# GraphBench - Electronic Circuits Example
import copy
import torch
import graphbench
from tqdm import tqdm
import torch_geometric
import torch.nn.functional as F
from torch_geometric.nn import GINConv, global_mean_pool
from torch.nn import Sequential, Linear, ReLU, BatchNorm1d
### ========= Load Dataset and Setup GraphBench Components ========= ###
dataset_name = "electronic_circuits_5_eff" # name of the task or list of tasks
evaluator_name = "electroniccircuit"
# Use relative path to data directory (from project root)
data_path = "./data"
# Setting up the components of GraphBench
Evaluator = graphbench.Evaluator(evaluator_name)
Loader = graphbench.Loader(data_path, dataset_name)
# Load a GraphBench dataset and get splits
dataset = Loader.load()
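# A minimal sanity check on the loaded object. Assumption (inferred from how the
# dataset is indexed further down, not from GraphBench docs): `dataset` is indexed
# per task, and each entry is a dict with "train"/"valid"/"test" lists of PyG Data objects.
for split_name in ("train", "valid", "test"):
    print(f"{split_name}: {len(dataset[0][split_name])} graphs")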
### ========================== Model Setup ========================== ###
class GINGraphRegressor(torch.nn.Module):
    def __init__(self, in_channels, hidden=64, out_channels=1, num_layers=3, dropout=0.5):
        super().__init__()
        self.convs = torch.nn.ModuleList()
        for i in range(num_layers):
            mlp = Sequential(
                Linear(in_channels if i == 0 else hidden, hidden),
                BatchNorm1d(hidden),
                ReLU(),
                Linear(hidden, hidden),
                ReLU(),
            )
            self.convs.append(GINConv(mlp))
        self.lin1 = Linear(hidden + 1, hidden)  # +1 for the per-graph duty value
        self.lin2 = Linear(hidden, out_channels)
        self.dropout = dropout

    def forward(self, data, duty):
        x, edge_index, batch = data.x, data.edge_index, data.batch
        x = x.squeeze(1)  # [num_nodes, 1, features] -> [num_nodes, features]
        for conv in self.convs:
            x = conv(x, edge_index)
            x = F.relu(x)
        g = global_mean_pool(x, batch)  # one embedding per graph
        g = torch.cat([g, duty.view(-1, 1)], dim=1)  # append the duty value
        g = self.lin1(g)
        g = F.relu(g)
        g = F.dropout(g, p=self.dropout, training=self.training)
        g = self.lin2(g)
        return torch.sigmoid(g)  # bound output to (0, 1); the efficiency target is assumed to lie in this range
first_graph_node_features = dataset[0]['train'][0].x
in_channels = first_graph_node_features.shape[2]
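# Node features are assumed to have shape [num_nodes, 1, in_channels]; the extra
# singleton dimension is why forward() calls x.squeeze(1). A cheap guard makes
# that assumption explicit:
assert first_graph_node_features.dim() == 3, "expected node features of shape [num_nodes, 1, features]"
print(f"in_channels = {in_channels}")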
out_channels = 1
### ========================= Clean Dataset ========================= ###
# The splits contain some NaN values and empty graphs; filter them out before training
def clean_dataset(data_list, name="Data"):
    cleaned = []
    for data in tqdm(data_list, desc=f"Cleaning {name}"):
        # Keep a graph only if x, y, and duty are NaN-free and the graph is non-empty
        if not (torch.isnan(data.x).any() or torch.isnan(data.y).any()
                or torch.isnan(data.duty).any() or data.num_nodes == 0):
            cleaned.append(data)
    print(f"{name}: Original size {len(data_list)}, Cleaned size {len(cleaned)}")
    return cleaned
print("Filtering corrupted data...")
train_data = clean_dataset(dataset[0]['train'], "Train")
val_data = clean_dataset(dataset[0]['valid'], "Val")
test_data = clean_dataset(dataset[0]['test'], "Test")
### ========================= Training Setup ========================= ###
model = GINGraphRegressor(
    in_channels=in_channels,
    out_channels=out_channels,
    hidden=64,
    num_layers=3,
    dropout=0.5,
)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
loss_fn = torch.nn.MSELoss()
# Create DataLoaders using the cleaned data
train_loader = torch_geometric.loader.DataLoader(train_data, batch_size=128, shuffle=True)
val_loader = torch_geometric.loader.DataLoader(val_data, batch_size=128, shuffle=False)
test_loader = torch_geometric.loader.DataLoader(test_data, batch_size=128, shuffle=False)
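# Optional: inspect one mini-batch to confirm the tensors the model consumes.
# Every attribute printed here (x, edge_index, duty, y) is accessed in train()
# and evaluate() below, so this only surfaces their shapes up front.
sample_batch = next(iter(train_loader))
print("x:", sample_batch.x.shape, "edge_index:", sample_batch.edge_index.shape,
      "duty:", sample_batch.duty.shape, "y:", sample_batch.y.shape)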
def train(model, loader, optimizer, loss_fn):
    model.train()
    total_loss = 0
    for data in tqdm(loader, desc="Training"):
        optimizer.zero_grad()
        out = model(data, data.duty.float())
        # Reshape data.y to match out's shape of [batch_size, 1]
        loss = loss_fn(out, data.y.float().view(-1, 1))
        loss.backward()
        optimizer.step()
        total_loss += loss.item() * data.num_graphs
    return total_loss / len(loader.dataset)
def evaluate(model, loader, loss_fn):
    model.eval()
    total_loss = 0
    all_preds = []
    all_labels = []
    with torch.no_grad():
        for data in tqdm(loader, desc="Evaluating"):
            out = model(data, data.duty.float())
            loss = loss_fn(out, data.y.float().view(-1, 1))
            total_loss += loss.item() * data.num_graphs
            # Flatten [batch_size, 1] predictions to [batch_size] before collecting
            all_preds.append(out.squeeze(-1) if out.dim() > 1 and out.shape[1] == 1 else out)
            all_labels.append(data.y.float())
    return total_loss / len(loader.dataset), torch.cat(all_labels), torch.cat(all_preds)
# Training loop with early stopping
num_epochs = 1  # kept at 1 here as a quick demo run; increase for real training
patience = 3
best_val_loss = float('inf')
patience_counter = 0
best_model_state = None
### ========================= Training Loop ========================= ###
for epoch in range(1, num_epochs + 1):
    print(f'Epoch {epoch}/{num_epochs}')
    loss = train(model, train_loader, optimizer, loss_fn)
    val_loss, _, _ = evaluate(model, val_loader, loss_fn)
    print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, Val Loss: {val_loss:.4f}')
    # Early stopping: keep the best weights, stop after `patience` stale epochs
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        patience_counter = 0
        best_model_state = copy.deepcopy(model.state_dict())
        print("Validation loss improved. Saving model.")
    else:
        patience_counter += 1
        print(f"Validation loss did not improve. Counter: {patience_counter}/{patience}")
        if patience_counter >= patience:
            print("Early stopping triggered.")
            break
# Load the best model
if best_model_state is not None:
    model.load_state_dict(best_model_state)
    print("Loaded best model weights.")
# Evaluate on the test set
test_loss, y_true, y_pred = evaluate(model, test_loader, loss_fn)
print(f'Test Loss: {test_loss:.4f}')
# Score predictions with the GraphBench evaluator instantiated above
results = Evaluator.evaluate(y_pred.view(-1, 1).cpu().numpy(), y_true.view(-1, 1).cpu().numpy())
print("GraphBench Evaluation Results:", results)