Created July 8, 2025 13:58
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Cross-entropy fine-tune on PrimeIntellect/SYNTHETIC-1
– ZeRO-2 (no remote_device)
– Flash-Attention-2 requested at model load
"""
import argparse, json, os
from pathlib import Path
from typing import Iterable

import torch, torch.nn as nn, deepspeed
from deepspeed import zero
from datasets import load_dataset, IterableDataset
from transformers import (
    AutoTokenizer, AutoModelForCausalLM,
    TrainingArguments, Trainer, TrainerCallback
)

SEQ_LEN = 8192
DATASET_ID = "PrimeIntellect/SYNTHETIC-1"
# ─────────────────── optional BitLinear swap ───────────────────
try:
    import mmfreelm.ops.fusedbitnet as fuse

    def replace_linear_with_fusedbit(m: nn.Module) -> nn.Module:
        # Walk a snapshot of the module tree (mutating while iterating the live
        # generator is unsafe) and swap every nn.Linear for a bf16 BitLinear,
        # copying over the original weights (and bias, if any).
        modules = dict(m.named_modules())
        for name, mod in list(modules.items()):
            if isinstance(mod, nn.Linear):
                fused = fuse.BitLinear(mod.in_features, mod.out_features,
                                       bias=mod.bias is not None).to(dtype=torch.bfloat16)
                with torch.no_grad():
                    fused.weight.copy_(mod.weight)
                    if mod.bias is not None:
                        fused.bias.copy_(mod.bias)
                parent, _, child = name.rpartition('.')
                setattr(m if parent == '' else modules[parent], child, fused)
        return m
except ImportError:
    # mmfreelm not installed: fall back to the model's plain nn.Linear layers.
    replace_linear_with_fusedbit = lambda x: x  # pragma: no cover
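# Note (not part of the original gist): with mmfreelm present, the swap touches
# every nn.Linear in the network (attention q/k/v/o projections, MLP projections,
# lm_head), leaving norms and embeddings untouched; without it the model simply
# trains as a plain bf16 checkpoint.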
# ─────────────────────── dataset helpers ───────────────────────
def streaming_ds(tok) -> IterableDataset:
    raw = load_dataset(DATASET_ID, split="train", streaming=True)

    def _gen():
        for ex in raw:
            p, a = ex.get("prompt"), ex.get("llm_response")
            if not p or not a:
                continue
            # Render the full conversation and the user-only prefix with the same
            # content format so the tokenized prefix length lines up.
            chat = tok.apply_chat_template(
                [{"role": "user", "content": p}, {"role": "assistant", "content": a}],
                tokenize=False, add_generation_prompt=False)
            user_only = tok.apply_chat_template(
                [{"role": "user", "content": p}],
                tokenize=False, add_generation_prompt=True)
            conv_ids = tok(chat, add_special_tokens=False)["input_ids"]
            user_len = len(tok(user_only, add_special_tokens=False)["input_ids"])
            yield {"input_ids": conv_ids, "user_len": user_len}

    return IterableDataset.from_generator(_gen)
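# Each yielded record is {"input_ids": <token ids of the full chat>,
# "user_len": <token count of the user turn plus the generation prompt>};
# the collator below turns a list of such records into fixed-length tensors.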
class Collator:
    """Left-pad / left-truncate to SEQ_LEN and mask the user prompt out of the loss."""
    def __init__(self, pad_id: int):
        self.pad_id = pad_id

    def __call__(self, rows):
        B, L = len(rows), SEQ_LEN
        ids    = torch.full((B, L), self.pad_id, dtype=torch.long)
        attn   = torch.zeros((B, L), dtype=torch.long)
        labels = torch.full((B, L), -100, dtype=torch.long)
        for b, r in enumerate(rows):
            toks, user = r["input_ids"], r["user_len"]
            over = max(0, len(toks) - L)      # drop overflow from the left
            toks = toks[over:]
            user = max(0, user - over)
            off = L - len(toks)               # left-padding offset
            ids[b, off:] = torch.tensor(toks)
            attn[b, off:] = 1
            lab = torch.tensor(toks)
            lab[:user] = -100                 # ignore user-prompt tokens
            labels[b, off:] = lab
        return {"input_ids": ids, "attention_mask": attn, "labels": labels}
class Tick(TrainerCallback):
    """No-op placeholder callback."""
    def on_step_end(self, a, s, c, **k):
        pass
# ────────────────────────── main ────────────────────────────────
def main():
    ap = argparse.ArgumentParser()
    ap.add_argument("--model_name", default="Qwen/Qwen3-8B")
    ap.add_argument("--output_root", default="./checkpoints")
    ap.add_argument("--max_steps", type=int, default=3000)
    ap.add_argument("--lr", type=float, default=10.0)
    args = ap.parse_args()

    tok = AutoTokenizer.from_pretrained(args.model_name, trust_remote_code=True)
    tok.pad_token = tok.pad_token or tok.eos_token
    tok.padding_side = "left"

    train_ds = streaming_ds(tok)
    collate = Collator(tok.pad_token_id)

    # ───── DeepSpeed config (stage-2, no remote_device) ────────
    ds_config = {
        "bf16": {"enabled": True},
        # Alternative kept from the original (ZeRO-3 + CPU optimizer offload):
        # "zero_optimization": {"stage": 3,
        #                       "offload_optimizer": {"device": "cpu", "pin_memory": False}},
        "zero_optimization": {"stage": 2},
        "gradient_clipping": 1e-8,
        "train_micro_batch_size_per_gpu": 1,
        "gradient_accumulation_steps": 8,
        "train_batch_size": 64,
        # Note: "stage3_gather_16bit_weights_on_model_save" normally sits inside
        # "zero_optimization", and "save_only_model" / "save_steps" are HF Trainer
        # options rather than DeepSpeed ones; kept here as in the original gist.
        "stage3_gather_16bit_weights_on_model_save": True,
        "save_only_model": True,
        "save_steps": 1000,
    }
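    # Batch-size sanity check (not in the original): DeepSpeed requires
    # train_batch_size == micro_batch x grad_accum x world_size, so 64 = 1 x 8 x 8
    # implies an 8-GPU launch; change train_batch_size if you run on fewer ranks.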
    cfg_path = "ds_cfg.json"
    with open(cfg_path, "w") as f:
        json.dump(ds_config, f)

    print("🔹 loading & sharding (GPU build)…")
    model = AutoModelForCausalLM.from_pretrained(
        args.model_name,
        torch_dtype=torch.bfloat16,
        attn_implementation="flash_attention_2",  # setting config.attn_implementation after loading is a no-op
        trust_remote_code=True
    )

    # Freeze the RMSNorm weights. Qwen3 names them "*norm*" (input_layernorm,
    # post_attention_layernorm, q_norm/k_norm, model.norm); the original match
    # string "rms_norm" never occurs in those parameter names.
    for n, p in model.named_parameters():
        if "norm" in n:
            p.requires_grad_(False)

    model = replace_linear_with_fusedbit(model)
    out = Path(args.output_root) / f"{Path(args.model_name).name.replace('/','-')}_CE"
    out.mkdir(parents=True, exist_ok=True)

    targs = TrainingArguments(
        output_dir=str(out), overwrite_output_dir=True,
        remove_unused_columns=False, max_steps=args.max_steps,
        per_device_train_batch_size=1, bf16=True,
        gradient_accumulation_steps=8, gradient_checkpointing=True,
        learning_rate=args.lr, warmup_steps=100,
        logging_steps=10, save_steps=1000, deepspeed=cfg_path,
        max_grad_norm=1e-8,  # unusual lr=10 / tiny-clip pairing kept as in the original gist
    )
    trainer = Trainer(model=model, args=targs,
                      train_dataset=train_ds,
                      data_collator=collate, tokenizer=tok)
    trainer.add_callback(Tick())

    print("🔹 start training …")
    trainer.train()
    trainer.save_model(str(out / "final"))
    print("✅ done – final weights in", out / "final")


if __name__ == "__main__":
    main()
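# Post-training note (a sketch, not from the gist): the gathered bf16 weights end
# up under <output_root>/<model>_CE/final.  A possible way to reload them with the
# BitLinear layers restored (assumes mmfreelm is installed and that BitLinear
# keeps nn.Linear's weight/bias parameter names):
#   model = AutoModelForCausalLM.from_pretrained("<output_root>/<model>_CE/final",
#                                                torch_dtype=torch.bfloat16)
#   model = replace_linear_with_fusedbit(model)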
How hard is it to turn something like Qwen3 30B-A3B, 32B, or the bigger ones into BitNet for accelerated performance if the end user's VRAM is just 8 GB or 16 GB? Can it be done on a T4, NanoPoor style (or at the very least for the "dining out" cost of GPU renting)? https://github.com/VatsaDev/NanoPoor