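// Web Worker: loads the FunctionGemma ONNX model with transformers.js and
// exposes chat generation to the main thread via Comlink.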
import {
  AutoModelForCausalLM,
  AutoTokenizer,
  type PreTrainedModel,
  type PreTrainedTokenizer,
} from "@huggingface/transformers";
import * as Comlink from "comlink";
import type { ChatMessage, ToolDefinition } from "./types/chat";

const MODEL_ID = "onnx-community/functiongemma-270m-it-ONNX";

let modelPromise: Promise<{
  model: PreTrainedModel;
  tokenizer: PreTrainedTokenizer;
}> | null = null;

export type LoadingProgress = {
  status: "initiate" | "download" | "progress" | "done" | "ready";
  file?: string;
  progress?: number;
  loaded?: number;
  total?: number;
};

type ProgressCallback = (progress: LoadingProgress) => void;

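// Lazily load the tokenizer and model exactly once; `??=` lets concurrent
// callers share the same in-flight promise.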
function getModel(onProgress?: ProgressCallback) {
  modelPromise ??= (async () => {
    console.log("[Worker] Loading model...");
    const [tokenizer, model] = await Promise.all([
      AutoTokenizer.from_pretrained(MODEL_ID, {
        progress_callback: onProgress,
      }),
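      // fp16 on the WebGPU backend assumes browser WebGPU support
      // (typically including the shader-f16 feature).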
      AutoModelForCausalLM.from_pretrained(MODEL_ID, {
        dtype: "fp16",
        device: "webgpu",
        progress_callback: onProgress,
      }),
    ]);
    console.log("[Worker] Model loaded!");
    return { model, tokenizer };
  })();
  return modelPromise;
}

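// API surface exposed to the main thread through Comlink.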
const workerAPI = {
  async initModel(progressCallback: ProgressCallback): Promise<void> {
    console.log("[Worker] initModel called");
    await getModel(progressCallback);
    console.log("[Worker] initModel completed");
  },

  async generate(
    messages: ChatMessage[],
    tools: ToolDefinition[],
  ): Promise<string> {
    console.log("[Worker] ========== GENERATE ==========");
    console.log("[Worker] Messages:", JSON.stringify(messages, null, 2));

    const { model, tokenizer } = await getModel();

    // biome-ignore lint/suspicious/noExplicitAny: tokenizer accepts flexible message format
    const inputs = tokenizer.apply_chat_template(messages as any, {
      tools,
      tokenize: true,
      add_generation_prompt: true,
      return_dict: true,
    }) as { input_ids: { dims: number[]; data: unknown } };

    console.log("[Worker] Input token count:", inputs.input_ids.dims[1]);

    // biome-ignore lint/complexity/noBannedTypes: model.generate type is not properly typed in transformers.js
    const output = await (model.generate as Function)({
      ...inputs,
      max_new_tokens: 512,
    });

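    // Keep only the newly generated tokens: select batch row 0, then token
    // positions from the end of the prompt onward.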
    const inputLength = inputs.input_ids.dims[1];
    const decoded = tokenizer.decode(output.slice(0, [inputLength, null]), {
      skip_special_tokens: false,
    });

    console.log("[Worker] Raw output:", decoded);

    return decoded as string;
  },
};

export type WorkerAPI = typeof workerAPI;

Comlink.expose(workerAPI);
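
// A minimal main-thread usage sketch (hypothetical file paths; assumes a
// bundler that supports `new Worker(new URL(...))` module workers). Note that
// callbacks crossing the worker boundary must be wrapped in Comlink.proxy:
//
//   import * as Comlink from "comlink";
//   import type { WorkerAPI } from "./worker";
//
//   const worker = new Worker(new URL("./worker.ts", import.meta.url), {
//     type: "module",
//   });
//   const api = Comlink.wrap<WorkerAPI>(worker);
//
//   await api.initModel(Comlink.proxy((p) => console.log(p.status)));
//   const raw = await api.generate(messages, tools);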