Analyze Model Behaviour of 2D Segmentation
Consider downloading this Jupyter Notebook and running it locally, or test it with Colab.
In this notebook, we show how to analyze model behaviour using MOVAL for multi-class 2D segmentation tasks.
We provide the model's predicted 2D segmentation results (network logits) for this tutorial, which will be downloaded automatically. The model training code is available at https://github.com/ZerojumpLine/Robust-Medical-Segmentation.
More specifically, we show an example of analyzing model behaviour under domain shift on cardiac MRI segmentation (4 classes: background, left ventricle (LV), myocardium (MYO) and right ventricle (RV)) based on a 3D U-Net. We use the logits computed on a test dataset acquired with a different scanner.
We will visualize the calibrated confidence scores as a proper indication of missegmentation.
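To make the idea of a pixel-wise confidence score concrete, here is a minimal sketch of the max class probability (MCP) computed from raw logits with NumPy. The helper names `softmax` and `max_class_probability` and the toy logits are assumptions made for illustration; MOVAL computes and calibrates its confidence scores internally and may do so differently.

```python
import numpy as np

def softmax(logits, axis=0):
    """Numerically stable softmax along the class axis."""
    shifted = logits - np.max(logits, axis=axis, keepdims=True)
    exp = np.exp(shifted)
    return exp / np.sum(exp, axis=axis, keepdims=True)

def max_class_probability(logits):
    """Return (prediction, confidence) maps for logits of shape (d, H, W)."""
    probs = softmax(logits, axis=0)
    return np.argmax(probs, axis=0), np.max(probs, axis=0)

# Toy logits: 4 classes on a 2 x 2 slice (random values, for illustration only).
rng = np.random.default_rng(0)
toy_logits = rng.normal(size=(4, 2, 2))
toy_pred, toy_conf = max_class_probability(toy_logits)
print(toy_pred.shape, toy_conf.shape)
```

A low value in `toy_conf` marks a pixel where the network is unsure, which is the kind of region we later compare against the actual segmentation errors.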
[1]:
!pip install moval
!pip install seaborn
!pip install pandas
!pip install tqdm
!pip install matplotlib
!pip install nibabel
[2]:
import os
import gdown
import itertools
import zipfile
import pandas as pd
import numpy as np
import nibabel as nib
import moval
from moval.solvers.utils import ComputMetric
from tqdm import tqdm
import seaborn as sns
import matplotlib.pyplot as plt
[3]:
print(f"The installed MOVAL version is {moval.__version__}")
The installed MOVAL version is 0.3.16
Load the data
[4]:
# download the cardiac data
output = "data_moval_supp.zip"
if not os.path.exists(output):
    url = "https://drive.google.com/u/0/uc?id=1ZlC66MGmPlf05aYYCKBaRT2q5uod8GFk&export=download"
    gdown.download(url, output, quiet=False)
directory_data = "data_moval_supp"
if not os.path.exists(directory_data):
    with zipfile.ZipFile(output, 'r') as zip_ref:
        zip_ref.extractall(directory_data)
[5]:
# download the corresponding image data
output = "img_cardiac.zip"
if not os.path.exists(output):
    url = "https://drive.google.com/u/0/uc?id=1kS5V69dfdPEGiMfauMLuhi76Vrb8VB5k&export=download"
    gdown.download(url, output, quiet=False)
directory_data = "img_cardiac"
if not os.path.exists(directory_data):
    with zipfile.ZipFile(output, 'r') as zip_ref:
        zip_ref.extractall(directory_data)
[6]:
ls
analysis_cls.ipynb data_moval_supp.zip img_cifar/
analysis_seg2d.ipynb estim_cls.ipynb img_cifar.zip
analysis_seg3d.ipynb estim_seg2d.ipynb img_prostate/
data_moval/ estim_seg3d.ipynb img_prostate.zip
data_moval.zip img_cardiac/
data_moval_supp/ img_cardiac.zip
[7]:
# load the cardiac segmentation results (validation set)
Datafile_eval = "data_moval_supp/Cardiacresults/seg-eval.txt"
with open(Datafile_eval) as Imglist_eval:
    Imglist_eval_read = Imglist_eval.read().splitlines()
logits = []
gt = []
for Imgname_eval in Imglist_eval_read:
    #
    caseID = Imgname_eval.split("/")[-2]
    #
    GT_file = f"data_moval_supp/Cardiacresults/GT/1/{caseID}/seg.nii.gz"
    #
    logit_cls0_file = "data_moval_supp/Cardiacresults/cardiacval/results/pred_" + caseID + "cls0_prob.nii.gz"
    logit_cls1_file = "data_moval_supp/Cardiacresults/cardiacval/results/pred_" + caseID + "cls1_prob.nii.gz"
    logit_cls2_file = "data_moval_supp/Cardiacresults/cardiacval/results/pred_" + caseID + "cls2_prob.nii.gz"
    logit_cls3_file = "data_moval_supp/Cardiacresults/cardiacval/results/pred_" + caseID + "cls3_prob.nii.gz"
    #
    logit_cls0_read = nib.load(logit_cls0_file)
    logit_cls1_read = nib.load(logit_cls1_file)
    logit_cls2_read = nib.load(logit_cls2_file)
    logit_cls3_read = nib.load(logit_cls3_file)
    #
    logit_cls0 = logit_cls0_read.get_fdata()   # ``(H, W, D)``
    logit_cls1 = logit_cls1_read.get_fdata()
    logit_cls2 = logit_cls2_read.get_fdata()
    logit_cls3 = logit_cls3_read.get_fdata()
    #
    GT_read = nib.load(GT_file)
    GTimg = GT_read.get_fdata()             # ``(H, W, D)``
    #
    logit_cls = np.stack((logit_cls0, logit_cls1, logit_cls2, logit_cls3))  # ``(d, H, W, D)``
    # only include the slices that contain labels
    for dslice in range(GTimg.shape[2]):
        if np.sum(GTimg[:, :, dslice]) > 0:
            logits.append(logit_cls[:, :, :, dslice])
            gt.append(GTimg[:, :, dslice])
# logits is a list of length ``n``, each element has ``(d, H, W)``.
# gt is a list of length ``n``, each element has ``(H, W)``.
# H and W could differ for different cases.
# load the test cases acquired with a different scanner
caseIDs = ['A5D0G0_0', 'B0N3W8_15', 'E5S7W7_12', 'H7N4V9_0', 'I0I2J8_7']
logits_test = []
gt_test = []
imgs_test = []
for caseID in caseIDs:
    GT_file = f"data_moval_supp/Cardiacresults/GT/2/{caseID}/seg.nii.gz"
    #
    logit_cls0_file = "data_moval_supp/Cardiacresults/cardiactest_2/results/pred_" + caseID + "cls0_prob.nii.gz"
    logit_cls1_file = "data_moval_supp/Cardiacresults/cardiactest_2/results/pred_" + caseID + "cls1_prob.nii.gz"
    logit_cls2_file = "data_moval_supp/Cardiacresults/cardiactest_2/results/pred_" + caseID + "cls2_prob.nii.gz"
    logit_cls3_file = "data_moval_supp/Cardiacresults/cardiactest_2/results/pred_" + caseID + "cls3_prob.nii.gz"
    #
    logit_cls0_read = nib.load(logit_cls0_file)
    logit_cls1_read = nib.load(logit_cls1_file)
    logit_cls2_read = nib.load(logit_cls2_file)
    logit_cls3_read = nib.load(logit_cls3_file)
    #
    logit_cls0 = logit_cls0_read.get_fdata()   # ``(H, W, D)``
    logit_cls1 = logit_cls1_read.get_fdata()
    logit_cls2 = logit_cls2_read.get_fdata()
    logit_cls3 = logit_cls3_read.get_fdata()
    #
    img_file = f"img_cardiac/img_cardiac/2/{caseID}/image.nii.gz"
    img_read = nib.load(img_file)
    img_data = img_read.get_fdata()         # ``(H, W, D)``
    #
    GT_read = nib.load(GT_file)
    GTimg = GT_read.get_fdata()             # ``(H, W, D)``
    logit_cls = np.stack((logit_cls0, logit_cls1, logit_cls2, logit_cls3))  # ``(d, H, W, D)``
    # only include the slices that contain labels
    for dslice in range(GTimg.shape[2]):
        if np.sum(GTimg[:, :, dslice]) > 0:
            logits_test.append(logit_cls[:, :, :, dslice])
            gt_test.append(GTimg[:, :, dslice])
            imgs_test.append(img_data[:, :, dslice])
# logits_test is a list of length ``n``, each element has ``(d, H, W)``.
# gt_test is a list of length ``n``, each element has ``(H, W)``.
# H and W could differ for different cases.
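The loading cell turns each 3D volume into a list of 2D slices and keeps only the slices that contain foreground labels. The same step can be sketched on a small synthetic volume; the helper name `labelled_slices` and the toy arrays are hypothetical and only serve to show the shapes involved.

```python
import numpy as np

def labelled_slices(logit_volume, label_volume):
    """Split (d, H, W, D) logits and (H, W, D) labels into per-slice 2D arrays,
    keeping only the slices that contain at least one foreground label."""
    logits_2d, labels_2d = [], []
    for k in range(label_volume.shape[2]):
        if np.any(label_volume[:, :, k] > 0):
            logits_2d.append(logit_volume[:, :, :, k])
            labels_2d.append(label_volume[:, :, k])
    return logits_2d, labels_2d

# Synthetic volume with 5 slices, of which only slices 1 and 3 carry labels.
toy_labels = np.zeros((8, 8, 5))
toy_labels[2:4, 2:4, 1] = 1
toy_labels[3:6, 3:6, 3] = 2
toy_logit_volume = np.zeros((4, 8, 8, 5))

slices_logits, slices_labels = labelled_slices(toy_logit_volume, toy_labels)
print(len(slices_logits), slices_logits[0].shape, slices_labels[0].shape)
```

Dropping empty slices keeps the per-slice DSC well defined, since a slice without any foreground has no class to score.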
[8]:
print(f"The validation predictions, ``logits`` are a list of length {len(logits)} each element has approximately {logits[0].shape}")
print(f"The validation labels, ``gt`` are a list of length {len(gt)}, each element has approximately {gt[0].shape}\n")
print(f"The test predictions, ``logits_test`` are a list of length {len(logits_test)} each element has approximately {logits_test[0].shape}")
print(f"The test labels, ``gt_test`` are a list of length {len(gt_test)}, each element has approximately {gt_test[0].shape}")
print(f"The test imgs, ``imgs_test`` are a list of length {len(imgs_test)}, each element has approximately {imgs_test[0].shape}")
The validation predictions, ``logits`` are a list of length 156 each element has approximately (4, 210, 257)
The validation labels, ``gt`` are a list of length 156, each element has approximately (210, 257)
The test predictions, ``logits_test`` are a list of length 38 each element has approximately (4, 338, 338)
The test labels, ``gt_test`` are a list of length 38, each element has approximately (338, 338)
The test imgs, ``imgs_test`` are a list of length 38, each element has approximately (338, 338)
[9]:
import random
random.seed(79)
test_inds = list(range(len(logits)))
random.shuffle(test_inds)
test_inds = test_inds[:100]
#
_logits = []
_gt = []
for test_ind in test_inds:
    _logits.append(logits[test_ind])
    _gt.append(gt[test_ind])
logits_val = _logits
gt_val = _gt
#
print(f"The validation predictions, ``logits`` are a list of length {len(logits_val)} each element has approximately {logits_val[0].shape}")
print(f"The validation labels, ``gt`` are a list of length {len(gt_val)}, each element has approximately {gt_val[0].shape}")
The validation predictions, ``logits`` are a list of length 100 each element has approximately (4, 223, 272)
The validation labels, ``gt`` are a list of length 100, each element has approximately (223, 272)
MOVAL estimation
[10]:
moval_options = []
moval_options.append(['ac-model', 'segmentation', 'max_class_probability-conf', False])
moval_options.append(['ts-model', 'segmentation', 'max_class_probability-conf', False])
moval_options.append(['ts-model', 'segmentation', 'max_class_probability-conf', True])
moval_options.append(['ts-atc-model', 'segmentation', 'entropy-conf', True])
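As a rough intuition for these options, the sketch below shows what temperature scaling and an entropy-based confidence score look like on toy logits, assuming that `ts` refers to temperature scaling and that `entropy-conf` is derived from the entropy of the softmax output. The function names, the normalization by `log(d)`, and the toy values are assumptions made for illustration; MOVAL's internal implementation may differ.

```python
import numpy as np

def softmax(logits, axis=0):
    """Numerically stable softmax along the class axis."""
    shifted = logits - np.max(logits, axis=axis, keepdims=True)
    exp = np.exp(shifted)
    return exp / np.sum(exp, axis=axis, keepdims=True)

def scaled_mcp(logits, temperature=1.0):
    """Max class probability after dividing the logits by a temperature."""
    return np.max(softmax(logits / temperature, axis=0), axis=0)

def entropy_confidence(logits, temperature=1.0):
    """One minus the normalized entropy: close to 1 when confident, 0 when uniform."""
    probs = softmax(logits / temperature, axis=0)
    entropy = -np.sum(probs * np.log(probs + 1e-12), axis=0)
    return 1.0 - entropy / np.log(probs.shape[0])

# Toy logits of shape (d=4, H=1, W=2): one confident pixel, one ambiguous pixel.
toy_logits = np.array([[[4.0, 0.5]], [[1.0, 0.4]], [[0.0, 0.3]], [[-1.0, 0.2]]])
for T in (0.5, 1.0, 2.0):
    print(T, scaled_mcp(toy_logits, T), entropy_confidence(toy_logits, T))
```

A temperature above 1 flattens the softmax and lowers both scores, while a temperature below 1 sharpens it. Fitting such a parameter on validation data is what moves the average confidence towards the observed performance.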
[11]:
def test_cls(estim_algorithm, mode, confidence_scores, class_specific, logits, gt, logits_tests, gt_tests):
    """Fit MOVAL under a given configuration for segmentation tasks
    Args:
        estim_algorithm (str):
            The algorithm to estimate model performance. We also provide a list of estimation algorithms which can be displayed by
            running :py:func:`moval.models.get_estim_options`.
        mode (str): The given task to estimate model performance.
        confidence_scores (str):
            The method to calculate the confidence scores. We provide a list of confidence score calculation methods which
            can be displayed by running :py:func:`moval.models.get_conf_options`.
        class_specific (bool):
            If ``True``, the calculation will match class-wise confidence to class-wise accuracy.
        logits: The network output (logits) of a list of n ``(d, H, W, (D))`` for segmentation.
        gt: The corresponding annotation of a list of n ``(H, W, (D))`` for segmentation.
        logits_tests: The network testing output (logits) of a list of n' ``(d, H', W', (D'))`` for segmentation.
        gt_tests: The corresponding testing annotation of a list of n' ``(H', W', (D'))`` for segmentation.
            Not used inside this function.
    Returns:
        moval_model: Optimized moval model.
    """
    moval_model = moval.MOVAL(
        mode = mode,
        metric = "f1score",
        confidence_scores = confidence_scores,
        estim_algorithm = estim_algorithm,
        class_specific = class_specific,
        approximate = True,
        approximate_boundary = 10
    )
    #
    moval_model.fit(logits, gt)
    # estimate the test DSC from the test logits (computed here but not returned).
    estim_dsc_test = moval_model.estimate(logits_tests)
    return moval_model
[12]:
moval_model_MCP = test_cls(
estim_algorithm = moval_options[0][0],
mode = moval_options[0][1],
confidence_scores = moval_options[0][2],
class_specific = moval_options[0][3],
logits = logits_val,
gt = gt_val,
logits_tests = logits_test,
gt_tests = gt_test
)
#
moval_model_baseline = test_cls(
estim_algorithm = moval_options[1][0],
mode = moval_options[1][1],
confidence_scores = moval_options[1][2],
class_specific = moval_options[1][3],
logits = logits_val,
gt = gt_val,
logits_tests = logits_test,
gt_tests = gt_test
)
#
moval_model_cs = test_cls(
estim_algorithm = moval_options[2][0],
mode = moval_options[2][1],
confidence_scores = moval_options[2][2],
class_specific = moval_options[2][3],
logits = logits_val,
gt = gt_val,
logits_tests = logits_test,
gt_tests = gt_test
)
#
moval_model_cs_entropy = test_cls(
estim_algorithm = moval_options[3][0],
mode = moval_options[3][1],
confidence_scores = moval_options[3][2],
class_specific = moval_options[3][3],
logits = logits_val,
gt = gt_val,
logits_tests = logits_test,
gt_tests = gt_test
)
Starting optimizing for model ac-model with confidence max_class_probability-conf based on metric f1score, class specific is False.
Calculating and saving the fitted case-wise performance...
Starting optimizing for model ts-model with confidence max_class_probability-conf based on metric f1score, class specific is False.
Opitimizing with 100 samples...
Be patient, it should take a while...
Calculating and saving the fitted case-wise performance...
Starting optimizing for model ts-model with confidence max_class_probability-conf based on metric f1score, class specific is True.
Opitimizing with 100 samples...
Be patient, it should take a while...
Calculating and saving the fitted case-wise performance...
Starting optimizing for model ts-atc-model with confidence entropy-conf based on metric f1score, class specific is True.
Opitimizing with 100 samples...
Be patient, it should take a while...
Calculating and saving the fitted case-wise performance...
Get the confidence scores
[13]:
conf_mcp = moval_model_MCP.model_.calibrate(logits_test) # n list of (H, W)
conf_baseline = moval_model_baseline.model_.calibrate(logits_test) # n list of (H, W)
conf_cs = moval_model_cs.model_.calibrate(logits_test) # n list of (H, W)
conf_cs_entropy = moval_model_cs_entropy.model_.calibrate(logits_test) # n list of (H, W)
[14]:
def caculate_dsc(score, inp, gt):
    """Calculate real dsc and estimated dsc
    Args:
        score: Confidence scores of a list of n ``(H, W, (D))``.
        inp: Network output of a list of n ``(d, H, W, (D))``.
        gt: Ground truth segmentation of a list of n ``(H, W, (D))``.
    Returns:
        estim_dsc: The mean estimated DSC for each case, of shape ``(n,)``.
        real_dsc: The mean real DSC for each case, of shape ``(n,)``.
    """
    from moval.models.utils import SoftDiceLoss
    num_class = inp[0].shape[0]
    estim_dsc_list = []
    #
    pred_all_flatten_bg = []
    gt_all_flatten_bg = []
    dsc = []
    for n_case in range(len(inp)):
        pred_case = np.argmax(inp[n_case], axis = 0) # ``(H, W, (D))``
        pred_flatten = pred_case.flatten() # ``n``
        score_case = score[n_case] # ``(H, W, (D))``
        score_flatten = score[n_case].flatten() # ``n``
        # place each pixel's confidence in the channel of its predicted class
        score_filled = np.zeros((score_flatten.shape + (num_class,))) # ``(n, d)``
        score_filled[np.arange(score_filled.shape[0]), pred_flatten] = score_flatten # ``(n, d)``
        score_filled = score_filled.T # ``(d, n)``
        score_filled = score_filled.reshape(((num_class,) + score_case.shape)) # ``(d, H, W, (D))``
        #
        estim_dsc = SoftDiceLoss(score_filled[np.newaxis, ...], pred_case[np.newaxis, ...])
        #
        gt_case = gt[n_case] # ``(H, W, (D))``
        pos_bg = np.where(pred_case.flatten() == 0)[0]
        pred_all_flatten_bg.append(pred_case.flatten()[pos_bg])
        gt_all_flatten_bg.append(gt_case.flatten()[pos_bg])
        dsc_case = np.zeros(inp[n_case].shape[0])
        for kcls in range(1, inp[n_case].shape[0]):
            if np.sum(gt_case == kcls) == 0:
                # class absent from the ground truth: mark with -1 and exclude from the mean
                estim_dsc[kcls] = -1
                dsc_case[kcls] = -1
            else:
                dsc_cal, _, _ = ComputMetric(pred_case == kcls, gt_case == kcls)
                dsc_case[kcls] = dsc_cal
        # average over the foreground classes that are present in the ground truth
        estim_dsc_list.append(np.mean(estim_dsc[1:][dsc_case[1:] >= 0]))
        dsc.append(np.mean(dsc_case[1:][dsc_case[1:] >= 0]))
    return estim_dsc_list, dsc
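For reference, the real DSC of one class is the overlap between two binary masks, 2|P ∩ G| / (|P| + |G|). The sketch below is a standalone NumPy version written for illustration; `dice_score` is a hypothetical helper and is not the `ComputMetric` function imported from MOVAL, whose return values and edge-case handling may differ.

```python
import numpy as np

def dice_score(pred_mask, gt_mask):
    """Dice similarity coefficient between two binary masks of the same shape."""
    pred_mask = np.asarray(pred_mask, dtype=bool)
    gt_mask = np.asarray(gt_mask, dtype=bool)
    denom = pred_mask.sum() + gt_mask.sum()
    if denom == 0:
        # Both masks are empty: the score is undefined, so return NaN.
        return float("nan")
    return 2.0 * np.logical_and(pred_mask, gt_mask).sum() / denom

# Toy label maps: class 1 covers 3 pixels in each map, 2 of them overlap.
toy_pred = np.array([[1, 1, 0], [0, 1, 0]])
toy_gt = np.array([[1, 0, 0], [0, 1, 1]])
toy_dsc = dice_score(toy_pred == 1, toy_gt == 1)
print(f"DSC for class 1: {toy_dsc:.3f}")
```

With 2 overlapping pixels out of 3 + 3, the toy example gives 4 / 6.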
Plot the discrepancy between estimated and real DSC
[15]:
def lineplot(estim_dsc, dsc, title = "Estimation Error with MCP"):
    """Make the line plot to show the estimated segmentation error
    Args:
        estim_dsc: The estimated dsc for all the cases with shape ``(n, )``
        dsc: The real dsc for all the cases with shape ``(n, )``
        title: The given title to be shown.
    """
    d = {'Estimated DSC': estim_dsc, 'Real DSC': dsc}
    df = pd.DataFrame(data=d)
    # sort the cases from the highest to the lowest real DSC
    df = df.sort_values(by="Real DSC", ascending=False, na_position='first')
    df = df.reset_index(drop=True)
    df.insert(2, "ID", range(len(dsc)), True)
    sns.set(rc={'figure.figsize':(12, 3)})
    sns.set_style("darkgrid")
    # Line plot
    ax = sns.lineplot(
        data=df,
        x="ID", y="Real DSC",
        color='#0074b6',
        label='Real DSC'
    )
    # Dot plot
    sns.scatterplot(
        data=df,
        x="ID", y="Real DSC",
        color='#0074b6',
        marker='o',
        s=50,
        label=None
    )
    # Fill between the lines
    plt.fill_between(
        df['ID'],
        df['Estimated DSC'],
        df['Real DSC'],
        color='#a52a2a',
        alpha=0.2,
        label='Estimation Error'
    )
    ax.set_xticks([])
    # Add legend
    plt.legend(bbox_to_anchor=(1, 1), loc='upper left')
    plt.title(title)
    plt.show()
[16]:
estim_dsc_mcp, dsc_mcp = caculate_dsc(conf_mcp, logits_test, gt_test)
[17]:
lineplot(estim_dsc_mcp, dsc_mcp, title = "Estimation Error with MCP")
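The shaded area in the plot is the gap between estimated and real DSC for each case. One simple way to summarize that gap as a single number is the mean absolute error, sketched below with made-up values; `mean_absolute_error` is a hypothetical helper and the toy numbers are not results from this notebook.

```python
import numpy as np

def mean_absolute_error(estimated, real):
    """Mean absolute difference between estimated and real per-case scores."""
    estimated = np.asarray(estimated, dtype=float)
    real = np.asarray(real, dtype=float)
    return float(np.mean(np.abs(estimated - real)))

# Made-up per-case DSC values, for illustration only.
toy_estimated = [0.90, 0.80, 0.70]
toy_real = [0.85, 0.60, 0.75]
toy_mae = mean_absolute_error(toy_estimated, toy_real)
print(f"MAE between estimated and real DSC: {toy_mae:.3f}")
```

In the notebook, the same helper could be applied to `estim_dsc_mcp` and `dsc_mcp` to compare confidence options numerically rather than only visually.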
Plot the confidence map
[18]:
def visual_two_case(conf_maps, imgs_toshow, lbls_toshow, preds_toshow, _real_dscs, _estim_dscs, title = 'The Visualization of MCP for Cases with High and Low DSC'):
    """Visualization of confidence score for two cases.
    Args:
        conf_maps: A list of confidence maps, of shape ``(H, W)``.
        imgs_toshow: A list of images, of shape ``(H, W)``.
        lbls_toshow: A list of labels, of shape ``(H, W)``.
        preds_toshow: A list of predictions, of shape ``(H, W)``.
        _real_dscs: A list of real DSC (float).
        _estim_dscs: A list of estimated DSC (float).
        title: The given title to be shown.
    """
    from matplotlib.gridspec import GridSpec
    from matplotlib.colors import to_rgba
    fig = plt.figure(figsize=(15, 8))
    gs = GridSpec(2, 3, width_ratios=[1, 1, 1.2])  # Adjust the ratios as needed
    cmap = sns.color_palette("ch:s=-.2,r=.6", as_cmap=True)
    ax0 = plt.subplot(gs[0])
    ax1 = plt.subplot(gs[1])
    ax2 = plt.subplot(gs[2])
    ax3 = plt.subplot(gs[3])
    ax4 = plt.subplot(gs[4])
    ax5 = plt.subplot(gs[5])
    # Define colors for each class
    class_colors = ['', 'green', 'blue', 'red']  # Add more colors if you have more classes
    # First row: the first case
    ax0.imshow(imgs_toshow[0], cmap='gray')
    ax0.set_title('Image Overlapped with Ground Truth', x=0.5, y=1.05)
    # Overlay the label map using class-specific colors
    for i in range(len(class_colors)):
        if i == 0:  # Skip label 0
            continue
        mask = lbls_toshow[0] == i
        label_map_rgb = np.zeros((*lbls_toshow[0].shape, 4))
        label_map_rgb[mask, :] = to_rgba(class_colors[i])
        ax0.imshow(label_map_rgb, alpha=0.5)
    ax0.set_xticks([])
    ax0.set_yticks([])
    sns.heatmap(np.abs(preds_toshow[0] - lbls_toshow[0]), cmap=cmap, xticklabels=False, yticklabels=False, cbar=False, ax=ax1)
    ax1.set_title('Absolute Error between Prediction and Ground Truth', x=0.5, y=1.1)
    sns.heatmap(1 - conf_maps[0], cmap=cmap, xticklabels=False, yticklabels=False, cbar_kws={'format': ''}, ax=ax2)
    cbar = ax2.collections[0].colorbar
    ax2.set_title('Confidence Score', x=0.5, y=1.1)
    cbar.set_ticks([])
    cbar.ax.text(2.5, 0.04, 'min', ha='center', va='center')
    cbar.ax.text(2.5, 0.5, 'max', ha='center', va='center')
    # Second row: the second case
    ax3.imshow(imgs_toshow[1], cmap='gray')
    ax3.set_title('Image Overlapped with Ground Truth', x=0.5, y=1.05)
    # Overlay the label map using class-specific colors
    for i in range(len(class_colors)):
        if i == 0:  # Skip label 0
            continue
        mask = lbls_toshow[1] == i
        label_map_rgb = np.zeros((*lbls_toshow[1].shape, 4))
        label_map_rgb[mask, :] = to_rgba(class_colors[i])
        ax3.imshow(label_map_rgb, alpha=0.5)
    ax3.set_xticks([])
    ax3.set_yticks([])
    sns.heatmap(np.abs(preds_toshow[1] - lbls_toshow[1]), cmap=cmap, xticklabels=False, yticklabels=False, cbar=False, ax=ax4)
    ax4.set_title('Absolute Error between Prediction and Ground Truth', x=0.5, y=1.1)
    sns.heatmap(1 - conf_maps[1], cmap=cmap, xticklabels=False, yticklabels=False, cbar_kws={'format': ''}, ax=ax5)
    cbar = ax5.collections[0].colorbar
    ax5.set_title('Confidence Score', x=0.5, y=1.1)
    cbar.set_ticks([])
    cbar.ax.text(2.5, 0.04, 'min', ha='center', va='center')
    cbar.ax.text(2.5, 0.5, 'max', ha='center', va='center')
    # Annotate the real and estimated DSC above the corresponding panels
    ax1.text(0.5, 1.05, f'Real DSC is {_real_dscs[0]:.3f}', transform=ax1.transAxes, ha='center', va='center', fontsize=12)
    ax4.text(0.5, 1.05, f'Real DSC is {_real_dscs[1]:.3f}', transform=ax4.transAxes, ha='center', va='center', fontsize=12)
    ax2.text(0.5, 1.05, f'Estimated DSC is {_estim_dscs[0]:.3f}', transform=ax2.transAxes, ha='center', va='center', fontsize=12)
    ax5.text(0.5, 1.05, f'Estimated DSC is {_estim_dscs[1]:.3f}', transform=ax5.transAxes, ha='center', va='center', fontsize=12)
    fig.suptitle(title, fontsize=16, y=1)
[19]:
dsc_sort = np.argsort(dsc_mcp)
imgs_toshow = [imgs_test[dsc_sort[-1]], imgs_test[dsc_sort[0]]]
lbls_toshow = [gt_test[dsc_sort[-1]], gt_test[dsc_sort[0]]]
preds_toshow = [np.argmax(logits_test[dsc_sort[-1]], axis=0), np.argmax(logits_test[dsc_sort[0]], axis=0)]
_real_dscs = [dsc_mcp[dsc_sort[-1]], dsc_mcp[dsc_sort[0]]]
#
conf_maps = [conf_mcp[dsc_sort[-1]], conf_mcp[dsc_sort[0]]]
_estim_dscs = [estim_dsc_mcp[dsc_sort[-1]], estim_dsc_mcp[dsc_sort[0]]]
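The cell above relies on `np.argsort` returning indices in ascending order of DSC, so the first index points to the worst case and the last index to the best case. A small sketch with made-up scores illustrates the indexing:

```python
import numpy as np

# Made-up per-case DSC values, for illustration only.
toy_dsc = [0.82, 0.35, 0.91, 0.60]
order = np.argsort(toy_dsc)          # indices sorted from lowest to highest DSC
worst_idx, best_idx = order[0], order[-1]
print(f"best case: {best_idx} (DSC {toy_dsc[best_idx]}), worst case: {worst_idx} (DSC {toy_dsc[worst_idx]})")
```

This is why each list in the cell is built as `[..., dsc_sort[-1]], [..., dsc_sort[0]]`: the high-DSC case is shown first and the low-DSC case second.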
[20]:
visual_two_case(conf_maps, imgs_toshow, lbls_toshow, preds_toshow, _real_dscs, _estim_dscs, title = 'The Visualization of MCP for Cases with High and Low DSC')