# 🎨 Fine-tuning FunctionGemma for Square Color Control

This notebook demonstrates how to fine-tune FunctionGemma to recognize color control commands.

**Author:** Harlley Oliveira
**Portfolio:** AI Engineering

## Objectives
1. Train the model to call `set_square_color` when the user wants to change the color
2. Train the model to call `get_square_color` when the user asks about the current color
3. Support various natural language command styles

## 📦 1. Setup and Installation

In [None]:
# Install dependencies
%pip install -q torch tensorboard
%pip install -q transformers datasets accelerate evaluate trl protobuf sentencepiece

In [None]:
# Authenticate with the Hugging Face Hub. With no token argument, login()
# asks for one interactively. The training config later sets push_to_hub=True
# and the last cell calls trainer.push_to_hub(), both of which need this.
from huggingface_hub import login

login(token=None)

In [None]:
# --- Run configuration ---------------------------------------------------------
BASE_MODEL = "google/functiongemma-270m-it"  # Hub checkpoint to fine-tune
OUTPUT_DIR = "functiongemma-square-color"    # trainer output dir and final save location

# Optimisation hyperparameters
LEARNING_RATE = 5e-5
NUM_EPOCHS = 8
BATCH_SIZE = 4  # used as per_device_train_batch_size

## 📊 2. Prepare the Dataset in FunctionGemma's Native Format

In [None]:
import json
from datasets import Dataset

# Tool definitions. TOOLS is built from these functions with transformers'
# get_json_schema, which derives each tool's JSON schema from the function
# signature and its Google-style docstring. The docstring text therefore becomes
# the tool/parameter description the model is trained on: edit it deliberately.
def set_square_color(color: str) -> str:
    """
    Sets the color of the square displayed on the screen.

    Args:
        color: The color to set, e.g. red, blue, green
    """
    # In this notebook the function is only passed to get_json_schema; the
    # returned string is a placeholder, not an actual UI update.
    return f"Color set to {color}"

# Read-only tool that takes no arguments. This docstring is read by
# get_json_schema when TOOLS is built, so the "Use this when..." sentence is a
# routing hint the model actually sees in the tool declaration.
def get_square_color() -> str:
    """
    Returns the current color of the square.
    Use this when the user asks about the current color.
    """
    # Placeholder return value; within this notebook only the function's schema is used.
    return "Current color"

# Convert the Python tool stubs into JSON schemas. The resulting list is passed
# as `tools=` to apply_chat_template for training and for evaluation.
from transformers.utils import get_json_schema

TOOLS = [get_json_schema(tool_fn) for tool_fn in (set_square_color, get_square_color)]

print("Tool schemas:")
print(json.dumps(TOOLS, indent=2))

In [None]:
# Load the training dataset: a JSON list of records with "user_content",
# "tool_name" and "tool_arguments" (a JSON-encoded string) fields.
from collections import Counter

# Explicit UTF-8 so loading does not depend on the platform's default encoding.
with open("dataset/square_color_dataset.json", "r", encoding="utf-8") as f:
    square_color_dataset = json.load(f)

# One pass over the data instead of one list comprehension per tool.
tool_counts = Counter(record["tool_name"] for record in square_color_dataset)
print(f"Total examples: {len(square_color_dataset)}")
print(f" - SET: {tool_counts['set_square_color']}")
print(f" - GET: {tool_counts['get_square_color']}")

# Preview first few examples
print("\nFirst 3 examples:")
for i, sample in enumerate(square_color_dataset[:3], start=1):
    print(f" {i}. \"{sample['user_content']}\" → {sample['tool_name']}")

In [None]:
# FunctionGemma answers with a native call string of the form
# call:<function_name>{<arguments>} instead of free text or JSON.
# This developer-role prompt is shared by every training and evaluation conversation.
SYSTEM_PROMPT = (
    "You are a model that can do function calling with the following functions"
)

def format_function_call_output(tool_name: str, tool_arguments: dict, escape_token: str = "<escape>") -> str:
    """
    Render a tool call in FunctionGemma's native call syntax.

    FunctionGemma delimits string argument values with its ``<escape>`` token,
    e.g. ``call:set_square_color{color:<escape>blue<escape>}``. Non-string
    values (numbers, booleans) are written bare. A call with no arguments is
    ``call:get_square_color{}``.

    Args:
        tool_name: Name of the function being called.
        tool_arguments: Mapping of argument name to value. Empty or None means
            the call has no arguments.
        escape_token: Delimiter placed around string values. Pass "" to emit
            bare string values.

    Returns:
        The call string used as the assistant's target output.
    """
    if not tool_arguments:
        # `{{}}` renders a literal "{}" inside the f-string.
        return f"call:{tool_name}{{}}"

    args_str = ",".join(
        f"{key}:{escape_token}{value}{escape_token}" if isinstance(value, str) else f"{key}:{value}"
        for key, value in tool_arguments.items()
    )
    return f"call:{tool_name}{{{args_str}}}"

# Sanity-check the formatter on one call with arguments and one without.
print("Example outputs:")
for _example_name, _example_args in (
    ("set_square_color", {"color": "blue"}),
    ("get_square_color", {}),
):
    print(format_function_call_output(_example_name, _example_args))

In [None]:
from transformers import AutoTokenizer

# Load the tokenizer before the model: create_training_text needs its chat
# template (apply_chat_template) to render the training strings.
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=BASE_MODEL)

def create_training_text(sample):
    """
    Render one dataset record as a complete FunctionGemma training string.

    The record's tool call is converted to FunctionGemma's call syntax and used
    as the assistant turn, after the developer prompt and the user request.
    The chat template also injects the TOOLS declarations.

    Args:
        sample: Dataset row with "user_content", "tool_name" and
            "tool_arguments" (JSON-encoded string) fields.

    Returns:
        {"text": <templated conversation string>}, the shape Dataset.map expects.
    """
    target_call = format_function_call_output(
        sample["tool_name"], json.loads(sample["tool_arguments"])
    )

    conversation = [
        {"role": "developer", "content": SYSTEM_PROMPT},
        {"role": "user", "content": sample["user_content"]},
        # The assistant turn is the raw call string, not a structured tool_calls entry.
        {"role": "assistant", "content": target_call},
    ]

    # Passing tools= makes the template emit the function declarations in the prompt.
    rendered = tokenizer.apply_chat_template(
        conversation,
        tools=TOOLS,
        tokenize=False,
        add_generation_prompt=False,
    )
    return {"text": rendered}

# Render every record to text and drop the original columns, so only "text"
# (the field SFTTrainer trains on) remains.
raw_dataset = Dataset.from_list(square_color_dataset)
dataset = raw_dataset.map(
    create_training_text,
    remove_columns=raw_dataset.column_names,
    batched=False,
)

# Deterministic 80/20 train/test split (fixed seed for reproducibility).
dataset = dataset.train_test_split(test_size=0.2, shuffle=True, seed=42)

for split_name, label in (("train", "Train"), ("test", "Test")):
    print(f"{label}: {len(dataset[split_name])} examples")

In [None]:
# Print one fully templated training example to check the exact string the model is trained on.
separator = "=" * 60
print(separator)
print("FORMATTED TRAINING EXAMPLE")
print(separator)
print(dataset["train"][0]["text"])
print(separator)

## 🤖 3. Load Model

In [None]:
import torch
from transformers import AutoModelForCausalLM

print("Loading model for fine-tuning...")

# bfloat16 weights; device_map="auto" lets Accelerate choose device placement.
model_kwargs = dict(
    dtype=torch.bfloat16,
    device_map="auto",
    attn_implementation="eager",
)
model = AutoModelForCausalLM.from_pretrained(BASE_MODEL, **model_kwargs)

for label, value in (("Device", model.device), ("DType", model.dtype)):
    print(f"{label}: {value}")
print(f"Parameters: {model.num_parameters():,}")

## 🧪 3.5. Pre-Training Evaluation (Baseline)

In [None]:
import re

# Compiled once and reused for every evaluated response.
# A call looks like call:name{key:<escape>value<escape>,other:3}. String values
# sit between <escape> delimiters, so an escaped value may itself contain "}" or
# "," without ending the call. Matching stops at the first unescaped "}", so
# trailing generated text is ignored.
_FUNCTION_CALL_RE = re.compile(
    r"call:(\w+)\{((?:<escape>.*?<escape>|(?!<escape>)[^}])*)\}",
    re.DOTALL,
)
# key:<escape>value<escape>  (group 2)  or  key:bare_value  (group 3)
_FUNCTION_ARG_RE = re.compile(
    r"(\w+):(?:<escape>(.*?)<escape>|([^,}]*))",
    re.DOTALL,
)


def extract_function_call(text):
    """
    Extract the first function call from FunctionGemma output.

    Accepts escaped string values (``call:set_square_color{color:<escape>blue<escape>}``)
    as well as bare values (``call:set_square_color{color:blue}``). Any markup
    before the call (e.g. a start-of-call token) and text after it are ignored.

    Args:
        text: Decoded model response.

    Returns:
        (function_name, arguments_dict) with whitespace-stripped string values,
        or (None, None) if no call is found.
    """
    match = _FUNCTION_CALL_RE.search(text)
    if not match:
        return None, None

    func_name, args_str = match.group(1), match.group(2)

    args = {}
    for arg in _FUNCTION_ARG_RE.finditer(args_str):
        escaped, bare = arg.group(2), arg.group(3)
        # Exactly one alternative participates; test for None so an empty
        # escaped value ("") is not mistaken for a missing one.
        value = escaped if escaped is not None else bare
        args[arg.group(1)] = value.strip()

    return func_name, args

def evaluate_model(model, tokenizer, test_samples, tools, system_prompt, verbose=True, max_new_tokens=128):
    """
    Evaluate tool-calling accuracy of `model` on labelled samples.

    For each sample, the developer prompt and user message are rendered with
    the chat template (including `tools`), a completion is greedy-decoded, and
    the call in it is parsed with extract_function_call.

    Args:
        model: Causal LM to evaluate. It is switched to eval mode for
            generation, and its previous train/eval mode is restored afterwards.
        tokenizer: Tokenizer providing `apply_chat_template` and `decode`.
        test_samples: Records with "user_content", "tool_name" and
            "tool_arguments" (JSON-encoded string) fields.
        tools: Tool schemas passed to the chat template.
        system_prompt: Content of the "developer" turn.
        verbose: Print the summary accuracies when True.
        max_new_tokens: Generation budget per sample.

    Returns:
        dict with counts ("total", "correct", "correct_tool", "correct_args"),
        percentages ("tool_accuracy", "full_accuracy"; 0.0 for an empty
        sample list) and per-sample "details".
    """
    results = {
        "total": len(test_samples),
        "correct": 0,
        "correct_tool": 0,
        "correct_args": 0,
        "details": [],
    }

    was_training = model.training
    model.eval()  # disable any dropout during generation; restored in `finally`
    try:
        for sample in test_samples:
            messages = [
                {"role": "developer", "content": system_prompt},
                {"role": "user", "content": sample["user_content"]},
            ]

            inputs = tokenizer.apply_chat_template(
                messages,
                tools=tools,
                tokenize=True,
                add_generation_prompt=True,
                return_dict=True,
                return_tensors="pt",
            ).to(model.device)

            with torch.no_grad():
                output = model.generate(
                    **inputs,
                    max_new_tokens=max_new_tokens,
                    do_sample=False,
                )

            # Decode only the newly generated tokens. Special tokens are kept so
            # the call markup reaches the parser and appears in the logged response.
            input_length = inputs["input_ids"].shape[1]
            response = tokenizer.decode(output[0][input_length:], skip_special_tokens=False)

            called_func, called_args = extract_function_call(response)
            expected_args = json.loads(sample["tool_arguments"])

            tool_correct = called_func == sample["tool_name"]
            # Every expected argument must be present with a case-insensitively
            # equal value. A tool with no expected arguments (e.g.
            # get_square_color) only needs the right name.
            args_correct = tool_correct and all(
                str((called_args or {}).get(key, "")).lower() == str(value).lower()
                for key, value in expected_args.items()
            )

            if tool_correct:
                results["correct_tool"] += 1
            if args_correct:
                results["correct"] += 1
                results["correct_args"] += 1

            results["details"].append({
                "input": sample["user_content"],
                "expected_tool": sample["tool_name"],
                "expected_args": sample["tool_arguments"],
                "called_func": called_func,
                "called_args": called_args,
                "response": response,
                "tool_correct": tool_correct,
                "args_correct": args_correct,
            })
    finally:
        if was_training:
            model.train()

    total = results["total"]
    # Guard against an empty evaluation set instead of raising ZeroDivisionError.
    results["tool_accuracy"] = results["correct_tool"] / total * 100 if total else 0.0
    results["full_accuracy"] = results["correct"] / total * 100 if total else 0.0

    if verbose:
        print(f"Tool Accuracy: {results['correct_tool']}/{results['total']} ({results['tool_accuracy']:.1f}%)")
        print(f"Full Accuracy (tool + args): {results['correct']}/{results['total']} ({results['full_accuracy']:.1f}%)")

    return results

In [None]:
# Create evaluation test set.
# Evaluate on HELD-OUT examples only. Sampling from the full dataset would
# include rows the model is trained on and inflate the post-training scores.
import random

# Rebuild the raw 80/20 split with the same test_size/shuffle/seed as the split
# of the rendered `dataset`. That split only keeps the "text" column, so the raw
# records are recovered by splitting the same rows the same way.
# NOTE(review): this relies on train_test_split being deterministic for a given
# row count and seed; keep these arguments in sync with the training split.
_raw_split = Dataset.from_list(square_color_dataset).train_test_split(
    test_size=0.2, shuffle=True, seed=42
)
held_out_samples = list(_raw_split["test"])

# Local RNG: reproducible sampling without reseeding the global `random` module.
rng = random.Random(42)

set_samples = [s for s in held_out_samples if s["tool_name"] == "set_square_color"]
get_samples = [s for s in held_out_samples if s["tool_name"] == "get_square_color"]

# Up to `test_cases` examples per tool (fewer if the held-out split is small).
test_cases = 25
eval_test_cases = (
    rng.sample(set_samples, min(test_cases, len(set_samples)))
    + rng.sample(get_samples, min(test_cases, len(get_samples)))
)

print("=" * 50)
print("PRE-TRAINING EVALUATION (Baseline)")
print("=" * 50)
print(f"\nEvaluating base model on {len(eval_test_cases)} held-out test cases...\n")

baseline_results = evaluate_model(
    model=model,
    tokenizer=tokenizer,
    test_samples=eval_test_cases,
    tools=TOOLS,
    system_prompt=SYSTEM_PROMPT,
)

# Show sample outputs
print("\n--- Sample Outputs (Base Model) ---")
for detail in baseline_results["details"][:4]:
    status = "✅" if detail["tool_correct"] else "❌"
    print(f"\n{status} Input: {detail['input']}")
    print(f" Expected: {detail['expected_tool']}")
    print(f" Got: {detail['called_func']} with args {detail['called_args']}")

## 🔥 4. Fine-tuning

In [None]:
from trl import SFTConfig, SFTTrainer

torch_dtype = model.dtype

# Enable the mixed-precision mode that matches the dtype the weights were loaded in.
use_fp16 = torch_dtype == torch.float16
use_bf16 = torch_dtype == torch.bfloat16

# Training configuration
args = SFTConfig(
    output_dir=OUTPUT_DIR,
    # Data / sequence handling
    max_length=512,
    packing=False,
    dataset_text_field="text",  # column produced by create_training_text
    # Optimisation schedule
    num_train_epochs=NUM_EPOCHS,
    per_device_train_batch_size=BATCH_SIZE,
    learning_rate=LEARNING_RATE,
    lr_scheduler_type="constant",
    optim="adamw_torch_fused",
    gradient_checkpointing=False,
    fp16=use_fp16,
    bf16=use_bf16,
    # Evaluation, checkpointing and logging (eval and save cadences match,
    # as load_best_model_at_end requires)
    logging_steps=1,
    eval_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    report_to="tensorboard",
    push_to_hub=True,
)

trainer = SFTTrainer(
    model=model,
    args=args,
    train_dataset=dataset["train"],
    eval_dataset=dataset["test"],
    processing_class=tokenizer,
)

print("Trainer created successfully!")

In [None]:
# 🚀 Start training!
print("Starting fine-tuning...")
trainer.train()

print("\n✅ Training complete!")

In [None]:
# Write the final weights and the tokenizer to OUTPUT_DIR.
# The model was loaded in bfloat16 and save_pretrained writes the tensors in
# their current dtype, so the saved files stay half-precision rather than FP32.
save_dir = OUTPUT_DIR
model.save_pretrained(save_dir, safe_serialization=True)
tokenizer.save_pretrained(save_dir)
print(f"Model saved to: {save_dir}")

## 📈 5. Visualize Results

In [None]:
import matplotlib.pyplot as plt

# Plot training loss (logged every step, logging_steps=1) and validation loss
# (evaluated each epoch) against fractional epoch.
log_history = trainer.state.log_history

train_logs = [entry for entry in log_history if "loss" in entry]
eval_logs = [entry for entry in log_history if "eval_loss" in entry]

epoch_train = [entry["epoch"] for entry in train_logs]
train_losses = [entry["loss"] for entry in train_logs]
epoch_eval = [entry["epoch"] for entry in eval_logs]
eval_losses = [entry["eval_loss"] for entry in eval_logs]

fig, ax = plt.subplots(figsize=(10, 6))
ax.plot(epoch_train, train_losses, label="Training Loss", alpha=0.7)
ax.plot(epoch_eval, eval_losses, label="Validation Loss", marker="o")
ax.set_xlabel("Epoch")
ax.set_ylabel("Loss")
ax.set_title("Training and Validation Loss")
ax.legend()
ax.grid(True)
plt.show()

## 🧪 6. Post-Training Evaluation

In [None]:
print("=" * 50)
print("POST-TRAINING EVALUATION (Fine-tuned)")
print("=" * 50)
print(f"\nEvaluating fine-tuned model on {len(eval_test_cases)} test cases...\n")

# Same samples, tools and prompt as the baseline run, so the numbers are comparable.
finetuned_results = evaluate_model(
    model=model,
    tokenizer=tokenizer,
    test_samples=eval_test_cases,
    tools=TOOLS,
    system_prompt=SYSTEM_PROMPT,
)

# Show sample outputs
print("\n--- Sample Outputs (Fine-tuned Model) ---")
for detail in finetuned_results["details"][:4]:
    status = "✅" if detail["tool_correct"] else "❌"
    print(f"\n{status} Input: {detail['input']}")
    print(f" Expected: {detail['expected_tool']}")
    print(f" Got: {detail['called_func']} with args {detail['called_args']}")

In [None]:
# Compare baseline vs fine-tuned accuracy on the same evaluation set.
# The "Improvement" column is a difference in percentage points.
print("=" * 60)
print("📊 COMPARISON: Baseline vs Fine-tuned")
print("=" * 60)

print(f"\n{'Metric':<30} {'Baseline':>12} {'Fine-tuned':>12} {'Improvement':>12}")
print("-" * 66)

tool_improvement = finetuned_results["tool_accuracy"] - baseline_results["tool_accuracy"]
print(f"{'Tool Accuracy':<30} {baseline_results['tool_accuracy']:>11.1f}% {finetuned_results['tool_accuracy']:>11.1f}% {tool_improvement:>+11.1f}%")

full_improvement = finetuned_results["full_accuracy"] - baseline_results["full_accuracy"]
print(f"{'Full Accuracy (tool + args)':<30} {baseline_results['full_accuracy']:>11.1f}% {finetuned_results['full_accuracy']:>11.1f}% {full_improvement:>+11.1f}%")

print("-" * 66)

if full_improvement > 0:
    print(f"\n✅ Fine-tuning improved accuracy by {full_improvement:.1f} percentage points!")
elif full_improvement == 0:
    print("\n⚠️ No change in accuracy.")
else:
    print("\n❌ Accuracy decreased. Check for overfitting or data issues.")

## 📤 7. Push to Hugging Face Hub

In [None]:
# Push the trained model from the trainer's output directory to the Hub.
trainer.push_to_hub()

# Hub model pages live at https://huggingface.co/<repo_id>.
print(f"\n✅ Model pushed to: https://huggingface.co/{trainer.hub_model_id}")