import io
import os
from functools import lru_cache
from collections import Counter
from dataclasses import dataclass
import regex as re
import mmap
from concurrent.futures import ProcessPoolExecutor, as_completed

# GPT-2 pre-tokenization pattern: contractions, runs of letters, runs of digits,
# runs of other symbols, and whitespace.
GPT2_PATTERN = r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
def _single_thread(input_path: str | os.PathLike, pattern: str, special_tokens: list[str], rank: int = 0, nranks: int = 1) -> Counter[str]:
    # Match special tokens first so the word pattern never splits them; if there
    # are none, fall back to an alternative that can never match.
    token_regex = '|'.join(re.escape(token) for token in special_tokens) or r'(?!)'
    matcher = re.compile(rf'({token_regex})|({pattern})')
    # count unique words
    word_counts: Counter[str] = Counter()
    with open(input_path, 'r') as f:
        for match in matcher.finditer(f.read()):
            word = match.group(2)
            if not word:
                # during training we only count normal tokens, not special tokens
                continue
            word_counts[word] += 1
    return word_counts
def _reduce_documents(input_path: str | os.PathLike, pattern: str, special_tokens: list[str], rank: int = 0, nranks: int = 1) -> Counter[str]:
    token_regex = '|'.join(re.escape(token) for token in special_tokens) or r'(?!)'
    matcher = re.compile(rf'({token_regex})|({pattern})')
    # count unique words in this rank's slice of the file
    word_counts: Counter[str] = Counter()
    with open(input_path, 'rb') as f:
        buf = io.BytesIO()
        with mmap.mmap(f.fileno(), 0, flags=mmap.MAP_SHARED, prot=mmap.PROT_READ) as mm:
            chunk_size = len(mm) // nranks
            mm.seek(chunk_size * rank)
            if rank > 0:
                # skip the partial line at the boundary; the previous rank reads it
                mm.readline()
            pos = mm.tell()
            if rank != nranks - 1:
                # extend this chunk to the end of the line containing the next
                # rank's boundary so no line is split between ranks
                mm.seek(chunk_size * (rank + 1))
                mm.readline()
                buf.write(mm[pos:mm.tell()])
            else:
                buf.write(mm[pos:])
    for match in matcher.finditer(buf.getvalue().decode('utf-8')):
        word = match.group(2)
        if not word:
            # during training we only count normal tokens, not special tokens
            continue
        word_counts[word] += 1
    return word_counts
def train_bpe(input_path: str | os.PathLike, vocab_size: int, special_tokens: list[str] | None = None, pattern: str = GPT2_PATTERN, max_workers: int | None = 1):
    """Train a BPE tokenizer on the text file at `input_path`."""
    special_tokens = special_tokens or []
    special_token_count = len(special_tokens)
    if vocab_size < 256 + special_token_count:
        raise ValueError(f"vocab_size must be at least {256 + special_token_count}.")

    # initialize vocab: special tokens get ids 0..N-1, the 256 raw bytes come after them
    vocab = {
        i + special_token_count: bytes([i]) for i in range(256)
    }
    vocab.update({
        i: token.encode('utf-8') for i, token in enumerate(special_tokens)
    })
    next_id = len(vocab)

    def decode(codes):
        # byte representation of a token sequence; used only for tie-breaking
        return b",".join([vocab[c] for c in codes])

    # pre-tokenize the corpus and count how often each word occurs
    word_counts: Counter[str] = Counter()
    if max_workers == 1:
        word_counts += _single_thread(
            input_path=input_path,
            pattern=pattern,
            special_tokens=special_tokens,
        )
    else:
        # None follows the concurrent.futures convention of using all CPUs
        nranks = max_workers or os.cpu_count() or 1
        with ProcessPoolExecutor(max_workers=nranks) as pool:
            futures = [
                pool.submit(
                    _reduce_documents,
                    input_path=input_path,
                    pattern=pattern,
                    special_tokens=special_tokens,
                    rank=i,
                    nranks=nranks,
                ) for i in range(nranks)
            ]
            # merge per-rank counts as workers finish
            for future in as_completed(futures):
                word_counts += future.result()
    # represent each word as a sequence of byte-level token ids
    word_tokens: dict[str, list[int]] = {}
    for word in word_counts.keys():
        word_tokens[word] = [b + special_token_count for b in word.encode('utf-8')]

    pair_counts: Counter[tuple[int, int]] = Counter()
    pair_words: dict[tuple[int, int], set[str]] = {}
    for word, count in word_counts.items():
        tokens = word_tokens[word]
        for pair in zip(tokens[:-1], tokens[1:]):
            pair_counts[pair] += count
            pair_words.setdefault(pair, set()).add(word)

    merges = []
    while len(vocab) < vocab_size:
        if not pair_counts:
            # we have merged all pairs, can stop
            break
        # most frequent pair; ties broken by the merged byte representation
        pair_to_merge, _ = max(pair_counts.items(), key=lambda item: (item[1], decode(item[0])))
        # found candidate, merge
        new_token = vocab[pair_to_merge[0]] + vocab[pair_to_merge[1]]
        vocab[next_id] = new_token
        merges.append((vocab[pair_to_merge[0]], vocab[pair_to_merge[1]]))

        affected_words = list(pair_words[pair_to_merge])
        for word in affected_words:
            tokens = word_tokens[word]
            count = word_counts[word]
            # remove all pairs from counters
            for pair in zip(tokens[:-1], tokens[1:]):
                pair_counts[pair] -= count
                pair_words[pair].discard(word)
                if pair_counts[pair] == 0:
                    del pair_counts[pair]
            # merge pairs
            i = 0
            while i < len(tokens) - 1:
                if tokens[i] == pair_to_merge[0] and tokens[i + 1] == pair_to_merge[1]:
                    tokens[i:i + 2] = [next_id]
                else:
                    i += 1
            # update pair statistics
            for pair in zip(tokens[:-1], tokens[1:]):
                pair_counts[pair] += count
                pair_words.setdefault(pair, set()).add(word)
            word_tokens[word] = tokens
        next_id += 1

    return vocab, merges
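

if __name__ == "__main__":
    # Usage sketch, not part of the training code above: train a small vocabulary
    # on a plain-text corpus. The corpus path "corpus.txt" and the <|endoftext|>
    # special token are illustrative assumptions; substitute your own corpus and
    # special tokens, and raise max_workers to pre-tokenize in parallel.
    vocab, merges = train_bpe(
        input_path="corpus.txt",
        vocab_size=1000,
        special_tokens=["<|endoftext|>"],
        max_workers=1,
    )
    print(f"vocab size: {len(vocab)}, merges learned: {len(merges)}")
    # inspect the first few merges as pairs of byte strings
    for left, right in merges[:5]:
        print(left, right)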