@pruksmhc · Last active July 24, 2025
calculate_tokens_per_byte_fineweb2hq.py
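# Measures tokens-per-byte (TPB) of the meta-llama/Llama-3.1-8B tokenizer on
# the epfml/FineWeb2-HQ dataset: streams up to 10,000 documents per language,
# counts UTF-8 bytes and tokens for each document, and prints per-language
# byte totals, token totals, and the tokens/bytes ratio.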
import multiprocessing as mp
from datasets import load_dataset
from transformers import AutoTokenizer
from itertools import islice
import tqdm
MAX_DOCS_PER_LANG = 10_000
langs = [
    'rus_Cyrl', 'cmn_Hani', 'deu_Latn', 'jpn_Jpan', 'spa_Latn',
    'fra_Latn', 'ita_Latn', 'por_Latn', 'pol_Latn', 'nld_Latn',
    'ind_Latn', 'tur_Latn', 'ces_Latn', 'vie_Latn', 'swe_Latn',
    'fas_Arab', 'arb_Arab', 'ell_Grek', 'dan_Latn', 'hun_Latn',
]
# Populated per worker by init_worker(); keeping it None here avoids loading
# the tokenizer redundantly in the parent process.
_tokenizer = None
def init_worker():
    global _tokenizer
    _tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B")
def count_bytes_and_tokens(example):
    """Return (UTF-8 byte count, token count) for one document."""
    text = example.get("text", "")
    if not text.strip():
        return (0, 0)
    return (len(text.encode("utf-8")),
            len(_tokenizer(text, add_special_tokens=False).input_ids))
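# Design note: the pool below uses the "spawn" start method, so each worker
# re-imports this module and then loads its own tokenizer via init_worker()
# rather than inheriting forked interpreter state.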
def process_lang(lang):
    # Stream the split and take at most MAX_DOCS_PER_LANG documents.
    dataset = load_dataset("epfml/FineWeb2-HQ", lang, streaming=True)
    examples = list(islice(dataset['train'], MAX_DOCS_PER_LANG))
    with mp.get_context("spawn").Pool(
        processes=max(mp.cpu_count() - 1, 1), initializer=init_worker
    ) as pool:
        results = list(tqdm.tqdm(
            pool.imap_unordered(count_bytes_and_tokens, examples, chunksize=10),
            total=len(examples),
            desc=f"Processing {lang}",
        ))
    bytes_breakdown = [res[0] for res in results]
    toks_breakdown = [res[1] for res in results]
    total_bytes = sum(bytes_breakdown)
    total_tokens = sum(toks_breakdown)
    return lang, total_bytes, total_tokens, bytes_breakdown, toks_breakdown
if __name__ == "__main__":
    bytes_per_lang = {}
    tokens_per_lang = {}
    tpb_per_lang = {}
    lang_bytes_breakdown = {}
    lang_toks_breakdown = {}
    for lang in langs:
        print(lang)
        lang, total_bytes, total_tokens, bytes_breakdown, toks_breakdown = process_lang(lang)
        bytes_per_lang[lang] = total_bytes
        tokens_per_lang[lang] = total_tokens
        tpb_per_lang[lang] = total_tokens / total_bytes
        lang_bytes_breakdown[lang] = bytes_breakdown
        lang_toks_breakdown[lang] = toks_breakdown
print("\nBytes per language:")
print(bytes_per_lang)
print("\nTokens per language:")
print(tokens_per_lang)
print("All together TPB")
print(tpb_per_lang)
breakpoint()
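Usage note: running this needs the datasets, transformers, and tqdm packages. The meta-llama/Llama-3.1-8B repo is gated on the Hugging Face Hub, so you will likely need to accept the license and authenticate first (e.g. with huggingface-cli login) before the tokenizer downloads. Below is a minimal sketch of putting the measured ratios to use, assuming the run finished and you are sitting at the final breakpoint() with tpb_per_lang in scope; the 1 GiB figure is purely illustrative:

# Hypothetical example at the breakpoint(): estimate how many Llama-3.1-8B
# tokens 1 GiB of German web text would yield at the measured TPB ratio.
est_tokens = tpb_per_lang["deu_Latn"] * (1 << 30)
print(f"~{est_tokens:,.0f} tokens per GiB of deu_Latn text")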