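# REINFORCE (vanilla policy gradient) for Gymnasium's LunarLander-v2, implemented in PyTorch.
# Assumes PyTorch and Gymnasium with the Box2D extra are installed
# (typically: pip install torch "gymnasium[box2d]"); note that newer Gymnasium releases
# may register this environment as "LunarLander-v3" rather than "LunarLander-v2".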
import gymnasium as gym
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Categorical
import numpy as np

class Policy(nn.Module):
    def __init__(self):
        super(Policy, self).__init__()
        self.l1 = nn.Linear(8, 128)
        self.l2 = nn.Linear(128, 128)
        self.l3 = nn.Linear(128, 4)

    def forward(self, observation):
        out = F.relu(self.l1(observation))
        out = F.relu(self.l2(out))
        logits = self.l3(out)
        return logits

    def action(self, observation):
        logits = self(observation)
        probs = F.softmax(logits, dim=1)
        action = torch.argmax(probs)
        return action.item()
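
# A quick shape sanity check (hypothetical helper, not part of the original script and
# never called): LunarLander-v2 observations are 8-dimensional and the action space is
# Discrete(4), so the network maps a (1, 8) input to (1, 4) logits.
def _sanity_check_shapes():
    policy = Policy()
    dummy_obs = torch.zeros(1, 8)   # batch of one 8-dimensional observation
    logits = policy(dummy_obs)      # one logit per discrete action
    assert logits.shape == (1, 4)
    assert policy.action(dummy_obs) in range(4)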

def train(env, policy):
    optimizer = torch.optim.AdamW(policy.parameters())
    for i in range(5000):
        observation, _ = env.reset()
        terminated, truncated = False, False
        rewards, log_probs = [], []
        while not (terminated or truncated):
            logits = policy(torch.tensor(observation).view(1, -1))
            probs = F.softmax(logits, dim=1)
            action_distribution = Categorical(probs)
            action = action_distribution.sample().item()
            next_observation, reward, terminated, truncated, _ = env.step(action)
            log_prob = action_distribution.log_prob(torch.tensor(action))
            observation = next_observation
            # Save the rewards and log probabilities
            rewards.append(reward)
            log_probs.append(log_prob)
        # Compute the discounted sum of rewards
        gamma = 0.99  # discount factor
        n = len(rewards)
        rewards = torch.tensor(rewards, dtype=torch.float32)
        discount = gamma ** torch.arange(n)
        discounted_sum_rewards = [torch.sum(rewards[t:] * discount[:n - t]) for t in range(n)]
        discounted_sum_rewards = torch.stack(discounted_sum_rewards)
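        # Worked example (illustrative, not executed): with gamma = 0.99 and
        # rewards [1.0, 1.0, 1.0], the per-step discounted returns are
        # [1 + 0.99 + 0.99**2, 1 + 0.99, 1] = [2.9701, 1.99, 1.0].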
        # Normalize the sum of discounted rewards
        normalized_discounted_sum_rewards = (discounted_sum_rewards - discounted_sum_rewards.mean()) / discounted_sum_rewards.std()
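        # For reference, this is one common formulation of the REINFORCE objective:
        #   loss = -(1/T) * sum_t log pi_theta(a_t | s_t) * G_hat_t,
        # where G_hat_t is the normalized discounted return from step t.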
        # Compute the loss
        avg_loss = -(torch.hstack(log_probs) * normalized_discounted_sum_rewards).mean()
        # Backpropagate the loss
        optimizer.zero_grad()
        avg_loss.backward()
        # Update the policy
        optimizer.step()
        # At every 1000th episode, check how the policy is doing
        if (i + 1) % 1000 == 0:
            episode_rewards = []
            for _ in range(100):
                observation, _ = env.reset()
                terminated, truncated = False, False
                episode_reward = 0
                while not (terminated or truncated):
                    with torch.no_grad():
                        action = policy.action(torch.tensor(observation).view(1, -1))
                    next_observation, reward, terminated, truncated, _ = env.step(action)
                    episode_reward += reward
                    observation = next_observation
                episode_rewards.append(episode_reward)
            print(f"Episode {i+1}: Mean Reward = {np.mean(episode_rewards)} and Standard Deviation = {np.std(episode_rewards)}")
            torch.save(policy.state_dict(), f"./policy_{i+1}.pt")
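
# Note: training samples actions from the softmax distribution (exploration), while the
# evaluation loop above and view_episode below act greedily via Policy.action.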
def view_episode(env, policy):
    observation, _ = env.reset()
    terminated, truncated = False, False
    episode_reward = 0
    while not (terminated or truncated):
        with torch.no_grad():
            action = policy.action(torch.tensor(observation).view(1, -1))
        observation, reward, terminated, truncated, _ = env.step(action)
        episode_reward += reward
        env.render()
    return episode_reward

if __name__ == "__main__":
    train_mode = True
    if train_mode:
        env = gym.make("LunarLander-v2")
        policy = Policy()
        train(env, policy)
    else:
        env = gym.make("LunarLander-v2", render_mode="human")
        policy = Policy()
        policy.load_state_dict(torch.load("./policy_5000.pt"))
        episode_reward = view_episode(env, policy)
        print(f"Episode Reward: {episode_reward}")