Convert Hugging Face Whisper model to OpenAI .pt format for stable-ts
#!/usr/bin/env python3
"""
Convert Hugging Face Whisper model to OpenAI .pt format for stable-ts
"""
import torch
from transformers import WhisperForConditionalGeneration, WhisperConfig
import numpy as np
from collections import OrderedDict
import sys
from pathlib import Path
def rename_keys(state_dict):
    """Map Hugging Face state dict keys to OpenAI Whisper format"""
    # Key mapping patterns
    mapping_patterns = [
        # Encoder mappings
        ("model.encoder.embed_positions.weight", "encoder.positional_embedding"),
        ("model.encoder.conv1.weight", "encoder.conv1.weight"),
        ("model.encoder.conv1.bias", "encoder.conv1.bias"),
        ("model.encoder.conv2.weight", "encoder.conv2.weight"),
        ("model.encoder.conv2.bias", "encoder.conv2.bias"),
        ("model.encoder.layer_norm.weight", "encoder.ln_post.weight"),
        ("model.encoder.layer_norm.bias", "encoder.ln_post.bias"),
        # Decoder mappings
        ("model.decoder.embed_tokens.weight", "decoder.token_embedding.weight"),
        ("model.decoder.embed_positions.weight", "decoder.positional_embedding"),
        ("model.decoder.layer_norm.weight", "decoder.ln.weight"),
        ("model.decoder.layer_norm.bias", "decoder.ln.bias"),
        # Note: proj_out.weight is skipped - OpenAI Whisper shares decoder.token_embedding.weight
    ]

    new_state_dict = OrderedDict()
    for key, value in state_dict.items():
        new_key = key

        # Apply direct mappings
        for hf_pattern, whisper_pattern in mapping_patterns:
            if key == hf_pattern:
                new_key = whisper_pattern
                break

        # Handle encoder layers
        if "model.encoder.layers." in key:
            new_key = key.replace("model.encoder.layers.", "encoder.blocks.")
            new_key = new_key.replace(".self_attn.k_proj.", ".attn.key.")
            new_key = new_key.replace(".self_attn.v_proj.", ".attn.value.")
            new_key = new_key.replace(".self_attn.q_proj.", ".attn.query.")
            new_key = new_key.replace(".self_attn.out_proj.", ".attn.out.")
            new_key = new_key.replace(".self_attn_layer_norm.", ".attn_ln.")
            # MLP layers in OpenAI format use numeric indices
            new_key = new_key.replace(".fc1.weight", ".mlp.0.weight")
            new_key = new_key.replace(".fc1.bias", ".mlp.0.bias")
            new_key = new_key.replace(".fc2.weight", ".mlp.2.weight")
            new_key = new_key.replace(".fc2.bias", ".mlp.2.bias")
            new_key = new_key.replace(".final_layer_norm.", ".mlp_ln.")
        # Handle decoder layers
        elif "model.decoder.layers." in key:
            new_key = key.replace("model.decoder.layers.", "decoder.blocks.")
            new_key = new_key.replace(".self_attn.k_proj.", ".attn.key.")
            new_key = new_key.replace(".self_attn.v_proj.", ".attn.value.")
            new_key = new_key.replace(".self_attn.q_proj.", ".attn.query.")
            new_key = new_key.replace(".self_attn.out_proj.", ".attn.out.")
            new_key = new_key.replace(".self_attn_layer_norm.", ".attn_ln.")
            new_key = new_key.replace(".encoder_attn.k_proj.", ".cross_attn.key.")
            new_key = new_key.replace(".encoder_attn.v_proj.", ".cross_attn.value.")
            new_key = new_key.replace(".encoder_attn.q_proj.", ".cross_attn.query.")
            new_key = new_key.replace(".encoder_attn.out_proj.", ".cross_attn.out.")
            new_key = new_key.replace(".encoder_attn_layer_norm.", ".cross_attn_ln.")
            # MLP layers in OpenAI format use numeric indices
            new_key = new_key.replace(".fc1.weight", ".mlp.0.weight")
            new_key = new_key.replace(".fc1.bias", ".mlp.0.bias")
            new_key = new_key.replace(".fc2.weight", ".mlp.2.weight")
            new_key = new_key.replace(".fc2.bias", ".mlp.2.bias")
            new_key = new_key.replace(".final_layer_norm.", ".mlp_ln.")

        # Skip keys we don't need
        if "proj_out" in key:
            continue  # Skip proj_out.weight - OpenAI shares decoder.token_embedding
        if any(skip in new_key for skip in ["model.", "proj_layer_norm", "_mask"]):
            if not new_key.startswith("encoder.") and not new_key.startswith("decoder."):
                continue

        new_state_dict[new_key] = value

    return new_state_dict
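

# Illustrative renames produced by rename_keys (shown for reference, not executed):
#   "model.encoder.layers.0.self_attn.q_proj.weight"    -> "encoder.blocks.0.attn.query.weight"
#   "model.decoder.layers.5.encoder_attn.out_proj.bias" -> "decoder.blocks.5.cross_attn.out.bias"
#   "model.decoder.embed_tokens.weight"                  -> "decoder.token_embedding.weight"
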
def get_model_dimensions(config):
    """Extract model dimensions from config"""
    # Model dimensions matching OpenAI Whisper structure
    dims = {
        'n_mels': getattr(config, 'num_mel_bins', 128),  # Should be 128 for large-v3
        'n_vocab': config.vocab_size,
        'n_audio_ctx': config.max_source_positions,
        'n_audio_state': config.d_model,
        'n_audio_head': config.encoder_attention_heads,
        'n_audio_layer': config.encoder_layers,
        'n_text_ctx': config.max_target_positions,
        'n_text_state': config.d_model,
        'n_text_head': config.decoder_attention_heads,
        'n_text_layer': config.decoder_layers,
    }
    return dims
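

# For reference: a whisper-large-v3 style config is expected to yield roughly
#   n_mels=128, n_vocab=51866, n_audio_ctx=1500, n_audio_state=1280,
#   n_audio_head=20, n_audio_layer=32, n_text_ctx=448, n_text_state=1280,
#   n_text_head=20, n_text_layer=32
# (values based on the published large-v3 architecture; the actual numbers
# always come from the loaded config).
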
def main():
    print("Hugging Face Whisper to OpenAI .pt Converter")
    print("=" * 50)

    model_id = "benmajor27/whisper-large-v3-hu_full"
    output_file = "whisper-large-v3-hungarian.pt"

    print(f"Model: {model_id}")
    print(f"Output: {output_file}")
    print()

    # Load the model from Hugging Face
    print("Loading model from Hugging Face...")
    print("This will download the model if not cached (~3GB)")
    try:
        model = WhisperForConditionalGeneration.from_pretrained(
            model_id,
            torch_dtype=torch.float32,
            low_cpu_mem_usage=True
        )
        config = model.config
        print("✓ Model loaded successfully")
    except Exception as e:
        print(f"Error loading model: {e}")
        sys.exit(1)

    # Get the state dict
    print("\nExtracting state dictionary...")
    state_dict = model.state_dict()
    print(f" Original keys: {len(state_dict)}")

    # Convert to OpenAI format
    print("\nConverting to OpenAI Whisper format...")
    converted_state_dict = rename_keys(state_dict)
    print(f" Converted keys: {len(converted_state_dict)}")

    # Get model dimensions
    print("\nExtracting model dimensions...")
    dims = get_model_dimensions(config)
    print(" Model configuration:")
    for key, value in dims.items():
        print(f" {key}: {value}")

    # Prepare the checkpoint
    print("\nPreparing checkpoint...")
    checkpoint = {
        'dims': dims,
        'model_state_dict': converted_state_dict,
    }

    # Save the model
    print(f"\nSaving to {output_file}...")
    torch.save(checkpoint, output_file)

    # Verify file size
    file_size = Path(output_file).stat().st_size / (1024**3)
    print(f"✓ Saved successfully ({file_size:.2f} GB)")

    print("\n" + "=" * 50)
    print("Conversion complete!")
    print("\nUsage with stable-ts:")
    print(f" stable-ts audio.mp3 --model ./{output_file} --language hu --task transcribe -o output.srt")

    # Additional verification
    print("\nVerifying saved model...")
    try:
        loaded = torch.load(output_file, map_location='cpu')
        print(" ✓ Model loads correctly")
        print(f" ✓ Contains 'dims': {'dims' in loaded}")
        print(f" ✓ Contains 'model_state_dict': {'model_state_dict' in loaded}")
        # Check for key components
        state_keys = list(loaded['model_state_dict'].keys())
        has_encoder = any('encoder' in k for k in state_keys)
        has_decoder = any('decoder' in k for k in state_keys)
        print(f" ✓ Has encoder weights: {has_encoder}")
        print(f" ✓ Has decoder weights: {has_decoder}")
    except Exception as e:
        print(f" ⚠ Warning: Could not verify model: {e}")

    print("\n✅ Ready to use with stable-ts!")


if __name__ == "__main__":
    main()
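

# Minimal sketch of loading the converted checkpoint from Python via stable-ts
# (assumes `pip install -U stable-ts`; API names per the stable-ts README):
#
#   import stable_whisper
#   model = stable_whisper.load_model("./whisper-large-v3-hungarian.pt")
#   result = model.transcribe("audio.mp3", language="hu")
#   result.to_srt_vtt("output.srt")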