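# A simple policy-gradient agent for Gymnasium's LunarLander-v2: a small MLP policy
# is trained by minimizing the negative log-probability of each sampled action,
# weighted by the reward received at that step. Every 1000 episodes the policy is
# evaluated greedily over 100 episodes and a checkpoint is saved.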
import gymnasium as gym
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Categorical
import numpy as np


class Policy(nn.Module):
    def __init__(self):
        super(Policy, self).__init__()
        self.l1 = nn.Linear(8, 128)    # Input observation is an 8-dimensional vector
        self.l2 = nn.Linear(128, 128)
        self.l3 = nn.Linear(128, 4)    # There are 4 actions, hence the output is a 4-dimensional vector

    def forward(self, observation):
        # observation should be a torch tensor of size (1, 8)
        out = F.relu(self.l1(observation))
        out = F.relu(self.l2(out))
        logits = self.l3(out)
        return logits

    def action(self, observation):
        logits = self(observation)
        # Convert the logits to a discrete probability distribution over the 4 actions
        probs = F.softmax(logits, dim=1)
        # Choose the action with the highest probability
        action = torch.argmax(probs)
        return action.item()


def train(env, policy):
    # Initialize the optimizer
    optimizer = torch.optim.AdamW(policy.parameters())
    # Run 5000 episodes of training
    for i in range(5000):
        observation, _ = env.reset()
        terminated, truncated = False, False
        losses = []
        while not (terminated or truncated):
            # Get the logits from the policy
            logits = policy(torch.tensor(observation).view(1, -1))
            # Convert the logits to a discrete probability distribution over the 4 actions
            probs = F.softmax(logits, dim=1)
            # Create a Categorical distribution over the 4 actions
            action_distribution = Categorical(probs)
            # Sample an action from the distribution
            action = action_distribution.sample().item()
            # Take a step in the environment
            next_observation, reward, terminated, truncated, _ = env.step(action)
            # Compute the log probability of the action
            log_prob = action_distribution.log_prob(torch.tensor(action))
            # Calculate the loss for taking this action from this observation
            loss = -log_prob * reward
            # Store the loss
            losses.append(loss)
            # Set the observation to be the next observation
            observation = next_observation
        # Compute the average loss
        avg_loss = torch.stack(losses).mean()
        # Backpropagate the loss
        optimizer.zero_grad()
        avg_loss.backward()
        # Update the policy
        optimizer.step()

        # At every 1000th episode, check how the policy is doing by running it on 100 episodes,
        # printing the mean and standard deviation of the reward, and saving the policy
        if (i + 1) % 1000 == 0:
            episode_rewards = []
            for _ in range(100):
                observation, _ = env.reset()
                terminated, truncated = False, False
                episode_reward = 0
                while not (terminated or truncated):
                    # Get the action from the policy
                    with torch.no_grad():
                        action = policy.action(torch.tensor(observation).view(1, -1))
                    # Take a step in the environment
                    next_observation, reward, terminated, truncated, _ = env.step(action)
                    # Add the reward to the episode reward
                    episode_reward += reward
                    # Set the observation to be the next observation
                    observation = next_observation
                episode_rewards.append(episode_reward)
            # Print the mean reward and standard deviation
            print(f"Episode {i + 1}: Mean Reward = {np.mean(episode_rewards)} and Standard Deviation = {np.std(episode_rewards)}")
            # Save the policy
            torch.save(policy.state_dict(), f"./policy_{i + 1}.pt")


def view_episode(env, policy):
    observation, _ = env.reset()
    terminated, truncated = False, False
    episode_reward = 0
    while not (terminated or truncated):
        with torch.no_grad():
            action = policy.action(torch.tensor(observation).view(1, -1))
        observation, reward, terminated, truncated, _ = env.step(action)
        episode_reward += reward
        env.render()
    return episode_reward


if __name__ == "__main__":
    train_mode = True
    if train_mode:
        env = gym.make("LunarLander-v2")
        policy = Policy()
        train(env, policy)
    else:
        env = gym.make("LunarLander-v2", render_mode="human")
        policy = Policy()
        policy.load_state_dict(torch.load("./policy_5000.pt"))
        episode_reward = view_episode(env, policy)
        print(f"Episode Reward: {episode_reward}")