File size: 3,425 Bytes
6ec9956
 
2d34814
6ec9956
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2d34814
6ec9956
 
 
2d34814
6ec9956
 
 
 
2d34814
6ec9956
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2d34814
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import torch
from konfai.network import network, blocks
from konfai.predictor import Reduction

class ConvBlock(network.ModuleArgsDict):
    """Two stacked ``Conv3d -> InstanceNorm3d -> LeakyReLU`` stages.

    The first 3x3x3 convolution maps ``in_channels`` to ``out_channels`` and applies
    ``stride`` (this is where ``UNetBlock`` downsamples). The second 3x3x3 convolution
    keeps ``out_channels`` with stride 1. With padding 1 and stride 1, the spatial size
    is preserved.

    Args:
        in_channels: number of channels of the incoming feature map.
        out_channels: number of channels produced by both convolutions.
        stride: stride of the first convolution. Either a single int or a per-axis
            3-tuple such as ``(1, 2, 2)``; ``UNetBlock`` passes both forms.
    """

    def __init__(self, in_channels : int, out_channels : int, stride: int | tuple[int, int, int] = 1 ) -> None:
        super().__init__()
        # Stage 0: channel change + optional downsampling.
        self.add_module("Conv_0", torch.nn.Conv3d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=stride, padding=1, bias=True))
        self.add_module("Norm_0", torch.nn.InstanceNorm3d(num_features=out_channels, affine=True))
        self.add_module("Activation_0", torch.nn.LeakyReLU(negative_slope=0.01, inplace=True))
        # Stage 1: refinement at the same width and resolution.
        self.add_module("Conv_1", torch.nn.Conv3d(in_channels=out_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1, bias=True))
        self.add_module("Norm_1", torch.nn.InstanceNorm3d(num_features=out_channels, affine=True))
        self.add_module("Activation_1", torch.nn.LeakyReLU(negative_slope=0.01, inplace=True))
        
class UNetHead(network.ModuleArgsDict):
    """Segmentation head: a pointwise projection to ``nb_class`` channels, then a softmax.

    Args:
        in_channels: width of the decoder output fed to the head.
        nb_class: number of output channels (classes).
    """

    def __init__(self, in_channels: int, nb_class: int) -> None:
        super().__init__()
        projection = torch.nn.Conv3d(
            in_channels=in_channels,
            out_channels=nb_class,
            kernel_size=1,
            stride=1,
            padding=0,
        )
        # Softmax over dim 1: the channel axis, assuming batched (N, C, D, H, W) input.
        normalizer = torch.nn.Softmax(dim=1)
        self.add_module("Conv", projection)
        self.add_module("Softmax", normalizer)
    
class UNetBlock(network.ModuleArgsDict):
    """One resolution level of the U-Net; it nests itself once per remaining ``channels`` entry.

    Each level:

    * applies ``DownConvBlock`` (``channels[0]`` -> ``channels[1]``);
    * if a deeper level exists, recurses into it and fuses the result with ``UpConvBlock``;
    * except at the outermost level (``i == 0``), upsamples back to ``channels[0]`` channels
      with a transposed convolution and concatenates with a skip branch.

    Args:
        channels: input width followed by the width of each level, outermost first.
        mri: when true, levels with ``i > 4`` resample with factor ``(1, 2, 2)`` instead of 2.
        i: depth of this level; 0 for the outermost call.
    """

    def __init__(self, channels, mri: bool, i : int = 0) -> None:
        super().__init__()
        # Resampling factor shared by this level's strided conv and its transposed conv.
        # For MRI, levels deeper than 4 leave the first spatial axis unscaled.
        scale = (1, 2, 2) if mri and i > 4 else 2
        is_outermost = i <= 0

        self.add_module(
            "DownConvBlock",
            ConvBlock(
                in_channels=channels[0],
                out_channels=channels[1],
                stride=1 if is_outermost else scale,
            ),
        )

        if len(channels) > 2:
            width = channels[1]
            self.add_module("UNetBlock", UNetBlock(channels[1:], mri, i + 1))
            # The nested level returns its upsampled features concatenated with its skip
            # branch, hence twice this level's width.
            self.add_module("UpConvBlock", ConvBlock(in_channels=width * 2, out_channels=width))

        if is_outermost:
            return

        self.add_module(
            "CONV_TRANSPOSE",
            torch.nn.ConvTranspose3d(
                in_channels=channels[1],
                out_channels=channels[0],
                kernel_size=scale,
                stride=scale,
                padding=0,
            ),
        )
        # NOTE(review): presumably branch 1 holds this level's input (channels[0] wide), which
        # matches the parent's UpConvBlock expecting 2x its width; confirm against konfai's
        # ModuleArgsDict branch semantics.
        self.add_module("SkipConnection", blocks.Concat(), in_branch=[0, 1])

class Unet_TS(network.Network):
    """3D U-Net built from a recursive ``UNetBlock`` followed by a softmax ``UNetHead``.

    The head is created with 42 output classes. ``load`` rebuilds the head convolution
    to the shape stored in the checkpoint, so checkpoints with a different class count
    can still be loaded.

    Args:
        optimizer: optimizer loader forwarded to ``network.Network``.
        schedulers: learning-rate scheduler loaders forwarded to ``network.Network``.
        outputs_criterions: per-output criterion loaders forwarded to ``network.Network``.
        channels: ``channels[0]`` is the number of input channels; each following entry is
            the width of one U-Net level, from full resolution down to the bottleneck.
        mri: when true, levels deeper than index 4 resample with ``(1, 2, 2)`` instead of 2.
    """

    # NOTE(review): the defaults below are shared literal objects. They are left as-is
    # because konfai presumably reads constructor defaults from the signature; verify
    # before replacing them with ``None`` sentinels.
    def __init__(self,
        optimizer: network.OptimizerLoader = network.OptimizerLoader(),
        schedulers: dict[str, network.LRSchedulersLoader] = {
            "default:ReduceLROnPlateau": network.LRSchedulersLoader(0)
        },
        outputs_criterions: dict[str, network.TargetCriterionsLoader] = {"default": network.TargetCriterionsLoader()},
        channels = [1, 32, 64, 128, 320, 320], 
        mri: bool = False) -> None:
        super().__init__(
            in_channels=channels[0],
            optimizer=optimizer,
            schedulers=schedulers,
            outputs_criterions=outputs_criterions,
            patch=None,
            dim=3,
        )
        self.add_module("UNetBlock", UNetBlock(channels, mri))
        # 42 is a placeholder class count; load() resizes the head to match the checkpoint.
        self.add_module("Head", UNetHead(channels[1], 42))

    def load(
        self,
        state_dict: dict[str, dict[str, torch.Tensor] | int],
        init: bool = True,
        ema: bool = False,
    ):
        """Resize the head to match ``state_dict``, then delegate to ``network.Network.load``.

        The head ``Conv3d`` is rebuilt from the weight at
        ``state_dict["Model"]["Unet_TS"]["Head.Conv.weight"]``. A ``Conv3d`` weight is laid
        out as ``(out_channels, in_channels, kD, kH, kW)``, which gives the class count and
        input width. The new layer is created on the same device and dtype as the current
        head, so loading after the model has been moved does not leave a CPU/float32 layer
        behind.

        Args:
            state_dict: checkpoint mapping; must contain the nested head weight entry above.
            init: forwarded unchanged to ``network.Network.load``.
            ema: forwarded unchanged to ``network.Network.load``.

        Raises:
            KeyError: if the checkpoint lacks the ``"Model" / "Unet_TS" / "Head.Conv.weight"`` entry.
        """
        # NOTE(review): the "Unet_TS" key is hard-coded; presumably it is this class's name
        # in the checkpoint. Confirm for subclasses or renamed models.
        head_weight = state_dict["Model"]["Unet_TS"]["Head.Conv.weight"]
        nb_class, in_channels = head_weight.shape[:2]
        head = self["Head"]
        # Match the existing head's placement so the model stays on a single device/dtype.
        reference = next(head.parameters(), None)
        factory_kwargs = {} if reference is None else {"device": reference.device, "dtype": reference.dtype}
        head.add_module("Conv", torch.nn.Conv3d(in_channels = in_channels, out_channels = nb_class, kernel_size = 1, stride = 1, padding = 0, **factory_kwargs))
        # NOTE(review): an optimizer created before this call would still reference the old
        # head parameters; verify that load() runs before optimizer construction.
        super().load(state_dict, init, ema)