Last active: November 13, 2024 02:49
-
-
Save AndyGrant/42cac0f32e5fa044c9d01b509c66a757 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import typing | |
| import multiprocessing | |
class BatchedExecutionPool:
    """Apply a function to a stream of inputs using a pool of worker processes.

    Inputs are read from ``input_generator`` in batches of ``batchsize``.
    Every result of a batch is collected before the next batch is read, so at
    most ``batchsize`` inputs are queued for the workers at any time. Each
    worker calls ``process_function(item, process_function_args)`` for every
    item it receives.

    Results are yielded in the order workers finish them, which is not
    necessarily the order of the inputs.
    """

    def __init__(
        self,
        input_generator: typing.Iterable[typing.Any],
        process_function: typing.Callable[..., typing.Any],
        process_function_args: typing.Any = None,
    ) -> None:
        """Store the work source and the per-item function.

        Args:
            input_generator: Source of work items. Any iterable is accepted.
                A generator/iterator is consumed in place, so items read by one
                ``execute`` call are not seen by a later one.
            process_function: Called in a worker process as
                ``process_function(item, process_function_args)``.
            process_function_args: Passed unchanged as the single second
                positional argument (it is not unpacked).

        Items, results and forwarded exceptions travel through
        ``multiprocessing.Queue`` objects, so they must be picklable.
        """
        # iter() returns generators/iterators unchanged and also lets plain
        # iterables (lists, ranges, ...) be used with next().
        self.input_generator = iter(input_generator)
        self.process_function = process_function
        self.process_function_args = process_function_args

    def execute(
        self,
        threads: int,
        batchsize: int,
    ) -> typing.Iterator[typing.Any]:
        """Process all remaining inputs, yielding each result as it arrives.

        This is a generator: nothing (including validation and process
        start-up) happens until the first result is requested.

        Args:
            threads: Number of worker *processes* to start; must be >= 1.
            batchsize: Maximum number of inputs enqueued per batch; must be >= 1.

        Yields:
            ``process_function(item, process_function_args)`` for each input,
            in completion order.

        Raises:
            ValueError: If ``threads`` or ``batchsize`` is less than 1.
            Exception: Whatever ``process_function`` raised in a worker is
                re-raised here; the remaining workers are then terminated.
            RuntimeError: If a worker process dies while results are pending.
        """
        # Without these checks, threads < 1 blocks forever waiting for results
        # and batchsize < 1 loops forever on "full" empty batches.
        if threads < 1:
            raise ValueError(f"threads must be >= 1, got {threads}")
        if batchsize < 1:
            raise ValueError(f"batchsize must be >= 1, got {batchsize}")

        in_queue = multiprocessing.Queue()
        out_queue = multiprocessing.Queue()
        workers = [
            multiprocessing.Process(
                target=BatchedExecutionPool._process_function_wrapper,
                args=(in_queue, out_queue, self.process_function, self.process_function_args),
                daemon=True,
            )
            for _ in range(threads)
        ]
        for worker in workers:
            worker.start()

        finished = False
        try:
            while True:
                n = self._enqueue_elements(batchsize, in_queue)
                for _ in range(n):
                    yield self._collect_result(out_queue, workers)
                # A short batch means the input is exhausted.
                if n != batchsize:
                    break
            finished = True
        finally:
            if finished:
                # Every result has been consumed and the workers are idle, so
                # one sentinel each shuts them down cleanly.
                for _ in workers:
                    in_queue.put(None)
                for worker in workers:
                    worker.join()
            else:
                # Early exit (consumer stopped iterating, or an error). Workers
                # may still hold unconsumed results, and joining a process with
                # unflushed queue data can deadlock, so kill them instead. Also
                # stop this process from waiting at exit to flush inputs that
                # nobody will read.
                in_queue.cancel_join_thread()
                for worker in workers:
                    worker.terminate()
                for worker in workers:
                    worker.join()

    def _enqueue_elements(
        self,
        batchsize: int,
        in_queue: multiprocessing.Queue,
    ) -> int:
        """Move up to ``batchsize`` inputs onto ``in_queue``.

        Each item is wrapped in a 1-tuple so that ``None`` stays reserved as
        the worker shutdown sentinel, even when the input itself yields None.

        Returns:
            The number of items enqueued. This is less than ``batchsize`` only
            when the input is exhausted.
        """
        for count in range(batchsize):
            try:
                item = next(self.input_generator)
            except StopIteration:
                return count
            in_queue.put((item,))
        return batchsize

    @staticmethod
    def _collect_result(
        out_queue: multiprocessing.Queue,
        workers: typing.List[multiprocessing.Process],
    ) -> typing.Any:
        """Wait for the next ``(ok, value)`` worker message and unwrap it.

        Re-raises an exception forwarded by a worker. Raises RuntimeError if a
        worker has died: workers only exit on the shutdown sentinel, which is
        not sent while results are pending, so a dead worker may have taken
        an item whose result will never arrive.
        """
        import queue  # for queue.Empty, raised by Queue.get on timeout

        while True:
            try:
                ok, value = out_queue.get(timeout=0.5)
            except queue.Empty:
                dead = [w for w in workers if not w.is_alive()]
                if dead:
                    codes = [w.exitcode for w in dead]
                    raise RuntimeError(
                        f"{len(dead)} worker process(es) exited unexpectedly "
                        f"(exit codes {codes})"
                    )
            else:
                break
        if not ok:
            raise value
        return value

    @staticmethod
    def _process_function_wrapper(
        in_queue: multiprocessing.Queue,
        out_queue: multiprocessing.Queue,
        process_function: typing.Callable[..., typing.Any],
        process_function_args: typing.Any,
    ) -> None:
        """Worker loop: process 1-tuple-wrapped items until a ``None`` sentinel.

        Each outcome is sent back as ``(True, result)``. If
        ``process_function`` raised, ``(False, exception)`` is sent instead,
        so the parent can re-raise rather than wait forever.
        """
        # Identity test: `!= None` would call the item's __ne__, which is not a
        # plain bool for e.g. array-like data.
        while (message := in_queue.get()) is not None:
            (item,) = message
            try:
                result = process_function(item, process_function_args)
            except Exception as exc:
                # NOTE: like results, the exception must be picklable to cross
                # the queue; an unpicklable one is dropped by the queue's feeder.
                out_queue.put((False, exc))
            else:
                out_queue.put((True, result))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.