Last active
March 6, 2025 18:51
-
-
Save vyeevani/9e57736e49c7dc180f32be3b809c85d2 to your computer and use it in GitHub Desktop.
Exploration of a sequence model that could handle anything.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| from typing import Sequence | |
| import einops | |
| import jax | |
| import equinox | |
def make_fourier_features(max_seq_len, embedding_size):
    """Build a table of sinusoidal position features.

    Row ``p`` holds ``sin(pi * p * f)`` for each of ``embedding_size // 2``
    frequencies ``f`` spaced linearly over ``[1, max_seq_len / 2]``, followed by
    the matching ``cos(pi * p * f)`` terms.

    Args:
        max_seq_len: number of positions (rows); positions are ``0 .. max_seq_len - 1``.
        embedding_size: number of feature columns to produce.

    Returns:
        Array of shape ``(max_seq_len, embedding_size)``. When ``embedding_size``
        is odd, the last column is zero padding so the width always equals
        ``embedding_size`` (previously it came out one column short).
    """
    min_freq = 1.0
    max_resolution = max_seq_len
    num_bands = embedding_size // 2
    freq_bands = jax.numpy.linspace(min_freq, max_resolution / 2, num=num_bands)
    # Positions are raw indices 0 .. max_seq_len - 1, shaped as a column for broadcasting.
    # NOTE(review): with integer positions the f = 1.0 band gives sin(pi * p) ~= 0 for
    # every p, and for large max_seq_len the arguments grow to ~pi * max_seq_len**2 / 2,
    # far beyond float32 precision. Normalising positions to [-1, 1] may be intended — confirm.
    pos = jax.numpy.linspace(0, max_resolution - 1, num=max_resolution)[:, None]
    per_pos_features = pos * freq_bands  # (max_seq_len, num_bands)
    sine_features = jax.numpy.sin(jax.numpy.pi * per_pos_features)
    cosine_features = jax.numpy.cos(jax.numpy.pi * per_pos_features)
    fourier_features = jax.numpy.concatenate([sine_features, cosine_features], axis=-1)
    # sin + cos yield 2 * (embedding_size // 2) columns; pad one zero column for odd sizes.
    missing = embedding_size - fourier_features.shape[-1]
    if missing:
        fourier_features = jax.numpy.pad(fourier_features, ((0, 0), (0, missing)))
    return fourier_features
def generate_io_mask(context_timesteps: jax.Array, latent_timesteps: jax.Array) -> jax.Array:
    """Build a causal attention mask from per-token timesteps.

    The mask is derived from each token's timestep rather than its position,
    so tokens may appear in any order. The same function serves both
    directions: encoding context into the latents, and decoding latents back
    into the context (swap the arguments).

    Args:
        context_timesteps: shape ``(T1,)`` timesteps of the tokens being attended to.
        latent_timesteps: shape ``(T2,)`` timesteps of the attending tokens.

    Returns:
        Boolean array of shape ``(T2, T1)``. Entry ``[i, j]`` is True when
        ``latent_timesteps[i] >= context_timesteps[j]``. The comparison is
        inclusive, so equal timesteps may attend to each other.

    TODO: padding is not handled yet. The intended approach is to give padded
    tokens the maximum unsigned-integer timestep.
    """
    # Column of attending timesteps vs. row of attended timesteps -> (T2, T1) grid.
    attending = latent_timesteps[:, None]
    attended = context_timesteps[None, :]
    return attending >= attended
class Layer(equinox.Module):
    """One cross-attention block in which ``output`` tokens attend over ``input_`` tokens.

    All sub-modules operate at width ``dimension``; attention uses 8 heads, and
    both MLPs have one hidden layer of width ``10 * dimension`` with GELU.
    """

    attn: equinox.nn.MultiheadAttention
    input_norm: equinox.nn.LayerNorm
    output_norm: equinox.nn.LayerNorm
    input_mlp: equinox.nn.MLP
    output_mlp: equinox.nn.MLP

    def __init__(self, dimension: int, rng: jax.Array):
        # Draw three sub-keys by repeatedly splitting and carrying the first half forward.
        subkeys = []
        carry = rng
        for _ in range(3):
            carry, subkey = jax.random.split(carry)
            subkeys.append(subkey)
        attn_key, input_mlp_key, output_mlp_key = subkeys

        hidden_width = dimension * 10
        self.attn = equinox.nn.MultiheadAttention(
            num_heads=8,
            query_size=dimension,
            use_query_bias=True,
            use_key_bias=True,
            use_value_bias=True,
            key=attn_key,
        )
        self.input_norm = equinox.nn.LayerNorm(dimension)
        self.output_norm = equinox.nn.LayerNorm(dimension)
        self.input_mlp = equinox.nn.MLP(
            dimension, dimension, hidden_width, 1, activation=jax.nn.gelu, key=input_mlp_key
        )
        self.output_mlp = equinox.nn.MLP(
            dimension, dimension, hidden_width, 1, activation=jax.nn.gelu, key=output_mlp_key
        )

    def __call__(self, input_: jax.Array, output: jax.Array, mask):
        """Update ``output`` by attending over ``input_``.

        Args:
            input_: ``(S_in, dimension)`` tokens to read from.
            output: ``(S_out, dimension)`` tokens to update.
            mask: boolean ``(S_out, S_in)`` attention mask, passed straight to
                equinox's ``MultiheadAttention``.

        Returns:
            The updated ``output``, shape ``(S_out, dimension)``.
        """
        per_token = equinox.filter_vmap
        # Residual MLP on the source tokens (applied before normalisation).
        input_ = input_ + per_token(self.input_mlp)(input_)
        attn_keys = per_token(self.input_norm)(input_)
        # NOTE(review): only the keys are layer-normalised; the queries (output) and
        # values (input_) are used un-normalised — confirm this is intentional.
        output = output + self.attn(output, attn_keys, input_, mask)
        # Pre-norm residual MLP on the updated target tokens.
        refined = per_token(self.output_mlp)(per_token(self.output_norm)(output))
        return output + refined
class Perceiver(equinox.Module):
    """Alternating cross-attention between a token sequence and a bank of latents.

    Input tokens are concatenated with Fourier position features, which doubles
    their width to ``2 * dimension``. Each layer pair then (1) lets the latents
    attend to the inputs and (2) lets the inputs attend back to the latents.
    Both steps use timestep-based causal masks. The result is projected back to
    width ``dimension``.
    """

    # Precomputed (max_seq_len, dimension) position table. It is static, so it is
    # not a pytree leaf and is not trained.
    # NOTE(review): recent equinox versions warn about JAX arrays stored in static
    # fields — confirm this is intended.
    fourier_features: jax.Array = equinox.field(static=True)
    # Learned latent bank, shape (latent_count, 2 * dimension).
    latent: jax.Array = equinox.field(static=False)
    input_layers: Sequence[Layer] = equinox.field(static=False)
    output_layers: Sequence[Layer] = equinox.field(static=False)
    output_projector: equinox.nn.Linear = equinox.field(static=False)

    def __init__(self, latent_count, dimension, num_layers, rng, max_seq_len=200000):
        """Initialise parameters.

        Args:
            latent_count: number of latent vectors per latent timestep.
            dimension: width of the incoming/outgoing tokens. Internal width is ``2 * dimension``.
            num_layers: number of (encode, decode) cross-attention layer pairs.
            rng: JAX PRNG key.
            max_seq_len: longest input sequence supported. It sets the size of the
                position table, which is ``max_seq_len * dimension`` floats (about
                0.8 GB in float32 at the default with ``dimension=1024``). Lower it
                for short sequences.
        """
        # This first sub-key is unused. The split is kept so a given rng still
        # produces exactly the same parameters as before.
        rng, _ = jax.random.split(rng)
        self.fourier_features = make_fourier_features(max_seq_len=max_seq_len, embedding_size=dimension)
        rng, key = jax.random.split(rng)
        self.latent = jax.random.normal(key, (latent_count, dimension * 2))
        rng, key = jax.random.split(rng)
        self.input_layers = [Layer(dimension * 2, k) for k in jax.random.split(key, num_layers)]
        rng, key = jax.random.split(rng)
        self.output_layers = [Layer(dimension * 2, k) for k in jax.random.split(key, num_layers)]
        rng, key = jax.random.split(rng)
        self.output_projector = equinox.nn.Linear(dimension * 2, dimension, key=key)

    def __call__(self, x, x_timesteps, num_timesteps):
        """Run the model over one (unbatched) sequence.

        Args:
            x: ``(S, dimension)`` token features.
            x_timesteps: ``(S,)`` causal timestep of each token.
            num_timesteps: number of latent timesteps. The latent bank is tiled
                once per timestep, giving ``num_timesteps * latent_count`` latents
                with timesteps ``0 .. num_timesteps - 1``.

        Returns:
            ``(S, dimension)`` updated token features.

        Raises:
            ValueError: if ``S`` exceeds the length of the position table.
        """
        seq_len = x.shape[0]
        table_len = self.fourier_features.shape[0]
        if seq_len > table_len:
            raise ValueError(
                f"sequence length {seq_len} exceeds max_seq_len={table_len} of the Fourier feature table"
            )
        latent_count = self.latent.shape[0]
        latent = einops.repeat(self.latent, "l d -> (t l) d", t=num_timesteps)
        latent_timesteps = einops.repeat(jax.numpy.arange(num_timesteps), "d -> (d a)", a=latent_count)
        # Encode: latent at timestep t reads inputs with timestep <= t.
        input_mask = generate_io_mask(x_timesteps, latent_timesteps)
        # Decode: input at timestep t reads latents with timestep <= t.
        output_mask = generate_io_mask(latent_timesteps, x_timesteps)
        # (S, dimension) position features ++ (S, dimension) tokens -> (S, 2 * dimension).
        x = einops.pack([self.fourier_features[:seq_len], x], "s *")[0]
        for input_layer, output_layer in zip(self.input_layers, self.output_layers):
            latent = input_layer(x, latent, input_mask)
            x = output_layer(latent, x, output_mask)
        return equinox.filter_vmap(self.output_projector)(x)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| from dataclasses import dataclass | |
| from typing import List, Dict | |
| import jax | |
| import equinox | |
| import einops | |
| from perceiver import Perceiver | |
@dataclass
class StartupSystem:
    """A system declared only by the token streams it emits.

    ``SequenceModelBuilder.build`` assigns causal ids to the output keys of all
    startup systems before it looks at any update system.
    """

    name: str  # human-readable identifier of the system
    output: List[str]  # keys of the token streams this system emits
@dataclass
class UpdateSystem:
    """A system declared by the token streams it reads and the ones it emits.

    ``SequenceModelBuilder.build`` gives any not-yet-seen input keys one causal
    id, then any not-yet-seen output keys the next one.
    """

    name: str  # human-readable identifier of the system
    input: List[str]  # keys of the token streams this system consumes
    output: List[str]  # keys of the token streams this system emits
class Token(equinox.Module):
    """One keyed stream of tokens, indexed by timestep along the leading axis."""

    # Stream name. It is static, so it is carried as metadata rather than as a pytree leaf.
    key: str = equinox.field(static=True)
    # Per-timestep validity flags, bool (t,). In the code shown it is only passed through.
    valid: jax.Array
    # Per-timestep timestep index, int (t,).
    timestep: jax.Array
    # Per-timestep payload, float (t, *).
    data: jax.Array
class SequenceModel(equinox.Module):
    """Flatten keyed token streams into one sequence and run the backbone over it.

    Every token is tagged with a scalar causal position
    ``timestep * (max_casual_id + 1) + casual_id_map[key]``. Within a timestep,
    tokens are therefore ordered by their stream's causal id. Every slot of
    timestep ``t`` comes before every slot of timestep ``t + 1``.
    (Field names keep the historical "casual" spelling; they mean "causal".)
    """

    casual_id_map: Dict[str, int] = equinox.field(static=True)
    max_casual_id: int = equinox.field(static=True)
    backbone: Perceiver = equinox.field(static=False)

    def __init__(self, casual_id_map: Dict[str, int], backbone: Perceiver):
        """Store the key -> causal-id map and the backbone.

        Raises:
            ValueError: if ``casual_id_map`` is empty.
        """
        if not casual_id_map:
            raise ValueError("casual_id_map must contain at least one key")
        self.casual_id_map = casual_id_map
        self.backbone = backbone
        self.max_casual_id = max(self.casual_id_map.values())

    def __call__(self, x_list: List[Token]):
        """Run the backbone over all streams jointly.

        Each ``Token.data`` is expected to be ``(t, s, d)``: ``t`` timesteps, ``s``
        tokens per timestep, and feature width ``d``. All ``s`` tokens of a
        timestep share its causal position.

        Returns:
            New ``Token`` objects with the same key/valid/timestep and updated data.
        """
        # Number of causal slots per timestep. Ids are assumed to run
        # 0 .. max_casual_id, which is how SequenceModelBuilder assigns them.
        # The old stride of max_casual_id made id max_casual_id at step t collide
        # with id 0 at step t + 1. It also produced zero latent steps when only
        # one id existed.
        slots_per_timestep = self.max_casual_id + 1
        casual_list, _ = einops.pack(
            [
                einops.repeat(
                    slots_per_timestep * x.timestep + self.casual_id_map[x.key],
                    "t -> t s",
                    s=x.data.shape[1],
                )
                for x in x_list
            ],
            "*",
        )
        data_list, data_shapes = einops.pack([x.data for x in x_list], "* d")
        # Latent steps must cover every causal position.
        # NOTE(review): this assumes each stream's timestep values lie in [0, len) — confirm.
        num_timesteps = max(x.timestep.shape[0] for x in x_list) * slots_per_timestep
        data_list = self.backbone(data_list, casual_list, num_timesteps)
        x_data = einops.unpack(data_list, data_shapes, "* d")
        return [
            Token(key=x.key, valid=x.valid, timestep=x.timestep, data=data)
            for x, data in zip(x_list, x_data)
        ]
@dataclass
class SequenceModelBuilder:
    """Fluent builder that derives a causal-id map from declared systems.

    Ids are assigned in declaration order. First come the startup systems'
    outputs. Then, for each update system, its new input keys are assigned,
    followed by its new output keys. A group of keys that introduces at least
    one new key consumes one fresh id, shared by all of that group's new keys.
    Keys that already have an id keep it.
    """

    backbone: equinox.Module = None
    startup_systems: List[StartupSystem] = equinox.field(default_factory=list)
    update_systems: List[UpdateSystem] = equinox.field(default_factory=list)

    def set_backbone(self, backbone: equinox.Module):
        """Set the backbone passed to the built model; returns ``self`` for chaining."""
        self.backbone = backbone
        return self

    def add_startup_system(self, system: StartupSystem):
        """Append a startup system; returns ``self`` for chaining."""
        self.startup_systems.append(system)
        return self

    def add_update_system(self, system: UpdateSystem):
        """Append an update system; returns ``self`` for chaining."""
        self.update_systems.append(system)
        return self

    @staticmethod
    def _provision_casual_ids(keys: List[str], casual_id_map: Dict[str, int], current_id: int) -> int:
        """Give every key in ``keys`` that has no id yet the id ``current_id``.

        Mutates ``casual_id_map`` in place. Returns the next free id, which
        advances only if at least one key was newly added.
        """
        provisioned = False
        for name in keys:
            if name not in casual_id_map:
                casual_id_map[name] = current_id
                provisioned = True
        return current_id + 1 if provisioned else current_id

    def build(self) -> SequenceModel:
        """Assign causal ids to every declared key and construct the model."""
        casual_id_map: Dict[str, int] = {}
        current_id = 0
        for startup_system in self.startup_systems:
            current_id = self._provision_casual_ids(startup_system.output, casual_id_map, current_id)
        for update_system in self.update_systems:
            # Inputs before outputs, so new outputs get a later id than new inputs.
            current_id = self._provision_casual_ids(update_system.input, casual_id_map, current_id)
            current_id = self._provision_casual_ids(update_system.output, casual_id_map, current_id)
        return SequenceModel(
            casual_id_map=casual_id_map,
            backbone=self.backbone,
        )
def test():
    """Smoke test: build a model for a toy env/policy loop and run it on a batch."""
    builder = SequenceModelBuilder()
    # The env emits obs/rew at startup. The policy maps obs/rew -> act.
    # The env then maps obs/act -> obs/rew.
    builder.add_startup_system(StartupSystem("env", ["obs", "rew"]))
    builder.add_update_system(UpdateSystem("policy", ["obs", "rew"], ["act"]))
    builder.add_update_system(UpdateSystem("env", ["obs", "act"], ["obs", "rew"]))

    dim = 1024
    backbone = Perceiver(2, dim, 1, jax.random.key(0))
    builder.set_backbone(backbone)
    model = builder.build()

    # Each stream's data is (timesteps, tokens_per_timestep, dim). `valid` is a
    # boolean flag per timestep, as Token declares.
    input_data = [
        Token(
            key="obs",
            valid=jax.numpy.ones(2, dtype=bool),
            timestep=jax.numpy.array([0, 1]),
            data=jax.numpy.ones((2, 3, dim)),
        ),
        Token(
            key="rew",
            valid=jax.numpy.ones(2, dtype=bool),
            timestep=jax.numpy.array([0, 1]),
            data=jax.numpy.ones((2, 1, dim)),
        ),
        Token(
            key="act",
            valid=jax.numpy.ones(3, dtype=bool),
            timestep=jax.numpy.array([0, 1, 2]),
            data=jax.numpy.ones((3, 1, dim)),
        ),
    ]
    batch_size = 10
    input_data = jax.tree.map(lambda leaf: einops.repeat(leaf, "... -> b ...", b=batch_size), input_data)
    result = equinox.filter_vmap(model)(input_data)
    print("Result:", result)


if __name__ == "__main__":
    test()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment