Last active
May 31, 2024 01:44
-
-
Save vyeevani/e35a72b531f1fa7a47f6cadbaa5af6d5 to your computer and use it in GitHub Desktop.
Diffusion processes. The tests depend on: https://gist.github.com/vyeevani/00127614e51841acb4d6cfc583de930a and https://gist.github.com/vyeevani/aee668ad21b3e4744af26305455790a1
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import jax | |
| import einops | |
| import equinox | |
def linear_beta_schedule(num_steps, beta_start, beta_end):
    """Return ``num_steps`` betas evenly spaced from ``beta_start`` to ``beta_end``.

    Both endpoints are included (``jax.numpy.linspace`` semantics).
    """
    betas = jax.numpy.linspace(beta_start, beta_end, num=num_steps)
    return betas
def cosine_beta_schedule(num_steps, s=0.008):
    """Return the betas of the cosine noise schedule.

    As proposed in https://openreview.net/forum?id=-NEXDKk8gZ

    ``alpha_bar(t) = cos((t + s) / (1 + s) * pi / 2) ** 2`` is evaluated on
    ``num_steps + 1`` evenly spaced points ``t`` in ``[0, 1]`` and normalised
    so that ``alpha_bar(0) == 1``. Each beta is
    ``1 - alpha_bar(t_i) / alpha_bar(t_{i-1})``, clipped to ``[0, 0.999]``.
    The result has shape ``(num_steps,)``.
    """
    grid = jax.numpy.linspace(0, num_steps, num_steps + 1) / num_steps
    angle = (grid + s) / (1 + s) * jax.numpy.pi * 0.5
    alpha_bar = jax.numpy.cos(angle) ** 2
    alpha_bar = alpha_bar / alpha_bar[0]
    step_ratio = alpha_bar[1:] / alpha_bar[:-1]
    # The last ratio is ~0 (cos(pi/2) ~ 0), so its beta is clipped to 0.999.
    return jax.numpy.clip(1 - step_ratio, 0, 0.999)
def make_forward_diffusion_process(schedule):
    """Build a function that noises an input to a uniformly random diffusion step.

    Args:
        schedule: 1-D array of betas, one per diffusion step (e.g. from
            ``linear_beta_schedule`` or ``cosine_beta_schedule``). Its length
            is the number of diffusion steps.

    Returns:
        ``forward(input_, key, *, return_aux=False)``.

        ``forward`` draws a single integer step ``t`` uniformly from
        ``[0, len(schedule))``. It also draws standard-normal ``noise`` shaped
        like ``input_``. It then returns
        ``sqrt(alpha_hat[t]) * input_ + sqrt(1 - alpha_hat[t]) * noise``,
        where ``alpha_hat`` is the cumulative product of ``1 - schedule``.

        The same ``t`` is used for every element of ``input_``.

        With ``return_aux=True`` it returns ``(noisy_input, t, noise)``, so a
        training loop can condition the denoiser on ``t`` and build its target.
        The default return value (just ``noisy_input``) is unchanged.
    """
    num_steps = schedule.shape[0]
    alphas = 1 - schedule
    alpha_hats = jax.numpy.cumprod(alphas)

    def forward(input_, key, *, return_aux=False):
        # Keep the original split sequence so a given key yields the same sample:
        # first split -> step key, second split (of the carried rng) -> noise key.
        rng, step_key = jax.random.split(key)
        _, noise_key = jax.random.split(rng)
        diffusion_step = jax.random.randint(step_key, (), 0, num_steps)
        noise = jax.random.normal(noise_key, input_.shape)
        alpha_hat = alpha_hats[diffusion_step]
        noisy_input = jax.numpy.sqrt(alpha_hat) * input_ + jax.numpy.sqrt(1.0 - alpha_hat) * noise
        if return_aux:
            return noisy_input, diffusion_step, noise
        return noisy_input

    return forward
def make_sample_diffusion_sequence_model_timestep(model, betas, temperature, input_shape):
    """Build a jitted reverse-diffusion sampler for one sequence.

    The returned ``sample_diffusion_model(context, input_, padding, key)``:

    * starts from standard-normal noise of shape
      ``(input_.shape[0], *input_shape)``;
    * runs ``len(betas)`` reverse steps, ``t = len(betas) - 1, ..., 0``.

    At each step it calls
    ``model(diffusion_step, context, noisy_input, input_, padding=padding, key=key)``,
    where ``diffusion_step`` is ``t`` repeated once per sequence element. It
    then moves to

        coef1 * x_t + coef2 * model_output
        coef1 = sqrt(alpha_t) * (1 - alpha_hat_{t-1}) / (1 - alpha_hat_t)
        coef2 = sqrt(alpha_hat_{t-1}) * (1 - alpha_t) / (1 - alpha_hat_t)

    These are the DDPM posterior-mean coefficients, so the model output is
    treated as an estimate of the clean input. For every step except the last
    (``t == 0``), it adds ``temperature * sqrt(beta_t) * z`` with ``z``
    standard normal.

    Args:
        model: callable with the signature shown above. Presumably it returns
            an array shaped like ``noisy_input`` — TODO confirm against the model.
        betas: 1-D beta schedule; its length is the number of reverse steps.
        temperature: scale applied to the injected noise.
        input_shape: per-element shape of the sampled sequence.

    Returns:
        The ``equinox.filter_jit``-compiled sampler described above.
    """
    num_steps = betas.shape[0]  # one reverse step per beta
    alphas = 1 - betas
    alpha_hats = jax.numpy.cumprod(alphas)
    # alpha_hat_{t-1}, with alpha_hat_{-1} := 1 so that step t == 0 is defined.
    alpha_hats_prev = jax.numpy.append(1., alpha_hats[:-1])

    @equinox.filter_jit
    def sample_diffusion_model(context, input_, padding, key):
        timesteps = input_.shape[0]  # sequence length

        # Traced by lax.scan inside filter_jit, so no separate jit is needed.
        def sample_step(carry, t):
            noisy_input, rng = carry
            rng, key = jax.random.split(rng)
            diffusion_step = einops.repeat(t, "-> t", t=timesteps)
            input_hat = model(diffusion_step, context, noisy_input, input_, padding=padding, key=key)
            coef1 = jax.numpy.sqrt(alphas[t]) * (1. - alpha_hats_prev[t]) / (1. - alpha_hats[t])
            coef2 = jax.numpy.sqrt(alpha_hats_prev[t]) * (1. - alphas[t]) / (1. - alpha_hats[t])
            input_mean = coef1 * noisy_input + coef2 * input_hat
            rng, key = jax.random.split(rng)
            z = jax.random.normal(key, shape=input_mean.shape)
            # (t > 0) zeroes the injected noise on the final step.
            denoised_input = input_mean + (t > 0) * jax.numpy.sqrt(betas[t]) * (temperature * z)
            return (denoised_input, rng), ()

        rng, key = jax.random.split(key)
        noisy_input = jax.random.normal(key, (timesteps, *input_shape))
        (sample, _), _ = jax.lax.scan(
            sample_step,
            (noisy_input, rng),
            jax.numpy.arange(num_steps - 1, -1, -1),
        )
        return sample

    return sample_diffusion_model
def sample_diffusion_sequence_model_trajectory(sample_diffusion_sequence_model_timestep, context, input_, key, *, starting_count=0):
    """Run a per-step sampler once per padding row, feeding each output back in.

    The paddings are the rows of a ``(T, T)`` lower-triangular ones matrix,
    where ``T = input_.shape[0]``. Row ``i`` has ones at positions ``0..i``.
    Rows from ``starting_count`` onward are used in order. Each pass calls
    ``sample_diffusion_sequence_model_timestep(context, current, padding, key)``
    and the result becomes ``current`` for the next pass.

    Args:
        sample_diffusion_sequence_model_timestep: callable
            ``(context, input_, padding, key) -> array``, e.g. the function
            returned by ``make_sample_diffusion_sequence_model_timestep``. It
            must return the same shape/dtype as ``input_`` (required by
            ``jax.lax.scan``).
        context: passed unchanged to every call.
        input_: array whose leading axis is the sequence length ``T``.
        key: PRNG key; split once per pass.
        starting_count: index of the first padding row to use. Default ``0``
            means all ``T`` rows, which matches the previous hard-coded
            behaviour. If it is ``>= T``, no pass runs and ``input_`` is
            returned unchanged.

    Returns:
        The output of the final pass.
    """
    timesteps = input_.shape[0]
    paddings = jax.numpy.tril(jax.numpy.ones((timesteps, timesteps)))[starting_count:]

    def sample(carry, padding):
        current, rng = carry
        rng, step_key = jax.random.split(rng)
        current = sample_diffusion_sequence_model_timestep(context, current, padding, step_key)
        return (current, rng), ()

    (input_, _), _ = jax.lax.scan(sample, (input_, key), paddings)
    return input_
| import unittest | |
| from diffusion_perceiver import DiffusionPerciever | |
class TestDiffusionModel(unittest.TestCase):
    """Shape-level smoke tests for the samplers and the forward diffusion process.

    ``num_timesteps`` is used both as the number of diffusion steps (schedule
    length) and as the sequence length of the mock context/inputs.
    """

    num_timesteps = 6
    context_shapes = ((10,), (20, 20, 3))  # tuple: class attrs are shared
    input_shape = (15, 3)
    latent_size = 64
    latent_count = 12
    temperature = 0.1

    def _make_inputs(self, key):
        """Return mock ``(context, inputs)`` arrays with a leading ``num_timesteps`` axis."""
        # Independent keys so the context entries and inputs are not identical draws.
        keys = jax.random.split(key, len(self.context_shapes) + 1)
        context = [
            jax.random.normal(k, (self.num_timesteps, *shape))
            for k, shape in zip(keys[:-1], self.context_shapes)
        ]
        inputs = jax.random.normal(keys[-1], (self.num_timesteps, *self.input_shape))
        return context, inputs

    def _make_sampler(self, key):
        """Build a DiffusionPerciever and wrap it in a cosine-schedule per-step sampler."""
        betas = cosine_beta_schedule(self.num_timesteps)
        diffusion_perceiver = DiffusionPerciever(
            list(self.context_shapes), self.input_shape, self.latent_size, self.latent_count, key
        )
        return make_sample_diffusion_sequence_model_timestep(
            diffusion_perceiver, betas, self.temperature, self.input_shape
        )

    def test_make_sample_diffusion_sequence_model_timestep(self):
        model_key, data_key, sample_key = jax.random.split(jax.random.PRNGKey(0), 3)
        sample_diffusion_model = self._make_sampler(model_key)
        context, inputs = self._make_inputs(data_key)
        padding = jax.numpy.ones(self.num_timesteps)  # all-ones padding
        outputs = sample_diffusion_model(context, inputs, padding, sample_key)
        self.assertEqual(outputs.shape, (self.num_timesteps, *self.input_shape), "Output shape mismatch")

    def test_trajectory_sampling(self):
        model_key, data_key, sample_key = jax.random.split(jax.random.PRNGKey(0), 3)
        sample_diffusion_model = self._make_sampler(model_key)
        context, inputs = self._make_inputs(data_key)
        # The trajectory sampler builds its own paddings, so none is passed here.
        trajectory_outputs = sample_diffusion_sequence_model_trajectory(sample_diffusion_model, context, inputs, sample_key)
        self.assertEqual(trajectory_outputs.shape, (self.num_timesteps, *self.input_shape), "Trajectory output shape mismatch")

    def test_make_forward_diffusion(self):
        data_key, noise_key = jax.random.split(jax.random.PRNGKey(0))
        forward_diffusion_process = make_forward_diffusion_process(cosine_beta_schedule(self.num_timesteps))
        inputs = jax.random.normal(data_key, (self.num_timesteps, *self.input_shape))
        outputs = forward_diffusion_process(inputs, noise_key)
        self.assertEqual(outputs.shape, (self.num_timesteps, *self.input_shape), "Output shape mismatch")
if __name__ == '__main__':
    # Run the unittest suite above when this file is executed as a script.
    unittest.main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Need to bring the forward part of this over here from my hacky notebook files.