"""Utility script for merging NexaSci LoRA adapters with the Falcon base model."""
from __future__ import annotations
import argparse
import json
import sys
from pathlib import Path
from typing import Any, Dict, Optional
import torch
from peft import PeftModel
from safetensors import safe_open
from transformers import AutoModelForCausalLM, AutoTokenizer
DEFAULT_BASE_MODEL = "Allanatrix/Nexa_Sci_distilled_Falcon-10B"
DEFAULT_ADAPTER_PATH = Path(__file__).resolve().parents[1] / "models" / "adapter_model.safetensors"
DEFAULT_OUTPUT_DIR = Path(__file__).resolve().parents[1] / "models" / "merged"
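# The defaults above assume this script lives in <repo>/scripts with the adapter
# weights stored under <repo>/models.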
def _resolve_torch_dtype(dtype_str: str | None) -> torch.dtype | None:
"""Map a string representation of a torch dtype to the actual type."""
if dtype_str is None:
return None
normalised = dtype_str.strip().lower()
mapping: Dict[str, torch.dtype] = {
"float16": torch.float16,
"fp16": torch.float16,
"half": torch.float16,
"bfloat16": torch.bfloat16,
"bf16": torch.bfloat16,
"float32": torch.float32,
"fp32": torch.float32,
}
if normalised not in mapping:
raise ValueError(f"Unsupported torch dtype: {dtype_str}")
return mapping[normalised]
def merge_lora_adapter(
base_model_id: str,
adapter_path: Path,
output_dir: Path,
torch_dtype: torch.dtype | None = None,
trust_remote_code: bool = True,
lora_rank: Optional[int] = None,
) -> None:
"""Merge a local LoRA adapter with the Falcon base model.
Parameters
----------
base_model_id:
Hugging Face model ID or local path to the base Falcon model.
adapter_path:
Path to the LoRA adapter weights (safetensors directory or file).
output_dir:
Destination directory for the merged model.
torch_dtype:
Optional torch dtype override (float16, bfloat16, float32).
trust_remote_code:
Whether to allow custom model code when loading Falcon.
"""
if not adapter_path.exists():
raise FileNotFoundError(f"Adapter path does not exist: {adapter_path}")
output_dir.mkdir(parents=True, exist_ok=True)
print(f"Loading base model: {base_model_id}")
model = AutoModelForCausalLM.from_pretrained(
base_model_id,
device_map="auto",
torch_dtype=torch_dtype,
low_cpu_mem_usage=True,
trust_remote_code=trust_remote_code,
)
    adapter_dir = adapter_path
    weights_path = adapter_path
    detected_rank = None
    if adapter_path.is_file():
        adapter_dir = adapter_path.parent
    else:
        # A directory was supplied; expect the standard PEFT weights filename inside it.
        weights_path = adapter_dir / "adapter_model.safetensors"
    config_path = adapter_dir / "adapter_config.json"
    # Always detect the rank from the adapter weights first
    print("Detecting LoRA rank from adapter weights...")
    try:
        with safe_open(weights_path, framework="pt") as f:
keys = list(f.keys())
# Look for lora_A or lora_B weights to infer rank
for key in keys:
if "lora_A" in key or "lora_B" in key:
tensor = f.get_tensor(key)
                    # In PEFT, lora_A has shape [r, in_features] and
                    # lora_B has shape [out_features, r].
if len(tensor.shape) == 2:
# Rank is the smaller dimension
detected_rank = min(tensor.shape)
print(f"✓ Detected rank {detected_rank} from {key} (shape: {tensor.shape})")
break
except Exception as e:
print(f"Warning: Could not detect rank from adapter weights: {e}")
# Use provided rank if given, otherwise use detected or default
if lora_rank is not None:
detected_rank = lora_rank
print(f"Using provided LoRA rank: {detected_rank}")
elif detected_rank is None:
detected_rank = 32
print(f"Using default rank: {detected_rank}")
# Check if existing config has wrong rank
recreate_config = True
if config_path.exists():
try:
with config_path.open() as f:
existing_config = json.load(f)
existing_rank = existing_config.get("r")
if existing_rank == detected_rank:
print(f"✓ Existing adapter_config.json has correct rank {detected_rank}")
recreate_config = False
else:
print(f"⚠ Existing config has rank {existing_rank}, but adapter needs rank {detected_rank}")
print(f" Recreating adapter_config.json...")
except Exception as e:
print(f"Warning: Could not read existing config: {e}")
if recreate_config:
print(f"Creating adapter_config.json for {adapter_path.name} with rank {detected_rank}")
adapter_config = {
"base_model_name_or_path": base_model_id,
"bias": "none",
"fan_in_fan_out": False,
"inference_mode": True,
"init_lora_weights": True,
"lora_alpha": detected_rank * 2, # Common practice: alpha = 2 * rank
"lora_dropout": 0.0,
"modules_to_save": None,
"peft_type": "LORA",
"r": detected_rank,
"target_modules": ["q_proj", "k_proj", "v_proj", "o_proj"],
"task_type": "CAUSAL_LM",
}
config_path.write_text(json.dumps(adapter_config, indent=2))
print(f"✓ Created {config_path}")
print(f"Loading adapter from: {adapter_dir}")
peft_model = PeftModel.from_pretrained(model, str(adapter_dir), is_trainable=False)
print("Merging adapter weights into base model...")
merged_model = peft_model.merge_and_unload()
print(f"Saving merged model to: {output_dir}")
merged_model.save_pretrained(output_dir, safe_serialization=True)
print("Saving tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(base_model_id, trust_remote_code=trust_remote_code)
tokenizer.save_pretrained(output_dir)
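    # The merged directory is a regular Hugging Face checkpoint and can be loaded
    # directly with AutoModelForCausalLM.from_pretrained(output_dir).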
metadata: Dict[str, Any] = {
"base_model": base_model_id,
"adapter_path": str(adapter_path),
"torch_dtype": str(torch_dtype) if torch_dtype is not None else "auto",
}
metadata_path = output_dir / "merge_metadata.json"
metadata_path.write_text(json.dumps(metadata, indent=2))
print(f"Wrote merge metadata to {metadata_path}")
def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
"""Parse command line arguments for the merge script."""
parser = argparse.ArgumentParser(description="Merge NexaSci LoRA weights with Falcon base model.")
parser.add_argument(
"--base-model",
default=DEFAULT_BASE_MODEL,
help="Base model repository or path (default: %(default)s).",
)
parser.add_argument(
"--adapter-path",
default=str(DEFAULT_ADAPTER_PATH),
help="Path to the LoRA adapter weights (default: %(default)s).",
)
parser.add_argument(
"--output-dir",
default=str(DEFAULT_OUTPUT_DIR),
help="Directory to save the merged model (default: %(default)s).",
)
parser.add_argument(
"--torch-dtype",
default=None,
choices=["float16", "fp16", "half", "bfloat16", "bf16", "float32", "fp32"],
help="Optional torch dtype override for loading the base model.",
)
parser.add_argument(
"--no-trust-remote-code",
action="store_false",
dest="trust_remote_code",
help="Disable trust_remote_code when loading the base model.",
)
parser.add_argument(
"--lora-rank",
type=int,
default=None,
help="LoRA rank (r). If not specified, will be auto-detected from adapter weights.",
)
return parser.parse_args(argv)
def main(argv: list[str] | None = None) -> int:
"""Entry point for merging LoRA adapters with the base model."""
args = parse_args(argv)
adapter_path = Path(args.adapter_path)
output_dir = Path(args.output_dir)
torch_dtype = _resolve_torch_dtype(args.torch_dtype)
try:
merge_lora_adapter(
base_model_id=args.base_model,
adapter_path=adapter_path,
output_dir=output_dir,
torch_dtype=torch_dtype,
trust_remote_code=args.trust_remote_code,
lora_rank=args.lora_rank,
)
except Exception as exc: # pragma: no cover - CLI helper
print(f"Error merging model: {exc}", file=sys.stderr)
return 1
return 0
if __name__ == "__main__": # pragma: no cover - CLI helper
sys.exit(main())