package nx

  1. Overview
  2. Docs
N-dimensional arrays for OCaml

Install

dune-project
 Dependency

Authors

Maintainers

Sources

raven-1.0.0.alpha3.tbz
sha256=96d35ce03dfbebd2313657273e24c2e2d20f9e6c7825b8518b69bd1d6ed5870f
sha512=90c5053731d4108f37c19430e45456063e872b04b8a1bbad064c356e1b18e69222de8bfcf4ec14757e71f18164ec6e4630ba770dbcb1291665de5418827d1465

doc/src/nx.io/nx_npy.ml.html

Source file nx_npy.ml

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
(*---------------------------------------------------------------------------
  Copyright (c) 2026 The Raven authors. All rights reserved.
  SPDX-License-Identifier: ISC
  ---------------------------------------------------------------------------*)

open Error
open Packed_nx

(* Short alias for [Printf.sprintf], used to build error messages. *)
let strf fmt = Printf.sprintf fmt

(* Repack an array produced by the Npy reader as an Nx tensor: the genarray
   is converted to C layout and its dimensions become the Nx shape.
   NOTE(review): a plain Bigarray layout change reverses the dimensions
   without moving data, so a Fortran-layout input (fortran_order .npy files)
   would come out transposed — confirm what Nx_buffer.genarray_change_layout
   does here. *)
let npy_to_nx = function
  | Npy.P raw ->
      let c_ga = Nx_buffer.genarray_change_layout raw Bigarray.C_layout in
      let dims = Nx_buffer.genarray_dims c_ga in
      let buffer = Nx_buffer.of_genarray c_ga in
      P (Nx.of_buffer buffer ~shape:dims)

(* Uniform exception-to-result conversion: run [f] and map any exception it
   raises to an [Error] of the matching category. Parse problems become
   [Format_error], file/zip/system problems become [Io_error], and anything
   else becomes [Other] with the printed exception. *)
let wrap_exn f =
  try f () with
  | Npy.Read_error msg -> Error (Format_error msg)
  | Zip.Error (archive, entry, msg) ->
      (* camlzip's payload is (archive name, entry name, message); the entry
         name may be empty for archive-level failures, in which case only the
         archive is mentioned. *)
      let where =
        if String.equal entry "" then archive
        else strf "%s in %s" entry archive
      in
      Error (Io_error (strf "zip: %s: %s" where msg))
  | Unix.Unix_error (e, _, _) -> Error (Io_error (Unix.error_message e))
  | Sys_error msg -> Error (Io_error msg)
  | Failure msg -> Error (Format_error msg)
  | ex -> Error (Other (Printexc.to_string ex))

(* Refuse to clobber [path] when [overwrite] is false and the file already
   exists.
   @raise Sys_error "file already exists: <path>" in that case. [Sys_error]
   (rather than [Failure]) makes [wrap_exn] report the refusal as an
   [Io_error]; [Failure] is mapped to [Format_error], which misdescribes an
   existing destination file. *)
let check_overwrite overwrite path =
  if (not overwrite) && Sys.file_exists path then
    raise (Sys_error ("file already exists: " ^ path))

(* Npy *)

(* Read the .npy file at [path] into a packed Nx tensor; any exception
   raised while reading is turned into an [Error] by [wrap_exn]. *)
let load_npy path =
  wrap_exn (fun () ->
      let packed = Npy.read_copy path in
      Ok (npy_to_nx packed))

(* Write [arr] to [path] in .npy format. When [overwrite] is false and [path]
   already exists, nothing is written and an [Error] is returned (via
   [check_overwrite] and [wrap_exn]). *)
let save_npy ?(overwrite = true) path arr =
  wrap_exn (fun () ->
      check_overwrite overwrite path;
      let data = Nx.to_buffer arr in
      let dims = Nx.shape arr in
      let genarray = Nx_buffer.to_genarray data dims in
      Npy.write genarray path;
      Ok ())

(* Npz *)

(* Load every array of the .npz archive at [path] into a hash table keyed by
   the member names reported by [Npy.Npz.entries]. The archive is closed even
   if reading a member fails; exceptions are mapped to [Error] by [wrap_exn]. *)
let load_npz path =
  wrap_exn @@ fun () ->
  let zi = Npy.Npz.open_in path in
  Fun.protect ~finally:(fun () -> Npy.Npz.close_in zi) @@ fun () ->
  let entries = Npy.Npz.entries zi in
  let archive = Hashtbl.create (List.length entries) in
  List.iter
    (fun name ->
      (* [replace], not [add]: a name listed twice must end up as a single
         binding instead of shadowing a stale duplicate. *)
      Hashtbl.replace archive name (npy_to_nx (Npy.Npz.read zi name)))
    entries;
  Ok archive

(* Load the single array [name] from the .npz archive at [path].
   Returns [Error (Missing_entry name)] when the archive has no such member;
   other failures are mapped to [Error] by [wrap_exn]. The archive is always
   closed. *)
let load_npz_entry ~name path =
  wrap_exn @@ fun () ->
  let zi = Npy.Npz.open_in path in
  Fun.protect ~finally:(fun () -> Npy.Npz.close_in zi) @@ fun () ->
  (* Check membership explicitly instead of relying only on [Npz.read]
     raising [Not_found]. NOTE(review): upstream npy appears to report a
     missing member from [read] with [Invalid_argument], which would surface
     as [Other] rather than [Missing_entry] — confirm against the Npy used. *)
  if not (List.exists (String.equal name) (Npy.Npz.entries zi)) then
    Error (Missing_entry name)
  else
    match Npy.Npz.read zi name with
    | packed -> Ok (npy_to_nx packed)
    | exception Not_found -> Error (Missing_entry name)

(* Write each [(name, packed tensor)] of [items] as a member of a new .npz
   archive at [path]. When [overwrite] is false and [path] exists, nothing is
   written and an [Error] is returned. The archive is closed even if a write
   fails. *)
let save_npz ?(overwrite = true) path items =
  wrap_exn (fun () ->
      check_overwrite overwrite path;
      let zo = Npy.Npz.open_out path in
      let write_item (entry_name, P t) =
        let data = Nx.to_buffer t in
        Npy.Npz.write zo entry_name (Nx_buffer.to_genarray data (Nx.shape t))
      in
      Fun.protect
        ~finally:(fun () -> Npy.Npz.close_out zo)
        (fun () ->
          List.iter write_item items;
          Ok ()))