Module Nx_core.Symbolic_shape
Symbolic dimensions for shape-polymorphic tensors.
This module enables shape-polymorphic tensor operations by representing tensor dimensions as symbolic expressions that can be resolved at runtime. Symbolic shapes support dynamic batch sizes, variable sequence lengths, and other runtime-determined dimensions in compiled kernels.
Overview
A symbolic shape consists of dimension expressions that may be static constants, symbolic variables, or arithmetic combinations. Variables can be bound to concrete values at runtime, enabling a single compiled kernel to handle multiple input shapes.
Create shapes with of_ints for static dimensions or use dynamic to introduce symbolic variables. Variables created with var are distinct even if they share the same name, preventing accidental aliasing.
Key Concepts
Variables and Identity
Each call to var creates a fresh variable with a unique identity. Variables with the same name are distinct:
let v1 = Symbolic_shape.var "x" ~min:0 ~max:10 in
let v2 = Symbolic_shape.var "x" ~min:0 ~max:10 in
assert (Symbolic_shape.var_id v1 <> Symbolic_shape.var_id v2) (* Distinct variables *)
This prevents unintended variable sharing across independent operations.
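One way to get per-call identity is to stamp each variable from a global counter. The sketch below assumes a record layout and a counter (`next_id`, `same_var`, and the `lo`/`hi` fields are hypothetical names, not this module's definitions); it only illustrates why two calls with the same name and bounds yield distinct variables.

```ocaml
(* Hypothetical model of a variable: a record stamped with a unique id. *)
type var = { id : int; name : string; lo : int; hi : int }

(* Hypothetical global counter used to hand out fresh identities. *)
let next_id = ref 0

(* Every call allocates a new id, whatever name and bounds are passed. *)
let var name ~min ~max =
  incr next_id;
  { id = !next_id; name; lo = min; hi = max }

(* Identity is the id alone; names and bounds play no part. *)
let same_var a b = a.id = b.id

let () =
  let v1 = var "x" ~min:0 ~max:10 in
  let v2 = var "x" ~min:0 ~max:10 in
  assert (not (same_var v1 v2));
  assert (same_var v1 v1)
```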
Dimension Expressions
Dimensions support arithmetic operations to express relationships:
let n = Symbolic_shape.var "n" ~min:1 ~max:100 in
let dim_n = Symbolic_shape.dim_of_var n in
let dim_2n = Symbolic_shape.mul dim_n (Symbolic_shape.static 2) in
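let dim_2n_plus_1 = Symbolic_shape.add dim_2n (Symbolic_shape.static 1) in
(* n is still unbound, so the expression cannot be evaluated yet *)
assert (Symbolic_shape.eval_dim dim_2n_plus_1 = None)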
Runtime Binding
Variables remain unbound until explicitly assigned values with bind. Binding checks that values respect variable bounds. Evaluation functions return None for shapes containing unbound variables.
Reshape with Inference
The special infer dimension (represented as -1) allows NumPy-style reshape operations where one dimension is computed from the total element count. resolve_reshape computes the inferred dimension:
let from_shape = Symbolic_shape.of_ints [| 2; 3; 4 |] in
let to_shape = [| Symbolic_shape.static 6; Symbolic_shape.infer |] in
match Symbolic_shape.resolve_reshape ~from_shape ~to_shape with
| Some resolved -> assert (Symbolic_shape.eval resolved = Some [| 6; 4 |])
| None -> assert false
A symbolic variable representing a dimension.
Variables have unique identities (compared by ID, not name or bounds) and mutable bindings that persist until changed. Each call to var creates a fresh variable even if the name is reused. Variables track minimum and maximum bounds to validate runtime values.
type expr =
  | Const of int  (* Static dimension with fixed value *)
  | Var of var  (* Symbolic variable *)
  | Add of expr * expr  (* Sum of two expressions *)
  | Mul of expr * expr  (* Product of two expressions *)
  | Neg of expr  (* Negation of an expression *)
A dimension is an expression.
Dimensions may be static constants, symbolic variables, or arithmetic combinations. Use static for constants, dim_of_var for variables, and add, mul, neg for arithmetic.
A shape is an array of dimensions.
Shapes are arrays of dimensions that the module treats as immutable: operations return new shapes instead of modifying their inputs. Each element is a dimension expression that may contain symbolic variables. Empty shapes represent scalars (rank-0 tensors).
Creation
static n creates a static dimension with value n.
Static dimensions have fixed values known at creation time.
val dynamic : string -> min:int -> max:int -> dim
dynamic name ~min ~max creates a dynamic dimension with bounds.
Creates a fresh symbolic variable and converts it to a dimension expression. Equivalent to dim_of_var (var name ~min ~max). The variable can be bound to any value in the range [min, max] at runtime.
val var : string -> min:int -> max:int -> var
var name ~min ~max creates a fresh symbolic variable.
Each call returns a distinct variable with a unique identity, regardless of name. Names are used for debugging and display purposes only. Variables remain unbound until explicitly assigned with bind.
The bounds [min, max] constrain valid runtime values.
dim_of_var var wraps a variable as a dimension expression.
of_ints arr creates a shape with all static dimensions.
Each element of arr is converted to a static dimension using static. This is the standard way to create concrete shapes.
of_list lst creates a shape with all static dimensions.
Equivalent to of_ints (Array.of_list lst).
Dimension Operations
add d1 d2 creates a dimension expression d1 + d2.
Constructs an Add expression representing the sum of two dimensions. Useful for expressing padding or concatenation dimensions.
mul d1 d2 creates a dimension expression d1 * d2.
Constructs a Mul expression representing the product of two dimensions. Useful for expressing flattened or tiled dimensions.
neg d creates a dimension expression -d.
Constructs a Neg expression representing the negation of a dimension. Rarely used directly; primarily for internal expression manipulation.
Runtime Binding
bind var value shape binds value to var globally and updates all occurrences of var in shape by identity.
Performs a global mutation of the variable's mutable state. The shape is traversed to find all instances matching var by identity, including those within compound expressions. The binding persists until changed.
Variables must be bound before shapes can be evaluated to concrete dimensions. Binding is checked against the variable's min and max bounds specified at creation.
Time complexity: O(n) where n is the total size of all expression trees in the shape.
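The effect of bind can be sketched with a mutable record that is shared by every Var node referring to it, so a single mutation is visible through all occurrences. This is a minimal model under assumptions: the record layout is hypothetical, the shape argument is dropped, and raising Invalid_argument on an out-of-bounds value is assumed (this page does not say how a violation is reported).

```ocaml
(* Hypothetical variable record with a mutable binding slot. *)
type var = { id : int; name : string; lo : int; hi : int; mutable value : int option }

type expr = Const of int | Var of var | Add of expr * expr | Mul of expr * expr | Neg of expr

(* Check the bounds, then mutate the shared record (global effect). *)
let bind v value =
  if value < v.lo || value > v.hi then
    invalid_arg (Printf.sprintf "bind %s#%d: %d outside [%d, %d]" v.name v.id value v.lo v.hi);
  v.value <- Some value

(* Minimal evaluator that reads the current binding of each Var node. *)
let rec eval_dim = function
  | Const n -> Some n
  | Var v -> v.value
  | Add (a, b) -> (match (eval_dim a, eval_dim b) with Some x, Some y -> Some (x + y) | _ -> None)
  | Mul (a, b) -> (match (eval_dim a, eval_dim b) with Some x, Some y -> Some (x * y) | _ -> None)
  | Neg a -> Option.map (fun x -> -x) (eval_dim a)

let () =
  let n = { id = 1; name = "n"; lo = 1; hi = 100; value = None } in
  (* The same record occurs in two dimensions of the shape. *)
  let shape = [| Var n; Add (Var n, Const 1) |] in
  assert (Array.for_all (fun d -> eval_dim d = None) shape);
  bind n 8;
  assert (Array.map eval_dim shape = [| Some 8; Some 9 |]);
  (* 0 lies outside [1, 100], so it is rejected and the binding is kept. *)
  assert (try bind n 0; false with Invalid_argument _ -> true)
```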
val eval : t -> int array option
eval shape returns the concrete shape if all dimensions are bound.
Evaluates all dimensions in shape to produce an integer array. Returns Some arr if all symbolic variables are bound and all expressions can be computed. Returns None if any dimension contains an unbound variable.
This is the primary way to extract concrete shapes for backend operations.
eval_dim dim evaluates a single dimension.
Returns Some n if dim is fully bound and can be evaluated to a concrete value. Returns None if dim contains unbound variables.
Evaluation is a depth-first recursion without memoization: each subexpression is evaluated and the results are combined with the corresponding operator.
val partial_eval : t -> int option array
partial_eval shape returns an array of evaluated dimensions.
Evaluates each dimension independently, returning Some n for bound dimensions and None for unbound dimensions. Unlike eval, this succeeds even when some dimensions remain symbolic.
Useful for debugging and displaying partially bound shapes.
val is_fully_bound : t -> bool
is_fully_bound shape returns true if all dimensions are bound.
Checks whether every symbolic variable in shape has been assigned a value. Static dimensions are always considered bound. Returns true if eval shape would succeed.
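The relationship between eval_dim, partial_eval, eval and is_fully_bound can be sketched over a hypothetical model of the expression tree. The record type and the helper `both` are illustrative assumptions; the point is that partial_eval maps eval_dim over each dimension independently, while eval is all-or-nothing.

```ocaml
(* Hypothetical model types; the real module's definitions may differ. *)
type var = { id : int; name : string; mutable value : int option }

type expr = Const of int | Var of var | Add of expr * expr | Mul of expr * expr | Neg of expr

(* Depth-first recursion without memoization. *)
let rec eval_dim = function
  | Const n -> Some n
  | Var v -> v.value
  | Add (a, b) -> both ( + ) a b
  | Mul (a, b) -> both ( * ) a b
  | Neg a -> Option.map (fun x -> -x) (eval_dim a)

(* Combine two subresults; None propagates from either side. *)
and both op a b =
  match (eval_dim a, eval_dim b) with
  | Some x, Some y -> Some (op x y)
  | _ -> None

(* Each dimension on its own; unbound ones stay None. *)
let partial_eval shape = Array.map eval_dim shape

(* Succeeds only when every dimension evaluates. *)
let eval shape =
  let dims = partial_eval shape in
  if Array.for_all Option.is_some dims then Some (Array.map Option.get dims) else None

let is_fully_bound shape = eval shape <> None

let () =
  let n = { id = 0; name = "n"; value = None } in
  let shape = [| Const 4; Add (Mul (Var n, Const 2), Const 1) |] in
  assert (partial_eval shape = [| Some 4; None |]);
  assert (eval shape = None && not (is_fully_bound shape));
  n.value <- Some 3;
  assert (eval shape = Some [| 4; 7 |])
```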
numel shape returns the total number of elements if shape is fully bound.
Computes the product of all dimensions if the shape can be fully evaluated. Returns Some 1 for empty shapes (scalars). Returns None if any dimension contains unbound variables.
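The numel rule amounts to a fold that starts from 1 and gives up as soon as a dimension is unknown. In this sketch the shape is simplified (an assumption) to the already-evaluated `int option array` that partial_eval produces.

```ocaml
(* Product of all dimensions; Some 1 for a scalar, None if any is unbound. *)
let numel (dims : int option array) : int option =
  Array.fold_left
    (fun acc d ->
      match (acc, d) with
      | Some p, Some n -> Some (p * n)
      | _ -> None)
    (Some 1) dims

let () =
  assert (numel [||] = Some 1);
  assert (numel [| Some 2; Some 3; Some 4 |] = Some 24);
  assert (numel [| Some 2; None |] = None)
```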
Special Values
Special dimension value representing "infer from context".
Equivalent to -1 in NumPy reshape operations. Use this in target shapes to indicate that a dimension should be computed from the total element count. At most one dimension in a shape may be infer; this constraint exists because multiple unknowns make element count calculation ambiguous. The constraint is enforced at runtime by resolve_reshape.
The resolve_reshape function computes the concrete value for inferred dimensions based on the source shape's element count.
is_infer dim returns true if dimension should be inferred.
Checks whether dim evaluates to -1, indicating it should be computed from context during reshape operations. Returns false for unbound variables, which cannot be evaluated to -1.
Shape Resolution
val resolve_reshape : from_shape:t -> to_shape:t -> t option
resolve_reshape ~from_shape ~to_shape resolves a reshape operation with inference.
Computes concrete dimensions for infer values in to_shape based on the element count of from_shape. At most one dimension in to_shape may be infer.
Returns Some resolved_shape if:
- from_shape is fully bound (all dimensions have concrete values)
- to_shape contains zero or one infer dimensions
- the total element count of from_shape divides evenly by the product of the known dimensions in to_shape
Returns None if:
- from_shape contains unbound variables
- the element count doesn't divide evenly (i.e., total_elements mod known_product <> 0)
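The inference arithmetic can be sketched on plain integer arrays, with -1 standing for infer. The helper `resolve_reshape_ints` is hypothetical and assumes both shapes are already evaluated; returning None for more than one infer dimension, and for a zero known product, are assumptions, since this page only says the constraint is enforced at runtime.

```ocaml
(* Hypothetical helper: from_dims is the evaluated source shape,
   to_dims uses -1 for the single dimension to infer. *)
let resolve_reshape_ints ~(from_dims : int array) ~(to_dims : int array) =
  let total = Array.fold_left ( * ) 1 from_dims in
  let infer_count = Array.fold_left (fun c d -> if d = -1 then c + 1 else c) 0 to_dims in
  (* Product of the known (non-inferred) target dimensions. *)
  let known = Array.fold_left (fun p d -> if d = -1 then p else p * d) 1 to_dims in
  (* Assumption: too many infers, or a zero known product, yields None. *)
  if infer_count > 1 || known = 0 || total mod known <> 0 then None
  else Some (Array.map (fun d -> if d = -1 then total / known else d) to_dims)

let () =
  (* 2 * 3 * 4 = 24 elements; 24 / 6 = 4 for the inferred dimension. *)
  assert (resolve_reshape_ints ~from_dims:[| 2; 3; 4 |] ~to_dims:[| 6; -1 |] = Some [| 6; 4 |]);
  (* 24 mod 5 <> 0, so the reshape cannot be resolved. *)
  assert (resolve_reshape_ints ~from_dims:[| 2; 3; 4 |] ~to_dims:[| 5; -1 |] = None)
```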
Examples
Resolving a reshape with one inferred dimension:
let from_shape = Symbolic_shape.of_ints [| 2; 3; 4 |] in
let to_shape = [| Symbolic_shape.static 6; Symbolic_shape.infer |] in
match Symbolic_shape.resolve_reshape ~from_shape ~to_shape with
| Some resolved -> assert (Symbolic_shape.eval resolved = Some [| 6; 4 |])
| None -> ()
substitute bindings shape substitutes variable bindings into shape.
Replaces variables in shape with their corresponding values from bindings, creating a new shape with Const nodes where variables were substituted. Variables not present in bindings remain as Var nodes. Unlike bind, this creates a new shape without mutating variable state.
Binding list format: (var, value) pairs where var is matched by its unique identity (not by name).
Useful for creating multiple specialized versions of a parametric shape without side effects.
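The non-mutating behaviour of substitute can be sketched as a tree rebuild that swaps matched Var nodes for Const nodes. The model types and the helper `subst_dim` are hypothetical; matching is by the `id` field only, so a different variable with the same name is left alone, and no binding slot is written.

```ocaml
(* Hypothetical model types. *)
type var = { id : int; name : string; mutable value : int option }

type expr = Const of int | Var of var | Add of expr * expr | Mul of expr * expr | Neg of expr

(* Rebuild the tree; a Var whose id appears in bindings becomes a Const. *)
let rec subst_dim bindings = function
  | Const n -> Const n
  | Var v as e -> (
      match List.find_opt (fun (v', _) -> v'.id = v.id) bindings with
      | Some (_, value) -> Const value
      | None -> e)
  | Add (a, b) -> Add (subst_dim bindings a, subst_dim bindings b)
  | Mul (a, b) -> Mul (subst_dim bindings a, subst_dim bindings b)
  | Neg a -> Neg (subst_dim bindings a)

let substitute bindings shape = Array.map (subst_dim bindings) shape

let () =
  let b = { id = 1; name = "batch"; value = None } in
  let other = { id = 2; name = "batch"; value = None } in
  let shape = [| Var b; Mul (Var other, Const 2) |] in
  let s8 = substitute [ (b, 8) ] shape in
  assert (s8.(0) = Const 8);
  (* Same name, different identity: left as a Var node. *)
  assert (match s8.(1) with Mul (Var v, Const 2) -> v == other | _ -> false);
  (* The original variable was not mutated. *)
  assert (b.value = None)
```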
Analysis
vars shape returns all unique symbolic variables in shape.
Extracts all distinct variables from dimension expressions. Variables are compared by identity, so the same variable object appears only once even if used in multiple dimensions or compound expressions.
Returns an empty list for shapes containing only static dimensions. The order of variables in the result is unspecified.
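Collecting variables with deduplication by identity can be sketched as a depth-first fold that keeps one entry per `id`. The types are a hypothetical model, and the linear membership check makes this quadratic in the number of distinct variables, which is fine for the small shapes tensors usually have. The result order here happens to be reverse discovery order; callers should not rely on it, matching the unspecified order above.

```ocaml
(* Hypothetical model types. *)
type var = { id : int; name : string }

type expr = Const of int | Var of var | Add of expr * expr | Mul of expr * expr | Neg of expr

(* Walk every expression tree, adding a variable only if its id is new. *)
let vars shape =
  let rec go acc = function
    | Const _ -> acc
    | Var v -> if List.exists (fun u -> u.id = v.id) acc then acc else v :: acc
    | Add (a, b) | Mul (a, b) -> go (go acc a) b
    | Neg a -> go acc a
  in
  Array.fold_left go [] shape

let () =
  let n = { id = 1; name = "n" } and m = { id = 2; name = "n" } in
  let shape = [| Var n; Add (Var n, Const 1); Mul (Var m, Var n) |] in
  (* n appears three times but is reported once; m shares its name but not its id. *)
  assert (List.length (vars shape) = 2)
```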
var_id v returns the unique identifier assigned to v.
var_name v returns the user-facing name of v.
var_bounds v returns the inclusive minimum and maximum bounds for v.
is_static shape returns true if all dimensions are static.
Checks whether shape contains only Const expressions with no symbolic variables. Static shapes can be evaluated without binding variables.
Returns true for empty shapes.
rank shape returns the number of dimensions.
Utilities
to_string shape returns human-readable representation.
Formats the shape as a bracketed list of dimension expressions. Static dimensions appear as integers. Variables show as name#id or name#id=value if bound. Empty variable names render as v{id}. Compound expressions use infix notation with parentheses.
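The formatting rules above can be sketched as follows. This is an approximation under assumptions: the "; " separator, the exact parenthesization, the "(-e)" form for negation and the "v{id}=value" form for a bound unnamed variable are guesses, and `var_to_string`/`dim_to_string` are hypothetical names over a hypothetical model.

```ocaml
(* Hypothetical model types. *)
type var = { id : int; name : string; mutable value : int option }

type expr = Const of int | Var of var | Add of expr * expr | Mul of expr * expr | Neg of expr

(* name#id, or v{id} for an empty name; "=value" appended when bound. *)
let var_to_string v =
  let base = if v.name = "" then Printf.sprintf "v%d" v.id else Printf.sprintf "%s#%d" v.name v.id in
  match v.value with
  | None -> base
  | Some n -> Printf.sprintf "%s=%d" base n

(* Integers for constants, infix with parentheses for compound expressions. *)
let rec dim_to_string = function
  | Const n -> string_of_int n
  | Var v -> var_to_string v
  | Add (a, b) -> Printf.sprintf "(%s + %s)" (dim_to_string a) (dim_to_string b)
  | Mul (a, b) -> Printf.sprintf "(%s * %s)" (dim_to_string a) (dim_to_string b)
  | Neg a -> Printf.sprintf "(-%s)" (dim_to_string a)

let to_string shape =
  "[" ^ String.concat "; " (Array.to_list (Array.map dim_to_string shape)) ^ "]"

let () =
  let n = { id = 0; name = "n"; value = None } in
  assert (to_string [| Const 2; Mul (Var n, Const 2) |] = "[2; (n#0 * 2)]");
  n.value <- Some 4;
  assert (to_string [| Var n |] = "[n#0=4]")
```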
equal s1 s2 compares shapes structurally.
Returns true if s1 and s2 have the same rank and corresponding dimensions are structurally equal. Dimensions are compared by:
- Constants: Equal if values match
- Variables: Equal if variable identities match (not names)
- Expressions: Equal if operators and subexpressions match recursively
Two shapes with different but equivalent variables are not equal:
let v1 = Symbolic_shape.var "x" ~min:0 ~max:10 in
let v2 = Symbolic_shape.var "x" ~min:0 ~max:10 in
let s1 = [| Symbolic_shape.dim_of_var v1 |] in
let s2 = [| Symbolic_shape.dim_of_var v2 |] in
assert (not (Symbolic_shape.equal s1 s2)) (* Different identities *)
Performs structural comparison without evaluation; expressions that evaluate to the same value may not be equal. For example, Add (Const 1, Const 2) is not equal to Const 3.
Time complexity: O(n*m) where n is the rank and m is the average number of nodes per expression tree, since the comparison may visit every node.
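Structural equality with identity-based variable comparison can be sketched as a recursive match over pairs of nodes. The model types and `equal_dim` are hypothetical; the sketch shows why expressions that evaluate to the same value can still compare unequal, and why the cost grows with the number of nodes visited.

```ocaml
(* Hypothetical model types. *)
type var = { id : int; name : string }

type expr = Const of int | Var of var | Add of expr * expr | Mul of expr * expr | Neg of expr

(* Same constructor and equal children; variables match on id only. *)
let rec equal_dim a b =
  match (a, b) with
  | Const x, Const y -> x = y
  | Var u, Var v -> u.id = v.id
  | Add (a1, a2), Add (b1, b2) | Mul (a1, a2), Mul (b1, b2) ->
      equal_dim a1 b1 && equal_dim a2 b2
  | Neg x, Neg y -> equal_dim x y
  | _ -> false

(* Ranks must match before dimensions are compared pairwise. *)
let equal s1 s2 = Array.length s1 = Array.length s2 && Array.for_all2 equal_dim s1 s2

let () =
  let x1 = { id = 1; name = "x" } and x2 = { id = 2; name = "x" } in
  assert (not (equal [| Var x1 |] [| Var x2 |]));
  assert (equal [| Var x1; Const 3 |] [| Var x1; Const 3 |]);
  (* 1 + 2 and 3 evaluate alike but are different trees. *)
  assert (not (equal [| Add (Const 1, Const 2) |] [| Const 3 |]))
```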