Convert Hugging Face Whisper model to OpenAI .pt format for stable-ts
#!/usr/bin/env python3
"""
Convert Hugging Face Whisper model to OpenAI .pt format for stable-ts
"""
import sys
from collections import OrderedDict
from pathlib import Path

import torch
from transformers import WhisperForConditionalGeneration
def rename_keys(state_dict):
    """Map Hugging Face state dict keys to OpenAI Whisper format."""
    # Direct one-to-one key mappings
    mapping_patterns = [
        # Encoder mappings
        ("model.encoder.embed_positions.weight", "encoder.positional_embedding"),
        ("model.encoder.conv1.weight", "encoder.conv1.weight"),
        ("model.encoder.conv1.bias", "encoder.conv1.bias"),
        ("model.encoder.conv2.weight", "encoder.conv2.weight"),
        ("model.encoder.conv2.bias", "encoder.conv2.bias"),
        ("model.encoder.layer_norm.weight", "encoder.ln_post.weight"),
        ("model.encoder.layer_norm.bias", "encoder.ln_post.bias"),
        # Decoder mappings
        ("model.decoder.embed_tokens.weight", "decoder.token_embedding.weight"),
        ("model.decoder.embed_positions.weight", "decoder.positional_embedding"),
        ("model.decoder.layer_norm.weight", "decoder.ln.weight"),
        ("model.decoder.layer_norm.bias", "decoder.ln.bias"),
        # Note: proj_out.weight is skipped - OpenAI Whisper ties it to
        # decoder.token_embedding.weight
    ]

    new_state_dict = OrderedDict()
    for key, value in state_dict.items():
        new_key = key

        # Apply direct mappings
        for hf_pattern, whisper_pattern in mapping_patterns:
            if key == hf_pattern:
                new_key = whisper_pattern
                break

        # Handle encoder layers
        if "model.encoder.layers." in key:
            new_key = key.replace("model.encoder.layers.", "encoder.blocks.")
            new_key = new_key.replace(".self_attn.k_proj.", ".attn.key.")
            new_key = new_key.replace(".self_attn.v_proj.", ".attn.value.")
            new_key = new_key.replace(".self_attn.q_proj.", ".attn.query.")
            new_key = new_key.replace(".self_attn.out_proj.", ".attn.out.")
            new_key = new_key.replace(".self_attn_layer_norm.", ".attn_ln.")
            # MLP layers in OpenAI format use numeric indices
            new_key = new_key.replace(".fc1.weight", ".mlp.0.weight")
            new_key = new_key.replace(".fc1.bias", ".mlp.0.bias")
            new_key = new_key.replace(".fc2.weight", ".mlp.2.weight")
            new_key = new_key.replace(".fc2.bias", ".mlp.2.bias")
            new_key = new_key.replace(".final_layer_norm.", ".mlp_ln.")
        # Handle decoder layers
        elif "model.decoder.layers." in key:
            new_key = key.replace("model.decoder.layers.", "decoder.blocks.")
            new_key = new_key.replace(".self_attn.k_proj.", ".attn.key.")
            new_key = new_key.replace(".self_attn.v_proj.", ".attn.value.")
            new_key = new_key.replace(".self_attn.q_proj.", ".attn.query.")
            new_key = new_key.replace(".self_attn.out_proj.", ".attn.out.")
            new_key = new_key.replace(".self_attn_layer_norm.", ".attn_ln.")
            new_key = new_key.replace(".encoder_attn.k_proj.", ".cross_attn.key.")
            new_key = new_key.replace(".encoder_attn.v_proj.", ".cross_attn.value.")
            new_key = new_key.replace(".encoder_attn.q_proj.", ".cross_attn.query.")
            new_key = new_key.replace(".encoder_attn.out_proj.", ".cross_attn.out.")
            new_key = new_key.replace(".encoder_attn_layer_norm.", ".cross_attn_ln.")
            # MLP layers in OpenAI format use numeric indices
            new_key = new_key.replace(".fc1.weight", ".mlp.0.weight")
            new_key = new_key.replace(".fc1.bias", ".mlp.0.bias")
            new_key = new_key.replace(".fc2.weight", ".mlp.2.weight")
            new_key = new_key.replace(".fc2.bias", ".mlp.2.bias")
            new_key = new_key.replace(".final_layer_norm.", ".mlp_ln.")

        # Skip keys we don't need
        if "proj_out" in key:
            continue  # proj_out.weight is tied to decoder.token_embedding.weight
        if any(skip in new_key for skip in ["model.", "proj_layer_norm", "_mask"]):
            if not new_key.startswith("encoder.") and not new_key.startswith("decoder."):
                continue

        new_state_dict[new_key] = value

    return new_state_dict
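
# A few concrete examples of what the renaming above produces
# (layer indices are illustrative):
#   model.encoder.layers.0.self_attn.q_proj.weight   -> encoder.blocks.0.attn.query.weight
#   model.decoder.layers.3.encoder_attn.out_proj.bias -> decoder.blocks.3.cross_attn.out.bias
#   model.decoder.layers.3.fc1.weight                 -> decoder.blocks.3.mlp.0.weight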
def get_model_dimensions(config):
    """Extract OpenAI-style model dimensions from the Hugging Face config."""
    dims = {
        'n_mels': getattr(config, 'num_mel_bins', 128),  # 128 for large-v3
        'n_vocab': config.vocab_size,
        'n_audio_ctx': config.max_source_positions,
        'n_audio_state': config.d_model,
        'n_audio_head': config.encoder_attention_heads,
        'n_audio_layer': config.encoder_layers,
        'n_text_ctx': config.max_target_positions,
        'n_text_state': config.d_model,
        'n_text_head': config.decoder_attention_heads,
        'n_text_layer': config.decoder_layers,
    }
    return dims
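
# Sanity check (assumption, based on the published large-v3 config): for a
# large-v3-derived checkpoint the values above are expected to come out as
# n_mels=128, n_vocab=51866, n_audio_ctx=1500, n_audio_state=1280,
# n_audio_head=20, n_audio_layer=32, n_text_ctx=448, n_text_state=1280,
# n_text_head=20, n_text_layer=32.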
def main():
    print("Hugging Face Whisper to OpenAI .pt Converter")
    print("=" * 50)

    model_id = "benmajor27/whisper-large-v3-hu_full"
    output_file = "whisper-large-v3-hungarian.pt"

    print(f"Model: {model_id}")
    print(f"Output: {output_file}")
    print()

    # Load the model from Hugging Face
    print("Loading model from Hugging Face...")
    print("This will download the model if not cached (~3GB)")
    try:
        model = WhisperForConditionalGeneration.from_pretrained(
            model_id,
            torch_dtype=torch.float32,
            low_cpu_mem_usage=True,
        )
        config = model.config
        print("✓ Model loaded successfully")
    except Exception as e:
        print(f"Error loading model: {e}")
        sys.exit(1)
    # Get the state dict
    print("\nExtracting state dictionary...")
    state_dict = model.state_dict()
    print(f"  Original keys: {len(state_dict)}")

    # Convert to OpenAI format
    print("\nConverting to OpenAI Whisper format...")
    converted_state_dict = rename_keys(state_dict)
    print(f"  Converted keys: {len(converted_state_dict)}")

    # Get model dimensions
    print("\nExtracting model dimensions...")
    dims = get_model_dimensions(config)
    print("  Model configuration:")
    for key, value in dims.items():
        print(f"    {key}: {value}")

    # Prepare the checkpoint
    print("\nPreparing checkpoint...")
    checkpoint = {
        'dims': dims,
        'model_state_dict': converted_state_dict,
    }
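    # This {'dims': ..., 'model_state_dict': ...} layout is what openai-whisper's
    # load_model() reads. The weights are saved in float32, so expect the file to
    # be roughly twice the size of an fp16 download (~6 GB for a large-v3 model).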
    # Save the model
    print(f"\nSaving to {output_file}...")
    torch.save(checkpoint, output_file)

    # Verify file size
    file_size = Path(output_file).stat().st_size / (1024**3)
    print(f"✓ Saved successfully ({file_size:.2f} GB)")

    print("\n" + "=" * 50)
    print("Conversion complete!")
    print("\nUsage with stable-ts:")
    print(f"  stable-ts audio.mp3 --model ./{output_file} --language hu --task transcribe -o output.srt")

    # Additional verification
    print("\nVerifying saved model...")
    try:
        loaded = torch.load(output_file, map_location='cpu')
        print("  ✓ Model loads correctly")
        print(f"  ✓ Contains 'dims': {'dims' in loaded}")
        print(f"  ✓ Contains 'model_state_dict': {'model_state_dict' in loaded}")
        # Check for key components
        state_keys = list(loaded['model_state_dict'].keys())
        has_encoder = any('encoder' in k for k in state_keys)
        has_decoder = any('decoder' in k for k in state_keys)
        print(f"  ✓ Has encoder weights: {has_encoder}")
        print(f"  ✓ Has decoder weights: {has_decoder}")
    except Exception as e:
        print(f"  ⚠ Warning: Could not verify model: {e}")

    print("\n✅ Ready to use with stable-ts!")


if __name__ == "__main__":
    main()
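
The converted checkpoint can also be loaded from Python instead of the CLI. A minimal sketch, assuming the stable-ts package (imported as stable_whisper) is installed; the audio and output file names are placeholders:

    import stable_whisper

    # load_model accepts a path to a .pt checkpoint as well as an official model name
    model = stable_whisper.load_model("./whisper-large-v3-hungarian.pt")
    result = model.transcribe("audio.mp3", language="hu")
    result.to_srt_vtt("output.srt")

stable_whisper.load_model hands the path through to openai-whisper's loader, which is why the file must carry the 'dims' and 'model_state_dict' entries this script writes.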