Skip to content

Instantly share code, notes, and snippets.

@dhbrojas
Last active June 30, 2025 11:13
Show Gist options
  • Select an option

  • Save dhbrojas/514538cc9e5443d8cafe9ef48198aabf to your computer and use it in GitHub Desktop.

Select an option

Save dhbrojas/514538cc9e5443d8cafe9ef48198aabf to your computer and use it in GitHub Desktop.
Beautiful ARLM Sequence Packing & Padding
from dataclasses import dataclass
from typing import List, Iterator
@dataclass
class Sequence:
    """A single token sequence taken from one document of one source."""

    # Input token IDs.
    x: List[int]
    # Target token IDs aligned with `x`.
    y: List[int]
    # Unique ID of the document this sequence belongs to.
    document: int
    # Unique ID of the source this sequence belongs to.
    source: int
@dataclass
class TrainSequence:
    """Many packed token sequences laid out back-to-back in one flat row.

    The four lists are parallel: entry ``i`` of ``document``/``source``
    labels token ``i`` of ``x``/``y``.

    NOTE: the self-referential annotations below are string forward
    references; the bare names raised ``NameError`` at class-definition
    time because the file does not use
    ``from __future__ import annotations``.
    """

    # Packed input token IDs.
    x: List[int]
    # Packed target token IDs.
    y: List[int]
    # For each token in `x`, the unique ID of the document it belongs to.
    document: List[int]
    # For each token in `y`, the unique ID of the source it belongs to.
    source: List[int]

    def __post_init__(self):
        # The four parallel lists must always have equal length.
        assert len(self.x) \
            == len(self.y) \
            == len(self.document) \
            == len(self.source)

    def __len__(self) -> int:
        """Number of tokens in the packed row."""
        return len(self.x)

    def __add__(self, other: "TrainSequence") -> "TrainSequence":
        """Concatenate two packed rows element-wise over all four lists."""
        return TrainSequence(
            x=self.x + other.x,
            y=self.y + other.y,
            document=self.document + other.document,
            source=self.source + other.source,
        )

    def __getitem__(self, key: slice) -> "TrainSequence":
        """Apply the same slice to all four parallel lists."""
        return TrainSequence(
            x=self.x[key],
            y=self.y[key],
            document=self.document[key],
            source=self.source[key],
        )

    @staticmethod
    def from_sequence(sequence: "Sequence") -> "TrainSequence":
        """Lift a single `Sequence` into a packed row, broadcasting its
        document and source IDs across every token."""
        return TrainSequence(
            x=sequence.x,
            y=sequence.y,
            document=[sequence.document] * len(sequence.x),
            source=[sequence.source] * len(sequence.x),
        )

    @staticmethod
    def padding(length: int, pad: int) -> "TrainSequence":
        """A pure-padding row: `pad` inputs, -100 targets (presumably the
        loss ignore-index -- confirm against the training loop) and -1
        document/source IDs marking non-data tokens."""
        return TrainSequence(
            x=[pad] * length,
            y=[-100] * length,
            document=[-1] * length,
            source=[-1] * length,
        )
class SequenceCollator:
    """Packs variable-length sequences into fixed-length training rows.

    Incoming sequences are split into chunks of at most `sequence_length`
    tokens and first-fit packed into open buckets. A row is emitted (the
    fullest bucket, padded up to `sequence_length`) once the number of
    open buckets exceeds `max_num_buckets`, or when the input iterator is
    exhausted and the remaining buckets are being drained.
    """

    def __init__(
        self,
        *,
        sequences: Iterator[Sequence],
        sequence_length: int,
        pad: int,
        max_num_buckets: int = 32,
    ):
        # `sequences` is set to None once the iterator is exhausted.
        self.sequences = sequences
        self.sequence_length = sequence_length
        self.pad = pad
        self.max_num_buckets = max_num_buckets
        self.buckets: List[TrainSequence] = []

    def next(self) -> "TrainSequence | None":
        """Return the next packed row, or None when all input is consumed.

        (Annotation fixed: the method explicitly returns None on
        exhaustion but was annotated as returning TrainSequence only.)
        """
        while True:
            if len(self.buckets) > self.max_num_buckets or self.sequences is None:
                if not self.buckets:
                    # Data source exhausted and no buckets left to drain.
                    return None
                # Flush the fullest bucket, padded out to sequence_length.
                # max() returns the first index on ties, matching the
                # original linear scan's behavior.
                fullest = max(
                    range(len(self.buckets)),
                    key=lambda i: len(self.buckets[i]),
                )
                row = self.buckets.pop(fullest)
                missing = self.sequence_length - len(row)
                return row + TrainSequence.padding(missing, self.pad)
            try:
                raw = next(self.sequences)
            except StopIteration:
                self.sequences = None
                continue
            sequence = TrainSequence.from_sequence(raw)
            # Split into sequence_length-sized chunks and first-fit pack
            # each chunk into the first bucket with enough room.
            while len(sequence) > 0:
                chunk = sequence[:self.sequence_length]
                sequence = sequence[self.sequence_length:]
                for i in range(len(self.buckets)):
                    if len(self.buckets[i]) + len(chunk) <= self.sequence_length:
                        self.buckets[i] = self.buckets[i] + chunk
                        chunk = None
                        break
                if chunk is not None:
                    # No open bucket can fit this chunk: open a new one.
                    self.buckets.append(chunk)
# Example usage: pack a single 3-token sequence into a 16-token row.
EOS = 0

example_sequences = [
    Sequence(x=[1, 2, 3], y=[2, 3, EOS], document=1, source=1),
]
collator = SequenceCollator(
    sequences=iter(example_sequences),
    sequence_length=16,
    pad=0,
)
collator.next()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment