package kaun

Flax-inspired neural network library for OCaml


Sources

raven-1.0.0.alpha2.tbz
sha256=93abc49d075a1754442ccf495645bc4fdc83e4c66391ec8aca8fa15d2b4f44d2
sha512=5eb958c51f30ae46abded4c96f48d1825f79c7ce03f975f9a6237cdfed0d62c0b4a0774296694def391573d849d1f869919c49008acffca95946b818ad325f6f


Module Bert.For_token_classification

BERT for token classification (NER, POS tagging)

val create : ?config:config -> num_labels:int -> unit -> Kaun.module_

val forward :
  model:Kaun.module_ ->
  params:Kaun.params ->
  compute_dtype:(float, 'a) Rune.dtype ->
  input_ids:(int32, Rune.int32_elt) Rune.t ->
  ?config:config ->
  ?attention_mask:(int32, Rune.int32_elt) Rune.t ->
  ?token_type_ids:(int32, Rune.int32_elt) Rune.t ->
  ?labels:(int32, Rune.int32_elt) Rune.t ->
  training:bool ->
  ?rngs:Rune.Rng.key ->
  unit ->
  (float, 'a) Rune.t * (float, 'a) Rune.t option

Returns (logits, loss), where logits has shape [batch_size; seq_len; num_labels] and loss is returned as an option.
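This page does not say how the loss is computed. The following library-free sketch uses plain OCaml arrays, not Rune tensors, and shows one common way to consume logits of shape [batch_size; seq_len; num_labels]. Taking the argmax over the last axis yields one label id per token, and a token-level cross-entropy averaged over unmasked positions is a typical training loss. The helper names predict_labels and masked_cross_entropy, and the convention that a mask value of 1 marks a real token, are hypothetical. They illustrate the idea and are not Kaun's API or a description of what forward does internally.

```ocaml
(* Hypothetical helpers: logits are modelled as float array array array
   indexed [batch][position][label], mirroring the documented shape
   [batch_size; seq_len; num_labels]. Not part of Kaun. *)

(* Index of the largest score in one row; ties keep the first index. *)
let argmax (row : float array) : int =
  let best = ref 0 in
  Array.iteri (fun i x -> if x > row.(!best) then best := i) row;
  !best

(* Predicted label id for every token, shape [batch_size][seq_len]. *)
let predict_labels (logits : float array array array) : int array array =
  Array.map (fun seq -> Array.map argmax seq) logits

(* Numerically stable log-softmax of one row of label scores. *)
let log_softmax (row : float array) : float array =
  let m = Array.fold_left Float.max neg_infinity row in
  let sum_exp = Array.fold_left (fun acc x -> acc +. exp (x -. m)) 0.0 row in
  let log_z = m +. log sum_exp in
  Array.map (fun x -> x -. log_z) row

(* Mean negative log-likelihood over positions where mask = 1
   (assumed convention); returns 0.0 if every position is masked. *)
let masked_cross_entropy ~(logits : float array array array)
    ~(labels : int array array) ~(mask : int array array) : float =
  let total = ref 0.0 and count = ref 0 in
  Array.iteri
    (fun b seq ->
      Array.iteri
        (fun t row ->
          if mask.(b).(t) = 1 then begin
            let lp = log_softmax row in
            total := !total -. lp.(labels.(b).(t));
            incr count
          end)
        seq)
    logits;
  if !count = 0 then 0.0 else !total /. float_of_int !count
```

For example, with batch_size = 1, seq_len = 3 and num_labels = 2, the logits [| [| [| 2.0; 0.0 |]; [| 0.0; 3.0 |]; [| 1.0; 1.0 |] |] |] decode to [| [| 0; 1; 0 |] |]. Masking the third position (a padding token, say) excludes it from the averaged loss.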