Skip to content

Instantly share code, notes, and snippets.

@vyeevani
Last active May 31, 2024 01:44
Show Gist options
  • Select an option

  • Save vyeevani/e35a72b531f1fa7a47f6cadbaa5af6d5 to your computer and use it in GitHub Desktop.

Select an option

Save vyeevani/e35a72b531f1fa7a47f6cadbaa5af6d5 to your computer and use it in GitHub Desktop.
import jax
import einops
import equinox
def linear_beta_schedule(num_steps, beta_start, beta_end):
    """Return ``num_steps`` betas evenly spaced from ``beta_start`` to ``beta_end``.

    Both endpoints are included in the returned 1-D array.
    """
    schedule = jax.numpy.linspace(beta_start, beta_end, num=num_steps)
    return schedule
def cosine_beta_schedule(num_steps, s=0.008):
    """Return the cosine beta schedule of https://openreview.net/forum?id=-NEXDKk8gZ.

    Args:
        num_steps: Number of diffusion steps; the result has this length.
        s: Offset added to the normalized time before it enters the cosine.

    Returns:
        1-D array of ``num_steps`` betas, clipped to ``[0, 0.999]``.
    """
    # Normalized time grid 0, 1/num_steps, ..., 1 (num_steps + 1 points).
    grid = jax.numpy.linspace(0, num_steps, num_steps + 1) / num_steps
    f = jax.numpy.cos((grid + s) / (1 + s) * jax.numpy.pi * 0.5) ** 2
    alpha_bar = f / f[0]
    # beta_t = 1 - alpha_bar_t / alpha_bar_{t-1}
    betas = 1 - (alpha_bar[1:] / alpha_bar[:-1])
    # alpha_bar is ~0 at the last grid point, which would give beta ~= 1; cap it.
    return jax.numpy.clip(betas, 0, 0.999)
def make_forward_diffusion_process(schedule):
    """Create a function that noises an input at a randomly drawn diffusion step.

    Args:
        schedule: 1-D array of betas; its length is the number of steps.

    Returns:
        ``forward(input_, key)``. It draws one integer step ``t`` in
        ``[0, len(schedule))``, shared by the whole ``input_``. It also draws
        standard-normal noise ``eps`` shaped like ``input_``. It returns
        ``sqrt(alpha_hat[t]) * input_ + sqrt(1 - alpha_hat[t]) * eps``, where
        ``alpha_hat`` is the cumulative product of ``1 - schedule``.

        NOTE(review): neither ``t`` nor ``eps`` is returned. Callers that need
        them (e.g. to condition a denoiser on ``t``) cannot recover them.
    """
    step_count = schedule.shape[0]
    cumulative_alphas = jax.numpy.cumprod(1 - schedule)

    def forward(input_, key):
        # Two sequential splits: the first subkey picks the step, and the
        # second split (of the remaining key) provides the noise key.
        remaining, step_key = jax.random.split(key)
        _, noise_key = jax.random.split(remaining)
        step = jax.random.randint(step_key, (), 0, step_count)
        eps = jax.random.normal(noise_key, input_.shape)
        alpha_hat = cumulative_alphas[step]
        return jax.numpy.sqrt(alpha_hat) * input_ + jax.numpy.sqrt(1.0 - alpha_hat) * eps

    return forward
def make_sample_diffusion_sequence_model_timestep(model, betas, temperature, input_shape):
    """Build a jitted ancestral (DDPM-style) sampler for a sequence model.

    The returned function runs the whole reverse chain, from step
    ``num_steps - 1`` down to 0. It starts from Gaussian noise and updates
    every sequence element in parallel.

    Args:
        model: Denoiser called as
            ``model(diffusion_step, context, noisy_input, input_, padding=..., key=...)``.
            Its output enters the DDPM posterior-mean formula in the role of
            the clean-sample (x_0) estimate.
        betas: 1-D array of per-step noise variances. Its length is the number
            of diffusion steps. ``1 - cumprod(1 - betas)`` appears in a
            denominator, so ``betas[0]`` must be > 0; a zero first beta makes
            the t == 0 coefficients 0/0 = NaN.
        temperature: Multiplier applied to the fresh Gaussian noise injected at
            every step except the last (t == 0).
        input_shape: Shape of a single sequence element.

    Returns:
        ``sample_diffusion_model(context, input_, padding, key)``. It returns
        an array of shape ``(input_.shape[0], *input_shape)``. ``input_``
        itself is only forwarded to ``model`` as conditioning.
    """
    num_steps = betas.shape[0]
    alphas = 1 - betas
    alpha_hats = jax.numpy.cumprod(alphas)
    # alpha_hat_{t-1}, using the convention alpha_hat_{-1} = 1.
    alpha_hats_prev = jax.numpy.append(1., alpha_hats[:-1])

    # Posterior q(x_{t-1} | x_t, x_0) mean coefficients, precomputed for every t
    # so the scan body only gathers them:
    #   mean = coef_xt[t] * x_t + coef_x0[t] * x0_hat
    coef_xt = jax.numpy.sqrt(alphas) * (1. - alpha_hats_prev) / (1. - alpha_hats)
    coef_x0 = jax.numpy.sqrt(alpha_hats_prev) * (1. - alphas) / (1. - alpha_hats)
    # Reverse-process std uses sigma_t^2 = beta_t.
    noise_std = jax.numpy.sqrt(betas)

    @equinox.filter_jit
    def sample_diffusion_model(context, input_, padding, key):
        timesteps = input_.shape[0]

        # No inner jax.jit: lax.scan already traces this body once.
        def sample_step(carry, t):
            noisy_input, rng = carry
            rng, model_key = jax.random.split(rng)
            # Every sequence element is at the same diffusion step t.
            diffusion_step = einops.repeat(t, "-> t", t=timesteps)
            input_hat = model(diffusion_step, context, noisy_input, input_, padding=padding, key=model_key)
            input_mean = coef_xt[t] * noisy_input + coef_x0[t] * input_hat
            rng, noise_key = jax.random.split(rng)
            z = jax.random.normal(noise_key, shape=input_mean.shape)
            # (t > 0) disables the noise term on the final step.
            denoised_input = input_mean + ((t > 0) * noise_std[t] * (temperature * z))
            return (denoised_input, rng), ()

        rng, init_key = jax.random.split(key)
        noisy_input = jax.random.normal(init_key, (timesteps, *input_shape))
        (sample, _), _ = jax.lax.scan(
            sample_step,
            (noisy_input, rng),
            jax.numpy.arange(num_steps - 1, -1, -1),
        )
        return sample

    return sample_diffusion_model
def sample_diffusion_sequence_model_trajectory(
    sample_diffusion_sequence_model_timestep, context, input_, key, *, starting_count=0
):
    """Sample a sequence by scanning the per-timestep sampler over causal masks.

    A lower-triangular mask is built whose row ``i`` is 1 for positions
    ``0..i`` and 0 afterwards. The function scans over its rows from
    ``starting_count`` onward. Each row is passed as ``padding`` to the
    per-timestep sampler, together with the sequence produced by the previous
    iteration. Each call's output replaces the entire carried sequence.

    Args:
        sample_diffusion_sequence_model_timestep: Callable
            ``(context, input_, padding, key) -> array`` returning an array
            shaped like ``input_``, e.g. the function returned by
            ``make_sample_diffusion_sequence_model_timestep``.
        context: Passed unchanged to every sampler call.
        input_: Initial sequence; its leading axis is the sequence length.
        key: JAX PRNG key, split once per scan iteration.
        starting_count: Index of the first mask row to use (default 0, i.e.
            every row). Values >= ``input_.shape[0]`` perform no sampling and
            return ``input_`` unchanged.

    Returns:
        The sequence carried out of the final scan iteration.

    Raises:
        ValueError: If ``starting_count`` is negative. Python slice semantics
            would otherwise silently select the *last* rows.
    """
    if starting_count < 0:
        raise ValueError(f"starting_count must be non-negative, got {starting_count}")
    timesteps = input_.shape[0]
    paddings = jax.numpy.tril(jax.numpy.ones((timesteps, timesteps)))[starting_count:]

    def sample(carry, padding):
        sequence, rng = carry
        rng, step_key = jax.random.split(rng)
        sequence = sample_diffusion_sequence_model_timestep(context, sequence, padding, step_key)
        return (sequence, rng), ()

    (sequence, _), _ = jax.lax.scan(sample, (input_, key), paddings)
    return sequence
import unittest
from diffusion_perceiver import DiffusionPerciever
class TestDiffusionModel(unittest.TestCase):
    """Shape-level smoke tests for the samplers and the forward process."""

    def _sampler_fixture(self):
        """Build a small DiffusionPerciever-backed sampler plus mock inputs.

        Returns ``(sampler, context, inputs, key, num_timesteps, input_shape)``.
        """
        num_timesteps = 6
        context_shapes = [(10,), (20, 20, 3)]
        input_shape = (15, 3)
        latent_size = 64
        latent_count = 12
        temperature = 0.1
        key = jax.random.PRNGKey(0)
        betas = cosine_beta_schedule(num_timesteps)
        perceiver = DiffusionPerciever(context_shapes, input_shape, latent_size, latent_count, key)
        context = [jax.random.normal(key, (num_timesteps, *shape)) for shape in context_shapes]
        inputs = jax.random.normal(key, (num_timesteps, *input_shape))
        sampler = make_sample_diffusion_sequence_model_timestep(perceiver, betas, temperature, input_shape)
        return sampler, context, inputs, key, num_timesteps, input_shape

    def test_make_sample_diffusion_sequence_model_timestep(self):
        sampler, context, inputs, key, steps, input_shape = self._sampler_fixture()
        # All-ones padding: every position of the sequence is visible.
        padding = jax.numpy.ones(steps)
        outputs = sampler(context, inputs, padding, key)
        self.assertEqual(outputs.shape, (steps, *input_shape), "Output shape mismatch")

    def test_trajectory_sampling(self):
        sampler, context, inputs, key, steps, input_shape = self._sampler_fixture()
        trajectory = sample_diffusion_sequence_model_trajectory(sampler, context, inputs, key)
        self.assertEqual(trajectory.shape, (steps, *input_shape), "Trajectory output shape mismatch")

    def test_make_forward_diffusion(self):
        steps = 6
        key = jax.random.PRNGKey(0)
        input_shape = (15, 3)
        forward = make_forward_diffusion_process(cosine_beta_schedule(steps))
        inputs = jax.random.normal(key, (steps, *input_shape))
        outputs = forward(inputs, key)
        self.assertEqual(outputs.shape, (steps, *input_shape), "Output shape mismatch")


if __name__ == '__main__':
    unittest.main()
@vyeevani
Copy link
Author

need to get the forward part of this in here from my hacky notebook files

@vyeevani
Copy link
Author

fixed the problem

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment