Quantizing gpt-oss locally: bitsandbytes NF4 with CPU offload, and NVFP4 via NVIDIA TensorRT Model Optimizer (modelopt)
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
import time

model_path = './gpt-oss-model-local'

# 4-bit NF4 quantization with double quantization; compute in bfloat16 and
# allow fp32 CPU offload for the layers that do not fit on the GPU.
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    llm_int8_enable_fp32_cpu_offload=True
)

# Manual split of the 24 decoder layers: the first 15 on GPU 0, the remaining
# layers (plus the final norm and lm_head) on CPU.
num_gpu_layers = 15
num_total_layers = 24
device_map = {
    "model.embed_tokens": 0,
    **{f"model.layers.{i}": 0 for i in range(num_gpu_layers)},
    **{f"model.layers.{i}": "cpu" for i in range(num_gpu_layers, num_total_layers)},
    "model.norm": "cpu",
    "lm_head": "cpu"
}
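# (Hypothetical check, not in the original gist) make sure the manual map
# covers every decoder layer before loading the weights.
assert sum(k.startswith("model.layers.") for k in device_map) == num_total_layers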
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    quantization_config=quantization_config,
    # attn_implementation="flash_attention_3",
    # attn_implementation="sdpa",
    device_map=device_map
)
print(f"current attention impl: {model.config._attn_implementation}")

tokenizer = AutoTokenizer.from_pretrained(model_path)

messages = [
    {"role": "user", "content": "Explain what MXFP4 quantization is in simple terms."},
]
inputs = tokenizer.apply_chat_template(
    messages,
    add_generation_prompt=True,
    return_tensors="pt",
    return_dict=True,
).to(model.device)

# Generate a single token, so the measured time approximates time-to-first-token.
max_new_tokens = 1

start_time = time.perf_counter()
outputs = model.generate(
    **inputs,
    max_new_tokens=max_new_tokens,
    temperature=0.7
)
# Stop the timer right after generation, so decoding/printing is not included.
end_time = time.perf_counter()
elapsed_time = end_time - start_time

print("inf end")
print(tokenizer.decode(outputs[0]))
print("\n" + "=" * 30)
print(f"elapsed time: {elapsed_time:.2f}sec")
print("=" * 30)
| """ | |
| "quantization_config": { | |
| "_load_in_4bit": true, | |
| "_load_in_8bit": false, | |
| "bnb_4bit_compute_dtype": "bfloat16", | |
| "bnb_4bit_quant_storage": "uint8", | |
| "bnb_4bit_quant_type": "nf4", | |
| "bnb_4bit_use_double_quant": true, | |
| "llm_int8_enable_fp32_cpu_offload": false, | |
| "llm_int8_has_fp16_weight": false, | |
| "llm_int8_skip_modules": null, | |
| "llm_int8_threshold": 6.0, | |
| "load_in_4bit": true, | |
| "load_in_8bit": false, | |
| "quant_method": "bitsandbytes" | |
| } | |
| """ |
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import modelopt.torch.quantization as mtq

# Load base model (upcast from original MXFP4 to BF16)
MODEL_ID = "openai/gpt-oss-20b"
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
    device_map="auto"
)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)

# Configure NVFP4 quantization
config = mtq.NVFP4_DEFAULT_CFG

# Calibration for optimal quantization
def forward_loop(model):
    calibration_prompts = [
        "The future of artificial intelligence is",
        "Machine learning has transformed",
        "Deep learning models are capable of"
    ]
    model.eval()
    with torch.no_grad():
        for prompt in calibration_prompts:
            inputs = tokenizer(
                prompt,
                return_tensors="pt",
                max_length=512,
                truncation=True
            ).to(model.device)
            model(**inputs)

# Apply quantization
model = mtq.quantize(model, config, forward_loop)

# Save quantized model
model.save_pretrained("/path/to/output", safe_serialization=True)
tokenizer.save_pretrained("/path/to/output")
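A minimal smoke test (an assumption, not in the original gist) to check that the NVFP4-quantized model still generates coherent text before relying on the saved checkpoint; it reuses only the `model` and `tokenizer` defined above.

# Hypothetical post-quantization check, assuming `model` and `tokenizer` above.
prompt = "Explain what NVFP4 quantization is in one sentence."
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
with torch.no_grad():
    out = model.generate(**inputs, max_new_tokens=64, do_sample=False)
print(tokenizer.decode(out[0], skip_special_tokens=True))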
"quantization_config": {
"_load_in_4bit": true,
"_load_in_8bit": false,
"bnb_4bit_compute_dtype": "bfloat16",
"bnb_4bit_quant_storage": "uint8",
"bnb_4bit_quant_type": "nf4",
"bnb_4bit_use_double_quant": true,
"llm_int8_enable_fp32_cpu_offload": false,
"llm_int8_has_fp16_weight": false,
"llm_int8_skip_modules": null,
"llm_int8_threshold": 6.0,
"load_in_4bit": true,
"load_in_8bit": false,
"quant_method": "bitsandbytes"
}