# based on https://github.com/Tencent-Hunyuan/HunyuanVideo/blob/main/hyvideo/modules/fp8_optimization.py
import os
import torch
import argparse
from tqdm.auto import tqdm
from safetensors.torch import load_file, save_file


def get_fp_maxval(bits=8, mantissa_bit=3, sign_bits=1):
    # largest finite value representable in the given float format
    _bits = torch.tensor(bits)
    _mantissa_bit = torch.tensor(mantissa_bit)
    _sign_bits = torch.tensor(sign_bits)
    M = torch.clamp(torch.round(_mantissa_bit), 1, _bits - _sign_bits)
    E = _bits - _sign_bits - M
    bias = 2 ** (E - 1) - 1
    mantissa = 1
    for i in range(mantissa_bit - 1):
        mantissa += 1 / (2 ** (i + 1))
    maxval = mantissa * 2 ** (2 ** E - 1 - bias)
    return maxval
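
# Sanity check: with the E4M3 defaults this evaluates to
# (1 + 1/2 + 1/4) * 2 ** (2 ** 4 - 1 - 7) = 1.75 * 256 = 448.0,
# which matches torch.finfo(torch.float8_e4m3fn).max.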


def quantize_to_fp8(x, bits=8, mantissa_bit=3, sign_bits=1):
    """
    Default is E4M3.
    """
    bits = torch.tensor(bits)
    mantissa_bit = torch.tensor(mantissa_bit)
    sign_bits = torch.tensor(sign_bits)
    M = torch.clamp(torch.round(mantissa_bit), 1, bits - sign_bits)
    E = bits - sign_bits - M
    bias = 2 ** (E - 1) - 1
    mantissa = 1
    for i in range(mantissa_bit - 1):
        mantissa += 1 / (2 ** (i + 1))
    maxval = mantissa * 2 ** (2 ** E - 1 - bias)
    minval = -maxval if sign_bits == 1 else torch.zeros_like(maxval)
    input_clamp = torch.min(torch.max(x, minval), maxval)
    # per-element power-of-two quantization step, derived from each value's exponent
    log_scales = torch.clamp((torch.floor(torch.log2(torch.abs(input_clamp)) + bias)).detach(), 1.0)
    log_scales = 2.0 ** (log_scales - M - bias.type(x.dtype))
    # quantize, then dequantize back to the input dtype
    qdq_out = torch.round(input_clamp / log_scales) * log_scales
    return qdq_out, log_scales


def fp8_tensor_quant(x, scale, bits=8, mantissa_bit=3, sign_bits=1):
    # broadcast the per-tensor scale across all but the first dimension
    for i in range(len(x.shape) - 1):
        scale = scale.unsqueeze(-1)
    new_x = x / scale
    quant_dequant_x, log_scales = quantize_to_fp8(new_x, bits=bits, mantissa_bit=mantissa_bit, sign_bits=sign_bits)
    return quant_dequant_x, scale, log_scales
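
# Illustrative usage (hypothetical tensor): dividing by an absmax-derived
# per-tensor scale maps the weight into E4M3 range, then quantize_to_fp8
# snaps each element to the nearest representable fp8 value.
#   w = torch.randn(128, 256)
#   s = w.abs().max() / get_fp_maxval()
#   qdq, s, _ = fp8_tensor_quant(w, s)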


def parse_args():
    parser = argparse.ArgumentParser(
        description="Convert safetensors to fp8 scaled",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "--file",
        type=str,
        required=True,
        help="Input .safetensors file to convert",
    )
    parser.add_argument(
        "--base_dtype",
        type=str,
        default="fp16",
        choices=["fp16", "bf16", "fp32"],
        help="dtype to use for anything that can't be converted to fp8",
    )
    parser.add_argument(
        "--ban_list",
        nargs="*",
        default=[],
        help="List of banned keys to keep in base dtype instead of converting to fp8 (zero or more strings)",
    )
    args = parser.parse_args()
    return args


def main(args):
    input_path = os.path.normpath(args.file)
    output_path = os.path.splitext(input_path)[0] + "_fp8_scaled.safetensors"
    orig_state_dict = load_file(input_path)
    new_state_dict = {}

    if args.base_dtype == "fp16":
        base_dtype = torch.float16
    elif args.base_dtype == "bf16":
        base_dtype = torch.bfloat16
    elif args.base_dtype == "fp32":
        base_dtype = torch.float32  # was `==` (a no-op comparison), leaving base_dtype unbound
    else:
        raise Exception(f"unknown dtype: {args.base_dtype}")

    # ban_list = ["text", "time", "head"]
    ban_list = args.ban_list
    maxval = get_fp_maxval()

    for key in tqdm(orig_state_dict.keys()):
        # decide whether to convert based on shape and banned keys
        convert = False
        if orig_state_dict[key].dim() == 2:
            convert = True
        for ban in ban_list:
            if ban in key:
                convert = False

        scale_key = key.rsplit(".", 1)[0] + ".scale_weight"
        if convert:
            weight = orig_state_dict[key]
            # per-tensor scale from the weight's absolute maximum
            scale = torch.max(torch.abs(weight.flatten())) / maxval
            linear_weight, scale, log_scales = fp8_tensor_quant(weight, scale)
            linear_weight = linear_weight.to(torch.float8_e4m3fn)
            new_state_dict[scale_key] = scale
            new_state_dict[key] = linear_weight
        else:
            # banned 2D weights still get a unit scale so every .scale_weight key exists
            if orig_state_dict[key].dim() == 2:
                new_state_dict[scale_key] = torch.ones(1)
            new_state_dict[key] = orig_state_dict[key].to(base_dtype)

    # marker tensor so loaders can detect the scaled-fp8 checkpoint format
    new_state_dict["scaled_fp8"] = torch.zeros(2).to(torch.float8_e4m3fn)
    save_file(new_state_dict, output_path)


if __name__ == "__main__":
    args = parse_args()
    main(args)
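
# Example invocation (script filename hypothetical; ban_list values taken
# from the commented-out example above):
#   python convert_fp8_scaled.py --file model.safetensors --base_dtype bf16 --ban_list text time head
# This writes model_fp8_scaled.safetensors next to the input, keeping any
# key containing "text", "time", or "head" in bf16.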