package kaun

Module Kaun_models.Bert

BERT: Bidirectional Encoder Representations from Transformers.

Devlin et al., 2018: "BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding"

A transformer-based model that uses bidirectional self-attention to build token representations conditioned on context from both the left and the right.

Configuration

type config = {
  vocab_size : int;  (** Size of vocabulary *)
  hidden_size : int;  (** Hidden dimension (d_model) *)
  num_hidden_layers : int;  (** Number of transformer encoder layers *)
  num_attention_heads : int;  (** Number of attention heads *)
  intermediate_size : int;  (** FFN intermediate dimension (typically 4 * hidden_size) *)
  hidden_act : [ `gelu | `relu | `swish | `gelu_new ];  (** Activation function *)
  hidden_dropout_prob : float;  (** Dropout probability for hidden layers *)
  attention_probs_dropout_prob : float;  (** Dropout for attention probabilities *)
  max_position_embeddings : int;  (** Maximum sequence length *)
  type_vocab_size : int;  (** Token type vocabulary size (for segment embeddings) *)
  layer_norm_eps : float;  (** Layer normalization epsilon *)
  pad_token_id : int;  (** Padding token ID *)
  position_embedding_type : [ `absolute | `relative ];  (** Type of position embeddings *)
  use_cache : bool;  (** Whether to cache key/values *)
  classifier_dropout : float option;  (** Dropout for classification head *)
}

BERT model configuration

val default_config : config

Standard BERT configurations

val bert_base_uncased : config

BERT Base: 12 layers, 768 hidden, 12 heads, 110M parameters

val bert_large_uncased : config

BERT Large: 24 layers, 1024 hidden, 16 heads, 340M parameters

val bert_base_cased : config

Same as base_uncased but preserves case information

val bert_base_multilingual : config

BERT Base for 104 languages
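
Any of these presets can be adjusted with OCaml record-update syntax; for example, a smaller variant (a sketch; the derived values are illustrative):

  let bert_mini : config =
    { bert_base_uncased with num_hidden_layers = 6; num_attention_heads = 8 }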

Model Components

val embeddings : config:config -> unit -> Kaun.module_

BERT embeddings combining token, position, and segment embeddings

val pooler : hidden_size:int -> unit -> Kaun.module_

BERT pooler: a dense + tanh projection of the [CLS] token representation
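
These components can be instantiated standalone, e.g. to reuse BERT's embedding layer elsewhere (a minimal sketch using only the signatures above):

  let emb = embeddings ~config:bert_base_uncased () in
  let pool = pooler ~hidden_size:bert_base_uncased.hidden_size () in
  ignore (emb, pool)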

type 'a output = {
  last_hidden_state : (float, 'a) Rune.t;  (** Sequence of hidden states at the last layer [batch_size; seq_len; hidden_size] *)
  pooler_output : (float, 'a) Rune.t option;  (** Pooled CLS token representation [batch_size; hidden_size] *)
  hidden_states : (float, 'a) Rune.t list option;  (** Hidden states from all layers if output_hidden_states=true *)
  attentions : (float, 'a) Rune.t list option;  (** Attention weights from all layers if output_attentions=true *)
}

Model outputs
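
Downstream code pattern-matches on the optional fields; for instance:

  let cls_features (out : _ output) =
    match out.pooler_output with
    | Some pooled -> pooled  (* [batch_size; hidden_size] *)
    | None -> failwith "model was created with add_pooling_layer:false"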

type 'a bert = {
  model : Kaun.module_;
  params : 'a Kaun.params;
  config : config;
  dtype : (float, 'a) Rune.dtype;
}

Unified BERT model type

type inputs = {
  input_ids : (int32, Rune.int32_elt) Rune.t;
  attention_mask : (int32, Rune.int32_elt) Rune.t;
  token_type_ids : (int32, Rune.int32_elt) Rune.t option;
  position_ids : (int32, Rune.int32_elt) Rune.t option;
}

Input tensors for BERT
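
A minimal construction sketch; Rune.create is assumed here to follow the Nx-style (dtype, shape, data) signature, and the token IDs are illustrative bert-base-uncased IDs (101 = [CLS], 102 = [SEP]):

  let input_ids =
    Rune.create Rune.int32 [| 1; 4 |] [| 101l; 7592l; 2088l; 102l |]
  in
  let attention_mask = Rune.create Rune.int32 [| 1; 4 |] [| 1l; 1l; 1l; 1l |] in
  let inputs =
    { input_ids; attention_mask; token_type_ids = None; position_ids = None }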

val create : ?config:config -> ?add_pooling_layer:bool -> unit -> Kaun.module_

Create a new BERT model

create ?config ?add_pooling_layer () creates a new BERT model.

  • parameter config

    Model configuration (default: bert_base_uncased)

  • parameter add_pooling_layer

    Whether to add pooling layer for CLS token (default: true)
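
For example, an encoder without the pooling head, as used for token-level tasks (a sketch using only this signature):

  let model = create ~config:bert_large_uncased ~add_pooling_layer:false ()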

val from_pretrained : ?model_id:string -> ?revision:Kaun_huggingface.revision -> ?cache_config:Kaun_huggingface.Config.t -> dtype:(float, 'a) Rune.dtype -> unit -> 'a bert

Load pretrained BERT from HuggingFace

from_pretrained ?model_id ~dtype () loads pretrained BERT.

Default model_id is "bert-base-uncased". The dtype argument is required; Rune.float32 in the examples below is assumed to mirror Nx's dtype values. Weights are downloaded and cached via Kaun_huggingface. Returns a unified bert record with model, params, and config.

Example:

  let bert = from_pretrained ~dtype:Rune.float32 () in
  (* Or with options: *)
  let bert =
    from_pretrained ~model_id:"bert-base-multilingual-cased"
      ~dtype:Rune.float32 ()
val forward : 'a bert -> inputs -> ?training:bool -> ?output_hidden_states:bool -> ?output_attentions:bool -> unit -> 'a output

Forward pass through BERT

forward bert inputs () performs a forward pass.

Fields of inputs:

  • input_ids

    Token IDs [batch_size; seq_len]

  • attention_mask

    Mask for padding tokens (1 for real tokens, 0 for padding)

  • token_type_ids

    Segment IDs for sentence pairs (0 or 1)

  • position_ids

    Custom position IDs (default: 0..seq_len-1)

  • parameter training

    Whether in training mode (affects dropout)

  • parameter output_hidden_states

    Whether to return all hidden states

  • parameter output_attentions

    Whether to return attention weights
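
Example (continuing the inputs sketch above; Rune.float32 is an assumed dtype value):

  let bert = from_pretrained ~dtype:Rune.float32 () in
  let out = forward bert inputs ~output_hidden_states:true () in
  (* out.last_hidden_state : [batch_size; seq_len; hidden_size] *)
  let n = match out.hidden_states with Some hs -> List.length hs | None -> 0 in
  Printf.printf "hidden-state sets: %d\n" n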

Task-Specific Heads

module For_masked_lm : sig ... end

BERT for masked language modeling

module For_sequence_classification : sig ... end

BERT for sequence classification

module For_token_classification : sig ... end

BERT for token classification (NER, POS tagging)

Tokenization

module Tokenizer : sig ... end

Utilities

val create_attention_mask : input_ids:(int32, Rune.int32_elt) Rune.t -> pad_token_id:int -> dtype:(float, 'a) Rune.dtype -> (float, 'a) Rune.t

Create attention mask from input IDs

Creates an attention mask with 1.0 for real tokens and 0.0 for padding positions.
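
For instance, with the input_ids built earlier:

  let mask =
    create_attention_mask ~input_ids ~pad_token_id:0 ~dtype:Rune.float32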

val get_embeddings : model:Kaun.module_ -> params:'a Kaun.params -> input_ids:(int32, Rune.int32_elt) Rune.t -> ?attention_mask:(int32, Rune.int32_elt) Rune.t -> layer_index:int -> unit -> (float, 'a) Rune.t

Get BERT embeddings for text analysis

Extract embeddings from a specific layer (0 = embeddings, 1..n = encoder layers)
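
For example, token embeddings from the last encoder layer of a 12-layer model (reusing the bert and input_ids sketches above):

  let feats =
    get_embeddings ~model:bert.model ~params:bert.params ~input_ids
      ~layer_index:12 ()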

val num_parameters : 'a Kaun.params -> int

Count total parameters in the model

val parameter_stats : 'a Kaun.params -> string

Get human-readable parameter statistics
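
For example, with the bert loaded above:

  Printf.printf "total parameters: %d\n" (num_parameters bert.params);
  print_endline (parameter_stats bert.params)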

BERT Configuration Parsing

val parse_bert_config : Yojson.Safe.t -> config

Parse BERT configuration from HuggingFace JSON format
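
A sketch with a truncated, illustrative config.json (whether omitted fields fall back to defaults is an assumption):

  let config =
    parse_bert_config
      (Yojson.Safe.from_string
         {|{ "vocab_size": 30522, "hidden_size": 768,
             "num_hidden_layers": 12, "num_attention_heads": 12,
             "intermediate_size": 3072 }|})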

Common Model Configurations

val load_bert_base_uncased : dtype:(float, 'a) Rune.dtype -> unit -> 'a bert

Load BERT Base Uncased (110M parameters)

val load_bert_large_uncased : dtype:(float, 'a) Rune.dtype -> unit -> 'a bert

Load BERT Large Uncased (340M parameters)

val load_bert_base_cased : dtype:(float, 'a) Rune.dtype -> unit -> 'a bert

Load BERT Base Cased (110M parameters)

val load_bert_base_multilingual_cased : dtype:(float, 'a) Rune.dtype -> unit -> 'a bert

Load Multilingual BERT Base Cased (110M parameters, 104 languages)
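
These are presumably thin wrappers over from_pretrained with the matching model_id; for example:

  let bert = load_bert_base_multilingual_cased ~dtype:Rune.float32 () in
  assert (bert.config.num_hidden_layers = 12)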