Ryan-PR committed on
Commit
5dc3610
·
verified ·
1 Parent(s): 6a3d054

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +47 -38
app.py CHANGED
@@ -320,6 +320,7 @@ def clear_ref_clicks(ref_state):
320
 
321
 
322
  @spaces.GPU(duration=40)
 
323
  def track_video(n_frames, video_state):
324
  input_points = video_state["input_points"]
325
  input_labels = video_state["input_labels"]
@@ -345,48 +346,56 @@ def track_video(n_frames, video_state):
345
 
346
  sam2_checkpoint = "./sam2/SAM2-Video-Predictor/checkpoints/sam2_hiera_large.pt"
347
  config = "sam2_hiera_l.yaml"
348
- video_predictor_local = build_sam2_video_predictor(
349
- config, sam2_checkpoint, device="cuda"
350
- )
351
 
352
- inference_state = video_predictor_local.init_state(
353
- images=images_np / 255, device="cuda"
354
- )
 
355
 
356
- if len(torch.from_numpy(video_state["masks"][0]).shape) == 3:
357
- mask0 = torch.from_numpy(video_state["masks"][0])[:, :, 0]
358
- else:
359
- mask0 = torch.from_numpy(video_state["masks"][0])
360
 
361
- video_predictor_local.add_new_mask(
362
- inference_state=inference_state,
363
- frame_idx=0,
364
- obj_id=obj_id,
365
- mask=mask0,
366
- )
367
 
368
- output_frames = []
369
- mask_frames = []
370
- color = (
371
- np.array(COLOR_PALETTE[int(time.time()) % len(COLOR_PALETTE)], dtype=np.float32)
372
- / 255.0
373
- )
374
- color = color[None, None, :]
375
- for out_frame_idx, out_obj_ids, out_mask_logits in video_predictor_local.propagate_in_video(
376
- inference_state
377
- ):
378
- frame = images_np[out_frame_idx].astype(np.float32) / 255.0
379
- mask = np.zeros((H, W, 3), dtype=np.float32)
380
- for i, logit in enumerate(out_mask_logits):
381
- out_mask = logit.cpu().squeeze().detach().numpy()
382
- out_mask = (out_mask[:, :, None] > 0).astype(np.float32)
383
- mask += out_mask
384
- mask = np.clip(mask, 0, 1)
385
- mask = cv2.resize(mask, (W_, H_))
386
- mask_frames.append(mask)
387
- painted = (1 - mask * 0.5) * frame + mask * 0.5 * color
388
- painted = np.uint8(np.clip(painted * 255, 0, 255))
389
- output_frames.append(painted)
 
 
 
 
 
 
 
 
 
 
 
 
 
390
 
391
  video_file = f"/tmp/{time.time()}-{random.random()}-tracked_output.mp4"
392
  clip = ImageSequenceClip(output_frames, fps=15)
 
320
 
321
 
322
  @spaces.GPU(duration=40)
323
+ @torch.no_grad()
324
  def track_video(n_frames, video_state):
325
  input_points = video_state["input_points"]
326
  input_labels = video_state["input_labels"]
 
346
 
347
  sam2_checkpoint = "./sam2/SAM2-Video-Predictor/checkpoints/sam2_hiera_large.pt"
348
  config = "sam2_hiera_l.yaml"
 
 
 
349
 
350
+ with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
351
+ video_predictor_local = build_sam2_video_predictor(
352
+ config, sam2_checkpoint, device="cuda"
353
+ )
354
 
355
+ inference_state = video_predictor_local.init_state(
356
+ images=images_np / 255, device="cuda"
357
+ )
 
358
 
359
+ if len(torch.from_numpy(video_state["masks"][0]).shape) == 3:
360
+ mask0 = torch.from_numpy(video_state["masks"][0])[:, :, 0]
361
+ else:
362
+ mask0 = torch.from_numpy(video_state["masks"][0])
 
 
363
 
364
+ video_predictor_local.add_new_mask(
365
+ inference_state=inference_state,
366
+ frame_idx=0,
367
+ obj_id=obj_id,
368
+ mask=mask0,
369
+ )
370
+
371
+ output_frames = []
372
+ mask_frames = []
373
+ color = (
374
+ np.array(
375
+ COLOR_PALETTE[int(time.time()) % len(COLOR_PALETTE)],
376
+ dtype=np.float32,
377
+ )
378
+ / 255.0
379
+ )
380
+ color = color[None, None, :]
381
+
382
+ for out_frame_idx, out_obj_ids, out_mask_logits in video_predictor_local.propagate_in_video(
383
+ inference_state
384
+ ):
385
+ frame = images_np[out_frame_idx].astype(np.float32) / 255.0
386
+
387
+ mask = np.zeros((H, W, 3), dtype=np.float32)
388
+ for i, logit in enumerate(out_mask_logits):
389
+ out_mask = logit.cpu().squeeze().detach().numpy()
390
+ out_mask = (out_mask[:, :, None] > 0).astype(np.float32)
391
+ mask += out_mask
392
+
393
+ mask = np.clip(mask, 0, 1)
394
+ mask = cv2.resize(mask, (W_, H_))
395
+ mask_frames.append(mask)
396
+ painted = (1 - mask * 0.5) * frame + mask * 0.5 * color
397
+ painted = np.uint8(np.clip(painted * 255, 0, 255))
398
+ output_frames.append(painted)
399
 
400
  video_file = f"/tmp/{time.time()}-{random.random()}-tracked_output.mp4"
401
  clip = ImageSequenceClip(output_frames, fps=15)