package kaun

Flax-inspired neural network library for OCaml

Sources

raven-1.0.0.alpha1.tbz
sha256=8e277ed56615d388bc69c4333e43d1acd112b5f2d5d352e2453aef223ff59867
sha512=369eda6df6b84b08f92c8957954d107058fb8d3d8374082e074b56f3a139351b3ae6e3a99f2d4a4a2930dd950fd609593467e502368a13ad6217b571382da28c


Module Bert.For_sequence_classification

BERT for sequence classification

val create : ?config:config -> num_labels:int -> unit -> Kaun.module_
val forward :
  model:Kaun.module_ ->
  params:'a Kaun.params ->
  input_ids:(int32, Rune.int32_elt) Rune.t ->
  ?attention_mask:(int32, Rune.int32_elt) Rune.t ->
  ?token_type_ids:(int32, Rune.int32_elt) Rune.t ->
  ?labels:(int32, Rune.int32_elt) Rune.t ->
  training:bool ->
  unit ->
  (float, 'a) Rune.t * (float, 'a) Rune.t option

Returns (logits, loss), where logits has shape [batch_size; num_labels] and loss is an optional tensor.
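The signatures above leave the input and output layout implicit. The sketch below is a hedged illustration that uses plain OCaml arrays in place of Rune tensors. It shows one way to build padded input_ids with a matching attention_mask, and one way to read the [batch_size; num_labels] logits: an argmax per row for predictions, and mean softmax cross-entropy as one plausible classification loss. The helper names pad_batch, argmax_rows and cross_entropy are hypothetical. So are the padding id 0 and the 1/0 mask convention. It is also an assumption that forward computes exactly this loss when ?labels is supplied; this page does not say so. Converting the arrays into Rune.t values is left out because that depends on Rune's tensor-construction API.

```ocaml
(* Hypothetical helpers, not part of Kaun: plain arrays stand in for
   Rune tensors so the sketch runs without the library. *)

(* Pad token-id sequences to the longest length in the batch and build the
   matching attention mask (assumed convention: 1 = real token, 0 = padding).
   pad_id = 0l is an assumption about the tokenizer's padding id. *)
let pad_batch ?(pad_id = 0l) (seqs : int32 list list) =
  let max_len = List.fold_left (fun m s -> max m (List.length s)) 0 seqs in
  let row s =
    let ids = Array.make max_len pad_id in
    let mask = Array.make max_len 0l in
    List.iteri (fun i t -> ids.(i) <- t; mask.(i) <- 1l) s;
    (ids, mask)
  in
  let rows = List.map row seqs in
  (Array.of_list (List.map fst rows), Array.of_list (List.map snd rows))

(* Predicted label for each example: index of the largest logit in its row
   of the [batch_size; num_labels] logits. *)
let argmax_rows (logits : float array array) =
  Array.map
    (fun row ->
      let best = ref 0 in
      Array.iteri (fun j v -> if v > row.(!best) then best := j) row;
      !best)
    logits

(* Mean softmax cross-entropy over the batch (an assumed loss, not
   necessarily Kaun's). log-sum-exp is shifted by the row maximum for
   numerical stability. *)
let cross_entropy (logits : float array array) (labels : int array) =
  let n = Array.length logits in
  let total = ref 0.0 in
  Array.iteri
    (fun i row ->
      let m = Array.fold_left max neg_infinity row in
      let lse =
        m +. log (Array.fold_left (fun acc v -> acc +. exp (v -. m)) 0.0 row)
      in
      total := !total +. (lse -. row.(labels.(i))))
    logits;
  !total /. float_of_int n
```

For example, pad_batch [[101l; 7592l; 102l]; [101l; 102l]] pads the second sequence with one 0l and masks that position with 0l. Passing a mask alongside padded ids is the usual way to keep padding tokens from influencing attention.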