|
|
import os, tempfile, json, traceback |
|
|
import gradio as gr |
|
|
from huggingface_hub import InferenceClient |
|
|
|
|
|
# Hub repository id passed as ``model=`` to the fal-ai text-to-video call and
# reported by the health endpoint.
MODEL_ID: str = "Wan-AI/Wan2.2-T2V-A14B-Diffusers"
|
|
|
|
|
def _pick_token(request: gr.Request | None):
    """Return the Hugging Face token to use for an inference call.

    Lookup order:
      1. The ``HF_TOKEN`` environment variable (e.g. a Space secret).
      2. An ``Authorization: Bearer <token>`` header on the incoming request.
      3. An ``x-hf-token`` header on the incoming request.

    Args:
        request: The Gradio request for the current call, or ``None`` when the
            function is invoked outside a request context.

    Returns:
        The token with surrounding whitespace removed, or ``None`` when no
        non-empty token was found in any source.
    """
    # A whitespace-only secret is treated as "not set" rather than sent as-is.
    env_token = (os.getenv("HF_TOKEN") or "").strip()
    if env_token:
        return env_token

    # Without a request there are no headers to inspect.
    if request is None:
        return None

    headers = request.headers

    auth = headers.get("authorization", "")
    if isinstance(auth, str):
        # partition() tolerates a missing credentials part ("Bearer" alone).
        scheme, _, credentials = auth.strip().partition(" ")
        credentials = credentials.strip()
        if scheme.lower() == "bearer" and credentials:
            return credentials

    fallback = headers.get("x-hf-token")
    if isinstance(fallback, str) and fallback.strip():
        return fallback.strip()

    return None
|
|
|
|
|
def _write_temp_mp4(data: bytes | bytearray) -> str:
    """Write *data* to a new ``.mp4`` temporary file and return its path.

    The file is created with ``delete=False`` so it survives this call (the
    caller hands the path to Gradio). If the write fails, the partially
    written file is closed and removed before the exception propagates.
    """
    tmp = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False)
    try:
        with tmp:  # closes (and flushes) the handle even if write() raises
            tmp.write(data)
    except BaseException:
        os.unlink(tmp.name)
        raise
    return tmp.name


def t2v(prompt: str, fps: int = 12, request: gr.Request | None = None):
    """Generate a video for *prompt* with ``MODEL_ID`` via the fal-ai provider.

    Args:
        prompt: Text description of the video. Must contain at least one
            non-whitespace character.
        fps: Value of the "FPS" slider. NOTE(review): this value is not
            forwarded to ``InferenceClient.text_to_video`` (which exposes no
            fps parameter), so it currently has no effect on the output.
        request: The Gradio request, used by ``_pick_token`` to find a
            caller-supplied token when ``HF_TOKEN`` is unset.

    Returns:
        Filesystem path of a temporary ``.mp4`` file containing the video.

    Raises:
        gr.Error: On any failure. The message is a JSON object with keys
            ``error``, ``traceback``, ``has_secret`` and ``model``.
    """
    try:
        # Reject empty prompts before spending a provider call on them.
        if not prompt or not prompt.strip():
            raise ValueError("Prompt must not be empty.")

        token = _pick_token(request)
        if not token:
            raise RuntimeError("No token found. Set Space secret HF_TOKEN or send Authorization: Bearer hf_...")

        client = InferenceClient(provider="fal-ai", api_key=token)
        video_bytes = client.text_to_video(prompt, model=MODEL_ID)

        # Defensive: accept a file-like result as well as raw bytes.
        if hasattr(video_bytes, "read"):
            video_bytes = video_bytes.read()
        if not isinstance(video_bytes, (bytes, bytearray)):
            raise RuntimeError(f"Unexpected provider return type: {type(video_bytes)}")

        return _write_temp_mp4(video_bytes)

    except Exception as e:
        # NOTE(review): the full traceback is sent to whoever called the
        # endpoint; on a public Space consider logging it server-side instead.
        err = {"error": str(e), "traceback": traceback.format_exc(), "has_secret": bool(os.getenv("HF_TOKEN")), "model": MODEL_ID}
        raise gr.Error(json.dumps(err, ensure_ascii=False)) from e
|
|
|
|
|
def health():
    """Return a liveness payload for the ``health`` endpoint.

    Keys: ``ok`` (always True), ``model`` (the configured ``MODEL_ID``) and
    ``has_secret`` (whether ``HF_TOKEN`` is set to a non-empty value).
    """
    status = dict(ok=True, model=MODEL_ID)
    status["has_secret"] = bool(os.getenv("HF_TOKEN"))
    return status
|
|
|
|
|
with gr.Blocks() as demo:
    # This file targets Gradio 5 (it uses `flagging_mode` and `ssr_mode`), where
    # named endpoints are served at /gradio_api/call/<api_name>; the legacy
    # /api/<api_name> routes no longer exist.
    gr.Markdown("### Wan2.2 T2V via fal-ai · HTTP API at `/gradio_api/call/predict` and `/gradio_api/call/health`")

    # Text-to-video form, also exposed as the "predict" API endpoint.
    # NOTE(review): t2v does not forward the FPS value to the provider.
    gr.Interface(
        fn=t2v,
        inputs=[gr.Textbox(label="Prompt"), gr.Slider(4, 24, value=12, step=1, label="FPS")],
        outputs=gr.Video(label="Video"),
        api_name="predict",
        title="Wan2.2 Text-to-Video",
        flagging_mode="never",
    )

    # Input-less status panel, also exposed as the "health" API endpoint.
    gr.Interface(fn=health, inputs=[], outputs="json", api_name="health")


if __name__ == "__main__":
    # Server-side rendering is explicitly disabled.
    demo.launch(ssr_mode=False)
|
|
|