Update total/Model.py
Browse files- total/Model.py +0 -13
total/Model.py
CHANGED
|
@@ -62,16 +62,3 @@ class Unet_TS(network.Network):
|
|
| 62 |
nb_class, in_channels = state_dict["Model"]["Unet_TS"]["Head.Conv.weight"].shape[:2]
|
| 63 |
self["Head"].add_module("Conv", torch.nn.Conv3d(in_channels = in_channels, out_channels = nb_class, kernel_size = 1, stride = 1, padding = 0))
|
| 64 |
super().load(state_dict, init, ema)
|
| 65 |
-
|
| 66 |
-
class Combine(Reduction):
|
| 67 |
-
|
| 68 |
-
def __init__(self):
|
| 69 |
-
pass
|
| 70 |
-
|
| 71 |
-
def __call__(self, tensors: list[torch.Tensor]) -> torch.Tensor:
|
| 72 |
-
fg_all = torch.cat([p[:, 1:, ...] for p in tensors], dim=1)
|
| 73 |
-
sum_fg = fg_all.sum(dim=1, keepdim=True)
|
| 74 |
-
bg = (1.0 - sum_fg).clamp(min=1e-6)
|
| 75 |
-
|
| 76 |
-
probs = torch.cat([bg, fg_all], dim=1)
|
| 77 |
-
return probs / probs.sum(dim=1, keepdim=True)
|
|
|
|
| 62 |
nb_class, in_channels = state_dict["Model"]["Unet_TS"]["Head.Conv.weight"].shape[:2]
|
| 63 |
self["Head"].add_module("Conv", torch.nn.Conv3d(in_channels = in_channels, out_channels = nb_class, kernel_size = 1, stride = 1, padding = 0))
|
| 64 |
super().load(state_dict, init, ema)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|