{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# ๐ŸŽจ Fine-tuning FunctionGemma for Square Color Control\n", "\n", "This notebook demonstrates how to fine-tune FunctionGemma to recognize color control commands.\n", "\n", "**Author:** Harlley Oliveira\n", "**Portfolio:** AI Engineering\n", "\n", "## Objectives\n", "1. Train the model to call `set_square_color` when the user wants to change the color\n", "2. Train the model to call `get_square_color` when the user asks about the current color\n", "3. Support various natural language command styles" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## ๐Ÿ“ฆ 1. Setup and Installation" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Install dependencies\n", "%pip install -q torch tensorboard\n", "%pip install -q transformers datasets accelerate evaluate trl protobuf sentencepiece" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Login to Hugging Face Hub\n", "from huggingface_hub import login\n", "login()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Configuration\n", "BASE_MODEL = \"google/functiongemma-270m-it\"\n", "OUTPUT_DIR = \"functiongemma-square-color\"\n", "LEARNING_RATE = 5e-5\n", "NUM_EPOCHS = 8\n", "BATCH_SIZE = 4" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## ๐Ÿ“Š 2. Prepare Dataset with Correct Format" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import json\n", "from datasets import Dataset\n", "\n", "# Tool definitions (same as before)\n", "def set_square_color(color: str) -> str:\n", " \"\"\"\n", " Sets the color of the square displayed on the screen.\n", " \n", " Args:\n", " color: The color to set, e.g. red, blue, green\n", " \"\"\"\n", " return f\"Color set to {color}\"\n", "\n", "def get_square_color() -> str:\n", " \"\"\"\n", " Returns the current color of the square.\n", " Use this when the user asks about the current color.\n", " \"\"\"\n", " return \"Current color\"\n", "\n", "# Get JSON schemas\n", "from transformers.utils import get_json_schema\n", "TOOLS = [\n", " get_json_schema(set_square_color),\n", " get_json_schema(get_square_color)\n", "]\n", "\n", "print(\"Tool schemas:\")\n", "print(json.dumps(TOOLS, indent=2))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Load training dataset\n", "with open(\"dataset/square_color_dataset.json\", \"r\") as f:\n", " square_color_dataset = json.load(f)\n", "\n", "print(f\"Total examples: {len(square_color_dataset)}\")\n", "print(f\" - SET: {len([x for x in square_color_dataset if x['tool_name'] == 'set_square_color'])}\")\n", "print(f\" - GET: {len([x for x in square_color_dataset if x['tool_name'] == 'get_square_color'])}\")\n", "\n", "# Preview first few examples\n", "print(\"\\nFirst 3 examples:\")\n", "for i, sample in enumerate(square_color_dataset[:3]):\n", " print(f\" {i+1}. 
\\\"{sample['user_content']}\\\" โ†’ {sample['tool_name']}\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# CRITICAL: FunctionGemma's expected output format\n", "# The model should output: call:func{args}\n", "\n", "SYSTEM_PROMPT = \"You are a model that can do function calling with the following functions\"\n", "\n", "def format_function_call_output(tool_name: str, tool_arguments: dict) -> str:\n", " \"\"\"\n", " Format the expected output in FunctionGemma's native format.\n", " \n", " FunctionGemma outputs: call:func_name{arg:value}\n", " \"\"\"\n", " if not tool_arguments:\n", " # For functions with no arguments\n", " return f\"call:{tool_name}{{}}\"\n", " \n", " # Format arguments with tokens for string values\n", " args_parts = []\n", " for key, value in tool_arguments.items():\n", " if isinstance(value, str):\n", " args_parts.append(f\"{key}:{value}\")\n", " else:\n", " args_parts.append(f\"{key}:{value}\")\n", " \n", " args_str = \",\".join(args_parts)\n", " return f\"call:{tool_name}{{{args_str}}}\"\n", "\n", "# Test the format\n", "print(\"Example outputs:\")\n", "print(format_function_call_output(\"set_square_color\", {\"color\": \"blue\"}))\n", "print(format_function_call_output(\"get_square_color\", {}))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from transformers import AutoTokenizer\n", "\n", "# Load tokenizer first to use apply_chat_template\n", "tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)\n", "\n", "def create_training_text(sample):\n", " \"\"\"\n", " Create the full training text using FunctionGemma's chat template.\n", " \n", " The key is that we format the assistant's response in FunctionGemma's\n", " native function call format.\n", " \"\"\"\n", " tool_args = json.loads(sample[\"tool_arguments\"])\n", " expected_output = format_function_call_output(sample[\"tool_name\"], tool_args)\n", " \n", " # Create messages - note: assistant content is the raw function call format\n", " messages = [\n", " {\"role\": \"developer\", \"content\": SYSTEM_PROMPT},\n", " {\"role\": \"user\", \"content\": sample[\"user_content\"]},\n", " {\"role\": \"assistant\", \"content\": expected_output},\n", " ]\n", " \n", " # Apply chat template WITH tools to get proper function declarations\n", " text = tokenizer.apply_chat_template(\n", " messages,\n", " tools=TOOLS,\n", " tokenize=False,\n", " add_generation_prompt=False\n", " )\n", " \n", " return {\"text\": text}\n", "\n", "# Create dataset\n", "dataset = Dataset.from_list(square_color_dataset)\n", "dataset = dataset.map(create_training_text, remove_columns=dataset.features, batched=False)\n", "\n", "# Split 80/20\n", "dataset = dataset.train_test_split(test_size=0.2, shuffle=True, seed=42)\n", "\n", "print(f\"Train: {len(dataset['train'])} examples\")\n", "print(f\"Test: {len(dataset['test'])} examples\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Visualize a formatted example\n", "print(\"=\" * 60)\n", "print(\"FORMATTED TRAINING EXAMPLE\")\n", "print(\"=\" * 60)\n", "print(dataset[\"train\"][0][\"text\"])\n", "print(\"=\" * 60)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## ๐Ÿค– 3. 
Load Model" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import torch\n", "from transformers import AutoModelForCausalLM\n", "\n", "print(\"Loading model for fine-tuning...\")\n", "\n", "model = AutoModelForCausalLM.from_pretrained(\n", " BASE_MODEL,\n", " dtype=torch.bfloat16,\n", " device_map=\"auto\",\n", " attn_implementation=\"eager\"\n", ")\n", "\n", "print(f\"Device: {model.device}\")\n", "print(f\"DType: {model.dtype}\")\n", "print(f\"Parameters: {model.num_parameters():,}\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## ๐Ÿงช 3.5. Pre-Training Evaluation (Baseline)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import re\n", "\n", "def extract_function_call(text):\n", " \"\"\"\n", " Extract function call from FunctionGemma's output format.\n", " Returns (function_name, arguments_dict) or (None, None) if not found.\n", " \"\"\"\n", " pattern = r\"call:(\\w+)\\{(.*)\\}\"\n", " match = re.search(pattern, text, re.DOTALL)\n", " \n", " if not match:\n", " return None, None\n", " \n", " func_name = match.group(1)\n", " args_str = match.group(2)\n", " \n", " # Parse arguments\n", " args = {}\n", " if args_str.strip():\n", " # Match key:value or key:value patterns\n", " arg_pattern = r\"(\\w+):(?:(.*?)|([^,}]*))\"\n", " for m in re.finditer(arg_pattern, args_str):\n", " key = m.group(1)\n", " value = m.group(2) if m.group(2) else m.group(3)\n", " args[key] = value.strip() if value else \"\"\n", " \n", " return func_name, args\n", "\n", "def evaluate_model(model, tokenizer, test_samples, tools, system_prompt, verbose=True):\n", " \"\"\"\n", " Evaluate model on test samples using FunctionGemma's format.\n", " \"\"\"\n", " results = {\n", " \"total\": len(test_samples),\n", " \"correct\": 0,\n", " \"correct_tool\": 0,\n", " \"correct_args\": 0,\n", " \"details\": []\n", " }\n", " \n", " for sample in test_samples:\n", " messages = [\n", " {\"role\": \"developer\", \"content\": system_prompt},\n", " {\"role\": \"user\", \"content\": sample[\"user_content\"]},\n", " ]\n", " \n", " inputs = tokenizer.apply_chat_template(\n", " messages,\n", " tools=tools,\n", " tokenize=True,\n", " add_generation_prompt=True,\n", " return_dict=True,\n", " return_tensors=\"pt\"\n", " ).to(model.device)\n", " \n", " with torch.no_grad():\n", " output = model.generate(\n", " **inputs,\n", " max_new_tokens=128,\n", " do_sample=False,\n", " )\n", " \n", " input_length = inputs['input_ids'].shape[1]\n", " response = tokenizer.decode(output[0][input_length:], skip_special_tokens=False)\n", " \n", " # Parse the function call from response\n", " called_func, called_args = extract_function_call(response)\n", " \n", " # Check if correct tool was called\n", " tool_correct = called_func == sample[\"tool_name\"]\n", " \n", " # Check arguments\n", " args_correct = False\n", " expected_args = json.loads(sample[\"tool_arguments\"])\n", " \n", " if tool_correct:\n", " if sample[\"tool_name\"] == \"get_square_color\":\n", " args_correct = True # No args needed\n", " elif called_args and \"color\" in called_args:\n", " args_correct = called_args.get(\"color\", \"\").lower() == expected_args.get(\"color\", \"\").lower()\n", " \n", " if tool_correct:\n", " results[\"correct_tool\"] += 1\n", " if tool_correct and args_correct:\n", " results[\"correct\"] += 1\n", " results[\"correct_args\"] += 1\n", " \n", " results[\"details\"].append({\n", " \"input\": sample[\"user_content\"],\n", " \"expected_tool\": 
sample[\"tool_name\"],\n", " \"expected_args\": sample[\"tool_arguments\"],\n", " \"called_func\": called_func,\n", " \"called_args\": called_args,\n", " \"response\": response,\n", " \"tool_correct\": tool_correct,\n", " \"args_correct\": args_correct\n", " })\n", " \n", " results[\"tool_accuracy\"] = results[\"correct_tool\"] / results[\"total\"] * 100\n", " results[\"full_accuracy\"] = results[\"correct\"] / results[\"total\"] * 100\n", " \n", " if verbose:\n", " print(f\"Tool Accuracy: {results['correct_tool']}/{results['total']} ({results['tool_accuracy']:.1f}%)\")\n", " print(f\"Full Accuracy (tool + args): {results['correct']}/{results['total']} ({results['full_accuracy']:.1f}%)\")\n", " \n", " return results" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Create evaluation test set\n", "import random\n", "\n", "random.seed(42)\n", "\n", "set_samples = [s for s in square_color_dataset if s[\"tool_name\"] == \"set_square_color\"]\n", "get_samples = [s for s in square_color_dataset if s[\"tool_name\"] == \"get_square_color\"]\n", "\n", "test_cases = 25\n", "eval_test_cases = random.sample(set_samples, min(test_cases, len(set_samples))) + \\\n", " random.sample(get_samples, min(test_cases, len(get_samples)))\n", "\n", "print(\"=\" * 50)\n", "print(\"PRE-TRAINING EVALUATION (Baseline)\")\n", "print(\"=\" * 50)\n", "print(f\"\\nEvaluating base model on {len(eval_test_cases)} test cases...\\n\")\n", "\n", "baseline_results = evaluate_model(\n", " model=model,\n", " tokenizer=tokenizer,\n", " test_samples=eval_test_cases,\n", " tools=TOOLS,\n", " system_prompt=SYSTEM_PROMPT\n", ")\n", "\n", "# Show sample outputs\n", "print(\"\\n--- Sample Outputs (Base Model) ---\")\n", "for i, detail in enumerate(baseline_results[\"details\"][:4]):\n", " status = \"โœ…\" if detail[\"tool_correct\"] else \"โŒ\"\n", " print(f\"\\n{status} Input: {detail['input']}\")\n", " print(f\" Expected: {detail['expected_tool']}\")\n", " print(f\" Got: {detail['called_func']} with args {detail['called_args']}\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## ๐Ÿ”ฅ 4. 
Fine-tuning" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from trl import SFTConfig, SFTTrainer\n", "\n", "torch_dtype = model.dtype\n", "\n", "# Training configuration\n", "args = SFTConfig(\n", " output_dir=OUTPUT_DIR,\n", " max_length=512,\n", " packing=False,\n", " num_train_epochs=NUM_EPOCHS,\n", " per_device_train_batch_size=BATCH_SIZE,\n", " gradient_checkpointing=False,\n", " optim=\"adamw_torch_fused\",\n", " logging_steps=1,\n", " eval_strategy=\"epoch\",\n", " save_strategy=\"epoch\",\n", " learning_rate=LEARNING_RATE,\n", " fp16=True if torch_dtype == torch.float16 else False,\n", " bf16=True if torch_dtype == torch.bfloat16 else False,\n", " lr_scheduler_type=\"constant\",\n", " push_to_hub=True,\n", " report_to=\"tensorboard\",\n", " load_best_model_at_end=True,\n", " metric_for_best_model=\"eval_loss\",\n", " dataset_text_field=\"text\", # IMPORTANT: specify the text field\n", ")\n", "\n", "# Create trainer\n", "trainer = SFTTrainer(\n", " model=model,\n", " args=args,\n", " train_dataset=dataset['train'],\n", " eval_dataset=dataset['test'],\n", " processing_class=tokenizer,\n", ")\n", "\n", "print(\"Trainer created successfully!\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# ๐Ÿš€ Start training!\n", "print(\"Starting fine-tuning...\")\n", "trainer.train()\n", "\n", "print(\"\\nโœ… Training complete!\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Save final model in the original dtype (BF16)\n", "# This prevents the model from being saved as FP32 (which doubles the size)\n", "model.save_pretrained(OUTPUT_DIR, safe_serialization=True)\n", "tokenizer.save_pretrained(OUTPUT_DIR)\n", "print(f\"Model saved to: {OUTPUT_DIR}\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## ๐Ÿ“ˆ 5. Visualize Results" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import matplotlib.pyplot as plt\n", "\n", "# Extract loss history\n", "log_history = trainer.state.log_history\n", "\n", "train_losses = [log[\"loss\"] for log in log_history if \"loss\" in log]\n", "epoch_train = [log[\"epoch\"] for log in log_history if \"loss\" in log]\n", "eval_losses = [log[\"eval_loss\"] for log in log_history if \"eval_loss\" in log]\n", "epoch_eval = [log[\"epoch\"] for log in log_history if \"eval_loss\" in log]\n", "\n", "# Plot\n", "plt.figure(figsize=(10, 6))\n", "plt.plot(epoch_train, train_losses, label=\"Training Loss\", alpha=0.7)\n", "plt.plot(epoch_eval, eval_losses, label=\"Validation Loss\", marker='o')\n", "plt.xlabel(\"Epoch\")\n", "plt.ylabel(\"Loss\")\n", "plt.title(\"Training and Validation Loss\")\n", "plt.legend()\n", "plt.grid(True)\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## ๐Ÿงช 6. 
Post-Training Evaluation" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "print(\"=\" * 50)\n", "print(\"POST-TRAINING EVALUATION (Fine-tuned)\")\n", "print(\"=\" * 50)\n", "print(f\"\\nEvaluating fine-tuned model on {len(eval_test_cases)} test cases...\\n\")\n", "\n", "finetuned_results = evaluate_model(\n", " model=model,\n", " tokenizer=tokenizer,\n", " test_samples=eval_test_cases,\n", " tools=TOOLS,\n", " system_prompt=SYSTEM_PROMPT\n", ")\n", "\n", "# Show sample outputs\n", "print(\"\\n--- Sample Outputs (Fine-tuned Model) ---\")\n", "for i, detail in enumerate(finetuned_results[\"details\"][:4]):\n", " status = \"โœ…\" if detail[\"tool_correct\"] else \"โŒ\"\n", " print(f\"\\n{status} Input: {detail['input']}\")\n", " print(f\" Expected: {detail['expected_tool']}\")\n", " print(f\" Got: {detail['called_func']} with args {detail['called_args']}\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Compare baseline vs fine-tuned\n", "print(\"=\" * 60)\n", "print(\"๐Ÿ“Š COMPARISON: Baseline vs Fine-tuned\")\n", "print(\"=\" * 60)\n", "\n", "print(f\"\\n{'Metric':<30} {'Baseline':>12} {'Fine-tuned':>12} {'Improvement':>12}\")\n", "print(\"-\" * 66)\n", "\n", "tool_improvement = finetuned_results[\"tool_accuracy\"] - baseline_results[\"tool_accuracy\"]\n", "print(f\"{'Tool Accuracy':<30} {baseline_results['tool_accuracy']:>11.1f}% {finetuned_results['tool_accuracy']:>11.1f}% {tool_improvement:>+11.1f}%\")\n", "\n", "full_improvement = finetuned_results[\"full_accuracy\"] - baseline_results[\"full_accuracy\"]\n", "print(f\"{'Full Accuracy (tool + args)':<30} {baseline_results['full_accuracy']:>11.1f}% {finetuned_results['full_accuracy']:>11.1f}% {full_improvement:>+11.1f}%\")\n", "\n", "print(\"-\" * 66)\n", "\n", "if full_improvement > 0:\n", " print(f\"\\nโœ… Fine-tuning improved accuracy by {full_improvement:.1f} percentage points!\")\n", "elif full_improvement == 0:\n", " print(f\"\\nโš ๏ธ No change in accuracy.\")\n", "else:\n", " print(f\"\\nโŒ Accuracy decreased. Check for overfitting or data issues.\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## ๐Ÿ“ค 7. Push to Hugging Face Hub" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Push to Hub\n", "trainer.push_to_hub()\n", "\n", "print(f\"\\nโœ… Model pushed to: https://huggingface.co/{trainer.hub_model_id}\")" ] } ], "metadata": { "accelerator": "GPU", "colab": { "gpuType": "T4", "provenance": [] }, "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "name": "python", "version": "3.10.0" } }, "nbformat": 4, "nbformat_minor": 4 }