package kaun


Module Kaun.Layer

Neural network layer constructors.

This module provides functional layer constructors for building neural networks. Each constructor returns a module_, a record that encapsulates parameter initialization and forward computation. Layers can be composed with sequential to build complex architectures.

All layers follow a consistent pattern: they take architecture parameters (dimensions, hyperparameters) and optional initialization strategies, returning a module that can be initialized with random number generators and applied to input tensors.

Usage Overview

Create layers by calling constructor functions:

  let dense = Layer.linear ~in_features:784 ~out_features:128 () in
  let activation = Layer.relu () in

Compose layers into networks:

  let network = Layer.sequential [
    Layer.linear ~in_features:784 ~out_features:128 ();
    Layer.relu ();
    Layer.dropout ~rate:0.2 ();
    Layer.linear ~in_features:128 ~out_features:10 ();
  ] in

Initialize and apply:

  let params = Kaun.init network ~rngs ~dtype in
  let output = Kaun.apply network params ~training:true input in

type module_ = {
  init : 'layout. rngs:Rune.Rng.key -> dtype:(float, 'layout) Rune.dtype -> Ptree.t;
    (*

    init ~rngs ~dtype initializes module parameters.

    Creates a parameter tree containing all trainable parameters for this module. The function is polymorphic over layout to support different tensor memory layouts.

    • parameter rngs

      Random number generator key for deterministic initialization

    • parameter dtype

      Data type specification, typically Rune.float32

    The RNG key should be split appropriately for modules with multiple parameters to ensure independent initialization.

    *)
  apply : 'layout. Ptree.t -> training:bool -> ?rngs:Rune.Rng.key -> (float, 'layout) Rune.t -> (float, 'layout) Rune.t;
    (*

    apply params ~training ?rngs input performs forward computation.

    Executes the module's forward pass using the provided parameters and input tensor.

    • parameter params

      Parameter tree from init function

    • parameter training

      Whether module is in training mode (affects dropout, batch norm, etc.)

    • parameter rngs

      Optional RNG key for stochastic operations (dropout, etc.)

    • parameter input

      Input tensor to transform

    The training flag enables different behaviors:

    • Dropout: Applied only when training=true
    • Batch normalization: Uses batch statistics when training=true
    • Other regularization: Activated based on training mode

    An RNG key is required for stochastic operations during training: operations needing randomness fail if rngs is omitted when training = true.

    *)
}
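
A module_ can be driven directly through its two fields. A minimal sketch, assuming Rune.Rng.key builds a key from an integer seed and Rune.zeros creates a zero-filled tensor:

```ocaml
(* Initialize and apply a single layer via its module_ record;
   this is presumably what Kaun.init / Kaun.apply do internally. *)
let layer = Layer.linear ~in_features:4 ~out_features:2 () in
let params = layer.init ~rngs:(Rune.Rng.key 0) ~dtype:Rune.float32 in
let x = Rune.zeros Rune.float32 [| 1; 4 |] in
let y = layer.apply params ~training:false x in
```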

Convolutional Layers

val conv1d : in_channels:int -> out_channels:int -> ?kernel_size:int -> ?stride:int -> ?dilation:int -> ?padding:[ `Same | `Valid | `Causal ] -> unit -> module_

conv1d ~in_channels ~out_channels ?kernel_size ?stride ?dilation ?padding () creates a 1D convolutional layer over inputs of shape [batch; in_channels; length]. Supports `Same, `Valid, and `Causal padding. Defaults: kernel_size=3, stride=1, dilation=1, padding=`Same.
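
For example, a causal convolution for autoregressive sequence models (a usage sketch):

```ocaml
(* Causal 1D convolution: each output position depends only on inputs
   at the same or earlier positions. *)
let temporal = Layer.conv1d ~in_channels:16 ~out_channels:32
    ~kernel_size:5 ~padding:`Causal () in
```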

val conv2d : in_channels:int -> out_channels:int -> ?kernel_size:(int * int) -> unit -> module_

conv2d ~in_channels ~out_channels ?kernel_size () creates a 2D convolutional layer.

Performs 2D convolution over 4D input tensors of shape [batch_size; in_channels; height; width]. The layer maintains learnable weight and bias parameters.

  • parameter in_channels

    Number of input channels

  • parameter out_channels

    Number of output filters

  • parameter kernel_size

    Filter dimensions as (height, width). Default: (3, 3)

The weight tensor has shape [out_channels; in_channels; kernel_height; kernel_width] and is initialized using Glorot uniform initialization. The bias tensor has shape [out_channels] and is zero-initialized.

Example

  let conv = Layer.conv2d ~in_channels:3 ~out_channels:64 ~kernel_size:(5, 5) () in
  (* Processes RGB images (3 channels) to produce 64 feature maps with 5x5 filters *)

Dense Layers

val linear : in_features:int -> out_features:int -> ?weight_init:Initializers.t -> ?bias_init:Initializers.t -> unit -> module_

linear ~in_features ~out_features ?weight_init ?bias_init () creates a fully connected layer.

Applies linear transformation y = xW^T + b where x is input, W is weight matrix, and b is bias vector. Accepts inputs of any shape with last dimension matching in_features.

  • parameter in_features

    Size of input feature dimension

  • parameter out_features

    Size of output feature dimension

The weight tensor has shape [out_features; in_features] and the bias has shape [out_features].

Examples

  let classifier = Layer.linear ~in_features:512 ~out_features:10 () in
  (* Maps 512-dimensional features to 10 class logits *)

  let custom_init = Layer.linear
    ~in_features:256 ~out_features:128
    ~weight_init:(Initializers.he_normal ())
    ~bias_init:(Initializers.constant 0.1) () in

Regularization Layers

val dropout : rate:float -> unit -> module_

dropout ~rate () creates a dropout layer for regularization.

During training, randomly sets elements to zero with probability rate and scales remaining elements by 1 / (1 - rate) to maintain expected values. During evaluation, applies identity transformation.

  • parameter rate

    Dropout probability in the range [0.0, 1.0)

Requires random number generator during training. No learnable parameters.

Example

  let drop = Layer.dropout ~rate:0.5 () in
  (* Randomly zeros 50% of activations during training *)

val batch_norm : num_features:int -> unit -> module_

batch_norm ~num_features () creates a batch normalization layer.

Normalizes inputs across the batch dimension, learning scale and shift parameters. Applies transformation y = γ((x - μ) / σ) + β where μ and σ are batch statistics, and γ, β are learnable parameters.

  • parameter num_features

    Number of features to normalize (typically channel dimension)

Maintains running statistics for evaluation mode. Parameters include scale (γ), bias (β), running mean, and running variance.
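
A typical conv -> batch norm -> activation block (a sketch; num_features must match the channel count produced by the preceding convolution):

```ocaml
let block = Layer.sequential [
  Layer.conv2d ~in_channels:3 ~out_channels:64 ();
  Layer.batch_norm ~num_features:64 ();
  Layer.relu ();
] in
```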

Pooling Layers

val max_pool2d : kernel_size:(int * int) -> ?stride:(int * int) -> unit -> module_

max_pool2d ~kernel_size ?stride () creates a 2D max pooling layer.

Applies maximum operation over spatial windows, reducing spatial dimensions while preserving channel dimension.

  • parameter kernel_size

    Pooling window size as (height, width)

  • parameter stride

    Pooling stride as (height, width). Default: same as kernel_size

No learnable parameters.

val avg_pool2d : kernel_size:(int * int) -> ?stride:(int * int) -> unit -> module_

avg_pool2d ~kernel_size ?stride () creates a 2D average pooling layer.

Applies average operation over spatial windows, providing smoother downsampling compared to max pooling.

  • parameter kernel_size

    Pooling window size as (height, width)

  • parameter stride

    Pooling stride as (height, width). Default: same as kernel_size

No learnable parameters.

Reshape Layers

val flatten : unit -> module_

flatten () creates a flatten layer that reshapes multidimensional inputs to 2D.

Preserves the batch dimension while flattening all other dimensions. Transforms shape [batch_size; d1; d2; ...; dn] to [batch_size; d1 * d2 * ... * dn].

Commonly used before dense layers in CNN architectures. No learnable parameters.
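
A minimal CNN classifier sketch. The flattened size 32 * 14 * 14 assumes 28x28 single-channel inputs, a convolution that preserves spatial dimensions (e.g. `Same padding, which this documentation does not confirm for conv2d's default), and one 2x2 pool halving them; adjust the in_features accordingly if the defaults differ:

```ocaml
let cnn = Layer.sequential [
  Layer.conv2d ~in_channels:1 ~out_channels:32 ();
  Layer.relu ();
  Layer.max_pool2d ~kernel_size:(2, 2) ();
  Layer.flatten ();
  Layer.linear ~in_features:(32 * 14 * 14) ~out_features:10 ();
] in
```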

Activation Functions

val relu : unit -> module_

relu () creates a ReLU activation layer applying max(0, x) elementwise.

Most common activation for hidden layers. Computationally efficient with good gradient flow for positive inputs. No learnable parameters.

val sigmoid : unit -> module_

sigmoid () creates a sigmoid activation layer applying 1 / (1 + exp(-x)) elementwise.

Maps inputs to range (0, 1). Commonly used for binary classification and gating mechanisms. No learnable parameters.

val tanh : unit -> module_

tanh () creates a hyperbolic tangent activation layer applying tanh(x) elementwise.

Maps inputs to range (-1, 1). Provides stronger gradients than sigmoid but can suffer from vanishing gradients. No learnable parameters.

val gelu : unit -> module_

gelu () creates a GELU activation layer.

Applies Gaussian Error Linear Unit activation, popular in transformer architectures. Smoother alternative to ReLU with better gradient properties. No learnable parameters.

val swish : unit -> module_

swish () creates a Swish activation layer applying x * sigmoid(x) elementwise.

Self-gated activation function that can outperform ReLU in deep networks. No learnable parameters.

Composition

val sequential : module_ list -> module_

sequential layers creates a sequential composition of layers.

Applies layers in order, threading output of each layer as input to the next. The resulting module's parameters are the union of all component layer parameters.

  • parameter layers

    List of layers to compose

Example

  let mlp = Layer.sequential [
    Layer.linear ~in_features:784 ~out_features:256 ();
    Layer.relu ();
    Layer.dropout ~rate:0.3 ();
    Layer.linear ~in_features:256 ~out_features:10 ();
  ] in

Advanced Layers

val einsum : einsum_str:string -> shape:int array -> ?kernel_init:Initializers.t -> unit -> module_

einsum ~einsum_str ~shape ?kernel_init () creates a parameterized Einstein summation layer.

Implements learnable tensor contractions specified by Einstein notation. Useful for implementing custom linear transformations and attention mechanisms.

  • parameter einsum_str

    Einstein summation string describing the contraction

  • parameter shape

    Shape of the learnable kernel parameter
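
As an illustration, a multi-head projection might be written as an einsum layer. The index convention of einsum_str (which operand the kernel binds to, and the letter assignment) is an assumption here, not confirmed by this documentation:

```ocaml
(* Hypothetical: contract a [batch; seq; 512] input with a learnable
   [512; 8; 64] kernel to produce [batch; seq; 8; 64]. *)
let heads = Layer.einsum ~einsum_str:"bsd,dhk->bshk"
    ~shape:[| 512; 8; 64 |] () in
```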

Normalization Layers

val rms_norm : dim:int -> ?eps:float -> ?scale_init:Initializers.t -> unit -> module_

rms_norm ~dim ?eps ?scale_init () creates a Root Mean Square normalization layer.

Applies RMS normalization with learnable scaling. Normalizes by the RMS of activations rather than full statistics like batch normalization.

  • parameter dim

    Dimension to normalize over

  • parameter eps

    Small constant for numerical stability. Default: 1e-6

val layer_norm : dim:int -> ?eps:float -> ?elementwise_affine:bool -> unit -> module_

layer_norm ~dim ?eps ?elementwise_affine () creates a layer normalization layer.

Normalizes activations across the feature dimension within each sample. Popular in transformer architectures for stable training.

  • parameter dim

    Dimension to normalize over

  • parameter eps

    Small constant for numerical stability. Default: 1e-6

  • parameter elementwise_affine

    Whether to learn scale and shift parameters. Default: true
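
A pre-norm feed-forward branch as found in transformer blocks (a sketch; the residual addition happens outside the sequential):

```ocaml
let ffn_branch = Layer.sequential [
  Layer.layer_norm ~dim:512 ();
  Layer.linear ~in_features:512 ~out_features:2048 ();
  Layer.gelu ();
  Layer.linear ~in_features:2048 ~out_features:512 ();
] in
```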

Embedding Layers

val embedding : vocab_size:int -> embed_dim:int -> ?scale:bool -> ?embedding_init:Initializers.t -> unit -> module_

embedding ~vocab_size ~embed_dim ?scale ?embedding_init () creates an embedding lookup layer.

Maps discrete tokens (integers) to dense vectors. Commonly used as the first layer in NLP models to convert token IDs to continuous representations.

  • parameter vocab_size

    Size of the vocabulary (number of possible tokens)

  • parameter embed_dim

    Dimensionality of embedding vectors

  • parameter scale

    Whether to scale embeddings by sqrt(embed_dim). Default: false

The embedding matrix has shape [vocab_size; embed_dim].
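
For example, a scaled token embedding as used in the original Transformer:

```ocaml
(* With scale = true, embeddings are multiplied by sqrt(512). *)
let embed = Layer.embedding ~vocab_size:32000 ~embed_dim:512 ~scale:true () in
```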

Feed-Forward Blocks

val mlp : in_features:int -> hidden_features:int -> out_features:int -> ?activation:[ `relu | `gelu | `swish ] -> ?dropout:float -> unit -> module_

mlp ~in_features ~hidden_features ~out_features ... creates a multi-layer perceptron (feed-forward network).

Standard MLP architecture: Linear -> Activation -> Dropout -> Linear -> Dropout. Commonly used in transformers and other architectures.

  • parameter in_features

    Input dimension

  • parameter hidden_features

    Hidden layer dimension

  • parameter out_features

    Output dimension

  • parameter activation

    Activation function. Default: `gelu

  • parameter dropout

    Dropout probability. Default: 0.0
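
A transformer-style feed-forward block with 4x expansion:

```ocaml
let ffn = Layer.mlp ~in_features:512 ~hidden_features:2048
    ~out_features:512 ~activation:`gelu ~dropout:0.1 () in
```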

Recurrent Layers

val rnn : input_size:int -> hidden_size:int -> ?return_sequences:bool -> ?learned_init:bool -> unit -> module_

Simple tanh RNN over a sequence. Input [batch; seq; input_size]; output [batch; hidden_size] (the last hidden state), or [batch; seq; hidden_size] when return_sequences = true.

val gru : input_size:int -> hidden_size:int -> ?return_sequences:bool -> ?learned_init:bool -> unit -> module_

GRU over a sequence. Input/output like rnn.

val lstm : input_size:int -> hidden_size:int -> ?return_sequences:bool -> ?learned_init:bool -> unit -> module_

LSTM over a sequence. Input/output like rnn.
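
A sequence classifier sketch: embed token IDs, encode with an LSTM, and classify the final hidden state. This assumes the embedding output can feed lstm directly and that the default returns only the last hidden state:

```ocaml
let classifier = Layer.sequential [
  Layer.embedding ~vocab_size:10000 ~embed_dim:128 ();
  Layer.lstm ~input_size:128 ~hidden_size:256 ();
  Layer.linear ~in_features:256 ~out_features:2 ();
] in
```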

Positional Encodings

val positional_embedding_learned : max_len:int -> embed_dim:int -> unit -> module_

Adds learned positional embeddings to input of shape [batch; seq; embed_dim].

val positional_encoding_sinusoidal_table : max_len:int -> embed_dim:int -> dtype:(float, 'layout) Rune.dtype -> (float, 'layout) Rune.t

Creates a [max_len; embed_dim] sinusoidal positional encoding table (not trainable), which can be added to token embeddings.
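
A usage sketch; Rune.add and broadcasting over the batch dimension are assumptions, and slicing the table down to the actual sequence length is omitted:

```ocaml
let table = Layer.positional_encoding_sinusoidal_table
    ~max_len:1024 ~embed_dim:512 ~dtype:Rune.float32 in
(* embeddings : [batch; 1024; 512] *)
let with_pos = Rune.add embeddings table in
```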