import gc
from functools import partial
from typing import Optional

import gradio as gr
import torch
from langdetect import detect, LangDetectException
from transformers import MarianMTModel, MarianTokenizer

from utils import get_pytorch_device, spaces_gpu, get_torch_dtype

# Mapping from primary language code to the Helsinki-NLP "<lang> -> en" translation model.
# Languages that are not listed here are translated with the fallback model supplied by the
# caller (see get_translation_model).
LANGUAGE_TO_MODEL_MAP = {
    "fr": "Helsinki-NLP/opus-mt-fr-en",
    "de": "Helsinki-NLP/opus-mt-de-en",
    "es": "Helsinki-NLP/opus-mt-es-en",
    "it": "Helsinki-NLP/opus-mt-it-en",
    "pt": "Helsinki-NLP/opus-mt-pt-en",
    "ru": "Helsinki-NLP/opus-mt-ru-en",
    "zh": "Helsinki-NLP/opus-mt-zh-en",
    "ja": "Helsinki-NLP/opus-mt-ja-en",
    "ko": "Helsinki-NLP/opus-mt-ko-en",
    "ar": "Helsinki-NLP/opus-mt-ar-en",
    "nl": "Helsinki-NLP/opus-mt-nl-en",
    "pl": "Helsinki-NLP/opus-mt-pl-en",
    "tr": "Helsinki-NLP/opus-mt-tr-en",
    "vi": "Helsinki-NLP/opus-mt-vi-en",
    "hi": "Helsinki-NLP/opus-mt-hi-en",
    "cs": "Helsinki-NLP/opus-mt-cs-en",
    "sv": "Helsinki-NLP/opus-mt-sv-en",
    "fi": "Helsinki-NLP/opus-mt-fi-en",
    "uk": "Helsinki-NLP/opus-mt-uk-en",
    "ro": "Helsinki-NLP/opus-mt-ro-en",
    "th": "Helsinki-NLP/opus-mt-th-en",
}


def detect_language(text: str) -> str:
    """Detect the language of the input text using the langdetect library.

    langdetect is a Python port of Google's language-detection library and supports
    over 55 languages.

    This function never raises ``LangDetectException``: when detection fails the
    exception is caught and ``"en"`` is returned, so that callers treat the text as
    already being English and leave it untouched.

    Args:
        text: Input text to detect the language of.

    Returns:
        The language code reported by langdetect (e.g. "en", "fr", "de", "ko", "ja").
        Most codes are ISO 639-1, but langdetect reports Chinese with a region suffix
        ("zh-cn" / "zh-tw"). Returns "en" if the language cannot be detected.
    """
    try:
        return detect(text)
    except LangDetectException:
        # Detection failed (e.g. empty text or text without usable features).
        # Default to English; the translation logic returns English text unchanged.
        return "en"


def get_translation_model(language_code: str, fallback_model: str) -> Optional[str]:
    """Get the appropriate translation model for a given language code.

    Only the primary language subtag is used for the lookup, so region-qualified
    codes such as langdetect's "zh-cn" / "zh-tw" resolve to the "zh" entry instead
    of silently falling through to the fallback model. The lookup is
    case-insensitive and accepts "_" as well as "-" as the subtag separator.

    Args:
        language_code: Language code of the source text (e.g. "fr", "de", "zh-cn", "en").
        fallback_model: Model ID to use if no language-specific model is available.

    Returns:
        ``None`` if the text is already in English (no translation needed), the
        language-specific model ID if one is listed in ``LANGUAGE_TO_MODEL_MAP``,
        otherwise ``fallback_model``.
    """
    normalized_code = (language_code or "").strip().lower().replace("_", "-")
    primary_subtag = normalized_code.split("-", 1)[0]

    if primary_subtag == "en":
        return None  # Already in English

    return LANGUAGE_TO_MODEL_MAP.get(primary_subtag, fallback_model)


@spaces_gpu
def translate_to_english(fallback_translation_model: str, text: str) -> str:
    """Translate text to English using automatic language detection.

    First detects the source language using the langdetect library, then selects the
    appropriate translation model and translates the text to English using a local
    MarianMT model.

    Args:
        fallback_translation_model: Fallback translation model to use if no
            language-specific model is available.
        text: Input text to translate to English.

    Returns:
        The translated text in English, or the original text unchanged if it is
        detected as English (which includes the case where detection fails) or if
        no translation model could be selected.

    Note:
        - Uses safetensors for secure model loading.
        - The device and dtype come from ``get_pytorch_device()`` / ``get_torch_dtype()``.
        - The tokenizer is called with ``truncation=True``, so input longer than the
          model's maximum length is truncated rather than rejected.
        - The model and tensors are released after inference, also when inference fails.
    """
    # Detect the language and pick the matching model; None means "nothing to translate".
    detected_lang = detect_language(text)
    translation_model = get_translation_model(detected_lang, fallback_translation_model)
    if translation_model is None:
        return text

    # Load model and tokenizer
    pytorch_device = get_pytorch_device()
    dtype = get_torch_dtype()

    tokenizer = MarianTokenizer.from_pretrained(translation_model)
    model = MarianMTModel.from_pretrained(
        translation_model, use_safetensors=True, dtype=dtype
    ).to(pytorch_device)

    # Pre-bind so the cleanup in ``finally`` can delete both names unconditionally.
    inputs = None
    translated = None
    try:
        # Tokenize and translate
        inputs = tokenizer(
            [text], return_tensors="pt", padding=True, truncation=True
        ).to(pytorch_device)

        # During inference, gradient calculations are unnecessary. torch.no_grad()
        # reduces memory consumption by not storing gradients.
        with torch.no_grad():
            translated = model.generate(**inputs)

        translation = tokenizer.batch_decode(translated, skip_special_tokens=True)[0]
    finally:
        # Release the model and tensors even if tokenization or generation raised.
        del model, tokenizer, inputs, translated
        # Collect first so memory held by unreachable objects is back in the allocator
        # cache before the cache is emptied.
        gc.collect()
        # NOTE(review): get_pytorch_device() is assumed to return either a device string
        # or a torch.device; comparing the string form covers "cuda", "cuda:0" and both
        # types. Confirm against utils.get_pytorch_device.
        if str(pytorch_device).startswith("cuda"):
            torch.cuda.empty_cache()

    return translation


def create_translation_tab(fallback_translation_model: str) -> None:
    """Create the translation to English tab in the Gradio interface.

    Sets up the UI components for translation with automatic language detection:
    an input textbox, a translate button, and a read-only output textbox. Clicking
    the button calls ``translate_to_english`` with the fallback model pre-bound.

    Args:
        fallback_translation_model: Fallback translation model to use if no
            language-specific model is available.
    """
    gr.Markdown("Translate text to English. The source language will be automatically detected.")
    translation_input = gr.Textbox(label="Input Text", lines=5)
    translation_button = gr.Button("Translate")
    translation_output = gr.Textbox(label="Translated Text", lines=5, interactive=False)
    translation_button.click(
        fn=partial(translate_to_english, fallback_translation_model),
        inputs=translation_input,
        outputs=translation_output,
    )