Created
March 12, 2025 19:03
-
-
Save samuellangajr/844f1135166effb3e1fcebddc57aeac0 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import torch | |
| import torch.nn as nn | |
| import torch.nn.utils.prune as prune | |
| import torch.quantization | |
| # Exemplo de um modelo simples (CNN) | |
| class SimpleCNN(nn.Module): | |
| def __init__(self): | |
| super(SimpleCNN, self).__init__() | |
| self.conv = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, stride=1, padding=1) | |
| self.pool = nn.MaxPool2d(kernel_size=2, stride=2) | |
| # Supondo uma imagem de entrada 224x224, após pooling a dimensão é 112x112 | |
| self.fc = nn.Linear(16 * 112 * 112, 10) # 10 classes de veículos, por exemplo | |
| def forward(self, x): | |
| x = self.pool(torch.relu(self.conv(x))) | |
| x = x.view(x.size(0), -1) | |
| x = self.fc(x) | |
| return x | |
| # Instanciar o modelo e colocá-lo em modo de avaliação | |
| model = SimpleCNN() | |
| model.eval() | |
| # Aplicar pruning na camada convolucional | |
| prune.random_unstructured(model.conv, name="weight", amount=0.3) # remove 30% dos pesos | |
| # Exibir a estrutura do modelo com pruning aplicado | |
| print("Modelo após pruning:") | |
| print(model) | |
| # Aplicar quantização dinâmica para a camada totalmente conectada | |
| model_quantized = torch.quantization.quantize_dynamic( | |
| model, {nn.Linear}, dtype=torch.qint8 | |
| ) | |
| print("\nModelo após quantização dinâmica:") | |
| print(model_quantized) | |
| # Agora, o modelo_quantized está otimizado para inferência em dispositivos com recursos limitados. |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment