Skip to content

Instantly share code, notes, and snippets.

@ilyakam
Forked from MahmoudAshraf97/silero_v5.py
Created October 21, 2025 18:17
Show Gist options
  • Select an option

  • Save ilyakam/6fa186e9f22bb721a5f5caee5734661a to your computer and use it in GitHub Desktop.

Select an option

Save ilyakam/6fa186e9f22bb721a5f5caee5734661a to your computer and use it in GitHub Desktop.
Reference Implementation of Silero V5 and V6 VAD model
import torch
import torch.nn as nn
import torch.nn.functional as F
class STFT(nn.Module):
    """Magnitude short-time Fourier transform computed with a strided 1-D convolution.

    The transform basis lives in the ``forward_basis_buffer`` buffer, shaped
    ``(filter_length + 2, 1, filter_length)``. It is created as zeros here and
    is meant to be filled in afterwards (e.g. by ``load_state_dict``); until
    then the module outputs zeros.

    Args:
        filter_length: Width of each basis kernel (analysis window), in samples.
        hop_length: Convolution stride between consecutive frames, in samples.
    """

    def __init__(self, filter_length, hop_length):
        super().__init__()
        self.filter_length = filter_length
        self.hop_length = hop_length
        # Reflect half a window of samples on each side of the signal.
        self.padding = nn.ReflectionPad1d(filter_length // 2)
        self.register_buffer(
            "forward_basis_buffer",
            torch.zeros([filter_length + 2, 1, filter_length]),
        )

    def forward(self, input_data):
        """Return the magnitude spectrogram of ``input_data``.

        Args:
            input_data: Tensor whose last dimension is the sample axis. A
                channel dimension is inserted at position 1 before the
                convolution, so a ``(batch, samples)`` input is expected.

        Returns:
            Tensor with ``filter_length // 2 + 1`` channels on dimension 1 and
            one fewer frame on the last dimension than the convolution yields.
        """
        padded = self.padding(input_data).unsqueeze(1)
        forward_transform = F.conv1d(
            padded, self.forward_basis_buffer, stride=self.hop_length
        )
        # The first half of the channels is treated as the real part and the
        # second half as the imaginary part.
        cutoff = self.filter_length // 2 + 1
        # The first frame is dropped (``1:``) -- presumably because it is the
        # one dominated by the left reflection padding; confirm against the
        # reference model before changing.
        real_part = forward_transform[:, :cutoff, 1:]
        imag_part = forward_transform[:, cutoff:, 1:]
        return torch.sqrt(real_part.pow(2) + imag_part.pow(2))
class Encoder(nn.Module):
    """STFT front end followed by four ReLU-activated 1-D convolutions.

    Args:
        input_feature_size: Number of input channels of the first convolution;
            must equal the number of channels produced by the STFT.
        filter_length: Passed through to ``STFT``.
        hop_length: Passed through to ``STFT``.
    """

    def __init__(self, input_feature_size, filter_length, hop_length):
        super().__init__()
        self.feature_extractor = STFT(
            filter_length=filter_length, hop_length=hop_length
        )
        self.relu = nn.ReLU()
        # (in_channels, out_channels, stride) per layer; every layer uses
        # kernel_size=3 and padding=1. The two stride-2 layers shorten the
        # last (frame) axis.
        layer_specs = (
            (input_feature_size, 128, 1),
            (128, 64, 2),
            (64, 64, 2),
            (64, 128, 1),
        )
        self.conv_layers = nn.ModuleList(
            [
                nn.Conv1d(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    kernel_size=3,
                    stride=stride,
                    padding=1,
                )
                for in_channels, out_channels, stride in layer_specs
            ]
        )

    def forward(self, x):
        """Return the encoded features, with 128 channels on dimension 1."""
        features = self.feature_extractor(x)
        for conv_layer in self.conv_layers:
            features = self.relu(conv_layer(features))
        return features
class Decoder(nn.Module):
    """Single-layer LSTM followed by a 1x1 convolution and a sigmoid.

    The layers are applied in this order: LSTM -> dropout (p=0.1) -> ReLU ->
    ``Conv1d(128, 1, kernel_size=1)`` -> sigmoid.
    """

    def __init__(self):
        super().__init__()
        self.rnn = nn.LSTM(input_size=128, hidden_size=128)
        self.dropout = nn.Dropout(p=0.1)
        self.relu = nn.ReLU()
        self.conv1d = nn.Conv1d(in_channels=128, out_channels=1, kernel_size=1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, input: torch.Tensor, h: torch.Tensor, c: torch.Tensor):
        """Run the decoder over a sequence.

        Args:
            input: LSTM input. The LSTM is built with the default
                ``batch_first=False``, so a 3-D input is laid out as
                ``(seq_len, batch, 128)``.
            h: Initial LSTM hidden state.
            c: Initial LSTM cell state.

        Returns:
            ``(probs, (h, c))`` where ``probs`` is the sigmoid output of the
            1x1 convolution, shaped ``(batch, 1, seq_len)`` for a 3-D input,
            and ``(h, c)`` is the final LSTM state.
        """
        # ``input`` is kept as the parameter name for keyword callers; it is
        # deliberately not rebound because it shadows the builtin.
        sequence, (h, c) = self.rnn(input, (h, c))
        sequence = self.relu(self.dropout(sequence))
        # Conv1d wants channels on dimension 1: (seq, batch, 128) -> (batch, 128, seq).
        probs = self.sigmoid(self.conv1d(sequence.permute(1, 2, 0)))
        return probs, (h, c)
class VadModel(nn.Module):
    """Encoder/decoder voice-activity model operating on fixed-size chunks.

    Sizes derived from ``sampling_rate`` (16000 / 8000):

    * STFT filter length: 256 / 128 samples
    * STFT hop length: 128 / 64 samples
    * ``num_samples`` (chunk size): 512 / 256 samples
    * ``context_size_samples``: 64 / 32 samples

    Args:
        sampling_rate: Either 16000 or 8000.

    Raises:
        AssertionError: If ``sampling_rate`` is not supported.
    """

    def __init__(self, sampling_rate=16000):
        super().__init__()
        assert sampling_rate in [16000, 8000], "Unsupported sampling rate"
        filter_length = int(sampling_rate / 62.5)
        hop_length = filter_length // 2
        encoder_feature_size = hop_length + 1
        self.num_samples = filter_length * 2
        self.context_size_samples = filter_length // 4
        self.encoder = Encoder(encoder_feature_size, filter_length, hop_length)
        self.decoder = Decoder()

    def get_initial_states(self, device: torch.device):
        """Return fresh all-zero ``(h, c)`` float32 states of shape ``(1, 1, 128)``."""
        h = torch.zeros((1, 1, 128), dtype=torch.float32, device=device)
        c = torch.zeros((1, 1, 128), dtype=torch.float32, device=device)
        return (h, c)

    def forward(self, x: torch.Tensor, h: torch.Tensor, c: torch.Tensor):
        """Encode ``x``, decode it with state ``(h, c)`` and flatten the output.

        Returns:
            ``(probs, (h, c))`` where ``probs`` is the decoder output reshaped
            to 1-D and ``(h, c)`` is the updated decoder state.
        """
        encoded = self.encoder(x)
        # Swap the last two dimensions into the layout the decoder is fed.
        encoded = encoded.permute(0, 2, 1)
        out, (h, c) = self.decoder(encoded, h, c)
        return out.reshape(-1), (h, c)

    def _frame_with_context(self, x: torch.Tensor):
        """Split 1-D ``x`` into chunks, each prefixed with the previous chunk's tail.

        The first chunk has no predecessor and is prefixed with zeros.
        """
        chunks = x.reshape(-1, self.num_samples)
        ctx = self.context_size_samples
        context = torch.cat([chunks.new_zeros(1, ctx), chunks[:-1, -ctx:]], dim=0)
        return torch.cat([context, chunks], dim=1)

    def predict(self, x: torch.Tensor):
        """Run the model over a whole 1-D signal starting from zero states.

        Args:
            x: 1-D tensor whose length is a multiple of ``num_samples``.

        Returns:
            ``(probs, (h, c))`` as returned by ``forward``, with one value in
            ``probs`` per chunk. A zero-length input yields an empty ``probs``
            and the untouched initial states.

        Raises:
            AssertionError: If ``x`` is not 1-D or its length is not a
                multiple of ``num_samples``.
        """
        assert x.ndim == 1, "Input should be a 1D tensor"
        assert (
            x.size(0) % self.num_samples == 0
        ), "Input size should be a multiple of num_samples"
        h, c = self.get_initial_states(device=x.device)
        if x.size(0) == 0:
            # No chunks to score; indexing the last context row would fail.
            empty = torch.zeros(0, dtype=torch.float32, device=x.device)
            return empty, (h, c)
        return self.forward(self._frame_with_context(x), h, c)
# Fetch the published Silero VAD v5.0 model through torch.hub; its weights are
# copied into the reference implementation below.
silero_model, utils = torch.hub.load(
    repo_or_dir="snakers4/silero-vad:v5.0",
    model="silero_vad",
)
new_model = VadModel(16000)
# Maps each state-dict key of ``VadModel`` to the corresponding key in the
# hub model's state dict (16 kHz branch).
dict_mapping_16k = {
    "encoder.feature_extractor.forward_basis_buffer": "_model.stft.forward_basis_buffer",
    "encoder.conv_layers.0.weight": "_model.encoder.0.reparam_conv.weight",
    "encoder.conv_layers.0.bias": "_model.encoder.0.reparam_conv.bias",
    "encoder.conv_layers.1.weight": "_model.encoder.1.reparam_conv.weight",
    "encoder.conv_layers.1.bias": "_model.encoder.1.reparam_conv.bias",
    "encoder.conv_layers.2.weight": "_model.encoder.2.reparam_conv.weight",
    "encoder.conv_layers.2.bias": "_model.encoder.2.reparam_conv.bias",
    "encoder.conv_layers.3.weight": "_model.encoder.3.reparam_conv.weight",
    "encoder.conv_layers.3.bias": "_model.encoder.3.reparam_conv.bias",
    "decoder.rnn.weight_ih_l0": "_model.decoder.rnn.weight_ih",
    "decoder.rnn.weight_hh_l0": "_model.decoder.rnn.weight_hh",
    "decoder.rnn.bias_ih_l0": "_model.decoder.rnn.bias_ih",
    "decoder.rnn.bias_hh_l0": "_model.decoder.rnn.bias_hh",
    "decoder.conv1d.weight": "_model.decoder.decoder.2.weight",
    "decoder.conv1d.bias": "_model.decoder.decoder.2.bias",
}
# Build the source state dict once instead of once per mapped key.
silero_state_dict = silero_model.state_dict()
new_model.load_state_dict(
    {
        target_key: silero_state_dict[source_key]
        for target_key, source_key in dict_mapping_16k.items()
    }
)
new_model = new_model.eval()
# Each exported row is one chunk prefixed with its context samples
# (512 + 64 = 576 for the 16 kHz model).
export_frame_size = new_model.num_samples + new_model.context_size_samples
torch.onnx.export(
    new_model,
    kwargs={
        "x": torch.randn(10, export_frame_size, dtype=torch.float32),
        "h": torch.randn(1, 1, 128, dtype=torch.float32),
        "c": torch.randn(1, 1, 128, dtype=torch.float32),
    },
    f="silero_vad_v5.onnx",
    input_names=["input", "h", "c"],
    output_names=["speech_probs", "hn", "cn"],
    # Only the number of chunks (dimension 0 of the input) is dynamic.
    dynamic_axes={"input": {0: "seq_len"}},
    dynamo=False,
)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment