package rune
sectionYPositions = computeSectionYPositions($el), 10)"
x-init="setTimeout(() => sectionYPositions = computeSectionYPositions($el), 10)"
>
Automatic differentiation and JIT compilation for OCaml
Install
dune-project
Dependency
Authors
Maintainers
Sources
raven-1.0.0.alpha2.tbz
sha256=93abc49d075a1754442ccf495645bc4fdc83e4c66391ec8aca8fa15d2b4f44d2
sha512=5eb958c51f30ae46abded4c96f48d1825f79c7ce03f975f9a6237cdfed0d62c0b4a0774296694def391573d849d1f869919c49008acffca95946b818ad325f6f
doc/src/rune.jit/shape_expr.ml.html
Source file shape_expr.ml
(* Minimal symbolic shape expressions for Rune JIT.

   This module mirrors the capabilities we need from Nx's Symbolic_shape
   while keeping lib-jit independent from Nx.

   Shapes are arrays of expressions. Each expression can be a constant, a
   symbolic variable, or basic arithmetic combinations (addition,
   multiplication, negation). Variables carry a unique id, an optional
   user-facing name, and bounds. *)

module Var = struct
  (* A symbolic dimension variable. An empty [name] means the variable is
     printed by id only. [min] and [max] are the bounds checked by [create]. *)
  type t = { id : int; name : string; min : int; max : int }

  (* [create ~id ~name ~min ~max] builds a variable.
     @raise Invalid_argument if [min] is negative or [min > max]. *)
  let create ~id ~name ~min ~max =
    if min < 0 then
      invalid_arg "Shape_expr.Var.create: min must be non-negative";
    if min > max then invalid_arg "Shape_expr.Var.create: min must be <= max";
    { id; name; min; max }

  let id v = v.id
  let name v = v.name
  let min v = v.min
  let max v = v.max
end

type expr =
  | Const of int
  | Var of Var.t
  | Add of expr * expr
  | Mul of expr * expr
  | Neg of expr

type shape = expr array

(* Smart constructors. *)
let const n = Const n
let var v = Var v
let add a b = Add (a, b)
let mul a b = Mul (a, b)
let neg e = Neg e

(* Lift a concrete shape into a shape of constant expressions. *)
let of_int_array arr = Array.map const arr

(* Fully parenthesised rendering of an expression. Named variables print as
   name#id, unnamed ones as v followed by the id. *)
let rec to_string_expr = function
  | Const n -> string_of_int n
  | Var v ->
      if Var.name v = "" then Printf.sprintf "v%d" (Var.id v)
      else Printf.sprintf "%s#%d" (Var.name v) (Var.id v)
  | Add (a, b) ->
      Printf.sprintf "(%s + %s)" (to_string_expr a) (to_string_expr b)
  | Mul (a, b) ->
      Printf.sprintf "(%s * %s)" (to_string_expr a) (to_string_expr b)
  | Neg e -> Printf.sprintf "(-%s)" (to_string_expr e)

(* Render a shape as a bracketed, semicolon-separated list. *)
let to_string shape =
  "[" ^ String.concat "; " (Array.to_list (Array.map to_string_expr shape)) ^ "]"

(* [eval_expr bindings e] evaluates [e], looking variables up by id in the
   association list [bindings]. Returns [None] if any variable reached during
   evaluation has no binding. *)
let rec eval_expr bindings expr =
  (* Apply [op] only when both operands evaluate. *)
  let lift2 op a b =
    match (eval_expr bindings a, eval_expr bindings b) with
    | Some x, Some y -> Some (op x y)
    | _ -> None
  in
  match expr with
  | Const n -> Some n
  | Var v -> List.assoc_opt (Var.id v) bindings
  | Add (a, b) -> lift2 ( + ) a b
  | Mul (a, b) -> lift2 ( * ) a b
  | Neg e -> Option.map (fun x -> -x) (eval_expr bindings e)

(* Evaluate every dimension of [shape]; unresolved dimensions are [None]. *)
let eval bindings shape = Array.map (eval_expr bindings) shape

(* Like [eval], but every dimension must resolve.
   @raise Invalid_argument naming the index of the first unresolved dimension. *)
let to_int_array_exn bindings shape =
  let evaluated = eval bindings shape in
  Array.mapi
    (fun i -> function
      | Some n -> n
      | None ->
          invalid_arg (Printf.sprintf "Shape_expr: dimension %d unresolved" i))
    evaluated

let map f shape = Array.map f shape

(* Pointwise combination of two shapes.
   @raise Invalid_argument if the ranks differ. *)
let map2 f s1 s2 =
  if Array.length s1 <> Array.length s2 then
    invalid_arg "Shape_expr.map2: shape rank mismatch";
  Array.map2 f s1 s2

let fold f init shape = Array.fold_left f init shape

(* Interval (lo, hi) containing every value [expr] can take while each
   variable ranges between its [Var.min] and [Var.max]. Sound for negative
   constants and nested negation as well. Like the rest of this module, it
   uses native int arithmetic without overflow checks. *)
let rec bounds_expr = function
  | Const n -> (n, n)
  | Var v -> (Var.min v, Var.max v)
  | Add (a, b) ->
      let (la, ha) = bounds_expr a and (lb, hb) = bounds_expr b in
      (la + lb, ha + hb)
  | Mul (a, b) ->
      let (la, ha) = bounds_expr a and (lb, hb) = bounds_expr b in
      (* Either factor may be negative, so the extremes of the product lie
         among the four corner products. *)
      let p1 = la * lb and p2 = la * hb and p3 = ha * lb and p4 = ha * hb in
      ( Stdlib.min (Stdlib.min p1 p2) (Stdlib.min p3 p4),
        Stdlib.max (Stdlib.max p1 p2) (Stdlib.max p3 p4) )
  | Neg e ->
      let (lo, hi) = bounds_expr e in
      (-hi, -lo)

(* Upper bound of an expression over all admissible variable values. For
   expressions built only from non-negative constants, variables, Add and Mul
   this is the sum/product of the operand maxima. *)
let upper_bound_expr expr = snd (bounds_expr expr)

(* Per-dimension upper bounds of a shape. *)
let upper_bounds shape = Array.map upper_bound_expr shape
sectionYPositions = computeSectionYPositions($el), 10)"
x-init="setTimeout(() => sectionYPositions = computeSectionYPositions($el), 10)"
>