import gymnasium as gym
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Categorical
import numpy as np
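
# LunarLander is a Box2D environment; this script assumes gymnasium's Box2D extra,
# PyTorch, and NumPy are installed, e.g. `pip install "gymnasium[box2d]" torch numpy`.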

class Policy(nn.Module):
    def __init__(self):
        super(Policy, self).__init__()
        self.l1 = nn.Linear(8, 128)  # Input observation is an 8-dimensional vector
        self.l2 = nn.Linear(128, 128)
        self.l3 = nn.Linear(128, 4)  # There are 4 actions, hence the output is a 4-dimensional vector

    def forward(self, observation):
        # observation should be a torch tensor of size (1, 8)
        out = F.relu(self.l1(observation))
        out = F.relu(self.l2(out))
        logits = self.l3(out)
        return logits

    def action(self, observation):
        logits = self(observation)
        # Convert the logits to a discrete probability distribution over the 4 actions
        probs = F.softmax(logits, dim=1)
        # Choose the action with the highest probability
        action = torch.argmax(probs)
        return action.item()
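
# Note: Policy.action above is greedy (argmax over the softmax probabilities) and
# is used for evaluation/rendering; during training, actions are instead sampled
# from a Categorical distribution so the policy keeps exploring.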


def train(env, policy):
    # Initialize the optimizer
    optimizer = torch.optim.AdamW(policy.parameters())
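    # No hyperparameters are passed, so PyTorch's AdamW defaults apply
    # (lr=1e-3, weight_decay=0.01).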
    # Run 5000 episodes of training
    for i in range(5000):
        observation, _ = env.reset()
        terminated, truncated = False, False
        losses = []
        while not (terminated or truncated):
            # Get the logits from the policy
            logits = policy(torch.tensor(observation).view(1, -1))
            # Convert the logits to a discrete probability distribution over the 4 actions
            probs = F.softmax(logits, dim=1)
            # Create a Categorical distribution over the 4 actions
            action_distribution = Categorical(probs)
            # Sample an action from the distribution
            action = action_distribution.sample().item()
            # Take a step in the environment
            next_observation, reward, terminated, truncated, _ = env.step(action)
            # Compute the log probability of the action
            log_prob = action_distribution.log_prob(torch.tensor(action))
            # Calculate the loss for taking this action from this observation
            loss = -log_prob * reward
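            # Note: the per-step loss weights -log pi(a_t | s_t) by the immediate
            # reward r_t only. The standard REINFORCE estimator,
            #   grad J ~= E[grad log pi(a_t | s_t) * G_t],
            # weights it by the return-to-go G_t = sum_{k >= t} r_k; this script
            # optimizes the simpler immediate-reward variant.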
            # Store the loss
            losses.append(loss)
            # Set the observation to be the next observation
            observation = next_observation
        # Compute the average loss
        avg_loss = torch.stack(losses).mean()
        # Backpropagate the loss
        optimizer.zero_grad()
        avg_loss.backward()
        # Update the policy
        optimizer.step()
        # Every 1000 episodes, evaluate the policy over 100 episodes, print the
        # mean and standard deviation of the reward, and save the weights
        if (i+1) % 1000 == 0:
            episode_rewards = []
            for _ in range(100):
                observation, _ = env.reset()
                terminated, truncated = False, False
                episode_reward = 0
                while not (terminated or truncated):
                    # Get the action from the policy
                    with torch.no_grad():
                        action = policy.action(torch.tensor(observation).view(1, -1))
                    # Take a step in the environment
                    next_observation, reward, terminated, truncated, _ = env.step(action)
                    # Add the reward to the episode reward
                    episode_reward += reward
                    # Set the observation to be the next observation
                    observation = next_observation
                episode_rewards.append(episode_reward)
            # Print the mean reward and standard deviation
            print(f"Episode {i+1}: Mean Reward = {np.mean(episode_rewards)} and Standard Deviation = {np.std(episode_rewards)}")
            # Save the policy
            torch.save(policy.state_dict(), f"./policy_{i+1}.pt")


def view_episode(env, policy):
    observation, _ = env.reset()
    terminated, truncated = False, False
    episode_reward = 0
    while not (terminated or truncated):
        with torch.no_grad():
            action = policy.action(torch.tensor(observation).view(1, -1))
        observation, reward, terminated, truncated, _ = env.step(action)
        episode_reward += reward
        env.render()
    return episode_reward


if __name__ == "__main__":
    train_mode = True
    if train_mode:
        env = gym.make("LunarLander-v2")  # Newer gymnasium releases register this as "LunarLander-v3"
        policy = Policy()
        train(env, policy)
    else:
        env = gym.make("LunarLander-v2", render_mode="human")
        policy = Policy()
        policy.load_state_dict(torch.load("./policy_5000.pt"))
        episode_reward = view_episode(env, policy)
        print(f"Episode Reward: {episode_reward}")