Last active
March 8, 2026 13:39
-
-
Save KMJ-007/f55f0e2aad1df5d3227714998527ed56 to your computer and use it in GitHub Desktop.
The problem goes like this: - There are 100 holes in a line and there is a bunny🐇 hiding in one of them. - You can only look in one hole at a time; every time you look, it jumps to an adjacent hole.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import random | |
| import numpy as np | |
| import statistics as stats | |
| from tqdm import tqdm | |
# Number of holes in the line; the belief vectors and strategies below are sized by it.
N = 100
def guess(P, j):
    """Update the belief after looking in hole ``j`` and not finding the bunny.

    Args:
        P: Probability that the bunny is in each hole, one entry per hole.
        j: Index of the hole that was just checked and found empty.

    Returns:
        A new list with entry ``j`` set to 0 and every other entry divided by
        the probability mass that was outside ``j``. ``P`` is not modified.
    """
    n = len(P)
    s = sum(P) - P[j]
    if s <= 0:
        # All belief sat on the hole that was just ruled out, so there is
        # nothing left to rescale (dividing would raise ZeroDivisionError).
        # Restart from an even spread over the remaining holes.
        fill = 1 / (n - 1) if n > 1 else 0
        return [0 if i == j else fill for i in range(n)]
    return [0 if i == j else p / s for i, p in enumerate(P)]
def jumps(P):
    """Propagate the belief through one bunny jump.

    The transition mirrors ``move_hole``: from an interior hole the bunny
    goes left or right with probability 0.5 each, while from either end hole
    it has a single neighbour and therefore always jumps inward. Total
    probability mass is preserved.

    Args:
        P: Probability that the bunny is in each hole before the jump.

    Returns:
        A new list with the probabilities after the jump. ``P`` is not
        modified.
    """
    n = len(P)
    if n < 2:
        # With fewer than two holes there is nowhere to jump to.
        return list(P)
    last = n - 1
    Pn = [0.0] * n
    for i, p in enumerate(P):
        if i == 0:
            # Left end: the only legal move is to hole 1.
            Pn[1] += p
        elif i == last:
            # Right end: the only legal move is to hole n - 2.
            Pn[last - 1] += p
        else:
            Pn[i - 1] += 0.5 * p
            Pn[i + 1] += 0.5 * p
    return Pn
def argmax(l):
    """Return the index of the largest element of ``l``.

    When several elements tie for the maximum, the lowest index is returned.
    """
    positions = range(len(l))
    return max(positions, key=l.__getitem__)
def move_hole(hole):
    """Return the hole the bunny occupies after one jump from ``hole``.

    At hole 0 the bunny moves right, at hole ``N - 1`` it moves left, and
    anywhere else it moves one step left or right chosen at random.
    """
    if hole == 0:
        step = 1
    elif hole == N - 1:
        step = -1
    else:
        step = random.choice([-1, 1])
    return hole + step
def play(strategy_fn, use_bandit=False, use_visits=False, use_recency=False):
    """Simulate one game and return how many looks it took to find the bunny.

    Args:
        strategy_fn: Callable that picks the hole to look in. It is called as
            ``strategy_fn(P, samples, visits, t)`` when ``use_bandit`` is set,
            as ``strategy_fn(P, visits, t)`` when only ``use_visits`` is set,
            and as ``strategy_fn(P)`` otherwise.
        use_bandit: Keep per-hole Beta(alpha, beta) counters and pass one
            sample per hole to the strategy.
        use_visits: Pass the ``visits`` list and the turn number to the
            strategy; after a miss the checked hole's entry is incremented.
        use_recency: After a miss, store the turn number in the checked
            hole's ``visits`` entry instead of incrementing it.

    Returns:
        The 1-based turn on which the chosen hole matched the bunny's hole.
    """
    belief = [1 / N] * N
    hits = [1] * N
    misses = [1] * N
    visits = [0] * N
    bunny = random.randint(0, N - 1)
    t = 0
    while True:
        t += 1

        # Ask the strategy for a hole, using the call shape it expects.
        if use_bandit:
            draws = [np.random.beta(a, b) for a, b in zip(hits, misses)]
            pick = strategy_fn(belief, draws, visits, t)
        elif use_visits:
            pick = strategy_fn(belief, visits, t)
        else:
            pick = strategy_fn(belief)

        if pick == bunny:
            if use_bandit:
                hits[pick] += 1
            return t

        # Miss: record it, update the belief, then let the bunny jump.
        if use_bandit:
            misses[pick] += 1
        if use_recency:
            visits[pick] = t
        elif use_visits:
            visits[pick] += 1
        belief = jumps(guess(belief, pick))
        bunny = move_hole(bunny)
def ucb_choice(P, visits, t, c=1.0):
    """Pick the hole with the highest belief plus an exploration bonus.

    Each hole scores ``P[i] + c * sqrt(log(t) / (visits[i] + 1))``, so holes
    with a smaller ``visits`` entry receive a larger bonus.

    Args:
        P: Current belief, one probability per hole.
        visits: Per-hole visit values maintained by the caller.
        t: Current turn number.
        c: Weight of the exploration bonus.

    Returns:
        Index of the highest-scoring hole.
    """
    log_t = np.log(t)
    scores = [P[i] + c * np.sqrt(log_t / (visits[i] + 1)) for i in range(N)]
    return argmax(scores)
def epsilon_greedy(P, visits, t, eps=0.12):
    """Pick the most likely hole, or with probability ``eps`` sample from ``P``.

    ``visits`` and ``t`` are accepted for signature compatibility with the
    other strategies and are not used.
    """
    explore = random.random() < eps
    if not explore:
        return argmax(P)
    return random.choices(range(N), weights=P)[0]
def parity_sweeping(P, visits, t):
    """Pick the most likely hole among those of the currently targeted parity.

    The targeted parity is ``(t // 10) % 2``, so it flips every 10 turns.
    ``visits`` is not used.
    """
    parity = (t // 10) % 2
    # range(parity, N, 2) lists, in ascending order, the holes with i % 2 == parity.
    return max(range(parity, N, 2), key=lambda i: P[i])
def double_coverage(P, visits, t):
    """Pick the more likely hole of a sliding pair of adjacent holes.

    The pair starts at ``(t // 2) % (N - 1)`` and so advances every two
    turns. On a tie the left hole is chosen. ``visits`` is not used.
    """
    left = (t // 2) % (N - 1)
    right = left + 1
    return right if P[right] > P[left] else left
def sweep_position(t):
    """Return the hole to check on turn ``t`` of a back-and-forth sweep.

    Even-numbered blocks of ``N`` turns walk left to right; odd-numbered
    blocks walk right to left.
    """
    cycle, offset = divmod(t, N)
    return offset if cycle % 2 == 0 else N - 1 - offset
def smart_sweep(P, visits, t):
    """Pick the hole maximising belief weighted by a staleness factor.

    The staleness factor of hole ``i`` is ``(m - visits[i]) / m`` with
    ``m = max(visits) + 1``, so holes with larger ``visits`` entries are
    discounted more. ``t`` is not used.
    """
    ceiling = max(visits) + 1
    weighted = [P[i] * ((ceiling - visits[i]) / ceiling) for i in range(N)]
    return argmax(weighted)
def recency_sweep(P, visits, t, window=50):
    """Pick the most likely hole among those satisfying ``t - visits[i] > window``.

    If no hole satisfies the condition, every hole is a candidate.
    """
    stale = [i for i in range(N) if t - visits[i] > window]
    pool = stale if stale else range(N)
    return max(pool, key=lambda i: P[i])
# Strategy registry: name -> (strategy_fn, use_bandit, use_visits, use_recency).
# The three flags are forwarded to play() and determine which arguments
# strategy_fn is called with.
strategies = {
    "random": (lambda P: random.randint(0, N - 1), False, False, False),
    "argmax": (lambda P, v, t: argmax(P), False, True, False),
    "weighted_random": (
        lambda P: random.choices(range(N), weights=P)[0],
        False, False, False,
    ),
    "thompson": (lambda P, samples, v, t: argmax(samples), True, False, False),
    "ucb_c1.5": (lambda P, v, t: ucb_choice(P, v, t, c=1.5), False, True, False),
    "epsilon_greedy": (
        lambda P, v, t: epsilon_greedy(P, v, t, eps=0.12),
        False, True, False,
    ),
    # These already have the (P, visits, t) signature, so no wrapper is needed.
    "parity_sweeping": (parity_sweeping, False, True, False),
    "double_coverage": (double_coverage, False, True, False),
    "sweep": (lambda P, v, t: sweep_position(t), False, True, False),
    "smart_sweep": (smart_sweep, False, True, False),
    "recency_sweep": (
        lambda P, v, t: recency_sweep(P, v, t, window=50),
        False, True, True,
    ),
}
| from concurrent.futures import ThreadPoolExecutor | |
def run_strategy(args):
    """Play one strategy repeatedly and summarise the number of looks needed.

    Args:
        args: Tuple ``(name, fn, use_bandit, use_visits, use_recency, n_runs)``.
            ``fn`` and the three flags are forwarded to ``play``; ``n_runs``
            is the number of games to simulate.

    Returns:
        ``(name, summary)`` where ``summary`` has the keys ``mean``,
        ``median``, ``stdev``, ``min``, ``max`` and ``score`` (mean + stdev).

    Raises:
        statistics.StatisticsError: If ``n_runs`` is 0, since there is no
            data to summarise.
    """
    name, fn, use_bandit, use_visits, use_recency, n_runs = args
    r = [
        play(fn, use_bandit, use_visits, use_recency)
        for _ in tqdm(range(n_runs), desc=name)
    ]
    # Compute each statistic once; mean and stdev both feed into the score.
    mean = stats.mean(r)
    # The sample standard deviation needs at least two data points; report
    # 0.0 for a single run instead of letting stats.stdev raise.
    stdev = stats.stdev(r) if len(r) > 1 else 0.0
    return name, {
        "mean": mean,
        "median": stats.median(r),
        "stdev": stdev,
        "min": min(r),
        "max": max(r),
        "score": mean + stdev,
    }
# Number of simulated games per strategy.
n_runs = 1000
strategy_args = [
    (name, fn, use_bandit, use_visits, use_recency, n_runs)
    for name, (fn, use_bandit, use_visits, use_recency) in strategies.items()
]

# executor.map yields results in input order, so the report keeps the order
# of the strategies dict.
# NOTE(review): the simulation looks like CPU-bound Python, so threads may
# not give real parallelism under the GIL — confirm before relying on timing.
with ThreadPoolExecutor() as executor:
    results = dict(executor.map(run_strategy, strategy_args))

header = (
    f"{'Strategy':<20} {'Mean':>8} {'Median':>8} {'StdDev':>8} {'Min':>6} {'Max':>6} {'Score':>8}"
)
print(header)
# Size the rule from the header so the two cannot drift apart.
print("-" * len(header))
for name, stats_dict in results.items():
    print(
        f"{name:<20} {stats_dict['mean']:>8.1f} {stats_dict['median']:>8.1f} "
        f"{stats_dict['stdev']:>8.1f} {stats_dict['min']:>6} {stats_dict['max']:>6} {stats_dict['score']:>8.1f}"
    )
Author
Author
| Strategy | Mean | Median | StdDev | Min | Max | Score |
|---|---|---|---|---|---|---|
| random | 100.1 | 72.0 | 96.2 | 1 | 914 | 196.2 |
| argmax | 89.5 | 52.0 | 106.1 | 1 | 761 | 195.6 |
| weighted_random | 99.0 | 64.0 | 103.6 | 1 | 779 | 202.7 |
| thompson | 96.9 | 64.5 | 93.9 | 1 | 508 | 190.8 |
| ucb_c1.5 | 79.5 | 49.5 | 80.6 | 1 | 613 | 160.2 |
| epsilon_greedy | 101.3 | 55.0 | 128.4 | 1 | 1498 | 229.7 |
| parity_sweeping | 96.5 | 52.0 | 130.5 | 1 | 1551 | 227.0 |
| double_coverage | 113.8 | 105.5 | 85.5 | 1 | 665 | 199.4 |
| sweep | 100.4 | 98.0 | 57.1 | 1 | 198 | 157.6 |
| smart_sweep | 83.2 | 51.0 | 91.7 | 1 | 659 | 174.9 |
| recency_sweep | 88.5 | 48.0 | 113.7 | 1 | 1208 | 202.2 |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
more: https://www.karanjanthe.me/posts/bunny-hopping/