MRI Synthesis
In this notebook, we will use transforms that generate synthetic MRIs with varying contrast and resolution from label maps.
This is a reimplementation of the domain randomization approach described in:
Billot, B., Greve, D.N., Puonti, O., Thielscher, A., Van Leemput, K., Fischl, B., Dalca, A.V. and Iglesias, J.E., 2023. SynthSeg: Segmentation of brain MRI scans of any contrast and resolution without retraining. Medical image analysis, 86, p.102789.
@article{billot2023synthseg,
title = {SynthSeg: Segmentation of brain MRI scans of any contrast and resolution without retraining},
author = {Billot, Benjamin and Greve, Douglas N and Puonti, Oula and Thielscher, Axel and Van Leemput, Koen and Fischl, Bruce and Dalca, Adrian V and Iglesias, Juan Eugenio},
journal = {Medical image analysis},
volume = {86},
pages = {102789},
year = {2023},
publisher = {Elsevier},
url = {https://www.sciencedirect.com/science/article/pii/S1361841523000506}
}
!curl -L \
  "https://github.com/BBillot/SynthSeg/raw/refs/heads/master/data/training_label_maps/training_seg_01.nii.gz" \
  -o "$TMPDIR/demo.nii.gz"
import torch
import os
import matplotlib.pyplot as plt
from cornucopia import (
LoadTransform, RelabelTransform,
IntensityTransform, SynthFromLabelTransform, RandomGaussianMixtureTransform)
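# Path of the label map downloaded into $TMPDIR above
fname = os.path.join(os.environ['TMPDIR'], 'demo.nii.gz')
# Load the label map as an integer tensor with a leading channel dimension
lab = LoadTransform(dtype=torch.int)(fname)
# Keep a single 2D slice: the middle index along the second-to-last axis
lab = lab[:, :, lab.shape[-2]//2, :]
# Renumber the label values (RelabelTransform with its default settings)
lab = RelabelTransform()(lab)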
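Label maps often use sparse label values (for example 0, 2, 41). We assume that RelabelTransform with default arguments renumbers them to contiguous integers; the sketch below shows that operation in plain PyTorch with torch.unique. The helper relabel_contiguous is hypothetical and is not cornucopia's implementation.

```python
import torch


def relabel_contiguous(lab: torch.Tensor) -> torch.Tensor:
    """Map the distinct values of `lab` to 0, 1, ..., K-1 in sorted order.

    Hypothetical helper: a plain-torch sketch, not cornucopia's implementation.
    """
    # `unique` returns the sorted distinct values; `inverse` gives, for every
    # voxel, the position of its value in that sorted list.
    _, inverse = torch.unique(lab, return_inverse=True)
    return inverse.reshape(lab.shape)


toy_lab = torch.tensor([[[0, 2, 41],
                         [41, 0, 2]]])        # (C=1, X=2, Y=3), sparse label values
print(relabel_contiguous(toy_lab).tolist())   # [[[0, 1, 2], [2, 0, 1]]]
```

Because the values are sorted, background 0 keeps index 0 as long as no negative labels are present.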
plt.figure(figsize=(10, 10))
plt.imshow(lab[0].T.flip(0), cmap='tab20', interpolation='nearest')
plt.axis('off')
plt.title('Labels')
plt.show()
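The .T.flip(0) used for display only changes how the array is drawn. plt.imshow puts row 0 at the top, so transposing makes the first array axis run horizontally, and flipping the rows makes the second axis point upwards, as on a usual x/y plot. The toy example below (an assumption-free tensor demo; which anatomical plane this shows depends on the file's orientation) makes that explicit.

```python
import torch

toy_slice = torch.arange(6).reshape(2, 3)   # indexed as toy_slice[x, y], shape (2, 3)
# [[0, 1, 2],
#  [3, 4, 5]]

shown = toy_slice.T.flip(0)                 # shape (3, 2): rows follow y, columns follow x
print(shown.tolist())                       # [[2, 5], [1, 4], [0, 3]]

# imshow draws row 0 at the top, so the bottom-left pixel is toy_slice[0, 0]:
# x increases to the right and y increases upwards.
print(shown[-1, 0].item() == toy_slice[0, 0].item())   # True
```

An equivalent way to get the same picture is plt.imshow(lab[0].T, origin='lower'), which tells matplotlib to place row 0 at the bottom instead of flipping the tensor.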
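Then, build an intensity model by chaining a RandomGaussianMixtureTransform, which draws a random intensity distribution for each label, with an IntensityTransform, and apply it to our labels. The + operator composes the two transforms so that they run in sequence.
Note that tensors fed to a Transform should have a channel dimension and no batch dimension.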
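In practice the channel-first, batch-free convention means two small adjustments: a raw 2D or 3D map needs unsqueeze(0) to gain a channel dimension, and a batch has to be processed one sample at a time. The sketch below is generic PyTorch and does not rely on any cornucopia batching utility; apply_to_batch and add_noise are hypothetical names, and add_noise is only a stand-in for a real random transform.

```python
import torch
from typing import Callable


def apply_to_batch(trf: Callable[[torch.Tensor], torch.Tensor],
                   batch: torch.Tensor) -> torch.Tensor:
    """Apply a per-sample transform to each element of a (B, C, *spatial) batch.

    Hypothetical helper: every sample is passed as (C, *spatial), i.e. without
    its batch dimension, and the outputs are stacked back along dim 0.
    """
    return torch.stack([trf(sample) for sample in batch], dim=0)


def add_noise(x: torch.Tensor) -> torch.Tensor:
    """Stand-in random transform (not a cornucopia class)."""
    x = x.float()
    return x + torch.randn_like(x)


# A raw 2D label map has no channel dimension; unsqueeze(0) adds one.
toy_lab = torch.zeros(8, 8, dtype=torch.int).unsqueeze(0)   # (C=1, X=8, Y=8)
toy_batch = torch.stack([toy_lab] * 3)                       # (B=3, C=1, X=8, Y=8)
toy_out = apply_to_batch(add_noise, toy_batch)
print(toy_out.shape)  # torch.Size([3, 1, 8, 8])
```

Since the transform is called once per sample, each element of the batch gets its own random draw. A transform that returns a tuple, such as SynthFromLabelTransform in this notebook, would need its outputs stacked element-wise instead.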
trf = RandomGaussianMixtureTransform(background=0) + IntensityTransform()
img = trf(lab)
plt.figure(figsize=(10, 10))
plt.imshow(img.squeeze().T.flip(0), cmap='gray', interpolation='nearest')
plt.axis('off')
plt.title('Synthetic Image')
plt.show()
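To make the intensity model more concrete, here is a rough plain-PyTorch sketch in the spirit of the domain randomization described by Billot et al.: one random Gaussian per label, followed by a smooth multiplicative bias field, a random gamma and additive noise. This is not cornucopia's code. We assume IntensityTransform bundles augmentations of this kind but have not checked exactly which ones it applies, our sketch simply leaves the background label at zero (we have not verified that this is what background=0 means in cornucopia), and all parameter ranges and helper names are made up for illustration.

```python
import torch
import torch.nn.functional as F


def sample_gmm_image(lab, background=0, mu_range=(0.0, 255.0), sigma_range=(0.0, 16.0)):
    """Sample one random Gaussian per label and fill that label's voxels from it.

    lab: integer label map of shape (1, *spatial). Returns a float32 image of
    the same shape. Assumption of this sketch: the `background` label stays 0.
    """
    img = torch.zeros(lab.shape, dtype=torch.float32)
    for k in torch.unique(lab).tolist():
        if k == background:
            continue
        mask = lab == k
        mu = torch.empty(()).uniform_(*mu_range)        # label mean
        sigma = torch.empty(()).uniform_(*sigma_range)  # label standard deviation
        img[mask] = mu + sigma * torch.randn(int(mask.sum()))
    return img


def random_bias_gamma_noise(img, bias_strength=0.3, gamma_std=0.2, noise_std=0.02):
    """Illustrative intensity augmentations for a 2D image of shape (1, X, Y)."""
    # Rescale to [0, 1] so that the gamma exponent acts on non-negative values.
    img = img - img.min()
    img = img / img.max().clamp_min(1e-8)
    # Smooth multiplicative bias field: random coarse 4x4 grid, upsampled
    # bilinearly and exponentiated so that it is strictly positive.
    coarse = bias_strength * torch.randn(1, 1, 4, 4)
    bias = F.interpolate(coarse, size=tuple(img.shape[-2:]), mode='bilinear',
                         align_corners=True)[0].exp()
    img = img * bias
    # Random gamma correction, exponent drawn log-normally around 1.
    gamma = torch.exp(gamma_std * torch.randn(()))
    img = img ** gamma
    # Additive Gaussian noise.
    return img + noise_std * torch.randn_like(img)


# Toy label map: background (0), a square (1) and a smaller inner square (2).
toy_lab = torch.zeros(1, 32, 32, dtype=torch.long)
toy_lab[:, 8:24, 8:24] = 1
toy_lab[:, 12:20, 12:20] = 2

toy_gmm = sample_gmm_image(toy_lab)
toy_img = random_bias_gamma_noise(toy_gmm)
print(toy_img.shape)  # torch.Size([1, 32, 32])
```

Every call draws new means, standard deviations, bias field, gamma and noise, which is what makes each synthetic image look like a different contrast.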
Now, let's synthesize a batch of them. Each call to trf draws new random parameters, so every sample has a different contrast.
shape = [4, 4]
plt.figure(figsize=(10, 10))
for i in range(shape[0] * shape[1]):
    plt.subplot(*shape, i+1)
    plt.imshow(trf(lab).squeeze().T.flip(0), cmap='gray', interpolation='nearest')
    plt.axis('off')
plt.show()
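Finally, let's try the full pipeline, which also applies random spatial deformations. SynthFromLabelTransform returns both the synthetic image and the deformed label map that matches it.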
trf = SynthFromLabelTransform()
img, newlab = trf(lab)
plt.figure(figsize=(10, 10))
plt.subplot(1, 2, 1)
plt.imshow(img.squeeze().T.flip(0), cmap='gray', interpolation='nearest')
plt.axis('off')
plt.title('Synthetic Image')
plt.subplot(1, 2, 2)
plt.imshow(newlab.squeeze().T.flip(0), cmap='tab20', interpolation='nearest')
plt.axis('off')
plt.title('Synthetic Label')
plt.show()
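Since the returned label map is meant to match the synthetic image, the same deformation has to be applied to both. As a rough sketch of that idea (not cornucopia's implementation), the code below builds a smooth random displacement field by upsampling a coarse random grid, then warps the image with bilinear interpolation and the labels with nearest-neighbour interpolation via torch.nn.functional.grid_sample, so that no in-between label values appear. The name random_smooth_deform and its parameters are hypothetical. The sketch has no affine component and does not integrate a velocity field, so the warp is not guaranteed to be diffeomorphic.

```python
import torch
import torch.nn.functional as F


def random_smooth_deform(img, lab, max_disp=0.05, grid=5):
    """Warp a 2D image and its label map with the same random smooth field.

    img: float tensor (C, X, Y); lab: integer tensor (1, X, Y).
    max_disp scales displacements in normalised [-1, 1] coordinates
    (illustrative value, not a library default).
    """
    X, Y = img.shape[-2:]
    # Identity sampling grid in normalised coordinates, shape (1, X, Y, 2).
    theta = torch.tensor([[[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]])
    identity = F.affine_grid(theta, size=(1, 1, X, Y), align_corners=True)
    # Random displacements on a coarse grid, upsampled to full resolution so
    # that the resulting field is smooth.
    coarse = max_disp * torch.randn(1, 2, grid, grid)
    disp = F.interpolate(coarse, size=(X, Y), mode='bicubic', align_corners=True)
    sampling = identity + disp.permute(0, 2, 3, 1)
    # Same field for both: bilinear for intensities, nearest for labels.
    # Points sampled outside the field of view become 0 (background).
    warped_img = F.grid_sample(img[None].float(), sampling, mode='bilinear',
                               padding_mode='zeros', align_corners=True)[0]
    warped_lab = F.grid_sample(lab[None].float(), sampling, mode='nearest',
                               padding_mode='zeros', align_corners=True)[0]
    return warped_img, warped_lab.round().to(lab.dtype)


# Toy example: a square label with a matching constant-intensity image.
toy_lab = torch.zeros(1, 32, 32, dtype=torch.long)
toy_lab[:, 8:24, 8:24] = 1
toy_img = toy_lab.float() * 100.0
def_img, def_lab = random_smooth_deform(toy_img, toy_lab)
print(def_img.shape, def_lab.dtype)  # torch.Size([1, 32, 32]) torch.int64
```

Using nearest-neighbour sampling for the labels keeps them a valid segmentation, while bilinear sampling keeps the image smooth; sharing one sampling grid is what keeps the two aligned.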
shape = [4, 4]
plt.figure(figsize=(10, 10))
for i in range(shape[0] * shape[1] // 2):
    img, newlab = trf(lab)
    plt.subplot(*shape, 2*i+1)
    plt.imshow(img.squeeze().T.flip(0), cmap='gray', interpolation='nearest')
    plt.axis('off')
    plt.subplot(*shape, 2*i+2)
    plt.imshow(newlab.squeeze().T.flip(0), cmap='tab20', interpolation='nearest')
    plt.axis('off')
plt.show()