@YannBerthelot
Created August 21, 2025 12:13
import jax
import jax.numpy as jnp
import flax.linen as nn
import numpy as np
import optax
from flax.linen.initializers import constant, orthogonal
from typing import Sequence, NamedTuple
from flax.training.train_state import TrainState
import distrax
from gymnax.wrappers.purerl import LogWrapper
from brax import envs
from brax.envs.wrappers.training import EpisodeWrapper, AutoResetWrapper
from flax import struct
from gymnax.environments import environment, spaces
from gymnax.wrappers.purerl import GymnaxWrapper
import matplotlib.pyplot as plt
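
# Minimal adapter that exposes a Brax environment through the gymnax-style API
# (reset/step returning (obs, state, reward, done, info) plus observation/action
# space descriptions), so the generic wrappers below can be reused unchanged.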
class BraxGymnaxWrapper:
    def __init__(self, env_name, backend="positional"):
        env = envs.get_environment(env_name=env_name, backend=backend)
        env = EpisodeWrapper(env, episode_length=1000, action_repeat=1)
        env = AutoResetWrapper(env)
        self._env = env
        self.action_size = env.action_size
        self.observation_size = (env.observation_size,)

    def reset(self, key, params=None):
        state = self._env.reset(key)
        return state.obs, state

    def step(self, key, state, action, params=None):
        next_state = self._env.step(state, action)
        return next_state.obs, next_state, next_state.reward, next_state.done > 0.5, {}

    def observation_space(self, params):
        return spaces.Box(
            low=-jnp.inf,
            high=jnp.inf,
            shape=(self._env.observation_size,),
        )

    def action_space(self, params):
        return spaces.Box(
            low=-1.0,
            high=1.0,
            shape=(self._env.action_size,),
        )
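
# Clip actions into [low, high] before they reach the wrapped environment.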
class ClipAction(GymnaxWrapper):
    def __init__(self, env, low=-1.0, high=1.0):
        super().__init__(env)
        self.low = low
        self.high = high

    def step(self, key, state, action, params=None):
        """TODO: In theory the below line should be the way to do this."""
        # action = jnp.clip(action, self.env.action_space.low, self.env.action_space.high)
        action = jnp.clip(action, self.low, self.high)
        return self._env.step(key, state, action, params)
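
# Vectorize reset/step with jax.vmap: keys, states and actions are batched over
# the leading axis while env params are shared across all environments.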
class VecEnv(GymnaxWrapper):
    def __init__(self, env):
        super().__init__(env)
        self.reset = jax.vmap(self._env.reset, in_axes=(0, None))
        self.step = jax.vmap(self._env.step, in_axes=(0, 0, 0, None))
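
# Running observation statistics (mean/var/count) carried alongside the wrapped
# env state. NormalizeVecObservation updates them with the parallel running-variance
# formula on every batch of observations and returns standardized observations.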
@struct.dataclass
class NormalizeVecObsEnvState:
    mean: jnp.ndarray
    var: jnp.ndarray
    count: float
    env_state: environment.EnvState
class NormalizeVecObservation(GymnaxWrapper):
    def __init__(self, env):
        super().__init__(env)

    def reset(self, key, params=None):
        obs, state = self._env.reset(key, params)
        state = NormalizeVecObsEnvState(
            mean=jnp.zeros_like(obs),
            var=jnp.ones_like(obs),
            count=1e-4,
            env_state=state,
        )

        batch_mean = jnp.mean(obs, axis=0)
        batch_var = jnp.var(obs, axis=0)
        batch_count = obs.shape[0]

        delta = batch_mean - state.mean
        tot_count = state.count + batch_count

        new_mean = state.mean + delta * batch_count / tot_count
        m_a = state.var * state.count
        m_b = batch_var * batch_count
        M2 = m_a + m_b + jnp.square(delta) * state.count * batch_count / tot_count
        new_var = M2 / tot_count
        new_count = tot_count

        state = NormalizeVecObsEnvState(
            mean=new_mean,
            var=new_var,
            count=new_count,
            env_state=state.env_state,
        )

        return (obs - state.mean) / jnp.sqrt(state.var + 1e-8), state

    def step(self, key, state, action, params=None):
        obs, env_state, reward, done, info = self._env.step(key, state.env_state, action, params)

        batch_mean = jnp.mean(obs, axis=0)
        batch_var = jnp.var(obs, axis=0)
        batch_count = obs.shape[0]

        delta = batch_mean - state.mean
        tot_count = state.count + batch_count

        new_mean = state.mean + delta * batch_count / tot_count
        m_a = state.var * state.count
        m_b = batch_var * batch_count
        M2 = m_a + m_b + jnp.square(delta) * state.count * batch_count / tot_count
        new_var = M2 / tot_count
        new_count = tot_count

        state = NormalizeVecObsEnvState(
            mean=new_mean,
            var=new_var,
            count=new_count,
            env_state=env_state,
        )
        return (obs - state.mean) / jnp.sqrt(state.var + 1e-8), state, reward, done, info
@struct.dataclass
class NormalizeVecRewEnvState:
    mean: jnp.ndarray
    var: jnp.ndarray
    count: float
    return_val: float
    env_state: environment.EnvState
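
# Reward normalization: track a running discounted return per environment and
# scale rewards by the standard deviation of that return (rewards are scaled,
# not mean-shifted), mirroring the common VecNormalize behaviour.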
class NormalizeVecReward(GymnaxWrapper):
    def __init__(self, env, gamma):
        super().__init__(env)
        self.gamma = gamma

    def reset(self, key, params=None):
        obs, state = self._env.reset(key, params)
        batch_count = obs.shape[0]
        state = NormalizeVecRewEnvState(
            mean=0.0,
            var=1.0,
            count=1e-4,
            return_val=jnp.zeros((batch_count,)),
            env_state=state,
        )
        return obs, state

    def step(self, key, state, action, params=None):
        obs, env_state, reward, done, info = self._env.step(key, state.env_state, action, params)
        return_val = state.return_val * self.gamma * (1 - done) + reward

        batch_mean = jnp.mean(return_val, axis=0)
        batch_var = jnp.var(return_val, axis=0)
        batch_count = obs.shape[0]

        delta = batch_mean - state.mean
        tot_count = state.count + batch_count

        new_mean = state.mean + delta * batch_count / tot_count
        m_a = state.var * state.count
        m_b = batch_var * batch_count
        M2 = m_a + m_b + jnp.square(delta) * state.count * batch_count / tot_count
        new_var = M2 / tot_count
        new_count = tot_count

        state = NormalizeVecRewEnvState(
            mean=new_mean,
            var=new_var,
            count=new_count,
            return_val=return_val,
            env_state=env_state,
        )
        return obs, state, reward / jnp.sqrt(state.var + 1e-8), done, info
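
# Actor-critic with separate MLPs (two hidden layers of 256): the actor outputs
# the mean of a diagonal Gaussian with a state-independent learned log-std
# (distrax.MultivariateNormalDiag); the critic outputs a scalar value.
# Orthogonal init with a small gain (0.01) on the final policy layer.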
class ActorCritic(nn.Module):
    action_dim: Sequence[int]
    activation: str = "tanh"

    @nn.compact
    def __call__(self, x):
        if self.activation == "relu":
            activation = nn.relu
        else:
            activation = nn.tanh
        actor_mean = nn.Dense(
            256, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0)
        )(x)
        actor_mean = activation(actor_mean)
        actor_mean = nn.Dense(
            256, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0)
        )(actor_mean)
        actor_mean = activation(actor_mean)
        actor_mean = nn.Dense(
            self.action_dim, kernel_init=orthogonal(0.01), bias_init=constant(0.0)
        )(actor_mean)
        actor_logstd = self.param("log_std", nn.initializers.zeros, (self.action_dim,))
        pi = distrax.MultivariateNormalDiag(actor_mean, jnp.exp(actor_logstd))

        critic = nn.Dense(
            256, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0)
        )(x)
        critic = activation(critic)
        critic = nn.Dense(
            256, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0)
        )(critic)
        critic = activation(critic)
        critic = nn.Dense(1, kernel_init=orthogonal(1.0), bias_init=constant(0.0))(
            critic
        )

        return pi, jnp.squeeze(critic, axis=-1)
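
# Per-step rollout data kept for the PPO update.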
class Transition(NamedTuple):
    done: jnp.ndarray
    action: jnp.ndarray
    value: jnp.ndarray
    reward: jnp.ndarray
    log_prob: jnp.ndarray
    obs: jnp.ndarray
    info: jnp.ndarray
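
# Build a pure train(rng) function closed over the config and the wrapped env,
# so the whole training loop (rollouts + updates) can be jit-compiled end to end.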
def make_train(config):
    config["NUM_UPDATES"] = (
        config["TOTAL_TIMESTEPS"] // config["NUM_STEPS"] // config["NUM_ENVS"]
    )
    config["MINIBATCH_SIZE"] = (
        config["NUM_ENVS"] * config["NUM_STEPS"] // config["NUM_MINIBATCHES"]
    )
    env, env_params = BraxGymnaxWrapper(config["ENV_NAME"]), None
    env = LogWrapper(env)
    env = ClipAction(env)
    env = VecEnv(env)
    if config["NORMALIZE_ENV"]:
        env = NormalizeVecObservation(env)
        env = NormalizeVecReward(env, config["GAMMA"])

    def linear_schedule(count):
        frac = (
            1.0
            - (count // (config["NUM_MINIBATCHES"] * config["UPDATE_EPOCHS"]))
            / config["NUM_UPDATES"]
        )
        return config["LR"] * frac
    def train(rng):
        # INIT NETWORK
        network = ActorCritic(
            env.action_space(env_params).shape[0], activation=config["ACTIVATION"]
        )
        rng, _rng = jax.random.split(rng)
        init_x = jnp.zeros(env.observation_space(env_params).shape)
        network_params = network.init(_rng, init_x)
        if config["ANNEAL_LR"]:
            tx = optax.chain(
                optax.clip_by_global_norm(config["MAX_GRAD_NORM"]),
                optax.adam(learning_rate=linear_schedule, eps=1e-5),
            )
        else:
            tx = optax.chain(
                optax.clip_by_global_norm(config["MAX_GRAD_NORM"]),
                optax.adam(config["LR"], eps=1e-5),
            )
        train_state = TrainState.create(
            apply_fn=network.apply,
            params=network_params,
            tx=tx,
        )

        # INIT ENV
        rng, _rng = jax.random.split(rng)
        reset_rng = jax.random.split(_rng, config["NUM_ENVS"])
        obsv, env_state = env.reset(reset_rng, env_params)
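
        # Each _update_step collects NUM_STEPS transitions from NUM_ENVS parallel
        # envs, computes GAE advantages, then runs UPDATE_EPOCHS epochs of
        # minibatch PPO updates; the whole loop is driven by jax.lax.scan.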
        # TRAIN LOOP
        def _update_step(runner_state, unused):
            # COLLECT TRAJECTORIES
            def _env_step(runner_state, unused):
                train_state, env_state, last_obs, rng = runner_state

                # SELECT ACTION
                rng, _rng = jax.random.split(rng)
                pi, value = network.apply(train_state.params, last_obs)
                action = pi.sample(seed=_rng)
                log_prob = pi.log_prob(action)

                # STEP ENV
                rng, _rng = jax.random.split(rng)
                rng_step = jax.random.split(_rng, config["NUM_ENVS"])
                obsv, env_state, reward, done, info = env.step(rng_step, env_state, action, env_params)
                transition = Transition(
                    done, action, value, reward, log_prob, last_obs, info
                )
                runner_state = (train_state, env_state, obsv, rng)
                return runner_state, transition

            runner_state, traj_batch = jax.lax.scan(
                _env_step, runner_state, None, config["NUM_STEPS"]
            )

            # CALCULATE ADVANTAGE
            train_state, env_state, last_obs, rng = runner_state
            _, last_val = network.apply(train_state.params, last_obs)
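
            # Generalized Advantage Estimation, computed with a reverse scan:
            #   delta_t = r_t + gamma * V(s_{t+1}) * (1 - done_t) - V(s_t)
            #   A_t     = delta_t + gamma * lambda * (1 - done_t) * A_{t+1}
            # Value targets are A_t + V(s_t).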
            def _calculate_gae(traj_batch, last_val):
                def _get_advantages(gae_and_next_value, transition):
                    gae, next_value = gae_and_next_value
                    done, value, reward = (
                        transition.done,
                        transition.value,
                        transition.reward,
                    )
                    delta = reward + config["GAMMA"] * next_value * (1 - done) - value
                    gae = (
                        delta
                        + config["GAMMA"] * config["GAE_LAMBDA"] * (1 - done) * gae
                    )
                    return (gae, value), gae

                _, advantages = jax.lax.scan(
                    _get_advantages,
                    (jnp.zeros_like(last_val), last_val),
                    traj_batch,
                    reverse=True,
                    unroll=16,
                )
                return advantages, advantages + traj_batch.value

            advantages, targets = _calculate_gae(traj_batch, last_val)
            # UPDATE NETWORK
            def _update_epoch(update_state, unused):
                def _update_minbatch(train_state, batch_info):
                    traj_batch, advantages, targets = batch_info
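
                    # PPO loss on one minibatch: clipped value loss, clipped
                    # surrogate policy loss on advantages normalized per
                    # minibatch, and an entropy bonus.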
                    def _loss_fn(params, traj_batch, gae, targets):
                        # RERUN NETWORK
                        pi, value = network.apply(params, traj_batch.obs)
                        log_prob = pi.log_prob(traj_batch.action)

                        # CALCULATE VALUE LOSS
                        value_pred_clipped = traj_batch.value + (
                            value - traj_batch.value
                        ).clip(-config["CLIP_EPS"], config["CLIP_EPS"])
                        value_losses = jnp.square(value - targets)
                        value_losses_clipped = jnp.square(value_pred_clipped - targets)
                        value_loss = (
                            0.5 * jnp.maximum(value_losses, value_losses_clipped).mean()
                        )

                        # CALCULATE ACTOR LOSS
                        ratio = jnp.exp(log_prob - traj_batch.log_prob)
                        gae = (gae - gae.mean()) / (gae.std() + 1e-8)
                        loss_actor1 = ratio * gae
                        loss_actor2 = (
                            jnp.clip(
                                ratio,
                                1.0 - config["CLIP_EPS"],
                                1.0 + config["CLIP_EPS"],
                            )
                            * gae
                        )
                        loss_actor = -jnp.minimum(loss_actor1, loss_actor2)
                        loss_actor = loss_actor.mean()
                        entropy = pi.entropy().mean()

                        total_loss = (
                            loss_actor
                            + config["VF_COEF"] * value_loss
                            - config["ENT_COEF"] * entropy
                        )
                        return total_loss, (value_loss, loss_actor, entropy)

                    grad_fn = jax.value_and_grad(_loss_fn, has_aux=True)
                    total_loss, grads = grad_fn(
                        train_state.params, traj_batch, advantages, targets
                    )
                    train_state = train_state.apply_gradients(grads=grads)
                    return train_state, total_loss

                train_state, traj_batch, advantages, targets, rng = update_state
                rng, _rng = jax.random.split(rng)
                batch_size = config["MINIBATCH_SIZE"] * config["NUM_MINIBATCHES"]
                assert (
                    batch_size == config["NUM_STEPS"] * config["NUM_ENVS"]
                ), "batch size must be equal to number of steps * number of envs"
                permutation = jax.random.permutation(_rng, batch_size)
                batch = (traj_batch, advantages, targets)
                batch = jax.tree_util.tree_map(
                    lambda x: x.reshape((batch_size,) + x.shape[2:]), batch
                )
                shuffled_batch = jax.tree_util.tree_map(
                    lambda x: jnp.take(x, permutation, axis=0), batch
                )
                minibatches = jax.tree_util.tree_map(
                    lambda x: jnp.reshape(
                        x, [config["NUM_MINIBATCHES"], -1] + list(x.shape[1:])
                    ),
                    shuffled_batch,
                )
                train_state, total_loss = jax.lax.scan(
                    _update_minbatch, train_state, minibatches
                )
                update_state = (train_state, traj_batch, advantages, targets, rng)
                return update_state, total_loss

            update_state = (train_state, traj_batch, advantages, targets, rng)
            update_state, loss_info = jax.lax.scan(
                _update_epoch, update_state, None, config["UPDATE_EPOCHS"]
            )
            train_state = update_state[0]
            metric = traj_batch.info
            rng = update_state[-1]
            if config.get("DEBUG"):
                def callback(info):
                    return_values = info["returned_episode_returns"][info["returned_episode"]]
                    timesteps = info["timestep"][info["returned_episode"]] * config["NUM_ENVS"]
                    for t in range(len(timesteps)):
                        print(f"global step={timesteps[t]}, episodic return={return_values[t]}")

                jax.debug.callback(callback, metric)

            runner_state = (train_state, env_state, last_obs, rng)
            return runner_state, metric

        rng, _rng = jax.random.split(rng)
        runner_state = (train_state, env_state, obsv, _rng)
        runner_state, metric = jax.lax.scan(
            _update_step, runner_state, None, config["NUM_UPDATES"]
        )
        return {"runner_state": runner_state, "metrics": metric}

    return train

if __name__ == "__main__":
    config = {
        "LR": 3e-4,
        "NUM_ENVS": 2048,
        "NUM_STEPS": 10,
        "TOTAL_TIMESTEPS": 5e7,
        "UPDATE_EPOCHS": 4,
        "NUM_MINIBATCHES": 32,
        "GAMMA": 0.99,
        "GAE_LAMBDA": 0.95,
        "CLIP_EPS": 0.2,
        "ENT_COEF": 0.0,
        "VF_COEF": 0.5,
        "MAX_GRAD_NORM": 0.5,
        "ACTIVATION": "tanh",
        "ENV_NAME": "hopper",
        "ANNEAL_LR": False,
        "NORMALIZE_ENV": True,
        "DEBUG": True,
    }
    rng = jax.random.PRNGKey(30)
    train_jit = jax.jit(make_train(config))
    out = train_jit(rng)
    plt.plot(out["metrics"]["returned_episode_returns"].mean(-1).reshape(-1))
    plt.xlabel("Updates")
    plt.ylabel("Return")
    plt.show()
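
    # Not part of the original gist: a minimal sketch of how training could be
    # run over several seeds in one call (assuming the batched training states
    # fit in device memory):
    #   rngs = jax.random.split(jax.random.PRNGKey(30), num=4)
    #   outs = jax.jit(jax.vmap(make_train(config)))(rngs)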