package nx

  1. Overview
  2. Docs

Source file nx_safetensors.ml

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
open Bigarray_ext
open Error
open Packed_nx

(* Decode an IEEE 754 binary16 bit pattern [h] (an int in [0, 0xffff]) into
   a float. Handles signed zeros, subnormals, infinities and NaN. *)
let float_of_f16_bits h =
  let sign = if h land 0x8000 <> 0 then -1.0 else 1.0 in
  let exponent = (h lsr 10) land 0x1f in
  let mantissa = h land 0x3ff in
  if exponent = 0 then
    (* Subnormal (or zero): value = mantissa * 2^-24. *)
    sign *. Float.ldexp (Float.of_int mantissa) (-24)
  else if exponent = 0x1f then
    if mantissa = 0 then sign *. Float.infinity else Float.nan
  else
    (* Normal: (1 + mantissa/1024) * 2^(exponent-15)
       = (mantissa | 0x400) * 2^(exponent-25). *)
    sign *. Float.ldexp (Float.of_int (mantissa lor 0x400)) (exponent - 25)

(* Decode a bfloat16 bit pattern [h] (an int in [0, 0xffff]) into a float.
   bfloat16 is exactly the upper 16 bits of an IEEE 754 binary32 value. *)
let float_of_bf16_bits h =
  Int32.float_of_bits (Int32.shift_left (Int32.of_int h) 16)

(* Build an Nx tensor of bigarray kind [kind] and shape [shape] by decoding
   consecutive [width]-byte elements from [data], starting at byte [offset].
   [read data pos] decodes the single element stored at byte [pos].
   Out-of-range reads raise [Invalid_argument] (from the bounds-checked
   [String] accessors). *)
let decode_elements kind ~width ~read ~shape data offset =
  let num_elems = Array.fold_left ( * ) 1 shape in
  let ba = Array1.create kind c_layout num_elems in
  for i = 0 to num_elems - 1 do
    Array1.unsafe_set ba i (read data (offset + (i * width)))
  done;
  Nx.reshape shape (Nx.of_bigarray_ext (genarray_of_array1 ba))

(** [load_safetensor path] reads the safetensors file at [path] and returns
    a table mapping each tensor name to its packed Nx tensor.

    Supported dtypes are F32, F64, I32, F16 and BF16. Tensors of any other
    dtype are skipped, and a warning is printed on stderr.

    Returns [Error (Io_error _)] when the file cannot be opened or read
    ([Sys_error]), and [Error (Format_error _)] when the safetensors payload
    cannot be deserialized. Any other exception raised while reading or
    decoding (e.g. a truncated data section) is reported as
    [Error (Other _)]. *)
let load_safetensor path =
  try
    (* Read the whole file; Fun.protect closes the channel even when reading
       fails part-way (e.g. End_of_file on a file truncated while reading). *)
    let buffer =
      let ic = open_in_bin path in
      Fun.protect
        ~finally:(fun () -> close_in_noerr ic)
        (fun () -> really_input_string ic (in_channel_length ic))
    in
    match Safetensors.deserialize buffer with
    | Error err -> Error (Format_error (Safetensors.string_of_error err))
    | Ok safetensors ->
        let tensors = Safetensors.tensors safetensors in
        let result = Hashtbl.create (List.length tensors) in
        List.iter
          (fun (name, (view : Safetensors.tensor_view)) ->
            let { Safetensors.dtype; shape; data; offset; _ } = view in
            let shape = Array.of_list shape in
            (* Decode this view's elements into [kind] and store the result
               under [name]. All multi-byte values are little-endian. *)
            let decode kind width read =
              Hashtbl.replace result name
                (P (decode_elements kind ~width ~read ~shape data offset))
            in
            match dtype with
            | Safetensors.F32 ->
                decode Float32 4 (fun s pos ->
                    Int32.float_of_bits (String.get_int32_le s pos))
            | Safetensors.F64 ->
                decode Float64 8 (fun s pos ->
                    Int64.float_of_bits (Safetensors.read_u64_le s pos))
            | Safetensors.I32 -> decode Int32 4 String.get_int32_le
            | Safetensors.F16 ->
                decode Float16 2 (fun s pos ->
                    float_of_f16_bits (String.get_uint16_le s pos))
            | Safetensors.BF16 ->
                decode Bfloat16 2 (fun s pos ->
                    float_of_bf16_bits (String.get_uint16_le s pos))
            | _ ->
                Printf.eprintf
                  "Warning: Skipping tensor '%s' with unsupported dtype %s\n"
                  name
                  (Safetensors.dtype_to_string dtype))
          tensors;
        Ok result
  with
  | Sys_error msg -> Error (Io_error msg)
  | ex -> Error (Other (Printexc.to_string ex))

(* Encode a float as an IEEE 754 binary16 bit pattern (an int in
   [0, 0xffff]). Values already representable in half precision are encoded
   exactly; every element read back from a Float16 bigarray is such a value.
   Other values are rounded to nearest (ties away from zero), and magnitudes
   that are too large become infinity. NaN is encoded as the quiet NaN
   0x7e00. *)
let f16_bits_of_float x =
  if Float.is_nan x then 0x7e00
  else
    let sign = if Float.sign_bit x then 0x8000 else 0 in
    let a = Float.abs x in
    if Float.equal a Float.infinity then sign lor 0x7c00
    else if Float.equal a 0.0 then sign
    else
      (* a = m * 2^e with m in [0.5, 1), i.e. a = (2m) * 2^(e-1). The
         binary16 exponent bias is 15, so the biased exponent is e + 14. *)
      let m, e = Float.frexp a in
      let biased = e + 14 in
      if biased >= 31 then sign lor 0x7c00
      else if biased >= 1 then
        (* If the mantissa rounds up to 1024, the addition carries into the
           exponent field. That yields the next binade, or infinity (0x7c00). *)
        let mant =
          Float.to_int (Float.round (((2.0 *. m) -. 1.0) *. 1024.0))
        in
        sign lor ((biased lsl 10) + mant)
      else
        (* Subnormal: value = mant * 2^-24. A result of 1024 (0x400) is
           exactly the smallest normal number. *)
        sign lor Float.to_int (Float.round (Float.ldexp a 24))

(* Encode a float as a bfloat16 bit pattern (an int in [0, 0xffff]): the
   upper 16 bits of its float32 representation, rounded to nearest-even.
   Values read back from a Bfloat16 bigarray are exactly representable and
   so round-trip unchanged. NaN keeps its sign and high payload bits, and is
   forced to be a quiet NaN. *)
let bf16_bits_of_float x =
  let bits = Int32.bits_of_float x in
  let upper = Int32.shift_right_logical bits 16 in
  if Float.is_nan x then Int32.to_int upper lor 0x0040
  else
    let lsb = Int32.logand upper 1l in
    let rounded = Int32.add bits (Int32.add 0x7fffl lsb) in
    Int32.to_int (Int32.shift_right_logical rounded 16)

(** [save_safetensor ?overwrite path items] serializes [items], a list of
    [(name, packed tensor)] pairs, into a safetensors file at [path].

    Supported element kinds are Float32, Float64, Int32, Float16 and
    Bfloat16. All multi-byte values are written little-endian.

    When [overwrite] is [false] (it defaults to [true]) and [path] already
    exists, returns [Error (Io_error _)] without writing. Serialization
    failures reported by [Safetensors] map to [Error (Format_error _)], and
    [Sys_error] maps to [Error (Io_error _)]. Other exceptions map to
    [Error (Other _)]; this includes the one raised by [fail_msg] for an
    unsupported dtype or a rejected tensor view. *)
let save_safetensor ?(overwrite = true) path items =
  try
    if (not overwrite) && Sys.file_exists path then
      Error (Io_error (Printf.sprintf "File '%s' already exists" path))
    else
      let tensor_views =
        List.map
          (fun (name, P arr) ->
            let dims = Nx.shape arr in
            let shape = Array.to_list dims in
            let num_elems = Array.fold_left ( * ) 1 dims in
            let ba = Nx.to_bigarray_ext arr in
            (* The tensor's elements in row-major order, as a 1-D bigarray.
               Computed on demand so unsupported dtypes fail before any
               copy is made. *)
            let flat () =
              array1_of_genarray (Nx.to_bigarray_ext (Nx.flatten arr))
            in
            let dtype, data =
              match Genarray.kind ba with
              | Float32 ->
                  let src = flat () in
                  let bytes = Bytes.create (num_elems * 4) in
                  for i = 0 to num_elems - 1 do
                    Bytes.set_int32_le bytes (i * 4)
                      (Int32.bits_of_float (Array1.unsafe_get src i))
                  done;
                  (Safetensors.F32, Bytes.unsafe_to_string bytes)
              | Float64 ->
                  let src = flat () in
                  let bytes = Bytes.create (num_elems * 8) in
                  for i = 0 to num_elems - 1 do
                    Safetensors.write_u64_le bytes (i * 8)
                      (Int64.bits_of_float (Array1.unsafe_get src i))
                  done;
                  (Safetensors.F64, Bytes.unsafe_to_string bytes)
              | Int32 ->
                  let src = flat () in
                  let bytes = Bytes.create (num_elems * 4) in
                  for i = 0 to num_elems - 1 do
                    Bytes.set_int32_le bytes (i * 4) (Array1.unsafe_get src i)
                  done;
                  (Safetensors.I32, Bytes.unsafe_to_string bytes)
              | Float16 ->
                  (* Elements are read back as OCaml floats, so re-encode
                     each one to its 16-bit IEEE binary16 pattern. *)
                  let src = flat () in
                  let bytes = Bytes.create (num_elems * 2) in
                  for i = 0 to num_elems - 1 do
                    Bytes.set_uint16_le bytes (i * 2)
                      (f16_bits_of_float (Array1.unsafe_get src i))
                  done;
                  (Safetensors.F16, Bytes.unsafe_to_string bytes)
              | Bfloat16 ->
                  (* Same approach as Float16, using the bfloat16 layout. *)
                  let src = flat () in
                  let bytes = Bytes.create (num_elems * 2) in
                  for i = 0 to num_elems - 1 do
                    Bytes.set_uint16_le bytes (i * 2)
                      (bf16_bits_of_float (Array1.unsafe_get src i))
                  done;
                  (Safetensors.BF16, Bytes.unsafe_to_string bytes)
              | _ ->
                  fail_msg "Unsupported dtype for safetensors: %s"
                    (Nx_core.Dtype.of_bigarray_ext_kind (Genarray.kind ba)
                    |> Nx_core.Dtype.to_string)
            in

            match Safetensors.tensor_view_new ~dtype ~shape ~data with
            | Ok view -> (name, view)
            | Error err ->
                fail_msg "Failed to create tensor view for '%s': %s" name
                  (Safetensors.string_of_error err))
          items
      in
      match Safetensors.serialize_to_file tensor_views None path with
      | Ok () -> Ok ()
      | Error err -> Error (Format_error (Safetensors.string_of_error err))
  with
  | Sys_error msg -> Error (Io_error msg)
  | ex -> Error (Other (Printexc.to_string ex))