Created May 24, 2024 02:18
-
-
Save vyeevani/00127614e51841acb4d6cfc583de930a to your computer and use it in GitHub Desktop.
Applies a Perceiver IO model capable of autoregressive generation to a diffusion model. Depends on: https://gist.github.com/vyeevani/aee668ad21b3e4744af26305455790a1
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import jax | |
| import einops | |
| import equinox | |
| from perceiver import Perceiver | |
class DiffusionPerciever(equinox.Module):
    """Perceiver-based denoiser that also conditions on its own previous inputs.

    Wraps a ``Perceiver`` (imported from the companion ``perceiver`` module).
    The Perceiver receives these input streams, in order:

    1. the diffusion step, one value per timestep (per-timestep shape ``(1,)``);
    2. every context stream described by ``context_shapes``;
    3. the clean input sequence delayed by one timestep, with the learned
       ``input_start`` array occupying slot 0;
    4. the noisy input sequence.

    NOTE(review): the class name is misspelled ("Perciever"). It is kept
    unchanged because it is the public name callers use.
    """

    perceiver: equinox.Module
    # Learned array placed at timestep 0 of the shifted clean-input stream;
    # its shape is ``input_shape``.
    input_start: jax.Array

    def __init__(self, context_shapes, input_shape, latent_size, latent_count, key):
        """Build the wrapped Perceiver and the learned start array.

        Args:
            context_shapes: per-timestep shapes of the context streams. Any
                sequence is accepted.
            input_shape: per-timestep shape of the clean and noisy inputs. It
                is also passed as the Perceiver's second argument (presumably
                the output shape; the unit test expects outputs of shape
                ``(T, *input_shape)``).
            latent_size: forwarded to ``Perceiver``.
            latent_count: forwarded to ``Perceiver``.
            key: ``jax.random`` PRNG key used for initialisation.
        """
        # Same key derivation as before, so a given ``key`` still yields the
        # same initial parameters.
        rng, perceiver_key = jax.random.split(key)
        _, start_key = jax.random.split(rng)

        diffusion_step_shape = (1,)
        # input_shape appears twice: once for the clean inputs shifted by one
        # timestep, once for the noisy inputs.
        stream_shapes = [diffusion_step_shape, *context_shapes, input_shape, input_shape]
        self.perceiver = Perceiver(stream_shapes, input_shape, latent_size, latent_count, perceiver_key)
        self.input_start = jax.random.normal(start_key, input_shape)

    def __call__(self, diffusion_step, context, input_, noisy_input, padding, key):
        """Run the wrapped Perceiver on one sequence.

        Args:
            diffusion_step: diffusion step. Either a scalar shared by all
                timesteps or one value per timestep, shaped ``(T,)`` or
                ``(T, 1)``.
            context: sequence of context arrays, ordered like the
                ``context_shapes`` given at construction.
            input_: clean inputs, shape ``(T, *input_shape)``.
            noisy_input: noisy inputs, shape ``(T, *input_shape)``. Its
                leading axis defines ``T``.
            padding: forwarded unchanged to the Perceiver.
            key: PRNG key forwarded unchanged to the Perceiver.

        Returns:
            The wrapped Perceiver's output.
        """
        num_timesteps = noisy_input.shape[0]
        # Match the declared per-timestep shape (1,):
        # scalar -> (T, 1), (T,) -> (T, 1), (T, 1) unchanged.
        step = jax.numpy.broadcast_to(
            jax.numpy.reshape(jax.numpy.asarray(diffusion_step), (-1, 1)),
            (num_timesteps, 1),
        )
        # Prepend the learned start array and drop the last clean input, so
        # slot t holds the clean input of timestep t-1. The last timestep of
        # input_ is therefore never used. Concatenating on axis 0 works for
        # any rank of input_shape.
        shifted_input = jax.numpy.concatenate([self.input_start[None], input_], axis=0)[:-1]
        return self.perceiver([step, *context, shifted_input, noisy_input], padding, key)
| import unittest | |
class TestDiffusionPerceiver(unittest.TestCase):
    """Smoke tests for ``DiffusionPerciever`` driven by random mock data."""

    def test_diffusion_perceiver(self):
        """Output has shape (T, *input_shape) and ignores the last clean input."""
        num_timesteps = 6
        context_shapes = [(10,), (20, 20, 3)]
        input_shape = (15, 3)
        latent_size = 64
        latent_count = 12

        # One independent key per random draw. Reusing a single key made
        # `inputs` and `noisy_input` bit-identical and tied both to the model
        # initialisation.
        init_key, input_key, noise_key, call_key, *context_keys = jax.random.split(
            jax.random.PRNGKey(0), 4 + len(context_shapes)
        )
        diffusion_perceiver = DiffusionPerciever(
            context_shapes, input_shape, latent_size, latent_count, init_key
        )

        # Mock context streams, clean inputs and noisy inputs, each with a
        # leading time axis.
        context = [
            jax.random.normal(k, (num_timesteps, *shape))
            for k, shape in zip(context_keys, context_shapes)
        ]
        inputs = jax.random.normal(input_key, (num_timesteps, *input_shape))
        noisy_input = jax.random.normal(noise_key, (num_timesteps, *input_shape))
        # All-ones padding, one entry per timestep (interpreted by Perceiver).
        padding = jax.numpy.ones(num_timesteps)
        # One diffusion step value per timestep.
        diffusion_step = jax.numpy.array([0] * num_timesteps)

        outputs = diffusion_perceiver(diffusion_step, context, inputs, noisy_input, padding, call_key)
        self.assertEqual(outputs.shape, (num_timesteps, *input_shape), "Output shape mismatch")

        # The one-step shift drops the last clean input, so changing it must
        # not change the output. The same call_key keeps any stochastic
        # layers identical between the two calls.
        perturbed_inputs = inputs.at[-1].add(1.0)
        outputs_perturbed = diffusion_perceiver(
            diffusion_step, context, perturbed_inputs, noisy_input, padding, call_key
        )
        self.assertTrue(
            bool(jax.numpy.allclose(outputs, outputs_perturbed)),
            "Last clean input leaked into the output",
        )
if __name__ == "__main__":
    # Executed as a script (not imported): run this module's unittest cases.
    unittest.main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.