harlley committed on
Commit
0e477bf
·
1 Parent(s): f0bc0bb

play around with finetuning

finetuning-functiongemma/README.md ADDED
@@ -0,0 +1,89 @@
1
+ # Fine-tuning FunctionGemma for Square Color Control
2
+
3
+ This directory contains everything needed to fine-tune FunctionGemma to recognize square color control commands.
4
+
5
+ ## 📋 Overview
6
+
7
+ FunctionGemma is a base model that requires fine-tuning to work well with custom functions. This project demonstrates:
8
+
9
+ 1. **Dataset creation** for function calling
10
+ 2. **Fine-tuning** with Hugging Face TRL's `SFTTrainer`
11
+ 3. **Export to ONNX** for browser use
12
+ 4. **Deploy to Hugging Face Hub**
13
+
14
+ ## 🚀 Quick Start
15
+
16
+ ### Option 1: Google Colab (Recommended)
17
+ 1. Upload the entire `finetuning-functiongemma/` folder to [Google Colab](https://colab.research.google.com)
18
+ 2. Open the notebook `finetune_functiongemma.ipynb`
19
+ 3. Select GPU runtime (T4 is sufficient)
20
+ 4. Run all cells
21
+
22
+ > **Note:** The notebook loads the dataset from `dataset/square_color_dataset.json`, so make sure to keep the folder structure intact.
23
+
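+ Before running all cells, you can confirm that a GPU runtime is actually attached:
+
+ ```python
+ import torch
+
+ # Quick sanity check in Colab: True means a CUDA GPU (e.g. a T4) is available.
+ print(torch.cuda.is_available())
+ print(torch.cuda.get_device_name(0) if torch.cuda.is_available() else "CPU only")
+ ```
+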
24
+ ### Option 2: Hugging Face Spaces
25
+ 1. Create a new Space with the Gradio template
26
+ 2. Configure GPU hardware for the Space (a GPU is required for fine-tuning)
27
+ 3. Run the notebook inside the Space
28
+
29
+ ## 📁 Structure
30
+
31
+ ```
32
+ finetuning-functiongemma/
33
+ ├── README.md # This file
34
+ ├── finetune_functiongemma.ipynb # Main notebook
35
+ ├── dataset/
36
+ │   └── square_color_dataset.json # Training dataset
37
+ └── export_to_onnx.py # Script to convert to ONNX
38
+ ```
39
+
40
+ ## 🎯 Target Functions
41
+
42
+ The model will be trained to recognize two functions (both sketched in code below):
43
+
44
+ ### `set_square_color`
45
+ Changes the square color to a new color.
46
+
47
+ **Example inputs:**
48
+ - "Change the color to blue"
49
+ - "Make it red"
50
+ - "Set the square to green"
51
+
52
+ ### `get_square_color`
53
+ Returns the current color of the square.
54
+
55
+ **Example inputs:**
56
+ - "What color is the square?"
57
+ - "Tell me the current color"
58
+ - "Which color is it?"
59
+
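+ Both functions are defined in the notebook as plain Python functions; their docstrings are converted into tool schemas with `transformers.utils.get_json_schema` and passed to the chat template. A condensed sketch of that cell:
+
+ ```python
+ from transformers.utils import get_json_schema
+
+ def set_square_color(color: str) -> str:
+     """
+     Sets the color of the square displayed on the screen.
+
+     Args:
+         color: The color to set, e.g. red, blue, green
+     """
+     return f"Color set to {color}"
+
+ def get_square_color() -> str:
+     """
+     Returns the current color of the square.
+     Use this when the user asks about the current color.
+     """
+     return "Current color"
+
+ # These schemas describe the available tools to the model.
+ TOOLS = [get_json_schema(set_square_color), get_json_schema(get_square_color)]
+ ```
+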
60
+ ## 📊 Dataset
61
+
62
+ The dataset contains varied examples in English (a sample entry is shown after this list), including:
63
+ - Direct commands ("set to red")
64
+ - Indirect commands ("I want it blue")
65
+ - Questions ("what color?")
66
+ - Natural language variations
67
+
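+ Each entry is a flat record pairing a user utterance with the expected tool call; note that `tool_arguments` is itself a JSON-encoded string, which the notebook parses with `json.loads` when building the chat-format conversations. A quick way to inspect the data:
+
+ ```python
+ import json
+
+ # Load the raw training examples (path relative to this directory).
+ with open("dataset/square_color_dataset.json") as f:
+     examples = json.load(f)
+
+ print(len(examples))   # total number of examples
+ print(examples[0])
+ # {'user_content': 'Change the color to blue',
+ #  'tool_name': 'set_square_color',
+ #  'tool_arguments': '{"color": "blue"}'}
+ ```
+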
68
+ ## 🔧 Requirements
69
+
70
+ ```bash
71
+ pip install torch transformers datasets trl accelerate
72
+ pip install "optimum[onnxruntime]"  # For ONNX export
73
+ ```
74
+
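+ The `export_to_onnx.py` script referenced in the structure above is not included in this commit. As a rough sketch, the export could be done with Optimum's `ORTModelForCausalLM` (the model id below is a placeholder for the fine-tuned checkpoint pushed to the Hub, and Gemma-family ONNX support in Optimum is assumed):
+
+ ```python
+ # Hypothetical export sketch -- adjust the model id and output directory.
+ from optimum.onnxruntime import ORTModelForCausalLM
+ from transformers import AutoTokenizer
+
+ model_id = "your-username/functiongemma-square-color"  # placeholder Hub id
+
+ # export=True converts the PyTorch checkpoint to ONNX while loading it.
+ model = ORTModelForCausalLM.from_pretrained(model_id, export=True)
+ tokenizer = AutoTokenizer.from_pretrained(model_id)
+
+ model.save_pretrained("onnx/")
+ tokenizer.save_pretrained("onnx/")
+ ```
+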
75
+ ## 📝 Important Notes
76
+
77
+ 1. **GPU Required**: Fine-tuning requires a GPU (a T4 is sufficient)
78
+ 2. **Time**: ~10-15 minutes with this dataset (~160 examples) and 8 epochs
79
+ 3. **Format**: The model uses special `<escape>` tokens for strings
80
+
81
+ ## 🔗 Useful Links
82
+
83
+ - [FunctionGemma Docs](https://ai.google.dev/gemma/docs/functiongemma)
84
+ - [Official Fine-tuning Tutorial](https://ai.google.dev/gemma/docs/functiongemma/finetuning-with-functiongemma)
85
+ - [HuggingFace TRL](https://huggingface.co/docs/trl)
86
+
87
+ ## Author
88
+
89
+ Created as an AI Engineering portfolio project.
finetuning-functiongemma/dataset/square_color_dataset.json ADDED
@@ -0,0 +1,969 @@
1
+ [
2
+ {
3
+ "user_content": "Change the color to blue",
4
+ "tool_name": "set_square_color",
5
+ "tool_arguments": "{\"color\": \"blue\"}"
6
+ },
7
+ {
8
+ "user_content": "What color is the square?",
9
+ "tool_name": "get_square_color",
10
+ "tool_arguments": "{}"
11
+ },
12
+ {
13
+ "user_content": "Make it red",
14
+ "tool_name": "set_square_color",
15
+ "tool_arguments": "{\"color\": \"red\"}"
16
+ },
17
+ {
18
+ "user_content": "Tell me the current color",
19
+ "tool_name": "get_square_color",
20
+ "tool_arguments": "{}"
21
+ },
22
+ {
23
+ "user_content": "Set the square to green",
24
+ "tool_name": "set_square_color",
25
+ "tool_arguments": "{\"color\": \"green\"}"
26
+ },
27
+ {
28
+ "user_content": "Which color is it?",
29
+ "tool_name": "get_square_color",
30
+ "tool_arguments": "{}"
31
+ },
32
+ {
33
+ "user_content": "I want the square to be purple",
34
+ "tool_name": "set_square_color",
35
+ "tool_arguments": "{\"color\": \"purple\"}"
36
+ },
37
+ {
38
+ "user_content": "What's the color right now?",
39
+ "tool_name": "get_square_color",
40
+ "tool_arguments": "{}"
41
+ },
42
+ {
43
+ "user_content": "Turn it yellow",
44
+ "tool_name": "set_square_color",
45
+ "tool_arguments": "{\"color\": \"yellow\"}"
46
+ },
47
+ {
48
+ "user_content": "Can you tell me what color the square is?",
49
+ "tool_name": "get_square_color",
50
+ "tool_arguments": "{}"
51
+ },
52
+ {
53
+ "user_content": "Paint it orange",
54
+ "tool_name": "set_square_color",
55
+ "tool_arguments": "{\"color\": \"orange\"}"
56
+ },
57
+ {
58
+ "user_content": "I'd like to know the current color",
59
+ "tool_name": "get_square_color",
60
+ "tool_arguments": "{}"
61
+ },
62
+ {
63
+ "user_content": "Switch to pink",
64
+ "tool_name": "set_square_color",
65
+ "tool_arguments": "{\"color\": \"pink\"}"
66
+ },
67
+ {
68
+ "user_content": "What is the square's color?",
69
+ "tool_name": "get_square_color",
70
+ "tool_arguments": "{}"
71
+ },
72
+ {
73
+ "user_content": "Make the square cyan",
74
+ "tool_name": "set_square_color",
75
+ "tool_arguments": "{\"color\": \"cyan\"}"
76
+ },
77
+ {
78
+ "user_content": "Show me the color",
79
+ "tool_name": "get_square_color",
80
+ "tool_arguments": "{}"
81
+ },
82
+ {
83
+ "user_content": "I want it to be white",
84
+ "tool_name": "set_square_color",
85
+ "tool_arguments": "{\"color\": \"white\"}"
86
+ },
87
+ {
88
+ "user_content": "What color is it set to?",
89
+ "tool_name": "get_square_color",
90
+ "tool_arguments": "{}"
91
+ },
92
+ {
93
+ "user_content": "Change to black",
94
+ "tool_name": "set_square_color",
95
+ "tool_arguments": "{\"color\": \"black\"}"
96
+ },
97
+ {
98
+ "user_content": "Tell me the color of the square",
99
+ "tool_name": "get_square_color",
100
+ "tool_arguments": "{}"
101
+ },
102
+ {
103
+ "user_content": "Set it to teal",
104
+ "tool_name": "set_square_color",
105
+ "tool_arguments": "{\"color\": \"teal\"}"
106
+ },
107
+ {
108
+ "user_content": "Query the current color",
109
+ "tool_name": "get_square_color",
110
+ "tool_arguments": "{}"
111
+ },
112
+ {
113
+ "user_content": "Make it magenta",
114
+ "tool_name": "set_square_color",
115
+ "tool_arguments": "{\"color\": \"magenta\"}"
116
+ },
117
+ {
118
+ "user_content": "Get the color",
119
+ "tool_name": "get_square_color",
120
+ "tool_arguments": "{}"
121
+ },
122
+ {
123
+ "user_content": "I'd like the square to be lime",
124
+ "tool_name": "set_square_color",
125
+ "tool_arguments": "{\"color\": \"lime\"}"
126
+ },
127
+ {
128
+ "user_content": "Read the current color",
129
+ "tool_name": "get_square_color",
130
+ "tool_arguments": "{}"
131
+ },
132
+ {
133
+ "user_content": "Update the color to navy",
134
+ "tool_name": "set_square_color",
135
+ "tool_arguments": "{\"color\": \"navy\"}"
136
+ },
137
+ {
138
+ "user_content": "Check the square color",
139
+ "tool_name": "get_square_color",
140
+ "tool_arguments": "{}"
141
+ },
142
+ {
143
+ "user_content": "Set color to coral",
144
+ "tool_name": "set_square_color",
145
+ "tool_arguments": "{\"color\": \"coral\"}"
146
+ },
147
+ {
148
+ "user_content": "What color do we have?",
149
+ "tool_name": "get_square_color",
150
+ "tool_arguments": "{}"
151
+ },
152
+ {
153
+ "user_content": "Put it in violet",
154
+ "tool_name": "set_square_color",
155
+ "tool_arguments": "{\"color\": \"violet\"}"
156
+ },
157
+ {
158
+ "user_content": "Display the color",
159
+ "tool_name": "get_square_color",
160
+ "tool_arguments": "{}"
161
+ },
162
+ {
163
+ "user_content": "Color it gold",
164
+ "tool_name": "set_square_color",
165
+ "tool_arguments": "{\"color\": \"gold\"}"
166
+ },
167
+ {
168
+ "user_content": "Fetch the current color",
169
+ "tool_name": "get_square_color",
170
+ "tool_arguments": "{}"
171
+ },
172
+ {
173
+ "user_content": "Apply salmon color",
174
+ "tool_name": "set_square_color",
175
+ "tool_arguments": "{\"color\": \"salmon\"}"
176
+ },
177
+ {
178
+ "user_content": "Return the color value",
179
+ "tool_name": "get_square_color",
180
+ "tool_arguments": "{}"
181
+ },
182
+ {
183
+ "user_content": "Use turquoise",
184
+ "tool_name": "set_square_color",
185
+ "tool_arguments": "{\"color\": \"turquoise\"}"
186
+ },
187
+ {
188
+ "user_content": "What's the current state of the color?",
189
+ "tool_name": "get_square_color",
190
+ "tool_arguments": "{}"
191
+ },
192
+ {
193
+ "user_content": "Modify the square to crimson",
194
+ "tool_name": "set_square_color",
195
+ "tool_arguments": "{\"color\": \"crimson\"}"
196
+ },
197
+ {
198
+ "user_content": "Retrieve the square's color",
199
+ "tool_name": "get_square_color",
200
+ "tool_arguments": "{}"
201
+ },
202
+ {
203
+ "user_content": "Please change to blue",
204
+ "tool_name": "set_square_color",
205
+ "tool_arguments": "{\"color\": \"blue\"}"
206
+ },
207
+ {
208
+ "user_content": "Could you tell me the color?",
209
+ "tool_name": "get_square_color",
210
+ "tool_arguments": "{}"
211
+ },
212
+ {
213
+ "user_content": "I need it red",
214
+ "tool_name": "set_square_color",
215
+ "tool_arguments": "{\"color\": \"red\"}"
216
+ },
217
+ {
218
+ "user_content": "What is the current color?",
219
+ "tool_name": "get_square_color",
220
+ "tool_arguments": "{}"
221
+ },
222
+ {
223
+ "user_content": "Let's make it green",
224
+ "tool_name": "set_square_color",
225
+ "tool_arguments": "{\"color\": \"green\"}"
226
+ },
227
+ {
228
+ "user_content": "Can you get the color?",
229
+ "tool_name": "get_square_color",
230
+ "tool_arguments": "{}"
231
+ },
232
+ {
233
+ "user_content": "Go with purple",
234
+ "tool_name": "set_square_color",
235
+ "tool_arguments": "{\"color\": \"purple\"}"
236
+ },
237
+ {
238
+ "user_content": "I want to know the color",
239
+ "tool_name": "get_square_color",
240
+ "tool_arguments": "{}"
241
+ },
242
+ {
243
+ "user_content": "How about yellow?",
244
+ "tool_name": "set_square_color",
245
+ "tool_arguments": "{\"color\": \"yellow\"}"
246
+ },
247
+ {
248
+ "user_content": "Give me the color info",
249
+ "tool_name": "get_square_color",
250
+ "tool_arguments": "{}"
251
+ },
252
+ {
253
+ "user_content": "Try orange",
254
+ "tool_name": "set_square_color",
255
+ "tool_arguments": "{\"color\": \"orange\"}"
256
+ },
257
+ {
258
+ "user_content": "Report the current color",
259
+ "tool_name": "get_square_color",
260
+ "tool_arguments": "{}"
261
+ },
262
+ {
263
+ "user_content": "Let's go with pink",
264
+ "tool_name": "set_square_color",
265
+ "tool_arguments": "{\"color\": \"pink\"}"
266
+ },
267
+ {
268
+ "user_content": "What's the square showing?",
269
+ "tool_name": "get_square_color",
270
+ "tool_arguments": "{}"
271
+ },
272
+ {
273
+ "user_content": "Change it to brown",
274
+ "tool_name": "set_square_color",
275
+ "tool_arguments": "{\"color\": \"brown\"}"
276
+ },
277
+ {
278
+ "user_content": "Tell me what color it is",
279
+ "tool_name": "get_square_color",
280
+ "tool_arguments": "{}"
281
+ },
282
+ {
283
+ "user_content": "Set to silver",
284
+ "tool_name": "set_square_color",
285
+ "tool_arguments": "{\"color\": \"silver\"}"
286
+ },
287
+ {
288
+ "user_content": "Check what color the square is",
289
+ "tool_name": "get_square_color",
290
+ "tool_arguments": "{}"
291
+ },
292
+ {
293
+ "user_content": "Make the color maroon",
294
+ "tool_name": "set_square_color",
295
+ "tool_arguments": "{\"color\": \"maroon\"}"
296
+ },
297
+ {
298
+ "user_content": "Show current color",
299
+ "tool_name": "get_square_color",
300
+ "tool_arguments": "{}"
301
+ },
302
+ {
303
+ "user_content": "blue",
304
+ "tool_name": "set_square_color",
305
+ "tool_arguments": "{\"color\": \"blue\"}"
306
+ },
307
+ {
308
+ "user_content": "red please",
309
+ "tool_name": "set_square_color",
310
+ "tool_arguments": "{\"color\": \"red\"}"
311
+ },
312
+ {
313
+ "user_content": "green!",
314
+ "tool_name": "set_square_color",
315
+ "tool_arguments": "{\"color\": \"green\"}"
316
+ },
317
+ {
318
+ "user_content": "color?",
319
+ "tool_name": "get_square_color",
320
+ "tool_arguments": "{}"
321
+ },
322
+ {
323
+ "user_content": "what color",
324
+ "tool_name": "get_square_color",
325
+ "tool_arguments": "{}",
326
+ "tool_arguments": "{}"
327
+ },
328
+ {
329
+ "user_content": "current color?",
330
+ "tool_name": "get_square_color",
331
+ "tool_arguments": "{}"
332
+ },
333
+ {
334
+ "user_content": "Set the square color to indigo",
335
+ "tool_name": "set_square_color",
336
+ "tool_arguments": "{\"color\": \"indigo\"}"
337
+ },
338
+ {
339
+ "user_content": "I want indigo",
340
+ "tool_name": "set_square_color",
341
+ "tool_arguments": "{\"color\": \"indigo\"}"
342
+ },
343
+ {
344
+ "user_content": "Make it olive",
345
+ "tool_name": "set_square_color",
346
+ "tool_arguments": "{\"color\": \"olive\"}"
347
+ },
348
+ {
349
+ "user_content": "Switch the color to beige",
350
+ "tool_name": "set_square_color",
351
+ "tool_arguments": "{\"color\": \"beige\"}"
352
+ },
353
+ {
354
+ "user_content": "Can you change it to lavender?",
355
+ "tool_name": "set_square_color",
356
+ "tool_arguments": "{\"color\": \"lavender\"}"
357
+ },
358
+ {
359
+ "user_content": "What's the color of the square right now?",
360
+ "tool_name": "get_square_color",
361
+ "tool_arguments": "{}"
362
+ },
363
+ {
364
+ "user_content": "I'm curious about the current color",
365
+ "tool_name": "get_square_color",
366
+ "tool_arguments": "{}"
367
+ },
368
+ {
369
+ "user_content": "Tell me what the square looks like",
370
+ "tool_name": "get_square_color",
371
+ "tool_arguments": "{}"
372
+ },
373
+ {
374
+ "user_content": "Please set it to aqua",
375
+ "tool_name": "set_square_color",
376
+ "tool_arguments": "{\"color\": \"aqua\"}"
377
+ },
378
+ {
379
+ "user_content": "Could you make it peach?",
380
+ "tool_name": "set_square_color",
381
+ "tool_arguments": "{\"color\": \"peach\"}"
382
+ },
383
+ {
384
+ "user_content": "Would you change the color to mint?",
385
+ "tool_name": "set_square_color",
386
+ "tool_arguments": "{\"color\": \"mint\"}"
387
+ },
388
+ {
389
+ "user_content": "I'd appreciate it if you set it to ruby",
390
+ "tool_name": "set_square_color",
391
+ "tool_arguments": "{\"color\": \"ruby\"}"
392
+ },
393
+ {
394
+ "user_content": "Can I get the color please?",
395
+ "tool_name": "get_square_color",
396
+ "tool_arguments": "{}"
397
+ },
398
+ {
399
+ "user_content": "Would you mind telling me the color?",
400
+ "tool_name": "get_square_color",
401
+ "tool_arguments": "{}"
402
+ },
403
+ {
404
+ "user_content": "I need to know the current color",
405
+ "tool_name": "get_square_color",
406
+ "tool_arguments": "{}"
407
+ },
408
+ {
409
+ "user_content": "Let me know the square's color",
410
+ "tool_name": "get_square_color",
411
+ "tool_arguments": "{}"
412
+ },
413
+ {
414
+ "user_content": "sky blue",
415
+ "tool_name": "set_square_color",
416
+ "tool_arguments": "{\"color\": \"sky blue\"}"
417
+ },
418
+ {
419
+ "user_content": "dark green",
420
+ "tool_name": "set_square_color",
421
+ "tool_arguments": "{\"color\": \"dark green\"}"
422
+ },
423
+ {
424
+ "user_content": "light blue",
425
+ "tool_name": "set_square_color",
426
+ "tool_arguments": "{\"color\": \"light blue\"}"
427
+ },
428
+ {
429
+ "user_content": "dark red",
430
+ "tool_name": "set_square_color",
431
+ "tool_arguments": "{\"color\": \"dark red\"}"
432
+ },
433
+ {
434
+ "user_content": "bright yellow",
435
+ "tool_name": "set_square_color",
436
+ "tool_arguments": "{\"color\": \"bright yellow\"}"
437
+ },
438
+ {
439
+ "user_content": "pale pink",
440
+ "tool_name": "set_square_color",
441
+ "tool_arguments": "{\"color\": \"pale pink\"}"
442
+ },
443
+ {
444
+ "user_content": "forest green",
445
+ "tool_name": "set_square_color",
446
+ "tool_arguments": "{\"color\": \"forest green\"}"
447
+ },
448
+ {
449
+ "user_content": "ocean blue",
450
+ "tool_name": "set_square_color",
451
+ "tool_arguments": "{\"color\": \"ocean blue\"}"
452
+ },
453
+ {
454
+ "user_content": "Set square to red",
455
+ "tool_name": "set_square_color",
456
+ "tool_arguments": "{\"color\": \"red\"}"
457
+ },
458
+ {
459
+ "user_content": "Square color = blue",
460
+ "tool_name": "set_square_color",
461
+ "tool_arguments": "{\"color\": \"blue\"}"
462
+ },
463
+ {
464
+ "user_content": "color: green",
465
+ "tool_name": "set_square_color",
466
+ "tool_arguments": "{\"color\": \"green\"}"
467
+ },
468
+ {
469
+ "user_content": "get color",
470
+ "tool_name": "get_square_color",
471
+ "tool_arguments": "{}"
472
+ },
473
+ {
474
+ "user_content": "show color",
475
+ "tool_name": "get_square_color",
476
+ "tool_arguments": "{}"
477
+ },
478
+ {
479
+ "user_content": "read color",
480
+ "tool_name": "get_square_color",
481
+ "tool_arguments": "{}"
482
+ },
483
+ {
484
+ "user_content": "Yo make it blue",
485
+ "tool_name": "set_square_color",
486
+ "tool_arguments": "{\"color\": \"blue\"}"
487
+ },
488
+ {
489
+ "user_content": "Hey change to red",
490
+ "tool_name": "set_square_color",
491
+ "tool_arguments": "{\"color\": \"red\"}"
492
+ },
493
+ {
494
+ "user_content": "Sup whats the color",
495
+ "tool_name": "get_square_color",
496
+ "tool_arguments": "{}"
497
+ },
498
+ {
499
+ "user_content": "yo color?",
500
+ "tool_name": "get_square_color",
501
+ "tool_arguments": "{}"
502
+ },
503
+ {
504
+ "user_content": "gimme yellow",
505
+ "tool_name": "set_square_color",
506
+ "tool_arguments": "{\"color\": \"yellow\"}"
507
+ },
508
+ {
509
+ "user_content": "hit me with that purple",
510
+ "tool_name": "set_square_color",
511
+ "tool_arguments": "{\"color\": \"purple\"}"
512
+ },
513
+ {
514
+ "user_content": "gonna need orange on that",
515
+ "tool_name": "set_square_color",
516
+ "tool_arguments": "{\"color\": \"orange\"}"
517
+ },
518
+ {
519
+ "user_content": "Just tell me the color already",
520
+ "tool_name": "get_square_color",
521
+ "tool_arguments": "{}"
522
+ },
523
+ {
524
+ "user_content": "Give me green now",
525
+ "tool_name": "set_square_color",
526
+ "tool_arguments": "{\"color\": \"green\"}"
527
+ },
528
+ {
529
+ "user_content": "What color are we looking at?",
530
+ "tool_name": "get_square_color",
531
+ "tool_arguments": "{}"
532
+ },
533
+ {
534
+ "user_content": "I would like to request the square be changed to azure",
535
+ "tool_name": "set_square_color",
536
+ "tool_arguments": "{\"color\": \"azure\"}"
537
+ },
538
+ {
539
+ "user_content": "Please kindly update the color to burgundy",
540
+ "tool_name": "set_square_color",
541
+ "tool_arguments": "{\"color\": \"burgundy\"}"
542
+ },
543
+ {
544
+ "user_content": "If you could, please inform me of the current color",
545
+ "tool_name": "get_square_color",
546
+ "tool_arguments": "{}"
547
+ },
548
+ {
549
+ "user_content": "I would appreciate knowing what color the square is",
550
+ "tool_name": "get_square_color",
551
+ "tool_arguments": "{}"
552
+ },
553
+ {
554
+ "user_content": "May I request that you change it to periwinkle?",
555
+ "tool_name": "set_square_color",
556
+ "tool_arguments": "{\"color\": \"periwinkle\"}"
557
+ },
558
+ {
559
+ "user_content": "Could you kindly set the color to chartreuse?",
560
+ "tool_name": "set_square_color",
561
+ "tool_arguments": "{\"color\": \"chartreuse\"}"
562
+ },
563
+ {
564
+ "user_content": "plz blue",
565
+ "tool_name": "set_square_color",
566
+ "tool_arguments": "{\"color\": \"blue\"}"
567
+ },
568
+ {
569
+ "user_content": "pls red",
570
+ "tool_name": "set_square_color",
571
+ "tool_arguments": "{\"color\": \"red\"}"
572
+ },
573
+ {
574
+ "user_content": "thx color?",
575
+ "tool_name": "get_square_color",
576
+ "tool_arguments": "{}"
577
+ },
578
+ {
579
+ "user_content": "ty what color",
580
+ "tool_name": "get_square_color",
581
+ "tool_arguments": "{}"
582
+ },
583
+ {
584
+ "user_content": "k make it green",
585
+ "tool_name": "set_square_color",
586
+ "tool_arguments": "{\"color\": \"green\"}"
587
+ },
588
+ {
589
+ "user_content": "Set the display color to amber",
590
+ "tool_name": "set_square_color",
591
+ "tool_arguments": "{\"color\": \"amber\"}"
592
+ },
593
+ {
594
+ "user_content": "Update display to scarlet",
595
+ "tool_name": "set_square_color",
596
+ "tool_arguments": "{\"color\": \"scarlet\"}"
597
+ },
598
+ {
599
+ "user_content": "Change display color to emerald",
600
+ "tool_name": "set_square_color",
601
+ "tool_arguments": "{\"color\": \"emerald\"}"
602
+ },
603
+ {
604
+ "user_content": "What is the display showing?",
605
+ "tool_name": "get_square_color",
606
+ "tool_arguments": "{}"
607
+ },
608
+ {
609
+ "user_content": "Get display color",
610
+ "tool_name": "get_square_color",
611
+ "tool_arguments": "{}"
612
+ },
613
+ {
614
+ "user_content": "Show display color",
615
+ "tool_name": "get_square_color",
616
+ "tool_arguments": "{}",
617
+ "tool_arguments": "{}"
618
+ },
619
+ {
620
+ "user_content": "Alright, set it to cerulean",
621
+ "tool_name": "set_square_color",
622
+ "tool_arguments": "{\"color\": \"cerulean\"}"
623
+ },
624
+ {
625
+ "user_content": "OK so make it tangerine",
626
+ "tool_name": "set_square_color",
627
+ "tool_arguments": "{\"color\": \"tangerine\"}"
628
+ },
629
+ {
630
+ "user_content": "Fine, change to mauve",
631
+ "tool_name": "set_square_color",
632
+ "tool_arguments": "{\"color\": \"mauve\"}"
633
+ },
634
+ {
635
+ "user_content": "Sure, what's the color?",
636
+ "tool_name": "get_square_color",
637
+ "tool_arguments": "{}"
638
+ },
639
+ {
640
+ "user_content": "Yeah tell me the color",
641
+ "tool_name": "get_square_color",
642
+ "tool_arguments": "{}"
643
+ },
644
+ {
645
+ "user_content": "So what color is it?",
646
+ "tool_name": "get_square_color",
647
+ "tool_arguments": "{}"
648
+ },
649
+ {
650
+ "user_content": "And the color is?",
651
+ "tool_name": "get_square_color",
652
+ "tool_arguments": "{}"
653
+ },
654
+ {
655
+ "user_content": "Make it #FF0000",
656
+ "tool_name": "set_square_color",
657
+ "tool_arguments": "{\"color\": \"#FF0000\"}"
658
+ },
659
+ {
660
+ "user_content": "Set to #00FF00",
661
+ "tool_name": "set_square_color",
662
+ "tool_arguments": "{\"color\": \"#00FF00\"}"
663
+ },
664
+ {
665
+ "user_content": "Change to #0000FF",
666
+ "tool_name": "set_square_color",
667
+ "tool_arguments": "{\"color\": \"#0000FF\"}"
668
+ },
669
+ {
670
+ "user_content": "Use hex #FFFF00",
671
+ "tool_name": "set_square_color",
672
+ "tool_arguments": "{\"color\": \"#FFFF00\"}"
673
+ },
674
+ {
675
+ "user_content": "Apply #FF00FF",
676
+ "tool_name": "set_square_color",
677
+ "tool_arguments": "{\"color\": \"#FF00FF\"}"
678
+ },
679
+ {
680
+ "user_content": "Set to rgb red",
681
+ "tool_name": "set_square_color",
682
+ "tool_arguments": "{\"color\": \"red\"}"
683
+ },
684
+ {
685
+ "user_content": "First, tell me what color it is",
686
+ "tool_name": "get_square_color",
687
+ "tool_arguments": "{}"
688
+ },
689
+ {
690
+ "user_content": "Before anything, what's the color?",
691
+ "tool_name": "get_square_color",
692
+ "tool_arguments": "{}"
693
+ },
694
+ {
695
+ "user_content": "To start, show me the current color",
696
+ "tool_name": "get_square_color",
697
+ "tool_arguments": "{}"
698
+ },
699
+ {
700
+ "user_content": "Now change it to slate",
701
+ "tool_name": "set_square_color",
702
+ "tool_arguments": "{\"color\": \"slate\"}"
703
+ },
704
+ {
705
+ "user_content": "Then make it ivory",
706
+ "tool_name": "set_square_color",
707
+ "tool_arguments": "{\"color\": \"ivory\"}"
708
+ },
709
+ {
710
+ "user_content": "After that set it to khaki",
711
+ "tool_name": "set_square_color",
712
+ "tool_arguments": "{\"color\": \"khaki\"}"
713
+ },
714
+ {
715
+ "user_content": "Can you check the color for me?",
716
+ "tool_name": "get_square_color",
717
+ "tool_arguments": "{}"
718
+ },
719
+ {
720
+ "user_content": "Just checking - what color is it?",
721
+ "tool_name": "get_square_color",
722
+ "tool_arguments": "{}"
723
+ },
724
+ {
725
+ "user_content": "Quick question - the color?",
726
+ "tool_name": "get_square_color",
727
+ "tool_arguments": "{}"
728
+ },
729
+ {
730
+ "user_content": "One thing - change to plum",
731
+ "tool_name": "set_square_color",
732
+ "tool_arguments": "{\"color\": \"plum\"}"
733
+ },
734
+ {
735
+ "user_content": "Real quick - make it rust",
736
+ "tool_name": "set_square_color",
737
+ "tool_arguments": "{\"color\": \"rust\"}"
738
+ },
739
+ {
740
+ "user_content": "BTW set it to jade",
741
+ "tool_name": "set_square_color",
742
+ "tool_arguments": "{\"color\": \"jade\"}"
743
+ },
744
+ {
745
+ "user_content": "FYI the color should be sapphire",
746
+ "tool_name": "set_square_color",
747
+ "tool_arguments": "{\"color\": \"sapphire\"}"
748
+ },
749
+ {
750
+ "user_content": "lmk the color",
751
+ "tool_name": "get_square_color",
752
+ "tool_arguments": "{}"
753
+ },
754
+ {
755
+ "user_content": "hmu with the color info",
756
+ "tool_name": "get_square_color",
757
+ "tool_arguments": "{}"
758
+ },
759
+ {
760
+ "user_content": "need the color asap",
761
+ "tool_name": "get_square_color",
762
+ "tool_arguments": "{}"
763
+ },
764
+ {
765
+ "user_content": "color pls",
766
+ "tool_name": "get_square_color",
767
+ "tool_arguments": "{}"
768
+ },
769
+ {
770
+ "user_content": "Hmm make it rose",
771
+ "tool_name": "set_square_color",
772
+ "tool_arguments": "{\"color\": \"rose\"}"
773
+ },
774
+ {
775
+ "user_content": "Ugh just set it to tan",
776
+ "tool_name": "set_square_color",
777
+ "tool_arguments": "{\"color\": \"tan\"}"
778
+ },
779
+ {
780
+ "user_content": "Wow change to electric blue",
781
+ "tool_name": "set_square_color",
782
+ "tool_arguments": "{\"color\": \"electric blue\"}"
783
+ },
784
+ {
785
+ "user_content": "Ooh make it neon green",
786
+ "tool_name": "set_square_color",
787
+ "tool_arguments": "{\"color\": \"neon green\"}"
788
+ },
789
+ {
790
+ "user_content": "Nice! What color?",
791
+ "tool_name": "get_square_color",
792
+ "tool_arguments": "{}"
793
+ },
794
+ {
795
+ "user_content": "Cool, show me the color",
796
+ "tool_name": "get_square_color",
797
+ "tool_arguments": "{}"
798
+ },
799
+ {
800
+ "user_content": "Awesome, what's the color now?",
801
+ "tool_name": "get_square_color",
802
+ "tool_arguments": "{}"
803
+ },
804
+ {
805
+ "user_content": "Great, tell me the color",
806
+ "tool_name": "get_square_color",
807
+ "tool_arguments": "{}"
808
+ },
809
+ {
810
+ "user_content": "I command you to set it to fuchsia",
811
+ "tool_name": "set_square_color",
812
+ "tool_arguments": "{\"color\": \"fuchsia\"}"
813
+ },
814
+ {
815
+ "user_content": "You must change it to cobalt",
816
+ "tool_name": "set_square_color",
817
+ "tool_arguments": "{\"color\": \"cobalt\"}"
818
+ },
819
+ {
820
+ "user_content": "I order you to make it bronze",
821
+ "tool_name": "set_square_color",
822
+ "tool_arguments": "{\"color\": \"bronze\"}"
823
+ },
824
+ {
825
+ "user_content": "You are required to tell me the color",
826
+ "tool_name": "get_square_color",
827
+ "tool_arguments": "{}"
828
+ },
829
+ {
830
+ "user_content": "You shall inform me of the current color",
831
+ "tool_name": "get_square_color",
832
+ "tool_arguments": "{}"
833
+ },
834
+ {
835
+ "user_content": "I hereby request the color information",
836
+ "tool_name": "get_square_color",
837
+ "tool_arguments": "{}"
838
+ },
839
+ {
840
+ "user_content": "change blue",
841
+ "tool_name": "set_square_color",
842
+ "tool_arguments": "{\"color\": \"blue\"}"
843
+ },
844
+ {
845
+ "user_content": "set red",
846
+ "tool_name": "set_square_color",
847
+ "tool_arguments": "{\"color\": \"red\"}"
848
+ },
849
+ {
850
+ "user_content": "make green",
851
+ "tool_name": "set_square_color",
852
+ "tool_arguments": "{\"color\": \"green\"}"
853
+ },
854
+ {
855
+ "user_content": "do yellow",
856
+ "tool_name": "set_square_color",
857
+ "tool_arguments": "{\"color\": \"yellow\"}"
858
+ },
859
+ {
860
+ "user_content": "try purple",
861
+ "tool_name": "set_square_color",
862
+ "tool_arguments": "{\"color\": \"purple\"}"
863
+ },
864
+ {
865
+ "user_content": "use orange",
866
+ "tool_name": "set_square_color",
867
+ "tool_arguments": "{\"color\": \"orange\"}"
868
+ },
869
+ {
870
+ "user_content": "whats the color",
871
+ "tool_name": "get_square_color",
872
+ "tool_arguments": "{}"
873
+ },
874
+ {
875
+ "user_content": "tell color",
876
+ "tool_name": "get_square_color",
877
+ "tool_arguments": "{}"
878
+ },
879
+ {
880
+ "user_content": "get current color",
881
+ "tool_name": "get_square_color",
882
+ "tool_arguments": "{}"
883
+ },
884
+ {
885
+ "user_content": "show what color",
886
+ "tool_name": "get_square_color",
887
+ "tool_arguments": "{}"
888
+ },
889
+ {
890
+ "user_content": "Want blue on the square",
891
+ "tool_name": "set_square_color",
892
+ "tool_arguments": "{\"color\": \"blue\"}"
893
+ },
894
+ {
895
+ "user_content": "Need the square red",
896
+ "tool_name": "set_square_color",
897
+ "tool_arguments": "{\"color\": \"red\"}"
898
+ },
899
+ {
900
+ "user_content": "Gotta have green",
901
+ "tool_name": "set_square_color",
902
+ "tool_arguments": "{\"color\": \"green\"}"
903
+ },
904
+ {
905
+ "user_content": "Wanna see yellow",
906
+ "tool_name": "set_square_color",
907
+ "tool_arguments": "{\"color\": \"yellow\"}"
908
+ },
909
+ {
910
+ "user_content": "Curious what color is showing",
911
+ "tool_name": "get_square_color",
912
+ "tool_arguments": "{}"
913
+ },
914
+ {
915
+ "user_content": "Wondering about the color",
916
+ "tool_name": "get_square_color",
917
+ "tool_arguments": "{}"
918
+ },
919
+ {
920
+ "user_content": "Interested in the current color",
921
+ "tool_name": "get_square_color",
922
+ "tool_arguments": "{}"
923
+ },
924
+ {
925
+ "user_content": "Looking to know the color",
926
+ "tool_name": "get_square_color",
927
+ "tool_arguments": "{}"
928
+ },
929
+ {
930
+ "user_content": "The square needs to be steel blue",
931
+ "tool_name": "set_square_color",
932
+ "tool_arguments": "{\"color\": \"steel blue\"}"
933
+ },
934
+ {
935
+ "user_content": "I think hot pink would be nice",
936
+ "tool_name": "set_square_color",
937
+ "tool_arguments": "{\"color\": \"hot pink\"}"
938
+ },
939
+ {
940
+ "user_content": "How about sea green?",
941
+ "tool_name": "set_square_color",
942
+ "tool_arguments": "{\"color\": \"sea green\"}"
943
+ },
944
+ {
945
+ "user_content": "Maybe midnight blue?",
946
+ "tool_name": "set_square_color",
947
+ "tool_arguments": "{\"color\": \"midnight blue\"}"
948
+ },
949
+ {
950
+ "user_content": "Thinking about the color, what is it?",
951
+ "tool_name": "get_square_color",
952
+ "tool_arguments": "{}"
953
+ },
954
+ {
955
+ "user_content": "Speaking of colors, which one is active?",
956
+ "tool_name": "get_square_color",
957
+ "tool_arguments": "{}"
958
+ },
959
+ {
960
+ "user_content": "On the topic of the square, what color?",
961
+ "tool_name": "get_square_color",
962
+ "tool_arguments": "{}"
963
+ },
964
+ {
965
+ "user_content": "Regarding the display, what's the color?",
966
+ "tool_name": "get_square_color",
967
+ "tool_arguments": "{}"
968
+ }
969
+ ]
finetuning-functiongemma/finetune_functiongemma.ipynb ADDED
@@ -0,0 +1,690 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {},
6
+ "source": [
7
+ "# 🎨 Fine-tuning FunctionGemma for Square Color Control\n",
8
+ "\n",
9
+ "This notebook demonstrates how to fine-tune FunctionGemma to recognize color control commands.\n",
10
+ "\n",
11
+ "**Author:** [Your Name]\n",
12
+ "**Portfolio:** AI Engineering\n",
13
+ "\n",
14
+ "## Objectives\n",
15
+ "1. Train the model to call `set_square_color` when the user wants to change the color\n",
16
+ "2. Train the model to call `get_square_color` when the user asks about the current color\n",
17
+ "3. Support various natural language command styles"
18
+ ]
19
+ },
20
+ {
21
+ "cell_type": "markdown",
22
+ "metadata": {},
23
+ "source": [
24
+ "## πŸ“¦ 1. Setup and Installation"
25
+ ]
26
+ },
27
+ {
28
+ "cell_type": "code",
29
+ "execution_count": null,
30
+ "metadata": {},
31
+ "outputs": [],
32
+ "source": [
33
+ "# Install dependencies\n",
34
+ "%pip install -q torch tensorboard\n",
35
+ "%pip install -q transformers datasets accelerate evaluate trl protobuf sentencepiece\n",
36
+ "\n",
37
+ "# If running on Ampere+ GPU (A100, L4), uncomment:\n",
38
+ "# %pip install -q flash-attn"
39
+ ]
40
+ },
41
+ {
42
+ "cell_type": "code",
43
+ "execution_count": null,
44
+ "metadata": {},
45
+ "outputs": [],
46
+ "source": [
47
+ "# Login to Hugging Face Hub\n",
48
+ "from huggingface_hub import login\n",
49
+ "\n",
50
+ "# If using Colab secrets:\n",
51
+ "# from google.colab import userdata\n",
52
+ "# login(token=userdata.get('HF_TOKEN'))\n",
53
+ "\n",
54
+ "# Or interactive login:\n",
55
+ "login()"
56
+ ]
57
+ },
58
+ {
59
+ "cell_type": "code",
60
+ "execution_count": null,
61
+ "metadata": {},
62
+ "outputs": [],
63
+ "source": [
64
+ "# Configuration\n",
65
+ "BASE_MODEL = \"google/functiongemma-270m-it\"\n",
66
+ "OUTPUT_DIR = \"functiongemma-square-color\" # Model name on your HF Hub\n",
67
+ "LEARNING_RATE = 5e-5\n",
68
+ "NUM_EPOCHS = 8\n",
69
+ "BATCH_SIZE = 4"
70
+ ]
71
+ },
72
+ {
73
+ "cell_type": "markdown",
74
+ "metadata": {},
75
+ "source": [
76
+ "## πŸ“Š 2. Prepare Dataset"
77
+ ]
78
+ },
79
+ {
80
+ "cell_type": "code",
81
+ "execution_count": null,
82
+ "metadata": {},
83
+ "outputs": [],
84
+ "source": [
85
+ "import json\n",
86
+ "from datasets import Dataset\n",
87
+ "from transformers.utils import get_json_schema\n",
88
+ "\n",
89
+ "# Tool definitions\n",
90
+ "def set_square_color(color: str) -> str:\n",
91
+ " \"\"\"\n",
92
+ " Sets the color of the square displayed on the screen.\n",
93
+ " \n",
94
+ " Args:\n",
95
+ " color: The color to set, e.g. red, blue, green\n",
96
+ " \"\"\"\n",
97
+ " return f\"Color set to {color}\"\n",
98
+ "\n",
99
+ "def get_square_color() -> str:\n",
100
+ " \"\"\"\n",
101
+ " Returns the current color of the square.\n",
102
+ " Use this when the user asks about the current color.\n",
103
+ " \"\"\"\n",
104
+ " return \"Current color\"\n",
105
+ "\n",
106
+ "# Generate schemas automatically\n",
107
+ "TOOLS = [\n",
108
+ " get_json_schema(set_square_color),\n",
109
+ " get_json_schema(get_square_color)\n",
110
+ "]\n",
111
+ "\n",
112
+ "print(\"Tool schemas:\")\n",
113
+ "print(json.dumps(TOOLS, indent=2))"
114
+ ]
115
+ },
116
+ {
117
+ "cell_type": "code",
118
+ "execution_count": null,
119
+ "metadata": {},
120
+ "outputs": [],
121
+ "source": [
122
+ "# Load training dataset from file\n",
123
+ "with open(\"dataset/square_color_dataset.json\", \"r\") as f:\n",
124
+ " square_color_dataset = json.load(f)\n",
125
+ "\n",
126
+ "print(f\"Total examples: {len(square_color_dataset)}\")\n",
127
+ "print(f\" - SET: {len([x for x in square_color_dataset if x['tool_name'] == 'set_square_color'])}\")\n",
128
+ "print(f\" - GET: {len([x for x in square_color_dataset if x['tool_name'] == 'get_square_color'])}\")\n",
129
+ "\n",
130
+ "# Preview first few examples\n",
131
+ "print(\"\\nFirst 3 examples:\")\n",
132
+ "for i, sample in enumerate(square_color_dataset[:3]):\n",
133
+ " print(f\" {i+1}. \\\"{sample['user_content']}\\\" β†’ {sample['tool_name']}\")"
134
+ ]
135
+ },
136
+ {
137
+ "cell_type": "code",
138
+ "execution_count": null,
139
+ "metadata": {},
140
+ "outputs": [],
141
+ "source": [
142
+ "# Convert to conversation format\n",
143
+ "SYSTEM_PROMPT = \"You are a model that can do function calling with the following functions\"\n",
144
+ "\n",
145
+ "def create_conversation(sample):\n",
146
+ " return {\n",
147
+ " \"messages\": [\n",
148
+ " {\"role\": \"developer\", \"content\": SYSTEM_PROMPT},\n",
149
+ " {\"role\": \"user\", \"content\": sample[\"user_content\"]},\n",
150
+ " {\n",
151
+ " \"role\": \"assistant\",\n",
152
+ " \"tool_calls\": [{\n",
153
+ " \"type\": \"function\",\n",
154
+ " \"function\": {\n",
155
+ " \"name\": sample[\"tool_name\"],\n",
156
+ " \"arguments\": json.loads(sample[\"tool_arguments\"])\n",
157
+ " }\n",
158
+ " }]\n",
159
+ " },\n",
160
+ " ],\n",
161
+ " \"tools\": TOOLS\n",
162
+ " }\n",
163
+ "\n",
164
+ "# Create dataset\n",
165
+ "dataset = Dataset.from_list(square_color_dataset)\n",
166
+ "dataset = dataset.map(create_conversation, remove_columns=dataset.features, batched=False)\n",
167
+ "\n",
168
+ "# Split 80/20\n",
169
+ "dataset = dataset.train_test_split(test_size=0.2, shuffle=True, seed=42)\n",
170
+ "\n",
171
+ "print(f\"Train: {len(dataset['train'])} examples\")\n",
172
+ "print(f\"Test: {len(dataset['test'])} examples\")"
173
+ ]
174
+ },
175
+ {
176
+ "cell_type": "code",
177
+ "execution_count": null,
178
+ "metadata": {},
179
+ "outputs": [],
180
+ "source": [
181
+ "# Visualize an example\n",
182
+ "print(\"Formatted conversation example:\")\n",
183
+ "print(json.dumps(dataset[\"train\"][0], indent=2))"
184
+ ]
185
+ },
186
+ {
187
+ "cell_type": "markdown",
188
+ "metadata": {},
189
+ "source": [
190
+ "## πŸ€– 3. Load Model"
191
+ ]
192
+ },
193
+ {
194
+ "cell_type": "code",
195
+ "execution_count": null,
196
+ "metadata": {},
197
+ "outputs": [],
198
+ "source": [
199
+ "import torch\n",
200
+ "from transformers import AutoTokenizer, AutoModelForCausalLM\n",
201
+ "\n",
202
+ "quantization_config = BitsAndBytesConfig(load_in_4bit=True)\n",
203
+ "\n",
204
+ "# Load model and tokenizer\n",
205
+ "model = AutoModelForCausalLM.from_pretrained(\n",
206
+ " BASE_MODEL,\n",
207
+ " torch_dtype=\"auto\",\n",
208
+ " device_map=\"auto\",\n",
209
+ " quantization_config=quantization_config, \n",
210
+ " attn_implementation=\"eager\"\n",
211
+ " \n",
212
+ ")\n",
213
+ "\n",
214
+ "tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)\n",
215
+ "\n",
216
+ "print(f\"Device: {model.device}\")\n",
217
+ "print(f\"DType: {model.dtype}\")\n",
218
+ "print(f\"Parameters: {model.num_parameters():,}\")"
219
+ ]
220
+ },
221
+ {
222
+ "cell_type": "code",
223
+ "execution_count": null,
224
+ "metadata": {},
225
+ "outputs": [],
226
+ "source": [
227
+ "# Visualize how the tokenizer formats the prompt\n",
228
+ "debug_msg = tokenizer.apply_chat_template(\n",
229
+ " dataset[\"train\"][0][\"messages\"],\n",
230
+ " tools=dataset[\"train\"][0][\"tools\"],\n",
231
+ " add_generation_prompt=False,\n",
232
+ " tokenize=False\n",
233
+ ")\n",
234
+ "\n",
235
+ "print(\"=== Formatted prompt ===\")\n",
236
+ "print(debug_msg)"
237
+ ]
238
+ },
239
+ {
240
+ "cell_type": "markdown",
241
+ "metadata": {},
242
+ "source": [
243
+ "## πŸ§ͺ 3.5. Pre-Training Evaluation (Baseline)\n",
244
+ "\n",
245
+ "Before fine-tuning, let's evaluate the base model to establish a baseline. This helps us measure the actual improvement from fine-tuning."
246
+ ]
247
+ },
248
+ {
249
+ "cell_type": "code",
250
+ "execution_count": null,
251
+ "metadata": {},
252
+ "outputs": [],
253
+ "source": [
254
+ "def evaluate_model(model, tokenizer, test_samples, tools, system_prompt, verbose=True):\n",
255
+ " \"\"\"\n",
256
+ " Evaluate model on a set of test samples.\n",
257
+ " Returns accuracy metrics and detailed results.\n",
258
+ " \"\"\"\n",
259
+ " results = {\n",
260
+ " \"total\": len(test_samples),\n",
261
+ " \"correct\": 0,\n",
262
+ " \"correct_tool\": 0,\n",
263
+ " \"correct_args\": 0,\n",
264
+ " \"details\": []\n",
265
+ " }\n",
266
+ " \n",
267
+ " for sample in test_samples:\n",
268
+ " messages = [\n",
269
+ " {\"role\": \"developer\", \"content\": system_prompt},\n",
270
+ " {\"role\": \"user\", \"content\": sample[\"user_content\"]},\n",
271
+ " ]\n",
272
+ " \n",
273
+ " inputs = tokenizer.apply_chat_template(\n",
274
+ " messages,\n",
275
+ " tools=tools,\n",
276
+ " tokenize=True,\n",
277
+ " add_generation_prompt=True,\n",
278
+ " return_dict=True,\n",
279
+ " return_tensors=\"pt\"\n",
280
+ " ).to(model.device)\n",
281
+ " \n",
282
+ " with torch.no_grad():\n",
283
+ " output = model.generate(\n",
284
+ " **inputs,\n",
285
+ " max_new_tokens=128,\n",
286
+ " do_sample=False,\n",
287
+ " )\n",
288
+ " \n",
289
+ " input_length = inputs['input_ids'].shape[1]\n",
290
+ " response = tokenizer.decode(output[0][input_length:], skip_special_tokens=False)\n",
291
+ " \n",
292
+ " # Check if correct tool was called\n",
293
+ " tool_correct = sample[\"tool_name\"] in response\n",
294
+ " \n",
295
+ " # Check if arguments are correct (for set_square_color)\n",
296
+ " args_correct = False\n",
297
+ " if tool_correct and sample[\"tool_name\"] == \"set_square_color\":\n",
298
+ " expected_args = json.loads(sample[\"tool_arguments\"])\n",
299
+ " args_correct = expected_args.get(\"color\", \"\") in response\n",
300
+ " elif tool_correct and sample[\"tool_name\"] == \"get_square_color\":\n",
301
+ " args_correct = True # No args needed\n",
302
+ " \n",
303
+ " if tool_correct:\n",
304
+ " results[\"correct_tool\"] += 1\n",
305
+ " if tool_correct and args_correct:\n",
306
+ " results[\"correct\"] += 1\n",
307
+ " results[\"correct_args\"] += 1\n",
308
+ " \n",
309
+ " results[\"details\"].append({\n",
310
+ " \"input\": sample[\"user_content\"],\n",
311
+ " \"expected_tool\": sample[\"tool_name\"],\n",
312
+ " \"expected_args\": sample[\"tool_arguments\"],\n",
313
+ " \"response\": response,\n",
314
+ " \"tool_correct\": tool_correct,\n",
315
+ " \"args_correct\": args_correct\n",
316
+ " })\n",
317
+ " \n",
318
+ " results[\"tool_accuracy\"] = results[\"correct_tool\"] / results[\"total\"] * 100\n",
319
+ " results[\"full_accuracy\"] = results[\"correct\"] / results[\"total\"] * 100\n",
320
+ " \n",
321
+ " if verbose:\n",
322
+ " print(f\"Tool Accuracy: {results['correct_tool']}/{results['total']} ({results['tool_accuracy']:.1f}%)\")\n",
323
+ " print(f\"Full Accuracy (tool + args): {results['correct']}/{results['total']} ({results['full_accuracy']:.1f}%)\")\n",
324
+ " \n",
325
+ " return results"
326
+ ]
327
+ },
328
+ {
329
+ "cell_type": "code",
330
+ "execution_count": null,
331
+ "metadata": {},
332
+ "outputs": [],
333
+ "source": [
334
+ "# Create evaluation test set from the dataset (sample 5 SET + 5 GET)\n",
335
+ "import random\n",
336
+ "\n",
337
+ "random.seed(42) # For reproducibility\n",
338
+ "\n",
339
+ "set_samples = [s for s in square_color_dataset if s[\"tool_name\"] == \"set_square_color\"]\n",
340
+ "get_samples = [s for s in square_color_dataset if s[\"tool_name\"] == \"get_square_color\"]\n",
341
+ "\n",
342
+ "eval_test_cases = random.sample(set_samples, min(5, len(set_samples))) + \\\n",
343
+ " random.sample(get_samples, min(5, len(get_samples)))\n",
344
+ "\n",
345
+ "print(\"=\" * 50)\n",
346
+ "print(\"PRE-TRAINING EVALUATION (Baseline)\")\n",
347
+ "print(\"=\" * 50)\n",
348
+ "print(f\"\\nEvaluating base model on {len(eval_test_cases)} test cases...\\n\")\n",
349
+ "\n",
350
+ "baseline_results = evaluate_model(\n",
351
+ " model=model,\n",
352
+ " tokenizer=tokenizer,\n",
353
+ " test_samples=eval_test_cases,\n",
354
+ " tools=TOOLS,\n",
355
+ " system_prompt=SYSTEM_PROMPT\n",
356
+ ")\n",
357
+ "\n",
358
+ "# Show some example outputs\n",
359
+ "print(\"\\n--- Sample Outputs (Base Model) ---\")\n",
360
+ "for i, detail in enumerate(baseline_results[\"details\"][:4]):\n",
361
+ " status = \"βœ…\" if detail[\"tool_correct\"] else \"❌\"\n",
362
+ " print(f\"\\n{status} Input: {detail['input']}\")\n",
363
+ " print(f\" Expected: {detail['expected_tool']}\")\n",
364
+ " print(f\" Output: {detail['response'][:200]}...\")"
365
+ ]
366
+ },
367
+ {
368
+ "cell_type": "markdown",
369
+ "metadata": {},
370
+ "source": [
371
+ "## πŸ”₯ 4. Fine-tuning"
372
+ ]
373
+ },
374
+ {
375
+ "cell_type": "code",
376
+ "execution_count": null,
377
+ "metadata": {},
378
+ "outputs": [],
379
+ "source": [
380
+ "import torch\n",
381
+ "from transformers import AutoTokenizer, AutoModelForCausalLM\n",
382
+ "\n",
383
+ "print(\"Reloading model for fine-tuning (without quantization)...\")\n",
384
+ "\n",
385
+ "del model\n",
386
+ "torch.cuda.empty_cache()\n",
387
+ "\n",
388
+ "model = AutoModelForCausalLM.from_pretrained(\n",
389
+ " BASE_MODEL,\n",
390
+ " torch_dtype=torch.bfloat16,\n",
391
+ " device_map=\"auto\",\n",
392
+ " attn_implementation=\"eager\"\n",
393
+ ")\n",
394
+ "\n",
395
+ "tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)\n",
396
+ "\n",
397
+ "print(f\"Device: {model.device}\")\n",
398
+ "print(f\"DType: {model.dtype}\")\n",
399
+ "print(f\"Parameters: {model.num_parameters():,}\")\n",
400
+ "print(\"Ready for fine-tuning!\")"
401
+ ]
402
+ },
403
+ {
404
+ "cell_type": "code",
405
+ "execution_count": null,
406
+ "metadata": {},
407
+ "outputs": [],
408
+ "source": [
409
+ "from trl import SFTConfig, SFTTrainer\n",
410
+ "\n",
411
+ "torch_dtype = model.dtype\n",
412
+ "\n",
413
+ "# Training configuration\n",
414
+ "args = SFTConfig(\n",
415
+ " output_dir=OUTPUT_DIR,\n",
416
+ " max_length=512,\n",
417
+ " packing=False,\n",
418
+ " num_train_epochs=NUM_EPOCHS,\n",
419
+ " per_device_train_batch_size=BATCH_SIZE,\n",
420
+ " gradient_checkpointing=False,\n",
421
+ " optim=\"adamw_torch_fused\",\n",
422
+ " logging_steps=1,\n",
423
+ " eval_strategy=\"epoch\",\n",
424
+ " save_strategy=\"epoch\",\n",
425
+ " learning_rate=LEARNING_RATE,\n",
426
+ " fp16=True if torch_dtype == torch.float16 else False,\n",
427
+ " bf16=True if torch_dtype == torch.bfloat16 else False,\n",
428
+ " lr_scheduler_type=\"constant\",\n",
429
+ " push_to_hub=True,\n",
430
+ " report_to=\"tensorboard\",\n",
431
+ " load_best_model_at_end=True,\n",
432
+ " metric_for_best_model=\"eval_loss\",\n",
433
+ ")\n",
434
+ "\n",
435
+ "# Create trainer\n",
436
+ "trainer = SFTTrainer(\n",
437
+ " model=model,\n",
438
+ " args=args,\n",
439
+ " train_dataset=dataset['train'],\n",
440
+ " eval_dataset=dataset['test'],\n",
441
+ " processing_class=tokenizer,\n",
442
+ ")\n",
443
+ "\n",
444
+ "print(\"Trainer created successfully!\")"
445
+ ]
446
+ },
447
+ {
448
+ "cell_type": "code",
449
+ "execution_count": null,
450
+ "metadata": {},
451
+ "outputs": [],
452
+ "source": [
453
+ "# πŸš€ Start training!\n",
454
+ "print(\"Starting fine-tuning...\")\n",
455
+ "trainer.train()\n",
456
+ "\n",
457
+ "print(\"\\nβœ… Training complete!\")"
458
+ ]
459
+ },
460
+ {
461
+ "cell_type": "code",
462
+ "execution_count": null,
463
+ "metadata": {},
464
+ "outputs": [],
465
+ "source": [
466
+ "# Save final model\n",
467
+ "trainer.save_model()\n",
468
+ "print(f\"Model saved to: {OUTPUT_DIR}\")"
469
+ ]
470
+ },
471
+ {
472
+ "cell_type": "markdown",
473
+ "metadata": {},
474
+ "source": [
475
+ "## πŸ“ˆ 5. Visualize Results"
476
+ ]
477
+ },
478
+ {
479
+ "cell_type": "code",
480
+ "execution_count": null,
481
+ "metadata": {},
482
+ "outputs": [],
483
+ "source": [
484
+ "import matplotlib.pyplot as plt\n",
485
+ "\n",
486
+ "# Extract loss history\n",
487
+ "log_history = trainer.state.log_history\n",
488
+ "\n",
489
+ "train_losses = [log[\"loss\"] for log in log_history if \"loss\" in log]\n",
490
+ "epoch_train = [log[\"epoch\"] for log in log_history if \"loss\" in log]\n",
491
+ "eval_losses = [log[\"eval_loss\"] for log in log_history if \"eval_loss\" in log]\n",
492
+ "epoch_eval = [log[\"epoch\"] for log in log_history if \"eval_loss\" in log]\n",
493
+ "\n",
494
+ "# Plot\n",
495
+ "plt.figure(figsize=(10, 6))\n",
496
+ "plt.plot(epoch_train, train_losses, label=\"Training Loss\", alpha=0.7)\n",
497
+ "plt.plot(epoch_eval, eval_losses, label=\"Validation Loss\", marker='o')\n",
498
+ "plt.xlabel(\"Epoch\")\n",
499
+ "plt.ylabel(\"Loss\")\n",
500
+ "plt.title(\"Training and Validation Loss\")\n",
501
+ "plt.legend()\n",
502
+ "plt.grid(True)\n",
503
+ "plt.show()"
504
+ ]
505
+ },
506
+ {
507
+ "cell_type": "markdown",
508
+ "metadata": {},
509
+ "source": [
510
+ "## πŸ§ͺ 6. Post-Training Evaluation\n",
511
+ "\n",
512
+ "Now let's evaluate the fine-tuned model and compare it with the baseline to measure the improvement."
513
+ ]
514
+ },
515
+ {
516
+ "cell_type": "code",
517
+ "execution_count": null,
518
+ "metadata": {},
519
+ "outputs": [],
520
+ "source": [
521
+ "print(\"=\" * 50)\n",
522
+ "print(\"POST-TRAINING EVALUATION (Fine-tuned)\")\n",
523
+ "print(\"=\" * 50)\n",
524
+ "print(f\"\\nEvaluating fine-tuned model on {len(eval_test_cases)} test cases...\\n\")\n",
525
+ "\n",
526
+ "finetuned_results = evaluate_model(\n",
527
+ " model=model,\n",
528
+ " tokenizer=tokenizer,\n",
529
+ " test_samples=eval_test_cases,\n",
530
+ " tools=TOOLS,\n",
531
+ " system_prompt=SYSTEM_PROMPT\n",
532
+ ")\n",
533
+ "\n",
534
+ "# Show some example outputs\n",
535
+ "print(\"\\n--- Sample Outputs (Fine-tuned Model) ---\")\n",
536
+ "for i, detail in enumerate(finetuned_results[\"details\"][:4]):\n",
537
+ " status = \"βœ…\" if detail[\"tool_correct\"] else \"❌\"\n",
538
+ " print(f\"\\n{status} Input: {detail['input']}\")\n",
539
+ " print(f\" Expected: {detail['expected_tool']}\")\n",
540
+ " print(f\" Output: {detail['response'][:200]}...\")"
541
+ ]
542
+ },
543
+ {
544
+ "cell_type": "code",
545
+ "execution_count": null,
546
+ "metadata": {},
547
+ "outputs": [],
548
+ "source": [
549
+ "# Compare baseline vs fine-tuned results\n",
550
+ "print(\"=\" * 60)\n",
551
+ "print(\"πŸ“Š COMPARISON: Baseline vs Fine-tuned\")\n",
552
+ "print(\"=\" * 60)\n",
553
+ "\n",
554
+ "print(f\"\\n{'Metric':<30} {'Baseline':>12} {'Fine-tuned':>12} {'Improvement':>12}\")\n",
555
+ "print(\"-\" * 66)\n",
556
+ "\n",
557
+ "# Tool accuracy comparison\n",
558
+ "tool_improvement = finetuned_results[\"tool_accuracy\"] - baseline_results[\"tool_accuracy\"]\n",
559
+ "print(f\"{'Tool Accuracy':<30} {baseline_results['tool_accuracy']:>11.1f}% {finetuned_results['tool_accuracy']:>11.1f}% {tool_improvement:>+11.1f}%\")\n",
560
+ "\n",
561
+ "# Full accuracy comparison\n",
562
+ "full_improvement = finetuned_results[\"full_accuracy\"] - baseline_results[\"full_accuracy\"]\n",
563
+ "print(f\"{'Full Accuracy (tool + args)':<30} {baseline_results['full_accuracy']:>11.1f}% {finetuned_results['full_accuracy']:>11.1f}% {full_improvement:>+11.1f}%\")\n",
564
+ "\n",
565
+ "print(\"-\" * 66)\n",
566
+ "\n",
567
+ "# Summary\n",
568
+ "if full_improvement > 0:\n",
569
+ " print(f\"\\nβœ… Fine-tuning improved accuracy by {full_improvement:.1f} percentage points!\")\n",
570
+ "elif full_improvement == 0:\n",
571
+ " print(f\"\\n⚠️ No change in accuracy. Consider adjusting training parameters.\")\n",
572
+ "else:\n",
573
+ " print(f\"\\n❌ Accuracy decreased. Check for overfitting or data issues.\")"
574
+ ]
575
+ },
576
+ {
577
+ "cell_type": "code",
578
+ "execution_count": null,
579
+ "metadata": {},
580
+ "outputs": [],
581
+ "source": [
582
+ "# Visualization: Baseline vs Fine-tuned comparison\n",
583
+ "import matplotlib.pyplot as plt\n",
584
+ "import numpy as np\n",
585
+ "\n",
586
+ "fig, axes = plt.subplots(1, 2, figsize=(14, 5))\n",
587
+ "\n",
588
+ "# Chart 1: Bar chart comparison\n",
589
+ "metrics = ['Tool\\nAccuracy', 'Full\\nAccuracy']\n",
590
+ "baseline_vals = [baseline_results[\"tool_accuracy\"], baseline_results[\"full_accuracy\"]]\n",
591
+ "finetuned_vals = [finetuned_results[\"tool_accuracy\"], finetuned_results[\"full_accuracy\"]]\n",
592
+ "\n",
593
+ "x = np.arange(len(metrics))\n",
594
+ "width = 0.35\n",
595
+ "\n",
596
+ "bars1 = axes[0].bar(x - width/2, baseline_vals, width, label='Baseline', color='#ff6b6b', alpha=0.8)\n",
597
+ "bars2 = axes[0].bar(x + width/2, finetuned_vals, width, label='Fine-tuned', color='#4ecdc4', alpha=0.8)\n",
598
+ "\n",
599
+ "axes[0].set_ylabel('Accuracy (%)')\n",
600
+ "axes[0].set_title('Model Performance: Baseline vs Fine-tuned')\n",
601
+ "axes[0].set_xticks(x)\n",
602
+ "axes[0].set_xticklabels(metrics)\n",
603
+ "axes[0].legend()\n",
604
+ "axes[0].set_ylim(0, 110)\n",
605
+ "axes[0].axhline(y=100, color='gray', linestyle='--', alpha=0.3)\n",
606
+ "\n",
607
+ "# Add value labels on bars\n",
608
+ "for bar in bars1:\n",
609
+ " height = bar.get_height()\n",
610
+ " axes[0].annotate(f'{height:.1f}%', xy=(bar.get_x() + bar.get_width() / 2, height),\n",
611
+ " xytext=(0, 3), textcoords=\"offset points\", ha='center', va='bottom', fontsize=10)\n",
612
+ "for bar in bars2:\n",
613
+ " height = bar.get_height()\n",
614
+ " axes[0].annotate(f'{height:.1f}%', xy=(bar.get_x() + bar.get_width() / 2, height),\n",
615
+ " xytext=(0, 3), textcoords=\"offset points\", ha='center', va='bottom', fontsize=10)\n",
616
+ "\n",
617
+ "# Chart 2: Per-sample comparison\n",
618
+ "sample_labels = [d[\"input\"][:20] + \"...\" for d in baseline_results[\"details\"]]\n",
619
+ "baseline_correct = [1 if d[\"tool_correct\"] else 0 for d in baseline_results[\"details\"]]\n",
620
+ "finetuned_correct = [1 if d[\"tool_correct\"] else 0 for d in finetuned_results[\"details\"]]\n",
621
+ "\n",
622
+ "x2 = np.arange(len(sample_labels))\n",
623
+ "width2 = 0.35\n",
624
+ "\n",
625
+ "axes[1].barh(x2 - width2/2, baseline_correct, width2, label='Baseline', color='#ff6b6b', alpha=0.8)\n",
626
+ "axes[1].barh(x2 + width2/2, finetuned_correct, width2, label='Fine-tuned', color='#4ecdc4', alpha=0.8)\n",
627
+ "\n",
628
+ "axes[1].set_xlabel('Correct (1) / Incorrect (0)')\n",
629
+ "axes[1].set_title('Per-Sample Results')\n",
630
+ "axes[1].set_yticks(x2)\n",
631
+ "axes[1].set_yticklabels(sample_labels, fontsize=8)\n",
632
+ "axes[1].legend(loc='lower right')\n",
633
+ "axes[1].set_xlim(-0.1, 1.5)\n",
634
+ "\n",
635
+ "plt.tight_layout()\n",
636
+ "plt.show()\n",
637
+ "\n",
638
+ "# Print detailed per-sample comparison\n",
639
+ "print(\"\\nπŸ“‹ Detailed Per-Sample Comparison:\")\n",
640
+ "print(\"-\" * 80)\n",
641
+ "for i, (b, f) in enumerate(zip(baseline_results[\"details\"], finetuned_results[\"details\"])):\n",
642
+ " b_status = \"βœ…\" if b[\"tool_correct\"] else \"❌\"\n",
643
+ " f_status = \"βœ…\" if f[\"tool_correct\"] else \"❌\"\n",
644
+ " change = \"\"\n",
645
+ " if not b[\"tool_correct\"] and f[\"tool_correct\"]:\n",
646
+ " change = \" πŸŽ‰ FIXED!\"\n",
647
+ " elif b[\"tool_correct\"] and not f[\"tool_correct\"]:\n",
648
+ " change = \" ⚠️ REGRESSED\"\n",
649
+ " print(f\"{b['input'][:40]:<42} Base: {b_status} Fine-tuned: {f_status}{change}\")"
650
+ ]
651
+ },
652
+ {
653
+ "cell_type": "markdown",
654
+ "metadata": {},
655
+ "source": [
656
+ "## πŸ“€ 7. Push to Hugging Face Hub"
657
+ ]
658
+ },
659
+ {
660
+ "cell_type": "code",
661
+ "execution_count": null,
662
+ "metadata": {},
663
+ "outputs": [],
664
+ "source": [
665
+ "# Push to Hub\n",
666
+ "trainer.push_to_hub()\n",
667
+ "\n",
668
+ "print(f\"\\nβœ… Model pushed to: https://huggingface.co/{trainer.hub_model_id}\")"
669
+ ]
670
+ }
671
+ ],
672
+ "metadata": {
673
+ "accelerator": "GPU",
674
+ "colab": {
675
+ "gpuType": "T4",
676
+ "provenance": []
677
+ },
678
+ "kernelspec": {
679
+ "display_name": "Python 3",
680
+ "language": "python",
681
+ "name": "python3"
682
+ },
683
+ "language_info": {
684
+ "name": "python",
685
+ "version": "3.10.0"
686
+ }
687
+ },
688
+ "nbformat": 4,
689
+ "nbformat_minor": 4
690
+ }
finetuning-functiongemma/finetune_functiongemma_v2.ipynb ADDED
@@ -0,0 +1,635 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {},
6
+ "source": [
7
+ "# 🎨 Fine-tuning FunctionGemma for Square Color Control\n",
8
+ "\n",
9
+ "This notebook demonstrates how to fine-tune FunctionGemma to recognize color control commands.\n",
10
+ "\n",
11
+ "**Author:** Harlley Oliveira\n",
12
+ "**Portfolio:** AI Engineering\n",
13
+ "\n",
14
+ "## Objectives\n",
15
+ "1. Train the model to call `set_square_color` when the user wants to change the color\n",
16
+ "2. Train the model to call `get_square_color` when the user asks about the current color\n",
17
+ "3. Support various natural language command styles"
18
+ ]
19
+ },
20
+ {
21
+ "cell_type": "markdown",
22
+ "metadata": {},
23
+ "source": [
24
+ "## πŸ“¦ 1. Setup and Installation"
25
+ ]
26
+ },
27
+ {
28
+ "cell_type": "code",
29
+ "execution_count": null,
30
+ "metadata": {},
31
+ "outputs": [],
32
+ "source": [
33
+ "# Install dependencies\n",
34
+ "%pip install -q torch tensorboard\n",
35
+ "%pip install -q transformers datasets accelerate evaluate trl protobuf sentencepiece"
36
+ ]
37
+ },
38
+ {
39
+ "cell_type": "code",
40
+ "execution_count": null,
41
+ "metadata": {},
42
+ "outputs": [],
43
+ "source": [
44
+ "# Login to Hugging Face Hub\n",
45
+ "from huggingface_hub import login\n",
46
+ "login()"
47
+ ]
48
+ },
49
+ {
50
+ "cell_type": "code",
51
+ "execution_count": null,
52
+ "metadata": {},
53
+ "outputs": [],
54
+ "source": [
55
+ "# Configuration\n",
56
+ "BASE_MODEL = \"google/functiongemma-270m-it\"\n",
57
+ "OUTPUT_DIR = \"functiongemma-square-color\"\n",
58
+ "LEARNING_RATE = 5e-5\n",
59
+ "NUM_EPOCHS = 8\n",
60
+ "BATCH_SIZE = 4"
61
+ ]
62
+ },
63
+ {
64
+ "cell_type": "markdown",
65
+ "metadata": {},
66
+ "source": [
67
+ "## πŸ“Š 2. Prepare Dataset with Correct Format"
68
+ ]
69
+ },
70
+ {
71
+ "cell_type": "code",
72
+ "execution_count": null,
73
+ "metadata": {},
74
+ "outputs": [],
75
+ "source": [
76
+ "import json\n",
77
+ "from datasets import Dataset\n",
78
+ "\n",
79
+ "# Tool definitions (same as before)\n",
80
+ "def set_square_color(color: str) -> str:\n",
81
+ " \"\"\"\n",
82
+ " Sets the color of the square displayed on the screen.\n",
83
+ " \n",
84
+ " Args:\n",
85
+ " color: The color to set, e.g. red, blue, green\n",
86
+ " \"\"\"\n",
87
+ " return f\"Color set to {color}\"\n",
88
+ "\n",
89
+ "def get_square_color() -> str:\n",
90
+ " \"\"\"\n",
91
+ " Returns the current color of the square.\n",
92
+ " Use this when the user asks about the current color.\n",
93
+ " \"\"\"\n",
94
+ " return \"Current color\"\n",
95
+ "\n",
96
+ "# Get JSON schemas\n",
97
+ "from transformers.utils import get_json_schema\n",
98
+ "TOOLS = [\n",
99
+ " get_json_schema(set_square_color),\n",
100
+ " get_json_schema(get_square_color)\n",
101
+ "]\n",
102
+ "\n",
103
+ "print(\"Tool schemas:\")\n",
104
+ "print(json.dumps(TOOLS, indent=2))"
105
+ ]
106
+ },
107
+ {
108
+ "cell_type": "code",
109
+ "execution_count": null,
110
+ "metadata": {},
111
+ "outputs": [],
112
+ "source": [
113
+ "# Load training dataset\n",
114
+ "with open(\"dataset/square_color_dataset.json\", \"r\") as f:\n",
115
+ " square_color_dataset = json.load(f)\n",
116
+ "\n",
117
+ "print(f\"Total examples: {len(square_color_dataset)}\")\n",
118
+ "print(f\" - SET: {len([x for x in square_color_dataset if x['tool_name'] == 'set_square_color'])}\")\n",
119
+ "print(f\" - GET: {len([x for x in square_color_dataset if x['tool_name'] == 'get_square_color'])}\")\n",
120
+ "\n",
121
+ "# Preview first few examples\n",
122
+ "print(\"\\nFirst 3 examples:\")\n",
123
+ "for i, sample in enumerate(square_color_dataset[:3]):\n",
124
+ " print(f\" {i+1}. \\\"{sample['user_content']}\\\" β†’ {sample['tool_name']}\")"
125
+ ]
126
+ },
127
+ {
128
+ "cell_type": "code",
129
+ "execution_count": null,
130
+ "metadata": {},
131
+ "outputs": [],
132
+ "source": [
133
+ "# CRITICAL: FunctionGemma's expected output format\n",
134
+ "# The model should output: <start_function_call>call:func{args}<end_function_call>\n",
135
+ "\n",
136
+ "SYSTEM_PROMPT = \"You are a model that can do function calling with the following functions\"\n",
137
+ "\n",
138
+ "def format_function_call_output(tool_name: str, tool_arguments: dict) -> str:\n",
139
+ " \"\"\"\n",
140
+ " Format the expected output in FunctionGemma's native format.\n",
141
+ " \n",
142
+ " FunctionGemma outputs: <start_function_call>call:func_name{arg:<escape>value<escape>}<end_function_call>\n",
143
+ " \"\"\"\n",
144
+ " if not tool_arguments:\n",
145
+ " # For functions with no arguments\n",
146
+ " return f\"<start_function_call>call:{tool_name}{{}}<end_function_call>\"\n",
147
+ " \n",
148
+ " # Format arguments with <escape> tokens for string values\n",
149
+ " args_parts = []\n",
150
+ " for key, value in tool_arguments.items():\n",
151
+ " if isinstance(value, str):\n",
152
+ " args_parts.append(f\"{key}:<escape>{value}<escape>\")\n",
153
+ " else:\n",
154
+ " args_parts.append(f\"{key}:{value}\")\n",
155
+ " \n",
156
+ " args_str = \",\".join(args_parts)\n",
157
+ " return f\"<start_function_call>call:{tool_name}{{{args_str}}}<end_function_call>\"\n",
158
+ "\n",
159
+ "# Test the format\n",
160
+ "print(\"Example outputs:\")\n",
161
+ "print(format_function_call_output(\"set_square_color\", {\"color\": \"blue\"}))\n",
162
+ "print(format_function_call_output(\"get_square_color\", {}))"
163
+ ]
164
+ },
165
+ {
166
+ "cell_type": "code",
167
+ "execution_count": null,
168
+ "metadata": {},
169
+ "outputs": [],
170
+ "source": [
171
+ "from transformers import AutoTokenizer\n",
172
+ "\n",
173
+ "# Load tokenizer first to use apply_chat_template\n",
174
+ "tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)\n",
175
+ "\n",
176
+ "def create_training_text(sample):\n",
177
+ " \"\"\"\n",
178
+ " Create the full training text using FunctionGemma's chat template.\n",
179
+ " \n",
180
+ " The key is that we format the assistant's response in FunctionGemma's\n",
181
+ " native function call format.\n",
182
+ " \"\"\"\n",
183
+ " tool_args = json.loads(sample[\"tool_arguments\"])\n",
184
+ " expected_output = format_function_call_output(sample[\"tool_name\"], tool_args)\n",
185
+ " \n",
186
+ " # Create messages - note: assistant content is the raw function call format\n",
187
+ " messages = [\n",
188
+ " {\"role\": \"developer\", \"content\": SYSTEM_PROMPT},\n",
189
+ " {\"role\": \"user\", \"content\": sample[\"user_content\"]},\n",
190
+ " {\"role\": \"assistant\", \"content\": expected_output},\n",
191
+ " ]\n",
192
+ " \n",
193
+ " # Apply chat template WITH tools to get proper function declarations\n",
194
+ " text = tokenizer.apply_chat_template(\n",
195
+ " messages,\n",
196
+ " tools=TOOLS,\n",
197
+ " tokenize=False,\n",
198
+ " add_generation_prompt=False\n",
199
+ " )\n",
200
+ " \n",
201
+ " return {\"text\": text}\n",
202
+ "\n",
203
+ "# Create dataset\n",
204
+ "dataset = Dataset.from_list(square_color_dataset)\n",
205
+ "dataset = dataset.map(create_training_text, remove_columns=dataset.features, batched=False)\n",
206
+ "\n",
207
+ "# Split 80/20\n",
208
+ "dataset = dataset.train_test_split(test_size=0.2, shuffle=True, seed=42)\n",
209
+ "\n",
210
+ "print(f\"Train: {len(dataset['train'])} examples\")\n",
211
+ "print(f\"Test: {len(dataset['test'])} examples\")"
212
+ ]
213
+ },
214
+ {
215
+ "cell_type": "code",
216
+ "execution_count": null,
217
+ "metadata": {},
218
+ "outputs": [],
219
+ "source": [
220
+ "# Visualize a formatted example\n",
221
+ "print(\"=\" * 60)\n",
222
+ "print(\"FORMATTED TRAINING EXAMPLE\")\n",
223
+ "print(\"=\" * 60)\n",
224
+ "print(dataset[\"train\"][0][\"text\"])\n",
225
+ "print(\"=\" * 60)"
226
+ ]
227
+ },
228
+ {
229
+ "cell_type": "markdown",
230
+ "metadata": {},
231
+ "source": [
232
+ "## πŸ€– 3. Load Model"
233
+ ]
234
+ },
235
+ {
236
+ "cell_type": "code",
237
+ "execution_count": null,
238
+ "metadata": {},
239
+ "outputs": [],
240
+ "source": [
241
+ "import torch\n",
242
+ "from transformers import AutoModelForCausalLM\n",
243
+ "\n",
244
+ "print(\"Loading model for fine-tuning...\")\n",
245
+ "\n",
246
+ "model = AutoModelForCausalLM.from_pretrained(\n",
247
+ " BASE_MODEL,\n",
248
+ " dtype=torch.bfloat16,\n",
249
+ " device_map=\"auto\",\n",
250
+ " attn_implementation=\"eager\"\n",
251
+ ")\n",
252
+ "\n",
253
+ "print(f\"Device: {model.device}\")\n",
254
+ "print(f\"DType: {model.dtype}\")\n",
255
+ "print(f\"Parameters: {model.num_parameters():,}\")"
256
+ ]
257
+ },
258
+ {
259
+ "cell_type": "markdown",
260
+ "metadata": {},
261
+ "source": [
262
+ "## πŸ§ͺ 3.5. Pre-Training Evaluation (Baseline)"
263
+ ]
264
+ },
265
+ {
266
+ "cell_type": "code",
267
+ "execution_count": null,
268
+ "metadata": {},
269
+ "outputs": [],
270
+ "source": [
271
+ "import re\n",
272
+ "\n",
273
+ "def extract_function_call(text):\n",
274
+ " \"\"\"\n",
275
+ " Extract function call from FunctionGemma's output format.\n",
276
+ " Returns (function_name, arguments_dict) or (None, None) if not found.\n",
277
+ " \"\"\"\n",
278
+ " pattern = r\"<start_function_call>call:(\\w+)\\{(.*)\\}<end_function_call>\"\n",
279
+ " match = re.search(pattern, text, re.DOTALL)\n",
280
+ " \n",
281
+ " if not match:\n",
282
+ " return None, None\n",
283
+ " \n",
284
+ " func_name = match.group(1)\n",
285
+ " args_str = match.group(2)\n",
286
+ " \n",
287
+ " # Parse arguments\n",
288
+ " args = {}\n",
289
+ " if args_str.strip():\n",
290
+ " # Match key:<escape>value<escape> or key:value patterns\n",
291
+ " arg_pattern = r\"(\\w+):(?:<escape>(.*?)<escape>|([^,}]*))\"\n",
292
+ " for m in re.finditer(arg_pattern, args_str):\n",
293
+ " key = m.group(1)\n",
294
+ " value = m.group(2) if m.group(2) else m.group(3)\n",
295
+ " args[key] = value.strip() if value else \"\"\n",
296
+ " \n",
297
+ " return func_name, args\n",
298
+ "\n",
299
+ "def evaluate_model(model, tokenizer, test_samples, tools, system_prompt, verbose=True):\n",
300
+ " \"\"\"\n",
301
+ " Evaluate model on test samples using FunctionGemma's format.\n",
302
+ " \"\"\"\n",
303
+ " results = {\n",
304
+ " \"total\": len(test_samples),\n",
305
+ " \"correct\": 0,\n",
306
+ " \"correct_tool\": 0,\n",
307
+ " \"correct_args\": 0,\n",
308
+ " \"details\": []\n",
309
+ " }\n",
310
+ " \n",
311
+ " for sample in test_samples:\n",
312
+ " messages = [\n",
313
+ " {\"role\": \"developer\", \"content\": system_prompt},\n",
314
+ " {\"role\": \"user\", \"content\": sample[\"user_content\"]},\n",
315
+ " ]\n",
316
+ " \n",
317
+ " inputs = tokenizer.apply_chat_template(\n",
318
+ " messages,\n",
319
+ " tools=tools,\n",
320
+ " tokenize=True,\n",
321
+ " add_generation_prompt=True,\n",
322
+ " return_dict=True,\n",
323
+ " return_tensors=\"pt\"\n",
324
+ " ).to(model.device)\n",
325
+ " \n",
326
+ " with torch.no_grad():\n",
327
+ " output = model.generate(\n",
328
+ " **inputs,\n",
329
+ " max_new_tokens=128,\n",
330
+ " do_sample=False,\n",
331
+ " )\n",
332
+ " \n",
333
+ " input_length = inputs['input_ids'].shape[1]\n",
334
+ " response = tokenizer.decode(output[0][input_length:], skip_special_tokens=False)\n",
335
+ " \n",
336
+ " # Parse the function call from response\n",
337
+ " called_func, called_args = extract_function_call(response)\n",
338
+ " \n",
339
+ " # Check if correct tool was called\n",
340
+ " tool_correct = called_func == sample[\"tool_name\"]\n",
341
+ " \n",
342
+ " # Check arguments\n",
343
+ " args_correct = False\n",
344
+ " expected_args = json.loads(sample[\"tool_arguments\"])\n",
345
+ " \n",
346
+ " if tool_correct:\n",
347
+ " if sample[\"tool_name\"] == \"get_square_color\":\n",
348
+ " args_correct = True # No args needed\n",
349
+ " elif called_args and \"color\" in called_args:\n",
350
+ " args_correct = called_args.get(\"color\", \"\").lower() == expected_args.get(\"color\", \"\").lower()\n",
351
+ " \n",
352
+ " if tool_correct:\n",
353
+ " results[\"correct_tool\"] += 1\n",
354
+ " if tool_correct and args_correct:\n",
355
+ " results[\"correct\"] += 1\n",
356
+ " results[\"correct_args\"] += 1\n",
357
+ " \n",
358
+ " results[\"details\"].append({\n",
359
+ " \"input\": sample[\"user_content\"],\n",
360
+ " \"expected_tool\": sample[\"tool_name\"],\n",
361
+ " \"expected_args\": sample[\"tool_arguments\"],\n",
362
+ " \"called_func\": called_func,\n",
363
+ " \"called_args\": called_args,\n",
364
+ " \"response\": response,\n",
365
+ " \"tool_correct\": tool_correct,\n",
366
+ " \"args_correct\": args_correct\n",
367
+ " })\n",
368
+ " \n",
369
+ " results[\"tool_accuracy\"] = results[\"correct_tool\"] / results[\"total\"] * 100\n",
370
+ " results[\"full_accuracy\"] = results[\"correct\"] / results[\"total\"] * 100\n",
371
+ " \n",
372
+ " if verbose:\n",
373
+ " print(f\"Tool Accuracy: {results['correct_tool']}/{results['total']} ({results['tool_accuracy']:.1f}%)\")\n",
374
+ " print(f\"Full Accuracy (tool + args): {results['correct']}/{results['total']} ({results['full_accuracy']:.1f}%)\")\n",
375
+ " \n",
376
+ " return results"
377
+ ]
378
+ },
379
+ {
380
+ "cell_type": "code",
381
+ "execution_count": null,
382
+ "metadata": {},
383
+ "outputs": [],
384
+ "source": [
385
+ "# Create evaluation test set\n",
386
+ "import random\n",
387
+ "\n",
388
+ "random.seed(42)\n",
389
+ "\n",
390
+ "set_samples = [s for s in square_color_dataset if s[\"tool_name\"] == \"set_square_color\"]\n",
391
+ "get_samples = [s for s in square_color_dataset if s[\"tool_name\"] == \"get_square_color\"]\n",
392
+ "\n",
393
+ "test_cases = 25\n",
394
+ "eval_test_cases = random.sample(set_samples, min(test_cases, len(set_samples))) + \\\n",
395
+ " random.sample(get_samples, min(test_cases, len(get_samples)))\n",
396
+ "\n",
397
+ "print(\"=\" * 50)\n",
398
+ "print(\"PRE-TRAINING EVALUATION (Baseline)\")\n",
399
+ "print(\"=\" * 50)\n",
400
+ "print(f\"\\nEvaluating base model on {len(eval_test_cases)} test cases...\\n\")\n",
401
+ "\n",
402
+ "baseline_results = evaluate_model(\n",
403
+ " model=model,\n",
404
+ " tokenizer=tokenizer,\n",
405
+ " test_samples=eval_test_cases,\n",
406
+ " tools=TOOLS,\n",
407
+ " system_prompt=SYSTEM_PROMPT\n",
408
+ ")\n",
409
+ "\n",
410
+ "# Show sample outputs\n",
411
+ "print(\"\\n--- Sample Outputs (Base Model) ---\")\n",
412
+ "for i, detail in enumerate(baseline_results[\"details\"][:4]):\n",
413
+ " status = \"βœ…\" if detail[\"tool_correct\"] else \"❌\"\n",
414
+ " print(f\"\\n{status} Input: {detail['input']}\")\n",
415
+ " print(f\" Expected: {detail['expected_tool']}\")\n",
416
+ " print(f\" Got: {detail['called_func']} with args {detail['called_args']}\")"
417
+ ]
418
+ },
419
+ {
420
+ "cell_type": "markdown",
421
+ "metadata": {},
422
+ "source": [
423
+ "## πŸ”₯ 4. Fine-tuning"
424
+ ]
425
+ },
426
+ {
427
+ "cell_type": "code",
428
+ "execution_count": null,
429
+ "metadata": {},
430
+ "outputs": [],
431
+ "source": [
432
+ "from trl import SFTConfig, SFTTrainer\n",
433
+ "\n",
434
+ "torch_dtype = model.dtype\n",
435
+ "\n",
436
+ "# Training configuration\n",
437
+ "args = SFTConfig(\n",
438
+ " output_dir=OUTPUT_DIR,\n",
439
+ " max_length=512,\n",
440
+ " packing=False,\n",
441
+ " num_train_epochs=NUM_EPOCHS,\n",
442
+ " per_device_train_batch_size=BATCH_SIZE,\n",
443
+ " gradient_checkpointing=False,\n",
444
+ " optim=\"adamw_torch_fused\",\n",
445
+ " logging_steps=1,\n",
446
+ " eval_strategy=\"epoch\",\n",
447
+ " save_strategy=\"epoch\",\n",
448
+ " learning_rate=LEARNING_RATE,\n",
449
+ " fp16=True if torch_dtype == torch.float16 else False,\n",
450
+ " bf16=True if torch_dtype == torch.bfloat16 else False,\n",
451
+ " lr_scheduler_type=\"constant\",\n",
452
+ " push_to_hub=True,\n",
453
+ " report_to=\"tensorboard\",\n",
454
+ " load_best_model_at_end=True,\n",
455
+ " metric_for_best_model=\"eval_loss\",\n",
456
+ " dataset_text_field=\"text\", # IMPORTANT: specify the text field\n",
457
+ ")\n",
458
+ "\n",
459
+ "# Create trainer\n",
460
+ "trainer = SFTTrainer(\n",
461
+ " model=model,\n",
462
+ " args=args,\n",
463
+ " train_dataset=dataset['train'],\n",
464
+ " eval_dataset=dataset['test'],\n",
465
+ " processing_class=tokenizer,\n",
466
+ ")\n",
467
+ "\n",
468
+ "print(\"Trainer created successfully!\")"
469
+ ]
470
+ },
471
+ {
472
+ "cell_type": "code",
473
+ "execution_count": null,
474
+ "metadata": {},
475
+ "outputs": [],
476
+ "source": [
477
+ "# πŸš€ Start training!\n",
478
+ "print(\"Starting fine-tuning...\")\n",
479
+ "trainer.train()\n",
480
+ "\n",
481
+ "print(\"\\nβœ… Training complete!\")"
482
+ ]
483
+ },
484
+ {
485
+ "cell_type": "code",
486
+ "execution_count": null,
487
+ "metadata": {},
488
+ "outputs": [],
489
+ "source": [
490
+ "# Save final model in the original dtype (BF16)\n",
491
+ "# This prevents the model from being saved as FP32 (which doubles the size)\n",
492
+ "model.save_pretrained(OUTPUT_DIR, safe_serialization=True)\n",
493
+ "tokenizer.save_pretrained(OUTPUT_DIR)\n",
494
+ "print(f\"Model saved to: {OUTPUT_DIR}\")"
495
+ ]
496
+ },
497
+ {
498
+ "cell_type": "markdown",
499
+ "metadata": {},
500
+ "source": [
501
+ "## πŸ“ˆ 5. Visualize Results"
502
+ ]
503
+ },
504
+ {
505
+ "cell_type": "code",
506
+ "execution_count": null,
507
+ "metadata": {},
508
+ "outputs": [],
509
+ "source": [
510
+ "import matplotlib.pyplot as plt\n",
511
+ "\n",
512
+ "# Extract loss history\n",
513
+ "log_history = trainer.state.log_history\n",
514
+ "\n",
515
+ "train_losses = [log[\"loss\"] for log in log_history if \"loss\" in log]\n",
516
+ "epoch_train = [log[\"epoch\"] for log in log_history if \"loss\" in log]\n",
517
+ "eval_losses = [log[\"eval_loss\"] for log in log_history if \"eval_loss\" in log]\n",
518
+ "epoch_eval = [log[\"epoch\"] for log in log_history if \"eval_loss\" in log]\n",
519
+ "\n",
520
+ "# Plot\n",
521
+ "plt.figure(figsize=(10, 6))\n",
522
+ "plt.plot(epoch_train, train_losses, label=\"Training Loss\", alpha=0.7)\n",
523
+ "plt.plot(epoch_eval, eval_losses, label=\"Validation Loss\", marker='o')\n",
524
+ "plt.xlabel(\"Epoch\")\n",
525
+ "plt.ylabel(\"Loss\")\n",
526
+ "plt.title(\"Training and Validation Loss\")\n",
527
+ "plt.legend()\n",
528
+ "plt.grid(True)\n",
529
+ "plt.show()"
530
+ ]
531
+ },
532
+ {
533
+ "cell_type": "markdown",
534
+ "metadata": {},
535
+ "source": [
536
+ "## πŸ§ͺ 6. Post-Training Evaluation"
537
+ ]
538
+ },
539
+ {
540
+ "cell_type": "code",
541
+ "execution_count": null,
542
+ "metadata": {},
543
+ "outputs": [],
544
+ "source": [
545
+ "print(\"=\" * 50)\n",
546
+ "print(\"POST-TRAINING EVALUATION (Fine-tuned)\")\n",
547
+ "print(\"=\" * 50)\n",
548
+ "print(f\"\\nEvaluating fine-tuned model on {len(eval_test_cases)} test cases...\\n\")\n",
549
+ "\n",
550
+ "finetuned_results = evaluate_model(\n",
551
+ " model=model,\n",
552
+ " tokenizer=tokenizer,\n",
553
+ " test_samples=eval_test_cases,\n",
554
+ " tools=TOOLS,\n",
555
+ " system_prompt=SYSTEM_PROMPT\n",
556
+ ")\n",
557
+ "\n",
558
+ "# Show sample outputs\n",
559
+ "print(\"\\n--- Sample Outputs (Fine-tuned Model) ---\")\n",
560
+ "for i, detail in enumerate(finetuned_results[\"details\"][:4]):\n",
561
+ " status = \"βœ…\" if detail[\"tool_correct\"] else \"❌\"\n",
562
+ " print(f\"\\n{status} Input: {detail['input']}\")\n",
563
+ " print(f\" Expected: {detail['expected_tool']}\")\n",
564
+ " print(f\" Got: {detail['called_func']} with args {detail['called_args']}\")"
565
+ ]
566
+ },
567
+ {
568
+ "cell_type": "code",
569
+ "execution_count": null,
570
+ "metadata": {},
571
+ "outputs": [],
572
+ "source": [
573
+ "# Compare baseline vs fine-tuned\n",
574
+ "print(\"=\" * 60)\n",
575
+ "print(\"πŸ“Š COMPARISON: Baseline vs Fine-tuned\")\n",
576
+ "print(\"=\" * 60)\n",
577
+ "\n",
578
+ "print(f\"\\n{'Metric':<30} {'Baseline':>12} {'Fine-tuned':>12} {'Improvement':>12}\")\n",
579
+ "print(\"-\" * 66)\n",
580
+ "\n",
581
+ "tool_improvement = finetuned_results[\"tool_accuracy\"] - baseline_results[\"tool_accuracy\"]\n",
582
+ "print(f\"{'Tool Accuracy':<30} {baseline_results['tool_accuracy']:>11.1f}% {finetuned_results['tool_accuracy']:>11.1f}% {tool_improvement:>+11.1f}%\")\n",
583
+ "\n",
584
+ "full_improvement = finetuned_results[\"full_accuracy\"] - baseline_results[\"full_accuracy\"]\n",
585
+ "print(f\"{'Full Accuracy (tool + args)':<30} {baseline_results['full_accuracy']:>11.1f}% {finetuned_results['full_accuracy']:>11.1f}% {full_improvement:>+11.1f}%\")\n",
586
+ "\n",
587
+ "print(\"-\" * 66)\n",
588
+ "\n",
589
+ "if full_improvement > 0:\n",
590
+ " print(f\"\\nβœ… Fine-tuning improved accuracy by {full_improvement:.1f} percentage points!\")\n",
591
+ "elif full_improvement == 0:\n",
592
+ " print(f\"\\n⚠️ No change in accuracy.\")\n",
593
+ "else:\n",
594
+ " print(f\"\\n❌ Accuracy decreased. Check for overfitting or data issues.\")"
595
+ ]
596
+ },
597
+ {
598
+ "cell_type": "markdown",
599
+ "metadata": {},
600
+ "source": [
601
+ "## πŸ“€ 7. Push to Hugging Face Hub"
602
+ ]
603
+ },
604
+ {
605
+ "cell_type": "code",
606
+ "execution_count": null,
607
+ "metadata": {},
608
+ "outputs": [],
609
+ "source": [
610
+ "# Push to Hub\n",
611
+ "trainer.push_to_hub()\n",
612
+ "\n",
613
+ "print(f\"\\nβœ… Model pushed to: https://huggingface.co/{trainer.hub_model_id}\")"
614
+ ]
615
+ }
616
+ ],
617
+ "metadata": {
618
+ "accelerator": "GPU",
619
+ "colab": {
620
+ "gpuType": "T4",
621
+ "provenance": []
622
+ },
623
+ "kernelspec": {
624
+ "display_name": "Python 3",
625
+ "language": "python",
626
+ "name": "python3"
627
+ },
628
+ "language_info": {
629
+ "name": "python",
630
+ "version": "3.10.0"
631
+ }
632
+ },
633
+ "nbformat": 4,
634
+ "nbformat_minor": 4
635
+ }
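
The fine-tuned checkpoint can be smoke-tested outside the notebook before wiring it into the app. The sketch below is a minimal example under stated assumptions: it reuses the notebook's `OUTPUT_DIR` name (`functiongemma-square-color`) as the checkpoint path (swap in your Hub repo id after `push_to_hub()`), re-declares the two tools, and parses the output with the same `<start_function_call>` pattern used in the evaluation cells; it is a sketch, not part of the files added above.

```python
# Minimal smoke test for the fine-tuned checkpoint (sketch; the path is an assumption).
import re
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.utils import get_json_schema

CHECKPOINT = "functiongemma-square-color"  # local OUTPUT_DIR, or your Hub repo id after push_to_hub()
SYSTEM_PROMPT = "You are a model that can do function calling with the following functions"

def set_square_color(color: str) -> str:
    """Sets the color of the square displayed on the screen.

    Args:
        color: The color to set, e.g. red, blue, green
    """
    return f"Color set to {color}"

def get_square_color() -> str:
    """Returns the current color of the square."""
    return "Current color"

TOOLS = [get_json_schema(set_square_color), get_json_schema(get_square_color)]

tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT)
model = AutoModelForCausalLM.from_pretrained(CHECKPOINT, dtype=torch.bfloat16, device_map="auto")

messages = [
    {"role": "developer", "content": SYSTEM_PROMPT},
    {"role": "user", "content": "Make the square yellow"},
]
inputs = tokenizer.apply_chat_template(
    messages, tools=TOOLS, add_generation_prompt=True,
    tokenize=True, return_dict=True, return_tensors="pt",
).to(model.device)

with torch.no_grad():
    out = model.generate(**inputs, max_new_tokens=64, do_sample=False)

# Decode only the newly generated tokens and look for a function call.
response = tokenizer.decode(out[0][inputs["input_ids"].shape[1]:], skip_special_tokens=False)
match = re.search(r"<start_function_call>call:(\w+)\{(.*)\}<end_function_call>", response, re.DOTALL)
print(match.groups() if match else f"No function call found: {response!r}")
# Expected if training worked: ('set_square_color', 'color:<escape>yellow<escape>')
```

If the parse fails, re-check that the assistant turns in the training data use the exact `<start_function_call>` format, since the same template must be reproduced at inference time.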
src/worker.ts CHANGED
@@ -31,7 +31,7 @@ function getModel(onProgress?: ProgressCallback) {
31
  progress_callback: onProgress,
32
  }),
33
  AutoModelForCausalLM.from_pretrained(MODEL_ID, {
34
- dtype: "q4",
35
  device: "webgpu",
36
  progress_callback: onProgress,
37
  }),
 
31
  progress_callback: onProgress,
32
  }),
33
  AutoModelForCausalLM.from_pretrained(MODEL_ID, {
34
+ dtype: "fp16",
35
  device: "webgpu",
36
  progress_callback: onProgress,
37
  }),