Recovering from Crashed Training Sessions with Weights & Biases

When training deep learning models, session crashes are inevitable: kernel restarts, out-of-memory (OOM) errors, dropped connections. Here's how to recover your work when you use Weights & Biases (W&B) for experiment tracking.

Key Advantage of W&B

W&B streams metrics to its servers in real time, so even if your session crashes, all training history logged up to that point is preserved on the W&B servers. You don't lose your work.
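
As a minimal illustration of what that means, each wandb.log call uploads its values as it happens, so only steps that were never logged are lost in a crash. This is just a sketch; the project name and metric values are placeholders:

import wandb

run = wandb.init(project="your-project-name", name="sync-demo")
for epoch in range(3):
    # ... train one epoch, compute metrics ...
    wandb.log({"epoch": epoch, "loss": 0.5 / (epoch + 1)})  # synced to W&B as soon as it is logged
run.finish()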

Setup: Ensure Reliable Model Checkpointing

Configure your training to save models regularly using Keras' ModelCheckpoint:

import wandb
from pathlib import Path
from datetime import datetime
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping
from wandb.integration.keras import WandbMetricsLogger
import tensorflow.keras.backend as K

# Assumes `model`, `train_data`, `val_data`, `batch_size`, and `num_epochs`
# are already defined in your notebook

# Model save location
model_path = Path('path/to/models/')
model_path.mkdir(parents=True, exist_ok=True)

# Extract model info for filename
lr = float(K.get_value(model.optimizer.learning_rate))
optimizer_name = model.optimizer.__class__.__name__.lower()
timestamp = datetime.now().strftime("%y%m%d-%H%M")
lr_str = f"{lr:.0e}".replace('-', '')  # e.g., 0.001 -> 1e03

# Create descriptive filename with epoch and metric placeholders
model_filename = (
    f"model_"
    f"{timestamp}_"
    f"{optimizer_name}_"
    f"lr{lr_str}_"
    f"epoch{{epoch:02d}}_"      # Keras will fill this in
    f"metric{{val_metric:.4f}}_" # Keras will fill this in
    f"best.h5"
)

# W&B run name (without epoch/metric)
run_name = (
    f"model_"
    f"{timestamp}_"
    f"{optimizer_name}_"
    f"lr{lr_str}"
)

# Configure W&B
config = {
    "architecture": "your_architecture",
    "total_params": model.count_params(),
    "optimizer": model.optimizer.__class__.__name__,
    "learning_rate": float(K.get_value(model.optimizer.learning_rate)),
    "loss": model.loss if isinstance(model.loss, str) else model.loss.__name__,
    "batch_size": batch_size,
    "epochs": num_epochs,
}

wandb.init(
    project="your-project-name",
    name=run_name,
    notes="Brief description of this experiment",
    tags=["relevant", "tags"],
    config=config
)

# Setup callbacks
callbacks = [
    # W&B metrics logging (saves to cloud in real-time)
    WandbMetricsLogger(log_freq="epoch"),
    
    # Keras model checkpointing (saves to disk)
    ModelCheckpoint(
        filepath=str(model_path / model_filename),
        monitor="val_metric",  # e.g., "val_f1", "val_loss"
        mode="max",            # "max" for metrics like F1, "min" for loss
        save_best_only=False,  # Set True to only save improvements
        save_weights_only=False,
        verbose=1
    ),
    
    # Optional: Early stopping
    EarlyStopping(
        monitor='val_metric',
        patience=5,
        restore_best_weights=True,
        mode='max'
    )
]

# Train
history = model.fit(
    train_data,
    epochs=num_epochs,
    validation_data=val_data,
    callbacks=callbacks
)

Recovery After a Crash

Step 1: Load Your Saved Model

from tensorflow.keras.models import load_model

# Find your model file (it will have epoch and metric in filename)
# e.g., model_251129-0831_adam_lr1e03_epoch18_metric0.8015_best.h5
model_path = 'path/to/models/your_model_file.h5'

# Load with custom objects if needed
model = load_model(
    model_path,
    custom_objects={'custom_metric': custom_metric}  # Add your custom metrics/losses
)

print(f"Model loaded successfully")
print(f"Best metric from filename: 0.8015")  # Read from filename

Step 2: Retrieve Training History from W&B

W&B saves all metrics in real-time. Retrieve them even after crash:

import wandb
import pandas as pd

# Initialize W&B API
api = wandb.Api()

# Method 1: Get run_id from W&B dashboard
# Go to your project → find your run → click "..." menu → "Copy" → "Run path"
# Example run path: "username/project-name/run_id"
run_id = "abc123xyz"
run = api.run(f"username/project-name/{run_id}")

# Method 2: Search by run name
runs = api.runs("username/project-name", 
                filters={"display_name": "model_251129-0831_adam_lr1e03"})
run = runs[0]

# Download history as DataFrame
history_df = run.history()

# Display available columns
print(history_df.columns.tolist())
# Typical columns: epoch/loss, epoch/val_loss, epoch/metric, epoch/val_metric, etc.

Step 3: Convert to Keras History Format

Many notebooks expect a Keras History object. Create a compatible mock:

# Extract metrics with correct column names
history_dict = {
    'loss': history_df['epoch/loss'].tolist(),
    'val_loss': history_df['epoch/val_loss'].tolist(),
    'metric': history_df['epoch/metric'].tolist(),        # e.g., 'f1', 'accuracy'
    'val_metric': history_df['epoch/val_metric'].tolist(),
    'lr': history_df['epoch/learning_rate'].tolist()
}

# Create mock History object that mimics Keras History
class MockHistory:
    def __init__(self, history_dict):
        self.history = history_dict

history = MockHistory(history_dict)

# Now you can use history.history['loss'] as normal
print(f"Training epochs: {len(history.history['loss'])}")
print(f"Best val_metric: {max(history.history['val_metric']):.4f}")

Step 4: Continue With Your Notebook

The mock history object works with standard Keras plotting code:

import matplotlib.pyplot as plt
import numpy as np

# Standard Keras history plotting (no changes needed)
loss = history.history['loss']
val_loss = history.history['val_loss']

epochs = np.arange(1, len(loss) + 1)

plt.plot(epochs, loss, label='train_loss')
plt.plot(epochs, val_loss, label='val_loss')
plt.legend()
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.xticks(np.arange(1, len(loss) + 1, 3))
plt.show()

Step 5: Evaluate or Continue Training (Optional)

# Option A: Evaluate loaded model on validation set
val_results = model.evaluate(val_data, verbose=1)
print(f"Validation results: {val_results}")

# Option B: Continue training from checkpoint
# Note: `epochs` is the target *total* epoch count, not the number of extra
# epochs - training runs from initial_epoch up to epochs
new_history = model.fit(
    train_data,
    epochs=total_epochs,    # e.g., 30 to run 12 more epochs after epoch 18
    initial_epoch=18,       # epoch at which the loaded checkpoint was saved
    validation_data=val_data,
    callbacks=callbacks
)
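
If you continue training, you may also want the new epochs to land in the original W&B run rather than a fresh one. A sketch using wandb.init's resume support; the run ID is a placeholder you would read from the dashboard or from wherever you saved it:

import wandb

# Resume logging into the original run instead of creating a new one
run = wandb.init(
    project="your-project-name",
    id="abc123xyz",   # placeholder run ID
    resume="must",    # error out if this run ID does not exist
)

With the run resumed, the WandbMetricsLogger callback from the setup section should append the new epochs to the same run's history.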

Best Practices

  1. Always use ModelCheckpoint - Don't rely solely on W&B model artifacts (they can have compatibility issues)

  2. Use descriptive filenames - Include timestamp, hyperparameters, and metric placeholders so you can identify the best checkpoint

  3. Set save_best_only=False initially - You can always delete checkpoints later, but you can't recover a model you didn't save

  4. Check W&B dashboard regularly - Verify metrics are uploading during training

  5. Save the run_id - Print it at training start so you can look the run up later (a sketch of persisting it to disk follows this list):

   run = wandb.init(...)
   print(f"W&B Run ID: {run.id}")

  6. Avoid WandbModelCheckpoint - Use Keras' ModelCheckpoint instead to avoid potential library compatibility issues
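
A minimal sketch of item 5 in practice: write the run ID to a file next to your checkpoints so it survives a kernel restart. The file name and path are arbitrary choices, not part of the W&B API; run_name and config are the values built in the setup section.

import wandb
from pathlib import Path

run = wandb.init(project="your-project-name", name=run_name, config=config)
print(f"W&B Run ID: {run.id}")

# Persist the run ID next to the checkpoints so it can be recovered after a crash
run_id_file = Path('path/to/models/') / "wandb_run_id.txt"
run_id_file.write_text(run.id)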

Common Issues

Q: My history DataFrame has different column names

A: The WandbMetricsLogger callback prefixes metric names with epoch/. Check the actual column names:

print(history_df.columns.tolist())

Q: Can I get metrics from multiple runs?

A: Yes, query multiple runs and compare:

runs = api.runs("username/project", filters={"tags": "experiment1"})
for run in runs:
    history = run.history()
    best_metric = history['epoch/val_metric'].max()
    print(f"{run.name}: {best_metric:.4f}")

Q: The model file isn't in my directory

A: Check whether ModelCheckpoint actually saved anything (the monitored metric must exist in the training logs). Look for its messages in the verbose output during training; the sketch below is a quick way to list what's on disk.
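
A minimal sketch, assuming the model_path directory from the setup section:

from pathlib import Path

model_path = Path('path/to/models/')
checkpoints = sorted(model_path.glob("*.h5"))
if checkpoints:
    for ckpt in checkpoints:
        print(ckpt.name)
else:
    print("No checkpoints found - check that the monitored metric matches a key in the training logs")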
