Avoiding Exploding Loss in TensorFlow Metal

Issue: Exploding Loss for Simple TensorFlow Metal Models

TL;DR: Add a BatchNormalization() layer before the final Dense layer to fix exploding loss on Apple Silicon. Alternatively, use mixed_precision.set_global_policy('float32') for a quick one-line fix.


Problem Description

IMPACTED VERSIONS

  • Python: 3.11.9
  • TensorFlow: 2.15.0
  • TensorFlow-MacOS: 2.15
  • TensorFlow-Metal: 1.1.0

Symptom: Simple TensorFlow Metal models on Apple M-series Macs often suffer from exploding loss after the first few epochs: loss climbs from < 0.3 to > 1000 within a handful of epochs [1], [2].

This typically occurs after the first two or three epochs. The model will train as expected with decreasing loss and validation loss, but then in the fourth or fifth epoch, loss begins growing exponentially. Adding data augmentation via ImageDataGenerator or Keras preprocessing layers tends to magnify this issue.

Example Training Output:

Epoch 1/10: loss: 0.4154 - val_loss: 0.0782 - val_accuracy: 0.9749
Epoch 2/10: loss: 0.1896 - val_loss: 0.1014 - val_accuracy: 0.9700
Epoch 3/10: loss: 0.3001 - val_loss: 0.3074 - val_accuracy: 0.9574
Epoch 4/10: loss: 4.6413 - val_loss: 9.5222 - val_accuracy: 0.9277  ← Explosion starts
Epoch 5/10: loss: 69.2627 - val_loss: 53.7030 - val_accuracy: 0.9290
Epoch 6/10: loss: 262.3925 - val_loss: 383.2134 - val_accuracy: 0.8640

Root Causes

1. Numerical Instability with Mixed Precision

  • TensorFlow Metal may default to mixed precision (16-bit + 32-bit floats)
  • 16-bit floats (float16) have a limited range: roughly ±65,504
  • When values exceed this range → overflow → they become infinity or NaN
  • This typically happens around epoch 3 as values accumulate (see the float16 sketch below)
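
A minimal NumPy sketch (not TensorFlow-specific) of the overflow behaviour described above: float16 saturates to infinity just past its maximum, while float32 has plenty of headroom.

import numpy as np

# float16 tops out at ~65,504; exceeding it overflows to inf
print(np.finfo(np.float16).max)            # 65504.0
print(np.float16(60000) * np.float16(2))   # inf  (overflow)
print(np.float32(60000) * np.float32(2))   # 120000.0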

2. Cascading Activation Explosion

  • Augmentation introduces additional variation in inputs
  • Without normalization, activations grow through layers:
    • Layer 1 outputs get large
    • Layer 2 receives large inputs → outputs get HUGE
    • Layer 3 receives HUGE inputs → outputs get ASTRONOMICAL
  • This cascade leads to exploding activations → exploding gradients → exploding loss (see the sketch after this list)
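
An illustrative NumPy sketch of the cascade (hypothetical layer sizes and weight scales, not the actual model): with weights a little too large and no normalization, the largest activation grows exponentially with depth.

import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(size=(1, 128)).astype(np.float32)

# Hypothetical stack of dense ReLU layers whose weights are slightly too large;
# each layer multiplies the typical activation magnitude by a factor > 1
for layer in range(1, 11):
    w = rng.normal(scale=0.5, size=(128, 128)).astype(np.float32)
    x = np.maximum(x @ w, 0.0)  # ReLU
    print(f"layer {layer:2d}: max activation = {float(x.max()):,.1f}")

# In float16 these values would overflow to inf once they pass ~65,504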

Solutions

Solution 1: Add BatchNormalization (RECOMMENDED)

Best approach - Addresses root cause and improves model performance.

model = models.Sequential([
    layers.Conv2D(32, (3,3), activation='relu', input_shape=(28,28,1)),
    layers.MaxPooling2D((2,2)),
    layers.Conv2D(64, (3,3), activation='relu'),
    layers.MaxPooling2D((2,2)),
    layers.Flatten(),
    layers.Dense(128, activation='relu'),
    layers.BatchNormalization(),  # ← Add this
    layers.Dense(10, activation='softmax')
])

Why it works: Normalizes activations (mean=0, std=1) between layers, preventing the cascade explosion. Think of it like resetting the volume to a normal level between stages of an audio system.

Best practice placement:

layers.Dense(128),
layers.BatchNormalization(),  # After dense, before activation
layers.Activation('relu'),

Solution 2: Force Float32 Precision (QUICK FIX)

One-line fix - Simple but doesn't improve the model.

from tensorflow.keras import mixed_precision
mixed_precision.set_global_policy('float32')  # Add at top of script

Why it works: Uses 32-bit floats with larger range (~±10^38), preventing overflow. Like using a bigger bucket so water doesn't spill over.
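
To confirm which precision policy is active (before or after applying the fix), the current global policy can be printed:

from tensorflow.keras import mixed_precision

# Shows the active policy, e.g. 'float32' or 'mixed_float16'
print(mixed_precision.global_policy())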


Solution 3: Gradient Clipping (ADDITIONAL SAFETY)

Prevents exploding gradients by limiting their magnitude.

from tensorflow.keras.optimizers import Adam
optimizer = Adam(learning_rate=1e-4, clipnorm=1.0)

model.compile(
    optimizer=optimizer,
    loss='categorical_crossentropy',
    metrics=['accuracy']
)

Why it works: Clips gradients to maximum norm of 1.0, directly preventing explosion. Like a speed limiter on a car.
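
Keras optimizers also accept clipvalue, which clips each gradient element individually rather than the overall norm; a small variant of the same idea:

from tensorflow.keras.optimizers import Adam

# Alternative: clip every gradient element to [-0.5, 0.5] instead of capping the global norm
optimizer = Adam(learning_rate=1e-4, clipvalue=0.5)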


Solution 4: Label Smoothing (OPTIONAL)

Prevents overconfident predictions that can cause extreme loss values.

from tensorflow.keras.losses import CategoricalCrossentropy
loss = CategoricalCrossentropy(label_smoothing=0.1)

model.compile(
    optimizer='adam',
    loss=loss,
    metrics=['accuracy']
)

Why it works: Changes one-hot labels from hard targets [0, 0, 1, 0, ...] to soft targets [0.01, 0.01, 0.91, 0.01, ...], preventing overconfidence and softening the loss landscape.
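
A quick NumPy check of the transformation (Keras applies y_smooth = y * (1 - ε) + ε / K for label_smoothing = ε over K classes):

import numpy as np

eps, k = 0.1, 10
one_hot = np.eye(k)[2]                    # class 2 → [0, 0, 1, 0, ...]
smoothed = one_hot * (1 - eps) + eps / k  # y * (1 - ε) + ε / K
print(smoothed)                           # [0.01 0.01 0.91 0.01 ... 0.01]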


Sample Code

Code That Causes Exploding Loss

import numpy as np
from tensorflow import keras
from tensorflow.keras import layers, models

# Load data
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
x_train = x_train.reshape(-1, 28, 28, 1) / 255.0
x_test = x_test.reshape(-1, 28, 28, 1) / 255.0
y_train = keras.utils.to_categorical(y_train, 10)
y_test = keras.utils.to_categorical(y_test, 10)

# Model with augmentation
model = models.Sequential([
    layers.Input(shape=(28, 28, 1)),
    layers.RandomTranslation(height_factor=0.1, width_factor=0.1),
    layers.RandomZoom(height_factor=0.1, width_factor=0.1),
    layers.Conv2D(32, (3,3), activation='relu'),
    layers.MaxPooling2D((2,2)),
    layers.Conv2D(64, (3,3), activation='relu'),
    layers.MaxPooling2D((2,2)),
    layers.Flatten(),
    layers.Dense(128, activation='relu'),
    layers.Dense(10, activation='softmax')
])

model.compile(
    optimizer='adam',
    loss='categorical_crossentropy',
    metrics=['accuracy']
)

history = model.fit(
    x_train, y_train,
    batch_size=128,
    epochs=10,
    validation_data=(x_test, y_test)
)

Fixed Code with BatchNormalization

import numpy as np
from tensorflow import keras
from tensorflow.keras import layers, models

# Load data
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
x_train = x_train.reshape(-1, 28, 28, 1) / 255.0
x_test = x_test.reshape(-1, 28, 28, 1) / 255.0
y_train = keras.utils.to_categorical(y_train, 10)
y_test = keras.utils.to_categorical(y_test, 10)

# Model with augmentation and BatchNorm
model = models.Sequential([
    layers.Input(shape=(28, 28, 1)),
    layers.RandomTranslation(height_factor=0.1, width_factor=0.1),
    layers.RandomZoom(height_factor=0.1, width_factor=0.1),
    layers.Conv2D(32, (3,3), activation='relu'),
    layers.MaxPooling2D((2,2)),
    layers.Conv2D(64, (3,3), activation='relu'),
    layers.MaxPooling2D((2,2)),
    layers.Flatten(),
    layers.Dense(128, activation='relu'),
    layers.BatchNormalization(),  # ← Fix: Add BatchNorm
    layers.Dense(10, activation='softmax')
])

model.compile(
    optimizer='adam',
    loss='categorical_crossentropy',
    metrics=['accuracy']
)

history = model.fit(
    x_train, y_train,
    batch_size=128,
    epochs=10,
    validation_data=(x_test, y_test)
)

Expected Output:

Epoch 1/10: loss: 1.2430 - val_accuracy: 0.9135
Epoch 2/10: loss: 0.5036 - val_accuracy: 0.9496
Epoch 3/10: loss: 0.3356 - val_accuracy: 0.9577
Epoch 4/10: loss: 0.2783 - val_accuracy: 0.9598  ← Stable
Epoch 5/10: loss: 0.2675 - val_accuracy: 0.9539
...
Epoch 10/10: loss: 0.4104 - val_accuracy: 0.9354

Recommended Combined Approach

For maximum stability, use both float32 policy AND BatchNormalization:

from tensorflow.keras import mixed_precision

# Set precision policy first
mixed_precision.set_global_policy('float32')

# Model with BatchNorm
model = models.Sequential([
    layers.Input(shape=(28, 28, 1)),
    layers.RandomTranslation(height_factor=0.1, width_factor=0.1),
    layers.RandomZoom(height_factor=0.1, width_factor=0.1),
    
    layers.Conv2D(32, (3,3)),
    layers.BatchNormalization(),
    layers.Activation('relu'),
    layers.MaxPooling2D((2,2)),
    
    layers.Conv2D(64, (3,3)),
    layers.BatchNormalization(),
    layers.Activation('relu'),
    layers.MaxPooling2D((2,2)),
    
    layers.Flatten(),
    layers.Dense(128),
    layers.BatchNormalization(),
    layers.Activation('relu'),
    layers.Dropout(0.5),
    layers.Dense(10, activation='softmax')
])
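
To complete the combined setup, compile and train exactly as in the earlier examples; the gradient clipping from Solution 3 can be layered on as extra safety:

from tensorflow.keras.optimizers import Adam

# Same compile/fit pattern as above, with optional clipnorm for additional protection
optimizer = Adam(learning_rate=1e-4, clipnorm=1.0)
model.compile(optimizer=optimizer, loss='categorical_crossentropy', metrics=['accuracy'])
history = model.fit(x_train, y_train, batch_size=128, epochs=10, validation_data=(x_test, y_test))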

Common Mistakes to Avoid

1. Double Normalization

DON'T normalize data AND use rescale=1./255 in ImageDataGenerator:

# Wrong - double normalization
x_train = x_train / 255.0  # First normalization
datagen = ImageDataGenerator(rescale=1./255)  # Second normalization (0-1 → 0-0.0039)

DO either normalize manually OR let ImageDataGenerator do it:

# Option 1: Manual normalization (no rescale in datagen)
x_train = x_train / 255.0
datagen = ImageDataGenerator(width_shift_range=0.1)  # No rescale

# Option 2: Let datagen handle it
x_train = x_train  # No division
datagen = ImageDataGenerator(rescale=1./255, width_shift_range=0.1)

2. Retraining on Exploded Weights

Always rebuild the model fresh when debugging. Once weights explode, they can't recover.

# Rebuild model from scratch
model = Sequential([...])
model.compile(...)
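
One way to confirm that a run has actually exploded before rebuilding (an illustrative check, not part of the original workflow): scan the weights for NaN or inf.

import numpy as np

# If any weight is NaN or inf, the model has exploded and must be rebuilt from scratch
exploded = any(not np.all(np.isfinite(w)) for w in model.get_weights())
print("exploded weights:", exploded)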

3. Wrong Loss Function

Use categorical_crossentropy with one-hot encoded labels:

y_train = keras.utils.to_categorical(y_train, 10)  # One-hot encode
model.compile(loss='categorical_crossentropy', ...)
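
If you prefer to keep integer class labels instead of one-hot vectors, the sparse variant of the same loss avoids the to_categorical step entirely:

# Equivalent setup with integer labels (no to_categorical needed)
model.compile(loss='sparse_categorical_crossentropy', optimizer='adam', metrics=['accuracy'])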

Data Preprocessing Template

from tensorflow import keras
from tensorflow.keras.preprocessing.image import ImageDataGenerator

# Load and normalize
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
x_train = x_train.reshape(-1, 28, 28, 1) / 255.0
x_test = x_test.reshape(-1, 28, 28, 1) / 255.0
y_train = keras.utils.to_categorical(y_train, 10)
y_test = keras.utils.to_categorical(y_test, 10)

# Augmentation (NO rescale - already normalized)
datagen = ImageDataGenerator(
    width_shift_range=0.1,
    height_shift_range=0.1,
    zoom_range=0.1
    # NO rescale parameter!
)

References

  1. GitHub issue: Apple Silicon M2 Loss Goes Up After 1-2 (Few) Epochs but not on Google Colab (#140)
  2. Stack Overflow: Loss increasing to extremely high numbers during training