Skip to content

Instantly share code, notes, and snippets.

@hiszpanski
Created August 17, 2022 07:21
Show Gist options
  • Select an option

  • Save hiszpanski/b3ec7f20e7d7326342ad69e86e973b6d to your computer and use it in GitHub Desktop.

Select an option

Save hiszpanski/b3ec7f20e7d7326342ad69e86e973b6d to your computer and use it in GitHub Desktop.
Sutton & Barto Gambler's Problem
from decimal import *
import matplotlib.pyplot as plt
"""
Curiously, the policy does not match the textbook. This may be due to
numerical precision, or ties.
"""
p = Decimal('0.4')
θ = Decimal('1e-50')
γ = Decimal(1)
goal = 100
def argmax(x):
    """Return the position of the largest element of the sequence ``x``.

    When several elements are equal to the maximum, the earliest position is
    returned, because ``max`` keeps the first maximal item it encounters.
    An empty ``x`` raises ``ValueError`` (propagated from ``max``).
    """
    candidates = range(len(x))
    # Rank each index by the element stored at that index.
    return max(candidates, key=x.__getitem__)
fig, axs = plt.subplots(2)


def _action_values(s):
    """Return the one-step lookahead value of every legal stake in state ``s``.

    Element ``i`` of the returned list corresponds to a stake of ``a = i + 1``.
    With probability ``p`` the gambler moves to ``s + a``, otherwise to
    ``s - a``. Stakes range over ``1..min(s, goal - s)``.
    Reads the module-level ``p``, ``γ``, ``goal``, ``r`` and ``V``.
    """
    return [
        p * (r[s + a] + γ * V[s + a]) + (1 - p) * (r[s - a] + γ * V[s - a])
        for a in range(1, min(s, goal - s) + 1)
    ]


# Value Iteration
# Index `goal` is never updated, so its value stays 0; the payoff for
# reaching it comes from r[goal] instead.
V = [0] * (goal + 1)
r = [0] * goal + [1]  # Reward is 1 if reach goal, 0 otherwise
while True:
    Δ = 0
    for s in range(1, goal):
        v = V[s]
        # In-place update: later states in this sweep see the new V[s].
        V[s] = max(_action_values(s))
        Δ = max(Δ, abs(v - V[s]))
    axs[0].plot(V[:goal])  # one curve per sweep
    if Δ < θ:
        break

# Derive policy from value function.
# argmax returns a 0-based position into the stake list, and position 0 is a
# stake of 1, so add 1 to get the stake itself rather than its list index.
# Among exactly equal values argmax keeps the smallest stake. NOTE(review):
# values that differ only by decimal rounding are not treated as ties, which
# may still make the plot differ from the textbook figure.
states = range(1, goal)
π = [argmax(_action_values(s)) + 1 for s in states]

axs[0].set(ylabel='Value')
axs[0].set_title('Value Function')
# Plot against the real state numbers 1..goal-1, not list positions 0..goal-2.
axs[1].bar(states, π)
axs[1].set(xlabel='State', ylabel='Action')
axs[1].set_title('Policy')
plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment