package nx

  1. Overview
  2. Docs

Source file nx_io.ml

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
open Bigarray_ext
open Error

(* Type definitions *)

(** An Nx array whose two type parameters are hidden by the existential
    constructor [P], so arrays of different dtypes share a single type.
    Re-exported from [Packed_nx.t]; the type equation keeps the two types
    interchangeable. *)
type packed_nx = Packed_nx.t = P : ('a, 'b) Nx.t -> packed_nx

(** A string-keyed table of packed arrays. NOTE(review): keys are presumably
    the entry/member names of multi-array files (npz, safetensors) — confirm
    against Nx_npy / Nx_safetensors. *)
type archive = (string, packed_nx) Hashtbl.t

module Safe = struct
  (** Errors returned by the functions of this module. Re-exported from
      [Error.t]; the type equation keeps the constructors interchangeable
      with those of [Error]. *)
  type error = Error.t =
    | Io_error of string
    | Format_error of string
    | Unsupported_dtype
    | Unsupported_shape
    | Missing_entry of string
    | Other of string

  (* Image dimensions *)

  (** Layout of an image array: [`Gray (h, w)] for a rank-2 array and
      [`Color (h, w, c)] for a rank-3 array (height, width, channels). *)
  type nx_dims = [ `Gray of int * int | `Color of int * int * int ]

  (** [get_nx_dims arr] classifies [arr] by rank as a single-plane (rank 2)
      or multi-channel (rank 3) image. Any other rank is reported through
      [fail_msg] with the offending rank and shape. *)
  let get_nx_dims arr : nx_dims =
    match Nx.shape arr with
    | [| h; w |] -> `Gray (h, w)
    | [| h; w; c |] -> `Color (h, w, c)
    | s ->
        fail_msg "Invalid nx dimensions: expected 2 or 3, got %d (%s)"
          (Array.length s)
          (Array.to_list s |> List.map string_of_int |> String.concat "x")

  (** [load_image ?grayscale path] decodes the image file at [path] with
      [Stb_image.load]. The result has shape [[|h; w|]] when the decoded
      image has one channel, and [[|h; w; c|]] otherwise.

      @param grayscale when [true], ask the decoder for 1 channel; otherwise
        (the default) ask for 3.
      @return [Error (Format_error msg)] when the decoder reports an error,
        [Error (Io_error msg)] on [Sys_error], and [Error (Other _)] for any
        other exception. *)
  let load_image ?(grayscale = false) path =
    try
      let desired_channels = if grayscale then 1 else 3 in
      match Stb_image.load ~channels:desired_channels path with
      | Ok img ->
          let h = Stb_image.height img in
          let w = Stb_image.width img in
          let c = Stb_image.channels img in
          let buffer = Stb_image.data img in
          (* View the decoder's flat pixel buffer as an Nx array, then give it
             image layout; single-channel images drop the channel axis. *)
          let nd = Nx.of_bigarray_ext (genarray_of_array1 buffer) in
          let shape = if c = 1 then [| h; w |] else [| h; w; c |] in
          Ok (Nx.reshape shape nd)
      | Error (`Msg msg) -> Error (Format_error msg)
    with
    | Sys_error msg -> Error (Io_error msg)
    | ex -> Error (Other (Printexc.to_string ex))

  (** [save_image ?overwrite path img] encodes [img] to [path]. The format is
      chosen from the (case-insensitive) extension: [.png], [.bmp], [.tga],
      or [.jpg]/[.jpeg] (written at JPEG quality 90).

      [img] must be rank 2, or rank 3 with 1 to 4 channels. Its element type
      is fixed to unsigned 8-bit by the kind match in the body.

      @param overwrite when [false], refuse to replace an existing file
        (default [true]).
      @return [Error (Io_error _)] if the file exists and [overwrite] is
        [false], or on [Sys_error].
        [Error (Format_error _)] for an unsupported extension or channel
        count.
        [Error (Other _)] for any other exception, including an invalid
        rank. *)
  let save_image ?(overwrite = true) path img =
    try
      (* Check if file exists and overwrite is false *)
      if (not overwrite) && Sys.file_exists path then
        Error (Io_error (Printf.sprintf "File '%s' already exists" path))
      else
        let h, w, c =
          match get_nx_dims img with
          | `Gray (h, w) -> (h, w, 1)
          | `Color (h, w, c) -> (h, w, c)
        in
        (* stb_image_write's encoders only handle 1 to 4 components per
           pixel. Other counts would be passed into C code that is not
           written for them (and cannot be caught as an OCaml exception),
           so reject them here with a clear error. *)
        if c < 1 || c > 4 then
          Error
            (Format_error
               (Printf.sprintf
                  "Unsupported number of channels: %d (expected 1 to 4)" c))
        else
          (* Ensure the input array is uint8: [Int8_unsigned] is the only
             kind matched, so only uint8 arrays type-check here. *)
          let data_gen = Nx.to_bigarray_ext img in
          let data =
            match Genarray.kind data_gen with
            | Int8_unsigned -> array1_of_genarray data_gen
          in
          let extension = Filename.extension path |> String.lowercase_ascii in
          match extension with
          | ".png" ->
              Stb_image_write.png path ~w ~h ~c data;
              Ok ()
          | ".bmp" ->
              Stb_image_write.bmp path ~w ~h ~c data;
              Ok ()
          | ".tga" ->
              Stb_image_write.tga path ~w ~h ~c data;
              Ok ()
          | ".jpg" | ".jpeg" ->
              Stb_image_write.jpg path ~w ~h ~c ~quality:90 data;
              Ok ()
          | _ ->
              Error
                (Format_error
                   (Printf.sprintf
                      "Unsupported image format: '%s'. Use .png, .bmp, .tga, .jpg"
                      extension))
    with
    | Sys_error msg -> Error (Io_error msg)
    | Invalid_argument msg -> Error (Other msg)
    | Failure msg -> Error (Other msg)
    | ex -> Error (Other (Printexc.to_string ex))

  (* NumPy formats: thin forwards to Nx_npy. *)

  (** Delegates to [Nx_npy.load_npy]. *)
  let load_npy path = Nx_npy.load_npy path

  (** Delegates to [Nx_npy.save_npy]; [overwrite] defaults to [true]. *)
  let save_npy ?(overwrite = true) path arr =
    Nx_npy.save_npy ~overwrite path arr

  (** Delegates to [Nx_npy.load_npz]. *)
  let load_npz path = Nx_npy.load_npz path

  (** Delegates to [Nx_npy.load_npz_member], selecting the member [name]. *)
  let load_npz_member ~name path = Nx_npy.load_npz_member ~name path

  (** Delegates to [Nx_npy.save_npz]; [overwrite] defaults to [true]. *)
  let save_npz ?(overwrite = true) path items =
    Nx_npy.save_npz ~overwrite path items

  (* Conversions from packed arrays: re-exports of the result-returning
     accessors of [Packed_nx]. *)

  let as_float16 = Packed_nx.as_float16
  let as_float32 = Packed_nx.as_float32
  let as_float64 = Packed_nx.as_float64
  let as_int8 = Packed_nx.as_int8
  let as_int16 = Packed_nx.as_int16
  let as_int32 = Packed_nx.as_int32
  let as_int64 = Packed_nx.as_int64
  let as_uint8 = Packed_nx.as_uint8
  let as_uint16 = Packed_nx.as_uint16
  let as_complex32 = Packed_nx.as_complex32
  let as_complex64 = Packed_nx.as_complex64

  (* SafeTensors support *)

  (** Delegates to [Nx_safetensors.load_safetensor]. *)
  let load_safetensor path = Nx_safetensors.load_safetensor path

  (** Delegates to [Nx_safetensors.save_safetensor]. Unlike the NumPy savers,
      no default is applied here: an omitted [overwrite] stays omitted, so
      the callee's own default applies. *)
  let save_safetensor ?overwrite path items =
    Nx_safetensors.save_safetensor ?overwrite path items
end

(* Main module functions - these fail directly instead of returning results *)

(** [unwrap_result r] extracts the payload of [Ok]. For [Error e] it raises
    [Failure] carrying [Error.to_string e]. It is used to build the raising
    API below from the result-returning [Safe] functions. *)
let unwrap_result r =
  match r with
  | Error e -> failwith (Error.to_string e)
  | Ok payload -> payload

(* Raising counterparts of the [Safe.as_*] conversions. Each one unpacks a
   [packed_nx] through the matching [Packed_nx] accessor and raises [Failure]
   (via [unwrap_result]) when that accessor returns [Error]. This is
   presumably the case of a dtype mismatch — confirm in Packed_nx. *)
let as_float16 p = unwrap_result (Packed_nx.as_float16 p)
let as_float32 p = unwrap_result (Packed_nx.as_float32 p)
let as_float64 p = unwrap_result (Packed_nx.as_float64 p)
let as_int8 p = unwrap_result (Packed_nx.as_int8 p)
let as_int16 p = unwrap_result (Packed_nx.as_int16 p)
let as_int32 p = unwrap_result (Packed_nx.as_int32 p)
let as_int64 p = unwrap_result (Packed_nx.as_int64 p)
let as_uint8 p = unwrap_result (Packed_nx.as_uint8 p)
let as_uint16 p = unwrap_result (Packed_nx.as_uint16 p)
let as_complex32 p = unwrap_result (Packed_nx.as_complex32 p)
let as_complex64 p = unwrap_result (Packed_nx.as_complex64 p)

(** Raising variant of [Safe.load_image]: returns the decoded array, or
    raises [Failure] with the rendered error. *)
let load_image ?grayscale path =
  unwrap_result (Safe.load_image ?grayscale path)

(** Raising variant of [Safe.save_image]: returns [()] on success, or raises
    [Failure] with the rendered error. *)
let save_image ?overwrite path img =
  unwrap_result (Safe.save_image ?overwrite path img)

(** Raising variant of [Safe.load_npy]. *)
let load_npy path = unwrap_result (Safe.load_npy path)

(** Raising variant of [Safe.save_npy]; an omitted [overwrite] falls back to
    [Safe.save_npy]'s default. *)
let save_npy ?overwrite path arr =
  unwrap_result (Safe.save_npy ?overwrite path arr)

(** Raising variant of [Safe.load_npz]. *)
let load_npz path = unwrap_result (Safe.load_npz path)

(** Raising variant of [Safe.load_npz_member]. *)
let load_npz_member ~name path =
  unwrap_result (Safe.load_npz_member ~name path)

(** Raising variant of [Safe.save_npz]; an omitted [overwrite] falls back to
    [Safe.save_npz]'s default. *)
let save_npz ?overwrite path items =
  unwrap_result (Safe.save_npz ?overwrite path items)

(** Raising variant of [Safe.load_safetensor]. *)
let load_safetensor path = unwrap_result (Safe.load_safetensor path)

(** Raising variant of [Safe.save_safetensor]; [overwrite] is forwarded
    unchanged, including when omitted. *)
let save_safetensor ?overwrite path items =
  unwrap_result (Safe.save_safetensor ?overwrite path items)