Toy task queue with common subexpression elimination
"""
Implements a simple task computation framework inspired by Dask.
The tasks are encoded as a map where the keys are the names of
values and the values are either computed values, references to
other keys, or tuples representing a function call.
For example, the function call fn(x) may be represented as
```python
{
'x': 1,
'fn': Q(fn=fn, args=(R('x'),), kwargs=())
}
```
However in this case it's probably better to represent it as {'fn': (fn, 1)}.
"""
from collections import Counter, defaultdict
from typing import Any, TypeVar, NamedTuple, overload
from collections.abc import Callable, Hashable
from concurrent.futures import Executor, Future, wait
from inspect import signature

T = TypeVar("T")

type K = str | tuple[str, *tuple[Hashable, ...]]
type V = Any


class R(NamedTuple):
    """A reference to the value stored under another key."""
    key: K


class Q[T](NamedTuple):
    """A deferred call fn(*args, **kwargs); args and kwargs may hold nested Q or R values."""
    fn: Callable[..., T]
    args: tuple["Q | R | V", ...]
    kwargs: tuple[tuple[str, "Q | R | V"], ...]


type TaskT = dict[K, Q]
class Delayed[T]:
    def __init__(self, fn: Callable[..., T], name: str | None = None):
        self.fn = fn
        self.name = fn.__qualname__ if name is None else name

    def __call__(self, *args: Q | V, **kwargs) -> Q[T]:
        # raises TypeError if not callable with arguments
        try:
            signature(self.fn).bind(*args, **kwargs)
        except TypeError as e:
            raise TypeError(f"unable to call {self.name}. {e.args[0]}") from None
        return Q[T](fn=self.fn, args=args, kwargs=tuple(kwargs.items()))

    def __repr__(self) -> str:
        return f"Delayed(fn={self.fn}, name={self.name})"
@overload
def delayed[T](fn_or_name: str) -> Callable[[Callable[..., T]], Delayed[T]]: ...
@overload
def delayed[T](fn_or_name: Callable[..., T]) -> Delayed[T]: ...
def delayed[T](fn_or_name: Callable[..., T] | str) -> Delayed[T] | Callable[[Callable[..., T]], Delayed[T]]:
    if callable(fn_or_name):
        return Delayed[T](fn=fn_or_name)
    else:
        def _inner(fn: Callable[..., T]) -> Delayed[T]:
            return Delayed[T](fn=fn, name=fn_or_name)
        return _inner
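
# Roughly, both decorator forms build Q nodes rather than calling the wrapped
# function (see the demo under __main__ below):
#
#     @delayed
#     def double(x: int) -> int:
#         return 2 * x
#
#     double(2)  # -> Q(fn=<function double>, args=(2,), kwargs=())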
def stable_key(item: Q) -> K:
    # Deduplication is by object identity: reusing the same Q node (e.g. passing
    # `four` twice) is what makes common subexpression elimination kick in.
    return (item.fn.__qualname__, id(item))
def build_adjacency_list(*value: Q) -> tuple[tuple[R | V, ...], TaskT]:
    counts = Counter[K]()
    seen = {}
    refs = {}

    def dfs(value: Q[T] | V) -> Q[T] | V:
        if not isinstance(value, Q):
            # reference or bare value
            return value
        # if seen, look up ref for the hash and replace arg with it
        key = stable_key(value)
        if key in seen:
            return seen[key]
        # otherwise make fresh reference and recurse down
        args = tuple(dfs(a) for a in value.args)
        kwargs = tuple((name, dfs(a)) for name, a in value.kwargs)
        name = value.fn.__qualname__
        ref = R((name, counts[name]))
        counts[name] += 1
        item = Q[T](value.fn, args, kwargs)
        refs[ref.key] = item
        seen[key] = ref
        return ref

    roots = []
    for v in value:
        roots.append(dfs(v))
    return tuple(roots), refs
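
# Roughly, for the demo graph below (four = double(2); sixteen = double(add(four, four))),
# build_adjacency_list(sixteen, four) produces something like:
#
#     roots = (R(('double', 1)), R(('double', 0)))
#     refs  = {
#         ('double', 0): Q(double, (2,), ()),
#         ('add', 0):    Q(add, (R(('double', 0)), R(('double', 0))), ()),
#         ('double', 1): Q(double, (R(('add', 0)),), ()),
#     }
#
# i.e. the shared `four` node is emitted once and referenced twice.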
def get(*args: Q, pool: Executor, progress: Callable[[set, set, set, set, dict[K, Exception]], None] | None = None):
    remaining_count = Counter[K]()
    depends_on = defaultdict[K, set](set)
    roots, values = build_adjacency_list(*args)

    # Count how many unresolved dependencies each task has, and record the
    # reverse edges so a finished task can notify its dependents.
    for key, task in values.items():
        remaining_count[key] = 0
        for a in task.args:
            if isinstance(a, R):
                deps = depends_on[a.key]
                if key not in deps:
                    deps.add(key)
                    remaining_count[key] += 1
        for _, a in task.kwargs:
            if isinstance(a, R):
                deps = depends_on[a.key]
                if key not in deps:
                    deps.add(key)
                    remaining_count[key] += 1

    def _lookup(item):
        # Replace a reference with the value computed for its key.
        if isinstance(item, R):
            return values[item.key]
        else:
            return item
    pending = set(remaining_count.keys())
    completed = set()
    running = set()
    failed = set()
    failed_msg = {}

    def submit(key: K):
        task = values[key]
        args = tuple(_lookup(item) for item in task.args)
        kwargs = {name: _lookup(item) for name, item in task.kwargs}

        def fail(key):
            # Mark the task as failed and cascade the failure to everything
            # that depends on it.
            pending.discard(key)
            failed.add(key)
            for item in depends_on[key]:
                # pop instead of del: with diamond-shaped dependencies the same
                # dependent can be reached along more than one failed path
                remaining_count.pop(item, None)
                if item not in failed:
                    fail(item)

        def on_complete(future: Future):
            running.remove(key)
            del remaining_count[key]
            try:
                values[key] = future.result()
                completed.add(key)
                for item in depends_on[key]:
                    remaining_count[item] -= 1
            except Exception as e:
                fail(key)
                failed_msg[key] = e
            try:
                if callable(progress):
                    progress(pending, running, completed, failed, failed_msg)
            except Exception:
                # progress reporting is best-effort
                pass

        pending.remove(key)
        running.add(key)
        future = pool.submit(task.fn, *args, **kwargs)
        future.add_done_callback(on_complete)
        return future
    # Keep submitting tasks whose dependencies are all resolved until every
    # task has either completed or failed.
    while remaining_count:
        ready = []
        # snapshot: done-callbacks mutate remaining_count from worker threads
        for k, v in list(remaining_count.items()):
            if v == 0 and k not in running:
                ready.append(k)
        futures = []
        for key in ready:
            futures.append(submit(key))
        wait(futures, return_when="FIRST_COMPLETED")

    result = []
    for r in roots:
        if isinstance(r, R):
            result.append(values[r.key])
        else:
            result.append(r)
    return tuple(result)
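
# For the demo below, get(sixteen, four, pool=pool) should come back as (16, 4);
# a root whose task failed is never replaced in `values`, so it comes back as
# its unevaluated Q node.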
if __name__ == "__main__":
    from pprint import pprint
    from concurrent.futures import ThreadPoolExecutor

    from loky import ProcessPoolExecutor

    @delayed
    def double(x: int) -> int:
        return 2 * x

    @delayed("my_add")
    def add(a: int, b: int) -> int:
        return a + b

    @delayed
    def length(*, kwarg: str) -> int:
        return len(kwarg)

    @delayed
    def fail():
        raise RuntimeError()

    four = double(2)
    sixteen = double(add(four, four))
    to_fail = double(fail())

    def progress(pending, running, completed, failed, failed_msg):
        num_pending = len(pending)
        num_running = len(running)
        num_completed = len(completed)
        num_failed = len(failed)
        if failed_msg:
            print(failed_msg)
        total = num_pending + num_running + num_completed + num_failed
        processed = num_completed + num_failed
        print(f"Processed {100 * processed / total:.1f}%, Failed {num_failed}, Completed {num_completed}")

    with ThreadPoolExecutor(max_workers=1) as pool:
        result = get(sixteen, four, pool=pool, progress=progress)
        pprint(result)
        result = get(sixteen, to_fail, pool=pool, progress=progress)
        pprint(result)
        result = get(double(length(kwarg="mystring")), pool=pool, progress=progress)
        pprint(result)

    with ProcessPoolExecutor(max_workers=1) as pool:
        result = get(sixteen, pool=pool, progress=progress)
        pprint(result)