package kaun
Module Kaun_models.LeNet
LeNet-5: Classic convolutional neural network for handwritten digit recognition.
Reference: LeCun et al., 1998, "Gradient-Based Learning Applied to Document Recognition".
One of the first successful CNNs, originally designed for handwritten digit classification and evaluated on MNIST.
type config = {
  num_classes : int;
      (* Number of output classes (default: 10 for digits) *)
  input_channels : int;
      (* Number of input channels (default: 1 for grayscale) *)
  input_size : int * int;
      (* Input image size (default: 32x32) *)
  activation : [ `tanh | `relu | `sigmoid ];
      (* Activation function (original used tanh) *)
  dropout_rate : float option;
      (* Optional dropout rate for regularization *)
}
Configuration for LeNet-5 model
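As an illustration of how the fields fit together, the sketch below mirrors the documented record in a local type so that it compiles without Kaun. The `baseline` value, the choice of `` `tanh `` and `None` as starting values, and the `activate` helper are assumptions made for this example, not part of the library.

```ocaml
(* Local mirror of the documented record, for illustration only.
   Real code would use Kaun_models.LeNet.config instead. *)
type config = {
  num_classes : int;
  input_channels : int;
  input_size : int * int;
  activation : [ `tanh | `relu | `sigmoid ];
  dropout_rate : float option;
}

(* Hypothetical baseline built from the documented defaults.
   `tanh and None are assumptions; the docs do not state them as defaults. *)
let baseline =
  { num_classes = 10;
    input_channels = 1;
    input_size = (32, 32);
    activation = `tanh;
    dropout_rate = None }

(* Record update syntax changes only the fields that differ. *)
let rgb_relu = { baseline with input_channels = 3; activation = `relu }

(* Scalar form of each activation choice (hypothetical helper). *)
let activate (a : [ `tanh | `relu | `sigmoid ]) (x : float) : float =
  match a with
  | `tanh -> tanh x
  | `relu -> Float.max 0.0 x
  | `sigmoid -> 1.0 /. (1.0 +. exp (-.x))
```

The record update keeps every field of `baseline` that is not listed, which is a convenient way to derive a variant configuration from an existing one.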
LeNet-5 model instance
Create a new LeNet-5 model
create ?config () creates a new LeNet-5 model.
Architecture:
- Conv1: 6 filters of 5x5
- Pool1: 2x2 average pooling
- Conv2: 16 filters of 5x5
- Pool2: 2x2 average pooling
- FC1: 120 units
- FC2: 84 units
- Output: num_classes units
The original paper used average pooling and tanh activation, but modern implementations often use max pooling and ReLU.
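The layer list above fixes how the spatial size shrinks through the network. The sketch below traces one side of the image through the four convolution and pooling stages. It assumes stride-1 convolutions without padding and non-overlapping 2x2 pooling, as in the classic LeNet-5; the page does not say whether this implementation pads its inputs. The helper names are hypothetical.

```ocaml
(* Output side length of a stride-1 convolution without padding (assumed). *)
let conv_out ~kernel size = size - kernel + 1

(* Output side length of a non-overlapping pooling window. *)
let pool_out ~window size = size / window

(* Trace one spatial dimension through Conv1, Pool1, Conv2, Pool2. *)
let trace size =
  let c1 = conv_out ~kernel:5 size in
  let p1 = pool_out ~window:2 c1 in
  let c2 = conv_out ~kernel:5 p1 in
  let p2 = pool_out ~window:2 c2 in
  [ c1; p1; c2; p2 ]

let () =
  let sides = trace 32 in
  List.iter (Printf.printf "%d ") sides;
  (* Conv2 produces 16 feature maps, each p2 x p2 after Pool2. *)
  let p2 = List.nth sides 3 in
  Printf.printf "\nflattened features: %d\n" (16 * p2 * p2)
```

Under these assumptions a 32x32 input shrinks to 28, 14, 10 and finally 5 pixels per side, so FC1 receives 16 * 5 * 5 = 400 features.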
Example:
(* rngs and input are assumed to be defined by the caller. *)
let model = LeNet.create ~config:LeNet.mnist_config () in
let params = Kaun.init model ~rngs ~dtype:Float32 in
let output = Kaun.apply model params ~training:false input in
output
Create model for MNIST
for_mnist () creates a LeNet-5 model configured for MNIST digits. Equivalent to create ~config:mnist_config ().
Create model for CIFAR-10
for_cifar10 () creates a LeNet-5 model configured for CIFAR-10. Uses 3 input channels for RGB images.
val forward :
model:t ->
params:'a Kaun.params ->
training:bool ->
input:(float, 'a) Rune.t ->
(float, 'a) Rune.t
Forward pass through the model
forward ~model ~params ~training ~input performs a forward pass.
val extract_features :
model:t ->
params:'a Kaun.params ->
input:(float, 'a) Rune.t ->
(float, 'a) Rune.t
Extract feature representations
extract_features ~model ~params ~input extracts feature representations from the second-to-last layer (FC2), useful for transfer learning or visualization. Returns features of shape [batch_size; 84].
Model statistics
num_parameters params returns the total number of parameters in the model.
parameter_breakdown params returns a detailed breakdown of parameters by layer.
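For a sense of the numbers involved, the per-layer counts can be worked out by hand from the documented architecture. The sketch below is plain OCaml and does not call Kaun. It assumes biases on every layer, parameter-free pooling, fully connected Conv2 filters, and 400 flattened features entering FC1 (a 32x32 input with unpadded 5x5 convolutions). The helper names are hypothetical, and the values the library reports may differ if its layers are configured differently.

```ocaml
(* Weights plus biases of a convolution layer. *)
let conv_params ~in_ch ~out_ch ~kernel =
  (in_ch * kernel * kernel * out_ch) + out_ch

(* Weights plus biases of a fully connected layer. *)
let dense_params ~inputs ~outputs = (inputs * outputs) + outputs

(* Per-layer counts for the documented architecture (pooling has none here). *)
let breakdown ~input_channels ~num_classes ~flattened =
  [ ("conv1", conv_params ~in_ch:input_channels ~out_ch:6 ~kernel:5);
    ("conv2", conv_params ~in_ch:6 ~out_ch:16 ~kernel:5);
    ("fc1", dense_params ~inputs:flattened ~outputs:120);
    ("fc2", dense_params ~inputs:120 ~outputs:84);
    ("output", dense_params ~inputs:84 ~outputs:num_classes) ]

let total layers = List.fold_left (fun acc (_, n) -> acc + n) 0 layers

let () =
  let layers = breakdown ~input_channels:1 ~num_classes:10 ~flattened:400 in
  List.iter (fun (name, n) -> Printf.printf "%-6s %d\n" name n) layers;
  Printf.printf "total  %d\n" (total layers)
```

With one input channel and ten classes this gives 156, 2416, 48120, 10164 and 850 parameters per layer, 61706 in total. FC1 accounts for most of them.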
Training Helpers
type train_config = {
learning_rate : float;
batch_size : int;
num_epochs : int;
weight_decay : float option;
momentum : float option;
}
Training configuration
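To show how these hyperparameters commonly interact, the sketch below applies one step of SGD with momentum and L2 weight decay to a flat parameter array. This is a textbook formulation used purely for illustration; the page does not say which optimizer or update rule the library applies. The record is a local mirror of the documented type, and `sgd_step` is a hypothetical helper.

```ocaml
(* Local mirror of the documented record, for illustration only. *)
type train_config = {
  learning_rate : float;
  batch_size : int;
  num_epochs : int;
  weight_decay : float option;
  momentum : float option;
}

(* One SGD step on a flat parameter array (assumed update rule):
     v' = momentum * v + grad + weight_decay * w
     w' = w - learning_rate * v'
   Absent options are treated as 0.0. *)
let sgd_step cfg ~weights ~grads ~velocity =
  let wd = Option.value cfg.weight_decay ~default:0.0 in
  let mu = Option.value cfg.momentum ~default:0.0 in
  let n = Array.length weights in
  let velocity' =
    Array.init n (fun i ->
        (mu *. velocity.(i)) +. grads.(i) +. (wd *. weights.(i)))
  in
  let weights' =
    Array.init n (fun i -> weights.(i) -. (cfg.learning_rate *. velocity'.(i)))
  in
  (weights', velocity')
```

The function returns fresh arrays instead of mutating its inputs, so the previous weights stay available for comparison or rollback.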
Default training configuration for MNIST
Compute accuracy
accuracy ~predictions ~labels computes classification accuracy. Predictions should be logits of shape [batch_size; num_classes], labels should be class indices of shape [batch_size].
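The computation described here amounts to taking the arg-max of each row of logits and comparing it with the label. The sketch below shows this on plain OCaml arrays in place of Rune tensors. `argmax` and `accuracy_of_arrays` are hypothetical stand-ins and are not the library's implementation.

```ocaml
(* Index of the largest element of a non-empty row (first one wins on ties). *)
let argmax (row : float array) : int =
  let best = ref 0 in
  Array.iteri (fun i x -> if x > row.(!best) then best := i) row;
  !best

(* Fraction of rows whose arg-max class equals the corresponding label.
   predictions: one row of logits per sample; labels: one class index each. *)
let accuracy_of_arrays ~(predictions : float array array) ~(labels : int array) :
    float =
  let n = Array.length labels in
  if n = 0 || Array.length predictions <> n then
    invalid_arg "accuracy_of_arrays: empty batch or size mismatch";
  let correct = ref 0 in
  Array.iteri
    (fun i row -> if argmax row = labels.(i) then incr correct)
    predictions;
  float_of_int !correct /. float_of_int n
```

Logits can be compared directly without applying softmax first, because softmax preserves the ordering of the values in each row.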