A single-file Equinox implementation of a Perceiver-like architecture that supports arbitrarily shaped inputs and autoregressive generation.
import numpy as np
import jax
import equinox
import einops
import typing
def make_attention(query_key_dimension, value_dimension):
    scale = np.sqrt(query_key_dimension)
    def attention(query, key, value, mask):
        """
        query: [q_seq_len, dimension]
        key: [kv_seq_len, dimension]
        value: [kv_seq_len, dimension]
        mask: [q_seq_len, kv_seq_len]
        """
        assert query.shape[1] == query_key_dimension, "Query dimension mismatch"
        assert key.shape[1] == query_key_dimension, "Key dimension mismatch"
        assert value.shape[1] == value_dimension, "Value dimension mismatch"
        assert key.shape[0] == value.shape[0], "Key and Value sequence length mismatch"
        if mask is not None:
            assert mask.shape == (query.shape[0], key.shape[0]), "Mask dimension mismatch"
        # Checkpoint the score computation so the [q_seq_len, kv_seq_len] matrix is
        # rematerialized in the backward pass instead of being stored as a residual.
        qk_op = lambda query, key: einops.einsum(query, key, "q_seq_len d, kv_seq_len d -> q_seq_len kv_seq_len") / scale
        scores = jax.checkpoint(qk_op)(query, key)
        if mask is not None:
            scores = jax.checkpoint(jax.numpy.where)(mask == 0, -1e9, scores)
        softmax_scores = jax.nn.softmax(scores, axis=-1)
        qkv_op = lambda qk, value: einops.einsum(qk, value, "q_seq_len kv_seq_len, kv_seq_len d -> q_seq_len d")
        output = jax.checkpoint(qkv_op)(softmax_scores, value)
        return output
    return attention
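# Example for make_attention above (illustrative shapes only): attention = make_attention(32, 64);
# given query (4, 32), key (7, 32), value (7, 64) and mask (4, 7), attention(query, key, value, mask)
# returns an array of shape (4, 64).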
def make_multiheaded_attention(query_key_dimension, value_dimension, heads):
    attention = make_attention(query_key_dimension // heads, value_dimension // heads)
    def multi_headed_attention(query, key, value, mask):
        query = einops.rearrange(query, "s (h d) -> h s d", h=heads)
        key = einops.rearrange(key, "s (h d) -> h s d", h=heads)
        value = einops.rearrange(value, "s (h d) -> h s d", h=heads)
        output = jax.vmap(attention, in_axes=(0, 0, 0, None))(query, key, value, mask)
        output = einops.rearrange(output, "h s d -> s (h d)")
        return output
    return multi_headed_attention
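# Example for make_multiheaded_attention above: make_multiheaded_attention(64, 64, 8) splits a
# (seq, 64) query/key/value into 8 heads of width 8, runs single-head attention per head with one
# shared mask (vmap's in_axes=None), and re-concatenates the head outputs into a (seq, 64) array.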
class FourierFeatures(equinox.Module):
    kernel: jax.numpy.ndarray
    reshape: equinox.Module
    def __init__(self, input_size, output_size, key):
        self.kernel = jax.random.normal(key, (output_size // 2, input_size)) * 0.2
        self.reshape = equinox.nn.Lambda(lambda x: x.reshape(output_size))
    def __call__(self, x, key):
        f = 2 * jax.numpy.pi * einops.einsum(self.kernel, x, "o i, i -> o")
        return self.reshape(jax.numpy.concatenate([jax.numpy.cos(f), jax.numpy.sin(f)]))
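# Note: FourierFeatures above is a random Fourier feature embedding. A coordinate vector x is
# mapped to [cos(2*pi*B x), sin(2*pi*B x)] with B a randomly initialized (output_size // 2, input_size)
# kernel, producing an output_size-dimensional positional encoding. The `key` argument of __call__
# is accepted but unused; it only matches the (x, key) calling convention used when the module is
# vmapped inside Perceiver.__call__ below.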
class Transformer(equinox.Module):
    query_projector: jax.Array
    key_projector: jax.Array
    value_projector: jax.Array
    attention: typing.Callable = equinox.field(static=True)
    layer_norm_1: equinox.Module
    linear: equinox.Module
    layer_norm_2: equinox.Module
    def __init__(
        self,
        query_dimension,
        key_dimension,
        value_dimension,
        attention_dimension,
        heads,
        hidden,
        key
    ):
        rng = key
        rng, key = jax.random.split(rng)
        self.query_projector = jax.random.normal(key, (attention_dimension, query_dimension))
        rng, key = jax.random.split(rng)
        self.key_projector = jax.random.normal(key, (attention_dimension, key_dimension))
        rng, key = jax.random.split(rng)
        self.value_projector = jax.random.normal(key, (query_dimension, value_dimension))
        rng, key = jax.random.split(rng)
        self.attention = make_multiheaded_attention(attention_dimension, query_dimension, heads)
        rng, key = jax.random.split(rng)
        self.layer_norm_1 = equinox.nn.LayerNorm(query_dimension)
        rng, key_1, key_2 = jax.random.split(rng, 3)
        self.linear = equinox.nn.Sequential([
            equinox.nn.Linear(query_dimension, hidden, key=key_1),
            equinox.nn.Lambda(lambda x: jax.nn.relu(x)),
            equinox.nn.Linear(hidden, query_dimension, key=key_2)
        ])
        self.layer_norm_2 = equinox.nn.LayerNorm(query_dimension)
    def __call__(self, query, key, value, mask):
        query_attention_projected = einops.einsum(
            self.query_projector,
            query,
            "a q, s q -> s a"
        )
        key_attention_projected = einops.einsum(
            self.key_projector,
            key,
            "a k, s k -> s a"
        )
        value_query_projected = einops.einsum(
            self.value_projector,
            value,
            "q v, s v -> s q"
        )
        x = jax.vmap(self.layer_norm_1)(query + self.attention(query_attention_projected, key_attention_projected, value_query_projected, mask))
        x = jax.vmap(self.layer_norm_2)(x + jax.vmap(self.linear)(x))
        return x
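# Note: the Transformer above is a single (cross-)attention block rather than a full stack.
# Query and key are projected into a shared attention space, value is projected into the query
# dimension, multi-headed attention is applied, and the result is added back to the original
# query, followed by LayerNorm and a two-layer MLP with a second residual + LayerNorm.
# Shapes: query (s_q, query_dimension), key (s_kv, key_dimension), value (s_kv, value_dimension)
# -> output (s_q, query_dimension).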
class Perceiver(equinox.Module):
    input_max_dimensions: int
    latent_size: int
    latent_count: int
    latent: jax.Array
    positional_embedders: typing.List[equinox.Module]
    input_encoder: equinox.Module
    input_backbone: equinox.Module
    backbone: equinox.Module
    output_decoder: equinox.Module
    output_latent: jax.Array
    output_projector: jax.Array
    def __init__(self, input_shapes, output_shape, latent_size, latent_count, key):
        rng = key
        rng, key = jax.random.split(rng)
        self.latent_size = latent_size
        self.latent_count = latent_count
        self.latent = jax.random.normal(key, (self.latent_count, self.latent_size))
        # Positional features must cover the largest input rank (+1 for the time axis),
        # rounded up to an even number so the Fourier features split evenly into cos/sin.
        self.input_max_dimensions = max([len(shape) + 1 for shape in input_shapes])
        self.input_max_dimensions = self.input_max_dimensions + 1 if self.input_max_dimensions % 2 else self.input_max_dimensions
        self.positional_embedders = []
        for input_shape in input_shapes:
            rng, key = jax.random.split(rng)
            self.positional_embedders.append(FourierFeatures(len(input_shape) + 1, self.input_max_dimensions, key))
        rng, key = jax.random.split(rng)
        # Each input token carries its positional embedding plus one raw value channel,
        # so the key/value dimension of the input encoder is input_max_dimensions + 1.
        self.input_encoder = Transformer(self.latent_size, self.input_max_dimensions + 1, self.input_max_dimensions + 1, self.input_max_dimensions + 1, 1, (self.input_max_dimensions + 1) * 10, key)
        rng, key = jax.random.split(rng)
        self.input_backbone = Transformer(self.latent_size, self.latent_size, self.latent_size, self.latent_size, 8, self.latent_size * 10, key)
        rng, key = jax.random.split(rng)
        self.backbone = Transformer(self.latent_size, self.latent_size, self.latent_size, self.latent_size, 8, self.latent_size * 10, key)
        rng, key = jax.random.split(rng)
        self.output_decoder = Transformer(self.latent_size, self.latent_size, self.latent_size, self.latent_size, 8, self.latent_size * 10, key)
        rng, key = jax.random.split(rng)
        self.output_latent = jax.random.normal(key, (1, self.latent_size))
        rng, key = jax.random.split(rng)
        self.output_projector = jax.random.normal(key, (*output_shape, self.latent_size))
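    # Worked example (matching the test below): with input_shapes
    # [(10,), (20, 20, 3), (15, 3), (5, 4), (5, 4)], max(len(shape) + 1) = 4, which is already
    # even, so every positional embedding has 4 features and every input token is 4 + 1 = 5
    # dimensional regardless of which modality it came from.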
    def __call__(self, xs, padding, key):
        rng = key
        def create_coordinate_space(shape):
            """
            Create a coordinate space for an arbitrary shape.
            Parameters:
            - shape (tuple of ints): The shape of the desired coordinate space.
            Returns:
            - jax.numpy.ndarray: An array of coordinates, with the last dimension holding the coordinate values.
            """
            # Generate grids of indices along each dimension
            grids = jax.numpy.meshgrid(*[jax.numpy.arange(dim) for dim in shape], indexing='ij')
            # Stack the grids along a new last axis, so each "pixel" has a coordinate
            coordinates = jax.numpy.stack(grids, axis=-1)
            return coordinates
        def multiple_vmap(func, count, in_axes):
            """Apply `jax.vmap` to `func` `count` times, once per leading dimension."""
            for _ in range(count):
                func = jax.vmap(func, in_axes=in_axes)
            return func
        def input_project(x, positional_embedder):
            # Embed every coordinate of x, flatten the non-time axes into one token axis,
            # and concatenate the raw value as an extra channel on each token.
            positional_embeddings = einops.rearrange(multiple_vmap(positional_embedder, len(x.shape), (0, None))(create_coordinate_space(x.shape), None), "t ... d -> t (...) d")
            x = einops.rearrange(x, "t ... -> t (...) 1")
            return einops.pack([x, positional_embeddings], "t r *")[0]
        inputs = []
        for (input, embedder) in zip(xs, self.positional_embedders):
            projected_input = input_project(input, embedder)
            inputs.append(projected_input)
        inputs = einops.pack(inputs, "t * d")[0]
        timesteps = inputs.shape[0]
        tokens_per_timestep = inputs.shape[1]
        inputs = einops.rearrange(inputs, "t r d -> (t r) d")
        causal_input_encoding_mask = einops.rearrange(
            jax.numpy.tril(jax.numpy.ones((self.latent_count, tokens_per_timestep, timesteps, timesteps))),
            "c e t1 t2 -> (t1 c) (t2 e)"
        )
        padding_input_encoding_mask = einops.repeat(einops.einsum(padding, padding, "t1, t2 -> t1 t2"), "t1 t2 -> (t1 c) (t2 e)", c=self.latent_count, e=tokens_per_timestep)
        input_encoding_mask = jax.checkpoint(lambda: causal_input_encoding_mask * padding_input_encoding_mask)()  # regenerate the combined mask in the backward pass instead of storing the full matrix
        backbone_mask = jax.numpy.ones((timesteps, self.latent_count, self.latent_count))
        latent = einops.repeat(self.latent, "c d -> t c d", t=timesteps)
        latent = einops.rearrange(latent, "t c d -> (t c) d", t=timesteps)
        # Two rounds of cross-attending the latents to the input tokens, each followed by
        # per-timestep self-attention over the latents, then one more backbone pass.
        latent = self.input_encoder(latent, inputs, inputs, input_encoding_mask)
        latent = einops.rearrange(latent, "(t c) d -> t c d", t=timesteps)
        latent = jax.vmap(self.input_backbone)(latent, latent, latent, backbone_mask)
        latent = einops.rearrange(latent, "t c d -> (t c) d", t=timesteps)
        latent = self.input_encoder(latent, inputs, inputs, input_encoding_mask)
        latent = einops.rearrange(latent, "(t c) d -> t c d", t=timesteps)
        latent = jax.vmap(self.input_backbone)(latent, latent, latent, backbone_mask)
        latent = jax.vmap(self.backbone)(latent, latent, latent, backbone_mask)
        # Decode each timestep's latents into a single output token, then project to output_shape.
        output_mask = jax.numpy.ones((1, self.latent_count))
        outputs = jax.vmap(self.output_decoder, in_axes=(None, 0, 0, None))(self.output_latent, latent, latent, output_mask)
        outputs = einops.rearrange(outputs, "t s i -> t (s i)")
        outputs = einops.einsum(self.output_projector, outputs, "... i, t i -> t ...")
        return outputs
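# --- Autoregressive rollout (illustrative sketch, not part of the original gist) ---
# The causal timestep mask plus the `padding` vector make closed-loop generation possible:
# rerun the full forward pass each step and feed the prediction for step t back in as the
# input for step t + 1. The helper below is a hypothetical sketch; it assumes the *last*
# input modality has the same shape as the output and is the stream being generated, and
# that the remaining modalities are conditioning signals already filled in `seed_inputs`.
def autoregressive_rollout_sketch(perceiver, seed_inputs, num_steps, key):
    """
    seed_inputs: list of jax arrays shaped (horizon, *input_shape_i); step 0 must be filled in.
    Returns the last modality's buffer with `num_steps` generated timesteps written in.
    """
    horizon = seed_inputs[0].shape[0]
    inputs = list(seed_inputs)
    padding = jax.numpy.zeros(horizon).at[0].set(1.0)
    for t in range(min(num_steps, horizon - 1)):
        key, subkey = jax.random.split(key)
        prediction = perceiver(inputs, padding, subkey)[t]  # output for the latest valid step
        inputs[-1] = inputs[-1].at[t + 1].set(prediction)   # feed the prediction back in
        padding = padding.at[t + 1].set(1.0)                # mark the new step as valid
    return inputs[-1]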
import unittest
class TestPerceiver(unittest.TestCase):
    def test_perceiver_generic(self):
        num_timesteps = 6
        input_shapes = [(10,), (20, 20, 3), (15, 3), (5, 4), (5, 4)]
        output_shape = (5, 4)
        latent_size = 64
        latent_count = 12
        key = jax.random.PRNGKey(0)
        perceiver = Perceiver(input_shapes, output_shape, latent_size, latent_count, key)
        # Create mock inputs, one array of shape (num_timesteps, *shape) per modality
        x = [jax.random.normal(key, (num_timesteps, *shape)) for shape in input_shapes]
        padding = jax.numpy.ones(num_timesteps)  # all timesteps are valid (no padding)
        key, subkey = jax.random.split(key)
        # Call the Perceiver with the mock inputs
        outputs = perceiver(x, padding, subkey)
        # Check the output shape: one output of output_shape per timestep
        self.assertEqual(outputs.shape, (num_timesteps, *output_shape), "Output shape mismatch")
# Run the tests
if __name__ == '__main__':
    unittest.main()