| """Utility script for merging NexaSci LoRA adapters with the Falcon base model.""" | |
from __future__ import annotations

import argparse
import json
import sys
from pathlib import Path
from typing import Any, Dict, Optional

import torch
from peft import PeftModel
from safetensors import safe_open
from transformers import AutoModelForCausalLM, AutoTokenizer

DEFAULT_BASE_MODEL = "Allanatrix/Nexa_Sci_distilled_Falcon-10B"
DEFAULT_ADAPTER_PATH = Path(__file__).resolve().parents[1] / "models" / "adapter_model.safetensors"
DEFAULT_OUTPUT_DIR = Path(__file__).resolve().parents[1] / "models" / "merged"
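
# Note: the default paths above assume the script lives one directory below the
# repository root (hence parents[1]); pass --adapter-path / --output-dir explicitly
# if the layout differs.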


def _resolve_torch_dtype(dtype_str: str | None) -> torch.dtype | None:
    """Map a string representation of a torch dtype to the actual type."""
    if dtype_str is None:
        return None
    normalised = dtype_str.strip().lower()
    mapping: Dict[str, torch.dtype] = {
        "float16": torch.float16,
        "fp16": torch.float16,
        "half": torch.float16,
        "bfloat16": torch.bfloat16,
        "bf16": torch.bfloat16,
        "float32": torch.float32,
        "fp32": torch.float32,
    }
    if normalised not in mapping:
        raise ValueError(f"Unsupported torch dtype: {dtype_str}")
    return mapping[normalised]
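
# For reference, a quick sketch of how _resolve_torch_dtype behaves (not an
# exhaustive test):
#
#   _resolve_torch_dtype("bf16")  -> torch.bfloat16
#   _resolve_torch_dtype("half")  -> torch.float16
#   _resolve_torch_dtype(None)    -> None
#   _resolve_torch_dtype("int8")  -> raises ValueError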


def merge_lora_adapter(
    base_model_id: str,
    adapter_path: Path,
    output_dir: Path,
    torch_dtype: torch.dtype | None = None,
    trust_remote_code: bool = True,
    lora_rank: Optional[int] = None,
) -> None:
    """Merge a local LoRA adapter with the Falcon base model.

    Parameters
    ----------
    base_model_id:
        Hugging Face model ID or local path to the base Falcon model.
    adapter_path:
        Path to the LoRA adapter weights (safetensors directory or file).
    output_dir:
        Destination directory for the merged model.
    torch_dtype:
        Optional torch dtype override (float16, bfloat16, float32).
    trust_remote_code:
        Whether to allow custom model code when loading Falcon.
    lora_rank:
        Optional LoRA rank (r) override; if omitted, the rank is
        auto-detected from the adapter weights.
    """
    if not adapter_path.exists():
        raise FileNotFoundError(f"Adapter path does not exist: {adapter_path}")

    output_dir.mkdir(parents=True, exist_ok=True)

    print(f"Loading base model: {base_model_id}")
    model = AutoModelForCausalLM.from_pretrained(
        base_model_id,
        device_map="auto",
        torch_dtype=torch_dtype,
        low_cpu_mem_usage=True,
        trust_remote_code=trust_remote_code,
    )
    adapter_dir = adapter_path
    detected_rank = None
    if adapter_path.is_file():
        adapter_dir = adapter_path.parent

    config_path = adapter_dir / "adapter_config.json"
    # Always detect the rank from the adapter weights first. When a directory was
    # supplied, fall back to the standard PEFT weights filename inside it.
    weights_path = adapter_path if adapter_path.is_file() else adapter_dir / "adapter_model.safetensors"
    print("Detecting LoRA rank from adapter weights...")
    try:
        with safe_open(str(weights_path), framework="pt") as f:
            keys = list(f.keys())
            # Look for lora_A or lora_B weights to infer the rank
            for key in keys:
                if "lora_A" in key or "lora_B" in key:
                    tensor = f.get_tensor(key)
                    # lora_A shape is typically [rank, in_dim] or [in_dim, rank];
                    # lora_B shape is typically [out_dim, rank] or [rank, out_dim]
                    if len(tensor.shape) == 2:
                        # The rank is the smaller dimension
                        detected_rank = min(tensor.shape)
                        print(f"✓ Detected rank {detected_rank} from {key} (shape: {tensor.shape})")
                        break
    except Exception as e:
        print(f"Warning: Could not detect rank from adapter weights: {e}")
    # Use the provided rank if given, otherwise use the detected rank or the default
    if lora_rank is not None:
        detected_rank = lora_rank
        print(f"Using provided LoRA rank: {detected_rank}")
    elif detected_rank is None:
        detected_rank = 32
        print(f"Using default rank: {detected_rank}")
    # Check whether the existing config has the wrong rank
    recreate_config = True
    if config_path.exists():
        try:
            with config_path.open() as f:
                existing_config = json.load(f)
            existing_rank = existing_config.get("r")
            if existing_rank == detected_rank:
                print(f"✓ Existing adapter_config.json has correct rank {detected_rank}")
                recreate_config = False
            else:
                print(f"⚠ Existing config has rank {existing_rank}, but adapter needs rank {detected_rank}")
                print("  Recreating adapter_config.json...")
        except Exception as e:
            print(f"Warning: Could not read existing config: {e}")
    if recreate_config:
        print(f"Creating adapter_config.json for {adapter_path.name} with rank {detected_rank}")
        adapter_config = {
            "base_model_name_or_path": base_model_id,
            "bias": "none",
            "fan_in_fan_out": False,
            "inference_mode": True,
            "init_lora_weights": True,
            "lora_alpha": detected_rank * 2,  # Common practice: alpha = 2 * rank
            "lora_dropout": 0.0,
            "modules_to_save": None,
            "peft_type": "LORA",
            "r": detected_rank,
            "target_modules": ["q_proj", "k_proj", "v_proj", "o_proj"],
            "task_type": "CAUSAL_LM",
        }
        config_path.write_text(json.dumps(adapter_config, indent=2))
        print(f"✓ Created {config_path}")
| print(f"Loading adapter from: {adapter_dir}") | |
| peft_model = PeftModel.from_pretrained(model, str(adapter_dir), is_trainable=False) | |
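    # merge_and_unload() folds each LoRA update (B @ A, scaled by lora_alpha / r) into
    # the corresponding base weight matrices and strips the PEFT wrappers, so the saved
    # checkpoint can be reloaded with plain transformers and no peft dependency.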
| print("Merging adapter weights into base model...") | |
| merged_model = peft_model.merge_and_unload() | |
| print(f"Saving merged model to: {output_dir}") | |
| merged_model.save_pretrained(output_dir, safe_serialization=True) | |
| print("Saving tokenizer...") | |
| tokenizer = AutoTokenizer.from_pretrained(base_model_id, trust_remote_code=trust_remote_code) | |
| tokenizer.save_pretrained(output_dir) | |
| metadata: Dict[str, Any] = { | |
| "base_model": base_model_id, | |
| "adapter_path": str(adapter_path), | |
| "torch_dtype": str(torch_dtype) if torch_dtype is not None else "auto", | |
| } | |
| metadata_path = output_dir / "merge_metadata.json" | |
| metadata_path.write_text(json.dumps(metadata, indent=2)) | |
| print(f"Wrote merge metadata to {metadata_path}") | |


def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse command line arguments for the merge script."""
    parser = argparse.ArgumentParser(description="Merge NexaSci LoRA weights with Falcon base model.")
    parser.add_argument(
        "--base-model",
        default=DEFAULT_BASE_MODEL,
        help="Base model repository or path (default: %(default)s).",
    )
    parser.add_argument(
        "--adapter-path",
        default=str(DEFAULT_ADAPTER_PATH),
        help="Path to the LoRA adapter weights (default: %(default)s).",
    )
    parser.add_argument(
        "--output-dir",
        default=str(DEFAULT_OUTPUT_DIR),
        help="Directory to save the merged model (default: %(default)s).",
    )
    parser.add_argument(
        "--torch-dtype",
        default=None,
        choices=["float16", "fp16", "half", "bfloat16", "bf16", "float32", "fp32"],
        help="Optional torch dtype override for loading the base model.",
    )
    parser.add_argument(
        "--no-trust-remote-code",
        action="store_false",
        dest="trust_remote_code",
        help="Disable trust_remote_code when loading the base model.",
    )
    parser.add_argument(
        "--lora-rank",
        type=int,
        default=None,
        help="LoRA rank (r). If not specified, will be auto-detected from adapter weights.",
    )
    return parser.parse_args(argv)


def main(argv: list[str] | None = None) -> int:
    """Entry point for merging LoRA adapters with the base model."""
    args = parse_args(argv)
    adapter_path = Path(args.adapter_path)
    output_dir = Path(args.output_dir)
    torch_dtype = _resolve_torch_dtype(args.torch_dtype)
    try:
        merge_lora_adapter(
            base_model_id=args.base_model,
            adapter_path=adapter_path,
            output_dir=output_dir,
            torch_dtype=torch_dtype,
            trust_remote_code=args.trust_remote_code,
            lora_rank=args.lora_rank,
        )
    except Exception as exc:  # pragma: no cover - CLI helper
        print(f"Error merging model: {exc}", file=sys.stderr)
        return 1
    return 0


if __name__ == "__main__":  # pragma: no cover - CLI helper
    sys.exit(main())