harlley committed
Commit 84928fb · Parent: d4f215b

add complete sequence tool calling with no hard-coded responses

.claude/settings.local.json CHANGED
@@ -3,7 +3,8 @@
     "allow": [
       "WebSearch",
       "Bash(npm run typecheck:*)",
-      "Bash(npm run lint:*)"
+      "Bash(npm run lint:*)",
+      "WebFetch(domain:ai.google.dev)"
     ]
   }
 }
src/components/AgenticInterface.tsx CHANGED
@@ -16,7 +16,8 @@ const INITIAL_MESSAGES: Message[] = [
 
 export function AgenticInterface() {
   const { squareColor, setSquareColor } = useColorStore();
-  const { processMessage, isProcessing, loadingStatus } = useFunctionCalling();
+  const { processMessage, continueWithToolResult, isProcessing, loadingStatus } =
+    useFunctionCalling();
   const [messages, setMessages] = useState<Message[]>(INITIAL_MESSAGES);
 
   const handleSendMessage = async (text: string) => {
@@ -31,13 +32,20 @@ export function AgenticInterface() {
       const color = result.functionCall.args.color as string | undefined;
       if (color) {
         setSquareColor(color);
-        botResponse = `Done! I've changed the square color to ${color}.`;
+        // Get response from model after executing the function
+        botResponse = await continueWithToolResult("set_square_color", {
+          success: true,
+          color,
+        });
       } else {
         botResponse =
           "I understood you want to change the color, but I couldn't determine which color. Please try again with a specific color.";
       }
     } else if (result.functionCall?.functionName === "getSquareColor") {
-      botResponse = `The current color of the square is ${squareColor}.`;
+      // Get response from model after executing the function
+      botResponse = await continueWithToolResult("get_square_color", {
+        color: squareColor,
+      });
     } else if (result.textResponse) {
       botResponse = result.textResponse;
     } else {
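With this change the component stops templating its own confirmation strings and instead round-trips every tool result through the model. The sketch below is a minimal, hypothetical illustration of that two-turn flow as the diff wires it up; the ToolTurnResult shape and the execute callback are assumptions for the example, not code from this repo.

// Hypothetical sketch of the two-turn tool-calling flow (TypeScript).
type ToolTurnResult = {
  functionCall: { functionName: string; args: Record<string, unknown> } | null;
  textResponse?: string;
};

async function answerWithTools(
  text: string,
  processMessage: (text: string) => Promise<ToolTurnResult>,
  continueWithToolResult: (name: string, result: unknown) => Promise<string>,
  execute: (name: string, args: Record<string, unknown>) => unknown,
): Promise<string> {
  // Turn 1: the model either answers directly or requests a function call.
  const result = await processMessage(text);
  if (!result.functionCall) {
    return result.textResponse ?? "Sorry, I didn't catch that.";
  }

  // The host app executes the function itself (e.g. updating the color store)...
  const toolResult = execute(
    result.functionCall.functionName,
    result.functionCall.args,
  );

  // Turn 2: ...then feeds the result back so the model phrases the reply.
  return continueWithToolResult(result.functionCall.functionName, toolResult);
}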
src/hooks/useFunctionCalling.ts CHANGED
@@ -138,5 +138,20 @@ export function useFunctionCalling() {
     [initModel],
   );
 
-  return { processMessage, isProcessing, loadingStatus };
+  const continueWithToolResult = useCallback(
+    async (functionName: string, functionResult: unknown): Promise<string> => {
+      const api = apiRef.current;
+      if (!api) return "Error: Worker not available";
+
+      setIsProcessing(true);
+      try {
+        return await api.continueWithToolResult(functionName, functionResult);
+      } finally {
+        setIsProcessing(false);
+      }
+    },
+    [],
+  );
+
+  return { processMessage, continueWithToolResult, isProcessing, loadingStatus };
 }
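The hook only forwards to apiRef.current; the worker setup itself is outside this diff. A minimal sketch of the wiring it presumably relies on, assuming the standard module-worker pattern wrapped with Comlink (Comlink.wrap is Comlink's real API; the URL and setup here are illustrative):

import * as Comlink from "comlink";
import type { WorkerAPI } from "../worker";

// Spawn the module worker and wrap its postMessage channel so that
// worker methods become awaitable functions on this side.
const worker = new Worker(new URL("../worker.ts", import.meta.url), {
  type: "module",
});
const api = Comlink.wrap<WorkerAPI>(worker);

// Calls resolve with whatever the worker method returns, e.g.:
//   const reply = await api.continueWithToolResult("get_square_color", {
//     color: "red",
//   });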
src/worker.ts CHANGED
@@ -5,6 +5,7 @@ import {
   type PreTrainedTokenizer,
 } from "@huggingface/transformers";
 import * as Comlink from "comlink";
+import { createStore } from "zustand/vanilla";
 
 const MODEL_ID = "onnx-community/functiongemma-270m-it-ONNX";
 
@@ -23,6 +24,42 @@ export type LoadingProgress = {
 
 type ProgressCallback = (progress: LoadingProgress) => void;
 
+// Types for FunctionGemma multi-turn conversation
+type ToolCall = {
+  type: "function";
+  function: {
+    name: string;
+    arguments: Record<string, unknown>;
+  };
+};
+
+type ChatMessage =
+  | { role: "developer" | "user"; content: string }
+  | { role: "assistant"; content?: string; tool_calls?: ToolCall[] }
+  | { role: "tool"; content: Array<{ name: string; response: unknown }> };
+
+export type FunctionCallResult = {
+  functionName: string;
+  args: Record<string, unknown>;
+} | null;
+
+type ConversationState = {
+  messages: ChatMessage[];
+  lastFunctionCall: FunctionCallResult;
+  setMessages: (messages: ChatMessage[]) => void;
+  setLastFunctionCall: (fc: FunctionCallResult) => void;
+  clear: () => void;
+};
+
+// Store to maintain conversation state between processMessage and continueWithToolResult
+const conversationStore = createStore<ConversationState>((set) => ({
+  messages: [],
+  lastFunctionCall: null,
+  setMessages: (messages) => set({ messages }),
+  setLastFunctionCall: (fc) => set({ lastFunctionCall: fc }),
+  clear: () => set({ messages: [], lastFunctionCall: null }),
+}));
+
 function getModel(onProgress?: ProgressCallback) {
   modelPromise ??= (async () => {
     console.log("[Worker] Loading model...");
@@ -79,11 +116,6 @@ const tools = [
 const developerPrompt =
   "You are a model that can do function calling with the following functions";
 
-export type FunctionCallResult = {
-  functionName: string;
-  args: Record<string, unknown>;
-} | null;
-
 // Parses functiongemma output: <start_function_call>call:func_name{param:<escape>value<escape>}<end_function_call>
 function parseFunctionCall(output: string): FunctionCallResult {
   const match = output.match(
@@ -186,12 +218,21 @@ const workerAPI = {
     const functionCall = parseFunctionCall(decoded as string);
     console.log("[Worker] Parsed function call:", functionCall);
 
-    // If function was called, don't return text response (it's usually garbage)
+    // If function was called, save conversation state for continueWithToolResult
    if (functionCall) {
      console.log(
        "[Worker] ✅ Function call successful:",
        functionCall.functionName,
      );
+
+      // Save conversation state for the next turn
+      const { setMessages, setLastFunctionCall } = conversationStore.getState();
+      setMessages([
+        { role: "developer", content: developerPrompt },
+        { role: "user", content: text },
+      ]);
+      setLastFunctionCall(functionCall);
+
      return { functionCall };
    }
 
@@ -216,6 +257,119 @@ const workerAPI = {
 
     return { functionCall: null, textResponse: textResponse || undefined };
   },
+
+  async continueWithToolResult(
+    functionName: string,
+    functionResult: unknown,
+  ): Promise<string> {
+    console.log("[Worker] ========== CONTINUE WITH TOOL RESULT ==========");
+    console.log("[Worker] Function:", functionName);
+    console.log("[Worker] Result:", functionResult);
+
+    const { messages, lastFunctionCall, clear } = conversationStore.getState();
+
+    if (messages.length === 0 || !lastFunctionCall) {
+      console.log("[Worker] ⚠️ No conversation state found");
+      return "I apologize, but I lost track of our conversation. Could you please try again?";
+    }
+
+    const { model, tokenizer } = await getModel();
+
+    // Build the full conversation with tool result
+    // Convert camelCase back to snake_case for the model
+    const snakeCaseName =
+      lastFunctionCall.functionName === "setSquareColor"
+        ? "set_square_color"
+        : lastFunctionCall.functionName === "getSquareColor"
+          ? "get_square_color"
+          : functionName;
+
+    // Build messages in the format expected by the tokenizer
+    // Note: Using 'as unknown' to bypass strict typing since the tokenizer
+    // accepts a more flexible format than what TypeScript infers
+    const fullMessages = [
+      ...messages,
+      // Assistant turn with tool_calls
+      {
+        role: "assistant",
+        tool_calls: [
+          {
+            type: "function",
+            function: {
+              name: snakeCaseName,
+              arguments: lastFunctionCall.args,
+            },
+          },
+        ],
+      },
+      // Tool response
+      {
+        role: "tool",
+        content: [{ name: snakeCaseName, response: functionResult }],
+      },
+    ] as unknown[];
+
+    console.log(
+      "[Worker] Full messages for final response:",
+      JSON.stringify(fullMessages, null, 2),
+    );
+
+    // biome-ignore lint/suspicious/noExplicitAny: tokenizer accepts flexible message format
+    const inputs = tokenizer.apply_chat_template(fullMessages as any, {
+      tools,
+      tokenize: true,
+      add_generation_prompt: true,
+      return_dict: true,
+    }) as { input_ids: { dims: number[]; data: unknown } };
+
+    console.log("[Worker] Input token count:", inputs.input_ids.dims[1]);
+
+    try {
+      const inputIds = Array.from(inputs.input_ids.data as Iterable<number>);
+      const decodedInput = tokenizer.decode(inputIds, {
+        skip_special_tokens: false,
+      });
+      console.log("[Worker] Decoded input for final response:", decodedInput);
+    } catch (e) {
+      console.log("[Worker] Could not decode input:", e);
+    }
+
+    // biome-ignore lint/complexity/noBannedTypes: model.generate type is not properly typed in transformers.js
+    const output = await (model.generate as Function)({
+      ...inputs,
+      max_new_tokens: 256,
+    });
+
+    const inputLength = inputs.input_ids.dims[1];
+    const decoded = tokenizer.decode(output.slice(0, [inputLength, null]), {
+      skip_special_tokens: false,
+    });
+
+    console.log("[Worker] Raw final response:", decoded);
+
+    // Clear conversation state
+    clear();
+
+    // Clean up the response
+    const cleanResponse = (decoded as string)
+      .replace(/<start_function_call>[\s\S]*?<end_function_call>/g, "")
+      .replace(/<start_function_response>[\s\S]*?<end_function_response>/g, "")
+      .replace(/<\|im_end\|>/g, "")
+      .replace(/<\|im_start\|>/g, "")
+      .replace(/<\|endoftext\|>/g, "")
+      .replace(/<end_of_turn>/g, "")
+      .replace(/<start_of_turn>/g, "")
+      .replace(/^model\s*/i, "")
+      .replace(/^assistant\s*/i, "")
+      .trim();
+
+    console.log("[Worker] ✅ Final response:", cleanResponse);
+
+    return cleanResponse || "Done!";
+  },
 };
 
 export type WorkerAPI = typeof workerAPI;
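The worker's processMessage path relies on parseFunctionCall, whose body falls outside these hunks; only the comment documents FunctionGemma's output format. As a rough sketch (an assumption, not the repo's implementation), a parser for that documented format could look like:

// Hypothetical parser for the format the comment documents:
// <start_function_call>call:func_name{param:<escape>value<escape>}<end_function_call>
type FunctionCallResult = {
  functionName: string;
  args: Record<string, unknown>;
} | null;

function parseFunctionCallSketch(output: string): FunctionCallResult {
  const match = output.match(
    /<start_function_call>call:(\w+)\{([\s\S]*?)\}<end_function_call>/,
  );
  if (!match) return null;

  const [, functionName, rawArgs] = match;
  const args: Record<string, unknown> = {};
  // Each argument looks like `name:<escape>value<escape>`.
  for (const [, key, value] of rawArgs.matchAll(
    /(\w+):<escape>([\s\S]*?)<escape>/g,
  )) {
    args[key] = value;
  }
  return { functionName, args };
}

// parseFunctionCallSketch(
//   "<start_function_call>call:set_square_color{color:<escape>blue<escape>}<end_function_call>",
// ) → { functionName: "set_square_color", args: { color: "blue" } }

Note also that continueWithToolResult calls clear() after one generation, so the store holds at most one pending tool call at a time; a follow-up question starts a fresh conversation.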