random hodgepodge of models that I've written while working on robot stuff
import typing
from operator import mul
from functools import reduce, partial

import jax
import numpy as np
import einops
import equinox
import equinox as eqx
import eqxvision as eqv
from eqxvision.utils import CLASSIFICATION_URLS

# I've forced the remat to happen here but I'm honestly not sure if this is the best thing to do in all cases :|
def make_attention(query_key_dimension, value_dimension):
    scale = np.sqrt(query_key_dimension)
    def attention(query, key, value, mask):
        """
        query: [q_seq_len, dimension]
        key: [kv_seq_len, dimension]
        value: [kv_seq_len, dimension]
        mask: [q_seq_len, kv_seq_len], or None to attend everywhere
        """
        assert query.shape[1] == query_key_dimension, "Query dimension mismatch"
        assert key.shape[1] == query_key_dimension, "Key dimension mismatch"
        assert value.shape[1] == value_dimension, "Value dimension mismatch"
        assert key.shape[0] == value.shape[0], "Key and Value sequence length mismatch"
        # Scaled dot-product scores, rematerialized on the backward pass.
        qk_op = lambda query, key: einops.einsum(query, key, "q_seq_len d, kv_seq_len d -> q_seq_len kv_seq_len") / scale
        scores = jax.checkpoint(qk_op)(query, key)
        if mask is not None:
            # Only check the mask shape when a mask is actually given.
            assert mask.shape == (query.shape[0], key.shape[0]), "Mask dimension mismatch"
            scores = jax.checkpoint(jax.numpy.where)(mask == 0, -1e9, scores)
        softmax_scores = jax.nn.softmax(scores, axis=-1)
        # Bug fix: the original lambda ignored its qk argument and closed over
        # softmax_scores, defeating the point of passing it through jax.checkpoint.
        qkv_op = lambda qk, value: einops.einsum(qk, value, "q_seq_len kv_seq_len, kv_seq_len d -> q_seq_len d")
        output = jax.checkpoint(qkv_op)(softmax_scores, value)
        return output
    return attention
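
# A minimal smoke test (illustrative, not from the original gist): single-head
# attention on toy shapes with an all-ones mask. The helper name and sizes are
# made up for the example.
def _attention_demo():
    q_key, k_key, v_key = jax.random.split(jax.random.PRNGKey(0), 3)
    attn = make_attention(query_key_dimension=8, value_dimension=16)
    query = jax.random.normal(q_key, (5, 8))    # [q_seq_len, qk_dim]
    key = jax.random.normal(k_key, (7, 8))      # [kv_seq_len, qk_dim]
    value = jax.random.normal(v_key, (7, 16))   # [kv_seq_len, v_dim]
    mask = jax.numpy.ones((5, 7))               # 1 = attend, 0 = blocked
    output = attn(query, key, value, mask)
    assert output.shape == (5, 16)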

def make_multiheaded_attention(query_key_dimension, value_dimension, heads):
    # Both dimensions must be divisible by `heads`; each head operates on its
    # per-head slice and the results are concatenated back together.
    attention = make_attention(query_key_dimension // heads, value_dimension // heads)
    def multi_headed_attention(query, key, value, mask):
        query = einops.rearrange(query, "s (h d) -> h s d", h=heads)
        key = einops.rearrange(key, "s (h d) -> h s d", h=heads)
        value = einops.rearrange(value, "s (h d) -> h s d", h=heads)
        output = jax.vmap(attention, in_axes=(0, 0, 0, None))(query, key, value, mask)
        output = einops.rearrange(output, "h s d -> s (h d)")
        return output
    return multi_headed_attention
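
# Sketch of the multi-head wrapper (illustrative, toy sizes): self-attention
# here just passes the same x as query, key, and value.
def _multihead_demo():
    mha = make_multiheaded_attention(query_key_dimension=16, value_dimension=16, heads=4)
    x = jax.random.normal(jax.random.PRNGKey(0), (6, 16))
    mask = jax.numpy.ones((6, 6))
    output = mha(x, x, x, mask)
    assert output.shape == (6, 16)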

class Transformer(equinox.Module):
    """A single attention block with a post-attention MLP.

    The residual stream lives in the query dimension, so values are projected
    from value_dimension up to query_dimension before attention.
    """
    query_projector: jax.Array
    key_projector: jax.Array
    value_projector: jax.Array
    attention: typing.Callable = eqx.field(static=True)
    layer_norm_1: equinox.Module
    linear: equinox.Module
    layer_norm_2: equinox.Module

    def __init__(
        self,
        query_dimension,
        key_dimension,
        value_dimension,
        attention_dimension,
        heads,
        hidden,
        key
    ):
        rng = key
        rng, key = jax.random.split(rng)
        self.query_projector = jax.random.normal(key, (attention_dimension, query_dimension))
        rng, key = jax.random.split(rng)
        self.key_projector = jax.random.normal(key, (attention_dimension, key_dimension))
        rng, key = jax.random.split(rng)
        self.value_projector = jax.random.normal(key, (query_dimension, value_dimension))
        self.attention = make_multiheaded_attention(attention_dimension, query_dimension, heads)
        self.layer_norm_1 = equinox.nn.LayerNorm(query_dimension)
        rng, key_1, key_2 = jax.random.split(rng, 3)
        self.linear = equinox.nn.Sequential([
            equinox.nn.Linear(query_dimension, hidden, key=key_1),
            equinox.nn.Lambda(lambda x: jax.nn.relu(x)),
            equinox.nn.Linear(hidden, query_dimension, key=key_2)
        ])
        self.layer_norm_2 = equinox.nn.LayerNorm(query_dimension)

    def __call__(self, query, key, value, mask):
        query_attention_projected = einops.einsum(
            self.query_projector,
            query,
            "a q, s q -> s a"
        )
        key_attention_projected = einops.einsum(
            self.key_projector,
            key,
            "a k, s k -> s a"
        )
        value_query_projected = einops.einsum(
            self.value_projector,
            value,
            "q v, s v -> s q"
        )
        # Post-norm residual blocks: attention, then the feedforward MLP.
        x = jax.vmap(self.layer_norm_1)(query + self.attention(query_attention_projected, key_attention_projected, value_query_projected, mask))
        x = jax.vmap(self.layer_norm_2)(x + jax.vmap(self.linear)(x))
        return x
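
# Illustrative construction of a single block (sizes made up for the example):
# attention_dimension and query_dimension must both be divisible by heads.
def _transformer_demo():
    block = Transformer(
        query_dimension=32,
        key_dimension=24,
        value_dimension=24,
        attention_dimension=64,
        heads=4,
        hidden=128,
        key=jax.random.PRNGKey(0),
    )
    query = jax.numpy.zeros((5, 32))
    memory = jax.numpy.zeros((7, 24))
    mask = jax.numpy.ones((5, 7))
    assert block(query, memory, memory, mask).shape == (5, 32)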

class EncoderTransformer(equinox.Module):
    transformer: Transformer
    def __init__(self, transformer):
        self.transformer = transformer
    def __call__(self, x, mask):
        # Self-attention: query, key, and value are all x. (Bug fix: the
        # original passed x four times, one argument too many.)
        return self.transformer(x, x, x, mask)

class DecoderTransformer(equinox.Module):
    transformer: Transformer
    def __init__(self, transformer):
        self.transformer = transformer
    def __call__(self, target, memory, mask):
        # Cross-attention: queries come from the target, keys and values from
        # memory. (Bug fix: the original passed five arguments.)
        return self.transformer(target, memory, memory, mask)
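
# Illustrative encoder/decoder wiring (sizes made up): the encoder self-attends
# over its input, and the decoder cross-attends from the target into that memory.
def _encoder_decoder_demo():
    enc_key, dec_key = jax.random.split(jax.random.PRNGKey(0))
    encoder = EncoderTransformer(Transformer(32, 32, 32, 64, 4, 128, enc_key))
    decoder = DecoderTransformer(Transformer(32, 32, 32, 64, 4, 128, dec_key))
    x = jax.numpy.zeros((7, 32))
    target = jax.numpy.zeros((5, 32))
    memory = encoder(x, jax.numpy.ones((7, 7)))
    output = decoder(target, memory, jax.numpy.ones((5, 7)))
    assert output.shape == (5, 32)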

def any_shape_linear(input_shape, output_shape, key=jax.random.PRNGKey(0)):
    flattened_input_shape = reduce(mul, input_shape)
    flattened_output_shape = reduce(mul, output_shape)
    return eqx.nn.Sequential([
        eqx.nn.Lambda(lambda x: x.reshape(flattened_input_shape)),
        eqx.nn.Linear(flattened_input_shape, flattened_output_shape, key=key),
        eqx.nn.Lambda(lambda x: x.reshape(output_shape))
    ])

# 512 channels at 7x7 spatial resolution out of the VGG feature extractor.
VGG_CLASSIFIER_INPUT_SIZE = 512 * 7 * 7

def make_mlp(input_shape, output_shape, num_layers=4, hidden_size=100, key=jax.random.PRNGKey(0)):
    flattened_input_shape = reduce(mul, input_shape)
    flattened_output_shape = reduce(mul, output_shape)
    layers = []
    layers.append(
        eqx.nn.Lambda(lambda x: x.reshape(flattened_input_shape))
    )
    input_layer_key, key = jax.random.split(key)
    layers.append(
        eqx.nn.Linear(flattened_input_shape, hidden_size, key=input_layer_key)
    )
    for i in range(num_layers):
        layers.append(eqx.nn.Lambda(jax.nn.relu))
        hidden_layer_key, key = jax.random.split(key)
        layers.append(
            eqx.nn.Linear(hidden_size, hidden_size, key=hidden_layer_key)
        )
    layers.append(eqx.nn.Lambda(jax.nn.relu))
    output_layer_key, key = jax.random.split(key)
    layers.append(
        eqx.nn.Linear(hidden_size, flattened_output_shape, key=output_layer_key)
    )
    layers.append(
        eqx.nn.Lambda(lambda x: x.reshape(output_shape))
    )
    model = eqx.nn.Sequential(layers)
    return model

class Residual(eqx.Module):
    linear: eqx.Module
    def __init__(self, linear):
        self.linear = linear
    def __call__(self, x, key):
        return self.linear(x) + x

def make_resnet_mlp(input_shape, output_shape, num_layers=4, hidden_size=100, key=jax.random.PRNGKey(0)):
    flattened_input_shape = reduce(mul, input_shape)
    flattened_output_shape = reduce(mul, output_shape)
    layers = []
    layers.append(
        eqx.nn.Lambda(lambda x: x.reshape(flattened_input_shape))
    )
    input_layer_key, key = jax.random.split(key)
    layers.append(
        eqx.nn.Linear(flattened_input_shape, hidden_size, key=input_layer_key)
    )
    for i in range(num_layers):
        layers.append(eqx.nn.Lambda(jax.nn.relu))
        hidden_layer_key, key = jax.random.split(key)
        layers.append(
            Residual(eqx.nn.Linear(hidden_size, hidden_size, key=hidden_layer_key))
        )
    layers.append(eqx.nn.Lambda(jax.nn.relu))
    output_layer_key, key = jax.random.split(key)
    layers.append(
        eqx.nn.Linear(hidden_size, flattened_output_shape, key=output_layer_key)
    )
    layers.append(
        eqx.nn.Lambda(lambda x: x.reshape(output_shape))
    )
    model = eqx.nn.Sequential(layers)
    return model

def make_cnn(input_shape, output_shape, num_layers, filters=64, kernel_size=3, key=jax.random.PRNGKey(0)):
    """
    expects a channels-first input_shape
    """
    layers = []
    def reshape(x, input_shape):
        return x.reshape(input_shape)
    layers.append(eqx.nn.Lambda(partial(reshape, input_shape=input_shape)))
    for _ in range(num_layers):
        cnn_key, key = jax.random.split(key)
        layers.append(eqx.nn.Conv2d(input_shape[0], filters, kernel_size=kernel_size, padding=(1, 1), key=cnn_key))
        layers.append(eqx.nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2)))
        input_shape = [filters, input_shape[1] // 2, input_shape[2] // 2]
    layers.append(any_shape_linear(input_shape, output_shape, key=key))
    model = eqx.nn.Sequential(layers)
    return model

def make_resnet_cnn(input_shape, output_shape, num_conv_layers=2, num_fc_layers=2, filters=64, kernel_size=3, hidden_size=100, key=jax.random.PRNGKey(0)):
    """
    expects channels to be the last dimension, both in input_shape and in __call__
    """
    layers = []
    # Move channels first, which is what eqx.nn.Conv2d expects. Bug fix: this
    # must be a transpose, not a reshape, or the spatial layout gets scrambled.
    input_shape = [input_shape[2], input_shape[0], input_shape[1]]
    layers.append(eqx.nn.Lambda(lambda x: einops.rearrange(x, "h w c -> c h w")))
    for i in range(num_conv_layers):
        cnn_key, key = jax.random.split(key)
        if i == 0:
            # The first conv changes the channel count, so it can't be residual.
            layers.append(eqx.nn.Conv2d(input_shape[0], filters, kernel_size=kernel_size, padding=(1, 1), key=cnn_key))
        else:
            layers.append(Residual(eqx.nn.Conv2d(input_shape[0], filters, kernel_size=kernel_size, padding=(1, 1), key=cnn_key)))
        layers.append(eqx.nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2)))
        input_shape = [filters, input_shape[1] // 2, input_shape[2] // 2]  # Update input shape for the next layer
    layers.append(make_resnet_mlp(input_shape, output_shape, num_fc_layers, hidden_size=hidden_size, key=key))
    # Combine all layers
    model = eqx.nn.Sequential(layers)
    return model
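
# Illustrative usage (toy sizes, not from the original gist): unlike make_cnn,
# make_resnet_cnn takes channels-last shapes and inputs.
def _resnet_cnn_demo():
    model = make_resnet_cnn((64, 64, 3), (10,), num_conv_layers=2, num_fc_layers=2)
    image = jax.numpy.zeros((64, 64, 3))
    assert model(image).shape == (10,)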

class NNTuple(eqx.Module):
    """Applies a list of modules elementwise to a tuple of inputs."""
    modules: list[eqx.Module]
    def __init__(self, modules):
        self.modules = modules
    def __call__(self, xs, key):
        return [
            module(x, key=key) for (module, x) in zip(self.modules, xs)
        ]

# The input must have at least one dimension: shape (1,) instead of ().
class FourierFeatures(eqx.Module):
    kernel: jax.numpy.ndarray
    reshape: eqx.Module
    def __init__(self, input_size, output_size, key):
        # Random Fourier projection; the cos and sin halves concatenate to output_size.
        self.kernel = jax.random.normal(key, (output_size // 2, input_size)) * 0.2
        self.reshape = eqx.nn.Lambda(lambda x: x.reshape(output_size))
    def __call__(self, x, key):
        f = 2 * jax.numpy.pi * einops.einsum(self.kernel, x, "o i, i -> o")
        return self.reshape(jax.numpy.concatenate([jax.numpy.cos(f), jax.numpy.sin(f)]))
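
# Illustrative usage: per the note above, a scalar input (e.g. a diffusion
# timestep) must be wrapped to shape (1,) before encoding.
def _fourier_demo():
    ff = FourierFeatures(input_size=1, output_size=64, key=jax.random.PRNGKey(0))
    t = jax.numpy.array([0.5])
    assert ff(t, key=None).shape == (64,)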

def make_diffusion_model(encoders, diffusion):
    # Encode each conditioning input, concatenate, and feed the diffusion net.
    return eqx.nn.Sequential(
        layers=[
            NNTuple(encoders),
            eqx.nn.Lambda(lambda encoded_values: jax.numpy.concatenate(encoded_values)),
            diffusion,
        ]
    )

def make_vgg_model(classifier, filename=None):
    model, state = eqv.models.vgg11(torch_weights=CLASSIFICATION_URLS["vgg11"])
    model = eqx.tree_at(
        lambda m: m.classifier,
        model,
        classifier
    )
    if filename:
        model = load_model(model, filename)
    return model

# Accepts a VGG model and returns a (dynamic, static) pair in which everything
# is frozen except the fine-tuned classifier head.
def make_classifier_finetuned_vgg(model):
    tree_struct = jax.tree_util.tree_map(lambda _: False, model)
    filter_spec = eqx.tree_at(
        lambda tree: [layer for layer in tree.classifier.layers],
        tree_struct,
        replace=[True for _ in range(len(model.classifier.layers))],
    )
    dynamic_model, static_model = eqx.partition(model, filter_spec)
    return dynamic_model, static_model

def make_fully_unfrozen_model(model):
    filter_spec = jax.tree_util.tree_map(lambda _: True, model)
    dynamic_model, static_model = eqx.partition(model, filter_spec)
    return dynamic_model, static_model
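
# Sketch of how the (dynamic, static) pair is typically consumed (this is the
# standard equinox fine-tuning pattern, not code from the gist): gradients are
# taken only with respect to the dynamic half, and the frozen half is recombined
# inside the loss. _finetune_loss and the squared error are made up for the example.
@eqx.filter_grad
def _finetune_loss(dynamic_model, static_model, x, y):
    model = eqx.combine(dynamic_model, static_model)
    prediction = jax.vmap(model)(x)
    return jax.numpy.mean((prediction - y) ** 2)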

def save_model(model, filename):
    eqx.tree_serialise_leaves(filename, model)

def load_model(model, filename):
    with open(filename, "rb") as f:
        return eqx.tree_deserialise_leaves(f, model)

import unittest

class TestModel(unittest.TestCase):
    def test_make_and_save_model(self):
        model = make_mlp((3, 224, 224), (4, 4), 100, key=jax.random.PRNGKey(1))
        state = eqx.nn.State(model)
        self.assertIsNotNone(model)
        self.assertIsNotNone(state)
        filename = "/tmp/model.eqx"
        save_model(model, filename)
        loaded_model = load_model(make_mlp((3, 224, 224), (4, 4), 100, key=jax.random.PRNGKey(1)), filename)
        loaded_state = eqx.nn.State(loaded_model)
        num_iter = 10
        for i in range(num_iter):
            sample_input = jax.random.normal(jax.random.PRNGKey(i), (3, 224, 224))
            key = jax.random.PRNGKey(i + num_iter)
            model_output, _ = model(sample_input, state, key=key)
            loaded_model_output, _ = loaded_model(sample_input, loaded_state, key=key)
            self.assertTrue(jax.numpy.allclose(model_output, loaded_model_output))

    def test_fully_unfrozen_model(self):
        model = make_mlp((3, 224, 224), (4, 4), 100, key=jax.random.PRNGKey(1))
        dynamic_model, static_model = make_fully_unfrozen_model(model)
        model = eqx.combine(dynamic_model, static_model)
        state = eqx.nn.State(model)
        self.assertIsNotNone(model)
        self.assertIsNotNone(state)
        filename = "/tmp/model.eqx"
        save_model(model, filename)
        loaded_model = load_model(make_mlp((3, 224, 224), (4, 4), 100, key=jax.random.PRNGKey(1)), filename)
        # Bug fix: partition the loaded model, not the original.
        loaded_dynamic_model, loaded_static_model = make_fully_unfrozen_model(loaded_model)
        loaded_model = eqx.combine(loaded_dynamic_model, loaded_static_model)
        loaded_state = eqx.nn.State(loaded_model)
        num_iter = 10
        for i in range(num_iter):
            sample_input = jax.random.normal(jax.random.PRNGKey(i), (3, 224, 224))
            key = jax.random.PRNGKey(i + num_iter)
            model_output, _ = model(sample_input, state, key=key)
            loaded_model_output, _ = loaded_model(sample_input, loaded_state, key=key)
            self.assertTrue(jax.numpy.allclose(model_output, loaded_model_output))

    def test_cnn(self):
        image_shape = (16, 512, 512, 3)
        output_shape = (16, 3,)
        fake_image = jax.random.normal(jax.random.PRNGKey(0), image_shape)
        # The images are channels-last and the keyword arguments match
        # make_resnet_cnn's signature, so call that rather than make_cnn.
        model = make_resnet_cnn(image_shape[1:], output_shape[1:], num_conv_layers=2, num_fc_layers=2, filters=64, kernel_size=3, hidden_size=100, key=jax.random.PRNGKey(0))
        state = eqx.nn.State(model)
        model_output, _ = eqx.filter_vmap(model)(fake_image, state)
        self.assertEqual(model_output.shape, output_shape)
        self.assertIsNotNone(model_output)

    def test_vgg(self):
        mlp_model = make_mlp((VGG_CLASSIFIER_INPUT_SIZE,), (4, 4), 100, key=jax.random.PRNGKey(1))
        model = make_vgg_model(mlp_model)
        state = eqx.nn.State(model)
        self.assertIsNotNone(model)
        self.assertIsNotNone(state)
        filename = "/tmp/model.eqx"
        save_model(model, filename)
        loaded_model = load_model(make_vgg_model(make_mlp((VGG_CLASSIFIER_INPUT_SIZE,), (4, 4), 100, key=jax.random.PRNGKey(1))), filename)
        loaded_state = eqx.nn.State(loaded_model)
        num_iter = 10
        for i in range(num_iter):
            sample_input = jax.random.normal(jax.random.PRNGKey(i), (3, 224, 224))
            key = jax.random.PRNGKey(i + num_iter)
            model_output, _ = model(sample_input, state, key=key)
            loaded_model_output, _ = loaded_model(sample_input, loaded_state, key=key)
            self.assertTrue(jax.numpy.allclose(model_output, loaded_model_output))

    def test_finetuned_vgg(self):
        mlp_model = make_mlp((VGG_CLASSIFIER_INPUT_SIZE,), (4, 4), 100, key=jax.random.PRNGKey(1))
        model = make_vgg_model(mlp_model)
        dynamic_model, static_model = make_classifier_finetuned_vgg(model)
        model = eqx.combine(dynamic_model, static_model)
        state = eqx.nn.State(model)
        self.assertIsNotNone(model)
        self.assertIsNotNone(state)
        filename = "/tmp/model.eqx"
        save_model(model, filename)
        loaded_model = load_model(make_vgg_model(make_mlp((VGG_CLASSIFIER_INPUT_SIZE,), (4, 4), 100, key=jax.random.PRNGKey(1))), filename)
        # Bug fix: partition the loaded model, not the original.
        loaded_dynamic_model, loaded_static_model = make_classifier_finetuned_vgg(loaded_model)
        loaded_model = eqx.combine(loaded_dynamic_model, loaded_static_model)
        loaded_state = eqx.nn.State(loaded_model)
        num_iter = 10
        for i in range(num_iter):
            sample_input = jax.random.normal(jax.random.PRNGKey(i), (3, 224, 224))
            key = jax.random.PRNGKey(i + num_iter)
            model_output, _ = model(sample_input, state, key=key)
            loaded_model_output, _ = loaded_model(sample_input, loaded_state, key=key)
            self.assertTrue(jax.numpy.allclose(model_output, loaded_model_output))

    def test_diffusion_model(self):
        IMAGE_SIZE = (1024, 1024, 3)
        LOOKAHEAD_STEPS = 4
        DIFFUSION_STEPS = 6
        POSE_SIZE = (4, 4)
        # IMAGE_SIZE is channels-last and the keyword arguments match
        # make_resnet_cnn's signature, so call that rather than make_cnn.
        image_encoder_model = make_resnet_cnn(input_shape=IMAGE_SIZE, output_shape=(100,), num_conv_layers=2, num_fc_layers=2)
        pose_encoder_model = make_mlp(input_shape=(LOOKAHEAD_STEPS, *POSE_SIZE), output_shape=(100,))
        diffusion_model = make_mlp(input_shape=(200,), output_shape=(LOOKAHEAD_STEPS, *POSE_SIZE))
        model = make_diffusion_model([image_encoder_model, pose_encoder_model], diffusion_model)
        # Bug fix: assertEqual rather than assertTrue, which treated the
        # expected shape as a failure message and always passed.
        self.assertEqual(model((
            jax.numpy.zeros(IMAGE_SIZE),
            jax.numpy.zeros((LOOKAHEAD_STEPS, *POSE_SIZE))
        )).shape, (LOOKAHEAD_STEPS, *POSE_SIZE))

if __name__ == '__main__':
    unittest.main()