# based on https://github.com/Tencent-Hunyuan/HunyuanVideo/blob/main/hyvideo/modules/fp8_optimization.py
# Converts a .safetensors checkpoint to float8_e4m3fn weights with per-tensor scale factors.

import os
import torch
import argparse
from tqdm.auto import tqdm
from safetensors.torch import load_file, save_file

def get_fp_maxval(bits=8, mantissa_bit=3, sign_bits=1):
    # largest representable magnitude for the given float format (default: e4m3 -> 448)
    _bits = torch.tensor(bits)
    _mantissa_bit = torch.tensor(mantissa_bit)
    _sign_bits = torch.tensor(sign_bits)
    M = torch.clamp(torch.round(_mantissa_bit), 1, _bits - _sign_bits)
    E = _bits - _sign_bits - M
    bias = 2 ** (E - 1) - 1
    mantissa = 1
    for i in range(mantissa_bit - 1):
        mantissa += 1 / (2 ** (i + 1))
    maxval = mantissa * 2 ** (2 ** E - 1 - bias)
    return maxval

def quantize_to_fp8(x, bits=8, mantissa_bit=3, sign_bits=1):
    """
    Fake-quantize x to the given float format (default is E4M3) and return
    the quantized-dequantized tensor along with the per-element scales.
    """
    bits = torch.tensor(bits)
    mantissa_bit = torch.tensor(mantissa_bit)
    sign_bits = torch.tensor(sign_bits)
    M = torch.clamp(torch.round(mantissa_bit), 1, bits - sign_bits)
    E = bits - sign_bits - M
    bias = 2 ** (E - 1) - 1
    mantissa = 1
    for i in range(mantissa_bit - 1):
        mantissa += 1 / (2 ** (i + 1))
    maxval = mantissa * 2 ** (2 ** E - 1 - bias)
    minval = -maxval if sign_bits == 1 else torch.zeros_like(maxval)
    input_clamp = torch.min(torch.max(x, minval), maxval)
    log_scales = torch.clamp((torch.floor(torch.log2(torch.abs(input_clamp)) + bias)).detach(), 1.0)
    log_scales = 2.0 ** (log_scales - M - bias.type(x.dtype))
    # round to the nearest representable value, then dequantize
    qdq_out = torch.round(input_clamp / log_scales) * log_scales
    return qdq_out, log_scales

def fp8_tensor_quant(x, scale, bits=8, mantissa_bit=3, sign_bits=1):
    # add trailing singleton dims so scale broadcasts against x
    for i in range(len(x.shape) - 1):
        scale = scale.unsqueeze(-1)
    new_x = x / scale
    quant_dequant_x, log_scales = quantize_to_fp8(new_x, bits=bits, mantissa_bit=mantissa_bit, sign_bits=sign_bits)
    return quant_dequant_x, scale, log_scales

def parse_args():
    parser = argparse.ArgumentParser(
        description = "Convert safetensors to fp8 scaled",
        formatter_class = argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "--file",
        type = str,
        required = True,
        help = "Input .safetensors file to convert",
    )
    parser.add_argument(
        "--base_dtype",
        type = str,
        default = "fp16",
        choices = ["fp16", "bf16", "fp32"],
        help = "dtype to use for anything that can't be converted to fp8",
    )
    parser.add_argument(
        "--ban_list",
        nargs = "*",
        default = [],
        help = "List of banned keys to keep in base dtype instead of converting to fp8 (zero or more strings)",
    )
    args = parser.parse_args()
    return args

def main(args):
    input_path = os.path.normpath(args.file)
    output_path = os.path.splitext(input_path)[0] + "_fp8_scaled.safetensors"

    orig_state_dict = load_file(input_path)
    new_state_dict = {}

    if args.base_dtype == "fp16":
        base_dtype = torch.float16
    elif args.base_dtype == "bf16":
        base_dtype = torch.bfloat16
    elif args.base_dtype == "fp32":
        base_dtype = torch.float32
    else:
        raise Exception(f"unknown dtype: {args.base_dtype}")

    # ban_list = ["text", "time", "head"]
    ban_list = args.ban_list
    maxval = get_fp_maxval()

    for key in tqdm(orig_state_dict.keys()):
        # decide whether to convert based on shape and banned keys
        convert = False
        if orig_state_dict[key].dim() == 2:
            convert = True
        for ban in ban_list:
            if ban in key:
                convert = False

        scale_key = key.rsplit(".", 1)[0] + ".scale_weight"

        if convert:
            # per-tensor scale so the largest weight magnitude maps to the fp8 max value
            weight = orig_state_dict[key]
            scale = torch.max(torch.abs(weight.flatten())) / maxval
            linear_weight, scale, log_scales = fp8_tensor_quant(weight, scale)
            linear_weight = linear_weight.to(torch.float8_e4m3fn)
            new_state_dict[scale_key] = scale
            new_state_dict[key] = linear_weight
        else:
            # keep in base dtype, but still write a unit scale for 2D weights
            if orig_state_dict[key].dim() == 2:
                new_state_dict[scale_key] = torch.ones(1)
            new_state_dict[key] = orig_state_dict[key].to(base_dtype)

    # marker tensor indicating the scaled fp8 format
    new_state_dict["scaled_fp8"] = torch.zeros(2).to(torch.float8_e4m3fn)
    save_file(new_state_dict, output_path)


if __name__ == "__main__":
    args = parse_args()
    main(args)
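
Usage note: assuming the gist is saved locally as convert_fp8_scaled.py (the filename is an assumption, not part of the gist), a typical run would be `python convert_fp8_scaled.py --file model.safetensors --base_dtype bf16 --ban_list text time head`, which writes model_fp8_scaled.safetensors next to the input. As a minimal sketch (not part of the gist), the hypothetical helper below shows how the output format could be dequantized back to a higher precision for inspection: each converted 2D weight is stored in fp8 alongside a per-tensor .scale_weight, so the original weight is approximately the fp8 value multiplied by its scale, and the extra scaled_fp8 tensor only marks the format.

import torch
from safetensors.torch import load_file

def dequantize_fp8_scaled(path, dtype=torch.float16):
    # hypothetical helper, not part of the gist above
    sd = load_file(path)
    sd.pop("scaled_fp8", None)  # format marker written by the converter
    out = {}
    for key, tensor in sd.items():
        if key.endswith(".scale_weight"):
            continue  # consumed together with its matching weight below
        scale_key = key.rsplit(".", 1)[0] + ".scale_weight"
        if tensor.dtype == torch.float8_e4m3fn and scale_key in sd:
            # the converter stored round(weight / scale) in fp8, so undo the scaling
            out[key] = tensor.to(dtype) * sd[scale_key].to(dtype)
        else:
            out[key] = tensor.to(dtype)
    return out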