Skip to content

Benchmark: TorchIO vs. cornucopia augmentation speed

import cornucopia as cc
import matplotlib.pyplot as plt
import os
from urllib.request import urlretrieve
from tempfile import gettempdir
import torch

def plot_mri(dat):
    """Draw one 2-D slice of ``dat`` as a grayscale image on the current axes.

    The slice is taken at index 0 of dim 0 and at the middle index of
    dim 2, keeping all of dim 1 and any trailing dims. It is rotated by
    90 degrees with ``torch.rot90`` (default dims) before display.

    Parameters
    ----------
    dat : torch.Tensor
        Volume to display. It may live on any device, because the slice is
        copied to the CPU before plotting.
        NOTE(review): presumably shaped (channel, x, y, z), so that the
        slice is 2-D. Confirm against callers.

    Side effects
    ------------
    Sets the global matplotlib style to ``'classic'`` on every call.
    """
    plt.style.use('classic')
    mid = dat.shape[2] // 2
    view = torch.rot90(dat[0, :, mid])
    plt.imshow(view.cpu(), cmap='gray')

# Download the demo volume (FreeSurfer fsaverage T1, per the URL path) once.
# Later runs reuse the cached copy in the system temp directory instead of
# fetching it again.
url = 'https://bitbucket.org/nipy/nitest-freesurfer/raw/0d307865704df71c3b2248139714806aad47139d/fsaverage/mri/T1.mgz'
fname = os.path.join(gettempdir(), 'T1.mgz')
if not os.path.exists(fname):
    # Download to a side file and rename it only on success. An interrupted
    # transfer then never leaves a truncated T1.mgz that the cache check
    # above would accept on the next run.
    part_fname = fname + '.part'
    urlretrieve(url, part_fname)
    os.replace(part_fname, fname)

# load from disk
dat = cc.LoadTransform(dtype='float32')(fname)
# NOTE(review): the meaning of the argument 32 is defined by cornucopia's
# CropTransform and is not visible here. Confirm before changing it.
dat = cc.CropTransform(32)(dat)
plot_mri(dat)

Output figure

import torchio as tio

augment = tio.Compose([
    tio.RandomAffine(),
    tio.RandomElasticDeformation(),
    tio.RandomAnisotropy(),
    tio.RandomNoise(),
])

dat = dat.cpu()
%timeit dat1 = augment(dat)
996 ms ± 20.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
# Show a fresh TorchIO-augmented sample. This is a new random draw, not the
# output of the timed calls.
plot_mri(dat=augment(dat))

Output figure

lowres = cc.randomize(cc.LowResSliceTransform)(cc.random.Uniform(0, 4),
                                               cc.random.Uniform(0, 1))
augment = cc.SequentialTransform([
    cc.RandomAffineElasticTransform(),
    lowres,
    cc.GaussianNoiseTransform()
])


dat = dat.cpu()
dat1 = augment(dat)  # precompile
%timeit augment(dat)

dat = dat.cuda()
dat1 = augment(dat)  # precompile
%timeit augment(dat)
3.12 s ± 559 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
54.8 ms ± 5.12 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
# Show a fresh cornucopia-augmented sample. dat is on the GPU at this point,
# and plot_mri copies the slice back to the CPU.
plot_mri(dat=augment(dat))

Output figure