Skip to content

Instantly share code, notes, and snippets.

@AntreasAntoniou
Created March 25, 2023 23:37
Show Gist options
  • Select an option

  • Save AntreasAntoniou/20b1434865242b7625325867a058e0ef to your computer and use it in GitHub Desktop.

Select an option

Save AntreasAntoniou/20b1434865242b7625325867a058e0ef to your computer and use it in GitHub Desktop.
A gist showing how to asynchronously yield batches from a list of dataloaders
import asyncio
import queue
import threading
from typing import List
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import tqdm
# 📘 Create a synthetic dataset for linear regression
class SyntheticDataset(Dataset):
    """Synthetic 1-D linear-regression dataset: ``y = a * x + b + noise``.

    Samples are drawn on the fly on every access, so reading the same index
    twice returns two different ``(x, y)`` pairs.

    Args:
        a: Slope of the underlying line.
        b: Intercept of the underlying line.
        noise_std: Scale applied to the standard-normal noise added to ``y``.
        num_samples: Number of samples reported by ``len()``.
    """

    def __init__(self, a, b, noise_std, num_samples):
        self.a = a
        self.b = b
        self.noise_std = noise_std
        self.num_samples = num_samples

    def __len__(self):
        """Return the configured number of samples."""
        return self.num_samples

    def __getitem__(self, idx):
        """Draw a fresh ``(x, y)`` pair of shape-``(1,)`` tensors.

        Raises:
            IndexError: If ``idx`` lies outside ``[-len, len)``.
        """
        # Without this check the dataset accepts any index, so plain
        # iteration (``for sample in dataset``), which stops only on
        # IndexError, would never terminate.
        if not -self.num_samples <= idx < self.num_samples:
            raise IndexError(
                f"index {idx} out of range for dataset of size {self.num_samples}"
            )
        x = torch.randn(1)
        noise = torch.randn(1) * self.noise_std
        y = self.a * x + self.b + noise
        return x, y
# 📘 Define a simple linear regression model
class LinearRegressionModel(nn.Module):
    """Linear model with one input feature and one output feature."""

    def __init__(self):
        super().__init__()
        # A single affine layer: one weight and one bias.
        self.linear = nn.Linear(in_features=1, out_features=1)

    def forward(self, x):
        """Apply the affine layer to ``x`` and return the result."""
        prediction = self.linear(x)
        return prediction
# 📘 Define the AsyncGeneratorWrapper class
from concurrent.futures import ThreadPoolExecutor
from typing import List
from torch.utils.data import DataLoader
class AsyncGeneratorWrapper:
    """Merge several data loaders into one stream of batches.

    ``run()`` iterates every loader in its own worker thread and pushes the
    batches onto a shared queue; ``process_queue()`` is an async generator
    that yields those batches in arrival order until every loader has
    finished.

    Args:
        data_loaders: Iterables that also support ``len()``; each is
            iterated once per ``run()`` call.
    """

    # Unique end-of-loader marker. A private object cannot be confused with
    # a real batch, unlike ``None``.
    _DONE = object()
    # Upper bound, in seconds, on one blocking wait for the next batch, so a
    # helper thread is never parked on the queue indefinitely.
    _POLL_SECONDS = 0.05

    def __init__(self, data_loaders: List[DataLoader]):
        self.data_loaders = data_loaders
        # Producers run in worker threads. asyncio.Queue is not thread-safe,
        # so the thread-safe queue.Queue is used for the hand-off instead.
        self.queue = queue.Queue()

    def wrapper(self, data_loader):
        """Push every batch of ``data_loader`` onto the queue, then a marker.

        The marker is queued even when iteration raises, so a consumer
        counting markers is not left waiting for a loader that died.
        """
        try:
            for value in data_loader:
                self.queue.put(value)
        finally:
            self.queue.put(self._DONE)

    def __len__(self):
        """Return the total number of batches across all loaders."""
        return sum(len(dl) for dl in self.data_loaders)

    async def process_queue(self):
        """Yield batches until one end marker per loader has been seen."""
        loop = asyncio.get_running_loop()
        pending = len(self.data_loaders)
        while pending > 0:
            try:
                # Fast path: take an already-queued batch without a thread hop.
                value = self.queue.get_nowait()
            except queue.Empty:
                try:
                    # Slow path: wait in a helper thread so the event loop
                    # stays responsive while producers are still working.
                    value = await loop.run_in_executor(
                        None, self.queue.get, True, self._POLL_SECONDS
                    )
                except queue.Empty:
                    continue
            if value is self._DONE:
                pending -= 1
            else:
                yield value

    async def run(self):
        """Iterate all loaders in worker threads and wait for them to finish.

        Any exception raised while iterating a loader is re-raised here.
        """
        if not self.data_loaders:
            return
        # One worker per loader, so a slow loader never delays the start of
        # another one.
        with ThreadPoolExecutor(max_workers=len(self.data_loaders)) as executor:
            futures = [executor.submit(self.wrapper, dl) for dl in self.data_loaders]
            for future in futures:
                await asyncio.wrap_future(future)
# 📘 Helper function to convert an async generator to a sync generator
def async_to_sync(async_generator):
    """Expose an async generator as an ordinary (blocking) generator.

    A background thread is started immediately; it runs its own event loop,
    drains ``async_generator`` and hands each item over through a
    thread-safe queue.

    Args:
        async_generator: The async generator (or async iterable) to drain.

    Returns:
        A generator yielding the same items in the same order. An exception
        raised by ``async_generator`` is re-raised from the returned
        generator at the point where the failure occurred.
    """
    q = queue.Queue()
    # Private end marker: ``None`` may be a legitimate item, so it cannot
    # double as the end-of-stream signal.
    end = object()

    async def _async_consumer():
        failure = None
        try:
            async for item in async_generator:
                q.put((item, None))
        except Exception as exc:
            # Forwarded to the consuming thread instead of being lost here.
            failure = exc
        finally:
            # Always signal the end, otherwise the consumer blocks forever
            # on ``q.get()`` after a failure.
            q.put((end, failure))

    def _sync_generator():
        while True:
            item, failure = q.get()
            if item is end:
                if failure is not None:
                    raise failure
                return
            yield item

    # Daemon thread: an abandoned generator must not keep the interpreter
    # alive at exit.
    threading.Thread(
        target=lambda: asyncio.run(_async_consumer()),
        daemon=True,
    ).start()
    return _sync_generator()
# 📘 Training loop
def train_model(data_generator, model, criterion, optimizer, num_epochs, length):
    """Train ``model`` for ``num_epochs`` passes over ``data_generator``.

    Args:
        data_generator: Object providing a ``run()`` coroutine that produces
            the batches and a ``process_queue()`` async generator that yields
            them as ``(inputs, targets)`` pairs.
        model: Callable mapping ``inputs`` to predictions.
        criterion: Loss function called as ``criterion(outputs, targets)``.
        optimizer: Optimizer providing ``zero_grad()`` and ``step()``.
        num_epochs: Number of passes over the data.
        length: Number of batches per epoch, used as the progress-bar total.
            When ``None``, ``len(data_generator)`` is used instead.
    """
    for epoch in range(num_epochs):
        print(f"Epoch {epoch + 1}/{num_epochs}")

        # asyncio.run() returns only once run() has finished, so the
        # producers are done before the first batch is consumed below.
        asyncio.run(data_generator.run())
        sync_generator = async_to_sync(data_generator.process_queue())

        # Honour the caller-supplied length; it was previously ignored.
        total = len(data_generator) if length is None else length
        with tqdm.tqdm(total=total) as pbar:
            for inputs, targets in sync_generator:
                optimizer.zero_grad()
                outputs = model(inputs)
                loss = criterion(outputs, targets)
                loss.backward()
                optimizer.step()
                pbar.update(1)
                pbar.set_description(f"Loss: {loss.item():.4f}")
# 📘 Create DataLoader instances
# Three datasets drawn from the same line (a=2, b=3) that differ only in size.
datasets = [
    SyntheticDataset(a=2, b=3, noise_std=0.1, num_samples=sample_count)
    for sample_count in (100, 200, 300)
]
data_loaders = [DataLoader(ds, batch_size=1, shuffle=True) for ds in datasets]

# Merge all loaders behind a single wrapper.
async_data_generator = AsyncGeneratorWrapper(data_loaders)

# Model, loss function and optimizer.
model = LinearRegressionModel()
criterion = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)

# Train the model.
num_epochs = 10
train_model(
    data_generator=async_data_generator,
    model=model,
    criterion=criterion,
    optimizer=optimizer,
    num_epochs=num_epochs,
    length=len(async_data_generator),
)
@AntreasAntoniou
Copy link
Author

Original problem:

I have a list of generators in Python. Each generator yields at a different speed. I want to be able to yield samples as soon as they are available from any generator.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment