Update app.py
Browse files
app.py
CHANGED
|
@@ -346,11 +346,11 @@ def track_video(n_frames, video_state):
|
|
| 346 |
sam2_checkpoint = "./sam2/SAM2-Video-Predictor/checkpoints/sam2_hiera_large.pt"
|
| 347 |
config = "sam2_hiera_l.yaml"
|
| 348 |
video_predictor_local = build_sam2_video_predictor(
|
| 349 |
-
config, sam2_checkpoint, device=
|
| 350 |
)
|
| 351 |
|
| 352 |
inference_state = video_predictor_local.init_state(
|
| 353 |
-
images=images / 255, device=
|
| 354 |
)
|
| 355 |
|
| 356 |
if len(torch.from_numpy(video_state["masks"][0]).shape) == 3:
|
|
|
|
| 346 |
sam2_checkpoint = "./sam2/SAM2-Video-Predictor/checkpoints/sam2_hiera_large.pt"
|
| 347 |
config = "sam2_hiera_l.yaml"
|
| 348 |
video_predictor_local = build_sam2_video_predictor(
|
| 349 |
+
config, sam2_checkpoint, device="cuda"
|
| 350 |
)
|
| 351 |
|
| 352 |
inference_state = video_predictor_local.init_state(
|
| 353 |
+
images=images / 255, device="cuda"
|
| 354 |
)
|
| 355 |
|
| 356 |
if len(torch.from_numpy(video_state["masks"][0]).shape) == 3:
|