Conversion script for MMS-300 Alignment model
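"""
Convert Meta's multilingual MMS CTC forced-alignment checkpoint (torchaudio
`wav2vec2_model` format) into a Hugging Face `Wav2Vec2ForCTC` model plus a
matching `Wav2Vec2CTCTokenizer`.
"""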
import os
import torch
from torchaudio.models import wav2vec2_model
from transformers import Wav2Vec2ForCTC, Wav2Vec2CTCTokenizer, AutoConfig
import json
import argparse
import tempfile
# Initialize parser
parser = argparse.ArgumentParser()
parser.add_argument(
    "--torch_model_dir",
    help="Directory of the torch model, will be downloaded if it doesn't exist",
    required=True,
)
parser.add_argument(
    "--hf_model_dir", help="Directory to save the converted model", required=True
)
args = parser.parse_args()
def load_model_dict(model_dir: str):
    """Download (if needed) and load the torchaudio alignment model and its dictionary."""
    os.makedirs(model_dir, exist_ok=True)

    model_path_name = os.path.join(model_dir, "ctc_alignment_mling_uroman_model.pt")
    if os.path.exists(model_path_name):
        print("Model path already exists. Skipping download...")
    else:
        print("Downloading model...")
        torch.hub.download_url_to_file(
            "https://dl.fbaipublicfiles.com/mms/torchaudio/ctc_alignment_mling_uroman/model.pt",
            model_path_name,
        )
    assert os.path.exists(model_path_name)
    state_dict = torch.load(model_path_name, map_location="cpu")

    # Architecture of the 300M-parameter MMS alignment model, mirroring the
    # configuration the checkpoint was trained with.
    model = wav2vec2_model(
        extractor_mode="layer_norm",
        extractor_conv_layer_config=[
            (512, 10, 5),
            (512, 3, 2),
            (512, 3, 2),
            (512, 3, 2),
            (512, 3, 2),
            (512, 2, 2),
            (512, 2, 2),
        ],
        extractor_conv_bias=True,
        encoder_embed_dim=1024,
        encoder_projection_dropout=0.0,
        encoder_pos_conv_kernel=128,
        encoder_pos_conv_groups=16,
        encoder_num_layers=24,
        encoder_num_heads=16,
        encoder_attention_dropout=0.0,
        encoder_ff_interm_features=4096,
        encoder_ff_interm_dropout=0.1,
        encoder_dropout=0.0,
        encoder_layer_norm_first=True,
        encoder_layer_drop=0.1,
        aux_num_out=31,
    )
    model.load_state_dict(state_dict)
    model.eval()

    dict_path_name = os.path.join(model_dir, "ctc_alignment_mling_uroman_model.dict")
    if os.path.exists(dict_path_name):
        print("Dictionary path already exists. Skipping download...")
    else:
        print("Downloading dictionary...")
        torch.hub.download_url_to_file(
            "https://dl.fbaipublicfiles.com/mms/torchaudio/ctc_alignment_mling_uroman/dictionary.txt",
            dict_path_name,
        )
    assert os.path.exists(dict_path_name)

    # Map each dictionary token to its index (line number in dictionary.txt).
    with open(dict_path_name) as f:
        dictionary = {line.strip(): i for i, line in enumerate(f)}

    return model, dictionary
torch_model, dictionary = load_model_dict(args.torch_model_dir)
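# Start from the MMS-300M config and disable all dropout, layerdrop, and time
# masking so the converted model behaves deterministically at inference time;
# vocab_size=31 matches the alignment head (aux_num_out=31) above.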
config = AutoConfig.from_pretrained("facebook/mms-300m")
config.attention_dropout = 0.0
config.hidden_dropout = 0.0
config.layerdrop = 0.0
config.feat_proj_dropout = 0.0
config.mask_time_prob = 0.0
config.vocab_size = 31
# config.feat_extract_activation = "linear"
# config.hidden_act = "linear"
hf_model = Wav2Vec2ForCTC._from_config(config)
# for conv_layer in hf_model.wav2vec2.feature_extractor.conv_layers:
# conv_layer.activation = torch.nn.Identity()
# hf_model.wav2vec2.encoder.pos_conv_embed.activation = torch.nn.Identity()
# for enc_layer in hf_model.wav2vec2.encoder.layers:
# enc_layer.feed_forward.intermediate_act_fn = torch.nn.Identity()
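# torchaudio and Hugging Face use different parameter names for the same
# architecture; rewrite each torchaudio key so it lands on its HF counterpart.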
hf_layer_names = list(hf_model.state_dict().keys())
torch_layer_names = list(torch_model.state_dict().keys())
torch_layer_names = ["wav2vec2." + n for n in torch_layer_names]
torch_layer_names = [
    n.replace("encoder.feature_projection.", "feature_projection.")
    for n in torch_layer_names
]
torch_layer_names = [
    n.replace("wav2vec2.encoder.transformer.", "wav2vec2.encoder.")
    for n in torch_layer_names
]
torch_layer_names = [n.replace("wav2vec2.aux.", "lm_head.") for n in torch_layer_names]
torch_to_hf_mapping = dict(
    zip(list(torch_model.state_dict().keys()), torch_layer_names)
)
# Build the renamed state dict, copy it into the HF model's tensors in place,
# then verify that every tensor matches exactly.
new_state_dict = {}
for k, v in torch_to_hf_mapping.items():
    new_state_dict[v] = torch_model.state_dict()[k]
for key in hf_model.state_dict().keys():
    hf_model.state_dict()[key].copy_(new_state_dict[key])
for key in hf_model.state_dict().keys():
    assert torch.all(hf_model.state_dict()[key] == new_state_dict[key])
os.makedirs(args.hf_model_dir, exist_ok=True)
hf_model.save_pretrained(args.hf_model_dir)

# Build a CTC tokenizer from the dictionary: write the vocab to a temporary
# JSON file, then mark the first four dictionary entries as special tokens.
with tempfile.NamedTemporaryFile(mode="w", suffix=".json") as tf:
    json.dump(dictionary, tf)
    tf.flush()
    tokenizer = Wav2Vec2CTCTokenizer(tf.name, bos_token=None)
for i in range(4):
    tokenizer._added_tokens_decoder[i].special = True
tokenizer.save_pretrained(args.hf_model_dir)
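A minimal sketch of how the converted artifacts could be loaded afterwards; the output directory name and the dummy waveform are illustrative assumptions, not part of the original script:

import torch
from transformers import Wav2Vec2ForCTC, Wav2Vec2CTCTokenizer

# Hypothetical invocation of the script above:
#   python convert.py --torch_model_dir ./mms_torch --hf_model_dir ./mms_hf
model = Wav2Vec2ForCTC.from_pretrained("./mms_hf")
tokenizer = Wav2Vec2CTCTokenizer.from_pretrained("./mms_hf")
model.eval()

# Placeholder mono 16 kHz audio; real use would load normalized speech here.
waveform = torch.randn(1, 16000)
with torch.no_grad():
    logits = model(waveform).logits  # shape: (1, num_frames, 31)
frame_tokens = logits.argmax(dim=-1)[0]
print(tokenizer.decode(frame_tokens))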