This module provides weight initialization strategies that follow the initializer API of Flax and JAX. Each initializer returns a function that takes an RNG seed, a shape, a device, and a dtype.
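As a minimal sketch of that calling convention, the hypothetical normal factory below fixes its hyperparameter up front and returns an init(seed, shape, device, dtype) function that is called once per parameter tensor. It is an illustration, not this module's implementation. It makes three assumptions: it uses NumPy, it takes the arguments in the order listed above, and it accepts only a "cpu" device.

```python
import numpy as np


def normal(stddev=1e-2):
    """Hypothetical factory illustrating the initializer calling convention."""

    def init(seed, shape, device="cpu", dtype=np.float32):
        # Assumption: this NumPy sketch only allocates on the CPU.
        if device != "cpu":
            raise ValueError(f"unsupported device: {device!r}")
        rng = np.random.default_rng(seed)  # same seed -> same weights
        return (stddev * rng.standard_normal(shape)).astype(dtype)

    return init


# Build the initializer once, then call it for each parameter.
init_fn = normal(stddev=0.02)
w = init_fn(0, (3, 4), "cpu", np.float32)
print(w.shape, w.dtype)  # (3, 4) float32
```

Separating the factory from the returned function lets hyperparameters be chosen once, while the seed, shape, device, and dtype are supplied per parameter.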
Orthogonal initializer
Returns an initializer that produces uniformly distributed (Haar-random) orthogonal matrices. If the shape is not square, the result has orthonormal rows when there are fewer rows than columns, and orthonormal columns otherwise.
Parameters:
scale: Scaling factor (default: 1.0)
column_axis: Axis containing the columns that should be orthogonal (default: -1)
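The construction can be sketched in NumPy. First, flatten every axis except column_axis into rows. Next, QR-decompose a Gaussian matrix and flip the column signs so the result is uniformly distributed. Finally, reshape back to the requested shape. This is an illustrative sketch under stated assumptions, not this module's implementation. It uses NumPy in place of the module's backend and accepts but ignores device.

```python
import math

import numpy as np


def orthogonal(scale=1.0, column_axis=-1):
    """NumPy sketch of an orthogonal initializer (not this module's code)."""

    def init(seed, shape, device="cpu", dtype=np.float32):
        # `device` is accepted for signature compatibility and ignored here.
        if len(shape) < 2:
            raise ValueError("orthogonal initializer needs at least a 2D shape")
        axis = column_axis % len(shape)
        n_cols = shape[axis]
        n_rows = math.prod(shape) // n_cols  # all other axes flattened
        # Sample with the longer side first so QR returns orthonormal columns.
        transpose = n_rows < n_cols
        rng = np.random.default_rng(seed)
        a = rng.standard_normal((n_cols, n_rows) if transpose else (n_rows, n_cols))
        q, r = np.linalg.qr(a)
        # Fixing the signs of diag(R) makes Q uniformly (Haar) distributed.
        q = q * np.sign(np.diag(r))
        if transpose:
            q = q.T  # fewer rows than columns: rows are orthonormal
        # Unflatten the other axes, then move the column axis back into place.
        other = tuple(d for i, d in enumerate(shape) if i != axis)
        q = np.moveaxis(q.reshape(other + (n_cols,)), -1, axis)
        return (scale * q).astype(dtype)

    return init


w = orthogonal()(0, (3, 5))
print(np.allclose(w @ w.T, np.eye(3), atol=1e-5))  # True: orthonormal rows
```

The sign correction matters because QR of a Gaussian matrix alone is not uniformly distributed over orthogonal matrices. Making the diagonal of R positive removes that bias.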
Delta orthogonal initializer for convolutional layers
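Initializer for convolution kernels that acts as the identity across the spatial dimensions. Every spatial tap is zero except the central one, which holds an orthogonal matrix mapping input channels to output channels. Requires a 3D, 4D, or 5D shape with square (equal) spatial dimensions.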
Parameters:
scale: Scaling factor (default: 1.0)
column_axis: Axis containing the columns that should be orthogonal (default: -1)
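A minimal NumPy sketch of the delta orthogonal construction follows. It builds an orthogonal channel matrix, places it at the central spatial tap, and leaves every other tap zero. A convolution with this kernel therefore only mixes channels at each position. The sketch rests on two assumptions, not documented behavior of this module:
- The kernel layout is (spatial..., in_channels, out_channels), as Flax uses for convolution kernels.
- column_axis indexes the 2D (in_channels, out_channels) center matrix.

```python
import numpy as np


def _orthogonal_2d(rng, n_rows, n_cols):
    """Haar-random matrix with orthonormal rows or columns (smaller side)."""
    transpose = n_rows < n_cols
    a = rng.standard_normal((n_cols, n_rows) if transpose else (n_rows, n_cols))
    q, r = np.linalg.qr(a)
    q = q * np.sign(np.diag(r))  # sign fix for a uniform distribution
    return q.T if transpose else q


def delta_orthogonal(scale=1.0, column_axis=-1):
    """NumPy sketch of a delta orthogonal initializer (not this module's code).

    Assumes a (*spatial, in_channels, out_channels) kernel layout and that
    column_axis indexes the 2D (in_channels, out_channels) center matrix.
    """

    def init(seed, shape, device="cpu", dtype=np.float32):
        # `device` is accepted for signature compatibility and ignored here.
        if len(shape) not in (3, 4, 5):
            raise ValueError("delta orthogonal needs a 3D, 4D, or 5D shape")
        *spatial, n_in, n_out = shape
        if len(set(spatial)) != 1:
            raise ValueError("spatial dimensions must all be equal")
        rng = np.random.default_rng(seed)
        if column_axis % 2 == 1:  # column_axis -1 or 1: out_channels axis
            center_matrix = _orthogonal_2d(rng, n_in, n_out)
        else:  # column_axis -2 or 0: in_channels axis
            center_matrix = _orthogonal_2d(rng, n_out, n_in).T
        # Zero everywhere except the central spatial tap.
        kernel = np.zeros(shape, dtype=dtype)
        center = tuple((k - 1) // 2 for k in spatial)
        kernel[center] = scale * center_matrix
        return kernel

    return init


k = delta_orthogonal()(0, (3, 3, 4, 8))  # 3x3 kernel, 4 -> 8 channels
c = k[1, 1]
print(np.allclose(c @ c.T, np.eye(4), atol=1e-5))  # True
print(np.count_nonzero(k[0]))  # 0
```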