package kaun

  1. Overview
  2. Docs

Module Attention.Multi_head

type config = {
  embed_dim : int;
  num_heads : int;
  num_kv_heads : int option;
  head_dim : int option;
  dropout : float;
  use_qk_norm : bool;
  attn_logits_soft_cap : float option;
  query_pre_attn_scalar : float option;
}
val make_config : embed_dim:int -> num_heads:int -> ?num_kv_heads:int -> ?head_dim:int -> ?dropout:float -> ?use_qk_norm:bool -> ?attn_logits_soft_cap:float -> ?query_pre_attn_scalar:float -> unit -> config
type params = Ptree.t
val init : config -> rngs:Rune.Rng.key -> dtype:(float, 'layout) Rune.dtype -> params
val apply : ?rngs:Rune.Rng.key -> ?attention_mask:Rune.bool_t -> config -> params -> training:bool -> query:(float, 'layout) Rune.t -> key:(float, 'layout) Rune.t -> value:(float, 'layout) Rune.t -> (float, 'layout) Rune.t