API reference

Fundamentals

Module Transformation

metoryx.transform(module: Module, *, to_callable: Callable[[Module], Callable[[...], Any]] | None = None) -> Transformed

Transform a module into initialization and applying functions.

Parameters:
  • module – The module to transform.

  • to_callable – An optional function that converts the module into a callable. If None, the module itself must be callable.

Returns:

A transformed module with separate init and apply functions.
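The init/apply split is the usual functional pattern in JAX: one function creates the parameters, and a pure function consumes them together with the inputs. The NumPy sketch below illustrates only that idea with two hypothetical helpers, affine_init and affine_apply. It is not metoryx's implementation, and the call signatures of the real Transformed, InitFn, and ApplyFn objects may differ.

```python
import numpy as np


def affine_init(seed: int, in_size: int, out_size: int) -> dict:
    """Hypothetical init function: creates parameters without touching any inputs."""
    rng = np.random.default_rng(seed)
    return {
        "kernel": rng.normal(0.0, 1.0 / np.sqrt(in_size), size=(in_size, out_size)),
        "bias": np.zeros(out_size),
    }


def affine_apply(params: dict, x: np.ndarray) -> np.ndarray:
    """Hypothetical apply function: pure, depends only on params and inputs."""
    return x @ params["kernel"] + params["bias"]


params = affine_init(seed=0, in_size=3, out_size=2)
y = affine_apply(params, np.ones((4, 3)))
```

Because the apply function is pure, calling it twice with the same params and inputs gives the same result, which is what makes it safe to wrap in jax.jit or jax.grad.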

metoryx.init(module: Module) -> InitFn

Transform a module into an initialization function.

Parameters:

module – The module to transform.

Returns:

The initialization function for the module.

metoryx.apply(module: Module, *, to_callable: Callable[[Module], Callable[[...], Any]] | None = None) -> ApplyFn

Transform a module into an apply function.

Parameters:
  • module – The module to transform.

  • to_callable – An optional function that converts the module into a callable. If None, the module itself must be callable.

Returns:

The apply function for the module.

metoryx.checkpoint(module: Module, *, to_callable: Callable[[Module], Callable] | None = None, concrete: bool = False, prevent_cse: bool = True, static_argnums: int | tuple[int, ...] = (), policy: Callable[[...], bool] | None = None) -> Callable[[...], Any]

Transform a module into a function that applies jax.checkpoint to its call method.

Parameters:
  • module – The module to wrap.

  • to_callable – An optional function that converts the module into a callable. If None, the module itself must be callable.

  • concrete – Whether to use concrete mode in jax.checkpoint.

  • prevent_cse – Whether to prevent common subexpression elimination in jax.checkpoint.

  • static_argnums – Indices of the arguments to treat as static, passed to jax.checkpoint.

  • policy – A custom policy function to pass to jax.checkpoint.

Returns:

A function that applies jax.checkpoint to the module’s call method.
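metoryx.checkpoint forwards these options to jax.checkpoint. The plain-JAX sketch below (no metoryx involved) shows what that wrapper does to a function: the gradient is unchanged, and the policy only controls which intermediates are stored and which are recomputed during the backward pass. nothing_saveable is one of the built-in policies in jax.checkpoint_policies.

```python
import jax
import jax.numpy as jnp


def block(x):
    # Elementwise ops whose intermediates would normally be saved for backprop.
    return jnp.sin(jnp.cos(x)) * x


# Recompute every intermediate in the backward pass instead of storing it.
remat_block = jax.checkpoint(block, policy=jax.checkpoint_policies.nothing_saveable)

x = jnp.linspace(-1.0, 1.0, 5)
g_plain = jax.grad(lambda v: block(v).sum())(x)
g_remat = jax.grad(lambda v: remat_block(v).sum())(x)
print(jnp.allclose(g_plain, g_remat))  # True
```

The trade-off is compute for memory: a stricter policy saves fewer activations but repeats more of the forward computation when gradients are taken.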

Modules, States, and Parameters

class metoryx.Module

Bases: object

Base class for all neural network modules.

class metoryx.State(col: str, init: ~metoryx._src.base.Initializer, shape: Shape, dtype: DType = <class 'jax.numpy.float32'>, param_dtype: DType | None = None, mutable: bool = False)

Bases: object

A container for a stateful variable in a module. The variable is lazily initialized and may be mutable.

__init__(col: str, init: ~metoryx._src.base.Initializer, shape: Shape, dtype: DType = <class 'jax.numpy.float32'>, param_dtype: DType | None = None, mutable: bool = False)

Initializes the State.

Parameters:
  • col – The collection name for the state variable.

  • init – The initializer function for the state variable.

  • shape – The shape of the state variable.

  • dtype – The data type of the state variable.

  • param_dtype – The data type for computation. If None, uses dtype.

  • mutable – Whether the state variable is mutable during the applying phase.

property value: Array

Get or initialize the value of the state variable.

This property is only available during the initializing or applying phase. In the initializing phase, it initializes the variable using the provided initializer. In the applying phase, it retrieves the current value from the context.

This property can also be set to an array with the same shape as the state variable. In the applying phase, if the state variable is mutable, setting this property updates the value in the context. Otherwise, setting this property updates the initializer to return the provided array during the next initialization.

property id: str

A unique identifier for the state variable.

property T: Array

Transpose of the state variable.

property ndim: int

Number of dimensions of the state variable.

class metoryx.Parameter(init: ~metoryx._src.base.Initializer, shape: Shape, dtype: DType = <class 'jax.numpy.float32'>, param_dtype: DType | None = None)

Bases: State

A container for a parameter variable in a module. Parameters are typically immutable during the applying phase.

__init__(init: ~metoryx._src.base.Initializer, shape: Shape, dtype: DType = <class 'jax.numpy.float32'>, param_dtype: DType | None = None)

Initializes the Parameter.

Parameters:
  • init – The initializer function for the parameter variable.

  • shape – The shape of the parameter variable.

  • dtype – The data type of the parameter variable.

  • param_dtype – The data type for computation. If None, uses dtype.

Random Numbers

class metoryx.PRNGKeys(default: PRNGKey | None = None, /, **kwargs: PRNGKey)

Bases:

Create a dictionary of PRNGKeys to feed into the apply function.

Parameters:
  • default – The default PRNGKey. next_rng_key uses it when no name is given, or when the given name is not found and strict is False.

  • **kwargs – Additional PRNGKeys to use for specific purposes.

Returns:

A dictionary of PRNGKeys to feed into the apply function.

metoryx.next_rng_key(name: str | None = None, num: int | tuple[int, ...] | None = None, *, strict: bool = False) -> PRNGKey

Get the next PRNGKey. This function is only available within the applying phase.

Parameters:
  • name – The name of the PRNGKey to use. If None, uses the default PRNGKey.

  • num – If provided, splits the PRNGKey into num keys.

  • strict – If True, raises an error if the specified name is not found in the context. If False, falls back to the default PRNGKey if the specified name is not found.

Returns:

The next PRNGKey.

Note

This function also updates the context with the new PRNGKey. Thus, subsequent calls to this function will return different keys.

Raises:

ValueError – If the context is not set, or if the specified name is not found in the context and strict is True, or if the default PRNGKey is not found in the context.
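The key-handling rules above (named keys with a default fallback, a fresh key on every call, an error in strict mode) can be mimicked with jax.random. The KeyContext class below is hypothetical and only illustrates the split-and-carry pattern; it is not how metoryx stores its context, and it only supports an integer num.

```python
import jax
import jax.numpy as jnp


class KeyContext:
    """Hypothetical stand-in for the context consulted by next_rng_key."""

    def __init__(self, default=None, **named):
        self.keys = dict(named)
        if default is not None:
            self.keys["__default__"] = default

    def next_key(self, name=None, num=None, *, strict=False):
        if name is not None and name not in self.keys:
            if strict:
                raise ValueError(f"PRNGKey {name!r} not found")
            name = None  # non-strict: fall back to the default key
        slot = "__default__" if name is None else name
        if slot not in self.keys:
            raise ValueError("default PRNGKey not found")
        # Keep one half of the split as the new carried key, hand out the other.
        carry, sub = jax.random.split(self.keys[slot])
        self.keys[slot] = carry
        return sub if num is None else jax.random.split(sub, num)


ctx = KeyContext(jax.random.PRNGKey(0), dropout=jax.random.PRNGKey(1))
k1 = ctx.next_key()
k2 = ctx.next_key()  # differs from k1 because the carried key was updated
```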

Common Modules and Functions

Linear

class metoryx.Dense(in_size: int, out_size: int, use_bias: bool = True, kernel_init: ~metoryx._src.base.Initializer = <function variance_scaling.<locals>.init>, bias_init: ~metoryx._src.base.Initializer = <function constant.<locals>.init>, dtype: DType = <class 'jax.numpy.float32'>, param_dtype: DType | None = None)

Bases: Module

Applies an affine linear transformation.

__init__(in_size: int, out_size: int, use_bias: bool = True, kernel_init: ~metoryx._src.base.Initializer = <function variance_scaling.<locals>.init>, bias_init: ~metoryx._src.base.Initializer = <function constant.<locals>.init>, dtype: DType = <class 'jax.numpy.float32'>, param_dtype: DType | None = None)

Initializes the Dense module.

Parameters:
  • in_size – Size of the input features.

  • out_size – Size of the output features.

  • use_bias – Whether to include a bias term.

  • kernel_init – Initializer for the weight matrix.

  • bias_init – Initializer for the bias term.

  • dtype – Data type of the parameters.

  • param_dtype – Data type for computation.

__call__(inputs: Array) -> Array

Applies an affine transformation to the input.

Parameters:

inputs – Input array of shape (*batch_size, in_size).

Returns:

The transformed output array of shape (*batch_size, out_size).
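As a shape check for the affine map, here is a NumPy sketch of what a dense layer computes. It assumes the kernel is stored as an (in_size, out_size) matrix and the bias as an (out_size,) vector; the actual parameter layout inside metoryx.Dense may differ.

```python
import numpy as np

in_size, out_size = 4, 3
rng = np.random.default_rng(0)
kernel = rng.normal(size=(in_size, out_size))  # assumed layout: (in_size, out_size)
bias = np.zeros(out_size)

x = rng.normal(size=(2, 5, in_size))  # the leading dims play the role of *batch_size
y = x @ kernel + bias                 # matmul contracts the last axis only
```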

class metoryx.Conv(in_size: int, out_size: int, kernel_size: ~typing.Sequence[int], padding: PaddingLike = 'SAME', strides: int | ~typing.Sequence[int] = 1, dilation: int | ~typing.Sequence[int] = 1, groups: int = 1, use_bias: bool = True, kernel_init: ~metoryx._src.base.Initializer = <function variance_scaling.<locals>.init>, bias_init: ~metoryx._src.base.Initializer = <function constant.<locals>.init>, dtype: DType = <class 'jax.numpy.float32'>, param_dtype: DType | None = None)

Bases: Module

Applies a convolution.

__init__(in_size: int, out_size: int, kernel_size: ~typing.Sequence[int], padding: PaddingLike = 'SAME', strides: int | ~typing.Sequence[int] = 1, dilation: int | ~typing.Sequence[int] = 1, groups: int = 1, use_bias: bool = True, kernel_init: ~metoryx._src.base.Initializer = <function variance_scaling.<locals>.init>, bias_init: ~metoryx._src.base.Initializer = <function constant.<locals>.init>, dtype: DType = <class 'jax.numpy.float32'>, param_dtype: DType | None = None)

Initializes the Conv module.

Parameters:
  • in_size – Number of input channels.

  • out_size – Number of output channels.

  • kernel_size – Size of the convolutional kernel.

  • padding – Padding method, either ‘SAME’, ‘VALID’, or a sequence of padding tuples.

  • strides – Strides of the convolution.

  • dilation – Dilation of the convolution.

  • groups – Number of groups for grouped convolution.

  • use_bias – Whether to include a bias term.

  • kernel_init – Initializer for the convolutional kernel.

  • bias_init – Initializer for the bias term.

  • dtype – Data type of the parameters.

  • param_dtype – Data type for computation.

padding: PaddingLike

__call__(inputs: Array) -> Array

Applies the convolution to the input.

Parameters:

inputs – Input array of shape (*batch_size, height, width, in_size).

Returns:

The convolved output array.
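The output spatial size is not spelled out above. Assuming Conv follows the usual jax.lax conventions for 'SAME' and 'VALID' padding, each spatial dimension can be computed as below, and the last axis of the output has size out_size. conv_output_size is a hypothetical helper, not part of metoryx.

```python
import math


def conv_output_size(n: int, k: int, stride: int = 1, dilation: int = 1,
                     padding: str = "SAME") -> int:
    """Output length along one spatial axis under jax.lax padding conventions."""
    effective_k = (k - 1) * dilation + 1  # dilation spreads the kernel taps apart
    if padding == "SAME":
        return math.ceil(n / stride)
    if padding == "VALID":
        return math.ceil((n - effective_k + 1) / stride)
    raise ValueError("explicit padding tuples are not handled in this sketch")


print(conv_output_size(32, 3, stride=2, padding="SAME"))               # 16
print(conv_output_size(32, 3, stride=2, padding="VALID"))              # 15
print(conv_output_size(32, 3, stride=2, dilation=2, padding="VALID"))  # 14
```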

class metoryx.Embed(size: int, num_embeddings: int, embedding_init: ~metoryx._src.base.Initializer = <function variance_scaling.<locals>.init>, dtype: DType = <class 'jax.numpy.float32'>, param_dtype: DType | None = None)

Bases: Module

Embeds integer inputs: each index is mapped to a learned vector, which is appended as a new last dimension.

__init__(size: int, num_embeddings: int, embedding_init: ~metoryx._src.base.Initializer = <function variance_scaling.<locals>.init>, dtype: DType = <class 'jax.numpy.float32'>, param_dtype: DType | None = None)

Initializes the Embed module.

Parameters:
  • size – Size of each embedding vector.

  • num_embeddings – Number of unique embeddings.

  • embedding_init – Initializer for the embedding matrix.

  • dtype – Data type of the parameters.

  • param_dtype – Data type for computation.

__call__(inputs: Array) -> Array

Embeds the input indices.

Parameters:

inputs – Input array of shape (*batch_size,) with integer indices.

Returns:

The embedded output array of shape (*batch_size, size).

attend(query: Array) -> Array

Computes the attention scores between the query and the embeddings.

Parameters:

query – Query array of shape (*batch_size, size).

Returns:

The attention scores of shape (*batch_size, num_embeddings).
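A NumPy sketch of the lookup and of attend, assuming the embedding table has shape (num_embeddings, size) and that attend is a dot product of the query against every row of the table (as in Flax's Embed.attend). metoryx's internals may differ.

```python
import numpy as np

num_embeddings, size = 10, 4
table = np.random.default_rng(0).normal(size=(num_embeddings, size))

ids = np.array([[1, 2, 3], [4, 5, 6]])  # integer indices, shape (*batch_size,)
emb = table[ids]                         # __call__: lookup -> (2, 3, size)
scores = emb @ table.T                   # attend: -> (2, 3, num_embeddings)
```

Using the same table for input lookup and output scores is the usual weight-tying setup in language models.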

Normalization

class metoryx.BatchNorm(size: int, momentum: float = 0.99, epsilon: float = 1e-05, use_scale: bool = True, use_bias: bool = True, scale_init: ~metoryx._src.base.Initializer = <function constant.<locals>.init>, bias_init: ~metoryx._src.base.Initializer = <function constant.<locals>.init>, dtype: DType = <class 'jax.numpy.float32'>, param_dtype: DType | None = None, axis_name: ~typing.Any | None = None, axis_index_groups: ~typing.Any | None = None)

Bases: Module

Batch normalization.

Ref. https://arxiv.org/abs/1502.03167

Batch normalization keeps a moving average of batch statistics. These are stored in the batch_stats collection.

__init__(size: int, momentum: float = 0.99, epsilon: float = 1e-05, use_scale: bool = True, use_bias: bool = True, scale_init: ~metoryx._src.base.Initializer = <function constant.<locals>.init>, bias_init: ~metoryx._src.base.Initializer = <function constant.<locals>.init>, dtype: DType = <class 'jax.numpy.float32'>, param_dtype: DType | None = None, axis_name: ~typing.Any | None = None, axis_index_groups: ~typing.Any | None = None)

Initializes the BatchNorm module.

Parameters:
  • size – Size of input features.

  • momentum – Momentum for the moving average.

  • epsilon – Small constant for numerical stability.

  • use_scale – Whether to use a scale parameter.

  • use_bias – Whether to use a bias parameter.

  • scale_init – Initializer for the scale parameter.

  • bias_init – Initializer for the bias parameter.

  • dtype – Data type for computation.

  • param_dtype – Data type of the parameters.

  • axis_name – Name of the mapped axis along which batch statistics are synchronized across devices.

  • axis_index_groups – Axis index groups for distributed training.

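__call__(inputs: Array, is_training: bool = False) -> Array

Applies batch normalization to the input. When is_training is True, the input is normalized with the statistics of the current batch and the moving averages in the batch_stats collection are updated; otherwise, the stored moving averages are used.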
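For reference, the training-mode computation and the moving-average update can be sketched in NumPy as below. The update rule ra = momentum * ra + (1 - momentum) * batch_stat is an assumption (it is the convention used by Flax, whose default momentum of 0.99 matches), and scale and bias are taken at their usual initial values of 1 and 0.

```python
import numpy as np

momentum, epsilon = 0.99, 1e-5
x = np.random.default_rng(0).normal(loc=3.0, scale=2.0, size=(64, 8))  # (batch, size)

# Moving averages as they might look right after initialization.
ra_mean, ra_var = np.zeros(8), np.ones(8)
scale, bias = np.ones(8), np.zeros(8)

# Training mode: normalize with the statistics of the current batch.
mean, var = x.mean(axis=0), x.var(axis=0)
y = (x - mean) / np.sqrt(var + epsilon) * scale + bias

# Fold the batch statistics into the moving averages (assumed update rule).
ra_mean = momentum * ra_mean + (1 - momentum) * mean
ra_var = momentum * ra_var + (1 - momentum) * var

# Evaluation mode would normalize with ra_mean and ra_var instead.
```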

class metoryx.LayerNorm(size: int, epsilon: float = 1e-06, use_scale: bool = True, use_bias: bool = True, scale_init: ~metoryx._src.base.Initializer = <function constant.<locals>.init>, bias_init: ~metoryx._src.base.Initializer = <function constant.<locals>.init>, dtype: DType = <class 'jax.numpy.float32'>, param_dtype: DType | None = None)

Bases: Module

Layer normalization.

Ref. https://arxiv.org/abs/1607.06450

__init__(size: int, epsilon: float = 1e-06, use_scale: bool = True, use_bias: bool = True, scale_init: ~metoryx._src.base.Initializer = <function constant.<locals>.init>, bias_init: ~metoryx._src.base.Initializer = <function constant.<locals>.init>, dtype: DType = <class 'jax.numpy.float32'>, param_dtype: DType | None = None)

Initializes the LayerNorm module.

Parameters:
  • size – Size of input features.

  • epsilon – Small constant for numerical stability.

  • use_scale – Whether to use a scale parameter.

  • use_bias – Whether to use a bias parameter.

  • scale_init – Initializer for the scale parameter.

  • bias_init – Initializer for the bias parameter.

  • dtype – Data type for computation.

  • param_dtype – Data type of the parameters.

__call__(inputs: Array) -> Array

Applies layer normalization to the input.
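Unlike batch normalization, layer normalization computes its statistics over the features of each example, so it needs no moving averages. A NumPy sketch, assuming normalization over the last axis and scale and bias initialized to 1 and 0:

```python
import numpy as np

epsilon = 1e-6
x = np.random.default_rng(0).normal(size=(2, 5, 16))  # (..., size)

# Statistics over the feature (last) axis, separately for every example.
mean = x.mean(axis=-1, keepdims=True)
var = x.var(axis=-1, keepdims=True)
scale, bias = np.ones(16), np.zeros(16)  # assumed initial values
y = (x - mean) / np.sqrt(var + epsilon) * scale + bias
```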

class metoryx.RMSNorm(size: int, epsilon: float = 1e-06, use_scale: bool = True, scale_init: ~metoryx._src.base.Initializer = <function constant.<locals>.init>, dtype: DType = <class 'jax.numpy.float32'>, param_dtype: DType | None = None)

Bases: Module

RMS layer normalization.

Ref. https://arxiv.org/abs/1910.07467

__init__(size: int, epsilon: float = 1e-06, use_scale: bool = True, scale_init: ~metoryx._src.base.Initializer = <function constant.<locals>.init>, dtype: DType = <class 'jax.numpy.float32'>, param_dtype: DType | None = None)

Initializes the RMSNorm module.

Parameters:
  • size – Size of input features.

  • epsilon – Small constant for numerical stability.

  • use_scale – Whether to use a scale parameter.

  • scale_init – Initializer for the scale parameter.

  • dtype – Data type for computation.

  • param_dtype – Data type of the parameters.

__call__(inputs: Array) -> Array

Applies RMS normalization to the input.
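RMSNorm drops the mean subtraction and the bias of layer normalization and rescales by the root mean square of the features. A NumPy sketch, assuming the last axis is the feature axis and epsilon is added inside the square root (where epsilon goes varies between implementations):

```python
import numpy as np

epsilon = 1e-6
x = np.random.default_rng(0).normal(loc=1.0, size=(4, 16))

# No mean subtraction: divide by the root mean square of the features.
rms = np.sqrt(np.mean(x**2, axis=-1, keepdims=True) + epsilon)
scale = np.ones(16)  # assumed initial value
y = x / rms * scale
```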

Dropout

metoryx.dropout(inputs: Array, rate: float, is_training: bool, *, rng_collection: str | None = None) -> Array

Applies dropout to the input array.

During training, randomly sets a fraction rate of input units to zero at each update step, which helps prevent overfitting. During evaluation, the input is returned unchanged.

Parameters:
  • inputs – Input array.

  • rate – Fraction of the input units to drop. Must be between 0 and 1.

  • is_training – Whether the model is in training mode.

  • rng_collection – Name of the RNG collection to use for generating dropout masks. If None, uses the default RNG collection.

Returns:

The array after applying dropout.
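The description above does not say whether kept units are rescaled. Many JAX libraries, Flax included, use inverted dropout, which scales kept units by 1 / (1 - rate) so that the expected activation is unchanged. The NumPy sketch below assumes that convention, and a NumPy generator stands in for next_rng_key; dropout_sketch is hypothetical.

```python
import numpy as np


def dropout_sketch(x, rate, is_training, rng):
    """Inverted dropout (assumed convention); `rng` stands in for next_rng_key."""
    if not is_training or rate == 0.0:
        return x                                # evaluation: input unchanged
    if rate == 1.0:
        return np.zeros_like(x)                 # everything dropped
    keep_prob = 1.0 - rate
    mask = rng.random(x.shape) < keep_prob      # True = keep this unit
    return np.where(mask, x / keep_prob, 0.0)   # rescale the survivors


rng = np.random.default_rng(0)
x = np.ones((1000,))
y = dropout_sketch(x, rate=0.25, is_training=True, rng=rng)  # entries are 0 or 1/0.75
```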

Pooling

metoryx.avg_pool(inputs: Array, kernel_size: Sequence[int], strides: int | Sequence[int], padding: PaddingLike = 'VALID') -> Array
metoryx.max_pool(inputs: Array, kernel_size: Sequence[int], strides: int | Sequence[int], padding: PaddingLike = 'VALID') -> Array
metoryx.min_pool(inputs: Array, kernel_size: Sequence[int], strides: int | Sequence[int], padding: PaddingLike = 'VALID') -> Array
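The pooling functions have no prose of their own. Assuming they follow the Flax convention of inputs shaped (batch, height, width, features) with the window sliding over the spatial axes, 'VALID' pooling reduces each window as in this NumPy sketch. pool2d_valid is hypothetical; min pooling is the same with np.min.

```python
import numpy as np


def pool2d_valid(x, kernel_size, strides, reduce_fn):
    """'VALID' 2-D pooling over axes 1 and 2 of an NHWC array."""
    kh, kw = kernel_size
    sh, sw = strides
    n, h, w, c = x.shape
    out_h, out_w = (h - kh) // sh + 1, (w - kw) // sw + 1
    out = np.empty((n, out_h, out_w, c), dtype=x.dtype)
    for i in range(out_h):
        for j in range(out_w):
            window = x[:, i * sh:i * sh + kh, j * sw:j * sw + kw, :]
            out[:, i, j, :] = reduce_fn(window, axis=(1, 2))  # reduce the spatial window
    return out


x = np.arange(16, dtype=np.float32).reshape(1, 4, 4, 1)
avg = pool2d_valid(x, (2, 2), (2, 2), np.mean)  # window means: 2.5, 4.5, 10.5, 12.5
mx = pool2d_valid(x, (2, 2), (2, 2), np.max)    # window maxima: 5, 7, 13, 15
```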

Activation Functions

Most activation functions are re-exported from jax.nn for convenience.

metoryx.celu(x: Array | ndarray | bool | number | bool | int | float | complex, alpha: Array | ndarray | bool | number | bool | int | float | complex = 1.0) -> Array

Continuously-differentiable exponential linear unit activation.

Computes the element-wise function:

\[\begin{split}\mathrm{celu}(x) = \begin{cases} x, & x > 0\\ \alpha \left(\exp(\frac{x}{\alpha}) - 1\right), & x \le 0 \end{cases}\end{split}\]

For more information, see Continuously Differentiable Exponential Linear Units.

Parameters:
  • x – input array

  • alpha – array or scalar (default: 1.0)

Returns:

An array.

metoryx.elu(x: Array | ndarray | bool | number | bool | int | float | complex, alpha: Array | ndarray | bool | number | bool | int | float | complex = 1.0) -> Array

Exponential linear unit activation function.

Computes the element-wise function:

\[\begin{split}\mathrm{elu}(x) = \begin{cases} x, & x > 0\\ \alpha \left(\exp(x) - 1\right), & x \le 0 \end{cases}\end{split}\]

Parameters:
  • x – input array

  • alpha – scalar or array of alpha values (default: 1.0)

Returns:

An array.

See also

selu()

metoryx.gelu(x: Array | ndarray | bool | number | bool | int | float | complex, approximate: bool = True) -> Array

Gaussian error linear unit activation function.

If approximate=False, computes the element-wise function:

\[\mathrm{gelu}(x) = \frac{x}{2} \left(\mathrm{erfc} \left( \frac{-x}{\sqrt{2}} \right) \right)\]

If approximate=True, uses the approximate formulation of GELU:

\[\mathrm{gelu}(x) = \frac{x}{2} \left(1 + \mathrm{tanh} \left( \sqrt{\frac{2}{\pi}} \left(x + 0.044715 x^3 \right) \right) \right)\]

For more information, see Gaussian Error Linear Units (GELUs), section 2.

Parameters:
  • x – input array

  • approximate – whether to use the approximate or exact formulation.
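To see how close the two formulations are, the following standard-library sketch evaluates both formulas above directly on a grid over [-5, 5]. It reimplements the math for illustration and does not call metoryx.gelu.

```python
import math


def gelu_exact(x: float) -> float:
    # x/2 * erfc(-x / sqrt(2))
    return 0.5 * x * math.erfc(-x / math.sqrt(2.0))


def gelu_tanh(x: float) -> float:
    # x/2 * (1 + tanh(sqrt(2/pi) * (x + 0.044715 x^3)))
    inner = math.sqrt(2.0 / math.pi) * (x + 0.044715 * x**3)
    return 0.5 * x * (1.0 + math.tanh(inner))


xs = [i / 10 for i in range(-50, 51)]
max_gap = max(abs(gelu_exact(x) - gelu_tanh(x)) for x in xs)
print(max_gap < 1e-3)  # True
```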

metoryx.glu(x: Array | ndarray | bool | number | bool | int | float | complex, axis: int = -1) -> Array

Gated linear unit activation function.

Computes the function:

\[\mathrm{glu}(x) = x\left[\ldots, 0:\frac{n}{2}, \ldots\right] \cdot \mathrm{sigmoid} \left( x\left[\ldots, \frac{n}{2}:n, \ldots\right] \right)\]

where the array is split into two along axis. The size of the axis dimension must be divisible by two.

Parameters:
  • x – input array

  • axis – the axis along which the split should be computed (default: -1)

Returns:

An array.

See also

sigmoid()
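The split-and-gate structure of the formula is easy to check by hand. This NumPy sketch reimplements it for illustration only:

```python
import numpy as np


def glu_sketch(x, axis=-1):
    # Split into two equal halves along `axis`; gate the first with sigmoid(second).
    a, b = np.split(x, 2, axis=axis)  # raises ValueError if the size is odd
    return a * (1.0 / (1.0 + np.exp(-b)))


print(glu_sketch(np.array([1.0, 2.0, 0.0, 0.0])))  # [0.5 1. ]
```

With a zero gate input, sigmoid(0) = 0.5, so each value in the first half is simply halved.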

metoryx.hard_sigmoid(x: Array | ndarray | bool | number | bool | int | float | complex) -> Array

Hard Sigmoid activation function.

Computes the element-wise function

\[\mathrm{hard\_sigmoid}(x) = \frac{\mathrm{relu6}(x + 3)}{6}\]

Parameters:

x – input array

Returns:

An array.

See also

relu6()

metoryx.hard_silu(x: Array | ndarray | bool | number | bool | int | float | complex) -> Array

Hard SiLU (swish) activation function

Computes the element-wise function

\[\mathrm{hard\_silu}(x) = x \cdot \mathrm{hard\_sigmoid}(x)\]

Both hard_silu() and hard_swish() are aliases for the same function.

Parameters:

x – input array

Returns:

An array.

See also

hard_sigmoid()

metoryx.hard_tanh(x: Array | ndarray | bool | number | bool | int | float | complex) -> Array

Hard \(\mathrm{tanh}\) activation function.

Computes the element-wise function:

\[\begin{split}\mathrm{hard\_tanh}(x) = \begin{cases} -1, & x < -1\\ x, & -1 \le x \le 1\\ 1, & 1 < x \end{cases}\end{split}\]

Parameters:

x – input array

Returns:

An array.

metoryx.identity(x: Array | ndarray | bool | number | bool | int | float | complex) -> Array

Identity activation function.

Returns the argument unmodified.

Parameters:

x – input array

Returns:

The argument x unmodified.

Examples

>>> jax.nn.identity(jax.numpy.array([-2., -1., -0.5, 0, 0.5, 1., 2.]))
Array([-2. , -1. , -0.5, 0. , 0.5, 1. , 2. ], dtype=float32)

metoryx.leaky_relu(x: Array | ndarray | bool | number | bool | int | float | complex, negative_slope: Array | ndarray | bool | number | bool | int | float | complex = 0.01) -> Array

Leaky rectified linear unit activation function.

Computes the element-wise function:

\[\begin{split}\mathrm{leaky\_relu}(x) = \begin{cases} x, & x \ge 0\\ \alpha x, & x < 0 \end{cases}\end{split}\]

where \(\alpha\) = negative_slope.

Parameters:
  • x – input array

  • negative_slope – array or scalar specifying the negative slope (default: 0.01)

Returns:

An array.

See also

relu()

metoryx.log_sigmoid(x: Array | ndarray | bool | number | bool | int | float | complex) -> Array

Log-sigmoid activation function.

Computes the element-wise function:

\[\mathrm{log\_sigmoid}(x) = \log(\mathrm{sigmoid}(x)) = -\log(1 + e^{-x})\]

Parameters:

x – input array

Returns:

An array.

See also

sigmoid()

metoryx.log_softmax(x: Array | ndarray | bool | number | bool | int | float | complex, axis: int | Sequence[int] | None = -1, where: Array | ndarray | bool | number | bool | int | float | complex | None = None) -> Array

Log-Softmax function.

Computes the logarithm of the softmax function, which rescales elements to the range \([-\infty, 0)\).

\[\mathrm{log\_softmax}(x)_i = \log \left( \frac{\exp(x_i)}{\sum_j \exp(x_j)} \right)\]

Parameters:
  • x – input array

  • axis – the axis or axes along which the log_softmax should be computed. Either an integer or a tuple of integers.

  • where – Elements to include in the log_softmax. The output for any masked-out element is minus infinity.

Returns:

An array.

Note

If any input values are +inf, the result will be all NaN: this reflects the fact that inf / inf is not well-defined in the context of floating-point math.

See also

softmax()

metoryx.mish(x: Array | ndarray | bool | number | bool | int | float | complex) -> Array

Mish activation function.

Computes the element-wise function:

\[\mathrm{mish}(x) = x \cdot \mathrm{tanh}(\mathrm{softplus}(x))\]

For more information, see Mish: A Self Regularized Non-Monotonic Activation Function.

Parameters:

x – input array

Returns:

An array.

metoryx.one_hot(x: Any, num_classes: int, *, dtype: Any | None = None, axis: int | Hashable = -1) -> Array

One-hot encodes the given indices.

Each index in the input x is encoded as a vector of zeros of length num_classes with the element at index set to one:

>>> jax.nn.one_hot(jnp.array([0, 1, 2]), 3)
Array([[1., 0., 0.],
       [0., 1., 0.],
       [0., 0., 1.]], dtype=float32)

Indices outside the range [0, num_classes) will be encoded as zeros:

>>> jax.nn.one_hot(jnp.array([-1, 3]), 3)
Array([[0., 0., 0.],
       [0., 0., 0.]], dtype=float32)

Parameters:
  • x – A tensor of indices.

  • num_classes – Number of classes in the one-hot dimension.

  • dtype – optional, a float dtype for the returned values (default jnp.float_).

  • axis – the axis or axes along which the function should be computed.
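The encoding rule, including the all-zeros result for out-of-range indices, amounts to a broadcast comparison against arange(num_classes). A NumPy sketch for the default axis=-1 (it ignores the axis argument and is not the JAX implementation):

```python
import numpy as np


def one_hot_sketch(x, num_classes, dtype=np.float32):
    # Compare every index with 0..num_classes-1; out-of-range indices match nothing.
    x = np.asarray(x)
    return (x[..., None] == np.arange(num_classes)).astype(dtype)
```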

metoryx.relu(x: Array | ndarray | bool | number | bool | int | float | complex) -> Array

Rectified linear unit activation function.

Computes the element-wise function:

\[\mathrm{relu}(x) = \max(x, 0)\]

except under differentiation, we take:

\[\nabla \mathrm{relu}(0) = 0\]

For more information, see Numerical influence of ReLU’(0) on backpropagation.

Parameters:

x – input array

Returns:

An array.

Examples

>>> jax.nn.relu(jax.numpy.array([-2., -1., -0.5, 0, 0.5, 1., 2.]))
Array([0. , 0. , 0. , 0. , 0.5, 1. , 2. ], dtype=float32)

See also

relu6()

metoryx.relu6(x: Array | ndarray | bool | number | bool | int | float | complex) -> Array

Rectified Linear Unit 6 activation function.

Computes the element-wise function

\[\mathrm{relu6}(x) = \min(\max(x, 0), 6)\]

except under differentiation, we take:

\[\nabla \mathrm{relu6}(0) = 0\]

and

\[\nabla \mathrm{relu6}(6) = 0\]

Parameters:

x – input array

Returns:

An array.

See also

relu()

metoryx.selu(x: Array | ndarray | bool | number | bool | int | float | complex) -> Array

Scaled exponential linear unit activation.

Computes the element-wise function:

\[\begin{split}\mathrm{selu}(x) = \lambda \begin{cases} x, & x > 0\\ \alpha e^x - \alpha, & x \le 0 \end{cases}\end{split}\]

where \(\lambda = 1.0507009873554804934193349852946\) and \(\alpha = 1.6732632423543772848170429916717\).

For more information, see Self-Normalizing Neural Networks.

Parameters:

x – input array

Returns:

An array.

See also

elu()

metoryx.sigmoid(x: Array | ndarray | bool | number | bool | int | float | complex) -> Array

Sigmoid activation function.

Computes the element-wise function:

\[\mathrm{sigmoid}(x) = \frac{1}{1 + e^{-x}}\]

Parameters:

x – input array

Returns:

An array.

See also

log_sigmoid()

metoryx.silu(x: Array | ndarray | bool | number | bool | int | float | complex) -> Array

SiLU (aka swish) activation function.

Computes the element-wise function:

\[\mathrm{silu}(x) = x \cdot \mathrm{sigmoid}(x) = \frac{x}{1 + e^{-x}}\]

swish() and silu() are both aliases for the same function.

Parameters:

x – input array

Returns:

An array.

See also

sigmoid()

metoryx.soft_sign(x: Array | ndarray | bool | number | bool | int | float | complex) -> Array

Soft-sign activation function.

Computes the element-wise function

\[\mathrm{soft\_sign}(x) = \frac{x}{|x| + 1}\]

Parameters:

x – input array

metoryx.softmax(x: Array | ndarray | bool | number | bool | int | float | complex, axis: int | Sequence[int] | None = -1, where: Array | ndarray | bool | number | bool | int | float | complex | None = None) -> Array

Softmax function.

Computes the function which rescales elements to the range \([0, 1]\) such that the elements along axis sum to \(1\).

\[\mathrm{softmax}(x)_i = \frac{\exp(x_i)}{\sum_j \exp(x_j)}\]

Parameters:
  • x – input array

  • axis – the axis or axes along which the softmax should be computed; the output sums to \(1\) across these dimensions. Either an integer or a tuple of integers.

  • where – Elements to include in the softmax. The output for any masked-out element is zero.

Returns:

An array.

Note

If any input values are +inf, the result will be all NaN: this reflects the fact that inf / inf is not well-defined in the context of floating-point math.

See also

log_softmax()
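A common way to evaluate both formulas stably is to subtract the per-axis maximum first, which cancels in the ratio. The NumPy sketch below shows this trick; it ignores the where argument and is not necessarily how JAX implements these functions.

```python
import numpy as np


def softmax_sketch(x, axis=-1):
    # Subtracting the max leaves the result unchanged but keeps exp from overflowing.
    z = x - np.max(x, axis=axis, keepdims=True)
    e = np.exp(z)
    return e / np.sum(e, axis=axis, keepdims=True)


def log_softmax_sketch(x, axis=-1):
    # log(exp(z_i) / sum_j exp(z_j)) = z_i - log(sum_j exp(z_j))
    z = x - np.max(x, axis=axis, keepdims=True)
    return z - np.log(np.sum(np.exp(z), axis=axis, keepdims=True))


x = np.array([[1.0, 2.0, 3.0], [1000.0, 1000.0, 1000.0]])
p = softmax_sketch(x)  # the second row would overflow without the max shift
```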

metoryx.softplus(x: Array | ndarray | bool | number | bool | int | float | complex) -> Array

Softplus activation function.

Computes the element-wise function

\[\mathrm{softplus}(x) = \log(1 + e^x)\]

Parameters:

x – input array

metoryx.sparse_plus(x: Array | ndarray | bool | number | bool | int | float | complex) -> Array

Sparse plus function.

Computes the function:

\[\begin{split}\mathrm{sparse\_plus}(x) = \begin{cases} 0, & x \leq -1\\ \frac{1}{4}(x+1)^2, & -1 < x < 1 \\ x, & 1 \leq x \end{cases}\end{split}\]

This is the counterpart of the softplus activation: it outputs exactly zero for inputs at or below -1 and is linear (equal to x) for inputs at or above 1, while remaining smooth, convex, and monotonic on the interval between -1 and 1.

Parameters:

x – input (float)

metoryx.sparse_sigmoid(x: Array | ndarray | bool | number | bool | int | float | complex) -> Array

Sparse sigmoid activation function.

Computes the function:

\[\begin{split}\mathrm{sparse\_sigmoid}(x) = \begin{cases} 0, & x \leq -1\\ \frac{1}{2}(x+1), & -1 < x < 1 \\ 1, & 1 \leq x \end{cases}\end{split}\]

This is the counterpart of the sigmoid activation: it outputs 0 for inputs at or below -1, 1 for inputs at or above 1, and is linear for inputs between -1 and 1. It is the derivative of sparse_plus.

For more information, see Learning with Fenchel-Young Losses (section 6.2).

Parameters:

x – input array

Returns:

An array.

See also

sigmoid()
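The statement that sparse_sigmoid is the derivative of sparse_plus can be checked numerically from the two piecewise formulas. The NumPy reimplementations below are for illustration only:

```python
import numpy as np


def sparse_plus(x):
    # 0 for x <= -1, (x + 1)^2 / 4 on (-1, 1), x for x >= 1
    return np.where(x <= -1, 0.0, np.where(x < 1, 0.25 * (x + 1) ** 2, x))


def sparse_sigmoid(x):
    # 0 for x <= -1, (x + 1) / 2 on (-1, 1), 1 for x >= 1
    return np.where(x <= -1, 0.0, np.where(x < 1, 0.5 * (x + 1), 1.0))


# Central finite differences of sparse_plus should match sparse_sigmoid.
xs = np.linspace(-2.0, 2.0, 41)
h = 1e-5
numeric = (sparse_plus(xs + h) - sparse_plus(xs - h)) / (2 * h)
```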

metoryx.squareplus(x: Array | ndarray | bool | number | bool | int | float | complex, b: Array | ndarray | bool | number | bool | int | float | complex = 4) -> Array

Squareplus activation function.

Computes the element-wise function

\[\mathrm{squareplus}(x) = \frac{x + \sqrt{x^2 + b}}{2}\]

as described in https://arxiv.org/abs/2112.11687.

Parameters:
  • x – input array

  • b – smoothness parameter

metoryx.standardize(x: Array | ndarray | bool | number | bool | int | float | complex, axis: int | Sequence[int] | None = -1, mean: Array | ndarray | bool | number | bool | int | float | complex | None = None, variance: Array | ndarray | bool | number | bool | int | float | complex | None = None, epsilon: Array | ndarray | bool | number | bool | int | float | complex = 1e-05, where: Array | ndarray | bool | number | bool | int | float | complex | None = None) -> Array

Standardizes input to zero mean and unit variance.

The standardization is given by:

\[x_{std} = \frac{x - \langle x\rangle}{\sqrt{\langle(x - \langle x\rangle)^2\rangle + \epsilon}}\]

where \(\langle x\rangle\) indicates the mean of \(x\), and \(\epsilon\) is a small correction factor introduced to avoid division by zero.

Parameters:
  • x – input array to be standardized.

  • axis – integer or tuple of integers representing the axes along which to standardize. Defaults to the last axis (-1).

  • mean – optionally specify the mean used for standardization. If not specified, then x.mean(axis, where=where) will be used.

  • variance – optionally specify the variance used for standardization. If not specified, then x.var(axis, where=where) will be used.

  • epsilon – correction factor added to variance to avoid division by zero; defaults to 1E-5.

  • where – optional boolean mask specifying which elements to use when computing the mean and variance.

Returns:

An array of the same shape as x containing the standardized input.
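For the default axis=-1 and without a where mask, the formula corresponds to the following NumPy computation (a sketch of the math, not a call into metoryx):

```python
import numpy as np

x = np.random.default_rng(0).normal(loc=5.0, scale=3.0, size=(4, 10))
epsilon = 1e-5

# <x> and <(x - <x>)^2> taken along the last axis, kept for broadcasting.
mean = x.mean(axis=-1, keepdims=True)
variance = x.var(axis=-1, keepdims=True)
x_std = (x - mean) / np.sqrt(variance + epsilon)
```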

Initializers

Most initializers are re-exported from jax.nn.initializers for convenience; zeros, ones, and constant are builders that take a dtype and return an initializer.

metoryx.initializers.zeros(dtype: DType | None = None) -> Initializer

Builds an initializer that generates arrays initialized to 0.

Parameters:

dtype – The data type of the initialized array.

Returns:

An initializer that returns an array initialized to 0.

metoryx.initializers.ones(dtype: DType | None = None) -> Initializer

Builds an initializer that generates arrays initialized to 1.

Parameters:

dtype – The data type of the initialized array.

Returns:

An initializer that returns an array initialized to 1.

metoryx.initializers.constant(value: ArrayLike, dtype: DType | None = None) -> Initializer

Builds an initializer that generates arrays initialized to a constant value.

Parameters:
  • value – The constant value to initialize the array.

  • dtype – The data type of the initialized array.

Returns:

An initializer that returns an array initialized to the specified constant value.

metoryx.initializers.uniform(scale: float = 0.01, dtype: DType | None = None) -> Initializer

Builds an initializer that returns real uniformly-distributed random arrays.

Parameters:
  • scale – optional; the upper bound of the random distribution.

  • dtype – optional; the initializer’s default dtype.

Returns:

An initializer that returns arrays whose values are uniformly distributed in the range [0, scale).

metoryx.initializers.normal(stddev: float = 0.01, dtype: DType | None = None) -> Initializer

Builds an initializer that returns real normally-distributed random arrays.

Parameters:
  • stddev – optional; the standard deviation of the distribution.

  • dtype – optional; the initializer’s default dtype.

Returns:

An initializer that returns arrays whose values are normally distributed with mean 0 and standard deviation stddev.

metoryx.initializers.truncated_normal(stddev: float = 0.01, dtype: DType | None = None, lower: float = -2.0, upper: float = 2.0) -> Initializer

Builds an initializer that returns truncated-normal random arrays.

Parameters:
  • stddev – optional; the standard deviation of the untruncated distribution. Note that this function does not apply the stddev correction used by the variance_scaling initializers; users who want that correction must apply it themselves via the stddev argument.

  • dtype – optional; the initializer’s default dtype.

  • lower – Float representing the lower bound for truncation. Applied before the output is multiplied by the stddev.

  • upper – Float representing the upper bound for truncation. Applied before the output is multiplied by the stddev.

Returns:

An initializer that returns arrays whose values follow the truncated normal distribution with mean 0 and standard deviation stddev, and range lower * stddev < x < upper * stddev.

metoryx.initializers.variance_scaling(scale: float, mode: Literal['fan_in', 'fan_out', 'fan_avg', 'fan_geo_avg'], distribution: Literal['truncated_normal', 'normal', 'uniform'], in_axis: int | Sequence[int] = -2, out_axis: int | Sequence[int] = -1, batch_axis: int | Sequence[int] = (), dtype: DType | None = None) -> Initializer

Initializer that adapts its scale to the shape of the weights tensor.

With distribution="truncated_normal" or distribution="normal", samples are drawn from a (truncated) normal distribution with a mean of zero and a standard deviation (after truncation, if applicable) of sqrt(scale/n), where n is, for each mode:

  • "fan_in": the number of inputs

  • "fan_out": the number of outputs

  • "fan_avg": the arithmetic average of the numbers of inputs and outputs

  • "fan_geo_avg": the geometric average of the numbers of inputs and outputs

This initializer can be configured with in_axis, out_axis, and batch_axis to work with general convolutional or dense layers; axes that are not in any of those arguments are assumed to be the “receptive field” (convolution kernel spatial axes).

With distribution="truncated_normal", the absolute values of the samples are truncated at 2 standard deviations before scaling.
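Because the target standard deviation sqrt(scale/n) is stated after truncation, the untruncated normal must be widened by the standard deviation of a standard normal truncated to (-2, 2). The sketch below computes that constant from the closed-form variance of a symmetrically truncated normal, 1 - 2*a*phi(a) / (Phi(a) - Phi(-a)); it illustrates the arithmetic and is not taken from the library source.

```python
import math

def truncated_std(a: float) -> float:
    """Std of a standard normal truncated symmetrically to (-a, a)."""
    pdf_a = math.exp(-0.5 * a * a) / math.sqrt(2.0 * math.pi)  # phi(a)
    mass = math.erf(a / math.sqrt(2.0))                         # Phi(a) - Phi(-a)
    return math.sqrt(1.0 - 2.0 * a * pdf_a / mass)

c = truncated_std(2.0)  # approx. 0.8796

# To end up with std sqrt(scale / n) *after* truncating at +/-2 standard
# deviations, the untruncated normal must have std sqrt(scale / n) / c.
scale, n = 1.0, 384.0
pre_truncation_std = math.sqrt(scale / n) / c
```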

With distribution="uniform", samples are drawn from:

  • a uniform interval, if dtype is real, or

  • a uniform disk, if dtype is complex,

with a mean of zero and a standard deviation of sqrt(scale/n) where n is defined above.

Parameters:
  • scale – scaling factor (positive float).

  • mode – one of "fan_in", "fan_out", "fan_avg", and "fan_geo_avg".

  • distribution – random distribution to use. One of "truncated_normal", "normal" and "uniform".

  • in_axis – axis or sequence of axes of the input dimension in the weights array.

  • out_axis – axis or sequence of axes of the output dimension in the weights array.

  • batch_axis – axis or sequence of axes in the weight array that should be ignored.

  • dtype – the dtype of the weights.
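The fan computation described above can be sketched as follows, assuming a shape is a plain tuple of ints. The helper names compute_fans and std_and_bound are hypothetical; the sketch only restates the documented rules: axes not named in in_axis, out_axis, or batch_axis form the receptive field, and a real uniform distribution with standard deviation s spans [-sqrt(3)*s, sqrt(3)*s].

```python
import math
from typing import Optional, Sequence, Tuple, Union

Axes = Union[int, Sequence[int]]

def _axes_size(shape: Sequence[int], axes: Axes) -> int:
    """Product of the sizes of the given axes (1 for an empty sequence)."""
    if isinstance(axes, int):
        return shape[axes]
    return math.prod(shape[i] for i in axes)

def compute_fans(shape: Sequence[int], in_axis: Axes = -2, out_axis: Axes = -1,
                 batch_axis: Axes = ()) -> Tuple[float, float]:
    in_size = _axes_size(shape, in_axis)
    out_size = _axes_size(shape, out_axis)
    batch_size = _axes_size(shape, batch_axis)
    # Every axis not named above is part of the receptive field.
    receptive = math.prod(shape) / (in_size * out_size * batch_size)
    return in_size * receptive, out_size * receptive

def std_and_bound(shape: Sequence[int], scale: float, mode: str,
                  distribution: str) -> Tuple[float, Optional[float]]:
    """Target std sqrt(scale / n) and, for "uniform", the real interval bound."""
    fan_in, fan_out = compute_fans(shape)
    n = {
        "fan_in": fan_in,
        "fan_out": fan_out,
        "fan_avg": (fan_in + fan_out) / 2,
        "fan_geo_avg": math.sqrt(fan_in * fan_out),
    }[mode]
    std = math.sqrt(scale / n)
    if distribution == "uniform":
        # Uniform on [-b, b] has std b / sqrt(3), hence b = sqrt(3) * std.
        return std, math.sqrt(3.0) * std
    return std, None

# 3x3 convolution kernel, 64 input and 128 output channels (HWIO layout).
print(compute_fans((3, 3, 64, 128)))  # (576.0, 1152.0)
```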

metoryx.initializers.glorot_uniform(in_axis: int | Sequence[int] = -2, out_axis: int | Sequence[int] = -1, batch_axis: int | Sequence[int] = (), dtype: DType | None = None) Initializer

Builds a Glorot uniform initializer (aka Xavier uniform initializer).

A Glorot uniform initializer is a specialization of variance_scaling where scale = 1.0, mode="fan_avg", and distribution="uniform".

Parameters:
  • in_axis – axis or sequence of axes of the input dimension in the weights array.

  • out_axis – axis or sequence of axes of the output dimension in the weights array.

  • batch_axis – axis or sequence of axes in the weight array that should be ignored.

  • dtype – the dtype of the weights.

Returns:

An initializer.
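As a worked check of this preset, assuming a dense kernel of shape (256, 512) with the default axes: fan_avg gives n = 384, and the uniform bound sqrt(3 * scale / n) reduces to the familiar sqrt(6 / (fan_in + fan_out)).

```python
import math

fan_in, fan_out = 256, 512                 # dense kernel of shape (256, 512)
scale = 1.0                                # glorot_uniform preset
n = (fan_in + fan_out) / 2                 # mode="fan_avg" -> 384.0
bound = math.sqrt(3.0 * scale / n)         # uniform on [-bound, bound]
classic = math.sqrt(6.0 / (fan_in + fan_out))
print(round(bound, 4))                     # 0.0884
```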

metoryx.initializers.glorot_normal(in_axis: int | Sequence[int] = -2, out_axis: int | Sequence[int] = -1, batch_axis: int | Sequence[int] = (), dtype: DType | None = None) Initializer

Builds a Glorot normal initializer (aka Xavier normal initializer).

A Glorot normal initializer is a specialization of variance_scaling where scale = 1.0, mode="fan_avg", and distribution="truncated_normal".

Parameters:
  • in_axis – axis or sequence of axes of the input dimension in the weights array.

  • out_axis – axis or sequence of axes of the output dimension in the weights array.

  • batch_axis – axis or sequence of axes in the weight array that should be ignored.

  • dtype – the dtype of the weights.

Returns:

An initializer.

metoryx.initializers.xavier_uniform(in_axis: int | Sequence[int] = -2, out_axis: int | Sequence[int] = -1, batch_axis: int | Sequence[int] = (), dtype: DType | None = None) Initializer

Builds a Glorot uniform initializer (aka Xavier uniform initializer).

A Glorot uniform initializer is a specialization of variance_scaling where scale = 1.0, mode="fan_avg", and distribution="uniform".

Parameters:
  • in_axis – axis or sequence of axes of the input dimension in the weights array.

  • out_axis – axis or sequence of axes of the output dimension in the weights array.

  • batch_axis – axis or sequence of axes in the weight array that should be ignored.

  • dtype – the dtype of the weights.

Returns:

An initializer.

metoryx.initializers.xavier_normal(in_axis: int | Sequence[int] = -2, out_axis: int | Sequence[int] = -1, batch_axis: int | Sequence[int] = (), dtype: DType | None = None) Initializer

Builds a Glorot normal initializer (aka Xavier normal initializer).

A Glorot normal initializer is a specialization of variance_scaling where scale = 1.0, mode="fan_avg", and distribution="truncated_normal".

Parameters:
  • in_axis – axis or sequence of axes of the input dimension in the weights array.

  • out_axis – axis or sequence of axes of the output dimension in the weights array.

  • batch_axis – axis or sequence of axes in the weight array that should be ignored.

  • dtype – the dtype of the weights.

Returns:

An initializer.

metoryx.initializers.he_uniform(in_axis: int | Sequence[int] = -2, out_axis: int | Sequence[int] = -1, batch_axis: int | Sequence[int] = (), dtype: DType | None = None) Initializer

Builds a He uniform initializer (aka Kaiming uniform initializer).

A He uniform initializer is a specialization of variance_scaling where scale = 2.0, mode="fan_in", and distribution="uniform".

Parameters:
  • in_axis – axis or sequence of axes of the input dimension in the weights array.

  • out_axis – axis or sequence of axes of the output dimension in the weights array.

  • batch_axis – axis or sequence of axes in the weight array that should be ignored.

  • dtype – the dtype of the weights.

Returns:

An initializer.
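For a convolution kernel the receptive field is folded into the fans. Assuming an HWIO kernel of shape (3, 3, 64, 128) and the default in_axis=-2, fan_in is 3 * 3 * 64 = 576, and the He uniform bound works out as follows.

```python
import math

# he_uniform: scale=2.0, mode="fan_in", distribution="uniform".
kh, kw, c_in = 3, 3, 64
fan_in = kh * kw * c_in                  # receptive field (3 * 3) times inputs
std = math.sqrt(2.0 / fan_in)            # target standard deviation
bound = math.sqrt(3.0) * std             # uniform on [-bound, bound]
print(bound)                             # equals sqrt(6 / fan_in)
```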

metoryx.initializers.he_normal(in_axis: int | Sequence[int] = -2, out_axis: int | Sequence[int] = -1, batch_axis: int | Sequence[int] = (), dtype: DType | None = None) Initializer

Builds a He normal initializer (aka Kaiming normal initializer).

A He normal initializer is a specialization of variance_scaling where scale = 2.0, mode="fan_in", and distribution="truncated_normal".

Parameters:
  • in_axis – axis or sequence of axes of the input dimension in the weights array.

  • out_axis – axis or sequence of axes of the output dimension in the weights array.

  • batch_axis – axis or sequence of axes in the weight array that should be ignored.

  • dtype – the dtype of the weights.

Returns:

An initializer.

metoryx.initializers.kaiming_uniform(in_axis: int | Sequence[int] = -2, out_axis: int | Sequence[int] = -1, batch_axis: int | Sequence[int] = (), dtype: DType | None = None) Initializer

Builds a He uniform initializer (aka Kaiming uniform initializer).

A He uniform initializer is a specialization of variance_scaling where scale = 2.0, mode="fan_in", and distribution="uniform".

Parameters:
  • in_axis – axis or sequence of axes of the input dimension in the weights array.

  • out_axis – axis or sequence of axes of the output dimension in the weights array.

  • batch_axis – axis or sequence of axes in the weight array that should be ignored.

  • dtype – the dtype of the weights.

Returns:

An initializer.

metoryx.initializers.kaiming_normal(in_axis: int | Sequence[int] = -2, out_axis: int | Sequence[int] = -1, batch_axis: int | Sequence[int] = (), dtype: DType | None = None) Initializer

Builds a He normal initializer (aka Kaiming normal initializer).

A He normal initializer is a specialization of variance_scaling where scale = 2.0, mode="fan_in", and distribution="truncated_normal".

Parameters:
  • in_axis – axis or sequence of axes of the input dimension in the weights array.

  • out_axis – axis or sequence of axes of the output dimension in the weights array.

  • batch_axis – axis or sequence of axes in the weight array that should be ignored.

  • dtype – the dtype of the weights.

Returns:

An initializer.

metoryx.initializers.lecun_uniform(in_axis: int | Sequence[int] = -2, out_axis: int | Sequence[int] = -1, batch_axis: int | Sequence[int] = (), dtype: DType | None = None) Initializer

Builds a LeCun uniform initializer.

A LeCun uniform initializer is a specialization of variance_scaling where scale = 1.0, mode="fan_in", and distribution="uniform".

Parameters:
  • in_axis – axis or sequence of axes of the input dimension in the weights array.

  • out_axis – axis or sequence of axes of the output dimension in the weights array.

  • batch_axis – axis or sequence of axes in the weight array that should be ignored.

  • dtype – the dtype of the weights.

Returns:

An initializer.

metoryx.initializers.lecun_normal(in_axis: int | Sequence[int] = -2, out_axis: int | Sequence[int] = -1, batch_axis: int | Sequence[int] = (), dtype: DType | None = None) Initializer

Builds a LeCun normal initializer.

A LeCun normal initializer is a specialization of variance_scaling where scale = 1.0, mode="fan_in", and distribution="truncated_normal".

Parameters:
  • in_axis – axis or sequence of axes of the input dimension in the weights array.

  • out_axis – axis or sequence of axes of the output dimension in the weights array.

  • batch_axis – axis or sequence of axes in the weight array that should be ignored.

  • dtype – the dtype of the weights.

Returns:

An initializer.
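The named presets above differ only in the (scale, mode, distribution) triple passed to variance_scaling. The lookup table below restates the documented triples and computes the target standard deviation sqrt(scale / n) for given fans; the PRESETS table and the target_std helper are illustrative, not part of metoryx.

```python
import math

# (scale, mode, distribution) for each named preset, as documented above.
PRESETS = {
    "glorot_uniform": (1.0, "fan_avg", "uniform"),
    "glorot_normal": (1.0, "fan_avg", "truncated_normal"),
    "he_uniform": (2.0, "fan_in", "uniform"),
    "he_normal": (2.0, "fan_in", "truncated_normal"),
    "lecun_uniform": (1.0, "fan_in", "uniform"),
    "lecun_normal": (1.0, "fan_in", "truncated_normal"),
}
# The xavier_* and kaiming_* entries document identical presets.
PRESETS["xavier_uniform"] = PRESETS["glorot_uniform"]
PRESETS["xavier_normal"] = PRESETS["glorot_normal"]
PRESETS["kaiming_uniform"] = PRESETS["he_uniform"]
PRESETS["kaiming_normal"] = PRESETS["he_normal"]

def target_std(name: str, fan_in: float, fan_out: float) -> float:
    """Target std sqrt(scale / n); these presets only use fan_in or fan_avg."""
    scale, mode, _ = PRESETS[name]
    n = fan_in if mode == "fan_in" else (fan_in + fan_out) / 2
    return math.sqrt(scale / n)

for name in ("glorot_normal", "he_normal", "lecun_normal"):
    print(name, round(target_std(name, 256, 512), 4))
```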

metoryx.initializers.orthogonal(scale: float = 1.0, column_axis: int = -1, dtype: DType | None = None) Initializer

Builds an initializer that returns uniformly distributed orthogonal matrices.

If the shape is not square, the matrices will have orthonormal rows or columns depending on which side is smaller.

Parameters:
  • scale – a multiplicative factor applied to the sampled orthogonal matrices.

  • column_axis – the axis that contains the columns that should be orthogonal.

  • dtype – the default dtype of the weights.

Returns:

An orthogonal initializer.
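A plain-Python sketch of sampling a random orthogonal 2D matrix by Gram-Schmidt orthonormalization of Gaussian vectors, which yields the uniform (Haar) distribution over orthonormal frames. Assumptions: the sketch handles only 2D shapes, ignores column_axis, and treats scale as a multiplier on the result (the way jax.nn.initializers.orthogonal applies it); orthogonal_sketch is a hypothetical name.

```python
import random
from typing import List

Matrix = List[List[float]]

def orthogonal_sketch(seed: int, rows: int, cols: int, scale: float = 1.0) -> Matrix:
    """Random rows x cols matrix with orthonormal columns (rows >= cols)
    or orthonormal rows (rows < cols), multiplied by scale."""
    rng = random.Random(seed)
    n, k = max(rows, cols), min(rows, cols)
    # Orthonormalize k Gaussian vectors of length n (modified Gram-Schmidt).
    basis: Matrix = []
    for _ in range(k):
        v = [rng.gauss(0.0, 1.0) for _ in range(n)]
        for q in basis:
            dot = sum(a * b for a, b in zip(v, q))
            v = [a - dot * b for a, b in zip(v, q)]  # remove component along q
        norm = sum(a * a for a in v) ** 0.5
        basis.append([a / norm for a in v])
    if rows >= cols:
        # The k orthonormal vectors become the columns of an n x k matrix.
        return [[scale * basis[j][i] for j in range(k)] for i in range(n)]
    # Otherwise they become the rows of a k x n matrix.
    return [[scale * x for x in vec] for vec in basis]

q = orthogonal_sketch(seed=0, rows=4, cols=3)  # 4 x 3, orthonormal columns
```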

metoryx.initializers.delta_orthogonal(scale: float = 1.0, column_axis: int = -1, dtype: DType | None = None) Initializer

Builds an initializer for delta orthogonal kernels.

Parameters:
  • scale – a multiplicative factor applied to the sampled orthogonal matrix.

  • column_axis – the axis that contains the columns that should be orthogonal.

  • dtype – the default dtype of the weights.

Returns:

A delta orthogonal initializer. The shape passed to the initializer must be 3D, 4D, or 5D.
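Structurally, a delta orthogonal kernel is zero everywhere except the spatial center tap, which holds an orthogonal channel-mixing matrix, so at initialization the convolution acts as an orthogonal map on channels. The 3D sketch below takes that matrix as an argument; the (k - 1) // 2 center index and the c_in <= c_out requirement follow JAX's delta_orthogonal and are assumptions about metoryx, and delta_kernel_1d is a hypothetical name.

```python
from typing import List

def delta_kernel_1d(k: int, ortho: List[List[float]]) -> List[List[List[float]]]:
    """Build a (k, c_in, c_out) kernel that is zero except at the center tap,
    which holds the given c_in x c_out orthogonal matrix."""
    c_in, c_out = len(ortho), len(ortho[0])
    if c_in > c_out:
        raise ValueError("sketch assumes c_in <= c_out")
    center = (k - 1) // 2  # assumed center index (JAX convention)
    kernel = [[[0.0] * c_out for _ in range(c_in)] for _ in range(k)]
    kernel[center] = [row[:] for row in ortho]  # copy so input is not aliased
    return kernel

# A 90-degree rotation is an orthogonal 2x2 matrix.
rot = [[0.0, -1.0], [1.0, 0.0]]
w = delta_kernel_1d(3, rot)
```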

Utilities

class metoryx.utils.AverageMeter(*, with_timer: bool = False, timer_key: str = 'elapsed_time')

Bases: object

Stores and computes the weighted average of multiple metrics over time.

__init__(*, with_timer: bool = False, timer_key: str = 'elapsed_time')

Initializes the AverageMeter.

Parameters:
  • with_timer – Whether to track the time elapsed between resetting the meter and calling compute.

  • timer_key – The key name for the elapsed time metric.

update(metric_dict: dict[str, Array | ndarray | bool | number | bool | int | float | complex], n: int = 1, *, prefix: str = '') None

Updates the meter with new metric values.

Parameters:
  • metric_dict – A dictionary where keys are metric names and values are the corresponding metric values (can be arrays).

  • n – The number of samples the metrics correspond to (default is 1).

  • prefix – A string prefix to add to each metric name (default is “”).

Returns:

The updated AverageMeter instance.

compute() dict[str, float]

Computes the average for each metric.

Returns:

A dictionary where keys are metric names and values are their average.

reset() None

Resets the meter to its initial state.
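The weighting rule behind update and compute can be illustrated with a small, hypothetical re-implementation (not metoryx's code): each value is weighted by its n, and compute returns sum(value * n) / sum(n) for each metric. Array values and the optional timer are omitted, and the prefix is assumed to be prepended to each key verbatim.

```python
from typing import Dict

class MiniAverageMeter:
    """Hypothetical illustration: running weighted sum and weight per metric."""

    def __init__(self) -> None:
        self.reset()

    def reset(self) -> None:
        self._sums: Dict[str, float] = {}
        self._counts: Dict[str, int] = {}

    def update(self, metric_dict: Dict[str, float], n: int = 1, *, prefix: str = "") -> None:
        for key, value in metric_dict.items():
            name = prefix + key
            # Weight each value by the number of samples it summarizes.
            self._sums[name] = self._sums.get(name, 0.0) + float(value) * n
            self._counts[name] = self._counts.get(name, 0) + n

    def compute(self) -> Dict[str, float]:
        return {k: self._sums[k] / self._counts[k] for k in self._sums}

meter = MiniAverageMeter()
meter.update({"loss": 1.0}, n=3, prefix="train/")  # mean loss over 3 samples
meter.update({"loss": 2.0}, n=1, prefix="train/")  # mean loss over 1 sample
print(meter.compute())  # {'train/loss': 1.25}
```

Weighting by n matters when the last batch of an epoch is smaller: a plain mean of per-batch values would give (1.0 + 2.0) / 2 = 1.5 here instead of the per-sample average 1.25.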

metoryx.assign_variables(module: Module, variables: Variables) Module

Assign arrays to the module.

Parameters:
  • module – The module to assign arrays to.

  • variables – The variables to assign to the module.

Returns:

The module with assigned arrays.

Notes

  • The assigned arrays will be reflected when the module is initialized.

  • Currently, the array tree structure must match the module’s variable structure.
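Because the variables tree must match the module's variable structure, it can help to compare trees before calling assign_variables. The sketch below assumes both trees are nested dicts with array leaves; same_structure is a hypothetical helper, and how you obtain the module's reference tree depends on the Variables type, which this section does not specify.

```python
from typing import Any, Mapping

def same_structure(a: Any, b: Any) -> bool:
    """True if two nested-dict trees have identical keys at every level."""
    if isinstance(a, Mapping) and isinstance(b, Mapping):
        return a.keys() == b.keys() and all(same_structure(a[k], b[k]) for k in a)
    # Leaves: both must be non-mappings; leaf values are not compared.
    return not isinstance(a, Mapping) and not isinstance(b, Mapping)

reference = {"dense": {"kernel": None, "bias": None}}  # hypothetical layout
loaded = {"dense": {"kernel": [[0.1, 0.2]], "bias": [0.0]}}
print(same_structure(reference, loaded))  # True
```

When JAX is available, comparing jax.tree_util.tree_structure(a) with jax.tree_util.tree_structure(b) performs a similar check on arbitrary pytrees.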