quant gpt oss local
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
import time
model_path = './gpt-oss-model-local'
# 4-bit NF4 quantization with double quantization and bf16 compute;
# fp32 CPU offload is enabled for the modules kept on CPU below.
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    llm_int8_enable_fp32_cpu_offload=True
)
# Manual device map: embeddings and the first 15 of 24 transformer layers
# on GPU 0, the remaining layers plus final norm and lm_head on CPU.
num_gpu_layers = 15
num_total_layers = 24
device_map = {
    "model.embed_tokens": 0,
    **{f"model.layers.{i}": 0 for i in range(num_gpu_layers)},
    **{f"model.layers.{i}": "cpu" for i in range(num_gpu_layers, num_total_layers)},
    "model.norm": "cpu",
    "lm_head": "cpu"
}
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    quantization_config=quantization_config,
    # attn_implementation="flash_attention_3",
    # attn_implementation="sdpa",
    device_map=device_map
)
print(f"current attention impl: {model.config._attn_implementation}")
tokenizer = AutoTokenizer.from_pretrained(model_path)
messages = [
    {"role": "user", "content": "Explain what MXFP4 quantization is in simple terms."},
]
inputs = tokenizer.apply_chat_template(
    messages,
    add_generation_prompt=True,
    return_tensors="pt",
    return_dict=True,
).to(model.device)
# Time a single-token generation step; note that temperature only takes
# effect if sampling is enabled.
max_new_tokens = 1
start_time = time.perf_counter()
outputs = model.generate(
    **inputs,
    max_new_tokens=max_new_tokens,
    temperature=0.7
)
end_time = time.perf_counter()
elapsed_time = end_time - start_time

print("inf end")
print(tokenizer.decode(outputs[0]))
print("\n" + "="*30)
print(f"elapsed time: {elapsed_time:.2f}sec")
print("="*30)
"""
"quantization_config": {
"_load_in_4bit": true,
"_load_in_8bit": false,
"bnb_4bit_compute_dtype": "bfloat16",
"bnb_4bit_quant_storage": "uint8",
"bnb_4bit_quant_type": "nf4",
"bnb_4bit_use_double_quant": true,
"llm_int8_enable_fp32_cpu_offload": false,
"llm_int8_has_fp16_weight": false,
"llm_int8_skip_modules": null,
"llm_int8_threshold": 6.0,
"load_in_4bit": true,
"load_in_8bit": false,
"quant_method": "bitsandbytes"
}
"""
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import modelopt.torch.quantization as mtq
# Load base model (upcast from original MXFP4 to BF16)
MODEL_ID = "openai/gpt-oss-20b"
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
    device_map="auto"
)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
# Configure NVFP4 quantization
config = mtq.NVFP4_DEFAULT_CFG
# Calibration loop: run a few representative prompts through the model so
# modelopt can collect the activation statistics it needs for quantization.
def forward_loop(model):
    calibration_prompts = [
        "The future of artificial intelligence is",
        "Machine learning has transformed",
        "Deep learning models are capable of"
    ]
    model.eval()
    with torch.no_grad():
        for prompt in calibration_prompts:
            inputs = tokenizer(
                prompt,
                return_tensors="pt",
                max_length=512,
                truncation=True
            ).to(model.device)
            model(**inputs)
# Apply quantization
model = mtq.quantize(model, config, forward_loop)
# Save quantized model
model.save_pretrained("/path/to/output", safe_serialization=True)
tokenizer.save_pretrained("/path/to/output")
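# Optional sanity check on the NVFP4-quantized model (a sketch; the prompt is
# illustrative and `model`/`tokenizer` come from the script above).
test_inputs = tokenizer("MXFP4 and NVFP4 are both", return_tensors="pt").to(model.device)
with torch.no_grad():
    test_outputs = model.generate(**test_inputs, max_new_tokens=32, do_sample=False)
print(tokenizer.decode(test_outputs[0], skip_special_tokens=True))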
"quantization_config": {

"_load_in_4bit": true,
"_load_in_8bit": false,
"bnb_4bit_compute_dtype": "bfloat16",
"bnb_4bit_quant_storage": "uint8",
"bnb_4bit_quant_type": "nf4",
"bnb_4bit_use_double_quant": true,
"llm_int8_enable_fp32_cpu_offload": false,
"llm_int8_has_fp16_weight": false,
"llm_int8_skip_modules": null,
"llm_int8_threshold": 6.0,
"load_in_4bit": true,
"load_in_8bit": false,
"quant_method": "bitsandbytes"
}
