A high-quality, truly single-file implementation of PPO -- simple to use, transparent, and dependency-light (only torch, gymnasium, numpy, and matplotlib). Includes a Lagrange penalty-based constrained-MDP solver and supports both continuous and discrete action spaces. Compatible with RNN policies. Designed for clarity, reproducibility, and research-grade performance.
| # ----------------------------------------------------------------------------- | |
| # PPO (Proximal Policy Optimization) — High-Quality Single-File Implementation | |
| # Author: Abhinav Bhatia | |
| # Source: https://gist.github.com/bhatiaabhinav/edb07949471c0ae9e71811146cd46311 | |
| # | |
| # Description: | |
| # A high-quality, truly single-file implementation of PPO (Proximal Policy Optimization), | |
| # designed for transparency, simplicity, and research-grade reproducibility. | |
| # | |
| # Key features: | |
| # • Supports both continuous and discrete action spaces. | |
| # • Compatible with feedforward and recurrent (RNN-based) policies. | |
| # • Includes a Lagrange penalty-based constrained-MDP solver. | |
| # • No external dependencies beyond `torch` and `gymnasium`, plus `numpy` and `matplotlib` (stats and plotting). | |
| # • Compact, well-structured, and easy to extend for experiments. | |
| # • Ideal for benchmarking, academic study, or educational use. | |
| # | |
| # License: | |
| # Licensed under the Creative Commons Attribution 4.0 International License (CC BY 4.0). | |
| # You are free to share and adapt this work, provided appropriate credit is given. | |
| # See: https://creativecommons.org/licenses/by/4.0/ | |
| # | |
| # Citation: | |
| # If you use or build upon this implementation in research or publications, | |
| # please consider citing the following: | |
| # | |
| # @misc{bhatia2025ppo, | |
| # author = {Abhinav Bhatia}, | |
| # title = {PPO (Proximal Policy Optimization) — High-Quality Single-File Implementation}, | |
| # year = {2025}, | |
| # howpublished = {\url{https://gist.github.com/bhatiaabhinav/edb07949471c0ae9e71811146cd46311}} | |
| # } | |
| # | |
| # ----------------------------------------------------------------------------- | |
| import abc | |
| import copy | |
| import os | |
| import shutil | |
| import sys | |
| from typing import Any, Dict, List, Optional, Tuple, Union | |
| import matplotlib.pyplot as plt | |
| import numpy as np | |
| import torch | |
| import torch.nn.functional as F | |
| import gymnasium | |
| from gymnasium import Env, make, wrappers | |
| from gymnasium.spaces import Box, Discrete | |
| from gymnasium.vector import AsyncVectorEnv, VectorEnv | |
| from torch import Tensor, nn | |
| autoreset_kwarg_samestep_newgymapi = {} | |
| try: | |
| from gymnasium.vector import AutoresetMode | |
| autoreset_kwarg_samestep_newgymapi['autoreset_mode'] = AutoresetMode.SAME_STEP | |
| except ImportError: # older gymnasium versions' reset logic already behaves like SAME_STEP by default | |
| pass | |
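| # Illustrative sketch (assumes a `make_one_env` factory): splat this dict when building | |
| # the vector env so that newer gymnasium versions use SAME_STEP autoreset, e.g.: | |
| #   envs = AsyncVectorEnv([make_one_env for _ in range(8)], **autoreset_kwarg_samestep_newgymapi) | |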
| # --- Utility functions --- | |
| def log_normal_prob(x: Tensor, mean: Tensor, std: Tensor) -> Tensor: | |
| """Compute log probability of x under a normal distribution with given mean and std.""" | |
| var = std ** 2 | |
| log_prob = -0.5 * (((x - mean) ** 2) / var + 2 * torch.log(std + 1e-6) + np.log(2 * np.pi)) | |
| return log_prob.sum(axis=-1, keepdim=True) | |
| def gaussian_entropy(logstd: Tensor) -> Tensor: | |
| """Compute the entropy of a Gaussian distribution given its log standard deviation.""" | |
| entropy = torch.sum(logstd + 0.5 * (1.0 + np.log(2 * np.pi)), dim=-1, keepdim=True) # (batch_size, 1) or (batch_size, seq_len, 1) | |
| return entropy | |
| def normalize_and_update_stats(mean: Tensor, var: Tensor, count: Tensor, x: Tensor, update_stats: bool): | |
| """Normalize input tensor x using running mean/var stats, optionally updating the stats in-place.""" | |
| if update_stats: | |
| batch_mean = x.mean(dim=0) | |
| batch_var = x.var(dim=0, unbiased=False) | |
| batch_count = x.shape[0] | |
| delta = batch_mean - mean | |
| tot_count = count + batch_count | |
| new_mean = mean + delta * batch_count / tot_count | |
| m_a = var * count | |
| m_b = batch_var * batch_count | |
| m_2 = m_a + m_b + torch.square(delta) * (count * batch_count / tot_count) | |
| new_var = m_2 / tot_count | |
| new_count = tot_count | |
| mean.data.copy_(new_mean) | |
| var.data.copy_(new_var) | |
| count.data.copy_(new_count) | |
| x = (x - mean) / torch.sqrt(var + 1e-8) | |
| x = torch.clamp(x, -10.0, 10.0) # avoid extreme values | |
| return x | |
| class SequenceModel(nn.Module): | |
| """ | |
| A model that may contain unidirectional `nn.RNNBase` layers. This class manages the hidden-state | |
| context for those layers: `forward` accepts an optional `cache` argument, a list holding one hidden | |
| state per RNN layer in the model, and returns the output together with the updated cache. If no | |
| cache is provided, the hidden states are initialized to zeros. | |
| """ | |
| def __init__(self, model): | |
| super(SequenceModel, self).__init__() | |
| self.model = model | |
| def forward(self, x: Tensor, cache: Optional[List[Union[Tensor, Tuple[Tensor, Tensor]]]]) -> Tuple[Tensor, Optional[List[Union[Tensor, Tuple[Tensor, Tensor]]]]]: | |
| """Forward pass through the model, managing RNN hidden states if present. cache[i] is the hidden state for the i-th RNN layer. It may be of any format depending on the RNN layer type. For example, for nn.LSTM, it is a tuple of (h, c). Assumes the input x is of shape (batch_size, seq_len, feature_dims...)""" | |
| assert len(x.shape) >= 3, "Input x must be of shape (batch_size, seq_len, feature_dims...)" | |
| if cache is not None: | |
| assert isinstance(cache, list), "cache must be a list of hidden states for RNN layers." | |
| new_cache = [] | |
| i = 0 | |
| for (layer_id, layer) in enumerate(self.model.children()): | |
| if isinstance(layer, nn.RNNBase): | |
| hidden_state = cache[i] if cache else None | |
| # make sure hidden state matches the batch size and it is on the correct device | |
| if hidden_state is not None: | |
| if isinstance(hidden_state, tuple) and hidden_state[0].dim() == 3: # LSTM | |
| h, c = hidden_state | |
| assert h.size(1) == x.size(0), "Hidden state batch size does not match input batch size. You may need to clear or reset the cache." | |
| assert c.size(1) == x.size(0), "Cell state batch size does not match input batch size. You may need to clear or reset the cache." | |
| hidden_state = (h.to(x.device), c.to(x.device)) | |
| elif isinstance(hidden_state, torch.Tensor) and hidden_state.dim() == 3: # GRU or RNN | |
| hidden_state = hidden_state.to(x.device) | |
| assert hidden_state.size(1) == x.size(0), "Hidden state batch size does not match input batch size. You may need to clear or reset the cache." | |
| else: | |
| raise ValueError("Unknown hidden state format in cache.") | |
| # if it is None, properly initialize it depending on the RNN type | |
| if hidden_state is None: | |
| if isinstance(layer, nn.LSTM): | |
| h_size = layer.proj_size if layer.proj_size > 0 else layer.hidden_size | |
| h_0 = torch.zeros(layer.num_layers, x.size(0), h_size, device=x.device) | |
| c_0 = torch.zeros(layer.num_layers, x.size(0), layer.hidden_size, device=x.device) | |
| hidden_state = (h_0, c_0) | |
| else: # GRU or RNN | |
| h_0 = torch.zeros(layer.num_layers, x.size(0), layer.hidden_size, device=x.device) | |
| hidden_state = h_0 | |
| x, hidden_state = layer(x, hidden_state) | |
| new_cache.append(hidden_state) | |
| i += 1 | |
| else: | |
| # reshape x to combine batch and sequence dims to apply non-RNN modules | |
| B, L = x.shape[0], x.shape[1] | |
| x = x.reshape(B * L, *x.shape[2:]) | |
| x = layer(x) | |
| # reshape back | |
| x = x.reshape(B, L, *x.shape[1:]) | |
| return x, new_cache | |
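| # Illustrative sketch of wrapping a recurrent policy network (layer sizes and | |
| # `obs_dim`/`act_dim` are placeholders; RNN layers must use batch_first=True): | |
| #   net = SequenceModel(nn.Sequential( | |
| #       nn.Linear(obs_dim, 64), nn.Tanh(), | |
| #       nn.LSTM(64, 64, batch_first=True),  # handled statefully here, bypassing nn.Sequential.forward | |
| #       nn.Linear(64, act_dim))) | |
| #   y, cache = net(torch.randn(8, 16, obs_dim), None)   # (batch=8, seq_len=16), fresh zero state | |
| #   y2, cache = net(torch.randn(8, 1, obs_dim), cache)  # continue one step from the cached state | |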
| def clear_cache(cache: Optional[List[Union[Tensor, Tuple[Tensor, Tensor]]]], batch_ids: Optional[List[int]] = None) -> Optional[List[Union[Tensor, Tuple[Tensor, Tensor]]]]: | |
| """Clear the cached hidden states for the given batch indices. If batch_ids is None, clear all hidden states by simply returning None.""" | |
| if cache is None or batch_ids is None: | |
| return None | |
| for hidden_state in cache: | |
| if hidden_state is None: | |
| continue | |
| if isinstance(hidden_state, tuple) and hidden_state[0].dim() == 3: # LSTM | |
| h, c = hidden_state | |
| h[:, batch_ids, :] = 0 # zero out the hidden states for the given batch indices | |
| c[:, batch_ids, :] = 0 # zero out the cell states for the given batch indices | |
| elif isinstance(hidden_state, torch.Tensor) and hidden_state.dim() == 3: # GRU or RNN | |
| hidden_state[:, batch_ids, :] = 0 # zero out the hidden states for the given batch indices | |
| else: | |
| raise ValueError("Unknown hidden state format in cache.") | |
| return cache | |
| def cache_slice_for_batch_ids(cache: Optional[List[Union[Tensor, Tuple[Tensor, Tensor]]]], batch_ids: List[int]) -> Optional[List[Union[Tensor, Tuple[Tensor, Tensor]]]]: | |
| """Extract a slice of the cache for the given batch indices.""" | |
| if cache is None: | |
| return None | |
| new_cache = [] | |
| for hidden_state in cache: | |
| if hidden_state is None: | |
| new_cache.append(None) | |
| continue | |
| if isinstance(hidden_state, tuple) and hidden_state[0].dim() == 3: # LSTM | |
| h, c = hidden_state | |
| h_slice = h[:, batch_ids, :] | |
| c_slice = c[:, batch_ids, :] | |
| new_cache.append((h_slice, c_slice)) | |
| elif isinstance(hidden_state, torch.Tensor) and hidden_state.dim() == 3: # GRU or RNN | |
| h_slice = hidden_state[:, batch_ids, :] | |
| new_cache.append(h_slice) | |
| else: | |
| raise ValueError("Unknown hidden state format in cache.") | |
| return new_cache | |
| def cache_detach(cache: Optional[List[Union[Tensor, Tuple[Tensor, Tensor]]]]) -> Optional[List[Union[Tensor, Tuple[Tensor, Tensor]]]]: | |
| """Detach the cached hidden states from the computation graph.""" | |
| if cache is None: | |
| return None | |
| new_cache = [] | |
| for hidden_state in cache: | |
| if hidden_state is None: | |
| new_cache.append(None) | |
| continue | |
| if isinstance(hidden_state, tuple): # LSTM | |
| h, c = hidden_state | |
| new_cache.append((h.detach(), c.detach())) | |
| else: # GRU or RNN | |
| new_cache.append(hidden_state.detach()) | |
| return new_cache | |
| def cache_move_to_device(cache: Optional[List[Union[Tensor, Tuple[Tensor, Tensor]]]], device: Union[torch.device, str]) -> Optional[List[Union[Tensor, Tuple[Tensor, Tensor]]]]: | |
| """Move the cached hidden states to the specified device.""" | |
| if cache is None: | |
| return None | |
| new_cache = [] | |
| for hidden_state in cache: | |
| if hidden_state is None: | |
| new_cache.append(None) | |
| continue | |
| if isinstance(hidden_state, tuple): # LSTM | |
| h, c = hidden_state | |
| new_cache.append((h.to(device), c.to(device))) | |
| else: # GRU or RNN | |
| new_cache.append(hidden_state.to(device)) | |
| return new_cache | |
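| # Illustrative sketch of the cache helpers (hidden states have the usual torch RNN | |
| # shape (num_layers, batch, hidden)): | |
| #   cache = clear_cache(cache, batch_ids=[1])       # zero env 1's state (e.g., new episode) | |
| #   sub = cache_slice_for_batch_ids(cache, [0, 3])  # states for a minibatch of envs | |
| #   cache = cache_detach(cache)                     # truncate BPTT at a context boundary | |
| #   cache = cache_move_to_device(cache, 'cpu') | |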
| class ActorBase(nn.Module, abc.ABC): | |
| """ | |
| Base class for the policy (actor) network in PPO. | |
| This class wraps a neural network model for policy representation and provides | |
| utilities for action sampling, log-probability computation, and entropy estimation. | |
| It supports optional observation normalization using running statistics and can | |
| operate in deterministic mode for evaluation. | |
| Attributes: | |
| model: The underlying neural network (nn.Module) for the policy. | |
| observation_space: The Gymnasium observation space (Box). | |
| action_space: The Gymnasium action space (Box or Discrete). | |
| deterministic: If True, actions are deterministic (mean/argmax). | |
| norm_obs: If True, normalize observations using running mean/var. Warning: in POMDPs, this may leak information across timesteps. | |
| obs_mean: Running mean for observation normalization (non-trainable). | |
| obs_var: Running variance for observation normalization (non-trainable). | |
| obs_count: Running count for observation normalization (non-trainable). | |
| Note: Subclasses must implement `sample_action`, `get_entropy`, and `get_logprob`. | |
| """ | |
| def __init__(self, model, observation_space: Box, action_space: Union[Box, Discrete], deterministic: bool = False, device='cpu', norm_obs: bool = False): | |
| super(ActorBase, self).__init__() | |
| self.model = model.to(device) | |
| self.observation_space = observation_space | |
| self.action_space = action_space | |
| self.deterministic = deterministic | |
| self.norm_obs = norm_obs | |
| self.obs_mean = nn.Parameter(torch.zeros(observation_space.shape, dtype=torch.float32, device=device), requires_grad=False) | |
| self.obs_var = nn.Parameter(torch.ones(observation_space.shape, dtype=torch.float32, device=device), requires_grad=False) | |
| self.obs_count = nn.Parameter(torch.tensor(0., dtype=torch.float32, device=device), requires_grad=False) | |
| self.cache = None # for RNN hidden states if needed | |
| def update_obs_stats(self, x: Tensor) -> None: | |
| """Update running normalization statistics for observations using the input batch.""" | |
| if self.norm_obs: | |
| if x.dtype == torch.uint8: | |
| x = x.float() / 255.0 | |
| _ = normalize_and_update_stats(self.obs_mean, self.obs_var, self.obs_count, x, True) | |
| def forward_model(self, x: Tensor) -> Tensor: | |
| """Forward pass through the policy network, applying optional observation normalization.""" | |
| # if dtype is UInt8, convert to float32 and scale to [0, 1] | |
| if x.dtype == torch.uint8: | |
| x = x.float() / 255.0 | |
| if isinstance(self.model, SequenceModel): | |
| # input must be at least 3D | |
| assert x.dim() >= 3, "Input tensor must be at least 3D: (batch_size, seq_len, feature_dims...) for SequenceModel." | |
| # for normalization, flatten batch and seq dims | |
| B, L = x.shape[0], x.shape[1] | |
| x = x.reshape(B * L, *x.shape[2:]) | |
| if self.norm_obs: | |
| x = normalize_and_update_stats(self.obs_mean, self.obs_var, self.obs_count, x, False) | |
| # reshape back | |
| x = x.reshape(B, L, *x.shape[1:]) | |
| x, self.cache = self.model(x, self.cache) | |
| else: | |
| if self.norm_obs: | |
| x = normalize_and_update_stats(self.obs_mean, self.obs_var, self.obs_count, x, False) | |
| x = self.model(x) | |
| return x | |
| def get_subfinal_layer_output(self, x: Tensor) -> Tensor: | |
| """Get the output from the second last layer of the model (useful for auxiliary tasks).""" | |
| if isinstance(self.model, SequenceModel): | |
| raise NotImplementedError("get_subfinal_layer_output is not implemented for SequenceModel.") | |
| if not isinstance(self.model, nn.Sequential): | |
| raise NotImplementedError("get_subfinal_layer_output is only implemented for nn.Sequential models.") | |
| if len(list(self.model.children())) < 2: | |
| raise ValueError("Model must have at least two layers to get subfinal layer output.") | |
| if x.dtype == torch.uint8: | |
| x = x.float() / 255.0 | |
| if self.norm_obs: | |
| x = normalize_and_update_stats(self.obs_mean, self.obs_var, self.obs_count, x, False) | |
| layers = list(self.model.children()) | |
| for layer in layers[:-1]: | |
| x = layer(x) | |
| return x | |
| @abc.abstractmethod | |
| def sample_action(self, x: Tensor) -> Tuple[Tensor, Dict[str, Tensor]]: | |
| """Sample actions from the policy given observations.""" | |
| raise NotImplementedError | |
| @abc.abstractmethod | |
| def get_entropy(self, x) -> Tensor: | |
| """Compute the entropy of the policy distribution for given observations.""" | |
| raise NotImplementedError | |
| @abc.abstractmethod | |
| def get_logprob(self, x: Tensor, action: Tensor) -> Tensor: | |
| """Compute the log-probability of given actions under the policy for observations.""" | |
| raise NotImplementedError | |
| def get_kl_div(self, obs, actions, old_logprobs): | |
| """Approximate KL divergence between old and new policies using log-ratio.""" | |
| log_probs = self.get_logprob(obs, actions) | |
| log_ratio = log_probs - old_logprobs | |
| kl = torch.mean(torch.exp(log_ratio) - 1 - log_ratio) | |
| # kl = torch.mean(old_logprobs - log_probs) | |
| return kl | |
| def get_policy_loss_and_entropy(self, obs, actions, advantages, old_logprobs, clip_ratio): | |
| """Compute the clipped PPO policy loss and mean entropy.""" | |
| log_probs = self.get_logprob(obs, actions) | |
| ratio = torch.exp(log_probs - old_logprobs) | |
| surrogate1 = ratio * advantages | |
| surrogate2 = torch.clamp(ratio, 1.0 - clip_ratio, 1.0 + clip_ratio) * advantages | |
| policy_loss = -torch.mean(torch.min(surrogate1, surrogate2)) | |
| entropy = self.get_entropy(obs).mean() | |
| return policy_loss, entropy | |
| def forward(self, x) -> Any: | |
| """Convenience forward pass: map observations to actions (handles numpy/torch, batching).""" | |
| # if input is numpy array, convert to torch tensor | |
| input_is_numpy = isinstance(x, np.ndarray) | |
| added_time_dim = False | |
| if input_is_numpy: | |
| if (x.dtype == np.uint8): | |
| x = x.astype(np.float32) / 255.0 | |
| x = torch.as_tensor(x, dtype=torch.float32, device=next(self.model.parameters()).device) | |
| if isinstance(self.model, SequenceModel) and len(x.shape) == len(self.observation_space.shape) + 1: | |
| x = x.unsqueeze(1) # Add seq_len=1 for batched single-step | |
| added_time_dim = True | |
| # if input is not batched, add a batch dimension | |
| input_is_unbatched = len(x.shape) == len(self.observation_space.shape) | |
| if input_is_unbatched: | |
| x = x.unsqueeze(0) | |
| if isinstance(self.model, SequenceModel): | |
| x = x.unsqueeze(1) # add seq_len dimension if RNN | |
| added_time_dim = True | |
| # get action | |
| action, _ = self.sample_action(x) | |
| if added_time_dim: | |
| action = action.squeeze(1) # remove seq_len dim if RNN | |
| # if input was unbatched, remove the batch dimension | |
| if input_is_unbatched: | |
| action = action.squeeze(0) | |
| # if discrete action space and input was unbatched, convert to a plain int | |
| if isinstance(self.action_space, Discrete) and input_is_unbatched: | |
| action = action.item() | |
| # if input was numpy, convert output to numpy | |
| if input_is_numpy and isinstance(action, torch.Tensor): | |
| action = action.detach().cpu().numpy() | |
| return action | |
| def clear_cache(self, batch_ids: Optional[List[int]] = None) -> None: | |
| """Clear the cached hidden states in the underlying model if it is a SequenceModel. If batch_ids is provided, only clear the hidden states for those batch indices.""" | |
| if isinstance(self.model, SequenceModel): | |
| self.cache = clear_cache(self.cache, batch_ids) | |
| def evaluate_policy(self, env: Env, num_episodes=100, deterministic=True, seed=0) -> Tuple[List[float], List[float], List[int], List[bool]]: | |
| """Evaluate the policy by rolling out episodes in a single environment. Default seed is 0. Returns lists of episode rewards, episode costs, episode lengths, and episode successes.""" | |
| self.clear_cache() | |
| self.eval() | |
| deterministic_before = self.deterministic | |
| self.deterministic = deterministic | |
| episode_rewards = [] | |
| episode_costs = [] | |
| episode_lengths = [] | |
| episode_successes = [] | |
| for episode in range(num_episodes): | |
| print(f"Evaluating episode {episode + 1}/{num_episodes}", end='\r') | |
| obs, _ = env.reset(seed=seed + episode) # distinct seed per episode; reseeding with the same seed would replay the same episode | |
| done = False | |
| total_reward = 0.0 | |
| total_cost = 0.0 | |
| length = 0 | |
| while not done: | |
| action = self(obs) | |
| obs, reward, terminated, truncated, info = env.step(action) | |
| done = terminated or truncated | |
| total_reward += reward # type: ignore | |
| cost = info.get("cost", 0.0) | |
| total_cost += cost # type: ignore | |
| length += 1 | |
| if done: | |
| episode_successes.append(info.get("is_success", False)) | |
| # no need to clear cache here; the RNN will automatically learn to reset based on an assumed flag in state-space that marks beginning of episodes. | |
| episode_rewards.append(total_reward) | |
| episode_costs.append(total_cost) | |
| episode_lengths.append(length) | |
| print() | |
| self.deterministic = deterministic_before # undo | |
| return episode_rewards, episode_costs, episode_lengths, episode_successes | |
| def evaluate_policy_parallel(self, envs: VectorEnv, num_episodes=100, deterministic=True, base_seed=0) -> Tuple[List[float], List[float], List[int], List[bool]]: | |
| """Evaluate the policy by rolling out episodes in a vectorized environment. Default base_seed is 0, the envs will be seeded: base_seed, base_seed+1, base_seed+2, ... etc. Returns lists of episode rewards, costs, lengths, and success flags.""" | |
| self.clear_cache() | |
| self.eval() | |
| deterministic_before = self.deterministic | |
| self.deterministic = deterministic | |
| episode_rewards = [] | |
| episode_costs = [] | |
| episode_lengths = [] | |
| episode_successes = [] | |
| episode_rew_vec = np.zeros(envs.num_envs) | |
| episode_cost_vec = np.zeros(envs.num_envs) | |
| episode_len_vec = np.zeros(envs.num_envs, dtype=int) | |
| num_envs = envs.num_envs | |
| obs = envs.reset(seed=base_seed)[0] # Automatically seeds each env with base_seed + env_index | |
| while len(episode_rewards) < num_episodes: | |
| action = self(obs) | |
| action = action if isinstance(envs.single_action_space, Box) else action.squeeze(-1) | |
| obs, rewards, terminateds, truncateds, infos = envs.step(action) | |
| episode_rew_vec += rewards | |
| episode_cost_vec += infos.get('cost', np.zeros(envs.num_envs)) | |
| episode_len_vec += 1 | |
| for i in range(num_envs): | |
| if terminateds[i] or truncateds[i]: | |
| # add the last step's cost from final_info if available, before recording the episode | |
| if 'final_info' in infos and 'cost' in infos['final_info']: | |
| episode_cost_vec[i] += infos['final_info']['cost'][i] # type: ignore | |
| elif 'final_info' in infos and isinstance(infos['final_info'], list) and 'cost' in infos['final_info'][i]: # older gymnasium versions | |
| episode_cost_vec[i] += infos['final_info'][i]['cost'] # type: ignore | |
| episode_rewards.append(episode_rew_vec[i]) | |
| episode_costs.append(episode_cost_vec[i]) | |
| episode_lengths.append(episode_len_vec[i]) | |
| # extract is_success from final_info if available | |
| if 'final_info' in infos and 'is_success' in infos['final_info']: | |
| episode_successes.append(infos['final_info']['is_success'][i]) # type: ignore | |
| elif 'final_info' in infos and isinstance(infos['final_info'], list) and 'is_success' in infos['final_info'][i]: # older gymnasium versions | |
| episode_successes.append(infos['final_info'][i]['is_success']) # type: ignore | |
| else: | |
| episode_successes.append(False) | |
| episode_rew_vec[i] = 0.0 | |
| episode_cost_vec[i] = 0.0 | |
| episode_len_vec[i] = 0 | |
| # no need to clear cache here; the RNN will automatically learn to reset based on an assumed flag in state-space that marks beginning of episodes. | |
| print(f"Evaluating episodes {len(episode_rewards)}/{num_episodes}", end='\r') | |
| if len(episode_rewards) >= num_episodes: | |
| break | |
| print() | |
| self.deterministic = deterministic_before # undo | |
| return episode_rewards, episode_costs, episode_lengths, episode_successes | |
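| # Illustrative sketch (assumes a `make_one_env` factory and a trained actor `pi`): | |
| #   eval_envs = AsyncVectorEnv([make_one_env for _ in range(8)], **autoreset_kwarg_samestep_newgymapi) | |
| #   rets, costs, lens, succ = pi.evaluate_policy_parallel(eval_envs, num_episodes=100) | |
| #   print(np.mean(rets), np.mean(succ)) | |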
| class ActorContinuous(ActorBase): | |
| """ | |
| Continuous action policy using a squashed Gaussian distribution. | |
| Outputs a mean from the network; samples from N(mean, exp(logstd)), applies tanh | |
| squashing to [-1, 1], and scales/shifts to match the action space bounds. | |
| Log-probabilities are adjusted for the tanh transformation and scaling. | |
| """ | |
| def __init__(self, model, observation_space: Box, action_space: Box, deterministic: bool = False, device='cpu', norm_obs: bool = False): | |
| super(ActorContinuous, self).__init__(model, observation_space, action_space, deterministic, device, norm_obs) | |
| assert np.all(np.isfinite(action_space.low)) and np.all(np.isfinite(action_space.high)) | |
| self.logstd = nn.Parameter(torch.zeros(action_space.shape[0], dtype=torch.float32, device=device), requires_grad=True) | |
| # shift/scale to affine map from [-1, 1] to [low, high] => a = tanh(raw) * scale + shift: | |
| self.shift = nn.Parameter(torch.tensor(action_space.low + action_space.high, dtype=torch.float32, device=device) / 2.0, requires_grad=False) | |
| self.scale = nn.Parameter(torch.tensor((action_space.high - action_space.low) / 2.0, dtype=torch.float32, device=device), requires_grad=False) | |
| def sample_action(self, x: Tensor) -> Tuple[Tensor, Dict[str, Tensor]]: | |
| """Sample continuous actions: reparameterize, tanh-squash, and scale to bounds.""" | |
| mean = self._get_mean(x) | |
| logstd = self._get_logstd(x) | |
| std = torch.exp(logstd) | |
| # Reparameterized sample with noise, then tanh-squash and scale to bounds. | |
| noise = torch.randn_like(mean) | |
| action = mean + std * noise | |
| log_prob = log_normal_prob(action, mean, std) | |
| action = torch.tanh(action) | |
| # Subtract log|det(Jacobian)| for tanh so log_prob matches final action space. | |
| log_prob -= torch.sum(torch.log(1 - action.pow(2) + 1e-6), dim=-1, keepdim=True) | |
| action = action * self.scale + self.shift | |
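| # Note: the constant log|scale| Jacobian term of this affine rescaling is omitted from | |
| # log_prob (here and in get_logprob); being the same constant for old and new policies, | |
| # it cancels in PPO's probability ratios. | |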
| info = {'logstd': logstd, 'log_prob': log_prob} | |
| return action, info | |
| def get_entropy(self, x): | |
| """Compute entropy of the squashed Gaussian policy.""" | |
| logstd = self._get_logstd(x) | |
| return gaussian_entropy(logstd) | |
| def get_logprob(self, x: Tensor, action: Tensor) -> Tensor: | |
| """Compute log-prob of actions under the squashed Gaussian policy.""" | |
| mean = self._get_mean(x) | |
| logstd = self._get_logstd(x) | |
| std = torch.exp(logstd) | |
| # Rescale action to [-1, 1] | |
| action_unshifted_unscaled = (action - self.shift) / (self.scale + 1e-6) | |
| action_unshifted_unscaled = torch.clamp(action_unshifted_unscaled, -0.999, 0.999) # otherwise atanh(1) is inf | |
| action_untanhed = torch.atanh(action_unshifted_unscaled) | |
| log_prob = log_normal_prob(action_untanhed, mean, std) | |
| log_prob -= torch.sum(torch.log(1 - action_unshifted_unscaled.pow(2) + 1e-6), dim=-1, keepdim=True) | |
| return log_prob | |
| def _get_mean(self, x: Tensor) -> Tensor: | |
| """Get the policy mean from the network.""" | |
| mean = self.forward_model(x) | |
| return mean | |
| def _get_logstd(self, x: Tensor) -> Tensor: | |
| """Get the log-standard deviation (-inf in deterministic mode; clamped for stability).""" | |
| logstd = self.logstd | |
| if self.deterministic: | |
| logstd = logstd - torch.inf | |
| logstd = torch.clamp(logstd, -20, 2) # to avoid numerical issues | |
| return logstd | |
| class ActorDiscrete(ActorBase): | |
| """ | |
| Discrete action policy using a categorical distribution. | |
| Outputs logits from the network, applies softmax for probabilities. | |
| In deterministic mode, logits are inflated to approximate argmax. | |
| """ | |
| def __init__(self, model, observation_space: Box, action_space: Discrete, deterministic: bool = False, device='cpu', norm_obs: bool = False): | |
| super(ActorDiscrete, self).__init__(model, observation_space, action_space, deterministic, device, norm_obs) | |
| assert action_space.n >= 2, "Action space must have at least 2 discrete actions." | |
| def sample_action(self, x: Tensor) -> Tuple[Tensor, Dict[str, Tensor]]: | |
| """Sample discrete actions from the categorical distribution.""" | |
| probs_all, _ = self._get_probs_logprobs_all(x) | |
| # sample now | |
| action_dist = torch.distributions.Categorical(probs=probs_all) | |
| action = action_dist.sample() | |
| log_prob = action_dist.log_prob(action) | |
| action = action.unsqueeze(-1) # make it (batch_size, 1) or (batch_size, seq_len, 1) | |
| log_prob = log_prob.unsqueeze(-1) # make it (batch_size, 1) or (batch_size, seq_len, 1) | |
| return action, {'log_prob': log_prob} | |
| def get_entropy(self, x): | |
| """Compute entropy of the categorical policy.""" | |
| probs_all, logprobs_all = self._get_probs_logprobs_all(x) | |
| entropy = -torch.sum(probs_all * logprobs_all, dim=-1) | |
| return entropy | |
| def get_logprob(self, x: Tensor, action: Tensor) -> Tensor: | |
| """Compute log-prob of discrete actions under the categorical policy.""" | |
| _, logprobs_all = self._get_probs_logprobs_all(x) | |
| return logprobs_all.gather(-1, action) | |
| def _get_probs_logprobs_all(self, x: Tensor) -> Tuple[Tensor, Tensor]: | |
| """Get softmax probabilities and log-probs over all actions.""" | |
| logits = self.forward_model(x) | |
| logits = logits - logits.max(dim=-1, keepdim=True).values # Stabilize logits by subtracting max (per batch row) before softmax/log_softmax. | |
| if self.deterministic: | |
| logits = logits * 1e6 # make it very large to approximate argmax | |
| probs_all = F.softmax(logits, dim=-1) # (batch_size, action_dim) or (batch_size, seq_len, action_dim) | |
| logprobs_all = F.log_softmax(logits, dim=-1) # (batch_size, action_dim) or (batch_size, seq_len, action_dim) | |
| return probs_all, logprobs_all | |
| def Actor(model, observation_space: Box, action_space: Union[Box, Discrete], deterministic: bool = False, device='cpu', norm_obs: bool = False) -> ActorBase: | |
| """Factory function to create an appropriate Actor subclass based on the action space.""" | |
| # raise a warning if norm_obs is True and the model is a SequenceModel | |
| if norm_obs and isinstance(model, SequenceModel): | |
| print("Warning: Observation normalization is enabled for a SequenceModel. In POMDPs, this may leak information across timesteps.", file=sys.stderr) | |
| if isinstance(action_space, Box): | |
| return ActorContinuous(model, observation_space, action_space, deterministic, device, norm_obs) | |
| elif isinstance(action_space, Discrete): | |
| return ActorDiscrete(model, observation_space, action_space, deterministic, device, norm_obs) | |
| else: | |
| raise NotImplementedError("Only Box and Discrete action spaces are supported.") | |
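| # Illustrative sketch (obs_space/act_space and layer sizes are placeholders): | |
| #   pi = Actor(nn.Sequential(nn.Linear(obs_space.shape[0], 64), nn.Tanh(), | |
| #                            nn.Linear(64, act_space.shape[0])),  # or n logits for Discrete | |
| #              obs_space, act_space, norm_obs=True)  # -> ActorContinuous for a Box space | |
| #   a = pi(obs_space.sample())  # numpy observation in, numpy action out | |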
| class Critic(nn.Module): | |
| """ | |
| Value function (critic) network for state-value estimation V(s). | |
| Estimates the expected discounted return from a given state under the current policy. | |
| Supports optional observation normalization using running statistics. | |
| Attributes: | |
| model: The underlying neural network (nn.Module) for value prediction. | |
| norm_obs: If True, normalize observations using running mean/var. Warning: in POMDPs, this may leak information across timesteps. | |
| obs_mean: Running mean for observation normalization (non-trainable). | |
| obs_var: Running variance for observation normalization (non-trainable). | |
| obs_count: Running count for observation normalization (non-trainable). | |
| """ | |
| def __init__(self, model, observation_space: Box, device='cpu', norm_obs: bool = False): | |
| super(Critic, self).__init__() | |
| self.model = model.to(device) # move to device, as ActorBase does | |
| self.norm_obs = norm_obs | |
| self.obs_mean = nn.Parameter(torch.zeros(observation_space.shape, dtype=torch.float32, device=device), requires_grad=False) | |
| self.obs_var = nn.Parameter(torch.ones(observation_space.shape, dtype=torch.float32, device=device), requires_grad=False) | |
| self.obs_count = nn.Parameter(torch.tensor(0., dtype=torch.float32, device=device), requires_grad=False) | |
| self.cache = None # for RNN hidden states if needed | |
| if norm_obs and isinstance(model, SequenceModel): | |
| print("Warning: Observation normalization is enabled for a SequenceModel. In POMDPs, this may leak information across timesteps.", file=sys.stderr) | |
| def update_obs_stats(self, x: Tensor) -> None: | |
| """Update running normalization statistics for observations using the input batch.""" | |
| if self.norm_obs: | |
| if x.dtype == torch.uint8: | |
| x = x.float() / 255.0 | |
| _ = normalize_and_update_stats(self.obs_mean, self.obs_var, self.obs_count, x, True) | |
| def get_loss(self, states, old_values, advantages): | |
| """Compute the MSE loss for value regression on target returns (old_values + advantages).""" | |
| values = (old_values + advantages).detach() | |
| values_pred = self.get_value(states) | |
| return F.mse_loss(values_pred, values) | |
| def get_value(self, x: Tensor) -> Tensor: | |
| """Predict state values, applying optional observation normalization.""" | |
| # if dtype is UInt8, convert to float32 and scale to [0, 1] | |
| if x.dtype == torch.uint8: | |
| x = x.float() / 255.0 | |
| if isinstance(self.model, SequenceModel): | |
| # input must be at least 3D | |
| assert x.dim() >= 3, "Input tensor must be at least 3D: (batch_size, seq_len, feature_dims...) for SequenceModel." | |
| # for normalization, flatten batch and seq dims | |
| B, L = x.shape[0], x.shape[1] | |
| x = x.reshape(B * L, *x.shape[2:]) | |
| if self.norm_obs: | |
| x = normalize_and_update_stats(self.obs_mean, self.obs_var, self.obs_count, x, False) | |
| # reshape back | |
| x = x.reshape(B, L, *x.shape[1:]) | |
| x, self.cache = self.model(x, self.cache) | |
| else: | |
| if self.norm_obs: | |
| x = normalize_and_update_stats(self.obs_mean, self.obs_var, self.obs_count, x, False) | |
| x = self.model(x) | |
| return x | |
| def get_subfinal_layer_output(self, x: Tensor) -> Tensor: | |
| """Get the output of the layer before the final output layer (for use in e.g., cost critics).""" | |
| if isinstance(self.model, SequenceModel): | |
| raise NotImplementedError("get_subfinal_layer_output is not implemented for SequenceModel.") | |
| else: | |
| # assuming model is nn.Sequential | |
| if not isinstance(self.model, nn.Sequential): | |
| raise NotImplementedError("get_subfinal_layer_output is only implemented for nn.Sequential models.") | |
| if len(self.model) < 2: | |
| raise ValueError("Model must have at least 2 layers to get subfinal layer output.") | |
| # if dtype is UInt8, convert to float32 and scale to [0, 1] | |
| if x.dtype == torch.uint8: | |
| x = x.float() / 255.0 | |
| if self.norm_obs: | |
| x = normalize_and_update_stats(self.obs_mean, self.obs_var, self.obs_count, x, False) | |
| for layer in list(self.model.children())[:-1]: | |
| x = layer(x) | |
| return x | |
| def clear_cache(self, batch_ids: Optional[List[int]] = None) -> None: | |
| """Clear the cached hidden states in the underlying model if it is a SequenceModel. If batch_ids is provided, only clear the hidden states for those batch indices.""" | |
| if isinstance(self.model, SequenceModel): | |
| self.cache = clear_cache(self.cache, batch_ids) | |
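| # Illustrative sketch (sizes are placeholders; the critic outputs one value per state): | |
| #   vf = Critic(nn.Sequential(nn.Linear(obs_space.shape[0], 64), nn.Tanh(), nn.Linear(64, 1)), | |
| #               obs_space, norm_obs=True) | |
| #   v = vf.get_value(torch.as_tensor(obs_space.sample(), dtype=torch.float32).unsqueeze(0)) | |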
| class PPO: | |
| """ | |
| Proximal Policy Optimization (PPO) trainer. | |
| This class implements the PPO algorithm, including trajectory collection, advantage | |
| estimation via GAE, and clipped policy/value updates. It supports vectorized | |
| environments in SAME_STEP autoreset mode for efficient data collection. | |
| Key Features: | |
| - Supports continuous and discrete actions via Actor subclasses. | |
| - Optional Constrained MDP (CMDP) mode with Lagrange multiplier for cost constraints. | |
| - Separate devices for inference (rollouts) and training (updates) for performance. | |
| - Comprehensive logging of metrics (returns, losses, KL, etc.) for analysis. | |
| - Optional annealing of hyperparameters (LR, entropy coef, clip ratio). | |
| - Normalization of observations, rewards, and advantages for stability. | |
| Usage: | |
| 1. Create a vectorized environment e.g., AsyncVectorEnv with SAME_STEP autoreset (skip AutoresetMode if using older gymnasium versions). | |
| 2. Define actor and critic networks (nn.Module). | |
| 3. Instantiate PPO with envs, actor, critic, and hyperparameters. | |
| 4. Call `train(num_iterations)` to run training loops. | |
| 5. Access `stats` dict for logged metrics (e.g., mean_returns, losses). | |
| CMDP Mode: | |
| - Requires a cost_critic and cost_threshold > 0. | |
| - Uses Lagrange multiplier to penalize cost advantages in policy updates. | |
| - Supports constraining discounted or undiscounted costs via `constrain_undiscounted_cost`. | |
| Example: | |
| envs = AsyncVectorEnv([...], autoreset_mode=AutoresetMode.SAME_STEP) # skip AutoresetMode if using older gymnasium versions | |
| actor = Actor(mlp_model, obs_space, act_space, norm_obs=True) | |
| critic = Critic(mlp_model, obs_space, norm_obs=True) | |
| ppo = PPO(envs, actor, critic, iters=1000, gamma=0.99, clip_ratio=0.2) | |
| ppo.train(1000) | |
| plt.plot(ppo.stats['mean_returns']) | |
| """ | |
| def __init__(self, | |
| envs: VectorEnv, | |
| actor: ActorBase, | |
| critic: Critic, | |
| iters: int, | |
| cmdp_mode: bool = False, | |
| cost_critic: Optional[Critic] = None, # required if cmdp_mode is True | |
| cost_threshold: float = 0.0, # must be > 0 if cmdp_mode is True | |
| constrain_undiscounted_cost: bool = False, # if True, use undiscounted cost to update lagrange multiplier, else use discounted cost | |
| lagrange_lr: float = 0.01, | |
| lagrange_max: float = 1000.0, | |
| moving_cost_estimate_step_size: float = 0.01, # step size for moving average estimation of cost per episode. Formula: new_estimate = (1-step_size)*old_estimate + step_size*new_observation. Whether the estimate is discounted or undiscounted depends on constrain_undiscounted_cost. | |
| actor_lr: float = 3e-4, | |
| critic_lr: float = 3e-4, | |
| decay_lr: bool = False, | |
| min_lr: float = 1.25e-5, | |
| gamma: float = 0.99, | |
| lam: float = 0.95, | |
| nsteps: int = 2048, | |
| nepochs: int = 10, | |
| batch_size: int = 64, | |
| context_length: Optional[int] = None, # if using RNNs, the context length for BPTT. It must be less than or equal to nsteps as well as less than or equal to batch_size. If None, set to minimum of nsteps and batch_size. | |
| clip_ratio: float = 0.2, | |
| decay_clip_ratio: bool = False, | |
| min_clip_ratio: float = 0.01, | |
| target_kl: float = np.inf, | |
| early_stop_critic: bool = False, | |
| value_loss_coef: float = 0.5, | |
| entropy_coef: float = 0.0, | |
| decay_entropy_coef: bool = False, | |
| normalize_advantages: bool = True, | |
| clipnorm: float = 0.5, | |
| norm_rewards: bool = False, | |
| adam_weight_decay: float = 0.0, | |
| adam_epsilon: float = 1e-7, | |
| base_seed: int = 42, # base seed for envs, the envs will be seeded: base_seed, base_seed+1, base_seed+2, ... etc. | |
| inference_device: str = 'cpu', | |
| training_device: str = 'cpu', | |
| custom_info_keys_to_log_at_episode_end: Optional[List[str]] = None,): # keys in info dict to log at episode end (None, the default, avoids a mutable default argument) | |
| """ | |
| Initialize the PPO trainer with environments, networks, and hyperparameters. | |
| Args: | |
| envs: Vectorized environment (must use SAME_STEP autoreset mode). | |
| actor: Policy network (ActorBase instance). | |
| critic: Value network (Critic instance). | |
| iters: Total number of training iterations. | |
| cmdp_mode: If True, enable Constrained MDP mode with cost constraints. | |
| cost_critic: Cost value network (required if cmdp_mode=True). | |
| cost_threshold: Target average cost per episode (required if cmdp_mode=True). | |
| constrain_undiscounted_cost: If True, constrain undiscounted costs for Lagrange update. | |
| lagrange_lr: Learning rate for Lagrange multiplier update. | |
| lagrange_max: Maximum value for Lagrange multiplier. | |
| moving_cost_estimate_step_size: Step size for moving average estimation of cost per episode. Formula: new_estimate = (1-step_size) x old_estimate + step_size x new_observation. Whether the estimate is discounted or undiscounted depends on constrain_undiscounted_cost. | |
| actor_lr: Initial learning rate for actor optimizer. | |
| critic_lr: Initial learning rate for critic optimizer. | |
| decay_lr: If True, anneal learning rates linearly over iterations. | |
| min_lr: Minimum learning rate after annealing. | |
| gamma: Discount factor for returns and advantages. | |
| lam: GAE lambda for advantage estimation. | |
| nsteps: Steps per trajectory rollout (per env). | |
| nepochs: Number of SGD epochs per iteration. | |
| batch_size: Minibatch size for SGD. | |
| context_length: For RNNs, the context length for BPTT. Must be <= nsteps and <= batch_size. If None, set to min(nsteps, batch_size). | |
| clip_ratio: PPO clipping parameter (epsilon). | |
| decay_clip_ratio: If True, anneal clip ratio linearly. | |
| min_clip_ratio: Minimum clip ratio after annealing. | |
| target_kl: Target KL divergence for early stopping. | |
| early_stop_critic: If True, stop critic updates when actor early-stops. | |
| value_loss_coef: Coefficient for value loss in total objective. | |
| entropy_coef: Initial coefficient for entropy regularization. | |
| decay_entropy_coef: If True, anneal entropy coef linearly. | |
| normalize_advantages: If True, normalize advantages before policy update. | |
| clipnorm: Gradient norm clipping value (0.5 by default). | |
| norm_rewards: If True, normalize rewards by running std of discounted returns. | |
| adam_weight_decay: Weight decay for Adam optimizers. | |
| adam_epsilon: Epsilon for Adam optimizers. | |
| inference_device: Device for rollout/inference (e.g., 'cpu'). | |
| base_seed: Base seed for envs; envs will be seeded sequentially at the first reset call e.g., base_seed, base_seed+1, ... | |
| training_device: Device for parameter updates (e.g., 'cuda'). | |
| custom_info_keys_to_log_at_episode_end: List of env info keys to log per episode. Their 100-window averages will be tracked in stats and printed during training. | |
| Raises: | |
| AssertionError: If envs are not in SAME_STEP autoreset mode (recent gymnasium versions; older versions usually already behave like SAME_STEP), or if CMDP parameters are invalid. | |
| """ | |
| try: | |
| from gymnasium.vector import AutoresetMode | |
| assert envs.metadata["autoreset_mode"] == AutoresetMode.SAME_STEP, "VectorEnv must be in SAME_STEP autoreset mode" | |
| except (KeyError, ImportError): | |
| print("Warning: Could not verify that VectorEnv is in SAME_STEP autoreset mode. Ensure that envs are created with autoreset_mode=AutoresetMode.SAME_STEP in recent gymnasium versions. The reset logic in older versions is usually already similar to SAME_STEP.", file=sys.stderr) | |
| self.envs = envs | |
| self.actor = actor.to(training_device) | |
| self.critic = critic.to(training_device) | |
| self.cmdp_mode = cmdp_mode | |
| if self.cmdp_mode: | |
| assert cost_critic is not None, "cost_critic must be provided in CMDP mode" | |
| assert cost_threshold > 0.0, "cost_threshold must be > 0 in CMDP mode" | |
| if not normalize_advantages: | |
| print("Warning: It is recommended to normalize advantages when using CMDP mode so that the scale of reward and cost advantages are similar before combining them.", file=sys.stderr) | |
| self.cost_critic = cost_critic.to(training_device) | |
| self.iters = iters | |
| self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=actor_lr, weight_decay=adam_weight_decay, eps=adam_epsilon) | |
| self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), lr=critic_lr, weight_decay=adam_weight_decay, eps=adam_epsilon) | |
| if self.cmdp_mode: | |
| self.cost_optimizer = torch.optim.Adam(self.cost_critic.parameters(), lr=critic_lr, weight_decay=adam_weight_decay, eps=adam_epsilon) | |
| self.actor_lr = actor_lr | |
| self.critic_lr = critic_lr | |
| self.decay_lr = decay_lr | |
| self.min_lr = min_lr | |
| self.gamma = gamma | |
| self.lam = lam | |
| self.nsteps = nsteps | |
| self.nepochs = nepochs | |
| self.batch_size = batch_size | |
| assert (nsteps * envs.num_envs) % batch_size == 0, "nsteps * num_envs must be divisible by batch_size" | |
| self.context_length = context_length if context_length is not None else min(nsteps, batch_size) | |
| self.is_rucur = isinstance(self.actor.model, SequenceModel) or isinstance(self.critic.model, SequenceModel) or (self.cmdp_mode and isinstance(self.cost_critic.model, SequenceModel)) | |
| if self.is_rucur: | |
| # ensure all models are SequenceModel | |
| assert isinstance(self.actor.model, SequenceModel) and isinstance(self.critic.model, SequenceModel), "All actor and critic models must be SequenceModel when using RNNs" | |
| if self.cmdp_mode: | |
| assert isinstance(self.cost_critic.model, SequenceModel), "All actor and critic models must be SequenceModel when using RNNs" | |
| # assert context_length is valid | |
| assert self.context_length <= self.nsteps, "context_length must be less than or equal to nsteps" | |
| # ensure context_length divides nsteps for simplicity | |
| assert self.nsteps % self.context_length == 0, "When using RNNs, nsteps must be divisible by context_length for simplicity" | |
| # ensure batch_size divides context_length times num_envs | |
| assert (self.context_length * envs.num_envs) % self.batch_size == 0, "When using RNNs, context_length * num_envs must be divisible by batch_size for simplicity" | |
| self.clip_ratio = clip_ratio | |
| self.decay_clip_ratio = decay_clip_ratio | |
| self.min_clip_ratio = min_clip_ratio | |
| self.target_kl = target_kl | |
| self.early_stop_critic = early_stop_critic | |
| self.value_loss_coef = value_loss_coef | |
| self.entropy_coef = entropy_coef | |
| self.decay_entropy_coef = decay_entropy_coef | |
| self.normalize_advantages = normalize_advantages | |
| self.clipnorm = clipnorm | |
| self.norm_rewards = norm_rewards | |
| self.base_seed = base_seed | |
| self.inference_device = inference_device | |
| self.training_device = training_device | |
| self.custom_info_keys_to_log_at_episode_end = custom_info_keys_to_log_at_episode_end if custom_info_keys_to_log_at_episode_end is not None else [] | |
| self.cost_threshold = cost_threshold | |
| self.constrain_undiscounted_cost = constrain_undiscounted_cost | |
| self.lagrange_lr = lagrange_lr | |
| self.lagrange_max = lagrange_max | |
| self.moving_cost_estimate_step_size = moving_cost_estimate_step_size | |
| self.lagrange = 0.0 # Lagrange multiplier | |
| self.moving_average_cost = 0.0 # moving average estimate of cost per episode, discounted or undiscounted depending on constrain_undiscounted_cost | |
| # stats | |
| self.stats = { | |
| 'iterations': 0, | |
| 'total_timesteps': 0, | |
| 'total_episodes': 0, | |
| 'timestepss': [], # after each iteration | |
| 'episodess': [], # after each iteration | |
| 'losses': [], # after each iteration | |
| 'actor_losses': [], # after each iteration | |
| 'critic_losses': [], # after each iteration | |
| 'cost_losses': [], # after each iteration | |
| 'entropies': [], # after each iteration | |
| 'kl_divs': [], # after each iteration | |
| 'values': [], # mean value of states in the batch after each iteration | |
| 'cost_values': [], # mean cost value of states in the batch after each iteration | |
| 'logprobs': [], # mean logprob of actions in the batch after each iteration (its negative is a sample-based estimate of the policy entropy on the batch) | |
| 'logstds': [], # mean logstd of actions in the batch after each iteration | |
| 'actor_lrs': [], # adjusted per iteration | |
| 'critic_lrs': [], # adjusted per iteration | |
| 'entropy_coefs': [], # adjusted per iteration | |
| 'clip_ratios': [], # adjusted per iteration | |
| 'lagranges': [], # after each iteration | |
| 'returns': [], # return of each trajectory collected so far | |
| 'discounted_returns': [], # discounted return of each trajectory collected so far | |
| 'discounted_costs': [], # discounted cost of each trajectory collected so far | |
| 'lengths': [], # length of each trajectory collected so far | |
| 'total_costs': [], # total cost of each trajectory collected so far (if env provides this info) | |
| 'successes': [], # success of each trajectory collected so far (if env provides this info) | |
| 'mean_returns': [], # mean return of latest 100 trajectories after each iteration | |
| 'mean_discounted_returns': [], # mean discounted return of latest 100 trajectories after each iteration | |
| 'mean_discounted_costs': [], # mean discounted cost of latest 100 trajectories after each iteration | |
| 'mean_lengths': [], # mean length of latest 100 trajectories after each iteration | |
| 'mean_total_costs': [], # mean total cost of latest 100 trajectories after each iteration (if env provides this info) | |
| 'mean_successes': [], # mean success of latest 100 trajectories after each iteration (if env provides this info) | |
| } | |
| self.stats['info_keys'] = {key: [] for key in (self.custom_info_keys_to_log_at_episode_end)} | |
| self.stats['info_key_means'] = {key: [] for key in (self.custom_info_keys_to_log_at_episode_end)} # mean of latest 100 episodes after each iteration | |
| state_shape: Tuple[int, ...] = envs.single_observation_space.shape # type: ignore | |
| if envs.single_observation_space.dtype == np.uint8: | |
| obs_dtype = torch.uint8 | |
| else: | |
| obs_dtype = torch.float32 | |
| if isinstance(envs.single_action_space, Discrete): | |
| action_dim = 1 | |
| action_dtype = torch.int64 | |
| elif isinstance(envs.single_action_space, Box): | |
| action_dim = envs.single_action_space.shape[0] # type: ignore | |
| action_dtype = torch.float32 | |
| else: | |
| raise NotImplementedError("Only Box and Discrete action spaces are supported.") | |
| M, N = self.nsteps, self.envs.num_envs | |
| # Experience buffers (shape: [N_envs, nsteps, ...]) for obs, actions, logprobs, | |
| # rewards/costs, done flags, and advantages/values. Flattened later for SGD. | |
| self.obs_buf = torch.zeros((N, M, *state_shape), dtype=obs_dtype, device=self.training_device) | |
| self.actions_buf = torch.zeros((N, M, action_dim), dtype=action_dtype, device=self.training_device) | |
| self.log_probs_buf = torch.zeros((N, M, 1), dtype=torch.float32, device=self.training_device) | |
| self.rewards_buf = torch.zeros((N, M, 1), dtype=torch.float32, device=self.training_device) | |
| self.costs_buf = torch.zeros((N, M, 1), dtype=torch.float32, device=self.training_device) | |
| self.terminateds_buf = torch.zeros((N, M, 1), dtype=torch.float32, device=self.training_device) | |
| self.truncateds_buf = torch.zeros((N, M, 1), dtype=torch.float32, device=self.training_device) | |
| self.values_buf = torch.zeros((N, M, 1), dtype=torch.float32, device=self.training_device) | |
| self.cost_values_buf = torch.zeros((N, M, 1), dtype=torch.float32, device=self.training_device) | |
| self.advantages_buf = torch.zeros((N, M, 1), dtype=torch.float32, device=self.training_device) | |
| self.cost_advantages_buf = torch.zeros((N, M, 1), dtype=torch.float32, device=self.training_device) | |
| # data buffers for collecting trajectories, if using a different device for inference | |
| if self.inference_device != self.training_device: | |
| self.obs_buf_inference = self.obs_buf.clone().to(self.inference_device) | |
| self.actions_buf_inference = self.actions_buf.clone().to(self.inference_device) | |
| self.log_probs_buf_inference = self.log_probs_buf.clone().to(self.inference_device) | |
| self.rewards_buf_inference = self.rewards_buf.clone().to(self.inference_device) | |
| self.costs_buf_inference = self.costs_buf.clone().to(self.inference_device) | |
| self.terminateds_buf_inference = self.terminateds_buf.clone().to(self.inference_device) | |
| self.truncateds_buf_inference = self.truncateds_buf.clone().to(self.inference_device) | |
| else: | |
| self.obs_buf_inference = self.obs_buf | |
| self.actions_buf_inference = self.actions_buf | |
| self.log_probs_buf_inference = self.log_probs_buf | |
| self.rewards_buf_inference = self.rewards_buf | |
| self.costs_buf_inference = self.costs_buf | |
| self.terminateds_buf_inference = self.terminateds_buf | |
| self.truncateds_buf_inference = self.truncateds_buf | |
| self.returns_buffer = np.zeros((N,), dtype=np.float32) | |
| self.discounted_returns_buffer = np.zeros((N,), dtype=np.float32) | |
| self.discounted_costs_buffer = np.zeros((N,), dtype=np.float32) | |
| self.lengths_buffer = np.zeros((N,), dtype=np.int32) | |
| self.costs_total_buffer = np.zeros((N,), dtype=np.float32) | |
| # clear caches in actor and critics | |
| self.actor.clear_cache() | |
| self.critic.clear_cache() | |
| if self.cmdp_mode: | |
| self.cost_critic.clear_cache() | |
| # context before this iteration for RNNs | |
| self.actor_inference_context_initial = None | |
| self.critic_context_initial = None | |
| self.cost_critic_context_initial = None | |
| # Optional copy of the actor on the inference device (kept in sync after updates). | |
| if self.inference_device != self.training_device: | |
| self.actor_inference = copy.deepcopy(self.actor).to(self.inference_device) | |
| self.critic_inference = copy.deepcopy(self.critic).to(self.inference_device) | |
| else: | |
| self.actor_inference = self.actor | |
| self.critic_inference = self.critic | |
| # reset all envs and store initial observation | |
| self.obs: torch.Tensor = torch.tensor(envs.reset(seed=self.base_seed)[0], dtype=obs_dtype, device=self.inference_device) # Automatically seeds each env with base_seed + env_index | |
| def collect_trajectories(self) -> None: | |
| """Collect rollout trajectories from the vector environment using the inference actor.""" | |
| self.actor_inference.eval() | |
| self.actor_inference.cache = copy.deepcopy(self.actor_inference_context_initial) # set context for RNNs | |
| for t in range(self.nsteps): | |
| with torch.no_grad(): | |
| # act and step | |
| obs = self.obs.unsqueeze(1) if self.is_rucur else self.obs # add seq_len dimension if RNN | |
| a_t, a_t_info = self.actor_inference.sample_action(obs) | |
| logp_t = a_t_info['log_prob'] | |
| if self.is_rucur: | |
| a_t = a_t.squeeze(1) # remove seq_len dimension | |
| logp_t = logp_t.squeeze(1) # remove seq_len dimension | |
| _a_t = a_t if isinstance(self.envs.single_action_space, Box) else a_t.squeeze(-1) | |
| obs_next_potentially_resetted, r_t, terminated_t, truncated_t, infos = self.envs.step(_a_t.cpu().numpy()) | |
| # store data | |
| self.obs_buf_inference[:, t, ...] = self.obs | |
| self.actions_buf_inference[:, t, :] = a_t | |
| self.log_probs_buf_inference[:, t, :] = logp_t | |
| self.rewards_buf_inference[:, t, 0] = torch.tensor(r_t, dtype=torch.float32, device=self.inference_device) | |
| # If CMDP: env must provide "cost" (either per-step or via "final_info"). | |
| # We aggregate both discounted and undiscounted views for different constraints. | |
| c_t = infos.get('cost', np.zeros(self.envs.num_envs)) | |
| if 'final_info' in infos and 'cost' in infos['final_info']: # some episode ended and vecenv provided final cost info for that env and zeros for others in a vectorized way (newer gymnasium versions) | |
| c_t += infos['final_info']['cost'] | |
| elif 'final_info' in infos and isinstance(infos['final_info'], list): # vecenv provided final info as a list of dicts (older gymnasium versions) | |
| for i in range(self.envs.num_envs): | |
| if (truncated_t[i] or terminated_t[i]) and 'cost' in infos['final_info'][i]: | |
| c_t[i] += infos['final_info'][i]['cost'] | |
| self.costs_buf_inference[:, t, 0] = torch.tensor(c_t, dtype=torch.float32, device=self.inference_device) | |
| self.terminateds_buf_inference[:, t, 0] = torch.tensor(terminated_t, dtype=torch.float32, device=self.inference_device) | |
| self.truncateds_buf_inference[:, t, 0] = torch.tensor(truncated_t, dtype=torch.float32, device=self.inference_device) | |
| # update current obs | |
| self.obs.data.copy_(torch.tensor(obs_next_potentially_resetted, device=self.inference_device), non_blocking=True) | |
| # update returns, lengths, costs | |
| self.discounted_returns_buffer += r_t * np.power(self.gamma, self.lengths_buffer) | |
| self.discounted_costs_buffer += c_t * np.power(self.gamma, self.lengths_buffer) | |
| self.returns_buffer += r_t | |
| self.lengths_buffer += 1 | |
| self.costs_total_buffer += c_t | |
| for i in range(self.envs.num_envs): | |
| if terminated_t[i] or truncated_t[i]: | |
| self.stats['returns'].append(self.returns_buffer[i]) | |
| self.stats['discounted_returns'].append(self.discounted_returns_buffer[i]) | |
| self.stats['discounted_costs'].append(self.discounted_costs_buffer[i]) | |
| self.stats['lengths'].append(self.lengths_buffer[i]) | |
| self.stats['total_costs'].append(self.costs_total_buffer[i]) | |
| if self.cmdp_mode: | |
| # update moving average estimate of cost per episode | |
| if self.constrain_undiscounted_cost: | |
| if (len(self.stats['total_costs']) == 1): | |
| self.moving_average_cost = self.costs_total_buffer[i] | |
| else: | |
| self.moving_average_cost = (1 - self.moving_cost_estimate_step_size) * self.moving_average_cost + self.moving_cost_estimate_step_size * self.costs_total_buffer[i] | |
| else: | |
| if (len(self.stats['discounted_costs']) == 1): | |
| self.moving_average_cost = self.discounted_costs_buffer[i] | |
| else: | |
| self.moving_average_cost = (1 - self.moving_cost_estimate_step_size) * self.moving_average_cost + self.moving_cost_estimate_step_size * self.discounted_costs_buffer[i] | |
| if 'final_info' in infos and 'is_success' in infos['final_info']: | |
| self.stats['successes'].append(infos['final_info']['is_success'][i]) # type: ignore | |
| elif 'final_info' in infos and isinstance(infos['final_info'], list) and 'is_success' in infos['final_info'][i]: # older gymnasium versions | |
| self.stats['successes'].append(infos['final_info'][i]['is_success']) | |
| self.returns_buffer[i] = 0.0 | |
| self.discounted_returns_buffer[i] = 0.0 | |
| self.discounted_costs_buffer[i] = 0.0 | |
| self.lengths_buffer[i] = 0 | |
| self.costs_total_buffer[i] = 0.0 | |
| self.stats['total_episodes'] += 1 | |
| # Don't reset the RNN cache here; the RNN is expected to learn to reset itself based on a flag in the observation that marks the beginning of an episode (see POMDPEnvWrapper below). | |
| for key in self.custom_info_keys_to_log_at_episode_end: | |
| key_found = False | |
| if 'final_info' in infos and key in infos['final_info']: | |
| self.stats['info_keys'][key].append(infos['final_info'][key][i]) # type: ignore | |
| key_found = True | |
| elif 'final_info' in infos and isinstance(infos['final_info'], list): # older gymnasium versions | |
| if key in infos['final_info'][i]: | |
| self.stats['info_keys'][key].append(infos['final_info'][i][key]) | |
| key_found = True | |
| if not key_found: | |
| # raise warning only first time | |
| if len(self.stats['info_keys'][key]) == 0: | |
| print(f"Warning: info key '{key}' not found in env info dictionary. Ignoring it. This warning will not be repeated.", file=sys.stderr) | |
| self.stats['info_keys'][key].append(np.nan) | |
| print(f"Collected steps {(t + 1) * self.envs.num_envs}/{self.nsteps * self.envs.num_envs}", end='\r') | |
| terminal_width = shutil.get_terminal_size((80, 20)).columns | |
| print(' ' * terminal_width, end='\r') # erase the progress log | |
| # if using different device for inference, copy data to training device | |
| if self.inference_device != self.training_device: | |
| self.obs_buf.data.copy_(self.obs_buf_inference.data, non_blocking=True) | |
| self.actions_buf.data.copy_(self.actions_buf_inference.data, non_blocking=True) | |
| self.log_probs_buf.data.copy_(self.log_probs_buf_inference.data, non_blocking=True) | |
| self.rewards_buf.data.copy_(self.rewards_buf_inference.data, non_blocking=True) | |
| self.costs_buf.data.copy_(self.costs_buf_inference.data, non_blocking=True) | |
| self.terminateds_buf.data.copy_(self.terminateds_buf_inference.data, non_blocking=True) | |
| self.truncateds_buf.data.copy_(self.truncateds_buf_inference.data, non_blocking=True) | |
| # update stats | |
| self.stats['total_timesteps'] += self.nsteps * self.envs.num_envs | |
| self.stats['timestepss'].append(self.stats['total_timesteps']) | |
| self.stats['episodess'].append(self.stats['total_episodes']) | |
| # Rolling window stats (last 100 episodes): mean return/length/cost/success. | |
| # Also track custom info keys from env for handy metrics. | |
| mean_return = np.mean(self.stats['returns'][-100:]) if len(self.stats['returns']) > 0 else 0.0 | |
| mean_discounted_return = np.mean(self.stats['discounted_returns'][-100:]) if len(self.stats['discounted_returns']) > 0 else 0.0 | |
| mean_discounted_cost = np.mean(self.stats['discounted_costs'][-100:]) if len(self.stats['discounted_costs']) > 0 else 0.0 | |
| mean_length = np.mean(self.stats['lengths'][-100:]) if len(self.stats['lengths']) > 0 else 0.0 | |
| mean_total_cost = np.mean(self.stats['total_costs'][-100:]) if len(self.stats['total_costs']) > 0 else 0.0 | |
| mean_success = np.mean(self.stats['successes'][-100:]) if len(self.stats['successes']) > 0 else 0.0 | |
| self.stats['mean_returns'].append(mean_return) | |
| self.stats['mean_discounted_returns'].append(mean_discounted_return) | |
| self.stats['mean_discounted_costs'].append(mean_discounted_cost) | |
| self.stats['mean_lengths'].append(mean_length) | |
| self.stats['mean_total_costs'].append(mean_total_cost) | |
| self.stats['mean_successes'].append(mean_success) | |
| for key in self.custom_info_keys_to_log_at_episode_end: | |
| key_mean = np.nanmean(self.stats['info_keys'][key][-100:]) if len(self.stats['info_keys'][key]) > 0 else 0.0 | |
| self.stats['info_key_means'][key].append(key_mean) | |
| def compute_values_advantages(self) -> None: | |
| """Compute value estimates and GAE advantages for rewards (and costs in CMDP mode).""" | |
| self.critic.eval() | |
| self.critic.cache = copy.deepcopy(self.critic_context_initial) # set context for RNNs | |
| self.actor.eval() | |
| self.actor.cache = copy.deepcopy(self.actor_inference_context_initial) # set context for RNNs | |
| if self.cmdp_mode: | |
| self.cost_critic.eval() | |
| self.cost_critic.cache = copy.deepcopy(self.cost_critic_context_initial) # set context for RNNs | |
| with torch.no_grad(): | |
| obs_buf_cat_cur_obs = torch.cat([self.obs_buf, self.obs.to(self.obs_buf.device).unsqueeze(1)], dim=1) # (N, M+1, obs_dim) | |
| if not self.is_rucur: | |
| obs_buf_cat_cur_obs = obs_buf_cat_cur_obs.reshape(-1, *obs_buf_cat_cur_obs.shape[2:]) # (N*(M+1), obs_dim) | |
| v_cat_v_cur = self.critic.get_value(obs_buf_cat_cur_obs) # (N*(M+1), 1) or (N, M+1, 1) depending on whether RNN | |
| v_cat_v_cur = v_cat_v_cur.reshape(self.envs.num_envs, self.nsteps + 1, 1) | |
| v = v_cat_v_cur[:, :-1, :] # (N, M, 1) | |
| v_next = v_cat_v_cur[:, 1:, :] # (N, M, 1) | |
| advantages = self.rewards_buf + self.gamma * (1.0 - self.terminateds_buf) * v_next - v # (N, M, 1) | |
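| # At this point `advantages` holds the one-step TD residual | |
| # delta_t = r_t + gamma * (1 - terminated_t) * V(s_{t+1}) - V(s_t); | |
| # the backward loop below accumulates these into GAE(lambda) estimates. | |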
| # ! Important: episode boundaries are not handled specially above, so when an | |
| # episode ends at step t, v_next is NOT the value of the true next observation. | |
| # For *terminated* episodes this is harmless (v_next is masked out of the | |
| # residual), but for *truncated* episodes v_next would be used for | |
| # bootstrapping, which is wrong. We therefore zero out advantages at truncation | |
| # steps so the losses are unaffected. This discards a little data | |
| # (insignificant when time limits are large), which is preferable to training | |
| # on wrong advantages. A more elaborate scheme that uses every data point adds | |
| # significant computational complexity and is hard to reconcile with RNNs. For | |
| # RNNs, we instead assume the state space contains an episode-start flag so the | |
| # RNN learns to reset its hidden state at episode boundaries. | |
| advantages *= (1.0 - self.truncateds_buf) | |
| adv_next = torch.zeros((self.envs.num_envs, 1), dtype=torch.float32, device=self.training_device) | |
| dones = (self.terminateds_buf + self.truncateds_buf).clamp(max=1.0) # (N, M, 1) | |
| for t in reversed(range(self.nsteps)): | |
| # Standard backward-time GAE(λ): A_t = δ_t + γλ (1 - done_t) A_{t+1} | |
| advantages[:, t, :] += (1.0 - dones[:, t, :]) * self.gamma * self.lam * adv_next | |
| adv_next = advantages[:, t, :] | |
| self.values_buf.data.copy_(v.data, non_blocking=True) | |
| self.advantages_buf.data.copy_(advantages.data, non_blocking=True) | |
| if self.cmdp_mode: | |
| c_cat_c_cur = self.cost_critic.get_value(obs_buf_cat_cur_obs) # (N*(M+1), 1) or (N, M+1, 1) depending on whether RNN | |
| c_cat_c_cur = c_cat_c_cur.reshape(self.envs.num_envs, self.nsteps + 1, 1) | |
| c = c_cat_c_cur[:, :-1, :] # (N, M, 1) | |
| c_next = c_cat_c_cur[:, 1:, :] # (N, M, 1) | |
| cost_advantages = self.costs_buf + self.gamma * (1.0 - self.terminateds_buf) * c_next - c # (N, M, 1) | |
| cost_advantages *= (1.0 - self.truncateds_buf) # zero out advantages where episodes are truncated (see comment above) | |
| cost_adv_next = torch.zeros((self.envs.num_envs, 1), dtype=torch.float32, device=self.training_device) | |
| for t in reversed(range(self.nsteps)): | |
| cost_advantages[:, t, :] += (1.0 - dones[:, t, :]) * self.gamma * self.lam * cost_adv_next | |
| cost_adv_next = cost_advantages[:, t, :] | |
| self.cost_values_buf.data.copy_(c.data, non_blocking=True) | |
| self.cost_advantages_buf.data.copy_(cost_advantages.data, non_blocking=True) | |
| def loss(self, obs: Tensor, actions: Tensor, old_values: Tensor, advantages: Tensor, old_logprobs: Tensor, old_cost_values: Tensor, cost_advantages: Tensor, clip_ratio, actor_coeff=1.0, value_coeff=0.5, cost_coeff=0.5, entropy_coeff=0.0) -> Tuple[Tensor, Tensor, Tensor, Optional[Tensor], Tensor]: | |
| """Compute the combined PPO loss (policy + value + entropy + optional cost).""" | |
| value_loss = self.critic.get_loss(obs, old_values, advantages) | |
| cost_loss = self.cost_critic.get_loss(obs, old_cost_values, cost_advantages) if self.cmdp_mode else None | |
| if self.normalize_advantages: | |
| advantages = (advantages - advantages.mean().detach()) / (advantages.std().detach() + 1e-3) | |
| if self.cmdp_mode: | |
| cost_advantages = (cost_advantages - cost_advantages.mean().detach()) / (cost_advantages.std().detach() + 1e-3) | |
| if self.cmdp_mode: | |
| advantages -= self.lagrange * cost_advantages | |
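| # Lagrangian relaxation of the CMDP: the policy objective uses the combined | |
| # advantage A_r - lambda * A_c, so a larger lambda trades reward for lower | |
| # cost; lambda itself is adapted at the end of each iteration (see | |
| # train_one_iteration). | |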
| if self.normalize_advantages: # normalize again after combining | |
| advantages = (advantages - advantages.mean().detach()) / (advantages.std().detach() + 1e-3) # again | |
| actor_loss, entropy = self.actor.get_policy_loss_and_entropy(obs, actions, advantages, old_logprobs, clip_ratio) | |
| if cost_loss is not None: | |
| total_loss = actor_coeff * actor_loss + value_coeff * value_loss + cost_coeff * cost_loss - entropy_coeff * entropy | |
| else: | |
| total_loss = actor_coeff * actor_loss + value_coeff * value_loss - entropy_coeff * entropy | |
| return total_loss, actor_loss, value_loss, cost_loss, entropy | |
| def train_one_iteration(self): | |
| """Perform one full PPO iteration: collect data, optionally normalize rewards/costs, update obs-normalization stats, compute advantages, anneal hyperparameters, update networks, log stats.""" | |
| # collect data | |
| self.collect_trajectories() | |
| # update normalization stats | |
| if self.norm_rewards and len(self.stats['discounted_returns']) > 1: | |
| reward_std = np.std(self.stats['discounted_returns']) | |
| if reward_std >= 1e-3: | |
| self.rewards_buf /= reward_std # type: ignore | |
| if self.cmdp_mode: | |
| costs_std = np.std(self.stats['discounted_costs']) | |
| if costs_std > 1e-3: | |
| self.costs_buf /= costs_std # type: ignore | |
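| # The scaling above keeps value (and cost-value) targets roughly unit-scale | |
| # without shifting their mean; it is skipped while the std estimate is | |
| # degenerate (fewer than two episodes, or std below 1e-3). | |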
| if self.actor.norm_obs: | |
| self.actor.update_obs_stats(self.obs_buf.reshape(-1, *self.obs_buf.shape[2:])) | |
| if self.critic.norm_obs: | |
| self.critic.update_obs_stats(self.obs_buf.reshape(-1, *self.obs_buf.shape[2:])) | |
| if self.cmdp_mode and self.cost_critic.norm_obs: | |
| self.cost_critic.update_obs_stats(self.obs_buf.reshape(-1, *self.obs_buf.shape[2:])) | |
| # compute values and advantages | |
| self.compute_values_advantages() | |
| # do annealing updates of lr, entropy coef, clip ratio | |
| entropy_coef = self.entropy_coef | |
| if self.decay_entropy_coef: | |
| entropy_coef = self.entropy_coef * (1.0 - float(self.stats['iterations']) / float(self.iters)) | |
| actor_lr, critic_lr = self.actor_lr, self.critic_lr | |
| if self.decay_lr: | |
| actor_lr = max(self.actor_lr * (1.0 - float(self.stats['iterations']) / float(self.iters)), self.min_lr) | |
| critic_lr = max(self.critic_lr * (1.0 - float(self.stats['iterations']) / float(self.iters)), self.min_lr) | |
| for param_group in self.actor_optimizer.param_groups: | |
| param_group['lr'] = actor_lr | |
| for param_group in self.critic_optimizer.param_groups: | |
| param_group['lr'] = critic_lr | |
| if self.cmdp_mode: | |
| for param_group in self.cost_optimizer.param_groups: | |
| param_group['lr'] = critic_lr | |
| clip_ratio = self.clip_ratio | |
| if self.decay_clip_ratio: | |
| clip_ratio = max(self.clip_ratio * (1.0 - float(self.stats['iterations']) / float(self.iters)), self.min_clip_ratio) | |
| # update policy and value networks | |
| epochs = 0 | |
| kl = 0.0 | |
| stop_actor_training = False | |
| if self.is_rucur: | |
| obs_buf, actions_buf, values_buf, advantages_buf, log_probs_buf, cost_values_buf, cost_advantages_buf = (self.obs_buf, self.actions_buf, self.values_buf, self.advantages_buf, self.log_probs_buf, self.cost_values_buf, self.cost_advantages_buf) | |
| else: | |
| # flatten the (N_envs, nsteps, ...) data to (N_envs * nsteps, ...) | |
| obs_buf, actions_buf, values_buf, advantages_buf, log_probs_buf, cost_values_buf, cost_advantages_buf = (self.obs_buf.reshape(-1, *self.obs_buf.shape[2:]), self.actions_buf.reshape(-1, *self.actions_buf.shape[2:]), self.values_buf.reshape(-1, 1), self.advantages_buf.reshape(-1, 1), self.log_probs_buf.reshape(-1, 1), self.cost_values_buf.reshape(-1, 1), self.cost_advantages_buf.reshape(-1, 1)) | |
| for epoch in range(self.nepochs): | |
| print(f"Training epoch {epoch + 1}/{self.nepochs}", end='\r') | |
| self.actor.train() | |
| self.critic.train() | |
| if self.cmdp_mode: | |
| self.cost_critic.train() | |
| total = self.nsteps * self.envs.num_envs | |
| idxs = torch.randperm(total, device=self.training_device) # for non-RNNs | |
| if self.is_rucur: | |
| n_segments_along_time = self.nsteps // self.context_length | |
| n_trajs_per_batch = self.batch_size // self.context_length | |
| n_batches = total // self.batch_size | |
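| # For RNNs, minibatches are contiguous (trajectory, time) tiles of shape | |
| # (batch_size // context_length) x context_length, scanned along the time | |
| # axis; detaching the cache between consecutive tiles (below) implements | |
| # truncated backpropagation through time with window context_length. | |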
| for mb_num in range(n_batches): | |
| if not self.is_rucur: | |
| start = mb_num * self.batch_size | |
| end = min(start + self.batch_size, total) | |
| mb_idxs = idxs[start:end] | |
| mb_loss, _, _, _, _ = self.loss(obs_buf[mb_idxs], actions_buf[mb_idxs], values_buf[mb_idxs], advantages_buf[mb_idxs], log_probs_buf[mb_idxs], cost_values_buf[mb_idxs], cost_advantages_buf[mb_idxs], clip_ratio, float(not stop_actor_training), self.value_loss_coef, self.value_loss_coef * float(self.cmdp_mode), entropy_coef) | |
| else: | |
| n_segments_along_batch = obs_buf.shape[0] // n_trajs_per_batch | |
| assert n_batches == n_segments_along_batch * n_segments_along_time, "Batch size not compatible with context length and number of environments." | |
| # We move along a grid of n_segments_along_batch x n_segments_along_time, scanning along the time axis. Whenever the time-axis index becomes 0, we reset the cache to the initial context for the corresponding trajectories along the batch axis. | |
| grid_row_index = (mb_num // n_segments_along_time) | |
| grid_col_index = (mb_num % n_segments_along_time) | |
| dim0_start = grid_row_index * n_trajs_per_batch | |
| dim0_end = min(dim0_start + n_trajs_per_batch, obs_buf.shape[0]) | |
| dim1_start = grid_col_index * self.context_length | |
| dim1_end = min(dim1_start + self.context_length, obs_buf.shape[1]) | |
| dim0_range = list(range(dim0_start, dim0_end)) | |
| if grid_col_index == 0: | |
| self.actor.cache = copy.deepcopy(cache_move_to_device(cache_slice_for_batch_ids(self.actor_inference_context_initial, dim0_range), self.training_device)) | |
| self.critic.cache = copy.deepcopy(cache_slice_for_batch_ids(self.critic_context_initial, dim0_range)) | |
| if self.cmdp_mode: | |
| self.cost_critic.cache = copy.deepcopy(cache_slice_for_batch_ids(self.cost_critic_context_initial, dim0_range)) | |
| else: | |
| # detach the cache from previous segment to avoid backprop through time | |
| self.actor.cache = cache_detach(self.actor.cache) | |
| self.critic.cache = cache_detach(self.critic.cache) | |
| if self.cmdp_mode: | |
| self.cost_critic.cache = cache_detach(self.cost_critic.cache) | |
| mb_loss, _, _, _, _ = self.loss(obs_buf[dim0_start:dim0_end, dim1_start:dim1_end, ...], actions_buf[dim0_start:dim0_end, dim1_start:dim1_end, ...], values_buf[dim0_start:dim0_end, dim1_start:dim1_end, ...], advantages_buf[dim0_start:dim0_end, dim1_start:dim1_end, ...], log_probs_buf[dim0_start:dim0_end, dim1_start:dim1_end, ...], cost_values_buf[dim0_start:dim0_end, dim1_start:dim1_end, ...], cost_advantages_buf[dim0_start:dim0_end, dim1_start:dim1_end, ...], clip_ratio, float(not stop_actor_training), self.value_loss_coef, self.value_loss_coef * float(self.cmdp_mode), entropy_coef) | |
| self.actor_optimizer.zero_grad() | |
| self.critic_optimizer.zero_grad() | |
| if self.cmdp_mode: | |
| self.cost_optimizer.zero_grad() | |
| mb_loss.backward() | |
| # update actor | |
| if not stop_actor_training: | |
| # clipnorm (only) if actor is being updated, as it is an expensive operation | |
| if self.clipnorm < np.inf: | |
| torch.nn.utils.clip_grad_norm_(self.actor.parameters(), self.clipnorm) | |
| self.actor_optimizer.step() | |
| # update critic | |
| if self.clipnorm < np.inf: | |
| torch.nn.utils.clip_grad_norm_(self.critic.parameters(), self.clipnorm) | |
| self.critic_optimizer.step() | |
| # update cost critic | |
| if self.cmdp_mode: | |
| if self.clipnorm < np.inf: | |
| torch.nn.utils.clip_grad_norm_(self.cost_critic.parameters(), self.clipnorm) | |
| self.cost_optimizer.step() | |
| epochs += 1 | |
| if not stop_actor_training: | |
| self.actor.eval() | |
| self.actor.cache = copy.deepcopy(cache_move_to_device(self.actor_inference_context_initial, self.training_device)) # set context for RNNs | |
| # Early stop if measured KL exceeds 1.5 * target_kl to avoid over-updating. | |
| kl = self.actor.get_kl_div(obs_buf, actions_buf, log_probs_buf).item() | |
| if kl > 1.5 * self.target_kl: | |
| stop_actor_training = True | |
| if self.early_stop_critic: | |
| break | |
| terminal_width = shutil.get_terminal_size((80, 20)).columns | |
| print(' ' * terminal_width, end='\r') # erase the progress log | |
| # compute loss on all data for logging. This will also forward the caches for RNNs to the end of the data and set it up for next iteration. | |
| with torch.no_grad(): | |
| self.actor.eval() | |
| self.critic.eval() | |
| if self.cmdp_mode: | |
| self.cost_critic.eval() | |
| self.actor.cache = copy.deepcopy(cache_move_to_device(self.actor_inference_context_initial, self.training_device)) # set context for RNNs | |
| self.critic.cache = copy.deepcopy(self.critic_context_initial) # set context for RNNs | |
| if self.cmdp_mode: | |
| self.cost_critic.cache = copy.deepcopy(self.cost_critic_context_initial) # set context for RNNs | |
| total_loss, actor_loss, critic_loss, cost_loss, entropy = self.loss(obs_buf, actions_buf, values_buf, advantages_buf, log_probs_buf, cost_values_buf, cost_advantages_buf, clip_ratio, 1, self.value_loss_coef, self.value_loss_coef * float(self.cmdp_mode), entropy_coef) | |
| value = values_buf.mean().item() | |
| mean_cost = cost_values_buf.mean().item() if self.cmdp_mode else 0.0 | |
| self.actor.cache = copy.deepcopy(cache_move_to_device(self.actor_inference_context_initial, self.training_device)) # reset context (on the training device) before calling get_logprob | |
| logprobs = self.actor.get_logprob(obs_buf, actions_buf) | |
| mean_logprob = logprobs.mean().item() | |
| if isinstance(self.actor, ActorContinuous): | |
| self.actor.cache = copy.deepcopy(cache_move_to_device(self.actor_inference_context_initial, self.training_device)) | |
| logstd = self.actor._get_logstd(obs_buf).mean(dim=0) | |
| mean_logstd = logstd.mean().item() | |
| else: | |
| mean_logstd = 0.0 | |
| # record cached contexts for RNNs for next iteration | |
| self.actor_inference_context_initial = copy.deepcopy(cache_move_to_device(cache_detach(self.actor.cache), self.inference_device)) | |
| self.critic_context_initial = copy.deepcopy(cache_detach(self.critic.cache)) | |
| if self.cmdp_mode: | |
| self.cost_critic_context_initial = copy.deepcopy(cache_detach(self.cost_critic.cache)) | |
| # update actor for inference, if using a different device for inference | |
| if self.inference_device != self.training_device: | |
| self.actor_inference.load_state_dict(self.actor.state_dict()) | |
| # Lagrange update: | |
| if self.cmdp_mode: | |
| # λ <- clip(λ + lr * (mean_cost - cost_threshold), [0, lagrange_max]) | |
| # Switch between discounted vs undiscounted cost via `constrain_undiscounted_cost`. | |
| self.lagrange = self.lagrange + self.lagrange_lr * (self.moving_average_cost - self.cost_threshold) | |
| self.lagrange = np.clip(self.lagrange, 0.0, self.lagrange_max) # avoid too large lagrange multiplier | |
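| # Illustration (hypothetical numbers): with cost_threshold = 25 and a moving- | |
| # average cost of 30, lambda grows by 5 * lagrange_lr this iteration; once the | |
| # constraint is satisfied (cost below threshold), lambda decays back toward 0. | |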
| # update stats | |
| self.stats['iterations'] += 1 | |
| self.stats['losses'].append(total_loss.item()) | |
| self.stats['actor_losses'].append(actor_loss.item()) | |
| self.stats['critic_losses'].append(critic_loss.item()) | |
| self.stats['cost_losses'].append(cost_loss.item() if cost_loss is not None else 0.0) | |
| self.stats['entropies'].append(entropy.item()) | |
| self.stats['kl_divs'].append(kl) | |
| self.stats['values'].append(value) | |
| self.stats['cost_values'].append(mean_cost) | |
| self.stats['logstds'].append(mean_logstd) | |
| self.stats['logprobs'].append(mean_logprob) | |
| self.stats['actor_lrs'].append(actor_lr) | |
| self.stats['critic_lrs'].append(critic_lr) | |
| self.stats['entropy_coefs'].append(entropy_coef) | |
| self.stats['clip_ratios'].append(clip_ratio) | |
| self.stats['lagranges'].append(self.lagrange) | |
| def train(self, for_iterations): | |
| """ | |
| Train the PPO agent for the specified number of iterations. | |
| Performs `for_iterations` calls to `train_one_iteration()`, printing progress | |
| and key metrics after each. Stops early if total iterations reach `self.iters`. | |
| Args: | |
| for_iterations: Number of iterations to train (or fewer if already complete). | |
| Returns: | |
| None (updates `self.stats` with training history). | |
| """ | |
| if self.stats['iterations'] >= self.iters: | |
| print("Training already complete.") | |
| return | |
| for it in range(for_iterations): | |
| self.train_one_iteration() | |
| print( | |
| f"Iteration {self.stats['iterations']}/{self.iters} complete.\n" | |
| f" Total timesteps: {self.stats['total_timesteps']}\n" | |
| f" Total episodes: {self.stats['total_episodes']}\n" | |
| f" Mean (100-window) return: {self.stats['mean_returns'][-1]:.6f}\n" | |
| f" Mean (100-window) total cost: {self.stats['mean_total_costs'][-1]:.6f}\n" | |
| f" Mean (100-window) discounted cost: {self.stats['mean_discounted_costs'][-1]:.6f}\n" | |
| f" Mean (100-window) length: {self.stats['mean_lengths'][-1]:.6f}\n" | |
| f" Mean (100-window) success: {self.stats['mean_successes'][-1]:.6f}\n" | |
| f" Exp. Moving average cost ({'undiscounted' if self.constrain_undiscounted_cost else 'discounted'}) (alpha = {self.moving_cost_estimate_step_size}): {self.moving_average_cost:.6f}\n" | |
| f" Lagrange multiplier: {self.stats['lagranges'][-1]:.6f}\n" | |
| f" Total loss: {self.stats['losses'][-1]:.6f}\n" | |
| f" Actor loss: {self.stats['actor_losses'][-1]:.6f}\n" | |
| f" Critic loss: {self.stats['critic_losses'][-1]:.6f}\n" | |
| f" Cost critic loss: {self.stats['cost_losses'][-1]:.6f}\n" | |
| f" Entropy: {self.stats['entropies'][-1]:.6f}\n" | |
| f" KL: {self.stats['kl_divs'][-1]:.6f}\n" | |
| f" Value: {self.stats['values'][-1]:.6f}\n" | |
| f" Cost value: {self.stats['cost_values'][-1]:.6f}\n" | |
| f" Learning rates - actor: {self.stats['actor_lrs'][-1]:.6f}, critic: {self.stats['critic_lrs'][-1]:.6f}\n" | |
| f" Entropy coef: {self.stats['entropy_coefs'][-1]:.6f}\n" | |
| f" Clip ratio: {self.stats['clip_ratios'][-1]:.6f}\n" | |
| + ''.join([f" Mean (100-window) {key}: {self.stats['info_key_means'][key][-1]:.6f}\n" for key in self.custom_info_keys_to_log_at_episode_end]), flush=True) | |
| if self.stats['iterations'] >= self.iters: | |
| print("Training complete.") | |
| break | |
| class POMDPEnvWrapper(gymnasium.Wrapper): | |
| """A wrapper to augment a flag indicating start of an episode in the observation space. Useful for RNN-based agents in POMDPs. Usually not required for Atari since it is clear which states are episode starts. TODO: This wrapper will concatenate the previous state, previous action, and previous reward to the current observation. | |
| """ | |
| def __init__(self, env: gymnasium.Env): | |
| super().__init__(env) | |
| original_obs_space = env.observation_space | |
| if isinstance(original_obs_space, Box): | |
| low = np.concatenate([original_obs_space.low, np.array([0.0], dtype=original_obs_space.dtype)], axis=0) | |
| high = np.concatenate([original_obs_space.high, np.array([1.0], dtype=original_obs_space.dtype)], axis=0) | |
| self.observation_space = Box(low=low, high=high, dtype=original_obs_space.dtype) # type: ignore | |
| else: | |
| raise NotImplementedError("POMDPEnvWrapper only supports Box observation spaces.") | |
| def reset(self, **kwargs): | |
| obs, info = self.env.reset(**kwargs) | |
| obs_with_flag = np.concatenate([obs, np.array([1.0], dtype=obs.dtype)], axis=0) # episode start flag set to 1.0 | |
| return obs_with_flag, info | |
| def step(self, action): | |
| obs, reward, terminated, truncated, info = self.env.step(action) | |
| obs_with_flag = np.concatenate([obs, np.array([0.0], dtype=obs.dtype)], axis=0) # episode start flag set to 0.0 | |
| return obs_with_flag, reward, terminated, truncated, info | |
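| # Usage sketch (hypothetical env): | |
| # env = POMDPEnvWrapper(gymnasium.make("Pendulum-v1")) | |
| # obs, info = env.reset() # obs[-1] == 1.0 marks the episode start | |
| # obs, r, term, trunc, info = env.step(env.action_space.sample()) # obs[-1] == 0.0 | |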
| suffix = "-default" | |
| # Example usage: | |
| # python ppo_single_file.py <EnvName> <iters> [model_path] | |
| # If model_path is provided, load the actor and evaluate; otherwise train from scratch. For Atari, specify an env name starting with "ALE/". | |
| if __name__ == "__main__": | |
| # Command-line argument parsing for environment, iterations, and optional model loading. | |
| envname = sys.argv[1] if len(sys.argv) > 1 else "ALE/Pong-v5" | |
| iters = int(sys.argv[2]) if len(sys.argv) > 2 else 100 | |
| print(f"Using environment: {envname}") | |
| def make_env(env_name: str, **kwargs) -> gymnasium.Env: | |
| if env_name.startswith("ALE/"): # handle Atari specially with preprocessing and frame stacking. | |
| # if it is not -v5, then remove the "ALE/" prefix. | |
| if not env_name.endswith("-v5"): | |
| env_name = env_name[4:] | |
| import ale_py | |
| gymnasium.register_envs(ale_py) | |
| env = wrappers.FrameStackObservation(wrappers.AtariPreprocessing(make(env_name, frameskip=1, **kwargs), frame_skip=4), stack_size=4, padding_type="zero") | |
| else: | |
| env = make(env_name, **kwargs) | |
| env = wrappers.TimeLimit(env, max_episode_steps=100000) | |
| return env | |
| # one single env for recording videos | |
| env = make_env(envname, render_mode="rgb_array") | |
| env = wrappers.RecordVideo(env, f"videos/{envname.replace('/', '-')}{suffix}", name_prefix='train', episode_trigger=lambda x: True) | |
| action_dim: int = env.action_space.n if isinstance(env.action_space, Discrete) else env.action_space.shape[0] # type: ignore | |
| print(f"obs_shape: {env.observation_space.shape}, action_dim: {action_dim}") | |
| # vectorized env for training and common PPO kwargs (irrespective of atari vs others) | |
| envs = AsyncVectorEnv([lambda i=i: make_env(envname) for i in range(8)], **autoreset_kwarg_samestep_newgymapi) | |
| ppo_kwargs = dict(entropy_coef=0.01, nsteps=256, norm_rewards=True, training_device='cuda' if torch.cuda.is_available() else 'cpu') | |
| # Model && PPO setup: branch for Atari (CNN) vs others (classic control / Mujoco / MLP). | |
| if envname.startswith("ALE/"): | |
| actor_model = nn.Sequential( | |
| nn.Conv2d(4, 32, kernel_size=8, stride=4), # 32 x 20 x 20 | |
| nn.ReLU(), | |
| nn.Conv2d(32, 64, kernel_size=4, stride=2), # 64 x 9 x 9 | |
| nn.ReLU(), | |
| nn.Conv2d(64, 64, kernel_size=3, stride=1), # 64 x 7 x 7 | |
| nn.ReLU(), | |
| nn.Flatten(), # 3136 | |
| nn.Linear(3136, 512), # 512 | |
| nn.ReLU(), | |
| nn.Linear(512, action_dim) # Output layer | |
| ) | |
| critic_model = nn.Sequential( | |
| nn.Conv2d(4, 32, kernel_size=8, stride=4), # 32 x 20 x 20 | |
| nn.ReLU(), | |
| nn.Conv2d(32, 64, kernel_size=4, stride=2), # 64 x 9 x 9 | |
| nn.ReLU(), | |
| nn.Conv2d(64, 64, kernel_size=3, stride=1), # 64 x 7 x 7 | |
| nn.ReLU(), | |
| nn.Flatten(), # 3136 | |
| nn.Linear(3136, 512), # 512 | |
| nn.ReLU(), | |
| nn.Linear(512, 1) # Output layer | |
| ) | |
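| # Both networks use the standard "Nature CNN" torso (Mnih et al., 2015): | |
| # three conv layers plus a 512-unit dense layer; only the output head differs. | |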
| ppo_kwargs.update(dict(actor_lr=0.0001, critic_lr=0.0001, decay_lr=True, decay_entropy_coef=True, batch_size=64, nepochs=4, inference_device=ppo_kwargs['training_device'])) | |
| else: | |
| obs_dim = env.observation_space.shape[0] # type: ignore | |
| actor_model = nn.Sequential( | |
| nn.Linear(obs_dim, 64), | |
| nn.ReLU(), | |
| nn.Linear(64, 64), | |
| nn.ReLU(), | |
| nn.Linear(64, action_dim) | |
| ) | |
| critic_model = nn.Sequential( | |
| nn.Linear(obs_dim, 64), | |
| nn.ReLU(), | |
| nn.Linear(64, 64), | |
| nn.ReLU(), | |
| nn.Linear(64, 1) | |
| ) | |
| actor = Actor(actor_model, env.observation_space, env.action_space, deterministic=False, norm_obs=True) # type: ignore | |
| critic = Critic(critic_model, env.observation_space, norm_obs=True) # type: ignore | |
| ppo = PPO(envs, actor, critic, iters, **ppo_kwargs) # type: ignore | |
| # Optional: Load a pre-trained model and evaluate. | |
| test_model = sys.argv[3] if len(sys.argv) > 3 else "" | |
| if test_model != "": | |
| assert os.path.exists(test_model), f"Model path {test_model} does not exist" | |
| print(f"Testing model from {test_model}") | |
| actor.load_state_dict(torch.load(test_model, map_location=ppo.inference_device)) | |
| rs, _, _, _ = actor.evaluate_policy_parallel(envs, ppo.iters, deterministic=True) # run parallel tests over iters episodes | |
| print(f"Mean return over {ppo.iters} episodes: {np.mean(rs):.6f} +/- {np.std(rs):.6f}") | |
| env.name_prefix = "test" | |
| actor.evaluate_policy(env, 5, deterministic=True) # run few episodes and record video | |
| else: | |
| print("Training model from scratch") | |
| os.makedirs("models", exist_ok=True) | |
| os.makedirs("plots", exist_ok=True) | |
| while ppo.stats['iterations'] < ppo.iters: | |
| ppo.train(100) # train in chunks of 100 iterations and save model, record a video, and plot training curve | |
| torch.save(ppo.actor.state_dict(), f"models/{envname.replace('/', '-')}{suffix}.pth") | |
| actor.evaluate_policy(env, 1, deterministic=False) # for video recording. Keeping deterministic=False to see the exact policy as in training | |
| plt.clf() | |
| plt.plot(np.arange(ppo.stats['iterations']) + 1, ppo.stats['mean_returns']) | |
| plt.gca().set(xlabel="Iteration", ylabel="Mean Return (100-episode window)", title=f"PPO on {envname}") | |
| plt.savefig(f"plots/{envname.replace('/', '-')}{suffix}.png") | |
| # Cleanup environments. | |
| envs.close() | |
| env.close() |
To train, run:
python ppo_single_file.py CartPole-v1 100   <--- trains for 100 PPO iterations; saves a model checkpoint, a video, and a training-curve plot every 100 iterations.
To test the trained model:
python ppo_single_file.py CartPole-v1 100 models/CartPole-v1-default.pth   <--- tests for 100 episodes in parallel and reports the mean return; also records 5 video episodes.
In general, read the code under the if __name__ == "__main__": block to see an example of how to use the library. Trust me, it's simple.
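For reference, here is a minimal sketch of driving the code as a library, mirroring the if __name__ == "__main__": block above. The constructor signatures follow that block; the env name, network sizes, and the file name ppo_single_file.py are illustrative assumptions:

import gymnasium
import torch
from gymnasium.vector import AsyncVectorEnv
from torch import nn
from ppo_single_file import PPO, Actor, Critic  # assumes the gist is saved as ppo_single_file.py

env = gymnasium.make("CartPole-v1")  # single env, used here only to read the spaces
envs = AsyncVectorEnv([lambda: gymnasium.make("CartPole-v1") for _ in range(8)])  # on newer gymnasium, also pass autoreset_mode=AutoresetMode.SAME_STEP
obs_dim = env.observation_space.shape[0]
action_dim = env.action_space.n  # Discrete action space

def mlp(out_dim):
    return nn.Sequential(nn.Linear(obs_dim, 64), nn.ReLU(), nn.Linear(64, 64), nn.ReLU(), nn.Linear(64, out_dim))

actor = Actor(mlp(action_dim), env.observation_space, env.action_space, deterministic=False, norm_obs=True)
critic = Critic(mlp(1), env.observation_space, norm_obs=True)
ppo = PPO(envs, actor, critic, 100, nsteps=256, entropy_coef=0.01, norm_rewards=True)
ppo.train(100)  # train for 100 iterations; metrics accumulate in ppo.stats
torch.save(actor.state_dict(), "models/CartPole-v1-default.pth")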