Skip to content

Instantly share code, notes, and snippets.

@pswpswpsw
Created February 12, 2026 18:32
Show Gist options
  • Select an option

  • Save pswpswpsw/7012f3ec9d055539cb42d532ddd9fa98 to your computer and use it in GitHub Desktop.

Select an option

Save pswpswpsw/7012f3ec9d055539cb42d532ddd9fa98 to your computer and use it in GitHub Desktop.
Simple 2D Conv Autoencoder (in tensorflow 2)
import matplotlib
matplotlib.use('pdf')
import matplotlib.pyplot as plt
import argparse
import numpy as np
import os
import tensorflow as tf
from tensorflow.keras import optimizers
from tensorflow.keras import layers, models
from skimage.transform import resize
def mkdir(CASE_NAME):
    """Create directory CASE_NAME (including parents) if it does not exist.

    Parameters
    ----------
    CASE_NAME : str
        Path of the output directory to create.

    Notes
    -----
    Uses ``exist_ok=True`` instead of an ``os.path.exists`` pre-check so
    there is no check-then-create race (TOCTOU) when several runs start
    concurrently, and the call is idempotent.
    """
    os.makedirs(CASE_NAME, exist_ok=True)
def all_resize_to_512_256(image_data, target_shape=(512, 256)):
    """Resize each 2-D snapshot in a stack to ``target_shape``.

    Parameters
    ----------
    image_data : ndarray, shape (N_snap, H, W)
        Stack of snapshots; each slice ``image_data[j]`` is resized
        independently.  (Shape assumption inferred from the 3-D indexing
        at the call sites — confirm against the data files.)
    target_shape : tuple of int, optional
        Output (rows, cols) per snapshot.  Defaults to (512, 256),
        preserving the original hard-coded behavior.

    Returns
    -------
    ndarray, shape (N_snap, target_shape[0], target_shape[1])

    Notes
    -----
    ``order=0`` selects nearest-neighbor interpolation and
    ``mode='constant'`` pads edges with a constant (skimage defaults to 0),
    matching the original call exactly.
    """
    return np.array([
        resize(snapshot, target_shape, order=0, mode='constant')
        for snapshot in image_data
    ])
# def ragimg(image_data):
# # image data shape = (N_snap, 32x64)
# total_img_list = []
# for j in range(image_data.shape[0]):
# total_img = []
# for i in range(image_data.shape[1]):
# total_img.append(image_data[j,i,0:image_data.shape[1]])
# total_img.append(image_data[j,i,image_data.shape[1]:])
# total_img = np.stack(total_img,axis=0)
# total_img_list.append(total_img)
# return np.array(total_img_list)
# ---------------------------------------------------------------------------
# Command-line configuration (shape / network / training parameters).
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser(description='config')

# Data and architecture specification (strings; CHANNEL_LIST is a
# comma-separated list, REDUCED_DIM is also embedded in the output dir name).
parser.add_argument('--TRAIN_DATA', type=str, help='normalized train data file path')
# parser.add_argument('--TEST_DATA',type=str, help='raw test data file path')
parser.add_argument('--CHANNEL_LIST', type=str, help='the number of channels distribution for CAE')
parser.add_argument('--REDUCED_DIM', type=str, help='bottleneck dimension')

# Optimization hyper-parameters (all have defaults).
parser.add_argument('--BATCH_SIZE', type=int, default=6400, help='batch size')
parser.add_argument('--EPOCHS', type=int, default=10000, help='total number of iterations')
parser.add_argument('--LR', type=float, default=1e-3, help='learning rate')
parser.add_argument('--ACT', type=str, default='swish', help='activation function')

args = parser.parse_args()
# Turn on memory growth for every visible GPU so TensorFlow allocates VRAM
# on demand instead of reserving it all at startup.  A RuntimeError is
# raised if GPUs were already initialized; report it and continue.
gpus = tf.config.experimental.list_physical_devices('GPU')
try:
    for gpu in gpus:  # no-op when the list is empty
        tf.config.experimental.set_memory_growth(gpu, True)
except RuntimeError as e:
    print(e)
# ---------------------------------------------------------------------------
# Derive the model input shape and output directory from the training-data
# filename, which is expected to encode the grid resolution (e.g. '...32x64...').
# ---------------------------------------------------------------------------
if '32x64' in args.TRAIN_DATA:
    INPUT_SHAPE = (64, 32, 1)
    RES_TAG = '32x64'
elif '64x128' in args.TRAIN_DATA:
    INPUT_SHAPE = (128, 64, 1)
    RES_TAG = '64x128'
elif '128x256' in args.TRAIN_DATA:
    INPUT_SHAPE = (256, 128, 1)
    RES_TAG = '128x256'
else:
    # Fail fast with a clear message; the original fell through and died
    # later with a NameError on DIR_NAME.
    raise ValueError(
        'TRAIN_DATA path %r does not contain a recognized resolution '
        '(expected one of: 32x64, 64x128, 128x256)' % args.TRAIN_DATA)

# Single construction of the run directory name (was duplicated per branch).
DIR_NAME = ('RT_CNN_' + RES_TAG + '_' + args.CHANNEL_LIST
            + '_RANK_' + args.REDUCED_DIM + '_ACT_' + args.ACT)
mkdir(DIR_NAME)

KERNEL_SIZE = (3, 3)
POOL_SCALING_SIZE = (2, 2)  # also used as the conv stride => 2x down/up-sampling
CHANNEL_LIST = [int(item) for item in args.CHANNEL_LIST.split(',')]
ACTIVATION = args.ACT
# ---------------------------------------------------------------------------
# Build the CNN autoencoder.
# Encoder: a stack of strided Conv2D -> BatchNorm -> Activation blocks
# (each stride-2 conv halves H and W), then flatten and compress to the
# bottleneck dimension.
# ---------------------------------------------------------------------------
model = models.Sequential()
for i, channel in enumerate(CHANNEL_LIST):
    if i == 0:
        # First layer fixes the input shape.
        model.add(layers.Conv2D(channel, KERNEL_SIZE, activation=None,
                                input_shape=INPUT_SHAPE, padding='same',
                                strides=POOL_SCALING_SIZE))
    else:
        model.add(layers.Conv2D(channel, KERNEL_SIZE, activation=None,
                                padding='same', strides=POOL_SCALING_SIZE))
    model.add(layers.BatchNormalization())
    model.add(layers.Activation(ACTIVATION))
    # model.add(layers.AveragePooling2D(POOL_SCALING_SIZE))

# Spatial shape of the last encoder feature map; needed to reshape the
# decoder input back to an image.
last_shape = model.output_shape[1:]
last_shape_dim = last_shape[0] * last_shape[1] * last_shape[2]

# Flatten the smallest feature map.
model.add(layers.Flatten())
# Latent representation.  BUG FIX: REDUCED_DIM is parsed as a *string*
# (it is concatenated into DIR_NAME above); Dense requires an int unit count.
model.add(layers.Dense(int(args.REDUCED_DIM)))
model.add(layers.Dense(last_shape_dim, activation=ACTIVATION))
# Reshape the flattened vector back to the encoder's last feature-map shape.
model.add(layers.Reshape(last_shape))
# model.summary()
# ---------------------------------------------------------------------------
# Decoder: mirror the encoder with stride-2 Conv2DTranspose blocks (each
# doubles H and W); the final layer maps back to the input channel count.
# NOTE(review): the pasted source lost its indentation — this assumes
# BatchNorm/Activation belong to the non-final branch only (a linear last
# layer suits the MSE reconstruction loss); confirm against the original.
# ---------------------------------------------------------------------------
reversed_channels = CHANNEL_LIST[::-1]  # hoisted: was re-reversed every iteration
for i, channel in enumerate(reversed_channels):
    if i == len(reversed_channels) - 1:
        # Last layer: restore the original number of channels.
        model.add(layers.Conv2DTranspose(INPUT_SHAPE[-1], KERNEL_SIZE,
                                         strides=POOL_SCALING_SIZE, padding='same'))
    else:
        # Intermediate layers take the *next* channel count in the mirror.
        model.add(layers.Conv2DTranspose(reversed_channels[i + 1], KERNEL_SIZE,
                                         strides=POOL_SCALING_SIZE, padding='same'))
        model.add(layers.BatchNormalization())
        model.add(layers.Activation(ACTIVATION))

opt = optimizers.Adam(learning_rate=args.LR)
model.compile(optimizer=opt, loss='mean_squared_error')
model.summary()
# ---------------------------------------------------------------------------
# Load the data.  The autoencoder targets are the inputs themselves, with a
# trailing channel axis appended for Conv2D.
# ---------------------------------------------------------------------------
DATA = np.load(args.TRAIN_DATA)
x_train = DATA['train_data'][..., np.newaxis]
y_train = x_train
x_test = DATA['test_data'][..., np.newaxis]
y_test = x_test

# Reference snapshots at the higher 256x512 resolution, used later for the
# learning / projection / total error decomposition.
H_RES_DATA = np.load('rt-256x512-cnn.npz')
x_train_hr = H_RES_DATA['train_data'][..., np.newaxis]
x_test_hr = H_RES_DATA['test_data'][..., np.newaxis]
# ---------------------------------------------------------------------------
# Train in chunks of 1000 epochs, collecting the loss history per chunk.
# ---------------------------------------------------------------------------
train_loss = []
test_loss = []
# max(1, ...) guards EPOCHS < 1000: the original ran zero chunks and then
# crashed in np.hstack on an empty list.
n_chunks = max(1, args.EPOCHS // 1000)
for i in range(n_chunks):
    result = model.fit(x_train, y_train,
                       batch_size=args.BATCH_SIZE,
                       validation_data=(x_test, y_test),
                       epochs=1000)
    train_loss.append(result.history['loss'])
    test_loss.append(result.history['val_loss'])
# Concatenate the per-chunk histories into one per-epoch series
# (np.hstack flattens the list of equal-length lists directly; the extra
# np.array wrap in the original was redundant).
train_loss_ar = np.hstack(train_loss)
test_loss_ar = np.hstack(test_loss)
# Plot the train/test loss curves at the CURRENT (low) resolution — the
# 512x256 errors are handled separately below.
fig, ax = plt.subplots()
ax.semilogy(train_loss_ar, label='train')
ax.semilogy(test_loss_ar, label='test')
ax.set_xlabel('epoch')
ax.set_ylabel('MSE')
ax.set_title('best test loss = %.4f' % np.min(test_loss_ar))
ax.legend(loc='best')
fig.savefig(DIR_NAME + '/loss.png')
plt.close(fig)
# ---------------------------------------------------------------------------
# 3x3 comparison figure: one row per snapshot (first train, last train,
# last test), with columns (prediction, ground truth, error).
# ---------------------------------------------------------------------------
y_pred_first = model.predict(x_train[[0]])
y_pred_last = model.predict(x_train[[-1]])
y_pred_test_last = model.predict(x_test[[-1]])

# Shared colormap limits: fields in [0.9, 2.3], errors in [-0.5, 0.5].
FIELD_KW = dict(origin='lower', cmap='RdBu', vmin=0.9, vmax=2.3)
ERR_KW = dict(origin='lower', cmap='RdBu', vmin=-0.5, vmax=0.5)

rows = [
    ('first train snap', y_pred_first[0, :, :, 0], x_train[0, :, :, 0]),
    ('last train snap', y_pred_last[0, :, :, 0], x_train[-1, :, :, 0]),
    ('last test snap', y_pred_test_last[0, :, :, 0], x_test[-1, :, :, 0]),
]
fig, axs = plt.subplots(3, 3, figsize=(16, 16))
for r, (tag, pred, true) in enumerate(rows):
    axs[r, 0].imshow(pred, **FIELD_KW)
    axs[r, 0].set_title(tag + ' - pred')
    axs[r, 1].imshow(true, **FIELD_KW)
    axs[r, 1].set_title(tag + ' - true')
    axs[r, 2].imshow(true - pred, **ERR_KW)
    axs[r, 2].set_title(tag + ' - err')
# BUG FIX: the original used range(2), disabling the grid only on the
# top-left 2x2 axes of the 3x3 grid.
for i1 in range(3):
    for i2 in range(3):
        axs[i1, i2].grid(False)
plt.savefig(DIR_NAME + '/compare_vs_gt_under_low_res_1st_test.png', bbox_inches='tight')
plt.close()
def _mse_per_pixel(a, b):
    """Per-snapshot MSE between two image stacks of shape (N, H, W)."""
    return np.linalg.norm(a - b, axis=(1, 2)) ** 2 / (a.shape[1] * a.shape[2])

# ---------------------------------------------------------------------------
# Error decomposition at the high (512x256) resolution:
#   _12: learning error   (upsampled prediction vs. upsampled truth)
#   _23: projection error (upsampled truth vs. native high-res truth)
#   _13: total error      (upsampled prediction vs. native high-res truth)
# ---------------------------------------------------------------------------
# - train
y_pred = all_resize_to_512_256(model.predict(x_train)[:, :, :, 0])
y_true = all_resize_to_512_256(x_train[:, :, :, 0])
y_true_hr = x_train_hr[:, :, :, 0]
train_mse_per_pixel_12 = _mse_per_pixel(y_pred, y_true)
train_mse_per_pixel_23 = _mse_per_pixel(y_true, y_true_hr)
train_mse_per_pixel_13 = _mse_per_pixel(y_pred, y_true_hr)
# - test (note: y_pred / y_true / y_true_hr are rebound to the test arrays;
# the save block below relies on that)
y_pred = all_resize_to_512_256(model.predict(x_test)[:, :, :, 0])
y_true = all_resize_to_512_256(x_test[:, :, :, 0])
y_true_hr = x_test_hr[:, :, :, 0]
test_mse_per_pixel_12 = _mse_per_pixel(y_pred, y_true)
test_mse_per_pixel_23 = _mse_per_pixel(y_true, y_true_hr)
test_mse_per_pixel_13 = _mse_per_pixel(y_pred, y_true_hr)
# Plot the three error components over the time index for both splits
# (solid = train, dashed = test).
plt.figure(figsize=(8, 8))
curves = [
    (train_mse_per_pixel_12, 'r-', 'learning error (train)'),
    (train_mse_per_pixel_23, 'b-', 'projection error (train)'),
    (train_mse_per_pixel_13, 'k-', 'total error (train)'),
    (test_mse_per_pixel_12, 'r--', 'learning error (test)'),
    (test_mse_per_pixel_23, 'b--', 'projection error (test)'),
    (test_mse_per_pixel_13, 'k--', 'total error (test)'),
]
for series, fmt, lbl in curves:
    plt.plot(series, fmt, label=lbl)
plt.xlabel('time index')
plt.ylabel('MSE per pixel')
plt.legend(loc='best', fontsize=10)
plt.yscale('log')
# plt.ylim([5e-5,4e-2])
plt.savefig(DIR_NAME + '/3_errors_vs_time.png', bbox_inches='tight', dpi=200)
plt.close()
## Persist the six error profiles for later analysis.
error_profiles = {
    'train_mse_per_pixel_12': train_mse_per_pixel_12,
    'train_mse_per_pixel_23': train_mse_per_pixel_23,
    'train_mse_per_pixel_13': train_mse_per_pixel_13,
    'test_mse_per_pixel_12': test_mse_per_pixel_12,
    'test_mse_per_pixel_23': test_mse_per_pixel_23,
    'test_mse_per_pixel_13': test_mse_per_pixel_13,
}
np.savez(DIR_NAME + '/err.npz', **error_profiles)

## Persist the HIGH-RESOLUTION predictions on the testing data.
## (y_pred / y_true / y_true_hr still hold the test-split arrays from the
## error-computation section above.)
np.savez(DIR_NAME + '/pred.npz',
         test_pod_pred_projected_on_hr=y_pred,
         test_pod_true_projected_on_hr=y_true,
         test_true_on_hr=y_true_hr)

# Save the trained model to the run directory.
model.save(DIR_NAME)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment