Module Kaun_models.Bert
BERT: Bidirectional Encoder Representations from Transformers.
Devlin et al., 2018: "BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding"
A transformer-based model that uses bidirectional self-attention to understand context from both directions, revolutionizing NLP tasks.
Configuration
type config = {
  vocab_size : int;  (** Size of vocabulary *)
  num_attention_heads : int;  (** Number of attention heads *)
  intermediate_size : int;  (** FFN intermediate dimension (typically 4 * hidden_size) *)
  attention_probs_dropout_prob : float;  (** Dropout for attention probabilities *)
  max_position_embeddings : int;  (** Maximum sequence length *)
  type_vocab_size : int;  (** Token type vocabulary size (for segment embeddings) *)
  layer_norm_eps : float;  (** Layer normalization epsilon *)
  pad_token_id : int;  (** Padding token ID *)
  position_embedding_type : [ `absolute | `relative ];  (** Type of position embeddings *)
  use_cache : bool;  (** Whether to cache key/values *)
  classifier_dropout : float option;  (** Dropout for classification head *)
}
BERT model configuration
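For illustration only, a minimal sketch of reading a few of the documented fields from a config value; the describe helper is hypothetical and not part of the library:
let describe (cfg : config) =
  Printf.printf "vocab=%d heads=%d max_len=%d\n"
    cfg.vocab_size cfg.num_attention_heads cfg.max_position_embeddings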
Model Components
BERT embeddings combining token, position, and segment embeddings
type 'a output = {
  pooler_output : (float, 'a) Rune.t option;
  (** Pooled [CLS] token representation [batch_size; hidden_size] *)
  attentions : (float, 'a) Rune.t list option;
  (** Attention weights from all layers if output_attentions = true *)
}
Model outputs
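As a sketch (assuming out is an output value obtained from forward), the optional per-layer attention weights can be inspected like this:
let n_attention_layers (out : _ output) =
  match out.attentions with
  | Some layers -> List.length layers
  | None -> 0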
type 'a bert = {
  model : Kaun.module_;
  params : 'a Kaun.params;
  config : config;
  dtype : (float, 'a) Rune.dtype;
}
Unified BERT model type
type inputs = {
  input_ids : (int32, Rune.int32_elt) Rune.t;
  attention_mask : (int32, Rune.int32_elt) Rune.t;
  token_type_ids : (int32, Rune.int32_elt) Rune.t option;
  position_ids : (int32, Rune.int32_elt) Rune.t option;
}
Input tensors for BERT
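A minimal sketch of building an inputs value by hand; the Rune.create constructor and the token ID values are assumptions made for illustration, not part of this module's documented API:
let inputs =
  (* hypothetical token IDs for one 4-token sequence; Rune.create is assumed *)
  let input_ids = Rune.create Rune.int32 [| 1; 4 |] [| 101l; 7592l; 2088l; 102l |] in
  let attention_mask = Rune.create Rune.int32 [| 1; 4 |] [| 1l; 1l; 1l; 1l |] in
  { input_ids; attention_mask; token_type_ids = None; position_ids = None }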
Create a new BERT model
create ?config ?add_pooling_layer () creates a new BERT model.
val from_pretrained :
  ?model_id:string ->
  ?revision:Kaun_huggingface.revision ->
  ?cache_config:Kaun_huggingface.Config.t ->
  dtype:(float, 'a) Rune.dtype ->
  unit ->
  'a bert
Load pretrained BERT from HuggingFace
from_pretrained ?model_id ~dtype () loads pretrained BERT.
Default model_id is "bert-base-uncased" and the model runs on the CPU; the required dtype argument selects the parameter precision (e.g. Float32). Returns a unified bert record with model, params, and config.
Example:
let bert = Bert.from_pretrained ~dtype:Rune.float32 () in
(* Or with options: *)
let bert =
  Bert.from_pretrained ~model_id:"bert-base-multilingual-cased" ~dtype:Rune.float32 ()
val forward :
  'a bert ->
  inputs ->
  ?training:bool ->
  ?output_hidden_states:bool ->
  ?output_attentions:bool ->
  unit ->
  'a output
Forward pass through BERT
forward bert inputs () performs a forward pass.
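Example (a sketch that reuses the inputs value built earlier and assumes Rune.float32 names the 32-bit float dtype):
let () =
  let bert = from_pretrained ~dtype:Rune.float32 () in
  let out = forward bert inputs () in
  match out.pooler_output with
  | Some pooled -> ignore pooled  (* [batch_size; hidden_size] pooled representation *)
  | None -> ()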
Task-Specific Heads
BERT for masked language modeling
BERT for sequence classification
BERT for token classification (NER, POS tagging)
Tokenization
Utilities
val create_attention_mask :
  input_ids:(int32, Rune.int32_elt) Rune.t ->
  pad_token_id:int ->
  dtype:(float, 'a) Rune.dtype ->
  (float, 'a) Rune.t
Create attention mask from input IDs
Creates an attention mask with 1.0 for real tokens and 0.0 for padding tokens.
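Example (sketch; input_ids comes from the earlier example, and the pad token ID of 0 is an assumption mirroring pad_token_id):
let mask = create_attention_mask ~input_ids ~pad_token_id:0 ~dtype:Rune.float32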
val get_embeddings :
  model:Kaun.module_ ->
  params:'a Kaun.params ->
  input_ids:(int32, Rune.int32_elt) Rune.t ->
  ?attention_mask:(int32, Rune.int32_elt) Rune.t ->
  layer_index:int ->
  unit ->
  (float, 'a) Rune.t
Get BERT embeddings for text analysis
Extract embeddings from a specific layer (0 = embeddings, 1..n = encoder layers)
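Example (sketch; bert and input_ids come from the earlier examples, and layer 8 is an arbitrary choice):
let hidden =
  get_embeddings ~model:bert.model ~params:bert.params ~input_ids ~layer_index:8 ()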
Count total parameters in the model
Get human-readable parameter statistics
BERT Configuration Parsing
Parse BERT configuration from HuggingFace JSON format
Common Model Configurations
Load BERT Base Uncased (110M parameters)
Load BERT Large Uncased (340M parameters)
Load BERT Base Cased (110M parameters)
Load Multilingual BERT Base Cased (110M parameters, 104 languages)