BitNet Finetuning
```
pip install git+https://github.com/childressg/matmulfreellm
```
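A quick sanity check that the kernels import cleanly; this assumes only the module path that the training script below already uses:

```
import mmfreelm.ops.fusedbitnet as fuse
print(fuse.__file__)  # should point into the installed mmfreelm package
```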
```
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
BitNet finetuning with DeepSpeed ZeRO (stage 3) for memory-efficient training.

Experiment toggles (described here, but hard-coded in main() rather than
parsed from the command line in this version):
    lambda_schedule [true|false]
    lambda_warmup <int>
    should_rms [true|false]
"""
import json

import torch
import torch.nn as nn
class RMSNorm(nn.Module):
    """Note: despite the name, this is a tanh-based normalization
    (y = tanh(alpha * x) * weight + bias), not a standard RMSNorm."""
    def __init__(self, num_features, alpha_init_value=0.25):
        super().__init__()
        self.alpha = nn.Parameter(torch.ones(1) * alpha_init_value)
        self.weight = nn.Parameter(torch.ones(num_features))
        self.bias = nn.Parameter(torch.zeros(num_features))

    def forward(self, x):
        x = torch.tanh(self.alpha * x)
        return x * self.weight + self.bias
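# Example: with alpha=0.25, an activation of 4.0 becomes tanh(1.0) ≈ 0.762
# before the elementwise affine, so large activations are softly squashed
# into (-1, 1) rather than RMS-rescaled.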
# Hugging Face Transformers
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    Trainer,
    TrainingArguments,
)
import datasets
import mmfreelm.ops.fusedbitnet as fuse  # from matmulfreellm (fusedbitnet module)
_current_lambda = 1.0       # quantization-strength lambda read by BitLinear
_global_should_rms = True   # if True, enable RMS in BitLinear

def get_current_lambda():
    """Return the current global lambda. BitLinear holds a reference to this
    callable, so updating _current_lambda takes effect on the next forward."""
    return _current_lambda
# ------------------------------------------------------------------------------
# Custom BitLinear Replacement
# ------------------------------------------------------------------------------
def replace_linear_with_fusedbit(model):
    # Snapshot the module list first, since we mutate the tree while iterating.
    for name, module in list(model.named_modules()):
        if isinstance(module, nn.Linear):
            fusedbit_layer = fuse.BitLinear(
                in_features=module.in_features,
                out_features=module.out_features,
                lambda_=get_current_lambda,   # pass the callable, not its value
                should_rms=_global_should_rms,
                bias=(module.bias is not None)
            ).to(dtype=torch.bfloat16)
            # Copy existing weights/bias
            with torch.no_grad():
                fusedbit_layer.weight.copy_(module.weight)
                if module.bias is not None:
                    fusedbit_layer.bias.copy_(module.bias)
            # 🔁 Swap BitLinear's internal norm for the tanh-based RMSNorm above
            for subname, submodule in fusedbit_layer.named_modules():
                if "norm" in subname.lower() and hasattr(submodule, "weight"):
                    setattr(
                        fusedbit_layer,
                        subname,
                        RMSNorm(num_features=submodule.weight.shape[0]).to(dtype=torch.bfloat16)
                    )
                    break
            # Replace the module in its parent
            parent_path = name.rsplit('.', 1)
            if len(parent_path) == 1:
                setattr(model, parent_path[0], fusedbit_layer)
            else:
                parent_module_name, child_name = parent_path
                parent_module = dict(model.named_modules())[parent_module_name]
                setattr(parent_module, child_name, fusedbit_layer)
    return model
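# Note on the parent lookup above: setattr(model, "a.b.c", layer) would create
# an attribute literally named "a.b.c", so the dotted name from named_modules()
# has to be split and resolved to the parent module before the child is swapped.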
def build_position_ids(input_ids):
    """
    input_ids: [batch_size, seq_length] (torch.LongTensor)
    returns:   position_ids: [batch_size, seq_length]
    """
    batch_size, seq_length = input_ids.shape
    return (
        torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
        .unsqueeze(0)
        .expand(batch_size, seq_length)
    )
class ChatTemplateCollator:
    def __init__(self, tokenizer, max_length=512):
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __call__(self, examples):
        batch_input_ids = []
        batch_attention_masks = []
        for ex in examples:
            # Build a 'chat' array recognized by apply_chat_template.
            # Adjust the role mapping for your data structure (some
            # ShareGPT-style datasets use "human"/"gpt" instead of
            # "user"/"assistant").
            chat = []
            for c in ex["conversations"]:
                chat.append({"role": c["from"], "content": c["value"]})
            tokenized = self.tokenizer.apply_chat_template(
                chat,
                tokenize=True,
                add_generation_prompt=False,
                return_dict=True,
                return_tensors="pt",
            )
            # Squeeze out the extra batch dimension
            input_ids = tokenized["input_ids"].squeeze(0)            # [seq_len]
            attention_mask = tokenized["attention_mask"].squeeze(0)  # [seq_len]
            # Apply length truncation if needed
            if input_ids.size(0) > self.max_length:
                input_ids = input_ids[: self.max_length]
                attention_mask = attention_mask[: self.max_length]
            batch_input_ids.append(input_ids)
            batch_attention_masks.append(attention_mask)
        # Pad the entire batch
        padded_input_ids = torch.nn.utils.rnn.pad_sequence(
            batch_input_ids,
            batch_first=True,
            padding_value=self.tokenizer.pad_token_id,
        ).clone()
        position_ids = build_position_ids(padded_input_ids).clone()
        padded_attention_masks = torch.nn.utils.rnn.pad_sequence(
            batch_attention_masks,
            batch_first=True,
            padding_value=0,
        )
        # For causal LM, labels = input_ids, but mask padding out of the loss
        labels = padded_input_ids.clone()
        labels[padded_attention_masks == 0] = -100
        return {
            "input_ids": padded_input_ids,
            "attention_mask": padded_attention_masks,
            "labels": labels,
            "position_ids": position_ids,
        }
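# HF causal-LM models shift the labels internally, so labels == input_ids is
# the standard setup; positions labeled -100 are ignored by the loss.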
def main():
    global _current_lambda, _global_should_rms
    _current_lambda = 1.0

    base_model_name = "deepseek-ai/DeepSeek-R1-Distill-Llama-8B"  # or "deepseek-ai/DeepSeek-R1-Distill-Qwen-32B"
    print(f"Loading model from {base_model_name} ...")
    student = AutoModelForCausalLM.from_pretrained(
        base_model_name,
        torch_dtype=torch.bfloat16,
        low_cpu_mem_usage=True,
        trust_remote_code=True,
    )
    student = replace_linear_with_fusedbit(student)

    tokenizer = AutoTokenizer.from_pretrained(base_model_name, trust_remote_code=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    tokenizer.model_max_length = 16384

    print("Loading dataset open-thoughts/OpenThoughts-114k...")
    train_dataset = datasets.load_dataset("open-thoughts/OpenThoughts-114k", split="train")
    data_collator = ChatTemplateCollator(tokenizer, max_length=16384)
    # Experiment name (the schedule and RMS toggles are hard-coded on here)
    exp_name = []
    exp_name.append(base_model_name.replace("/", "-"))
    exp_name.append("LambdaSchedule")  # would be conditional on a lambda_schedule flag
    exp_name.append("RMS")             # would be conditional on a should_rms flag
    output_dir = "_".join(exp_name)
    # DeepSpeed config
    ds_config = {
        "bf16": {
            "enabled": True
        },
        "zero_optimization": {
            "stage": 3,
            "overlap_comm": True,
            "contiguous_gradients": True,
            "reduce_scatter": True,
            "reduce_bucket_size": 104857600,
            "allgather_partitions": True,
            "allgather_bucket_size": 104857600,
            "offload_param": {
                "device": "cpu",
                "pin_memory": True
            },
            "offload_optimizer": {
                "device": "cpu",
                "pin_memory": True
            },
            # belongs under zero_optimization, not at the top level
            "stage3_gather_16bit_weights_on_model_save": True,
        },
        "gradient_clipping": 1.0,
        "train_micro_batch_size_per_gpu": "auto",
        "gradient_accumulation_steps": "auto",
        # save_steps / save_only_model are Trainer options, not DeepSpeed keys,
        # so they live in TrainingArguments below.
    }
    ds_config_path = "temp_ds_config.json"
    with open(ds_config_path, "w") as f:
        json.dump(ds_config, f)
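    # The "auto" values above are resolved by the HF Trainer from the
    # TrainingArguments below (micro batch size and accumulation steps).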
    # TrainingArguments
    training_args = TrainingArguments(
        output_dir=output_dir,
        overwrite_output_dir=True,
        remove_unused_columns=False,
        max_steps=1500,
        per_device_train_batch_size=1,
        save_steps=500,
        save_only_model=True,
        logging_steps=10,
        evaluation_strategy="no",
        bf16=True,  # matches ds_config
        gradient_checkpointing=True,
        gradient_accumulation_steps=8,
        deepspeed=ds_config_path,
        learning_rate=8e-5,
        warmup_steps=500,
    )
    # Create Trainer
    trainer = Trainer(
        model=student,
        args=training_args,
        train_dataset=train_dataset,
        tokenizer=tokenizer,
        data_collator=data_collator,
    )

    # Sanity check: every parameter should be bfloat16 after the swap
    for name, param in student.named_parameters():
        print(name, param.dtype)

    print(f"Starting training for {output_dir} with DeepSpeed ZeRO ...")
    trainer.train()  # pass resume_from_checkpoint=True to resume

    # Save final model
    final_save_dir = "final_" + output_dir
    trainer.save_model(final_save_dir)
    print(f"✅ Done. Model saved to {final_save_dir}.")

if __name__ == "__main__":
    main()
```
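The run name advertises a lambda schedule, but nothing in the script ever updates `_current_lambda`. Below is a minimal sketch of how the warmup could be wired in with a `TrainerCallback`; the callback class, the linear ramp, and the 500-step default are assumptions, not part of the gist, and the callback must live in the same module as `_current_lambda` for the `global` statement to work:

```
from transformers import TrainerCallback

class LambdaWarmupCallback(TrainerCallback):
    """Hypothetical: ramp lambda linearly from 0 to 1 over `lambda_warmup`
    steps, then hold it at 1.0. BitLinear reads the value through
    get_current_lambda() on each forward pass."""
    def __init__(self, lambda_warmup=500):
        self.lambda_warmup = lambda_warmup

    def on_step_begin(self, args, state, control, **kwargs):
        global _current_lambda
        _current_lambda = min(1.0, state.global_step / self.lambda_warmup)

# usage: trainer.add_callback(LambdaWarmupCallback(lambda_warmup=500))
```

Since the config uses ZeRO stage 3 with CPU offload, the script is meant to run under a distributed launcher rather than plain `python`, e.g. `deepspeed train.py` (filename hypothetical).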