(* kaun_datasets.ml *)
(** Ready-to-use datasets for machine learning *)
(* Log source for this module; messages are tagged "kaun.datasets". *)
let src = Logs.Src.create "kaun.datasets" ~doc:"Kaun datasets module"
(* Module-local logger built on [src]; used by the loaders below. *)
module Log = (val Logs.src_log src : Logs.LOG)
(** {1 Core Types} *)
(** A dataset whose elements are Rune tensors of element type ['elt] and
    kind ['kind]. *)
type ('elt, 'kind) tensor_dataset = ('elt, 'kind) Rune.t Kaun.Dataset.t
(** {1 Vision Datasets} *)
(** [mnist ?train ?flatten ?normalize ?data_format ?cache_dir ()] loads the
    MNIST digit dataset as an [(images, labels)] tensor dataset.

    - [train]: training split when [true] (default), test split otherwise.
    - [flatten]: when [true], images are reshaped to [|n; 784|].
    - [normalize]: when [true] (default), pixel values are scaled to [0, 1].
    - [data_format]: [`NCHW] (default) or [`NHWC] channel layout.
    - [cache_dir] is currently ignored. *)
let mnist ?(train = true) ?(flatten = false) ?(normalize = true)
    ?(data_format = `NCHW) ?cache_dir:_ () =
  let (x_train, y_train), (x_test, y_test) = Nx_datasets.load_mnist () in
  let images, labels =
    if train then (x_train, y_train) else (x_test, y_test)
  in
  (* Cast to float32 in Nx, then hand the buffer over to Rune. *)
  let to_rune t = Rune.of_bigarray (Nx.to_bigarray (Nx.cast Nx.float32 t)) in
  let x = to_rune images in
  let y = to_rune labels in
  (* Scale raw byte values into [0, 1] when requested. *)
  let x =
    if normalize then Rune.div x (Rune.scalar Rune.float32 255.0) else x
  in
  (* Images are assumed NHWC with a single channel axis; for NCHW, move the
     channel axis into position 1. *)
  let x =
    match data_format with
    | `NHWC -> x
    | `NCHW ->
        let shape = Rune.shape x in
        let with_channel =
          Rune.reshape [| shape.(0); shape.(1); shape.(2); 1 |] x
        in
        Rune.transpose with_channel ~axes:[ 0; 3; 1; 2 ]
  in
  let x =
    if flatten then Rune.reshape [| (Rune.shape x).(0); 28 * 28 |] x else x
  in
  (* Labels arrive as (n, 1); drop the trailing axis. *)
  let y = Rune.squeeze y ~axes:[ 1 ] in
  Kaun.Dataset.from_tensors (x, y)
(** [cifar10 ?train ?normalize ?data_format ?augmentation ?cache_dir ()]
    loads CIFAR-10 as an [(images, labels)] tensor dataset.

    - [train]: training split when [true] (default), test split otherwise.
    - [normalize]: when [true] (default), pixels are scaled to [0, 1] and
      standardized with the per-channel mean/std constants below.
    - [data_format]: [`NCHW] (default, the layout the data arrives in) or
      [`NHWC].
    - [augmentation]: data augmentation is not implemented; requesting it
      logs a warning and returns the plain dataset (previously the flag was
      silently ignored by an [if]-expression with identical branches).
    - [cache_dir] is currently ignored. *)
let cifar10 ?(train = true) ?(normalize = true) ?(data_format = `NCHW)
    ?(augmentation = false) ?cache_dir:_ () =
  let (x_train, y_train), (x_test, y_test) = Nx_datasets.load_cifar10 () in
  let x, y = if train then (x_train, y_train) else (x_test, y_test) in
  let x = Nx.cast Nx.float32 x in
  let y = Nx.cast Nx.float32 y in
  let x = Rune.of_bigarray (Nx.to_bigarray x) in
  let y = Rune.of_bigarray (Nx.to_bigarray y) in
  let x =
    if normalize then
      (* Per-channel standardization constants (ImageNet-style values). *)
      let mean_arr =
        Bigarray.Array1.of_array Bigarray.float32 Bigarray.c_layout
          [| 0.485; 0.456; 0.406 |]
      in
      let std_arr =
        Bigarray.Array1.of_array Bigarray.float32 Bigarray.c_layout
          [| 0.229; 0.224; 0.225 |]
      in
      let mean = Rune.of_bigarray (Bigarray.genarray_of_array1 mean_arr) in
      let std = Rune.of_bigarray (Bigarray.genarray_of_array1 std_arr) in
      let x = Rune.div x (Rune.scalar Rune.float32 255.0) in
      (* Broadcast mean/std over the NCHW batch. *)
      let mean = Rune.reshape [| 1; 3; 1; 1 |] mean in
      let std = Rune.reshape [| 1; 3; 1; 1 |] std in
      Rune.div (Rune.sub x mean) std
    else x
  in
  let x =
    match data_format with
    | `NCHW -> x
    | `NHWC -> Rune.transpose x ~axes:[ 0; 2; 3; 1 ]
  in
  (* Labels arrive as (n, 1); drop the trailing axis. *)
  let y = Rune.squeeze y ~axes:[ 1 ] in
  if augmentation && train then
    Log.warn (fun m ->
        m "cifar10: augmentation is not implemented; returning plain dataset");
  Kaun.Dataset.from_tensors (x, y)
(** [fashion_mnist ?train ?flatten ?normalize ?data_format ?cache_dir ()]
    loads the Fashion-MNIST dataset as an [(images, labels)] tensor dataset.

    - [train]: training split when [true] (default), test split otherwise.
    - [flatten]: when [true], images are reshaped to [|n; 784|].
    - [normalize]: when [true] (default), pixel values are scaled to [0, 1].
    - [data_format]: [`NCHW] (default) or [`NHWC] channel layout.
    - [cache_dir] is currently ignored. *)
let fashion_mnist ?(train = true) ?(flatten = false) ?(normalize = true)
    ?(data_format = `NCHW) ?cache_dir:_ () =
  let (x_train, y_train), (x_test, y_test) =
    Nx_datasets.load_fashion_mnist ()
  in
  let raw_x, raw_y = if train then (x_train, y_train) else (x_test, y_test) in
  (* Cast to float32 in Nx, then transfer the buffer into Rune. *)
  let x = Rune.of_bigarray (Nx.to_bigarray (Nx.cast Nx.float32 raw_x)) in
  let y = Rune.of_bigarray (Nx.to_bigarray (Nx.cast Nx.float32 raw_y)) in
  let x =
    match normalize with
    | true -> Rune.div x (Rune.scalar Rune.float32 255.0)
    | false -> x
  in
  (* Images are assumed NHWC with one channel; for NCHW, relocate the
     channel axis to position 1. *)
  let x =
    match data_format with
    | `NHWC -> x
    | `NCHW ->
        let s = Rune.shape x in
        Rune.reshape [| s.(0); s.(1); s.(2); 1 |] x
        |> Rune.transpose ~axes:[ 0; 3; 1; 2 ]
  in
  let x =
    match flatten with
    | false -> x
    | true -> Rune.reshape [| (Rune.shape x).(0); 28 * 28 |] x
  in
  (* Labels arrive as (n, 1); drop the trailing axis. *)
  Kaun.Dataset.from_tensors (x, Rune.squeeze y ~axes:[ 1 ])
(** {1 Text Datasets} *)
(** [imdb ?train ?tokenizer ?max_length ?cache_dir ()] returns a synthetic
    stand-in for the IMDB sentiment dataset: 25,000 alternating
    positive/negative sample sentences (matching the real split sizes),
    tokenized and zipped with float labels (0.0 / 1.0).

    - [tokenizer]: defaults to {!Kaun.Dataset.whitespace_tokenizer}.
    - [max_length]: tokenization is truncated to this many tokens
      (default 512).
    - [cache_dir] is currently ignored. *)
let imdb ?(train = true) ?tokenizer ?(max_length = 512) ?cache_dir:_ () =
  (* Both splits of the real IMDB dataset contain 25k reviews. *)
  let num_samples = if train then 25000 else 25000 in
  let positive =
    "This movie was absolutely fantastic great amazing wonderful"
  and negative = "This movie was terrible awful bad horrible worst" in
  (* Alternate positive/negative so labels are (i mod 2). *)
  let texts =
    Array.init num_samples (fun i -> if i mod 2 = 0 then positive else negative)
  in
  let labels =
    Array.init num_samples (fun i ->
        Rune.scalar Rune.float32 (float_of_int (i mod 2)))
  in
  let tok =
    match tokenizer with
    | Some t -> t
    | None -> Kaun.Dataset.whitespace_tokenizer
  in
  let tokenized =
    Kaun.Dataset.from_array texts
    |> Kaun.Dataset.tokenize tok ~max_length ~truncation:true
  in
  Kaun.Dataset.zip tokenized (Kaun.Dataset.from_array labels)
(** [wikitext ?dataset_name ?tokenizer ?sequence_length ?cache_dir ()]
    returns a synthetic stand-in for WikiText: a small built-in corpus is
    tokenized and sliced into [(input_ids, target_ids)] windows where the
    targets are the inputs shifted by one token (next-token prediction).

    - [dataset_name] is currently ignored (only the built-in text is used).
    - [tokenizer]: defaults to {!Kaun.Dataset.whitespace_tokenizer}.
    - [sequence_length]: window length in tokens (default 1024).
    - [cache_dir] is currently ignored.

    Returns an empty dataset when the corpus has fewer than
    [sequence_length + 1] tokens. *)
let wikitext ?(dataset_name = `Wikitext2) ?tokenizer ?(sequence_length = 1024)
    ?cache_dir:_ () =
  let _ = dataset_name in
  let text =
    String.concat " "
      [
        "The quick brown fox jumps over the lazy dog.";
        "Machine learning is a subset of artificial intelligence.";
        "Neural networks are inspired by biological neurons.";
        "Deep learning has revolutionized computer vision.";
        "Natural language processing enables machines to understand text.";
      ]
  in
  let tokenizer =
    Option.value tokenizer ~default:Kaun.Dataset.whitespace_tokenizer
  in
  let tokens = tokenizer text in
  (* Window [i] consumes tokens [i*L .. i*L + L] (L inputs plus one extra
     token for the shifted targets), so exactly [(len - 1) / L] windows fit.
     The previous formula both undercounted by one window and forced a
     minimum of 1, which made [Array.sub] raise [Invalid_argument] whenever
     the corpus was shorter than [sequence_length + 1] tokens — always the
     case for the built-in text at the default length. *)
  let num_windows = max 0 ((Array.length tokens - 1) / sequence_length) in
  let windows =
    Array.init num_windows (fun i ->
        let start = i * sequence_length in
        let input_ids = Array.sub tokens start sequence_length in
        let target_ids = Array.sub tokens (start + 1) sequence_length in
        (input_ids, target_ids))
  in
  Kaun.Dataset.from_array windows
(** {1 Structured Data} *)
(** [iris ?normalize ?train_split ?shuffle_seed ()] loads the Iris dataset
    as a [(features, targets)] tensor dataset, optionally standardized and
    shuffled.

    NOTE(review): [train_split] is accepted but ignored — the full dataset
    is returned; confirm whether splitting was intended here.
    NOTE(review): standardization uses [~axes:[1]] (per-sample, across the
    four feature columns) rather than per-feature [~axes:[0]] — confirm
    this is intentional. *)
let iris ?(normalize = true) ?(train_split = 0.8) ?shuffle_seed () =
  let _ = train_split in
  let features, targets = Nx_datasets.load_iris () in
  (* Cast to float32 in Nx, then transfer the buffer into Rune. *)
  let to_rune t = Rune.of_bigarray (Nx.to_bigarray (Nx.cast Nx.float32 t)) in
  let x = to_rune features in
  let y = to_rune targets in
  let x =
    if not normalize then x
    else
      let mu = Rune.mean x ~axes:[ 1 ] ~keepdims:true in
      let sigma = Rune.std x ~axes:[ 1 ] ~keepdims:true in
      (* Small epsilon guards against division by a zero std. *)
      let eps = Rune.scalar Rune.float32 1e-8 in
      Rune.div (Rune.sub x mu) (Rune.add sigma eps)
  in
  let dataset = Kaun.Dataset.from_tensors (x, y) in
  match shuffle_seed with
  | None -> dataset
  | Some seed -> Kaun.Dataset.shuffle ~rng:(Rune.Rng.key seed) dataset
(** [boston_housing ?normalize ?train_split ()] loads a housing-price
    regression dataset as a [(features, targets)] tensor dataset.

    NOTE(review): despite the name, this loads the California housing
    dataset ({!Nx_datasets.load_california_housing}) — presumably because
    the Boston dataset was withdrawn; consider renaming or documenting.
    NOTE(review): [train_split] is accepted but ignored.
    NOTE(review): standardization uses [~axes:[1]] (per-sample, across
    feature columns) rather than per-feature [~axes:[0]] — confirm. *)
let boston_housing ?(normalize = true) ?(train_split = 0.8) () =
  let _ = train_split in
  let features, targets = Nx_datasets.load_california_housing () in
  (* Cast to float32 in Nx, then transfer the buffer into Rune. *)
  let to_rune t = Rune.of_bigarray (Nx.to_bigarray (Nx.cast Nx.float32 t)) in
  let x = to_rune features in
  let y = to_rune targets in
  let x =
    if not normalize then x
    else
      let mu = Rune.mean x ~axes:[ 1 ] ~keepdims:true in
      let sigma = Rune.std x ~axes:[ 1 ] ~keepdims:true in
      (* Small epsilon guards against division by a zero std. *)
      let eps = Rune.scalar Rune.float32 1e-8 in
      Rune.div (Rune.sub x mu) (Rune.add sigma eps)
  in
  Kaun.Dataset.from_tensors (x, y)
(** {1 Dataset Utilities} *)
(** [download_and_extract ~url ~cache_dir ?extract ()] downloads [url] into
    [cache_dir] (creating the directory if needed) and, when [extract] is
    [true] (default) and the file is a [.tar.gz] or [.zip] archive, unpacks
    it into [cache_dir].  Downloads and extractions are skipped when their
    result already exists on disk.

    Returns the extracted directory path for archives, otherwise the
    downloaded file path.
    @raise Failure if the download or extraction command exits non-zero.

    Fixes over the previous version: the optional parameter was missing its
    name ([?( = true)]) and the extraction-directory binding was missing
    its identifier ([let = ...]), neither of which compiled; the archive
    suffix is now stripped fully ([chop_extension] turned "foo.tar.gz" into
    "foo.tar"); shell arguments are quoted with {!Filename.quote}. *)
let download_and_extract ~url ~cache_dir ?(extract = true) () =
  if not (Sys.file_exists cache_dir) then
    Sys.command (Printf.sprintf "mkdir -p %s" (Filename.quote cache_dir))
    |> ignore;
  let filename = Filename.basename url in
  let filepath = Filename.concat cache_dir filename in
  if not (Sys.file_exists filepath) then (
    Log.info (fun m -> m "Downloading %s to %s..." url filepath);
    let cmd =
      Printf.sprintf "curl -L -o %s %s" (Filename.quote filepath)
        (Filename.quote url)
    in
    match Sys.command cmd with
    | 0 -> Log.info (fun m -> m "Download complete")
    | _ -> failwith (Printf.sprintf "Failed to download %s" url));
  if
    extract
    && (Filename.check_suffix filename ".tar.gz"
       || Filename.check_suffix filename ".zip")
  then (
    (* Strip the entire archive suffix so "foo.tar.gz" maps to "foo". *)
    let extract_dir =
      if Filename.check_suffix filepath ".tar.gz" then
        Filename.chop_suffix filepath ".tar.gz"
      else Filename.chop_suffix filepath ".zip"
    in
    if not (Sys.file_exists extract_dir) then (
      Log.info (fun m -> m "Extracting %s..." filename);
      let cmd =
        if Filename.check_suffix filename ".tar.gz" then
          Printf.sprintf "tar -xzf %s -C %s" (Filename.quote filepath)
            (Filename.quote cache_dir)
        else
          Printf.sprintf "unzip -q %s -d %s" (Filename.quote filepath)
            (Filename.quote cache_dir)
      in
      match Sys.command cmd with
      | 0 -> Log.info (fun m -> m "Extraction complete")
      | _ -> failwith (Printf.sprintf "Failed to extract %s" filename));
    extract_dir)
  else filepath
(** [train_test_split ?test_size ?shuffle ?seed dataset] splits [dataset]
    into a [(train, test)] pair.  [test_size] (default 0.2) is the fraction
    of elements routed to the test split; when [shuffle] is [true]
    (default) the dataset is shuffled first, seeded by [seed] if given.
    @raise Failure if the dataset's cardinality is unknown or infinite.

    NOTE(review): both splits are derived from the same shuffled dataset
    value via [take]/[skip]; if [Kaun.Dataset.shuffle] reshuffles on each
    iteration, the splits could overlap — confirm shuffling happens once. *)
let train_test_split ?(test_size = 0.2) ?(shuffle = true) ?seed dataset =
  let total =
    match Kaun.Dataset.cardinality dataset with
    | Kaun.Dataset.Finite n -> n
    | _ -> failwith "Cannot split dataset with unknown or infinite cardinality"
  in
  let test_len = int_of_float (float_of_int total *. test_size) in
  let train_len = total - test_len in
  let source =
    if not shuffle then dataset
    else
      match seed with
      | None -> Kaun.Dataset.shuffle dataset
      | Some s -> Kaun.Dataset.shuffle ~rng:(Rune.Rng.key s) dataset
  in
  let train = Kaun.Dataset.take train_len source in
  let test =
    source |> Kaun.Dataset.skip train_len |> Kaun.Dataset.take test_len
  in
  (train, test)