When training deep learning models, session crashes are inevitable (kernel restarts, OOM errors, connection issues). Here's how to recover your work when using Weights & Biases for experiment tracking.
W&B saves metrics in real time, so even if your session crashes, all training history up to that point is preserved on W&B's servers. You don't lose your work.
Configure your training to save models regularly using Keras' ModelCheckpoint:
from pathlib import Path
from datetime import datetime
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping
from wandb.integration.keras import WandbMetricsLogger
import tensorflow.keras.backend as K
import wandb
# Model save location
model_path = Path('path/to/models/')
model_path.mkdir(parents=True, exist_ok=True)
# Extract model info for filename
lr = float(K.get_value(model.optimizer.learning_rate))
optimizer_name = model.optimizer.__class__.__name__.lower()
timestamp = datetime.now().strftime("%y%m%d-%H%M")
lr_str = f"{lr:.0e}".replace('-', '') # e.g., 0.001 -> 1e03
# Create descriptive filename with epoch and metric placeholders
model_filename = (
    f"model_"
    f"{timestamp}_"
    f"{optimizer_name}_"
    f"lr{lr_str}_"
    f"epoch{{epoch:02d}}_"        # Keras fills this in at save time
    f"metric{{val_metric:.4f}}_"  # Keras fills this in at save time
    f"best.h5"
)
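# For example, this template renders to something like:
# model_251129-0831_adam_lr1e03_epoch18_metric0.8015_best.h5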
# W&B run name (without epoch/metric)
run_name = (
    f"model_"
    f"{timestamp}_"
    f"{optimizer_name}_"
    f"lr{lr_str}"
)
# Configure W&B
config = {
    "architecture": "your_architecture",
    "total_params": model.count_params(),
    "optimizer": model.optimizer.__class__.__name__,
    "learning_rate": float(K.get_value(model.optimizer.learning_rate)),
    # Works for string losses and loss functions; falls back to the class name for Loss instances
    "loss": model.loss if isinstance(model.loss, str) else getattr(model.loss, "__name__", type(model.loss).__name__),
    "batch_size": batch_size,
    "epochs": num_epochs,
}
wandb.init(
    project="your-project-name",
    name=run_name,
    notes="Brief description of this experiment",
    tags=["relevant", "tags"],
    config=config,
)
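# Print the run id right away; if the session crashes, this is how you find
# the run again later (see the "Save the run_id" tip below)
print(f"W&B Run ID: {wandb.run.id}")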
# Setup callbacks
callbacks = [
    # W&B metrics logging (saves to the cloud in real time)
    WandbMetricsLogger(log_freq="epoch"),
    # Keras model checkpointing (saves to disk)
    ModelCheckpoint(
        filepath=str(model_path / model_filename),
        monitor="val_metric",     # e.g., "val_f1", "val_loss"
        mode="max",               # "max" for metrics like F1, "min" for loss
        save_best_only=False,     # set True to only save improvements
        save_weights_only=False,
        verbose=1,
    ),
    # Optional: early stopping
    EarlyStopping(
        monitor="val_metric",
        patience=5,
        restore_best_weights=True,
        mode="max",
    ),
]
# Train
history = model.fit(
    train_data,
    epochs=num_epochs,
    validation_data=val_data,
    callbacks=callbacks
)
After a crash, reload the checkpoint that ModelCheckpoint wrote to disk:
from tensorflow.keras.models import load_model
# Find your model file (it will have epoch and metric in filename)
# e.g., model_251129-0831_adam_lr1e03_epoch18_metric0.8015_best.h5
model_path = 'path/to/models/your_model_file.h5'
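# If you don't remember the exact filename, you could instead grab the newest
# checkpoint by modification time (a sketch; assumes all checkpoints are .h5
# files in this directory):
#   from pathlib import Path
#   model_path = str(max(Path('path/to/models/').glob('*.h5'), key=lambda p: p.stat().st_mtime))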
# Load with custom objects if needed
model = load_model(
    model_path,
    custom_objects={'custom_metric': custom_metric}  # add your custom metrics/losses here
)
print(f"Model loaded successfully")
print(f"Best metric from filename: 0.8015") # Read from filenameW&B saves all metrics in real-time. Retrieve them even after crash:
import wandb
import pandas as pd
# Initialize W&B API
api = wandb.Api()
# Method 1: Get run_id from W&B dashboard
# Go to your project → find your run → click "..." menu → "Copy" → "Run path"
# Example run path: "username/project-name/run_id"
run_id = "abc123xyz"
run = api.run(f"username/project-name/{run_id}")
# Method 2: Search by run name
runs = api.runs(
    "username/project-name",
    filters={"display_name": "model_251129-0831_adam_lr1e03"}
)
run = runs[0]
# Download history as DataFrame
history_df = run.history()
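# The hyperparameters you logged via `config` are preserved on the run too
# (a sketch; the keys below assume the config dict defined earlier)
print(run.config["optimizer"], run.config["learning_rate"])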
# Display available columns
print(history_df.columns.tolist())
# Typical columns: epoch/loss, epoch/val_loss, epoch/metric, epoch/val_metric, etc.
Many notebooks expect a Keras History object. Create a compatible mock:
# Extract metrics with correct column names
history_dict = {
    'loss': history_df['epoch/loss'].tolist(),
    'val_loss': history_df['epoch/val_loss'].tolist(),
    'metric': history_df['epoch/metric'].tolist(),  # e.g., 'f1', 'accuracy'
    'val_metric': history_df['epoch/val_metric'].tolist(),
    'lr': history_df['epoch/learning_rate'].tolist()
}
# Create mock History object that mimics Keras History
class MockHistory:
    def __init__(self, history_dict):
        self.history = history_dict

history = MockHistory(history_dict)
# Now you can use history.history['loss'] as normal
print(f"Training epochs: {len(history.history['loss'])}")
print(f"Best val_metric: {max(history.history['val_metric']):.4f}")The mock history object works with standard Keras plotting code:
import matplotlib.pyplot as plt
import numpy as np
# Standard Keras history plotting (no changes needed)
loss = history.history['loss']
val_loss = history.history['val_loss']
epochs = np.arange(1, len(loss) + 1)
plt.plot(epochs, loss, label='train_loss')
plt.plot(epochs, val_loss, label='val_loss')
plt.legend()
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.xticks(np.arange(1, len(loss) + 1, 3))
plt.show()
With the model and history restored, you can either evaluate the checkpoint or resume training.
# Option A: Evaluate the loaded model on the validation set
val_results = model.evaluate(val_data, verbose=1)
print(f"Validation results: {val_results}")
# Option B: Continue training from checkpoint
new_history = model.fit(
    train_data,
    epochs=additional_epochs,  # note: in Keras this is the final epoch index, not "how many more"
    initial_epoch=18,          # start from the epoch where training stopped
    validation_data=val_data,
    callbacks=callbacks
)
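If you don't remember the exact epoch the run stopped at, the recovered W&B history gives it to you. A minimal sketch, assuming `history_df` from the API query above holds one logged row per completed epoch and `extra_epochs` is a hypothetical count of additional epochs to train:
extra_epochs = 10                      # hypothetical: how many more epochs to run
initial_epoch = len(history_df)        # epochs already completed before the crash
new_history = model.fit(
    train_data,
    epochs=initial_epoch + extra_epochs,  # fit stops at this absolute epoch index
    initial_epoch=initial_epoch,
    validation_data=val_data,
    callbacks=callbacks
)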
- Always use `ModelCheckpoint` - Don't rely solely on W&B model artifacts (they can have compatibility issues)
- Use descriptive filenames - Include timestamp, hyperparameters, and metric placeholders so you can identify the best checkpoint
- Set `save_best_only=False` initially - You can always delete checkpoints later, but you can't recover a model you didn't save
- Check the W&B dashboard regularly - Verify metrics are uploading during training
- Save the run_id - Print it at training start (see the resume sketch after this list):
  run = wandb.init(...)
  print(f"W&B Run ID: {run.id}")
- Avoid `WandbModelCheckpoint` - Use Keras' `ModelCheckpoint` instead to avoid potential library compatibility issues
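With the run id saved, a restarted session can even attach to the same W&B run instead of creating a new one. A minimal sketch, assuming the project name and id below are replaced with your own:
import wandb

run = wandb.init(
    project="your-project-name",
    id="abc123xyz",    # the id printed at training start
    resume="allow"     # reattach to the existing run if it exists
)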
Q: My history DataFrame has different column names
A: W&B prefixes metrics with `epoch/`. Check the actual column names:
print(history_df.columns.tolist())
Q: Can I get metrics from multiple runs?
A: Yes, query multiple runs and compare:
runs = api.runs("username/project", filters={"tags": "experiment1"})
for run in runs:
    history = run.history()
    best_metric = history['epoch/val_metric'].max()
    print(f"{run.name}: {best_metric:.4f}")
Q: The model file isn't in my directory
A: Check whether ModelCheckpoint actually saved anything (the metric named in `monitor` must exist under exactly that name, or no file is written). Look for its verbose output during training.
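To see exactly which names are available for `monitor`, print the keys Keras records in the history (this works on the mock history above too); a minimal sketch:
# These keys are the names ModelCheckpoint and EarlyStopping can monitor
print(list(history.history.keys()))  # e.g., ['loss', 'val_loss', 'val_metric', ...]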