Analyze Model Behaviour of 2D Segmentation

Consider downloading this Jupyter Notebook and running it locally, or test it with Colab.

  • In this notebook, we show how to analyze model behaviour using MOVAL for multi-class 2D segmentation tasks.

  • We provide the model-predicted 2D segmentation results (network logits) for this tutorial, which will be downloaded automatically. We also provide the model training code at https://github.com/ZerojumpLine/Robust-Medical-Segmentation.

  • More specifically, we show an example of analyzing model behaviour under domain shift on cardiac MRI segmentation (into 4 classes: background, left ventricle (LV), myocardium (MYO) and right ventricle (RV)) based on a 3D U-Net. We use the logits computed on a test dataset acquired with a different scanner.

  • We will visualize the calibrated confidence scores as a proper indication of missegmentation.
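
To make the notion of a per-pixel confidence score concrete, the sketch below computes the maximum class probability (MCP) from raw logits with plain NumPy. It runs on synthetic values only; the helper name `max_class_probability` is ours for illustration, and MOVAL's internal implementation may differ.

```python
import numpy as np

def max_class_probability(logits):
    """Per-pixel maximum softmax probability for logits of shape (d, H, W)."""
    # Subtract the per-pixel maximum before exponentiating for numerical stability.
    shifted = logits - logits.max(axis=0, keepdims=True)
    probs = np.exp(shifted)
    probs /= probs.sum(axis=0, keepdims=True)
    return probs.max(axis=0)  # shape (H, W)

# Synthetic logits for 4 classes on a 2 x 2 slice (not taken from the dataset).
toy_logits = np.zeros((4, 2, 2))
toy_logits[1, 0, 0] = 5.0  # one pixel where class 1 clearly dominates
conf_map = max_class_probability(toy_logits)
print(conf_map)
```

Pixels where one class dominates get a confidence close to 1, while pixels with equal logits for all 4 classes get the minimum possible value of 0.25.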

[1]:
!pip install moval
!pip install seaborn
!pip install pandas
!pip install tqdm
!pip install matplotlib
!pip install nibabel
[2]:
import os
import gdown
import itertools
import zipfile
import pandas as pd
import numpy as np
import nibabel as nib
import moval
from moval.solvers.utils import ComputMetric
from tqdm import tqdm
import seaborn as sns
import matplotlib.pyplot as plt
[3]:
print(f"The installed MOVAL version is {moval.__version__}")
The installed MOVAL version is 0.3.16

Load the data

[4]:
# download the cardiac segmentation results (logits and labels)

output = "data_moval_supp.zip"
if not os.path.exists(output):
    url = "https://drive.google.com/u/0/uc?id=1ZlC66MGmPlf05aYYCKBaRT2q5uod8GFk&export=download"
    output = "data_moval_supp.zip"
    gdown.download(url, output, quiet=False)

directory_data = "data_moval_supp"
if not os.path.exists(directory_data):
    with zipfile.ZipFile(output, 'r') as zip_ref:
        zip_ref.extractall(directory_data)
[5]:
# download the corresponding image data

output = "img_cardiac.zip"
if not os.path.exists(output):
    url = "https://drive.google.com/u/0/uc?id=1kS5V69dfdPEGiMfauMLuhi76Vrb8VB5k&export=download"
    output = "img_cardiac.zip"
    gdown.download(url, output, quiet=False)

directory_data = "img_cardiac"
if not os.path.exists(directory_data):
    with zipfile.ZipFile(output, 'r') as zip_ref:
        zip_ref.extractall(directory_data)
[6]:
ls
analysis_cls.ipynb    data_moval_supp.zip   img_cifar/
analysis_seg2d.ipynb  estim_cls.ipynb       img_cifar.zip
analysis_seg3d.ipynb  estim_seg2d.ipynb     img_prostate/
data_moval/           estim_seg3d.ipynb     img_prostate.zip
data_moval.zip        img_cardiac/
data_moval_supp/      img_cardiac.zip
[7]:
# load the validation and test data for cardiac segmentation

Datafile_eval = "data_moval_supp/Cardiacresults/seg-eval.txt"
with open(Datafile_eval) as Imglist_eval:
    Imglist_eval_read = Imglist_eval.read().splitlines()

logits = []
gt = []
for Imgname_eval in Imglist_eval_read:
    #
    caseID = Imgname_eval.split("/")[-2]
    #
    GT_file = f"data_moval_supp/Cardiacresults/GT/1/{caseID}/seg.nii.gz"
    #
    logit_cls0_file = "data_moval_supp/Cardiacresults/cardiacval/results/pred_" + caseID + "cls0_prob.nii.gz"
    logit_cls1_file = "data_moval_supp/Cardiacresults/cardiacval/results/pred_" + caseID + "cls1_prob.nii.gz"
    logit_cls2_file = "data_moval_supp/Cardiacresults/cardiacval/results/pred_" + caseID + "cls2_prob.nii.gz"
    logit_cls3_file = "data_moval_supp/Cardiacresults/cardiacval/results/pred_" + caseID + "cls3_prob.nii.gz"
    #
    logit_cls0_read = nib.load(logit_cls0_file)
    logit_cls1_read = nib.load(logit_cls1_file)
    logit_cls2_read = nib.load(logit_cls2_file)
    logit_cls3_read = nib.load(logit_cls3_file)
    #
    logit_cls0      = logit_cls0_read.get_fdata()   # ``(H, W, D)``
    logit_cls1      = logit_cls1_read.get_fdata()
    logit_cls2      = logit_cls2_read.get_fdata()
    logit_cls3      = logit_cls3_read.get_fdata()
    #
    GT_read         = nib.load(GT_file)
    GTimg           = GT_read.get_fdata()           # ``(H, W, D)``
    #
    logit_cls = np.stack((logit_cls0, logit_cls1, logit_cls2, logit_cls3))  # ``(d, H, W, D)``
    # only include the slices that contain labels
    for dslice in range(GTimg.shape[2]):
        if np.sum(GTimg[:, :, dslice]) > 0:
            logits.append(logit_cls[:, :, :, dslice])
            gt.append(GTimg[:, :, dslice])

# logits is a list of length ``n``,  each element has ``(d, H, W)``.
# gt is a list of length ``n``,  each element has ``(H, W)``.
# H and W could differ for different cases.

caseIDs = ['A5D0G0_0', 'B0N3W8_15', 'E5S7W7_12', 'H7N4V9_0', 'I0I2J8_7']

logits_test = []
gt_test = []
imgs_test = []
for caseID in caseIDs:
    GT_file = f"data_moval_supp/Cardiacresults/GT/2/{caseID}/seg.nii.gz"
    #
    logit_cls0_file = "data_moval_supp/Cardiacresults/cardiactest_2/results/pred_" + caseID + "cls0_prob.nii.gz"
    logit_cls1_file = "data_moval_supp/Cardiacresults/cardiactest_2/results/pred_" + caseID + "cls1_prob.nii.gz"
    logit_cls2_file = "data_moval_supp/Cardiacresults/cardiactest_2/results/pred_" + caseID + "cls2_prob.nii.gz"
    logit_cls3_file = "data_moval_supp/Cardiacresults/cardiactest_2/results/pred_" + caseID + "cls3_prob.nii.gz"
    #
    logit_cls0_read = nib.load(logit_cls0_file)
    logit_cls1_read = nib.load(logit_cls1_file)
    logit_cls2_read = nib.load(logit_cls2_file)
    logit_cls3_read = nib.load(logit_cls3_file)
    #
    logit_cls0      = logit_cls0_read.get_fdata()   # ``(H, W, D)``
    logit_cls1      = logit_cls1_read.get_fdata()
    logit_cls2      = logit_cls2_read.get_fdata()
    logit_cls3      = logit_cls3_read.get_fdata()
    #
    img_file = f"img_cardiac/img_cardiac/2/{caseID}/image.nii.gz"
    img_read = nib.load(img_file)
    img_data      = img_read.get_fdata()   # ``(H, W, D)``
    #
    GT_read         = nib.load(GT_file)
    GTimg           = GT_read.get_fdata()           # ``(H, W, D)``
    logit_cls = np.stack((logit_cls0, logit_cls1, logit_cls2, logit_cls3))  # ``(d, H, W, D)``
    # only include the slices that contain labels
    for dslice in range(GTimg.shape[2]):
        if np.sum(GTimg[:, :, dslice]) > 0:
            logits_test.append(logit_cls[:, :, :, dslice])
            gt_test.append(GTimg[:, :, dslice])
            imgs_test.append(img_data[:, :, dslice])

# logits_test is a list of length ``n``,  each element has ``(d, H, W)``.
# gt_test is a list of length ``n``,  each element has ``(H, W)``.
# H and W could differ for different cases.
[8]:
print(f"The validation predictions, ``logits`` are a list of length {len(logits)} each element has approximately {logits[0].shape}")
print(f"The validation labels, ``gt`` are a list of length {len(gt)}, each element has approximately {gt[0].shape}\n")
print(f"The test predictions, ``logits_test`` are a list of length {len(logits_test)} each element has approximately {logits_test[0].shape}")
print(f"The test labels, ``gt_test`` are a list of length {len(gt_test)}, each element has approximately {gt_test[0].shape}")
print(f"The test imgs, ``imgs_test`` are a list of length {len(imgs_test)}, each element has approximately {imgs_test[0].shape}")
The validation predictions, ``logits`` are a list of length 156 each element has approximately (4, 210, 257)
The validation labels, ``gt`` are a list of length 156, each element has approximately (210, 257)

The test predictions, ``logits_test`` are a list of length 38 each element has approximately (4, 338, 338)
The test labels, ``gt_test`` are a list of length 38, each element has approximately (338, 338)
The test imgs, ``imgs_test`` are a list of length 38, each element has approximately (338, 338)
[9]:
import random
random.seed(79)
test_inds = list(range(len(logits)))
random.shuffle(test_inds)
test_inds = test_inds[:100]
#
_logits = []
_gt = []
for test_ind in test_inds:
    _logits.append(logits[test_ind])
    _gt.append(gt[test_ind])
logits_val = _logits
gt_val = _gt
#
print(f"The validation predictions, ``logits_val`` are a list of length {len(logits_val)} each element has approximately {logits_val[0].shape}")
print(f"The validation labels, ``gt_val`` are a list of length {len(gt_val)}, each element has approximately {gt_val[0].shape}")
The validation predictions, ``logits_val`` are a list of length 100 each element has approximately (4, 223, 272)
The validation labels, ``gt_val`` are a list of length 100, each element has approximately (223, 272)

MOVAL estimation

[10]:
moval_options = []
moval_options.append(['ac-model', 'segmentation', 'max_class_probability-conf', False])
moval_options.append(['ts-model', 'segmentation', 'max_class_probability-conf', False])
moval_options.append(['ts-model', 'segmentation', 'max_class_probability-conf', True])
moval_options.append(['ts-atc-model', 'segmentation', 'entropy-conf', True])
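
The option names above are not explained in this notebook. Assuming that `ts` refers to temperature scaling (our reading; please check the MOVAL documentation), the sketch below illustrates the general idea on synthetic logits: dividing the logits by a temperature larger than 1 softens the softmax and lowers the maximum class probability. The helper names are hypothetical and this is not MOVAL's implementation.

```python
import numpy as np

def softmax(logits, axis=0):
    """Numerically stable softmax along the class axis."""
    shifted = logits - logits.max(axis=axis, keepdims=True)
    e = np.exp(shifted)
    return e / e.sum(axis=axis, keepdims=True)

def scaled_confidence(logits, temperature=1.0):
    """Max class probability after dividing the logits by a temperature."""
    return softmax(logits / temperature, axis=0).max(axis=0)

# Synthetic logits for one pixel with 4 classes, shape (d, H, W) = (4, 1, 1).
toy_logits = np.array([4.0, 1.0, 0.0, 0.0]).reshape(4, 1, 1)
conf_raw = scaled_confidence(toy_logits, temperature=1.0)
conf_soft = scaled_confidence(toy_logits, temperature=2.0)
print(conf_raw.item(), conf_soft.item())
```

In a calibration setting, the temperature would be fitted on validation data so that the confidence matches the observed performance; here it is fixed by hand purely for illustration.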
[11]:
def test_cls(estim_algorithm, mode, confidence_scores, class_specific, logits, gt, logits_tests, gt_tests):
    """Fit MOVAL under a given configuration for segmentation tasks

    Args:
        estim_algorithm (str):
            The algorithm to estimate model performance. We also provide a list of estimation algorithms which can be displayed by
            running :py:func:`moval.models.get_estim_options`.
        mode (str): The given task to estimate model performance.
        confidence_scores (str):
            The method to calculate the confidence scores. We provide a list of confidence score calculation methods which
            can be displayed by running :py:func:`moval.models.get_conf_options`.
        class_specific (bool):
            If ``True``, the calculation will match class-wise confidence to class-wise accuracy.
        logits: The network output (logits) of a list of n ``(d, H, W, (D))`` for segmentation.
        gt: The corresponding annotation of a list of n ``(H, W, (D))`` for segmentation.
        logits_tests: The network testing output (logits) of a list of n' ``(d, H', W', (D'))`` for segmentation.
        gt_tests: The corresponding testing annotation of a list of n' ``(H', W', (D'))`` for segmentation.
            Not used inside this function.

    Returns:
        moval_model: Optimized moval model.

    """

    moval_model = moval.MOVAL(
                mode = mode,
                metric = "f1score",
                confidence_scores = confidence_scores,
                estim_algorithm = estim_algorithm,
                class_specific = class_specific,
                approximate = True,
                approximate_boundary = 10
                )

    # fit the calibration parameters on the validation data
    moval_model.fit(logits, gt)

    # estimate the test DSC from the test logits passed in as an argument
    # (the estimate is not used further in this notebook)
    estim_dsc_test = moval_model.estimate(logits_tests)

    return moval_model
[12]:
moval_model_MCP = test_cls(
        estim_algorithm = moval_options[0][0],
        mode = moval_options[0][1],
        confidence_scores = moval_options[0][2],
        class_specific = moval_options[0][3],
        logits = logits_val,
        gt = gt_val,
        logits_tests = logits_test,
        gt_tests = gt_test
    )
#
moval_model_baseline = test_cls(
        estim_algorithm = moval_options[1][0],
        mode = moval_options[1][1],
        confidence_scores = moval_options[1][2],
        class_specific = moval_options[1][3],
        logits = logits_val,
        gt = gt_val,
        logits_tests = logits_test,
        gt_tests = gt_test
    )
#
moval_model_cs = test_cls(
        estim_algorithm = moval_options[2][0],
        mode = moval_options[2][1],
        confidence_scores = moval_options[2][2],
        class_specific = moval_options[2][3],
        logits = logits_val,
        gt = gt_val,
        logits_tests = logits_test,
        gt_tests = gt_test
    )
#
moval_model_cs_entropy = test_cls(
        estim_algorithm = moval_options[3][0],
        mode = moval_options[3][1],
        confidence_scores = moval_options[3][2],
        class_specific = moval_options[3][3],
        logits = logits_val,
        gt = gt_val,
        logits_tests = logits_test,
        gt_tests = gt_test
    )
Starting optimizing for model ac-model with confidence max_class_probability-conf based on metric f1score, class specific is False.
Calculating and saving the fitted case-wise performance...
Starting optimizing for model ts-model with confidence max_class_probability-conf based on metric f1score, class specific is False.
Opitimizing with 100 samples...
Be patient, it should take a while...
Calculating and saving the fitted case-wise performance...
Starting optimizing for model ts-model with confidence max_class_probability-conf based on metric f1score, class specific is True.
Opitimizing with 100 samples...
Be patient, it should take a while...
Calculating and saving the fitted case-wise performance...
Starting optimizing for model ts-atc-model with confidence entropy-conf based on metric f1score, class specific is True.
Opitimizing with 100 samples...
Be patient, it should take a while...
Calculating and saving the fitted case-wise performance...

Get the confidence scores

[13]:
conf_mcp = moval_model_MCP.model_.calibrate(logits_test) # n list of (H, W)
conf_baseline = moval_model_baseline.model_.calibrate(logits_test)  # n list of (H, W)
conf_cs = moval_model_cs.model_.calibrate(logits_test)  # n list of (H, W)
conf_cs_entropy = moval_model_cs_entropy.model_.calibrate(logits_test)  # n list of (H, W)
[14]:
def caculate_dsc(score, inp, gt):
    """Calculate real dsc and estimated dsc

    Args:
        score: Confidence scores of a list of n ``(H, W, (D))``.
        inp: Network output of a list of n ``(d, H, W, (D))``
        gt: Ground truth segmentation of a list of n ``(H, W, (D))``.

    Returns:
        estim_dsc: The mean estimated DSC for each case, a list of length ``n``.
        real_dsc: The mean real DSC for each case, a list of length ``n``.

    """
    from moval.models.utils import SoftDiceLoss
    num_class = inp[0].shape[0]
    estim_dsc_list = []
    #
    pred_all_flatten_bg = []
    gt_all_flatten_bg = []
    dsc = []
    for n_case in range(len(inp)):
        pred_case = np.argmax(inp[n_case], axis = 0) # ``(H, W, (D))``
        pred_flatten = pred_case.flatten() # ``n``
        score_case = score[n_case] # ``(H, W, (D))``
        score_flatten = score[n_case].flatten() # ``n``
        score_filled = np.zeros((score_flatten.shape + (num_class,))) # ``(n, d)``
        score_filled[np.arange(score_filled.shape[0]), pred_flatten] = score_flatten # ``(n, d)``
        score_filled = score_filled.T # ``(d, n)``
        score_filled = score_filled.reshape(((num_class,) + score_case.shape)) # ``(d, H, W, (D))``
        #
        estim_dsc = SoftDiceLoss(score_filled[np.newaxis, ...], pred_case[np.newaxis, ...])

        #
        gt_case     = gt[n_case] # ``(H, W, (D))``
        pos_bg = np.where(pred_case.flatten() == 0)[0]

        pred_all_flatten_bg.append(pred_case.flatten()[pos_bg])
        gt_all_flatten_bg.append(gt_case.flatten()[pos_bg])

        dsc_case = np.zeros(inp[n_case].shape[0])
        for kcls in range(1, inp[n_case].shape[0]):
            if np.sum(gt_case == kcls) == 0:
                estim_dsc[kcls] = -1
                dsc_case[kcls] = -1
            else:
                dsc_cal, _, _ = ComputMetric(pred_case == kcls, gt_case == kcls)
                dsc_case[kcls] = dsc_cal

        estim_dsc_list.append(np.mean(estim_dsc[1:][dsc_case[1:] >= 0]))
        dsc.append(np.mean(dsc_case[1:][dsc_case[1:] >= 0]))

    return estim_dsc_list, dsc
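
For readers unfamiliar with the Dice similarity coefficient (DSC) used above, the following is a minimal NumPy sketch of the standard definition, 2|A ∩ B| / (|A| + |B|), for one binary mask pair. The function `dice_score` is a hypothetical helper written for illustration; it is not claimed to reproduce `ComputMetric` exactly.

```python
import numpy as np

def dice_score(pred_mask, gt_mask):
    """Dice similarity coefficient between two binary masks of the same shape."""
    pred_mask = np.asarray(pred_mask, dtype=bool)
    gt_mask = np.asarray(gt_mask, dtype=bool)
    denom = pred_mask.sum() + gt_mask.sum()
    if denom == 0:
        # Both masks are empty, so the score is undefined.
        return float("nan")
    return 2.0 * np.logical_and(pred_mask, gt_mask).sum() / denom

# Synthetic 2 x 3 label maps (not taken from the dataset).
pred = np.array([[1, 1, 0], [0, 0, 0]])
ref = np.array([[1, 0, 0], [0, 1, 0]])
dsc_toy = dice_score(pred == 1, ref == 1)
print(dsc_toy)  # 0.5: one overlapping pixel, two foreground pixels in each mask
```

As in `caculate_dsc`, a multi-class score can be obtained by evaluating each foreground class separately and averaging over the classes that are present in the ground truth.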

Plot the discrepancy between estimated and real DSC

[15]:
def lineplot(estim_dsc, dsc, title = "Estimation Error with MCP"):
    """Make the line plot to show the estimated segmentation error

    Args:
        estim_dsc: The estimated DSC for all the cases with shape ``(n, )``
        dsc: The real DSC for all the cases with shape ``(n, )``
        title: The given title to be shown.

    """

    d = {'Estimated DSC': estim_dsc, 'Real DSC': dsc}
    df = pd.DataFrame(data=d)
    df = df.sort_values(by="Real DSC", ascending=False, na_position='first')
    df = df.reset_index(drop=True)
    df.insert(2, "ID", range(len(dsc)), True)

    sns.set(rc={'figure.figsize':(12, 3)})
    sns.set_style("darkgrid")

    # Line plot
    ax = sns.lineplot(
        data=df,
        x="ID", y="Real DSC",
        color='#0074b6',
        label='Real DSC'
    )

    # Dot plot
    sns.scatterplot(
        data=df,
        x="ID", y="Real DSC",
        color='#0074b6',
        marker='o',
        s=50,
        label=None
    )

    # Fill between the lines
    plt.fill_between(
        df['ID'],
        df['Estimated DSC'],
        df['Real DSC'],
        color='#a52a2a',
        alpha=0.2,
        label='Estimation Error'
    )
    ax.set_xticks([])

    # Add legend
    plt.legend(bbox_to_anchor=(1, 1), loc='upper left')

    plt.title(title)

    plt.show()

[16]:
estim_dsc_mcp, dsc_mcp = caculate_dsc(conf_mcp, logits_test, gt_test)
[17]:
lineplot(estim_dsc_mcp, dsc_mcp, title = "Estimation Error with MCP")
../_images/demos_analysis_seg2d_23_0.png
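
The shaded area in the plot shows the per-case gap between estimated and real DSC. One simple way to summarize that gap as a single number is the mean absolute error, sketched below on synthetic values. The helper `mean_absolute_error` is ours and is not part of MOVAL.

```python
import numpy as np

def mean_absolute_error(estimated, real):
    """Mean absolute difference between estimated and real per-case scores."""
    estimated = np.asarray(estimated, dtype=float)
    real = np.asarray(real, dtype=float)
    return float(np.mean(np.abs(estimated - real)))

# Synthetic per-case scores (not results from this notebook).
toy_estim = [0.90, 0.85, 0.70]
toy_real = [0.88, 0.80, 0.40]
mae_toy = mean_absolute_error(toy_estim, toy_real)
print(f"MAE: {mae_toy:.3f}")
```

In this notebook the same helper could be applied as `mean_absolute_error(estim_dsc_mcp, dsc_mcp)`.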

Plot the confidence map

[18]:
def visual_two_case(conf_maps, imgs_toshow, lbls_toshow, preds_toshow, _real_dscs, _estim_dscs, title = 'The Visualization of MCP for Cases with High and Low DSC'):
    """Visualization of confidence score for two cases.

    Args:
        conf_maps: A list of confidence maps, each of shape ``(H, W)``.
        imgs_toshow: A list of images, each of shape ``(H, W)``.
        lbls_toshow: A list of labels, each of shape ``(H, W)``.
        preds_toshow: A list of predictions, each of shape ``(H, W)``.
        _real_dscs: A list of real DSC (float).
        _estim_dscs: A list of estimated DSC (float).
        title: The given title to be shown.

    """

    from matplotlib.gridspec import GridSpec

    fig = plt.figure(figsize=(15, 8))
    gs = GridSpec(2, 3, width_ratios=[1, 1, 1.2])  # Adjust the ratios as needed

    cmap = sns.color_palette("ch:s=-.2,r=.6", as_cmap=True)

    ax0 = plt.subplot(gs[0])
    ax1 = plt.subplot(gs[1])
    ax2 = plt.subplot(gs[2])
    ax3 = plt.subplot(gs[3])
    ax4 = plt.subplot(gs[4])
    ax5 = plt.subplot(gs[5])

    # Define colors for each class
    class_colors = ['', 'green', 'blue', 'red']  # Add more colors if you have more classes

    ax0.imshow(imgs_toshow[0], cmap='gray')
    ax0.set_title('Image Overlapped with Ground Truth', x=0.5, y=1.05)
    # Overlay the label map using class-specific colors
    for i in range(len(class_colors)):
        if i == 0:  # Skip label 0
            continue
        mask = lbls_toshow[0] == i
        label_map_rgb = np.zeros((*lbls_toshow[0].shape, 4))
        label_map_rgb[mask, :] = plt.cm.colors.to_rgba(class_colors[i])

        ax0.imshow(label_map_rgb, alpha=0.5)
    ax0.set_xticks([])
    ax0.set_yticks([])
    sns.heatmap(np.abs(preds_toshow[0] - lbls_toshow[0]), cmap=cmap, xticklabels=False, yticklabels=False, cbar=False, ax=ax1)
    ax1.set_title('Absolute Error between Prediction and Ground Truth', x=0.5, y=1.1)
    sns.heatmap(1 - conf_maps[0], cmap=cmap, xticklabels=False, yticklabels=False, cbar_kws={'format': ''}, ax=ax2)
    cbar = ax2.collections[0].colorbar
    ax2.set_title('Confidence Score', x=0.5, y=1.1)
    cbar.set_ticks([])
    cbar.ax.text(2.5, 0.04, 'min', ha='center', va='center')
    cbar.ax.text(2.5, 0.5, 'max', ha='center', va='center')
    #
    ax3.imshow(imgs_toshow[1], cmap='gray')
    ax3.set_title('Image Overlapped with Ground Truth', x=0.5, y=1.05)
    # Overlay the label map using class-specific colors
    for i in range(len(class_colors)):
        if i == 0:  # Skip label 0
            continue
        mask = lbls_toshow[1] == i
        label_map_rgb = np.zeros((*lbls_toshow[1].shape, 4))
        label_map_rgb[mask, :] = plt.cm.colors.to_rgba(class_colors[i])

        ax3.imshow(label_map_rgb, alpha=0.5)
    ax3.set_xticks([])
    ax3.set_yticks([])
    sns.heatmap(np.abs(preds_toshow[1] - lbls_toshow[1]), cmap=cmap, xticklabels=False, yticklabels=False, cbar=False, ax=ax4)
    ax4.set_title('Absolute Error between Prediction and Ground Truth', x=0.5, y=1.1)
    sns.heatmap(1 - conf_maps[1], cmap=cmap, xticklabels=False, yticklabels=False, cbar_kws={'format': ''}, ax=ax5)
    cbar = ax5.collections[0].colorbar
    ax5.set_title('Confidence Score', x=0.5, y=1.1)
    cbar.set_ticks([])
    cbar.ax.text(2.5, 0.04, 'min', ha='center', va='center')

    ax1.text(0.5, 1.05,  f'Real DSC is {_real_dscs[0]:.3f}', transform=ax1.transAxes, ha='center', va='center', fontsize=12)
    ax4.text(0.5, 1.05,  f'Real DSC is {_real_dscs[1]:.3f}', transform=ax4.transAxes, ha='center', va='center', fontsize=12)
    ax2.text(0.5, 1.05,  f'Estimated DSC is {_estim_dscs[0]:.3f}', transform=ax2.transAxes, ha='center', va='center', fontsize=12)
    ax5.text(0.5, 1.05,  f'Estimated DSC is {_estim_dscs[1]:.3f}', transform=ax5.transAxes, ha='center', va='center', fontsize=12)
    cbar.ax.text(2.5, 0.5, 'max', ha='center', va='center')

    fig.suptitle(title, fontsize=16, y=1)
[19]:
dsc_sort = np.argsort(dsc_mcp)
imgs_toshow = [imgs_test[dsc_sort[-1]], imgs_test[dsc_sort[0]]]
lbls_toshow = [gt_test[dsc_sort[-1]], gt_test[dsc_sort[0]]]
preds_toshow = [np.argmax(logits_test[dsc_sort[-1]], axis=0), np.argmax(logits_test[dsc_sort[0]], axis=0)]
_real_dscs = [dsc_mcp[dsc_sort[-1]], dsc_mcp[dsc_sort[0]]]
#
conf_maps = [conf_mcp[dsc_sort[-1]], conf_mcp[dsc_sort[0]]]
_estim_dscs = [estim_dsc_mcp[dsc_sort[-1]], estim_dsc_mcp[dsc_sort[0]]]
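
The cell above relies on `np.argsort` returning indices in ascending order, so the first index points at the case with the lowest DSC and the last index at the case with the highest. A small sketch on synthetic scores illustrates this:

```python
import numpy as np

# Synthetic per-case DSC values (not results from this notebook).
toy_dsc = [0.82, 0.35, 0.91, 0.60]

order = np.argsort(toy_dsc)  # indices sorted from lowest to highest DSC
worst_idx = int(order[0])    # index of the case with the lowest DSC
best_idx = int(order[-1])    # index of the case with the highest DSC
print(worst_idx, best_idx)
```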
[20]:
visual_two_case(conf_maps, imgs_toshow, lbls_toshow, preds_toshow, _real_dscs, _estim_dscs, title = 'The Visualization of MCP for Cases with High and Low DSC')
../_images/demos_analysis_seg2d_27_0.png