MRI Synthesis
In this notebook, we will use transforms that generate synthetic MRIs with varying contrast and resolution from label maps.
This is a reimplementation of the domain randomization approach described in
Billot, B., Greve, D.N., Puonti, O., Thielscher, A., Van Leemput, K., Fischl, B., Dalca, A.V. and Iglesias, J.E., 2023. SynthSeg: Segmentation of brain MRI scans of any contrast and resolution without retraining. Medical image analysis, 86, p.102789.
@article{billot2023synthseg,
title = {SynthSeg: Segmentation of brain MRI scans of any contrast and resolution without retraining},
author = {Billot, Benjamin and Greve, Douglas N and Puonti, Oula and Thielscher, Axel and Van Leemput, Koen and Fischl, Bruce and Dalca, Adrian V and Iglesias, Juan Eugenio and others},
journal = {Medical image analysis},
volume = {86},
pages = {102789},
year = {2023},
publisher = {Elsevier},
url = {https://www.sciencedirect.com/science/article/pii/S1361841523000506}
}
Let us download a demonstration label map from the SynthSeg repository.
!pushd $TMPDIR \
&& curl \
-L "https://github.com/BBillot/SynthSeg/raw/refs/heads/master/data/training_label_maps/training_seg_01.nii.gz" \
-o demo.nii.gz \
&& popd
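If the `!` shell escape or `curl` is not available (for example, when running the code outside Jupyter), the same file can be fetched from Python with the standard library. The sketch below uses `fetch`, a hypothetical helper that is not part of cornucopia. It assumes the file should go to `$TMPDIR` and falls back to the system temporary directory when that variable is unset. In the fallback case, the returned path should be used as `fname`, because the loading cell reads `os.environ['TMPDIR']` directly.

```python
import os
import tempfile
import urllib.request

# Same demonstration label map as the shell cell above
URL = ("https://github.com/BBillot/SynthSeg/raw/refs/heads/master/"
       "data/training_label_maps/training_seg_01.nii.gz")


def fetch(url, dest_dir=None, name="demo.nii.gz"):
    """Download `url` to `dest_dir/name` and return the local path.

    Hypothetical helper: uses $TMPDIR when set, otherwise the system
    temporary directory (an assumption, mirroring the shell cell).
    """
    if dest_dir is None:
        dest_dir = os.environ.get("TMPDIR", tempfile.gettempdir())
    path = os.path.join(dest_dir, name)
    # urlretrieve follows HTTP redirects, similar to `curl -L`
    urllib.request.urlretrieve(url, path)
    return path


# Usage (requires network access):
# fname = fetch(URL)
```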
We will be using the following transforms:
import torch
import os
import matplotlib.pyplot as plt
from cornucopia import (
    LoadTransform,                   # Load NIfTI files
    RelabelTransform,                # Ensure contiguous labels
    RandomGaussianMixtureTransform,  # Sample values from a Gaussian in each label
    IntensityTransform,              # Set of common intensity augmentations
    SynthFromLabelTransform,         # Complete "label-to-image" transform
)
First, let's load our label map and display it.
# Path to the demonstration label map
fname = os.path.join(os.environ['TMPDIR'], 'demo.nii.gz')
# Load the label map (and preserve its integer data type)
lab = LoadTransform(dtype=torch.int)(fname)
# Extract a single 2D slice
lab = lab[:, :, lab.shape[-2]//2, :]
# Ensure that labels are contiguous
lab = RelabelTransform()(lab)
# Display the label map
plt.figure(figsize=(10, 10))
plt.imshow(lab[0].T.flip(0), cmap='tab20', interpolation='nearest')
plt.axis('off')
plt.title('Labels')
plt.show()
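`RelabelTransform()` is applied above so that the label values are contiguous. The SynthSeg training label maps use sparse, FreeSurfer-style label values (for example 2, 3, 41, ...). Contiguous values 0..K-1 are more convenient whenever labels serve as indices, such as one Gaussian per label. The idea can be sketched in plain Python with `relabel_contiguous`, a hypothetical helper. We assume, without checking cornucopia's source, that values are ranked in sorted order, so a background label 0 stays 0. This is not `RelabelTransform`'s actual implementation.

```python
def relabel_contiguous(labels):
    """Map the distinct values of a nested-list label map to 0..K-1.

    Values are ranked in sorted order, so a background label 0 (if present)
    stays 0. Returns the relabeled map and the old->new lookup table.
    """
    # Collect the distinct label values present in the map
    uniques = sorted({v for row in labels for v in row})
    lookup = {old: new for new, old in enumerate(uniques)}
    relabeled = [[lookup[v] for v in row] for row in labels]
    return relabeled, lookup


# Sparse, FreeSurfer-like toy labels
toy = [[0, 2, 2], [41, 41, 3]]
new, lut = relabel_contiguous(toy)
print(new)  # [[0, 1, 1], [3, 3, 2]]
print(lut)  # {0: 0, 2: 1, 3: 2, 41: 3}
```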
We will start by defining a sequence of transforms that generates an MRI-like image from a label map. This sequence does not include any geometric transformation.
Warning
Tensors fed to a Transform layer should have a channel dimension but no batch dimension.
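This convention is why the slicing cell above keeps the leading channel axis (`lab[:, ...]`). In practice, a batched tensor of shape `(B, C, *spatial)` has to be split into samples before each sample is passed to a transform, and a channel-less slice needs a channel axis added first (`x[None]` for a tensor). The sketch below shows the pattern on nested lists. It uses two hypothetical helpers, `add_channel_dim` and `apply_per_sample`. With torch tensors, iterating over a tensor also yields its first-axis slices, and `torch.stack` would reassemble the batch.

```python
def add_channel_dim(sample):
    """Wrap a channel-less sample as a single-channel one (like x[None])."""
    return [sample]


def apply_per_sample(transform, batch):
    """Call a batch-free transform on each sample of a (B, C, *spatial) batch.

    Hypothetical helper: iterating over the batch yields (C, *spatial)
    samples, so `transform` never sees the batch axis. Returns a list;
    with torch tensors, torch.stack(...) would rebuild the batch.
    """
    return [transform(sample) for sample in batch]


# Toy "transform": doubles every value of a (C, H, W) nested list
def double(sample):
    return [[[2 * v for v in row] for row in channel] for channel in sample]


slice2d = [[1, 2], [3, 4]]                            # (H, W), no channel axis
sample = add_channel_dim(slice2d)                     # (1, H, W)
batch = [sample, add_channel_dim([[5, 6], [7, 8]])]   # (B=2, 1, H, W)
print(apply_per_sample(double, batch))
# [[[[2, 4], [6, 8]]], [[[10, 12], [14, 16]]]]
```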
# Define a transform that applies, in sequence:
# 1) A Gaussian mixture, such that intensities are sampled with
# different means and covariances in each label;
# 2) A series of common augmentations
# (bias field, resampling, smoothing, ...).
trf = RandomGaussianMixtureTransform(background=0) + IntensityTransform()
# Apply the transformation to the label map
img = trf(lab)
# Display the resulting image
plt.figure(figsize=(10, 10))
plt.imshow(img.squeeze().T.flip(0), cmap='gray', interpolation='nearest')
plt.axis('off')
plt.title('Synthetic Image')
plt.show()
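The first step of `trf` is the Gaussian mixture described in the comment above. A toy version works on a nested-list label map: for each label, it draws a random mean and standard deviation, then samples every voxel of that label from the corresponding Gaussian. The ranges (`mu_max`, `sigma_max`) and the choice to set background voxels to 0 are assumptions made for this sketch. cornucopia's `RandomGaussianMixtureTransform` has its own parameters and sampling scheme.

```python
import random


def gaussian_mixture_image(labels, background=0, mu_max=255.0,
                           sigma_max=16.0, rng=random):
    """Toy label-to-intensity GMM on a nested-list label map.

    For each non-background label, draw a mean in [0, mu_max] and a standard
    deviation in [0, sigma_max], then sample every voxel of that label from
    N(mean, std). Background voxels are set to 0 (assumption).
    """
    uniques = sorted({v for row in labels for v in row})
    # One (mean, std) pair per label, redrawn at every call
    params = {lab: (rng.uniform(0, mu_max), rng.uniform(0, sigma_max))
              for lab in uniques if lab != background}
    image = []
    for row in labels:
        out_row = []
        for v in row:
            if v == background:
                out_row.append(0.0)
            else:
                mean, std = params[v]
                out_row.append(rng.gauss(mean, std))
        image.append(out_row)
    return image


# Two calls on the same labels give two different "contrasts"
toy = [[0, 1, 1], [2, 2, 0]]
img_a = gaussian_mixture_image(toy, rng=random.Random(0))
img_b = gaussian_mixture_image(toy, rng=random.Random(1))
```

Drawing the per-label parameters anew at every call is what makes the contrast vary from one sample to the next.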
Now, let's synthesize several of them to see how much the contrast varies from one sample to the next.
plotgrid = [4, 4]
plt.figure(figsize=(10, 10))
for i in range(plotgrid[0] * plotgrid[1]):
    # Sample an image
    img = trf(lab)
    # Display the image
    plt.subplot(*plotgrid, i+1)
    plt.imshow(img.squeeze().T.flip(0), cmap='gray', interpolation='nearest')
    plt.axis('off')
plt.show()
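Because `trf` is random, each call draws new parameters and yields a different image. The second part of `trf`, `IntensityTransform()`, is described above only as "bias field, resampling, smoothing, ...". The toy function below illustrates that kind of augmentation with min-max normalisation, a random gamma correction, a multiplicative bias field and additive noise. It is not cornucopia's implementation. The bias field here is a simplification: a random linear ramp in log space, whereas SynthSeg uses smooth random fields. Resampling and smoothing are omitted, and all parameter names and defaults are hypothetical.

```python
import math
import random


def random_intensity_augment(image, rng=random, gamma_std=0.3,
                             bias_amp=0.3, noise_std=0.05):
    """Toy intensity augmentation on a nested-list image.

    Each call draws fresh parameters:
      1) min-max normalise to [0, 1];
      2) gamma correction x -> x ** exp(g), with g ~ N(0, gamma_std);
      3) multiplicative bias field exp(ax * u + ay * w), a random linear
         ramp in log space over normalised coordinates u, w in [-1, 1];
      4) additive Gaussian noise with standard deviation noise_std.
    """
    flat = [v for row in image for v in row]
    lo, hi = min(flat), max(flat)
    scale = (hi - lo) or 1.0
    gamma = math.exp(rng.gauss(0.0, gamma_std))
    ax = rng.uniform(-bias_amp, bias_amp)
    ay = rng.uniform(-bias_amp, bias_amp)
    ny, nx = len(image), len(image[0])
    out = []
    for i, row in enumerate(image):
        out_row = []
        for j, v in enumerate(row):
            x = ((v - lo) / scale) ** gamma
            u = 2 * j / max(nx - 1, 1) - 1  # column coordinate in [-1, 1]
            w = 2 * i / max(ny - 1, 1) - 1  # row coordinate in [-1, 1]
            bias = math.exp(ax * u + ay * w)
            out_row.append(x * bias + rng.gauss(0.0, noise_std))
        out.append(out_row)
    return out
```

Setting all three strengths to zero reduces the function to plain min-max normalisation, which is a convenient sanity check.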
Finally, let's try the full pipeline, which also applies deformations.
# Define a complete "label-to-image" transform
# This is equivalent to the "GMM + intensity augmentation" sequence
# used earlier, but also applies geometric transformations to the
# label map, prior to the GMM.
trf = SynthFromLabelTransform()
# Apply the transform to generate a synthetic image and companion label map
img, newlab = trf(lab)
# Display the results
plt.figure(figsize=(10, 10))
plt.subplot(1, 2, 1)
plt.imshow(img.squeeze().T.flip(0), cmap='gray', interpolation='nearest')
plt.axis('off')
plt.title('Synthetic Image')
plt.subplot(1, 2, 2)
plt.imshow(newlab.squeeze().T.flip(0), cmap='tab20', interpolation='nearest')
plt.axis('off')
plt.title('Synthetic Label')
plt.show()
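According to the comment in the cell above, `SynthFromLabelTransform` deforms the label map before the Gaussian mixture is applied. This keeps the synthetic label aligned with the synthetic image by construction. A label map must be warped with nearest-neighbour sampling so that it only contains existing label values and no blended ones. The sketch below covers only a random rotation and scaling about the image centre. SynthSeg also uses nonlinear deformations, which are omitted here. `random_affine_labels` and its parameter ranges are hypothetical.

```python
import math
import random


def random_affine_labels(labels, rng=random, max_rot_deg=15.0,
                         max_scale=0.15, fill=0):
    """Warp a 2D nested-list label map with a random rotation + scaling.

    Uses inverse mapping with nearest-neighbour sampling so that the output
    only contains label values already present in the input (no blending).
    Output voxels that map outside the input field of view get `fill`.
    """
    ny, nx = len(labels), len(labels[0])
    cy, cx = (ny - 1) / 2, (nx - 1) / 2
    theta = math.radians(rng.uniform(-max_rot_deg, max_rot_deg))
    s = 1.0 + rng.uniform(-max_scale, max_scale)
    cos_t, sin_t = math.cos(theta), math.sin(theta)
    out = []
    for i in range(ny):
        row = []
        for j in range(nx):
            # Map each output voxel back to input coordinates about the centre
            y, x = i - cy, j - cx
            src_y = (cos_t * y + sin_t * x) / s + cy
            src_x = (-sin_t * y + cos_t * x) / s + cx
            ii, jj = round(src_y), round(src_x)  # nearest neighbour
            if 0 <= ii < ny and 0 <= jj < nx:
                row.append(labels[ii][jj])
            else:
                row.append(fill)
        out.append(row)
    return out
```

Sampling from output to input (inverse mapping) guarantees that every output voxel receives exactly one value, so the result has no holes.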
plotgrid = [4, 4]
plt.figure(figsize=(10, 10))
for i in range(plotgrid[0] * plotgrid[1]//2):
    # Generate an (image, label) pair
    img, newlab = trf(lab)
    # Display them
    plt.subplot(*plotgrid, 2*i+1)
    plt.imshow(img.squeeze().T.flip(0), cmap='gray', interpolation='nearest')
    plt.axis('off')
    plt.subplot(*plotgrid, 2*i+2)
    plt.imshow(newlab.squeeze().T.flip(0), cmap='tab20', interpolation='nearest')
    plt.axis('off')
plt.show()
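In SynthSeg, a segmentation network is trained on (image, label) pairs generated on the fly from a small set of label maps. A minimal way to organise this in Python is an endless generator that picks a random source map at each step and passes it through any callable that returns an `(image, label)` tuple. We assume that `SynthFromLabelTransform()` fits that contract, as in the cells above. `synth_pairs` and `take_batch` are hypothetical helpers. In a torch training loop, the same logic could instead live in a `torch.utils.data.IterableDataset`.

```python
import itertools
import random


def synth_pairs(label_maps, transform, rng=random):
    """Endless generator of (image, label) training pairs.

    Each step picks one source label map at random and passes it through
    `transform`, which must return an (image, label) tuple.
    """
    while True:
        yield transform(rng.choice(label_maps))


def take_batch(gen, n):
    """Draw `n` pairs from `gen` and return (images, labels) as two lists."""
    images, labels = zip(*itertools.islice(gen, n))
    return list(images), list(labels)


# Usage with the notebook's objects (not run here):
# gen = synth_pairs([lab], SynthFromLabelTransform())
# images, labels = take_batch(gen, 8)
```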