
cornucopia.base

Transform

Transform(*, returns=None, append=False, prefix=True, include=None, exclude=None, consume=None)

Bases: Module, ABC

Base class for all transforms.

Parameters:

Name Type Description Default
returns [list or dict of] str

Which tensors to return. Can be a nested structure. Most transforms accept 'input' and 'output' as valid returns. The default is 'output'.

None
append bool | str

Append the (structure of) returned tensors to the parent structure.

Warning

This option does not keep the input tensors in the returned structure. To preserve them, combine append with a returns value that includes 'input', as in the examples below.

Example

# With lists
trf = MyTransform(returns=['input', 'output'], append=True)
x1, y1, x2, y2 = trf([x1, x2])

# With dicts
trf = MyTransform(returns={'x': 'input', 'y': 'output'}, append=True, prefix=True)
out = trf({'path1': x1, 'path2': x2})
assert out.keys() == {'path1.x', 'path1.y', 'path2.x', 'path2.y'}

Changed in v0.5: can be a string.

If append is a str and the parent is a dict, its value is used as the separator between the prefix and the key. See prefix.

False
prefix bool | str

If append is set and the parent is a dict, prefix each returned key before inserting it into the output dictionary.

If True, the prefix is the input key.

Changed in v0.5: can be a string.

True
include [list of] str | re.Pattern

List of keys to which the transform should apply. Default: all.

Changed in v0.5: can be a regex or glob pattern.

None
exclude [list of] str | re.Pattern

List of keys to which the transform should not apply. Default: none.

Changed in v0.5: can be a regex or glob pattern.

None
consume [list of] str | re.Pattern

List of keys to remove from the output after applying the transform. Default: none.

Added in v0.5. Can be a regex or glob pattern.

None
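The way include, exclude, returns, append, prefix and consume combine on a dict input can be sketched in plain Python. apply_to_dict and _matches are hypothetical helpers, not cornucopia code, and numbers stand in for tensors. The sketch assumes that plain strings are globs, compiled patterns must match the whole key, keys that are not selected pass through unchanged, and '.' is the separator when append is True (inferred from the path1.x keys in the append example).

```python
import fnmatch
import re
from typing import Callable, Dict, Optional, Union


def _matches(key: str, patterns) -> bool:
    """True if `key` matches any glob string or compiled regex in `patterns`."""
    if patterns is None:
        return False
    if isinstance(patterns, (str, re.Pattern)):
        patterns = [patterns]
    for pat in patterns:
        if isinstance(pat, re.Pattern):
            if pat.fullmatch(key):  # assumption: regexes must match the whole key
                return True
        elif fnmatch.fnmatchcase(key, pat):  # plain strings are treated as globs
            return True
    return False


def apply_to_dict(
    inputs: Dict[str, float],
    fn: Callable[[float], float],
    returns: Optional[Dict[str, str]] = None,
    append: Union[bool, str] = False,
    prefix: bool = True,
    include=None,
    exclude=None,
    consume=None,
) -> dict:
    """Hypothetical sketch of how the Transform options act on a flat dict."""
    # A string `append` doubles as the separator; '.' otherwise (assumed).
    sep = append if isinstance(append, str) else "."
    out: dict = {}
    for key, x in inputs.items():
        selected = (include is None or _matches(key, include)) and not _matches(key, exclude)
        if not selected:
            out[key] = x  # assumption: unselected keys pass through unchanged
            continue
        available = {"input": x, "output": fn(x)}
        if returns is None:
            out[key] = available["output"]  # default: return only the output
        elif append is False:
            # Without append, the returned structure replaces the input value.
            out[key] = {name: available[which] for name, which in returns.items()}
        else:
            # With append, returned entries are flattened into the parent dict.
            for name, which in returns.items():
                out[f"{key}{sep}{name}" if prefix else name] = available[which]
    # consume: drop matching keys from the final output.
    return {k: v for k, v in out.items() if not _matches(k, consume)}


def double(v: float) -> float:
    return 2 * v


print(apply_to_dict({"path1": 1.0, "path2": 3.0}, double,
                    returns={"x": "input", "y": "output"}, append=True))
# {'path1.x': 1.0, 'path1.y': 2.0, 'path2.x': 3.0, 'path2.y': 6.0}
```

Applying consume last, to the output keys, lets it remove entries created by append, such as the copied inputs.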

is_final property

is_final

Returns:

Type Description
bool

Whether the transform is final (i.e., deterministic) or not.

get_prm

get_prm()

Get the parameters of the transform, for use in subtransforms.

Returns:

Type Description
dict

A dictionary containing the attributes returns, append, prefix, include, exclude, and consume.
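The intended use of get_prm can be illustrated with a transform that builds a subtransform and forwards its own options to it, so the subtransform selects keys and formats its outputs the same way. MiniTransform, RandomScale, FixedScale and specialize are hypothetical plain-Python stand-ins, not cornucopia classes.

```python
import random
from typing import Optional


class MiniTransform:
    """Hypothetical stand-in for Transform that only stores the shared options."""

    def __init__(self, *, returns=None, append=False, prefix=True,
                 include=None, exclude=None, consume=None):
        self.returns = returns
        self.append = append
        self.prefix = prefix
        self.include = include
        self.exclude = exclude
        self.consume = consume

    def get_prm(self) -> dict:
        # The six keyword arguments accepted by the constructor.
        return dict(returns=self.returns, append=self.append, prefix=self.prefix,
                    include=self.include, exclude=self.exclude, consume=self.consume)


class FixedScale(MiniTransform):
    """Hypothetical deterministic subtransform."""

    def __init__(self, factor: float, **kwargs):
        super().__init__(**kwargs)
        self.factor = factor


class RandomScale(MiniTransform):
    """Hypothetical random transform that builds a FixedScale subtransform."""

    def __init__(self, max_factor: float = 2.0, **kwargs):
        super().__init__(**kwargs)
        self.max_factor = max_factor

    def specialize(self, rng: Optional[random.Random] = None) -> FixedScale:
        rng = rng or random.Random()
        factor = rng.uniform(1.0, self.max_factor)
        # Forward the options so the subtransform behaves like its parent.
        return FixedScale(factor, **self.get_prm())
```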

forward

forward(*a, **k)

Apply the transform recursively.

Parameters:

Name Type Description Default
*a [nested list or dict of] tensor

Input tensors, with shape (C, *shape)

()
**k [nested list or dict of] tensor

Input tensors, with shape (C, *shape)

()

Returns:

Type Description
[nested list or dict of] tensor

Output tensors, with shape (C, *shape).
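The recursive part of forward amounts to walking the nested input structure and applying the transform to each tensor it finds. A minimal pure-Python sketch of that recursion, assuming leaves are anything that is not a list, tuple or dict (numbers stand in for tensors); apply_recursive is a hypothetical helper, and the real method also has to deal with options such as returns and include, which this sketch ignores.

```python
from typing import Any, Callable


def apply_recursive(x: Any, leaf_fn: Callable[[Any], Any]) -> Any:
    """Hypothetical sketch: apply `leaf_fn` to every leaf of a nested list/tuple/dict."""
    if isinstance(x, dict):
        # Keep the keys, transform the values.
        return {key: apply_recursive(value, leaf_fn) for key, value in x.items()}
    if isinstance(x, (list, tuple)):
        # Rebuild a container of the same type (plain list or tuple; namedtuples are not handled).
        return type(x)(apply_recursive(value, leaf_fn) for value in x)
    # Anything else is treated as a single tensor with shape (C, *shape).
    return leaf_fn(x)


nested = {"img": 1.0, "pair": [2.0, (3.0, 4.0)]}
print(apply_recursive(nested, lambda t: t * 10))
# {'img': 10.0, 'pair': [20.0, (30.0, 40.0)]}
```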

xform

xform(x, /, args=NoArguments())

Apply the transform to a tensor.

In general, non-final transforms do not implement this method.

Parameters:

Name Type Description Default
x (C_inp, *spatial_inp) tensor

A single input tensor

required
args Arguments

The original input arguments to the transform, in case they are needed.

NoArguments()

Returns:

Name Type Description
y Returned | (C_out, *spatial_out) tensor

A single output tensor, or a Returned object containing multiple output tensors and their corresponding keys.

final

final(x, /, args=NoArguments(), **kwargs)

Generate the final version of the transform.

Some transforms save the output type of this function in their Final attribute.

Added in v0.5.

Before this, one had to use make_final(x, max_depth=inf).

Parameters:

Name Type Description Default
x tensor

A single input tensor, with shape (C, *shape).

required
args Arguments

The original input arguments to the transform, in case they are needed.

NoArguments()

Returns:

Type Description
FinalTransform

A final version of the transform.

next

next(x, /, args=NoArguments(), **kwargs)

Generate the next version of the transform.

Some transforms save the output type of this function in their Next attribute.

Added in v0.5.

Before this, one had to use make_final(x, max_depth=1).

Parameters:

Name Type Description Default
x tensor

A single input tensor, with shape (C, *shape).

required
args Arguments

The original input arguments to the transform, in case they are needed.

NoArguments()

Returns:

Type Description
Transform

A more specialized version of the transform.

unroll

unroll(x, /, max_depth=inf, args=NoArguments(), **kwargs)

Generate the next (i.e., more final) version(s) of the transform.

  • To completely finalize a transform, call unroll(x, max_depth=inf) or final(x).
  • To get the next version of a transform, call unroll(x, max_depth=1) or next(x).

Added in v0.5.

Before this, it was named make_final.

Parameters:

Name Type Description Default
x tensor

A single input tensor, with shape (C, *shape).

required
max_depth int | {inf}

Maximum depth to apply unroll recursively. If not inf, the resulting transform may not be fully final. Default: no limit.

inf
args Arguments

The original input arguments to the transform, in case they are needed.

NoArguments()

Returns:

Type Description
Transform

A more specialized version of the transform.
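The relation between unroll, next and final can be pictured with a toy model in which each unroll step peels off one level of randomness. ToyTransform and _specialize_once are hypothetical, and the sketch drops the x and args arguments that the real methods use to specialize on features of the input such as its shape and dtype.

```python
import math


class ToyTransform:
    """Hypothetical transform that needs `steps_left` specializations to become final."""

    def __init__(self, steps_left: int):
        self.steps_left = steps_left

    @property
    def is_final(self) -> bool:
        return self.steps_left == 0

    def _specialize_once(self) -> "ToyTransform":
        # One unroll step, e.g. sampling one level of random parameters.
        return ToyTransform(self.steps_left - 1)

    def unroll(self, max_depth: float = math.inf) -> "ToyTransform":
        # Stop once the transform is final or the depth budget is spent.
        if self.is_final or max_depth <= 0:
            return self
        return self._specialize_once().unroll(max_depth - 1)

    def final(self) -> "ToyTransform":
        return self.unroll(max_depth=math.inf)

    def next(self) -> "ToyTransform":
        return self.unroll(max_depth=1)
```

Since math.inf - 1 is still math.inf, final() keeps unrolling until is_final is reached, while next() stops after one step.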

inverse

inverse(*a, **k)

Apply the inverse transform recursively.

Parameters:

Name Type Description Default
*a [nested list or dict of] tensor

Input tensors, with shape (C, *shape)

()
**k [nested list or dict of] tensor

Input tensors, with shape (C, *shape)

()

Returns:

Type Description
[nested list or dict of] tensor

Output tensors, with shape (C, *shape).

make_inverse

make_inverse()

Generate the inverse transform.

FinalTransform

FinalTransform(*, returns=None, append=False, prefix=True, include=None, exclude=None, consume=None)

Bases: Transform

Base class for deterministic transforms.

Final transforms must implement the xform method.
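A final transform is deterministic, so the same object maps a given input to the same output every time and, when the mapping is invertible, can be paired with an exact inverse. The sketch models one such transform in plain Python; AffineIntensity is hypothetical, lists of numbers stand in for tensors, and having make_inverse return another final transform is an assumption about how inverses are represented.

```python
from typing import List


class AffineIntensity:
    """Hypothetical final transform: y = scale * x + shift, elementwise."""

    is_final = True  # deterministic: the same input always gives the same output

    def __init__(self, scale: float, shift: float = 0.0):
        if scale == 0:
            raise ValueError("scale must be non-zero for the transform to be invertible")
        self.scale = scale
        self.shift = shift

    def xform(self, x: List[float]) -> List[float]:
        # A flat list of numbers stands in for a (C, *shape) tensor.
        return [self.scale * v + self.shift for v in x]

    def make_inverse(self) -> "AffineIntensity":
        # Solving y = a*x + b for x gives x = (1/a)*y - b/a, which is again affine.
        return AffineIntensity(1.0 / self.scale, -self.shift / self.scale)


fwd = AffineIntensity(2.0, 1.0)
y = fwd.xform([0.0, 1.5])            # [1.0, 4.0]
x = fwd.make_inverse().xform(y)      # [0.0, 1.5]
```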

NonFinalTransform

NonFinalTransform(*, shared=False, **kwargs)

Bases: _SharedMixin, Transform

Transforms whose parameters depend on features of the input tensor (shape, dtype, etc.).

Non-final transforms implement unroll but not xform. Their aim is to generate a more specialized transform at call time.

Parameters:

Name Type Description Default
shared {'channels', 'tensors', 'channels+tensors', ''} | bool
  • 'channels': the same transform is applied to all channels of a tensor, but different transforms are used for different tensors.
  • 'tensors': the same transform is applied to all tensors, but with a different transform for each channel.
  • 'channels+tensors' or True: the same transform is applied to all channels of all tensors.
  • '' or False: a different transform is applied to each channel of each tensor.
'channels'
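One way to read the shared modes is in terms of how many independent random draws a non-final transform makes for a batch of tensors. In this sketch, sample_factors is a hypothetical helper that draws one scalar per (tensor, channel) pair and reuses draws according to shared; cornucopia's own sampling may be organized differently.

```python
import random
from typing import List, Optional, Union


def sample_factors(n_tensors: int, n_channels: int,
                   shared: Union[bool, str] = "channels",
                   rng: Optional[random.Random] = None) -> List[List[float]]:
    """Hypothetical: return factors[tensor][channel], reusing draws as `shared` says."""
    rng = rng or random.Random()
    # Boolean aliases documented for `shared`.
    if shared is True:
        shared = "channels+tensors"
    elif shared is False:
        shared = ""
    if shared == "channels+tensors":
        value = rng.random()  # a single draw for everything
        return [[value] * n_channels for _ in range(n_tensors)]
    if shared == "channels":
        # One draw per tensor, reused across that tensor's channels.
        return [[rng.random()] * n_channels for _ in range(n_tensors)]
    if shared == "tensors":
        # One draw per channel, reused across all tensors.
        per_channel = [rng.random() for _ in range(n_channels)]
        return [list(per_channel) for _ in range(n_tensors)]
    if shared == "":
        # Independent draw for every channel of every tensor.
        return [[rng.random() for _ in range(n_channels)] for _ in range(n_tensors)]
    raise ValueError(f"unknown shared mode: {shared!r}")
```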

SpecialTransform

SpecialTransform(*, returns=None, append=False, prefix=True, include=None, exclude=None, consume=None)

Bases: Transform

Base class for transforms that act on other transforms.

Such transforms cannot be easily classified as "final" or "non-final", because this characteristic depends on the transforms that they embed.

They all implement unroll, but some may also implement a "fast-track" xform that is applied in simple cases (e.g., when the transform is not shared across tensors) for efficiency.

Added in v0.5.

Before this, special transforms inherited directly from Transform.
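Because a special transform's finality is inherited from the transforms it wraps, a container can compute is_final from its children and unroll them one by one. ChildSketch and SequenceSketch are hypothetical plain-Python classes illustrating that idea; they are not cornucopia's container transforms and they omit the fast-track xform.

```python
from typing import List


class ChildSketch:
    """Hypothetical wrapped transform with a fixed finality flag."""

    def __init__(self, name: str, is_final: bool):
        self.name = name
        self.is_final = is_final

    def unroll(self) -> "ChildSketch":
        # In this toy model, one unroll step always yields a final transform.
        return ChildSketch(self.name + "*", True)


class SequenceSketch:
    """Hypothetical special transform that applies its children in order."""

    def __init__(self, children: List[ChildSketch]):
        self.children = children

    @property
    def is_final(self) -> bool:
        # Final only if every wrapped transform is final.
        return all(child.is_final for child in self.children)

    def unroll(self) -> "SequenceSketch":
        # Specialize non-final children and keep final ones unchanged.
        return SequenceSketch([c if c.is_final else c.unroll() for c in self.children])
```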