# Gist by @antonl (last active April 7, 2025 23:14)

import io
import os
from collections import Counter
import regex as re
import mmap
from concurrent.futures import ProcessPoolExecutor, as_completed
GPT2_PATTERN = r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
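
# Illustrative note (not part of the original gist): applied to "Hello, world!",
# the pattern above yields the pre-tokens ["Hello", ",", " world", "!"]; BPE merges
# learned below never cross these pre-token boundaries.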

def _single_thread(input_path: str | os.PathLike, pattern: str, special_tokens: list[str], rank: int = 0, nranks: int = 1) -> Counter[str]:
    if special_tokens:
        token_regex = '|'.join(re.escape(token) for token in special_tokens)
    else:
        token_regex = r'(?!)'  # matches nothing, so plain words still land in group 2
    matcher = re.compile(rf'({token_regex})|({pattern})')
    # count unique words
    word_counts: Counter[str] = Counter()
    with open(input_path, 'r') as f:
        for match in matcher.finditer(f.read()):
            word = match.group(2)
            if not word:
                # during training we only count normal words, not special tokens
                continue
            word_counts[word] += 1
    return word_counts

def _reduce_documents(input_path: str | os.PathLike, pattern: str, special_tokens: list[str], rank: int = 0, nranks: int = 1) -> Counter[str]:
    if special_tokens:
        token_regex = '|'.join(re.escape(token) for token in special_tokens)
    else:
        token_regex = r'(?!)'  # matches nothing, so plain words still land in group 2
    matcher = re.compile(rf'({token_regex})|({pattern})')
    # count unique words in this rank's slice of the file
    word_counts: Counter[str] = Counter()
    with open(input_path, 'rb') as f:
        buf = io.BytesIO()
        # memory-map the file so each rank can slice out its own chunk
        # (the flags=/prot= keyword arguments are POSIX-only)
        with mmap.mmap(f.fileno(), 0, flags=mmap.MAP_SHARED, prot=mmap.PROT_READ) as mm:
            chunk_size = len(mm) // nranks
            mm.seek(chunk_size * rank)
            if rank > 0:
                mm.readline()  # skip the partial first line; the previous rank finishes it
            pos = mm.tell()
            if rank != nranks - 1:
                end = chunk_size * (rank + 1)
                if pos < end:
                    buf.write(mm[pos:end])
                    mm.seek(end)
                    buf.write(mm.readline())  # finish the line cut at the chunk boundary
            else:
                buf.write(mm[pos:])
    for match in matcher.finditer(buf.getvalue().decode('utf-8')):
        word = match.group(2)
        if not word:
            # during training we only count normal words, not special tokens
            continue
        word_counts[word] += 1
    return word_counts

def train_bpe(input_path: str | os.PathLike, vocab_size: int, special_tokens: list[str] | None = None, pattern: str = GPT2_PATTERN, max_workers: int | None = 1):
    """Train a BPE tokenizer from the text file at the given path."""
    special_tokens = special_tokens or []
    special_token_count = len(special_tokens)
    if vocab_size < 256 + special_token_count:
        raise ValueError(f"vocab_size must be at least {256 + special_token_count}.")

    # initialize the vocab: special tokens first, then the 256 single-byte tokens
    vocab = {
        i + special_token_count: bytes([i]) for i in range(256)
    }
    vocab.update({
        i: token.encode('utf-8') for i, token in enumerate(special_tokens)
    })
    next_id = len(vocab)

    def decode(codes):
        # byte rendering of a pair, used only as a deterministic tie-breaker below
        return b",".join([vocab[c] for c in codes])

    # pass 1: count pre-tokenized words, optionally across several worker processes
    word_counts: Counter[str] = Counter()
    if max_workers == 1:
        word_counts += _single_thread(
            input_path=input_path,
            pattern=pattern,
            special_tokens=special_tokens,
        )
    else:
        nranks = max_workers or 1
        with ProcessPoolExecutor(max_workers=nranks) as pool:
            futures = [
                pool.submit(
                    _reduce_documents,
                    input_path=input_path,
                    pattern=pattern,
                    special_tokens=special_tokens,
                    rank=i,
                    nranks=nranks,
                ) for i in range(nranks)
            ]
            for future in as_completed(futures):
                word_counts += future.result()

    # represent each word as a sequence of byte-level token ids
    word_tokens: dict[str, list[int]] = {}
    for word in word_counts.keys():
        word_tokens[word] = list(i + special_token_count for i in word.encode('utf-8'))

    # pair statistics: how often each adjacent pair occurs, and in which words it appears
    pair_counts: Counter[tuple[int, int]] = Counter()
    pair_words: dict[tuple[int, int], set[str]] = {}
    for word, count in word_counts.items():
        tokens = word_tokens[word]
        for pair in zip(tokens[:-1], tokens[1:]):
            pair_counts[pair] += count
            pair_words.setdefault(pair, set()).add(word)

    merges = []
    while len(vocab) < vocab_size:
        if not pair_counts:
            # we have merged all pairs, can stop
            break
        # pick the most frequent pair; ties broken on the pair's byte rendering
        pair_to_merge, _ = max(pair_counts.items(), key=lambda item: (item[1], decode(item[0])))
        # found candidate, merge
        new_token = vocab[pair_to_merge[0]] + vocab[pair_to_merge[1]]
        vocab[next_id] = new_token
        merges.append((vocab[pair_to_merge[0]], vocab[pair_to_merge[1]]))

        affected_words = list(pair_words[pair_to_merge])
        for word in affected_words:
            tokens = word_tokens[word]
            count = word_counts[word]
            # remove all of this word's pairs from the counters
            for pair in zip(tokens[:-1], tokens[1:]):
                pair_counts[pair] -= count
                pair_words[pair].discard(word)
                if pair_counts[pair] == 0:
                    del pair_counts[pair]
            # merge every occurrence of the chosen pair within the word
            i = 0
            while i < len(tokens) - 1:
                if tokens[i] == pair_to_merge[0] and tokens[i+1] == pair_to_merge[1]:
                    tokens[i:i+2] = [next_id]
                else:
                    i += 1
            # re-add the word's updated pair statistics
            for pair in zip(tokens[:-1], tokens[1:]):
                pair_counts[pair] += count
                pair_words.setdefault(pair, set()).add(word)
            word_tokens[word] = tokens
        next_id += 1
    return vocab, merges
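

# Example usage: an illustrative sketch, not part of the original gist; the corpus
# path, vocab size, and special token below are placeholder assumptions. The
# __main__ guard matters because ProcessPoolExecutor may spawn fresh interpreters.
if __name__ == "__main__":
    vocab, merges = train_bpe(
        input_path="corpus.txt",           # hypothetical plain-text training corpus
        vocab_size=1000,
        special_tokens=["<|endoftext|>"],  # document separator, as in GPT-2-style corpora
        max_workers=4,                     # count pre-tokens in 4 worker processes
    )
    print(f"learned {len(merges)} merges; vocab size = {len(vocab)}")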