-
-
Save djm/dfa409b055bc0169d2913921ee116e49 to your computer and use it in GitHub Desktop.
Bulk utilities
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| """ | |
| Utilities for working with bulk data and batches. | |
| """ | |
| import itertools | |
def batches(items, batch_size=500):
    """
    Given an iterable of items and a batch size, yield individual lists
    of items of maximum length `batch_size`.

    Every batch except possibly the last contains exactly `batch_size`
    items; the final batch holds whatever remains. Nothing is yielded for
    an empty iterable.

    Args:
        items: Any iterable. It is consumed lazily, one batch at a time,
            so it may be a generator or otherwise unbounded stream.
        batch_size: Maximum number of items per batch; must be >= 1.

    Returns:
        A generator of lists; each yielded list is a new list object.

    Raises:
        ValueError: If `batch_size` is less than 1. Because this is a
            generator, the error surfaces when iteration starts.
    """
    if batch_size < 1:
        # Previously 0 raised ZeroDivisionError (only for non-empty input)
        # and negative sizes silently batched by abs(batch_size).
        raise ValueError(f"batch_size must be >= 1, got {batch_size!r}")
    iterator = iter(items)
    while True:
        # islice pulls at most `batch_size` items from the shared iterator
        # without materialising the rest of the input.
        batch = list(itertools.islice(iterator, batch_size))
        if not batch:
            return
        yield batch
def dedup(obs, key=None):
    """
    Yield the items of `obs` in their original order, skipping any item
    whose key has already been seen.

    Args:
        obs: Iterable of items; consumed lazily.
        key: Optional callable mapping an item to a hashable value that
            decides uniqueness. When omitted (None), the item itself is the
            key, so items are compared by hash *and* equality. Comparing bare
            ``hash()`` results instead would wrongly merge distinct items
            whose hashes collide (e.g. ``-1`` and ``-2`` in CPython); passing
            ``key=hash`` explicitly still selects that behaviour.

    Returns:
        A generator yielding the first occurrence of each distinct key.
    """
    seen = set()
    for ob in obs:
        k = ob if key is None else key(ob)
        if k not in seen:
            seen.add(k)
            yield ob
def sort_then_groupby(seq, *, key):
    """
    Sort `seq` by `key`, then group consecutive equal-key items.

    `itertools.groupby` only merges *adjacent* items with equal keys, so
    sorting first guarantees each key produces exactly one group. `key` is
    evaluated once per item for the sort and again during grouping.

    Args:
        seq: Iterable of items to group.
        key: Callable used both as the sort key and the grouping key.

    Returns:
        The ``itertools.groupby`` iterator of ``(key_value, group)`` pairs,
        in ascending key order.
    """
    ordered = sorted(seq, key=key)
    grouped = itertools.groupby(ordered, key=key)
    return grouped
def dups(obs, key=None):
    """
    Given a sequence, return an iterable of duplicate elements.

    Each key that occurs more than once is reported exactly once, at the
    moment its second occurrence is encountered. Output order therefore
    follows the position of second occurrences, and the yielded object is
    that second occurrence.

    Args:
        obs: Sequence (or any iterable) to check; consumed lazily.
        key: Optional callable mapping an item to a hashable value that
            decides equality. When omitted (None), the item itself is the
            key, so items are compared by hash *and* equality. Comparing bare
            ``hash()`` results would misreport distinct items whose hashes
            collide (e.g. ``-1`` and ``-2`` in CPython) as duplicates;
            passing ``key=hash`` explicitly still selects that behaviour.

    Returns:
        Iterable (generator) of items.
    """
    seen = set()
    # Keys already reported, so a key seen 3+ times is yielded only once.
    yielded = set()
    for ob in obs:
        k = ob if key is None else key(ob)
        if k not in seen:
            seen.add(k)
        elif k not in yielded:
            yielded.add(k)
            yield ob
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment