---
library_name: transformers
tags:
- radiology
- ct
- organ
- classification
license: apache-2.0
base_model:
- timm/tf_efficientnetv2_b0.in1k
pipeline_tag: image-classification
---

# TotalClassifier: Slice-Level Organ Classification for CT Examinations

TotalClassifier is a classification model that predicts which organs are present on a 2D slice from a CT volume.
It supports axial, sagittal, and coronal images and a variety of windowing parameters.
The model uses a `tf_efficientnetv2_b0` backbone with a gated recurrent unit (GRU) head that performs sequence modeling across the extracted slice-level features.
It also works with single 2D images.

The model is trained on the publicly available [TotalSegmentator dataset](https://zenodo.org/records/10047292), version 2.0.1. It predicts 117 labels corresponding to the
labels available in TotalSegmentator. The classification labels were generated from the provided segmentation labels.

Note that the model expects a single-channel input. If you create a multi-channel image using multiple CT windows, simply take the mean across channels.
The model also expects 8-bit input (converted to float). Thus, if your CT volume is in Hounsfield units, apply a standard window,
such as soft tissue (level=50, width=400), before passing it to the model.

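
The windowing and channel-averaging steps described above can be sketched in NumPy. This is a minimal illustration, not necessarily the exact preprocessing performed by the model's own loading utilities: the helper names `apply_window` and `to_single_channel`, the synthetic volume, and the second (lung) window are assumptions made for the example.

```python
import numpy as np

def apply_window(hu, level, width):
    """Clip a Hounsfield-unit array to a window and rescale it to 8-bit [0, 255]."""
    lower = level - width / 2.0
    upper = level + width / 2.0
    clipped = np.clip(hu.astype(np.float32), lower, upper)
    scaled = (clipped - lower) / (upper - lower) * 255.0
    return scaled.astype(np.uint8)

def to_single_channel(windowed):
    """Average a (..., num_channels) windowed array across its last axis."""
    return windowed.astype(np.float32).mean(axis=-1)

# synthetic stand-in for a CT volume in Hounsfield units: (num_slices, height, width)
rng = np.random.default_rng(0)
hu_volume = rng.integers(-1024, 2000, size=(4, 8, 8)).astype(np.int16)

# soft tissue window from the text; the lung window is an illustrative assumption
soft_tissue = apply_window(hu_volume, level=50, width=400)
lung = apply_window(hu_volume, level=-600, width=1500)

# stack windows as channels, then collapse them to the single channel the model expects
multi_channel = np.stack([soft_tissue, lung], axis=-1)  # (4, 8, 8, 2)
single_channel = to_single_channel(multi_channel)       # (4, 8, 8), float32
```

With a single window the averaging step is unnecessary, since the windowed array already has one channel.
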
## Example Usage

```
import numpy as np
import torch
from transformers import AutoModel

device = "cuda"
organ_model = AutoModel.from_pretrained("ianpan/total-classifier", trust_remote_code=True).eval().to(device)

# the model can load a CT from a folder of DICOM files, if pydicom is installed
# here we apply a soft tissue window
ct_volume = organ_model.load_stack_from_dicom_folder("/path/to/dicom/folder", windows=[[50, 400]], dicom_extension=".dcm")

# ct_volume.shape is (num_slices, height, width, num_channels) if applying windows
# otherwise it is (num_slices, height, width) if using original Hounsfield units

# preprocess
x = organ_model.preprocess(ct_volume, mode="3d", torchify=True, add_batch_dim=True, device=device)

# here, ct_volume is a numpy array
# if you are loading volumes as torch.Tensors, then you can skip the preprocess function
# and just resize the volume to height and width of 256 x 256

# x is now a torch.Tensor with shape (1, num_slices, num_channels, height, width)
# note that these are the expected dims for the model's forward method

with torch.inference_mode():
    out = organ_model(x)
    out_df = organ_model(x, return_as_df=True)

# out is a torch.Tensor of shape (1, num_slices, 117) containing scores [0-1] for each organ label
# out_df is a list of pandas DataFrames with shape (num_slices, 117), where column names are the organ names
# each element of the list corresponds to one sample in the batch
# however, if using batch sizes >1, then all samples need to be padded to the same number of slices

# you can use out_df to get only the slices where a predicted organ label meets a certain threshold
out_df = out_df[0]
threshold = 0.5
liver_indices = np.where(out_df["liver"].values >= threshold)[0]

# or slices where at least one of the specified organ labels meets the threshold
organs_of_interest = ["liver", "spleen", "pancreas"]
threshold = 0.5
slice_indices = np.where((out_df[organs_of_interest].values >= threshold).any(axis=1))[0]

# organ_model.label2index can be used to convert organ label names to the indices 0-116
# organ_model.index2label is the inverse
```

If you have a large number of slices and limited GPU memory, you can either process the volume in chunks
or downsample the volume along the slice dimension and interpolate the predictions back to the original number of slices.
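
Both memory-saving strategies can be sketched as follows. This is a minimal illustration under stated assumptions: `predict_fn` is a hypothetical callable that maps an array of slices with shape `(n, ...)` to per-slice scores of shape `(n, num_labels)`, and `fake_predict` is a toy stand-in for it so that the example runs without the model. Because the GRU head models the sequence of slices, predictions from separate chunks may differ from those obtained on the full volume, particularly near chunk boundaries.

```python
import numpy as np

def predict_in_chunks(predict_fn, volume, chunk_size):
    """Run predict_fn on consecutive slice chunks and concatenate along the slice axis."""
    outputs = []
    for start in range(0, volume.shape[0], chunk_size):
        outputs.append(predict_fn(volume[start:start + chunk_size]))
    return np.concatenate(outputs, axis=0)

def predict_downsampled(predict_fn, volume, step):
    """Predict on every step-th slice, then linearly interpolate back to all slices."""
    num_slices = volume.shape[0]
    kept = np.arange(0, num_slices, step)
    coarse = predict_fn(volume[kept])  # (len(kept), num_labels)
    all_idx = np.arange(num_slices)
    # np.interp holds the last kept value constant for any slices past kept[-1]
    columns = [np.interp(all_idx, kept, coarse[:, j]) for j in range(coarse.shape[1])]
    return np.stack(columns, axis=1)

def fake_predict(slices):
    """Toy stand-in for the model (hypothetical): two scores from mean slice intensity."""
    mean_intensity = slices.reshape(slices.shape[0], -1).mean(axis=1) / 255.0
    return np.stack([mean_intensity, 1.0 - mean_intensity], axis=1)

# synthetic volume of 10 slices whose intensity increases linearly with slice index
volume = np.linspace(0, 255, 10)[:, None, None] * np.ones((1, 4, 4))

chunked = predict_in_chunks(fake_predict, volume, chunk_size=3)
interpolated = predict_downsampled(fake_predict, volume, step=3)
```

To use this with the actual model, `predict_fn` would need to wrap the preprocessing and forward pass shown earlier and return a NumPy array for a single volume. That wrapper is left out here because its details depend on how the volume is loaded.
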