Spaces:
Running
Running
refactor: create generic useChat hook and simplify architecture
Browse files- Add useChat hook that receives tools as parameter and handles UI + ML state
- Create unified Zustand store (useChatStore) replacing 2 stores
- Simplify worker to only handle model loading and inference (~90 lines)
- Extract function calling logic to lib/functionCalling.ts
- Move tool definitions to config/tools.ts
- Add shared types in types/chat.ts
- package.json +1 -1
- src/components/AgenticInterface.tsx +25 -60
- src/components/MessageBubble.tsx +4 -10
- src/config/tools.ts +34 -0
- src/hooks/{useFunctionCalling.ts → useChat.ts} +127 -33
- src/lib/functionCalling.ts +53 -0
- src/store/useChatStore.ts +50 -0
- src/store/useColorStore.ts +0 -11
- src/types/chat.ts +36 -0
- src/worker.ts +10 -287
package.json
CHANGED
|
@@ -7,7 +7,7 @@
|
|
| 7 |
"dev": "vite",
|
| 8 |
"build": "tsc -b && vite build",
|
| 9 |
"preview": "vite preview",
|
| 10 |
-
"typecheck": "tsc
|
| 11 |
"lint": "biome lint ./src",
|
| 12 |
"format": "biome format --write ./src",
|
| 13 |
"check": "biome check --write ./src"
|
|
|
|
| 7 |
"dev": "vite",
|
| 8 |
"build": "tsc -b && vite build",
|
| 9 |
"preview": "vite preview",
|
| 10 |
+
"typecheck": "tsc -b",
|
| 11 |
"lint": "biome lint ./src",
|
| 12 |
"format": "biome format --write ./src",
|
| 13 |
"check": "biome check --write ./src"
|
src/components/AgenticInterface.tsx
CHANGED
|
@@ -1,75 +1,40 @@
|
|
| 1 |
-
import {
|
| 2 |
-
import {
|
| 3 |
-
import {
|
|
|
|
|
|
|
| 4 |
import { ChatSidebar } from "./ChatSidebar";
|
| 5 |
import { ColorControlForm } from "./ColorControlForm";
|
| 6 |
import { ColorVisualizer } from "./ColorVisualizer";
|
| 7 |
-
import type { Message } from "./MessageBubble";
|
| 8 |
-
|
| 9 |
-
const INITIAL_MESSAGES: Message[] = [
|
| 10 |
-
{
|
| 11 |
-
id: 1,
|
| 12 |
-
text: "Hi! I can change the square color for you. Ask me to set a color or ask what color it currently is!",
|
| 13 |
-
sender: "bot",
|
| 14 |
-
},
|
| 15 |
-
];
|
| 16 |
|
| 17 |
export function AgenticInterface() {
|
| 18 |
-
const { squareColor, setSquareColor } =
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
const result = await processMessage(text);
|
| 30 |
-
|
| 31 |
-
if (result.functionCall?.functionName === "setSquareColor") {
|
| 32 |
-
const color = result.functionCall.args.color as string | undefined;
|
| 33 |
-
if (color) {
|
| 34 |
-
setSquareColor(color);
|
| 35 |
-
// Get response from model after executing the function
|
| 36 |
-
botResponse = await continueWithToolResult("set_square_color", {
|
| 37 |
-
success: true,
|
| 38 |
-
color,
|
| 39 |
-
});
|
| 40 |
-
} else {
|
| 41 |
-
botResponse =
|
| 42 |
-
"I understood you want to change the color, but I couldn't determine which color. Please try again with a specific color.";
|
| 43 |
-
}
|
| 44 |
-
} else if (result.functionCall?.functionName === "getSquareColor") {
|
| 45 |
-
// Get response from model after executing the function
|
| 46 |
-
botResponse = await continueWithToolResult("get_square_color", {
|
| 47 |
-
color: squareColor,
|
| 48 |
-
});
|
| 49 |
-
} else if (result.textResponse) {
|
| 50 |
-
botResponse = result.textResponse;
|
| 51 |
-
} else {
|
| 52 |
-
botResponse =
|
| 53 |
-
"Sorry, I had trouble understanding that. Could you please try again?";
|
| 54 |
}
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
};
|
| 65 |
-
setMessages((prev) => [...prev, botMessage]);
|
| 66 |
-
};
|
| 67 |
|
| 68 |
return (
|
| 69 |
<div className="flex flex-col md:flex-row h-screen w-screen overflow-hidden bg-background text-foreground font-sans selection:bg-primary/30">
|
| 70 |
<ChatSidebar
|
| 71 |
messages={messages}
|
| 72 |
-
onSendMessage={
|
| 73 |
isLoading={isProcessing || loadingStatus.isLoading}
|
| 74 |
modelReady={loadingStatus.isModelReady}
|
| 75 |
progress={loadingStatus.downloadProgress}
|
|
|
|
| 1 |
+
import { useCallback } from "react";
|
| 2 |
+
import { colorTools } from "@/config/tools";
|
| 3 |
+
import { useChat } from "@/hooks/useChat";
|
| 4 |
+
import { useChatStore } from "@/store/useChatStore";
|
| 5 |
+
import type { FunctionCallResult } from "@/types/chat";
|
| 6 |
import { ChatSidebar } from "./ChatSidebar";
|
| 7 |
import { ColorControlForm } from "./ColorControlForm";
|
| 8 |
import { ColorVisualizer } from "./ColorVisualizer";
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 9 |
|
| 10 |
export function AgenticInterface() {
|
| 11 |
+
const { squareColor, setSquareColor } = useChatStore();
|
| 12 |
+
|
| 13 |
+
const handleFunctionCall = useCallback(
|
| 14 |
+
async (fc: FunctionCallResult) => {
|
| 15 |
+
if (fc?.functionName === "setSquareColor") {
|
| 16 |
+
const color = fc.args.color as string;
|
| 17 |
+
setSquareColor(color);
|
| 18 |
+
return { success: true, color };
|
| 19 |
+
}
|
| 20 |
+
if (fc?.functionName === "getSquareColor") {
|
| 21 |
+
return { color: squareColor };
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 22 |
}
|
| 23 |
+
return { error: "Unknown function" };
|
| 24 |
+
},
|
| 25 |
+
[squareColor, setSquareColor],
|
| 26 |
+
);
|
| 27 |
|
| 28 |
+
const { messages, sendMessage, isProcessing, loadingStatus } = useChat({
|
| 29 |
+
tools: colorTools,
|
| 30 |
+
onFunctionCall: handleFunctionCall,
|
| 31 |
+
});
|
|
|
|
|
|
|
|
|
|
| 32 |
|
| 33 |
return (
|
| 34 |
<div className="flex flex-col md:flex-row h-screen w-screen overflow-hidden bg-background text-foreground font-sans selection:bg-primary/30">
|
| 35 |
<ChatSidebar
|
| 36 |
messages={messages}
|
| 37 |
+
onSendMessage={sendMessage}
|
| 38 |
isLoading={isProcessing || loadingStatus.isLoading}
|
| 39 |
modelReady={loadingStatus.isModelReady}
|
| 40 |
progress={loadingStatus.downloadProgress}
|
src/components/MessageBubble.tsx
CHANGED
|
@@ -1,15 +1,11 @@
|
|
| 1 |
import { IconRobot, IconUser } from "@tabler/icons-react";
|
|
|
|
| 2 |
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
type Message = {
|
| 6 |
-
id: number;
|
| 7 |
-
text: string;
|
| 8 |
-
sender: MessageSender;
|
| 9 |
-
};
|
| 10 |
|
| 11 |
type MessageBubbleProps = {
|
| 12 |
-
message:
|
| 13 |
};
|
| 14 |
|
| 15 |
export function MessageBubble({ message }: MessageBubbleProps) {
|
|
@@ -41,5 +37,3 @@ export function MessageBubble({ message }: MessageBubbleProps) {
|
|
| 41 |
</div>
|
| 42 |
);
|
| 43 |
}
|
| 44 |
-
|
| 45 |
-
export type { Message, MessageSender };
|
|
|
|
| 1 |
import { IconRobot, IconUser } from "@tabler/icons-react";
|
| 2 |
+
import type { UIMessage } from "@/types/chat";
|
| 3 |
|
| 4 |
+
// Re-export for backwards compatibility
|
| 5 |
+
export type Message = UIMessage;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6 |
|
| 7 |
type MessageBubbleProps = {
|
| 8 |
+
message: UIMessage;
|
| 9 |
};
|
| 10 |
|
| 11 |
export function MessageBubble({ message }: MessageBubbleProps) {
|
|
|
|
| 37 |
</div>
|
| 38 |
);
|
| 39 |
}
|
|
|
|
|
|
src/config/tools.ts
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import type { ToolDefinition } from "@/types/chat";
|
| 2 |
+
|
| 3 |
+
export const colorTools: ToolDefinition[] = [
|
| 4 |
+
{
|
| 5 |
+
type: "function",
|
| 6 |
+
function: {
|
| 7 |
+
name: "set_square_color",
|
| 8 |
+
description: "Sets the color of the square displayed on the screen.",
|
| 9 |
+
parameters: {
|
| 10 |
+
type: "object",
|
| 11 |
+
properties: {
|
| 12 |
+
color: {
|
| 13 |
+
type: "string",
|
| 14 |
+
description: "The color to set, e.g. red, blue, green",
|
| 15 |
+
},
|
| 16 |
+
},
|
| 17 |
+
required: ["color"],
|
| 18 |
+
},
|
| 19 |
+
},
|
| 20 |
+
},
|
| 21 |
+
{
|
| 22 |
+
type: "function",
|
| 23 |
+
function: {
|
| 24 |
+
name: "get_square_color",
|
| 25 |
+
description:
|
| 26 |
+
"Returns the current color of the square. Use this when the user asks 'what color is the square' or 'tell me the color'.",
|
| 27 |
+
parameters: {
|
| 28 |
+
type: "object",
|
| 29 |
+
properties: {},
|
| 30 |
+
required: [],
|
| 31 |
+
},
|
| 32 |
+
},
|
| 33 |
+
},
|
| 34 |
+
];
|
src/hooks/{useFunctionCalling.ts → useChat.ts}
RENAMED
|
@@ -1,6 +1,20 @@
|
|
| 1 |
import * as Comlink from "comlink";
|
| 2 |
import { useCallback, useEffect, useRef, useState } from "react";
|
| 3 |
-
import
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4 |
|
| 5 |
export type LoadingStatus = {
|
| 6 |
isLoading: boolean;
|
|
@@ -10,7 +24,12 @@ export type LoadingStatus = {
|
|
| 10 |
error?: string;
|
| 11 |
};
|
| 12 |
|
| 13 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 14 |
const workerRef = useRef<Worker | null>(null);
|
| 15 |
const apiRef = useRef<Comlink.Remote<WorkerAPI> | null>(null);
|
| 16 |
const initializationPromiseRef = useRef<Promise<void> | null>(null);
|
|
@@ -22,6 +41,9 @@ export function useFunctionCalling() {
|
|
| 22 |
});
|
| 23 |
const [isProcessing, setIsProcessing] = useState(false);
|
| 24 |
|
|
|
|
|
|
|
|
|
|
| 25 |
useEffect(() => {
|
| 26 |
if (!("gpu" in navigator)) {
|
| 27 |
setLoadingStatus((prev) => ({
|
|
@@ -45,12 +67,12 @@ export function useFunctionCalling() {
|
|
| 45 |
};
|
| 46 |
}, []);
|
| 47 |
|
|
|
|
| 48 |
const initModel = useCallback(async () => {
|
| 49 |
const api = apiRef.current;
|
| 50 |
if (!api) return;
|
| 51 |
|
| 52 |
setLoadingStatus((prev) => ({ ...prev, isLoading: true }));
|
| 53 |
-
console.log("[useFunctionCalling] Starting model initialization...");
|
| 54 |
|
| 55 |
const fileProgress = new Map<string, { loaded: number; total: number }>();
|
| 56 |
let lastUpdate = 0;
|
|
@@ -65,7 +87,6 @@ export function useFunctionCalling() {
|
|
| 65 |
});
|
| 66 |
}
|
| 67 |
|
| 68 |
-
// Throttle updates to avoid "Maximum update depth exceeded"
|
| 69 |
const now = Date.now();
|
| 70 |
if (now - lastUpdate < 100) return;
|
| 71 |
lastUpdate = now;
|
|
@@ -80,7 +101,6 @@ export function useFunctionCalling() {
|
|
| 80 |
const currentProgress =
|
| 81 |
grandTotal > 0 ? (totalLoaded / grandTotal) * 100 : 0;
|
| 82 |
|
| 83 |
-
console.log("[useFunctionCalling] Progress:", progress);
|
| 84 |
setLoadingStatus((prev) => ({
|
| 85 |
...prev,
|
| 86 |
currentFile: progress.file ?? prev.currentFile,
|
|
@@ -88,14 +108,13 @@ export function useFunctionCalling() {
|
|
| 88 |
}));
|
| 89 |
}),
|
| 90 |
);
|
| 91 |
-
|
| 92 |
setLoadingStatus({
|
| 93 |
isLoading: false,
|
| 94 |
isModelReady: true,
|
| 95 |
downloadProgress: 100,
|
| 96 |
});
|
| 97 |
} catch (error) {
|
| 98 |
-
console.error("[useFunctionCalling] Failed to load model:", error);
|
| 99 |
setLoadingStatus((prev) => ({
|
| 100 |
...prev,
|
| 101 |
isLoading: false,
|
|
@@ -105,13 +124,56 @@ export function useFunctionCalling() {
|
|
| 105 |
}
|
| 106 |
}, []);
|
| 107 |
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 112 |
const api = apiRef.current;
|
| 113 |
-
if (!api) return
|
| 114 |
|
|
|
|
| 115 |
if (!initializationPromiseRef.current) {
|
| 116 |
initializationPromiseRef.current = initModel().catch((err) => {
|
| 117 |
initializationPromiseRef.current = null;
|
|
@@ -122,36 +184,68 @@ export function useFunctionCalling() {
|
|
| 122 |
try {
|
| 123 |
await initializationPromiseRef.current;
|
| 124 |
} catch {
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
|
| 129 |
-
|
| 130 |
-
|
| 131 |
-
setIsProcessing(true);
|
| 132 |
-
try {
|
| 133 |
-
return await api.processMessage(text);
|
| 134 |
-
} finally {
|
| 135 |
-
setIsProcessing(false);
|
| 136 |
}
|
| 137 |
-
},
|
| 138 |
-
[initModel],
|
| 139 |
-
);
|
| 140 |
|
| 141 |
-
|
| 142 |
-
|
| 143 |
-
const api = apiRef.current;
|
| 144 |
-
if (!api) return "Error: Worker not available";
|
| 145 |
|
| 146 |
setIsProcessing(true);
|
|
|
|
| 147 |
try {
|
| 148 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 149 |
} finally {
|
| 150 |
setIsProcessing(false);
|
| 151 |
}
|
| 152 |
},
|
| 153 |
-
[
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 154 |
);
|
| 155 |
|
| 156 |
-
return {
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 157 |
}
|
|
|
|
| 1 |
import * as Comlink from "comlink";
|
| 2 |
import { useCallback, useEffect, useRef, useState } from "react";
|
| 3 |
+
import { useChatStore } from "@/store/useChatStore";
|
| 4 |
+
import {
|
| 5 |
+
parseFunctionCall,
|
| 6 |
+
cleanupModelResponse,
|
| 7 |
+
} from "@/lib/functionCalling";
|
| 8 |
+
import { toSnakeCase } from "@/lib/utils";
|
| 9 |
+
import type {
|
| 10 |
+
ToolDefinition,
|
| 11 |
+
ChatMessage,
|
| 12 |
+
FunctionCallResult,
|
| 13 |
+
} from "@/types/chat";
|
| 14 |
+
import type { WorkerAPI, LoadingProgress } from "@/worker";
|
| 15 |
+
|
| 16 |
+
const DEVELOPER_PROMPT =
|
| 17 |
+
"You are a model that can do function calling with the following functions";
|
| 18 |
|
| 19 |
export type LoadingStatus = {
|
| 20 |
isLoading: boolean;
|
|
|
|
| 24 |
error?: string;
|
| 25 |
};
|
| 26 |
|
| 27 |
+
type UseChatOptions = {
|
| 28 |
+
tools: ToolDefinition[];
|
| 29 |
+
onFunctionCall?: (fc: FunctionCallResult) => Promise<unknown>;
|
| 30 |
+
};
|
| 31 |
+
|
| 32 |
+
export function useChat({ tools, onFunctionCall }: UseChatOptions) {
|
| 33 |
const workerRef = useRef<Worker | null>(null);
|
| 34 |
const apiRef = useRef<Comlink.Remote<WorkerAPI> | null>(null);
|
| 35 |
const initializationPromiseRef = useRef<Promise<void> | null>(null);
|
|
|
|
| 41 |
});
|
| 42 |
const [isProcessing, setIsProcessing] = useState(false);
|
| 43 |
|
| 44 |
+
const { messages, addMessage, setConversation } = useChatStore();
|
| 45 |
+
|
| 46 |
+
// Initialize worker
|
| 47 |
useEffect(() => {
|
| 48 |
if (!("gpu" in navigator)) {
|
| 49 |
setLoadingStatus((prev) => ({
|
|
|
|
| 67 |
};
|
| 68 |
}, []);
|
| 69 |
|
| 70 |
+
// Initialize model
|
| 71 |
const initModel = useCallback(async () => {
|
| 72 |
const api = apiRef.current;
|
| 73 |
if (!api) return;
|
| 74 |
|
| 75 |
setLoadingStatus((prev) => ({ ...prev, isLoading: true }));
|
|
|
|
| 76 |
|
| 77 |
const fileProgress = new Map<string, { loaded: number; total: number }>();
|
| 78 |
let lastUpdate = 0;
|
|
|
|
| 87 |
});
|
| 88 |
}
|
| 89 |
|
|
|
|
| 90 |
const now = Date.now();
|
| 91 |
if (now - lastUpdate < 100) return;
|
| 92 |
lastUpdate = now;
|
|
|
|
| 101 |
const currentProgress =
|
| 102 |
grandTotal > 0 ? (totalLoaded / grandTotal) * 100 : 0;
|
| 103 |
|
|
|
|
| 104 |
setLoadingStatus((prev) => ({
|
| 105 |
...prev,
|
| 106 |
currentFile: progress.file ?? prev.currentFile,
|
|
|
|
| 108 |
}));
|
| 109 |
}),
|
| 110 |
);
|
| 111 |
+
|
| 112 |
setLoadingStatus({
|
| 113 |
isLoading: false,
|
| 114 |
isModelReady: true,
|
| 115 |
downloadProgress: 100,
|
| 116 |
});
|
| 117 |
} catch (error) {
|
|
|
|
| 118 |
setLoadingStatus((prev) => ({
|
| 119 |
...prev,
|
| 120 |
isLoading: false,
|
|
|
|
| 124 |
}
|
| 125 |
}, []);
|
| 126 |
|
| 127 |
+
// Continue conversation with tool result
|
| 128 |
+
const continueWithResult = useCallback(
|
| 129 |
+
async (result: unknown): Promise<string> => {
|
| 130 |
+
const api = apiRef.current;
|
| 131 |
+
|
| 132 |
+
// Get fresh state directly from store to avoid stale closure
|
| 133 |
+
const { conversationMessages, lastFunctionCall, clearConversation } =
|
| 134 |
+
useChatStore.getState();
|
| 135 |
+
|
| 136 |
+
if (!api || !lastFunctionCall || conversationMessages.length === 0) {
|
| 137 |
+
return "Error: no conversation to continue";
|
| 138 |
+
}
|
| 139 |
+
|
| 140 |
+
const snakeCaseName = toSnakeCase(lastFunctionCall.functionName);
|
| 141 |
+
|
| 142 |
+
const fullMessages: ChatMessage[] = [
|
| 143 |
+
...conversationMessages,
|
| 144 |
+
{
|
| 145 |
+
role: "assistant",
|
| 146 |
+
tool_calls: [
|
| 147 |
+
{
|
| 148 |
+
type: "function",
|
| 149 |
+
function: {
|
| 150 |
+
name: snakeCaseName,
|
| 151 |
+
arguments: lastFunctionCall.args,
|
| 152 |
+
},
|
| 153 |
+
},
|
| 154 |
+
],
|
| 155 |
+
},
|
| 156 |
+
{
|
| 157 |
+
role: "tool",
|
| 158 |
+
content: [{ name: snakeCaseName, response: result }],
|
| 159 |
+
},
|
| 160 |
+
];
|
| 161 |
+
|
| 162 |
+
const rawOutput = await api.generate(fullMessages, tools);
|
| 163 |
+
clearConversation();
|
| 164 |
+
|
| 165 |
+
return cleanupModelResponse(rawOutput || "");
|
| 166 |
+
},
|
| 167 |
+
[tools],
|
| 168 |
+
);
|
| 169 |
+
|
| 170 |
+
// Send message
|
| 171 |
+
const sendMessage = useCallback(
|
| 172 |
+
async (text: string) => {
|
| 173 |
const api = apiRef.current;
|
| 174 |
+
if (!api) return;
|
| 175 |
|
| 176 |
+
// Ensure model is initialized
|
| 177 |
if (!initializationPromiseRef.current) {
|
| 178 |
initializationPromiseRef.current = initModel().catch((err) => {
|
| 179 |
initializationPromiseRef.current = null;
|
|
|
|
| 184 |
try {
|
| 185 |
await initializationPromiseRef.current;
|
| 186 |
} catch {
|
| 187 |
+
addMessage({
|
| 188 |
+
id: Date.now(),
|
| 189 |
+
text: "Failed to initialize model. Please try again.",
|
| 190 |
+
sender: "bot",
|
| 191 |
+
});
|
| 192 |
+
return;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 193 |
}
|
|
|
|
|
|
|
|
|
|
| 194 |
|
| 195 |
+
// Add user message to UI
|
| 196 |
+
addMessage({ id: Date.now(), text, sender: "user" });
|
|
|
|
|
|
|
| 197 |
|
| 198 |
setIsProcessing(true);
|
| 199 |
+
|
| 200 |
try {
|
| 201 |
+
const mlMessages: ChatMessage[] = [
|
| 202 |
+
{ role: "developer", content: DEVELOPER_PROMPT },
|
| 203 |
+
{ role: "user", content: text },
|
| 204 |
+
];
|
| 205 |
+
|
| 206 |
+
const rawOutput = await api.generate(mlMessages, tools);
|
| 207 |
+
const functionCall = parseFunctionCall(rawOutput || "");
|
| 208 |
+
|
| 209 |
+
if (functionCall && onFunctionCall) {
|
| 210 |
+
// Save conversation state for continuation
|
| 211 |
+
setConversation(mlMessages, functionCall);
|
| 212 |
+
|
| 213 |
+
// Execute function and get result
|
| 214 |
+
const result = await onFunctionCall(functionCall);
|
| 215 |
+
|
| 216 |
+
// Continue conversation with result
|
| 217 |
+
const response = await continueWithResult(result);
|
| 218 |
+
addMessage({ id: Date.now(), text: response, sender: "bot" });
|
| 219 |
+
} else if (functionCall) {
|
| 220 |
+
// Function call but no handler
|
| 221 |
+
addMessage({
|
| 222 |
+
id: Date.now(),
|
| 223 |
+
text: "Function call detected but no handler provided.",
|
| 224 |
+
sender: "bot",
|
| 225 |
+
});
|
| 226 |
+
} else {
|
| 227 |
+
// No function call, just text response
|
| 228 |
+
const response = cleanupModelResponse(rawOutput || "");
|
| 229 |
+
addMessage({ id: Date.now(), text: response, sender: "bot" });
|
| 230 |
+
}
|
| 231 |
} finally {
|
| 232 |
setIsProcessing(false);
|
| 233 |
}
|
| 234 |
},
|
| 235 |
+
[
|
| 236 |
+
tools,
|
| 237 |
+
onFunctionCall,
|
| 238 |
+
addMessage,
|
| 239 |
+
setConversation,
|
| 240 |
+
continueWithResult,
|
| 241 |
+
initModel,
|
| 242 |
+
],
|
| 243 |
);
|
| 244 |
|
| 245 |
+
return {
|
| 246 |
+
messages,
|
| 247 |
+
sendMessage,
|
| 248 |
+
isProcessing,
|
| 249 |
+
loadingStatus,
|
| 250 |
+
};
|
| 251 |
}
|
src/lib/functionCalling.ts
ADDED
|
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import type { FunctionCallResult } from "@/types/chat";
|
| 2 |
+
import { toCamelCase } from "./utils";
|
| 3 |
+
|
| 4 |
+
/**
|
| 5 |
+
* Parses FunctionGemma output format:
|
| 6 |
+
* <start_function_call>call:func_name{param:<escape>value<escape>}<end_function_call>
|
| 7 |
+
*/
|
| 8 |
+
export function parseFunctionCall(output: string): FunctionCallResult {
|
| 9 |
+
const match = output.match(
|
| 10 |
+
/<start_function_call>call:(\w+)\{([^}]*)\}<end_function_call>/,
|
| 11 |
+
);
|
| 12 |
+
|
| 13 |
+
if (!match) return null;
|
| 14 |
+
|
| 15 |
+
const funcName = match[1];
|
| 16 |
+
const argsStr = match[2];
|
| 17 |
+
|
| 18 |
+
// Convert function name from snake_case to camelCase
|
| 19 |
+
const functionName = toCamelCase(funcName);
|
| 20 |
+
|
| 21 |
+
// Parse all arguments generically
|
| 22 |
+
// Format: key:<escape>value<escape>
|
| 23 |
+
const args: Record<string, unknown> = {};
|
| 24 |
+
const argRegex = /(\w+):<escape>([^<]*)<escape>/g;
|
| 25 |
+
|
| 26 |
+
for (const argMatch of argsStr.matchAll(argRegex)) {
|
| 27 |
+
const key = toCamelCase(argMatch[1]);
|
| 28 |
+
const value = argMatch[2].trim();
|
| 29 |
+
args[key] = value;
|
| 30 |
+
}
|
| 31 |
+
|
| 32 |
+
return { functionName, args };
|
| 33 |
+
}
|
| 34 |
+
|
| 35 |
+
/**
|
| 36 |
+
* Cleans up model response by removing special tokens
|
| 37 |
+
*/
|
| 38 |
+
export function cleanupModelResponse(text: string): string {
|
| 39 |
+
return (
|
| 40 |
+
text
|
| 41 |
+
.replace(/<start_function_call>[\s\S]*?<end_function_call>/g, "")
|
| 42 |
+
.replace(/<start_function_response>[\s\S]*?<end_function_response>/g, "")
|
| 43 |
+
.replace(/<end_of_turn>/g, "")
|
| 44 |
+
.replace(/<start_of_turn>/g, "")
|
| 45 |
+
.replace(/<\|im_end\|>/g, "")
|
| 46 |
+
.replace(/<\|im_start\|>/g, "")
|
| 47 |
+
.replace(/<\|endoftext\|>/g, "")
|
| 48 |
+
.replace(/<tool_call>[\s\S]*?<\/tool_call>/g, "")
|
| 49 |
+
.replace(/^model\s*/i, "")
|
| 50 |
+
.replace(/^assistant\s*/i, "")
|
| 51 |
+
.trim() || "Done!"
|
| 52 |
+
);
|
| 53 |
+
}
|
src/store/useChatStore.ts
ADDED
|
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import { create } from "zustand";
|
| 2 |
+
import type {
|
| 3 |
+
ChatMessage,
|
| 4 |
+
FunctionCallResult,
|
| 5 |
+
UIMessage,
|
| 6 |
+
} from "@/types/chat";
|
| 7 |
+
|
| 8 |
+
const INITIAL_MESSAGE: UIMessage = {
|
| 9 |
+
id: 1,
|
| 10 |
+
text: "Hi! I can change the square color for you. Ask me to set a color or ask what color it currently is!",
|
| 11 |
+
sender: "bot",
|
| 12 |
+
};
|
| 13 |
+
|
| 14 |
+
interface ChatState {
|
| 15 |
+
// UI Messages (displayed in chat)
|
| 16 |
+
messages: UIMessage[];
|
| 17 |
+
addMessage: (message: UIMessage) => void;
|
| 18 |
+
|
| 19 |
+
// ML Conversation (for multi-turn function calling)
|
| 20 |
+
conversationMessages: ChatMessage[];
|
| 21 |
+
lastFunctionCall: FunctionCallResult;
|
| 22 |
+
setConversation: (
|
| 23 |
+
messages: ChatMessage[],
|
| 24 |
+
functionCall: FunctionCallResult,
|
| 25 |
+
) => void;
|
| 26 |
+
clearConversation: () => void;
|
| 27 |
+
|
| 28 |
+
// App State
|
| 29 |
+
squareColor: string;
|
| 30 |
+
setSquareColor: (color: string) => void;
|
| 31 |
+
}
|
| 32 |
+
|
| 33 |
+
export const useChatStore = create<ChatState>((set) => ({
|
| 34 |
+
// UI Messages
|
| 35 |
+
messages: [INITIAL_MESSAGE],
|
| 36 |
+
addMessage: (message) =>
|
| 37 |
+
set((state) => ({ messages: [...state.messages, message] })),
|
| 38 |
+
|
| 39 |
+
// ML Conversation
|
| 40 |
+
conversationMessages: [],
|
| 41 |
+
lastFunctionCall: null,
|
| 42 |
+
setConversation: (conversationMessages, lastFunctionCall) =>
|
| 43 |
+
set({ conversationMessages, lastFunctionCall }),
|
| 44 |
+
clearConversation: () =>
|
| 45 |
+
set({ conversationMessages: [], lastFunctionCall: null }),
|
| 46 |
+
|
| 47 |
+
// App State
|
| 48 |
+
squareColor: "rebeccapurple",
|
| 49 |
+
setSquareColor: (squareColor) => set({ squareColor }),
|
| 50 |
+
}));
|
src/store/useColorStore.ts
DELETED
|
@@ -1,11 +0,0 @@
|
|
| 1 |
-
import { create } from "zustand";
|
| 2 |
-
|
| 3 |
-
interface ColorState {
|
| 4 |
-
squareColor: string;
|
| 5 |
-
setSquareColor: (color: string) => void;
|
| 6 |
-
}
|
| 7 |
-
|
| 8 |
-
export const useColorStore = create<ColorState>((set) => ({
|
| 9 |
-
squareColor: "rebeccapurple",
|
| 10 |
-
setSquareColor: (color: string) => set({ squareColor: color }),
|
| 11 |
-
}));
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/types/chat.ts
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
export type ToolCall = {
|
| 2 |
+
type: "function";
|
| 3 |
+
function: {
|
| 4 |
+
name: string;
|
| 5 |
+
arguments: Record<string, unknown>;
|
| 6 |
+
};
|
| 7 |
+
};
|
| 8 |
+
|
| 9 |
+
export type ToolDefinition = {
|
| 10 |
+
type: "function";
|
| 11 |
+
function: {
|
| 12 |
+
name: string;
|
| 13 |
+
description: string;
|
| 14 |
+
parameters: {
|
| 15 |
+
type: "object";
|
| 16 |
+
properties: Record<string, unknown>;
|
| 17 |
+
required: string[];
|
| 18 |
+
};
|
| 19 |
+
};
|
| 20 |
+
};
|
| 21 |
+
|
| 22 |
+
export type ChatMessage =
|
| 23 |
+
| { role: "developer" | "user"; content: string }
|
| 24 |
+
| { role: "assistant"; content?: string; tool_calls?: ToolCall[] }
|
| 25 |
+
| { role: "tool"; content: Array<{ name: string; response: unknown }> };
|
| 26 |
+
|
| 27 |
+
export type FunctionCallResult = {
|
| 28 |
+
functionName: string;
|
| 29 |
+
args: Record<string, unknown>;
|
| 30 |
+
} | null;
|
| 31 |
+
|
| 32 |
+
export type UIMessage = {
|
| 33 |
+
id: number;
|
| 34 |
+
text: string;
|
| 35 |
+
sender: "user" | "bot";
|
| 36 |
+
};
|
src/worker.ts
CHANGED
|
@@ -5,8 +5,7 @@ import {
|
|
| 5 |
type PreTrainedTokenizer,
|
| 6 |
} from "@huggingface/transformers";
|
| 7 |
import * as Comlink from "comlink";
|
| 8 |
-
import {
|
| 9 |
-
import { toCamelCase, toSnakeCase } from "./lib/utils";
|
| 10 |
|
| 11 |
const MODEL_ID = "onnx-community/functiongemma-270m-it-ONNX";
|
| 12 |
|
|
@@ -25,42 +24,6 @@ export type LoadingProgress = {
|
|
| 25 |
|
| 26 |
type ProgressCallback = (progress: LoadingProgress) => void;
|
| 27 |
|
| 28 |
-
// Types for FunctionGemma multi-turn conversation
|
| 29 |
-
type ToolCall = {
|
| 30 |
-
type: "function";
|
| 31 |
-
function: {
|
| 32 |
-
name: string;
|
| 33 |
-
arguments: Record<string, unknown>;
|
| 34 |
-
};
|
| 35 |
-
};
|
| 36 |
-
|
| 37 |
-
type ChatMessage =
|
| 38 |
-
| { role: "developer" | "user"; content: string }
|
| 39 |
-
| { role: "assistant"; content?: string; tool_calls?: ToolCall[] }
|
| 40 |
-
| { role: "tool"; content: Array<{ name: string; response: unknown }> };
|
| 41 |
-
|
| 42 |
-
export type FunctionCallResult = {
|
| 43 |
-
functionName: string;
|
| 44 |
-
args: Record<string, unknown>;
|
| 45 |
-
} | null;
|
| 46 |
-
|
| 47 |
-
type ConversationState = {
|
| 48 |
-
messages: ChatMessage[];
|
| 49 |
-
lastFunctionCall: FunctionCallResult;
|
| 50 |
-
setMessages: (messages: ChatMessage[]) => void;
|
| 51 |
-
setLastFunctionCall: (fc: FunctionCallResult) => void;
|
| 52 |
-
clear: () => void;
|
| 53 |
-
};
|
| 54 |
-
|
| 55 |
-
// Store to maintain conversation state between processMessage and continueWithToolResult
|
| 56 |
-
const conversationStore = createStore<ConversationState>((set) => ({
|
| 57 |
-
messages: [],
|
| 58 |
-
lastFunctionCall: null,
|
| 59 |
-
setMessages: (messages) => set({ messages }),
|
| 60 |
-
setLastFunctionCall: (fc) => set({ lastFunctionCall: fc }),
|
| 61 |
-
clear: () => set({ messages: [], lastFunctionCall: null }),
|
| 62 |
-
}));
|
| 63 |
-
|
| 64 |
function getModel(onProgress?: ProgressCallback) {
|
| 65 |
modelPromise ??= (async () => {
|
| 66 |
console.log("[Worker] Loading model...");
|
|
@@ -80,71 +43,6 @@ function getModel(onProgress?: ProgressCallback) {
|
|
| 80 |
return modelPromise;
|
| 81 |
}
|
| 82 |
|
| 83 |
-
const tools = [
|
| 84 |
-
{
|
| 85 |
-
type: "function",
|
| 86 |
-
function: {
|
| 87 |
-
name: "set_square_color",
|
| 88 |
-
description: "Sets the color of the square displayed on the screen.",
|
| 89 |
-
parameters: {
|
| 90 |
-
type: "object",
|
| 91 |
-
properties: {
|
| 92 |
-
color: {
|
| 93 |
-
type: "string",
|
| 94 |
-
description: "The color to set, e.g. red, blue, green",
|
| 95 |
-
},
|
| 96 |
-
},
|
| 97 |
-
required: ["color"],
|
| 98 |
-
},
|
| 99 |
-
},
|
| 100 |
-
},
|
| 101 |
-
{
|
| 102 |
-
type: "function",
|
| 103 |
-
function: {
|
| 104 |
-
name: "get_square_color",
|
| 105 |
-
description:
|
| 106 |
-
"Returns the current color of the square. Use this when the user asks 'what color is the square' or 'tell me the color'.",
|
| 107 |
-
parameters: {
|
| 108 |
-
type: "object",
|
| 109 |
-
properties: {},
|
| 110 |
-
required: [],
|
| 111 |
-
},
|
| 112 |
-
},
|
| 113 |
-
},
|
| 114 |
-
];
|
| 115 |
-
|
| 116 |
-
// Exact prompt from functiongemma documentation
|
| 117 |
-
const developerPrompt =
|
| 118 |
-
"You are a model that can do function calling with the following functions";
|
| 119 |
-
|
| 120 |
-
// Parses functiongemma output: <start_function_call>call:func_name{param:<escape>value<escape>}<end_function_call>
|
| 121 |
-
function parseFunctionCall(output: string): FunctionCallResult {
|
| 122 |
-
const match = output.match(
|
| 123 |
-
/<start_function_call>call:(\w+)\{([^}]*)\}<end_function_call>/,
|
| 124 |
-
);
|
| 125 |
-
|
| 126 |
-
if (!match) return null;
|
| 127 |
-
|
| 128 |
-
const funcName = match[1];
|
| 129 |
-
const argsStr = match[2];
|
| 130 |
-
|
| 131 |
-
// Convert function name from snake_case to camelCase
|
| 132 |
-
const functionName = toCamelCase(funcName);
|
| 133 |
-
|
| 134 |
-
// Parse all arguments generically
|
| 135 |
-
// Format: key:<escape>value<escape>
|
| 136 |
-
const args: Record<string, unknown> = {};
|
| 137 |
-
const argRegex = /(\w+):<escape>([^<]*)<escape>/g;
|
| 138 |
-
|
| 139 |
-
for (const argMatch of argsStr.matchAll(argRegex)) {
|
| 140 |
-
const key = toCamelCase(argMatch[1]);
|
| 141 |
-
const value = argMatch[2].trim();
|
| 142 |
-
args[key] = value;
|
| 143 |
-
}
|
| 144 |
-
|
| 145 |
-
return { functionName, args };
|
| 146 |
-
}
|
| 147 |
-
|
| 148 |
const workerAPI = {
|
| 149 |
async initModel(progressCallback: ProgressCallback): Promise<void> {
|
| 150 |
console.log("[Worker] initModel called");
|
|
@@ -152,161 +50,17 @@ const workerAPI = {
|
|
| 152 |
console.log("[Worker] initModel completed");
|
| 153 |
},
|
| 154 |
|
| 155 |
-
async
|
| 156 |
-
|
| 157 |
-
|
| 158 |
-
}> {
|
| 159 |
-
console.log("[Worker] ========== NEW REQUEST ==========");
|
| 160 |
-
console.log("[Worker] User input:", text);
|
| 161 |
-
|
| 162 |
-
if (!text.trim()) return { functionCall: null };
|
| 163 |
-
|
| 164 |
-
const { model, tokenizer } = await getModel();
|
| 165 |
-
|
| 166 |
-
const messages = [
|
| 167 |
-
{ role: "developer", content: developerPrompt },
|
| 168 |
-
{ role: "user", content: text },
|
| 169 |
-
];
|
| 170 |
-
|
| 171 |
-
console.log("[Worker] Messages:", JSON.stringify(messages, null, 2));
|
| 172 |
-
|
| 173 |
-
const inputs = tokenizer.apply_chat_template(messages, {
|
| 174 |
-
tools,
|
| 175 |
-
tokenize: true,
|
| 176 |
-
add_generation_prompt: true,
|
| 177 |
-
return_dict: true,
|
| 178 |
-
}) as { input_ids: { dims: number[]; data: unknown } };
|
| 179 |
-
|
| 180 |
-
console.log("[Worker] Input token count:", inputs.input_ids.dims[1]);
|
| 181 |
-
console.log("[Worker] Input structure:", Object.keys(inputs));
|
| 182 |
-
console.log(
|
| 183 |
-
"[Worker] input_ids type:",
|
| 184 |
-
typeof inputs.input_ids,
|
| 185 |
-
inputs.input_ids.constructor?.name,
|
| 186 |
-
);
|
| 187 |
-
|
| 188 |
-
try {
|
| 189 |
-
const inputIds = Array.from(inputs.input_ids.data as Iterable<number>);
|
| 190 |
-
const decodedInput = tokenizer.decode(inputIds, {
|
| 191 |
-
skip_special_tokens: false,
|
| 192 |
-
});
|
| 193 |
-
console.log("[Worker] Decoded input (what model sees):", decodedInput);
|
| 194 |
-
} catch (e) {
|
| 195 |
-
console.log("[Worker] Could not decode input:", e);
|
| 196 |
-
}
|
| 197 |
-
|
| 198 |
-
// biome-ignore lint/complexity/noBannedTypes: model.generate type is not properly typed in transformers.js
|
| 199 |
-
const output = await (model.generate as Function)({
|
| 200 |
-
...inputs,
|
| 201 |
-
max_new_tokens: 512,
|
| 202 |
-
});
|
| 203 |
-
|
| 204 |
-
const inputLength = inputs.input_ids.dims[1];
|
| 205 |
-
const outputLength = output.dims ? output.dims[1] : output.length;
|
| 206 |
-
console.log("[Worker] Output token count:", outputLength);
|
| 207 |
-
console.log("[Worker] New tokens generated:", outputLength - inputLength);
|
| 208 |
-
|
| 209 |
-
const decoded = tokenizer.decode(output.slice(0, [inputLength, null]), {
|
| 210 |
-
skip_special_tokens: false,
|
| 211 |
-
});
|
| 212 |
-
|
| 213 |
-
console.log("[Worker] Raw model output:", decoded);
|
| 214 |
-
const functionCall = parseFunctionCall(decoded as string);
|
| 215 |
-
console.log("[Worker] Parsed function call:", functionCall);
|
| 216 |
-
|
| 217 |
-
// If function was called, save conversation state for continueWithToolResult
|
| 218 |
-
if (functionCall) {
|
| 219 |
-
console.log(
|
| 220 |
-
"[Worker] ✅ Function call successful:",
|
| 221 |
-
functionCall.functionName,
|
| 222 |
-
);
|
| 223 |
-
|
| 224 |
-
// Save conversation state for the next turn
|
| 225 |
-
const { setMessages, setLastFunctionCall } = conversationStore.getState();
|
| 226 |
-
setMessages([
|
| 227 |
-
{ role: "developer", content: developerPrompt },
|
| 228 |
-
{ role: "user", content: text },
|
| 229 |
-
]);
|
| 230 |
-
setLastFunctionCall(functionCall);
|
| 231 |
-
|
| 232 |
-
return { functionCall };
|
| 233 |
-
}
|
| 234 |
-
|
| 235 |
-
// If model generated an error format, return null to trigger fallback
|
| 236 |
-
if ((decoded as string).includes("<start_function_call>error:")) {
|
| 237 |
-
console.log(
|
| 238 |
-
"[Worker] ⚠️ Model generated error format, triggering fallback",
|
| 239 |
-
);
|
| 240 |
-
return { functionCall: null };
|
| 241 |
-
}
|
| 242 |
-
|
| 243 |
-
console.log("[Worker] ❌ No function call detected, returning fallback");
|
| 244 |
-
|
| 245 |
-
const textResponse = (decoded as string)
|
| 246 |
-
.replace(/<tool_call>[\s\S]*?<\/tool_call>/g, "")
|
| 247 |
-
.replace(/<\|im_end\|>/g, "")
|
| 248 |
-
.replace(/<\|im_start\|>/g, "")
|
| 249 |
-
.replace(/<\|endoftext\|>/g, "")
|
| 250 |
-
.replace(/<start_function_call>[\s\S]*?<end_function_call>/g, "")
|
| 251 |
-
.replace(/<start_function_response>/g, "")
|
| 252 |
-
.trim();
|
| 253 |
-
|
| 254 |
-
return { functionCall: null, textResponse: textResponse || undefined };
|
| 255 |
-
},
|
| 256 |
-
|
| 257 |
-
async continueWithToolResult(
|
| 258 |
-
functionName: string,
|
| 259 |
-
functionResult: unknown,
|
| 260 |
): Promise<string> {
|
| 261 |
-
console.log("[Worker] ==========
|
| 262 |
-
console.log("[Worker]
|
| 263 |
-
console.log("[Worker] Result:", functionResult);
|
| 264 |
-
|
| 265 |
-
const { messages, lastFunctionCall, clear } = conversationStore.getState();
|
| 266 |
-
|
| 267 |
-
if (messages.length === 0 || !lastFunctionCall) {
|
| 268 |
-
console.log("[Worker] ⚠️ No conversation state found");
|
| 269 |
-
return "I apologize, but I lost track of our conversation. Could you please try again?";
|
| 270 |
-
}
|
| 271 |
|
| 272 |
const { model, tokenizer } = await getModel();
|
| 273 |
|
| 274 |
-
// Build the full conversation with tool result
|
| 275 |
-
// Convert camelCase back to snake_case for the model
|
| 276 |
-
const snakeCaseName = toSnakeCase(lastFunctionCall.functionName);
|
| 277 |
-
|
| 278 |
-
// Build messages in the format expected by the tokenizer
|
| 279 |
-
// Note: Using 'as unknown' to bypass strict typing since the tokenizer
|
| 280 |
-
// accepts a more flexible format than what TypeScript infers
|
| 281 |
-
const fullMessages = [
|
| 282 |
-
...messages,
|
| 283 |
-
// Assistant turn with tool_calls
|
| 284 |
-
{
|
| 285 |
-
role: "assistant",
|
| 286 |
-
tool_calls: [
|
| 287 |
-
{
|
| 288 |
-
type: "function",
|
| 289 |
-
function: {
|
| 290 |
-
name: snakeCaseName,
|
| 291 |
-
arguments: lastFunctionCall.args,
|
| 292 |
-
},
|
| 293 |
-
},
|
| 294 |
-
],
|
| 295 |
-
},
|
| 296 |
-
// Tool response
|
| 297 |
-
{
|
| 298 |
-
role: "tool",
|
| 299 |
-
content: [{ name: snakeCaseName, response: functionResult }],
|
| 300 |
-
},
|
| 301 |
-
] as unknown[];
|
| 302 |
-
|
| 303 |
-
console.log(
|
| 304 |
-
"[Worker] Full messages for final response:",
|
| 305 |
-
JSON.stringify(fullMessages, null, 2),
|
| 306 |
-
);
|
| 307 |
-
|
| 308 |
// biome-ignore lint/suspicious/noExplicitAny: tokenizer accepts flexible message format
|
| 309 |
-
const inputs = tokenizer.apply_chat_template(
|
| 310 |
tools,
|
| 311 |
tokenize: true,
|
| 312 |
add_generation_prompt: true,
|
|
@@ -315,23 +69,10 @@ const workerAPI = {
|
|
| 315 |
|
| 316 |
console.log("[Worker] Input token count:", inputs.input_ids.dims[1]);
|
| 317 |
|
| 318 |
-
try {
|
| 319 |
-
const inputIds = Array.from(inputs.input_ids.data as Iterable<number>);
|
| 320 |
-
const decodedInput = tokenizer.decode(inputIds, {
|
| 321 |
-
skip_special_tokens: false,
|
| 322 |
-
});
|
| 323 |
-
console.log(
|
| 324 |
-
"[Worker] Decoded input for final response:",
|
| 325 |
-
decodedInput,
|
| 326 |
-
);
|
| 327 |
-
} catch (e) {
|
| 328 |
-
console.log("[Worker] Could not decode input:", e);
|
| 329 |
-
}
|
| 330 |
-
|
| 331 |
// biome-ignore lint/complexity/noBannedTypes: model.generate type is not properly typed in transformers.js
|
| 332 |
const output = await (model.generate as Function)({
|
| 333 |
...inputs,
|
| 334 |
-
max_new_tokens:
|
| 335 |
});
|
| 336 |
|
| 337 |
const inputLength = inputs.input_ids.dims[1];
|
|
@@ -339,27 +80,9 @@ const workerAPI = {
|
|
| 339 |
skip_special_tokens: false,
|
| 340 |
});
|
| 341 |
|
| 342 |
-
console.log("[Worker] Raw
|
| 343 |
-
|
| 344 |
-
// Clear conversation state
|
| 345 |
-
clear();
|
| 346 |
-
|
| 347 |
-
// Clean up the response
|
| 348 |
-
const cleanResponse = (decoded as string)
|
| 349 |
-
.replace(/<start_function_call>[\s\S]*?<end_function_call>/g, "")
|
| 350 |
-
.replace(/<start_function_response>[\s\S]*?<end_function_response>/g, "")
|
| 351 |
-
.replace(/<\|im_end\|>/g, "")
|
| 352 |
-
.replace(/<\|im_start\|>/g, "")
|
| 353 |
-
.replace(/<\|endoftext\|>/g, "")
|
| 354 |
-
.replace(/<end_of_turn>/g, "")
|
| 355 |
-
.replace(/<start_of_turn>/g, "")
|
| 356 |
-
.replace(/^model\s*/i, "")
|
| 357 |
-
.replace(/^assistant\s*/i, "")
|
| 358 |
-
.trim();
|
| 359 |
-
|
| 360 |
-
console.log("[Worker] ✅ Final response:", cleanResponse);
|
| 361 |
|
| 362 |
-
return
|
| 363 |
},
|
| 364 |
};
|
| 365 |
|
|
|
|
| 5 |
type PreTrainedTokenizer,
|
| 6 |
} from "@huggingface/transformers";
|
| 7 |
import * as Comlink from "comlink";
|
| 8 |
+
import type { ChatMessage, ToolDefinition } from "./types/chat";
|
|
|
|
| 9 |
|
| 10 |
const MODEL_ID = "onnx-community/functiongemma-270m-it-ONNX";
|
| 11 |
|
|
|
|
| 24 |
|
| 25 |
type ProgressCallback = (progress: LoadingProgress) => void;
|
| 26 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 27 |
function getModel(onProgress?: ProgressCallback) {
|
| 28 |
modelPromise ??= (async () => {
|
| 29 |
console.log("[Worker] Loading model...");
|
|
|
|
| 43 |
return modelPromise;
|
| 44 |
}
|
| 45 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 46 |
const workerAPI = {
|
| 47 |
async initModel(progressCallback: ProgressCallback): Promise<void> {
|
| 48 |
console.log("[Worker] initModel called");
|
|
|
|
| 50 |
console.log("[Worker] initModel completed");
|
| 51 |
},
|
| 52 |
|
| 53 |
+
async generate(
|
| 54 |
+
messages: ChatMessage[],
|
| 55 |
+
tools: ToolDefinition[],
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 56 |
): Promise<string> {
|
| 57 |
+
console.log("[Worker] ========== GENERATE ==========");
|
| 58 |
+
console.log("[Worker] Messages:", JSON.stringify(messages, null, 2));
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 59 |
|
| 60 |
const { model, tokenizer } = await getModel();
|
| 61 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 62 |
// biome-ignore lint/suspicious/noExplicitAny: tokenizer accepts flexible message format
|
| 63 |
+
const inputs = tokenizer.apply_chat_template(messages as any, {
|
| 64 |
tools,
|
| 65 |
tokenize: true,
|
| 66 |
add_generation_prompt: true,
|
|
|
|
| 69 |
|
| 70 |
console.log("[Worker] Input token count:", inputs.input_ids.dims[1]);
|
| 71 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 72 |
// biome-ignore lint/complexity/noBannedTypes: model.generate type is not properly typed in transformers.js
|
| 73 |
const output = await (model.generate as Function)({
|
| 74 |
...inputs,
|
| 75 |
+
max_new_tokens: 512,
|
| 76 |
});
|
| 77 |
|
| 78 |
const inputLength = inputs.input_ids.dims[1];
|
|
|
|
| 80 |
skip_special_tokens: false,
|
| 81 |
});
|
| 82 |
|
| 83 |
+
console.log("[Worker] Raw output:", decoded);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 84 |
|
| 85 |
+
return decoded as string;
|
| 86 |
},
|
| 87 |
};
|
| 88 |
|