Nefertury commited on
Commit
7327516
·
verified ·
1 Parent(s): 166f868

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +64 -60
app.py CHANGED
@@ -2,10 +2,9 @@ import os
2
  import torch
3
  import gradio as gr
4
  import requests
5
- from typing import List, Dict
6
- from threading import Lock
7
-
8
- from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
9
  from peft import PeftModel
10
 
11
  # --- 1. Конфигурация и загрузка модели ---
@@ -24,7 +23,7 @@ TOP_P = 0.9
24
  REPETITION_PENALTY = 1.05
25
  SYS_PROMPT_TT = (
26
  "Син - татар цифрлы ярдәмчесе. Татар телендә һәрвакыт ачык һәм дустанә җавап бир."
27
- "мәгълүмат җитәрлек булмаса, 1-2 кыска аныклаучы сорау бир. "
28
  "Һәрвакыт татарча гына җавап бир."
29
  )
30
 
@@ -40,13 +39,11 @@ model.config.use_cache = True
40
  model.eval()
41
  print("✅ Модель успешно загружена!")
42
 
43
- # --- 2. Логика приложения (функции перевода и генерации) ---
44
 
45
  YANDEX_TRANSLATE_URL = "https://translate.api.cloud.yandex.net/translate/v2/translate"
46
- YANDEX_DETECT_URL = "https://translate.api.cloud.yandex.net/translate/v2/detect" # НОВЫЙ URL ДЛЯ ОПРЕДЕЛЕНИЯ ЯЗЫКА
47
- generation_lock = Lock()
48
 
49
- # НОВАЯ ФУНКЦИЯ для определения языка 🧠
50
  def detect_language(text: str) -> str:
51
  headers = {"Authorization": f"Api-Key {YANDEX_API_KEY}"}
52
  payload = {"folderId": YANDEX_FOLDER_ID, "text": text}
@@ -54,40 +51,32 @@ def detect_language(text: str) -> str:
54
  resp = requests.post(YANDEX_DETECT_URL, headers=headers, json=payload, timeout=10)
55
  resp.raise_for_status()
56
  data = resp.json()
57
- return data.get("languageCode", "ru") # Если не определился, считаем, что русский
58
  except requests.exceptions.RequestException as e:
59
  print(f"Ошибка определения языка: {e}")
60
- return "ru" # В случае ошибки считаем, что это русский для безопасности
61
 
62
- def _yandex_translate(texts: List[str], source: str, target: str) -> List[str]:
63
  headers = {"Authorization": f"Api-Key {YANDEX_API_KEY}"}
64
- payload = {"folderId": YANDEX_FOLDER_ID, "texts": texts, "sourceLanguageCode": source, "targetLanguageCode": target}
65
  try:
66
  resp = requests.post(YANDEX_TRANSLATE_URL, headers=headers, json=payload, timeout=30)
67
  resp.raise_for_status()
68
- data = resp.json()
69
- return [item["text"] for item in data["translations"]]
70
  except requests.exceptions.RequestException as e:
71
  print(f"Ошибка перевода: {e}")
72
- return [f"Ошибка перевода: {text}" for text in texts]
73
-
74
- def ru2tt(text: str) -> str:
75
- return _yandex_translate([text], "ru", "tt")[0]
76
-
77
- def tt2ru(text: str) -> str:
78
- return _yandex_translate([text], "tt", "ru")[0]
79
 
80
  def render_prompt(messages: List[Dict[str, str]]) -> str:
 
81
  if getattr(tok, "chat_template", None):
82
  try:
83
  return tok.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
84
- except Exception:
85
- pass
86
  sys_text = ""
87
  turns = []
88
  for m in messages:
89
- if m["role"] == "system":
90
- sys_text += m["content"].strip() + "\n"
91
  i = 0
92
  while i < len(messages):
93
  m = messages[i]
@@ -95,32 +84,51 @@ def render_prompt(messages: List[Dict[str, str]]) -> str:
95
  next_assistant = None
96
  if i + 1 < len(messages) and messages[i + 1]["role"] == "assistant":
97
  next_assistant = messages[i + 1]["content"]
98
- if len(turns) == 0 and sys_text:
99
- user_block = f"<<SYS>>\n{sys_text.strip()}\n<</SYS>>\n\n{m['content']}"
100
- else:
101
- user_block = m["content"]
102
  if next_assistant is None:
103
  turns.append(f"<s>[INST] {user_block} [/INST]")
104
  else:
105
  turns.append(f"<s>[INST] {user_block} [/INST] {next_assistant}</s>")
106
  i += 1
107
  i += 1
108
- if not turns:
109
- return f"<s>[INST] <<SYS>>\n{sys_text.strip()}\n<</SYS>>\n\n [/INST]" if sys_text else "<s>[INST] [/INST]"
110
- return "".join(turns)
111
 
 
112
  @torch.inference_mode()
113
- def generate_tt_reply(messages: List[Dict[str, str]]) -> str:
114
- with generation_lock:
115
- prompt = render_prompt(messages)
116
- inputs = tok(prompt, return_tensors="pt").to(model.device)
117
- out = model.generate(**inputs, max_new_tokens=MAX_NEW_TOKENS, do_sample=True, temperature=TEMPERATURE, top_p=TOP_P, repetition_penalty=REPETITION_PENALTY, eos_token_id=tok.eos_token_id, pad_token_id=tok.pad_token_id)
118
- gen_text = tok.decode(out[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
119
- return gen_text.strip()
120
-
121
- # --- 3. Gradio интерфейс ---
122
-
123
- def chat_fn(message, history):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
124
  # 1. Формируем историю для модели
125
  messages = [{"role": "system", "content": SYS_PROMPT_TT}]
126
  for user_msg, bot_msg in history:
@@ -128,29 +136,25 @@ def chat_fn(message, history):
128
  if bot_msg:
129
  messages.append({"role": "assistant", "content": bot_msg})
130
 
131
- # 2. ОПРЕДЕЛЯЕМ ЯЗЫК и переводим, если нужно 🛡️
132
  detected_lang = detect_language(message)
133
- if detected_lang != "tt":
134
- user_tt = ru2tt(message)
135
- else:
136
- user_tt = message # Уже на татарском, используем как есть
137
-
138
  messages.append({"role": "user", "content": user_tt})
139
 
140
- # 3. Генерируем ответ модели
141
- tt_reply = generate_tt_reply(messages)
142
-
143
- # 4. Добавляем в историю татарский вопрос и татарский ответ
144
- history.append([user_tt, tt_reply])
145
 
146
- # 5. Возвращаем полную историю на татарском
147
- return history
 
 
148
 
149
- # Создаем интерфейс с татарскими надписями
150
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
151
  gr.Markdown("## Татарский чат-бот от команды Сбера")
152
- chatbot = gr.Chatbot(label="Диалог", height=500)
153
- msg = gr.Textbox(label="Хәбәрегезне рус яки татар телендә языгыз", placeholder="Татарстанның башкаласы нинди шәһәр?")
154
  clear = gr.Button("🗑️ Чистарту")
155
 
156
  msg.submit(chat_fn, inputs=[msg, chatbot], outputs=chatbot)
 
2
  import torch
3
  import gradio as gr
4
  import requests
5
+ from typing import List, Dict, Iterator
6
+ from threading import Thread
7
+ from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, TextIteratorStreamer
 
8
  from peft import PeftModel
9
 
10
  # --- 1. Конфигурация и загрузка модели ---
 
23
  REPETITION_PENALTY = 1.05
24
  SYS_PROMPT_TT = (
25
  "Син - татар цифрлы ярдәмчесе. Татар телендә һәрвакыт ачык һәм дустанә җавап бир."
26
+ "мәгълүmat җитәрлек булмаса, 1-2 кыска аныклаучы сорау бир. "
27
  "Һәрвакыт татарча гына җавап бир."
28
  )
29
 
 
39
  model.eval()
40
  print("✅ Модель успешно загружена!")
41
 
42
+ # --- 2. Логика приложения (с изменениями для стриминга) ---
43
 
44
  YANDEX_TRANSLATE_URL = "https://translate.api.cloud.yandex.net/translate/v2/translate"
45
+ YANDEX_DETECT_URL = "https://translate.api.cloud.yandex.net/translate/v2/detect"
 
46
 
 
47
  def detect_language(text: str) -> str:
48
  headers = {"Authorization": f"Api-Key {YANDEX_API_KEY}"}
49
  payload = {"folderId": YANDEX_FOLDER_ID, "text": text}
 
51
  resp = requests.post(YANDEX_DETECT_URL, headers=headers, json=payload, timeout=10)
52
  resp.raise_for_status()
53
  data = resp.json()
54
+ return data.get("languageCode", "ru")
55
  except requests.exceptions.RequestException as e:
56
  print(f"Ошибка определения языка: {e}")
57
+ return "ru"
58
 
59
+ def ru2tt(text: str) -> str:
60
  headers = {"Authorization": f"Api-Key {YANDEX_API_KEY}"}
61
+ payload = {"folderId": YANDEX_FOLDER_ID, "texts": [text], "sourceLanguageCode": "ru", "targetLanguageCode": "tt"}
62
  try:
63
  resp = requests.post(YANDEX_TRANSLATE_URL, headers=headers, json=payload, timeout=30)
64
  resp.raise_for_status()
65
+ return resp.json()["translations"][0]["text"]
 
66
  except requests.exceptions.RequestException as e:
67
  print(f"Ошибка перевода: {e}")
68
+ return f"Ошибка перевода: {text}"
 
 
 
 
 
 
69
 
70
  def render_prompt(messages: List[Dict[str, str]]) -> str:
71
+ # Ваша функция render_prompt остается без изменений
72
  if getattr(tok, "chat_template", None):
73
  try:
74
  return tok.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
75
+ except Exception: pass
 
76
  sys_text = ""
77
  turns = []
78
  for m in messages:
79
+ if m["role"] == "system": sys_text += m["content"].strip() + "\n"
 
80
  i = 0
81
  while i < len(messages):
82
  m = messages[i]
 
84
  next_assistant = None
85
  if i + 1 < len(messages) and messages[i + 1]["role"] == "assistant":
86
  next_assistant = messages[i + 1]["content"]
87
+ user_block = f"<<SYS>>\n{sys_text.strip()}\n<</SYS>>\n\n{m['content']}" if len(turns) == 0 and sys_text else m['content']
 
 
 
88
  if next_assistant is None:
89
  turns.append(f"<s>[INST] {user_block} [/INST]")
90
  else:
91
  turns.append(f"<s>[INST] {user_block} [/INST] {next_assistant}</s>")
92
  i += 1
93
  i += 1
94
+ return "".join(turns) if turns else (f"<s>[INST] <<SYS>>\n{sys_text.strip()}\n<</SYS>>\n\n [/INST]" if sys_text else "<s>[INST] [/INST]")
 
 
95
 
96
+ # ❗ ИЗМЕНЕННАЯ ФУНКЦИЯ ГЕНЕРАЦИИ
97
  @torch.inference_mode()
98
+ def generate_tt_reply_stream(messages: List[Dict[str, str]]) -> Iterator[str]:
99
+ prompt = render_prompt(messages)
100
+ inputs = tok(prompt, return_tensors="pt").to(model.device)
101
+
102
+ # Создаем streamer
103
+ streamer = TextIteratorStreamer(tok, skip_prompt=True, skip_special_tokens=True)
104
+
105
+ # Аргументы для генерации
106
+ generation_kwargs = dict(
107
+ inputs,
108
+ streamer=streamer,
109
+ max_new_tokens=MAX_NEW_TOKENS,
110
+ do_sample=True,
111
+ temperature=TEMPERATURE,
112
+ top_p=TOP_P,
113
+ repetition_penalty=REPETITION_PENALTY,
114
+ eos_token_id=tok.eos_token_id,
115
+ pad_token_id=tok.pad_token_id,
116
+ )
117
+
118
+ # Запускаем генерацию в отдельном потоке
119
+ thread = Thread(target=model.generate, kwargs=generation_kwargs)
120
+ thread.start()
121
+
122
+ # Yield'им каждый новый кусочек текста
123
+ generated_text = ""
124
+ for new_text in streamer:
125
+ generated_text += new_text
126
+ yield generated_text
127
+
128
+ # --- 3. Gradio интерфейс (с изменениями для стриминга) ---
129
+
130
+ # ❗ ИЗМЕНЕННАЯ ФУНКЦИЯ-КОНТРОЛЛЕР
131
+ def chat_fn(message: str, history: list) -> Iterator[list]:
132
  # 1. Формируем историю для модели
133
  messages = [{"role": "system", "content": SYS_PROMPT_TT}]
134
  for user_msg, bot_msg in history:
 
136
  if bot_msg:
137
  messages.append({"role": "assistant", "content": bot_msg})
138
 
139
+ # 2. Определяем язык и переводим, если нужно
140
  detected_lang = detect_language(message)
141
+ user_tt = ru2tt(message) if detected_lang != "tt" else message
142
+
 
 
 
143
  messages.append({"role": "user", "content": user_tt})
144
 
145
+ # 3. Добавляем в историю сообщение пользователя и пустой ответ бота
146
+ history.append([user_tt, ""])
 
 
 
147
 
148
+ # 4. Стримим ответ модели и обновляем историю на лету
149
+ for partial_response in generate_tt_reply_stream(messages):
150
+ history[-1][1] = partial_response # Обновляем последнее сообщение в истории
151
+ yield history # Возвращаем всю историю на каждом шаге
152
 
153
+ # Создаем интерфейс
154
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
155
  gr.Markdown("## Татарский чат-бот от команды Сбера")
156
+ chatbot = gr.Chatbot(label="Диалог", height=500, bubble_full_width=False)
157
+ msg = gr.Textbox(label="Хәбәрегезне рус яки татар телендә языгыз", placeholder="Татарстанның башкаласы нинди шәһәр? / Какая столица Татарстана?")
158
  clear = gr.Button("🗑️ Чистарту")
159
 
160
  msg.submit(chat_fn, inputs=[msg, chatbot], outputs=chatbot)