qMRI Synthesis
This notebook is related to MRI synthesis, except that instead of using a generic contrast model, we use a physics-based forward model. This allows other types of artefacts to be included (for example, inhomogeneity of the excitation field, which acts on the intensity in a nonlinear way). However, the range of parameters that yield a "useful" contrast is much narrower. Depending on the application, it may therefore be useful to train using both "nonphysical" and "physical" random contrasts.
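To make the idea of a physics-based forward model concrete, here is a minimal sketch of the textbook steady-state signal equation for a spoiled gradient echo. It is not necessarily the exact model implemented by the library. The function name `spoiled_gre_signal`, its parameters, and the tissue and sequence values are assumptions chosen for illustration. The excitation-field factor `b1` scales the flip angle, so it enters the signal through sines and cosines rather than as a simple gain.

```python
import math


def spoiled_gre_signal(pd, t1, t2s, tr, te, flip_deg, b1=1.0):
    """Steady-state spoiled gradient-echo signal (Ernst equation).

    pd: proton density (arbitrary units); t1, t2s, tr, te: seconds;
    flip_deg: nominal flip angle in degrees; b1: multiplicative
    excitation-field factor applied to the flip angle (1.0 = nominal).
    """
    alpha = math.radians(flip_deg) * b1   # effective flip angle
    e1 = math.exp(-tr / t1)               # T1 recovery over one TR
    steady_state = math.sin(alpha) * (1 - e1) / (1 - math.cos(alpha) * e1)
    return pd * steady_state * math.exp(-te / t2s)  # T2* decay at echo time


# Illustrative tissue and sequence values (assumptions, not measured data).
tissue = dict(pd=0.8, t1=1.3, t2s=0.05)
sequence = dict(tr=0.025, te=0.005, flip_deg=20)

signals = {b1: spoiled_gre_signal(**tissue, **sequence, b1=b1) for b1 in (0.8, 1.0, 1.2)}
for b1, s in signals.items():
    print(f"B1 factor {b1:.1f}: signal {s:.4f}")
```

With these values, the printed signal does not scale in proportion to the B1 factor, which is the nonlinear effect a generic contrast model cannot reproduce.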
!pushd $TMPDIR \
&& curl \
-L "https://github.com/BBillot/SynthSeg/raw/master/data/training_label_maps/training_seg_01.nii.gz" \
-o demo.nii.gz \
&& popd
import torch
import os
import matplotlib.pyplot as plt
from cornucopia import (
    LoadTransform, RelabelTransform, RandomGMMGradientEchoTransform,
    IntensityTransform, QuantileTransform, MakeAffinePair,
    RandomAffineElasticTransform, RandomAffineTransform,
)
fname = os.path.join(os.environ['TMPDIR'], 'demo.nii.gz')
lab = LoadTransform(dtype=torch.int)(fname)
lab = lab[:, :, lab.shape[-2]//2, :]
lab = RelabelTransform()(lab)
plt.figure(figsize=(10, 10))
plt.imshow(lab[0].T.flip(0), cmap='gray', interpolation='nearest')
plt.axis('off')
plt.title('Labels')
plt.show()
Then, instantiate a RandomGMMGradientEchoTransform, chain it with a QuantileTransform, and apply the result to our labels.
Note that tensors fed to a Transform layer should have a channel dimension, and no batch dimension.
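The shape convention above can be illustrated with plain PyTorch. This is a sketch on hypothetical tensors (the sizes 192 and 160 are arbitrary), not code from the library: a 2D slice gets a leading channel axis, and a batched tensor is split into individual items, each of which could then be passed to a transform on its own.

```python
import torch

# Hypothetical 2D label slice of shape (H, W), standing in for real data.
slice_2d = torch.zeros(192, 160, dtype=torch.int)

# Add a leading channel dimension: (H, W) -> (C=1, H, W). No batch dimension.
with_channel = slice_2d.unsqueeze(0)
print(with_channel.shape)

# A batched tensor (B, C, H, W) does not match the convention as is.
# One option is to split it into B separate (C, H, W) items.
batch = torch.zeros(4, 1, 192, 160)
items = list(batch.unbind(0))
print(len(items), items[0].shape)
```

In this notebook, the slicing `lab[:, :, lab.shape[-2]//2, :]` already keeps the leading channel axis, so no extra reshaping is needed before applying the transform.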
trf = RandomGMMGradientEchoTransform() + QuantileTransform()
img = trf(lab)
plt.figure(figsize=(10, 10))
plt.imshow(img.squeeze().T.flip(0), cmap='gray', interpolation='nearest', vmin=0, vmax=1)
plt.axis('off')
plt.title('Synthetic Image')
plt.show()
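Now, let's synthesize a bunch of them. Each call to the transform samples new random parameters, so every image has a different contrast.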
shape = [4, 4]
imgs = []
plt.figure(figsize=(10, 10))
for i in range(shape[0] * shape[1]):
    plt.subplot(*shape, i+1)
    imgs.append(trf(lab))
    plt.imshow(imgs[-1].squeeze().T.flip(0), cmap='gray', interpolation='nearest')
    plt.axis('off')
plt.show()
# Show the same samples again, this time with a fixed display ceiling (vmax=0.8)
plt.figure(figsize=(10, 10))
for i in range(shape[0] * shape[1]):
    plt.subplot(*shape, i+1)
    plt.imshow(imgs[i].squeeze().T.flip(0), cmap='gray', interpolation='nearest', vmax=0.8)
    plt.axis('off')
plt.show()
Finally, let's try the full pipeline. We first add intensity augmentation to the synthesis, then add deformations.
trf = (
    RandomGMMGradientEchoTransform() +
    QuantileTransform() +
    IntensityTransform()
)
img = trf(lab)
plt.figure(figsize=(10, 10))
plt.imshow(img.squeeze().T.flip(0), cmap='gray', interpolation='nearest')
plt.axis('off')
plt.title('Synthetic Image')
plt.show()
shape = [4, 4]
plt.figure(figsize=(10, 10))
for i in range(shape[0] * shape[1]):
    plt.subplot(*shape, i+1)
    img = trf(lab)
    plt.imshow(img.squeeze().T.flip(0), cmap='gray', interpolation='nearest')
    plt.axis('off')
plt.show()
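# Nonlinear deformation (affine + elastic), applied first
wrp = RandomAffineElasticTransform()
# Pair of affine views (no shear, no zoom), returned under 'left', 'right' and 'flow'
aff = MakeAffinePair(
    RandomAffineTransform(shears=0, zooms=0),
    returns=dict(left='left', right='right', flow='flow')
)
# Synthesize an image from each label map and keep both; 'flow' is excluded
gre = RandomGMMGradientEchoTransform(
    returns=dict(label='input', image='output'), append=True, exclude='flow'
)
# Normalize and augment the images only, not the labels or the flow
qtl = QuantileTransform(include=['left.image', 'right.image'])
aug = IntensityTransform(include=['left.image', 'right.image'])
trf = (wrp + aff + gre + qtl + aug)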
out = trf(lab)
image_left = out['left.image']
image_right = out['right.image']
label_left = out['left.label']
label_right = out['right.label']
flow = out['flow']
plt.figure(figsize=(10, 10))
plt.subplot(2, 2, 1)
plt.imshow(image_left.squeeze().T.flip(0), cmap='gray', interpolation='nearest')
plt.axis('off')
plt.subplot(2, 2, 2)
plt.imshow(image_right.squeeze().T.flip(0), cmap='gray', interpolation='nearest')
plt.axis('off')
plt.subplot(2, 2, 3)
plt.imshow(label_left.squeeze().T.flip(0), cmap='gray', interpolation='nearest')
plt.axis('off')
plt.subplot(2, 2, 4)
plt.imshow(label_right.squeeze().T.flip(0), cmap='gray', interpolation='nearest')
plt.axis('off')
plt.show()