Skip to content

Instantly share code, notes, and snippets.

@youtux
Created April 26, 2021 08:04
Show Gist options
  • Select an option

  • Save youtux/fdad2f116dae119bb0b022c6e8a54116 to your computer and use it in GitHub Desktop.

Select an option

Save youtux/fdad2f116dae119bb0b022c6e8a54116 to your computer and use it in GitHub Desktop.
query_yield_per - Execute an SQLAlchemy query in chunks, advancing the primary key (or any other unique key) at each chunk
def query_yield_per(query, next_batch_filter, size, limit=None):
    """Execute a query in many chunks, by advancing the ordering clause.

    This is a generator function that yields the rows of ``query`` one at a
    time, while fetching them from the database at most ``size`` rows per
    round-trip. Each subsequent chunk is selected with a filter derived from
    the last row of the previous chunk, rather than with OFFSET.

    The query must already have an ORDER BY clause, and it must not have any
    LIMIT or OFFSET. The ordering must be on a unique key (e.g. the primary
    key): ``next_batch_filter`` is expected to return a strict "after this
    row" condition, so rows sharing the boundary value of a non-unique key
    would be skipped.

    ``next_batch_filter`` must be a callable taking one argument, the last row
    of the chunk just fetched, and returning the filter condition that selects
    the rows that come after it.

    Example::

        query = session.query(Foo).order_by(Foo.id)
        rows = query_yield_per(query, next_batch_filter=lambda foo: Foo.id > foo.id, size=100)
        list(rows)  # -> every Foo row, in id order, fetched 100 at a time

    :param query: The ordered query to execute
    :type query: sqlalchemy.orm.Query
    :param next_batch_filter: The callable that, given the last item of a
        batch, returns the condition to fetch the next batch
    :type next_batch_filter: callable
    :param size: The maximum number of rows fetched per batch
    :type size: int
    :param limit: The maximum total number of rows to yield, or None for all
    :type limit: int or None
    :raises ValueError: if ``size`` is None or smaller than 1, or if ``limit``
        is not None and smaller than 1. Since this is a generator, the error
        surfaces on the first ``next()`` call, not when the function is called.
    """
    if size is None or size < 1:
        raise ValueError("Size must be a positive number")
    if limit is not None and limit < 1:
        raise ValueError("Limit must be None or a positive number")

    records_found = 0
    # The first batch is the unfiltered query; later batches are filtered
    # relative to the last row seen.
    batch_query = query
    while True:
        if limit is None:
            chunk_size = size
        else:
            # Never fetch more rows than are still allowed by ``limit``.
            chunk_size = min(size, limit - records_found)

        items = batch_query.limit(chunk_size).all()
        yield from items
        records_found += len(items)

        # A batch shorter than requested means the result set is exhausted,
        # so another round-trip would only return an empty list.
        if len(items) < chunk_size:
            break
        # Stop as soon as the limit is reached instead of querying LIMIT 0.
        if limit is not None and records_found >= limit:
            break

        batch_query = query.filter(next_batch_filter(items[-1]))
###### Tests ######
import pytest
from sqlalchemy import Column, Integer, create_engine
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker
# Declarative base whose metadata holds the test tables (created by the
# ``session`` fixture via ``Base.metadata.create_all``).
Base = declarative_base()


class Foo(Base):
    """Minimal mapped model for the tests: a table with a single integer primary key."""

    __tablename__ = 'foo'
    id = Column(Integer, primary_key=True)

    def __repr__(self):
        # Readable form (e.g. ``Foo(3)``) so failing assertions show which rows were involved.
        return 'Foo(%r)' % (self.id,)
@pytest.fixture
def session():
    """Yield a session on a fresh in-memory SQLite database seeded with ``Foo`` ids 0..99.

    The yielded session is closed and the engine disposed after the test,
    even if the test fails.
    """
    engine = create_engine('sqlite:///:memory:', echo=True)
    Base.metadata.create_all(engine)
    Session = sessionmaker(bind=engine)

    # Seed through a separate, short-lived session so the test's own session
    # does not reuse objects loaded while seeding.
    seed_session = Session()
    seed_session.add_all(Foo(id=i) for i in range(100))
    seed_session.commit()
    seed_session.close()

    # Yield the same session that is closed below (the original yielded a
    # second, untracked Session() that was never closed).
    session = Session()
    try:
        yield session
    finally:
        session.close()
        engine.dispose()
@pytest.mark.parametrize('size', [1, 2, 42, 100, 1000])
def test_valid_sizes(session, size):
    """Every row comes back exactly once and in id order, whatever the batch size."""
    ordered = session.query(Foo).order_by(Foo.id)
    rows = query_yield_per(ordered, next_batch_filter=lambda last: Foo.id > last.id, size=size)
    ids = [row.id for row in rows]
    assert ids == list(range(100))
@pytest.mark.parametrize('size', [0, None, -1])
def test_invalid_sizes(session, size):
    """A batch size that is None or below 1 is rejected with ValueError once iteration starts."""
    ordered = session.query(Foo).order_by(Foo.id)
    with pytest.raises(ValueError):
        next(iter(query_yield_per(
            ordered,
            next_batch_filter=lambda last: Foo.id > last.id,
            size=size,
        )))
@pytest.mark.parametrize('limit,expected_count', [
    (None, 100),
    (1, 1),
    (5, 5),
    (42, 42),
    (100, 100),
    (1000, 100),
])
@pytest.mark.parametrize('size', [1, 2, 5, 43, 100, 10000])
def test_valid_limit(session, limit, size, expected_count):
    """At most ``limit`` rows are yielded (all of them when None), in id order, for any batch size."""
    ordered = session.query(Foo).order_by(Foo.id)
    rows = query_yield_per(
        ordered,
        next_batch_filter=lambda last: Foo.id > last.id,
        size=size,
        limit=limit,
    )
    assert [row.id for row in rows] == list(range(expected_count))
@pytest.mark.parametrize('limit', [0, -1])
def test_invalid_limit(session, limit):
    """A non-None limit below 1 is rejected with ValueError once iteration starts."""
    ordered = session.query(Foo).order_by(Foo.id)
    with pytest.raises(ValueError):
        next(iter(query_yield_per(
            ordered,
            next_batch_filter=lambda last: Foo.id > last.id,
            size=10,
            limit=limit,
        )))
@pytest.mark.parametrize('limit', [None, 1, 10])
@pytest.mark.parametrize('size', [1, 2, 10])
def test_no_initial_record(session, size, limit):
    """A query that matches no rows yields nothing, for any size/limit combination."""
    empty = session.query(Foo).filter(False).order_by(Foo.id)
    rows = query_yield_per(
        empty,
        next_batch_filter=lambda last: Foo.id > last.id,
        size=size,
        limit=limit,
    )
    assert [row for row in rows] == []
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment