GraphBench - Electronic Circuits Example

Trains a GIN-based graph regressor on the electronic_circuits_5_eff task and scores its test predictions with the GraphBench electronic-circuit evaluator.
import copy

import torch
import graphbench
import torch_geometric
import torch.nn.functional as F
from tqdm import tqdm
from torch_geometric.nn import GINConv, global_mean_pool
from torch.nn import Sequential, Linear, ReLU, BatchNorm1d
### ========= Load Dataset and Setup GraphBench Components ========= ###
dataset_name = "electronic_circuits_5_eff"  # name of the task or list of tasks
evaluator_name = "electroniccircuit"

# Use a relative path to the data directory (from the project root)
data_path = "./data"

# Set up the GraphBench components
Evaluator = graphbench.Evaluator(evaluator_name)
Loader = graphbench.Loader(data_path, dataset_name)

# Load a GraphBench dataset and get its splits
dataset = Loader.load()
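
# Optional sanity check: the indexing used throughout this script assumes
# Loader.load() returns a list of tasks, each a dict of PyG Data lists keyed
# by 'train'/'valid'/'test'. Print the split sizes to confirm that layout.
print({split: len(graphs) for split, graphs in dataset[0].items()})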
### ========================== Model Setup ========================== ###
class GINGraphRegressor(torch.nn.Module):
    def __init__(self, in_channels, hidden=64, out_channels=1, num_layers=3, dropout=0.5):
        super().__init__()
        self.convs = torch.nn.ModuleList()
        for i in range(num_layers):
            mlp = Sequential(
                Linear(in_channels if i == 0 else hidden, hidden),
                BatchNorm1d(hidden),
                ReLU(),
                Linear(hidden, hidden),
                ReLU(),
            )
            self.convs.append(GINConv(mlp))
        # hidden + 1: the scalar duty cycle is concatenated onto the pooled graph embedding
        self.lin1 = Linear(hidden + 1, hidden)
        self.lin2 = Linear(hidden, out_channels)
        self.dropout = dropout

    def forward(self, data, duty):
        x, edge_index, batch = data.x, data.edge_index, data.batch
        x = x.squeeze(1)  # node features arrive as [num_nodes, 1, in_channels]
        for conv in self.convs:
            x = conv(x, edge_index)
            x = F.relu(x)
        g = global_mean_pool(x, batch)
        g = torch.cat([g, duty.view(-1, 1)], dim=1)
        g = self.lin1(g)
        g = F.relu(g)
        g = F.dropout(g, p=self.dropout, training=self.training)
        g = self.lin2(g)
        # Squash the output into (0, 1), matching an efficiency-style target
        return torch.sigmoid(g)
# Infer the input feature dimension from the first training graph
# (node features are stored as [num_nodes, 1, in_channels])
first_graph_node_features = dataset[0]['train'][0].x
in_channels = first_graph_node_features.shape[2]
out_channels = 1
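
# Optional shape check: push one made-up 4-node graph through a throwaway model
# to confirm the tensor plumbing before training. The path graph and the duty
# value 0.5 below are arbitrary placeholders, not dataset values.
_dummy = torch_geometric.data.Data(
    x=torch.randn(4, 1, in_channels),  # same [num_nodes, 1, features] layout as the dataset
    edge_index=torch.tensor([[0, 1, 2], [1, 2, 3]]),
)
_dummy.batch = torch.zeros(4, dtype=torch.long)
with torch.no_grad():
    _check = GINGraphRegressor(in_channels).eval()(_dummy, torch.tensor([0.5]))
assert _check.shape == (1, 1) and 0.0 < _check.item() < 1.0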
### ========================= Clean Dataset ========================= ###
# The training set contains some NaN values, so filter those graphs out
def clean_dataset(data_list, name="Data"):
    cleaned = []
    for data in tqdm(data_list, desc=f"Cleaning {name}"):
        # Drop graphs with NaNs in x, y, or duty, and drop empty graphs
        if not (torch.isnan(data.x).any() or torch.isnan(data.y).any()
                or torch.isnan(data.duty).any() or data.num_nodes == 0):
            cleaned.append(data)
    print(f"{name}: Original size {len(data_list)}, Cleaned size {len(cleaned)}")
    return cleaned

print("Filtering corrupted data...")
train_data = clean_dataset(dataset[0]['train'], "Train")
val_data = clean_dataset(dataset[0]['valid'], "Val")
test_data = clean_dataset(dataset[0]['test'], "Test")
### ========================= Training Setup ========================= ###
model = GINGraphRegressor(
    in_channels=in_channels,
    out_channels=out_channels,
    hidden=64,
    num_layers=3,
    dropout=0.5,
)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
loss_fn = torch.nn.MSELoss()

# Create DataLoaders from the cleaned data
train_loader = torch_geometric.loader.DataLoader(train_data, batch_size=128, shuffle=True)
val_loader = torch_geometric.loader.DataLoader(val_data, batch_size=128, shuffle=False)
test_loader = torch_geometric.loader.DataLoader(test_data, batch_size=128, shuffle=False)
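
# Optional: peek at one mini-batch to confirm the attributes the loops below
# rely on (x, edge_index, batch, y, duty). Uncomment while debugging.
# print(next(iter(train_loader)))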
def train(model, loader, optimizer, loss_fn):
    model.train()
    total_loss = 0
    for data in tqdm(loader, desc="Training"):
        optimizer.zero_grad()
        out = model(data, data.duty.float())
        # Reshape data.y to match the output shape [batch_size, 1]
        loss = loss_fn(out, data.y.float().view(-1, 1))
        loss.backward()
        optimizer.step()
        total_loss += loss.item() * data.num_graphs
    return total_loss / len(loader.dataset)
def evaluate(model, loader, loss_fn):
    model.eval()
    total_loss = 0
    all_preds = []
    all_labels = []
    with torch.no_grad():
        for data in tqdm(loader, desc="Evaluating"):
            out = model(data, data.duty.float())
            loss = loss_fn(out, data.y.float().view(-1, 1))
            total_loss += loss.item() * data.num_graphs
            # Flatten [batch_size, 1] predictions to [batch_size] to match the labels
            all_preds.append(out.squeeze(-1) if out.dim() > 1 and out.shape[1] == 1 else out)
            all_labels.append(data.y.float())
    return total_loss / len(loader.dataset), torch.cat(all_labels), torch.cat(all_preds)
# Training configuration with early stopping
num_epochs = 1  # kept at 1 for a quick demo run; raise it for real training
patience = 3
best_val_loss = float('inf')
patience_counter = 0
best_model_state = None
### ========================= Training Loop ========================= ###
for epoch in range(1, num_epochs + 1):
    print(f'Epoch {epoch}/{num_epochs}')
    loss = train(model, train_loader, optimizer, loss_fn)
    val_loss, _, _ = evaluate(model, val_loader, loss_fn)
    print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, Val Loss: {val_loss:.4f}')

    # Early stopping logic
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        patience_counter = 0
        best_model_state = copy.deepcopy(model.state_dict())
        print("Validation loss improved. Saving model.")
    else:
        patience_counter += 1
        print(f"Validation loss did not improve. Counter: {patience_counter}/{patience}")
        if patience_counter >= patience:
            print("Early stopping triggered.")
            break
# Restore the best (early-stopped) model weights
if best_model_state is not None:
    model.load_state_dict(best_model_state)
    print("Loaded best model weights.")

# Evaluate on the test set and score the predictions with the GraphBench evaluator
test_loss, y_true, y_pred = evaluate(model, test_loader, loss_fn)
print(f'Test Loss: {test_loss:.4f}')
results = Evaluator.evaluate(y_pred.view(-1, 1).cpu().numpy(), y_true.view(-1, 1).cpu().numpy())
print("GraphBench Evaluation Results:", results)