# Author: VBoussot
# Commit f33b46f (verified): "Update total_mr/Model.py"
import torch
from konfai.network import network, blocks
from konfai.predictor import Reduction
class ConvBlock(network.ModuleArgsDict):
    """Two stacked Conv3d -> InstanceNorm3d -> LeakyReLU stages.

    Only the first convolution uses ``stride``, so only it can downsample.
    The second convolution always uses stride 1. Both convolutions use a 3x3x3
    kernel with padding 1 and produce ``out_channels`` channels.

    Submodules are registered in the order ``Conv_0``, ``Norm_0``,
    ``Activation_0``, ``Conv_1``, ``Norm_1``, ``Activation_1``. PyTorch derives
    state_dict keys from these names, so they must stay stable for checkpoints
    to load.

    Args:
        in_channels: Channels of the input tensor.
        out_channels: Channels produced by both convolutions.
        stride: Stride of the first convolution. Either an int, or a per-axis
            3-tuple such as ``(1, 2, 2)`` (``UNetBlock`` passes such tuples).
    """

    def __init__(self, in_channels: int, out_channels: int, stride: int | tuple[int, int, int] = 1) -> None:
        super().__init__()
        # (input channels, stride) for each stage; only stage 0 may downsample.
        stages = ((in_channels, stride), (out_channels, 1))
        for idx, (stage_in, stage_stride) in enumerate(stages):
            self.add_module(f"Conv_{idx}", torch.nn.Conv3d(in_channels=stage_in, out_channels=out_channels, kernel_size=3, stride=stage_stride, padding=1, bias=True))
            self.add_module(f"Norm_{idx}", torch.nn.InstanceNorm3d(num_features=out_channels, affine=True))
            self.add_module(f"Activation_{idx}", torch.nn.LeakyReLU(negative_slope=0.01, inplace=True))
class UNetHead(network.ModuleArgsDict):
    """Segmentation head: a 1x1x1 projection to ``nb_class`` channels, then a softmax.

    Args:
        in_channels: Number of feature channels entering the head.
        nb_class: Number of output channels (one per class).
    """

    def __init__(self, in_channels: int, nb_class: int) -> None:
        super().__init__()
        projection = torch.nn.Conv3d(in_channels=in_channels, out_channels=nb_class, kernel_size=1, stride=1, padding=0)
        self.add_module("Conv", projection)
        # dim=1 is the channel axis of a Conv3d output (N, C, D, H, W).
        self.add_module("Softmax", torch.nn.Softmax(dim=1))
class UNetBlock(network.ModuleArgsDict):
    """One resolution level of the recursive U-Net, holding all deeper levels.

    Each level applies a ``ConvBlock`` (strided except at the outermost level).
    If more levels remain, it nests a deeper ``UNetBlock`` and fuses its output
    with ``UpConvBlock``. Every non-outermost level finally upsamples back to
    ``channels[0]`` and adds a ``SkipConnection`` concat.

    Args:
        channels: Channel counts from this level downwards. ``channels[0]`` is
            the input width and ``channels[1]`` is the width of this level.
        mri: When true, levels with ``i > 4`` scale by ``(1, 2, 2)`` instead of
            2, leaving the first spatial axis unchanged.
        i: Depth of this level (0 for the outermost block).
    """

    def __init__(self, channels, mri: bool, i : int = 0) -> None:
        super().__init__()
        # Scale factor shared by the strided down-conv and the transposed up-conv.
        scale = (1, 2, 2) if (mri and i > 4) else 2
        has_parent = i > 0

        self.add_module("DownConvBlock", ConvBlock(in_channels=channels[0], out_channels=channels[1], stride=scale if has_parent else 1))

        if len(channels) > 2:
            self.add_module("UNetBlock", UNetBlock(channels[1:], mri, i + 1))
            # The nested level returns its upsampled features concatenated
            # with its own input, i.e. 2 * channels[1] channels.
            self.add_module("UpConvBlock", ConvBlock(in_channels=channels[1] * 2, out_channels=channels[1]))

        if not has_parent:
            # Outermost level: no upsampling and no skip concat; the head follows directly.
            return

        self.add_module("CONV_TRANSPOSE", torch.nn.ConvTranspose3d(in_channels=channels[1], out_channels=channels[0], kernel_size=scale, stride=scale, padding=0))
        # NOTE(review): presumably concatenates branch 0 (upsampled features)
        # with branch 1 (this block's input); confirm KonfAI branch semantics.
        self.add_module("SkipConnection", blocks.Concat(), in_branch=[0, 1])
class Unet_TS(network.Network):
    """3D U-Net for 3D segmentation, built on KonfAI's ``network.Network``.

    The network is one recursive ``UNetBlock`` followed by a ``UNetHead``
    (1x1x1 convolution plus softmax over channels).

    Args:
        optimizer: Optimizer loader forwarded to ``network.Network``.
        schedulers: Learning-rate scheduler loaders forwarded to ``network.Network``.
        outputs_criterions: Loss/criterion loaders forwarded to ``network.Network``.
        channels: Channel widths per level; ``channels[0]`` is the input channel count.
        mri: Enables anisotropic ``(1, 2, 2)`` scaling at levels deeper than 4.
        nb_class: Number of output classes of the head. Defaults to 42, the
            previously hard-coded value. ``load`` still rebuilds the head to
            match the checkpoint.
    """

    # NOTE(review): the dict/list/object defaults below are shared between
    # instances. They are kept as literal defaults because KonfAI presumably
    # reads constructor defaults from the signature; confirm before changing them.
    def __init__(self,
                 optimizer: network.OptimizerLoader = network.OptimizerLoader(),
                 schedulers: dict[str, network.LRSchedulersLoader] = {
                     "default:ReduceLROnPlateau": network.LRSchedulersLoader(0)
                 },
                 outputs_criterions: dict[str, network.TargetCriterionsLoader] = {"default": network.TargetCriterionsLoader()},
                 channels = [1, 32, 64, 128, 320, 320],
                 mri: bool = False,
                 nb_class: int = 42) -> None:
        super().__init__(
            in_channels=channels[0],
            optimizer=optimizer,
            schedulers=schedulers,
            outputs_criterions=outputs_criterions,
            patch=None,
            dim=3,
        )
        self.add_module("UNetBlock", UNetBlock(channels, mri))
        self.add_module("Head", UNetHead(channels[1], nb_class))

    def load(
        self,
        state_dict: dict[str, dict[str, torch.Tensor] | int],
        init: bool = True,
        ema: bool = False,
    ):
        """Resize the head to match the checkpoint, then delegate loading.

        The head convolution is rebuilt from the shape of
        ``state_dict["Model"]["Unet_TS"]["Head.Conv.weight"]``. This lets
        checkpoints with a different class count or head input width load.
        The actual loading is then done by ``network.Network.load``.

        Raises:
            KeyError: If the checkpoint lacks the head weight entry.
        """
        # Conv3d weights are laid out as (out_channels, in_channels, kD, kH, kW).
        nb_class, in_channels = state_dict["Model"]["Unet_TS"]["Head.Conv.weight"].shape[:2]
        # NOTE(review): presumably replaces the existing "Conv" entry in place;
        # confirm ModuleArgsDict.add_module semantics for duplicate names.
        self["Head"].add_module("Conv", torch.nn.Conv3d(in_channels = in_channels, out_channels = nb_class, kernel_size = 1, stride = 1, padding = 0))
        super().load(state_dict, init, ema)