package torch


A Long Short-Term Memory (LSTM) recurrent neural network.

type t
type state = [
  | `h_c of Tensor.t * Tensor.t
]
val create : Var_store.t -> input_dim:int -> hidden_size:int -> t

create vs ~input_dim ~hidden_size creates a new LSTM with input dimension input_dim and hidden size hidden_size. Its trainable parameters are registered in the variable store vs.
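To make the two size parameters concrete, here is a hypothetical helper (lstm_param_count is not part of this module) that counts the trainable parameters of a single LSTM layer. It assumes the common PyTorch-style layout: four gates stacked into an input-to-hidden matrix of shape (4 * hidden_size) * input_dim, a hidden-to-hidden matrix of shape (4 * hidden_size) * hidden_size, and two bias vectors of length 4 * hidden_size. The parameters that create actually registers in vs may be laid out differently.

```ocaml
(* Hypothetical helper, not part of the library: parameter count of a
   single-layer LSTM under an assumed PyTorch-style layout. *)
let lstm_param_count ~input_dim ~hidden_size =
  (* The four gates (input, forget, cell candidate, output) are stacked. *)
  let gate_size = 4 * hidden_size in
  let w_ih = gate_size * input_dim (* input-to-hidden weights *)
  and w_hh = gate_size * hidden_size (* hidden-to-hidden weights *)
  and biases = 2 * gate_size (* assumed separate input and hidden biases *) in
  w_ih + w_hh + biases

let () =
  (* 80 * 10 + 80 * 20 + 2 * 80 = 2560 *)
  Printf.printf "%d\n" (lstm_param_count ~input_dim:10 ~hidden_size:20)
```

Under this assumption the count grows linearly in input_dim but quadratically in hidden_size, because of the hidden-to-hidden matrix.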

val step : t -> state -> Tensor.t -> state

step t state input_ applies a single LSTM step to input_, starting from state, and returns the updated state.
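The following sketch illustrates what one step does to the `h_c pair, using the standard LSTM cell equations on plain float arrays instead of tensors. It is for illustration only, not the library's implementation: the params record, the gate order (input, forget, cell candidate, output) and the single combined bias are assumptions.

```ocaml
(* Hypothetical float-array LSTM cell. The state mirrors the documented
   `h_c pair: (hidden state h, cell state c). *)
type state = [ `h_c of float array * float array ]

(* Assumed layout: gates stacked in the order i, f, g, o. *)
type params = {
  w_ih : float array array; (* (4 * hidden) rows, input_dim columns *)
  w_hh : float array array; (* (4 * hidden) rows, hidden columns *)
  b : float array; (* length 4 * hidden *)
}

let sigmoid x = 1.0 /. (1.0 +. exp (-.x))

(* Matrix-vector product: each row of m is dotted with v. *)
let matvec m v =
  Array.map
    (fun row ->
      let acc = ref 0.0 in
      Array.iteri (fun j w -> acc := !acc +. (w *. v.(j))) row;
      !acc)
    m

let step p (`h_c (h, c) : state) (x : float array) : state =
  let n = Array.length h in
  let zx = matvec p.w_ih x and zh = matvec p.w_hh h in
  (* Pre-activations of all four gates, stacked into one vector. *)
  let z = Array.init (4 * n) (fun k -> zx.(k) +. zh.(k) +. p.b.(k)) in
  let gate k j = z.((k * n) + j) in
  (* New cell state: c' = f * c + i * g *)
  let c' =
    Array.init n (fun j ->
        let i = sigmoid (gate 0 j)
        and f = sigmoid (gate 1 j)
        and g = tanh (gate 2 j) in
        (f *. c.(j)) +. (i *. g))
  in
  (* New hidden state: h' = o * tanh c' *)
  let h' = Array.init n (fun j -> sigmoid (gate 3 j) *. tanh c'.(j)) in
  `h_c (h', c')

(* With all-zero weights every sigmoid gate is 0.5 and g = tanh 0 = 0,
   so c' = 0.5 * c and h' = 0.5 * tanh c'. *)
let zero_params =
  {
    w_ih = Array.make_matrix 4 1 0.0;
    w_hh = Array.make_matrix 4 1 0.0;
    b = Array.make 4 0.0;
  }

let example = step zero_params (`h_c ([| 0.0 |], [| 1.0 |])) [| 2.0 |]
```

The cell state c carries information across steps through the forget gate, while h is the per-step output.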

val seq : t -> Tensor.t -> is_training:bool -> Tensor.t * state

seq t inputs ~is_training runs the LSTM over a whole sequence, starting from a zero state, and returns the hidden states together with the final state. inputs must have shape batch_size * timesteps * input_dim; the returned output tensor then has shape batch_size * timesteps * hidden_size.
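A minimal usage sketch, assuming this module is reachable as Torch.Layer.Lstm (as in the ocaml-torch bindings) and that Var_store.create, Tensor.randn and Tensor.shape behave as in those bindings. It checks the output shape stated above; the sizes and the store name "lstm" are arbitrary.

```ocaml
open Torch

let batch_size = 4
let timesteps = 5
let input_dim = 3
let hidden_size = 8

let vs = Var_store.create ~name:"lstm" ()
let lstm = Layer.Lstm.create vs ~input_dim ~hidden_size

(* Random input of shape batch_size * timesteps * input_dim. *)
let inputs = Tensor.randn [ batch_size; timesteps; input_dim ]

(* Run the whole sequence from a zero state (inference only here). *)
let output, _final_state = Layer.Lstm.seq lstm inputs ~is_training:false

let output_shape = Tensor.shape output

let () =
  (* Expected per the documentation: batch_size x timesteps x hidden_size. *)
  output_shape |> List.map string_of_int |> String.concat "x" |> print_endline
```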

val zero_state : t -> batch_size:int -> state

zero_state t ~batch_size returns an all-zero initial state for a batch of batch_size sequences, suitable as the starting state for step.
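A sketch of unrolling the LSTM by hand with zero_state and step, which mirrors what seq does in a single call. It assumes the module is reachable as Torch.Layer.Lstm, that Tensor.select ~dim ~index extracts one timestep as a batch_size * input_dim tensor, and that the resulting hidden state h has shape batch_size * hidden_size; these are assumptions, not guarantees from this page.

```ocaml
open Torch

let batch_size = 2
let timesteps = 3
let input_dim = 4
let hidden_size = 6

let vs = Var_store.create ~name:"lstm_step" ()
let lstm = Layer.Lstm.create vs ~input_dim ~hidden_size
let inputs = Tensor.randn [ batch_size; timesteps; input_dim ]

(* Start from zero_state and apply step once per timestep. *)
let final_state =
  let state = ref (Layer.Lstm.zero_state lstm ~batch_size) in
  for t = 0 to timesteps - 1 do
    (* Timestep t along dimension 1: shape [batch_size; input_dim]. *)
    let x_t = Tensor.select inputs ~dim:1 ~index:t in
    state := Layer.Lstm.step lstm !state x_t
  done;
  !state

(* Destructure the `h_c pair to inspect the hidden state. *)
let h_shape =
  let (`h_c (h, _c)) = final_state in
  Tensor.shape h
```

Stepping manually is useful when each input depends on the previous output, for example when sampling a sequence one token at a time.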
