package nx

  1. Overview
  2. Docs
N-dimensional arrays for OCaml

Install

dune-project
 Dependency

Authors

Maintainers

Sources

raven-1.0.0.alpha3.tbz
sha256=96d35ce03dfbebd2313657273e24c2e2d20f9e6c7825b8518b69bd1d6ed5870f
sha512=90c5053731d4108f37c19430e45456063e872b04b8a1bbad064c356e1b18e69222de8bfcf4ec14757e71f18164ec6e4630ba770dbcb1291665de5418827d1465

doc/src/nx.io/packed_nx.ml.html

Source file packed_nx.ml

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
(*---------------------------------------------------------------------------
  Copyright (c) 2026 The Raven authors. All rights reserved.
  SPDX-License-Identifier: ISC
  ---------------------------------------------------------------------------*)

(** [strf fmt ...] formats its remaining arguments according to [fmt] and
    returns the resulting string. Short local alias for [Printf.sprintf]. *)
let strf fmt = Printf.sprintf fmt

(** An [Nx.t] array with its type parameters ['a] and ['b] hidden
    existentially by the [P] constructor. Arrays of different dtypes can
    therefore share one type. Use [to_typed] to recover a typed view. *)
type t = P : ('a, 'b) Nx.t -> t

(** Mutable hash table mapping string keys to packed arrays.
    NOTE(review): keys are presumably array names inside an on-disk archive;
    confirm against the callers in nx.io. *)
type archive = (string, t) Hashtbl.t

(** [err_dtype_mismatch ~expected ~got] builds the error message reported when
    a packed array's dtype differs from the requested one. [expected] and [got]
    are dtype names that are inserted verbatim. *)
let err_dtype_mismatch ~expected ~got =
  Printf.sprintf "dtype mismatch: expected %s, got %s" expected got

(** [to_typed dtype packed] returns the array stored in [packed] at type
    [(a, b) Nx.t], provided its runtime dtype equals [dtype]. A successful
    [Nx_core.Dtype.equal_witness] check supplies the type equality that makes
    the cast safe.
    @raise Failure if the stored dtype differs from [dtype]. The message names
    both dtypes. *)
let to_typed : type a b. (a, b) Nx.dtype -> t -> (a, b) Nx.t =
 fun dtype packed ->
  match packed with
  | P arr -> (
      let actual = Nx.dtype arr in
      match Nx_core.Dtype.equal_witness actual dtype with
      | None ->
          let expected = Nx_core.Dtype.to_string dtype in
          let got = Nx_core.Dtype.to_string actual in
          failwith (err_dtype_mismatch ~expected ~got)
      | Some Type.Equal ->
          (* The witness refines the hidden parameters to [a] and [b]. *)
          (arr : (a, b) Nx.t))

(** [packed_shape packed] returns [Nx.shape] of the array held in [packed],
    whatever its dtype. *)
let packed_shape = function P arr -> Nx.shape arr