Skip to content

Instantly share code, notes, and snippets.

@vyeevani
Created May 23, 2024 17:40
Show Gist options
  • Select an option

  • Save vyeevani/e10c4a92bb74edf51b03d8a05e652049 to your computer and use it in GitHub Desktop.

Select an option

Save vyeevani/e10c4a92bb74edf51b03d8a05e652049 to your computer and use it in GitHub Desktop.
Ragged arrays in JAX
import jax
import chex
@chex.dataclass
class OptionalPyTree:
    """A value paired with a flag saying whether that value is present.

    Declared with ``chex.dataclass`` so it behaves as a JAX pytree and can be
    returned from / passed through ``jax.vmap`` and ``jax.tree`` utilities.
    Mapping a constructor over a batch (as ``jax.vmap(make_optional_pytree, ...)``
    does in this file) yields one instance whose fields carry a leading batch axis.
    """

    # The wrapped payload; treated as placeholder data when ``exists`` is false.
    array: jax.Array
    # Presence flag: either a Python bool or a boolean JAX array (e.g. one flag
    # per element after ``jax.vmap``).
    # NOTE(review): annotated as ``bool`` but holds arrays in the batched case.
    exists: bool
def make_optional_pytree(array, exists):
    """Build an ``OptionalPyTree`` holding ``array`` with presence flag ``exists``.

    Kept as a plain function (rather than calling the class directly) so it can
    be mapped with ``jax.vmap`` to construct a whole batch of optional values.
    """
    return OptionalPyTree(array=array, exists=exists)
# Two dense rows of input data.
x = jax.numpy.array([[1.0, 2.0, 3.0],
                     [4.0, 5.0, 6.0]])
# A batch of two optional rows, built by mapping make_optional_pytree over the
# leading axis of both inputs: row 0 is flagged missing, row 1 is present
# (its values are all zero).
y = jax.vmap(make_optional_pytree)(
    jax.numpy.zeros((2, 3)),
    jax.numpy.array([False, True]),
)
def optional_vmap(func, zero_value, in_axes=0, out_axes=0):
    """Vectorize ``func`` like ``jax.vmap`` while honouring ``OptionalPyTree`` inputs.

    Every ``OptionalPyTree`` found in the arguments -- either as a top-level
    argument or nested anywhere inside an argument pytree -- is unwrapped to its
    ``array`` before ``func`` is called. All other values are passed through
    unchanged and count as present. For each mapped element, ``func`` is used
    only if every ``OptionalPyTree`` in its arguments has a true ``exists``
    flag; otherwise ``zero_value`` is returned for that element.

    Args:
        func: Per-element function; receives the unwrapped arguments positionally.
        zero_value: Fallback result for elements with a missing input. The choice
            is made with ``jax.lax.cond``, so it must have the same pytree
            structure, shapes and dtypes as ``func``'s per-element output.
        in_axes: Forwarded unchanged to ``jax.vmap``.
        out_axes: Forwarded unchanged to ``jax.vmap``.

    Returns:
        The vectorized function produced by ``jax.vmap``.

    Note:
        When the existence flags are batched, ``jax.vmap`` lowers ``lax.cond``
        to a select that evaluates both branches. ``func`` therefore still runs
        on the placeholder data of missing elements (its result is discarded),
        so it must be safe to evaluate on that data.
    """

    def is_optional(node):
        return isinstance(node, OptionalPyTree)

    def optional_apply(*args):
        # Stop flattening at OptionalPyTree nodes wherever they occur, so optional
        # values nested inside containers are checked, not just top-level args.
        nodes = jax.tree.leaves(args, is_leaf=is_optional)
        flags = [node.exists for node in nodes if is_optional(node)]
        # An explicit bool dtype keeps the empty case (no optional inputs) well
        # defined: all() over an empty bool array is True.
        exists = jax.numpy.all(jax.numpy.array(flags, dtype=bool))
        # Replace each OptionalPyTree by its payload; other leaves pass through.
        unwrapped = jax.tree.map(
            lambda node: node.array if is_optional(node) else node,
            args,
            is_leaf=is_optional,
        )
        return jax.lax.cond(
            exists,
            lambda: func(*unwrapped),
            lambda: zero_value,
        )

    return jax.vmap(optional_apply, in_axes=in_axes, out_axes=out_axes)
# Add x and y row by row; a row whose y entry is missing yields zeros instead.
print(
    optional_vmap(
        lambda a, b: a + b,
        jax.numpy.array([0.0, 0.0, 0.0]),
    )(x, y)
)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment