Design and implement symbolic shape inference capability for the ONNX IR that can propagate shapes through the graph while preserving symbolic dimension relationships (e.g., batch, seq_len, N+1).

- `SymbolicDim` class (`_core.py:1241`): Immutable symbolic dimension with string or `None` values
- `Shape` class (`_core.py:1330`): Supports mixed static/dynamic dimensions, freezing, denotations
- `ShapeInferencePass` (`passes/common/shape_inference.py`): Delegates to ONNX's C++ shape inference
- Pass infrastructure: `InPlacePass`, `FunctionalPass`, `PassManager` in `_pass_infra.py`
- Current shape inference relies on `onnx.shape_inference.infer_shapes()`, which:
  - Requires round-trip serialization (IR → proto → C++ → proto → IR)
  - Has limited symbolic expression support (doesn't track `N+1`, `N*2`, etc.)
  - Cannot easily be extended with custom op shape inference
- `SymbolicDim` only stores string names, not expressions or relationships

Create a native symbolic shape inference system that:
- Operates directly on the IR (no serialization overhead)
- Supports symbolic expressions (e.g., `N+1`, `batch*heads`) using SymPy
- Is extensible via per-op shape inference functions
- Tracks dimension equivalence/constraints

Why SymPy:
- Mature symbolic math library with expression simplification built-in
- Supports substitution, evaluation, comparison, and constraint solving
- Handles edge cases (division, modulo, inequalities)
- Well-tested and maintained

- 1.1 Update `SymbolicDim` to support SymPy expressions (backward compatible)
  - Accept `str | None | sympy.Expr` in constructor
  - Store SymPy expression internally (`sympy.Symbol` for strings, expression for complex)
  - Preserve existing `value` property behavior for simple cases
  - Add `expr` property returning the SymPy expression
- 1.2 Add arithmetic operations to `SymbolicDim`
  - `__add__`, `__sub__`, `__mul__`, `__floordiv__`, `__mod__`
  - Return new `SymbolicDim` with computed SymPy expression
- 1.3 Update `Shape` class to leverage enhanced `SymbolicDim`
  - No API changes needed - `Shape` already accepts `SymbolicDim`
  - Add `evaluate(bindings)` method to compute concrete shape
  - Add `simplify()` method to simplify all symbolic dimensions
- 1.4 Update serialization (`serde.py`) to handle SymPy expressions
  - Serialize complex expressions to `dim_param` strings (e.g., `"N + 1"`)
  - Parse `dim_param` strings back to SymPy expressions where possible
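
The Phase 1 tasks above might look roughly like the following minimal sketch. `SymDim` and `parse_dim_param` are hypothetical stand-ins written for illustration, not the actual `ir.SymbolicDim` or `serde.py` code; only the SymPy calls (`Symbol`, `Integer`, `subs`, `parse_expr`) are real APIs. Symbols are handed to the parser explicitly because short names such as `N` or `E` already have a meaning in SymPy's default parsing namespace.

```python
from __future__ import annotations

import operator
import re

import sympy
from sympy.parsing.sympy_parser import parse_expr


class SymDim:
    """Hypothetical stand-in for the enhanced SymbolicDim (illustration only)."""

    def __init__(self, value: str | None | sympy.Expr = None) -> None:
        # Strings become symbols; None (unknown) and expressions are stored as-is.
        self._expr = sympy.Symbol(value) if isinstance(value, str) else value

    @property
    def expr(self) -> sympy.Expr | None:
        return self._expr

    @property
    def value(self) -> str | None:
        # "batch" stays "batch"; expressions fall back to their string form.
        return None if self._expr is None else str(self._expr)

    def _apply(self, other: SymDim | int, op) -> SymDim:
        rhs = other.expr if isinstance(other, SymDim) else sympy.Integer(other)
        if self._expr is None or rhs is None:
            return SymDim(None)  # arithmetic on an unknown dim stays unknown
        return SymDim(op(self._expr, rhs))

    def __add__(self, other: SymDim | int) -> SymDim:
        return self._apply(other, operator.add)

    def __sub__(self, other: SymDim | int) -> SymDim:
        return self._apply(other, operator.sub)

    def __mul__(self, other: SymDim | int) -> SymDim:
        return self._apply(other, operator.mul)

    def __floordiv__(self, other: SymDim | int) -> SymDim:
        return self._apply(other, operator.floordiv)

    def __mod__(self, other: SymDim | int) -> SymDim:
        return self._apply(other, operator.mod)

    def evaluate(self, bindings: dict[str, int]) -> int:
        """Substitute concrete values; assumes the dim is known and fully bound."""
        subs = {sympy.Symbol(name): val for name, val in bindings.items()}
        return int(self._expr.subs(subs))


# Function names that `//` and `%` expressions print as; these must not be
# turned into symbols when a dim_param string is parsed back.
_FUNCTION_NAMES = {"floor", "Mod"}


def parse_dim_param(text: str) -> SymDim:
    """Parse a dim_param string, treating every other identifier as a symbol."""
    names = set(re.findall(r"[A-Za-z_]\w*", text)) - _FUNCTION_NAMES
    symbols = {name: sympy.Symbol(name) for name in names}
    return SymDim(parse_expr(text, local_dict=symbols))
```

In this sketch `SymDim("N") + 1` carries the expression `N + 1`, its `value` is the string that would be written to `dim_param`, and `parse_dim_param` turns that string back into an equal expression.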

- 2.1 Design `ShapeInferenceContext` to track:
  - Dimension variable bindings
  - Dimension constraints/equivalences
  - Value-to-shape mappings
  - Current opset version for lookups
- 2.2 Define shape merge policies (enum):
  - `SKIP` - Don't update if shape already exists
  - `OVERRIDE` - Always replace with inferred shape
  - `REFINE` - Only update if inferred shape is more specific (e.g., concrete beats symbolic, named symbolic beats `None`)
  - `STRICT` - Fail if inferred shape conflicts with existing
- 2.3 Create `OpShapeInferenceRegistry` with opset support
  - `register(domain, op_type, opsets)` decorator/method
  - `get(domain, op_type, opset)` lookup by version
  - Fallback to closest lower opset if exact match not found
- 2.4 Implement `SymbolicShapeInferencePass` that:
  - Traverses graph in topological order
  - Looks up inference function by (domain, op_type, opset)
  - Merges shapes according to configured policy
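
One way to read the four merge policies from 2.2 is sketched below. The names `MergePolicy` and `merge_shapes` are hypothetical, shapes are modelled as plain tuples of `int | str | None` rather than `ir.Shape`, and two choices are assumptions rather than part of the plan: a "conflict" means two different concrete sizes (or different ranks), and `STRICT` otherwise behaves like `REFINE`.

```python
from __future__ import annotations

from enum import Enum
from typing import Optional, Tuple, Union

Dim = Union[int, str, None]  # concrete size, named symbolic dim, or unknown
ShapeT = Optional[Tuple[Dim, ...]]  # None means "no shape recorded yet"


class MergePolicy(Enum):
    SKIP = "skip"
    OVERRIDE = "override"
    REFINE = "refine"
    STRICT = "strict"


def _specificity(dim: Dim) -> int:
    """Order dims as: unknown (0) < named symbolic (1) < concrete (2)."""
    if dim is None:
        return 0
    return 2 if isinstance(dim, int) else 1


def merge_shapes(existing: ShapeT, inferred: ShapeT, policy: MergePolicy) -> ShapeT:
    # With only one side available there is nothing to merge.
    if existing is None:
        return inferred
    if inferred is None or policy is MergePolicy.SKIP:
        return existing
    if policy is MergePolicy.OVERRIDE:
        return inferred

    strict = policy is MergePolicy.STRICT
    if len(existing) != len(inferred):
        if strict:
            raise ValueError(f"rank conflict: {existing} vs {inferred}")
        return existing  # REFINE keeps what is already there

    merged = []
    for old, new in zip(existing, inferred):
        if isinstance(old, int) and isinstance(new, int) and old != new:
            if strict:
                raise ValueError(f"dimension conflict: {old} vs {new}")
            merged.append(old)
        else:
            # Take the inferred dim only when it is strictly more specific.
            merged.append(new if _specificity(new) > _specificity(old) else old)
    return tuple(merged)
```

Under these assumptions, refining `(None, "seq", 3)` with an inferred `("batch", 128, 3)` yields `("batch", 128, 3)`, while `STRICT` raises when `(2, 3)` meets `(2, 4)`.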

- 3.1 Implement shape inference for `Add`
  - Broadcasting rules with symbolic dimensions
  - Handle rank mismatch (prepend 1s to shorter shape)
  - Symbolic broadcast: `1` broadcasts to any dim, matching dims stay, else error/unknown
- 3.2 Implement shape inference for `Transpose`
  - Permute dimensions according to `perm` attribute
  - Default perm (reverse) if not specified
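
The broadcasting and permutation rules listed for `Add` and `Transpose` can be sketched on plain tuples as follows. The function names are hypothetical and dims are modelled as `int | str | None` instead of `SymbolicDim`. The sketch raises on two different concrete sizes and returns `None` (unknown) whenever a mismatch involves a symbolic dim, which is one possible reading of "else error/unknown".

```python
from __future__ import annotations

from typing import Optional, Sequence, Tuple, Union

Dim = Union[int, str, None]  # concrete size, named symbolic dim, or unknown


def broadcast_dim(a: Dim, b: Dim) -> Dim:
    """Broadcast a single pair of dimensions."""
    if a == 1:
        return b  # 1 broadcasts to any dim
    if b == 1:
        return a
    if a is None or b is None:
        return None  # unknown stays unknown
    if a == b:
        return a  # matching dims stay
    if isinstance(a, int) and isinstance(b, int):
        raise ValueError(f"cannot broadcast {a} with {b}")
    return None  # e.g. "N" vs "M": cannot be decided statically


def broadcast_shapes(lhs: Sequence[Dim], rhs: Sequence[Dim]) -> Tuple[Dim, ...]:
    """Broadcast two shapes after prepending 1s to the shorter one."""
    rank = max(len(lhs), len(rhs))
    lhs_padded = (1,) * (rank - len(lhs)) + tuple(lhs)
    rhs_padded = (1,) * (rank - len(rhs)) + tuple(rhs)
    return tuple(broadcast_dim(a, b) for a, b in zip(lhs_padded, rhs_padded))


def transpose_shape(
    shape: Sequence[Dim], perm: Optional[Sequence[int]] = None
) -> Tuple[Dim, ...]:
    """Permute dims by `perm`; the default reverses the dimension order."""
    rank = len(shape)
    if perm is None:
        perm = tuple(reversed(range(rank)))
    if sorted(perm) != list(range(rank)):
        raise ValueError(f"perm {tuple(perm)} is not a permutation of 0..{rank - 1}")
    return tuple(shape[i] for i in perm)
```

For example, broadcasting `("batch", 1, 256)` with `(8, 256)` gives `("batch", 8, 256)`, and transposing `("batch", "seq", 64)` without `perm` gives `(64, "seq", "batch")`.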

- 4.1 Design `DimensionConstraint` system
  - Equality constraints: `dim1 == dim2`
  - Arithmetic constraints: `dim1 == dim2 + k`
  - Range constraints: `dim >= 1`
- 4.2 Implement constraint propagation/unification
- 4.3 Add constraint validation (detect conflicts)
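
For the equality constraints in 4.1, unification and conflict detection could be prototyped with a union-find structure, as in the sketch below. `DimUnifier` is a hypothetical name, and the sketch covers only `dim1 == dim2` and binding a dim to a concrete size; arithmetic and range constraints are left out.

```python
from __future__ import annotations

from typing import Dict, Union


class DimUnifier:
    """Hypothetical union-find over dimension names with optional concrete values."""

    def __init__(self) -> None:
        self._parent: Dict[str, str] = {}
        self._value: Dict[str, int] = {}  # class representative -> concrete size

    def find(self, name: str) -> str:
        """Return the representative of the equivalence class of `name`."""
        self._parent.setdefault(name, name)
        root = name
        while self._parent[root] != root:
            root = self._parent[root]
        # Path compression: point every visited name directly at the root.
        while self._parent[name] != root:
            self._parent[name], name = root, self._parent[name]
        return root

    def bind(self, name: str, value: int) -> None:
        """Record `name == value`, raising if the class already has another value."""
        root = self.find(name)
        known = self._value.get(root)
        if known is not None and known != value:
            raise ValueError(f"{name} is already bound to {known}, not {value}")
        self._value[root] = value

    def unify(self, a: str, b: str) -> None:
        """Record `a == b`, raising if both sides have different concrete values."""
        root_a, root_b = self.find(a), self.find(b)
        if root_a == root_b:
            return
        val_a, val_b = self._value.get(root_a), self._value.get(root_b)
        if val_a is not None and val_b is not None and val_a != val_b:
            raise ValueError(f"conflict: {a} == {val_a} but {b} == {val_b}")
        self._parent[root_b] = root_a
        if val_b is not None:
            # Move the concrete value over to the new representative.
            self._value[root_a] = self._value.pop(root_b)

    def resolve(self, name: str) -> Union[int, str]:
        """Return the concrete size if known, else the class representative."""
        root = self.find(name)
        return self._value.get(root, root)
```

After `unify("N", "M")` and `bind("M", 4)`, resolving `"N"` returns `4`, and a later `bind("N", 5)` is reported as a conflict.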

- 5.1 Create user-facing API (`infer_symbolic_shapes()`)
- 5.2 Add configuration options:
  - Strict mode (fail on unknown shapes vs. leave as unknown)
  - Custom op registry
  - Data propagation for Constant/Shape ops
- 5.3 Write comprehensive documentation
- 5.4 Add unit tests for all components
- 6.1 Bidirectional inference (propagate constraints backward)
- 6.2 Integration with ONNX's shape inference as fallback
- 6.3 Symbolic dimension visualization/debugging tools
- 6.4 Support for dynamic shapes with runtime bounds

Decision: Use SymPy for symbolic expressions
- Mature library with built-in simplification, substitution, and solving
- `sympy.Symbol("batch")` for named dimensions
- `sympy.Integer(128)` for constants (or plain Python `int`)
- Arithmetic expressions: `sympy.Symbol("N") + 1`
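
A short illustration of the SymPy behaviour this decision relies on follows. The helper `dims_equal` is a hypothetical name. One detail worth keeping in mind when comparing dimensions: `==` on SymPy expressions tests structural equality, so two equivalent expressions written differently compare unequal until they are expanded or simplified.

```python
import sympy

N = sympy.Symbol("N")
heads = sympy.Symbol("heads")

# Trivial arithmetic is normalized when the expression is built.
same_after_roundtrip = (N + 1) - 1 == N
cancelled = (N * heads) / heads == N

# Substitution yields a concrete value once every symbol is bound.
concrete = (N * heads + 1).subs({N: 4, heads: 8})

# Structurally different but mathematically equal expressions.
lhs = (N + 1) ** 2
rhs = N**2 + 2 * N + 1
structurally_equal = lhs == rhs


def dims_equal(a: sympy.Expr, b: sympy.Expr) -> bool:
    """Hypothetical helper: treat dims as equal if their difference simplifies to 0."""
    return sympy.simplify(a - b) == 0
```

With this helper, `dims_equal(lhs, rhs)` holds even though `lhs == rhs` does not, which suggests routing dimension comparisons through a helper of this kind rather than bare `==`.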

Decision: Unify by enhancing `SymbolicDim` (backward compatible)
- `SymbolicDim` accepts `str | None | sympy.Expr`
- Existing code using `SymbolicDim("batch")` continues to work
- New code can use `SymbolicDim(sympy.Symbol("N") + 1)`
- `Shape` class unchanged - already stores `int | SymbolicDim`

Backward Compatibility:

```python
import onnx_ir as ir

# Existing code (still works)
dim = ir.SymbolicDim("batch")
dim.value  # Returns "batch"

# New capability
dim = ir.SymbolicDim("N") + 1
dim.expr  # Returns sympy.Symbol("N") + 1
dim.value  # Returns "N + 1" (string representation)

# Shape class unchanged
shape = ir.Shape(["batch", None, 3])  # Still works
shape = ir.Shape([ir.SymbolicDim("N") + 1, 128])  # Now also works
```

Option A: Decorator-based registration

```python
@register_shape_inference("", "Add")
def infer_add(ctx, node): ...
```

Option B: Explicit registration

```python
registry.register("", "Add", infer_add)
```

Recommendation: Support both patterns.

```
src/onnx_ir/
├── _core.py                     # Updated SymbolicDim and Shape classes
├── shape_inference/             # Public module for shape inference
│   ├── __init__.py              # Public API exports
│   ├── _context.py              # ShapeInferenceContext
│   ├── _registry.py             # OpShapeInferenceRegistry with opset support
│   ├── _broadcast.py            # Broadcasting utilities (shared)
│   └── ops/                     # Per-op shape inference (one file per op)
│       ├── __init__.py          # Registers all ops
│       ├── _add.py              # Add
│       ├── _transpose.py        # Transpose
│       ├── _elementwise.py      # Shared logic for element-wise ops (parameterized)
│       └── ...                  # Future ops
├── passes/
│   └── common/
│       └── symbolic_shape_inference.py  # SymbolicShapeInferencePass
```

1. One file per operator (unless parameterizable):

```python
# shape_inference/ops/_add.py
from onnx_ir.shape_inference import registry


@registry.register("", "Add", opsets=range(7, 20))
def infer_add(ctx, node):
    ...
```

2. Parameterized ops share a file:

```python
# shape_inference/ops/_elementwise.py
from onnx_ir.shape_inference import registry


def _make_elementwise_inferrer(op_type):
    def infer(ctx, node):
        # shared broadcasting logic
        ...

    return infer


# Register all elementwise ops
for op in ["Add", "Sub", "Mul", "Div", "Pow"]:
    registry.register("", op, opsets=range(7, 20))(_make_elementwise_inferrer(op))
```

3. Opset-aware registration:

```python
# shape_inference/_registry.py
from __future__ import annotations

from collections.abc import Callable


class OpShapeInferenceRegistry:
    def register(self, domain: str, op_type: str, opsets: range | int | None = None):
        """Register shape inference for specific opset versions.

        Args:
            domain: ONNX domain (e.g., "", "com.microsoft")
            op_type: Operator type (e.g., "Add", "Transpose")
            opsets: Opset versions to register for. None means all versions.
        """
        ...

    def get(self, domain: str, op_type: str, opset: int) -> Callable | None:
        """Get shape inference function for specific opset version."""
        ...
```

```python
import onnx_ir as ir
from onnx_ir.passes.common import SymbolicShapeInferencePass

# Load model with symbolic batch dimension
model = ir.load("model.onnx")

# Run symbolic shape inference
pass_ = SymbolicShapeInferencePass()
result = pass_(model)

# Shapes are now propagated - same Shape class, enhanced SymbolicDim
for value in model.graph.values():
    print(f"{value.name}: {value.shape}")
# Example output: "output: Shape([batch, 256, seq_len + 1])"

# Evaluate with concrete values
shape = result.model.graph.outputs[0].shape
concrete = shape.evaluate({"batch": 32, "seq_len": 128})
# Result: (32, 256, 129)

# Arithmetic on symbolic dimensions
batch = ir.SymbolicDim("batch")
seq_len = ir.SymbolicDim("seq_len")
new_dim = batch * seq_len  # SymbolicDim with expr: batch*seq_len

# Simplification via SymPy
dim = ir.SymbolicDim("N") + 0
dim.simplify()  # Returns SymbolicDim("N")
```

- Start with read-only inference (Phase 1-3), add constraint solving later (Phase 4)
- Prioritize commonly used ops in Phase 3
- SymPy provides expression simplification, substitution, and comparison out of the box (note that `==` tests structural equality, so equivalence checks should go through `simplify()` or `equals()`)
- SymPy's `solve()` can be used for constraint solving in Phase 4
- Ensure thread-safety for registry access
- SymPy is a required dependency; add to `pyproject.toml` if not present
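
The registry notes above (opset-aware lookup, fallback to the closest lower opset, thread-safe access) could be combined roughly as follows. `OpsetRegistry` is a hypothetical name, not the planned `OpShapeInferenceRegistry`, and storing `opsets=None` under opset 1 so that the fallback covers every version is an assumption of this sketch.

```python
from __future__ import annotations

import threading
from typing import Callable, Dict, Optional, Tuple, Union


class OpsetRegistry:
    """Hypothetical thread-safe registry with closest-lower-opset fallback."""

    def __init__(self) -> None:
        self._lock = threading.Lock()
        # (domain, op_type) -> {opset version: inference function}
        self._table: Dict[Tuple[str, str], Dict[int, Callable]] = {}

    def register(
        self, domain: str, op_type: str, opsets: Union[range, int, None] = None
    ) -> Callable[[Callable], Callable]:
        """Return a decorator that registers a function for the given opsets."""
        if opsets is None:
            versions = [1]  # assumption: opset 1 plus fallback means "all versions"
        elif isinstance(opsets, int):
            versions = [opsets]
        else:
            versions = list(opsets)

        def decorator(fn: Callable) -> Callable:
            with self._lock:
                entry = self._table.setdefault((domain, op_type), {})
                for version in versions:
                    entry[version] = fn
            return fn

        return decorator

    def get(self, domain: str, op_type: str, opset: int) -> Optional[Callable]:
        """Return the function for `opset`, or for the closest lower opset."""
        with self._lock:
            entry = self._table.get((domain, op_type), {})
            candidates = [version for version in entry if version <= opset]
            return entry[max(candidates)] if candidates else None
```

A single lock around both registration and lookup keeps the sketch simple; if lookups dominate after import time, the table could instead be frozen once all ops are registered.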