Heng2004 commited on
Commit
fe6264c
·
verified ·
1 Parent(s): b396cd3

Create model_utils.py

Browse files
Files changed (1) hide show
  1. model_utils.py +173 -0
model_utils.py ADDED
@@ -0,0 +1,173 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
# model_utils.py
"""Model loading and shared state for the Lao history chatbot.

Importing this module eagerly downloads/initializes the SeaLLMs chat model
and loads the curriculum + QA data, so the first user request is fast.
"""

import re
from typing import List, Optional

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

import qa_store
from loader import load_curriculum, load_manual_qa, rebuild_combined_qa

# -----------------------------
# Model
# -----------------------------
MODEL_NAME = "SeaLLMs/SeaLLMs-v3-1.5B-Chat"

# float32 keeps the model CPU-friendly (no half-precision hardware needed).
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch.float32,
)

# Load data once at import time.
load_curriculum()
load_manual_qa()
rebuild_combined_qa()

# System prompt (Lao): "You are a Lao history assistant for grade M.1
# students. Answer only in Lao, in 2-3 short, easy sentences, based only
# on the information below. If the information is insufficient or unclear,
# say you are not sure."
SYSTEM_PROMPT = (
    "ທ່ານແມ່ນຜູ້ຊ່ວຍເຫຼືອດ້ານປະຫວັດສາດຂອງປະເທດລາວ "
    "ສໍາລັບນັກຮຽນຊັ້ນ ມ.1. "
    "ຕອບແຕ່ພາສາລາວ ໃຫ້ຕອບສັ້ນໆ 2–3 ປະໂຫຍກ ແລະເຂົ້າໃຈງ່າຍ. "
    "ໃຫ້ອີງຈາກຂໍ້ມູນຂ້າງລຸ່ມນີ້ເທົ່ານັ້ນ. "
    "ຖ້າຂໍ້ມູນບໍ່ພຽງພໍ ຫຼືບໍ່ຊັດເຈນ ໃຫ້ບອກວ່າບໍ່ແນ່ໃຈ."
)
+
35
+
36
def _format_entry(entry: dict) -> str:
    """Render one textbook entry as a labeled context block (Lao header + body text)."""
    header = (
        f"[ຊັ້ນ {entry.get('grade','')}, "
        f"ບົດ {entry.get('chapter','')}, "
        f"ຫົວຂໍ້ {entry.get('section','')} – {entry.get('title','')}]"
    )
    return f"{header}\n{entry.get('text','')}"


def retrieve_context(question: str, max_entries: int = 2) -> str:
    """
    Simple keyword retrieval over textbook entries.

    Scores each entry by (a) occurrences of lowercased query terms in its
    text + title, (b) substring hits inside declared keywords (weight 2),
    and (c) a +1 bonus when a term appears in the entry topic.

    Returns up to ``max_entries`` formatted context blocks joined with blank
    lines; falls back to the first entries when the query yields no usable
    terms or nothing scores above zero, and to RAW_KNOWLEDGE when there are
    no entries at all.
    """
    if not qa_store.ENTRIES:
        return qa_store.RAW_KNOWLEDGE

    q = question.lower().strip()
    # Drop 1-character fragments; they match too broadly.
    terms = [t for t in re.split(r"\s+", q) if len(t) > 1]

    if not terms:
        # No usable query terms: just show the leading entries.
        # (Also uses .get('text','') now — the original indexed e['text'] here,
        # a KeyError risk, while every other path used .get.)
        return "\n\n".join(_format_entry(e) for e in qa_store.ENTRIES[:max_entries])

    scored = []
    for e in qa_store.ENTRIES:
        base = (e.get("text", "") + " " + e.get("title", "")).lower()
        score = sum(base.count(t) for t in terms)

        # Declared keywords count double; substring match is intentional so
        # that partial Lao word forms still hit.
        for kw in e.get("keywords", []):
            kw_lower = kw.lower()
            score += 2 * sum(1 for t in terms if t in kw_lower)

        # Bug fix: lowercase the topic before comparing — the query terms are
        # lowercased, so the original case-sensitive check silently missed
        # topics containing uppercase characters.
        topic = e.get("topic", "").lower()
        if topic and any(t in topic for t in terms):
            score += 1

        if score > 0:
            scored.append((score, e))

    # Highest score first; sort is stable, so entry order breaks ties.
    scored.sort(key=lambda x: x[0], reverse=True)
    top_entries = [e for _, e in scored[:max_entries]]

    if not top_entries:
        top_entries = qa_store.ENTRIES[:max_entries]

    return "\n\n".join(_format_entry(e) for e in top_entries)
96
+
97
+
98
def build_prompt(question: str) -> str:
    """Assemble the full generation prompt: system prompt, retrieved reference text, question."""
    reference = retrieve_context(question)
    # Lao section labels: "ຂໍ້ມູນອ້າງອີງ" = reference info, "ຄຳຖາມ" = question,
    # "ຄຳຕອບດ້ວຍພາສາລາວ" = answer in Lao (left open for the model to complete).
    return f"""{SYSTEM_PROMPT}

ຂໍ້ມູນອ້າງອີງ:
{reference}

ຄຳຖາມ: {question}

ຄຳຕອບດ້ວຍພາສາລາວ:"""
108
+
109
+
110
def generate_answer(question: str) -> str:
    """Generate an answer with the shared model using greedy decoding.

    Builds the retrieval-augmented prompt, runs `model.generate`, and
    returns only the newly generated text (prompt tokens stripped).
    """
    prompt = build_prompt(question)
    encoded = tokenizer(prompt, return_tensors="pt")

    with torch.no_grad():
        output_ids = model.generate(
            **encoded,
            max_new_tokens=160,
            do_sample=False,  # deterministic greedy decoding
        )

    # Decode only the continuation: skip the prompt-length prefix of the output.
    prompt_len = encoded["input_ids"].shape[1]
    completion = tokenizer.decode(output_ids[0][prompt_len:], skip_special_tokens=True)
    return completion.strip()
123
+
124
+
125
def answer_from_qa(question: str) -> Optional[str]:
    """
    Look up an answer in the curated QA data.

    1) exact match in QA_INDEX
    2) fuzzy match via word overlap with ALL_QA_KNOWLEDGE

    Returns None when the question normalizes to nothing, has no usable
    terms, or no stored question shares at least one word with it.
    """
    norm_q = qa_store.normalize_question(question)
    if not norm_q:
        return None

    # Exact-match fast path (EAFP: single dict lookup).
    try:
        return qa_store.QA_INDEX[norm_q]
    except KeyError:
        pass

    # Single-character fragments are ignored for the overlap score.
    q_terms = [t for t in norm_q.split(" ") if len(t) > 1]
    if not q_terms:
        return None

    best_score = 0
    best_answer: Optional[str] = None

    for item in qa_store.ALL_QA_KNOWLEDGE:
        stored = [t for t in item["norm_q"].split(" ") if len(t) > 1]
        overlap = sum(t in stored for t in q_terms)
        if overlap > best_score:
            best_score, best_answer = overlap, item["a"]

    # A single shared word is enough to accept the fuzzy match.
    return best_answer if best_score >= 1 else None
155
+
156
+
157
def laos_history_bot(message: str, history: List) -> str:
    """
    Main chatbot function for the Student tab.

    Tries the curated QA lookup first; falls back to model generation.
    ``history`` is accepted to satisfy the chat-UI callback signature but
    is not used.
    """
    if not message.strip():
        # Lao: "Please type a question first."
        return "ກະລຸນາພິມຄໍາຖາມກ່ອນ."

    direct = answer_from_qa(message)
    if direct:
        return direct

    try:
        return generate_answer(message)
    except Exception as e:  # noqa: BLE001 — UI boundary: surface the error as chat text
        # Bug fix: the original error prefix contained mojibake (U+FFFD
        # replacement characters: "ລະບົ��ມີບັນຫາ"); restored to the intended
        # Lao "ລະບົບມີບັນຫາ" ("the system has a problem").
        return f"ລະບົບມີບັນຫາ: {e}"