elumixor / TRPO.py
Last active November 27, 2025 11:23
import gym
import numpy as np
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from torch.optim import Adam
from torch.distributions import Categorical
from collections import namedtuple
env = gym.make('CartPole-v0')
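
# The capture of this gist is missing several definitions that the functions below
# rely on (the actor/critic networks, the Rollout container, and the trust-region
# size). The block below is an assumed reconstruction rather than the author's exact
# code: the layer sizes, learning rate, and max_d_kl value are illustrative.
state_size = env.observation_space.shape[0]
num_actions = env.action_space.n

Rollout = namedtuple('Rollout', ['states', 'actions', 'rewards', 'next_states'])

actor = nn.Sequential(nn.Linear(state_size, 32),
                      nn.Tanh(),
                      nn.Linear(32, num_actions),
                      nn.Softmax(dim=1))

critic = nn.Sequential(nn.Linear(state_size, 32),
                       nn.Tanh(),
                       nn.Linear(32, 1))
critic_optimizer = Adam(critic.parameters(), lr=0.005)

max_d_kl = 0.01  # trust-region bound on the KL divergence (assumed value)
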
def apply_update(grad_flattened):
    # Write a flat update vector back into the actor's parameters, slice by slice
    n = 0
    for p in actor.parameters():
        numel = p.numel()
        g = grad_flattened[n:n + numel].view(p.shape)
        p.data += g
        n += numel

def conjugate_gradient(A, b, delta=0., max_iterations=float('inf')):
    x = torch.zeros_like(b)
    r = b.clone()
    p = b.clone()
    i = 0
    while i < max_iterations:
        AVP = A(p)
        dot_old = r @ r
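        # The capture cuts off inside this loop; the remainder below is the
        # standard conjugate-gradient recurrence, filled in as an assumption.
        alpha = dot_old / (p @ AVP)
        x_new = x + alpha * p
        if (x - x_new).norm() <= delta:  # stop once the iterate stops moving
            return x_new
        r = r - alpha * AVP              # update the residual
        p = r + (r @ r) / dot_old * p    # update the search direction
        x = x_new
        i += 1
    return x
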
def flat_grad(y, x, retain_graph=False, create_graph=False):
    # Gradient of y with respect to x, flattened into a single 1-D vector
    if create_graph:
        retain_graph = True
    g = torch.autograd.grad(y, x, retain_graph=retain_graph, create_graph=create_graph)
    g = torch.cat([t.view(-1) for t in g])
    return g

def kl_div(p, q):
    # KL(p || q) between categorical distributions given as probability tensors
    p = p.detach()
    return (p * (p.log() - q.log())).sum(-1).mean()


def surrogate_loss(new_probabilities, old_probabilities, advantages):
    # TRPO surrogate objective: importance-weighted advantages
    return (new_probabilities / old_probabilities * advantages).mean()

def estimate_advantages(states, last_state, rewards):
    # Discounted returns, bootstrapped from the critic's value of the last state,
    # minus the critic's value estimates
    values = critic(states)
    last_value = critic(last_state.unsqueeze(0)).flatten()  # flatten so the write into next_values[i] matches shapes
    next_values = torch.zeros_like(rewards)
    for i in reversed(range(rewards.shape[0])):
        last_value = next_values[i] = rewards[i] + 0.99 * last_value
    advantages = next_values - values
    return advantages
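
# `normalize` and `update_critic` are called in update_agent but are missing from
# this capture; the minimal versions below are assumed reconstructions.
def normalize(x):
    # Standardize advantages to zero mean and unit variance
    return (x - x.mean()) / x.std()


def update_critic(advantages):
    # Fit the critic by driving the advantage estimates toward zero (MSE-style loss)
    loss = 0.5 * (advantages ** 2).mean()
    critic_optimizer.zero_grad()
    loss.backward()
    critic_optimizer.step()
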
def update_agent(rollouts):
    states = torch.cat([r.states for r in rollouts], dim=0)
    actions = torch.cat([r.actions for r in rollouts], dim=0).flatten()

    advantages = [estimate_advantages(states, next_states[-1], rewards)
                  for states, _, rewards, next_states in rollouts]
    advantages = normalize(torch.cat(advantages, dim=0).flatten())

    update_critic(advantages)

    distribution = actor(states)
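    # The capture of update_agent is truncated after this point. Everything from here
    # to the `criterion` line search is an assumed reconstruction of the standard TRPO
    # step: old action probabilities, surrogate loss, KL, Hessian-vector product,
    # conjugate gradient, and the maximal step within the trust region.
    distribution = torch.distributions.utils.clamp_probs(distribution)
    probabilities = distribution[range(distribution.shape[0]), actions]

    L = surrogate_loss(probabilities, probabilities.detach(), advantages)
    KL = kl_div(distribution, distribution)

    parameters = list(actor.parameters())
    g = flat_grad(L, parameters, retain_graph=True)
    d_kl = flat_grad(KL, parameters, create_graph=True)

    def HVP(v):
        # Hessian-vector product of the KL term, used inside conjugate_gradient
        return flat_grad(d_kl @ v, parameters, retain_graph=True)

    search_dir = conjugate_gradient(HVP, g, max_iterations=10)
    max_length = torch.sqrt(2 * max_d_kl / (search_dir @ HVP(search_dir)))
    max_step = max_length * search_dir
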
    def criterion(step):
        # Apply the parameters' update, then test the TRPO acceptance conditions
        apply_update(step)
        with torch.no_grad():
            distribution_new = actor(states)
            distribution_new = torch.distributions.utils.clamp_probs(distribution_new)
            probabilities_new = distribution_new[range(distribution_new.shape[0]), actions]
            L_new = surrogate_loss(probabilities_new, probabilities, advantages)
            # The capture is truncated here; the acceptance test below is the usual TRPO check, filled in as an assumption.
            KL_new = kl_div(distribution, distribution_new)

        if L_new - L > 0 and KL_new <= max_d_kl:
            return True
        apply_update(-step)  # reject the step: roll the parameters back
        return False

    # Backtracking line search: shrink the step until criterion accepts it
    i = 0
    while not criterion((0.9 ** i) * max_step) and i < 10:
        i += 1
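

# --- Hypothetical usage sketch (not part of the captured gist) ---
# A minimal rollout-collection and training loop, assuming the definitions above and
# the classic gym API the gist imports (env.reset() -> state, env.step() -> 4-tuple).
# The names get_action and train, and the epoch/rollout counts, are illustrative.
def get_action(state):
    state = torch.tensor(state, dtype=torch.float32).unsqueeze(0)  # batch of one state
    return Categorical(actor(state)).sample().item()               # actor outputs action probabilities


def train(epochs=50, num_rollouts=10):
    mean_rewards = []
    for epoch in range(epochs):
        rollouts = []
        for _ in range(num_rollouts):
            state, done, samples = env.reset(), False, []
            while not done:
                with torch.no_grad():
                    action = get_action(state)
                next_state, reward, done, _ = env.step(action)
                samples.append((state, action, reward, next_state))
                state = next_state
            # Stack the episode into tensors matching the Rollout fields
            states, actions, rewards, next_states = zip(*samples)
            states = torch.stack([torch.from_numpy(s) for s in states], dim=0).float()
            next_states = torch.stack([torch.from_numpy(s) for s in next_states], dim=0).float()
            actions = torch.as_tensor(actions).unsqueeze(1)
            rewards = torch.as_tensor(rewards, dtype=torch.float32).unsqueeze(1)
            rollouts.append(Rollout(states, actions, rewards, next_states))
        update_agent(rollouts)
        mean_rewards.append(np.mean([r.rewards.sum().item() for r in rollouts]))
        print(f'Epoch {epoch}: mean rollout reward {mean_rewards[-1]:.1f}')

    plt.plot(mean_rewards)
    plt.xlabel('Epoch')
    plt.ylabel('Mean rollout reward')
    plt.show()


train()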