Spaces:
Sleeping
Sleeping
Commit
·
5139a47
1
Parent(s):
227a9e0
another attempt at RT speedup for L4
Browse files
app.py
CHANGED
|
@@ -1,3 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
from magenta_rt import system, audio as au
|
| 2 |
import numpy as np
|
| 3 |
from fastapi import FastAPI, UploadFile, File, Form, Body, HTTPException, Response, Request, WebSocket, WebSocketDisconnect
|
|
@@ -15,7 +44,7 @@ from utils import (
|
|
| 15 |
|
| 16 |
from jam_worker import JamWorker, JamParams, JamChunk
|
| 17 |
import uuid, threading
|
| 18 |
-
|
| 19 |
import logging
|
| 20 |
|
| 21 |
import gradio as gr
|
|
@@ -25,25 +54,6 @@ from typing import Optional
|
|
| 25 |
import json, asyncio, base64
|
| 26 |
import time
|
| 27 |
|
| 28 |
-
# ---- Perf knobs (add at top of app.py) ----
|
| 29 |
-
os.environ.setdefault("JAX_PLATFORMS", "cuda") # prefer GPU
|
| 30 |
-
os.environ.setdefault("XLA_FLAGS",
|
| 31 |
-
"--xla_gpu_enable_triton_gemm=true "
|
| 32 |
-
"--xla_gpu_enable_latency_hiding_scheduler=true "
|
| 33 |
-
"--xla_gpu_autotune_level=2")
|
| 34 |
-
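# TF32 is enabled by default on Ampere/Ada for matmul. NOTE: NVIDIA_TF32_OVERRIDE=0 actually disables TF32 in NVIDIA libraries, so the default set below works against the stated intent of keeping it enabled: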
|
| 35 |
-
os.environ.setdefault("NVIDIA_TF32_OVERRIDE", "0")
|
| 36 |
-
|
| 37 |
-
import jax
|
| 38 |
-
jax.config.update("jax_default_matmul_precision", "fastest") # allow TF32
|
| 39 |
-
# Optional: persist XLA compile artifacts across restarts (saves warmup time)
|
| 40 |
-
try:
|
| 41 |
-
from jax.experimental.compilation_cache import compilation_cache as cc
|
| 42 |
-
cc.initialize_cache(os.environ.get("JAX_CACHE_DIR", "/home/appuser/.cache/jax"))
|
| 43 |
-
except Exception:
|
| 44 |
-
pass
|
| 45 |
-
# --------------------------------------------
|
| 46 |
-
|
| 47 |
|
| 48 |
|
| 49 |
from starlette.websockets import WebSocketState
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
# Useful XLA GPU optimizations (NOTE: XLA aborts at startup on flags it does not recognize, so keep this list in sync with the installed jaxlib)
|
| 3 |
+
os.environ.setdefault(
|
| 4 |
+
"XLA_FLAGS",
|
| 5 |
+
" ".join([
|
| 6 |
+
"--xla_gpu_enable_triton_gemm=true",
|
| 7 |
+
"--xla_gpu_enable_latency_hiding_scheduler=true",
|
| 8 |
+
"--xla_gpu_autotune_level=2",
|
| 9 |
+
])
|
| 10 |
+
)
|
| 11 |
+
|
| 12 |
+
# Optional: persist JAX compile cache across restarts (reduces warmup time)
|
| 13 |
+
os.environ.setdefault("JAX_CACHE_DIR", "/home/appuser/.cache/jax")
|
| 14 |
+
|
| 15 |
+
import jax
|
| 16 |
+
# ✅ Valid choices include: "default", "high", "highest", "tensorfloat32", "float32", etc.
|
| 17 |
+
# TF32 is the sweet spot on Ampere/Ada GPUs for ~1.1–1.3× matmul speedups.
|
| 18 |
+
jax.config.update("jax_default_matmul_precision", "tensorfloat32")
|
| 19 |
+
|
| 20 |
+
# Initialize the on-disk compilation cache (best-effort)
|
| 21 |
+
try:
|
| 22 |
+
from jax.experimental.compilation_cache import compilation_cache as cc
|
| 23 |
+
cc.initialize_cache(os.environ["JAX_CACHE_DIR"])
|
| 24 |
+
except Exception:
|
| 25 |
+
pass
|
| 26 |
+
# --------------------------------------------------------------------
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
|
| 30 |
from magenta_rt import system, audio as au
|
| 31 |
import numpy as np
|
| 32 |
from fastapi import FastAPI, UploadFile, File, Form, Body, HTTPException, Response, Request, WebSocket, WebSocketDisconnect
|
|
|
|
| 44 |
|
| 45 |
from jam_worker import JamWorker, JamParams, JamChunk
|
| 46 |
import uuid, threading
|
| 47 |
+
|
| 48 |
import logging
|
| 49 |
|
| 50 |
import gradio as gr
|
|
|
|
| 54 |
import json, asyncio, base64
|
| 55 |
import time
|
| 56 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 57 |
|
| 58 |
|
| 59 |
from starlette.websockets import WebSocketState
|