API reference¶
Fundamentals¶
Module Transformation¶
- metoryx.transform(module: Module, *, to_callable: Callable[[Module], Callable[[...], Any]] | None = None) Transformed¶
Transform a module into separate initialization and apply functions.
- Parameters:
module – The module to transform.
to_callable – An optional function to convert the module into a callable. If None, the module itself must be callable.
- Returns:
A transformed module with separate init and apply functions.
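For orientation, a minimal sketch of the workflow. The MLP module is hypothetical, and the exact init/apply calling conventions (a PRNG key plus example inputs for init; the returned variables plus call arguments for apply) are assumptions, not part of the documented API:

import jax
import jax.numpy as jnp
import metoryx as mx

class MLP(mx.Module):
    def __init__(self):
        self.hidden = mx.Dense(784, 256)
        self.out = mx.Dense(256, 10)

    def __call__(self, x):
        return self.out(mx.relu(self.hidden(x)))

transformed = mx.transform(MLP())
x = jnp.ones((32, 784))
# Assumed conventions: init takes a PRNG key plus example inputs and
# returns the variable collections; apply takes those variables plus
# the call arguments.
variables = transformed.init(jax.random.PRNGKey(0), x)
outputs = transformed.apply(variables, x)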
- metoryx.init(module: Module) InitFn¶
Transform a module into an initialization function.
- Parameters:
module – The module to transform.
- Returns:
The initialization function for the module.
- metoryx.apply(module: Module, *, to_callable: Callable[[Module], Callable[[...], Any]] | None = None) ApplyFn¶
Transform a module into an apply function.
- Parameters:
module – The module to transform.
to_callable – An optional function to convert the module into a callable. If None, the module itself must be callable.
- Returns:
The apply function for the module.
- metoryx.checkpoint(module: Module, *, to_callable: Callable[[Module], Callable] | None = None, concrete: bool = False, prevent_cse: bool = True, static_argnums: int | tuple[int, ...] = (), policy: Callable[[...], bool] | None = None) Callable[[...], Any]¶
Transform a module into a function that applies jax.checkpoint to its call method.
- Parameters:
module – The module to wrap.
to_callable – An optional function to convert the module into a callable. If None, the module itself must be callable.
concrete – Whether to use concrete mode in jax.checkpoint.
prevent_cse – Whether to prevent common subexpression elimination in jax.checkpoint.
static_argnums – Static argument numbers to pass to jax.checkpoint.
policy – A custom policy function to pass to jax.checkpoint.
- Returns:
A function that applies jax.checkpoint to the module’s call method.
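A short sketch of checkpointing, reusing the hypothetical MLP from the transform sketch above; the wrapped function is assumed to accept the same arguments as the module's __call__:

import metoryx as mx

mlp = MLP()  # hypothetical module from the transform sketch above
rematted = mx.checkpoint(mlp)
# `rematted` behaves like mlp.__call__ but recomputes intermediate
# activations on the backward pass instead of storing them, trading
# compute for memory. Use it in place of mlp(x) inside a transformed
# function.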
Modules, States, and Parameters¶
- class metoryx.Module¶
Bases: object
Base class for all neural network modules.
- class metoryx.State(col: str, init: ~metoryx._src.base.Initializer, shape: Shape, dtype: DType = <class 'jax.numpy.float32'>, param_dtype: DType | None = None, mutable: bool = False)¶
Bases: object
A container for a stateful variable in a module. Lazily initialized and can be mutable.
- __init__(col: str, init: ~metoryx._src.base.Initializer, shape: Shape, dtype: DType = <class 'jax.numpy.float32'>, param_dtype: DType | None = None, mutable: bool = False)¶
Initializes the State.
- Parameters:
col – The collection name for the state variable.
init – The initializer function for the state variable.
shape – The shape of the state variable.
dtype – The data type of the state variable.
param_dtype – The data type for computation. If None, uses dtype.
mutable – Whether the state variable is mutable during the applying phase.
- property value: Array¶
Get or initialize the value of the state variable.
This property is only available during the initializing or applying phase. In the initializing phase, it initializes the variable using the provided initializer. In the applying phase, it retrieves the current value from the context.
This property can also be set to an array with the same shape as the state variable. In the applying phase, if the state variable is mutable, setting this property updates the value in the context. Otherwise, setting this property updates the initializer to return the provided array during the next initialization.
- property id: str¶
A unique identifier for the state variable.
- property T: Array¶
Transpose of the state variable.
- property ndim: int¶
Number of dimensions of the state variable.
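As an illustration, a hypothetical counter module built on State (a sketch; the collection name "counters" is arbitrary):

import metoryx as mx

class Counter(mx.Module):
    def __init__(self):
        # A mutable scalar in the "counters" collection, starting at zero.
        self.count = mx.State(
            "counters", mx.initializers.zeros(), shape=(), mutable=True
        )

    def __call__(self):
        # Reading .value initializes the variable (initializing phase) or
        # fetches it from the context (applying phase); assigning writes
        # it back to the context because mutable=True.
        self.count.value = self.count.value + 1
        return self.count.value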
- class metoryx.Parameter(init: ~metoryx._src.base.Initializer, shape: Shape, dtype: DType = <class 'jax.numpy.float32'>, param_dtype: DType | None = None)¶
Bases: State
A container for a parameter variable in a module. Typically immutable during the applying phase.
- __init__(init: ~metoryx._src.base.Initializer, shape: Shape, dtype: DType = <class 'jax.numpy.float32'>, param_dtype: DType | None = None)¶
Initializes the Parameter.
- Parameters:
init – The initializer function for the parameter variable.
shape – The shape of the parameter variable.
dtype – The data type of the parameter variable.
param_dtype – The data type for computation. If None, uses dtype.
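A minimal custom layer using Parameter, assuming parameters are declared in __init__ and read through .value inside __call__ (as with State above):

import metoryx as mx

class Scale(mx.Module):
    # Hypothetical layer that rescales features element-wise.
    def __init__(self, size: int):
        self.scale = mx.Parameter(mx.initializers.ones(), shape=(size,))

    def __call__(self, x):
        return x * self.scale.value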
Random Numbers¶
- class metoryx.PRNGKeys(default: PRNGKey | None = None, /, **kwargs: PRNGKey)¶
Create a dictionary of PRNGKeys to feed into the apply function.
- Parameters:
default – The default PRNGKey to use if none is provided.
**kwargs – Additional PRNGKeys to use for specific purposes.
- Returns:
A dictionary of PRNGKeys to feed into the apply function.
- metoryx.next_rng_key(name: str | None = None, num: int | tuple[int, ...] | None = None, *, strict: bool = False) PRNGKey¶
Get the next PRNGKey. This function is only available within the applying phase.
- Parameters:
name – The name of the PRNGKey to use. If None, uses the default PRNGKey.
num – If provided, splits the PRNGKey into num keys.
strict – If True, raises an error if the specified name is not found in the context. If False, falls back to the default PRNGKey if the specified name is not found.
- Returns:
The next PRNGKey.
Note
This function also updates the context with the new PRNGKey. Thus, subsequent calls to this function will return different keys.
- Raises:
ValueError – If the context is not set, or if the specified name is not found in the context and strict is True, or if the default PRNGKey is not found in the context.
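A sketch of pairing PRNGKeys with next_rng_key. That the dictionary is fed to the apply function comes from the PRNGKeys description; exactly how it is passed is an assumption:

import jax
import metoryx as mx

# One default stream plus a dedicated "dropout" stream.
rngs = mx.PRNGKeys(jax.random.PRNGKey(0), dropout=jax.random.PRNGKey(1))

# Inside a module's __call__ (the applying phase):
#   mx.next_rng_key("dropout")              # uses the "dropout" stream
#   mx.next_rng_key("noise")                # falls back to the default stream
#   mx.next_rng_key("noise", strict=True)   # raises: "noise" was not provided
#   mx.next_rng_key(num=4)                  # splits the default key into 4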
Common Modules and Functions¶
Linear¶
- class metoryx.Dense(in_size: int, out_size: int, use_bias: bool = True, kernel_init: ~metoryx._src.base.Initializer = <function variance_scaling.<locals>.init>, bias_init: ~metoryx._src.base.Initializer = <function constant.<locals>.init>, dtype: DType = <class 'jax.numpy.float32'>, param_dtype: DType | None = None)¶
Bases: Module
Applies an affine linear transformation.
- __init__(in_size: int, out_size: int, use_bias: bool = True, kernel_init: ~metoryx._src.base.Initializer = <function variance_scaling.<locals>.init>, bias_init: ~metoryx._src.base.Initializer = <function constant.<locals>.init>, dtype: DType = <class 'jax.numpy.float32'>, param_dtype: DType | None = None)¶
Initializes the Dense module.
- Parameters:
in_size – Size of the input features.
out_size – Size of the output features.
use_bias – Whether to include a bias term.
kernel_init – Initializer for the weight matrix.
bias_init – Initializer for the bias term.
dtype – Data type of the parameters.
param_dtype – Data type for computation.
- __call__(inputs: Array) Array¶
Applies an affine transformation to the input.
- Parameters:
inputs – Input array of shape (*batch_size, in_size).
- Returns:
The transformed output array of shape (*batch_size, out_size).
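For example (the layer must be called inside the initializing or applying phase, e.g. under transform):

import metoryx as mx

layer = mx.Dense(784, 256)
# Inside init/apply: y = layer(x) maps (batch, 784) -> (batch, 256).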
- class metoryx.Conv(in_size: int, out_size: int, kernel_size: ~typing.Sequence[int], padding: PaddingLike = 'SAME', strides: int | ~typing.Sequence[int] = 1, dilation: int | ~typing.Sequence[int] = 1, groups: int = 1, use_bias: bool = True, kernel_init: ~metoryx._src.base.Initializer = <function variance_scaling.<locals>.init>, bias_init: ~metoryx._src.base.Initializer = <function constant.<locals>.init>, dtype: DType = <class 'jax.numpy.float32'>, param_dtype: DType | None = None)¶
Bases: Module
Applies a convolution.
- __init__(in_size: int, out_size: int, kernel_size: ~typing.Sequence[int], padding: PaddingLike = 'SAME', strides: int | ~typing.Sequence[int] = 1, dilation: int | ~typing.Sequence[int] = 1, groups: int = 1, use_bias: bool = True, kernel_init: ~metoryx._src.base.Initializer = <function variance_scaling.<locals>.init>, bias_init: ~metoryx._src.base.Initializer = <function constant.<locals>.init>, dtype: DType = <class 'jax.numpy.float32'>, param_dtype: DType | None = None)¶
Initializes the Conv module.
- Parameters:
in_size – Number of input channels.
out_size – Number of output channels.
kernel_size – Size of the convolutional kernel.
padding – Padding method, either ‘SAME’, ‘VALID’, or a sequence of padding tuples.
strides – Strides of the convolution.
dilation – Dilation of the convolution.
groups – Number of groups for grouped convolution.
use_bias – Whether to include a bias term.
kernel_init – Initializer for the convolutional kernel.
bias_init – Initializer for the bias term.
dtype – Data type of the parameters.
param_dtype – Data type for computation.
- padding: PaddingLike¶
- __call__(inputs: Array) Array¶
Applies the convolution to the input.
- Parameters:
inputs – Input array of shape (*batch_size, height, width, in_size).
- Returns:
The convolved output array.
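For example, a strided 3x3 convolution on channels-last inputs (called inside the initializing or applying phase):

import metoryx as mx

conv = mx.Conv(3, 16, kernel_size=(3, 3), strides=2, padding="SAME")
# Inside init/apply: (batch, 32, 32, 3) -> (batch, 16, 16, 16).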
- class metoryx.Embed(size: int, num_embeddings: int, embedding_init: ~metoryx._src.base.Initializer = <function variance_scaling.<locals>.init>, dtype: DType = <class 'jax.numpy.float32'>, param_dtype: DType | None = None)¶
Bases: Module
Embeds the inputs along the last dimension.
- __init__(size: int, num_embeddings: int, embedding_init: ~metoryx._src.base.Initializer = <function variance_scaling.<locals>.init>, dtype: DType = <class 'jax.numpy.float32'>, param_dtype: DType | None = None)¶
Initializes the Embed module.
- Parameters:
size – Size of each embedding vector.
num_embeddings – Number of unique embeddings.
embedding_init – Initializer for the embedding matrix.
dtype – Data type of the parameters.
param_dtype – Data type for computation.
- __call__(inputs: Array) Array¶
Embeds the input indices.
- Parameters:
inputs – Input array of shape (*batch_size,) with integer indices.
- Returns:
The embedded output array of shape (*batch_size, size).
- attend(query: Array) Array¶
Computes the attention scores between the query and the embeddings.
- Parameters:
query – Query array of shape (*batch_size, size).
- Returns:
The attention scores of shape (*batch_size, num_embeddings).
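For example, a token embedding table, with attend used for a weight-tied output projection (a sketch; shapes follow the documented conventions):

import metoryx as mx

table = mx.Embed(size=64, num_embeddings=10_000)
# Inside init/apply:
#   table(tokens)    # int tokens (batch, seq)     -> (batch, seq, 64)
#   table.attend(h)  # features   (batch, seq, 64) -> (batch, seq, 10_000)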
Normalization¶
- class metoryx.BatchNorm(size: int, momentum: float = 0.99, epsilon: float = 1e-05, use_scale: bool = True, use_bias: bool = True, scale_init: ~metoryx._src.base.Initializer = <function constant.<locals>.init>, bias_init: ~metoryx._src.base.Initializer = <function constant.<locals>.init>, dtype: DType = <class 'jax.numpy.float32'>, param_dtype: DType | None = None, axis_name: ~typing.Any | None = None, axis_index_groups: ~typing.Any | None = None)¶
Bases: Module
Batch normalization.
Ref. https://arxiv.org/abs/1502.03167
Batch normalization keeps a moving average of batch statistics. These are stored in the batch_stats collection.
- __init__(size: int, momentum: float = 0.99, epsilon: float = 1e-05, use_scale: bool = True, use_bias: bool = True, scale_init: ~metoryx._src.base.Initializer = <function constant.<locals>.init>, bias_init: ~metoryx._src.base.Initializer = <function constant.<locals>.init>, dtype: DType = <class 'jax.numpy.float32'>, param_dtype: DType | None = None, axis_name: ~typing.Any | None = None, axis_index_groups: ~typing.Any | None = None)¶
Initialize BatchNorm layer.
- Parameters:
size – Size of input features.
momentum – Momentum for the moving average.
epsilon – Small constant for numerical stability.
use_scale – Whether to use a scale parameter.
use_bias – Whether to use a bias parameter.
scale_init – Initializer for the scale parameter.
bias_init – Initializer for the bias parameter.
dtype – Data type of the parameters.
param_dtype – Data type for computation.
axis_name – Axis name to sync batch statistics along devices.
axis_index_groups – Axis index groups for distributed training.
- __call__(inputs: Array, is_training: bool = False) Array¶
Applies batch normalization to the inputs.
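For example (the is_training behavior in the comments is the standard batch-norm convention, consistent with the batch_stats note above):

import metoryx as mx

norm = mx.BatchNorm(256)
# Inside init/apply:
#   norm(x, is_training=True)   # batch statistics; updates batch_stats
#   norm(x, is_training=False)  # uses the stored moving averages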
- class metoryx.LayerNorm(size: int, epsilon: float = 1e-06, use_scale: bool = True, use_bias: bool = True, scale_init: ~metoryx._src.base.Initializer = <function constant.<locals>.init>, bias_init: ~metoryx._src.base.Initializer = <function constant.<locals>.init>, dtype: DType = <class 'jax.numpy.float32'>, param_dtype: DType | None = None)¶
Bases: Module
Layer normalization.
Ref. https://arxiv.org/abs/1607.06450
- __init__(size: int, epsilon: float = 1e-06, use_scale: bool = True, use_bias: bool = True, scale_init: ~metoryx._src.base.Initializer = <function constant.<locals>.init>, bias_init: ~metoryx._src.base.Initializer = <function constant.<locals>.init>, dtype: DType = <class 'jax.numpy.float32'>, param_dtype: DType | None = None)¶
Initialize LayerNorm layer.
- Parameters:
size – Size of input features.
epsilon – Small constant for numerical stability.
use_scale – Whether to use a scale parameter.
use_bias – Whether to use a bias parameter.
scale_init – Initializer for the scale parameter.
bias_init – Initializer for the bias parameter.
dtype – Data type of the parameters.
param_dtype – Data type for computation.
- __call__(inputs: Array) Array¶
Applies layer normalization to the inputs.
- class metoryx.RMSNorm(size: int, epsilon: float = 1e-06, use_scale: bool = True, scale_init: ~metoryx._src.base.Initializer = <function constant.<locals>.init>, dtype: DType = <class 'jax.numpy.float32'>, param_dtype: DType | None = None)¶
Bases: Module
RMS layer normalization.
Ref. https://arxiv.org/abs/1910.07467
- __init__(size: int, epsilon: float = 1e-06, use_scale: bool = True, scale_init: ~metoryx._src.base.Initializer = <function constant.<locals>.init>, dtype: DType = <class 'jax.numpy.float32'>, param_dtype: DType | None = None)¶
Initialize RMSNorm layer.
- Parameters:
size – Size of input features.
epsilon – Small constant for numerical stability.
use_scale – Whether to use a scale parameter.
scale_init – Initializer for the scale parameter.
dtype – Data type of the parameters.
param_dtype – Data type for computation.
- __call__(inputs: Array) Array¶
Applies RMS layer normalization to the inputs.
Dropout¶
- metoryx.dropout(inputs: Array, rate: float, is_training: bool, *, rng_collection: str | None = None) Array¶
Applies dropout to the input array.
During training, randomly sets a fraction rate of the input units to zero at each update step, which helps prevent overfitting. During evaluation, the input is returned unchanged.
- Parameters:
inputs – Input array.
rate – Fraction of the input units to drop. Must be between 0 and 1.
is_training – Whether the model is in training mode.
rng_collection – Name of the RNG collection to use for generating dropout masks. If None, uses the default RNG collection.
- Returns:
The array after applying dropout.
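For example, a hypothetical module applying dropout before its output layer (dropout draws its mask from the RNG context, so a PRNGKey must be supplied to the apply function):

import metoryx as mx

class Classifier(mx.Module):
    def __init__(self):
        self.dense = mx.Dense(256, 10)

    def __call__(self, x, is_training: bool = False):
        # Drops 10% of units during training; identity at evaluation.
        x = mx.dropout(x, rate=0.1, is_training=is_training)
        return self.dense(x)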
Pooling¶
- metoryx.avg_pool(inputs: Array, kernel_size: Sequence[int], strides: int | Sequence[int], padding: PaddingLike = 'VALID') Array¶
Applies average pooling over the spatial dimensions of the input.
- metoryx.max_pool(inputs: Array, kernel_size: Sequence[int], strides: int | Sequence[int], padding: PaddingLike = 'VALID') Array¶
Applies max pooling over the spatial dimensions of the input.
- metoryx.min_pool(inputs: Array, kernel_size: Sequence[int], strides: int | Sequence[int], padding: PaddingLike = 'VALID') Array¶
Applies min pooling over the spatial dimensions of the input.
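A sketch of the pooling functions on channels-last inputs (exact output shapes depend on padding):

import jax.numpy as jnp
import metoryx as mx

x = jnp.ones((8, 32, 32, 16))
y = mx.max_pool(x, kernel_size=(2, 2), strides=2)  # -> (8, 16, 16, 16) with 'VALID'
z = mx.avg_pool(x, kernel_size=(2, 2), strides=(2, 2), padding="SAME")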
Activation Functions¶
Most activation functions are exported from jax.nn for convenience.
- metoryx.celu(x: Array | ndarray | bool | number | bool | int | float | complex, alpha: Array | ndarray | bool | number | bool | int | float | complex = 1.0) Array¶
Continuously-differentiable exponential linear unit activation.
Computes the element-wise function:
\[\mathrm{celu}(x) = \begin{cases} x, & x > 0\\ \alpha \left(\exp(\frac{x}{\alpha}) - 1\right), & x \le 0 \end{cases}\]
For more information, see Continuously Differentiable Exponential Linear Units.
- Parameters:
x – input array
alpha – array or scalar (default: 1.0)
- Returns:
An array.
- metoryx.elu(x: Array | ndarray | bool | number | bool | int | float | complex, alpha: Array | ndarray | bool | number | bool | int | float | complex = 1.0) Array¶
Exponential linear unit activation function.
Computes the element-wise function:
\[\mathrm{elu}(x) = \begin{cases} x, & x > 0\\ \alpha \left(\exp(x) - 1\right), & x \le 0 \end{cases}\]
- Parameters:
x – input array
alpha – scalar or array of alpha values (default: 1.0)
- Returns:
An array.
- metoryx.gelu(x: Array | ndarray | bool | number | bool | int | float | complex, approximate: bool = True) Array¶
Gaussian error linear unit activation function.
If approximate=False, computes the element-wise function:
\[\mathrm{gelu}(x) = \frac{x}{2} \left(\mathrm{erfc} \left( \frac{-x}{\sqrt{2}} \right) \right)\]
If approximate=True, uses the approximate formulation of GELU:
\[\mathrm{gelu}(x) = \frac{x}{2} \left(1 + \mathrm{tanh} \left( \sqrt{\frac{2}{\pi}} \left(x + 0.044715 x^3 \right) \right) \right)\]
For more information, see Gaussian Error Linear Units (GELUs), section 2.
- Parameters:
x – input array
approximate – whether to use the approximate or exact formulation.
- metoryx.glu(x: Array | ndarray | bool | number | bool | int | float | complex, axis: int = -1) Array¶
Gated linear unit activation function.
Computes the function:
\[\mathrm{glu}(x) = x\left[\ldots, 0:\frac{n}{2}, \ldots\right] \cdot \mathrm{sigmoid} \left( x\left[\ldots, \frac{n}{2}:n, \ldots\right] \right)\]
where the array is split into two along axis. The size of the axis dimension must be divisible by two.
- Parameters:
x – input array
axis – the axis along which the split should be computed (default: -1)
- Returns:
An array.
- metoryx.hard_sigmoid(x: Array | ndarray | bool | number | bool | int | float | complex) Array¶
Hard Sigmoid activation function.
Computes the element-wise function
\[\mathrm{hard\_sigmoid}(x) = \frac{\mathrm{relu6}(x + 3)}{6}\]
- Parameters:
x – input array
- Returns:
An array.
- metoryx.hard_silu(x: Array | ndarray | bool | number | bool | int | float | complex) Array¶
Hard SiLU (swish) activation function.
Computes the element-wise function
\[\mathrm{hard\_silu}(x) = x \cdot \mathrm{hard\_sigmoid}(x)\]
Both hard_silu() and hard_swish() are aliases for the same function.
- Parameters:
x – input array
- Returns:
An array.
- metoryx.hard_tanh(x: Array | ndarray | bool | number | bool | int | float | complex) Array¶
Hard \(\mathrm{tanh}\) activation function.
Computes the element-wise function:
\[\mathrm{hard\_tanh}(x) = \begin{cases} -1, & x < -1\\ x, & -1 \le x \le 1\\ 1, & 1 < x \end{cases}\]
- Parameters:
x – input array
- Returns:
An array.
- metoryx.identity(x: Array | ndarray | bool | number | bool | int | float | complex) Array¶
Identity activation function.
Returns the argument unmodified.
- Parameters:
x – input array
- Returns:
The argument x unmodified.
Examples
>>> jax.nn.identity(jax.numpy.array([-2., -1., -0.5, 0, 0.5, 1., 2.]))
Array([-2. , -1. , -0.5,  0. ,  0.5,  1. ,  2. ], dtype=float32)
- metoryx.leaky_relu(x: Array | ndarray | bool | number | bool | int | float | complex, negative_slope: Array | ndarray | bool | number | bool | int | float | complex = 0.01) Array¶
Leaky rectified linear unit activation function.
Computes the element-wise function:
\[\mathrm{leaky\_relu}(x) = \begin{cases} x, & x \ge 0\\ \alpha x, & x < 0 \end{cases}\]
where \(\alpha\) = negative_slope.
- Parameters:
x – input array
negative_slope – array or scalar specifying the negative slope (default: 0.01)
- Returns:
An array.
- metoryx.log_sigmoid(x: Array | ndarray | bool | number | bool | int | float | complex) Array¶
Log-sigmoid activation function.
Computes the element-wise function:
\[\mathrm{log\_sigmoid}(x) = \log(\mathrm{sigmoid}(x)) = -\log(1 + e^{-x})\]
- Parameters:
x – input array
- Returns:
An array.
- metoryx.log_softmax(x: Array | ndarray | bool | number | bool | int | float | complex, axis: int | Sequence[int] | None = -1, where: Array | ndarray | bool | number | bool | int | float | complex | None = None) Array¶
Log-Softmax function.
Computes the logarithm of the softmax function, which rescales elements to the range \([-\infty, 0)\).
\[\mathrm{log\_softmax}(x)_i = \log \left( \frac{\exp(x_i)}{\sum_j \exp(x_j)} \right)\]
- Parameters:
x – input array
axis – the axis or axes along which the log_softmax should be computed. Either an integer or a tuple of integers.
where – Elements to include in the log_softmax. The output for any masked-out element is minus infinity.
- Returns:
An array.
Note
If any input values are +inf, the result will be all NaN: this reflects the fact that inf / inf is not well-defined in the context of floating-point math.
- metoryx.mish(x: Array | ndarray | bool | number | bool | int | float | complex) Array¶
Mish activation function.
Computes the element-wise function:
\[\mathrm{mish}(x) = x \cdot \mathrm{tanh}(\mathrm{softplus}(x))\]
For more information, see Mish: A Self Regularized Non-Monotonic Activation Function.
- Parameters:
x – input array
- Returns:
An array.
- metoryx.one_hot(x: Any, num_classes: int, *, dtype: Any | None = None, axis: int | Hashable = -1) Array¶
One-hot encodes the given indices.
Each index in the input x is encoded as a vector of zeros of length num_classes with the element at index set to one:
>>> jax.nn.one_hot(jnp.array([0, 1, 2]), 3)
Array([[1., 0., 0.],
       [0., 1., 0.],
       [0., 0., 1.]], dtype=float32)
Indices outside the range [0, num_classes) will be encoded as zeros:
>>> jax.nn.one_hot(jnp.array([-1, 3]), 3)
Array([[0., 0., 0.],
       [0., 0., 0.]], dtype=float32)
- Parameters:
x – A tensor of indices.
num_classes – Number of classes in the one-hot dimension.
dtype – optional, a float dtype for the returned values (default jnp.float_).
axis – the axis or axes along which the function should be computed.
- metoryx.relu(x: Array | ndarray | bool | number | bool | int | float | complex) Array¶
Rectified linear unit activation function.
Computes the element-wise function:
\[\mathrm{relu}(x) = \max(x, 0)\]
except under differentiation, we take:
\[\nabla \mathrm{relu}(0) = 0\]
For more information see Numerical influence of ReLU’(0) on backpropagation.
- Parameters:
x – input array
- Returns:
An array.
Examples
>>> jax.nn.relu(jax.numpy.array([-2., -1., -0.5, 0, 0.5, 1., 2.]))
Array([0. , 0. , 0. , 0. , 0.5, 1. , 2. ], dtype=float32)
- metoryx.relu6(x: Array | ndarray | bool | number | bool | int | float | complex) Array¶
Rectified Linear Unit 6 activation function.
Computes the element-wise function
\[\mathrm{relu6}(x) = \min(\max(x, 0), 6)\]
except under differentiation, we take:
\[\nabla \mathrm{relu6}(0) = 0\]
and
\[\nabla \mathrm{relu6}(6) = 0\]
- Parameters:
x – input array
- Returns:
An array.
- metoryx.selu(x: Array | ndarray | bool | number | bool | int | float | complex) Array¶
Scaled exponential linear unit activation.
Computes the element-wise function:
\[\mathrm{selu}(x) = \lambda \begin{cases} x, & x > 0\\ \alpha e^x - \alpha, & x \le 0 \end{cases}\]
where \(\lambda = 1.0507009873554804934193349852946\) and \(\alpha = 1.6732632423543772848170429916717\).
For more information, see Self-Normalizing Neural Networks.
- Parameters:
x – input array
- Returns:
An array.
- metoryx.sigmoid(x: Array | ndarray | bool | number | bool | int | float | complex) Array¶
Sigmoid activation function.
Computes the element-wise function:
\[\mathrm{sigmoid}(x) = \frac{1}{1 + e^{-x}}\]
- Parameters:
x – input array
- Returns:
An array.
- metoryx.silu(x: Array | ndarray | bool | number | bool | int | float | complex) Array¶
SiLU (aka swish) activation function.
Computes the element-wise function:
\[\mathrm{silu}(x) = x \cdot \mathrm{sigmoid}(x) = \frac{x}{1 + e^{-x}}\]
swish() and silu() are both aliases for the same function.
- Parameters:
x – input array
- Returns:
An array.
- metoryx.soft_sign(x: Array | ndarray | bool | number | bool | int | float | complex) Array¶
Soft-sign activation function.
Computes the element-wise function
\[\mathrm{soft\_sign}(x) = \frac{x}{|x| + 1}\]
- Parameters:
x – input array
- metoryx.softmax(x: Array | ndarray | bool | number | bool | int | float | complex, axis: int | Sequence[int] | None = -1, where: Array | ndarray | bool | number | bool | int | float | complex | None = None) Array¶
Softmax function.
Computes the function which rescales elements to the range \([0, 1]\) such that the elements along axis sum to \(1\).
\[\mathrm{softmax}(x)_i = \frac{\exp(x_i)}{\sum_j \exp(x_j)}\]
- Parameters:
x – input array
axis – the axis or axes along which the softmax should be computed. The softmax output summed across these dimensions should sum to \(1\). Either an integer or a tuple of integers.
where – Elements to include in the softmax. The output for any masked-out element is zero.
- Returns:
An array.
Note
If any input values are +inf, the result will be all NaN: this reflects the fact that inf / inf is not well-defined in the context of floating-point math.
- metoryx.softplus(x: Array | ndarray | bool | number | bool | int | float | complex) Array¶
Softplus activation function.
Computes the element-wise function
\[\mathrm{softplus}(x) = \log(1 + e^x)\]
- Parameters:
x – input array
- metoryx.sparse_plus(x: Array | ndarray | bool | number | bool | int | float | complex) Array¶
Sparse plus function.
Computes the function:
\[\mathrm{sparse\_plus}(x) = \begin{cases} 0, & x \leq -1\\ \frac{1}{4}(x+1)^2, & -1 < x < 1 \\ x, & 1 \leq x \end{cases}\]
This is the twin function of the softplus activation, ensuring a zero output for inputs less than -1 and a linear output for inputs greater than 1, while remaining smooth, convex, and monotonic by an adequate definition between -1 and 1.
- Parameters:
x – input (float)
- metoryx.sparse_sigmoid(x: Array | ndarray | bool | number | bool | int | float | complex) Array¶
Sparse sigmoid activation function.
Computes the function:
\[\mathrm{sparse\_sigmoid}(x) = \begin{cases} 0, & x \leq -1\\ \frac{1}{2}(x+1), & -1 < x < 1 \\ 1, & 1 \leq x \end{cases}\]
This is the twin function of the sigmoid activation, ensuring a zero output for inputs less than -1, a 1 output for inputs greater than 1, and a linear output for inputs between -1 and 1. It is the derivative of sparse_plus.
For more information, see Learning with Fenchel-Young Losses (section 6.2).
- Parameters:
x – input array
- Returns:
An array.
- metoryx.squareplus(x: Array | ndarray | bool | number | bool | int | float | complex, b: Array | ndarray | bool | number | bool | int | float | complex = 4) Array¶
Squareplus activation function.
Computes the element-wise function
\[\mathrm{squareplus}(x) = \frac{x + \sqrt{x^2 + b}}{2}\]
as described in https://arxiv.org/abs/2112.11687.
- Parameters:
x – input array
b – smoothness parameter
- metoryx.standardize(x: Array | ndarray | bool | number | bool | int | float | complex, axis: int | Sequence[int] | None = -1, mean: Array | ndarray | bool | number | bool | int | float | complex | None = None, variance: Array | ndarray | bool | number | bool | int | float | complex | None = None, epsilon: Array | ndarray | bool | number | bool | int | float | complex = 1e-05, where: Array | ndarray | bool | number | bool | int | float | complex | None = None) Array¶
Standardizes input to zero mean and unit variance.
The standardization is given by:
\[x_{std} = \frac{x - \langle x\rangle}{\sqrt{\langle(x - \langle x\rangle)^2\rangle + \epsilon}}\]
where \(\langle x\rangle\) indicates the mean of \(x\), and \(\epsilon\) is a small correction factor introduced to avoid division by zero.
- Parameters:
x – input array to be standardized.
axis – integer or tuple of integers representing the axes along which to standardize. Defaults to the last axis (-1).
mean – optionally specify the mean used for standardization. If not specified, then x.mean(axis, where=where) will be used.
variance – optionally specify the variance used for standardization. If not specified, then x.var(axis, where=where) will be used.
epsilon – correction factor added to variance to avoid division by zero; defaults to 1e-5.
where – optional boolean mask specifying which elements to use when computing the mean and variance.
- Returns:
An array of the same shape as x containing the standardized input.
Initializers¶
Initializers are exported from jax.nn.initializers for convenience.
- metoryx.initializers.zeros(dtype: DType | None = None) Initializer¶
Builds an initializer that generates arrays initialized to 0.
- Parameters:
dtype – The data type of the initialized array.
- Returns:
An initializer that returns an array initialized to 0.
- metoryx.initializers.ones(dtype: DType | None = None) Initializer¶
Builds an initializer that generates arrays initialized to 1.
- Parameters:
dtype – The data type of the initialized array.
- Returns:
An initializer that returns an array initialized to 1.
- metoryx.initializers.constant(value: ArrayLike, dtype: DType | None = None) Initializer¶
Builds an initializer that generates arrays initialized to a constant value.
- Parameters:
value – The constant value to initialize the array.
dtype – The data type of the initialized array.
- Returns:
An initializer that returns an array initialized to the specified constant value.
- metoryx.initializers.uniform(scale: float = 0.01, dtype: DType | None = None) Initializer¶
Builds an initializer that returns real uniformly-distributed random arrays.
- Parameters:
scale – optional; the upper bound of the random distribution.
dtype – optional; the initializer’s default dtype.
- Returns:
An initializer that returns arrays whose values are uniformly distributed in the range [0, scale).
- metoryx.initializers.normal(stddev: float = 0.01, dtype: DType | None = None) Initializer¶
Builds an initializer that returns real normally-distributed random arrays.
- Parameters:
stddev – optional; the standard deviation of the distribution.
dtype – optional; the initializer’s default dtype.
- Returns:
An initializer that returns arrays whose values are normally distributed with mean 0 and standard deviation stddev.
- metoryx.initializers.truncated_normal(stddev: float = 0.01, dtype: DType | None = None, lower: float = -2.0, upper: float = 2.0) Initializer¶
Builds an initializer that returns truncated-normal random arrays.
- Parameters:
stddev – optional; the standard deviation of the untruncated distribution. Note that this function does not apply the stddev correction as is done in the variance_scaling initializers, and users are expected to apply this correction themselves via the stddev arg if they wish to employ it.
dtype – optional; the initializer’s default dtype.
lower – Float representing the lower bound for truncation. Applied before the output is multiplied by the stddev.
upper – Float representing the upper bound for truncation. Applied before the output is multiplied by the stddev.
- Returns:
An initializer that returns arrays whose values follow the truncated normal distribution with mean 0 and standard deviation stddev, and range lower * stddev < x < upper * stddev.
- metoryx.initializers.variance_scaling(scale: float, mode: Literal['fan_in', 'fan_out', 'fan_avg', 'fan_geo_avg'], distribution: Literal['truncated_normal', 'normal', 'uniform'], in_axis: int | Sequence[int] = -2, out_axis: int | Sequence[int] = -1, batch_axis: int | Sequence[int] = (), dtype: DType | None = None) Initializer¶
Initializer that adapts its scale to the shape of the weights tensor.
With distribution="truncated_normal" or distribution="normal", samples are drawn from a (truncated) normal distribution with a mean of zero and a standard deviation (after truncation, if applicable) of sqrt(scale/n), where n is, for each mode:
- "fan_in": the number of inputs
- "fan_out": the number of outputs
- "fan_avg": the arithmetic average of the numbers of inputs and outputs
- "fan_geo_avg": the geometric average of the numbers of inputs and outputs
This initializer can be configured with in_axis, out_axis, and batch_axis to work with general convolutional or dense layers; axes that are not in any of those arguments are assumed to be the “receptive field” (convolution kernel spatial axes).
With distribution="truncated_normal", the absolute values of the samples are truncated at 2 standard deviations before scaling.
With distribution="uniform", samples are drawn from:
- a uniform interval, if dtype is real, or
- a uniform disk, if dtype is complex,
with a mean of zero and a standard deviation of sqrt(scale/n), where n is defined above.
- Parameters:
scale – scaling factor (positive float).
mode – one of "fan_in", "fan_out", "fan_avg", and "fan_geo_avg".
distribution – random distribution to use. One of "truncated_normal", "normal", and "uniform".
in_axis – axis or sequence of axes of the input dimension in the weights array.
out_axis – axis or sequence of axes of the output dimension in the weights array.
batch_axis – axis or sequence of axes in the weight array that should be ignored.
dtype – the dtype of the weights.
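For example, variance_scaling with scale = 2.0, mode="fan_in", and distribution="truncated_normal" reproduces he_normal (see the specializations below):

import metoryx as mx

init = mx.initializers.variance_scaling(2.0, "fan_in", "truncated_normal")
layer = mx.Dense(128, 64, kernel_init=init)  # same as kernel_init=mx.initializers.he_normal()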
- metoryx.initializers.glorot_uniform(in_axis: int | Sequence[int] = -2, out_axis: int | Sequence[int] = -1, batch_axis: int | Sequence[int] = (), dtype: DType | None = None) Initializer¶
Builds a Glorot uniform initializer (aka Xavier uniform initializer).
A Glorot uniform initializer is a specialization of variance_scaling where scale = 1.0, mode="fan_avg", and distribution="uniform".
- Parameters:
in_axis – axis or sequence of axes of the input dimension in the weights array.
out_axis – axis or sequence of axes of the output dimension in the weights array.
batch_axis – axis or sequence of axes in the weight array that should be ignored.
dtype – the dtype of the weights.
- Returns:
An initializer.
- metoryx.initializers.glorot_normal(in_axis: int | Sequence[int] = -2, out_axis: int | Sequence[int] = -1, batch_axis: int | Sequence[int] = (), dtype: DType | None = None) Initializer¶
Builds a Glorot normal initializer (aka Xavier normal initializer).
A Glorot normal initializer is a specialization of variance_scaling where scale = 1.0, mode="fan_avg", and distribution="truncated_normal".
- Parameters:
in_axis – axis or sequence of axes of the input dimension in the weights array.
out_axis – axis or sequence of axes of the output dimension in the weights array.
batch_axis – axis or sequence of axes in the weight array that should be ignored.
dtype – the dtype of the weights.
- Returns:
An initializer.
- metoryx.initializers.xavier_uniform(in_axis: int | Sequence[int] = -2, out_axis: int | Sequence[int] = -1, batch_axis: int | Sequence[int] = (), dtype: DType | None = None) Initializer¶
Builds a Glorot uniform initializer (aka Xavier uniform initializer).
A Glorot uniform initializer is a specialization of variance_scaling where scale = 1.0, mode="fan_avg", and distribution="uniform".
- Parameters:
in_axis – axis or sequence of axes of the input dimension in the weights array.
out_axis – axis or sequence of axes of the output dimension in the weights array.
batch_axis – axis or sequence of axes in the weight array that should be ignored.
dtype – the dtype of the weights.
- Returns:
An initializer.
- metoryx.initializers.xavier_normal(in_axis: int | Sequence[int] = -2, out_axis: int | Sequence[int] = -1, batch_axis: int | Sequence[int] = (), dtype: DType | None = None) Initializer¶
Builds a Glorot normal initializer (aka Xavier normal initializer).
A Glorot normal initializer is a specialization of variance_scaling where scale = 1.0, mode="fan_avg", and distribution="truncated_normal".
- Parameters:
in_axis – axis or sequence of axes of the input dimension in the weights array.
out_axis – axis or sequence of axes of the output dimension in the weights array.
batch_axis – axis or sequence of axes in the weight array that should be ignored.
dtype – the dtype of the weights.
- Returns:
An initializer.
- metoryx.initializers.he_uniform(in_axis: int | Sequence[int] = -2, out_axis: int | Sequence[int] = -1, batch_axis: int | Sequence[int] = (), dtype: DType | None = None) Initializer¶
Builds a He uniform initializer (aka Kaiming uniform initializer).
A He uniform initializer is a specialization of variance_scaling where scale = 2.0, mode="fan_in", and distribution="uniform".
- Parameters:
in_axis – axis or sequence of axes of the input dimension in the weights array.
out_axis – axis or sequence of axes of the output dimension in the weights array.
batch_axis – axis or sequence of axes in the weight array that should be ignored.
dtype – the dtype of the weights.
- Returns:
An initializer.
- metoryx.initializers.he_normal(in_axis: int | Sequence[int] = -2, out_axis: int | Sequence[int] = -1, batch_axis: int | Sequence[int] = (), dtype: DType | None = None) Initializer¶
Builds a He normal initializer (aka Kaiming normal initializer).
A He normal initializer is a specialization of variance_scaling where scale = 2.0, mode="fan_in", and distribution="truncated_normal".
- Parameters:
in_axis – axis or sequence of axes of the input dimension in the weights array.
out_axis – axis or sequence of axes of the output dimension in the weights array.
batch_axis – axis or sequence of axes in the weight array that should be ignored.
dtype – the dtype of the weights.
- Returns:
An initializer.
- metoryx.initializers.kaiming_uniform(in_axis: int | Sequence[int] = -2, out_axis: int | Sequence[int] = -1, batch_axis: int | Sequence[int] = (), dtype: DType | None = None) Initializer¶
Builds a He uniform initializer (aka Kaiming uniform initializer).
A He uniform initializer is a specialization of variance_scaling where scale = 2.0, mode="fan_in", and distribution="uniform".
- Parameters:
in_axis – axis or sequence of axes of the input dimension in the weights array.
out_axis – axis or sequence of axes of the output dimension in the weights array.
batch_axis – axis or sequence of axes in the weight array that should be ignored.
dtype – the dtype of the weights.
- Returns:
An initializer.
- metoryx.initializers.kaiming_normal(in_axis: int | Sequence[int] = -2, out_axis: int | Sequence[int] = -1, batch_axis: int | Sequence[int] = (), dtype: DType | None = None) Initializer¶
Builds a He normal initializer (aka Kaiming normal initializer).
A He normal initializer is a specialization of variance_scaling where scale = 2.0, mode="fan_in", and distribution="truncated_normal".
- Parameters:
in_axis – axis or sequence of axes of the input dimension in the weights array.
out_axis – axis or sequence of axes of the output dimension in the weights array.
batch_axis – axis or sequence of axes in the weight array that should be ignored.
dtype – the dtype of the weights.
- Returns:
An initializer.
- metoryx.initializers.lecun_uniform(in_axis: int | Sequence[int] = -2, out_axis: int | Sequence[int] = -1, batch_axis: int | Sequence[int] = (), dtype: DType | None = None) Initializer¶
Builds a Lecun uniform initializer.
A Lecun uniform initializer is a specialization of variance_scaling where scale = 1.0, mode="fan_in", and distribution="uniform".
- Parameters:
in_axis – axis or sequence of axes of the input dimension in the weights array.
out_axis – axis or sequence of axes of the output dimension in the weights array.
batch_axis – axis or sequence of axes in the weight array that should be ignored.
dtype – the dtype of the weights.
- Returns:
An initializer.
- metoryx.initializers.lecun_normal(in_axis: int | Sequence[int] = -2, out_axis: int | Sequence[int] = -1, batch_axis: int | Sequence[int] = (), dtype: DType | None = None) Initializer¶
Builds a Lecun normal initializer.
A Lecun normal initializer is a specialization of variance_scaling where scale = 1.0, mode="fan_in", and distribution="truncated_normal".
- Parameters:
in_axis – axis or sequence of axes of the input dimension in the weights array.
out_axis – axis or sequence of axes of the output dimension in the weights array.
batch_axis – axis or sequence of axes in the weight array that should be ignored.
dtype – the dtype of the weights.
- Returns:
An initializer.
- metoryx.initializers.orthogonal(scale: float = 1.0, column_axis: int = -1, dtype: DType | None = None) Initializer¶
Builds an initializer that returns uniformly distributed orthogonal matrices.
If the shape is not square, the matrices will have orthonormal rows or columns depending on which side is smaller.
- Parameters:
scale – the upper bound of the uniform distribution.
column_axis – the axis that contains the columns that should be orthogonal.
dtype – the default dtype of the weights.
- Returns:
An orthogonal initializer.
- metoryx.initializers.delta_orthogonal(scale: float = 1.0, column_axis: int = -1, dtype: DType | None = None) Initializer¶
Builds an initializer for delta orthogonal kernels.
- Parameters:
scale – the upper bound of the uniform distribution.
column_axis – the axis that contains the columns that should be orthogonal.
dtype – the default dtype of the weights.
- Returns:
A delta orthogonal initializer. The shape passed to the initializer must be 3D, 4D, or 5D.
Utilities¶
- class metoryx.utils.AverageMeter(*, with_timer: bool = False, timer_key: str = 'elapsed_time')¶
Bases: object
Stores and computes the weighted average of multiple metrics over time.
- __init__(*, with_timer: bool = False, timer_key: str = 'elapsed_time')¶
Initializes the AverageMeter.
- Parameters:
with_timer – Whether to track elapsed time from the meter reset to the compute call.
timer_key – The key name for the elapsed time metric.
- update(metric_dict: dict[str, Array | ndarray | bool | number | bool | int | float | complex], n: int = 1, *, prefix: str = '') None¶
Updates the meter with new metric values.
- Parameters:
metric_dict – A dictionary where keys are metric names and values are the corresponding metric values (can be arrays).
n – The number of samples the metrics correspond to (default is 1).
prefix – A string prefix to add to each metric name (default is “”).
- compute() dict[str, float]¶
Computes the average for each metric.
- Returns:
A dictionary where keys are metric names and values are their average.
- reset() None¶
Resets the meter to its initial state.
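A typical training-loop sketch (batches, train_step, and batch_size are hypothetical placeholders):

import metoryx as mx

meter = mx.utils.AverageMeter(with_timer=True)
for batch in batches:  # hypothetical iterable of training batches
    loss = train_step(batch)  # hypothetical step returning a scalar loss
    meter.update({"loss": loss}, n=batch_size, prefix="train/")
print(meter.compute())  # e.g. {"train/loss": ..., "elapsed_time": ...}
meter.reset()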
- metoryx.assign_variables(module: Module, variables: Variables) Module¶
Assign arrays to the module.
- Parameters:
module – The module to assign arrays to.
variables – The variables to assign to the module.
- Returns:
The module with assigned arrays.
Notes
The assigned arrays will be reflected when the module is initialized.
Currently, the array tree structure must match the module’s variable structure.
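For example, loading pretrained weights so that the next initialization returns them (a sketch; model and pretrained are hypothetical, and pretrained is assumed to be a variable tree matching the module's structure):

import metoryx as mx

model = mx.assign_variables(model, pretrained)
# A subsequent initialization (e.g. via metoryx.transform or metoryx.init)
# now yields the assigned arrays instead of freshly initialized ones.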