calculate_tokens_per_byte_fineweb2hq.py
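A short summary of what this script does, based on a read of the code below: it estimates tokens-per-byte (TPB) for 20 FineWeb2-HQ language subsets by streaming up to 10,000 documents per language, counting each document's UTF-8 bytes and Llama 3.1 tokens in a multiprocessing pool, and printing per-language byte totals, token totals, and the TPB ratio.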
import multiprocessing as mp
from itertools import islice

import tqdm
from datasets import load_dataset
from transformers import AutoTokenizer

MAX_DOCS_PER_LANG = 10_000

# FineWeb2-HQ language configs (ISO 639-3 code + script).
langs = [
    'rus_Cyrl', 'cmn_Hani', 'deu_Latn', 'jpn_Jpan', 'spa_Latn',
    'fra_Latn', 'ita_Latn', 'por_Latn', 'pol_Latn', 'nld_Latn',
    'ind_Latn', 'tur_Latn', 'ces_Latn', 'vie_Latn', 'swe_Latn',
    'fas_Arab', 'arb_Arab', 'ell_Grek', 'dan_Latn', 'hun_Latn'
]

# Loaded once per worker by init_worker(); with the "spawn" start method,
# module-level state is not inherited, so there is no point loading the
# tokenizer eagerly in the parent process.
_tokenizer = None

def init_worker():
    global _tokenizer
    _tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B")

def count_bytes_and_tokens(example):
    """Return (UTF-8 byte count, token count) for one document."""
    text = example.get("text", "")
    if not text.strip():
        return (0, 0)
    return (
        len(text.encode("utf-8")),
        len(_tokenizer(text, add_special_tokens=False).input_ids),
    )

def process_lang(lang):
    # Stream the split so only MAX_DOCS_PER_LANG documents are fetched.
    dataset = load_dataset("epfml/FineWeb2-HQ", lang, streaming=True)
    examples = list(islice(dataset["train"], MAX_DOCS_PER_LANG))
    with mp.get_context("spawn").Pool(
        processes=max(mp.cpu_count() - 1, 1), initializer=init_worker
    ) as pool:
        results = list(tqdm.tqdm(
            pool.imap_unordered(count_bytes_and_tokens, examples, chunksize=10),
            total=len(examples),
            desc=f"Processing {lang}",
        ))
    bytes_breakdown = [res[0] for res in results]
    toks_breakdown = [res[1] for res in results]
    return lang, sum(bytes_breakdown), sum(toks_breakdown), bytes_breakdown, toks_breakdown

if __name__ == "__main__":
    bytes_per_lang = {}
    tokens_per_lang = {}
    tpb_per_lang = {}
    lang_bytes_breakdown = {}
    lang_toks_breakdown = {}
    for lang in langs:
        print(lang)
        lang, total_bytes, total_tokens, bytes_breakdown, toks_breakdown = process_lang(lang)
        bytes_per_lang[lang] = total_bytes
        tokens_per_lang[lang] = total_tokens
        # Tokens per byte: lower means the tokenizer compresses this language better.
        tpb_per_lang[lang] = total_tokens / total_bytes if total_bytes else 0.0
        lang_bytes_breakdown[lang] = bytes_breakdown
        lang_toks_breakdown[lang] = toks_breakdown
    print("\nBytes per language:")
    print(bytes_per_lang)
    print("\nTokens per language:")
    print(tokens_per_lang)
    print("\nTokens per byte (TPB) per language:")
    print(tpb_per_lang)
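Usage sketch, assuming `datasets`, `transformers`, and `tqdm` are installed and that you have access to the meta-llama/Llama-3.1-8B repo on Hugging Face (the model is gated, so you may first need to accept its license and authenticate, e.g. with `huggingface-cli login`):

    python calculate_tokens_per_byte_fineweb2hq.py

The script prints three dicts keyed by language config: total UTF-8 bytes, total tokens, and tokens per byte (total tokens divided by total bytes over the sampled documents).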