Created March 25, 2023 23:37
A gist showing how one can have a list of dataloaders asynchronously yielded from, so that batches are consumed as soon as any loader produces them.
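In rough outline, the intended usage (all names are taken from the full listing below) is to wrap a list of DataLoaders, drain them into a shared queue, and then iterate the merged stream synchronously. A minimal sketch:

    datasets = [SyntheticDataset(a=2, b=3, noise_std=0.1, num_samples=n) for n in (100, 200, 300)]
    loaders = [DataLoader(ds, batch_size=1, shuffle=True) for ds in datasets]
    merged = AsyncGeneratorWrapper(loaders)
    asyncio.run(merged.run())  # drain every loader into the shared queue
    for x, y in async_to_sync(merged.process_queue()):
        ...  # consume batches in whatever order the loaders produced them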
import asyncio
import queue
import threading
from concurrent.futures import ThreadPoolExecutor
from typing import List

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import tqdm


# Create a synthetic dataset for linear regression
class SyntheticDataset(Dataset):
    def __init__(self, a, b, noise_std, num_samples):
        self.a = a
        self.b = b
        self.noise_std = noise_std
        self.num_samples = num_samples

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        x = torch.randn(1)
        noise = torch.randn(1) * self.noise_std
        y = self.a * x + self.b + noise
        return x, y


# Define a simple linear regression model
class LinearRegressionModel(nn.Module):
    def __init__(self):
        super(LinearRegressionModel, self).__init__()
        self.linear = nn.Linear(1, 1)

    def forward(self, x):
        return self.linear(x)


# Define the AsyncGeneratorWrapper class
class AsyncGeneratorWrapper:
    def __init__(self, data_loaders: List[DataLoader]):
        self.data_loaders = data_loaders
        self.queue = asyncio.Queue()

    def wrapper(self, data_loader):
        # Runs in a worker thread: push every batch, then a None sentinel.
        # NOTE: asyncio.Queue is not thread-safe in general; this works here
        # because nothing awaits queue.get() while the workers are filling it.
        for value in data_loader:
            self.queue.put_nowait(value)
        self.queue.put_nowait(None)

    def __len__(self):
        return sum(len(dl) for dl in self.data_loaders)

    async def process_queue(self):
        # Yield items until one None sentinel per data loader has been seen.
        num_none_received = 0
        while num_none_received < len(self.data_loaders):
            value = await self.queue.get()
            if value is None:
                num_none_received += 1
            else:
                yield value

    async def run(self):
        # Drain every data loader into the queue using a thread pool.
        with ThreadPoolExecutor() as executor:
            tasks = [executor.submit(self.wrapper, dl) for dl in self.data_loaders]
            for future in tasks:
                await asyncio.wrap_future(future)


# Helper function to convert an async generator to a sync generator
def async_to_sync(async_generator):
    q = queue.Queue()

    async def _async_consumer():
        async for item in async_generator:
            q.put(item)
        q.put(None)

    def _sync_generator():
        while True:
            item = q.get()
            if item is None:
                break
            yield item

    threading.Thread(target=lambda: asyncio.run(_async_consumer())).start()
    return _sync_generator()


# Training loop
def train_model(data_generator, model, criterion, optimizer, num_epochs, length):
    for epoch in range(num_epochs):
        print(f"Epoch {epoch + 1}/{num_epochs}")
        # Run the AsyncGeneratorWrapper (fills the queue), then get the sync generator
        asyncio.run(data_generator.run())
        sync_generator = async_to_sync(data_generator.process_queue())
        with tqdm.tqdm(total=length) as pbar:
            for inputs, targets in sync_generator:
                optimizer.zero_grad()
                outputs = model(inputs)
                loss = criterion(outputs, targets)
                loss.backward()
                optimizer.step()
                pbar.update(1)
                pbar.set_description(f"Loss: {loss.item():.4f}")


# Create DataLoader instances
datasets = [
    SyntheticDataset(a=2, b=3, noise_std=0.1, num_samples=100),
    SyntheticDataset(a=2, b=3, noise_std=0.1, num_samples=200),
    SyntheticDataset(a=2, b=3, noise_std=0.1, num_samples=300),
]
data_loaders = [DataLoader(dataset, batch_size=1, shuffle=True) for dataset in datasets]

# Create the AsyncGeneratorWrapper instance
async_data_generator = AsyncGeneratorWrapper(data_loaders)

# Initialize the model, loss function, and optimizer
model = LinearRegressionModel()
criterion = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)

# Train the model
num_epochs = 10
train_model(
    async_data_generator,
    model,
    criterion,
    optimizer,
    num_epochs,
    len(async_data_generator),
)
Original problem:
I have a list of generators in Python, each yielding at a different speed. I want to be able to yield samples as soon as they are available from any generator.
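For reference, here is a minimal, framework-free sketch of that idea: each generator is drained by its own thread into a shared queue.Queue, and a single consumer yields items in whatever order they arrive. The helper name merge_generators is hypothetical and not part of the gist.

    import queue
    import threading

    def merge_generators(generators):
        # Hypothetical helper: yield items from any generator as soon as they arrive.
        q = queue.Queue()
        sentinel = object()  # unique end-of-stream marker, one per generator

        def drain(gen):
            for item in gen:
                q.put(item)
            q.put(sentinel)

        threads = [threading.Thread(target=drain, args=(g,), daemon=True) for g in generators]
        for t in threads:
            t.start()

        finished = 0
        while finished < len(generators):
            item = q.get()
            if item is sentinel:
                finished += 1
            else:
                yield item

Unlike the gist's fill-then-consume design, this sketch interleaves production and consumption, so slow generators never block fast ones.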