
Contrast augmentation

A large number of MRI contrasts exist, but most labeled data comes from a small number of sequences. It is therefore difficult to train networks that generalize to arbitrary contrasts. Here, we showcase a contrast augmentation transform that fits a Gaussian mixture model to the intensities of the input image and then shifts the means and variances of its components.
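To make the mechanism concrete, here is a minimal NumPy sketch of the idea. It is our own simplified formulation, not cornucopia's implementation. The function names, the EM initialisation, and the sampling ranges are all assumptions: new means are drawn uniformly over the image's intensity range, and scale factors uniformly in [0.5, 2]. A 1D mixture is fitted to the voxel intensities with EM. Every voxel x is then remapped to the responsibility-weighted sum of per-class affine maps, sum_k r_k(x) * (mu'_k + s_k * (x - mu_k)). This gives class k the new mean mu'_k and multiplies its standard deviation by s_k.

```python
import numpy as np


def log_joint(x, pi, mu, sigma):
    """log(pi_k) + log N(x | mu_k, sigma_k^2), up to a constant, per sample and class."""
    return (np.log(pi) - np.log(sigma)
            - 0.5 * ((x[:, None] - mu) / sigma) ** 2)


def responsibilities(x, pi, mu, sigma):
    """Posterior probability that each intensity belongs to each class."""
    logp = log_joint(x, pi, mu, sigma)
    logp -= logp.max(axis=1, keepdims=True)  # avoid overflow in exp
    p = np.exp(logp)
    return p / p.sum(axis=1, keepdims=True)


def fit_gmm_1d(x, k, n_iter=50, eps=1e-6):
    """Fit a k-class 1D Gaussian mixture to intensities with plain EM.

    Means start at evenly spaced quantiles, so the fit is deterministic.
    """
    x = np.asarray(x, dtype=float).ravel()
    mu = np.quantile(x, (np.arange(k) + 0.5) / k)
    sigma = np.full(k, x.std() / k + eps)
    pi = np.full(k, 1.0 / k)
    for _ in range(n_iter):
        resp = responsibilities(x, pi, mu, sigma)  # E-step
        nk = resp.sum(axis=0) + eps                # M-step
        pi = nk / nk.sum()
        mu = (resp * x[:, None]).sum(axis=0) / nk
        sigma = np.sqrt((resp * (x[:, None] - mu) ** 2).sum(axis=0) / nk) + eps
    return pi, mu, sigma


def contrast_mixture_augment(img, k=6, scale_range=(0.5, 2.0), rng=None):
    """Hypothetical GMM-based contrast augmentation (not cornucopia's code).

    Each voxel is mapped to sum_k r_k(x) * (new_mu_k + s_k * (x - mu_k)),
    i.e. every class gets a new random mean and its spread is scaled by s_k.
    """
    rng = np.random.default_rng(rng)
    img = np.asarray(img, dtype=float)
    x = img.ravel()
    pi, mu, sigma = fit_gmm_1d(x, k)
    new_mu = rng.uniform(x.min(), x.max(), size=k)  # shifted class means
    scale = rng.uniform(*scale_range, size=k)       # new_sigma_k = s_k * sigma_k
    resp = responsibilities(x, pi, mu, sigma)
    out = (resp * (new_mu + scale * (x[:, None] - mu))).sum(axis=1)
    return out.reshape(img.shape)
```

Using soft responsibilities rather than hard class labels keeps the remapping a continuous function of intensity. As a result, voxels near the boundary between two tissue classes blend smoothly instead of jumping between the two new contrasts.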

First, let's download an example image.

!pushd $TMPDIR \
&& wget https://surfer.nmr.mgh.harvard.edu/pub/data/voxelmorph/tutorial_data.tar.gz -O data.tar.gz \
&& tar -xzvf data.tar.gz \
&& popd

brain_2d_no_smooth.h5
brain_2d_smooth.h5
brain_3d.h5
fs_rgb.npy
subj1.npz
subj2.npz
tutorial_data.npz
import os

import matplotlib.pyplot as plt
import numpy as np
import torch
from cornucopia import ContrastMixtureTransform

# Load the first 2D training slice from the downloaded archive
fname = os.path.join(os.environ['TMPDIR'], 'tutorial_data.npz')
dat = np.load(fname)['train'][0]
dat = torch.as_tensor(dat)

plt.figure(figsize=(10, 10))
plt.imshow(dat, cmap='gray', interpolation='nearest')
plt.axis('off')
plt.title('MRI')
plt.show()

Output figure

Let us now instantiate a contrast augmentation layer and apply it to the MRI. We use fewer classes (6) than the default (12) because we are dealing with skull-stripped 2D images, which have far fewer intensity modes than an intact 3D volume.

trf = ContrastMixtureTransform(nk=6)
# [None] adds a leading channel dimension; [0] removes it from the output
aug = trf(dat[None])[0]

plt.figure(figsize=(10, 10))
plt.imshow(aug, cmap='gray', interpolation='nearest')
plt.axis('off')
plt.title('Augmented MRI')
plt.show()

Output figure
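The reasoning behind nk=6 (fewer intensity modes, so fewer classes) can be made quantitative with a standard model-selection heuristic, the Bayesian information criterion (BIC). The tutorial itself does not do this; it simply picks 6. The sketch below uses hypothetical helper names, gmm_bic_1d and pick_num_classes. It fits 1D mixtures with k = 1, 2, ... classes by EM and keeps the k with the lowest BIC = p * ln(n) - 2 * ln(L). Here p = 3k - 1 is the number of free parameters, n is the number of voxels, and L is the likelihood of the fitted mixture.

```python
import numpy as np


def _log_joint(x, pi, mu, sigma):
    """log(pi_k * N(x | mu_k, sigma_k^2)) for every sample (rows) and class (columns)."""
    return (np.log(pi) - np.log(sigma) - 0.5 * np.log(2.0 * np.pi)
            - 0.5 * ((x[:, None] - mu) / sigma) ** 2)


def gmm_bic_1d(x, k, n_iter=100, eps=1e-6):
    """Fit a k-class 1D Gaussian mixture with EM and return its BIC (lower is better)."""
    x = np.asarray(x, dtype=float).ravel()
    mu = np.quantile(x, (np.arange(k) + 0.5) / k)
    sigma = np.full(k, x.std() / k + eps)
    pi = np.full(k, 1.0 / k)
    for _ in range(n_iter):
        logp = _log_joint(x, pi, mu, sigma)                   # E-step
        resp = np.exp(logp - logp.max(axis=1, keepdims=True))
        resp /= resp.sum(axis=1, keepdims=True)
        nk = resp.sum(axis=0) + eps                           # M-step
        pi = nk / nk.sum()
        mu = (resp * x[:, None]).sum(axis=0) / nk
        sigma = np.sqrt((resp * (x[:, None] - mu) ** 2).sum(axis=0) / nk) + eps
    # Log-likelihood of the fitted mixture via log-sum-exp over classes
    logp = _log_joint(x, pi, mu, sigma)
    m = logp.max(axis=1)
    loglik = np.sum(m + np.log(np.exp(logp - m[:, None]).sum(axis=1)))
    n_params = 3 * k - 1  # k means, k variances, k - 1 free mixing weights
    return n_params * np.log(x.size) - 2.0 * loglik


def pick_num_classes(x, k_max=12):
    """Return the number of classes in 1..k_max with the lowest BIC, plus all scores."""
    bics = {k: gmm_bic_1d(x, k) for k in range(1, k_max + 1)}
    return min(bics, key=bics.get), bics
```

On a skull-stripped slice, one would probably pass only the nonzero voxels (for example `img[img > 0]`). A constant zero background would otherwise collapse one class to near-zero variance and inflate the likelihood. The resulting k is only a rough guide for nk, not a rule.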

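Now, let's synthesize a 4-by-4 grid of augmentations. Each call to the transform samples new random contrast parameters, so every panel looks different.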

shape = [4, 4]

plt.figure(figsize=(10, 10))
for i in range(shape[0] * shape[1]):
    plt.subplot(*shape, i+1)
    plt.imshow(trf(dat[None])[0], cmap='gray', interpolation='nearest')
    plt.axis('off')
plt.show()

Output figure
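During training, the same idea is applied on the fly: every sample in a batch gets its own independent random contrast. The framework-agnostic NumPy sketch below illustrates this pattern. The names augment_batch and random_gamma are hypothetical, and random_gamma is only a toy stand-in for a real contrast transform. A transform such as trf, which draws its own random parameters on each call, could be plugged in as `lambda img, rng: trf(img)`, at the cost of seeded reproducibility.

```python
import numpy as np


def augment_batch(image, transform, batch_size, rng=None):
    """Stack `batch_size` independent random augmentations of one image.

    `transform(image, rng)` must return an array shaped like `image` and
    draw all of its random parameters from `rng`, so a fixed seed gives a
    reproducible batch.
    """
    rng = np.random.default_rng(rng)
    return np.stack([transform(image, rng) for _ in range(batch_size)])


def random_gamma(image, rng, gamma_range=(0.5, 2.0)):
    """Toy contrast transform: rescale to [0, 1] and apply a random gamma."""
    lo, hi = image.min(), image.max()
    unit = (image - lo) / (hi - lo + 1e-12)
    return unit ** rng.uniform(*gamma_range)


# 16 independent augmentations, e.g. to fill a 4x4 grid or a mini-batch
image = np.linspace(0.0, 1.0, 12).reshape(3, 4)
batch = augment_batch(image, random_gamma, 16, rng=0)
```

Passing the random generator explicitly, rather than relying on global state, makes a batch reproducible from a single seed. This helps when debugging an augmentation pipeline.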