Skip to content

Instantly share code, notes, and snippets.

@zahlman
Last active December 19, 2024 06:40
Show Gist options
  • Select an option

  • Save zahlman/c1d2e98eac57cbb853ce2af515fecc23 to your computer and use it in GitHub Desktop.

Select an option

Save zahlman/c1d2e98eac57cbb853ce2af515fecc23 to your computer and use it in GitHub Desktop.
#!/usr/bin/env python
import itertools, operator, random, timeit
# Benchmark sizing.  GROUPS is 2**14 == 128 * 128, the two-level split that
# shard_128_radix performs by default.
GROUPS = 2 ** 14  # number of shards (16384)
ELEMENTS = 2 ** 20  # number of Value objects generated for the benchmark (1048576)
ESTIMATE = ELEMENTS // GROUPS  # average shard size; used as the presize hint (64)
ITERS = 10  # timeit repetitions per sharding strategy
class Value:
    """Benchmark payload whose only data is an integer ``group`` id.

    By default the group is drawn with ``random.randrange(GROUPS)``; pass
    ``group`` explicitly to build a deterministic instance.
    """

    def __init__(self, group=None):
        if group is None:
            group = random.randrange(GROUPS)
        self.group = group

    def __eq__(self, other):
        # For testing: two values compare equal when they belong to the same
        # group, which lets two shardings be compared shard by shard.
        # Unrelated types are deferred to Python's default comparison rather
        # than crashing on ``other.group``.
        # (Defining __eq__ without __hash__ leaves instances unhashable.)
        if not isinstance(other, Value):
            return NotImplemented
        return self.group == other.group

    def __repr__(self):
        return '{}(group={!r})'.format(type(self).__name__, self.group)
def _match_lists(x, y):
    """Assert that two shardings agree shard by shard (testing helper).

    Both inputs must hold exactly GROUPS shards.  Each pair of shards is
    compared only over the length of the shorter one, so trailing extra
    entries in either shard are not examined.
    """
    assert len(x) == len(y) == GROUPS

    def prefixes_agree(left, right):
        overlap = min(len(left), len(right))
        return left[:overlap] == right[:overlap]

    # all() stops at the first disagreeing pair, like a per-shard assert would.
    assert all(prefixes_agree(left, right) for left, right in zip(x, y))
def shard_naive(values, groups=None):
    """Shard *values* by appending each one to its group's list.

    Args:
        values: iterable of objects with an integer ``group`` attribute in
            ``range(groups)``.
        groups: number of shards; defaults to the module-level ``GROUPS``.

    Returns:
        A list of ``groups`` lists; ``result[g]`` holds, in input order, the
        values whose ``group`` is ``g``.
    """
    if groups is None:
        groups = GROUPS
    result = [[] for _ in range(groups)]
    for v in values:
        result[v.group].append(v)
    return result
def shard_presized_heuristic(values, groups=None, estimate=None):
    """Shard *values* into per-group lists pre-allocated to a guessed size.

    Each shard starts as ``estimate`` placeholder slots that are filled by
    index; once a shard's slots are used up, further values are appended.
    Placeholder slots a shard never used are removed before returning, so
    the result has the same layout as ``shard_naive``.

    Args:
        values: iterable of objects with an integer ``group`` attribute in
            ``range(groups)``.
        groups: number of shards; defaults to the module-level ``GROUPS``.
        estimate: slots pre-allocated per shard; defaults to the
            module-level ``ESTIMATE``.

    Returns:
        A list of ``groups`` lists; ``result[g]`` holds, in input order, the
        values whose ``group`` is ``g``.
    """
    if groups is None:
        groups = GROUPS
    if estimate is None:
        estimate = ESTIMATE
    result = [[None] * estimate for _ in range(groups)]
    # counts[g] is how many pre-allocated slots of shard g are filled; it
    # stops at ``estimate``, after which the shard only grows by append.
    counts = [0] * groups
    for v in values:
        group = v.group
        count = counts[group]
        if count == estimate:
            result[group].append(v)
        else:
            result[group][count] = v
            counts[group] += 1
    # Shards that got fewer than ``estimate`` values still end in None
    # placeholders; cut them off so every shard holds only real values.
    for shard, filled in zip(result, counts):
        if filled < estimate:
            del shard[filled:]
    return result
def shard_library_sorted(values, groups=None):
    """Shard *values* by stably sorting on ``group`` and splitting the runs.

    ``itertools.groupby`` only yields groups that actually occur, so the runs
    are written into a list pre-filled with empty shards.  This keeps
    ``result[g]`` aligned with group ``g`` even when some group is empty.

    Args:
        values: iterable of objects with an integer ``group`` attribute in
            ``range(groups)``.
        groups: number of shards; defaults to the module-level ``GROUPS``.

    Returns:
        A list of ``groups`` lists; ``result[g]`` holds, in input order
        (``sorted`` is stable), the values whose ``group`` is ``g``.
    """
    if groups is None:
        groups = GROUPS
    get_group = operator.attrgetter('group')
    result = [[] for _ in range(groups)]
    runs = itertools.groupby(sorted(values, key=get_group), key=get_group)
    for group, shard in runs:
        result[group] = list(shard)
    return result
def shard_128_radix(values, groups=None, fanout=128):
    """Two-pass radix sharding: bucket by high bits, then split by low bits.

    A group id is split as ``group == (hi << shift) | lo``.  ``shift`` is the
    smallest value for which ``fanout`` first-pass buckets of ``1 << shift``
    groups each cover ``range(groups)``.  With the defaults
    (``GROUPS == 16384``, ``fanout == 128``) this is 128 buckets of 128
    groups, i.e. ``shift == 7`` and a low-bit mask of ``0x7f``.

    Args:
        values: iterable of objects with an integer ``group`` attribute in
            ``range(groups)``.
        groups: number of shards; defaults to the module-level ``GROUPS``.
        fanout: maximum number of first-pass buckets (the 128 in the name).

    Returns:
        A list of ``groups`` lists; ``result[g]`` holds, in input order, the
        values whose ``group`` is ``g``.

    Raises:
        ValueError: if ``fanout`` is less than 1.
    """
    if groups is None:
        groups = GROUPS
    if fanout < 1:
        raise ValueError("fanout must be at least 1")
    # Round the per-bucket span up to a power of two so the split is a
    # shift/mask instead of a division.
    shift = (-(-groups // fanout) - 1).bit_length()
    span = 1 << shift
    mask = span - 1
    supergroups = [[] for _ in range(-(-groups // span))]
    for v in values:
        supergroups[v.group >> shift].append(v)
    result = []
    for s in supergroups:
        chunk = [[] for _ in range(span)]
        for v in s:
            chunk[v.group & mask].append(v)
        result.extend(chunk)
    # The last bucket may reach past ``groups``; those trailing lists are empty.
    del result[groups:]
    return result
# Build the shared data set, then cross-check each alternative strategy
# against the append-based shard_naive before anything is timed.
data = [Value() for _ in range(ELEMENTS)]
for _candidate in (shard_presized_heuristic, shard_library_sorted, shard_128_radix):
    _match_lists(shard_naive(data), _candidate(data))
del _candidate
def time(label, algo):
    """Print *label* followed by the total seconds for ITERS runs of ``algo(data)``."""

    def run_once():
        return algo(data)

    total = timeit.timeit(run_once, number=ITERS)
    print(label, total)
# Time every sharding strategy on the shared data set, one after another.
for _label, _algo in (
    ("Naive:", shard_naive),
    ("Presized:", shard_presized_heuristic),
    ("Library sort / groupby:", shard_library_sorted),
    ("Simple radixing first:", shard_128_radix),
):
    time(_label, _algo)
del _label, _algo
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment