harlley committed on
Commit
e214ac8
·
1 Parent(s): c4197e7

refactor: create generic useChat hook and simplify architecture

Browse files

- Add useChat hook that receives tools as parameter and handles UI + ML state
- Create unified Zustand store (useChatStore) replacing 2 stores
- Simplify worker to only handle model loading and inference (~90 lines)
- Extract function calling logic to lib/functionCalling.ts
- Move tool definitions to config/tools.ts
- Add shared types in types/chat.ts

package.json CHANGED
@@ -7,7 +7,7 @@
7
  "dev": "vite",
8
  "build": "tsc -b && vite build",
9
  "preview": "vite preview",
10
- "typecheck": "tsc --noEmit",
11
  "lint": "biome lint ./src",
12
  "format": "biome format --write ./src",
13
  "check": "biome check --write ./src"
 
7
  "dev": "vite",
8
  "build": "tsc -b && vite build",
9
  "preview": "vite preview",
10
+ "typecheck": "tsc -b",
11
  "lint": "biome lint ./src",
12
  "format": "biome format --write ./src",
13
  "check": "biome check --write ./src"
src/components/AgenticInterface.tsx CHANGED
@@ -1,75 +1,40 @@
1
- import { useState } from "react";
2
- import { useFunctionCalling } from "@/hooks/useFunctionCalling";
3
- import { useColorStore } from "@/store/useColorStore";
 
 
4
  import { ChatSidebar } from "./ChatSidebar";
5
  import { ColorControlForm } from "./ColorControlForm";
6
  import { ColorVisualizer } from "./ColorVisualizer";
7
- import type { Message } from "./MessageBubble";
8
-
9
- const INITIAL_MESSAGES: Message[] = [
10
- {
11
- id: 1,
12
- text: "Hi! I can change the square color for you. Ask me to set a color or ask what color it currently is!",
13
- sender: "bot",
14
- },
15
- ];
16
 
17
  export function AgenticInterface() {
18
- const { squareColor, setSquareColor } = useColorStore();
19
- const { processMessage, continueWithToolResult, isProcessing, loadingStatus } =
20
- useFunctionCalling();
21
- const [messages, setMessages] = useState<Message[]>(INITIAL_MESSAGES);
22
-
23
- const handleSendMessage = async (text: string) => {
24
- const userMessage: Message = { id: Date.now(), text, sender: "user" };
25
- setMessages((prev) => [...prev, userMessage]);
26
-
27
- let botResponse: string;
28
- try {
29
- const result = await processMessage(text);
30
-
31
- if (result.functionCall?.functionName === "setSquareColor") {
32
- const color = result.functionCall.args.color as string | undefined;
33
- if (color) {
34
- setSquareColor(color);
35
- // Get response from model after executing the function
36
- botResponse = await continueWithToolResult("set_square_color", {
37
- success: true,
38
- color,
39
- });
40
- } else {
41
- botResponse =
42
- "I understood you want to change the color, but I couldn't determine which color. Please try again with a specific color.";
43
- }
44
- } else if (result.functionCall?.functionName === "getSquareColor") {
45
- // Get response from model after executing the function
46
- botResponse = await continueWithToolResult("get_square_color", {
47
- color: squareColor,
48
- });
49
- } else if (result.textResponse) {
50
- botResponse = result.textResponse;
51
- } else {
52
- botResponse =
53
- "Sorry, I had trouble understanding that. Could you please try again?";
54
  }
55
- } catch {
56
- botResponse =
57
- "Sorry, there was an error processing your request. Please try again.";
58
- }
59
 
60
- const botMessage: Message = {
61
- id: Date.now() + 1,
62
- text: botResponse,
63
- sender: "bot",
64
- };
65
- setMessages((prev) => [...prev, botMessage]);
66
- };
67
 
68
  return (
69
  <div className="flex flex-col md:flex-row h-screen w-screen overflow-hidden bg-background text-foreground font-sans selection:bg-primary/30">
70
  <ChatSidebar
71
  messages={messages}
72
- onSendMessage={handleSendMessage}
73
  isLoading={isProcessing || loadingStatus.isLoading}
74
  modelReady={loadingStatus.isModelReady}
75
  progress={loadingStatus.downloadProgress}
 
1
+ import { useCallback } from "react";
2
+ import { colorTools } from "@/config/tools";
3
+ import { useChat } from "@/hooks/useChat";
4
+ import { useChatStore } from "@/store/useChatStore";
5
+ import type { FunctionCallResult } from "@/types/chat";
6
  import { ChatSidebar } from "./ChatSidebar";
7
  import { ColorControlForm } from "./ColorControlForm";
8
  import { ColorVisualizer } from "./ColorVisualizer";
 
 
 
 
 
 
 
 
 
9
 
10
  export function AgenticInterface() {
11
+ const { squareColor, setSquareColor } = useChatStore();
12
+
13
+ const handleFunctionCall = useCallback(
14
+ async (fc: FunctionCallResult) => {
15
+ if (fc?.functionName === "setSquareColor") {
16
+ const color = fc.args.color as string;
17
+ setSquareColor(color);
18
+ return { success: true, color };
19
+ }
20
+ if (fc?.functionName === "getSquareColor") {
21
+ return { color: squareColor };
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
  }
23
+ return { error: "Unknown function" };
24
+ },
25
+ [squareColor, setSquareColor],
26
+ );
27
 
28
+ const { messages, sendMessage, isProcessing, loadingStatus } = useChat({
29
+ tools: colorTools,
30
+ onFunctionCall: handleFunctionCall,
31
+ });
 
 
 
32
 
33
  return (
34
  <div className="flex flex-col md:flex-row h-screen w-screen overflow-hidden bg-background text-foreground font-sans selection:bg-primary/30">
35
  <ChatSidebar
36
  messages={messages}
37
+ onSendMessage={sendMessage}
38
  isLoading={isProcessing || loadingStatus.isLoading}
39
  modelReady={loadingStatus.isModelReady}
40
  progress={loadingStatus.downloadProgress}
src/components/MessageBubble.tsx CHANGED
@@ -1,15 +1,11 @@
1
  import { IconRobot, IconUser } from "@tabler/icons-react";
 
2
 
3
- type MessageSender = "user" | "bot";
4
-
5
- type Message = {
6
- id: number;
7
- text: string;
8
- sender: MessageSender;
9
- };
10
 
11
  type MessageBubbleProps = {
12
- message: Message;
13
  };
14
 
15
  export function MessageBubble({ message }: MessageBubbleProps) {
@@ -41,5 +37,3 @@ export function MessageBubble({ message }: MessageBubbleProps) {
41
  </div>
42
  );
43
  }
44
-
45
- export type { Message, MessageSender };
 
1
  import { IconRobot, IconUser } from "@tabler/icons-react";
2
+ import type { UIMessage } from "@/types/chat";
3
 
4
+ // Re-export for backwards compatibility
5
+ export type Message = UIMessage;
 
 
 
 
 
6
 
7
  type MessageBubbleProps = {
8
+ message: UIMessage;
9
  };
10
 
11
  export function MessageBubble({ message }: MessageBubbleProps) {
 
37
  </div>
38
  );
39
  }
 
 
src/config/tools.ts ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import type { ToolDefinition } from "@/types/chat";
2
+
3
+ export const colorTools: ToolDefinition[] = [
4
+ {
5
+ type: "function",
6
+ function: {
7
+ name: "set_square_color",
8
+ description: "Sets the color of the square displayed on the screen.",
9
+ parameters: {
10
+ type: "object",
11
+ properties: {
12
+ color: {
13
+ type: "string",
14
+ description: "The color to set, e.g. red, blue, green",
15
+ },
16
+ },
17
+ required: ["color"],
18
+ },
19
+ },
20
+ },
21
+ {
22
+ type: "function",
23
+ function: {
24
+ name: "get_square_color",
25
+ description:
26
+ "Returns the current color of the square. Use this when the user asks 'what color is the square' or 'tell me the color'.",
27
+ parameters: {
28
+ type: "object",
29
+ properties: {},
30
+ required: [],
31
+ },
32
+ },
33
+ },
34
+ ];
src/hooks/{useFunctionCalling.ts → useChat.ts} RENAMED
@@ -1,6 +1,20 @@
1
  import * as Comlink from "comlink";
2
  import { useCallback, useEffect, useRef, useState } from "react";
3
- import type { FunctionCallResult, LoadingProgress, WorkerAPI } from "@/worker";
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
 
5
  export type LoadingStatus = {
6
  isLoading: boolean;
@@ -10,7 +24,12 @@ export type LoadingStatus = {
10
  error?: string;
11
  };
12
 
13
- export function useFunctionCalling() {
 
 
 
 
 
14
  const workerRef = useRef<Worker | null>(null);
15
  const apiRef = useRef<Comlink.Remote<WorkerAPI> | null>(null);
16
  const initializationPromiseRef = useRef<Promise<void> | null>(null);
@@ -22,6 +41,9 @@ export function useFunctionCalling() {
22
  });
23
  const [isProcessing, setIsProcessing] = useState(false);
24
 
 
 
 
25
  useEffect(() => {
26
  if (!("gpu" in navigator)) {
27
  setLoadingStatus((prev) => ({
@@ -45,12 +67,12 @@ export function useFunctionCalling() {
45
  };
46
  }, []);
47
 
 
48
  const initModel = useCallback(async () => {
49
  const api = apiRef.current;
50
  if (!api) return;
51
 
52
  setLoadingStatus((prev) => ({ ...prev, isLoading: true }));
53
- console.log("[useFunctionCalling] Starting model initialization...");
54
 
55
  const fileProgress = new Map<string, { loaded: number; total: number }>();
56
  let lastUpdate = 0;
@@ -65,7 +87,6 @@ export function useFunctionCalling() {
65
  });
66
  }
67
 
68
- // Throttle updates to avoid "Maximum update depth exceeded"
69
  const now = Date.now();
70
  if (now - lastUpdate < 100) return;
71
  lastUpdate = now;
@@ -80,7 +101,6 @@ export function useFunctionCalling() {
80
  const currentProgress =
81
  grandTotal > 0 ? (totalLoaded / grandTotal) * 100 : 0;
82
 
83
- console.log("[useFunctionCalling] Progress:", progress);
84
  setLoadingStatus((prev) => ({
85
  ...prev,
86
  currentFile: progress.file ?? prev.currentFile,
@@ -88,14 +108,13 @@ export function useFunctionCalling() {
88
  }));
89
  }),
90
  );
91
- console.log("[useFunctionCalling] Model loaded successfully!");
92
  setLoadingStatus({
93
  isLoading: false,
94
  isModelReady: true,
95
  downloadProgress: 100,
96
  });
97
  } catch (error) {
98
- console.error("[useFunctionCalling] Failed to load model:", error);
99
  setLoadingStatus((prev) => ({
100
  ...prev,
101
  isLoading: false,
@@ -105,13 +124,56 @@ export function useFunctionCalling() {
105
  }
106
  }, []);
107
 
108
- const processMessage = useCallback(
109
- async (
110
- text: string,
111
- ): Promise<{ functionCall: FunctionCallResult; textResponse?: string }> => {
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
112
  const api = apiRef.current;
113
- if (!api) return { functionCall: null };
114
 
 
115
  if (!initializationPromiseRef.current) {
116
  initializationPromiseRef.current = initModel().catch((err) => {
117
  initializationPromiseRef.current = null;
@@ -122,36 +184,68 @@ export function useFunctionCalling() {
122
  try {
123
  await initializationPromiseRef.current;
124
  } catch {
125
- return {
126
- functionCall: null,
127
- textResponse: "Failed to initialize model. Please try again.",
128
- };
129
- }
130
-
131
- setIsProcessing(true);
132
- try {
133
- return await api.processMessage(text);
134
- } finally {
135
- setIsProcessing(false);
136
  }
137
- },
138
- [initModel],
139
- );
140
 
141
- const continueWithToolResult = useCallback(
142
- async (functionName: string, functionResult: unknown): Promise<string> => {
143
- const api = apiRef.current;
144
- if (!api) return "Error: Worker not available";
145
 
146
  setIsProcessing(true);
 
147
  try {
148
- return await api.continueWithToolResult(functionName, functionResult);
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
149
  } finally {
150
  setIsProcessing(false);
151
  }
152
  },
153
- [],
 
 
 
 
 
 
 
154
  );
155
 
156
- return { processMessage, continueWithToolResult, isProcessing, loadingStatus };
 
 
 
 
 
157
  }
 
1
  import * as Comlink from "comlink";
2
  import { useCallback, useEffect, useRef, useState } from "react";
3
+ import { useChatStore } from "@/store/useChatStore";
4
+ import {
5
+ parseFunctionCall,
6
+ cleanupModelResponse,
7
+ } from "@/lib/functionCalling";
8
+ import { toSnakeCase } from "@/lib/utils";
9
+ import type {
10
+ ToolDefinition,
11
+ ChatMessage,
12
+ FunctionCallResult,
13
+ } from "@/types/chat";
14
+ import type { WorkerAPI, LoadingProgress } from "@/worker";
15
+
16
+ const DEVELOPER_PROMPT =
17
+ "You are a model that can do function calling with the following functions";
18
 
19
  export type LoadingStatus = {
20
  isLoading: boolean;
 
24
  error?: string;
25
  };
26
 
27
+ type UseChatOptions = {
28
+ tools: ToolDefinition[];
29
+ onFunctionCall?: (fc: FunctionCallResult) => Promise<unknown>;
30
+ };
31
+
32
+ export function useChat({ tools, onFunctionCall }: UseChatOptions) {
33
  const workerRef = useRef<Worker | null>(null);
34
  const apiRef = useRef<Comlink.Remote<WorkerAPI> | null>(null);
35
  const initializationPromiseRef = useRef<Promise<void> | null>(null);
 
41
  });
42
  const [isProcessing, setIsProcessing] = useState(false);
43
 
44
+ const { messages, addMessage, setConversation } = useChatStore();
45
+
46
+ // Initialize worker
47
  useEffect(() => {
48
  if (!("gpu" in navigator)) {
49
  setLoadingStatus((prev) => ({
 
67
  };
68
  }, []);
69
 
70
+ // Initialize model
71
  const initModel = useCallback(async () => {
72
  const api = apiRef.current;
73
  if (!api) return;
74
 
75
  setLoadingStatus((prev) => ({ ...prev, isLoading: true }));
 
76
 
77
  const fileProgress = new Map<string, { loaded: number; total: number }>();
78
  let lastUpdate = 0;
 
87
  });
88
  }
89
 
 
90
  const now = Date.now();
91
  if (now - lastUpdate < 100) return;
92
  lastUpdate = now;
 
101
  const currentProgress =
102
  grandTotal > 0 ? (totalLoaded / grandTotal) * 100 : 0;
103
 
 
104
  setLoadingStatus((prev) => ({
105
  ...prev,
106
  currentFile: progress.file ?? prev.currentFile,
 
108
  }));
109
  }),
110
  );
111
+
112
  setLoadingStatus({
113
  isLoading: false,
114
  isModelReady: true,
115
  downloadProgress: 100,
116
  });
117
  } catch (error) {
 
118
  setLoadingStatus((prev) => ({
119
  ...prev,
120
  isLoading: false,
 
124
  }
125
  }, []);
126
 
127
+ // Continue conversation with tool result
128
+ const continueWithResult = useCallback(
129
+ async (result: unknown): Promise<string> => {
130
+ const api = apiRef.current;
131
+
132
+ // Get fresh state directly from store to avoid stale closure
133
+ const { conversationMessages, lastFunctionCall, clearConversation } =
134
+ useChatStore.getState();
135
+
136
+ if (!api || !lastFunctionCall || conversationMessages.length === 0) {
137
+ return "Error: no conversation to continue";
138
+ }
139
+
140
+ const snakeCaseName = toSnakeCase(lastFunctionCall.functionName);
141
+
142
+ const fullMessages: ChatMessage[] = [
143
+ ...conversationMessages,
144
+ {
145
+ role: "assistant",
146
+ tool_calls: [
147
+ {
148
+ type: "function",
149
+ function: {
150
+ name: snakeCaseName,
151
+ arguments: lastFunctionCall.args,
152
+ },
153
+ },
154
+ ],
155
+ },
156
+ {
157
+ role: "tool",
158
+ content: [{ name: snakeCaseName, response: result }],
159
+ },
160
+ ];
161
+
162
+ const rawOutput = await api.generate(fullMessages, tools);
163
+ clearConversation();
164
+
165
+ return cleanupModelResponse(rawOutput || "");
166
+ },
167
+ [tools],
168
+ );
169
+
170
+ // Send message
171
+ const sendMessage = useCallback(
172
+ async (text: string) => {
173
  const api = apiRef.current;
174
+ if (!api) return;
175
 
176
+ // Ensure model is initialized
177
  if (!initializationPromiseRef.current) {
178
  initializationPromiseRef.current = initModel().catch((err) => {
179
  initializationPromiseRef.current = null;
 
184
  try {
185
  await initializationPromiseRef.current;
186
  } catch {
187
+ addMessage({
188
+ id: Date.now(),
189
+ text: "Failed to initialize model. Please try again.",
190
+ sender: "bot",
191
+ });
192
+ return;
 
 
 
 
 
193
  }
 
 
 
194
 
195
+ // Add user message to UI
196
+ addMessage({ id: Date.now(), text, sender: "user" });
 
 
197
 
198
  setIsProcessing(true);
199
+
200
  try {
201
+ const mlMessages: ChatMessage[] = [
202
+ { role: "developer", content: DEVELOPER_PROMPT },
203
+ { role: "user", content: text },
204
+ ];
205
+
206
+ const rawOutput = await api.generate(mlMessages, tools);
207
+ const functionCall = parseFunctionCall(rawOutput || "");
208
+
209
+ if (functionCall && onFunctionCall) {
210
+ // Save conversation state for continuation
211
+ setConversation(mlMessages, functionCall);
212
+
213
+ // Execute function and get result
214
+ const result = await onFunctionCall(functionCall);
215
+
216
+ // Continue conversation with result
217
+ const response = await continueWithResult(result);
218
+ addMessage({ id: Date.now(), text: response, sender: "bot" });
219
+ } else if (functionCall) {
220
+ // Function call but no handler
221
+ addMessage({
222
+ id: Date.now(),
223
+ text: "Function call detected but no handler provided.",
224
+ sender: "bot",
225
+ });
226
+ } else {
227
+ // No function call, just text response
228
+ const response = cleanupModelResponse(rawOutput || "");
229
+ addMessage({ id: Date.now(), text: response, sender: "bot" });
230
+ }
231
  } finally {
232
  setIsProcessing(false);
233
  }
234
  },
235
+ [
236
+ tools,
237
+ onFunctionCall,
238
+ addMessage,
239
+ setConversation,
240
+ continueWithResult,
241
+ initModel,
242
+ ],
243
  );
244
 
245
+ return {
246
+ messages,
247
+ sendMessage,
248
+ isProcessing,
249
+ loadingStatus,
250
+ };
251
  }
src/lib/functionCalling.ts ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import type { FunctionCallResult } from "@/types/chat";
2
+ import { toCamelCase } from "./utils";
3
+
4
+ /**
5
+ * Parses FunctionGemma output format:
6
+ * <start_function_call>call:func_name{param:<escape>value<escape>}<end_function_call>
7
+ */
8
+ export function parseFunctionCall(output: string): FunctionCallResult {
9
+ const match = output.match(
10
+ /<start_function_call>call:(\w+)\{([^}]*)\}<end_function_call>/,
11
+ );
12
+
13
+ if (!match) return null;
14
+
15
+ const funcName = match[1];
16
+ const argsStr = match[2];
17
+
18
+ // Convert function name from snake_case to camelCase
19
+ const functionName = toCamelCase(funcName);
20
+
21
+ // Parse all arguments generically
22
+ // Format: key:<escape>value<escape>
23
+ const args: Record<string, unknown> = {};
24
+ const argRegex = /(\w+):<escape>([^<]*)<escape>/g;
25
+
26
+ for (const argMatch of argsStr.matchAll(argRegex)) {
27
+ const key = toCamelCase(argMatch[1]);
28
+ const value = argMatch[2].trim();
29
+ args[key] = value;
30
+ }
31
+
32
+ return { functionName, args };
33
+ }
34
+
35
+ /**
36
+ * Cleans up model response by removing special tokens
37
+ */
38
+ export function cleanupModelResponse(text: string): string {
39
+ return (
40
+ text
41
+ .replace(/<start_function_call>[\s\S]*?<end_function_call>/g, "")
42
+ .replace(/<start_function_response>[\s\S]*?<end_function_response>/g, "")
43
+ .replace(/<end_of_turn>/g, "")
44
+ .replace(/<start_of_turn>/g, "")
45
+ .replace(/<\|im_end\|>/g, "")
46
+ .replace(/<\|im_start\|>/g, "")
47
+ .replace(/<\|endoftext\|>/g, "")
48
+ .replace(/<tool_call>[\s\S]*?<\/tool_call>/g, "")
49
+ .replace(/^model\s*/i, "")
50
+ .replace(/^assistant\s*/i, "")
51
+ .trim() || "Done!"
52
+ );
53
+ }
src/store/useChatStore.ts ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import { create } from "zustand";
2
+ import type {
3
+ ChatMessage,
4
+ FunctionCallResult,
5
+ UIMessage,
6
+ } from "@/types/chat";
7
+
8
+ const INITIAL_MESSAGE: UIMessage = {
9
+ id: 1,
10
+ text: "Hi! I can change the square color for you. Ask me to set a color or ask what color it currently is!",
11
+ sender: "bot",
12
+ };
13
+
14
+ interface ChatState {
15
+ // UI Messages (displayed in chat)
16
+ messages: UIMessage[];
17
+ addMessage: (message: UIMessage) => void;
18
+
19
+ // ML Conversation (for multi-turn function calling)
20
+ conversationMessages: ChatMessage[];
21
+ lastFunctionCall: FunctionCallResult;
22
+ setConversation: (
23
+ messages: ChatMessage[],
24
+ functionCall: FunctionCallResult,
25
+ ) => void;
26
+ clearConversation: () => void;
27
+
28
+ // App State
29
+ squareColor: string;
30
+ setSquareColor: (color: string) => void;
31
+ }
32
+
33
+ export const useChatStore = create<ChatState>((set) => ({
34
+ // UI Messages
35
+ messages: [INITIAL_MESSAGE],
36
+ addMessage: (message) =>
37
+ set((state) => ({ messages: [...state.messages, message] })),
38
+
39
+ // ML Conversation
40
+ conversationMessages: [],
41
+ lastFunctionCall: null,
42
+ setConversation: (conversationMessages, lastFunctionCall) =>
43
+ set({ conversationMessages, lastFunctionCall }),
44
+ clearConversation: () =>
45
+ set({ conversationMessages: [], lastFunctionCall: null }),
46
+
47
+ // App State
48
+ squareColor: "rebeccapurple",
49
+ setSquareColor: (squareColor) => set({ squareColor }),
50
+ }));
src/store/useColorStore.ts DELETED
@@ -1,11 +0,0 @@
1
- import { create } from "zustand";
2
-
3
- interface ColorState {
4
- squareColor: string;
5
- setSquareColor: (color: string) => void;
6
- }
7
-
8
- export const useColorStore = create<ColorState>((set) => ({
9
- squareColor: "rebeccapurple",
10
- setSquareColor: (color: string) => set({ squareColor: color }),
11
- }));
 
 
 
 
 
 
 
 
 
 
 
 
src/types/chat.ts ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ export type ToolCall = {
2
+ type: "function";
3
+ function: {
4
+ name: string;
5
+ arguments: Record<string, unknown>;
6
+ };
7
+ };
8
+
9
+ export type ToolDefinition = {
10
+ type: "function";
11
+ function: {
12
+ name: string;
13
+ description: string;
14
+ parameters: {
15
+ type: "object";
16
+ properties: Record<string, unknown>;
17
+ required: string[];
18
+ };
19
+ };
20
+ };
21
+
22
+ export type ChatMessage =
23
+ | { role: "developer" | "user"; content: string }
24
+ | { role: "assistant"; content?: string; tool_calls?: ToolCall[] }
25
+ | { role: "tool"; content: Array<{ name: string; response: unknown }> };
26
+
27
+ export type FunctionCallResult = {
28
+ functionName: string;
29
+ args: Record<string, unknown>;
30
+ } | null;
31
+
32
+ export type UIMessage = {
33
+ id: number;
34
+ text: string;
35
+ sender: "user" | "bot";
36
+ };
src/worker.ts CHANGED
@@ -5,8 +5,7 @@ import {
5
  type PreTrainedTokenizer,
6
  } from "@huggingface/transformers";
7
  import * as Comlink from "comlink";
8
- import { createStore } from "zustand/vanilla";
9
- import { toCamelCase, toSnakeCase } from "./lib/utils";
10
 
11
  const MODEL_ID = "onnx-community/functiongemma-270m-it-ONNX";
12
 
@@ -25,42 +24,6 @@ export type LoadingProgress = {
25
 
26
  type ProgressCallback = (progress: LoadingProgress) => void;
27
 
28
- // Types for FunctionGemma multi-turn conversation
29
- type ToolCall = {
30
- type: "function";
31
- function: {
32
- name: string;
33
- arguments: Record<string, unknown>;
34
- };
35
- };
36
-
37
- type ChatMessage =
38
- | { role: "developer" | "user"; content: string }
39
- | { role: "assistant"; content?: string; tool_calls?: ToolCall[] }
40
- | { role: "tool"; content: Array<{ name: string; response: unknown }> };
41
-
42
- export type FunctionCallResult = {
43
- functionName: string;
44
- args: Record<string, unknown>;
45
- } | null;
46
-
47
- type ConversationState = {
48
- messages: ChatMessage[];
49
- lastFunctionCall: FunctionCallResult;
50
- setMessages: (messages: ChatMessage[]) => void;
51
- setLastFunctionCall: (fc: FunctionCallResult) => void;
52
- clear: () => void;
53
- };
54
-
55
- // Store to maintain conversation state between processMessage and continueWithToolResult
56
- const conversationStore = createStore<ConversationState>((set) => ({
57
- messages: [],
58
- lastFunctionCall: null,
59
- setMessages: (messages) => set({ messages }),
60
- setLastFunctionCall: (fc) => set({ lastFunctionCall: fc }),
61
- clear: () => set({ messages: [], lastFunctionCall: null }),
62
- }));
63
-
64
  function getModel(onProgress?: ProgressCallback) {
65
  modelPromise ??= (async () => {
66
  console.log("[Worker] Loading model...");
@@ -80,71 +43,6 @@ function getModel(onProgress?: ProgressCallback) {
80
  return modelPromise;
81
  }
82
 
83
- const tools = [
84
- {
85
- type: "function",
86
- function: {
87
- name: "set_square_color",
88
- description: "Sets the color of the square displayed on the screen.",
89
- parameters: {
90
- type: "object",
91
- properties: {
92
- color: {
93
- type: "string",
94
- description: "The color to set, e.g. red, blue, green",
95
- },
96
- },
97
- required: ["color"],
98
- },
99
- },
100
- },
101
- {
102
- type: "function",
103
- function: {
104
- name: "get_square_color",
105
- description:
106
- "Returns the current color of the square. Use this when the user asks 'what color is the square' or 'tell me the color'.",
107
- parameters: {
108
- type: "object",
109
- properties: {},
110
- required: [],
111
- },
112
- },
113
- },
114
- ];
115
-
116
- // Exact prompt from functiongemma documentation
117
- const developerPrompt =
118
- "You are a model that can do function calling with the following functions";
119
-
120
- // Parses functiongemma output: <start_function_call>call:func_name{param:<escape>value<escape>}<end_function_call>
121
- function parseFunctionCall(output: string): FunctionCallResult {
122
- const match = output.match(
123
- /<start_function_call>call:(\w+)\{([^}]*)\}<end_function_call>/,
124
- );
125
-
126
- if (!match) return null;
127
-
128
- const funcName = match[1];
129
- const argsStr = match[2];
130
-
131
- // Convert function name from snake_case to camelCase
132
- const functionName = toCamelCase(funcName);
133
-
134
- // Parse all arguments generically
135
- // Format: key:<escape>value<escape>
136
- const args: Record<string, unknown> = {};
137
- const argRegex = /(\w+):<escape>([^<]*)<escape>/g;
138
-
139
- for (const argMatch of argsStr.matchAll(argRegex)) {
140
- const key = toCamelCase(argMatch[1]);
141
- const value = argMatch[2].trim();
142
- args[key] = value;
143
- }
144
-
145
- return { functionName, args };
146
- }
147
-
148
  const workerAPI = {
149
  async initModel(progressCallback: ProgressCallback): Promise<void> {
150
  console.log("[Worker] initModel called");
@@ -152,161 +50,17 @@ const workerAPI = {
152
  console.log("[Worker] initModel completed");
153
  },
154
 
155
- async processMessage(text: string): Promise<{
156
- functionCall: FunctionCallResult;
157
- textResponse?: string;
158
- }> {
159
- console.log("[Worker] ========== NEW REQUEST ==========");
160
- console.log("[Worker] User input:", text);
161
-
162
- if (!text.trim()) return { functionCall: null };
163
-
164
- const { model, tokenizer } = await getModel();
165
-
166
- const messages = [
167
- { role: "developer", content: developerPrompt },
168
- { role: "user", content: text },
169
- ];
170
-
171
- console.log("[Worker] Messages:", JSON.stringify(messages, null, 2));
172
-
173
- const inputs = tokenizer.apply_chat_template(messages, {
174
- tools,
175
- tokenize: true,
176
- add_generation_prompt: true,
177
- return_dict: true,
178
- }) as { input_ids: { dims: number[]; data: unknown } };
179
-
180
- console.log("[Worker] Input token count:", inputs.input_ids.dims[1]);
181
- console.log("[Worker] Input structure:", Object.keys(inputs));
182
- console.log(
183
- "[Worker] input_ids type:",
184
- typeof inputs.input_ids,
185
- inputs.input_ids.constructor?.name,
186
- );
187
-
188
- try {
189
- const inputIds = Array.from(inputs.input_ids.data as Iterable<number>);
190
- const decodedInput = tokenizer.decode(inputIds, {
191
- skip_special_tokens: false,
192
- });
193
- console.log("[Worker] Decoded input (what model sees):", decodedInput);
194
- } catch (e) {
195
- console.log("[Worker] Could not decode input:", e);
196
- }
197
-
198
- // biome-ignore lint/complexity/noBannedTypes: model.generate type is not properly typed in transformers.js
199
- const output = await (model.generate as Function)({
200
- ...inputs,
201
- max_new_tokens: 512,
202
- });
203
-
204
- const inputLength = inputs.input_ids.dims[1];
205
- const outputLength = output.dims ? output.dims[1] : output.length;
206
- console.log("[Worker] Output token count:", outputLength);
207
- console.log("[Worker] New tokens generated:", outputLength - inputLength);
208
-
209
- const decoded = tokenizer.decode(output.slice(0, [inputLength, null]), {
210
- skip_special_tokens: false,
211
- });
212
-
213
- console.log("[Worker] Raw model output:", decoded);
214
- const functionCall = parseFunctionCall(decoded as string);
215
- console.log("[Worker] Parsed function call:", functionCall);
216
-
217
- // If function was called, save conversation state for continueWithToolResult
218
- if (functionCall) {
219
- console.log(
220
- "[Worker] ✅ Function call successful:",
221
- functionCall.functionName,
222
- );
223
-
224
- // Save conversation state for the next turn
225
- const { setMessages, setLastFunctionCall } = conversationStore.getState();
226
- setMessages([
227
- { role: "developer", content: developerPrompt },
228
- { role: "user", content: text },
229
- ]);
230
- setLastFunctionCall(functionCall);
231
-
232
- return { functionCall };
233
- }
234
-
235
- // If model generated an error format, return null to trigger fallback
236
- if ((decoded as string).includes("<start_function_call>error:")) {
237
- console.log(
238
- "[Worker] ⚠️ Model generated error format, triggering fallback",
239
- );
240
- return { functionCall: null };
241
- }
242
-
243
- console.log("[Worker] ❌ No function call detected, returning fallback");
244
-
245
- const textResponse = (decoded as string)
246
- .replace(/<tool_call>[\s\S]*?<\/tool_call>/g, "")
247
- .replace(/<\|im_end\|>/g, "")
248
- .replace(/<\|im_start\|>/g, "")
249
- .replace(/<\|endoftext\|>/g, "")
250
- .replace(/<start_function_call>[\s\S]*?<end_function_call>/g, "")
251
- .replace(/<start_function_response>/g, "")
252
- .trim();
253
-
254
- return { functionCall: null, textResponse: textResponse || undefined };
255
- },
256
-
257
- async continueWithToolResult(
258
- functionName: string,
259
- functionResult: unknown,
260
  ): Promise<string> {
261
- console.log("[Worker] ========== CONTINUE WITH TOOL RESULT ==========");
262
- console.log("[Worker] Function:", functionName);
263
- console.log("[Worker] Result:", functionResult);
264
-
265
- const { messages, lastFunctionCall, clear } = conversationStore.getState();
266
-
267
- if (messages.length === 0 || !lastFunctionCall) {
268
- console.log("[Worker] ⚠️ No conversation state found");
269
- return "I apologize, but I lost track of our conversation. Could you please try again?";
270
- }
271
 
272
  const { model, tokenizer } = await getModel();
273
 
274
- // Build the full conversation with tool result
275
- // Convert camelCase back to snake_case for the model
276
- const snakeCaseName = toSnakeCase(lastFunctionCall.functionName);
277
-
278
- // Build messages in the format expected by the tokenizer
279
- // Note: Using 'as unknown' to bypass strict typing since the tokenizer
280
- // accepts a more flexible format than what TypeScript infers
281
- const fullMessages = [
282
- ...messages,
283
- // Assistant turn with tool_calls
284
- {
285
- role: "assistant",
286
- tool_calls: [
287
- {
288
- type: "function",
289
- function: {
290
- name: snakeCaseName,
291
- arguments: lastFunctionCall.args,
292
- },
293
- },
294
- ],
295
- },
296
- // Tool response
297
- {
298
- role: "tool",
299
- content: [{ name: snakeCaseName, response: functionResult }],
300
- },
301
- ] as unknown[];
302
-
303
- console.log(
304
- "[Worker] Full messages for final response:",
305
- JSON.stringify(fullMessages, null, 2),
306
- );
307
-
308
  // biome-ignore lint/suspicious/noExplicitAny: tokenizer accepts flexible message format
309
- const inputs = tokenizer.apply_chat_template(fullMessages as any, {
310
  tools,
311
  tokenize: true,
312
  add_generation_prompt: true,
@@ -315,23 +69,10 @@ const workerAPI = {
315
 
316
  console.log("[Worker] Input token count:", inputs.input_ids.dims[1]);
317
 
318
- try {
319
- const inputIds = Array.from(inputs.input_ids.data as Iterable<number>);
320
- const decodedInput = tokenizer.decode(inputIds, {
321
- skip_special_tokens: false,
322
- });
323
- console.log(
324
- "[Worker] Decoded input for final response:",
325
- decodedInput,
326
- );
327
- } catch (e) {
328
- console.log("[Worker] Could not decode input:", e);
329
- }
330
-
331
  // biome-ignore lint/complexity/noBannedTypes: model.generate type is not properly typed in transformers.js
332
  const output = await (model.generate as Function)({
333
  ...inputs,
334
- max_new_tokens: 256,
335
  });
336
 
337
  const inputLength = inputs.input_ids.dims[1];
@@ -339,27 +80,9 @@ const workerAPI = {
339
  skip_special_tokens: false,
340
  });
341
 
342
- console.log("[Worker] Raw final response:", decoded);
343
-
344
- // Clear conversation state
345
- clear();
346
-
347
- // Clean up the response
348
- const cleanResponse = (decoded as string)
349
- .replace(/<start_function_call>[\s\S]*?<end_function_call>/g, "")
350
- .replace(/<start_function_response>[\s\S]*?<end_function_response>/g, "")
351
- .replace(/<\|im_end\|>/g, "")
352
- .replace(/<\|im_start\|>/g, "")
353
- .replace(/<\|endoftext\|>/g, "")
354
- .replace(/<end_of_turn>/g, "")
355
- .replace(/<start_of_turn>/g, "")
356
- .replace(/^model\s*/i, "")
357
- .replace(/^assistant\s*/i, "")
358
- .trim();
359
-
360
- console.log("[Worker] ✅ Final response:", cleanResponse);
361
 
362
- return cleanResponse || "Done!";
363
  },
364
  };
365
 
 
5
  type PreTrainedTokenizer,
6
  } from "@huggingface/transformers";
7
  import * as Comlink from "comlink";
8
+ import type { ChatMessage, ToolDefinition } from "./types/chat";
 
9
 
10
  const MODEL_ID = "onnx-community/functiongemma-270m-it-ONNX";
11
 
 
24
 
25
  type ProgressCallback = (progress: LoadingProgress) => void;
26
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
  function getModel(onProgress?: ProgressCallback) {
28
  modelPromise ??= (async () => {
29
  console.log("[Worker] Loading model...");
 
43
  return modelPromise;
44
  }
45
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
  const workerAPI = {
47
  async initModel(progressCallback: ProgressCallback): Promise<void> {
48
  console.log("[Worker] initModel called");
 
50
  console.log("[Worker] initModel completed");
51
  },
52
 
53
+ async generate(
54
+ messages: ChatMessage[],
55
+ tools: ToolDefinition[],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56
  ): Promise<string> {
57
+ console.log("[Worker] ========== GENERATE ==========");
58
+ console.log("[Worker] Messages:", JSON.stringify(messages, null, 2));
 
 
 
 
 
 
 
 
59
 
60
  const { model, tokenizer } = await getModel();
61
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62
  // biome-ignore lint/suspicious/noExplicitAny: tokenizer accepts flexible message format
63
+ const inputs = tokenizer.apply_chat_template(messages as any, {
64
  tools,
65
  tokenize: true,
66
  add_generation_prompt: true,
 
69
 
70
  console.log("[Worker] Input token count:", inputs.input_ids.dims[1]);
71
 
 
 
 
 
 
 
 
 
 
 
 
 
 
72
  // biome-ignore lint/complexity/noBannedTypes: model.generate type is not properly typed in transformers.js
73
  const output = await (model.generate as Function)({
74
  ...inputs,
75
+ max_new_tokens: 512,
76
  });
77
 
78
  const inputLength = inputs.input_ids.dims[1];
 
80
  skip_special_tokens: false,
81
  });
82
 
83
+ console.log("[Worker] Raw output:", decoded);
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
84
 
85
+ return decoded as string;
86
  },
87
  };
88