Last active
May 5, 2025 19:57
-
-
Save antonl/2359b0b6754c8c48b2d53d8535b7dedc to your computer and use it in GitHub Desktop.
Toy task queue with common subexpression elimination
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| """ | |
| Implements a simple task computation framework inspired by Dask. | |
| The tasks are encoded as a map where the keys are the names of | |
| values and the values are either computed values, references to | |
| other keys, or tuples representing a function call. | |
| For example, the function call fn(x) may be represented as | |
| ```python | |
| { | |
| 'x': 1, | |
| 'fn': Q(fn=fn, args=(R('x'),), kwargs=()) | |
| } | |
| ``` | |
| However in this case it's probably better to represent it as {'fn': (fn, 1)}. | |
| """ | |
| from collections import Counter, defaultdict | |
| from typing import Any, TypeVar, NamedTuple, overload | |
| from collections.abc import Callable, Hashable | |
| from concurrent.futures import Executor, Future, wait | |
| from inspect import signature | |
| T = TypeVar("T") | |
| type K = str | tuple[str, *tuple[Hashable, ...]] | |
| type V = Any | |
| class R(NamedTuple): | |
| key: K | |
| class Q[T](NamedTuple): | |
| fn: Callable[..., T] | |
| args: tuple["Q | R | V", ...] | |
| kwargs: tuple[tuple[str, "Q | R | V"], ...] | |
| type TaskT = dict[K, Q] | |
| class Delayed[T]: | |
| def __init__(self, fn: Callable[..., T], name: str | None = None): | |
| self.fn = fn | |
| self.name = fn.__qualname__ if name is None else name | |
| def __call__(self, *args: Q | V, **kwargs) -> Q[T]: | |
| # raises TypeError if not callable with arguments | |
| try: | |
| signature(self.fn).bind(*args, **kwargs) | |
| except TypeError as e: | |
| raise TypeError(f"unable to call {self.name}. {e.args[0]}") from None | |
| return Q[T](fn=self.fn, args=args, kwargs=tuple(kwargs.items())) | |
| def __repr__(self) -> str: | |
| return f"Delayed(fn={self.fn}, name={self.name})" | |
| @overload | |
| def delayed[T](fn_or_name: str) -> Callable[[Callable[..., T]], Delayed[T]]: ... | |
| @overload | |
| def delayed[T](fn_or_name: Callable[..., T]) -> Delayed[T]: ... | |
| def delayed[T](fn_or_name: Callable[..., T] | str) -> Delayed[T] | Callable[[Callable[..., T]], Delayed[T]]: | |
| if callable(fn_or_name): | |
| return Delayed[T](fn=fn_or_name) | |
| else: | |
| def _inner(fn: Callable[..., T]) -> Delayed[T]: | |
| return Delayed[T](fn=fn, name=fn_or_name) | |
| return _inner | |
def stable_key(item: Q) -> K:
    """Return an identity-based key for the task node ``item``.

    The key combines the function's qualified name with ``id(item)``. Two nodes
    share a key only when they are the same object; equal but separately
    constructed calls are not merged. Ids are only unique among live objects,
    so the key is meaningful only while ``item`` is kept alive.
    """
    label = item.fn.__qualname__
    return label, id(item)
def build_adjacency_list(*value: Q) -> tuple[tuple[R | V, ...], TaskT]:
    """Flatten nested task expressions into a flat ``{key: task}`` graph.

    Each distinct ``Q`` object reachable from ``value`` gets a fresh key
    ``(fn.__qualname__, n)``, where ``n`` counts tasks with that name. The task
    is stored with every nested ``Q`` argument replaced by an ``R`` to that
    argument's key. A ``Q`` object reached along several paths is emitted only
    once (common subexpression elimination by object identity).

    Args:
        *value: Root expressions. Roots that are not ``Q`` are passed through.

    Returns:
        ``(roots, refs)``. ``roots[i]`` is an ``R`` for each ``Q`` root, or the
        original value otherwise. ``refs`` maps keys to rewritten tasks, with
        every dependency inserted before its dependents.
    """
    counts = Counter[K]()
    seen: dict[K, R] = {}
    refs: TaskT = {}

    def as_ref(arg: Q | V) -> R | V:
        # Children are emitted before their parent, so nested calls are in `seen`.
        return seen[stable_key(arg)] if isinstance(arg, Q) else arg

    def emit(node: Q) -> None:
        # Store `node` under a fresh key with nested calls replaced by references.
        name = node.fn.__qualname__
        ref = R((name, counts[name]))
        counts[name] += 1
        refs[ref.key] = Q[T](
            node.fn,
            tuple(as_ref(a) for a in node.args),
            tuple((kw, as_ref(a)) for kw, a in node.kwargs),
        )
        seen[stable_key(node)] = ref

    def visit(root: Q | V) -> R | V:
        if not isinstance(root, Q):
            # reference or bare value
            return root
        # Post-order traversal with an explicit stack rather than recursion, so
        # deep chains such as double(double(...)) do not hit the recursion limit.
        stack: list[tuple[Q, bool]] = [(root, False)]
        while stack:
            node, children_done = stack.pop()
            if stable_key(node) in seen:
                continue  # shared subexpression, already emitted
            if children_done:
                emit(node)
                continue
            stack.append((node, True))
            children = [a for a in node.args if isinstance(a, Q)]
            children += [a for _, a in node.kwargs if isinstance(a, Q)]
            # Reversed so the leftmost argument is popped (and numbered) first,
            # matching left-to-right argument order.
            stack.extend((child, False) for child in reversed(children))
        return seen[stable_key(root)]

    return tuple(visit(v) for v in value), refs
def _dependency_counts(values: TaskT) -> tuple[Counter[K], defaultdict[K, set]]:
    """Return ``(remaining_count, depends_on)`` for a flattened task graph.

    ``remaining_count[key]`` is the number of distinct keys that ``key``
    references through ``R`` arguments (positional or keyword).
    ``depends_on[key]`` is the set of keys that reference ``key``, i.e. its
    dependents.
    """
    remaining_count = Counter[K]()
    depends_on = defaultdict[K, set](set)
    for key, task in values.items():
        remaining_count[key] = 0
        operands = [*task.args, *(a for _, a in task.kwargs)]
        for a in operands:
            if isinstance(a, R) and key not in depends_on[a.key]:
                depends_on[a.key].add(key)
                remaining_count[key] += 1
    return remaining_count, depends_on


def get(*args: Q, pool: Executor, progress: Callable[[set, set, set, set, dict[K, Exception]], None] | None = None):
    """Evaluate deferred task expressions on ``pool`` and return their values.

    The expressions are flattened with ``build_adjacency_list``, so a ``Q``
    object reused in several places runs once. Each task is submitted as soon
    as all of its dependencies have finished. All bookkeeping happens on the
    calling thread: futures are only waited on, never given callbacks, so
    worker threads never touch the scheduler state.

    Args:
        *args: Root expressions. Values that are not ``Q`` are returned unchanged.
        pool: Executor the task functions are submitted to.
        progress: Optional callback invoked after every finished task with the
            live ``(pending, running, completed, failed, failed_msg)``
            collections. Exceptions it raises are ignored.

    Returns:
        A tuple with one entry per root. Each entry is the computed value or,
        for a root that failed (or whose dependency failed), its unevaluated
        task. ``failed_msg`` maps each task that raised to its exception.

    Raises:
        RuntimeError: if some tasks can never become ready, e.g. because they
            reference a key that is not part of the graph.
    """
    roots, values = build_adjacency_list(*args)
    remaining_count, depends_on = _dependency_counts(values)

    pending = set(remaining_count)
    running = set()
    completed = set()
    failed = set()
    failed_msg = {}
    in_flight: dict[Future, K] = {}

    def resolve(item):
        # Dependencies have finished, so values[key] now holds their result.
        return values[item.key] if isinstance(item, R) else item

    def mark_failed(key: K) -> None:
        # Fail `key` and everything downstream of it. This is iterative and
        # idempotent per key, so diamond-shaped graphs and long chains are safe.
        stack = [key]
        while stack:
            k = stack.pop()
            if k in failed:
                continue
            pending.discard(k)
            failed.add(k)
            remaining_count.pop(k, None)
            stack.extend(depends_on[k])

    def report() -> None:
        if callable(progress):
            try:
                progress(pending, running, completed, failed, failed_msg)
            except Exception:
                # Progress reporting is best-effort and must not break scheduling.
                pass

    while remaining_count:
        ready = [k for k, n in remaining_count.items() if n == 0 and k in pending]
        for key in ready:
            task = values[key]
            call_args = tuple(resolve(a) for a in task.args)
            call_kwargs = {name: resolve(a) for name, a in task.kwargs}
            pending.remove(key)
            running.add(key)
            in_flight[pool.submit(task.fn, *call_args, **call_kwargs)] = key

        if not in_flight:
            # Nothing running and nothing ready: the remaining tasks can never run.
            stuck = ", ".join(sorted(map(repr, remaining_count)))
            raise RuntimeError(f"tasks can never become ready (unresolved dependencies): {stuck}")

        # Block until at least one in-flight task finishes (no busy-waiting).
        done, _ = wait(in_flight, return_when="FIRST_COMPLETED")
        for future in done:
            key = in_flight.pop(future)
            running.remove(key)
            del remaining_count[key]
            try:
                result = future.result()
            except Exception as e:
                mark_failed(key)
                failed_msg[key] = e
            else:
                values[key] = result
                completed.add(key)
                for item in depends_on[key]:
                    remaining_count[item] -= 1
            report()

    return tuple(values[r.key] if isinstance(r, R) else r for r in roots)
if __name__ == "__main__":
    # Demo: thread-pool evaluation (stdlib only), then an optional process-pool run.
    from pprint import pprint
    from concurrent.futures import ThreadPoolExecutor

    @delayed
    def double(x: int) -> int:
        return 2 * x

    @delayed("my_add")
    def add(a: int, b: int) -> int:
        return a + b

    @delayed
    def length(*, kwarg: str) -> int:
        return len(kwarg)

    @delayed
    def fail():
        raise RuntimeError()

    four = double(2)
    # `four` is the same object in both argument slots, so it is computed once.
    sixteen = double(add(four, four))
    to_fail = double(fail())

    def progress(pending, running, completed, failed, failed_msg):
        """Print a one-line summary after each finished task."""
        num_pending = len(pending)
        num_running = len(running)
        num_completed = len(completed)
        num_failed = len(failed)
        if failed_msg:
            print(failed_msg)
        total = num_pending + num_running + num_completed + num_failed
        processed = num_completed + num_failed
        print(f"Processed {100*processed / total:.1f}%, Failed {num_failed}, Completed {num_completed}")

    with ThreadPoolExecutor(max_workers=1) as pool:
        result = get(sixteen, four, pool=pool, progress=progress)
        pprint(result)
        result = get(sixteen, to_fail, pool=pool, progress=progress)
        pprint(result)
        result = get(double(length(kwarg="mystring")), pool=pool, progress=progress)
        pprint(result)

    # loky is an optional third-party package. Skip this part of the demo when
    # it is missing instead of aborting before the stdlib examples run.
    try:
        from loky import ProcessPoolExecutor
    except ImportError:
        print("loky is not installed; skipping the process-pool example")
    else:
        with ProcessPoolExecutor(max_workers=1) as pool:
            result = get(sixteen, pool=pool, progress=progress)
            pprint(result)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment