Created
August 17, 2022 07:21
-
-
Save hiszpanski/b3ec7f20e7d7326342ad69e86e973b6d to your computer and use it in GitHub Desktop.
Sutton & Barto Gambler's Problem
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| from decimal import * | |
| import matplotlib.pyplot as plt | |
# NOTE: the policy this script derives may not match the textbook's figure.
# Likely causes are limited numerical precision and states where several
# stakes are (near-)equally good, so tie-breaking decides which is shown.

p = Decimal("0.4")  # probability that a stake is won (weights the s + a outcome)
θ = Decimal("1E-50")  # sweeping stops once the largest change in V is below this
γ = Decimal("1")  # discount factor; 1 means future value is not discounted
goal = 100  # capital whose arrival earns the reward of 1 and ends the game
def argmax(x):
    """Return the index of the largest element of the sequence *x*.

    When several elements share the maximum, the lowest index is returned.
    Raises ValueError when *x* is empty.
    """
    indices = range(len(x))
    # max() keeps the first maximal candidate it sees, so ties resolve to
    # the earliest index.
    return max(indices, key=lambda k: x[k])
fig, axs = plt.subplots(2)

# Value iteration: sweep the non-terminal capitals 1..goal-1, updating V in
# place, until no entry changes by θ or more during a sweep. V[0] and
# V[goal] are never updated and stay 0; the payoff for reaching the goal
# comes from r[goal] instead.
V = [0] * (goal + 1)
r = [0] * goal + [1]  # Reward is 1 only when the successor state is `goal`
converged = False
while not converged:
    Δ = 0
    for s in range(1, goal):
        previous = V[s]
        stakes = range(1, min(s, goal - s) + 1)
        V[s] = max(
            p * (r[s + a] + γ * V[s + a]) + (1 - p) * (r[s - a] + γ * V[s - a])
            for a in stakes
        )
        Δ = max(Δ, abs(previous - V[s]))
    # Draw one curve per sweep so successive estimates are visible.
    axs[0].plot(V[:goal])
    converged = Δ < θ
# Derive a greedy policy from the converged value function.
# For capital s the candidate stakes are 1..min(s, goal - s). argmax()
# returns a 0-based position within that list, so adding 1 recovers the
# stake itself. If several stakes share the maximum value, argmax keeps the
# first, i.e. the smallest stake. Stakes that differ only by rounding are
# decided by that rounding.
# π[s - 1] is the chosen stake for capital s, for s = 1 .. goal - 1.
π = [
    1 + argmax([
        p * (r[s + a] + γ * V[s + a]) + (1 - p) * (r[s - a] + γ * V[s - a])
        for a in range(1, min(s, goal - s) + 1)
    ])
    for s in range(1, goal)
]
axs[0].set(ylabel='Value')
axs[0].set_title('Value Function')
# Plot against the actual capitals 1..goal-1 rather than list positions
# 0..goal-2, so each bar sits above the state it belongs to.
axs[1].bar(range(1, goal), π)
axs[1].set(xlabel='State', ylabel='Action')
axs[1].set_title('Policy')
plt.show()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment