Reference Implementation of the Silero V5 and V6 VAD Models
import torch
import torch.nn as nn
import torch.nn.functional as F

class STFT(nn.Module):
    """Magnitude STFT implemented as a 1-D convolution with a fixed
    Fourier basis (filled in from the pretrained checkpoint below)."""

    def __init__(self, filter_length, hop_length):
        super(STFT, self).__init__()
        self.filter_length = filter_length
        self.hop_length = hop_length
        self.padding = nn.ReflectionPad1d(filter_length // 2)
        # Real and imaginary DFT filters stacked along the output channels:
        # (filter_length + 2) == 2 * (filter_length // 2 + 1) channels.
        self.register_buffer(
            "forward_basis_buffer", torch.zeros([filter_length + 2, 1, filter_length])
        )

    def forward(self, input_data):
        input_data = self.padding(input_data).unsqueeze(1)
        forward_transform = F.conv1d(
            input_data, self.forward_basis_buffer, stride=self.hop_length
        )
        cutoff = self.filter_length // 2 + 1
        # The first time frame is discarded; it overlaps the reflection padding.
        real_part = forward_transform[:, :cutoff, 1:]
        imag_part = forward_transform[:, cutoff:, 1:]
        magnitude = torch.sqrt(real_part.pow(2) + imag_part.pow(2))
        return magnitude
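
# For intuition only: the buffer above is loaded from the checkpoint further
# down, but a plain (unwindowed) DFT basis of the same shape could be built as
# sketched here. Whether Silero's checkpoint applies a window is NOT assumed,
# and the sign convention of the imaginary filters is irrelevant because only
# the magnitude is used. This helper is hypothetical and unused below.
def _example_fourier_basis(filter_length: int) -> torch.Tensor:
    # FFT of the identity matrix yields the DFT matrix row by row.
    fourier = torch.fft.rfft(torch.eye(filter_length))  # (n, n // 2 + 1)
    basis = torch.cat([fourier.real.T, fourier.imag.T], dim=0)  # (n + 2, n)
    return basis.unsqueeze(1)  # (n + 2, 1, n), matching forward_basis_buffer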

class Encoder(nn.Module):
    def __init__(self, input_feature_size, filter_length, hop_length):
        super(Encoder, self).__init__()
        self.feature_extractor = STFT(
            filter_length=filter_length, hop_length=hop_length
        )
        self.relu = nn.ReLU()
        # The two stride-2 layers collapse the 4 STFT frames produced per
        # chunk down to a single 128-dimensional embedding.
        self.conv_layers = nn.ModuleList(
            [
                nn.Conv1d(
                    in_channels=input_feature_size,
                    out_channels=128,
                    kernel_size=3,
                    padding=1,
                ),
                nn.Conv1d(
                    in_channels=128, out_channels=64, kernel_size=3, stride=2, padding=1
                ),
                nn.Conv1d(
                    in_channels=64, out_channels=64, kernel_size=3, stride=2, padding=1
                ),
                nn.Conv1d(in_channels=64, out_channels=128, kernel_size=3, padding=1),
            ]
        )

    def forward(self, x):
        x = self.feature_extractor(x)
        for conv_layer in self.conv_layers:
            x = self.relu(conv_layer(x))
        return x
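
# Shape contract, derivable from the layer parameters (16 kHz case): a batch
# of N chunks of 512 + 64 = 576 samples maps to one embedding per chunk:
#   (N, 576) -> STFT -> (N, 129, 4) -> conv stack (two stride-2 layers)
#            -> (N, 128, 1)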

class Decoder(nn.Module):
    def __init__(self):
        super(Decoder, self).__init__()
        # batch_first=False: the LSTM consumes (seq, batch, features), so the
        # chunk dimension acts as the sequence and state carries across chunks.
        self.rnn = nn.LSTM(input_size=128, hidden_size=128)
        self.dropout = nn.Dropout(p=0.1)
        self.relu = nn.ReLU()
        self.conv1d = nn.Conv1d(in_channels=128, out_channels=1, kernel_size=1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, input: torch.Tensor, h: torch.Tensor, c: torch.Tensor):
        input, (h, c) = self.rnn(input, (h, c))
        input = self.dropout(input)
        input = self.relu(input)
        # (seq, batch, 128) -> (batch, 128, seq) for the pointwise conv head.
        input = self.conv1d(input.permute(1, 2, 0))
        input = self.sigmoid(input)
        return input, (h, c)
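
# Minimal sanity check of the decoder's layout (untrained weights, shapes
# only): 10 chunk embeddings form an LSTM sequence of length 10 over a batch
# of 1, so the recurrent state threads across chunks, not across batch items.
_dec = Decoder().eval()
_probs, (_h, _c) = _dec(
    torch.randn(10, 1, 128),  # (seq=10 chunks, batch=1, features=128)
    torch.zeros(1, 1, 128),
    torch.zeros(1, 1, 128),
)
assert _probs.shape == (1, 1, 10)  # one probability per chunk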

class VadModel(nn.Module):
    def __init__(self, sampling_rate=16000):
        super(VadModel, self).__init__()
        assert sampling_rate in [16000, 8000], "Unsupported sampling rate"
        filter_length = int(sampling_rate / 62.5)  # 256 @ 16 kHz, 128 @ 8 kHz
        hop_length = filter_length // 2
        encoder_feature_size = hop_length + 1  # rFFT bins: 129 @ 16 kHz
        self.num_samples = filter_length * 2  # 512-sample (32 ms) chunks @ 16 kHz
        self.context_size_samples = filter_length // 4  # 64 samples @ 16 kHz
        self.encoder = Encoder(encoder_feature_size, filter_length, hop_length)
        self.decoder = Decoder()

    def get_initial_states(self, device: torch.device):
        h = torch.zeros((1, 1, 128), dtype=torch.float32, device=device)
        c = torch.zeros((1, 1, 128), dtype=torch.float32, device=device)
        return (h, c)

    def forward(self, x: torch.Tensor, h: torch.Tensor, c: torch.Tensor):
        x = self.encoder(x)
        # (chunks, 128, 1) -> (chunks, 1, 128): chunks become the LSTM sequence.
        x = x.permute(0, 2, 1)
        out, (h, c) = self.decoder(x, h, c)
        return out.reshape(-1), (h, c)

    def predict(self, x: torch.Tensor):
        assert x.ndim == 1, "Input should be a 1D tensor"
        assert (
            x.size(0) % self.num_samples == 0
        ), "Input size should be a multiple of num_samples"
        h, c = self.get_initial_states(device=x.device)
        x = x.reshape(-1, self.num_samples)
        # Prefix each chunk with the tail of the previous chunk; the first
        # chunk gets a zero context (zero the last row, then roll it to row 0).
        context = x[..., -self.context_size_samples :].clone()
        context[-1] = 0
        context = context.roll(1, 0)
        x = torch.cat([context, x], 1)
        return self.forward(x, h, c)
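
# End-to-end shape check on untrained weights (illustrative only; meaningful
# probabilities require the checkpoint loaded below):
_vad = VadModel(16000).eval()
with torch.no_grad():
    _p, (_hn, _cn) = _vad.predict(torch.zeros(_vad.num_samples * 4))
assert _p.shape == (4,)  # one speech probability per 512-sample chunk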

# Download the original TorchScript model from the Silero VAD repo so its
# weights can be copied into the re-implementation above.
silero_model, utils = torch.hub.load(
    repo_or_dir="snakers4/silero-vad:v5.0",
    model="silero_vad",
)
new_model = VadModel(16000)

# Map this module's parameter names to the corresponding entries in the JIT
# checkpoint's state dict (16 kHz path).
dict_mapping_16k = {
    "encoder.feature_extractor.forward_basis_buffer": "_model.stft.forward_basis_buffer",
    "encoder.conv_layers.0.weight": "_model.encoder.0.reparam_conv.weight",
    "encoder.conv_layers.0.bias": "_model.encoder.0.reparam_conv.bias",
    "encoder.conv_layers.1.weight": "_model.encoder.1.reparam_conv.weight",
    "encoder.conv_layers.1.bias": "_model.encoder.1.reparam_conv.bias",
    "encoder.conv_layers.2.weight": "_model.encoder.2.reparam_conv.weight",
    "encoder.conv_layers.2.bias": "_model.encoder.2.reparam_conv.bias",
    "encoder.conv_layers.3.weight": "_model.encoder.3.reparam_conv.weight",
    "encoder.conv_layers.3.bias": "_model.encoder.3.reparam_conv.bias",
    "decoder.rnn.weight_ih_l0": "_model.decoder.rnn.weight_ih",
    "decoder.rnn.weight_hh_l0": "_model.decoder.rnn.weight_hh",
    "decoder.rnn.bias_ih_l0": "_model.decoder.rnn.bias_ih",
    "decoder.rnn.bias_hh_l0": "_model.decoder.rnn.bias_hh",
    "decoder.conv1d.weight": "_model.decoder.decoder.2.weight",
    "decoder.conv1d.bias": "_model.decoder.decoder.2.bias",
}
silero_state_dict = silero_model.state_dict()
new_model.load_state_dict(
    {key: silero_state_dict[source] for key, source in dict_mapping_16k.items()}
)
new_model = new_model.eval()
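
# Optional parity check against the original JIT model. This assumes the JIT
# module's (x, sr) call signature and its reset_states() helper, both part of
# the snakers4/silero-vad hub model; verify against your installed version
# before relying on it:
# silero_model.reset_states()
# _chunk = torch.randn(512)
# _ref = silero_model(_chunk, 16000)
# _ours, _ = new_model.predict(_chunk)
# torch.testing.assert_close(_ours, _ref.reshape(-1), atol=1e-5, rtol=1e-4)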

# Export with a dynamic chunk dimension; the example input is 10 chunks of
# 512 samples plus the 64-sample context (576 columns total).
torch.onnx.export(
    new_model,
    kwargs={
        "x": torch.randn(10, 576, dtype=torch.float32),
        "h": torch.randn(1, 1, 128, dtype=torch.float32),
        "c": torch.randn(1, 1, 128, dtype=torch.float32),
    },
    f="silero_vad_v5.onnx",
    input_names=["input", "h", "c"],
    output_names=["speech_probs", "hn", "cn"],
    dynamic_axes={"input": {0: "seq_len"}},
    dynamo=False,  # use the TorchScript-based exporter
)
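
# Optional smoke test of the exported graph. A minimal sketch assuming the
# onnxruntime package is installed (it is not used elsewhere in this script);
# the tensor names are the input_names/output_names passed to the export above.
try:
    import numpy as np
    import onnxruntime as ort

    sess = ort.InferenceSession("silero_vad_v5.onnx")
    probs, hn, cn = sess.run(
        ["speech_probs", "hn", "cn"],
        {
            "input": np.zeros((10, 576), dtype=np.float32),
            "h": np.zeros((1, 1, 128), dtype=np.float32),
            "c": np.zeros((1, 1, 128), dtype=np.float32),
        },
    )
    print(probs.shape, hn.shape, cn.shape)  # (10,) (1, 1, 128) (1, 1, 128)
except ImportError:
    pass  # onnxruntime not installed; skip the check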