Build error

import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, StopStringCriteria, StoppingCriteriaList
import torch

# Load the tokenizer and model
repo_name = "nvidia/Hymba-1.5B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(repo_name, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(repo_name, trust_remote_code=True)

# Move the model to GPU with float16 precision for efficiency
model = model.to("cuda").to(torch.float16)

# Initialize the conversation history
messages = [
    {"role": "system", "content": "You are a helpful assistant."}
]

# Define stopping criteria
stopping_criteria = StoppingCriteriaList([StopStringCriteria(tokenizer=tokenizer, stop_strings=["</s>"])])
# Chat function for Gradio interface
def chat_function(user_input):
    # Add user message to the conversation history
    messages.append({"role": "user", "content": user_input})

    # Apply the chat template and tokenize the conversation
    # (a list of role/content dicts cannot be passed to tokenizer() directly)
    tokenized_chat = tokenizer.apply_chat_template(
        messages, tokenize=True, add_generation_prompt=True, return_tensors="pt"
    ).to("cuda")

    # Generate a response
    outputs = model.generate(
        tokenized_chat,
        max_new_tokens=256,
        do_sample=False,  # greedy decoding; temperature only has an effect when do_sample=True
        use_cache=True,
        stopping_criteria=stopping_criteria
    )

    # Decode only the newly generated tokens, skipping the prompt
    response = tokenizer.decode(outputs[0][tokenized_chat.shape[1]:], skip_special_tokens=True)

    # Add the assistant's response to the conversation history
    messages.append({"role": "assistant", "content": response})
    return response
# Set up the Gradio interface
# (the gr.inputs / gr.outputs namespaces are deprecated and removed in Gradio 4;
#  live=True is dropped so the model is not called on every keystroke)
iface = gr.Interface(
    fn=chat_function,
    inputs=gr.Textbox(label="Your message", placeholder="Enter your message here..."),
    outputs=gr.Textbox(label="Response"),
    title="Hymba Chatbot",
    description="Chat with the Hymba-1.5B-Instruct model!"
)

# Launch the Gradio interface
iface.launch()
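
If the Space is built against a recent Gradio release, gr.ChatInterface gives a proper chat UI and manages the conversation history for you, instead of keeping a module-level messages list. Below is a minimal sketch under that assumption; chat_fn is a hypothetical name, and it reuses the tokenizer, model, and stopping_criteria defined above. It would replace the gr.Interface block, not run alongside it.

def chat_fn(message, history):
    # Rebuild the chat from the UI-provided history (list of [user, assistant] pairs
    # in Gradio 4.x's default ChatInterface format).
    chat = [{"role": "system", "content": "You are a helpful assistant."}]
    for user_msg, bot_msg in history:
        chat.append({"role": "user", "content": user_msg})
        chat.append({"role": "assistant", "content": bot_msg})
    chat.append({"role": "user", "content": message})

    input_ids = tokenizer.apply_chat_template(
        chat, tokenize=True, add_generation_prompt=True, return_tensors="pt"
    ).to("cuda")
    outputs = model.generate(
        input_ids,
        max_new_tokens=256,
        do_sample=False,
        use_cache=True,
        stopping_criteria=stopping_criteria
    )
    # Return only the newly generated tokens as text
    return tokenizer.decode(outputs[0][input_ids.shape[1]:], skip_special_tokens=True)

gr.ChatInterface(chat_fn, title="Hymba Chatbot", description="Chat with the Hymba-1.5B-Instruct model!").launch()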