# app.py
import os
import threading

import gradio as gr
from huggingface_hub import hf_hub_download
from llama_cpp import Llama
# -------------------------
# Config (change if needed)
# -------------------------
MODEL_REPO = os.environ.get("MODEL_REPO", "unsloth/gpt-oss-20b-GGUF")
MODEL_FILES = [
    "gpt-oss-20b-Q4_K_M.gguf",  # recommended (11.6 GB)
    "gpt-oss-20b-Q4_0.gguf",
    "gpt-oss-20b-Q4_K_S.gguf",
    "gpt-oss-20b-Q2_K.gguf",  # smaller, lower quality
]
MODEL_CACHE_DIR = os.environ.get("MODEL_CACHE_DIR", "/tmp/model-cache")
N_CTX = int(os.environ.get("N_CTX", 2048))  # safer default for CPU
N_THREADS = int(os.environ.get("N_THREADS", 2))
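# Hypothetical local run overriding the defaults (values are illustrative):
#   MODEL_REPO=unsloth/gpt-oss-20b-GGUF N_CTX=4096 N_THREADS=8 python app.py
# N_THREADS should roughly match the physical CPU cores available to the Space.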
# Globals & loader state
llm = None  # Llama instance once a quant loads successfully
loaded_model = None  # filename of the quant that loaded
loading = False
load_error = None
_load_lock = threading.Lock()
# -------------------------
# Loader (background)
# -------------------------
def load_model_background():
    """Try each quant in MODEL_FILES in order until one downloads and loads."""
    global llm, loaded_model, loading, load_error
    with _load_lock:
        if llm is not None or loading:
            return
        loading = True
        load_error = None
        last_exc = None
        for fname in MODEL_FILES:
            try:
                print(f"[loader] attempting: {fname}")
                path = hf_hub_download(repo_id=MODEL_REPO, filename=fname, cache_dir=MODEL_CACHE_DIR)
                print(f"[loader] downloaded: {path} -- instantiating Llama...")
                llm = Llama(
                    model_path=path,
                    n_ctx=N_CTX,
                    n_threads=N_THREADS,
                    n_gpu_layers=0,  # CPU-only
                    use_mmap=True,   # map the file rather than reading it all into RAM
                    use_mlock=False, # don't pin pages; Space memory is limited
                )
                loaded_model = fname
                loading = False
                print(f"[loader] SUCCESS: loaded {fname}")
                return
            except Exception as e:
                print(f"[loader] failed for {fname}: {e}")
                last_exc = e
        # all quants failed
        load_error = str(last_exc)
        loading = False
        print("[loader] ALL QUANTS FAILED. last error:", load_error)
# start background loader immediately (will print progress to logs)
_thread = threading.Thread(target=load_model_background, daemon=True)
_thread.start()
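# The daemon thread lets Gradio start serving immediately while the multi-GB
# download runs; requests made before it finishes are rejected politely below.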
# -------------------------
# Prompt builder
# -------------------------
MAX_HISTORY_TURNS = 6

def build_prompt_from_history(history, user_input, reasoning_level="high"):
    """Render the chat as a plain User/Assistant transcript.

    `history` is a list of dicts: {"role": "user"|"assistant", "content": ...}.
    Only the last MAX_HISTORY_TURNS user/assistant pairs are kept so the
    prompt stays within N_CTX.
    """
    sys_line = f"[SYSTEM] Reasoning: {reasoning_level}\n\n"
    pairs = []
    tmp = []
    for m in history:
        if m.get("role") == "user":
            tmp = [m.get("content"), ""]
        elif m.get("role") == "assistant":
            if tmp:
                tmp[1] = m.get("content")
                pairs.append(tuple(tmp))
                tmp = []
    pairs = pairs[-MAX_HISTORY_TURNS:]  # keep only the most recent turns
    chat_ctx = ""
    for u, a in pairs:
        chat_ctx += f"User: {u}\nAssistant: {a}\n"
    chat_ctx += f"User: {user_input}\nAssistant:"
    return sys_line + chat_ctx
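# Worked example: with history = [{"role": "user", "content": "Hi"},
# {"role": "assistant", "content": "Hello!"}] and user_input = "How are you?",
# the function returns:
#   [SYSTEM] Reasoning: high
#
#   User: Hi
#   Assistant: Hello!
#   User: How are you?
#   Assistant: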
# -------------------------
# Gen function
# -------------------------
def generate_reply(history, user_input, temperature, top_p, max_new_tokens, reasoning_level):
    if load_error:
        raise RuntimeError(f"Model load failed: {load_error}")
    if loading:
        raise RuntimeError("Model is currently downloading/loading. Check the Space logs, wait a minute, then retry.")
    if llm is None:
        raise RuntimeError("Model not loaded yet. Please wait or check logs.")
    prompt = build_prompt_from_history(history, user_input, reasoning_level)
    out = llm(
        prompt,
        max_tokens=int(max_new_tokens),
        temperature=float(temperature),
        top_p=float(top_p),
        stop=["\nUser:", "\nAssistant:"],  # cut off if the model starts a new turn
    )
    return out["choices"][0]["text"].strip()
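# A minimal streaming sketch (not wired in here): llama-cpp-python also accepts
# stream=True, yielding partial completions that a Gradio generator could emit:
#
#   for chunk in llm(prompt, max_tokens=256, stream=True):
#       piece = chunk["choices"][0]["text"]
#       ...  # append `piece` to the last assistant message and yield history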
# -------------------------
# Gradio UI
# -------------------------
with gr.Blocks(title="gpt-oss-20b GGUF (llama.cpp) -- CPU Space") as demo:
    gr.Markdown(
        "**gpt-oss-20b (GGUF)** -- recommended quant: **Q4_K_M**. The model loads "
        "in the background; watch the logs. If you see an 'invalid ggml type' error "
        "in the logs, see the troubleshooting notes below."
    )
    status = gr.Markdown(value="Model status: loading in background (check logs).")
    chatbot = gr.Chatbot(label="Chat", type="messages")
    with gr.Row():
        txt = gr.Textbox(placeholder="Type message...", lines=3)
        with gr.Column(scale=40):
            temperature = gr.Slider(0.0, 2.0, value=1.0, step=0.05, label="Temperature")
            top_p = gr.Slider(0.0, 1.0, value=0.95, step=0.01, label="Top-p")
            max_new_tokens = gr.Slider(16, 1024, value=256, step=8, label="Max new tokens")
            reasoning = gr.Dropdown(choices=["low", "medium", "high"], value="high", label="Reasoning level")
    refresh_btn = gr.Button("Refresh status")
    clear_btn = gr.Button("Clear chat")
    state = gr.State(value=[])  # chat history as a list of message dicts
    def refresh_status():
        s = "LOADING" if loading else ("LOADED: " + (loaded_model or "none"))
        if load_error:
            s += f"\n\n**ERROR:** {load_error}\n\nCommon cause: the GGUF requires a newer llama.cpp build (see troubleshooting)."
        return gr.update(value=f"**Model status:** {s}")
    def user_send(user_message, history, temperature, top_p, max_new_tokens, reasoning):
        # history is a list of {"role": ..., "content": ...} dicts held in gr.State
        history = history or []
        if not user_message or not user_message.strip():
            return history, history, ""
        # surface loader problems inline instead of raising
        if load_error:
            history.append({"role": "assistant", "content": f"Model failed to load:\n\n{load_error}\n\nSee the Space logs and the troubleshooting notes in the README."})
            return history, history, ""
        if loading or llm is None:
            history.append({"role": "assistant", "content": "Model is still loading in the background. Please wait ~1-3 minutes and retry. Check the Space logs for download progress."})
            return history, history, ""
        # generate
        try:
            reply = generate_reply(history, user_message, temperature, top_p, max_new_tokens, reasoning)
        except Exception as e:
            history.append({"role": "assistant", "content": f"Error generating: {e}"})
            return history, history, ""
        history.append({"role": "user", "content": user_message})
        history.append({"role": "assistant", "content": reply})
        return history, history, ""
    # user_send returns (history, history, "") so the stored state stays in
    # sync with the visible chat and the textbox is cleared after each send
    txt.submit(user_send, inputs=[txt, state, temperature, top_p, max_new_tokens, reasoning], outputs=[chatbot, state, txt])
    refresh_btn.click(fn=refresh_status, inputs=None, outputs=[status])
    # clearing must reset both the visible chat and the stored history
    clear_btn.click(fn=lambda: ([], []), inputs=None, outputs=[chatbot, state])
if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)