@vyeevani
Last active May 24, 2024 02:17
A single-file Equinox implementation of a Perceiver-like architecture that supports arbitrary input shapes and autoregressive generation.
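# Pipeline: each input of arbitrary shape is flattened into tokens of
# (raw value, Fourier positional features), cross-attended into a fixed set of
# learned latents per timestep under a causal timestep mask, refined by latent
# self-attention backbones, and decoded by cross-attending a learned output
# latent that is then projected to the output shape.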
import numpy as np
import jax
import equinox
import einops
import typing
def make_attention(query_key_dimension, value_dimension):
    scale = np.sqrt(query_key_dimension)
    def attention(query, key, value, mask):
        """
        query: [q_seq_len, dimension]
        key: [kv_seq_len, dimension]
        value: [kv_seq_len, dimension]
        mask: [q_seq_len, kv_seq_len]
        """
        assert query.shape[1] == query_key_dimension, "Query dimension mismatch"
        assert key.shape[1] == query_key_dimension, "Key dimension mismatch"
        assert value.shape[1] == value_dimension, "Value dimension mismatch"
        assert key.shape[0] == value.shape[0], "Key and Value sequence length mismatch"
        assert mask.shape == (query.shape[0], key.shape[0]), "Mask dimension mismatch"
        # Checkpoint the score and output einsums so the large intermediate
        # matrices are recomputed in the backward pass instead of being stored.
        qk_op = lambda query, key: einops.einsum(query, key, "q_seq_len d, kv_seq_len d -> q_seq_len kv_seq_len") / scale
        # scores = einops.einsum(query, key, "q_seq_len d, kv_seq_len d -> q_seq_len kv_seq_len") / scale
        scores = jax.checkpoint(qk_op)(query, key)
        if mask is not None:
            scores = jax.checkpoint(jax.numpy.where)(mask == 0, -1e9, scores)
        softmax_scores = jax.nn.softmax(scores, axis=-1)
        qkv_op = lambda qk, value: einops.einsum(qk, value, "q_seq_len kv_seq_len, kv_seq_len d -> q_seq_len d")
        output = jax.checkpoint(qkv_op)(softmax_scores, value)
        # output = einops.einsum(softmax_scores, value, "q_seq_len kv_seq_len, kv_seq_len d -> q_seq_len d")
        return output
    return attention
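# A minimal usage sketch of the attention closure; the sequence lengths, feature
# sizes, and the helper name `_attention_example` are illustrative only.
def _attention_example():
    attention = make_attention(8, 8)
    query = jax.numpy.ones((4, 8))   # [q_seq_len, query_key_dimension]
    key = jax.numpy.ones((6, 8))     # [kv_seq_len, query_key_dimension]
    value = jax.numpy.ones((6, 8))   # [kv_seq_len, value_dimension]
    mask = jax.numpy.ones((4, 6))    # [q_seq_len, kv_seq_len]; zero entries are masked out
    return attention(query, key, value, mask)  # -> [4, 8]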
def make_multiheaded_attention(query_key_dimension, value_dimension, heads):
    attention = make_attention(query_key_dimension // heads, value_dimension // heads)
    def multi_headed_attention(query, key, value, mask):
        query = einops.rearrange(query, "s (h d) -> h s d", h=heads)
        key = einops.rearrange(key, "s (h d) -> h s d", h=heads)
        value = einops.rearrange(value, "s (h d) -> h s d", h=heads)
        output = jax.vmap(attention, in_axes=(0, 0, 0, None))(query, key, value, mask)
        output = einops.rearrange(output, "h s d -> s (h d)")
        return output
    return multi_headed_attention
class FourierFeatures(equinox.Module):
    kernel: jax.numpy.ndarray
    reshape: equinox.Module
    def __init__(self, input_size, output_size, key):
        self.kernel = jax.random.normal(key, (output_size // 2, input_size)) * 0.2
        self.reshape = equinox.nn.Lambda(lambda x: x.reshape(output_size))
    def __call__(self, x, key):
        f = 2 * jax.numpy.pi * einops.einsum(self.kernel, x, "o i, i -> o")
        return self.reshape(jax.numpy.concatenate([jax.numpy.cos(f), jax.numpy.sin(f)]))
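# A minimal sketch of FourierFeatures on a single coordinate vector; the sizes and
# the helper name `_fourier_features_example` are illustrative only.
def _fourier_features_example():
    embedder = FourierFeatures(input_size=2, output_size=8, key=jax.random.PRNGKey(0))
    coordinate = jax.numpy.array([3.0, 1.0])  # e.g. a (row, column) position
    return embedder(coordinate, None)         # -> [8], concatenated cos/sin features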
class Transformer(equinox.Module):
    query_projector: jax.Array
    key_projector: jax.Array
    value_projector: jax.Array
    attention: typing.Callable = equinox.field(static=True)
    layer_norm_1: equinox.Module
    linear: equinox.Module
    layer_norm_2: equinox.Module
    def __init__(
        self,
        query_dimension,
        key_dimension,
        value_dimension,
        attention_dimension,
        heads,
        hidden,
        key
    ):
        rng = key
        rng, key = jax.random.split(rng)
        self.query_projector = jax.random.normal(key, (attention_dimension, query_dimension))
        rng, key = jax.random.split(rng)
        self.key_projector = jax.random.normal(key, (attention_dimension, key_dimension))
        rng, key = jax.random.split(rng)
        self.value_projector = jax.random.normal(key, (query_dimension, value_dimension))
        rng, key = jax.random.split(rng)
        self.attention = make_multiheaded_attention(attention_dimension, query_dimension, heads)
        rng, key = jax.random.split(rng)
        self.layer_norm_1 = equinox.nn.LayerNorm(query_dimension)
        rng, key_1, key_2 = jax.random.split(rng, 3)
        self.linear = equinox.nn.Sequential([
            equinox.nn.Linear(query_dimension, hidden, key=key_1),
            equinox.nn.Lambda(lambda x: jax.nn.relu(x)),
            equinox.nn.Linear(hidden, query_dimension, key=key_2)
        ])
        self.layer_norm_2 = equinox.nn.LayerNorm(query_dimension)
    def __call__(self, query, key, value, mask):
        # Project queries and keys into the attention dimension, and values into the
        # query (residual) dimension, before attending.
        query_attention_projected = einops.einsum(
            self.query_projector,
            query,
            "a q, s q -> s a"
        )
        key_attention_projected = einops.einsum(
            self.key_projector,
            key,
            "a k, s k -> s a"
        )
        value_query_projected = einops.einsum(
            self.value_projector,
            value,
            "q v, s v -> s q"
        )
        x = jax.vmap(self.layer_norm_1)(query + self.attention(query_attention_projected, key_attention_projected, value_query_projected, mask))
        x = jax.vmap(self.layer_norm_2)(x + jax.vmap(self.linear)(x))
        return x
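# A minimal sketch of the Transformer block used for cross-attention (queries in one
# feature space, keys/values in another); sizes and the helper name are illustrative.
def _transformer_example():
    block = Transformer(
        query_dimension=16, key_dimension=8, value_dimension=8,
        attention_dimension=16, heads=2, hidden=32, key=jax.random.PRNGKey(0),
    )
    latents = jax.numpy.ones((4, 16))  # query side: 4 latent tokens
    tokens = jax.numpy.ones((10, 8))   # key/value side: 10 input tokens
    mask = jax.numpy.ones((4, 10))     # every latent may attend to every token
    return block(latents, tokens, tokens, mask)  # -> [4, 16]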
class Perceiver(equinox.Module):
    input_max_dimensions: int
    latent_size: int
    latent_count: int
    latent: jax.Array
    positional_embedders: typing.List[equinox.Module]
    input_encoder: equinox.Module
    input_backbone: equinox.Module
    backbone: equinox.Module
    output_decoder: equinox.Module
    output_latent: jax.Array
    output_projector: jax.Array
    def __init__(self, input_shapes, output_shape, latent_size, latent_count, key):
        rng = key
        rng, key = jax.random.split(rng)
        self.latent_size = latent_size
        self.latent_count = latent_count
        self.latent = jax.random.normal(key, (self.latent_count, self.latent_size))
        self.input_max_dimensions = max([len(shape) + 1 for shape in input_shapes])
        self.input_max_dimensions = self.input_max_dimensions + 1 if self.input_max_dimensions % 2 else self.input_max_dimensions
        self.positional_embedders = []
        for input_shape in input_shapes:
            rng, key = jax.random.split(rng)
            self.positional_embedders.append(FourierFeatures(len(input_shape) + 1, self.input_max_dimensions, key))
        rng, key = jax.random.split(rng)
        # the input token dimension is the positional embedding size plus the scalar value from the byte array
        self.input_encoder = Transformer(self.latent_size, self.input_max_dimensions + 1, self.input_max_dimensions + 1, self.input_max_dimensions + 1, 1, (self.input_max_dimensions + 1) * 10, key)
        rng, key = jax.random.split(rng)
        self.input_backbone = Transformer(self.latent_size, self.latent_size, self.latent_size, self.latent_size, 8, self.latent_size * 10, key)
        rng, key = jax.random.split(rng)
        self.backbone = Transformer(self.latent_size, self.latent_size, self.latent_size, self.latent_size, 8, self.latent_size * 10, key)
        rng, key = jax.random.split(rng)
        self.output_decoder = Transformer(self.latent_size, self.latent_size, self.latent_size, self.latent_size, 8, self.latent_size * 10, key)
        rng, key = jax.random.split(rng)
        self.output_latent = jax.random.normal(key, (1, self.latent_size))
        rng, key = jax.random.split(rng)
        self.output_projector = jax.random.normal(key, (*output_shape, self.latent_size))
    def __call__(self, xs, padding, key):
        rng = key
        def create_coordinate_space(shape):
            """
            Create a coordinate space for an arbitrary shape.
            Parameters:
            - shape (tuple of ints): The shape of the desired coordinate space.
            Returns:
            - jax.numpy.ndarray: An array of coordinates, with the last dimension representing the coordinate values.
            """
            # Generate grids of indices along each dimension
            grids = jax.numpy.meshgrid(*[jax.numpy.arange(dim) for dim in shape], indexing='ij')
            # Stack the grids along a new last axis, so each "pixel" has a coordinate
            coordinates = jax.numpy.stack(grids, axis=-1)
            return coordinates
        def multiple_vmap(func, count, in_axes):
            """Apply `jax.vmap` to `func` `count` times, once per leading dimension."""
            for _ in range(count):
                func = jax.vmap(func, in_axes=in_axes)
            return func
        def input_project(x, positional_embedder):
            # Embed every coordinate, flatten all non-time dimensions, and append the raw
            # value to its positional features to form one token per input element.
            positional_embeddings = einops.rearrange(multiple_vmap(positional_embedder, len(x.shape), (0, None))(create_coordinate_space(x.shape), None), "t ... d -> t (...) d")
            x = einops.rearrange(x, "t ... -> t (...) 1")
            return einops.pack([x, positional_embeddings], "t r *")[0]
        inputs = []
        for (input, embedder) in zip(xs, self.positional_embedders):
            projected_input = input_project(input, embedder)
            inputs.append(projected_input)
        inputs = einops.pack(inputs, "t * d")[0]
        timesteps = inputs.shape[0]
        tokens_per_timestep = inputs.shape[1]
        inputs = einops.rearrange(inputs, "t r d -> (t r) d")
        # Causal mask over timesteps: latents at timestep t1 may only attend to
        # input tokens from timesteps t2 <= t1.
        causal_input_encoding_mask = einops.rearrange(
            jax.numpy.tril(jax.numpy.ones((self.latent_count, tokens_per_timestep, timesteps, timesteps))),
            "c e t1 t2 -> (t1 c) (t2 e)"
        )
        padding_input_encoding_mask = einops.repeat(einops.einsum(padding, padding, "t1, t2 -> t1 t2"), "t1 t2 -> (t1 c) (t2 e)", c=self.latent_count, e=tokens_per_timestep)
        input_encoding_mask = jax.checkpoint(lambda: causal_input_encoding_mask * padding_input_encoding_mask)()  # regenerate the input encoding mask under checkpointing to avoid keeping around a massive matrix
        backbone_mask = jax.numpy.ones((timesteps, self.latent_count, self.latent_count))
        latent = einops.repeat(self.latent, "c d -> t c d", t=timesteps)
        latent = einops.rearrange(latent, "t c d -> (t c) d", t=timesteps)
        # Two rounds of cross-attention into the latents, each followed by latent self-attention.
        latent = self.input_encoder(latent, inputs, inputs, input_encoding_mask)
        latent = einops.rearrange(latent, "(t c) d -> t c d", t=timesteps)
        latent = jax.vmap(self.input_backbone)(latent, latent, latent, backbone_mask)
        latent = einops.rearrange(latent, "t c d -> (t c) d", t=timesteps)
        latent = self.input_encoder(latent, inputs, inputs, input_encoding_mask)
        latent = einops.rearrange(latent, "(t c) d -> t c d", t=timesteps)
        latent = jax.vmap(self.input_backbone)(latent, latent, latent, backbone_mask)
        latent = jax.vmap(self.backbone)(latent, latent, latent, backbone_mask)
        # Decode by cross-attending a single learned output latent against each timestep's
        # latents, then project the result to the output shape.
        output_mask = jax.numpy.ones((1, self.latent_count))
        outputs = jax.vmap(self.output_decoder, in_axes=(None, 0, 0, None))(self.output_latent, latent, latent, output_mask)
        outputs = einops.rearrange(outputs, "t s i -> t (s i)")
        outputs = einops.einsum(self.output_projector, outputs, "... i, t i -> t ...")
        return outputs
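# A hedged sketch of autoregressive rollout using the causal timestep mask: the output
# at step t is fed back as the last input stream at step t + 1. The feedback wiring and
# the helper name `_autoregressive_rollout_example` are assumptions for illustration;
# each call re-encodes the full sequence (nothing is cached).
def _autoregressive_rollout_example(perceiver, initial_inputs, steps, key):
    # initial_inputs: one array per input stream, each shaped (1, *input_shape); the last
    # stream is assumed to share output_shape so predictions can be written back into it.
    inputs = [jax.numpy.concatenate([x, jax.numpy.zeros((steps,) + x.shape[1:])]) for x in initial_inputs]
    for t in range(steps):
        padding = (jax.numpy.arange(1 + steps) <= t).astype(jax.numpy.float32)  # mask out future steps
        outputs = perceiver(inputs, padding, key)           # (1 + steps, *output_shape)
        inputs[-1] = inputs[-1].at[t + 1].set(outputs[t])   # feed the prediction into the next step
    return inputs[-1][1:]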
import unittest
class TestPerceiver(unittest.TestCase):
    def test_perceiver_generic(self):
        num_timesteps = 6
        input_shapes = [(10,), (20, 20, 3), (15, 3), (5, 4), (5, 4)]
        output_shape = (5, 4)
        latent_size = 64
        latent_count = 12
        key = jax.random.PRNGKey(0)
        perceiver = Perceiver(input_shapes, output_shape, latent_size, latent_count, key)
        # Create mock inputs
        x = [jax.random.normal(key, (num_timesteps, *shape)) for shape in input_shapes]
        padding = jax.numpy.ones(num_timesteps)  # Generate padding based on num_timesteps
        rng, key = jax.random.split(key)
        # Call the Perceiver with mock inputs
        outputs = perceiver(x, padding, key)
        # Check outputs
        self.assertEqual(outputs.shape, (num_timesteps, *output_shape), "Output shape mismatch")
# Run the tests
if __name__ == '__main__':
    unittest.main()