@Codys12
Created April 19, 2025 01:56
BitNet Finetuning
```
pip install "git+https://github.com/childressg/matmulfreellm.git"  # provides mmfreelm.ops.fusedbitnet
```
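The package supplies the fused `BitLinear` layer used by the training script below (`mmfreelm.ops.fusedbitnet`). As a minimal smoke test, the sketch below mirrors the constructor keywords used in the script; the `lambda_` and `should_rms` arguments may be specific to this fork, and the fused kernels are assumed to need a CUDA device:

```
import torch
import mmfreelm.ops.fusedbitnet as fuse

# Hypothetical smoke test: constructor arguments mirror the training script;
# lambda_/should_rms may be fork-specific.
layer = fuse.BitLinear(
    in_features=64,
    out_features=32,
    lambda_=lambda: 1.0,   # constant lambda for the check
    should_rms=True,
    bias=True,
).to(device="cuda", dtype=torch.bfloat16)

x = torch.randn(2, 64, device="cuda", dtype=torch.bfloat16)
print(layer(x).shape)  # expected: torch.Size([2, 32])
```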
```
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
DeepSpeed ZeRO for memory-efficient training (stage 3).
4-way experiment toggling:
--lambda_schedule [true|false]
--lambda_warmup <int>
--should_rms [true|false]
"""
import argparse
import sys
import time
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
class RMSNorm(nn.Module):
    def __init__(self, num_features, alpha_init_value=0.25):
        super().__init__()
        self.alpha = nn.Parameter(torch.ones(1) * alpha_init_value)
        self.weight = nn.Parameter(torch.ones(num_features))
        self.bias = nn.Parameter(torch.zeros(num_features))

    def forward(self, x):
        # Note: despite the name, this applies a learnable tanh gate
        # (tanh(alpha * x) * weight + bias) rather than classic RMS normalization.
        x = torch.tanh(self.alpha * x)
        return x * self.weight + self.bias
# Hugging Face Transformers
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    Trainer,
    TrainingArguments,
)
from transformers.trainer_utils import TrainOutput
from torch.optim import AdamW
import datasets
import mmfreelm.ops.fusedbitnet as fuse # from matmulfreellm (fusedbitnet module)
_current_lambda = 1.0
_global_should_rms = True # If True, enable RMS in BitLinear
def get_current_lambda():
"""Return the current global lambda."""
return _current_lambda
# ------------------------------------------------------------------------------
# Custom BitLinear Replacement
# ------------------------------------------------------------------------------
def replace_linear_with_fusedbit(model):
    """Swap every nn.Linear in the model for a fused BitLinear layer (bf16)."""
    for name, module in model.named_modules():
        if isinstance(module, nn.Linear):
            fusedbit_layer = fuse.BitLinear(
                in_features=module.in_features,
                out_features=module.out_features,
                lambda_=get_current_lambda,
                should_rms=True,
                bias=(module.bias is not None),
            ).to(dtype=torch.bfloat16)
            # Copy existing weights/bias
            with torch.no_grad():
                fusedbit_layer.weight.copy_(module.weight)
                if module.bias is not None:
                    fusedbit_layer.bias.copy_(module.bias)
            # 🔁 Replace the BitLinear's internal norm (if present) with the tanh-gated RMSNorm above
            for subname, submodule in fusedbit_layer.named_modules():
                if isinstance(submodule, nn.Module) and "norm" in subname.lower():
                    setattr(
                        fusedbit_layer,
                        subname,
                        RMSNorm(num_features=submodule.weight.shape[0]).to(dtype=torch.bfloat16),
                    )
                    break
            # Replace the module in its parent
            parent_path = name.rsplit('.', 1)
            if len(parent_path) == 1:
                setattr(model, parent_path[0], fusedbit_layer)
            else:
                parent_module_name, child_name = parent_path
                parent_module = dict(model.named_modules())[parent_module_name]
                setattr(parent_module, child_name, fusedbit_layer)
    return model
def build_position_ids(input_ids):
"""
input_ids: [batch_size, seq_length] (torch.LongTensor)
returns: position_ids: [batch_size, seq_length]
"""
batch_size, seq_length = input_ids.shape
return torch.arange(seq_length, dtype=torch.long, device=input_ids.device)\
.unsqueeze(0).expand(batch_size, seq_length)
class ChatTemplateCollator:
    def __init__(self, tokenizer, max_length=512):
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __call__(self, examples):
        batch_input_ids = []
        batch_attention_masks = []
        for ex in examples:
            # Build a 'chat' list recognized by apply_chat_template.
            # This depends on your custom approach to chat templates;
            # adjust as needed for your data structure.
            chat = []
            for c in ex["conversations"]:
                chat.append({"role": c["from"], "content": c["value"]})
            tokenized = self.tokenizer.apply_chat_template(
                chat,
                tokenize=True,
                add_generation_prompt=False,
                return_dict=True,
                return_tensors="pt",
            )
            # Squeeze out the extra batch dimension
            input_ids = tokenized["input_ids"].squeeze(0)            # [seq_len]
            attention_mask = tokenized["attention_mask"].squeeze(0)  # [seq_len]
            # Apply length truncation if needed
            if input_ids.size(0) > self.max_length:
                input_ids = input_ids[: self.max_length]
                attention_mask = attention_mask[: self.max_length]
            batch_input_ids.append(input_ids)
            batch_attention_masks.append(attention_mask)
        # Pad the entire batch
        padded_input_ids = torch.nn.utils.rnn.pad_sequence(
            batch_input_ids,
            batch_first=True,
            padding_value=self.tokenizer.pad_token_id,
        ).clone()
        position_ids = build_position_ids(padded_input_ids).clone()
        padded_attention_masks = torch.nn.utils.rnn.pad_sequence(
            batch_attention_masks,
            batch_first=True,
            padding_value=0,
        )
        # For causal LM, labels = input_ids; mask padding positions so they
        # don't contribute to the loss.
        labels = padded_input_ids.clone()
        labels[padded_attention_masks == 0] = -100
        return {
            "input_ids": padded_input_ids,
            "attention_mask": padded_attention_masks,
            "labels": labels,
            "position_ids": position_ids,
        }
def main():
    global _current_lambda, _global_should_rms
    _current_lambda = 1.0

    base_model_name = "deepseek-ai/DeepSeek-R1-Distill-Llama-8B"  # or "deepseek-ai/DeepSeek-R1-Distill-Qwen-32B"
    print(f"Loading model from {base_model_name} ...")
    student = AutoModelForCausalLM.from_pretrained(
        base_model_name,
        torch_dtype=torch.bfloat16,
        low_cpu_mem_usage=True,
        trust_remote_code=True,
    )
    student = replace_linear_with_fusedbit(student)

    tokenizer = AutoTokenizer.from_pretrained(base_model_name, trust_remote_code=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    tokenizer.model_max_length = 16384

    print("Loading dataset open-thoughts/OpenThoughts-114k ...")
    train_dataset = datasets.load_dataset("open-thoughts/OpenThoughts-114k", split="train")
    data_collator = ChatTemplateCollator(tokenizer, max_length=16384)

    # Experiment name
    exp_name = []
    exp_name.append(base_model_name.replace("/", "-"))
    exp_name.append("LambdaSchedule")  # if args.lambda_schedule else "NoLambda"
    exp_name.append("RMS")             # if args.should_rms else "NoRMS"
    output_dir = "_".join(exp_name)
    # DeepSpeed config
    import json
    ds_config = {
        "bf16": {
            "enabled": True
        },
        "zero_optimization": {
            "stage": 3,
            "overlap_comm": True,
            "contiguous_gradients": True,
            "reduce_scatter": True,
            "reduce_bucket_size": 104857600,
            "allgather_partitions": True,
            "allgather_bucket_size": 104857600,
            "offload_param": {
                "device": "cpu",
                "pin_memory": True
            },
            "offload_optimizer": {
                "device": "cpu",
                "pin_memory": True
            },
            # Must be nested under zero_optimization so ZeRO-3 saves gather
            # full (unsharded) 16-bit weights.
            "stage3_gather_16bit_weights_on_model_save": True,
        },
        "gradient_clipping": 1.0,
        "train_micro_batch_size_per_gpu": "auto",
        "gradient_accumulation_steps": "auto",
        # Checkpoint frequency and save_only_model are Trainer settings,
        # so they are passed via TrainingArguments below.
    }
    ds_config_path = "temp_ds_config.json"
    with open(ds_config_path, "w") as f:
        json.dump(ds_config, f)
    # TrainingArguments
    training_args = TrainingArguments(
        output_dir=output_dir,
        overwrite_output_dir=True,
        remove_unused_columns=False,
        max_steps=1500,
        per_device_train_batch_size=1,
        save_steps=500,
        save_only_model=True,
        logging_steps=10,
        evaluation_strategy="no",
        bf16=True,  # matches ds_config
        gradient_checkpointing=True,
        gradient_accumulation_steps=8,
        deepspeed=ds_config_path,
        learning_rate=8e-5,
        warmup_steps=500,
    )
    # Create Trainer
    trainer = Trainer(
        model=student,
        args=training_args,
        train_dataset=train_dataset,
        tokenizer=tokenizer,
        data_collator=data_collator,
    )

    for name, param in student.named_parameters():
        print(name, param.dtype)

    print(f"Starting training for {output_dir} with DeepSpeed ZeRO ...")
    trainer.train()  # resume_from_checkpoint=True

    # Save final model
    final_save_dir = "final_" + output_dir
    trainer.save_model(final_save_dir)
    print(f"✅ Done. Model saved to {final_save_dir}.")
if __name__ == "__main__":
    main()
```
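Because the script relies on the Hugging Face DeepSpeed integration (`deepspeed=ds_config_path` in `TrainingArguments`), it should be launched with the DeepSpeed launcher so ZeRO-3 partitioning and CPU offload take effect. A sketch, assuming the script above is saved as `finetune_bitnet.py` (the filename and GPU count are arbitrary):

```
deepspeed --num_gpus=8 finetune_bitnet.py
```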