@MahmoudAshraf97
Created February 19, 2025 12:36
Silero V5 for batched inference
from typing import Tuple

import line_profiler
import torch
import torch.nn as nn


class Encoder(nn.Module):
    def __init__(self, feature_size, filter_length, hop_length):
        super(Encoder, self).__init__()
        self._n_fft = filter_length
        self._hop_length = hop_length
        self._window = nn.Parameter(
            torch.hann_window(filter_length), requires_grad=False
        )
        self.relu = nn.ReLU()
        self.conv_layers = nn.ModuleList(
            [
                nn.Conv1d(
                    in_channels=feature_size, out_channels=128, kernel_size=3, padding=1
                ),
                nn.Conv1d(
                    in_channels=128, out_channels=64, kernel_size=3, stride=2, padding=1
                ),
                nn.Conv1d(
                    in_channels=64, out_channels=64, kernel_size=3, stride=2, padding=1
                ),
                nn.Conv1d(in_channels=64, out_channels=128, kernel_size=3, padding=1),
            ]
        )

    def forward(self, x):
        # Magnitude spectrogram, dropping the first STFT frame.
        x = torch.stft(
            x,
            n_fft=self._n_fft,
            hop_length=self._hop_length,
            window=self._window,
            return_complex=True,
        )[:, :, 1:].abs()
        for conv_layer in self.conv_layers:
            x = self.relu(conv_layer(x))
        return x


class Decoder(nn.Module):
    def __init__(self):
        super(Decoder, self).__init__()
        self.dropout = nn.Dropout(p=0.1)
        self.relu = nn.ReLU()
        self.conv1d = nn.Conv1d(in_channels=128, out_channels=1, kernel_size=1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, input) -> torch.Tensor:
        input = self.dropout(input)
        input = self.relu(input)
        input = self.conv1d(input.unsqueeze(-1))
        input = self.sigmoid(input)
        return input


class VadModel(nn.Module):
    def __init__(self, sampling_rate=16000):
        super(VadModel, self).__init__()
        assert sampling_rate in [16000, 8000], "Unsupported sampling rate"
        filter_length = int(sampling_rate / 62.5)
        hop_length = filter_length // 2
        encoder_feature_size = hop_length + 1
        self.num_samples = filter_length * 2
        self.context_size_samples = filter_length // 4
        self.encoder = Encoder(encoder_feature_size, filter_length, hop_length)
        self.lstm = nn.LSTM(input_size=128, hidden_size=128)
        self.decoder = Decoder()

    # @torch.inference_mode()
    @line_profiler.profile
    def forward(
        self,
        inputs: torch.Tensor,
        states: Tuple[torch.Tensor, torch.Tensor],
        input_lengths: list[int],
    ):
        # One 128-dim encoder vector per chunk; regroup the flat chunk batch
        # into per-stream sequences and pad them for the LSTM.
        encoder_output = self.encoder(inputs).squeeze(-1)
        encoder_output = torch.nn.utils.rnn.pad_sequence(
            encoder_output.split(input_lengths), batch_first=True
        )
        packed_encoder_output = torch.nn.utils.rnn.pack_padded_sequence(
            encoder_output,
            enforce_sorted=False,
            lengths=input_lengths,
            batch_first=True,
        )
        lstm_output, states = self.lstm(packed_encoder_output, states)
        unpacked_lstm_output = torch.nn.utils.rnn.unpack_sequence(lstm_output)
        output = self.decoder(torch.cat(unpacked_lstm_output, dim=0)).view(-1)
        # Return speech probabilities split back per stream, plus the updated
        # LSTM states (h, c) moved to CPU.
        return (
            output.split(input_lengths),
            torch.cat(states).cpu().split(1),
        )


# Maps this module's parameter names to the corresponding keys in the
# official Silero VAD V5 checkpoint (16 kHz branch).
dict_mapping_16k = {
    "encoder.conv_layers.0.weight": "_model.encoder.0.reparam_conv.weight",
    "encoder.conv_layers.0.bias": "_model.encoder.0.reparam_conv.bias",
    "encoder.conv_layers.1.weight": "_model.encoder.1.reparam_conv.weight",
    "encoder.conv_layers.1.bias": "_model.encoder.1.reparam_conv.bias",
    "encoder.conv_layers.2.weight": "_model.encoder.2.reparam_conv.weight",
    "encoder.conv_layers.2.bias": "_model.encoder.2.reparam_conv.bias",
    "encoder.conv_layers.3.weight": "_model.encoder.3.reparam_conv.weight",
    "encoder.conv_layers.3.bias": "_model.encoder.3.reparam_conv.bias",
    "lstm.weight_ih_l0": "_model.decoder.rnn.weight_ih",
    "lstm.weight_hh_l0": "_model.decoder.rnn.weight_hh",
    "lstm.bias_ih_l0": "_model.decoder.rnn.bias_ih",
    "lstm.bias_hh_l0": "_model.decoder.rnn.bias_hh",
    "decoder.conv1d.weight": "_model.decoder.decoder.2.weight",
    "decoder.conv1d.bias": "_model.decoder.decoder.2.bias",
}

# The 8 kHz branch uses the same layout under the "_model_8k" prefix.
dict_mapping_8k = {
    k: v.replace("_model.", "_model_8k.") for k, v in dict_mapping_16k.items()
}
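
A minimal usage sketch, not part of the original gist: it assumes that the jit model returned by torch.hub.load("snakers4/silero-vad", "silero_vad") exposes a state_dict with the "_model.*" keys referenced above, and that each chunk fed to the model is num_samples + context_size_samples samples long (512 + 64 at 16 kHz). Both assumptions should be verified against the installed silero-vad version.

if __name__ == "__main__":
    # Hedged sketch: remap the official Silero VAD V5 weights onto VadModel
    # via dict_mapping_16k. strict=False because the Hann window buffer is
    # built locally and has no counterpart in the mapping.
    jit_model, _ = torch.hub.load("snakers4/silero-vad", "silero_vad")
    jit_state = jit_model.state_dict()
    model = VadModel(sampling_rate=16000).eval()
    model.load_state_dict(
        {ours: jit_state[theirs] for ours, theirs in dict_mapping_16k.items()},
        strict=False,
    )

    # Two audio streams with 3 and 2 chunks respectively, flattened into one
    # batch; chunk length is an assumption, see the note above.
    input_lengths = [3, 2]
    chunk_len = model.num_samples + model.context_size_samples
    inputs = torch.randn(sum(input_lengths), chunk_len)

    # Zero-initialised LSTM states: (h, c), each (num_layers=1, num_streams, 128).
    states = (
        torch.zeros(1, len(input_lengths), 128),
        torch.zeros(1, len(input_lengths), 128),
    )

    with torch.inference_mode():
        probs_per_stream, (h, c) = model(inputs, states, input_lengths)

    for i, probs in enumerate(probs_per_stream):
        print(f"stream {i}: speech probabilities {probs.tolist()}")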