Minitron, LLM Training
{
  "architectures": ["Qwen3ForCausalLM"],
  "attention_bias": false,
  "attention_dropout": 0.0,
  "bos_token_id": 151643,
  "eos_token_id": 151645,
  "head_dim": 128,
  "hidden_act": "silu",
  "hidden_size": 1024,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "max_position_embeddings": 4096,
  "max_window_layers": 28,
  "model_type": "qwen3",
  "num_attention_heads": 16,
  "num_hidden_layers": 28,
  "num_key_value_heads": 8,
  "rms_norm_eps": 1e-6,
  "rope_scaling": null,
  "rope_theta": 50000,
  "sliding_window": null,
  "tie_word_embeddings": true,
  "torch_dtype": "bfloat16",
  "transformers_version": "4.51.0",
  "use_cache": false,
  "use_sliding_window": false,
  "vocab_size": 151936
}
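
The config above defines a small Qwen3-style causal LM, the one the training script loads from the ./qwen3-0.6b directory. A minimal sketch of how it is consumed, assuming the JSON is saved as ./qwen3-0.6b/config.json with the tokenizer files alongside it:

# Hypothetical standalone check; the path and filename layout are assumptions.
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

config = AutoConfig.from_pretrained("./qwen3-0.6b")
tokenizer = AutoTokenizer.from_pretrained("./qwen3-0.6b")

# from_config builds the architecture with freshly initialized (random) weights;
# no pretrained checkpoint is downloaded, which is what the from-scratch training
# script below relies on.
model = AutoModelForCausalLM.from_config(config)
print(f"{sum(p.numel() for p in model.parameters()):,} parameters")

The full training script follows.
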
# Run: OMP_NUM_THREADS=8 torchrun --nproc-per-node 8 train.py
#
# Tasks:
# * Add WandB logging
# * Improve the data loading pipeline (see https://gist.github.com/dhbrojas/30ce33fcbdb55f973a4e64fde5e52fd7)
# * Generate ~64 tokens of text using the model every ~100 steps based on some prompt
import os
import torch
import math
import time
import pyarrow.parquet as pq
import tqdm
import wandb
import torch.distributed as dist
from dataclasses import dataclass
from torch.nn.functional import cross_entropy
from torch import Tensor
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.nn import Module
from typing import Dict, Any, Iterator
from torch.optim import AdamW
from torch.optim.lr_scheduler import LambdaLR
from tokenizers import Tokenizer
from transformers import (
    AutoTokenizer,
    AutoConfig,
    AutoModelForCausalLM,
)

def dprint(msg: str, *, all: bool = False) -> None:
    """Print to stdout on the main process, or all processes if `all` is True."""
    if dist.get_rank() == 0 or all:
        print(f"[{dist.get_rank()}:{dist.get_world_size()}] {msg}")


def dprint_module_parameters(module: Module, *, all: bool = False) -> None:
    num_parameters = sum(p.numel() for p in module.parameters())
    num_trainable_parameters = sum(
        p.numel() for p in module.parameters() if p.requires_grad
    )
    percent_trainable = 100 * num_trainable_parameters / num_parameters
    dprint(
        f"Module has {num_parameters:,} parameters, {num_trainable_parameters:,} ({percent_trainable:,.2f}%) trainable",
        all=all,
    )
    for name, param in module.named_parameters():
        dprint(
            f"* {name}, {list(param.shape)} @ {str(param.dtype).replace('torch.', '')}{' (frozen)' if not param.requires_grad else ''}",
            all=all,
        )

class ParquetReader:
    def __init__(self, file: str, batch_size: int = 256):
        self.fp = pq.ParquetFile(file)
        self.num_rows = self.fp.metadata.num_rows
        self.num_rows_read = 0
        self.batch = None
        self.batches = self.fp.iter_batches(batch_size=batch_size)

    def __len__(self) -> int:
        return self.num_rows - self.num_rows_read

    def __iter__(self) -> Iterator[Dict[str, Any]]:
        return self

    def __next__(self) -> Dict[str, Any]:
        if self.num_rows_read >= self.num_rows:
            raise StopIteration
        if self.batch is None or len(self.batch) == 0:
            self.batch = next(self.batches).to_pylist()
        row = self.batch.pop(0)
        self.num_rows_read += 1
        return row

@dataclass
class Batch:
    x: Tensor  # (B, L)
    y: Tensor  # (B, L)


class DataLoader:
    def __init__(
        self,
        tokenizer: Tokenizer,
        reader: Iterator[Dict[str, Any]],
        *,
        device: torch.device,
        max_sequence_length: int = 2048,
        batch_size: int = 8,
        infinite: bool = True,
    ):
        self.tokenizer = tokenizer
        self.reader = reader
        self.max_sequence_length = max_sequence_length
        self.batch_size = batch_size
        # NOTE: `infinite` is stored but not yet honored; when the reader is
        # exhausted, its StopIteration simply propagates out of __next__.
        self.infinite = infinite
        self.device = device
        self.buffer = []

    def __next__(self) -> Batch:
        B, L = self.batch_size, self.max_sequence_length
        # As long as we don't have enough tokens buffered, tokenize more samples.
        while len(self.buffer) < B * L + 1:
            sample = next(self.reader)
            text = sample["text"] if "text" in sample else sample["clean_content"]
            self.buffer.extend(self.tokenizer.encode(text))
        tokens = self.buffer[:B * L + 1]
        x = torch.tensor(tokens[:-1], dtype=torch.long, device=self.device).reshape(B, L)
        y = torch.tensor(tokens[1:], dtype=torch.long, device=self.device).reshape(B, L)
        assert x.shape == y.shape == torch.Size([B, L])
        self.buffer = self.buffer[B * L + 1:]
        return Batch(x=x, y=y)

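# Hypothetical smoke test (not part of the training run, hence commented out):
# iterate one batch on CPU to confirm shapes before launching the distributed job.
# The path and sizes here are placeholders.
#
#   reader = ParquetReader("data/shard_000000.parquet")
#   dl = DataLoader(
#       AutoTokenizer.from_pretrained("./qwen3-0.6b"), reader,
#       device=torch.device("cpu"), max_sequence_length=128, batch_size=2,
#   )
#   b = next(dl)
#   assert b.x.shape == b.y.shape == torch.Size([2, 128])
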
def warmup_cosine_decay(*, W: int, D: int, min_lr_scale_factor: float = 0.1):
    """Linear warmup over W steps, cosine decay to min_lr_scale_factor at step D, flat afterwards."""
    def lr_lambda(current_step: int):
        if current_step <= W:
            if W > 0:
                return current_step / W
            else:
                return 1.0
        elif current_step <= D:
            return (
                min_lr_scale_factor
                + (1 - min_lr_scale_factor)
                * (1 + math.cos(math.pi * (current_step - W) / (D - W)))
                / 2
            )
        else:
            return min_lr_scale_factor
    return lr_lambda


def truncated_normal_weights(module: Module):
    # Re-initialize every parameter with a truncated normal (std 0.02, matching
    # `initializer_range` in the config). Note that this also overwrites RMSNorm
    # weights, which are normally initialized to 1, and that when used via
    # `Module.apply` each parameter is visited once per enclosing module (the
    # final draw is what remains).
    for param in module.parameters(recurse=True):
        torch.nn.init.trunc_normal_(param, mean=0.0, std=0.02)

if __name__ == "__main__":
    dist.init_process_group(backend="nccl")
    torch.set_float32_matmul_precision("high")

    WS, R = dist.get_world_size(), dist.get_rank()
    LR = int(os.environ.get("LOCAL_RANK", R))
    MASTER = R == 0

    # Bind this process to the GPU with local ID `LR` (equal to `R` on a single node)
    torch.cuda.set_device(LR)
    device = torch.device(f"cuda:{LR}")
    dprint(f"World Size: {WS}, Rank: {R}, Local Rank: {LR}")
    dprint(f"Device: {device}")

    # Configure our hyperparameters
    dataset = os.environ.get("DATASET", "web")
    micro_batch_size, sequence_length, steps, gradacc = 8, 2048, 2500, 8

    # Configure tokenizer
    config = AutoConfig.from_pretrained("./qwen3-0.6b")
    tokenizer = AutoTokenizer.from_pretrained("./qwen3-0.6b")
    vocab_size = config.vocab_size

    # Load dataset
    if dataset == "web":
        # Very hardcoded!
        reader = ParquetReader(f"/mnt/cgx/datasets/fineweb-edu-10b/00{R}_00000.parquet")
    else:
        reader = ParquetReader(f"data/shard_00000{LR}.parquet")
    dataloader = DataLoader(
        tokenizer,
        reader,
        infinite=True,
        max_sequence_length=sequence_length,
        batch_size=micro_batch_size,
        device=device,
    )
    # Load the very first batch; the rest will be prefetched inside the training loop
    batch = next(dataloader)

    # Configure model, optimizer, and learning rate scheduler
    model = AutoModelForCausalLM.from_config(config)
    optimizer = AdamW(model.parameters(), lr=0.003, weight_decay=0.1, betas=(0.9, 0.95), eps=1e-8)
    lr_scheduler = LambdaLR(optimizer, lr_lambda=warmup_cosine_decay(W=250, D=steps))
    model = model.bfloat16().to(device)
    model.apply(truncated_normal_weights)
    model.compile()
    dprint_module_parameters(model)
    model = DDP(model, device_ids=[LR], output_device=LR)
    model.train()

    for step in range(steps):
        start = time.time()
        losses = torch.zeros(gradacc, device=device)
        for micro_step in range(gradacc):
            # Only sync gradients on the last iteration of the gradient accumulation
            model.require_backward_grad_sync = (micro_step == gradacc - 1)
            # Compute loss
            logits = model(input_ids=batch.x).logits
            # assert logits.shape == (micro_batch_size, sequence_length, vocab_size)
            loss = cross_entropy(
                logits.view(-1, vocab_size),
                batch.y.flatten(),
                ignore_index=-100,
                reduction='mean',
            )
            # It's very important to divide the loss by the number of gradient accumulation
            # iterations, otherwise the magnitude of the gradients will be `gradacc` times
            # too large (DDP already averages gradients across ranks).
            (loss / gradacc).backward()
            # While the GPU is busy with forward + backward, prefetch and tokenize the next
            # batch on the CPU
            batch = next(dataloader)
            # Save the loss for logging purposes (kept on-device to avoid a sync here)
            losses[micro_step] = loss.detach()
        # Clip the gradient norm for stable training
        gradnorm = torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0).item()
        # Optimizer step
        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()
        # Average the loss across all processes
        dist.all_reduce(losses, dist.ReduceOp.AVG)
        loss = losses.mean()
        # Compute some metrics
        elapsed = time.time() - start
        gbs = gradacc * micro_batch_size * sequence_length * WS
        tokens = gbs * (step + 1)
        metrics = {
            "loss": loss.item(),
            "gradnorm": gradnorm,
            "mem": torch.cuda.memory_reserved() / 1024**3,
            "lr": lr_scheduler.get_last_lr()[0],
            "gbs": gbs,
            "tps": int(gbs / elapsed),
            "tokens": tokens,
            "elapsed": elapsed,
        }
        fmt = f"[STEP {step}] "
        for key, value in metrics.items():
            if isinstance(value, float):
                if value >= 0.01:
                    fmt += f"{key}={value:.2f} "
                else:
                    fmt += f"{key}={value:.2e} "
            elif isinstance(value, int):
                fmt += f"{key}={value:,} "
            else:
                fmt += f"{key}={str(value)} "
        dprint(fmt)

    dist.destroy_process_group()
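
The first and third items on the task list at the top of the script (WandB logging and periodic text generation) are still open. Below is a hedged sketch of how they could be wired in, reusing names defined in train.py (model, tokenizer, metrics, step, MASTER, device); the project name and prompt are placeholders, and this is an illustration rather than the author's intended implementation.

# Sketch only: helper functions to splice into train.py for the remaining tasks.
import torch
import wandb


def init_wandb(master: bool, run_config: dict) -> None:
    # Call once before the training loop; only the main process talks to WandB.
    if master:
        wandb.init(project="minitron", config=run_config)  # project name is a placeholder


def log_metrics(master: bool, metrics: dict, step: int) -> None:
    # Call right after the `metrics` dict is built inside the training loop.
    if master:
        wandb.log(metrics, step=step)


def sample_text(model, tokenizer, device, prompt: str = "Once upon a time", max_new_tokens: int = 64) -> str:
    # Call every ~100 steps on the main process, e.g. `if MASTER and step % 100 == 0`.
    # `model` may be the DDP wrapper; unwrap it before calling generate().
    unwrapped = getattr(model, "module", model)
    unwrapped.eval()
    ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
    with torch.no_grad():
        out = unwrapped.generate(ids, max_new_tokens=max_new_tokens, do_sample=True, top_p=0.9)
    unwrapped.train()
    return tokenizer.decode(out[0], skip_special_tokens=True)

With use_cache set to false in the config, generation may run without a KV cache and be slow, which is usually acceptable for a 64-token qualitative check.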