package nx

  1. Overview
  2. Docs

Source file nx_txt.ml

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
open Bigarray_ext

(** Errors returned by the functions in this file. This re-exports the
    constructors of [Error.t] so they can be used unqualified here. *)
type error = Error.t =
  | Io_error of string
      (** Built in this file from the message of a [Sys_error] or
          [Unix.Unix_error] raised while opening, reading or writing. *)
  | Format_error of string
      (** Built in this file for invalid literals, out-of-range values,
          invalid [skiprows]/[max_rows] arguments and malformed input. *)
  | Unsupported_dtype
      (** Returned when [spec_of_dtype] has no text codec for the dtype. *)
  | Unsupported_shape
      (** Returned by [save] when [layout_of_shape] rejects the shape
          (rank greater than 2). *)
  | Missing_entry of string
      (** Not constructed in this file; NOTE(review): presumably used by
          other I/O backends sharing [Error.t] - confirm. *)
  | Other of string
      (** Not constructed in this file; NOTE(review): looks like a generic
          fallback - confirm against [Error]. *)

(** Text layouts this module can write: a single value, one row of [n]
    values, or [rows] lines of [cols] values. *)
type layout = Scalar | Vector of int | Matrix of int * int

(** [layout_of_shape shape] classifies [shape] by its rank: rank 0 is
    [Scalar], rank 1 is [Vector], rank 2 is [Matrix]. Any higher rank yields
    [None]. *)
let layout_of_shape shape =
  match Array.length shape with
  | 0 -> Some Scalar
  | 1 -> Some (Vector shape.(0))
  | 2 -> Some (Matrix (shape.(0), shape.(1)))
  | _ -> None

(** [option_exists pred opt] is [pred x] when [opt] is [Some x], and [false]
    when [opt] is [None]. *)
let option_exists pred opt = Option.fold ~none:false ~some:pred opt

(** [split_lines_opt text] splits the optional [text] on ['\n']. [None] gives
    the empty list; [Some ""] gives [[""]]. *)
let split_lines_opt text =
  match text with
  | Some contents -> String.split_on_char '\n' contents
  | None -> []

(** [trim_trailing_whitespace s] returns [s] without its trailing spaces and
    tabs. Other characters (including ['\n'] and ['\r']) are left in place.

    The cut point is found with a single backward scan and at most one
    substring is allocated, instead of copying the string once per removed
    character. *)
let trim_trailing_whitespace s =
  let len = String.length s in
  (* [keep_until i] is the length of the prefix that survives trimming,
     given that everything at index >= [i] is already known to be blank. *)
  let rec keep_until i =
    if i > 0 && (s.[i - 1] = ' ' || s.[i - 1] = '\t') then keep_until (i - 1)
    else i
  in
  let keep = keep_until len in
  if keep = len then s else String.sub s 0 keep

(** [try_parse f s] is [Some (f s)], or [None] when [f s] raises [Failure].

    Only [Failure] is caught: that is the exception the standard
    [*_of_string] conversions raise on malformed input. Exceptions such as
    [Out_of_memory], [Stack_overflow] or [Sys.Break] are deliberately left to
    propagate instead of being reported as a parse failure. *)
let try_parse f s = try Some (f s) with Failure _ -> None

(** Option-returning wrappers around the standard string conversions. *)
let float_of_string_opt = try_parse float_of_string

let int_of_string_opt = try_parse int_of_string
let int32_of_string_opt = try_parse Int32.of_string
let int64_of_string_opt = try_parse Int64.of_string
let nativeint_of_string_opt = try_parse Nativeint.of_string

(** Per-dtype text codec: how one element of a given dtype is written to and
    read from text. Instances are produced by [spec_of_dtype]. *)
module type SPEC = sig
  type elt
  (** OCaml type of a single element. *)

  type kind
  (** Bigarray element-kind tag paired with [elt]. *)

  val kind : (elt, kind) Bigarray_ext.kind
  (** Kind value; [load] passes it to [Genarray.create] for the result
      buffer. *)

  val print : out_channel -> elt -> unit
  (** [print oc v] writes the textual form of [v] to [oc], with no separator
      or newline. *)

  val parse : string -> (elt, error) result
  (** [parse token] converts one field to an element, or returns the error
      describing why [token] was rejected. *)
end

(** [invalid_literal dtype_name token] builds the [Format_error] reported
    when [token] cannot be read as a [dtype_name] value. Trailing spaces and
    tabs are dropped from [token] before it is quoted in the message. *)
let invalid_literal dtype_name token =
  let shown = trim_trailing_whitespace token in
  let message = Printf.sprintf "Invalid %s literal: %S" dtype_name shown in
  Format_error message

(** [out_of_range dtype_name token] builds the [Format_error] reported when
    [token] is a well-formed number that does not fit in [dtype_name].
    Trailing spaces and tabs are dropped from [token] before it is quoted in
    the message. *)
let out_of_range dtype_name token =
  let shown = trim_trailing_whitespace token in
  let message =
    Printf.sprintf "Value %S is out of range for %s" shown dtype_name
  in
  Format_error message

(** [parse_float dtype token] reads [token] as a float. On failure it returns
    the invalid-literal error for the dtype named [dtype]. *)
let parse_float dtype token =
  token |> float_of_string_opt |> function
  | None -> Error (invalid_literal dtype token)
  | Some value -> Ok value

(** [parse_bool token] reads a boolean from [token], ignoring surrounding
    whitespace and ASCII case.

    Accepted forms, tried in this order:
    - the words [true]/[t]/[yes]/[y] and [false]/[f]/[no]/[n];
    - an integer literal: [0] is [false], anything else is [true];
    - a float literal: zero (of either sign) is [false], anything else is
      [true].

    Anything else yields the invalid-literal error for ["bool"]. *)
let parse_bool token =
  let lowered = String.lowercase_ascii (String.trim token) in
  match lowered with
  | "true" | "t" | "yes" | "y" -> Ok true
  | "false" | "f" | "no" | "n" -> Ok false
  | _ -> (
      match int_of_string_opt lowered with
      | Some 0 -> Ok false
      | Some _ -> Ok true
      | None -> (
          match float_of_string_opt lowered with
          (* [<> 0.0] rather than [abs_float f > 0.0]: NaN is not zero, so it
             must be truthy like every other non-zero value. With [>], NaN
             compared false and was read as [false]. *)
          | Some f -> Ok (f <> 0.0)
          | None -> Error (invalid_literal "bool" token)))

(** [parse_int_with_bounds dtype token ~min ~max] reads [token] as an integer
    and checks that it lies in the inclusive range [min..max]. A malformed
    token gives the invalid-literal error; a well-formed but out-of-bounds
    value gives the out-of-range error. [dtype] is the name used in both
    messages. *)
let parse_int_with_bounds dtype token ~min ~max =
  match int_of_string_opt token with
  | None -> Error (invalid_literal dtype token)
  | Some value ->
      if value < min || value > max then Error (out_of_range dtype token)
      else Ok value

(** [spec_of_dtype dtype] returns the text codec for [dtype] as a first-class
    [SPEC] module, or [None] when the dtype has no case below.

    Floating-point dtypes are printed with ["%.18e"]. Integer dtypes are
    printed in decimal. 8- and 16-bit integers are range-checked on parse.
    Booleans are printed as ["1"]/["0"] and parsed with [parse_bool]. *)
let spec_of_dtype (type a) (type b) (dtype : (a, b) Nx.dtype) :
    (module SPEC with type elt = a and type kind = b) option =
  let dtype_name = Nx_core.Dtype.to_string dtype in
  (* Identity functor: each branch passes its implementation through this
     signature before packing it as a [SPEC]. *)
  let module Seal (Impl : sig
    type elt
    type kind

    val kind : (elt, kind) Bigarray_ext.kind
    val print : out_channel -> elt -> unit
    val parse : string -> (elt, error) result
  end) =
  struct
    include Impl
  end in
  let open Nx_core.Dtype in
  (* Printers and parsers shared by several branches. *)
  let write_float oc v = Printf.fprintf oc "%.18e" v in
  let write_int oc v = output_string oc (string_of_int v) in
  let read_float token = parse_float dtype_name token in
  let read_bounded ~lo ~hi token =
    parse_int_with_bounds dtype_name token ~min:lo ~max:hi
  in
  let read_with of_string_opt token =
    match of_string_opt token with
    | Some v -> Ok v
    | None -> Error (invalid_literal dtype_name token)
  in
  match dtype with
  | Float16 ->
      let module Spec = Seal (struct
        type elt = float
        type kind = Bigarray_ext.float16_elt

        let kind = Nx_core.Dtype.to_bigarray_ext_kind dtype
        let print oc v = write_float oc v
        let parse token = read_float token
      end) in
      Some (module Spec : SPEC with type elt = a and type kind = b)
  | Float32 ->
      let module Spec = Seal (struct
        type elt = float
        type kind = Bigarray_ext.float32_elt

        let kind = Nx_core.Dtype.to_bigarray_ext_kind dtype
        let print oc v = write_float oc v
        let parse token = read_float token
      end) in
      Some (module Spec : SPEC with type elt = a and type kind = b)
  | Float64 ->
      let module Spec = Seal (struct
        type elt = float
        type kind = Bigarray_ext.float64_elt

        let kind = Nx_core.Dtype.to_bigarray_ext_kind dtype
        let print oc v = write_float oc v
        let parse token = read_float token
      end) in
      Some (module Spec : SPEC with type elt = a and type kind = b)
  | BFloat16 ->
      let module Spec = Seal (struct
        type elt = float
        type kind = Bigarray_ext.bfloat16_elt

        let kind = Nx_core.Dtype.to_bigarray_ext_kind dtype
        let print oc v = write_float oc v
        let parse token = read_float token
      end) in
      Some (module Spec : SPEC with type elt = a and type kind = b)
  | Int8 ->
      let module Spec = Seal (struct
        type elt = int
        type kind = Bigarray_ext.int8_signed_elt

        let kind = Nx_core.Dtype.to_bigarray_ext_kind dtype
        let print oc v = write_int oc v
        let parse token = read_bounded ~lo:(-128) ~hi:127 token
      end) in
      Some (module Spec : SPEC with type elt = a and type kind = b)
  | UInt8 ->
      let module Spec = Seal (struct
        type elt = int
        type kind = Bigarray_ext.int8_unsigned_elt

        let kind = Nx_core.Dtype.to_bigarray_ext_kind dtype
        let print oc v = write_int oc v
        let parse token = read_bounded ~lo:0 ~hi:255 token
      end) in
      Some (module Spec : SPEC with type elt = a and type kind = b)
  | Int16 ->
      let module Spec = Seal (struct
        type elt = int
        type kind = Bigarray_ext.int16_signed_elt

        let kind = Nx_core.Dtype.to_bigarray_ext_kind dtype
        let print oc v = write_int oc v
        let parse token = read_bounded ~lo:(-32768) ~hi:32767 token
      end) in
      Some (module Spec : SPEC with type elt = a and type kind = b)
  | UInt16 ->
      let module Spec = Seal (struct
        type elt = int
        type kind = Bigarray_ext.int16_unsigned_elt

        let kind = Nx_core.Dtype.to_bigarray_ext_kind dtype
        let print oc v = write_int oc v
        let parse token = read_bounded ~lo:0 ~hi:65535 token
      end) in
      Some (module Spec : SPEC with type elt = a and type kind = b)
  | Int32 ->
      let module Spec = Seal (struct
        type elt = int32
        type kind = Bigarray_ext.int32_elt

        let kind = Nx_core.Dtype.to_bigarray_ext_kind dtype
        let print oc v = output_string oc (Int32.to_string v)
        let parse token = read_with int32_of_string_opt token
      end) in
      Some (module Spec : SPEC with type elt = a and type kind = b)
  | Int64 ->
      let module Spec = Seal (struct
        type elt = int64
        type kind = Bigarray_ext.int64_elt

        let kind = Nx_core.Dtype.to_bigarray_ext_kind dtype
        let print oc v = output_string oc (Int64.to_string v)
        let parse token = read_with int64_of_string_opt token
      end) in
      Some (module Spec : SPEC with type elt = a and type kind = b)
  | Int ->
      let module Spec = Seal (struct
        type elt = int
        type kind = Bigarray_ext.int_elt

        let kind = Nx_core.Dtype.to_bigarray_ext_kind dtype
        let print oc v = write_int oc v
        let parse token = read_with int_of_string_opt token
      end) in
      Some (module Spec : SPEC with type elt = a and type kind = b)
  | NativeInt ->
      let module Spec = Seal (struct
        type elt = nativeint
        type kind = Bigarray_ext.nativeint_elt

        let kind = Nx_core.Dtype.to_bigarray_ext_kind dtype
        let print oc v = output_string oc (Nativeint.to_string v)
        let parse token = read_with nativeint_of_string_opt token
      end) in
      Some (module Spec : SPEC with type elt = a and type kind = b)
  | Bool ->
      let module Spec = Seal (struct
        type elt = bool
        type kind = Bigarray_ext.bool_elt

        let kind = Nx_core.Dtype.to_bigarray_ext_kind dtype
        let print oc v = output_string oc (if v then "1" else "0")
        let parse = parse_bool
      end) in
      Some (module Spec : SPEC with type elt = a and type kind = b)
  (* Every dtype not listed above has no text codec. *)
  | _ -> None

(** [save ?sep ?append ?newline ?header ?footer ?comments ~out arr] writes
    [arr] to the file [out] as delimited text.

    - [sep] (default [" "]) is written between values on a line.
    - [append] (default [false]) appends to [out] instead of truncating it.
      The file is created if it does not exist.
    - [newline] (default ["\n"]) terminates every line written.
    - [header] / [footer], when given, are split on ['\n']; each resulting
      line is written before / after the data, prefixed with [comments]
      (default ["# "]) unless [comments] is empty.

    A rank-0 array is written as one value on one line, a rank-1 array as a
    single line, and a rank-2 array as one line per row.

    Returns [Error Unsupported_shape] for rank greater than 2,
    [Error Unsupported_dtype] when the dtype has no text codec, and
    [Error (Io_error _)] when opening or writing the file fails. *)
let save ?(sep = " ") ?(append = false) ?(newline = "\n") ?header ?footer
    ?(comments = "# ") ~out (type a) (type b) (arr : (a, b) Nx.t) =
  match layout_of_shape (Nx.shape arr) with
  | None -> Error Unsupported_shape
  | Some layout -> (
      match spec_of_dtype (Nx.dtype arr) with
      | None -> Error Unsupported_dtype
      | Some spec_module -> (
          let module S =
            (val spec_module : SPEC with type elt = a and type kind = b)
          in
          let perm = 0o666 in
          let flags =
            if append then [ Open_wronly; Open_creat; Open_append; Open_text ]
            else [ Open_wronly; Open_creat; Open_trunc; Open_text ]
          in
          try
            let oc = open_out_gen flags perm out in
            Fun.protect
            (* [close_out_noerr]: an exception raised from [finally] would be
               wrapped in [Fun.Finally_raised] and bypass the [Sys_error]
               handler below. Write failures are surfaced by the explicit
               [flush] at the end of the body instead. *)
              ~finally:(fun () -> close_out_noerr oc)
              (fun () ->
                let write_prefixed line =
                  if comments <> "" then output_string oc comments;
                  output_string oc line;
                  output_string oc newline
                in
                List.iter write_prefixed (split_lines_opt header);
                let data =
                  (Nx.to_bigarray_ext arr
                    : (S.elt, S.kind, Bigarray.c_layout) Genarray.t)
                in
                (match layout with
                | Scalar ->
                    let value = Genarray.get data [||] in
                    S.print oc value;
                    output_string oc newline
                | Vector length ->
                    let view = array1_of_genarray data in
                    for j = 0 to length - 1 do
                      if j > 0 then output_string oc sep;
                      S.print oc (Array1.unsafe_get view j)
                    done;
                    output_string oc newline
                | Matrix (rows, cols) ->
                    let view = array2_of_genarray data in
                    for i = 0 to rows - 1 do
                      for j = 0 to cols - 1 do
                        if j > 0 then output_string oc sep;
                        S.print oc (Array2.unsafe_get view i j)
                      done;
                      output_string oc newline
                    done);
                List.iter write_prefixed (split_lines_opt footer);
                (* Push buffered output to the file now, so a failed write
                   raises [Sys_error] here and is reported as [Io_error]
                   rather than being lost when the channel is closed. *)
                flush oc;
                Ok ())
          with
          | Sys_error msg -> Error (Io_error msg)
          | Unix.Unix_error (e, _, _) -> Error (Io_error (Unix.error_message e))
          ))

(* [split_fields ~sep line] trims [line] and splits it on the separator
   string [sep], dropping empty fields (so repeated separators collapse). An
   empty [sep] yields the whole trimmed line as a single field; a blank line
   yields no fields. *)
let split_fields ~sep line =
  let trimmed = String.trim line in
  let sep_len = String.length sep in
  if trimmed = "" then [||]
  else if sep_len = 0 then [| trimmed |]
  else if sep_len = 1 then
    trimmed
    |> String.split_on_char sep.[0]
    |> List.filter (fun s -> s <> "")
    |> Array.of_list
  else
    let len = String.length trimmed in
    (* Prepend the field [start, stop) to [acc] unless it is empty. *)
    let push acc start stop =
      if stop > start then String.sub trimmed start (stop - start) :: acc
      else acc
    in
    (* [field_start] is where the current field began; [search_from] is where
       to look for the next separator. They are tracked separately so that a
       lone occurrence of the separator's first character stays inside the
       field instead of discarding the text before it. *)
    let rec scan acc field_start search_from =
      let candidate =
        if search_from >= len then None
        else String.index_from_opt trimmed search_from sep.[0]
      in
      match candidate with
      | None -> List.rev (push acc field_start len)
      | Some idx ->
          if idx + sep_len <= len && String.sub trimmed idx sep_len = sep
          then
            let next = idx + sep_len in
            scan (push acc field_start idx) next next
          else scan acc field_start (idx + 1)
    in
    scan [] 0 0 |> Array.of_list

(** [load ?sep ?comments ?skiprows ?max_rows dtype path] reads delimited text
    from [path] into an array of dtype [dtype].

    - [sep] (default [" "]) separates fields; empty fields are ignored. An
      empty [sep] treats each line as a single field.
    - [comments] (default ["#"]): lines whose trimmed text starts with the
      trimmed [comments] string are skipped. An empty string disables
      comment detection.
    - [skiprows] (default [0]) is the number of leading lines of the file to
      discard, whatever their content.
    - [max_rows], when given, stops reading after that many data rows.

    Blank lines are skipped. Every data row must have the same number of
    fields. The result has shape [[|cols|]] when a single row was read,
    [[|rows|]] when every row has a single field, and [[|rows; cols|]]
    otherwise.

    Returns [Error (Format_error _)] for a negative [skiprows], a
    non-positive [max_rows], rows of differing width, an empty data set, or a
    field that cannot be parsed; [Error Unsupported_dtype] when the dtype has
    no text codec; and [Error (Io_error _)] when the file cannot be opened or
    read. *)
let load ?(sep = " ") ?(comments = "#") ?(skiprows = 0) ?max_rows (type a)
    (type b) (dtype : (a, b) Nx.dtype) path =
  if skiprows < 0 then Error (Format_error "skiprows must be non-negative")
  else if option_exists (fun rows -> rows <= 0) max_rows then
    Error (Format_error "max_rows must be strictly positive")
  else
    match spec_of_dtype dtype with
    | None -> Error Unsupported_dtype
    | Some spec_module -> (
        let module S =
          (val spec_module : SPEC with type elt = a and type kind = b)
        in
        try
          let ic = open_in path in
          Fun.protect
          (* [close_in_noerr]: an exception raised from [finally] would be
             wrapped in [Fun.Finally_raised] and escape the handlers below. *)
            ~finally:(fun () -> close_in_noerr ic)
            (fun () ->
              let comment_prefix = String.trim comments in
              let is_comment_line line =
                if comment_prefix = "" then false
                else
                  let trimmed = String.trim line in
                  let len = String.length comment_prefix in
                  String.length trimmed >= len
                  && String.sub trimmed 0 len = comment_prefix
              in
              let rows_rev = ref [] in
              let cols = ref None in
              let rows_read = ref 0 in
              let read_error = ref None in
              (* Reads lines until end of file, the [max_rows] limit, or the
                 first structural error. [skip_remaining] counts the leading
                 lines still to discard. *)
              let rec loop skip_remaining =
                if Option.is_some !read_error then ()
                else if option_exists (fun rows -> !rows_read >= rows) max_rows
                then ()
                else
                  match input_line ic with
                  | line ->
                      if skip_remaining > 0 then loop (skip_remaining - 1)
                      else if is_comment_line line then loop 0
                      else
                        let fields = split_fields ~sep line in
                        if Array.length fields = 0 then loop 0
                        else (
                          (* The first data row fixes the expected width. *)
                          (match !cols with
                          | None -> cols := Some (Array.length fields)
                          | Some expected ->
                              if Array.length fields <> expected then
                                read_error :=
                                  Some
                                    (Format_error
                                       "Inconsistent number of columns"));
                          if Option.is_none !read_error then (
                            rows_rev := fields :: !rows_rev;
                            incr rows_read);
                          loop 0)
                  | exception End_of_file -> ()
              in
              loop skiprows;
              match (!read_error, !cols, !rows_rev) with
              | Some err, _, _ -> Error err
              | _, None, _ -> Error (Format_error "No data found")
              | _, _, [] -> Error (Format_error "No data found")
              | _, Some col_count, rows_rev_list -> (
                  let rows = List.rev rows_rev_list |> Array.of_list in
                  let row_count = Array.length rows in
                  let dims = [| row_count; col_count |] in
                  let ba = Genarray.create S.kind Bigarray.c_layout dims in
                  (* Only the first parse failure is kept; later fields are
                     not parsed once it is set. *)
                  let parse_error = ref None in
                  for i = 0 to row_count - 1 do
                    let row = rows.(i) in
                    for j = 0 to col_count - 1 do
                      if Option.is_none !parse_error then
                        match S.parse row.(j) with
                        | Ok value -> Genarray.set ba [| i; j |] value
                        | Error err -> parse_error := Some err
                    done
                  done;
                  match !parse_error with
                  | Some err -> Error err
                  | None ->
                      let tensor = Nx.of_bigarray_ext ba in
                      (* A single row or a single column is returned as a
                         rank-1 array. *)
                      let result =
                        if row_count = 1 then
                          Nx.reshape [| col_count |] tensor
                        else if col_count = 1 then
                          Nx.reshape [| row_count |] tensor
                        else tensor
                      in
                      Ok result))
        with
        | Sys_error msg -> Error (Io_error msg)
        | Unix.Unix_error (e, _, _) -> Error (Io_error (Unix.error_message e)))