package nx

  1. Overview
  2. Docs

Source file packed_nx.ml

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
(** A tensor whose element type and storage kind are hidden behind an
    existential. This lets tensors of different dtypes share one container
    (see [archive]). Use [convert_result_with_error] or one of the [as_*]
    accessors to recover a statically typed [Nx.t]. *)
type t = P : ('a, 'b) Nx.t -> t

(** A mutable table of packed tensors keyed by string. NOTE(review): the keys
    are presumably array/entry names (e.g. members of an .npz archive) —
    confirm against the loader that fills it. *)
type archive = (string, t) Hashtbl.t

(** [convert_result_with_error target_dtype packed] recovers the tensor
    stored in [packed] at the statically requested [target_dtype].

    No element data is converted. The tensor is returned as-is ([Ok]) when
    its runtime dtype maps to the same bigarray kind as [target_dtype].
    Otherwise the result is [Error Unsupported_dtype]. *)
let convert_result_with_error : type a b.
    (a, b) Nx.dtype -> t -> ((a, b) Nx.t, Error.t) result =
 fun target_dtype -> function
  | P tensor -> (
      let actual = Nx_core.Dtype.to_bigarray_ext_kind (Nx.dtype tensor) in
      let wanted = Nx_core.Dtype.to_bigarray_ext_kind target_dtype in
      (* An equality witness on the two kinds lets the type checker refine
         the existentially typed [tensor] to [(a, b) Nx.t]. *)
      match Npy.Eq.Kind.( === ) actual wanted with
      | None -> Error Unsupported_dtype
      | Some Npy.Eq.W -> Ok tensor)

(* Typed accessors. Each one unpacks a [t] at a fixed dtype. It returns
   [Error Unsupported_dtype] when the packed tensor holds a different dtype;
   see [convert_result_with_error]. They are defined by partial application.
   Every target dtype is concrete, so the resulting types contain no weak
   type variables. *)
let as_float16 = convert_result_with_error Nx.float16
let as_float32 = convert_result_with_error Nx.float32
let as_float64 = convert_result_with_error Nx.float64
let as_int8 = convert_result_with_error Nx.int8
let as_int16 = convert_result_with_error Nx.int16
let as_int32 = convert_result_with_error Nx.int32
let as_int64 = convert_result_with_error Nx.int64
let as_uint8 = convert_result_with_error Nx.uint8
let as_uint16 = convert_result_with_error Nx.uint16
let as_complex32 = convert_result_with_error Nx.complex32
let as_complex64 = convert_result_with_error Nx.complex64