Skip to content

Instantly share code, notes, and snippets.

@vyeevani
Created May 24, 2024 02:18
Show Gist options
  • Select an option

  • Save vyeevani/00127614e51841acb4d6cfc583de930a to your computer and use it in GitHub Desktop.

Select an option

Save vyeevani/00127614e51841acb4d6cfc583de930a to your computer and use it in GitHub Desktop.
Applies a Perceiver IO model capable of autoregressive generation to a diffusion model. Depends on: https://gist.github.com/vyeevani/aee668ad21b3e4744af26305455790a1
import jax
import einops
import equinox
from perceiver import Perceiver
class DiffusionPerciever(equinox.Module):
    """Wrap a ``Perceiver`` so it can be used as an autoregressive diffusion model.

    The wrapped perceiver is given, in this order: the diffusion step, every
    context stream, the inputs shifted one timestep backwards, and the noisy
    inputs.

    The class name keeps its original (misspelled) form so existing callers
    continue to work.
    """

    # The underlying Perceiver instance that does all of the computation.
    perceiver: equinox.Module
    # Randomly initialised item of shape ``input_shape`` that is placed in
    # front of the inputs when they are shifted one timestep backwards.
    input_start: jax.Array

    def __init__(self, context_shapes, input_shape, latent_size, latent_count, key):
        """Build the wrapped perceiver and the start-of-sequence item.

        Args:
            context_shapes: Sequence of per-timestep shapes, one per context
                stream (any iterable of shapes is accepted).
            input_shape: Per-timestep shape of the input being generated; also
                passed to ``Perceiver`` as its second argument.
            latent_size: Forwarded unchanged to ``Perceiver``.
            latent_count: Forwarded unchanged to ``Perceiver``.
            key: JAX PRNG key used to initialise the perceiver and
                ``input_start``.
        """
        # Split twice so the perceiver and the start item are initialised from
        # two different keys.
        carry_key, perceiver_key = jax.random.split(key)
        _, start_key = jax.random.split(carry_key)

        diffusion_step_shape = (1,)
        # input_shape is listed twice: once for the inputs shifted one timestep
        # backwards and once for the noisy inputs.
        self.perceiver = Perceiver(
            [diffusion_step_shape, *context_shapes, input_shape, input_shape],
            input_shape,
            latent_size,
            latent_count,
            perceiver_key,
        )
        self.input_start = jax.random.normal(start_key, input_shape)

    def __call__(self, diffusion_step, context, input_, noisy_input, padding, key):
        """Run the wrapped perceiver on one sequence.

        Args:
            diffusion_step: Diffusion step stream, passed first to the perceiver.
            context: Sequence of context arrays, one per entry of the
                ``context_shapes`` given at construction.
            input_: Clean inputs with a leading timestep axis; they are shifted
                one timestep backwards before being passed on.
            noisy_input: Noisy inputs, passed last to the perceiver.
            padding: Forwarded unchanged to the perceiver.
            key: JAX PRNG key, forwarded unchanged to the perceiver.

        Returns:
            Whatever the wrapped perceiver returns for these streams.
        """
        # One named axis per dimension of a single input item, so the pack
        # pattern follows the rank of input_shape instead of assuming rank 2.
        item_axes = [f"d{axis}" for axis in range(self.input_start.ndim)]
        pattern = " ".join(["*", *item_axes])

        # Prepend the start item to the inputs, then always drop the last
        # timestep so the result has as many timesteps as input_.
        packed, _ = einops.pack([self.input_start, input_], pattern)
        shifted_input = packed[:-1]

        return self.perceiver(
            [diffusion_step, *context, shifted_input, noisy_input],
            padding,
            key,
        )
import unittest
class TestDiffusionPerceiver(unittest.TestCase):
    """Shape smoke test for ``DiffusionPerciever``."""

    def test_diffusion_perceiver(self):
        """The output has one item of ``input_shape`` per timestep."""
        num_timesteps = 6
        context_shapes = [(10,), (20, 20, 3)]
        input_shape = (15, 3)
        latent_size = 64
        latent_count = 12

        # Give every consumer of randomness its own key. Reusing a single key
        # makes the draws correlated; in particular the clean and noisy inputs
        # would be bit-identical because they share a shape.
        keys = jax.random.split(jax.random.PRNGKey(0), 4 + len(context_shapes))
        model_key, inputs_key, noise_key, call_key = keys[:4]
        context_keys = keys[4:]

        diffusion_perceiver = DiffusionPerciever(
            context_shapes, input_shape, latent_size, latent_count, model_key
        )

        # Generate mock context, clean inputs and noisy inputs, each with a
        # leading timestep axis.
        context = [
            jax.random.normal(context_key, (num_timesteps, *shape))
            for context_key, shape in zip(context_keys, context_shapes)
        ]
        inputs = jax.random.normal(inputs_key, (num_timesteps, *input_shape))
        noisy_input = jax.random.normal(noise_key, (num_timesteps, *input_shape))
        padding = jax.numpy.ones(num_timesteps)  # one padding entry per timestep

        # Diffusion step 0 for every timestep.
        # NOTE(review): this has shape (num_timesteps,) while the model declares
        # a per-timestep diffusion step shape of (1,) - confirm the Perceiver
        # accepts this.
        diffusion_step = jax.numpy.array([0] * num_timesteps)

        outputs = diffusion_perceiver(
            diffusion_step, context, inputs, noisy_input, padding, call_key
        )

        # Validate the shape of the outputs
        self.assertEqual(outputs.shape, (num_timesteps, *input_shape), "Output shape mismatch")
# Run the test suite only when this file is executed as a script, not on import.
if __name__ == "__main__":
    unittest.main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment