This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import time | |
| import torch | |
| import numpy as np | |
| from astropy.wcs import WCS | |
| from astropy.wcs.wcsapi import BaseHighLevelWCS | |
| def pixel_to_pixel(wcs_in: BaseHighLevelWCS, wcs_out: BaseHighLevelWCS, *inputs): | |
| """ |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Compare the last test spectrum: noisy observation, ground truth, and the
# model's prediction, all plotted on a common x-grid over [-1, 1].
fig = plt.figure(figsize=(16, 8))
x_axis = np.linspace(-1, 1, n)
plt.plot(x_axis, Y_test[-1], label='Noisy', color='C1')
plt.plot(x_axis, X_test[-1], label='True', color='C2', linewidth=4)
# ysol_list[-1][-1][-1] takes the last entry at each nesting level; presumably
# last batch / final RIM step / last sample in the batch -- TODO confirm.
# NOTE(review): the test dataset is batched with drop_remainder=True, so when
# len(Y_test) is not a multiple of batch_size this prediction may belong to an
# earlier sample than Y_test[-1] / X_test[-1] -- confirm the indices line up.
plt.plot(
    x_axis,
    ysol_list[-1][-1][-1].reshape(n),
    label='Predicted',
    linestyle='dashed',
    color='C3',
    linewidth=3,
)
plt.legend(prop={'size': 20})
plt.xlabel('X-axis', fontsize=20)
plt.ylabel('Normalized y-axis', fontsize=20)
plt.xticks(fontsize=15)
plt.yticks(fontsize=15)
plt.title('RIM Example using a Noisy Gaussian')
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Run the model on the held-out test set.
# NOTE(review): drop_remainder=True discards the final
# len(Y_test) % batch_size samples, so the last prediction need not correspond
# to Y_test[-1] -- keep this in mind when comparing predictions to inputs.
test_dataset = tf.data.Dataset.from_tensor_slices((Y_test, A_test, N_test))
test_dataset = test_dataset.batch(batch_size, drop_remainder=True)
ysol = model(test_dataset)

# Convert the model output to nested lists of NumPy arrays:
# ysol_list[i][j] is ysol[i][j].numpy(). Distinct names are used for the outer
# and inner iteration (the previous `for val in val` shadowed the loop variable).
ysol_list = [[tensor.numpy() for tensor in group] for group in ysol]
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Epoch numbers 1, 2, ..., epochs for the per-epoch series.
# np.linspace(0, epochs, epochs) placed points epochs / (epochs - 1) apart,
# so the x positions did not correspond to epoch indices.
epoch_axis = np.arange(1, epochs + 1)

# Training vs. validation loss per epoch.
plt.plot(epoch_axis, training_loss, label='training')
plt.plot(epoch_axis, valid_loss, label='validation')
plt.xlabel('Epoch')
plt.legend()
plt.show()

# Learning rate per epoch. The first entry of learning_rates is skipped so the
# series has `epochs` points; presumably it is the initial rate recorded
# before training -- TODO confirm against model.fit.
plt.plot(epoch_axis, learning_rates[1:], label='learning rate')
plt.xlabel('Epoch')
plt.legend()
plt.show()
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Train the RIM model on the training set, validating on val_dataset; the four
# results are unpacked in the order model.fit returns them.
(
    ysol_valid,
    training_loss,
    valid_loss,
    learning_rates,
) = model.fit(batch_size, epochs, train_dataset, val_dataset)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Load model and define hyper parameters
epochs = 100
batch_size = 16
model = RIM(
    rnn_units1=256,
    rnn_units2=256,
    conv_filters=8,
    kernel_size=2,
    input_size=n,
    dimensions=1,
    t_steps=10,
    learning_rate=0.005,
)

# Prepare the training dataset: slice into examples, group into fixed-size
# batches (an incomplete final batch is dropped), and prefetch 2 batches.
train_dataset = (
    tf.data.Dataset.from_tensor_slices((X_train, Y_train, A_train, N_train))
    .batch(batch_size, drop_remainder=True)
    .prefetch(2)
)
# Prepare the validation dataset
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Create training, validation, and test sets
# NOTE(review): despite their names, these look like cumulative cut-off
# fractions (0.7 -> 0.9 -> 1.0) rather than per-split sizes -- confirm against
# the validation/test slicing.
train_percentage = 0.7
valid_percentage = 0.9
test_percentage = 1.0
len_X = len(gaussians_initial)

# Training
# Exclusive end index of the training split, computed once so the four
# training slices below always use the same boundary.
train_end = int(train_percentage * len_X)
X_train = gaussians_initial[:train_end]
Y_train = gaussians_final[:train_end]
A_train = powerlaw_conv[:train_end]
# np.diag turns a 1-D vector into a diagonal matrix (but extracts the diagonal
# of a 2-D input); each noise_val is assumed to be 1-D -- TODO confirm.
N_train = [np.diag(noise_val) for noise_val in noise[:train_end]]
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
n = 50  # Size of spectrum
N = 5000  # Number of spectra


def gaussian(x, mu, sig):
    """Evaluate an unnormalized Gaussian whose peak value is 1.

    Computes ``exp(-(x - mu)**2 / (2 * sig**2))`` element-wise.

    Args:
        x: Point or NumPy array of points at which to evaluate.
        mu: Centre of the Gaussian.
        sig: Width parameter (standard deviation).

    Returns:
        The Gaussian evaluated at ``x``; equals 1 where ``x == mu``.
    """
    squared_offset = np.power(x - mu, 2.)
    twice_variance = 2 * np.power(sig, 2.)
    return np.exp(-squared_offset / twice_variance)
| def conv_mat(n): | |
| """ | |
| Create convolution matrix that is an identity matrix with noise | |
| """ | |
| conv_mat = np.eye(n)+np.random.normal(0, 0.05, (n,n)) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
data = pd.read_csv(DATASET_PATH, delimiter=' ')
# Move the central value and both bounds into log10 space in one step.
data[['M_BH', 'M_low', 'M_high']] = np.log10(data[['M_BH', 'M_low', 'M_high']])
# Replace the bounds by their (log-space) distances from M_BH, i.e. the upper
# and lower error bars; the two columns are computed independently.
data['M_high'] = data['M_high'] - data['M_BH']
data['M_low'] = data['M_BH'] - data['M_low']
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import corner | |
| import emcee | |
| from google.colab import drive | |
| import pandas as pd | |
| import matplotlib.pyplot as plt | |
| import mplcyberpunk | |
| import numpy as np | |
| from sklearn.linear_model import LinearRegression | |
| from sklearn.metrics import mean_squared_error | |
# Apply the "cyberpunk" Matplotlib style to all subsequent figures; the style
# is presumably registered by the mplcyberpunk import -- confirm.
plt.style.use(style="cyberpunk")
Newer Older