@justinchuby
Created January 28, 2026 00:54
Shape inference plan

Symbolic Shape Inference for ONNX IR

Problem Statement

Design and implement symbolic shape inference capability for the ONNX IR that can propagate shapes through the graph while preserving symbolic dimension relationships (e.g., batch, seq_len, N+1).

Current State Analysis

Existing Infrastructure

  • SymbolicDim class (_core.py:1241): Immutable symbolic dimension with string or None values
  • Shape class (_core.py:1330): Supports mixed static/dynamic dimensions, freezing, denotations
  • ShapeInferencePass (passes/common/shape_inference.py): Delegates to ONNX's C++ shape inference
  • Pass infrastructure: InPlacePass, FunctionalPass, PassManager in _pass_infra.py

Limitations of Current Approach

  1. Current shape inference relies on onnx.shape_inference.infer_shapes() which:
    • Requires round-trip serialization (IR → proto → C++ → proto → IR)
    • Limited symbolic expression support (doesn't track N+1, N*2, etc.)
    • Cannot easily be extended with custom shape inference for non-standard ops
  2. SymbolicDim only stores string names, not expressions or relationships

Proposed Approach

High-Level Design

Create a native symbolic shape inference system that:

  1. Operates directly on the IR (no serialization overhead)
  2. Supports symbolic expressions (e.g., N+1, batch*heads) using SymPy
  3. Is extensible via per-op shape inference functions
  4. Tracks dimension equivalence/constraints

Why SymPy?

  • Mature symbolic math library with expression simplification built-in
  • Supports substitution, evaluation, comparison, and constraint solving
  • Handles edge cases (division, modulo, inequalities)
  • Well-tested and maintained

Work Plan

Phase 1: SymPy Integration into SymbolicDim

  • 1.1 Update SymbolicDim to support SymPy expressions (backward compatible)
    • Accept str | None | sympy.Expr in constructor
    • Store SymPy expression internally (sympy.Symbol for strings, expression for complex)
    • Preserve existing value property behavior for simple cases
    • Add expr property returning the SymPy expression
  • 1.2 Add arithmetic operations to SymbolicDim
    • __add__, __sub__, __mul__, __floordiv__, __mod__
    • Return new SymbolicDim with computed SymPy expression
  • 1.3 Update Shape class to leverage enhanced SymbolicDim
    • No API changes needed - Shape already accepts SymbolicDim
    • Add evaluate(bindings) method to compute concrete shape
    • Add simplify() method to simplify all symbolic dimensions
  • 1.4 Update serialization (serde.py) to handle SymPy expressions
    • Serialize complex expressions to dim_param strings (e.g., "N + 1")
    • Parse dim_param strings back to SymPy expressions where possible
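The Phase 1 items above can be sketched as a minimal, self-contained version of the enhanced SymbolicDim (a hypothetical simplification; the real class in _core.py carries additional invariants such as immutability and hashing):

```python
# Minimal sketch of the enhanced SymbolicDim (items 1.1-1.3); illustrative only.
import sympy

class SymbolicDim:
    def __init__(self, value):
        if value is None:
            self._expr = None
        elif isinstance(value, str):
            self._expr = sympy.Symbol(value)
        else:  # assumed to be a sympy.Expr
            self._expr = value

    @property
    def expr(self):
        return self._expr

    @property
    def value(self):
        # String representation; preserves existing behavior for plain names
        return None if self._expr is None else str(self._expr)

    def __add__(self, other):
        other_expr = other._expr if isinstance(other, SymbolicDim) else other
        return SymbolicDim(self._expr + other_expr)

    def __mul__(self, other):
        other_expr = other._expr if isinstance(other, SymbolicDim) else other
        return SymbolicDim(self._expr * other_expr)

    def evaluate(self, bindings):
        # Substitute concrete values; returns an int when fully bound
        result = self._expr.subs(bindings)
        return int(result) if result.is_Integer else SymbolicDim(result)

dim = SymbolicDim("N") + 1
print(dim.value)               # "N + 1"
print(dim.evaluate({"N": 4}))  # 5
```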

Phase 2: Shape Inference Engine

  • 2.1 Design ShapeInferenceContext to track:
    • Dimension variable bindings
    • Dimension constraints/equivalences
    • Value-to-shape mappings
    • Current opset version for lookups
  • 2.2 Define shape merge policies (enum):
    • SKIP - Don't update if shape already exists
    • OVERRIDE - Always replace with inferred shape
    • REFINE - Only update if inferred shape is more specific (e.g., concrete beats symbolic, named symbolic beats None)
    • STRICT - Fail if inferred shape conflicts with existing
  • 2.3 Create OpShapeInferenceRegistry with opset support
    • register(domain, op_type, opsets) decorator/method
    • get(domain, op_type, opset) lookup by version
    • Fallback to closest lower opset if exact match not found
  • 2.4 Implement SymbolicShapeInferencePass that:
    • Traverses graph in topological order
    • Looks up inference function by (domain, op_type, opset)
    • Merges shapes according to configured policy
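A sketch of the merge policies in 2.2, applied one dimension at a time. The enum names follow the plan; the dimension encoding (int for concrete, str for named symbolic, None for unknown) and helper names are illustrative:

```python
# Illustrative merge-policy logic for a single dimension (item 2.2).
import enum

class MergePolicy(enum.Enum):
    SKIP = "skip"
    OVERRIDE = "override"
    REFINE = "refine"
    STRICT = "strict"

def _specificity(dim):
    # Concrete int beats a named symbolic dim, which beats unknown (None)
    if isinstance(dim, int):
        return 2
    return 1 if dim is not None else 0

def merge_dim(existing, inferred, policy):
    if policy is MergePolicy.SKIP:
        return existing if existing is not None else inferred
    if policy is MergePolicy.OVERRIDE:
        return inferred
    if policy is MergePolicy.REFINE:
        return inferred if _specificity(inferred) > _specificity(existing) else existing
    # STRICT: fail on conflicting concrete dimensions
    if isinstance(existing, int) and isinstance(inferred, int) and existing != inferred:
        raise ValueError(f"shape conflict: {existing} vs {inferred}")
    return existing if existing is not None else inferred

print(merge_dim("batch", 32, MergePolicy.REFINE))   # 32 (concrete is more specific)
print(merge_dim(None, "batch", MergePolicy.REFINE)) # "batch"
```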

Phase 3: Initial Operator Support

  • 3.1 Implement shape inference for Add
    • Broadcasting rules with symbolic dimensions
    • Handle rank mismatch (prepend 1s to shorter shape)
    • Symbolic broadcast: 1 broadcasts to any dim, matching dims stay, else error/unknown
  • 3.2 Implement shape inference for Transpose
    • Permute dimensions according to perm attribute
    • Default perm (reverse) if not specified
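The broadcasting rules in 3.1 can be sketched as follows (an illustrative helper, not the planned API; dims are int or str, and symbolic mismatches fall back to unknown, one of the two options named above):

```python
# Sketch of symbolic broadcasting for Add (item 3.1); illustrative only.
def broadcast_shapes(a, b):
    # Rank mismatch: prepend 1s to the shorter shape
    rank = max(len(a), len(b))
    a = [1] * (rank - len(a)) + list(a)
    b = [1] * (rank - len(b)) + list(b)
    out = []
    for da, db in zip(a, b):
        if da == 1:
            out.append(db)       # 1 broadcasts to any dim
        elif db == 1 or da == db:
            out.append(da)       # matching dims stay
        elif isinstance(da, int) and isinstance(db, int):
            raise ValueError(f"cannot broadcast {da} and {db}")
        else:
            out.append(None)     # symbolic mismatch: unknown
    return out

print(broadcast_shapes(["batch", 1, 4], [8, 4]))  # ['batch', 8, 4]
```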

Phase 4: Constraint System

  • 4.1 Design DimensionConstraint system
    • Equality constraints: dim1 == dim2
    • Arithmetic constraints: dim1 == dim2 + k
    • Range constraints: dim >= 1
  • 4.2 Implement constraint propagation/unification
  • 4.3 Add constraint validation (detect conflicts)
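To illustrate how SymPy could back this phase, constraints can be collected as sympy.Eq equations and handed to sympy.solve, which performs the propagation and conflict detection in one step (plain SymPy, no IR types):

```python
# Sketch of constraint solving and conflict detection with SymPy (Phase 4).
import sympy

N, M = sympy.symbols("N M", positive=True)

# Equality and arithmetic constraints as equations
constraints = [sympy.Eq(M, N + 1), sympy.Eq(M, 5)]
solution = sympy.solve(constraints, [N, M], dict=True)
print(solution)  # [{N: 4, M: 5}]

# Conflict detection: an unsatisfiable system solves to an empty list
conflict = [sympy.Eq(N, 3), sympy.Eq(N, 4)]
print(sympy.solve(conflict, [N], dict=True))  # []
```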

Phase 5: Integration & API

  • 5.1 Create user-facing API (infer_symbolic_shapes())
  • 5.2 Add configuration options:
    • Strict mode (fail on unknown shapes vs. leave as unknown)
    • Custom op registry
    • Data propagation for Constant/Shape ops
  • 5.3 Write comprehensive documentation
  • 5.4 Add unit tests for all components

Phase 6: (Optional) Advanced Features

  • 6.1 Bidirectional inference (propagate constraints backward)
  • 6.2 Integration with ONNX's shape inference as fallback
  • 6.3 Symbolic dimension visualization/debugging tools
  • 6.4 Support for dynamic shapes with runtime bounds

Key Design Decisions

1. Expression Representation

Decision: Use SymPy for symbolic expressions

  • Mature library with built-in simplification, substitution, and solving
  • sympy.Symbol("batch") for named dimensions
  • sympy.Integer(128) for constants (or plain Python int)
  • Arithmetic expressions: sympy.Symbol("N") + 1
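The behaviors listed above can be demonstrated with plain SymPy, independent of any IR types:

```python
# Substitution, simplification, and auto-evaluation in plain SymPy.
import sympy

batch = sympy.Symbol("batch")
n = sympy.Symbol("N")

expr = n + 1
print(expr.subs({n: 128}))    # 129
print(sympy.simplify(n + 0))  # N
print((batch * 2) / 2)        # batch (SymPy cancels automatically)
```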

2. Integration with Existing Shape Class

Decision: Unify by enhancing SymbolicDim (backward compatible)

  • SymbolicDim accepts str | None | sympy.Expr
  • Existing code using SymbolicDim("batch") continues to work
  • New code can use SymbolicDim(sympy.Symbol("N") + 1)
  • Shape class unchanged - already stores int | SymbolicDim

Backward Compatibility:

# Existing code (still works)
dim = ir.SymbolicDim("batch")
dim.value  # Returns "batch"

# New capability
dim = ir.SymbolicDim("N") + 1
dim.expr   # Returns sympy.Symbol("N") + 1
dim.value  # Returns "N + 1" (string representation)

# Shape class unchanged
shape = ir.Shape(["batch", None, 3])  # Still works
shape = ir.Shape([ir.SymbolicDim("N") + 1, 128])  # Now also works

3. Op Registration

Option A: Decorator-based registration

@register_shape_inference("", "Add")
def infer_add(ctx, node): ...

Option B: Explicit registration

registry.register("", "Add", infer_add)

Recommendation: Support both patterns.


File Structure Proposal

src/onnx_ir/
├── _core.py                      # Updated SymbolicDim and Shape classes
├── shape_inference/              # Public module for shape inference
│   ├── __init__.py               # Public API exports
│   ├── _context.py               # ShapeInferenceContext
│   ├── _registry.py              # OpShapeInferenceRegistry with opset support
│   ├── _broadcast.py             # Broadcasting utilities (shared)
│   └── ops/                      # Per-op shape inference (one file per op)
│       ├── __init__.py           # Registers all ops
│       ├── _add.py               # Add
│       ├── _transpose.py         # Transpose
│       ├── _elementwise.py       # Shared logic for element-wise ops (parameterized)
│       └── ...                   # Future ops
├── passes/
│   └── common/
│       └── symbolic_shape_inference.py  # SymbolicShapeInferencePass

Module Design Principles

1. One file per operator (unless parameterizable):

# shape_inference/ops/_add.py
from onnx_ir.shape_inference import registry

@registry.register("", "Add", opsets=range(7, 20))
def infer_add(ctx, node):
    ...

2. Parameterized ops share a file:

# shape_inference/ops/_elementwise.py
def _make_elementwise_inferrer(op_type):
    def infer(ctx, node):
        # shared broadcasting logic
        ...
    return infer

# Register all elementwise ops
for op in ["Add", "Sub", "Mul", "Div", "Pow"]:
    registry.register("", op, opsets=range(7, 20))(_make_elementwise_inferrer(op))

3. Opset-aware registration:

# shape_inference/_registry.py
from typing import Callable

class OpShapeInferenceRegistry:
    def register(self, domain: str, op_type: str, opsets: range | int | None = None):
        """Register shape inference for specific opset versions.
        
        Args:
            domain: ONNX domain (e.g., "", "com.microsoft")
            op_type: Operator type (e.g., "Add", "Transpose")
            opsets: Opset versions to register for. None means all versions.
        """
        ...
    
    def get(self, domain: str, op_type: str, opset: int) -> Callable | None:
        """Get shape inference function for specific opset version."""
        ...
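The fallback behavior in get() might be implemented along these lines (a hypothetical internal layout: a dict keyed by (domain, op_type) mapping opset version to function; here None is treated as "since opset 1"):

```python
# Illustrative registry with closest-lower-opset fallback (item 2.3).
from typing import Callable

class OpShapeInferenceRegistry:
    def __init__(self):
        self._funcs: dict[tuple[str, str], dict[int, Callable]] = {}

    def register(self, domain, op_type, opsets=None):
        def decorator(func):
            versions = self._funcs.setdefault((domain, op_type), {})
            for opset in ([opsets] if isinstance(opsets, int) else opsets or [1]):
                versions[opset] = func
            return func
        return decorator

    def get(self, domain, op_type, opset):
        versions = self._funcs.get((domain, op_type), {})
        # Exact match first, then fall back to the closest lower opset
        candidates = [v for v in versions if v <= opset]
        return versions[max(candidates)] if candidates else None

registry = OpShapeInferenceRegistry()

@registry.register("", "Add", opsets=range(7, 14))
def infer_add(ctx, node): ...

print(registry.get("", "Add", 13) is infer_add)  # True (exact match)
print(registry.get("", "Add", 20) is infer_add)  # True (fallback to 13)
print(registry.get("", "Add", 6))                # None (below any registration)
```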

Example Usage (Target API)

import onnx_ir as ir
from onnx_ir.passes.common import SymbolicShapeInferencePass

# Load model with symbolic batch dimension
model = ir.load("model.onnx")

# Run symbolic shape inference
pass_ = SymbolicShapeInferencePass()
result = pass_(model)

# Shapes are now propagated - same Shape class, enhanced SymbolicDim
for value in model.graph.values():
    print(f"{value.name}: {value.shape}")
    # Example output: "output: Shape([batch, 256, seq_len + 1])"

# Evaluate with concrete values
shape = result.model.graph.outputs[0].shape
concrete = shape.evaluate({"batch": 32, "seq_len": 128})
# Result: (32, 256, 129)

# Arithmetic on symbolic dimensions
batch = ir.SymbolicDim("batch")
seq_len = ir.SymbolicDim("seq_len")
new_dim = batch * seq_len  # SymbolicDim with expr: batch*seq_len

# Simplification via SymPy
dim = ir.SymbolicDim("N") + 0
dim.simplify()  # Returns SymbolicDim("N")

Notes

  • Start with read-only inference (Phase 1-3), add constraint solving later (Phase 4)
  • Prioritize commonly used ops in Phase 3
  • SymPy handles expression simplification, substitution, and comparison automatically
  • SymPy's solve() can be used for constraint solving in Phase 4
  • Ensure thread-safety for registry access
  • SymPy is a required dependency; add to pyproject.toml if not present