"""Utility script for merging NexaSci LoRA adapters with the Falcon base model.""" from __future__ import annotations import argparse import json import sys from pathlib import Path from typing import Any, Dict, Optional import torch from peft import PeftModel from safetensors import safe_open from transformers import AutoModelForCausalLM, AutoTokenizer DEFAULT_BASE_MODEL = "Allanatrix/Nexa_Sci_distilled_Falcon-10B" DEFAULT_ADAPTER_PATH = Path(__file__).resolve().parents[1] / "models" / "adapter_model.safetensors" DEFAULT_OUTPUT_DIR = Path(__file__).resolve().parents[1] / "models" / "merged" def _resolve_torch_dtype(dtype_str: str | None) -> torch.dtype | None: """Map a string representation of a torch dtype to the actual type.""" if dtype_str is None: return None normalised = dtype_str.strip().lower() mapping: Dict[str, torch.dtype] = { "float16": torch.float16, "fp16": torch.float16, "half": torch.float16, "bfloat16": torch.bfloat16, "bf16": torch.bfloat16, "float32": torch.float32, "fp32": torch.float32, } if normalised not in mapping: raise ValueError(f"Unsupported torch dtype: {dtype_str}") return mapping[normalised] def merge_lora_adapter( base_model_id: str, adapter_path: Path, output_dir: Path, torch_dtype: torch.dtype | None = None, trust_remote_code: bool = True, lora_rank: Optional[int] = None, ) -> None: """Merge a local LoRA adapter with the Falcon base model. Parameters ---------- base_model_id: Hugging Face model ID or local path to the base Falcon model. adapter_path: Path to the LoRA adapter weights (safetensors directory or file). output_dir: Destination directory for the merged model. torch_dtype: Optional torch dtype override (float16, bfloat16, float32). trust_remote_code: Whether to allow custom model code when loading Falcon. 
""" if not adapter_path.exists(): raise FileNotFoundError(f"Adapter path does not exist: {adapter_path}") output_dir.mkdir(parents=True, exist_ok=True) print(f"Loading base model: {base_model_id}") model = AutoModelForCausalLM.from_pretrained( base_model_id, device_map="auto", torch_dtype=torch_dtype, low_cpu_mem_usage=True, trust_remote_code=trust_remote_code, ) adapter_dir = adapter_path detected_rank = None if adapter_path.is_file(): adapter_dir = adapter_path.parent config_path = adapter_dir / "adapter_config.json" # Always detect rank from adapter weights first print(f"Detecting LoRA rank from adapter weights...") try: with safe_open(adapter_path, framework="pt") as f: keys = list(f.keys()) # Look for lora_A or lora_B weights to infer rank for key in keys: if "lora_A" in key or "lora_B" in key: tensor = f.get_tensor(key) # lora_A shape is typically [rank, in_dim] or [in_dim, rank] # lora_B shape is typically [out_dim, rank] or [rank, out_dim] if len(tensor.shape) == 2: # Rank is the smaller dimension detected_rank = min(tensor.shape) print(f"✓ Detected rank {detected_rank} from {key} (shape: {tensor.shape})") break except Exception as e: print(f"Warning: Could not detect rank from adapter weights: {e}") # Use provided rank if given, otherwise use detected or default if lora_rank is not None: detected_rank = lora_rank print(f"Using provided LoRA rank: {detected_rank}") elif detected_rank is None: detected_rank = 32 print(f"Using default rank: {detected_rank}") # Check if existing config has wrong rank recreate_config = True if config_path.exists(): try: with config_path.open() as f: existing_config = json.load(f) existing_rank = existing_config.get("r") if existing_rank == detected_rank: print(f"✓ Existing adapter_config.json has correct rank {detected_rank}") recreate_config = False else: print(f"⚠ Existing config has rank {existing_rank}, but adapter needs rank {detected_rank}") print(f" Recreating adapter_config.json...") except Exception as e: print(f"Warning: Could not read existing config: {e}") if recreate_config: print(f"Creating adapter_config.json for {adapter_path.name} with rank {detected_rank}") adapter_config = { "base_model_name_or_path": base_model_id, "bias": "none", "fan_in_fan_out": False, "inference_mode": True, "init_lora_weights": True, "lora_alpha": detected_rank * 2, # Common practice: alpha = 2 * rank "lora_dropout": 0.0, "modules_to_save": None, "peft_type": "LORA", "r": detected_rank, "target_modules": ["q_proj", "k_proj", "v_proj", "o_proj"], "task_type": "CAUSAL_LM", } config_path.write_text(json.dumps(adapter_config, indent=2)) print(f"✓ Created {config_path}") print(f"Loading adapter from: {adapter_dir}") peft_model = PeftModel.from_pretrained(model, str(adapter_dir), is_trainable=False) print("Merging adapter weights into base model...") merged_model = peft_model.merge_and_unload() print(f"Saving merged model to: {output_dir}") merged_model.save_pretrained(output_dir, safe_serialization=True) print("Saving tokenizer...") tokenizer = AutoTokenizer.from_pretrained(base_model_id, trust_remote_code=trust_remote_code) tokenizer.save_pretrained(output_dir) metadata: Dict[str, Any] = { "base_model": base_model_id, "adapter_path": str(adapter_path), "torch_dtype": str(torch_dtype) if torch_dtype is not None else "auto", } metadata_path = output_dir / "merge_metadata.json" metadata_path.write_text(json.dumps(metadata, indent=2)) print(f"Wrote merge metadata to {metadata_path}") def parse_args(argv: list[str] | None = None) -> argparse.Namespace: """Parse command 


def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse command-line arguments for the merge script."""
    parser = argparse.ArgumentParser(description="Merge NexaSci LoRA weights with Falcon base model.")
    parser.add_argument(
        "--base-model",
        default=DEFAULT_BASE_MODEL,
        help="Base model repository or path (default: %(default)s).",
    )
    parser.add_argument(
        "--adapter-path",
        default=str(DEFAULT_ADAPTER_PATH),
        help="Path to the LoRA adapter weights (default: %(default)s).",
    )
    parser.add_argument(
        "--output-dir",
        default=str(DEFAULT_OUTPUT_DIR),
        help="Directory to save the merged model (default: %(default)s).",
    )
    parser.add_argument(
        "--torch-dtype",
        default=None,
        choices=["float16", "fp16", "half", "bfloat16", "bf16", "float32", "fp32"],
        help="Optional torch dtype override for loading the base model.",
    )
    parser.add_argument(
        "--no-trust-remote-code",
        action="store_false",
        dest="trust_remote_code",
        help="Disable trust_remote_code when loading the base model.",
    )
    parser.add_argument(
        "--lora-rank",
        type=int,
        default=None,
        help="LoRA rank (r). If not specified, will be auto-detected from adapter weights.",
    )
    return parser.parse_args(argv)


def main(argv: list[str] | None = None) -> int:
    """Entry point for merging LoRA adapters with the base model."""
    args = parse_args(argv)
    adapter_path = Path(args.adapter_path)
    output_dir = Path(args.output_dir)
    torch_dtype = _resolve_torch_dtype(args.torch_dtype)
    try:
        merge_lora_adapter(
            base_model_id=args.base_model,
            adapter_path=adapter_path,
            output_dir=output_dir,
            torch_dtype=torch_dtype,
            trust_remote_code=args.trust_remote_code,
            lora_rank=args.lora_rank,
        )
    except Exception as exc:  # pragma: no cover - CLI helper
        print(f"Error merging model: {exc}", file=sys.stderr)
        return 1
    return 0


if __name__ == "__main__":  # pragma: no cover - CLI helper
    sys.exit(main())
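

# Example invocation (the script path below is illustrative; the flags match
# ``parse_args`` above and every one of them has a sensible default):
#
#   python scripts/merge_lora.py \
#       --base-model Allanatrix/Nexa_Sci_distilled_Falcon-10B \
#       --adapter-path models/adapter_model.safetensors \
#       --output-dir models/merged \
#       --torch-dtype bfloat16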