package kaun
Module Kaun_models.LeNet
LeNet-5: Classic convolutional neural network for handwritten digit recognition.
LeCun et al., 1998: "Gradient-Based Learning Applied to Document Recognition". One of the first successful CNNs, it was originally designed for MNIST digit classification.
type config = {
  num_classes : int;  (* Number of output classes (default: 10 for digits) *)
  input_channels : int;  (* Number of input channels (default: 1 for grayscale) *)
  input_size : int * int;  (* Input image size (default: 32x32) *)
  activation : [ `tanh | `relu | `sigmoid ];  (* Activation function (original used tanh) *)
  dropout_rate : float option;  (* Optional dropout rate for regularization *)
}
Configuration for LeNet-5 model
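The activation field selects one of three element-wise nonlinearities. The sketch below shows what each variant computes on a single float; it is a scalar illustration under the assumption that Kaun applies the same formulas element-wise to tensors, and `apply_activation` is a hypothetical helper, not part of this module.

```ocaml
(* Hypothetical scalar versions of the three [activation] choices.
   Kaun would apply these element-wise to tensors; this sketch only
   shows the per-value formula. *)
type activation = [ `tanh | `relu | `sigmoid ]

let apply_activation (a : activation) (x : float) : float =
  match a with
  | `tanh -> tanh x                          (* squashes into (-1, 1) *)
  | `relu -> Float.max 0.0 x                 (* 0 for negatives, identity otherwise *)
  | `sigmoid -> 1.0 /. (1.0 +. exp (-. x))   (* squashes into (0, 1) *)

let () =
  List.iter
    (fun (name, a) ->
      Printf.printf "%s(-1.0) = %.4f\n" name (apply_activation a (-1.0)))
    [ ("tanh", `tanh); ("relu", `relu); ("sigmoid", `sigmoid) ]
```

ReLU zeroes negative inputs, while tanh and sigmoid saturate for large magnitudes, which is one reason modern variants tend to prefer ReLU.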
LeNet-5 model instance
Create a new LeNet-5 model
create ?config () creates a new LeNet-5 model.
Architecture:
- Conv1: 6 filters of 5x5
- Pool1: 2x2 average pooling
- Conv2: 16 filters of 5x5
- Pool2: 2x2 average pooling
- FC1: 120 units
- FC2: 84 units
- Output: num_classes units
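The layer sizes above determine how the spatial resolution shrinks before the fully connected layers. A minimal sketch of that arithmetic, assuming 5x5 convolutions with stride 1 and no padding and 2x2 pooling with stride 2 (the conventional LeNet-5 setup; the source does not state padding or stride). The helpers `conv5`, `pool2`, and `trace` are hypothetical.

```ocaml
(* Sketch: trace spatial sizes through the layers listed above.
   Assumptions: 5x5 convolutions with stride 1 and no padding,
   2x2 pooling with stride 2. *)

(* Output size of a 5x5 "valid" convolution with stride 1. *)
let conv5 (h, w) = (h - 4, w - 4)

(* Output size of 2x2 pooling with stride 2 (odd sizes round down). *)
let pool2 (h, w) = (h / 2, w / 2)

let trace input_size =
  let c1 = conv5 input_size in        (* Conv1: 6 feature maps *)
  let p1 = pool2 c1 in                (* Pool1 *)
  let c2 = conv5 p1 in                (* Conv2: 16 feature maps *)
  let p2 = pool2 c2 in                (* Pool2 *)
  let flat = 16 * fst p2 * snd p2 in  (* values flattened into FC1 *)
  (c1, p1, c2, p2, flat)

let () =
  let (c1h, c1w), (p1h, p1w), (c2h, c2w), (p2h, p2w), flat = trace (32, 32) in
  Printf.printf
    "Conv1 %dx%dx6 -> Pool1 %dx%dx6 -> Conv2 %dx%dx16 -> Pool2 %dx%dx16 -> FC1 input %d\n"
    c1h c1w p1h p1w c2h c2w p2h p2w flat
```

For the default 32x32 input this prints `Conv1 28x28x6 -> Pool1 14x14x6 -> Conv2 10x10x16 -> Pool2 5x5x16 -> FC1 input 400`. Under the same assumptions a raw 28x28 MNIST image would reach FC1 as 4x4x16 = 256 values, so it would need padding to 32x32 to match these layer sizes.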
The original paper used average pooling and tanh activation, but modern implementations often use max pooling and ReLU.
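To make the average-versus-max distinction concrete, here is a sketch of non-overlapping 2x2 pooling (stride 2) on a single feature map stored as a `float array array`. It is an illustration only: the helper names are hypothetical, and it models plain average pooling without the trainable coefficients of the original paper's subsampling layers.

```ocaml
(* 2x2 pooling with stride 2 over one feature map (array of rows).
   [reduce] combines the four values in each window; a trailing odd
   row or column is dropped. *)
let pool2x2 (reduce : float -> float -> float -> float -> float)
    (m : float array array) : float array array =
  let h = Array.length m / 2 in
  let w = if Array.length m = 0 then 0 else Array.length m.(0) / 2 in
  Array.init h (fun i ->
      Array.init w (fun j ->
          reduce
            m.(2 * i).(2 * j) m.(2 * i).(2 * j + 1)
            m.(2 * i + 1).(2 * j) m.(2 * i + 1).(2 * j + 1)))

(* Average pooling: mean of each 2x2 window. *)
let avg_pool = pool2x2 (fun a b c d -> (a +. b +. c +. d) /. 4.0)

(* Max pooling: largest value in each 2x2 window. *)
let max_pool = pool2x2 (fun a b c d -> Float.max (Float.max a b) (Float.max c d))
```

On the 2x4 map `[| [|1.; 2.; 3.; 4.|]; [|5.; 6.; 7.; 8.|] |]`, average pooling yields `[| [|3.5; 5.5|] |]` and max pooling yields `[| [|6.; 8.|] |]`: both halve each spatial dimension, but max pooling keeps only the strongest activation per window.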
Example (assuming rngs and input are already bound):
  let model = LeNet.create ~config:LeNet.mnist_config () in
  let params = Kaun.init model ~rngs ~dtype:Float32 in
  let output = Kaun.apply model params ~training:false input in
Create model for MNIST
for_mnist () creates a LeNet-5 model configured for MNIST digits. Equivalent to create ~config:mnist_config ().
Create model for CIFAR-10
for_cifar10 () creates a LeNet-5 model configured for CIFAR-10. Uses 3 input channels for RGB images.
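The two helpers differ mainly in the configuration they pass to create. The sketch below mirrors the documented config record locally and builds MNIST-like and CIFAR-10-like values from the defaults stated in the field comments; `mnist_like` and `cifar10_like` are hypothetical names, and the actual mnist_config (and whatever for_cifar10 uses) may set other fields differently.

```ocaml
(* Local mirror of the documented [config] record, for illustration. *)
type config = {
  num_classes : int;
  input_channels : int;
  input_size : int * int;
  activation : [ `tanh | `relu | `sigmoid ];
  dropout_rate : float option;
}

(* Built from the defaults listed in the field comments; the library's
   own [mnist_config] may differ. *)
let mnist_like = {
  num_classes = 10;
  input_channels = 1;      (* grayscale *)
  input_size = (32, 32);
  activation = `tanh;
  dropout_rate = None;
}

(* CIFAR-10 also has 10 classes and 32x32 images, but RGB: 3 channels.
   Record update syntax copies every other field unchanged. *)
let cifar10_like = { mnist_like with input_channels = 3 }
```

Using `{ r with ... }` keeps the two configurations in sync: only the channel count changes, so the layer sizes after Conv1 are identical for both datasets.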
val forward :
  model:t ->
  params:'a Kaun.params ->
  training:bool ->
  input:(float, 'a) Rune.t ->
  (float, 'a) Rune.t
Forward pass through the model
forward ~model ~params ~training ~input performs a forward pass.
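The source does not say exactly what the training flag changes, but a common reason for such a flag is dropout (compare the dropout_rate config field), which should be active only during training. A minimal sketch of inverted dropout on a plain `float array`, assuming that convention; `dropout` is a hypothetical helper, not Kaun's implementation.

```ocaml
(* Inverted dropout: during training, zero each value with probability
   [rate] and scale survivors by 1 / (1 - rate) so the expected value is
   unchanged; at inference ([training = false]) return the input as-is. *)
let dropout ~rate ~training (xs : float array) : float array =
  if (not training) || rate <= 0.0 then xs
  else begin
    let keep = 1.0 -. rate in
    Array.map
      (fun x -> if Random.float 1.0 < keep then x /. keep else 0.0)
      xs
  end

let () =
  Random.self_init ();
  let xs = Array.make 8 1.0 in
  let show label a =
    print_string label;
    Array.iter (Printf.printf " %.1f") a;
    print_newline ()
  in
  show "train:" (dropout ~rate:0.5 ~training:true xs);
  show "eval: " (dropout ~rate:0.5 ~training:false xs)
```

Because survivors are rescaled during training, no correction is needed at inference, which is why passing `~training:false` gives a deterministic output.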
val extract_features :
  model:t ->
  params:'a Kaun.params ->
  input:(float, 'a) Rune.t ->
  (float, 'a) Rune.t
Extract feature representations
extract_features ~model ~params ~input extracts feature representations from the second-to-last layer (FC2), useful for transfer learning or visualization. Returns features of shape [batch_size; 84].
Model statistics
num_parameters params returns the total number of parameters in the model.
parameter_breakdown params returns a detailed breakdown of parameters by layer.
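As a cross-check for num_parameters and parameter_breakdown, the per-layer counts can be computed by hand as weights plus biases. This sketch assumes every Conv2 filter sees all 6 Conv1 maps (the common modern variant rather than the paper's partial connectivity), that pooling layers have no parameters, and that the input is 32x32 so 16 × 5 × 5 = 400 values reach FC1. The helpers are hypothetical.

```ocaml
(* Parameter counts as weights + biases. Assumptions: full Conv1->Conv2
   connectivity, parameter-free pooling, 32x32 input (400 values into FC1). *)
let conv ~in_ch ~out_ch ~k = (out_ch * in_ch * k * k) + out_ch
let dense ~inputs ~outputs = (inputs * outputs) + outputs

let breakdown ~input_channels ~num_classes =
  [
    ("conv1", conv ~in_ch:input_channels ~out_ch:6 ~k:5);
    ("conv2", conv ~in_ch:6 ~out_ch:16 ~k:5);
    ("fc1", dense ~inputs:400 ~outputs:120);
    ("fc2", dense ~inputs:120 ~outputs:84);
    ("output", dense ~inputs:84 ~outputs:num_classes);
  ]

(* Sum the per-layer counts. *)
let total layers = List.fold_left (fun acc (_, n) -> acc + n) 0 layers

let () =
  let layers = breakdown ~input_channels:1 ~num_classes:10 in
  List.iter (fun (name, n) -> Printf.printf "%-7s %6d\n" name n) layers;
  Printf.printf "total   %6d\n" (total layers)
```

For a 1-channel, 10-class model this gives 156 + 2,416 + 48,120 + 10,164 + 850 = 61,706 parameters; with 3 input channels only Conv1 grows (to 456), for 62,006 in total. The figures reported by num_parameters may differ if the implementation departs from these assumptions. FC1 dominates the count.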
Training Helpers
type train_config = {
  learning_rate : float;
  batch_size : int;
  num_epochs : int;
  weight_decay : float option;
  momentum : float option;
}
Training configuration
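The source does not say which optimizer consumes these fields. Assuming plain SGD, the sketch below shows how learning_rate, the optional weight_decay (added to the gradient as an L2 term), and the optional momentum (classic velocity form) would combine in one update of a single scalar weight. It also shows how batch_size sets the number of steps per epoch. `sgd_step` and `steps_per_epoch` are hypothetical helpers.

```ocaml
(* Local mirror of the documented [train_config] record. *)
type train_config = {
  learning_rate : float;
  batch_size : int;
  num_epochs : int;
  weight_decay : float option;
  momentum : float option;
}

(* One SGD step for a scalar weight, returning (new_velocity, new_weight).
   Assumed update: g = grad + wd * w;  v = mu * v + g;  w = w - lr * v. *)
let sgd_step cfg ~velocity ~weight ~grad =
  let g =
    match cfg.weight_decay with
    | Some wd -> grad +. (wd *. weight)  (* L2 weight decay *)
    | None -> grad
  in
  let v =
    match cfg.momentum with
    | Some mu -> (mu *. velocity) +. g   (* accumulate velocity *)
    | None -> g                          (* plain SGD *)
  in
  (v, weight -. (cfg.learning_rate *. v))

(* Optimizer steps per epoch, counting a final partial batch. *)
let steps_per_epoch cfg ~num_examples =
  (num_examples + cfg.batch_size - 1) / cfg.batch_size
```

For example, with `batch_size = 64`, the 60,000 MNIST training images take 938 steps per epoch; total training is then `num_epochs` times that.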
Default training configuration for MNIST
Compute accuracy
accuracy ~predictions ~labels computes classification accuracy. Predictions should be logits of shape [batch_size; num_classes]; labels should be class indices of shape [batch_size].
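Classification accuracy from logits reduces to an argmax per row followed by a comparison with the label. A sketch over plain OCaml arrays, with a `float array array` standing in for the [batch_size; num_classes] tensor and an `int array` for the labels; these helpers are hypothetical and only illustrate the computation, not the library's Rune-based implementation.

```ocaml
(* Index of the largest logit in a row; ties go to the lowest index. *)
let argmax (row : float array) : int =
  let best = ref 0 in
  Array.iteri (fun i x -> if x > row.(!best) then best := i) row;
  !best

(* Fraction of examples whose predicted class (argmax of the logits)
   equals the label. *)
let accuracy ~(predictions : float array array) ~(labels : int array) : float =
  let n = Array.length labels in
  if Array.length predictions <> n then
    invalid_arg "accuracy: predictions and labels differ in batch size";
  if n = 0 then 0.0
  else begin
    let correct = ref 0 in
    Array.iteri
      (fun i row -> if argmax row = labels.(i) then incr correct)
      predictions;
    float_of_int !correct /. float_of_int n
  end
```

With logits `[| [|0.1; 2.0; 0.3|]; [|1.0; 0.2; 0.1|] |]` the predicted classes are 1 and 0, so labels `[|1; 2|]` give an accuracy of 0.5. Softmax is unnecessary here because it does not change which logit is largest.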