Valentin Boussot
commited on
Commit
Β·
406b8bb
1
Parent(s):
7d4dc86
Refactor model structure to support KonfAI inference from Hugging Face repositories
Browse files- M298.pt +0 -3
- M730.pt +0 -3
- M731.pt +0 -3
- M732.pt +0 -3
- M733.pt +0 -3
- M853.pt +0 -3
- M297.pt β total-3mm/M297.pt +0 -0
- Model.py β total-3mm/Model.py +0 -0
- Prediction_CT_Fast.yml β total-3mm/Prediction.yml +0 -0
- total-3mm/metadata.json +7 -0
- total-3mm/requirements.txt +1 -0
- M291.pt β total/M291.pt +0 -0
- M292.pt β total/M292.pt +0 -0
- M293.pt β total/M293.pt +0 -0
- M294.pt β total/M294.pt +0 -0
- M295.pt β total/M295.pt +0 -0
- total/Model.py +78 -0
- Prediction_CT.yml β total/Prediction.yml +0 -0
- total/metadata.json +7 -0
- total/requirements.txt +1 -0
- M852.pt β total_mr-3mm/M852.pt +0 -0
- total_mr-3mm/Model.py +78 -0
- Prediction_MR_Fast.yml β total_mr-3mm/Prediction.yml +0 -0
- total_mr-3mm/metadata.json +7 -0
- total_mr-3mm/requirements.txt +1 -0
- M850.pt β total_mr/M850.pt +0 -0
- M851.pt β total_mr/M851.pt +0 -0
- total_mr/Model.py +78 -0
- Prediction_MR.yml β total_mr/Prediction.yml +0 -0
- total_mr/metadata.json +7 -0
- total_mr/requirements.txt +1 -0
M298.pt
DELETED
|
@@ -1,3 +0,0 @@
|
|
| 1 |
-
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:8f04169298561a6e1fcad186ef1a24623b48edf5f776710793de13a5a9c6d40e
|
| 3 |
-
size 66225317
|
|
|
|
|
|
|
|
|
|
|
|
M730.pt
DELETED
|
@@ -1,3 +0,0 @@
|
|
| 1 |
-
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:462c3c7ca2eb1a40a0984d5b1f004cfbf617c3e9218af65f442006e132fd594e
|
| 3 |
-
size 123170169
|
|
|
|
|
|
|
|
|
|
|
|
M731.pt
DELETED
|
@@ -1,3 +0,0 @@
|
|
| 1 |
-
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:f9998319e380221bc0e69dae2d0f120f7e9d7c38254dc982afd05aa9f44877b1
|
| 3 |
-
size 123169913
|
|
|
|
|
|
|
|
|
|
|
|
M732.pt
DELETED
|
@@ -1,3 +0,0 @@
|
|
| 1 |
-
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:ce4d777f8da2ebdeb2a4b19f976f8eac36bba252c5f63fccf76f9e3dd8bc296c
|
| 3 |
-
size 66217253
|
|
|
|
|
|
|
|
|
|
|
|
M733.pt
DELETED
|
@@ -1,3 +0,0 @@
|
|
| 1 |
-
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:3340e70c92bb9c095f1db8e139319efeb7fb8e444f04ffbd8421c70e657fc444
|
| 3 |
-
size 22435409
|
|
|
|
|
|
|
|
|
|
|
|
M853.pt
DELETED
|
@@ -1,3 +0,0 @@
|
|
| 1 |
-
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:024cadec3839c0b552fb07b4ba866073430c3856d06ef5a2e0bb2896d20e7596
|
| 3 |
-
size 22434641
|
|
|
|
|
|
|
|
|
|
|
|
M297.pt β total-3mm/M297.pt
RENAMED
|
File without changes
|
Model.py β total-3mm/Model.py
RENAMED
|
File without changes
|
Prediction_CT_Fast.yml β total-3mm/Prediction.yml
RENAMED
|
File without changes
|
total-3mm/metadata.json
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"display_name": "Total 3mm",
|
| 3 |
+
"short_description": "",
|
| 4 |
+
"description": "",
|
| 5 |
+
"tta": 0,
|
| 6 |
+
"mc_dropout": 0
|
| 7 |
+
}
|
total-3mm/requirements.txt
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
segmentation_models_pytorch
|
M291.pt β total/M291.pt
RENAMED
|
File without changes
|
M292.pt β total/M292.pt
RENAMED
|
File without changes
|
M293.pt β total/M293.pt
RENAMED
|
File without changes
|
M294.pt β total/M294.pt
RENAMED
|
File without changes
|
M295.pt β total/M295.pt
RENAMED
|
File without changes
|
total/Model.py
ADDED
|
@@ -0,0 +1,78 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from konfai.network import network, blocks
|
| 3 |
+
from konfai.predictor import Reduction
|
| 4 |
+
|
| 5 |
+
class ConvBlock(network.ModuleArgsDict):
|
| 6 |
+
def __init__(self, in_channels : int, out_channels : int, stride: int = 1 ) -> None:
|
| 7 |
+
super().__init__()
|
| 8 |
+
self.add_module("Conv_0", torch.nn.Conv3d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=stride, padding=1, bias=True))
|
| 9 |
+
self.add_module("Norm_0", torch.nn.InstanceNorm3d(num_features=out_channels, affine=True))
|
| 10 |
+
self.add_module("Activation_0", torch.nn.LeakyReLU(negative_slope=0.01, inplace=True))
|
| 11 |
+
self.add_module("Conv_1", torch.nn.Conv3d(in_channels=out_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1, bias=True))
|
| 12 |
+
self.add_module("Norm_1", torch.nn.InstanceNorm3d(num_features=out_channels, affine=True))
|
| 13 |
+
self.add_module("Activation_1", torch.nn.LeakyReLU(negative_slope=0.01, inplace=True))
|
| 14 |
+
|
| 15 |
+
class UNetHead(network.ModuleArgsDict):
|
| 16 |
+
def __init__(self, in_channels: int, nb_class: int) -> None:
|
| 17 |
+
super().__init__()
|
| 18 |
+
self.add_module("Conv", torch.nn.Conv3d(in_channels = in_channels, out_channels = nb_class, kernel_size = 1, stride = 1, padding = 0))
|
| 19 |
+
self.add_module("Softmax", torch.nn.Softmax(dim=1))
|
| 20 |
+
|
| 21 |
+
class UNetBlock(network.ModuleArgsDict):
|
| 22 |
+
|
| 23 |
+
def __init__(self, channels, mri: bool, i : int = 0) -> None:
|
| 24 |
+
super().__init__()
|
| 25 |
+
self.add_module("DownConvBlock", ConvBlock(in_channels=channels[0], out_channels=channels[1], stride= ((1,2,2) if mri and i > 4 else 2) if i>0 else 1))
|
| 26 |
+
|
| 27 |
+
if len(channels) > 2:
|
| 28 |
+
self.add_module("UNetBlock", UNetBlock(channels[1:], mri, i+1))
|
| 29 |
+
self.add_module("UpConvBlock", ConvBlock(in_channels=channels[1]*2, out_channels=channels[1]))
|
| 30 |
+
|
| 31 |
+
if i > 0:
|
| 32 |
+
self.add_module("CONV_TRANSPOSE", torch.nn.ConvTranspose3d(in_channels = channels[1], out_channels = channels[0], kernel_size = (1,2,2) if mri and i > 4 else 2, stride = (1,2,2) if mri and i > 4 else 2, padding = 0))
|
| 33 |
+
self.add_module("SkipConnection", blocks.Concat(), in_branch=[0, 1])
|
| 34 |
+
|
| 35 |
+
class Unet_TS(network.Network):
|
| 36 |
+
|
| 37 |
+
def __init__(self,
|
| 38 |
+
optimizer: network.OptimizerLoader = network.OptimizerLoader(),
|
| 39 |
+
schedulers: dict[str, network.LRSchedulersLoader] = {
|
| 40 |
+
"default:ReduceLROnPlateau": network.LRSchedulersLoader(0)
|
| 41 |
+
},
|
| 42 |
+
outputs_criterions: dict[str, network.TargetCriterionsLoader] = {"default": network.TargetCriterionsLoader()},
|
| 43 |
+
channels = [1, 32, 64, 128, 320, 320],
|
| 44 |
+
mri: bool = False) -> None:
|
| 45 |
+
super().__init__(
|
| 46 |
+
in_channels=channels[0],
|
| 47 |
+
optimizer=optimizer,
|
| 48 |
+
schedulers=schedulers,
|
| 49 |
+
outputs_criterions=outputs_criterions,
|
| 50 |
+
patch=None,
|
| 51 |
+
dim=3,
|
| 52 |
+
)
|
| 53 |
+
self.add_module("UNetBlock", UNetBlock(channels, mri))
|
| 54 |
+
self.add_module("Head", UNetHead(channels[1], 42))
|
| 55 |
+
|
| 56 |
+
def load(
|
| 57 |
+
self,
|
| 58 |
+
state_dict: dict[str, dict[str, torch.Tensor] | int],
|
| 59 |
+
init: bool = True,
|
| 60 |
+
ema: bool = False,
|
| 61 |
+
):
|
| 62 |
+
nb_class, in_channels = state_dict["Model"]["Unet_TS"]["Head.Conv.weight"].shape[:2]
|
| 63 |
+
self["Head"].add_module("Conv", torch.nn.Conv3d(in_channels = in_channels, out_channels = nb_class, kernel_size = 1, stride = 1, padding = 0))
|
| 64 |
+
super().load(state_dict, init, ema)
|
| 65 |
+
|
| 66 |
+
class Combine(Reduction):
|
| 67 |
+
|
| 68 |
+
def __init__(self):
|
| 69 |
+
pass
|
| 70 |
+
|
| 71 |
+
def __call__(self, tensors: list[torch.Tensor]) -> torch.Tensor:
|
| 72 |
+
fg_all = torch.cat([p[:, 1:, ...] for p in tensors], dim=1)
|
| 73 |
+
sum_fg = fg_all.sum(dim=1, keepdim=True)
|
| 74 |
+
bg = (1.0 - sum_fg).clamp(min=1e-6)
|
| 75 |
+
|
| 76 |
+
probs = torch.cat([bg, fg_all], dim=1)
|
| 77 |
+
|
| 78 |
+
return probs / probs.sum(dim=1, keepdim=True)
|
Prediction_CT.yml β total/Prediction.yml
RENAMED
|
File without changes
|
total/metadata.json
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"display_name": "total",
|
| 3 |
+
"short_description": "",
|
| 4 |
+
"description": "",
|
| 5 |
+
"tta": 0,
|
| 6 |
+
"mc_dropout": 0
|
| 7 |
+
}
|
total/requirements.txt
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
segmentation_models_pytorch
|
M852.pt β total_mr-3mm/M852.pt
RENAMED
|
File without changes
|
total_mr-3mm/Model.py
ADDED
|
@@ -0,0 +1,78 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from konfai.network import network, blocks
|
| 3 |
+
from konfai.predictor import Reduction
|
| 4 |
+
|
| 5 |
+
class ConvBlock(network.ModuleArgsDict):
|
| 6 |
+
def __init__(self, in_channels : int, out_channels : int, stride: int = 1 ) -> None:
|
| 7 |
+
super().__init__()
|
| 8 |
+
self.add_module("Conv_0", torch.nn.Conv3d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=stride, padding=1, bias=True))
|
| 9 |
+
self.add_module("Norm_0", torch.nn.InstanceNorm3d(num_features=out_channels, affine=True))
|
| 10 |
+
self.add_module("Activation_0", torch.nn.LeakyReLU(negative_slope=0.01, inplace=True))
|
| 11 |
+
self.add_module("Conv_1", torch.nn.Conv3d(in_channels=out_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1, bias=True))
|
| 12 |
+
self.add_module("Norm_1", torch.nn.InstanceNorm3d(num_features=out_channels, affine=True))
|
| 13 |
+
self.add_module("Activation_1", torch.nn.LeakyReLU(negative_slope=0.01, inplace=True))
|
| 14 |
+
|
| 15 |
+
class UNetHead(network.ModuleArgsDict):
|
| 16 |
+
def __init__(self, in_channels: int, nb_class: int) -> None:
|
| 17 |
+
super().__init__()
|
| 18 |
+
self.add_module("Conv", torch.nn.Conv3d(in_channels = in_channels, out_channels = nb_class, kernel_size = 1, stride = 1, padding = 0))
|
| 19 |
+
self.add_module("Softmax", torch.nn.Softmax(dim=1))
|
| 20 |
+
|
| 21 |
+
class UNetBlock(network.ModuleArgsDict):
|
| 22 |
+
|
| 23 |
+
def __init__(self, channels, mri: bool, i : int = 0) -> None:
|
| 24 |
+
super().__init__()
|
| 25 |
+
self.add_module("DownConvBlock", ConvBlock(in_channels=channels[0], out_channels=channels[1], stride= ((1,2,2) if mri and i > 4 else 2) if i>0 else 1))
|
| 26 |
+
|
| 27 |
+
if len(channels) > 2:
|
| 28 |
+
self.add_module("UNetBlock", UNetBlock(channels[1:], mri, i+1))
|
| 29 |
+
self.add_module("UpConvBlock", ConvBlock(in_channels=channels[1]*2, out_channels=channels[1]))
|
| 30 |
+
|
| 31 |
+
if i > 0:
|
| 32 |
+
self.add_module("CONV_TRANSPOSE", torch.nn.ConvTranspose3d(in_channels = channels[1], out_channels = channels[0], kernel_size = (1,2,2) if mri and i > 4 else 2, stride = (1,2,2) if mri and i > 4 else 2, padding = 0))
|
| 33 |
+
self.add_module("SkipConnection", blocks.Concat(), in_branch=[0, 1])
|
| 34 |
+
|
| 35 |
+
class Unet_TS(network.Network):
|
| 36 |
+
|
| 37 |
+
def __init__(self,
|
| 38 |
+
optimizer: network.OptimizerLoader = network.OptimizerLoader(),
|
| 39 |
+
schedulers: dict[str, network.LRSchedulersLoader] = {
|
| 40 |
+
"default:ReduceLROnPlateau": network.LRSchedulersLoader(0)
|
| 41 |
+
},
|
| 42 |
+
outputs_criterions: dict[str, network.TargetCriterionsLoader] = {"default": network.TargetCriterionsLoader()},
|
| 43 |
+
channels = [1, 32, 64, 128, 320, 320],
|
| 44 |
+
mri: bool = False) -> None:
|
| 45 |
+
super().__init__(
|
| 46 |
+
in_channels=channels[0],
|
| 47 |
+
optimizer=optimizer,
|
| 48 |
+
schedulers=schedulers,
|
| 49 |
+
outputs_criterions=outputs_criterions,
|
| 50 |
+
patch=None,
|
| 51 |
+
dim=3,
|
| 52 |
+
)
|
| 53 |
+
self.add_module("UNetBlock", UNetBlock(channels, mri))
|
| 54 |
+
self.add_module("Head", UNetHead(channels[1], 42))
|
| 55 |
+
|
| 56 |
+
def load(
|
| 57 |
+
self,
|
| 58 |
+
state_dict: dict[str, dict[str, torch.Tensor] | int],
|
| 59 |
+
init: bool = True,
|
| 60 |
+
ema: bool = False,
|
| 61 |
+
):
|
| 62 |
+
nb_class, in_channels = state_dict["Model"]["Unet_TS"]["Head.Conv.weight"].shape[:2]
|
| 63 |
+
self["Head"].add_module("Conv", torch.nn.Conv3d(in_channels = in_channels, out_channels = nb_class, kernel_size = 1, stride = 1, padding = 0))
|
| 64 |
+
super().load(state_dict, init, ema)
|
| 65 |
+
|
| 66 |
+
class Combine(Reduction):
|
| 67 |
+
|
| 68 |
+
def __init__(self):
|
| 69 |
+
pass
|
| 70 |
+
|
| 71 |
+
def __call__(self, tensors: list[torch.Tensor]) -> torch.Tensor:
|
| 72 |
+
fg_all = torch.cat([p[:, 1:, ...] for p in tensors], dim=1)
|
| 73 |
+
sum_fg = fg_all.sum(dim=1, keepdim=True)
|
| 74 |
+
bg = (1.0 - sum_fg).clamp(min=1e-6)
|
| 75 |
+
|
| 76 |
+
probs = torch.cat([bg, fg_all], dim=1)
|
| 77 |
+
|
| 78 |
+
return probs / probs.sum(dim=1, keepdim=True)
|
Prediction_MR_Fast.yml β total_mr-3mm/Prediction.yml
RENAMED
|
File without changes
|
total_mr-3mm/metadata.json
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"display_name": "Total MR 3mm",
|
| 3 |
+
"short_description": "",
|
| 4 |
+
"description": "",
|
| 5 |
+
"tta": 0,
|
| 6 |
+
"mc_dropout": 0
|
| 7 |
+
}
|
total_mr-3mm/requirements.txt
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
segmentation_models_pytorch
|
M850.pt β total_mr/M850.pt
RENAMED
|
File without changes
|
M851.pt β total_mr/M851.pt
RENAMED
|
File without changes
|
total_mr/Model.py
ADDED
|
@@ -0,0 +1,78 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from konfai.network import network, blocks
|
| 3 |
+
from konfai.predictor import Reduction
|
| 4 |
+
|
| 5 |
+
class ConvBlock(network.ModuleArgsDict):
|
| 6 |
+
def __init__(self, in_channels : int, out_channels : int, stride: int = 1 ) -> None:
|
| 7 |
+
super().__init__()
|
| 8 |
+
self.add_module("Conv_0", torch.nn.Conv3d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=stride, padding=1, bias=True))
|
| 9 |
+
self.add_module("Norm_0", torch.nn.InstanceNorm3d(num_features=out_channels, affine=True))
|
| 10 |
+
self.add_module("Activation_0", torch.nn.LeakyReLU(negative_slope=0.01, inplace=True))
|
| 11 |
+
self.add_module("Conv_1", torch.nn.Conv3d(in_channels=out_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1, bias=True))
|
| 12 |
+
self.add_module("Norm_1", torch.nn.InstanceNorm3d(num_features=out_channels, affine=True))
|
| 13 |
+
self.add_module("Activation_1", torch.nn.LeakyReLU(negative_slope=0.01, inplace=True))
|
| 14 |
+
|
| 15 |
+
class UNetHead(network.ModuleArgsDict):
|
| 16 |
+
def __init__(self, in_channels: int, nb_class: int) -> None:
|
| 17 |
+
super().__init__()
|
| 18 |
+
self.add_module("Conv", torch.nn.Conv3d(in_channels = in_channels, out_channels = nb_class, kernel_size = 1, stride = 1, padding = 0))
|
| 19 |
+
self.add_module("Softmax", torch.nn.Softmax(dim=1))
|
| 20 |
+
|
| 21 |
+
class UNetBlock(network.ModuleArgsDict):
|
| 22 |
+
|
| 23 |
+
def __init__(self, channels, mri: bool, i : int = 0) -> None:
|
| 24 |
+
super().__init__()
|
| 25 |
+
self.add_module("DownConvBlock", ConvBlock(in_channels=channels[0], out_channels=channels[1], stride= ((1,2,2) if mri and i > 4 else 2) if i>0 else 1))
|
| 26 |
+
|
| 27 |
+
if len(channels) > 2:
|
| 28 |
+
self.add_module("UNetBlock", UNetBlock(channels[1:], mri, i+1))
|
| 29 |
+
self.add_module("UpConvBlock", ConvBlock(in_channels=channels[1]*2, out_channels=channels[1]))
|
| 30 |
+
|
| 31 |
+
if i > 0:
|
| 32 |
+
self.add_module("CONV_TRANSPOSE", torch.nn.ConvTranspose3d(in_channels = channels[1], out_channels = channels[0], kernel_size = (1,2,2) if mri and i > 4 else 2, stride = (1,2,2) if mri and i > 4 else 2, padding = 0))
|
| 33 |
+
self.add_module("SkipConnection", blocks.Concat(), in_branch=[0, 1])
|
| 34 |
+
|
| 35 |
+
class Unet_TS(network.Network):
|
| 36 |
+
|
| 37 |
+
def __init__(self,
|
| 38 |
+
optimizer: network.OptimizerLoader = network.OptimizerLoader(),
|
| 39 |
+
schedulers: dict[str, network.LRSchedulersLoader] = {
|
| 40 |
+
"default:ReduceLROnPlateau": network.LRSchedulersLoader(0)
|
| 41 |
+
},
|
| 42 |
+
outputs_criterions: dict[str, network.TargetCriterionsLoader] = {"default": network.TargetCriterionsLoader()},
|
| 43 |
+
channels = [1, 32, 64, 128, 320, 320],
|
| 44 |
+
mri: bool = False) -> None:
|
| 45 |
+
super().__init__(
|
| 46 |
+
in_channels=channels[0],
|
| 47 |
+
optimizer=optimizer,
|
| 48 |
+
schedulers=schedulers,
|
| 49 |
+
outputs_criterions=outputs_criterions,
|
| 50 |
+
patch=None,
|
| 51 |
+
dim=3,
|
| 52 |
+
)
|
| 53 |
+
self.add_module("UNetBlock", UNetBlock(channels, mri))
|
| 54 |
+
self.add_module("Head", UNetHead(channels[1], 42))
|
| 55 |
+
|
| 56 |
+
def load(
|
| 57 |
+
self,
|
| 58 |
+
state_dict: dict[str, dict[str, torch.Tensor] | int],
|
| 59 |
+
init: bool = True,
|
| 60 |
+
ema: bool = False,
|
| 61 |
+
):
|
| 62 |
+
nb_class, in_channels = state_dict["Model"]["Unet_TS"]["Head.Conv.weight"].shape[:2]
|
| 63 |
+
self["Head"].add_module("Conv", torch.nn.Conv3d(in_channels = in_channels, out_channels = nb_class, kernel_size = 1, stride = 1, padding = 0))
|
| 64 |
+
super().load(state_dict, init, ema)
|
| 65 |
+
|
| 66 |
+
class Combine(Reduction):
|
| 67 |
+
|
| 68 |
+
def __init__(self):
|
| 69 |
+
pass
|
| 70 |
+
|
| 71 |
+
def __call__(self, tensors: list[torch.Tensor]) -> torch.Tensor:
|
| 72 |
+
fg_all = torch.cat([p[:, 1:, ...] for p in tensors], dim=1)
|
| 73 |
+
sum_fg = fg_all.sum(dim=1, keepdim=True)
|
| 74 |
+
bg = (1.0 - sum_fg).clamp(min=1e-6)
|
| 75 |
+
|
| 76 |
+
probs = torch.cat([bg, fg_all], dim=1)
|
| 77 |
+
|
| 78 |
+
return probs / probs.sum(dim=1, keepdim=True)
|
Prediction_MR.yml β total_mr/Prediction.yml
RENAMED
|
File without changes
|
total_mr/metadata.json
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"display_name": "Total MR",
|
| 3 |
+
"short_description": "",
|
| 4 |
+
"description": "",
|
| 5 |
+
"tta": 0,
|
| 6 |
+
"mc_dropout": 0
|
| 7 |
+
}
|
total_mr/requirements.txt
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
segmentation_models_pytorch
|