VBoussot commited on
Commit
1e9e15a
·
verified ·
1 Parent(s): c51038e

Update total/Model.py

Browse files
Files changed (1) hide show
  1. total/Model.py +0 -13
total/Model.py CHANGED
@@ -62,16 +62,3 @@ class Unet_TS(network.Network):
62
  nb_class, in_channels = state_dict["Model"]["Unet_TS"]["Head.Conv.weight"].shape[:2]
63
  self["Head"].add_module("Conv", torch.nn.Conv3d(in_channels = in_channels, out_channels = nb_class, kernel_size = 1, stride = 1, padding = 0))
64
  super().load(state_dict, init, ema)
65
-
66
- class Combine(Reduction):
67
-
68
- def __init__(self):
69
- pass
70
-
71
- def __call__(self, tensors: list[torch.Tensor]) -> torch.Tensor:
72
- fg_all = torch.cat([p[:, 1:, ...] for p in tensors], dim=1)
73
- sum_fg = fg_all.sum(dim=1, keepdim=True)
74
- bg = (1.0 - sum_fg).clamp(min=1e-6)
75
-
76
- probs = torch.cat([bg, fg_all], dim=1)
77
- return probs / probs.sum(dim=1, keepdim=True)
 
62
  nb_class, in_channels = state_dict["Model"]["Unet_TS"]["Head.Conv.weight"].shape[:2]
63
  self["Head"].add_module("Conv", torch.nn.Conv3d(in_channels = in_channels, out_channels = nb_class, kernel_size = 1, stride = 1, padding = 0))
64
  super().load(state_dict, init, ema)