Created
July 17, 2025 21:12
-
-
Save antonl/e62dba7036a5d74748852a84798100b0 to your computer and use it in GitHub Desktop.
Nested async concurrency bug
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import asyncio | |
| from dataclasses import dataclass, field | |
| from time import perf_counter_ns, sleep | |
| from typing import Literal | |
@dataclass
class API:
    """Toy backend whose transactions must never overlap.

    Every call to ``transact`` appends a ``"start"`` record and then a
    ``"stop"`` record (each stamped with ``perf_counter_ns()``) to ``state``.
    ``verify`` checks afterwards that those records form strictly sequential,
    non-overlapping sessions.
    """

    # Event log: (session id, event kind, perf_counter_ns timestamp).
    state: list[tuple[str, Literal["start", "stop"], int]] = field(
        default_factory=list
    )
    # Seconds each transaction blocks for. This field name shadows the
    # imported ``time.sleep`` only inside the class body; the method bodies
    # below still resolve ``sleep`` to the module-level function.
    sleep: float = 0.01

    def transact(self, id_: str):
        """Perform one blocking "transaction" tagged with ``id_``.

        Deliberately not thread-safe: the blocking ``sleep`` between the start
        and stop records lets another OS thread run ``transact`` at the same
        time, which shows up in ``state`` as overlapping sessions.
        """
        self.state.append((id_, "start", perf_counter_ns()))
        sleep(self.sleep)
        self.state.append((id_, "stop", perf_counter_ns()))

    def verify(self) -> bool:
        """Return True iff ``state`` consists of non-overlapping sessions.

        Each ``"start"`` must be closed by a stop record with the same id
        before the next ``"start"``. On the first violation this prints a
        diagnostic and returns False. Violations are:

        * nested starts,
        * a stop whose id differs from the open session,
        * a stop with no open session,
        * a session still open at the end of the log.

        Any kind other than ``"start"`` is treated as a stop.
        """
        # (id, start timestamp) of the currently open session, or None.
        # Sessions may not nest, so at most one can be open at a time.
        open_session: tuple[str, int] | None = None
        for id_, type_, ts in self.state:
            if type_ == "start":
                if open_session is not None:
                    print(f"ERROR: nested sessions {id_} and {open_session[0]}")
                    return False
                open_session = (id_, ts)
                continue

            # Any non-"start" record closes the open session.
            if open_session is None:
                print(f"ERROR: bare stop {id_}")
                return False
            if open_session[0] != id_:
                print(
                    f"ERROR: mismatched start/stop ids {id_} and {open_session[0]}"
                )
                return False
            open_session = None

        # A trailing "start" with no matching stop is an incomplete session;
        # the original check silently accepted it.
        if open_session is not None:
            print(f"ERROR: unterminated session {open_session[0]}")
            return False
        return True
async def transact(api: API, id_: str, niter: int):
    """Alternate ``niter`` times between async work and a blocking API call.

    Each round first yields to the running event loop for ``api.sleep``
    seconds (a stand-in for real concurrent work). It then calls the
    synchronous, non-thread-safe ``api.transact`` directly, which blocks the
    thread that is running this coroutine.
    """
    for _ in range(niter):
        # Concurrent part: other tasks on this loop may run meanwhile.
        await asyncio.sleep(api.sleep)
        # Synchronous part: records one session in the non-thread-safe API.
        api.transact(id_)
async def nested_async_transact(api: API, id_: str, niter: int):
    """Run ``transact`` inside a private event loop on a worker thread.

    ``asyncio.to_thread`` runs the helper in a separate OS thread. There,
    ``asyncio.run`` starts a brand-new event loop. The blocking
    ``api.transact`` calls made from that loop are therefore not serialized
    with those made on the caller's loop. This is the concurrency hazard the
    script demonstrates.
    """

    def _run_in_private_loop():
        # Executed on the worker thread; blocks it until the nested run ends.
        asyncio.run(transact(api, id_, niter))

    await asyncio.to_thread(_run_in_private_loop)
async def main():
    """Contrast a single-loop scenario with a nested-loop scenario.

    In both scenarios a fresh ``API`` is driven by two concurrent tasks, and
    the log is checked afterwards.

    * "normal": both tasks run ``transact`` on this event loop. Their blocking
      ``api.transact`` calls are therefore serialized, and the check passes.
    * "nested": the second task goes through ``nested_async_transact``, which
      runs on a worker thread with its own loop. Its ``api.transact`` calls
      can overlap the first task's, so the final check is expected to raise
      ``AssertionError``; that failure is the bug being demonstrated.
    """
    # (second coroutine function, failure message) for each scenario,
    # executed in this order.
    scenarios = (
        (transact, "normal case failed"),
        (nested_async_transact, "nested case failed"),
    )
    for second_job, failure_message in scenarios:
        api = API()
        assert api.verify()
        async with asyncio.TaskGroup() as tg:
            tg.create_task(transact(api, "task-1", niter=20))
            tg.create_task(second_job(api, "task-2", niter=20))
        assert api.verify(), failure_message


if __name__ == "__main__":
    asyncio.run(main())
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment