Skip to content

Instantly share code, notes, and snippets.

@junrushao
Created April 28, 2025 04:21
Show Gist options
  • Select an option

  • Save junrushao/39f40f6191d403726f310319c6ed787c to your computer and use it in GitHub Desktop.

Select an option

Save junrushao/39f40f6191d403726f310319c6ed787c to your computer and use it in GitHub Desktop.
import random
import time
import mlc.dataclasses as mlcd
from mlc import DepGraph
from tqdm import tqdm
# from mlc.core._dep_graph_py import DepGraph
# Seed the global RNG so generate_graph() draws the same random graph on every run.
random.seed(1234)
@mlcd.py_class(repr=False)
class Var(mlcd.PyClass):
    """A named variable that statements list as an argument or an output."""

    name: str

    def __str__(self) -> str:
        """Return the variable's name unchanged."""
        return self.name
@mlcd.py_class(repr=False)
class Stmt(mlcd.PyClass):
    """A statement that reads the variables in ``args`` and defines those in ``outs``."""

    args: list[Var]
    outs: list[Var]

    def __str__(self) -> str:
        """Render as ``<outs> := <args>``; an empty side is printed as ``(empty)``."""
        sides = []
        # Outputs go on the left-hand side, arguments on the right.
        for group in (self.outs, self.args):
            if group:
                sides.append(", ".join(var.name for var in group))
            else:
                sides.append("(empty)")
        return " := ".join(sides)
def stmt_inputs(stmt: Stmt) -> list[Var]:
    """Return the variables ``stmt`` reads, i.e. its ``args`` list itself (not a copy)."""
    consumed = stmt.args
    return consumed
def stmt_outputs(stmt: Stmt) -> list[Var]:
    """Return the variables ``stmt`` defines, i.e. its ``outs`` list itself (not a copy)."""
    produced = stmt.outs
    return produced
def generate_graph(
    num_nodes: int = 10,
    num_input_vars: int = 3,
    max_num_inputs: int = 5,
    max_num_outputs: int = 5,
) -> tuple[list[Stmt], list[Var]]:
    """Build a random straight-line program for benchmarking.

    Variables are named ``v_<k>`` where ``k`` is their creation index. Each
    statement reads a random sample of the variables created so far and defines
    between 1 and ``max_num_outputs`` fresh variables.

    Args:
        num_nodes: Number of statements to generate.
        num_input_vars: Number of variables created up front, before any statement.
        max_num_inputs: Upper bound on variables read per statement (also capped
            by how many variables exist at that point).
        max_num_outputs: Upper bound on variables defined per statement.

    Returns:
        ``(statements, input_variables)`` where ``input_variables`` is the
        leading ``num_input_vars`` slice of all variables created.
    """
    pool = [Var(f"v_{idx}") for idx in range(num_input_vars)]
    program = []
    for _ in tqdm(range(num_nodes)):
        fan_in = min(
            random.randint(1, max_num_inputs),
            len(pool),
        )
        consumed = random.sample(pool, fan_in)
        fan_out = random.randint(1, max_num_outputs)
        # New variables continue the numbering from the current pool size.
        first_id = len(pool)
        produced = [Var(f"v_{first_id + offset}") for offset in range(fan_out)]
        pool.extend(produced)
        program.append(Stmt(args=consumed, outs=produced))
    return program, pool[:num_input_vars]
def main() -> None:
    """Generate a one-million-statement graph and time ``DepGraph.from_stmts`` on it.

    The dependency graph is built ten times and the mean wall-clock time per
    build (measured with ``time.monotonic``) is printed.
    """
    num_iters = 10
    print("Generating graph...")
    stmts, input_vars = generate_graph(
        num_nodes=1000000,
        num_input_vars=3,
        max_num_inputs=5,
        max_num_outputs=5,
    )
    print("Building dep graph...")
    started = time.monotonic()
    for _ in tqdm(range(num_iters)):
        DepGraph.from_stmts(input_vars, stmts, stmt_inputs, stmt_outputs)
    elapsed = time.monotonic() - started
    print(f"Average time per iteration: {elapsed / num_iters:.2f} seconds")
# Run the benchmark only when this file is executed as a script, not when imported.
if __name__ == "__main__":
main()
from __future__ import annotations
import dataclasses
from collections.abc import Callable, Generator
from typing import Any
# Statements and variables are opaque to the dependency graph. They are used as
# dictionary keys below, so they must be hashable.
type Stmt = Any
type Var = Any
@dataclasses.dataclass(slots=True)
class DepNode:
stmt: Stmt
input_vars: list[Var]
output_vars: list[Var]
prev: DepNode | None = None
next: DepNode | None = None
def clear(self) -> None:
self.stmt = None
self.input_vars.clear()
self.output_vars.clear()
self.prev = None
self.next = None
def is_(self, other: Any) -> bool:
return self is other
@dataclasses.dataclass(slots=True)
class DepGraph:
stmt_to_inputs: Callable[[Stmt], list[Var]]
stmt_to_outputs: Callable[[Stmt], list[Var]]
_stmt_to_node: dict[Stmt, DepNode]
_var_to_producer: dict[Var, DepNode]
_var_to_consumers: dict[Var, list[DepNode]]
_head: DepNode
@staticmethod
def from_stmts(
input_vars: list[Var],
stmts: list[Stmt],
stmt_to_inputs: Callable[[Stmt], list[Var]],
stmt_to_outputs: Callable[[Stmt], list[Var]],
) -> DepGraph:
g = DepGraph(stmt_to_inputs, stmt_to_outputs, {}, {}, {}, None) # type: ignore[arg-type]
g._head = DepNode(None, [], input_vars)
g._stmt_to_node[None] = g._head
for v in input_vars:
g._var_to_producer[v] = g._head
g._var_to_consumers[v] = []
prev = g._head
for stmt in stmts:
node = DepNode(stmt, stmt_to_inputs(stmt), stmt_to_outputs(stmt))
g.insert_after(prev, node)
prev = node
return g
def clear(self) -> None:
for node in self.nodes:
node.clear()
self._stmt_to_node.clear()
self._var_to_producer.clear()
self._var_to_consumers.clear()
def create_node(self, stmt: Stmt) -> DepNode:
return DepNode(stmt, self.stmt_to_inputs(stmt), self.stmt_to_outputs(stmt))
def get_node_from_stmt(self, stmt: Stmt) -> DepNode:
if stmt in self._stmt_to_node:
return self._stmt_to_node[stmt]
raise RuntimeError(f"Stmt not in graph: {stmt}")
def insert_before(self, anchor: DepNode, to_insert: DepNode) -> None:
if anchor.prev is None:
raise RuntimeError("Can't insert before _head")
if anchor.stmt not in self._stmt_to_node:
raise RuntimeError(f"Anchor not in graph: {anchor.stmt}")
self._insert(anchor.prev, anchor, to_insert)
def insert_after(self, anchor: DepNode, to_insert: DepNode) -> None:
if anchor.stmt not in self._stmt_to_node:
raise RuntimeError(f"Anchor not in graph: {anchor.stmt}")
self._insert(anchor, anchor.next, to_insert)
def erase_node(self, to_erase: DepNode) -> None:
if to_erase.prev is None:
raise RuntimeError("Can't erase _head")
if to_erase.stmt not in self._stmt_to_node:
raise RuntimeError(f"Node not in graph: {to_erase.stmt}")
# Unlink
del self._stmt_to_node[to_erase.stmt]
to_erase.prev.next = to_erase.next
if to_erase.next:
to_erase.next.prev = to_erase.prev
# Remove produced vars
for var in to_erase.output_vars:
if self._var_to_consumers[var]:
raise RuntimeError(f"Produced var still has consumers: {var}")
del self._var_to_producer[var]
del self._var_to_consumers[var]
# Remove from consumer lists of inputs
for var in to_erase.input_vars:
if var not in self._var_to_producer:
raise RuntimeError(f"Var not produced: {var}")
lst = self._var_to_consumers[var]
if to_erase not in lst:
raise RuntimeError(f"Node not a consumer of {var}")
lst.remove(to_erase)
to_erase.clear()
def replace(self, old_node: DepNode, new_node: DepNode) -> None:
if old_node is new_node:
return
if old_node.prev is None:
raise RuntimeError("Can't replace _head")
if old_node.stmt not in self._stmt_to_node:
raise RuntimeError(f"Old node not in graph: {old_node.stmt}")
if new_node.prev or new_node.next:
raise RuntimeError(f"New node already in graph: {new_node.stmt}")
if len(old_node.output_vars) != len(new_node.output_vars):
raise RuntimeError("Mismatched output count")
# Swap outputs
for i, old_var in enumerate(old_node.output_vars):
new_var = new_node.output_vars[i]
consumers = list(self._var_to_consumers[old_var])
for c in consumers:
c.input_vars = [new_var if v == old_var else v for v in c.input_vars]
del self._var_to_producer[old_var]
del self._var_to_consumers[old_var]
self._var_to_producer[new_var] = new_node
self._var_to_consumers[new_var] = consumers
# Remove old inputs
for var in old_node.input_vars:
self._var_to_consumers[var].remove(old_node)
# Add new inputs
for var in new_node.input_vars:
if var not in self._var_to_producer:
raise RuntimeError(f"Var not produced: {var}")
self._var_to_consumers[var].append(new_node)
# Link new_node
new_node.prev = old_node.prev
new_node.next = old_node.next
old_node.prev.next = new_node
if old_node.next:
old_node.next.prev = new_node
del self._stmt_to_node[old_node.stmt]
if new_node.stmt in self._stmt_to_node:
raise RuntimeError(f"Stmt already in graph: {new_node.stmt}")
self._stmt_to_node[new_node.stmt] = new_node
old_node.clear()
def get_node_producers(self, node: DepNode) -> list[DepNode]:
return [self._var_to_producer[v] for v in node.input_vars]
def get_node_consumers(self, node: DepNode) -> list[DepNode]:
out = []
for v in node.output_vars:
out.extend(self._var_to_consumers[v])
return out
def get_var_producer(self, var: Var) -> DepNode:
if var in self._var_to_producer:
return self._var_to_producer[var]
raise RuntimeError(f"Var not produced: {var}")
def get_var_consumers(self, var: Var) -> list[DepNode]:
if var in self._var_to_consumers:
return list(self._var_to_consumers[var])
raise RuntimeError(f"Var not consumed: {var}")
@property
def nodes(self) -> Generator[DepNode, None, None]:
node: DepNode | None = self._head
while node:
yield node
node = node.next
def _insert(self, prev: DepNode | None, next: DepNode | None, to_insert: DepNode) -> None:
# 1) Must not already be in the graph
if to_insert.prev or to_insert.next:
raise RuntimeError(f"Node is already in the graph: {to_insert.stmt}")
if to_insert.stmt in self._stmt_to_node:
raise RuntimeError(f"Stmt already in the graph: {to_insert.stmt}")
# 2) Link into the doubly-linked list
self._stmt_to_node[to_insert.stmt] = to_insert
to_insert.prev = prev
to_insert.next = next
if prev:
prev.next = to_insert
else:
self._head = to_insert
if next:
next.prev = to_insert
# 3) Register this node as producer for its outputs
for var in to_insert.output_vars:
if var in self._var_to_producer:
raise RuntimeError(f"Variable already has a producer: {var}")
self._var_to_producer[var] = to_insert
self._var_to_consumers[var] = []
# 4) Register this node as consumer for its inputs
for var in to_insert.input_vars:
if var not in self._var_to_producer:
raise RuntimeError(f"Variable is not produced by any node: {var}")
self._var_to_consumers[var].append(to_insert)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment