package kaun

  1. Overview
  2. Docs

Module Bert.For_token_classification

BERT for token classification (NER, POS tagging)

val create : ?config:config -> num_labels:int -> unit -> Kaun.module_
val forward : model:Kaun.module_ -> params:'a Kaun.params -> input_ids:(int32, Rune.int32_elt) Rune.t -> ?attention_mask:(int32, Rune.int32_elt) Rune.t -> ?token_type_ids:(int32, Rune.int32_elt) Rune.t -> ?labels:(int32, Rune.int32_elt) Rune.t -> training:bool -> unit -> (float, 'a) Rune.t * (float, 'a) Rune.t option

Returns (logits, loss), where logits has shape [batch_size; seq_len; num_labels] and loss is optional (its type is an option).