package kaun

Flax-inspired neural network library for OCaml

Sources

raven-1.0.0.alpha1.tbz
sha256=8e277ed56615d388bc69c4333e43d1acd112b5f2d5d352e2453aef223ff59867
sha512=369eda6df6b84b08f92c8957954d107058fb8d3d8374082e074b56f3a139351b3ae6e3a99f2d4a4a2930dd950fd609593467e502368a13ad6217b571382da28c

Module Kaun_models.LeNet

LeNet-5: Classic convolutional neural network for handwritten digit recognition.

LeCun et al., 1998: "Gradient-Based Learning Applied to Document Recognition". One of the first successful CNNs, originally designed for MNIST digit classification.

type config = {
  num_classes : int;  (* Number of output classes (default: 10 for digits) *)
  input_channels : int;  (* Number of input channels (default: 1 for grayscale) *)
  input_size : int * int;  (* Input image size (default: 32x32) *)
  activation : [ `tanh | `relu | `sigmoid ];  (* Activation function (original used tanh) *)
  dropout_rate : float option;  (* Optional dropout rate for regularization *)
}

Configuration for LeNet-5 model
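
A minimal sketch of overriding individual fields with OCaml's record-update syntax. The config type is mirrored locally so the snippet compiles without the library; the values in base follow the field comments above, except dropout_rate = None, which is an assumption since the page does not state that default. In real code one would presumably start from LeNet.default_config rather than a hand-written base.

```ocaml
(* Local mirror of the config record documented above, so this compiles
   on its own. *)
type config = {
  num_classes : int;
  input_channels : int;
  input_size : int * int;
  activation : [ `tanh | `relu | `sigmoid ];
  dropout_rate : float option;
}

(* Defaults as listed in the field comments; [dropout_rate = None] is an
   assumption, not something the documentation states. *)
let base = {
  num_classes = 10;
  input_channels = 1;
  input_size = (32, 32);
  activation = `tanh;
  dropout_rate = None;
}

(* A "modern" variant: ReLU plus dropout, every other field unchanged. *)
let modern = { base with activation = `relu; dropout_rate = Some 0.5 }
```

Record update keeps the untouched fields in sync with the base value, which is convenient when only the activation or regularization differs between experiments.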

val default_config : config

Default configuration (original LeNet-5 for MNIST)

val mnist_config : config

MNIST-specific configuration (28x28 input, padded to 32x32)
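
To illustrate what "28x28 input, padded to 32x32" amounts to, here is a small sketch using plain float arrays rather than Rune tensors. It assumes symmetric zero padding of 2 pixels per side; pad_image is a hypothetical helper, and how the library itself pads is not stated on this page.

```ocaml
(* Zero-pad a 2D image by [pad] pixels on every side.
   Assumption: symmetric zero padding; the library may pad differently. *)
let pad_image ~pad (img : float array array) : float array array =
  let h = Array.length img in
  let w = if h = 0 then 0 else Array.length img.(0) in
  let out = Array.make_matrix (h + (2 * pad)) (w + (2 * pad)) 0.0 in
  (* Copy each source row into the interior of the padded image. *)
  Array.iteri (fun i row -> Array.blit row 0 out.(i + pad) pad w) img;
  out

let () =
  let digit = Array.make_matrix 28 28 1.0 in
  let padded = pad_image ~pad:2 digit in
  (* Prints "32x32". *)
  Printf.printf "%dx%d\n" (Array.length padded) (Array.length padded.(0))
```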

val cifar10_config : config

CIFAR-10 configuration (32x32 RGB input)

LeNet-5 model instance

val create : ?config:config -> unit -> t

create ?config () creates a new LeNet-5 model.

Architecture:

  • Conv1: 6 filters of 5x5
  • Pool1: 2x2 average pooling
  • Conv2: 16 filters of 5x5
  • Pool2: 2x2 average pooling
  • FC1: 120 units
  • FC2: 84 units
  • Output: num_classes units

The original paper used average pooling and tanh activation, but modern implementations often use max pooling and ReLU.
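
The layer list above implies a specific sequence of spatial sizes. The sketch below traces them for a square input, assuming unpadded ("valid") 5x5 convolutions with stride 1 and 2x2 pooling with stride 2, as in the original paper; the page does not state the padding or strides this implementation uses, and out_size, trace and flattened are illustrative names.

```ocaml
(* Output side length of a convolution or pooling window.
   Assumption: no padding ("valid"), as in the original LeNet-5. *)
let out_size ~input ~kernel ~stride = ((input - kernel) / stride) + 1

(* Trace the spatial side length after Conv1, Pool1, Conv2 and Pool2. *)
let trace input =
  let c1 = out_size ~input ~kernel:5 ~stride:1 in
  let p1 = out_size ~input:c1 ~kernel:2 ~stride:2 in
  let c2 = out_size ~input:p1 ~kernel:5 ~stride:1 in
  let p2 = out_size ~input:c2 ~kernel:2 ~stride:2 in
  [ c1; p1; c2; p2 ]

(* Number of features fed into FC1: 16 channels times the final map area. *)
let flattened input =
  match List.rev (trace input) with
  | last :: _ -> 16 * last * last
  | [] -> 0
```

Under these assumptions a 32x32 input shrinks to 28, 14, 10 and finally 5 pixels per side, so FC1 receives 16 * 5 * 5 = 400 features.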

Example:

  (* Assumes rngs and input are already in scope. *)
  let model = LeNet.create ~config:LeNet.mnist_config () in
  let params = Kaun.init model ~rngs ~dtype:Float32 in
  let output = Kaun.apply model params ~training:false input in
  output

val for_mnist : unit -> t

Create model for MNIST

for_mnist () creates a LeNet-5 model configured for MNIST digits. Equivalent to create ~config:mnist_config ().

val for_cifar10 : unit -> t

Create model for CIFAR-10

for_cifar10 () creates a LeNet-5 model configured for CIFAR-10. Uses 3 input channels for RGB images.

val forward : model:t -> params:'a Kaun.params -> training:bool -> input:(float, 'a) Rune.t -> (float, 'a) Rune.t

Forward pass through the model

forward ~model ~params ~training ~input performs a forward pass.

  • model: The LeNet-5 model
  • params: Model parameters
  • training: Whether in training mode (affects dropout if configured)
  • input: Input tensor of shape [batch_size; channels; height; width]
  • returns: Output logits of shape [batch_size; num_classes]
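
The training flag matters only when dropout is configured. As a rough illustration of why, here is a sketch of inverted dropout on a plain float array. It assumes survivors are rescaled by 1 / (1 - rate) at training time, which is a common convention but not something this page confirms for the library; dropout here is a standalone helper, not a Kaun function.

```ocaml
(* Inverted dropout on a flat array of activations.
   Assumption: survivors are rescaled by 1 / (1 - rate) during training so
   that inference needs no rescaling; the library's scheme may differ. *)
let dropout ~training ~rate (xs : float array) : float array =
  if (not training) || rate <= 0.0 then Array.copy xs
  else
    let keep = 1.0 -. rate in
    Array.map
      (fun x -> if Random.float 1.0 < keep then x /. keep else 0.0)
      xs
```

With training set to false the activations pass through unchanged, which is why evaluation runs are deterministic while training runs are not.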

val extract_features : model:t -> params:'a Kaun.params -> input:(float, 'a) Rune.t -> (float, 'a) Rune.t

Extract feature representations

extract_features ~model ~params ~input extracts feature representations from the second-to-last layer (FC2), useful for transfer learning or visualization. Returns features of shape [batch_size; 84].

val num_parameters : 'a Kaun.params -> int

Model statistics

num_parameters params returns the total number of parameters in the model.

val parameter_breakdown : 'a Kaun.params -> string

parameter_breakdown params returns a detailed breakdown of parameters by layer.
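
For a sense of what a per-layer breakdown looks like, the sketch below counts weights and biases by hand for the architecture listed under create. It assumes 1 input channel, 10 classes, a 32x32 input, unpadded convolutions, and full connectivity between Conv1 and Conv2 feature maps; the numbers the library reports depend on its actual layer definitions, and conv_params and dense_params are illustrative helpers.

```ocaml
(* Weights plus biases for a convolutional and a dense layer. *)
let conv_params ~in_ch ~out_ch ~k = (in_ch * k * k * out_ch) + out_ch
let dense_params ~inputs ~outputs = (inputs * outputs) + outputs

(* Assumptions: 1 input channel, 10 classes, 32x32 input, valid
   convolutions, fully connected Conv1 -> Conv2 feature maps. *)
let breakdown =
  [ ("conv1", conv_params ~in_ch:1 ~out_ch:6 ~k:5);
    ("conv2", conv_params ~in_ch:6 ~out_ch:16 ~k:5);
    ("fc1", dense_params ~inputs:(16 * 5 * 5) ~outputs:120);
    ("fc2", dense_params ~inputs:120 ~outputs:84);
    ("output", dense_params ~inputs:84 ~outputs:10) ]

let total = List.fold_left (fun acc (_, n) -> acc + n) 0 breakdown

let () =
  List.iter (fun (name, n) -> Printf.printf "%-6s %6d\n" name n) breakdown;
  Printf.printf "%-6s %6d\n" "total" total
```

Under these assumptions the total is 61,706 parameters, with FC1 accounting for most of them.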

Training Helpers

type train_config = {
  learning_rate : float;
  batch_size : int;
  num_epochs : int;
  weight_decay : float option;
  momentum : float option;
}

Training configuration
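
To show how the optional weight_decay and momentum fields might enter an update rule, here is a sketch of one SGD step on plain float arrays. It assumes L2 weight decay added to the gradient and classical (heavy-ball) momentum, with None disabling the respective term; which optimizer the library actually builds from this record is not stated here, and sgd_step is a hypothetical helper. The record is mirrored locally so the snippet compiles on its own.

```ocaml
(* Local mirror of the train_config record documented above. *)
type train_config = {
  learning_rate : float;
  batch_size : int;
  num_epochs : int;
  weight_decay : float option;
  momentum : float option;
}

(* One in-place update of [weights] and [velocity] given [grads].
   Assumptions: L2 weight decay is added to the gradient, momentum is the
   classical heavy-ball form, and [None] disables the corresponding term.
   All three arrays must have the same length. *)
let sgd_step cfg ~weights ~velocity ~grads =
  let wd = Option.value cfg.weight_decay ~default:0.0 in
  let mu = Option.value cfg.momentum ~default:0.0 in
  Array.iteri
    (fun i g ->
      let g = g +. (wd *. weights.(i)) in
      velocity.(i) <- (mu *. velocity.(i)) +. g;
      weights.(i) <- weights.(i) -. (cfg.learning_rate *. velocity.(i)))
    grads
```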

val default_train_config : train_config

Default training configuration for MNIST

val accuracy : predictions:(float, 'a) Rune.t -> labels:(int, Rune.int32_elt) Rune.t -> float

Compute accuracy

accuracy ~predictions ~labels computes classification accuracy. Predictions should be logits of shape [batch_size; num_classes]; labels should be class indices of shape [batch_size].
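
The computation described above can be sketched with plain arrays in place of Rune tensors: take the argmax of each row of logits and compare it with the label. This is an illustration of the metric, not the library's implementation; returning 0.0 for an empty batch is an assumption made here.

```ocaml
(* Index of the largest element in a non-empty row of logits.
   Ties resolve to the first maximum. *)
let argmax (row : float array) : int =
  let best = ref 0 in
  Array.iteri (fun i x -> if x > row.(!best) then best := i) row;
  !best

(* Fraction of rows whose argmax equals the label.
   [predictions] and [labels] must have the same length; an empty batch
   yields 0.0 (an assumption of this sketch). *)
let accuracy ~(predictions : float array array) ~(labels : int array) : float =
  let n = Array.length labels in
  if n = 0 then 0.0
  else begin
    let correct = ref 0 in
    Array.iteri
      (fun i row -> if argmax row = labels.(i) then incr correct)
      predictions;
    float_of_int !correct /. float_of_int n
  end
```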
