@Codys12
Created July 8, 2025 13:58
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Cross-entropy fine-tune on PrimeIntellect/SYNTHETIC-1
– ZeRO-2 (no remote_device)
– Flash-Attention-2 requested at load time (attn_implementation)
"""
import argparse, json, os
from pathlib import Path
from typing import Iterable
import torch, torch.nn as nn, deepspeed
from deepspeed import zero
from datasets import load_dataset, IterableDataset
from transformers import (
    AutoTokenizer, AutoModelForCausalLM,
    TrainingArguments, Trainer, TrainerCallback
)
SEQ_LEN = 8192
DATASET_ID = "PrimeIntellect/SYNTHETIC-1"
# ─────────────────── optional BitLinear swap ───────────────────
try:
    import mmfreelm.ops.fusedbitnet as fuse

    def replace_linear_with_fusedbit(m: nn.Module) -> nn.Module:
        """Swap every nn.Linear for a fused BitLinear, copying the original weights."""
        # Snapshot the module list first so the tree is not mutated while iterating.
        for name, mod in list(m.named_modules()):
            if isinstance(mod, nn.Linear):
                fused = fuse.BitLinear(mod.in_features, mod.out_features,
                                       bias=mod.bias is not None).to(dtype=torch.bfloat16)
                with torch.no_grad():
                    fused.weight.copy_(mod.weight)
                    if mod.bias is not None:
                        fused.bias.copy_(mod.bias)
                parent, child = name.rsplit('.', 1) if '.' in name else ('', name)
                setattr(m if parent == '' else m.get_submodule(parent), child, fused)
        return m
except ImportError:
    replace_linear_with_fusedbit = lambda x: x  # pragma: no cover
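# Quick sanity check for the swap helper (hypothetical snippet; requires mmfreelm):
#   m = nn.Sequential(nn.Linear(4, 4))
#   m = replace_linear_with_fusedbit(m)
#   print(type(m[0]))   # fuse.BitLinear if the import succeeded, nn.Linear otherwise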
# ─────────────────────── dataset helpers ───────────────────────
def streaming_ds(tok) -> IterableDataset:
    """Stream SYNTHETIC-1 and tokenize each prompt/response pair on the fly."""
    raw = load_dataset(DATASET_ID, split="train", streaming=True)

    def _gen():
        for ex in raw:
            p, a = ex.get("prompt"), ex.get("llm_response")
            if not p or not a:
                continue
            # Full conversation (prompt + answer) rendered with the chat template.
            chat = tok.apply_chat_template(
                [{"role": "user", "content": p}, {"role": "assistant", "content": a}],
                tokenize=False, add_generation_prompt=False)
            # Prompt-only rendering, used to measure how many leading tokens to mask.
            user_only = tok.apply_chat_template(
                [{"role": "user", "content": p}],
                tokenize=False, add_generation_prompt=True)
            conv_ids = tok(chat, add_special_tokens=False)["input_ids"]
            user_len = len(tok(user_only, add_special_tokens=False)["input_ids"])
            yield {"input_ids": conv_ids, "user_len": user_len}

    return IterableDataset.from_generator(_gen)
class Collator:
    """Left-pad/truncate to SEQ_LEN and mask the prompt tokens out of the loss."""

    def __init__(self, pad_id: int):
        self.pad_id = pad_id

    def __call__(self, rows):
        B, L = len(rows), SEQ_LEN
        ids = torch.full((B, L), self.pad_id, dtype=torch.long)
        attn = torch.zeros((B, L), dtype=torch.long)
        labels = torch.full((B, L), -100, dtype=torch.long)
        for b, r in enumerate(rows):
            toks, user = r["input_ids"], r["user_len"]
            over = max(0, len(toks) - L)      # truncate from the left if too long
            toks = toks[over:]
            user = max(0, user - over)
            off = L - len(toks)               # left-pad offset
            ids[b, off:] = torch.tensor(toks)
            attn[b, off:] = 1
            lab = torch.tensor(toks)
            lab[:user] = -100                 # no loss on the prompt portion
            labels[b, off:] = lab
        return {"input_ids": ids, "attention_mask": attn, "labels": labels}
class Tick(TrainerCallback):
    def on_step_end(self, a, s, c, **k):
        pass
# ────────────────────────── main ────────────────────────────────
def main():
    ap = argparse.ArgumentParser()
    ap.add_argument("--model_name", default="Qwen/Qwen3-8B")
    ap.add_argument("--output_root", default="./checkpoints")
    ap.add_argument("--max_steps", type=int, default=3000)
    ap.add_argument("--lr", type=float, default=10.0)
    args = ap.parse_args()

    tok = AutoTokenizer.from_pretrained(args.model_name, trust_remote_code=True)
    tok.pad_token = tok.pad_token or tok.eos_token
    tok.padding_side = "left"
    train_ds = streaming_ds(tok)
    collate = Collator(tok.pad_token_id)
    # ───── DeepSpeed config (stage-2, no remote_device) ────────
    ds_config = {
        "bf16": {"enabled": True},
        "zero_optimization": {"stage": 2},
        # stage-3 alternative: {"stage": 3, "offload_optimizer": {"device": "cpu", "pin_memory": False}},
        "gradient_clipping": 1e-8,
        "train_micro_batch_size_per_gpu": 1,
        "gradient_accumulation_steps": 8,
        "train_batch_size": 64,
        "stage3_gather_16bit_weights_on_model_save": True,
        "save_only_model": True,
        "save_steps": 1000,
    }
    cfg_path = "ds_cfg.json"
    with open(cfg_path, "w") as f:
        json.dump(ds_config, f)
print("🔹 loading & sharding (GPU build)…")
model = AutoModelForCausalLM.from_pretrained(
args.model_name,
torch_dtype=torch.bfloat16,
trust_remote_code=True
)
for n, p in model.named_parameters():
if "rms_norm" in n:
p.requires_grad_(False)
model = replace_linear_with_fusedbit(model)
model.config.attn_implementation = "flash_attention_2"
out = Path(args.output_root) / f"{Path(args.model_name).name.replace('/','-')}_CE"
out.mkdir(parents=True, exist_ok=True)
    targs = TrainingArguments(
        output_dir=str(out), overwrite_output_dir=True,
        remove_unused_columns=False, max_steps=args.max_steps,
        per_device_train_batch_size=1, bf16=True,
        gradient_accumulation_steps=8, gradient_checkpointing=True,
        learning_rate=args.lr, warmup_steps=100,
        logging_steps=10, save_steps=1000,
        deepspeed=cfg_path, max_grad_norm=1e-8,
    )
    trainer = Trainer(model=model, args=targs,
                      train_dataset=train_ds,
                      data_collator=collate, tokenizer=tok)
    trainer.add_callback(Tick())

    print("🔹 start training …")
    trainer.train()
    trainer.save_model(str(out / "final"))
    print("✅ done – final weights in", out / "final")


if __name__ == "__main__":
    main()
@BradKML
BradKML commented Aug 26, 2025

How hard would it be to turn something like Qwen3 30B-A3B, 32B, or the bigger models into BitNet for accelerated performance when the end user's VRAM is only 8 GB or 16 GB? Can it be done on a T4, NanoPoor style (or at the very least at the "dining out" cost of GPU rental)? https://github.com/VatsaDev/NanoPoor
