Skip to content

Instantly share code, notes, and snippets.

@tripplyons
Created January 16, 2025 16:27
Show Gist options
  • Select an option

  • Save tripplyons/e0eb3215a66720008f430b7141bc20d9 to your computer and use it in GitHub Desktop.

Select an option

Save tripplyons/e0eb3215a66720008f430b7141bc20d9 to your computer and use it in GitHub Desktop.
# Parallel and serial JAX implementations of the RWKV-7 recurrence.
import jax
import jax.numpy as jnp
def associative_op(left, right):
    """Compose two affine state updates of the form ``S -> S @ wab + vk``.

    Each argument is a ``(wab, vk)`` pair. Applying ``left`` first and
    ``right`` second to a state ``S`` gives::

        (S @ left_wab + left_vk) @ right_wab + right_vk
            == S @ (left_wab @ right_wab) + (left_vk @ right_wab + right_vk)

    so the composition is again a ``(wab, vk)`` pair. Matrix multiplication
    is associative, hence so is this operator, which is the property
    ``jax.lax.associative_scan`` requires of its combine function.

    Args:
        left: ``(wab, vk)`` pair describing the earlier update.
        right: ``(wab, vk)`` pair describing the later update.

    Returns:
        Tuple ``(new_wab, new_vk)`` describing the combined update.
    """
    left_wab, left_vk = left
    right_wab, right_vk = right
    # jnp.matmul multiplies over the last two axes and broadcasts any
    # leading axes, so the same code works on single matrices and on the
    # stacked slices that associative_scan passes in.
    new_wab = jnp.matmul(left_wab, right_wab)
    new_vk = jnp.matmul(left_vk, right_wab) + right_vk
    return new_wab, new_vk
def parallel_states(wab, vk):
    """Compute all states of ``S_t = S_{t-1} @ wab[t] + vk[t]`` in parallel.

    Runs ``jax.lax.associative_scan`` with ``associative_op`` over the
    leading axis of ``wab`` and ``vk`` (the scan's default axis). The first
    scanned element is ``(wab[0], vk[0])`` itself, so the returned states
    correspond to a recurrence that starts from an all-zero state.

    Args:
        wab: Stack of per-step matrices that right-multiply the state.
        vk: Stack of per-step additive terms, with the same leading length
            as ``wab``.

    Returns:
        The stacked states, one per step along the leading axis.
    """
    # The scan also yields the cumulative products of wab; only the
    # accumulated additive part is the state, so the products are dropped.
    _, S = jax.lax.associative_scan(associative_op, (wab, vk))
    return S
def recurrent_states(S, wab, vk):
    """Advance the recurrence by one step.

    Args:
        S: Current state matrix.
        wab: Matrix that right-multiplies the state for this step.
        vk: Additive term for this step.

    Returns:
        The next state, ``S @ wab + vk``.
    """
    return jnp.matmul(S, wab) + vk
# Demo: build a random 10-step sequence of 2x2 matrices and check that the
# parallel scan and the step-by-step loop produce the same states.
wab_rng, vk_rng = jax.random.split(jax.random.PRNGKey(0), 2)
wab = jax.random.normal(wab_rng, (10, 2, 2))
vk = jax.random.normal(vk_rng, (10, 2, 2))

parallel_S = parallel_states(wab, vk)

# Spot-check associativity on the first three steps: both groupings should
# print the same [0, 0] entries for the combined wab and vk matrices.
a = (wab[0], vk[0])
b = (wab[1], vk[1])
c = (wab[2], vk[2])
ab_c = associative_op(associative_op(a, b), c)
a_bc = associative_op(a, associative_op(b, c))
print([x[0][0] for x in ab_c])
print([x[0][0] for x in a_bc])

# Serial reference: start from a zero state and apply one step at a time.
current_S = jnp.zeros((2, 2))
recurrent_S = []
# Iterate the arrays directly instead of repeating the hard-coded length.
for wab_t, vk_t in zip(wab, vk):
    current_S = recurrent_states(current_S, wab_t, vk_t)
    recurrent_S.append(current_S)
recurrent_S = jnp.stack(recurrent_S)

# The two lines printed below should agree up to floating-point error.
print(parallel_S[:, 0, 0])
print(recurrent_S[:, 0, 0])
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment