Module Kaun_models.LeNet
LeNet-5: Classic convolutional neural network for handwritten digit recognition.
LeCun et al., 1998, "Gradient-Based Learning Applied to Document Recognition". One of the first successful CNNs, originally designed for MNIST digit classification.
type config = {
  num_classes : int;  (* Number of output classes (default: 10 for digits) *)
  input_channels : int;  (* Number of input channels (default: 1 for grayscale) *)
  input_size : int * int;  (* Input image size (default: 32x32) *)
  activation : [ `tanh | `relu | `sigmoid ];  (* Activation function (original used tanh) *)
  dropout_rate : float option;  (* Optional dropout rate for regularization *)
}
Configuration for LeNet-5 model.
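The three choices of the activation field correspond to the standard scalar functions below. This is a minimal sketch on plain floats, not on Rune tensors; apply_activation is a hypothetical helper and does not claim to be how Kaun_models applies the activation internally.

```ocaml
(* Hypothetical helper mirroring the [activation] field of [config].
   It works on a single float, not on a Rune tensor. *)
let apply_activation (act : [ `tanh | `relu | `sigmoid ]) (x : float) : float =
  match act with
  | `tanh -> tanh x                          (* squashes to (-1, 1); used by the 1998 paper *)
  | `relu -> Float.max 0.0 x                 (* zero for negatives, identity otherwise *)
  | `sigmoid -> 1.0 /. (1.0 +. Float.exp (-.x))  (* squashes to (0, 1) *)

let () =
  List.iter
    (fun x -> Printf.printf "relu %.1f = %.1f\n" x (apply_activation `relu x))
    [ -1.0; 2.0 ]
```

ReLU is unbounded above and has a constant gradient for positive inputs, which is why modern variants often prefer it over tanh or sigmoid.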
LeNet-5 model instance
Create a new LeNet-5 model
create ?config () creates a new LeNet-5 model.
Architecture:
- Conv1: 6 filters of 5x5
- Pool1: 2x2 average pooling
- Conv2: 16 filters of 5x5
- Pool2: 2x2 average pooling
- FC1: 120 units
- FC2: 84 units
- Output: num_classes units
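With the default 32x32 input, the layer list above implies the feature-map sizes traced below. This minimal OCaml sketch assumes stride-1 convolutions without padding and non-overlapping 2x2 pooling, as in the classic LeNet-5 description; this page does not state the padding Kaun_models uses, and conv_out, pool_out and lenet_shapes are hypothetical helpers.

```ocaml
(* Assumption: "valid" convolutions (stride 1, no padding). *)
let conv_out ~kernel n = n - kernel + 1

(* Assumption: non-overlapping pooling windows. *)
let pool_out ~window n = n / window

(* Returns the (height, width, channels) after each conv/pool stage,
   the flattened size fed into FC1, and the widths of FC1, FC2, output. *)
let lenet_shapes ~height ~width ~num_classes =
  let h1 = conv_out ~kernel:5 height and w1 = conv_out ~kernel:5 width in
  let h2 = pool_out ~window:2 h1 and w2 = pool_out ~window:2 w1 in
  let h3 = conv_out ~kernel:5 h2 and w3 = conv_out ~kernel:5 w2 in
  let h4 = pool_out ~window:2 h3 and w4 = pool_out ~window:2 w3 in
  let maps =
    [ ("conv1", (h1, w1, 6)); ("pool1", (h2, w2, 6));
      ("conv2", (h3, w3, 16)); ("pool2", (h4, w4, 16)) ]
  in
  let flat = h4 * w4 * 16 in
  (maps, flat, [ 120; 84; num_classes ])

let () =
  let maps, flat, dense = lenet_shapes ~height:32 ~width:32 ~num_classes:10 in
  List.iter (fun (name, (h, w, c)) -> Printf.printf "%s: %dx%dx%d\n" name h w c) maps;
  Printf.printf "flatten: %d -> dense %s\n" flat
    (String.concat " -> " (List.map string_of_int dense))
```

This prints 28x28x6, 14x14x6, 10x10x16 and 5x5x16 for the four stages, then `flatten: 400 -> dense 120 -> 84 -> 10`. Under the same assumptions a raw 28x28 image would flatten to 4x4x16 = 256 features instead, so the input width of FC1 depends on input_size. The FC2 width of 84 matches the feature size returned by extract_features.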
The original paper used average pooling and tanh activation, but modern implementations often use max pooling and ReLU.
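The difference between the two pooling styles is only in how each 2x2 window is reduced. A minimal sketch on a single-channel feature map stored as a plain float array array (not a Rune tensor); pool2x2, avg_pool and max_pool are hypothetical names.

```ocaml
(* Non-overlapping 2x2 pooling over one channel. [combine] reduces the
   four values of each window; odd trailing rows/columns are dropped. *)
let pool2x2 combine (m : float array array) : float array array =
  let h = Array.length m / 2 in
  let w = if h = 0 then 0 else Array.length m.(0) / 2 in
  Array.init h (fun i ->
      Array.init w (fun j ->
          combine
            m.(2 * i).(2 * j) m.(2 * i).(2 * j + 1)
            m.(2 * i + 1).(2 * j) m.(2 * i + 1).(2 * j + 1)))

(* Average pooling, as in the original LeNet-5 description. *)
let avg_pool = pool2x2 (fun a b c d -> (a +. b +. c +. d) /. 4.0)

(* Max pooling, the common modern substitute. *)
let max_pool = pool2x2 (fun a b c d -> Float.max (Float.max a b) (Float.max c d))
```

For the 2x4 map `[| [|1.; 2.; 3.; 4.|]; [|5.; 6.; 7.; 8.|] |]`, avg_pool yields `[| [|3.5; 5.5|] |]` while max_pool yields `[| [|6.; 8.|] |]`.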
Example:
let model = LeNet.create ~config:LeNet.mnist_config () in
let params = Kaun.init model ~rngs ~dtype:Float32 in
let output = Kaun.apply model params ~training:false input in
Create model for MNIST
for_mnist () creates a LeNet-5 model configured for MNIST digits. Equivalent to create ~config:mnist_config ().
Create model for CIFAR-10
for_cifar10 () creates a LeNet-5 model configured for CIFAR-10. Uses 3 input channels for RGB images.
val forward :
  model:t ->
  params:Kaun.params ->
  training:bool ->
  input:(float, 'a) Rune.t ->
  (float, 'a) Rune.t
Forward pass through the model
forward ~model ~params ~training ~input performs a forward pass.
val extract_features :
  model:t ->
  params:Kaun.params ->
  input:(float, 'a) Rune.t ->
  (float, 'a) Rune.t
Extract feature representations
extract_features ~model ~params ~input extracts feature representations from the second-to-last layer (FC2), useful for transfer learning or visualization. Returns features of shape [batch_size; 84].
Model statistics
num_parameters params returns the total number of parameters in the model.
parameter_breakdown params returns a detailed breakdown of parameters by layer.
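The per-layer counts can be checked by hand from the architecture. This sketch assumes every layer has a bias, conv2 is densely connected to all 6 input maps (the 1998 paper used a sparse connection table there), pooling has no parameters, and the input is 32x32; the figures Kaun's num_parameters reports may differ if its implementation departs from these assumptions. conv_params, dense_params, lenet_breakdown and total are hypothetical helpers.

```ocaml
(* Weights plus biases of a conv layer with [out_c] k x k filters over [in_c] channels. *)
let conv_params ~in_c ~out_c ~k = (out_c * in_c * k * k) + out_c

(* Weights plus biases of a fully connected layer. *)
let dense_params ~inp ~out = (inp * out) + out

(* Assumes a 32x32 input, so conv2/pool2 leave 16 maps of 5x5 for FC1. *)
let lenet_breakdown ~input_channels ~num_classes =
  [ ("conv1", conv_params ~in_c:input_channels ~out_c:6 ~k:5);
    ("conv2", conv_params ~in_c:6 ~out_c:16 ~k:5);
    ("fc1", dense_params ~inp:(16 * 5 * 5) ~out:120);
    ("fc2", dense_params ~inp:120 ~out:84);
    ("output", dense_params ~inp:84 ~out:num_classes) ]

let total breakdown = List.fold_left (fun acc (_, n) -> acc + n) 0 breakdown

let () =
  let b = lenet_breakdown ~input_channels:1 ~num_classes:10 in
  List.iter (fun (name, n) -> Printf.printf "%-6s %d\n" name n) b;
  Printf.printf "total  %d\n" (total b)
```

For the MNIST setting this gives 156, 2416, 48120, 10164 and 850, a total of 61706; FC1 alone holds most of the parameters. With 3 input channels (CIFAR-10) only conv1 grows, to 456, for a total of 62006.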
Training helpers
type train_config = {
  learning_rate : float;
  batch_size : int;
  num_epochs : int;
  weight_decay : float option;
  momentum : float option;
}
Training configuration
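This page does not say which optimizer consumes these fields. One common reading, assumed here and not confirmed by the source, is plain SGD with optional classical momentum and L2 weight decay added to the gradient. The sketch redeclares the record locally so it compiles on its own; sgd_step and steps_per_epoch are hypothetical helpers operating on a flat float array.

```ocaml
(* Local mirror of [train_config] so this sketch is self-contained. *)
type train_config = {
  learning_rate : float;
  batch_size : int;
  num_epochs : int;
  weight_decay : float option;
  momentum : float option;
}

(* One SGD step, assuming the update v <- mu * v + (g + wd * p); p <- p - lr * v.
   [params] and [velocity] are updated in place. *)
let sgd_step cfg ~params ~grads ~velocity =
  let wd = Option.value cfg.weight_decay ~default:0.0 in
  let mu = Option.value cfg.momentum ~default:0.0 in
  Array.iteri
    (fun i p ->
      let g = grads.(i) +. (wd *. p) in
      velocity.(i) <- (mu *. velocity.(i)) +. g;
      params.(i) <- p -. (cfg.learning_rate *. velocity.(i)))
    params

(* Optimizer steps per epoch, counting a final partial batch. *)
let steps_per_epoch cfg ~num_examples =
  (num_examples + cfg.batch_size - 1) / cfg.batch_size
```

With momentum = None the step reduces to plain gradient descent. For 60000 MNIST training images and batch_size = 64, steps_per_epoch gives 938, the last batch holding only 32 examples.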
Default training configuration for MNIST
Compute accuracy
accuracy ~predictions ~labels computes classification accuracy. predictions should be logits of shape [batch_size; num_classes], and labels should be class indices of shape [batch_size].
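Accuracy on logits only needs the index of the largest logit per row, since softmax preserves ordering. A minimal sketch on plain OCaml arrays rather than Rune tensors; argmax and this accuracy function are hypothetical stand-ins, not the module's implementation.

```ocaml
(* Index of the largest value in a row of logits; the first index wins on ties. *)
let argmax (row : float array) : int =
  let best = ref 0 in
  Array.iteri (fun i x -> if x > row.(!best) then best := i) row;
  !best

(* Fraction of rows whose predicted class equals the label. *)
let accuracy ~(predictions : float array array) ~(labels : int array) : float =
  let n = Array.length labels in
  if Array.length predictions <> n then invalid_arg "accuracy: length mismatch";
  if n = 0 then 0.0
  else begin
    let correct = ref 0 in
    Array.iteri (fun i row -> if argmax row = labels.(i) then incr correct) predictions;
    float_of_int !correct /. float_of_int n
  end
```

For logits whose row-wise argmaxes are 1, 0, 2, 0 and labels [|1; 0; 2; 2|], three of four predictions match, so the result is 0.75.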