Skip to content

Instantly share code, notes, and snippets.

@vyeevani
Last active March 6, 2025 18:51
Show Gist options
  • Select an option

  • Save vyeevani/9e57736e49c7dc180f32be3b809c85d2 to your computer and use it in GitHub Desktop.

Select an option

Save vyeevani/9e57736e49c7dc180f32be3b809c85d2 to your computer and use it in GitHub Desktop.
An exploration of a general-purpose sequence model that can handle any mix of named token streams.
from typing import Sequence
import einops
import jax
import equinox
def make_fourier_features(max_seq_len, embedding_size):
    """Build a table of sinusoidal (Fourier) position features.

    Positions are spread evenly over [-1, 1]. Each position is projected onto
    ``embedding_size // 2`` frequency bands, linearly spaced between 1 and
    ``max_seq_len / 2`` (roughly the Nyquist limit for ``max_seq_len`` samples
    on that interval). This follows the Perceiver position encoding.

    Args:
        max_seq_len: number of positions (rows) to generate.
        embedding_size: width of each feature vector. The first
            ``embedding_size // 2`` columns hold sine features and the next
            ``embedding_size // 2`` hold cosine features. When the size is odd,
            a final all-zero column pads the output to exactly this width.

    Returns:
        Float array of shape ``(max_seq_len, embedding_size)``.
    """
    min_freq = 1.0
    num_bands = embedding_size // 2
    freq_bands = jax.numpy.linspace(min_freq, max_seq_len / 2, num=num_bands)
    # Positions must be normalised. With raw integer positions n,
    # sin(pi * n * f) depends only on f mod 2, so every band aliases, and the
    # f = 1 sine band is identically zero.
    pos = jax.numpy.linspace(-1.0, 1.0, num=max_seq_len)[:, None]
    per_pos_features = jax.numpy.pi * pos * freq_bands
    fourier_features = jax.numpy.concatenate(
        [jax.numpy.sin(per_pos_features), jax.numpy.cos(per_pos_features)], axis=-1
    )
    if embedding_size % 2:
        # Keep the advertised width so the table can be concatenated with
        # `embedding_size`-wide inputs.
        fourier_features = jax.numpy.pad(fourier_features, ((0, 0), (0, 1)))
    return fourier_features
def generate_io_mask(context_timesteps: jax.Array, latent_timesteps: jax.Array) -> jax.Array:
    """
    Generate a causal attention mask from per-token timesteps.

    The mask is computed from timestep values rather than token positions, so tokens may appear
    out of order. It is used in both directions: encoding context tokens into the latents, and
    decoding latents back into the context.

    TODO: this currently doesn't take padding into account. It should not be hard to add: padded
    context tokens could be given the maximum unsigned-integer timestep, so no row ever enables them.

    context_timesteps: array of integer timesteps with shape (T1,) (the columns of the mask)
    latent_timesteps: array of integer timesteps with shape (T2,) (the rows of the mask)
    output: boolean array with shape (T2, T1). Entry [i, j] is True when
        latent_timesteps[i] >= context_timesteps[j]. This is the (query, key) layout that
        attention masks expect when the latent side supplies the queries.
    """
    return latent_timesteps[:, None] >= context_timesteps[None, :]
class Layer(equinox.Module):
    """Residual cross-attention block in which `output` tokens attend to `input_` tokens.

    The inputs first receive a residual MLP update. The result is layer-normed and used as
    both the keys and the values of an 8-head attention whose queries are the `output`
    tokens. The attended result is added back to `output`, followed by a pre-norm residual
    MLP. Only the updated `output` is returned; the MLP-updated `input_` is local to the call.
    """

    attn: equinox.nn.MultiheadAttention = equinox.field(static=False)
    input_norm: equinox.nn.LayerNorm = equinox.field(static=False)
    output_norm: equinox.nn.LayerNorm = equinox.field(static=False)
    input_mlp: equinox.nn.MLP = equinox.field(static=False)
    output_mlp: equinox.nn.MLP = equinox.field(static=False)

    def __init__(self, dimension: int, rng: jax.Array):
        """Create the block.

        Args:
            dimension: feature width of both token sets. It is used by the attention,
                the layer norms and the MLPs.
            rng: PRNG key used to initialise the attention and the two MLPs.
        """
        rng, key = jax.random.split(rng)
        self.attn = equinox.nn.MultiheadAttention(
            num_heads=8,
            query_size=dimension,
            use_query_bias=True,
            use_key_bias=True,
            use_value_bias=True,
            key=key,
        )
        self.input_norm = equinox.nn.LayerNorm(dimension)
        self.output_norm = equinox.nn.LayerNorm(dimension)
        rng, key = jax.random.split(rng)
        # MLP(in_size, out_size, width_size, depth=1): one hidden layer, 10x the feature width.
        self.input_mlp = equinox.nn.MLP(dimension, dimension, dimension * 10, 1, activation=jax.nn.gelu, key=key)
        rng, key = jax.random.split(rng)
        self.output_mlp = equinox.nn.MLP(dimension, dimension, dimension * 10, 1, activation=jax.nn.gelu, key=key)

    def __call__(self, input_: jax.Array, output: jax.Array, mask):
        """Update `output` by attending over `input_`.

        Args:
            input_: (S_in, dimension) tokens to read from (per-row MLP/norm via vmap).
            output: (S_out, dimension) tokens that act as queries and are updated residually.
            mask: boolean mask of shape (S_out, S_in), passed straight to
                `equinox.nn.MultiheadAttention`. True marks an allowed query/key pair.

        Returns:
            The updated `output`, shape (S_out, dimension).
        """
        input_ = input_ + equinox.filter_vmap(self.input_mlp)(input_)
        normalized_input = equinox.filter_vmap(self.input_norm)(input_)
        # Keys and values both come from the normalised inputs, as in the Perceiver
        # cross-attention. Using the raw residual stream as the value would pair
        # normalised keys with un-normalised values.
        # NOTE(review): the query (`output`) is not layer-normed here, unlike the Perceiver
        # reference. Adding that would require a new norm parameter.
        output = output + self.attn(output, normalized_input, normalized_input, mask)
        output = output + equinox.filter_vmap(self.output_mlp)(equinox.filter_vmap(self.output_norm)(output))
        return output
class Perceiver(equinox.Module):
    """Perceiver-style backbone that alternates encoding inputs into latents and decoding back.

    A learned block of `latent_count` latent vectors is tiled once per latent timestep. In each
    of `num_layers` rounds, the latents first attend to the position-augmented inputs, and then
    the inputs attend to the updated latents. Masks from `generate_io_mask` allow a query to see
    only tokens whose timestep is <= its own. The final input tokens are projected back to
    `dimension`.
    """

    # Fixed sinusoidal position table. It is a regular (non-static) field because equinox
    # static fields must be hashable and JAX arrays are not: marking it static breaks jit
    # caching. It is read through `jax.lax.stop_gradient`, so it receives zero gradient.
    # NOTE(review): optimizers with decoupled weight decay would still shrink it; filter it
    # out of the optimizer state if that matters.
    fourier_features: jax.Array = equinox.field(static=False)
    latent: jax.Array = equinox.field(static=False)
    input_layers: Sequence[Layer] = equinox.field(static=False)
    output_layers: Sequence[Layer] = equinox.field(static=False)
    output_projector: equinox.nn.Linear = equinox.field(static=False)

    def __init__(self, latent_count, dimension, num_layers, rng, max_seq_len=200000):
        """Create the backbone.

        Args:
            latent_count: number of learned latent vectors per latent timestep.
            dimension: width of the input tokens and of the output. Internally, tokens are
                2 * dimension wide: the input concatenated with a `dimension`-wide
                position encoding.
            num_layers: number of (encode, decode) layer pairs.
            rng: PRNG key used for every learned parameter.
            max_seq_len: maximum number of input tokens that can receive a position encoding.
                The table holds max_seq_len * dimension values, so lowering it saves memory.
        """
        # This split's key is unused. The split is kept so that the keys below, and
        # therefore the initial weights for a given seed, stay unchanged.
        rng, _ = jax.random.split(rng)
        self.fourier_features = make_fourier_features(max_seq_len=max_seq_len, embedding_size=dimension)
        rng, key = jax.random.split(rng)
        self.latent = jax.random.normal(key, (latent_count, dimension * 2))
        rng, key = jax.random.split(rng)
        self.input_layers = [Layer(dimension * 2, layer_key) for layer_key in jax.random.split(key, num_layers)]
        rng, key = jax.random.split(rng)
        self.output_layers = [Layer(dimension * 2, layer_key) for layer_key in jax.random.split(key, num_layers)]
        rng, key = jax.random.split(rng)
        self.output_projector = equinox.nn.Linear(dimension * 2, dimension, key=key)

    def __call__(self, x, x_timesteps, num_timesteps):
        """Run the backbone over one sequence.

        Args:
            x: (S, dimension) input tokens.
            x_timesteps: (S,) integer causal position of each input token.
            num_timesteps: number of latent timesteps. It must be a concrete integer because
                it determines array shapes.

        Returns:
            (S, dimension) output tokens.

        Raises:
            ValueError: if S exceeds the number of rows in the position table.
        """
        seq_len = x.shape[0]
        max_positions = self.fourier_features.shape[0]
        if seq_len > max_positions:
            raise ValueError(
                f"input has {seq_len} tokens but position features cover only {max_positions}; "
                "construct the Perceiver with a larger max_seq_len"
            )
        latent_count = self.latent.shape[0]
        # Tile the latent block once per timestep and tag each copy with its timestep.
        latent = einops.repeat(self.latent, "l d -> (t l) d", t=num_timesteps)
        latent_timesteps = einops.repeat(jax.numpy.arange(num_timesteps), "d -> (d a)", a=latent_count)
        input_mask = generate_io_mask(x_timesteps, latent_timesteps)
        output_mask = generate_io_mask(latent_timesteps, x_timesteps)
        positions = jax.lax.stop_gradient(self.fourier_features[:seq_len])
        x = einops.pack([positions, x], "s *")[0]
        for input_layer, output_layer in zip(self.input_layers, self.output_layers):
            latent = input_layer(x, latent, input_mask)
            x = output_layer(latent, x, output_mask)
        return equinox.filter_vmap(self.output_projector)(x)
from dataclasses import dataclass
from typing import List, Dict
import jax
import equinox
import einops
from perceiver import Perceiver
@dataclass
class StartupSystem:
    """A named start-up system and the token keys it outputs.

    SequenceModelBuilder processes start-up systems first, so their output keys
    receive the earliest causal ids.
    """

    name: str  # identifier of the system
    output: list[str]  # token keys this system produces
@dataclass
class UpdateSystem:
    """A named per-step system: the token keys it reads and the keys it produces.

    SequenceModelBuilder gives unseen `input` keys a causal id, then unseen
    `output` keys the next one.
    """

    name: str  # identifier of the system
    input: list[str]  # token keys this system consumes
    output: list[str]  # token keys this system produces
class Token(equinox.Module):
    """One stream of tokens that share a key, with per-timestep metadata.

    Field shapes use t = number of timesteps in the stream:
        key: stream name. It is static, so it is part of the pytree structure rather
            than a leaf.
        valid: per-timestep flag, boolean, shape (t,).
        timestep: integer timestep of each entry, shape (t,).
        data: float payload with leading axis t, shape (t, *).
    """

    key: str = equinox.field(static=True)
    valid: jax.Array
    timestep: jax.Array
    data: jax.Array
class SequenceModel(equinox.Module):
    """Run a backbone over a heterogeneous list of token streams in causal order.

    Each Token stream gets a causal id from `casual_id_map`, keyed by `Token.key`. A token's
    position in the causal order is `timestep * (max_casual_id + 1) + casual_id`. For
    non-negative ids (as SequenceModelBuilder produces), every stream at timestep t therefore
    comes before every stream at timestep t + 1, and streams within a timestep are ordered by id.

    The "casual" spelling (meaning "causal") is kept in attribute names because they are
    part of the public interface.
    """

    # NOTE(review): equinox static fields are expected to be hashable, and a dict is not.
    # This will likely break jit caching (e.g. equinox.filter_jit). The field is left as-is
    # to keep the attribute type.
    casual_id_map: Dict[str, int] = equinox.field(static=True)
    max_casual_id: int = equinox.field(static=True)
    backbone: Perceiver = equinox.field(static=False)

    def __init__(self, casual_id_map: Dict[str, int], backbone: Perceiver):
        """Store the key -> causal-id map and the backbone.

        Raises:
            ValueError: if `casual_id_map` is empty (from `max()`).
        """
        self.casual_id_map = casual_id_map
        self.backbone = backbone
        self.max_casual_id = max(self.casual_id_map.values())

    def __call__(self, x_list: List[Token]):
        """Process every stream jointly and return them with updated data.

        Each `x.data` is (t, s, d): t timesteps with s tokens each. All tokens are flattened
        into a single sequence for the backbone, then split back per stream. `valid` is
        passed through untouched.
        NOTE(review): the number of latent timesteps assumes timestep values are < the
        longest stream's length.

        Raises:
            ValueError: if `x_list` is empty.
        """
        if not x_list:
            raise ValueError("SequenceModel needs at least one Token stream")
        # Ids run from 0 to max_casual_id inclusive, so each timestep needs max_casual_id + 1
        # slots. A stride of max_casual_id would make the last stream at timestep t collide
        # with the first stream at t + 1, letting it see the future.
        stride = self.max_casual_id + 1
        casual_list, _ = einops.pack(
            [
                einops.repeat(stride * x.timestep + self.casual_id_map[x.key], "t -> t s", s=x.data.shape[1])
                for x in x_list
            ],
            "*",
        )
        data_list, data_shapes = einops.pack([x.data for x in x_list], "* d")
        num_timesteps = max(x.timestep.shape[0] for x in x_list) * stride
        data_list = self.backbone(data_list, casual_list, num_timesteps)
        x_data = einops.unpack(data_list, data_shapes, "* d")
        return [
            Token(key=x.key, valid=x.valid, timestep=x.timestep, data=data)
            for x, data in zip(x_list, x_data)
        ]
@dataclass
class SequenceModelBuilder:
    """Fluent builder that derives causal ids for token keys from declared systems.

    Keys are assigned ids in declaration order:
    - all start-up systems' outputs first;
    - then, for each update system, its unseen inputs and then its unseen outputs.

    Each group that introduces at least one new key consumes one id. Keys already seen
    keep their first id.
    """

    backbone: equinox.Module = None
    startup_systems: List[StartupSystem] = equinox.field(default_factory=list)
    update_systems: List[UpdateSystem] = equinox.field(default_factory=list)

    def set_backbone(self, backbone: equinox.Module):
        """Set the backbone model; returns self for chaining."""
        self.backbone = backbone
        return self

    def add_startup_system(self, system: StartupSystem):
        """Append a start-up system; returns self for chaining."""
        self.startup_systems.append(system)
        return self

    def add_update_system(self, system: UpdateSystem):
        """Append an update system; returns self for chaining."""
        self.update_systems.append(system)
        return self

    @staticmethod
    def _provision(casual_id_map: Dict[str, int], keys: List[str], next_id: int) -> int:
        """Give every key not yet in `casual_id_map` the id `next_id`; return the next free id."""
        provisioned = False
        for name in keys:
            if name not in casual_id_map:
                casual_id_map[name] = next_id
                provisioned = True
        return next_id + 1 if provisioned else next_id

    def build(self) -> SequenceModel:
        """Create a SequenceModel from the declared systems and the backbone.

        Raises:
            ValueError: if no system declares any key.
        """
        casual_id_map: Dict[str, int] = {}
        next_id = 0
        for startup_system in self.startup_systems:
            next_id = self._provision(casual_id_map, startup_system.output, next_id)
        for update_system in self.update_systems:
            next_id = self._provision(casual_id_map, update_system.input, next_id)
            next_id = self._provision(casual_id_map, update_system.output, next_id)
        if not casual_id_map:
            raise ValueError("SequenceModelBuilder.build() needs at least one system that declares a key")
        return SequenceModel(
            casual_id_map=casual_id_map,
            backbone=self.backbone,
        )
def test():
    """Smoke-test the whole pipeline on dummy data.

    Builds causal ids for an environment/policy loop ("obs", "rew", "act") and wraps a
    Perceiver backbone. It then runs the model under `equinox.filter_vmap` over a batch of
    identical copies of the dummy tokens and prints the returned Token list.
    """
    dim = 1024
    batch_size = 10

    builder = SequenceModelBuilder()
    builder.add_startup_system(StartupSystem("env", ["obs", "rew"]))
    builder.add_update_system(UpdateSystem("policy", ["obs", "rew"], ["act"]))
    builder.add_update_system(UpdateSystem("env", ["obs", "act"], ["obs", "rew"]))
    # Real Perceiver backbone: 2 latents per timestep, 1 encode/decode layer pair.
    builder.set_backbone(Perceiver(2, dim, 1, jax.random.key(0)))
    model = builder.build()

    # Payloads are (timesteps, tokens_per_timestep, dim); validity flags are boolean.
    input_data = [
        Token(
            key="obs",
            valid=jax.numpy.ones(2, dtype=bool),
            timestep=jax.numpy.array([0, 1]),
            data=jax.numpy.ones((2, 3, dim)),
        ),
        Token(
            key="rew",
            valid=jax.numpy.ones(2, dtype=bool),
            timestep=jax.numpy.array([0, 1]),
            data=jax.numpy.ones((2, 1, dim)),
        ),
        Token(
            key="act",
            valid=jax.numpy.ones(3, dtype=bool),
            timestep=jax.numpy.array([0, 1, 2]),
            data=jax.numpy.ones((3, 1, dim)),
        ),
    ]
    # Add a leading batch axis to every array leaf so filter_vmap has something to map over.
    input_data = jax.tree.map(lambda leaf: einops.repeat(leaf, "... -> b ...", b=batch_size), input_data)
    result = equinox.filter_vmap(model)(input_data)
    print("Result:", result)


if __name__ == "__main__":
    test()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment