Minitron, LLM Training
{
"architectures": ["Qwen3ForCausalLM"],
"attention_bias": false,
"attention_dropout": 0.0,
"bos_token_id": 151643,
"eos_token_id": 151645,
"head_dim": 128,
"hidden_act": "silu",
"hidden_size": 1024,
"initializer_range": 0.02,
"intermediate_size": 3072,
"max_position_embeddings": 4096,
"max_window_layers": 28,
"model_type": "qwen3",
"num_attention_heads": 16,
"num_hidden_layers": 28,
"num_key_value_heads": 8,
"rms_norm_eps": 1e-6,
"rope_scaling": null,
"rope_theta": 50000,
"sliding_window": null,
"tie_word_embeddings": true,
"torch_dtype": "bfloat16",
"transformers_version": "4.51.0",
"use_cache": false,
"use_sliding_window": false,
"vocab_size": 151936
}
# Run: OMP_NUM_THREADS=8 torchrun --nproc-per-node 8 train.py
#
# Tasks:
# * Add WandB logging
# * Improve the data loading pipeline (see https://gist.github.com/dhbrojas/30ce33fcbdb55f973a4e64fde5e52fd7)
# * Generate ~64 tokens of text using the model every ~100 steps based on some prompt
import os
import torch
import math
import time
import pyarrow.parquet as pq
import tqdm
import wandb
import torch.distributed as dist
from dataclasses import dataclass
from torch.nn.functional import cross_entropy
from torch import Tensor
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.nn import Module
from typing import Dict, Any, Iterator
from torch.optim import AdamW
from torch.optim.lr_scheduler import LambdaLR
from tokenizers import Tokenizer
from transformers import (
    AutoTokenizer,
    AutoConfig,
    AutoModelForCausalLM,
)
def dprint(msg: str, *, all: bool = False) -> None:
    """Print to stdout on the main process, or all processes if `all` is True."""
    if dist.get_rank() == 0 or all:
        print(f"[{dist.get_rank()}:{dist.get_world_size()}] {msg}")


def dprint_module_parameters(module: Module, *, all: bool = False) -> None:
    num_parameters = sum(p.numel() for p in module.parameters())
    num_trainable_parameters = sum(
        p.numel() for p in module.parameters() if p.requires_grad
    )
    percent_trainable = 100 * num_trainable_parameters / num_parameters
    dprint(
        f"Module has {num_parameters:,} parameters, {num_trainable_parameters:,} ({percent_trainable:,.2f}%) trainable",
        all=all,
    )
    for name, param in module.named_parameters():
        dprint(
            f"* {name}, {list(param.shape)} @ {str(param.dtype).replace('torch.', '')}{' (frozen)' if not param.requires_grad else ''}",
            all=all,
        )
class ParquetReader:
    def __init__(self, file: str, batch_size: int = 256):
        self.fp = pq.ParquetFile(file)
        self.num_rows = self.fp.metadata.num_rows
        self.num_rows_read = 0
        self.batch = None
        self.batches = self.fp.iter_batches(batch_size=batch_size)

    def __len__(self) -> int:
        return self.num_rows - self.num_rows_read

    def __iter__(self) -> Iterator[Dict[str, Any]]:
        return self

    def __next__(self) -> Dict[str, Any]:
        if self.num_rows_read >= self.num_rows:
            raise StopIteration
        if self.batch is None or len(self.batch) == 0:
            self.batch = next(self.batches).to_pylist()
        row = self.batch.pop(0)
        self.num_rows_read += 1
        return row
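# Usage sketch (hypothetical shard path): rows are streamed one read batch at a
# time, so a multi-gigabyte shard never has to fit in memory.
#
#   reader = ParquetReader("data/shard_000000.parquet")
#   len(reader)   # rows remaining
#   next(reader)  # e.g. {"text": "..."} for FineWeb-Edu style shards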
@dataclass
class Batch:
    x: Tensor  # (B, L)
    y: Tensor  # (B, L)
class DataLoader:
    def __init__(
        self,
        tokenizer: Tokenizer,
        reader: Iterator[Dict[str, Any]],
        *,
        device: torch.device,
        max_sequence_length: int = 2048,
        batch_size: int = 8,
        infinite: bool = True,
    ):
        self.tokenizer = tokenizer
        self.reader = reader
        self.max_sequence_length = max_sequence_length
        self.batch_size = batch_size
        self.infinite = infinite
        self.device = device
        self.buffer = []

    def __next__(self) -> Batch:
        B, L = self.batch_size, self.max_sequence_length
        # As long as we don't have enough samples, tokenize some more.
        while len(self.buffer) < B * L + 1:
            sample = next(self.reader)
            text = sample["text"] if "text" in sample else sample["clean_content"]
            self.buffer.extend(self.tokenizer.encode(text))
        tokens = self.buffer[:B * L + 1]
        x = torch.tensor(tokens[:-1], dtype=torch.long, device=self.device).reshape(B, L)
        y = torch.tensor(tokens[1:], dtype=torch.long, device=self.device).reshape(B, L)
        assert x.shape == y.shape == torch.Size([B, L])
        self.buffer = self.buffer[B * L + 1:]
        return Batch(x=x, y=y)
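# --- TODO task 2 (rough sketch, not wired in below): overlap tokenization with training ---
# One possible starting point for the "improve the data loading pipeline" task:
# prepare the next batch on a background thread so the training loop never waits
# on the CPU. `PrefetchingLoader` is a hypothetical helper, not part of the
# original gist; it simply wraps the DataLoader above.
from concurrent.futures import ThreadPoolExecutor


class PrefetchingLoader:
    def __init__(self, loader: "DataLoader"):
        self.loader = loader
        self.executor = ThreadPoolExecutor(max_workers=1)
        # Kick off preparation of the first batch immediately.
        self.future = self.executor.submit(next, self.loader)

    def __next__(self) -> Batch:
        # Hand back the batch prepared in the background, then start on the next one.
        batch = self.future.result()
        self.future = self.executor.submit(next, self.loader)
        return batch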
def warmup_cosine_decay(*, W: int, D: int, min_lr_scale_factor: float = 0.1):
    def lr_lambda(current_step: int):
        if current_step <= W:
            if W > 0:
                return current_step / W
            else:
                return 1.0
        elif current_step <= D:
            return (
                min_lr_scale_factor
                + (1 - min_lr_scale_factor)
                * (1 + math.cos(math.pi * (current_step - W) / (D - W)))
                / 2
            )
        else:
            return min_lr_scale_factor

    return lr_lambda
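# Example: with the values used below (W=250, D=steps=2500, min_lr_scale_factor=0.1),
# the multiplier ramps linearly from 0 to 1 over the first 250 steps, cosine-decays
# from 1 down to 0.1 by step 2500, and then stays at 0.1.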
def truncated_normal_weights(module: Module):
    # Called once per submodule via `model.apply`, so only initialize this
    # module's own (non-recursive) parameters to avoid re-initializing the same
    # tensors multiple times.
    for param in module.parameters(recurse=False):
        torch.nn.init.trunc_normal_(param, mean=0.0, std=0.02)
if __name__ == "__main__":
    dist.init_process_group(backend="nccl")
    torch.set_float32_matmul_precision("high")
    WS, R = dist.get_world_size(), dist.get_rank()
    LR = int(os.environ.get("LOCAL_RANK", R))
    MASTER = R == 0
    # Each process drives the GPU matching its local rank (identical to the
    # global rank for the single-node launch above, but correct for multi-node too)
    torch.cuda.set_device(LR)
    device = torch.device(f"cuda:{LR}")
dprint(f"World Size: {WS}, Rank: {R}, Local Rank: {LR}")
dprint(f"Device: {device}")
# Configure our hyperparameters
dataset = os.environ.get("DATASET", "web")
micro_batch_size, sequence_length, steps, gradacc = 8, 2048, 2500, 8
# Configure tokenizer
config = AutoConfig.from_pretrained("./qwen3-0.6b")
tokenizer = AutoTokenizer.from_pretrained("./qwen3-0.6b")
vocab_size = config.vocab_size
# Load dataset
if dataset == "web":
# Very hardcoded!
reader = ParquetReader(f"/mnt/cgx/datasets/fineweb-edu-10b/00{R}_00000.parquet")
else:
reader = ParquetReader(f"data/shard_00000{LR}.parquet")
dataloader = DataLoader(
tokenizer,
reader,
infinite=True,
max_sequence_length=sequence_length,
batch_size=micro_batch_size,
device=device,
)
    # Load the very first batch here; subsequent batches are fetched inside the
    # training loop while the GPU is still busy with forward/backward.
    batch = next(dataloader)
    # Configure model, optimizer, and learning rate scheduler
    model = AutoModelForCausalLM.from_config(config)
    optimizer = AdamW(model.parameters(), lr=0.003, weight_decay=0.1, betas=(0.9, 0.95), eps=1e-8)
    lr_scheduler = LambdaLR(optimizer, lr_lambda=warmup_cosine_decay(W=250, D=steps))
    model = model.bfloat16().to(device)
    model.apply(truncated_normal_weights)
    model.compile()
    dprint_module_parameters(model)
    model = DDP(model, device_ids=[LR], output_device=LR)
    model.train()
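    # --- TODO task 1: WandB logging (rough sketch) ---
    # The project name "minitron" is a placeholder; only rank 0 talks to WandB.
    if MASTER:
        wandb.init(
            project="minitron",
            config={
                "micro_batch_size": micro_batch_size,
                "sequence_length": sequence_length,
                "steps": steps,
                "gradacc": gradacc,
                "dataset": dataset,
            },
        )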
    for step in range(steps):
        start = time.time()
        losses = torch.zeros(gradacc, device=device)
        for iter in range(gradacc):
            # Only sync gradients after the last iteration of the gradient accumulation
            model.require_backward_grad_sync = (iter == gradacc - 1)
            # Compute loss
            logits = model(input_ids=batch.x).logits
            # assert logits.shape == (micro_batch_size, sequence_length, vocab_size)
            loss = cross_entropy(
                logits.view(-1, vocab_size),
                batch.y.flatten(),
                ignore_index=-100,
                reduction='mean',
            )
            # It's very important to divide the loss by the number of gradient accumulation
            # iterations, otherwise the magnitude of the gradients will be `gradacc` times too large.
            (loss / gradacc).backward()
            # The backward kernels above run asynchronously on the GPU, so use the time
            # to prefetch and tokenize the next batch on the CPU.
            batch = next(dataloader)
            # Save the loss for logging purposes
            losses[iter] = loss.detach().item()
        # Clip the gradient norm for stable training
        gradnorm = torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0).item()
        # Optimizer step
        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()
        # We gather the loss across all processes
        dist.all_reduce(losses, dist.ReduceOp.AVG)
        loss = losses.mean()
        # Compute some metrics
        elapsed = time.time() - start
        gbs = gradacc * micro_batch_size * sequence_length * WS
        tokens = gbs * (step + 1)
        metrics = {
            "loss": loss.detach().item(),
            "gradnorm": gradnorm,
            "mem": torch.cuda.memory_reserved() / 1024**3,
            "lr": lr_scheduler.get_last_lr()[0],
            "gbs": gbs,
            "tps": int(gbs / elapsed),
            "tokens": tokens,
            "elapsed": elapsed,
        }
        fmt = f"[STEP {step}] "
        for key, value in metrics.items():
            if isinstance(value, float):
                if value >= 0.01:
                    fmt += f"{key}={value:.2f} "
                else:
                    fmt += f"{key}={value:.2e} "
            elif isinstance(value, int):
                fmt += f"{key}={value:,} "
            else:
                fmt += f"{key}={str(value)} "
        dprint(fmt)
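        # --- TODO task 1 (continued): report metrics to WandB (sketch) ---
        if MASTER:
            wandb.log(metrics, step=step)
        # --- TODO task 3: sample ~64 tokens every ~100 steps (rough sketch) ---
        # The prompt is a placeholder; `model.module` unwraps DDP so generation
        # skips gradient hooks, and the KV cache is turned on just for sampling
        # since the config ships with `use_cache: false`.
        if MASTER and step % 100 == 0:
            model.eval()
            with torch.no_grad():
                prompt = tokenizer("Once upon a time", return_tensors="pt").input_ids.to(device)
                out = model.module.generate(prompt, max_new_tokens=64, do_sample=True, use_cache=True)
            dprint(f"[SAMPLE @ STEP {step}] {tokenizer.decode(out[0], skip_special_tokens=True)}")
            model.train()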
    dist.destroy_process_group()