package neural_nets_lib
sectionYPositions = computeSectionYPositions($el), 10)"
x-init="setTimeout(() => sectionYPositions = computeSectionYPositions($el), 10)"
>
A from-scratch deep learning framework with an optimizing compiler, shape inference, and concise syntax
Install
dune-project
Dependency
Authors
Maintainers
Sources
0.3.3.3.tar.gz
md5=9170d4d98422350c9a73a95adfb795dc
sha512=c1b024a69b1d0338af6e34508dbf6dccf3c2b6cc156e7628c3d7853c7040e225bdfc0a8731bb4db5a97edba90e26439987bfa505154d23af46f119c07ad809ed
doc/src/neural_nets_lib/tensor.ml.html
Source file tensor.ml
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 
open Base
module Nd = Arrayjit.Ndarray
module Tn = Arrayjit.Tnode
module Asgns = Arrayjit.Assignments
module Idx = Arrayjit.Indexing
module Debug_runtime = Arrayjit.Utils.Debug_runtime

type tn = Tn.t
type asgns = Asgns.t
type init_op = Arrayjit.Ops.init_op
type fetch_op = Asgns.fetch_op
type projections = Arrayjit.Indexing.projections

[%%global_debug_log_level Nothing]
[%%global_debug_log_level_from_env_var "OCANNL_LOG_LEVEL"]

(** Differentiation data of a tensor: the gradient node, the code that zeroes the gradients and
    the backpropagation code. *)
type diff = { grad : (Tn.t[@sexp.opaque]); zero_grads : Asgns.t; backprop : Asgns.t }
[@@deriving sexp_of]

(** A tensor: a value node with its shape, the forward code computing it, optional
    differentiation data, and the sub-tensors it was built from. *)
type t = {
  forward : Asgns.t;
  diff : diff option;
  id : int;
  value : Tn.t;
  shape : Shape.t;
  children : subtensor list;
}

(** A child reference. [embedded] marks children whose forward code was absorbed into the
    parent (see [op]); only embedded children are serialized / displayed in full. *)
and subtensor = { subtensor : t; embedded : bool }

(** Serializes id, label, forward code, diff and children. Non-embedded children are
    serialized by id only, to avoid repeating shared sub-graphs. *)
let rec sexp_of_t t =
  Sexp.message "Tensor"
    [
      ("id", sexp_of_int t.id);
      ("label", [%sexp_of: string list] t.value.label);
      ("forward", [%sexp_of: Asgns.t] t.forward);
      ("diff", [%sexp_of: diff option] t.diff);
      ("children", [%sexp_of: subtensor list] t.children);
    ]

and sexp_of_subtensor ch =
  Sexp.message "child"
    [
      (if ch.embedded then ("", sexp_of_t ch.subtensor)
       else ("ref-id", sexp_of_int ch.subtensor.id));
    ]

(* Tensors are ordered / compared by their [id] only. *)
include Comparator.Make (struct
  type nonrec t = t

  let compare t1 t2 = Int.compare t1.id t2.id
  let sexp_of_t = sexp_of_t
end)

(** Global mutable session bookkeeping: the id counter, and the tensors whose forward
    (resp. backprop) code has not yet been absorbed into another tensor or consumed. *)
type session_state = {
  mutable next_id : int;
  mutable forward_roots : t Map.M(Int).t;
  mutable backprop_roots : t Map.M(Int).t;
}

let session_state =
  { next_id = 0; forward_roots = Map.empty (module Int); backprop_roots = Map.empty (module Int) }

let is_fwd_root t = Map.mem session_state.forward_roots t.id
let remove_fwd_root t = session_state.forward_roots <- Map.remove session_state.forward_roots t.id
let is_bprop_root t = Map.mem session_state.backprop_roots t.id

let remove_bprop_root t =
  session_state.backprop_roots <- Map.remove session_state.backprop_roots t.id

(** [with_unchanged_roots ~f] runs [f ()] and afterwards restores the forward and backprop root
    maps to what they were before the call, whether [f] returns normally or raises. Any
    exception from [f] is re-raised with its original backtrace. *)
let with_unchanged_roots ~f =
  let fwd_roots = session_state.forward_roots in
  let bprop_roots = session_state.backprop_roots in
  let restore () =
    session_state.forward_roots <- fwd_roots;
    session_state.backprop_roots <- bprop_roots
  in
  (* [Exn.protect] runs [finally] on both paths and preserves the backtrace when re-raising,
     unlike a hand-written [with e -> ...; raise e]. *)
  Exn.protect ~f ~finally:restore

let default_value_prec = ref Arrayjit.Ops.single
let default_grad_prec = ref Arrayjit.Ops.single

(** Debug name of the tensor's value node; [label] defaults to the value node's label. *)
let debug_name ?label t =
  let label = Option.value label ~default:t.value.label in
  Tn.debug_name ~id:t.id ~label

exception Session_error of string * t option [@@deriving sexp]

let session_error_printer = function
  | Session_error (msg, None) -> Some msg
  | Session_error (msg, Some m) -> Some [%string "For #%{m.id#Int} %{debug_name m}: %{msg}"]
  | _ -> None

let () = Stdlib.Printexc.register_printer session_error_printer
let lazy_to_dims shape = lazy (Shape.to_dims shape)

(** Assignment that fills [array] with the constant [0.], with dims taken from [shape]. *)
let fetch_zeros array shape =
  Asgns.Fetch { array; fetch_op = Constant 0.; dims = lazy_to_dims shape }

let default_init_op = Arrayjit.Ops.Constant_fill { values = [| 0.0 |]; strict = false }

(* Used to bound the formatter geometry when rendering [ndarray] literals as labels. *)
let max_sublabel_length = ref 25

(** Builds an [Accum_binop] assignment into [t] from [t1] and [t2], after propagating a
    [Broadcast (logic, t1.shape, t2.shape)] shape constraint onto [t.shape]. Each [*_is_grad]
    flag selects the tensor's gradient node instead of its value node.
    @raise exception from [Option.value_exn] if a selected tensor has no [diff]. *)
let raw_binop ~initialize_neutral ~accum ~(t : t) ~(lhs_is_grad : bool) ~op ~(t1 : t)
    ~(rhs1_is_grad : bool) ~(t2 : t) ~rhs2_is_grad ~logic : Asgns.t =
  let shape = t.shape in
  let shape_logic = Shape.Broadcast (logic, t1.shape, t2.shape) in
  let local_shape_update = Shape.{ shape; logic = shape_logic; id = get_update_id () } in
  Shape.propagate_shapes local_shape_update;
  let projections = lazy (Shape.derive_projections local_shape_update) in
  let lhs = if lhs_is_grad then (Option.value_exn t.diff).grad else t.value in
  let rhs1 = if rhs1_is_grad then (Option.value_exn t1.diff).grad else t1.value in
  let rhs2 = if rhs2_is_grad then (Option.value_exn t2.diff).grad else t2.value in
  Asgns.Accum_binop { initialize_neutral; accum; lhs; op; rhs1; rhs2; projections }

(** Unary counterpart of [raw_binop]: propagates [Transpose (logic, t1.shape)] onto [t.shape]
    and returns an [Accum_unop] assignment. *)
let raw_unop ~initialize_neutral ~accum ~(t : t) ~(lhs_is_grad : bool) ~op ~(t1 : t)
    ~(rhs_is_grad : bool) ~logic =
  let shape = t.shape in
  let shape_logic = Shape.Transpose (logic, t1.shape) in
  let local_shape_update = Shape.{ shape; logic = shape_logic; id = get_update_id () } in
  Shape.propagate_shapes local_shape_update;
  let projections = lazy (Shape.derive_projections local_shape_update) in
  let lhs = if lhs_is_grad then (Option.value_exn t.diff).grad else t.value in
  let rhs = if rhs_is_grad then (Option.value_exn t1.diff).grad else t1.value in
  Asgns.Accum_unop { initialize_neutral; accum; lhs; op; rhs; projections }

type grad_spec = Require_grad | Prohibit_grad | If_needed [@@deriving sexp, equal, variants]

(** Core tensor constructor.

    - Allocates a fresh id for the value node (and the next id for the gradient node, if any).
    - The value precision is the promotion of the arguments' precisions, or
      [!default_value_prec] when there are no arguments.
    - Shape constraints: [Terminal init_op] for no arguments, [Transpose] for one, and a
      [Broadcast] per consecutive pair otherwise; projections come from the first constraint.
    - Forward code: the forward code of each argument that is still a forward root (in
      increasing id order), followed by [op_asn]; those arguments stop being forward roots.
    - No gradient is created if [grad_spec] is [Prohibit_grad], or if it is not [Require_grad]
      and no argument has a [diff]. Otherwise backprop code is [grad_asn] followed by the
      arguments' backprop code (for backprop roots) in decreasing id order.
    - The result is registered as a forward root (and a backprop root when differentiable). *)
let op ~(label : string list) ?(compose_op = Shape.Pointwise_bin)
    ?(transpose_op = Shape.Pointwise_un) ?(init_op = default_init_op) ~op_asn ~grad_asn
    ?(grad_spec = If_needed) make_shape (orig_ts : t list) : t =
  let ordered_ts = List.dedup_and_sort orig_ts ~compare:(fun t1 t2 -> Int.ascending t1.id t2.id) in
  (* Only the first occurrence of a forward-root argument is embedded. *)
  let children =
    List.folding_map orig_ts
      ~init:(Set.empty (module Int))
      ~f:(fun used ti ->
        ( Set.add used ti.id,
          { subtensor = ti; embedded = is_fwd_root ti && not (Set.mem used ti.id) } ))
  in
  let id = session_state.next_id in
  session_state.next_id <- session_state.next_id + 1;
  let shape = make_shape ~debug_name:(Tn.debug_name ~id ~label) ~id in
  let prec =
    List.map orig_ts ~f:(fun ti -> ti.value.prec)
    |> List.reduce ~f:Arrayjit.Ops.promote_prec
    |> Option.value ~default:!default_value_prec
  in
  let rec shape_logics = function
    | [] -> [ Shape.Terminal init_op ]
    | [ t1 ] -> [ Shape.Transpose (transpose_op, t1.shape) ]
    | [ t1; t2 ] -> [ Shape.Broadcast (compose_op, t1.shape, t2.shape) ]
    | t1 :: (t2 :: _ as ts) -> Shape.Broadcast (compose_op, t1.shape, t2.shape) :: shape_logics ts
  in
  let local_shape_updates =
    List.map ~f:(fun logic -> Shape.{ shape; logic; id = get_update_id () })
    @@ shape_logics orig_ts
  in
  let dims = lazy_to_dims shape in
  List.iter ~f:Shape.propagate_shapes local_shape_updates;
  let projections = lazy (Shape.derive_projections @@ List.hd_exn local_shape_updates) in
  let v = Tn.create prec ~id ~label ~dims init_op in
  (* The code needs to be included in the order it was computed due to potential non-tree DAGs. *)
  let fwds = List.map ordered_ts ~f:(fun ti -> if is_fwd_root ti then ti.forward else Asgns.Noop) in
  let forward = Asgns.sequential @@ fwds @ [ op_asn ~v ~projections ] in
  List.iter ordered_ts ~f:(fun ti -> remove_fwd_root ti);
  if
    is_prohibit_grad grad_spec
    || Fn.non is_require_grad grad_spec
       && List.for_all orig_ts ~f:(fun ti -> Option.is_none ti.diff)
  then (
    let tensor = { forward; diff = None; id; value = v; shape; children } in
    session_state.forward_roots <- Map.add_exn session_state.forward_roots ~key:id ~data:tensor;
    tensor)
  else
    let g_prec =
      let f ti = Option.map ti.diff ~f:(fun d -> d.grad.Tn.prec) in
      Option.value ~default:!default_grad_prec
      @@ List.reduce ~f:Arrayjit.Ops.promote_prec
      @@ List.filter_map orig_ts ~f
    in
    let grad_id = session_state.next_id in
    session_state.next_id <- session_state.next_id + 1;
    let g = Tn.create g_prec ~id:grad_id ~label:("grad" :: label) ~dims default_init_op in
    let dcode ti = Option.value_map ti.diff ~default:Asgns.Noop in
    let is_bck_root ti = Map.mem session_state.backprop_roots ti.id in
    let zero_grads =
      let zero_g = dcode ~f:(fun diff -> diff.zero_grads) in
      let zeros =
        List.map ordered_ts ~f:(fun ti -> if is_bck_root ti then zero_g ti else Asgns.Noop)
      in
      Asgns.sequential @@ zeros @ [ fetch_zeros g shape ]
    in
    (* The code needs to be included in the reverse order to which it was computed! This
       guarantees that all ancestors of a node are backpropagated before the node is
       backpropagated, even for non-tree DAGs. *)
    let backprop =
      let bprop = dcode ~f:(fun diff -> diff.backprop) in
      let bcks =
        List.map ordered_ts ~f:(fun ti -> if is_bck_root ti then bprop ti else Asgns.Noop)
      in
      Asgns.sequential @@ (grad_asn ~v ~g ~projections :: List.rev bcks)
    in
    List.iter ordered_ts ~f:(fun ti ->
        session_state.backprop_roots <- Map.remove session_state.backprop_roots ti.id);
    (* The order is not relevant, we keep the same order as in backprop for readability. *)
    let diff = Some { grad = g; zero_grads; backprop } in
    let tensor = { forward; diff; id; value = v; shape; children } in
    session_state.forward_roots <- Map.add_exn session_state.forward_roots ~key:id ~data:tensor;
    session_state.backprop_roots <- Map.add_exn session_state.backprop_roots ~key:id ~data:tensor;
    tensor

(** Two-argument [op]: binds [t1], [t2] into [op_asn]/[grad_asn] and uses a fresh
    [Shape.make ()]. *)
let binop ~label ?compose_op ~op_asn ~grad_asn ?grad_spec t1 t2 =
  let op_asn ~v ~projections = op_asn ~v ~t1 ~t2 ~projections in
  let grad_asn ~v ~g ~projections = grad_asn ~v ~g ~t1 ~t2 ~projections in
  op ~label ?compose_op ?transpose_op:None ~op_asn ~grad_asn ?grad_spec (Shape.make ()) [ t1; t2 ]

(** One-argument [op]: binds [t1] into [op_asn]/[grad_asn] and uses a fresh [Shape.make ()]. *)
let unop ~label ?transpose_op ~op_asn ~grad_asn ?grad_spec t1 =
  let op_asn ~v ~projections = op_asn ~v ~t1 ~projections in
  let grad_asn ~v ~g ~projections = grad_asn ~v ~g ~t1 ~projections in
  op ~label ?compose_op:None ?transpose_op ~op_asn ~grad_asn ?grad_spec (Shape.make ()) [ t1 ]

(** A leaf tensor (no arguments, no-op gradient code). Its forward code is:
    - a [Constant c] fetch, when there is no [fetch_op], [init_op] is a single-value
      [Constant_fill] of [c], and [grad_spec] is not [Require_grad] (the scalar literal case);
    - [Noop] when there is no [fetch_op] otherwise;
    - the fetch produced by [fetch_op ~v]; an [Imported] fetch forces the value node to
      [Materialized]. *)
let term ~label ~grad_spec ?batch_dims ?input_dims ?output_dims ?batch_axes ?input_axes
    ?output_axes ?deduced ?init_op ?fetch_op () =
  let op_asn ~v ~projections =
    let open Asgns in
    let dims = lazy (Lazy.force projections).Idx.lhs_dims in
    match (fetch_op, init_op) with
    | None, Some (Arrayjit.Ops.Constant_fill { values = [| c |]; strict = _ })
      when not (is_require_grad grad_spec) ->
        (* The scalar literal case: bind [c] directly rather than re-matching [init_op]. *)
        Fetch { array = v; fetch_op = Constant c; dims }
    | None, _ -> Noop
    | Some fetch_op, _ ->
        let fetch_op = fetch_op ~v in
        (match fetch_op with
        | Constant _ | Slice _ | Embed_symbol _ -> ()
        | Imported _ ->
            (* Note: [Imported] can be used for merging across devices. But, some use cases of
               [Imported] will require a hosted tensor node. *)
            Tn.update_memory_mode v Materialized 22);
        Fetch { array = v; fetch_op; dims }
  in
  let grad_asn ~v:_ ~g:_ ~projections:_ = Asgns.Noop in
  let make_shape =
    Shape.make ?batch_dims ?input_dims ?output_dims ?batch_axes ?input_axes ?output_axes ?deduced
      ()
  in
  op ~label ?compose_op:None ?transpose_op:None ?init_op ~op_asn ~grad_asn ~grad_spec make_shape []

let float_to_label v = Float.to_string v

(** A scalar constant [c] with a single output axis of size 1 (named [axis_label] if given),
    no batch or input axes, gradients prohibited by default, and an [Effectively_constant]
    value node. The label is [float_to_label c] prepended to [label]. *)
let number ?(label = []) ?axis_label ?(grad_spec = Prohibit_grad) c =
  (* Note: no axis label so that we do not conflict with user labels. *)
  let label = float_to_label c :: label in
  let init_op = Arrayjit.Ops.Constant_fill { values = [| c |]; strict = true } in
  let t = term ~label ~grad_spec ~batch_dims:[] ~input_dims:[] ~init_op in
  let t =
    match axis_label with
    | None -> t ~output_dims:[ 1 ] ()
    | Some axis_label -> t ~output_axes:[ (axis_label, 1) ] ()
  in
  Tn.update_memory_mode t.value Effectively_constant 24;
  t

(** A constant tensor initialized from [values]. Its label is the inline rendering of the
    array, or ["c"] followed by the dims when that rendering spans several lines. Axis kinds
    with neither dims nor axes given default to no axes. Note: this reconfigures the geometry
    of the global [Stdlib.Format.str_formatter]. *)
let ndarray ?(label = []) ?(grad_spec = Prohibit_grad) ?batch_dims ?input_dims ?output_dims
    ?batch_axes ?input_axes ?output_axes ?(strict = true) values =
  let to_dim_list dims axes =
    Option.value ~default:[] @@ Option.first_some dims @@ Option.map axes ~f:(List.map ~f:snd)
  in
  let batch_ds = to_dim_list batch_dims batch_axes in
  let output_ds = to_dim_list output_dims output_axes in
  let input_ds = to_dim_list input_dims input_axes in
  let op_label =
    Stdlib.Format.pp_set_geometry Stdlib.Format.str_formatter ~max_indent:!max_sublabel_length
      ~margin:(!max_sublabel_length * 2);
    let dims = Array.concat_map [| batch_ds; output_ds; input_ds |] ~f:Array.of_list in
    let ndarr = Nd.create_array Arrayjit.Ops.double ~dims (Constant_fill { values; strict }) in
    let ( ! ) = List.length in
    Nd.pp_array_inline ~num_batch_axes:!batch_ds ~num_output_axes:!output_ds
      ~num_input_axes:!input_ds Stdlib.Format.str_formatter ndarr;
    Stdlib.Format.flush_str_formatter ()
  in
  let op_label =
    if String.contains op_label '\n' then
      "c" ^ Idx.dims_to_string
      @@ Array.concat_map [| batch_ds; output_ds; input_ds |] ~f:Array.of_list
    else op_label
  in
  let label = op_label :: label in
  let batch_dims = Option.first_some batch_dims @@ Option.some_if (Option.is_none batch_axes) [] in
  let input_dims = Option.first_some input_dims @@ Option.some_if (Option.is_none input_axes) [] in
  let output_dims =
    Option.first_some output_dims @@ Option.some_if (Option.is_none output_axes) []
  in
  let t =
    term ~label ~grad_spec ?batch_dims ?input_dims ?output_dims ?batch_axes ?input_axes
      ?output_axes ~deduced:Not_constrained ~init_op:(Constant_fill { values; strict }) ()
  in
  Tn.update_memory_mode t.value Effectively_constant 24;
  t

(** A trainable parameter: no batch axes, [Require_grad], initialized from [values] when given
    and [Standard_uniform] otherwise. The value node is [Hosted Nonconstant]; the gradient node
    is [Never_virtual]. *)
let param ?input_dims ?output_dims ?input_axes ?output_axes ?deduced ?(strict = false) ?values
    label =
  let init_op =
    match values with
    | Some values -> Arrayjit.Ops.Constant_fill { values; strict }
    | None -> Standard_uniform
  in
  let t =
    term ~label:[ label ] ~grad_spec:Require_grad ~batch_dims:[] ?input_dims ?output_dims
      ?input_axes ?output_axes ?deduced ~init_op ()
  in
  let v = t.value in
  (* It is convenient to use the param syntax for volatiles (mutable inputs). *)
  Tn.update_memory_mode v (Hosted Nonconstant) 24;
  (* In principle, gradients can even be local, if a single jitted block does forward, backprop,
     and update computations. Use-cases needing [Materialized] gradients need to request that
     before any jitting. *)
  (* [Require_grad] guarantees [t.diff] is [Some _]. *)
  let g = (Option.value_exn t.diff).grad in
  Tn.update_memory_mode g Never_virtual 26;
  t

(** Applies [f] to the value and gradient nodes of [t], then recursively to those of its
    embedded children. *)
let rec iter_embedded_arrays ~f t =
  f t.value;
  Option.iter t.diff ~f:(fun diff -> f diff.grad);
  List.iter ~f:(fun ch -> if ch.embedded then iter_embedded_arrays ~f ch.subtensor) t.children

(** Returns [t]'s forward code and removes [t] from the forward roots.
    @raise Session_error if [t] is not a forward root, or if another forward root with
    children exists. *)
let consume_forward_code t =
  if not @@ is_fwd_root t then
    raise
    @@ Session_error
         ( "Tensor.consume_forward_code: tensor is not a root for tnode: "
           ^ debug_name t ~label:t.value.label,
           Some t );
  let unsafe_roots =
    Map.data session_state.forward_roots
    |> List.filter ~f:(fun r -> not (List.is_empty r.children || r.id = t.id))
  in
  if not @@ List.is_empty unsafe_roots then
    raise
    @@ Session_error
         ( [%string
             {|Tensor.consume_forward_code for %{debug_name t ~label:t.value.label}: found potentially unsafe roots: %{String.concat ~sep:", " @@ List.map ~f:debug_name unsafe_roots}|}],
           Some t );
  remove_fwd_root t;
  t.forward

(** Returns [(zero_grads, backprop)] for [t] and removes [t] from the backprop roots.
    @raise Session_error if [t] has no gradient, is not a backprop root, or if another
    backprop root with children exists. *)
let consume_backprop_code t =
  let diff =
    Option.value_or_thunk t.diff ~default:(fun () ->
        raise
        @@ Session_error
             ( "Tensor.consume_backprop_code: tensor is not differentiable for tnode: "
               ^ debug_name t ~label:t.value.label,
               Some t ))
  in
  if not @@ is_bprop_root t then
    raise
    @@ Session_error
         ( "Tensor.consume_backprop_code: tensor is not a root for tnode: "
           ^ debug_name t ~label:diff.grad.label,
           Some t );
  let unsafe_roots =
    Map.data session_state.backprop_roots
    |> List.filter ~f:(fun r -> not (List.is_empty r.children || r.id = t.id))
  in
  if not @@ List.is_empty unsafe_roots then
    raise
    @@ Session_error
         ( [%string
             {|Tensor.consume_backprop_code for %{debug_name t ~label:diff.grad.label}: found potentially unsafe roots: %{String.concat ~sep:", " @@ List.map ~f:debug_name unsafe_roots}|}],
           Some t );
  remove_bprop_root t;
  (diff.zero_grads, diff.backprop)

(** One-line summary: ["#id label dims ... [child ids]"]. *)
let header t =
  let v_dims_s = Tn.dims_to_string t.value in
  let g_dims_s =
    match t.diff with None -> "<no-grad>" | Some diff -> Tn.dims_to_string diff.grad
  in
  let dims_s =
    if String.equal v_dims_s g_dims_s then "dims " ^ v_dims_s
    else "dims val " ^ v_dims_s ^ " grad " ^ g_dims_s
  in
  "#" ^ Int.to_string t.id ^ " " ^ Tn.label t.value ^ " " ^ dims_s ^ " ["
  ^ String.concat ~sep:","
      (List.map t.children ~f:(fun { subtensor = { id; _ }; _ } -> Int.to_string id))
  ^ "]"
(*^" "^PrintBox_text.to_string (PrintBox.Simple.to_box v.label)*)

(** Renders a lazily-computed optional payload without forcing it: if [v] has not been forced
    yet, shows [missing ()] with ["<not-in-yet> "]; if it was forced to [None], shows
    [missing ()] with ["<void>"]; otherwise renders the payload with [present]. *)
let lazy_optional_payload ~present ~missing v =
  if Lazy.is_val v then
    match Lazy.force v with
    | Some p -> present p
    | None -> `Vlist (false, [ `Text (missing ()); `Text "<void>" ])
  else `Vlist (false, [ `Text (missing ()); `Text "<not-in-yet> " ])

type array_print_style =
  [ `Default | `Inline | `Label_layout of (string * int) list | `N5_layout of string ]

(** Builds a [PrintBox_utils.dag] for [t] and (unless [single_node]) its children. Non-embedded
    children are rendered as references by id. Array payloads that have not been forced yet are
    not forced here; they are shown as not-yet-available instead. *)
let to_dag ?(single_node = false) ?entries_per_axis ~with_shape ~with_id ~with_value ~with_grad t
    =
  let rec to_dag { subtensor = t; embedded } : PrintBox_utils.dag =
    let id = Int.to_string t.id in
    let children = if single_node then [] else List.map ~f:to_dag t.children in
    let indices = Shape.default_display_indices t.shape in
    let labels = Shape.to_labels t.shape in
    let where_located a =
      match a.Tn.memory_mode with
      | None -> "<waiting>"
      | Some (m, prov) ->
          [%string "<%{Sexp.to_string_hum @@ Tn.sexp_of_memory_mode m} %{prov#Int}>"]
    in
    let txt =
      if with_id then "#" ^ id ^ " " ^ Tn.label t.value
        (* ^ " DEBUG: " ^ where_located t.value *)
      else Tn.label t.value
    in
    let grad_txt diff =
      let label = Tn.label diff.grad in
      let label =
        if String.is_substring (String.lowercase label) ~substring:"grad" then label
        else label ^ " Gradient"
      in
      if with_id then "#" ^ Int.to_string diff.grad.id ^ " " ^ label
        (* ^ " DEBUG: " ^ where_located diff.grad *)
      else label
    in
    let add_shape nodes =
      if with_shape then
        let shape = `Box (PrintBox.asprintf "%a" Sexp.pp_hum ([%sexp_of: Shape.t] t.shape)) in
        `Vlist (false, nodes @ [ shape ])
      else `Vlist (false, nodes)
    in
    match (not embedded, with_value, with_grad, t.diff) with
    | true, _, _, _ -> `Embed_subtree_ID (Int.to_string t.id)
    | _, false, false, _ | _, false, true, None ->
        `Subtree_with_ID (id, `Tree (add_shape [ `Text txt ], children))
    | _, true, false, _ | _, true, true, None ->
        let node =
          lazy_optional_payload t.value.array
            ~present:(fun v_array ->
              `Box
                (Nd.render_array ~brief:true ~prefix:txt ?entries_per_axis ~labels ~indices
                   v_array))
            ~missing:(fun () -> txt ^ " " ^ where_located t.value)
        in
        `Subtree_with_ID (id, `Tree (add_shape [ node ], children))
    | _, false, true, Some diff ->
        let prefix = grad_txt diff in
        (* Like the other branches, do not force a gradient array that is not forced yet:
           displaying a tensor should not have the side effects of computing the lazy. *)
        let node =
          lazy_optional_payload diff.grad.array
            ~present:(fun g_array ->
              `Box
                (Nd.render_array ~brief:true ~prefix ?entries_per_axis ~labels ~indices g_array))
            ~missing:(fun () -> prefix ^ " " ^ where_located diff.grad)
        in
        `Subtree_with_ID (id, `Tree (add_shape [ node ], children))
    | _, true, true, Some diff ->
        let node =
          let value =
            lazy_optional_payload t.value.array
              ~present:(fun v_array ->
                `Box
                  (Nd.render_array ~brief:true ~prefix:txt ?entries_per_axis ~labels ~indices
                     v_array))
              ~missing:(fun () -> txt ^ " " ^ where_located t.value)
          in
          let grad =
            lazy_optional_payload diff.grad.array
              ~present:(fun g_array ->
                `Box
                  (Nd.render_array ~brief:true ~prefix:(grad_txt diff) ?entries_per_axis ~labels
                     ~indices g_array))
              ~missing:(fun () -> grad_txt diff ^ " " ^ where_located diff.grad)
          in
          `Vlist (false, [ value; grad ])
        in
        `Subtree_with_ID (id, `Tree (add_shape [ node ], children))
  in
  to_dag { subtensor = t; embedded = true }

(** [to_dag] followed by [PrintBox_utils.reformat_dag depth]. *)
let to_printbox ?single_node ?entries_per_axis ?(with_id = false) ?(with_shape = false)
    ?(with_value = true) ~with_grad ~depth t =
  to_dag ?single_node ?entries_per_axis ~with_id ~with_shape ~with_value ~with_grad t
  |> PrintBox_utils.reformat_dag depth

(** Prints [t] to stdout: its value array (only if already forced, unless [~force:true]),
    optionally its gradient array, its forward/backprop code ([with_code]) and the low-level
    form of that code ([with_low_level]).
    @raise Session_error for a [`Label_layout] with repeated or unknown labels.
    @raise Invalid_argument for an [`N5_layout] with non-integer labels. *)
let print ~with_grad ~with_code ?(force = false) ?(with_low_level = false)
    (style : array_print_style) t =
  let sh = t.shape in
  let label = Tn.label t.value in
  let prefix =
    "[" ^ Int.to_string t.id ^ "]: " ^ label ^ " shape "
    ^ Shape.to_string_hum ~style:`Axis_number_and_size sh
    ^ " "
  in
  let grad_txt diff =
    let label = Tn.label diff.grad in
    if String.is_substring (String.lowercase label) ~substring:"grad" then label
    else label ^ " Gradient"
  in
  let labels = Shape.to_labels t.shape in
  let indices =
    match style with
    | `Default -> Shape.default_display_indices sh
    | `N5_layout priorities ->
        let f : (string, int) Either.t -> int = function
          | Either.Second i -> i
          | First _ -> invalid_arg "`N5_layout requires integer-only labels"
        in
        let p_labels = Shape.(axis_labels @@ axis_labels_of_spec priorities) in
        (Shape.axis_map_to_dims_index p_labels : (string, int) Either.t array) |> Array.map ~f
    | `Label_layout label_idcs ->
        let inv_labels =
          Array.mapi labels ~f:(fun i l -> (l, i))
          |> Array.to_list
          |> Map.of_alist (module String)
        in
        let inv_labels =
          match inv_labels with
          | `Duplicate_key l ->
              raise @@ Session_error ("`Label_layout found a repeating label: " ^ l, Some t)
          | `Ok inv_labels -> inv_labels
        in
        let result = Array.create ~len:(Array.length labels) 0 in
        List.iter label_idcs ~f:(fun (l, priority) ->
            match Map.find inv_labels l with
            | Some pos -> result.(pos) <- priority
            | None ->
                raise @@ Session_error ("`Label_layout label not found in shape: " ^ l, Some t));
        result
    | `Inline -> [||]
  in
  let needs_spec =
    Array.exists ~f:(Fn.non String.is_empty) labels
    || Shape.(List.exists ~f:Row.(equal_dim @@ get_dim ~d:1 ()) sh.input.dims)
  in
  let axes_spec = if needs_spec then Some (Shape.to_string_hum ~style:`Only_labels sh) else None in
  let num_batch_axes = List.length sh.batch.dims in
  let num_input_axes = List.length sh.input.dims in
  let num_output_axes = List.length sh.output.dims in
  (* TODO: code sharing with [to_dag] *)
  (if not (force || Lazy.is_val t.value.array) then
     Stdlib.Format.printf "%s <not-in-yet>@ " prefix
   else
     match (style, t.value.array) with
     | `Inline, (lazy None) -> Stdlib.Format.printf "<virtual>@ "
     | `Inline, (lazy (Some arr)) ->
         Nd.pp_array_inline (Stdlib.Format.get_std_formatter ()) ~num_batch_axes ~num_input_axes
           ~num_output_axes ?axes_spec arr
     | _, (lazy None) -> Stdlib.Format.printf "<virtual>@ "
     | _, (lazy (Some arr)) ->
         Nd.pp_array (Stdlib.Format.get_std_formatter ()) ~prefix ~labels ~indices arr;
         Stdlib.Format.print_newline ());
  if with_grad then
    Option.iter t.diff ~f:(fun diff ->
        if not (force || Lazy.is_val diff.grad.array) then
          Stdlib.Format.printf "%s <not-in-yet>@ " (grad_txt diff)
        else
          match (style, diff.grad.array) with
          | `Inline, (lazy (Some arr)) ->
              Nd.pp_array_inline (Stdlib.Format.get_std_formatter ()) ~num_batch_axes
                ~num_input_axes ~num_output_axes ?axes_spec arr;
              Stdlib.Format.print_newline ()
          | _, (lazy (Some arr)) ->
              Nd.pp_array (Stdlib.Format.get_std_formatter ())
                ~prefix:(prefix ^ " " ^ grad_txt diff)
                ~labels ~indices arr;
              Stdlib.Format.print_newline ()
          | _, (lazy None) -> Stdlib.Format.printf "%s <virtual>@ " (grad_txt diff));
  if with_code then (
    (match t.forward with
    | Noop -> ()
    | fwd_code ->
        Stdlib.Format.printf "@[<v 2>Current forward body:%a@]@," (Asgns.fprint_hum ()) fwd_code);
    match t.diff with
    | Some { backprop = Noop; _ } -> ()
    | Some { backprop = bwd_code; _ } ->
        Stdlib.Format.printf "@[<v 2>Current backprop body:%a@]@," (Asgns.fprint_hum ()) bwd_code
    | None -> ());
  if with_low_level then (
    (match t.forward with
    | Noop -> ()
    | fwd_code ->
        Stdlib.Format.printf "@[<v 2>Current forward low-level body:%a@]@,"
          (Arrayjit.Low_level.fprint_hum ())
        @@ Asgns.to_low_level fwd_code);
    match t.diff with
    | Some { backprop = Noop; _ } -> ()
    | Some { backprop = bwd_code; _ } ->
        Stdlib.Format.printf "@[<v 2>Current backprop low-level body:%a@]@,"
          (Arrayjit.Low_level.fprint_hum ())
        @@ Asgns.to_low_level bwd_code
    | None -> ());
  Stdlib.Format.printf "\n%!"
(** Prints every current forward root, in increasing id order, via [print]. *)
let print_forward_roots ~with_grad ~with_code (style : array_print_style) =
  let roots = Map.to_alist ~key_order:`Increasing session_state.forward_roots in
  roots
  |> List.iter ~f:(fun (key, tensor) ->
         (* The roots map is keyed by tensor id. *)
         assert (Int.equal key tensor.id);
         print ~with_grad ~with_code style tensor)

(** Renders [t] (see [to_dag]) as a box tree cut at [depth] and writes it to stdout. *)
let print_tree ?entries_per_axis ?(with_backend_info = false) ?(with_id = true)
    ?(with_shape = false) ?(with_value = true) ~with_grad ~depth t =
  (* FIXME: print backend info *)
  ignore with_backend_info;
  let dag = to_dag ?entries_per_axis ~with_id ~with_shape ~with_value ~with_grad t in
  let box = PrintBox_utils.dag_to_box (PrintBox_utils.boxify depth dag) in
  PrintBox_text.output Stdio.stdout box

(** 1D points of the value array; empty when the array is absent. Forces the lazy array. *)
let value_1d_points ?from_axis ~xdim t =
  match Lazy.force t.value.array with
  | None -> [||]
  | Some arr -> Nd.retrieve_1d_points ?from_axis ~xdim arr

(** 2D points of the value array; empty when the array is absent. Forces the lazy array. *)
let value_2d_points ?from_axis ~xdim ~ydim t =
  match Lazy.force t.value.array with
  | None -> [||]
  | Some arr -> Nd.retrieve_2d_points ?from_axis ~xdim ~ydim arr

(** 1D points of the gradient array; empty without a gradient or when its array is absent. *)
let grad_1d_points ?from_axis ~xdim t =
  t.diff
  |> Option.bind ~f:(fun d -> Lazy.force d.grad.array)
  |> Option.value_map ~default:[||] ~f:(fun arr -> Nd.retrieve_1d_points ?from_axis ~xdim arr)

(** 2D points of the gradient array; empty without a gradient or when its array is absent. *)
let grad_2d_points ?from_axis ~xdim ~ydim t =
  t.diff
  |> Option.bind ~f:(fun d -> Lazy.force d.grad.array)
  |> Option.value_map ~default:[||] ~f:(fun arr ->
         Nd.retrieve_2d_points ?from_axis ~xdim ~ydim arr)

(** Setter for single cells of the value array; raises if the array is absent. *)
let set_value t =
  let arr = Option.value_exn (Lazy.force t.value.array) in
  Nd.set_from_float arr

(** Getter for single cells of the value array; raises if the array is absent. *)
let get_value t =
  let arr = Option.value_exn (Lazy.force t.value.array) in
  Nd.get_as_float arr

(** Setter for single cells of the gradient array; raises without a gradient or array. *)
let set_grad t =
  let grad = (Option.value_exn t.diff).grad in
  let arr = Option.value_exn (Lazy.force grad.array) in
  Nd.set_from_float arr

(** Getter for single cells of the gradient array; raises without a gradient or array. *)
let get_grad t =
  let grad = (Option.value_exn t.diff).grad in
  let arr = Option.value_exn (Lazy.force grad.array) in
  Nd.get_as_float arr

(** Refills the whole value array from [values] (non-strict constant fill). *)
let set_values t values =
  let arr = Option.value_exn (Lazy.force t.value.array) in
  Nd.(reset (Constant_fill { values; strict = false }) arr)

(** All entries of the value array, flattened. *)
let get_values t =
  let arr = Option.value_exn (Lazy.force t.value.array) in
  Nd.retrieve_flat_values arr
sectionYPositions = computeSectionYPositions($el), 10)"
x-init="setTimeout(() => sectionYPositions = computeSectionYPositions($el), 10)"
>