Created
April 28, 2025 04:21
-
-
Save junrushao/39f40f6191d403726f310319c6ed787c to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import random | |
| import time | |
| import mlc.dataclasses as mlcd | |
| from mlc import DepGraph | |
| from tqdm import tqdm | |
| # from mlc.core._dep_graph_py import DepGraph | |
# Seed the module-level RNG so that every run generates the same random graph.
random.seed(1234)
@mlcd.py_class(repr=False)
class Var(mlcd.PyClass):
    """A named variable of the synthetic benchmark program."""

    # Display name; also what ``str(var)`` returns.
    name: str

    def __str__(self) -> str:
        """Return the variable's name."""
        return self.name
@mlcd.py_class(repr=False)
class Stmt(mlcd.PyClass):
    """A statement that reads the variables in ``args`` and defines ``outs``."""

    # Variables read by the statement.
    args: list[Var]
    # Variables defined by the statement.
    outs: list[Var]

    def __str__(self) -> str:
        """Render as ``outs := args``; an empty side is shown as ``(empty)``."""
        sides = []
        # Outputs first so they end up on the left-hand side.
        for group in (self.outs, self.args):
            if group:
                sides.append(", ".join(var.name for var in group))
            else:
                sides.append("(empty)")
        lhs, rhs = sides
        return f"{lhs} := {rhs}"
def stmt_inputs(stmt: Stmt) -> list[Var]:
    """Return the variables read by ``stmt`` (its ``args`` list, not a copy)."""
    inputs = stmt.args
    return inputs
def stmt_outputs(stmt: Stmt) -> list[Var]:
    """Return the variables defined by ``stmt`` (its ``outs`` list, not a copy)."""
    outputs = stmt.outs
    return outputs
def generate_graph(
    num_nodes: int = 10,
    num_input_vars: int = 3,
    max_num_inputs: int = 5,
    max_num_outputs: int = 5,
) -> tuple[list[Stmt], list[Var]]:
    """Build a random straight-line program for the benchmark.

    Variables are named ``v_0``, ``v_1``, ... in creation order.  Each
    statement reads between 1 and ``max_num_inputs`` distinct, already
    existing variables (capped by how many exist) and defines between 1 and
    ``max_num_outputs`` fresh ones.  Progress is reported through ``tqdm``.

    Args:
        num_nodes: Number of statements to generate.
        num_input_vars: Number of variables that exist before any statement.
        max_num_inputs: Upper bound on the variables one statement reads.
        max_num_outputs: Upper bound on the variables one statement defines.

    Returns:
        ``(statements, input_vars)``, where ``input_vars`` is the first
        ``num_input_vars`` variables created.
    """
    variables = [Var(f"v_{index}") for index in range(num_input_vars)]
    statements = []
    for _step in tqdm(range(num_nodes)):
        # The draw order (input count, input sample, output count) is kept
        # fixed so that a seeded RNG always reproduces the same program.
        fan_in = min(random.randint(1, max_num_inputs), len(variables))
        consumed = random.sample(variables, fan_in)
        fan_out = random.randint(1, max_num_outputs)
        # Fresh variables continue the global numbering.
        first_index = len(variables)
        produced = [Var(f"v_{first_index + offset}") for offset in range(fan_out)]
        variables.extend(produced)
        statements.append(Stmt(args=consumed, outs=produced))
    return statements, variables[:num_input_vars]
def main() -> None:
    """Time ``DepGraph.from_stmts`` on a large randomly generated program.

    Generates one million statements, builds the dependency graph ten times
    and prints the average wall-clock time per build.
    """
    num_iterations = 10

    print("Generating graph...")
    stmts, input_vars = generate_graph(
        num_nodes=1_000_000,
        num_input_vars=3,
        max_num_inputs=5,
        max_num_outputs=5,
    )

    print("Building dep graph...")
    started = time.monotonic()
    for _ in tqdm(range(num_iterations)):
        DepGraph.from_stmts(input_vars, stmts, stmt_inputs, stmt_outputs)
    elapsed = time.monotonic() - started
    print(f"Average time per iteration: {elapsed / num_iterations:.2f} seconds")
# Run the benchmark only when this file is executed as a script.
if __name__ == "__main__":
    main()
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| from __future__ import annotations | |
| import dataclasses | |
| from collections.abc import Callable, Generator | |
| from typing import Any | |
| type Stmt = Any | |
| type Var = Any | |
| @dataclasses.dataclass(slots=True) | |
| class DepNode: | |
| stmt: Stmt | |
| input_vars: list[Var] | |
| output_vars: list[Var] | |
| prev: DepNode | None = None | |
| next: DepNode | None = None | |
| def clear(self) -> None: | |
| self.stmt = None | |
| self.input_vars.clear() | |
| self.output_vars.clear() | |
| self.prev = None | |
| self.next = None | |
| def is_(self, other: Any) -> bool: | |
| return self is other | |
| @dataclasses.dataclass(slots=True) | |
| class DepGraph: | |
| stmt_to_inputs: Callable[[Stmt], list[Var]] | |
| stmt_to_outputs: Callable[[Stmt], list[Var]] | |
| _stmt_to_node: dict[Stmt, DepNode] | |
| _var_to_producer: dict[Var, DepNode] | |
| _var_to_consumers: dict[Var, list[DepNode]] | |
| _head: DepNode | |
| @staticmethod | |
| def from_stmts( | |
| input_vars: list[Var], | |
| stmts: list[Stmt], | |
| stmt_to_inputs: Callable[[Stmt], list[Var]], | |
| stmt_to_outputs: Callable[[Stmt], list[Var]], | |
| ) -> DepGraph: | |
| g = DepGraph(stmt_to_inputs, stmt_to_outputs, {}, {}, {}, None) # type: ignore[arg-type] | |
| g._head = DepNode(None, [], input_vars) | |
| g._stmt_to_node[None] = g._head | |
| for v in input_vars: | |
| g._var_to_producer[v] = g._head | |
| g._var_to_consumers[v] = [] | |
| prev = g._head | |
| for stmt in stmts: | |
| node = DepNode(stmt, stmt_to_inputs(stmt), stmt_to_outputs(stmt)) | |
| g.insert_after(prev, node) | |
| prev = node | |
| return g | |
| def clear(self) -> None: | |
| for node in self.nodes: | |
| node.clear() | |
| self._stmt_to_node.clear() | |
| self._var_to_producer.clear() | |
| self._var_to_consumers.clear() | |
| def create_node(self, stmt: Stmt) -> DepNode: | |
| return DepNode(stmt, self.stmt_to_inputs(stmt), self.stmt_to_outputs(stmt)) | |
| def get_node_from_stmt(self, stmt: Stmt) -> DepNode: | |
| if stmt in self._stmt_to_node: | |
| return self._stmt_to_node[stmt] | |
| raise RuntimeError(f"Stmt not in graph: {stmt}") | |
| def insert_before(self, anchor: DepNode, to_insert: DepNode) -> None: | |
| if anchor.prev is None: | |
| raise RuntimeError("Can't insert before _head") | |
| if anchor.stmt not in self._stmt_to_node: | |
| raise RuntimeError(f"Anchor not in graph: {anchor.stmt}") | |
| self._insert(anchor.prev, anchor, to_insert) | |
| def insert_after(self, anchor: DepNode, to_insert: DepNode) -> None: | |
| if anchor.stmt not in self._stmt_to_node: | |
| raise RuntimeError(f"Anchor not in graph: {anchor.stmt}") | |
| self._insert(anchor, anchor.next, to_insert) | |
| def erase_node(self, to_erase: DepNode) -> None: | |
| if to_erase.prev is None: | |
| raise RuntimeError("Can't erase _head") | |
| if to_erase.stmt not in self._stmt_to_node: | |
| raise RuntimeError(f"Node not in graph: {to_erase.stmt}") | |
| # Unlink | |
| del self._stmt_to_node[to_erase.stmt] | |
| to_erase.prev.next = to_erase.next | |
| if to_erase.next: | |
| to_erase.next.prev = to_erase.prev | |
| # Remove produced vars | |
| for var in to_erase.output_vars: | |
| if self._var_to_consumers[var]: | |
| raise RuntimeError(f"Produced var still has consumers: {var}") | |
| del self._var_to_producer[var] | |
| del self._var_to_consumers[var] | |
| # Remove from consumer lists of inputs | |
| for var in to_erase.input_vars: | |
| if var not in self._var_to_producer: | |
| raise RuntimeError(f"Var not produced: {var}") | |
| lst = self._var_to_consumers[var] | |
| if to_erase not in lst: | |
| raise RuntimeError(f"Node not a consumer of {var}") | |
| lst.remove(to_erase) | |
| to_erase.clear() | |
| def replace(self, old_node: DepNode, new_node: DepNode) -> None: | |
| if old_node is new_node: | |
| return | |
| if old_node.prev is None: | |
| raise RuntimeError("Can't replace _head") | |
| if old_node.stmt not in self._stmt_to_node: | |
| raise RuntimeError(f"Old node not in graph: {old_node.stmt}") | |
| if new_node.prev or new_node.next: | |
| raise RuntimeError(f"New node already in graph: {new_node.stmt}") | |
| if len(old_node.output_vars) != len(new_node.output_vars): | |
| raise RuntimeError("Mismatched output count") | |
| # Swap outputs | |
| for i, old_var in enumerate(old_node.output_vars): | |
| new_var = new_node.output_vars[i] | |
| consumers = list(self._var_to_consumers[old_var]) | |
| for c in consumers: | |
| c.input_vars = [new_var if v == old_var else v for v in c.input_vars] | |
| del self._var_to_producer[old_var] | |
| del self._var_to_consumers[old_var] | |
| self._var_to_producer[new_var] = new_node | |
| self._var_to_consumers[new_var] = consumers | |
| # Remove old inputs | |
| for var in old_node.input_vars: | |
| self._var_to_consumers[var].remove(old_node) | |
| # Add new inputs | |
| for var in new_node.input_vars: | |
| if var not in self._var_to_producer: | |
| raise RuntimeError(f"Var not produced: {var}") | |
| self._var_to_consumers[var].append(new_node) | |
| # Link new_node | |
| new_node.prev = old_node.prev | |
| new_node.next = old_node.next | |
| old_node.prev.next = new_node | |
| if old_node.next: | |
| old_node.next.prev = new_node | |
| del self._stmt_to_node[old_node.stmt] | |
| if new_node.stmt in self._stmt_to_node: | |
| raise RuntimeError(f"Stmt already in graph: {new_node.stmt}") | |
| self._stmt_to_node[new_node.stmt] = new_node | |
| old_node.clear() | |
| def get_node_producers(self, node: DepNode) -> list[DepNode]: | |
| return [self._var_to_producer[v] for v in node.input_vars] | |
| def get_node_consumers(self, node: DepNode) -> list[DepNode]: | |
| out = [] | |
| for v in node.output_vars: | |
| out.extend(self._var_to_consumers[v]) | |
| return out | |
| def get_var_producer(self, var: Var) -> DepNode: | |
| if var in self._var_to_producer: | |
| return self._var_to_producer[var] | |
| raise RuntimeError(f"Var not produced: {var}") | |
| def get_var_consumers(self, var: Var) -> list[DepNode]: | |
| if var in self._var_to_consumers: | |
| return list(self._var_to_consumers[var]) | |
| raise RuntimeError(f"Var not consumed: {var}") | |
| @property | |
| def nodes(self) -> Generator[DepNode, None, None]: | |
| node: DepNode | None = self._head | |
| while node: | |
| yield node | |
| node = node.next | |
| def _insert(self, prev: DepNode | None, next: DepNode | None, to_insert: DepNode) -> None: | |
| # 1) Must not already be in the graph | |
| if to_insert.prev or to_insert.next: | |
| raise RuntimeError(f"Node is already in the graph: {to_insert.stmt}") | |
| if to_insert.stmt in self._stmt_to_node: | |
| raise RuntimeError(f"Stmt already in the graph: {to_insert.stmt}") | |
| # 2) Link into the doubly-linked list | |
| self._stmt_to_node[to_insert.stmt] = to_insert | |
| to_insert.prev = prev | |
| to_insert.next = next | |
| if prev: | |
| prev.next = to_insert | |
| else: | |
| self._head = to_insert | |
| if next: | |
| next.prev = to_insert | |
| # 3) Register this node as producer for its outputs | |
| for var in to_insert.output_vars: | |
| if var in self._var_to_producer: | |
| raise RuntimeError(f"Variable already has a producer: {var}") | |
| self._var_to_producer[var] = to_insert | |
| self._var_to_consumers[var] = [] | |
| # 4) Register this node as consumer for its inputs | |
| for var in to_insert.input_vars: | |
| if var not in self._var_to_producer: | |
| raise RuntimeError(f"Variable is not produced by any node: {var}") | |
| self._var_to_consumers[var].append(to_insert) |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment