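// Web worker that loads FunctionGemma (ONNX, via transformers.js) and exposes
// model loading and chat generation to the main thread through Comlink.
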
import {
  AutoModelForCausalLM,
  AutoTokenizer,
  type PreTrainedModel,
  type PreTrainedTokenizer,
} from "@huggingface/transformers";
import * as Comlink from "comlink";

import type { ChatMessage, ToolDefinition } from "./types/chat";

const MODEL_ID = "onnx-community/functiongemma-270m-it-ONNX";

// The model and tokenizer are loaded once and memoized; concurrent callers
// share the same in-flight promise.
let modelPromise: Promise<{
  model: PreTrainedModel;
  tokenizer: PreTrainedTokenizer;
}> | null = null;

// Download/loading progress events, in the shape transformers.js passes to
// `progress_callback`.
export type LoadingProgress = {
  status: "initiate" | "download" | "progress" | "done" | "ready";
  file?: string;
  progress?: number;
  loaded?: number;
  total?: number;
};

type ProgressCallback = (progress: LoadingProgress) => void;

function getModel(onProgress?: ProgressCallback) {
  // `??=` kicks off the load only on the first call; later callers reuse the
  // same promise and their progress callback is ignored.
  modelPromise ??= (async () => {
    console.log("[Worker] Loading model...");
    const [tokenizer, model] = await Promise.all([
      AutoTokenizer.from_pretrained(MODEL_ID, {
        progress_callback: onProgress,
      }),
      AutoModelForCausalLM.from_pretrained(MODEL_ID, {
        dtype: "fp16",
        device: "webgpu",
        progress_callback: onProgress,
      }),
    ]);
    console.log("[Worker] Model loaded!");
    return { model, tokenizer };
  })();
  return modelPromise;
}
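
// Note: device "webgpu" assumes the browser supports WebGPU and will throw
// otherwise. A guarded device choice (a sketch, not part of the original
// file) could be passed as `device` above:
//
//   const device = "gpu" in navigator ? "webgpu" : "wasm";
//
// ideally paired with a quantized dtype such as "q8" when falling back to wasm.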

const workerAPI = {
  async initModel(progressCallback: ProgressCallback): Promise<void> {
    console.log("[Worker] initModel called");
    await getModel(progressCallback);
    console.log("[Worker] initModel completed");
  },

  async generate(
    messages: ChatMessage[],
    tools: ToolDefinition[],
  ): Promise<string> {
    console.log("[Worker] ========== GENERATE ==========");
    console.log("[Worker] Messages:", JSON.stringify(messages, null, 2));
    const { model, tokenizer } = await getModel();

    // Render the conversation and tool definitions into token ids using the
    // model's chat template.
    // biome-ignore lint/suspicious/noExplicitAny: tokenizer accepts flexible message format
    const inputs = tokenizer.apply_chat_template(messages as any, {
      tools,
      tokenize: true,
      add_generation_prompt: true,
      return_dict: true,
    }) as { input_ids: { dims: number[]; data: unknown } };
    console.log("[Worker] Input token count:", inputs.input_ids.dims[1]);

    // biome-ignore lint/complexity/noBannedTypes: model.generate type is not properly typed in transformers.js
    const output = await (model.generate as Function)({
      ...inputs,
      max_new_tokens: 512,
    });

    // generate() returns prompt + completion; slice off the prompt tokens and
    // keep special tokens so tool-call markers survive for downstream parsing.
    const inputLength = inputs.input_ids.dims[1];
    const decoded = tokenizer.decode(output.slice(0, [inputLength, null]), {
      skip_special_tokens: false,
    });
    console.log("[Worker] Raw output:", decoded);
    return decoded as string;
  },
};

export type WorkerAPI = typeof workerAPI;

Comlink.expose(workerAPI);
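
/*
 * Example main-thread usage (a sketch; the worker path and the `messages` /
 * `tools` values here are assumptions, not part of this file):
 *
 *   import * as Comlink from "comlink";
 *   import type { WorkerAPI } from "./worker";
 *
 *   const worker = new Worker(new URL("./worker.ts", import.meta.url), {
 *     type: "module",
 *   });
 *   const api = Comlink.wrap<WorkerAPI>(worker);
 *
 *   // Callbacks must be wrapped in Comlink.proxy() to cross the boundary.
 *   await api.initModel(Comlink.proxy((p) => console.log(p.status)));
 *   const raw = await api.generate(messages, tools);
 */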