package kaun

Module Kaun_models.GPT2

GPT-2: Generative Pre-trained Transformer 2 for causal language modeling.

Radford et al., 2019: "Language Models are Unsupervised Multitask Learners"

A transformer-based autoregressive language model that uses causal self-attention for text generation and language understanding tasks.

Configuration

type config = {
  vocab_size : int;  (** Size of vocabulary *)
  n_positions : int;  (** Maximum sequence length *)
  n_embd : int;  (** Hidden dimension (d_model) *)
  n_layer : int;  (** Number of transformer decoder layers *)
  n_head : int;  (** Number of attention heads *)
  n_inner : int option;  (** FFN intermediate dimension (defaults to 4 * n_embd) *)
  activation_function : [ `gelu | `relu | `swish | `gelu_new ];  (** Activation function *)
  resid_pdrop : float;  (** Dropout probability for residual connections *)
  embd_pdrop : float;  (** Dropout probability for embeddings *)
  attn_pdrop : float;  (** Dropout probability for attention weights *)
  layer_norm_epsilon : float;  (** Layer normalization epsilon *)
  initializer_range : float;  (** Standard deviation for weight initialization *)
  scale_attn_weights : bool;  (** Whether to scale attention weights *)
  use_cache : bool;  (** Whether to cache key/values *)
  scale_attn_by_inverse_layer_idx : bool;  (** Scale attention by 1/sqrt(layer_idx) *)
  reorder_and_upcast_attn : bool;  (** Reorder and upcast attention *)
  bos_token_id : int option;  (** Beginning-of-sequence token ID *)
  eos_token_id : int option;  (** End-of-sequence token ID *)
  pad_token_id : int option;  (** Padding token ID *)
}

GPT-2 model configuration

val default_config : config

Standard GPT-2 configurations

val gpt2_small : config

GPT-2 Small: 12 layers, 768 hidden, 12 heads, 124M parameters

val gpt2_medium : config

GPT-2 Medium: 24 layers, 1024 hidden, 16 heads, 355M parameters

val gpt2_large : config

GPT-2 Large: 36 layers, 1280 hidden, 20 heads, 774M parameters

val gpt2_xl : config

GPT-2 XL: 48 layers, 1600 hidden, 25 heads, 1.5B parameters
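
Because config is an ordinary record, a custom variant can be derived from any preset with OCaml's record-update syntax. A minimal sketch using the fields declared above (the chosen values are illustrative only):

  (* A small experimental config based on GPT-2 Small. *)
  let tiny = { gpt2_small with n_layer = 4; n_head = 4; n_embd = 256 }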

Model Components

val embeddings : config:config -> unit -> Kaun.module_

GPT-2 embeddings combining token and position embeddings

type 'a output = {
  last_hidden_state : (float, 'a) Rune.t;
    (** Hidden states at the last layer, shape [batch_size; seq_len; hidden_size] *)
  hidden_states : (float, 'a) Rune.t list option;
    (** Hidden states from all layers if output_hidden_states = true *)
  attentions : (float, 'a) Rune.t list option;
    (** Attention weights from all layers if output_attentions = true *)
}

Model outputs

type 'a gpt2 = {
  model : Kaun.module_;
  params : 'a Kaun.params;
  config : config;
  dtype : (float, 'a) Rune.dtype;
}

Unified GPT-2 model type

type inputs = {
  input_ids : (int32, Rune.int32_elt) Rune.t;
  attention_mask : (int32, Rune.int32_elt) Rune.t option;
  position_ids : (int32, Rune.int32_elt) Rune.t option;
}

Input tensors for GPT-2

val create : ?config:config -> unit -> Kaun.module_

Create a new GPT-2 model

create ?config () creates a new GPT-2 model.

  • parameter config

    Model configuration (default: gpt2_small)
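
A minimal sketch of create with and without an explicit configuration (assuming the module is in scope as GPT2):

  let default_model = GPT2.create () in               (* uses gpt2_small *)
  let medium_model = GPT2.create ~config:GPT2.gpt2_medium ()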

val from_pretrained : ?model_id:string -> ?revision:Kaun_huggingface.revision -> ?cache_config:Kaun_huggingface.Config.t -> dtype:(float, 'a) Rune.dtype -> unit -> 'a gpt2

Load pretrained GPT-2 from HuggingFace

from_pretrained ?model_id ~dtype () loads pretrained GPT-2.

Default model_id is "gpt2". Returns a unified gpt2 record with the model, params, and config.

Example:

  let gpt2 = GPT2.from_pretrained ~dtype:Rune.float32 () in
  (* Or with options: *)
  let gpt2 = GPT2.from_pretrained ~model_id:"gpt2-medium" ~dtype:Rune.float32 ()
val forward : 'a gpt2 -> inputs -> ?training:bool -> ?output_hidden_states:bool -> ?output_attentions:bool -> unit -> 'a output

Forward pass through GPT-2

forward gpt2 inputs ?training ?output_hidden_states ?output_attentions () performs a forward pass.

  • parameter input_ids

    Token IDs of shape [batch_size; seq_len]

  • parameter attention_mask

    Mask for padding tokens (1 for real tokens, 0 for padding)

  • parameter position_ids

    Custom position IDs (default: 0..seq_len-1)

  • parameter training

    Whether in training mode (affects dropout)

  • parameter output_hidden_states

    Whether to return all hidden states

  • parameter output_attentions

    Whether to return attention weights
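
Putting inputs, forward, and output together, a minimal sketch. Rune.create and the example token IDs are assumptions here; the exact tensor-construction API may differ:

  let gpt2 = GPT2.from_pretrained ~dtype:Rune.float32 () in
  (* Assumed constructor: dtype, shape, flat data. *)
  let input_ids = Rune.create Rune.int32 [| 1; 3 |] [| 15496l; 11l; 995l |] in
  let inputs = { GPT2.input_ids; attention_mask = None; position_ids = None } in
  let out = GPT2.forward gpt2 inputs () in
  (* out.last_hidden_state has shape [batch_size; seq_len; hidden_size]. *)
  out.GPT2.last_hidden_state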

Task-Specific Heads

module For_causal_lm : sig ... end

GPT-2 for causal language modeling

Tokenization

module Tokenizer : sig ... end

Utilities

val num_parameters : 'a Kaun.params -> int

Count total parameters in the model

val parameter_stats : 'a Kaun.params -> string

Get human-readable parameter statistics
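
Both utilities take the params field of a loaded model; a sketch:

  let gpt2 = GPT2.from_pretrained ~dtype:Rune.float32 () in
  Printf.printf "total parameters: %d\n" (GPT2.num_parameters gpt2.params);
  print_endline (GPT2.parameter_stats gpt2.params)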

GPT-2 Configuration Parsing

val parse_gpt2_config : Yojson.Safe.t -> config

Parse GPT-2 configuration from HuggingFace JSON format
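
This pairs naturally with Yojson's file reader; a sketch, assuming a local config.json in HuggingFace format:

  let config = GPT2.parse_gpt2_config (Yojson.Safe.from_file "config.json")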

Common Model Configurations

val load_gpt2_small : dtype:(float, 'a) Rune.dtype -> unit -> 'a gpt2

Load GPT-2 Small (124M parameters)

val load_gpt2_medium : dtype:(float, 'a) Rune.dtype -> unit -> 'a gpt2

Load GPT-2 Medium (355M parameters)

val load_gpt2_large : dtype:(float, 'a) Rune.dtype -> unit -> 'a gpt2

Load GPT-2 Large (774M parameters)

val load_gpt2_xl : dtype:(float, 'a) Rune.dtype -> unit -> 'a gpt2

Load GPT-2 XL (1.5B parameters)
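
These loaders are presumably shorthand for from_pretrained with the matching model_id; a sketch:

  let small = GPT2.load_gpt2_small ~dtype:Rune.float32 () in
  let xl = GPT2.load_gpt2_xl ~dtype:Rune.float32 ()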