random hodgepodge of models that may or may not be working
import typing
from operator import mul
from functools import partial, reduce

import einops
import equinox as eqx
import eqxvision as eqv
import jax
import numpy as np
from eqxvision.utils import CLASSIFICATION_URLS
# I've forced the remat to happen here but I'm honestly not sure if this is the best thing to do in all cases :|
def make_attention(query_key_dimension, value_dimension):
    scale = np.sqrt(query_key_dimension)
    def attention(query, key, value, mask):
        """
        query: [q_seq_len, dimension]
        key: [kv_seq_len, dimension]
        value: [kv_seq_len, dimension]
        mask: [q_seq_len, kv_seq_len] or None
        """
        assert query.shape[1] == query_key_dimension, "Query dimension mismatch"
        assert key.shape[1] == query_key_dimension, "Key dimension mismatch"
        assert value.shape[1] == value_dimension, "Value dimension mismatch"
        assert key.shape[0] == value.shape[0], "Key and Value sequence length mismatch"
        if mask is not None:
            assert mask.shape == (query.shape[0], key.shape[0]), "Mask dimension mismatch"
        # checkpointed so the [q_seq_len, kv_seq_len] score matrix is rematerialized on the backward pass
        qk_op = lambda query, key: einops.einsum(query, key, "q_seq_len d, kv_seq_len d -> q_seq_len kv_seq_len") / scale
        scores = jax.checkpoint(qk_op)(query, key)
        if mask is not None:
            scores = jax.checkpoint(jax.numpy.where)(mask == 0, -1e9, scores)
        softmax_scores = jax.nn.softmax(scores, axis=-1)
        qkv_op = lambda qk, value: einops.einsum(qk, value, "q_seq_len kv_seq_len, kv_seq_len d -> q_seq_len d")
        output = jax.checkpoint(qkv_op)(softmax_scores, value)
        return output
    return attention
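# A minimal usage sketch for make_attention (not part of the original gist);
# the sequence lengths and dimensions are placeholders chosen only to show the
# expected layout. mask uses 1 = attend, 0 = masked out.
def _example_attention():
    attention = make_attention(query_key_dimension=16, value_dimension=32)
    query = jax.random.normal(jax.random.PRNGKey(0), (8, 16))
    key = jax.random.normal(jax.random.PRNGKey(1), (10, 16))
    value = jax.random.normal(jax.random.PRNGKey(2), (10, 32))
    mask = jax.numpy.ones((8, 10))
    output = attention(query, key, value, mask)
    assert output.shape == (8, 32)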
def make_multiheaded_attention(query_key_dimension, value_dimension, heads):
    attention = make_attention(query_key_dimension // heads, value_dimension // heads)
    def multi_headed_attention(query, key, value, mask):
        # split the feature dimension into heads, attend per head, then merge back
        query = einops.rearrange(query, "s (h d) -> h s d", h=heads)
        key = einops.rearrange(key, "s (h d) -> h s d", h=heads)
        value = einops.rearrange(value, "s (h d) -> h s d", h=heads)
        output = jax.vmap(attention, in_axes=(0, 0, 0, None))(query, key, value, mask)
        output = einops.rearrange(output, "h s d -> s (h d)")
        return output
    return multi_headed_attention
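# Multi-head sketch (assumed shapes, not from the original gist): the model
# dimension must be divisible by the head count since each head gets d // heads.
def _example_multiheaded_attention():
    mha = make_multiheaded_attention(query_key_dimension=64, value_dimension=64, heads=4)
    x = jax.random.normal(jax.random.PRNGKey(0), (12, 64))
    mask = jax.numpy.tril(jax.numpy.ones((12, 12)))  # causal mask
    output = mha(x, x, x, mask)
    assert output.shape == (12, 64)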
class Transformer(eqx.Module):
    query_projector: jax.Array
    key_projector: jax.Array
    value_projector: jax.Array
    attention: typing.Callable = eqx.field(static=True)
    layer_norm_1: eqx.Module
    linear: eqx.Module
    layer_norm_2: eqx.Module
    def __init__(
        self,
        query_dimension,
        key_dimension,
        value_dimension,
        attention_dimension,
        heads,
        hidden,
        key
    ):
        rng = key
        rng, key = jax.random.split(rng)
        self.query_projector = jax.random.normal(key, (attention_dimension, query_dimension))
        rng, key = jax.random.split(rng)
        self.key_projector = jax.random.normal(key, (attention_dimension, key_dimension))
        rng, key = jax.random.split(rng)
        self.value_projector = jax.random.normal(key, (query_dimension, value_dimension))
        self.attention = make_multiheaded_attention(attention_dimension, query_dimension, heads)
        self.layer_norm_1 = eqx.nn.LayerNorm(query_dimension)
        rng, key_1, key_2 = jax.random.split(rng, 3)
        self.linear = eqx.nn.Sequential([
            eqx.nn.Linear(query_dimension, hidden, key=key_1),
            eqx.nn.Lambda(jax.nn.relu),
            eqx.nn.Linear(hidden, query_dimension, key=key_2)
        ])
        self.layer_norm_2 = eqx.nn.LayerNorm(query_dimension)
    def __call__(self, query, key, value, mask):
        # project queries/keys into the attention dimension and values into the query dimension
        query_attention_projected = einops.einsum(
            self.query_projector,
            query,
            "a q, s q -> s a"
        )
        key_attention_projected = einops.einsum(
            self.key_projector,
            key,
            "a k, s k -> s a"
        )
        value_query_projected = einops.einsum(
            self.value_projector,
            value,
            "q v, s v -> s q"
        )
        x = jax.vmap(self.layer_norm_1)(query + self.attention(query_attention_projected, key_attention_projected, value_query_projected, mask))
        x = jax.vmap(self.layer_norm_2)(x + jax.vmap(self.linear)(x))
        return x
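# Sketch of one Transformer block in self-attention mode (dimensions are
# placeholders, not from the original gist). query/key/value dimensions are the
# per-token feature sizes; attention_dimension is the projected size used inside
# attention and must be divisible by heads.
def _example_transformer_block():
    block = Transformer(
        query_dimension=64,
        key_dimension=64,
        value_dimension=64,
        attention_dimension=64,
        heads=4,
        hidden=128,
        key=jax.random.PRNGKey(0),
    )
    x = jax.random.normal(jax.random.PRNGKey(1), (12, 64))
    mask = jax.numpy.ones((12, 12))
    out = block(x, x, x, mask)
    assert out.shape == (12, 64)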
class EncoderTransformer(eqx.Module):
    transformer: Transformer
    def __init__(self, transformer):
        self.transformer = transformer
    def __call__(self, x, mask):
        # self-attention: queries, keys, and values all come from the same sequence
        return self.transformer(x, x, x, mask)
class DecoderTransformer(eqx.Module):
    transformer: Transformer
    def __init__(self, transformer):
        self.transformer = transformer
    def __call__(self, target, memory, mask):
        # cross-attention: queries come from the target, keys and values from the encoder memory
        return self.transformer(target, memory, memory, mask)
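# Hypothetical encoder/decoder wiring (shapes assumed, not from the original
# gist): the encoder self-attends over the source, the decoder cross-attends
# from the target into the encoder memory.
def _example_encoder_decoder():
    key_1, key_2 = jax.random.split(jax.random.PRNGKey(0))
    encoder = EncoderTransformer(Transformer(64, 64, 64, 64, 4, 128, key_1))
    decoder = DecoderTransformer(Transformer(64, 64, 64, 64, 4, 128, key_2))
    source = jax.random.normal(jax.random.PRNGKey(1), (10, 64))
    target = jax.random.normal(jax.random.PRNGKey(2), (7, 64))
    memory = encoder(source, jax.numpy.ones((10, 10)))
    out = decoder(target, memory, jax.numpy.ones((7, 10)))
    assert out.shape == (7, 64)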
def any_shape_linear(input_shape, output_shape, key=jax.random.PRNGKey(0)):
    flattened_input_shape = reduce(mul, input_shape)
    flattened_output_shape = reduce(mul, output_shape)
    return eqx.nn.Sequential([
        eqx.nn.Lambda(lambda x: x.reshape(flattened_input_shape)),
        eqx.nn.Linear(flattened_input_shape, flattened_output_shape, key=key),
        eqx.nn.Lambda(lambda x: x.reshape(output_shape))
    ])
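# any_shape_linear sketch (arbitrary shapes, not from the original gist): a
# dense map between tensors of any shape by flattening, applying a Linear, and
# reshaping back.
def _example_any_shape_linear():
    layer = any_shape_linear((4, 4), (2, 8), key=jax.random.PRNGKey(0))
    out = layer(jax.numpy.ones((4, 4)))
    assert out.shape == (2, 8)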
VGG_CLASSIFIER_INPUT_SIZE = 512 * 7 * 7
def make_mlp(input_shape, output_shape, num_layers=4, hidden_size=100, key=jax.random.PRNGKey(0)):
    flattened_input_shape = reduce(mul, input_shape)
    flattened_output_shape = reduce(mul, output_shape)
    layers = []
    layers.append(
        eqx.nn.Lambda(lambda x: x.reshape(flattened_input_shape))
    )
    input_layer_key, key = jax.random.split(key)
    layers.append(
        eqx.nn.Linear(flattened_input_shape, hidden_size, key=input_layer_key)
    )
    for _ in range(num_layers):
        layers.append(eqx.nn.Lambda(jax.nn.relu))
        hidden_layer_key, key = jax.random.split(key)
        layers.append(
            eqx.nn.Linear(hidden_size, hidden_size, key=hidden_layer_key)
        )
    layers.append(eqx.nn.Lambda(jax.nn.relu))
    output_layer_key, key = jax.random.split(key)
    layers.append(
        eqx.nn.Linear(hidden_size, flattened_output_shape, key=output_layer_key)
    )
    layers.append(
        eqx.nn.Lambda(lambda x: x.reshape(output_shape))
    )
    model = eqx.nn.Sequential(layers)
    return model
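# make_mlp sketch (placeholder shapes, not from the original gist): any-shape
# in, any-shape out, with num_layers hidden blocks of hidden_size units between
# the two reshape layers.
def _example_mlp():
    mlp = make_mlp((3, 8, 8), (4, 4), num_layers=2, hidden_size=32, key=jax.random.PRNGKey(0))
    out = mlp(jax.numpy.ones((3, 8, 8)))
    assert out.shape == (4, 4)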
class Residual(eqx.Module):
    linear: eqx.Module
    def __init__(self, linear):
        self.linear = linear
    def __call__(self, x, key=None):
        return self.linear(x) + x
def make_resnet_mlp(input_shape, output_shape, num_layers=4, hidden_size=100, key=jax.random.PRNGKey(0)):
    flattened_input_shape = reduce(mul, input_shape)
    flattened_output_shape = reduce(mul, output_shape)
    layers = []
    layers.append(
        eqx.nn.Lambda(lambda x: x.reshape(flattened_input_shape))
    )
    input_layer_key, key = jax.random.split(key)
    layers.append(
        eqx.nn.Linear(flattened_input_shape, hidden_size, key=input_layer_key)
    )
    for _ in range(num_layers):
        layers.append(eqx.nn.Lambda(jax.nn.relu))
        hidden_layer_key, key = jax.random.split(key)
        layers.append(
            Residual(eqx.nn.Linear(hidden_size, hidden_size, key=hidden_layer_key))
        )
    layers.append(eqx.nn.Lambda(jax.nn.relu))
    output_layer_key, key = jax.random.split(key)
    layers.append(
        eqx.nn.Linear(hidden_size, flattened_output_shape, key=output_layer_key)
    )
    layers.append(
        eqx.nn.Lambda(lambda x: x.reshape(output_shape))
    )
    model = eqx.nn.Sequential(layers)
    return model
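# Same interface as make_mlp, but each hidden Linear is wrapped in a Residual
# skip connection (placeholder shapes, not from the original gist).
def _example_resnet_mlp():
    mlp = make_resnet_mlp((3, 8, 8), (4, 4), num_layers=2, hidden_size=32, key=jax.random.PRNGKey(0))
    out = mlp(jax.numpy.ones((3, 8, 8)))
    assert out.shape == (4, 4)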
def make_cnn(input_shape, output_shape, num_conv_layers=2, num_fc_layers=2, filters=64, kernel_size=3, hidden_size=100, key=jax.random.PRNGKey(0)):
    """
    expects channels to be the last dimension both in the input shape and in __call__
    """
    layers = []
    # move channels first, which is what eqx.nn.Conv2d expects
    input_shape = [input_shape[2], input_shape[0], input_shape[1]]
    layers.append(eqx.nn.Lambda(lambda x: jax.numpy.transpose(x, (2, 0, 1))))
    for _ in range(num_conv_layers):
        cnn_key, key = jax.random.split(key)
        layers.append(eqx.nn.Conv2d(input_shape[0], filters, kernel_size=kernel_size, padding=(1, 1), key=cnn_key))
        layers.append(eqx.nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2)))
        input_shape = [filters, input_shape[1] // 2, input_shape[2] // 2]
    layers.append(make_mlp(input_shape, output_shape, num_fc_layers, hidden_size=hidden_size, key=key))
    model = eqx.nn.Sequential(layers)
    return model
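# make_cnn sketch (channels-last input, placeholder shapes, not from the
# original gist): a conv/pool stack followed by the MLP head.
def _example_cnn():
    cnn = make_cnn((32, 32, 3), (10,), num_conv_layers=2, num_fc_layers=2, key=jax.random.PRNGKey(0))
    out = cnn(jax.numpy.ones((32, 32, 3)))
    assert out.shape == (10,)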
def make_resnet_cnn(input_shape, output_shape, num_conv_layers=2, num_fc_layers=2, filters=64, kernel_size=3, hidden_size=100, key=jax.random.PRNGKey(0)):
    """
    expects channels to be the last dimension both in the input shape and in __call__
    """
    layers = []
    input_shape = [input_shape[2], input_shape[0], input_shape[1]]
    # transpose (not reshape) so the channel axis actually moves to the front
    layers.append(eqx.nn.Lambda(lambda x: jax.numpy.transpose(x, (2, 0, 1))))
    for i in range(num_conv_layers):
        cnn_key, key = jax.random.split(key)
        if i == 0:
            layers.append(eqx.nn.Conv2d(input_shape[0], filters, kernel_size=kernel_size, padding=(1, 1), key=cnn_key))
        else:
            layers.append(Residual(eqx.nn.Conv2d(input_shape[0], filters, kernel_size=kernel_size, padding=(1, 1), key=cnn_key)))
        layers.append(eqx.nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2)))
        input_shape = [filters, input_shape[1] // 2, input_shape[2] // 2]  # update input shape for the next layer
    layers.append(make_resnet_mlp(input_shape, output_shape, num_fc_layers, hidden_size=hidden_size, key=key))
    # combine all layers
    model = eqx.nn.Sequential(layers)
    return model
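# Same idea with residual convolutions after the first layer; input is
# channels-last and gets transposed to channels-first internally (placeholder
# shapes, not from the original gist).
def _example_resnet_cnn():
    cnn = make_resnet_cnn((32, 32, 3), (10,), num_conv_layers=2, num_fc_layers=2, key=jax.random.PRNGKey(0))
    out = cnn(jax.numpy.ones((32, 32, 3)))
    assert out.shape == (10,)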
class NNTuple(eqx.Module):
    modules: list[eqx.Module]
    def __init__(self, modules):
        self.modules = modules
    def __call__(self, xs, key=None):
        # apply the i-th module to the i-th input
        return [
            module(x, key=key) for (module, x) in zip(self.modules, xs)
        ]
# the calling value should have at least one dimension: shape (1,) instead of ()
class FourierFeatures(eqx.Module):
    kernel: jax.numpy.ndarray
    reshape: eqx.Module
    def __init__(self, input_size, output_size, key):
        self.kernel = jax.random.normal(key, (output_size // 2, input_size)) * 0.2
        self.reshape = eqx.nn.Lambda(lambda x: x.reshape(output_size))
    def __call__(self, x, key=None):
        f = 2 * jax.numpy.pi * einops.einsum(self.kernel, x, "o i, i -> o")
        return self.reshape(jax.numpy.concatenate([jax.numpy.cos(f), jax.numpy.sin(f)]))
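# FourierFeatures sketch (not from the original gist): embeds a low-dimensional
# value (for example a diffusion timestep) into output_size sin/cos features.
# Note the (1,) input shape rather than a bare scalar.
def _example_fourier_features():
    ff = FourierFeatures(input_size=1, output_size=64, key=jax.random.PRNGKey(0))
    out = ff(jax.numpy.array([0.5]))
    assert out.shape == (64,)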
def make_diffusion_model(encoders, diffusion):
    return eqx.nn.Sequential(
        layers=[
            NNTuple(encoders),
            eqx.nn.Lambda(lambda encoded_values: jax.numpy.concatenate(encoded_values)),
            diffusion,
        ]
    )
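# Diffusion-model wiring sketch (encoder output sizes are placeholders, not
# from the original gist): each encoder maps one element of the input tuple to
# a flat vector, the vectors are concatenated, and the diffusion head maps the
# result to the target shape.
def _example_diffusion_model():
    key_1, key_2, key_3 = jax.random.split(jax.random.PRNGKey(0), 3)
    encoders = [
        make_mlp((4, 4), (100,), key=key_1),
        make_mlp((8,), (100,), key=key_2),
    ]
    head = make_mlp((200,), (4, 4), key=key_3)
    model = make_diffusion_model(encoders, head)
    out = model((jax.numpy.ones((4, 4)), jax.numpy.ones((8,))))
    assert out.shape == (4, 4)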
def make_vgg_model(classifier, filename=None):
    model, state = eqv.models.vgg11(torch_weights=CLASSIFICATION_URLS["vgg11"])
    # swap the pretrained classifier head for the one passed in
    model = eqx.tree_at(
        lambda m: m.classifier,
        model,
        classifier
    )
    if filename:
        model = load_model(model, filename)
    return model
# accepts a vgg model and returns a (dynamic, static) partition in which
# everything except the classifier is frozen
def make_classifier_finetuned_vgg(model):
    tree_struct = jax.tree_util.tree_map(lambda _: False, model)
    filter_spec = eqx.tree_at(
        lambda tree: [layer for layer in tree.classifier.layers],
        tree_struct,
        replace=[True for _ in range(len(model.classifier.layers))],
    )
    dynamic_model, static_model = eqx.partition(model, filter_spec)
    return dynamic_model, static_model
def make_fully_unfrozen_model(model):
    filter_spec = jax.tree_util.tree_map(lambda _: True, model)
    dynamic_model, static_model = eqx.partition(model, filter_spec)
    return dynamic_model, static_model
def save_model(model, filename):
    eqx.tree_serialise_leaves(filename, model)
def load_model(model, filename):
    with open(filename, "rb") as f:
        return eqx.tree_deserialise_leaves(f, model)
import unittest
class TestModel(unittest.TestCase):
    def test_make_and_save_model(self):
        model = make_mlp((3, 224, 224), (4, 4), hidden_size=100, key=jax.random.PRNGKey(1))
        state = eqx.nn.State(model)
        self.assertIsNotNone(model)
        self.assertIsNotNone(state)
        filename = "/tmp/model.eqx"
        save_model(model, filename)
        loaded_model = load_model(make_mlp((3, 224, 224), (4, 4), hidden_size=100, key=jax.random.PRNGKey(1)), filename)
        loaded_state = eqx.nn.State(loaded_model)
        num_iter = 10
        for i in range(num_iter):
            sample_input = jax.random.normal(jax.random.PRNGKey(i), (3, 224, 224))
            key = jax.random.PRNGKey(i + num_iter)
            model_output, _ = model(sample_input, state, key=key)
            loaded_model_output, _ = loaded_model(sample_input, loaded_state, key=key)
            self.assertTrue(jax.numpy.allclose(model_output, loaded_model_output))
    def test_fully_unfrozen_model(self):
        model = make_mlp((3, 224, 224), (4, 4), hidden_size=100, key=jax.random.PRNGKey(1))
        dynamic_model, static_model = make_fully_unfrozen_model(model)
        model = eqx.combine(dynamic_model, static_model)
        state = eqx.nn.State(model)
        self.assertIsNotNone(model)
        self.assertIsNotNone(state)
        filename = "/tmp/model.eqx"
        save_model(model, filename)
        loaded_model = load_model(make_mlp((3, 224, 224), (4, 4), hidden_size=100, key=jax.random.PRNGKey(1)), filename)
        loaded_dynamic_model, loaded_static_model = make_fully_unfrozen_model(loaded_model)
        loaded_model = eqx.combine(loaded_dynamic_model, loaded_static_model)
        loaded_state = eqx.nn.State(loaded_model)
        num_iter = 10
        for i in range(num_iter):
            sample_input = jax.random.normal(jax.random.PRNGKey(i), (3, 224, 224))
            key = jax.random.PRNGKey(i + num_iter)
            model_output, _ = model(sample_input, state, key=key)
            loaded_model_output, _ = loaded_model(sample_input, loaded_state, key=key)
            self.assertTrue(jax.numpy.allclose(model_output, loaded_model_output))
    def test_cnn(self):
        image_shape = (16, 512, 512, 3)
        output_shape = (16, 3)
        fake_image = jax.random.normal(jax.random.PRNGKey(0), image_shape)
        model = make_cnn(image_shape[1:], output_shape[1:], num_conv_layers=2, num_fc_layers=2, filters=64, kernel_size=3, hidden_size=100, key=jax.random.PRNGKey(0))
        state = eqx.nn.State(model)
        model_output, _ = eqx.filter_vmap(model)(fake_image, state)
        self.assertEqual(model_output.shape, output_shape)
        self.assertIsNotNone(model_output)
    def test_vgg(self):
        mlp_model = make_mlp((VGG_CLASSIFIER_INPUT_SIZE,), (4, 4), hidden_size=100, key=jax.random.PRNGKey(1))
        model = make_vgg_model(mlp_model)
        state = eqx.nn.State(model)
        self.assertIsNotNone(model)
        self.assertIsNotNone(state)
        filename = "/tmp/model.eqx"
        save_model(model, filename)
        loaded_model = load_model(make_vgg_model(make_mlp((VGG_CLASSIFIER_INPUT_SIZE,), (4, 4), hidden_size=100, key=jax.random.PRNGKey(1))), filename)
        loaded_state = eqx.nn.State(loaded_model)
        num_iter = 10
        for i in range(num_iter):
            sample_input = jax.random.normal(jax.random.PRNGKey(i), (3, 224, 224))
            key = jax.random.PRNGKey(i + num_iter)
            model_output, _ = model(sample_input, state, key=key)
            loaded_model_output, _ = loaded_model(sample_input, loaded_state, key=key)
            self.assertTrue(jax.numpy.allclose(model_output, loaded_model_output))
    def test_finetuned_vgg(self):
        mlp_model = make_mlp((VGG_CLASSIFIER_INPUT_SIZE,), (4, 4), hidden_size=100, key=jax.random.PRNGKey(1))
        model = make_vgg_model(mlp_model)
        dynamic_model, static_model = make_classifier_finetuned_vgg(model)
        model = eqx.combine(dynamic_model, static_model)
        state = eqx.nn.State(model)
        self.assertIsNotNone(model)
        self.assertIsNotNone(state)
        filename = "/tmp/model.eqx"
        save_model(model, filename)
        loaded_model = load_model(make_vgg_model(make_mlp((VGG_CLASSIFIER_INPUT_SIZE,), (4, 4), hidden_size=100, key=jax.random.PRNGKey(1))), filename)
        loaded_dynamic_model, loaded_static_model = make_classifier_finetuned_vgg(loaded_model)
        loaded_model = eqx.combine(loaded_dynamic_model, loaded_static_model)
        loaded_state = eqx.nn.State(loaded_model)
        num_iter = 10
        for i in range(num_iter):
            sample_input = jax.random.normal(jax.random.PRNGKey(i), (3, 224, 224))
            key = jax.random.PRNGKey(i + num_iter)
            model_output, _ = model(sample_input, state, key=key)
            loaded_model_output, _ = loaded_model(sample_input, loaded_state, key=key)
            self.assertTrue(jax.numpy.allclose(model_output, loaded_model_output))
    def test_diffusion_model(self):
        IMAGE_SIZE = (1024, 1024, 3)
        LOOKAHEAD_STEPS = 4
        DIFFUSION_STEPS = 6
        POSE_SIZE = (4, 4)
        image_encoder_model = make_cnn(input_shape=IMAGE_SIZE, output_shape=(100,), num_conv_layers=2, num_fc_layers=2)
        pose_encoder_model = make_mlp(input_shape=(LOOKAHEAD_STEPS, *POSE_SIZE), output_shape=(100,))
        diffusion_model = make_mlp(input_shape=(200,), output_shape=(LOOKAHEAD_STEPS, *POSE_SIZE))
        model = make_diffusion_model([image_encoder_model, pose_encoder_model], diffusion_model)
        self.assertEqual(model((
            jax.numpy.zeros(IMAGE_SIZE),
            jax.numpy.zeros((LOOKAHEAD_STEPS, *POSE_SIZE))
        )).shape, (LOOKAHEAD_STEPS, *POSE_SIZE))
if __name__ == '__main__':
    unittest.main()
I wouldn't recommend using the CNN and VGG stuff, but the transformer + Fourier feature code definitely works.