package nx

  1. Overview
  2. Docs

Source file symbolic_shape.ml

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
(* Symbolic shapes for shape-polymorphic tensors. *)

(** A symbolic dimension variable.

    [min] and [max] are inclusive bounds that are checked when a value is
    bound via [bind_var]. [value] is [None] while the variable is unbound and
    [Some n] once bound. Binding mutates this field in place, so the new
    value is visible through every expression that holds this same record. *)
type var = { name : string; min : int; max : int; mutable value : int option }

(** Integer expressions over constants and dimension variables. *)
type expr =
  | Const of int  (** A known integer. *)
  | Var of var  (** A reference to a (possibly unbound) variable. *)
  | Add of expr * expr  (** Sum of two expressions. *)
  | Mul of expr * expr  (** Product of two expressions. *)
  | Neg of expr  (** Arithmetic negation. *)

(** A single dimension of a shape; an alias for [expr]. *)
type dim = expr

(** A shape: one dimension per axis, so its length is the rank. *)
type t = dim array

(** [static n] is the constant dimension [n].
    Raises via [Error.invalid] when [n] is negative; zero is accepted. *)
let static n =
  if n >= 0 then Const n
  else
    Error.invalid ~op:"static"
      ~what:(Printf.sprintf "dimension %d" n)
      ~reason:"negative dimension" ()

(** [dynamic name ~min ~max] creates a fresh, unbound variable dimension
    named [name]. Any value later bound to it must lie in the inclusive range
    [[min, max]].
    Raises via [Error.invalid] if [min] is negative or greater than [max];
    the negativity check is performed first. *)
let dynamic name ~min ~max =
  let fail what reason = Error.invalid ~op:"dynamic" ~what ~reason () in
  if min < 0 then fail (Printf.sprintf "min=%d" min) "must be non-negative"
  else if min > max then
    fail (Printf.sprintf "bounds [%d, %d]" min max) "min > max"
  else Var { name; min; max; value = None }

(** [add lhs rhs] is the symbolic sum of [lhs] and [rhs]. *)
let add lhs rhs = Add (lhs, rhs)

(** [mul lhs rhs] is the symbolic product of [lhs] and [rhs]. *)
let mul lhs rhs = Mul (lhs, rhs)

(** [neg inner] is the symbolic negation of [inner]. *)
let neg inner = Neg inner

(** [of_ints arr] lifts an integer array to a static shape, element by
    element through [static] (which rejects negative entries). *)
let of_ints arr = Array.map static arr

(** [of_list lst] is [of_ints] applied to the elements of [lst]. *)
let of_list lst = lst |> Array.of_list |> of_ints

(** [bind_var var value] sets [var]'s value to [value], mutating the record
    in place.
    Raises via [Error.invalid] if [value] lies outside the inclusive range
    [[var.min, var.max]]; in that case [var] is left unchanged. *)
let bind_var var value =
  let { name; min = lo; max = hi; _ } = var in
  if value >= lo && value <= hi then var.value <- Some value
  else
    Error.invalid ~op:"bind"
      ~what:(Printf.sprintf "value %d for variable %s" value name)
      ~reason:(Printf.sprintf "outside bounds [%d, %d]" lo hi)
      ()

(** [bind name value shape] binds every variable called [name] that occurs
    anywhere in [shape] to [value], using [bind_var] (so each binding is
    bounds-checked against that variable's own range). Dimensions are visited
    left to right, and within an expression the left operand comes first.
    Variables with other names are left untouched. *)
let bind name value shape =
  let rec visit = function
    | Const _ -> ()
    | Var v -> if String.equal v.name name then bind_var v value
    | Add (lhs, rhs) | Mul (lhs, rhs) ->
        visit lhs;
        visit rhs
    | Neg inner -> visit inner
  in
  Array.iter visit shape

(** [eval_expr e] evaluates [e] to an integer, or returns [None] if any
    variable it mentions is still unbound. *)
let rec eval_expr expr =
  (* Evaluate both operands and combine them with [op]; [None] if either
     operand is unevaluable. *)
  let combine op lhs rhs =
    match eval_expr lhs with
    | None -> None
    | Some x -> Option.map (op x) (eval_expr rhs)
  in
  match expr with
  | Const n -> Some n
  | Var { value; _ } -> value
  | Add (lhs, rhs) -> combine ( + ) lhs rhs
  | Mul (lhs, rhs) -> combine ( * ) lhs rhs
  | Neg inner -> Option.map Int.neg (eval_expr inner)

(** Alias of [eval_expr] for a single shape dimension. *)
let eval_dim = eval_expr

(** [eval shape] evaluates every dimension of [shape] to a concrete integer.
    Returns [None] if any dimension still depends on an unbound variable.
    An empty shape yields [Some [||]]. *)
let eval shape =
  let exception Unbound in
  let force dim =
    match eval_dim dim with Some n -> n | None -> raise_notrace Unbound
  in
  match Array.map force shape with
  | dims -> Some dims
  | exception Unbound -> None

(** [partial_eval shape] evaluates each dimension independently: bound
    dimensions become [Some n], unevaluable ones [None]. *)
let partial_eval shape = shape |> Array.map (fun dim -> eval_expr dim)

(** [expr_is_bound e] is [true] when every variable occurring in [e] has a
    value. Constant-only expressions are trivially bound. *)
let rec expr_is_bound expr =
  match expr with
  | Const _ -> true
  | Var { value = Some _; _ } -> true
  | Var { value = None; _ } -> false
  | Neg inner -> expr_is_bound inner
  | Add (lhs, rhs) | Mul (lhs, rhs) -> expr_is_bound lhs && expr_is_bound rhs

(** [is_fully_bound shape] is [true] when every dimension of [shape] is
    bound (vacuously [true] for an empty shape). *)
let is_fully_bound shape = Array.for_all (fun dim -> expr_is_bound dim) shape

(** [expr_vars e] lists the variables occurring in [e] in left-to-right
    order, including repeats. *)
let expr_vars expr =
  (* Thread an accumulator and visit the right operand first, so each
     variable is consed exactly once while the final list stays in
     left-to-right order. Concatenating sub-results with [@] at every
     [Add]/[Mul] node would re-copy lists and go quadratic on deep,
     left-nested expressions. *)
  let rec collect acc = function
    | Const _ -> acc
    | Var v -> v :: acc
    | Add (lhs, rhs) | Mul (lhs, rhs) -> collect (collect acc rhs) lhs
    | Neg inner -> collect acc inner
  in
  collect [] expr

(** [vars shape] lists the variables appearing in [shape], sorted by name.
    Duplicates are removed by comparing [name] only, so two distinct variable
    records that share a name are reported once. *)
let vars shape =
  shape |> Array.to_list |> List.concat_map expr_vars
  |> List.sort_uniq (fun v1 v2 -> String.compare v1.name v2.name)

(** [expr_is_static e] is [true] when [e] contains no variables at all,
    regardless of whether any variables are bound. *)
let rec expr_is_static expr =
  match expr with
  | Var _ -> false
  | Const _ -> true
  | Neg inner -> expr_is_static inner
  | Add (lhs, rhs) | Mul (lhs, rhs) ->
      expr_is_static lhs && expr_is_static rhs

(** [is_static shape] is [true] when no dimension of [shape] mentions a
    variable. *)
let is_static shape = Array.for_all (fun dim -> expr_is_static dim) shape

(** [rank shape] is the number of dimensions in [shape]. *)
let rank shape =
  let n = Array.length shape in
  n

(** [to_string shape] renders [shape] as ["[d1; d2; ...]"]. Compound
    expressions are fully parenthesised; a bound variable is shown as
    ["name=value"] and an unbound one by its name alone. *)
let to_string shape =
  let rec render expr =
    match expr with
    | Const n -> string_of_int n
    | Var { name; value = None; _ } -> name
    | Var { name; value = Some n; _ } -> Printf.sprintf "%s=%d" name n
    | Add (lhs, rhs) -> Printf.sprintf "(%s + %s)" (render lhs) (render rhs)
    | Mul (lhs, rhs) -> Printf.sprintf "(%s * %s)" (render lhs) (render rhs)
    | Neg inner -> Printf.sprintf "(-%s)" (render inner)
  in
  let body = shape |> Array.to_list |> List.map render |> String.concat "; " in
  String.concat "" [ "["; body; "]" ]

(** [expr_equal a b] is structural equality on expressions, except that
    variables are compared by physical identity ([==]). Two distinct
    variable records are therefore unequal even if they share a name and
    bounds. *)
let rec expr_equal lhs rhs =
  match (lhs, rhs) with
  | Const a, Const b -> Int.equal a b
  | Var a, Var b -> a == b
  | Add (l1, r1), Add (l2, r2) -> expr_equal l1 l2 && expr_equal r1 r2
  | Mul (l1, r1), Mul (l2, r2) -> expr_equal l1 l2 && expr_equal r1 r2
  | Neg a, Neg b -> expr_equal a b
  | (Const _ | Var _ | Add _ | Mul _ | Neg _), _ -> false

(** [equal s1 s2] holds when both shapes have the same rank and their
    dimensions are pairwise [expr_equal]. *)
let equal s1 s2 =
  let len = Array.length s1 in
  len = Array.length s2
  &&
  let rec same_from i = i >= len || (expr_equal s1.(i) s2.(i) && same_from (i + 1)) in
  same_from 0

(** [numel shape] is the product of all dimensions of [shape], or [None] if
    any dimension cannot be evaluated. The empty shape has [Some 1] element.
    Once a dimension fails to evaluate, the remaining ones are not
    evaluated. *)
let numel shape =
  Array.fold_left
    (fun acc dim ->
      match acc with
      | None -> None
      | Some product -> Option.map (fun d -> product * d) (eval_dim dim))
    (Some 1) shape

(** Special dimension value meaning "infer from context", like [-1] in
    NumPy's [reshape]. *)
let infer = Const (-1)

(** [is_infer dim] is [true] when [dim] evaluates to [-1], the [infer]
    marker. This compares values, not structure: any fully bound expression
    that evaluates to [-1] is also treated as the marker, and an unevaluable
    dimension never is. *)
let is_infer dim =
  match eval_dim dim with
  | None -> false
  | Some n -> n = -1

(** [resolve_reshape ~from_shape ~to_shape] fills in at most one [infer]
    dimension of [to_shape] so that its element count matches [from_shape].

    Returns [None] in these cases:
    - [from_shape] is not fully known.
    - The inferred dimension cannot be computed because another dimension
      of [to_shape] is still symbolic.
    - The product of the known dimensions is zero.
    - The known dimensions do not evenly divide the total element count.

    When [to_shape] has no [infer] dimension, it is returned (as a copy)
    unchanged; its size is not checked against [from_shape] here.
    Zero-sized dimensions are accepted, consistent with [static].

    Raises via [Error.invalid] if an evaluable dimension of [to_shape] is
    negative (other than the [infer] marker) or if more than one dimension
    is [infer]. *)
let resolve_reshape ~from_shape ~to_shape =
  match numel from_shape with
  | None -> None (* Can't resolve if source shape is not fully known *)
  | Some total_elements -> (
      let resolved = Array.copy to_shape in
      let infer_indices = ref [] in
      (* Product of the evaluable, non-infer dimensions of [to_shape]. *)
      let known_product = ref 1 in
      (* Whether some dimension of [to_shape] could not be evaluated; such a
         dimension contributes an unknown factor to the element count. *)
      let has_symbolic = ref false in

      Array.iteri
        (fun i dim ->
          if is_infer dim then infer_indices := i :: !infer_indices
          else
            match eval_dim dim with
            | Some n when n >= 0 -> known_product := !known_product * n
            | Some n ->
                Error.invalid ~op:"resolve_reshape"
                  ~what:(Printf.sprintf "dimension %d" i)
                  ~reason:(Printf.sprintf "invalid size %d" n)
                  ()
            | None -> has_symbolic := true (* Keep symbolic dimension as is *))
        to_shape;

      match !infer_indices with
      | [] -> Some resolved (* No inference needed *)
      | [ idx ] ->
          (* With an unknown symbolic factor the inferred size cannot be
             determined; dividing by the known product alone would be
             wrong. *)
          if !has_symbolic then None
            (* A zero known product makes the inferred size ambiguous and
               would divide by zero. *)
          else if !known_product > 0 && total_elements mod !known_product = 0
          then (
            let inferred_size = total_elements / !known_product in
            resolved.(idx) <- static inferred_size;
            Some resolved)
          else None (* Can't evenly divide *)
      | _ ->
          Error.invalid ~op:"resolve_reshape" ~what:"shape"
            ~reason:"can only infer one dimension" ())

(** [substitute bindings shape] returns a new shape in which every variable
    whose name appears in the association list [bindings] is replaced by the
    corresponding constant. The first matching entry wins, and variables
    without an entry are kept as they are. [shape] itself is not modified.
    Substituted values are not checked against the variable's bounds (unlike
    [bind_var]). *)
let substitute bindings shape =
  let rec subst expr =
    match expr with
    | Const n -> Const n
    | Var { name; _ } as var -> (
        match List.assoc_opt name bindings with
        | Some n -> Const n
        | None -> var)
    | Add (lhs, rhs) -> Add (subst lhs, subst rhs)
    | Mul (lhs, rhs) -> Mul (subst lhs, subst rhs)
    | Neg inner -> Neg (subst inner)
  in
  Array.map subst shape