Created
February 12, 2026 18:32
-
-
Save pswpswpsw/7012f3ec9d055539cb42d532ddd9fa98 to your computer and use it in GitHub Desktop.
Simple 2D Conv Autoencoder (in tensorflow 2)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import matplotlib | |
| matplotlib.use('pdf') | |
| import matplotlib.pyplot as plt | |
| import argparse | |
| import numpy as np | |
| import os | |
| import tensorflow as tf | |
| from tensorflow.keras import optimizers | |
| from tensorflow.keras import layers, models | |
| from skimage.transform import resize | |
def mkdir(CASE_NAME):
    """Create directory CASE_NAME (including parents) if it does not exist.

    Uses ``exist_ok=True`` so the check-then-create race between
    ``os.path.exists`` and ``os.makedirs`` cannot raise FileExistsError
    when two jobs start with the same run directory.
    """
    os.makedirs(CASE_NAME, exist_ok=True)
def all_resize_to_512_256(image_data, target_shape=(512, 256)):
    """Resize every snapshot in a batch to ``target_shape``.

    Parameters
    ----------
    image_data : np.ndarray
        Batch of 2-D snapshots, shape (N_snap, H, W).
    target_shape : tuple of int, optional
        Output (height, width); defaults to (512, 256) for backward
        compatibility with the original fixed-size behaviour.

    Returns
    -------
    np.ndarray
        Array of shape (N_snap, *target_shape).

    Notes
    -----
    ``order=0`` selects nearest-neighbour interpolation, so the resize
    introduces no intensity values that were absent from the input.
    """
    return np.array([
        resize(snapshot, target_shape, order=0, mode='constant')
        for snapshot in image_data
    ])
| # def ragimg(image_data): | |
| # # image data shape = (N_snap, 32x64) | |
| # total_img_list = [] | |
| # for j in range(image_data.shape[0]): | |
| # total_img = [] | |
| # for i in range(image_data.shape[1]): | |
| # total_img.append(image_data[j,i,0:image_data.shape[1]]) | |
| # total_img.append(image_data[j,i,image_data.shape[1]:]) | |
| # total_img = np.stack(total_img,axis=0) | |
| # total_img_list.append(total_img) | |
| # return np.array(total_img_list) | |
# Options
# Shape Network Parameters
# Command-line configuration, declared as a flat spec table:
# (flag, value type, help text, default or None for "no default").
_ARG_SPECS = [
    ('--TRAIN_DATA', str, 'normalized train data file path', None),
    # ('--TEST_DATA', str, 'raw test data file path', None),  # unused
    ('--CHANNEL_LIST', str, 'the number of channels distribution for CAE', None),
    ('--REDUCED_DIM', str, 'bottleneck dimension', None),
    ('--BATCH_SIZE', int, 'batch size', 6400),
    ('--EPOCHS', int, 'total number of iterations', 10000),
    ('--LR', float, 'learning rate', 1e-3),
    ('--ACT', str, 'activation function', 'swish'),
]
parser = argparse.ArgumentParser(description='config')
for _flag, _type, _help, _default in _ARG_SPECS:
    if _default is None:
        parser.add_argument(_flag, type=_type, help=_help)
    else:
        parser.add_argument(_flag, type=_type, help=_help, default=_default)
args = parser.parse_args()
# Enable on-demand GPU memory allocation so TensorFlow does not grab the
# whole device up front (lets several jobs share one GPU).
gpus = tf.config.experimental.list_physical_devices('GPU')
if gpus:
    try:
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
    except RuntimeError as e:
        # Memory growth must be configured before any GPU is initialized;
        # if it is already too late, report and continue with defaults.
        print(e)
# Map the data resolution (encoded as a token in the TRAIN_DATA path) to
# the network input shape (H, W, C) and a run-directory name.
# NOTE(review): assumes the path contains exactly one of these tokens.
if '32x64' in args.TRAIN_DATA:
    INPUT_SHAPE = (64, 32, 1)
    RES_TAG = '32x64'
elif '64x128' in args.TRAIN_DATA:
    INPUT_SHAPE = (128, 64, 1)
    RES_TAG = '64x128'
elif '128x256' in args.TRAIN_DATA:
    INPUT_SHAPE = (256, 128, 1)
    RES_TAG = '128x256'
else:
    # Previously an unmatched path left INPUT_SHAPE/DIR_NAME undefined,
    # producing a confusing NameError at mkdir(DIR_NAME); fail fast instead.
    raise ValueError('unrecognized resolution in TRAIN_DATA path: %r'
                     % args.TRAIN_DATA)
DIR_NAME = ('RT_CNN_' + RES_TAG + '_' + args.CHANNEL_LIST
            + '_RANK_' + args.REDUCED_DIM + '_ACT_' + args.ACT)
mkdir(DIR_NAME)
# Fixed conv hyper-parameters: 3x3 kernels, stride-2 down/up-sampling.
KERNEL_SIZE = (3, 3)
POOL_SCALING_SIZE = (2, 2)
# '--CHANNEL_LIST' arrives as a comma-separated string, e.g. "16,32,64".
CHANNEL_LIST = [int(item) for item in args.CHANNEL_LIST.split(',')]
ACTIVATION = args.ACT
# -------------------- build CNN autoencoder --------------------
# Encoder: a stack of stride-2 conv blocks (conv -> BN -> activation),
# each halving the spatial resolution while changing the channel count.
model = models.Sequential()
for i, channel in enumerate(CHANNEL_LIST):
    conv_kwargs = {'activation': None, 'padding': 'same',
                   'strides': POOL_SCALING_SIZE}
    if i == 0:
        # Only the first layer needs an explicit input shape.
        conv_kwargs['input_shape'] = INPUT_SHAPE
    model.add(layers.Conv2D(channel, KERNEL_SIZE, **conv_kwargs))
    model.add(layers.BatchNormalization())
    model.add(layers.Activation(ACTIVATION))
# Remember the (H, W, C) shape of the deepest feature map so the decoder
# side can reshape the flat bottleneck back into an image.
last_shape = model.output_shape[1:]
last_shape_dim = last_shape[0] * last_shape[1] * last_shape[2]
model.add(layers.Flatten())
# Linear bottleneck: the latent representation of width REDUCED_DIM.
# REDUCED_DIM is parsed as a *string* (it is also concatenated into
# DIR_NAME), so convert it to int before using it as a layer width.
model.add(layers.Dense(int(args.REDUCED_DIM)))
model.add(layers.Dense(last_shape_dim, activation=ACTIVATION))
# Un-flatten back to the encoder's final feature-map shape.
model.add(layers.Reshape(last_shape))
# model.summary()
# Decoder: mirror of the encoder using transposed convolutions, each
# doubling the spatial resolution. The output-channel sequence is the
# reversed channel list shifted by one, ending at the input channel
# count. Every stage keeps the BN -> activation pattern — including the
# final reconstruction layer, matching the original design.
# NOTE(review): BN + activation on the output layer constrains the
# reconstruction range — confirm this is intentional.
decoder_channels = CHANNEL_LIST[::-1][1:] + [INPUT_SHAPE[-1]]
for out_channels in decoder_channels:
    model.add(layers.Conv2DTranspose(out_channels, KERNEL_SIZE,
                                     strides=POOL_SCALING_SIZE,
                                     padding='same'))
    model.add(layers.BatchNormalization())
    model.add(layers.Activation(ACTIVATION))
# Plain MSE reconstruction objective with Adam.
opt = optimizers.Adam(learning_rate=args.LR)
model.compile(optimizer=opt, loss='mean_squared_error')
model.summary()
# train the model
# - get training data (npz with 'train_data' / 'test_data' arrays)
DATA = np.load(args.TRAIN_DATA)
# Add a trailing channel axis: (N, H, W) -> (N, H, W, 1).
x_train = np.expand_dims(DATA['train_data'], -1)
# Autoencoder target equals its input.
y_train = x_train
x_test = np.expand_dims(DATA['test_data'], -1)
y_test = x_test
# - get higher resolution data
# NOTE(review): high-resolution reference is a hard-coded relative path,
# so the script must be launched from the directory containing this file.
H_RES_DATA = np.load('rt-256x512-cnn.npz')
x_train_hr = np.expand_dims(H_RES_DATA['train_data'], -1)
x_test_hr = np.expand_dims(H_RES_DATA['test_data'], -1)
# - train the model in 1000-epoch segments, accumulating each segment's
#   loss history so the full curve can be plotted afterwards.
train_loss = []
test_loss = []
for i in range(args.EPOCHS // 1000):  # floor division; EPOCHS is an int
    result = model.fit(x_train, y_train,
                       batch_size=args.BATCH_SIZE,
                       validation_data=(x_test, y_test),
                       epochs=1000)
    train_loss.append(result.history['loss'])
    test_loss.append(result.history['val_loss'])
# Flatten the list of per-segment histories into one 1-D array each.
# (np.concatenate replaces the roundabout np.hstack(np.array(...)).)
train_loss_ar = np.concatenate(train_loss)
test_loss_ar = np.concatenate(test_loss)
# plot loss at CURRENT resolution! so no 512x256
plt.figure()
plt.semilogy(train_loss_ar, label='train')
plt.semilogy(test_loss_ar, label='test')
plt.xlabel('epoch')
plt.ylabel('MSE')
plt.title('best test loss = %.4f' % np.min(test_loss_ar))
plt.legend(loc='best')
plt.savefig(DIR_NAME + '/loss.png')
plt.close()
# 3x3 comparison panel: rows = {first train, last train, last test}
# snapshot, columns = {prediction, ground truth, error}.
y_pred_first = model.predict(x_train[[0]])
y_pred_last = model.predict(x_train[[-1]])
y_pred_test_last = model.predict(x_test[[-1]])
fig, axs = plt.subplots(3, 3, figsize=(16, 16))
panels = [
    ('first train snap', x_train[0, :, :, 0], y_pred_first[0, :, :, 0]),
    ('last train snap', x_train[-1, :, :, 0], y_pred_last[0, :, :, 0]),
    ('last test snap', x_test[-1, :, :, 0], y_pred_test_last[0, :, :, 0]),
]
for row, (tag, true_img, pred_img) in enumerate(panels):
    axs[row, 0].imshow(pred_img, origin='lower', cmap='RdBu', vmin=0.9, vmax=2.3)
    axs[row, 0].set_title(tag + ' - pred')
    axs[row, 1].imshow(true_img, origin='lower', cmap='RdBu', vmin=0.9, vmax=2.3)
    axs[row, 1].set_title(tag + ' - true')
    axs[row, 2].imshow(true_img - pred_img, origin='lower', cmap='RdBu', vmin=-0.5, vmax=0.5)
    axs[row, 2].set_title(tag + ' - err')
# Bug fix: disable the grid on the FULL 3x3 panel; the original
# range(2) loops only covered the top-left 2x2 subplots.
for i1 in range(3):
    for i2 in range(3):
        axs[i1, i2].grid(False)
plt.savefig(DIR_NAME + '/compare_vs_gt_under_low_res_1st_test.png', bbox_inches='tight')
plt.close()
# compute 3 errors at high resolution (512x256), one value per snapshot:
#   12: learning error   (CAE prediction vs low-res truth, both upsampled)
#   23: projection error (upsampled low-res truth vs native high-res truth)
#   13: total error      (CAE prediction vs native high-res truth)

def _mse_per_pixel(a, b):
    """Per-snapshot MSE between two (N, H, W) arrays.

    The squared Frobenius norm over each (H, W) slice is the sum of
    squared pixel errors; dividing by H*W yields per-pixel MSE, giving
    one value per snapshot.
    """
    return np.linalg.norm(a - b, axis=(1, 2)) ** 2 / a.shape[1] / a.shape[2]

# - train
y_pred = all_resize_to_512_256(model.predict(x_train)[:, :, :, 0])
y_true = all_resize_to_512_256(x_train[:, :, :, 0])
y_true_hr = x_train_hr[:, :, :, 0]
train_mse_per_pixel_12 = _mse_per_pixel(y_pred, y_true)
train_mse_per_pixel_23 = _mse_per_pixel(y_true, y_true_hr)
train_mse_per_pixel_13 = _mse_per_pixel(y_pred, y_true_hr)
# - test
# NOTE: y_pred / y_true / y_true_hr are deliberately rebound to the test
# arrays here — the save step below depends on them holding TEST data.
y_pred = all_resize_to_512_256(model.predict(x_test)[:, :, :, 0])
y_true = all_resize_to_512_256(x_test[:, :, :, 0])
y_true_hr = x_test_hr[:, :, :, 0]
test_mse_per_pixel_12 = _mse_per_pixel(y_pred, y_true)
test_mse_per_pixel_23 = _mse_per_pixel(y_true, y_true_hr)
test_mse_per_pixel_13 = _mse_per_pixel(y_pred, y_true_hr)
plt.figure(figsize=(8, 8))
plt.plot(train_mse_per_pixel_12, 'r-', label='learning error (train)')
plt.plot(train_mse_per_pixel_23, 'b-', label='projection error (train)')
plt.plot(train_mse_per_pixel_13, 'k-', label='total error (train)')
plt.plot(test_mse_per_pixel_12, 'r--', label='learning error (test)')
plt.plot(test_mse_per_pixel_23, 'b--', label='projection error (test)')
plt.plot(test_mse_per_pixel_13, 'k--', label='total error (test)')
plt.xlabel('time index')
plt.ylabel('MSE per pixel')
plt.legend(loc='best', fontsize=10)
plt.yscale('log')
# plt.ylim([5e-5,4e-2])
plt.savefig(DIR_NAME + '/3_errors_vs_time.png', bbox_inches='tight', dpi=200)
plt.close()
## save errors profiles (one per-snapshot MSE curve per error type)
np.savez(DIR_NAME+'/err.npz',
         train_mse_per_pixel_12=train_mse_per_pixel_12,
         train_mse_per_pixel_23=train_mse_per_pixel_23,
         train_mse_per_pixel_13=train_mse_per_pixel_13,
         test_mse_per_pixel_12=test_mse_per_pixel_12,
         test_mse_per_pixel_23=test_mse_per_pixel_23,
         test_mse_per_pixel_13=test_mse_per_pixel_13)
## save predictions at HIGH RESOLUTION on testing data
# NOTE: y_pred / y_true / y_true_hr were last assigned in the test-error
# section above, so these hold the TEST-set arrays.
np.savez(DIR_NAME + '/pred.npz',
         test_pod_pred_projected_on_hr=y_pred,
         test_pod_true_projected_on_hr=y_true,
         test_true_on_hr=y_true_hr
         )
# save the model (written into the run directory DIR_NAME)
model.save(DIR_NAME)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment