Spaces: Running on Zero
import gradio as gr
from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM, AutoTokenizer

from utils import get_pytorch_device, spaces_gpu

# Global chatbot instance (initialized once)
_chatbot = None
_tokenizer = None
_is_seq2seq = None


def get_chatbot(model: str):
    """Get or create the chatbot model instance.

    This function implements a singleton pattern to load and cache the chatbot
    model and tokenizer. It supports both causal language models (like GPT-style
    models) and sequence-to-sequence models (like BlenderBot). The model type
    is automatically detected from the model configuration.

    Args:
        model: Hugging Face model ID to use for the chatbot.

    Returns:
        Tuple containing:
        - Model: The loaded transformer model (AutoModelForCausalLM or AutoModelForSeq2SeqLM)
        - Tokenizer: The corresponding tokenizer
        - bool: Whether the model is a seq2seq model (True) or causal LM (False)

    Note:
        - Models are loaded with safetensors for secure loading.
        - Automatically selects the best available device (CUDA/XPU/MPS/CPU).
        - Sets pad_token to eos_token if pad_token is not configured.
        - Model is cached globally after first load for performance.
    """
    global _chatbot, _tokenizer, _is_seq2seq

    if _chatbot is None:
        device = get_pytorch_device()
        _tokenizer = AutoTokenizer.from_pretrained(model)

        # Try to determine model type and load accordingly.
        # Check the model config to see if it's seq2seq.
        try:
            from transformers import AutoConfig

            config = AutoConfig.from_pretrained(model)
            # Seq2seq models have encoder/decoder, causal LMs don't
            _is_seq2seq = hasattr(config, 'is_encoder_decoder') and config.is_encoder_decoder
        except Exception:
            # Default to causal LM (most modern chat models)
            _is_seq2seq = False

        if _is_seq2seq:
            _chatbot = AutoModelForSeq2SeqLM.from_pretrained(
                model,
                use_safetensors=True
            ).to(device)
        else:
            _chatbot = AutoModelForCausalLM.from_pretrained(
                model,
                use_safetensors=True
            ).to(device)

        # Set pad token if not set
        if _tokenizer.pad_token is None:
            _tokenizer.pad_token = _tokenizer.eos_token

    return _chatbot, _tokenizer, _is_seq2seq
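

# Illustrative use of get_chatbot (a sketch, not part of the original app).
# "facebook/blenderbot-400M-distill" is only an example model ID for an
# encoder-decoder model; the snippet is kept commented out so nothing loads
# at import time.
#
#   chatbot_model, chatbot_tokenizer, seq2seq = get_chatbot("facebook/blenderbot-400M-distill")
#   # seq2seq would be True for BlenderBot-style encoder-decoder models
#   # and False for causal LMs such as GPT-style chat models.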
def chat(model: str, message: str, conversation_history: list[dict] | None) -> tuple[str, list[dict]]:
    """Generate a chatbot response given a user message and conversation history.

    This function handles conversation with AI chatbots, supporting both modern
    chat models with chat templates (like Qwen, Mistral) and older models
    without templates (like BlenderBot). It manages conversation history and
    formats inputs appropriately based on the model type.

    Args:
        model: Hugging Face model ID to use for the chatbot.
        message: The user's current message as a string.
        conversation_history: Optional list of previous conversation messages.
            Each message is a dict with "role" ("user" or "assistant") and "content".
            If None, starts a new conversation.

    Returns:
        Tuple containing:
        - str: The assistant's response message
        - list[dict]: Updated conversation history including the new exchange

    Note:
        - Supports models with chat templates (uses apply_chat_template)
        - Falls back to manual formatting for models without templates
        - Handles both causal LM and seq2seq model architectures
        - Uses sampling with temperature=0.7 for varied responses
        - Generates up to 256 new tokens
        - Automatically manages conversation context and history
        - Extracts only newly generated text for causal LMs with chat templates
    """
    model_instance, tokenizer, is_seq2seq = get_chatbot(model)

    # Initialize conversation history if this is the first message
    if conversation_history is None:
        conversation_history = []

    # Add the user's message
    conversation_history.append({"role": "user", "content": message})

    device = get_pytorch_device()

    # Check if tokenizer has a chat template (modern chat models)
    use_chat_template = hasattr(tokenizer, 'chat_template') and tokenizer.chat_template is not None

    if use_chat_template:
        # Use chat template for modern chat models (Qwen, Mistral, etc.)
        try:
            formatted_input = tokenizer.apply_chat_template(
                conversation_history,
                tokenize=False,
                add_generation_prompt=True
            )
            inputs = tokenizer(formatted_input, return_tensors="pt", truncation=True).to(device)
        except Exception:
            use_chat_template = False

    if not use_chat_template:
        # For models without chat templates (BlenderBot, older models)
        if is_seq2seq:
            # Seq2seq format: "User: ...\nAssistant: ..."
            dialogue_text = ""
            for msg in conversation_history:
                if msg["role"] == "user":
                    dialogue_text += f"User: {msg['content']}\n"
                elif msg["role"] == "assistant":
                    dialogue_text += f"Assistant: {msg['content']}\n"
            inputs = tokenizer([dialogue_text], return_tensors="pt", truncation=True, max_length=512).to(device)
        else:
            # Causal LM format: just concatenate messages
            dialogue_text = ""
            for msg in conversation_history:
                if msg["role"] == "user":
                    dialogue_text += f"User: {msg['content']}\n\n"
                elif msg["role"] == "assistant":
                    dialogue_text += f"Assistant: {msg['content']}\n\n"
            dialogue_text += "Assistant:"
            inputs = tokenizer(dialogue_text, return_tensors="pt", truncation=True, max_length=1024).to(device)

    # Generate response
    outputs = model_instance.generate(
        **inputs,
        max_new_tokens=256,
        do_sample=True,
        temperature=0.7,
        pad_token_id=tokenizer.eos_token_id
    )

    # Decode the response
    if is_seq2seq:
        # For seq2seq, output is just the generated response
        response = tokenizer.decode(outputs[0], skip_special_tokens=True)
        # Clean up any "Assistant:" prefix
        if response.startswith("Assistant:"):
            response = response[len("Assistant:"):].strip()
    else:
        # For causal LMs, extract only the newly generated part
        if use_chat_template:
            # Extract only new tokens (generated part)
            input_length = inputs.input_ids.shape[1]
            generated_tokens = outputs[0][input_length:]
            response = tokenizer.decode(generated_tokens, skip_special_tokens=True)
        else:
            # Extract text after the prompt
            full_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
            response = full_text.split("Assistant:")[-1].strip()

    # Add the assistant's response to history
    conversation_history.append({"role": "assistant", "content": response})

    return response, conversation_history
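

# Illustrative call to chat() (a sketch; the model ID is only an example of a
# chat-template model and is not prescribed by this module):
#
#   reply, history = chat("Qwen/Qwen2.5-0.5B-Instruct", "Hello!", None)
#   # history now holds both turns in messages format:
#   #   [{"role": "user", "content": "Hello!"},
#   #    {"role": "assistant", "content": reply}]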
def create_chatbot_tab(model: str):
    """Create the chatbot tab in the Gradio interface.

    This function sets up all UI components for the conversational chatbot,
    including:
    - Chatbot component for displaying conversation history (using messages format)
    - Text input box for user messages
    - Send button and Enter key submission support

    It also wires up event handlers for both button clicks and Enter key presses.
    The conversation history uses Gradio's messages format (list of dicts with
    "role" and "content" keys), which matches the internal chatbot API format.

    Args:
        model: Hugging Face model ID to use for the chatbot.
    """
    gr.Markdown("Have a conversation with an AI chatbot.")
    chatbot_output = gr.Chatbot(label="Conversation", type="messages")
    chatbot_input = gr.Textbox(label="Your message")
    chatbot_send_button = gr.Button("Send")

    def chat_interface(message: str, history: list[dict] | None):
        """Handle chatbot interaction with Gradio messages format.

        This function handles chatbot interactions using Gradio's messages format,
        where each message is a dictionary with "role" and "content" keys.

        Args:
            message: The user's message string from the input box.
            history: Gradio's chat history in messages format (list of dicts with
                "role" and "content" keys). If None, starts a new conversation.

        Returns:
            Tuple containing:
            - Updated chat history in messages format
            - Empty string (to clear the input field)
        """
        if not message.strip():
            return history, ""

        # Use history directly as the conversation state, since both use the same format
        response, updated_conversation = chat(model, message, history)
        return updated_conversation, ""  # Return updated conversation history and clear input field

    chatbot_send_button.click(
        fn=chat_interface,
        inputs=[chatbot_input, chatbot_output],
        outputs=[chatbot_output, chatbot_input]
    )
    chatbot_input.submit(
        fn=chat_interface,
        inputs=[chatbot_input, chatbot_output],
        outputs=[chatbot_output, chatbot_input]
    )
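

# Illustrative wiring of create_chatbot_tab into a Gradio app (a sketch under
# the assumption that a separate top-level script builds the Blocks layout;
# the module name "chatbot_tab" and the model ID are hypothetical):
#
#   import gradio as gr
#   from chatbot_tab import create_chatbot_tab
#
#   with gr.Blocks() as demo:
#       with gr.Tab("Chatbot"):
#           create_chatbot_tab("Qwen/Qwen2.5-0.5B-Instruct")
#
#   demo.launch()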