amine_dubs committed
Commit · ec41997
Parent(s): 050f2a9

changed model

Browse files:
- backend/main.py +25 -57
- static/script.js +25 -8
backend/main.py
CHANGED
@@ -87,8 +87,8 @@ def initialize_model():
     try:
         print(f"Initializing model and tokenizer (attempt {model_initialization_attempts})...")
 
-        # Use a …
-        model_name = "…
+        # Use a better translation model that handles multilingual tasks well
+        model_name = "facebook/nllb-200-distilled-600M"  # Better multilingual translation model
 
         # Check for available device - properly detect CPU/GPU
         device = "cpu"  # Default to CPU which is more reliable
@@ -101,7 +101,8 @@ def initialize_model():
         print(f"Loading tokenizer from {model_name}...")
         tokenizer = AutoTokenizer.from_pretrained(
             model_name,
-            cache_dir="/tmp/transformers_cache"
+            cache_dir="/tmp/transformers_cache",
+            use_fast=True  # Use faster tokenizer when possible
         )
         if tokenizer is None:
             print("Failed to load tokenizer")
@@ -130,7 +131,7 @@ def initialize_model():
     try:
         # Create the pipeline with explicit model and tokenizer
         translator = pipeline(
-            "…
+            "translation",
             model=model,
             tokenizer=tokenizer,
             device=0 if device == "cuda" else -1,  # Proper device mapping
@@ -142,7 +143,8 @@ def initialize_model():
         return False
 
     # Test the model with a simple translation to verify it works
-    …
+    # NLLB needs language codes in format like "eng_Latn" and "ara_Arab"
+    test_result = translator("hello", src_lang="eng_Latn", tgt_lang="ara_Arab", max_length=128)
     print(f"Model test result: {test_result}")
     if not test_result or not isinstance(test_result, list) or len(test_result) == 0:
         print("Model test failed: Invalid output format")
@@ -176,32 +178,25 @@ def translate_text(text, source_lang, target_lang):
         return use_fallback_translation(text, source_lang, target_lang)
 
     try:
-        # Prepare input with explicit instruction format for better results with …
-        …
-        …
-            input_text = f"You are a bilingual in {source_lang} and Arabic, a professional translator, translate this script from {source_lang} to Arabic MSA with cultural sensitivity and accuracy, with a focus on meaning and eloquence (Balagha), avoiding overly literal translations.: {text}"
-        else:
-            input_text = f"Translate from {source_lang} to {target_lang}: {text}"
+        # Prepare input with explicit instruction format for better results with NLLB
+        src_lang_code = f"{source_lang}_Latn" if source_lang != "ar" else f"{source_lang}_Arab"
+        tgt_lang_code = f"{target_lang}_Latn" if target_lang != "ar" else f"{target_lang}_Arab"
 
         # Use a more reliable timeout approach with concurrent.futures
         with concurrent.futures.ThreadPoolExecutor() as executor:
             future = executor.submit(
                 lambda: translator(
-                    …
-                    …
-                    …
-                    …
-                )[0]["…
+                    text,
+                    src_lang=src_lang_code,
+                    tgt_lang=tgt_lang_code,
+                    max_length=512
+                )[0]["translation_text"]
             )
 
             try:
                 # Set a reasonable timeout (15 seconds instead of 10)
                 result = future.result(timeout=15)
 
-                # Clean up result (remove any instruction preamble if present)
-                if ':' in result and len(result.split(':', 1)) > 1:
-                    result = result.split(':', 1)[1].strip()
-
                 return result
             except concurrent.futures.TimeoutError:
                 print(f"Model inference timed out after 15 seconds, falling back to online translation")
@@ -230,8 +225,8 @@ def check_and_reinitialize_model():
         return initialize_model()
 
     # Test the existing model with a simple translation
-    test_text = "…
-    result = translator(test_text, max_length=128)
+    test_text = "hello"
+    result = translator(test_text, src_lang="eng_Latn", tgt_lang="fra_Latn", max_length=128)
 
     # If we got a valid result, model is working fine
     if result and isinstance(result, list) and len(result) > 0:
@@ -388,51 +383,24 @@ async def translate_text_endpoint(request: TranslationRequest):
             raise Exception("Failed to initialize translation model")
 
         # Format the prompt for the model
-        …
-        …
-            "zh": "Chinese", "ja": "Japanese", "ko": "Korean", "ar": "Arabic",
-            "ru": "Russian", "pt": "Portuguese", "it": "Italian", "nl": "Dutch"
-        }
-
-        source_lang_name = lang_code_map.get(source_lang.lower(), source_lang)
-        target_lang_name = lang_code_map.get(target_lang.lower(), target_lang)
+        src_lang_code = f"{source_lang}_Latn" if source_lang != "ar" else f"{source_lang}_Arab"
+        tgt_lang_code = f"{target_lang}_Latn" if target_lang != "ar" else f"{target_lang}_Arab"
 
-        # Create a proper prompt for instruction-based models
-        prompt = f"Translate from {source_lang_name} to {target_lang_name}: {text}"
-        print(f"Using prompt: {prompt}")
-
-        # Check that translator is callable before proceeding
-        if not callable(translator):
-            print("[DEBUG] Translator is not callable, attempting to reinitialize")
-            success = initialize_model()
-            if not success or not callable(translator):
-                raise Exception("Translator is not callable after reinitialization")
         print("[DEBUG] Calling translator model...")
         # Use a thread pool to execute the translation with a timeout
        with concurrent.futures.ThreadPoolExecutor() as executor:
             future = executor.submit(
                 lambda: translator(
-                    …
-                    …
-                    …
-                    …
-                )
+                    text,
+                    src_lang=src_lang_code,
+                    tgt_lang=tgt_lang_code,
+                    max_length=512
+                )[0]["translation_text"]
             )
 
             try:
                 result = future.result(timeout=15)
-
-                if not result or not isinstance(result, list) or len(result) == 0:
-                    raise Exception(f"Invalid model output format: {result}")
-
-                translation_result = result[0]["generated_text"]
-
-                # Clean up the output - remove any prefix like "Translation:"
-                prefixes = ["Translation:", "Translation: ", f"{target_lang_name}:", f"{target_lang_name}: "]
-                for prefix in prefixes:
-                    if translation_result.startswith(prefix):
-                        translation_result = translation_result[len(prefix):].strip()
-
+                translation_result = result
                 print(f"Local model translation result: {translation_result}")
             except concurrent.futures.TimeoutError:
                 print("Translation timed out after 15 seconds")
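Note on the NLLB language codes introduced above (not part of the commit): NLLB-200 identifies languages by FLORES-200 codes, a three-letter ISO 639-3 code plus a script tag such as "eng_Latn" or "arb_Arab". Deriving codes with f"{source_lang}_Latn" therefore only yields valid codes if the app already passes three-letter codes; a two-letter code like "en" becomes "en_Latn", which the NLLB tokenizer does not recognize. A minimal sketch of an explicit lookup, assuming two-letter ISO 639-1 input; the ISO_TO_FLORES table and to_flores helper are hypothetical, not from this repo:

# Hypothetical helper: map ISO 639-1 codes to the FLORES-200 codes NLLB-200 expects.
from transformers import pipeline

ISO_TO_FLORES = {
    "en": "eng_Latn", "fr": "fra_Latn", "es": "spa_Latn", "de": "deu_Latn",
    "ar": "arb_Arab",  # Modern Standard Arabic in FLORES-200
    "zh": "zho_Hans", "ja": "jpn_Jpan", "ru": "rus_Cyrl",
}

def to_flores(code: str) -> str:
    # Fall back to English rather than passing an unknown code to the tokenizer.
    return ISO_TO_FLORES.get(code.lower(), "eng_Latn")

translator = pipeline("translation", model="facebook/nllb-200-distilled-600M")
out = translator("hello", src_lang=to_flores("en"), tgt_lang=to_flores("ar"), max_length=128)
print(out[0]["translation_text"])
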
static/script.js
CHANGED
@@ -56,15 +56,16 @@ document.addEventListener('DOMContentLoaded', () => {
         docLoadingIndicator.style.display = 'none';
     }
 
-    // …
+    // Fix the text form submission handler to use correct field IDs
     if (textForm) {
         textForm.addEventListener('submit', async (e) => {
             e.preventDefault();
             clearFeedback();
 
-            …
-            const …
-            const …
+            // Use correct field IDs matching the HTML
+            const sourceText = document.getElementById('text-input').value.trim();
+            const sourceLang = document.getElementById('source-lang-text').value;
+            const targetLang = document.getElementById('target-lang-text').value;
 
             if (!sourceText) {
                 displayError('Please enter text to translate');
@@ -72,8 +73,16 @@ document.addEventListener('DOMContentLoaded', () => {
             }
 
             try {
-                // Show loading state
-                document.getElementById('text-loading')…
+                // Show loading state (create it if missing)
+                let textLoading = document.getElementById('text-loading');
+                if (!textLoading) {
+                    textLoading = document.createElement('div');
+                    textLoading.id = 'text-loading';
+                    textLoading.className = 'loading-spinner';
+                    textLoading.innerHTML = 'Translating...';
+                    textForm.appendChild(textLoading);
+                }
+                textLoading.style.display = 'block';
 
                 // Log payload for debugging
                 console.log('Sending payload:', { text: sourceText, source_lang: sourceLang, target_lang: targetLang });
@@ -91,7 +100,7 @@ document.addEventListener('DOMContentLoaded', () => {
                 });
 
                 // Hide loading state
-                …
+                textLoading.style.display = 'none';
 
                 // Log response status
                 console.log('Response status:', response.status);
@@ -112,13 +121,21 @@ document.addEventListener('DOMContentLoaded', () => {
                     return;
                 }
 
+                if (!data.translated_text) {
+                    displayError('Translation returned empty text');
+                    return;
+                }
+
                 textOutput.textContent = data.translated_text;
                 textResultBox.style.display = 'block';
 
             } catch (error) {
                 console.error('Error:', error);
                 displayError('Network error or invalid response format');
-                …
+
+                // Hide loading if it exists
+                const textLoading = document.getElementById('text-loading');
+                if (textLoading) textLoading.style.display = 'none';
             }
         });
     }
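
For reference, the request/response contract the front end assumes after this change: the form posts JSON with text, source_lang, and target_lang (see the console.log payload above) and reads translated_text from the response. A minimal sketch with Python requests; the /translate path and port are assumptions, since the FastAPI route itself is not shown in this diff:

import requests

# Hypothetical endpoint URL; only the payload and response keys are taken from the diff.
resp = requests.post(
    "http://localhost:8000/translate",
    json={"text": "hello", "source_lang": "en", "target_lang": "ar"},
    timeout=30,
)
resp.raise_for_status()
print(resp.json()["translated_text"])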