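"""PPO with continuous actions on Brax environments, written end-to-end in JAX.

The whole training pipeline (vectorised environment stepping, observation/reward
normalisation, GAE, and clipped PPO updates) is compiled as a single jitted
function; see the ``__main__`` block at the bottom for the configuration used.
"""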
import jax
import jax.numpy as jnp
import flax.linen as nn
import numpy as np
import optax
from flax.linen.initializers import constant, orthogonal
from typing import Sequence, NamedTuple
from flax.training.train_state import TrainState
import distrax
from gymnax.wrappers.purerl import LogWrapper
from brax import envs
from brax.envs.wrappers.training import EpisodeWrapper, AutoResetWrapper
from flax import struct
from gymnax.environments import environment, spaces
from gymnax.wrappers.purerl import GymnaxWrapper
import matplotlib.pyplot as plt

class BraxGymnaxWrapper:
    def __init__(self, env_name, backend="positional"):
        env = envs.get_environment(env_name=env_name, backend=backend)
        env = EpisodeWrapper(env, episode_length=1000, action_repeat=1)
        env = AutoResetWrapper(env)
        self._env = env
        self.action_size = env.action_size
        self.observation_size = (env.observation_size,)

    def reset(self, key, params=None):
        state = self._env.reset(key)
        return state.obs, state

    def step(self, key, state, action, params=None):
        next_state = self._env.step(state, action)
        return next_state.obs, next_state, next_state.reward, next_state.done > 0.5, {}

    def observation_space(self, params):
        return spaces.Box(
            low=-jnp.inf,
            high=jnp.inf,
            shape=(self._env.observation_size,),
        )

    def action_space(self, params):
        return spaces.Box(
            low=-1.0,
            high=1.0,
            shape=(self._env.action_size,),
        )

class ClipAction(GymnaxWrapper):
    def __init__(self, env, low=-1.0, high=1.0):
        super().__init__(env)
        self.low = low
        self.high = high

    def step(self, key, state, action, params=None):
        """TODO: In theory the below line should be the way to do this."""
        # action = jnp.clip(action, self.env.action_space.low, self.env.action_space.high)
        action = jnp.clip(action, self.low, self.high)
        return self._env.step(key, state, action, params)

class VecEnv(GymnaxWrapper):
    def __init__(self, env):
        super().__init__(env)
        self.reset = jax.vmap(self._env.reset, in_axes=(0, None))
        self.step = jax.vmap(self._env.step, in_axes=(0, 0, 0, None))

@struct.dataclass
class NormalizeVecObsEnvState:
    mean: jnp.ndarray
    var: jnp.ndarray
    count: float
    env_state: environment.EnvState


class NormalizeVecObservation(GymnaxWrapper):
    def __init__(self, env):
        super().__init__(env)

    def reset(self, key, params=None):
        obs, state = self._env.reset(key, params)
        state = NormalizeVecObsEnvState(
            mean=jnp.zeros_like(obs),
            var=jnp.ones_like(obs),
            count=1e-4,
            env_state=state,
        )

        batch_mean = jnp.mean(obs, axis=0)
        batch_var = jnp.var(obs, axis=0)
        batch_count = obs.shape[0]

        delta = batch_mean - state.mean
        tot_count = state.count + batch_count

        new_mean = state.mean + delta * batch_count / tot_count
        m_a = state.var * state.count
        m_b = batch_var * batch_count
        M2 = m_a + m_b + jnp.square(delta) * state.count * batch_count / tot_count
        new_var = M2 / tot_count
        new_count = tot_count

        state = NormalizeVecObsEnvState(
            mean=new_mean,
            var=new_var,
            count=new_count,
            env_state=state.env_state,
        )
        return (obs - state.mean) / jnp.sqrt(state.var + 1e-8), state

    def step(self, key, state, action, params=None):
        obs, env_state, reward, done, info = self._env.step(key, state.env_state, action, params)

        batch_mean = jnp.mean(obs, axis=0)
        batch_var = jnp.var(obs, axis=0)
        batch_count = obs.shape[0]

        delta = batch_mean - state.mean
        tot_count = state.count + batch_count

        new_mean = state.mean + delta * batch_count / tot_count
        m_a = state.var * state.count
        m_b = batch_var * batch_count
        M2 = m_a + m_b + jnp.square(delta) * state.count * batch_count / tot_count
        new_var = M2 / tot_count
        new_count = tot_count

        state = NormalizeVecObsEnvState(
            mean=new_mean,
            var=new_var,
            count=new_count,
            env_state=env_state,
        )
        return (obs - state.mean) / jnp.sqrt(state.var + 1e-8), state, reward, done, info

@struct.dataclass
class NormalizeVecRewEnvState:
    mean: jnp.ndarray
    var: jnp.ndarray
    count: float
    return_val: float
    env_state: environment.EnvState


class NormalizeVecReward(GymnaxWrapper):
    def __init__(self, env, gamma):
        super().__init__(env)
        self.gamma = gamma

    def reset(self, key, params=None):
        obs, state = self._env.reset(key, params)
        batch_count = obs.shape[0]
        state = NormalizeVecRewEnvState(
            mean=0.0,
            var=1.0,
            count=1e-4,
            return_val=jnp.zeros((batch_count,)),
            env_state=state,
        )
        return obs, state

    def step(self, key, state, action, params=None):
        obs, env_state, reward, done, info = self._env.step(key, state.env_state, action, params)
        return_val = state.return_val * self.gamma * (1 - done) + reward

        batch_mean = jnp.mean(return_val, axis=0)
        batch_var = jnp.var(return_val, axis=0)
        batch_count = obs.shape[0]

        delta = batch_mean - state.mean
        tot_count = state.count + batch_count

        new_mean = state.mean + delta * batch_count / tot_count
        m_a = state.var * state.count
        m_b = batch_var * batch_count
        M2 = m_a + m_b + jnp.square(delta) * state.count * batch_count / tot_count
        new_var = M2 / tot_count
        new_count = tot_count

        state = NormalizeVecRewEnvState(
            mean=new_mean,
            var=new_var,
            count=new_count,
            return_val=return_val,
            env_state=env_state,
        )
        return obs, state, reward / jnp.sqrt(state.var + 1e-8), done, info

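# Illustrative restatement (not part of the original gist): both normalisation
# wrappers above fold a batch's mean/variance into running statistics with the
# parallel mean/variance combination (Chan et al.). The helper below is a
# hypothetical, standalone version of that update; it is never called by the
# training code and exists only to make the formula easier to read.
def _update_running_stats(mean, var, count, batch):
    """Combine running (mean, var, count) with the statistics of `batch` (axis 0)."""
    batch_mean = jnp.mean(batch, axis=0)
    batch_var = jnp.var(batch, axis=0)
    batch_count = batch.shape[0]

    delta = batch_mean - mean
    tot_count = count + batch_count

    new_mean = mean + delta * batch_count / tot_count
    # M2 accumulates the combined sum of squared deviations.
    m2 = (
        var * count
        + batch_var * batch_count
        + jnp.square(delta) * count * batch_count / tot_count
    )
    return new_mean, m2 / tot_count, tot_count
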
class ActorCritic(nn.Module):
    action_dim: Sequence[int]
    activation: str = "tanh"

    @nn.compact
    def __call__(self, x):
        if self.activation == "relu":
            activation = nn.relu
        else:
            activation = nn.tanh

        actor_mean = nn.Dense(
            256, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0)
        )(x)
        actor_mean = activation(actor_mean)
        actor_mean = nn.Dense(
            256, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0)
        )(actor_mean)
        actor_mean = activation(actor_mean)
        actor_mean = nn.Dense(
            self.action_dim, kernel_init=orthogonal(0.01), bias_init=constant(0.0)
        )(actor_mean)
        actor_logstd = self.param("log_std", nn.initializers.zeros, (self.action_dim,))
        pi = distrax.MultivariateNormalDiag(actor_mean, jnp.exp(actor_logstd))

        critic = nn.Dense(
            256, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0)
        )(x)
        critic = activation(critic)
        critic = nn.Dense(
            256, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0)
        )(critic)
        critic = activation(critic)
        critic = nn.Dense(1, kernel_init=orthogonal(1.0), bias_init=constant(0.0))(
            critic
        )

        return pi, jnp.squeeze(critic, axis=-1)

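# Illustrative sketch (not part of the original gist): how the ActorCritic module
# is initialised and queried for a single observation. The dimensions below are
# hypothetical (an 8-dimensional observation, a 2-dimensional action space), and
# this helper is never called by the training script.
def _actor_critic_shape_example():
    net = ActorCritic(action_dim=2, activation="tanh")
    params = net.init(jax.random.PRNGKey(0), jnp.zeros((8,)))
    pi, value = net.apply(params, jnp.zeros((8,)))  # pi: MultivariateNormalDiag, value: scalar
    action = pi.sample(seed=jax.random.PRNGKey(1))  # shape (2,), clipped later by ClipAction
    log_prob = pi.log_prob(action)                  # scalar log-density of the sampled action
    return action, log_prob, value
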
class Transition(NamedTuple):
    done: jnp.ndarray
    action: jnp.ndarray
    value: jnp.ndarray
    reward: jnp.ndarray
    log_prob: jnp.ndarray
    obs: jnp.ndarray
    info: jnp.ndarray

def make_train(config):
    config["NUM_UPDATES"] = (
        config["TOTAL_TIMESTEPS"] // config["NUM_STEPS"] // config["NUM_ENVS"]
    )
    config["MINIBATCH_SIZE"] = (
        config["NUM_ENVS"] * config["NUM_STEPS"] // config["NUM_MINIBATCHES"]
    )
    env, env_params = BraxGymnaxWrapper(config["ENV_NAME"]), None
    env = LogWrapper(env)
    env = ClipAction(env)
    env = VecEnv(env)
    if config["NORMALIZE_ENV"]:
        env = NormalizeVecObservation(env)
        env = NormalizeVecReward(env, config["GAMMA"])

    def linear_schedule(count):
        frac = (
            1.0
            - (count // (config["NUM_MINIBATCHES"] * config["UPDATE_EPOCHS"]))
            / config["NUM_UPDATES"]
        )
        return config["LR"] * frac

    def train(rng):
        # INIT NETWORK
        network = ActorCritic(
            env.action_space(env_params).shape[0], activation=config["ACTIVATION"]
        )
        rng, _rng = jax.random.split(rng)
        init_x = jnp.zeros(env.observation_space(env_params).shape)
        network_params = network.init(_rng, init_x)
        if config["ANNEAL_LR"]:
            tx = optax.chain(
                optax.clip_by_global_norm(config["MAX_GRAD_NORM"]),
                optax.adam(learning_rate=linear_schedule, eps=1e-5),
            )
        else:
            tx = optax.chain(
                optax.clip_by_global_norm(config["MAX_GRAD_NORM"]),
                optax.adam(config["LR"], eps=1e-5),
            )
        train_state = TrainState.create(
            apply_fn=network.apply,
            params=network_params,
            tx=tx,
        )

        # INIT ENV
        rng, _rng = jax.random.split(rng)
        reset_rng = jax.random.split(_rng, config["NUM_ENVS"])
        obsv, env_state = env.reset(reset_rng, env_params)
        # TRAIN LOOP
        def _update_step(runner_state, unused):
            # COLLECT TRAJECTORIES
            def _env_step(runner_state, unused):
                train_state, env_state, last_obs, rng = runner_state

                # SELECT ACTION
                rng, _rng = jax.random.split(rng)
                pi, value = network.apply(train_state.params, last_obs)
                action = pi.sample(seed=_rng)
                log_prob = pi.log_prob(action)

                # STEP ENV
                rng, _rng = jax.random.split(rng)
                rng_step = jax.random.split(_rng, config["NUM_ENVS"])
                obsv, env_state, reward, done, info = env.step(rng_step, env_state, action, env_params)
                transition = Transition(
                    done, action, value, reward, log_prob, last_obs, info
                )
                runner_state = (train_state, env_state, obsv, rng)
                return runner_state, transition

            runner_state, traj_batch = jax.lax.scan(
                _env_step, runner_state, None, config["NUM_STEPS"]
            )

            # CALCULATE ADVANTAGE
            train_state, env_state, last_obs, rng = runner_state
            _, last_val = network.apply(train_state.params, last_obs)

            def _calculate_gae(traj_batch, last_val):
                def _get_advantages(gae_and_next_value, transition):
                    gae, next_value = gae_and_next_value
                    done, value, reward = (
                        transition.done,
                        transition.value,
                        transition.reward,
                    )
                    delta = reward + config["GAMMA"] * next_value * (1 - done) - value
                    gae = (
                        delta
                        + config["GAMMA"] * config["GAE_LAMBDA"] * (1 - done) * gae
                    )
                    return (gae, value), gae

                _, advantages = jax.lax.scan(
                    _get_advantages,
                    (jnp.zeros_like(last_val), last_val),
                    traj_batch,
                    reverse=True,
                    unroll=16,
                )
                return advantages, advantages + traj_batch.value

            advantages, targets = _calculate_gae(traj_batch, last_val)
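            # For reference (comment added, not in the original gist): the reverse
            # scan above implements the standard GAE recursion,
            #   delta_t = r_t + GAMMA * V(s_{t+1}) * (1 - done_t) - V(s_t)
            #   A_t     = delta_t + GAMMA * GAE_LAMBDA * (1 - done_t) * A_{t+1},
            # and the critic targets are A_t + V(s_t).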
            # UPDATE NETWORK
            def _update_epoch(update_state, unused):
                def _update_minbatch(train_state, batch_info):
                    traj_batch, advantages, targets = batch_info

                    def _loss_fn(params, traj_batch, gae, targets):
                        # RERUN NETWORK
                        pi, value = network.apply(params, traj_batch.obs)
                        log_prob = pi.log_prob(traj_batch.action)

                        # CALCULATE VALUE LOSS
                        value_pred_clipped = traj_batch.value + (
                            value - traj_batch.value
                        ).clip(-config["CLIP_EPS"], config["CLIP_EPS"])
                        value_losses = jnp.square(value - targets)
                        value_losses_clipped = jnp.square(value_pred_clipped - targets)
                        value_loss = (
                            0.5 * jnp.maximum(value_losses, value_losses_clipped).mean()
                        )

                        # CALCULATE ACTOR LOSS
                        ratio = jnp.exp(log_prob - traj_batch.log_prob)
                        gae = (gae - gae.mean()) / (gae.std() + 1e-8)
                        loss_actor1 = ratio * gae
                        loss_actor2 = (
                            jnp.clip(
                                ratio,
                                1.0 - config["CLIP_EPS"],
                                1.0 + config["CLIP_EPS"],
                            )
                            * gae
                        )
                        loss_actor = -jnp.minimum(loss_actor1, loss_actor2)
                        loss_actor = loss_actor.mean()
                        entropy = pi.entropy().mean()

                        total_loss = (
                            loss_actor
                            + config["VF_COEF"] * value_loss
                            - config["ENT_COEF"] * entropy
                        )
                        return total_loss, (value_loss, loss_actor, entropy)

                    grad_fn = jax.value_and_grad(_loss_fn, has_aux=True)
                    total_loss, grads = grad_fn(
                        train_state.params, traj_batch, advantages, targets
                    )
                    train_state = train_state.apply_gradients(grads=grads)
                    return train_state, total_loss

                train_state, traj_batch, advantages, targets, rng = update_state
                rng, _rng = jax.random.split(rng)
                batch_size = config["MINIBATCH_SIZE"] * config["NUM_MINIBATCHES"]
                assert (
                    batch_size == config["NUM_STEPS"] * config["NUM_ENVS"]
                ), "batch size must be equal to number of steps * number of envs"
                permutation = jax.random.permutation(_rng, batch_size)
                batch = (traj_batch, advantages, targets)
                batch = jax.tree_util.tree_map(
                    lambda x: x.reshape((batch_size,) + x.shape[2:]), batch
                )
                shuffled_batch = jax.tree_util.tree_map(
                    lambda x: jnp.take(x, permutation, axis=0), batch
                )
                minibatches = jax.tree_util.tree_map(
                    lambda x: jnp.reshape(
                        x, [config["NUM_MINIBATCHES"], -1] + list(x.shape[1:])
                    ),
                    shuffled_batch,
                )
                train_state, total_loss = jax.lax.scan(
                    _update_minbatch, train_state, minibatches
                )
                update_state = (train_state, traj_batch, advantages, targets, rng)
                return update_state, total_loss

            update_state = (train_state, traj_batch, advantages, targets, rng)
            update_state, loss_info = jax.lax.scan(
                _update_epoch, update_state, None, config["UPDATE_EPOCHS"]
            )
            train_state = update_state[0]
            metric = traj_batch.info
            rng = update_state[-1]
            if config.get("DEBUG"):

                def callback(info):
                    return_values = info["returned_episode_returns"][info["returned_episode"]]
                    timesteps = info["timestep"][info["returned_episode"]] * config["NUM_ENVS"]
                    for t in range(len(timesteps)):
                        print(f"global step={timesteps[t]}, episodic return={return_values[t]}")

                jax.debug.callback(callback, metric)

            runner_state = (train_state, env_state, last_obs, rng)
            return runner_state, metric

        rng, _rng = jax.random.split(rng)
        runner_state = (train_state, env_state, obsv, _rng)
        runner_state, metric = jax.lax.scan(
            _update_step, runner_state, None, config["NUM_UPDATES"]
        )
        return {"runner_state": runner_state, "metrics": metric}

    return train

if __name__ == "__main__":
    config = {
        "LR": 3e-4,
        "NUM_ENVS": 2048,
        "NUM_STEPS": 10,
        "TOTAL_TIMESTEPS": 5e7,
        "UPDATE_EPOCHS": 4,
        "NUM_MINIBATCHES": 32,
        "GAMMA": 0.99,
        "GAE_LAMBDA": 0.95,
        "CLIP_EPS": 0.2,
        "ENT_COEF": 0.0,
        "VF_COEF": 0.5,
        "MAX_GRAD_NORM": 0.5,
        "ACTIVATION": "tanh",
        "ENV_NAME": "hopper",
        "ANNEAL_LR": False,
        "NORMALIZE_ENV": True,
        "DEBUG": True,
    }
    rng = jax.random.PRNGKey(30)
    train_jit = jax.jit(make_train(config))
    out = train_jit(rng)

    plt.plot(out["metrics"]["returned_episode_returns"].mean(-1).reshape(-1))
    plt.xlabel("Updates")
    plt.ylabel("Return")
    plt.show()
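    # Optional extension (not in the original gist): because `train` is a pure
    # function of the PRNG key, several seeds can be trained in parallel by
    # vmapping over keys. Left commented out, since it multiplies runtime and
    # memory use by the number of seeds:
    # rngs = jax.random.split(rng, 4)
    # train_vjit = jax.jit(jax.vmap(make_train(config)))
    # outs = train_vjit(rngs)
    # "returned_episode_returns" then carries a leading seed dimension.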