Skip to content

Instantly share code, notes, and snippets.

@samuellangajr
Created March 12, 2025 19:03
Show Gist options
  • Select an option

  • Save samuellangajr/844f1135166effb3e1fcebddc57aeac0 to your computer and use it in GitHub Desktop.

Select an option

Save samuellangajr/844f1135166effb3e1fcebddc57aeac0 to your computer and use it in GitHub Desktop.
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune
import torch.quantization
# Exemplo de um modelo simples (CNN)
class SimpleCNN(nn.Module):
def __init__(self):
super(SimpleCNN, self).__init__()
self.conv = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, stride=1, padding=1)
self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
# Supondo uma imagem de entrada 224x224, após pooling a dimensão é 112x112
self.fc = nn.Linear(16 * 112 * 112, 10) # 10 classes de veículos, por exemplo
def forward(self, x):
x = self.pool(torch.relu(self.conv(x)))
x = x.view(x.size(0), -1)
x = self.fc(x)
return x
# Instanciar o modelo e colocá-lo em modo de avaliação
model = SimpleCNN()
model.eval()
# Aplicar pruning na camada convolucional
prune.random_unstructured(model.conv, name="weight", amount=0.3) # remove 30% dos pesos
# Exibir a estrutura do modelo com pruning aplicado
print("Modelo após pruning:")
print(model)
# Aplicar quantização dinâmica para a camada totalmente conectada
model_quantized = torch.quantization.quantize_dynamic(
model, {nn.Linear}, dtype=torch.qint8
)
print("\nModelo após quantização dinâmica:")
print(model_quantized)
# Agora, o modelo_quantized está otimizado para inferência em dispositivos com recursos limitados.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment