Beautiful ARLM Sequence Packing & Padding
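A small, self-contained collator for packing token sequences for autoregressive language model (ARLM) training: incoming sequences are split to the training context length, packed into partially filled buckets using a first-fit strategy, and the fullest bucket is padded to the target length and emitted once enough buckets accumulate or the data source is exhausted.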
from __future__ import annotations

from dataclasses import dataclass
from typing import Iterator, List, Optional


@dataclass
class Sequence:
    """Contains a single token sequence"""

    x: List[int]
    y: List[int]
    document: int
    """The unique ID of the document this sequence belongs to"""
    source: int
    """The unique ID of the source this sequence belongs to"""


@dataclass
class TrainSequence:
    """Contains many packed token sequences"""

    x: List[int]
    y: List[int]
    document: List[int]
    """For each token in `x`, the unique ID of the document it belongs to"""
    source: List[int]
    """For each token in `x`, the unique ID of the source it belongs to"""

    def __post_init__(self):
        assert len(self.x) \
            == len(self.y) \
            == len(self.document) \
            == len(self.source)

    def __len__(self):
        return len(self.x)

    def __add__(self, other: TrainSequence) -> TrainSequence:
        return TrainSequence(
            x=self.x + other.x,
            y=self.y + other.y,
            document=self.document + other.document,
            source=self.source + other.source,
        )

    def __getitem__(self, key: slice) -> TrainSequence:
        return TrainSequence(
            x=self.x[key],
            y=self.y[key],
            document=self.document[key],
            source=self.source[key],
        )

    @staticmethod
    def from_sequence(sequence: Sequence) -> TrainSequence:
        return TrainSequence(
            x=sequence.x,
            y=sequence.y,
            document=[sequence.document] * len(sequence.x),
            source=[sequence.source] * len(sequence.x),
        )

    @staticmethod
    def padding(length: int, pad: int) -> TrainSequence:
        return TrainSequence(
            x=[pad] * length,
            y=[-100] * length,  # -100 is the conventional ignore index for the loss
            document=[-1] * length,
            source=[-1] * length,
        )


class SequenceCollator:
    def __init__(
        self,
        *,
        sequences: Iterator[Sequence],
        sequence_length: int,
        pad: int,
        max_num_buckets: int = 32,
    ):
        self.sequences = sequences
        self.sequence_length = sequence_length
        self.pad = pad
        self.max_num_buckets = max_num_buckets
        self.buckets: List[TrainSequence] = []

    def next(self) -> Optional[TrainSequence]:
        while True:
            # Flush once we hold more than `max_num_buckets` partially filled
            # buckets, or once the data source is exhausted.
            if len(self.buckets) > self.max_num_buckets or self.sequences is None:
                if len(self.buckets) == 0:
                    # We've exhausted the data source and have no more buckets.
                    return None
                # Find the fullest bucket, pad it to `sequence_length` and return it.
                highest = 0
                highest_index = 0
                for i in range(len(self.buckets)):
                    if len(self.buckets[i]) > highest:
                        highest = len(self.buckets[i])
                        highest_index = i
                sequence = self.buckets.pop(highest_index)
                missing = self.sequence_length - len(sequence)
                sequence = sequence + TrainSequence.padding(missing, self.pad)
                return sequence
            try:
                sequence = next(self.sequences)
            except StopIteration:
                self.sequences = None
                continue
            sequence = TrainSequence.from_sequence(sequence)
            while len(sequence) > 0:
                # Split the sequence into chunks of at most `sequence_length`
                # tokens and pack each chunk into the first bucket it fits in.
                subsequence = sequence[:self.sequence_length]
                sequence = sequence[self.sequence_length:]
                for i in range(len(self.buckets)):
                    if len(self.buckets[i]) + len(subsequence) <= self.sequence_length:
                        self.buckets[i] = self.buckets[i] + subsequence
                        subsequence = None
                        break
                if subsequence is not None:
                    # Could not find a fit, create a new bucket.
                    self.buckets.append(subsequence)


# Example usage
EOS = 0
collator = SequenceCollator(
    sequences=iter([
        Sequence(
            x=[1, 2, 3],
            y=[2, 3, EOS],
            source=1,
            document=1,
        )
    ]),
    sequence_length=16,
    pad=0,
)
collator.next()
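A minimal sketch of how the collator might be driven until the data source runs dry. The `drain` helper, the second toy `Sequence`, and the shorter `sequence_length` below are illustrative assumptions, not part of the gist; the sketch relies only on `Sequence`, `TrainSequence`, `SequenceCollator` and `EOS` defined above.

# Hypothetical helper (not part of the gist): call `next()` until it returns
# None, collecting the fixed-length packed sequences it emits.
def drain(collator: SequenceCollator) -> List[TrainSequence]:
    out: List[TrainSequence] = []
    while True:
        packed = collator.next()
        if packed is None:
            break
        # Every emitted sequence is padded up to exactly `sequence_length`.
        assert len(packed) == collator.sequence_length
        out.append(packed)
    return out

collator = SequenceCollator(
    sequences=iter([
        Sequence(x=[1, 2, 3], y=[2, 3, EOS], source=1, document=1),
        Sequence(x=[4, 5, 6, 7], y=[5, 6, 7, EOS], source=1, document=2),
    ]),
    sequence_length=8,
    pad=0,
)
for packed in drain(collator):
    # Padded positions carry y == -100 (ignored by the loss) and
    # document == source == -1, so they are easy to mask downstream.
    print(packed.x, packed.y, packed.document)

With these inputs both toy sequences fit into a single bucket (3 + 4 = 7 tokens), so `drain` yields one `TrainSequence` of length 8 with one padded position at the end; the per-token `document` IDs make it possible to tell the two packed documents apart downstream.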