TL;DR: Add a BatchNormalization() layer before the final Dense layer to fix exploding loss on Apple Silicon. Alternatively, use mixed_precision.set_global_policy('float32') as a quick one-line fix.
IMPACTED VERSIONS
- Python: 3.11.9
- TensorFlow: 2.15.0
- TensorFlow-MacOS: 2.15
- TensorFlow-Metal: 1.1.0
Symptom: Simple TensorFlow Metal models on Apple M-series Macs often suffer from exploding loss after the first few epochs, with loss jumping from < 0.3 to > 1000 within a few epochs [1], [2].
The model trains as expected for the first two or three epochs, with loss and validation loss decreasing, but around the fourth or fifth epoch the loss begins to grow exponentially. Adding data augmentation via ImageDataGenerator or Keras preprocessing layers tends to magnify the issue.
Example Training Output:
Epoch 1/10: loss: 0.4154 - val_loss: 0.0782 - val_accuracy: 0.9749
Epoch 2/10: loss: 0.1896 - val_loss: 0.1014 - val_accuracy: 0.9700
Epoch 3/10: loss: 0.3001 - val_loss: 0.3074 - val_accuracy: 0.9574
Epoch 4/10: loss: 4.6413 - val_loss: 9.5222 - val_accuracy: 0.9277 ← Explosion starts
Epoch 5/10: loss: 69.2627 - val_loss: 53.7030 - val_accuracy: 0.9290
Epoch 6/10: loss: 262.3925 - val_loss: 383.2134 - val_accuracy: 0.8640
Why this happens:
- TensorFlow Metal may default to mixed precision (16-bit + 32-bit floats)
- 16-bit floats (float16) have a limited range: roughly ±65,504
- When values exceed this range → overflow → infinity or NaN (see the snippet after this list)
- This typically happens after the first few epochs, as values accumulate
- Augmentation introduces additional variation in inputs, making large activations more likely
- Without normalization, activations grow through layers:
- Layer 1 outputs get large
- Layer 2 receives large inputs → outputs get HUGE
- Layer 3 receives HUGE inputs → outputs get ASTRONOMICAL
- This cascade leads to exploding activations → exploding gradients → exploding loss
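A standalone sketch (not part of the training script) that makes the float16 ceiling concrete:
import numpy as np

# float16 tops out at 65504; float32 reaches roughly 3.4e38
print(np.finfo(np.float16).max)   # 65504.0
print(np.float16(65504) * 2)      # inf  (overflow)
print(np.float32(65504) * 2)      # 131008.0 (still representable)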
Best approach - add a BatchNormalization layer. This addresses the root cause and can also improve model performance.
model = models.Sequential([
layers.Conv2D(32, (3,3), activation='relu', input_shape=(28,28,1)),
layers.MaxPooling2D((2,2)),
layers.Conv2D(64, (3,3), activation='relu'),
layers.MaxPooling2D((2,2)),
layers.Flatten(),
layers.Dense(128, activation='relu'),
layers.BatchNormalization(), # ← Add this
layers.Dense(10, activation='softmax')
])
Why it works: Normalizes activations (mean=0, std=1) between layers, preventing the cascade explosion. Think of it like resetting the volume to a normal level between stages of an audio system.
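A small sanity check, separate from the model above, showing the normalization effect in training mode (the large input values here are made up for illustration):
import numpy as np
import tensorflow as tf

# Feed deliberately large activations through a fresh BatchNormalization layer
big_activations = tf.constant(
    np.random.normal(loc=500.0, scale=200.0, size=(64, 128)).astype("float32")
)
bn = tf.keras.layers.BatchNormalization()
normalized = bn(big_activations, training=True)
# With the default gamma=1, beta=0, the output is roughly mean 0, std 1
print(float(tf.reduce_mean(normalized)), float(tf.math.reduce_std(normalized)))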
Best practice placement:
layers.Dense(128),
layers.BatchNormalization(), # After dense, before activation
layers.Activation('relu'),
One-line fix - simple, but doesn't improve the model itself.
from tensorflow.keras import mixed_precision
mixed_precision.set_global_policy('float32')  # Add at top of script
Why it works: Uses 32-bit floats with a much larger range (~±3.4×10^38), preventing overflow. Like using a bigger bucket so water doesn't spill over.
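To confirm the policy actually took effect (a quick check; the model must be built after the policy is set):
from tensorflow.keras import mixed_precision

mixed_precision.set_global_policy('float32')
print(mixed_precision.global_policy())  # should report the float32 policy
# Layers created from here on compute and store their variables in float32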
Gradient clipping - prevents exploding gradients by limiting their magnitude.
from tensorflow.keras.optimizers import Adam
optimizer = Adam(learning_rate=1e-4, clipnorm=1.0)
model.compile(
optimizer=optimizer,
loss='categorical_crossentropy',
metrics=['accuracy']
)
Why it works: Clips gradients to a maximum norm of 1.0, directly preventing explosion. Like a speed limiter on a car.
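A standalone illustration of what clipnorm=1.0 does to an oversized gradient (a sketch; the optimizer applies this rescaling to each gradient tensor automatically):
import tensorflow as tf

# A gradient with norm 5.0 gets rescaled so its norm is at most 1.0
grad = tf.constant([3.0, 4.0])
clipped = tf.clip_by_norm(grad, clip_norm=1.0)
print(tf.norm(grad).numpy(), tf.norm(clipped).numpy())  # 5.0 1.0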
Label smoothing - prevents overconfident predictions that can cause extreme loss values.
from tensorflow.keras.losses import CategoricalCrossentropy
loss = CategoricalCrossentropy(label_smoothing=0.1)
model.compile(
optimizer='adam',
loss=loss,
metrics=['accuracy']
)
Why it works: Changes one-hot labels from hard targets [0, 0, 1, 0, ...] to soft targets [0.01, 0.01, 0.91, 0.01, ...], preventing overconfidence and softening the loss landscape.
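The smoothing Keras applies internally, reproduced as a quick sketch (10 classes, smoothing 0.1, matching the numbers above):
import tensorflow as tf

num_classes, smoothing = 10, 0.1
y_one_hot = tf.one_hot(2, num_classes)
# Keras computes: y * (1 - smoothing) + smoothing / num_classes
y_smoothed = y_one_hot * (1 - smoothing) + smoothing / num_classes
print(y_smoothed.numpy())  # [0.01 0.01 0.91 0.01 ... 0.01]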
Complete example that reproduces the problem (augmentation, no BatchNorm):
import numpy as np
from tensorflow import keras
from tensorflow.keras import layers, models
# Load data
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
x_train = x_train.reshape(-1, 28, 28, 1) / 255.0
x_test = x_test.reshape(-1, 28, 28, 1) / 255.0
y_train = keras.utils.to_categorical(y_train, 10)
y_test = keras.utils.to_categorical(y_test, 10)
# Model with augmentation
model = models.Sequential([
layers.Input(shape=(28, 28, 1)),
layers.RandomTranslation(height_factor=0.1, width_factor=0.1),
layers.RandomZoom(height_factor=0.1, width_factor=0.1),
layers.Conv2D(32, (3,3), activation='relu'),
layers.MaxPooling2D((2,2)),
layers.Conv2D(64, (3,3), activation='relu'),
layers.MaxPooling2D((2,2)),
layers.Flatten(),
layers.Dense(128, activation='relu'),
layers.Dense(10, activation='softmax')
])
model.compile(
optimizer='adam',
loss='categorical_crossentropy',
metrics=['accuracy']
)
history = model.fit(
x_train, y_train,
batch_size=128,
epochs=10,
validation_data=(x_test, y_test)
)
Complete fixed example (same model with BatchNorm added):
import numpy as np
from tensorflow import keras
from tensorflow.keras import layers, models
# Load data
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
x_train = x_train.reshape(-1, 28, 28, 1) / 255.0
x_test = x_test.reshape(-1, 28, 28, 1) / 255.0
y_train = keras.utils.to_categorical(y_train, 10)
y_test = keras.utils.to_categorical(y_test, 10)
# Model with augmentation and BatchNorm
model = models.Sequential([
layers.Input(shape=(28, 28, 1)),
layers.RandomTranslation(height_factor=0.1, width_factor=0.1),
layers.RandomZoom(height_factor=0.1, width_factor=0.1),
layers.Conv2D(32, (3,3), activation='relu'),
layers.MaxPooling2D((2,2)),
layers.Conv2D(64, (3,3), activation='relu'),
layers.MaxPooling2D((2,2)),
layers.Flatten(),
layers.Dense(128, activation='relu'),
layers.BatchNormalization(), # ← Fix: Add BatchNorm
layers.Dense(10, activation='softmax')
])
model.compile(
optimizer='adam',
loss='categorical_crossentropy',
metrics=['accuracy']
)
history = model.fit(
x_train, y_train,
batch_size=128,
epochs=10,
validation_data=(x_test, y_test)
)
Expected Output:
Epoch 1/10: loss: 1.2430 - val_accuracy: 0.9135
Epoch 2/10: loss: 0.5036 - val_accuracy: 0.9496
Epoch 3/10: loss: 0.3356 - val_accuracy: 0.9577
Epoch 4/10: loss: 0.2783 - val_accuracy: 0.9598 ← Stable
Epoch 5/10: loss: 0.2675 - val_accuracy: 0.9539
...
Epoch 10/10: loss: 0.4104 - val_accuracy: 0.9354
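While verifying a fix, it can help to abort training automatically as soon as the loss turns NaN instead of watching it climb for several epochs. Keras ships a callback for this (optional; reuses the variables from the script above):
import tensorflow as tf

# Stops training immediately if the loss becomes NaN
callbacks = [tf.keras.callbacks.TerminateOnNaN()]

history = model.fit(
    x_train, y_train,
    batch_size=128,
    epochs=10,
    validation_data=(x_test, y_test),
    callbacks=callbacks
)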
For maximum stability, use both float32 policy AND BatchNormalization:
from tensorflow.keras import mixed_precision
# Set precision policy first
mixed_precision.set_global_policy('float32')
# Model with BatchNorm
model = models.Sequential([
layers.Input(shape=(28, 28, 1)),
layers.RandomTranslation(height_factor=0.1, width_factor=0.1),
layers.RandomZoom(height_factor=0.1, width_factor=0.1),
layers.Conv2D(32, (3,3)),
layers.BatchNormalization(),
layers.Activation('relu'),
layers.MaxPooling2D((2,2)),
layers.Conv2D(64, (3,3)),
layers.BatchNormalization(),
layers.Activation('relu'),
layers.MaxPooling2D((2,2)),
layers.Flatten(),
layers.Dense(128),
layers.BatchNormalization(),
layers.Activation('relu'),
layers.Dropout(0.5),
layers.Dense(10, activation='softmax')
])
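If extra safety margin is wanted, the gradient clipping and label smoothing shown earlier can be layered onto this model at compile time, for example:
from tensorflow.keras.losses import CategoricalCrossentropy
from tensorflow.keras.optimizers import Adam

model.compile(
    optimizer=Adam(learning_rate=1e-4, clipnorm=1.0),
    loss=CategoricalCrossentropy(label_smoothing=0.1),
    metrics=['accuracy']
)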
DON'T normalize data AND use rescale=1./255 in ImageDataGenerator:
# Wrong - double normalization
x_train = x_train / 255.0 # First normalization
datagen = ImageDataGenerator(rescale=1./255)  # Second normalization (0-1 → 0-0.0039)
DO either normalize manually OR let ImageDataGenerator do it:
# Option 1: Manual normalization (no rescale in datagen)
x_train = x_train / 255.0
datagen = ImageDataGenerator(width_shift_range=0.1) # No rescale
# Option 2: Let datagen handle it
x_train = x_train # No division
datagen = ImageDataGenerator(rescale=1./255, width_shift_range=0.1)
Always rebuild the model fresh when debugging. Once weights explode, they can't recover.
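A quick way to confirm the weights have actually blown up before rebuilding (a sketch using NumPy and the standard model.get_weights() call):
import numpy as np

# True if any weight tensor contains NaN or Inf
blown_up = any(
    np.isnan(w).any() or np.isinf(w).any()
    for w in model.get_weights()
)
print("Weights exploded:", blown_up)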
# Rebuild model from scratch
model = Sequential([...])
model.compile(...)
Use categorical_crossentropy with one-hot encoded labels:
y_train = keras.utils.to_categorical(y_train, 10) # One-hot encode
model.compile(loss='categorical_crossentropy', ...)
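If the labels are instead kept as integer class IDs (as load_data() returns them), the matching loss is sparse_categorical_crossentropy; mixing up the two formats is another common source of nonsensical loss values. A minimal sketch of the alternative pairing (assumes from tensorflow import keras):
# Integer labels (no to_categorical) pair with the sparse loss
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
model.compile(loss='sparse_categorical_crossentropy',
              optimizer='adam',
              metrics=['accuracy'])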
Putting the preprocessing rules together:
from tensorflow import keras
from tensorflow.keras.preprocessing.image import ImageDataGenerator

# Load and normalize
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
x_train = x_train.reshape(-1, 28, 28, 1) / 255.0
x_test = x_test.reshape(-1, 28, 28, 1) / 255.0
y_train = keras.utils.to_categorical(y_train, 10)
y_test = keras.utils.to_categorical(y_test, 10)
# Augmentation (NO rescale - already normalized)
datagen = ImageDataGenerator(
width_shift_range=0.1,
height_shift_range=0.1,
zoom_range=0.1
# NO rescale parameter!
)
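To actually train with this generator, the standard flow() + fit() pattern applies (a sketch; model is assumed to be the fixed model defined earlier):
# Feed augmented batches to the model; validation data stays un-augmented
history = model.fit(
    datagen.flow(x_train, y_train, batch_size=128),
    epochs=10,
    validation_data=(x_test, y_test)
)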