Created March 12, 2025 19:09
Save samuellangajr/8dded26f4728f9a458aa0d80214f5716 to your computer and use it in GitHub Desktop.
Descreva e implemente uma estratégia para lidar com um conjunto de dados de treinamento desbalanceado, onde algumas classes de veículos são muito mais frequentes que outras.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np
import torch
import torch.nn as nn
from imblearn.over_sampling import SMOTE
from sklearn.utils.class_weight import compute_class_weight
# Example labels: class 0 has 1000 samples and class 1 has 100 (a 10:1 imbalance).
y_train = np.array([0] * 1000 + [1] * 100)

# 1. Class weights
# Weight each class inversely to its frequency so the loss is not dominated by the
# majority class. With "balanced", weight_c = n_samples / (n_classes * count_c);
# here that gives 1100 / (2 * 1000) = 0.55 for class 0 and 1100 / (2 * 100) = 5.5
# for class 1.
# NOTE: np.unique returns only the classes that actually occur in y_train, so the
# weight vector has one entry per *observed* class. Its length must equal the number
# of logits the model outputs, so every class has to appear in the training labels.
class_weights = compute_class_weight("balanced", classes=np.unique(y_train), y=y_train)
class_weights = torch.tensor(class_weights, dtype=torch.float32)

# Cross-entropy loss that scales each sample's contribution by its class weight.
criterion = nn.CrossEntropyLoss(weight=class_weights)

# Stand-in for model outputs: random logits, one row per sample, one column per class.
# The shape is derived from the data instead of being hard-coded. A dedicated, seeded
# generator keeps the printed loss reproducible without touching torch's global RNG.
num_samples = len(y_train)
num_classes = len(class_weights)
generator = torch.Generator().manual_seed(42)
y_pred = torch.randn(num_samples, num_classes, generator=generator)

# CrossEntropyLoss expects int64 (Long) class indices. A NumPy array built from Python
# ints can be int32 on some platforms (e.g. Windows with NumPy < 2), so cast explicitly.
y_true = torch.as_tensor(y_train, dtype=torch.long)

# Compute the class-weighted loss.
loss = criterion(y_pred, y_true)
print(f"Perda com pesos nas classes: {loss.item()}")
# 2. SMOTE (Synthetic Minority Over-sampling Technique)
# Instead of re-weighting the loss, generate synthetic minority-class samples so the
# classes become balanced in the data itself.
# NOTE(review): apply resampling only to the training split, after the
# train/validation/test split, so synthetic points never leak into evaluation data.
# Combining SMOTE with the class-weighted loss in section 1 corrects the imbalance
# twice, so usually only one of the two strategies is used.

# Placeholder features for an imbalanced dataset. The generator is seeded so the
# input, and therefore the resampled output, is reproducible, matching random_state=42.
rng = np.random.default_rng(42)
X = rng.standard_normal((1100, 30))  # 1100 samples, 30 features
y = np.array([0] * 1000 + [1] * 100)

# sampling_strategy="auto" oversamples every class except the majority one;
# random_state fixes how the synthetic samples are drawn.
smote = SMOTE(sampling_strategy="auto", random_state=42)
X_resampled, y_resampled = smote.fit_resample(X, y)

# np.count_nonzero counts matches in C instead of iterating the array in Python.
print(
    f"Original dataset: {len(y)} classes 0: {np.count_nonzero(y == 0)}, "
    f"classes 1: {np.count_nonzero(y == 1)}"
)
print(
    f"Resampled dataset: {len(y_resampled)} classes 0: {np.count_nonzero(y_resampled == 0)}, "
    f"classes 1: {np.count_nonzero(y_resampled == 1)}"
)
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment