# Author: Valentin Boussot
# Commit: "Add fast inference" (2d34814)
import torch
import requests
from tqdm import tqdm
import zipfile
import shutil
from pathlib import Path
import os
from functools import partial
from Model import Unet_TS
# For each architecture ``type``, the substring identifying the single
# ``decoder.seg_layers`` entry that is kept; all other seg-layer entries of the
# source checkpoint are skipped. NOTE(review): presumably the kept one is the
# final (full-resolution) head and the others are nnU-Net deep-supervision
# heads that Unet_TS does not have — confirm against Model.Unet_TS.
_KEPT_SEG_LAYER = {
    1: "decoder.seg_layers.4",
    2: "decoder.seg_layers.3",
    3: "decoder.seg_layers.2",
}


def _remap_state_dict(
    state_dict: dict[str, torch.Tensor],
    destination_keys: list[str],
    unet_type: int,
) -> dict[str, torch.Tensor]:
    """Rename the kept tensors of *state_dict* to *destination_keys*, by position.

    Source keys are filtered (see ``is_skipped``) and then paired in order with
    *destination_keys*. Surplus source tensors are ignored.

    Raises:
        ValueError: If fewer usable source tensors exist than destination keys.
    """
    kept_seg_layer = _KEPT_SEG_LAYER.get(unet_type)

    def is_skipped(key: str) -> bool:
        if "decoder.seg_layers" in key:
            # Unknown types keep no seg layer at all (same as before).
            return kept_seg_layer is None or kept_seg_layer not in key
        # NOTE(review): presumably duplicate registrations of the same
        # parameters in nnU-Net (``all_modules`` aliases and the decoder's
        # reference to the encoder) — confirm.
        return "all_modules" in key or "decoder.encoder" in key

    source_keys = [key for key in state_dict if not is_skipped(key)]
    if len(source_keys) < len(destination_keys):
        raise ValueError(
            f"Checkpoint provides {len(source_keys)} usable tensors but "
            f"{len(destination_keys)} destination keys were requested"
        )
    return {dst: state_dict[src] for dst, src in zip(destination_keys, source_keys)}


def convert_torchScript_full(model_name: str, model: torch.nn.Module, type: int, url: str) -> None:
    """Download nnU-Net weights, rename them for ``model`` and save them.

    The destination parameter names are read, one per line, from
    ``Destination_Unet_{type}.txt`` in the current working directory. They are
    matched positionally against the checkpoint's parameters after the filtered
    entries are dropped. The renamed weights are passed to
    ``model.load_state_dict``, which validates them with PyTorch's default
    strict key matching. They are then saved to ``{model_name}.pt`` as
    ``{"Model": {"Unet_TS": weights}}``.

    Args:
        model_name: Output file name, without the ``.pt`` extension.
        model: Network used to validate the remapped weights.
        type: Architecture variant (1, 2 or 3). It selects the destination key
            file and which ``decoder.seg_layers`` entry is kept.
        url: URL of the nnU-Net weights archive, handed to ``download``.

    Raises:
        FileNotFoundError: If the destination key file does not exist.
        ValueError: If the checkpoint has fewer usable tensors than destination keys.
    """
    # Read the key list first, so a missing file fails before the (large) download.
    # strip() also removes the "\r" of CRLF files. Blank lines (e.g. a trailing
    # one) are ignored instead of becoming an empty parameter name.
    with open(f"Destination_Unet_{type}.txt", encoding="utf-8") as key_file:
        destination_keys = [key for line in key_file if (key := line.strip())]

    state_dict = download(url)
    weights = _remap_state_dict(state_dict, destination_keys, type)
    model.load_state_dict(weights)
    torch.save({"Model": {"Unet_TS": weights}}, f"{model_name}.pt")
def download(url: str) -> dict[str, torch.Tensor]:
    """Download an nnU-Net weights archive and return its network weights.

    The archive is streamed into the current working directory under the last
    component of *url*. It is extracted into a sibling directory named after it
    without ``.zip``. The first ``checkpoint_final.pth`` found recursively is
    loaded. The archive is always removed, and so is the extracted directory
    once extraction has succeeded, even if loading fails.

    Args:
        url: Direct URL of a ``.zip`` archive.

    Returns:
        The ``"network_weights"`` entry of the checkpoint, loaded onto the CPU.

    Raises:
        ValueError: If *url* does not name a ``.zip`` archive.
        requests.HTTPError: If the server answers with an error status.
        FileNotFoundError: If the archive contains no ``checkpoint_final.pth``.
    """
    archive_name = url.split("/")[-1]
    extract_name = archive_name.replace(".zip", "")
    # Guard the cleanup below: an empty or dot-only name would make rmtree
    # target the working directory itself.
    if extract_name == archive_name or not extract_name.strip("."):
        raise ValueError(f"Expected a URL pointing to a .zip archive, got {url!r}")
    archive_path = Path(archive_name)
    extract_dir = Path(extract_name)

    try:
        with requests.get(url, stream=True) as response:
            # Check the status before touching the filesystem.
            response.raise_for_status()
            # A missing content-length gives an open-ended progress bar.
            total_size = int(response.headers.get("content-length", 0)) or None
            with open(archive_path, "wb") as archive, tqdm(
                total=total_size, unit="B", unit_scale=True, desc="Downloading"
            ) as progress_bar:
                for chunk in response.iter_content(chunk_size=8192 * 16):
                    progress_bar.update(len(chunk))
                    archive.write(chunk)
        with zipfile.ZipFile(archive_path, "r") as zip_f:
            zip_f.extractall(extract_dir)
    finally:
        archive_path.unlink(missing_ok=True)

    try:
        checkpoint = next(extract_dir.rglob("checkpoint_final.pth"), None)
        if checkpoint is None:
            raise FileNotFoundError(f"No checkpoint_final.pth found in archive {url}")
        # SECURITY: weights_only=False unpickles arbitrary Python objects. Only
        # use this with archives from a trusted source.
        # map_location="cpu" lets checkpoints saved with CUDA tensors load on
        # CPU-only hosts.
        return torch.load(checkpoint, map_location="cpu", weights_only=False)["network_weights"]
    finally:
        shutil.rmtree(extract_dir, ignore_errors=True)
# Common prefix of every TotalSegmentator weights release URL.
url = "https://github.com/wasserth/TotalSegmentator/releases/download/"

# Unet_TS variants, which differ only in the channel list they are built with.
UnetCPP_1 = partial(Unet_TS, channels=[1, 32, 64, 128, 256, 320, 320])
UnetCPP_2 = partial(Unet_TS, channels=[1, 32, 64, 128, 256, 320])
UnetCPP_3 = partial(Unet_TS, channels=[1, 32, 64, 128, 256])

# Architecture type -> constructor of the matching Unet_TS variant.
_ARCHITECTURES = {1: UnetCPP_1, 2: UnetCPP_2, 3: UnetCPP_3}

# (model name, architecture type, nb_class, extra Unet_TS kwargs, archive path under `url`)
_MODEL_SPECS = (
    ("M291", 1, 25, {}, "v2.0.0-weights/Dataset291_TotalSegmentator_part1_organs_1559subj.zip"),
    ("M292", 1, 27, {}, "v2.0.0-weights/Dataset292_TotalSegmentator_part2_vertebrae_1532subj.zip"),
    ("M293", 1, 19, {}, "v2.0.0-weights/Dataset293_TotalSegmentator_part3_cardiac_1559subj.zip"),
    ("M294", 1, 24, {}, "v2.0.0-weights/Dataset294_TotalSegmentator_part4_muscles_1559subj.zip"),
    ("M295", 1, 27, {}, "v2.0.0-weights/Dataset295_TotalSegmentator_part5_ribs_1559subj.zip"),
    ("M297", 2, 118, {}, "v2.0.4-weights/Dataset297_TotalSegmentator_total_3mm_1559subj_v204.zip"),
    ("M298", 2, 118, {}, "v2.0.0-weights/Dataset298_TotalSegmentator_total_6mm_1559subj.zip"),
    ("M730", 1, 30, {"mri": True}, "v2.2.0-weights/Dataset730_TotalSegmentatorMRI_part1_organs_495subj.zip"),
    ("M731", 1, 28, {"mri": True}, "v2.2.0-weights/Dataset731_TotalSegmentatorMRI_part2_muscles_495subj.zip"),
    ("M732", 2, 57, {}, "v2.2.0-weights/Dataset732_TotalSegmentatorMRI_total_3mm_495subj.zip"),
    ("M733", 3, 57, {}, "v2.2.0-weights/Dataset733_TotalSegmentatorMRI_total_6mm_495subj.zip"),
    ("M850", 1, 30, {"mri": True}, "v2.5.0-weights/Dataset850_TotalSegMRI_part1_organs_1088subj.zip"),
    ("M851", 1, 22, {"mri": True}, "v2.5.0-weights/Dataset851_TotalSegMRI_part2_muscles_1088subj.zip"),
    ("M852", 2, 51, {}, "v2.5.0-weights/Dataset852_TotalSegMRI_total_3mm_1088subj.zip"),
    ("M853", 3, 51, {}, "v2.5.0-weights/Dataset853_TotalSegMRI_total_6mm_1088subj.zip"),
)

# Model name -> (instantiated network, architecture type, weights archive URL).
# NOTE: every network is constructed here, at import time, in table order.
models = {
    name: (_ARCHITECTURES[arch](nb_class=nb_class, **extra), arch, url + path)
    for name, arch, nb_class, extra, path in _MODEL_SPECS
}
if __name__ == "__main__":
    # Convert every registered checkpoint in turn.
    for name, (network, unet_type, weights_url) in models.items():
        convert_torchScript_full(
            model_name=name,
            model=network,
            type=unet_type,
            url=weights_url,
        )