import os

import gradio as gr
import torch
from transformers import AutoTokenizer, AutoConfig
from model import SmolLM2
# 1. Setup and Loading
device = "cuda" if torch.cuda.is_available() else "cpu"
checkpoint_path = "checkpoint_5050.pt"
tokenizer_path = "./custom_tokenizer"
print(f"Using device: {device}")

# Load Tokenizer
if os.path.exists(tokenizer_path):
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
else:
    # Fall back to the base SmolLM2 tokenizer if the custom one is missing
    tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM2-135M")

if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
# Load Model
config = AutoConfig.from_pretrained("HuggingFaceTB/SmolLM2-135M")
config.vocab_size = len(tokenizer) # Sync vocab size
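# A custom BPE tokenizer can differ in size from the base config's default, and
# the checkpoint's embedding/output shapes must match len(tokenizer) to load.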
model = SmolLM2(config).to(device)

# Load Checkpoint
if os.path.exists(checkpoint_path):
    print(f"Loading checkpoint from {checkpoint_path}")
    # weights_only=True restricts torch.load to tensors (no arbitrary unpickling)
    state_dict = torch.load(checkpoint_path, map_location=device, weights_only=True)
    model.load_state_dict(state_dict)
else:
    print("Checkpoint not found! Using random weights.")

# Inference mode: disables any dropout, so randomness comes only from the sampler
model.eval()

# 2. Generation Function
def generate(prompt, max_new_tokens, temperature, top_k, top_p, repetition_penalty):
    # Gradio sliders deliver floats, so cast the integer-valued settings up front
    max_new_tokens = int(max_new_tokens)
    top_k = int(top_k)
    temperature = float(temperature)
    top_p = float(top_p)
    repetition_penalty = float(repetition_penalty)

    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    input_ids = inputs.input_ids
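
    # Per-step sampling pipeline (standard order, matching the loop below):
    #   logits -> repetition penalty -> temperature -> top-k mask -> top-p mask
    #   -> softmax -> multinomial draw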
    # Generate token-by-token: model(generated_ids) returns raw logits of shape
    # (batch, seq_len, vocab_size), so decoding is implemented manually rather
    # than via transformers' generate().
    with torch.no_grad():
        generated_ids = input_ids
        for _ in range(max_new_tokens):
            outputs = model(generated_ids)
            next_token_logits = outputs[:, -1, :]
            # Repetition Penalty: scale down logits of tokens already generated
            if repetition_penalty != 1.0:
                for i in range(generated_ids.shape[0]):
                    for previous_token in set(generated_ids[i].tolist()):
                        if next_token_logits[i, previous_token] < 0:
                            next_token_logits[i, previous_token] *= repetition_penalty
                        else:
                            next_token_logits[i, previous_token] /= repetition_penalty
            # Temperature
            next_token_logits = next_token_logits / temperature
            # Top-K: mask everything below the k-th largest logit
            if top_k > 0:
                indices_to_remove = next_token_logits < torch.topk(next_token_logits, top_k)[0][..., -1, None]
                next_token_logits[indices_to_remove] = float('-inf')
            # Top-P (Nucleus Sampling): keep the smallest prefix of sorted tokens
            # whose cumulative probability exceeds top_p
            if top_p < 1.0:
                sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True)
                cumulative_probs = torch.softmax(sorted_logits, dim=-1).cumsum(dim=-1)
                sorted_indices_to_remove = cumulative_probs > top_p
                # Shift right so the first token crossing the threshold is kept
                sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
                sorted_indices_to_remove[..., 0] = 0
                indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
                next_token_logits[indices_to_remove] = float('-inf')
            probs = torch.nn.functional.softmax(next_token_logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
            generated_ids = torch.cat([generated_ids, next_token], dim=1)
            if next_token.item() == tokenizer.eos_token_id:
                break
    return tokenizer.decode(generated_ids[0], skip_special_tokens=True)
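
# Quick sanity check (hypothetical settings; uncomment to exercise generate()
# end-to-end before the UI launches):
# print(generate("Once upon a time", 40, 0.8, 40, 0.9, 1.2))
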
# 3. Gradio UI - Redesigned
print(f"Gradio Version: {gr.__version__}")
with gr.Blocks() as demo:
    gr.Markdown(
        """
        # 🌌 SmolLM2-135M Playground
        ### A custom 135M parameter model trained from scratch.
        """
    )
    with gr.Row():
        # Sidebar for Settings
        with gr.Column(scale=1, variant="panel"):
            gr.Markdown("### ⚙️ Generation Settings")
            gr.Markdown("Adjust these parameters to control the creativity and length of the generated text.")
            max_new_tokens = gr.Slider(minimum=10, maximum=1024, value=150, step=10, label="Max New Tokens", info="Maximum number of tokens to generate.")
            temperature = gr.Slider(minimum=0.1, maximum=2.0, value=0.8, step=0.1, label="Temperature", info="Higher values give more random/creative output.")
            top_k = gr.Slider(minimum=0, maximum=100, value=40, step=1, label="Top-K", info="Sample only from the K most likely tokens (0 disables).")
            top_p = gr.Slider(minimum=0.1, maximum=1.0, value=0.9, step=0.05, label="Top-P", info="Nucleus sampling probability mass (1.0 disables).")
            repetition_penalty = gr.Slider(minimum=1.0, maximum=2.0, value=1.2, step=0.1, label="Repetition Penalty", info="Penalize repeated tokens (1.0 disables).")
        # Main Content Area
        with gr.Column(scale=3):
            prompt = gr.Textbox(
                label="Input Prompt",
                placeholder="Type your prompt here (e.g., 'First Citizen:')...",
                lines=5
            )
            gr.Examples(
                examples=[
                    ["First Citizen:"],
                    ["The meaning of life is"],
                    ["Once upon a time"],
                    ["To be or not to be"],
                    ["The quick brown fox"]
                ],
                inputs=prompt,
                label="Click on an example to load it:"
            )
            generate_btn = gr.Button("✨ Generate Text", variant="primary", size="lg")
            output = gr.Textbox(
                label="Generated Output",
                lines=12,
                interactive=False
            )
    # Footer / Info
    with gr.Accordion("ℹ️ Model Information", open=False):
        gr.Markdown(
            """
            * **Architecture**: SmolLM2 (Transformer with Grouped Query Attention)
            * **Parameters**: 135M
            * **Training Data**: Wikitext / Custom Dataset
            * **Tokenizer**: Custom BPE
            """
        )
    generate_btn.click(
        fn=generate,
        inputs=[prompt, max_new_tokens, temperature, top_k, top_p, repetition_penalty],
        outputs=output
    )

if __name__ == "__main__":
    demo.launch()