Spaces:
Runtime error
Runtime error
wondervictor
committed on
Commit
·
fc81a43
1
Parent(s):
6cd385f
update README
Browse files- condition/midas/midas/vit.py +33 -13
condition/midas/midas/vit.py
CHANGED
|
@@ -128,12 +128,32 @@ def _resize_pos_embed(self, posemb, gs_h, gs_w):
|
|
| 128 |
return posemb
|
| 129 |
|
| 130 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 131 |
def flat_forward_flex(model, x):
|
| 132 |
b, c, h, w = x.shape
|
| 133 |
|
| 134 |
-
pos_embed =
|
| 135 |
-
|
| 136 |
-
|
| 137 |
|
| 138 |
B = x.shape[0]
|
| 139 |
|
|
@@ -352,10 +372,10 @@ def _make_vit_b16_backbone(
|
|
| 352 |
|
| 353 |
# We inject this function into the VisionTransformer instances so that
|
| 354 |
# we can use it with interpolated position embeddings without modifying the library source.
|
| 355 |
-
pretrained.model.forward_flex = types.MethodType(forward_flex,
|
| 356 |
-
|
| 357 |
-
pretrained.model._resize_pos_embed = types.MethodType(
|
| 358 |
-
|
| 359 |
|
| 360 |
return pretrained
|
| 361 |
|
|
@@ -550,13 +570,13 @@ def _make_vit_b_rn50_backbone(
|
|
| 550 |
|
| 551 |
# We inject this function into the VisionTransformer instances so that
|
| 552 |
# we can use it with interpolated position embeddings without modifying the library source.
|
| 553 |
-
pretrained.model.forward_flex = types.MethodType(forward_flex,
|
| 554 |
-
|
| 555 |
|
| 556 |
-
# We inject this function into the VisionTransformer instances so that
|
| 557 |
-
# we can use it with interpolated position embeddings without modifying the library source.
|
| 558 |
-
pretrained.model._resize_pos_embed = types.MethodType(
|
| 559 |
-
|
| 560 |
|
| 561 |
return pretrained
|
| 562 |
|
|
|
|
| 128 |
return posemb
|
| 129 |
|
| 130 |
|
| 131 |
+
def _flat_resize_pos_embed(model, posemb, gs_h, gs_w):
|
| 132 |
+
posemb_tok, posemb_grid = (
|
| 133 |
+
posemb[:, :model.start_index],
|
| 134 |
+
posemb[0, model.start_index:],
|
| 135 |
+
)
|
| 136 |
+
|
| 137 |
+
gs_old = int(math.sqrt(len(posemb_grid)))
|
| 138 |
+
|
| 139 |
+
posemb_grid = posemb_grid.reshape(1, gs_old, gs_old,
|
| 140 |
+
-1).permute(0, 3, 1, 2)
|
| 141 |
+
posemb_grid = F.interpolate(posemb_grid,
|
| 142 |
+
size=(gs_h, gs_w),
|
| 143 |
+
mode="bilinear")
|
| 144 |
+
posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_h * gs_w, -1)
|
| 145 |
+
|
| 146 |
+
posemb = torch.cat([posemb_tok, posemb_grid], dim=1)
|
| 147 |
+
|
| 148 |
+
return posemb
|
| 149 |
+
|
| 150 |
+
|
| 151 |
def flat_forward_flex(model, x):
|
| 152 |
b, c, h, w = x.shape
|
| 153 |
|
| 154 |
+
pos_embed = _flat_resize_pos_embed(model, model.pos_embed,
|
| 155 |
+
h // model.patch_size[1],
|
| 156 |
+
w // model.patch_size[0])
|
| 157 |
|
| 158 |
B = x.shape[0]
|
| 159 |
|
|
|
|
| 372 |
|
| 373 |
# We inject this function into the VisionTransformer instances so that
|
| 374 |
# we can use it with interpolated position embeddings without modifying the library source.
|
| 375 |
+
# pretrained.model.forward_flex = types.MethodType(forward_flex,
|
| 376 |
+
# pretrained.model)
|
| 377 |
+
# pretrained.model._resize_pos_embed = types.MethodType(
|
| 378 |
+
# _resize_pos_embed, pretrained.model)
|
| 379 |
|
| 380 |
return pretrained
|
| 381 |
|
|
|
|
| 570 |
|
| 571 |
# We inject this function into the VisionTransformer instances so that
|
| 572 |
# we can use it with interpolated position embeddings without modifying the library source.
|
| 573 |
+
# pretrained.model.forward_flex = types.MethodType(forward_flex,
|
| 574 |
+
# pretrained.model)
|
| 575 |
|
| 576 |
+
# # We inject this function into the VisionTransformer instances so that
|
| 577 |
+
# # we can use it with interpolated position embeddings without modifying the library source.
|
| 578 |
+
# pretrained.model._resize_pos_embed = types.MethodType(
|
| 579 |
+
# _resize_pos_embed, pretrained.model)
|
| 580 |
|
| 581 |
return pretrained
|
| 582 |
|