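"""Firebox app.py: ChatGPT prompt tester.

Downloads the fka/awesome-chatgpt-prompts dataset (cached locally as
prompts.csv), loads the LiquidAI/LFM2-2.6B model, and serves a Gradio UI
for running a selected prompt through the model.
"""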
import os
import asyncio
import threading
import pandas as pd
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import gradio as gr
# Dataset path and model name
PROMPTS_CSV = "prompts.csv"
MODEL_NAME = "LiquidAI/LFM2-2.6B"
# Check for dataset, download if missing
if not os.path.exists(PROMPTS_CSV):
    print("prompts.csv not found. Downloading dataset from Hugging Face...")
    dataset = load_dataset("fka/awesome-chatgpt-prompts", split="train")
    df = pd.DataFrame(dataset)
    df.to_csv(PROMPTS_CSV, index=False)
    print("Dataset saved to prompts.csv")
else:
    df = pd.read_csv(PROMPTS_CSV)

all_prompts = df['prompt'].tolist()
print(f"Total prompts available: {len(all_prompts)}")
# Load first 20 prompts for fast startup
fast_prompts = all_prompts[:20]
remaining_prompts = all_prompts[20:]
# Load tokenizer and model
print("Loading model...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
print(f"Model loaded on {device}")
# Async function to load remaining prompts
async def load_remaining_prompts():
    global fast_prompts
    print("Loading remaining prompts asynchronously...")
    await asyncio.sleep(1)  # simulate async loading
    fast_prompts.extend(remaining_prompts)
    print("All prompts loaded.")
# Function to generate response
def generate_response(prompt, max_tokens=100):
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    outputs = model.generate(**inputs, max_new_tokens=max_tokens)
    # Decode only the newly generated tokens, skipping the echoed prompt
    response = tokenizer.decode(outputs[0][inputs["input_ids"].shape[-1]:], skip_special_tokens=True)
    return response
# Gradio callback: run the selected prompt through the model
def chat_with_prompt(prompt_idx):
    # Dropdown values arrive as strings, so cast the index back to int
    prompt = fast_prompts[int(prompt_idx)]
    response = generate_response(prompt)
    return f"Prompt:\n{prompt}\n\nResponse:\n{response}"
with gr.Blocks() as demo:
    gr.Markdown("## ChatGPT Prompt Tester")
    prompt_dropdown = gr.Dropdown(choices=[str(i) for i in range(len(fast_prompts))], label="Select Prompt Index")
    output_text = gr.Textbox(label="Model Response", lines=15)
    prompt_dropdown.change(chat_with_prompt, inputs=prompt_dropdown, outputs=output_text)
# Load the remaining prompts in the background; asyncio.create_task needs a
# running event loop, so drive the coroutine from a daemon thread instead
threading.Thread(target=lambda: asyncio.run(load_remaining_prompts()), daemon=True).start()
# Launch Gradio
demo.launch(server_name="0.0.0.0", server_port=7860)