Update app.py
app.py
CHANGED
@@ -320,6 +320,7 @@ def clear_ref_clicks(ref_state):
 

 @spaces.GPU(duration=40)
+@torch.no_grad()
 def track_video(n_frames, video_state):
     input_points = video_state["input_points"]
     input_labels = video_state["input_labels"]

@@ -345,48 +346,56 @@ def track_video(n_frames, video_state):
 
     sam2_checkpoint = "./sam2/SAM2-Video-Predictor/checkpoints/sam2_hiera_large.pt"
     config = "sam2_hiera_l.yaml"
-    video_predictor_local = build_sam2_video_predictor(
-        config, sam2_checkpoint, device="cuda"
-    )
-    mask0 = torch.from_numpy(video_state["masks"][0])
-        mask=mask0,
-    )
+
+    with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
+        video_predictor_local = build_sam2_video_predictor(
+            config, sam2_checkpoint, device="cuda"
+        )
+
+        inference_state = video_predictor_local.init_state(
+            images=images_np / 255, device="cuda"
+        )
+
+        if len(torch.from_numpy(video_state["masks"][0]).shape) == 3:
+            mask0 = torch.from_numpy(video_state["masks"][0])[:, :, 0]
+        else:
+            mask0 = torch.from_numpy(video_state["masks"][0])
+
+        video_predictor_local.add_new_mask(
+            inference_state=inference_state,
+            frame_idx=0,
+            obj_id=obj_id,
+            mask=mask0,
+        )
+
+        output_frames = []
+        mask_frames = []
+        color = (
+            np.array(
+                COLOR_PALETTE[int(time.time()) % len(COLOR_PALETTE)],
+                dtype=np.float32,
+            )
+            / 255.0
+        )
+        color = color[None, None, :]
+
+        for out_frame_idx, out_obj_ids, out_mask_logits in video_predictor_local.propagate_in_video(
+            inference_state
+        ):
+            frame = images_np[out_frame_idx].astype(np.float32) / 255.0
+
+            mask = np.zeros((H, W, 3), dtype=np.float32)
+            for i, logit in enumerate(out_mask_logits):
+                out_mask = logit.cpu().squeeze().detach().numpy()
+                out_mask = (out_mask[:, :, None] > 0).astype(np.float32)
+                mask += out_mask
+
+            mask = np.clip(mask, 0, 1)
+            mask = cv2.resize(mask, (W_, H_))
+            mask_frames.append(mask)
+            painted = (1 - mask * 0.5) * frame + mask * 0.5 * color
+            painted = np.uint8(np.clip(painted * 255, 0, 255))
+            output_frames.append(painted)
 
     video_file = f"/tmp/{time.time()}-{random.random()}-tracked_output.mp4"
     clip = ImageSequenceClip(output_frames, fps=15)
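For reference, the new code follows one inference pattern: build the SAM2 video predictor under torch.no_grad() and bfloat16 autocast, seed frame 0 with a mask, then propagate and alpha-blend the predicted masks over the frames. Below is a minimal sketch of that pattern outside the Gradio app. It assumes the same patched predictor used by this Space, whose init_state() accepts an in-memory frame array (upstream SAM2's init_state() takes a video_path instead), and it reuses the checkpoint/config paths from the diff; the track_object helper, the fixed obj_id=1, and the green overlay color are illustrative placeholders, not part of the commit.

import numpy as np
import torch
from sam2.build_sam import build_sam2_video_predictor

SAM2_CHECKPOINT = "./sam2/SAM2-Video-Predictor/checkpoints/sam2_hiera_large.pt"
SAM2_CONFIG = "sam2_hiera_l.yaml"

@torch.no_grad()
def track_object(frames, first_mask):
    """frames: (T, H, W, 3) uint8 array; first_mask: (H, W) 0/1 mask for frame 0.
    Returns a list of uint8 frames with the tracked object tinted by the overlay color."""
    color = np.array([0.0, 1.0, 0.0], dtype=np.float32)[None, None, :]  # placeholder overlay color

    # Same setup as the commit: bfloat16 autocast, gradients disabled, inference only.
    with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
        predictor = build_sam2_video_predictor(SAM2_CONFIG, SAM2_CHECKPOINT, device="cuda")
        # Patched API assumed here; upstream SAM2 init_state() takes video_path instead.
        state = predictor.init_state(images=frames / 255, device="cuda")

        # Seed tracking with the frame-0 mask; SAM2 expects a 2-D (H, W) mask.
        mask0 = torch.from_numpy(first_mask)
        if mask0.ndim == 3:
            mask0 = mask0[:, :, 0]
        predictor.add_new_mask(inference_state=state, frame_idx=0, obj_id=1, mask=mask0)

        painted_frames = []
        # propagate_in_video yields (frame_idx, obj_ids, mask_logits) for each frame.
        for frame_idx, obj_ids, mask_logits in predictor.propagate_in_video(state):
            frame = frames[frame_idx].astype(np.float32) / 255.0
            # Union of all tracked objects, thresholded at logit > 0.
            mask = np.zeros(frame.shape[:2] + (1,), dtype=np.float32)
            for logit in mask_logits:
                mask += (logit.squeeze().float().cpu().numpy()[:, :, None] > 0).astype(np.float32)
            mask = np.clip(mask, 0.0, 1.0)
            # 50% alpha blend of the overlay color inside the mask.
            painted = (1 - mask * 0.5) * frame + mask * 0.5 * color
            painted_frames.append(np.uint8(np.clip(painted * 255, 0, 255)))
    return painted_frames

Running the whole handler under no_grad plus bfloat16 autocast avoids building autograd graphs and keeps activations in half-width bfloat16, which helps the call stay within the 40-second @spaces.GPU allocation.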