MRI Synthesis

In this notebook, we will use transforms that generate synthetic MRIs with varying contrast and resolution from label maps.

This is a reimplementation of the domain randomization approach described in:

Billot, B., Greve, D.N., Puonti, O., Thielscher, A., Van Leemput, K., Fischl, B., Dalca, A.V. and Iglesias, J.E., 2023. SynthSeg: Segmentation of brain MRI scans of any contrast and resolution without retraining. Medical Image Analysis, 86, p.102789.

    @article{billot2023synthseg,
      title     = {SynthSeg: Segmentation of brain MRI scans of any contrast and resolution without retraining},
      author    = {Billot, Benjamin and Greve, Douglas N and Puonti, Oula and Thielscher, Axel and Van Leemput, Koen and Fischl, Bruce and Dalca, Adrian V and Iglesias, Juan Eugenio and others},
      journal   = {Medical image analysis},
      volume    = {86},
      pages     = {102789},
      year      = {2023},
      publisher = {Elsevier},
      url       = {https://www.sciencedirect.com/science/article/pii/S1361841523000506}
    }
!pushd $TMPDIR \
    && curl  \
    -L "https://github.com/BBillot/SynthSeg/raw/refs/heads/master/data/training_label_maps/training_seg_01.nii.gz" \
    -o demo.nii.gz \
    && popd
import torch
import os
import matplotlib.pyplot as plt
from cornucopia import (
    LoadTransform, RelabelTransform,
    IntensityTransform, SynthFromLabelTransform, RandomGaussianMixtureTransform)
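# Path to the label map downloaded above
fname = os.path.join(os.environ['TMPDIR'], 'demo.nii.gz')
# Load as an integer tensor with a leading channel dimension
lab = LoadTransform(dtype=torch.int)(fname)
# Keep the middle slice along the second spatial axis to get a 2D label map
lab = lab[:, :, lab.shape[-2]//2, :]
# Relabel the map (by default, to consecutive integers)
lab = RelabelTransform()(lab)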

plt.figure(figsize=(10, 10))
plt.imshow(lab[0].T.flip(0), cmap='tab20', interpolation='nearest')
plt.axis('off')
plt.title('Labels')
plt.show()

Output figure
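The indexing in the cell above drops one spatial axis while keeping the leading channel axis. The following is a small sketch of that shape bookkeeping on toy tensors; the shapes and variable names are made up for illustration and nothing here calls cornucopia.

```python
import torch

# Toy stand-in for a loaded label volume: (channels, X, Y, Z). Shape is made up.
volume = torch.zeros(1, 6, 8, 4, dtype=torch.int)

# Same indexing pattern as above: keep channel, X and Z, take the middle of Y.
middle = volume.shape[-2] // 2
slice_2d = volume[:, :, middle, :]
print(slice_2d.shape)  # torch.Size([1, 6, 4])

# A tensor without a channel axis can be given one with unsqueeze(0).
bare = torch.zeros(6, 4, dtype=torch.int)
with_channel = bare.unsqueeze(0)  # (6, 4) -> (1, 6, 4)

# A batched tensor (B, C, X, Z) can be split into per-sample (C, X, Z) tensors.
batch = torch.zeros(3, 1, 6, 4, dtype=torch.int)
samples = batch.unbind(0)
```

Each element of `samples` has a channel axis and no batch axis, so it could be processed one sample at a time.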

Next, compose a RandomGaussianMixtureTransform with an IntensityTransform using the `+` operator, and apply the result to our labels. Note that tensors fed to a Transform layer should have a channel dimension and no batch dimension.

trf = RandomGaussianMixtureTransform(background=0) + IntensityTransform()
img = trf(lab)

plt.figure(figsize=(10, 10))
plt.imshow(img.squeeze().T.flip(0), cmap='gray', interpolation='nearest')
plt.axis('off')
plt.title('Synthetic Image')
plt.show()

Output figure
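To give a rough idea of what sampling an image from a label map means, here is a minimal sketch in plain PyTorch: every non-background label gets its own randomly drawn Gaussian, and its voxels are filled with samples from it. This is an illustration only, not cornucopia's implementation; the function name `sample_gmm_image` and the parameter ranges are hypothetical.

```python
import torch

def sample_gmm_image(labels, background=0, generator=None):
    """Fill each label region with intensities from its own random Gaussian.

    labels: integer tensor of shape (1, *spatial).
    Returns a float tensor of the same shape.
    """
    image = torch.zeros(labels.shape, dtype=torch.float32)
    for label in labels.unique().tolist():
        if label == background:
            continue  # background voxels stay at zero
        mask = labels == label
        # Hypothetical ranges: mean in [0, 1), standard deviation in [0, 0.1)
        mean = torch.rand(1, generator=generator)
        std = 0.1 * torch.rand(1, generator=generator)
        noise = torch.randn(int(mask.sum()), generator=generator)
        image[mask] = mean + std * noise
    return image

# Toy 3x3 label map with a channel axis; the values are made up.
toy_labels = torch.tensor([[[0, 1, 1], [0, 2, 2], [3, 3, 0]]])
toy_image = sample_gmm_image(toy_labels, generator=torch.Generator().manual_seed(0))
```

Calling the function again with a different seed assigns different means to the same regions, which is the source of the varying contrast.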

Now, let's synthesize several more. Each call to the transform draws new random parameters, so every image has a different contrast:

shape = [4, 4]
plt.figure(figsize=(10, 10))

for i in range(shape[0] * shape[1]):
    plt.subplot(*shape, i+1)
    plt.imshow(trf(lab).squeeze().T.flip(0), cmap='gray', interpolation='nearest')
    plt.axis('off')
plt.show()

Output figure

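Finally, let's try the full pipeline, which also applies deformations. SynthFromLabelTransform returns both the synthetic image and the matching label map: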

trf = SynthFromLabelTransform()
img, newlab = trf(lab)

plt.figure(figsize=(10, 10))
plt.subplot(1, 2, 1)
plt.imshow(img.squeeze().T.flip(0), cmap='gray', interpolation='nearest')
plt.axis('off')
plt.title('Synthetic Image')
plt.subplot(1, 2, 2)
plt.imshow(newlab.squeeze().T.flip(0), cmap='tab20', interpolation='nearest')
plt.axis('off')
plt.title('Synthetic Label')
plt.show()

Output figure

shape = [4, 4]
plt.figure(figsize=(10, 10))

for i in range(shape[0] * shape[1]//2):
    img, newlab = trf(lab)
    plt.subplot(*shape, 2*i+1)
    plt.imshow(img.squeeze().T.flip(0), cmap='gray', interpolation='nearest')
    plt.axis('off')
    plt.subplot(*shape, 2*i+2)
    plt.imshow(newlab.squeeze().T.flip(0), cmap='tab20', interpolation='nearest')
    plt.axis('off')
plt.show()

Output figure
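In the pairs above, each synthetic image stays aligned with its label map, which requires both to undergo the same geometric change. The sketch below illustrates that idea with a deliberately simple stand-in: one random circular shift (via `torch.roll`) shared by both tensors, rather than a real deformation field. The function `random_shift_pair` and its `max_shift` parameter are hypothetical and not part of cornucopia.

```python
import torch

def random_shift_pair(image, labels, max_shift=2, generator=None):
    """Apply one shared random circular shift to an image and its label map."""
    ndim = image.dim() - 1  # number of spatial axes (the first axis is channels)
    # Draw one integer shift per spatial axis, then reuse it for both tensors.
    shifts = torch.randint(
        -max_shift, max_shift + 1, (ndim,), generator=generator
    ).tolist()
    dims = list(range(1, ndim + 1))
    return torch.roll(image, shifts, dims), torch.roll(labels, shifts, dims)

# Toy data: the "image" intensity encodes the label, so alignment is easy to check.
labels = torch.tensor([[[0, 1, 1, 0], [0, 2, 2, 0], [0, 0, 3, 3]]])
image = labels.float() * 10
new_image, new_labels = random_shift_pair(image, labels)

# The pair is still aligned after the shared shift.
aligned = torch.equal(new_image, new_labels.float() * 10)
```

Drawing the shift once and applying it to both tensors is what keeps them aligned; sampling it separately for each would break the correspondence.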