Skip to content

Instantly share code, notes, and snippets.

@KMJ-007
Last active March 8, 2026 13:39
Show Gist options
  • Select an option

  • Save KMJ-007/f55f0e2aad1df5d3227714998527ed56 to your computer and use it in GitHub Desktop.

Select an option

Save KMJ-007/f55f0e2aad1df5d3227714998527ed56 to your computer and use it in GitHub Desktop.
The problem goes like this: - There are 100 holes in a line and there is a bunny🐇 hiding in one of them. - You can only look in one hole at a time; every time you look, it jumps to an adjacent hole.
import random
import numpy as np
import statistics as stats
from tqdm import tqdm
# Number of holes in the line; every belief/visit list in this file has N entries.
N = 100
def guess(P, j):
    """Condition the belief on a failed look into hole ``j``.

    Args:
        P: Sequence of per-hole probabilities (the current belief).
        j: Index of the hole that was just inspected and found empty.

    Returns:
        A new list with the same length as ``P`` in which entry ``j`` is zero
        and the remaining entries are rescaled to sum to one.  If no
        probability mass is left outside ``j``, the other holes share the mass
        uniformly, so the function never divides by zero.
    """
    n = len(P)
    # Sum the other holes directly instead of ``sum(P) - P[j]``: the
    # subtraction can cancel to exactly zero when P[j] dominates the total.
    rest = sum(p for i, p in enumerate(P) if i != j)
    if rest <= 0:
        # Degenerate belief: fall back to "anywhere except j".
        share = 1 / (n - 1) if n > 1 else 0.0
        return [0.0 if i == j else share for i in range(n)]
    return [0.0 if i == j else p / rest for i, p in enumerate(P)]
def jumps(P):
    """Propagate the belief through one bunny jump.

    The transition mirrors ``move_hole``: a bunny in an interior hole moves
    left or right with probability 1/2 each, while a bunny in an end hole has
    only one neighbour and therefore moves there with probability 1.  The
    total probability mass is preserved (previously half of each end hole's
    mass was dropped, so the belief disagreed with the simulated bunny).

    Args:
        P: Sequence of per-hole probabilities before the jump.

    Returns:
        A new list of per-hole probabilities after the jump.  A line with
        fewer than two holes is returned as an unchanged copy because the
        bunny has nowhere to go.
    """
    n = len(P)
    if n < 2:
        return list(P)
    Pn = [0.0] * n
    for i, p in enumerate(P):
        if i == 0:
            # Left edge: the only legal move is to hole 1.
            Pn[1] += p
        elif i == n - 1:
            # Right edge: the only legal move is to hole n - 2.
            Pn[n - 2] += p
        else:
            Pn[i - 1] += 0.5 * p
            Pn[i + 1] += 0.5 * p
    return Pn
def argmax(l):
    """Return the index of the largest element of ``l``.

    Ties resolve to the lowest index; an empty sequence raises ``ValueError``
    (propagated from ``max``).
    """
    best_index, _ = max(enumerate(l), key=lambda pair: pair[1])
    return best_index
def move_hole(hole):
    """Return the bunny's next hole after one jump.

    At either end of the line there is a single neighbour, so the move is
    forced; elsewhere the bunny steps left or right with equal probability.
    """
    if hole == 0:
        return hole + 1
    if hole == N - 1:
        return hole - 1
    step = random.choice([-1, 1])
    return hole + step
def play(strategy_fn, use_bandit=False, use_visits=False, use_recency=False):
    """Play one game and return how many looks it took to find the bunny.

    Args:
        strategy_fn: Callable that picks the hole to inspect.  It is called as
            ``strategy_fn(P, samples, visits, t)`` when ``use_bandit`` is set,
            as ``strategy_fn(P, visits, t)`` when only ``use_visits`` is set,
            and as ``strategy_fn(P)`` otherwise.
        use_bandit: Draw one Beta(alpha, beta) sample per hole each turn and
            pass the samples to the strategy.
        use_visits: Pass the per-hole ``visits`` list and the turn number.
        use_recency: Store the turn of the last look in ``visits`` instead of
            a running count.

    Returns:
        The 1-based turn on which the chosen hole matched the bunny's hole.
    """
    belief = [1 / N] * N
    successes = [1] * N
    failures = [1] * N
    visits = [0] * N
    hole = random.randint(0, N - 1)
    turn = 0
    while True:
        turn += 1

        # Ask the strategy for a hole, using the call shape it expects.
        if use_bandit:
            samples = [np.random.beta(a, b) for a, b in zip(successes, failures)]
            pick = strategy_fn(belief, samples, visits, turn)
        elif use_visits:
            pick = strategy_fn(belief, visits, turn)
        else:
            pick = strategy_fn(belief)

        if pick == hole:
            if use_bandit:
                successes[pick] += 1
            return turn

        # Miss: record it, then update the belief and let the bunny jump.
        if use_bandit:
            failures[pick] += 1
        if use_recency:
            visits[pick] = turn
        elif use_visits:
            visits[pick] += 1
        belief = guess(belief, pick)
        belief = jumps(belief)
        hole = move_hole(hole)
def ucb_choice(P, visits, t, c=1.0):
    """Pick the hole with the highest belief plus a UCB-style bonus.

    Each hole scores ``P[i] + c * sqrt(log(t) / (visits[i] + 1))``; the
    ``+ 1`` keeps the bonus finite for holes whose visit entry is zero.
    """
    log_t = np.log(t)
    scores = [P[i] + c * np.sqrt(log_t / (visits[i] + 1)) for i in range(N)]
    return argmax(scores)
def epsilon_greedy(P, visits, t, eps=0.12):
    """Pick the most likely hole, or with probability ``eps`` sample from ``P``.

    ``visits`` and ``t`` are accepted for signature compatibility with the
    other strategies and are not used.
    """
    exploring = random.random() < eps
    if not exploring:
        return argmax(P)
    # Exploration draws a hole in proportion to the current belief.
    return random.choices(range(N), weights=P)[0]
def parity_sweeping(P, visits, t):
    """Pick the most likely hole among those of the currently active parity.

    The active parity (even/odd index) flips every 10 turns.  ``visits`` is
    accepted for signature compatibility and is not used.
    """
    wanted = (t // 10) % 2
    same_parity = filter(lambda i: i % 2 == wanted, range(N))
    return max(same_parity, key=lambda i: P[i])
def double_coverage(P, visits, t):
    """Pick the likelier hole of an adjacent pair that advances every 2 turns.

    The pair's left index is ``(t // 2) % (N - 1)``; ties go to the left
    hole.  ``visits`` is accepted for signature compatibility and is not used.
    """
    left = (t // 2) % (N - 1)
    right = left + 1
    if P[right] > P[left]:
        return right
    return left
def sweep_position(t):
    """Return the hole for turn ``t`` of a back-and-forth sweep.

    Even-numbered passes (``t // N`` even) run left to right, odd-numbered
    passes run right to left.
    """
    cycle, offset = divmod(t, N)
    going_right = cycle % 2 == 0
    return offset if going_right else N - 1 - offset
def smart_sweep(P, visits, t):
    """Pick the hole maximising belief weighted by how rarely it was visited.

    Each hole's belief is multiplied by ``(max_v - visits[i]) / max_v`` where
    ``max_v = max(visits) + 1``, so less-visited holes keep more of their
    belief.  ``t`` is accepted for signature compatibility and is not used.
    """
    ceiling = max(visits) + 1
    weighted = []
    for i in range(N):
        staleness = (ceiling - visits[i]) / ceiling
        weighted.append(P[i] * staleness)
    return argmax(weighted)
def recency_sweep(P, visits, t, window=50):
    """Pick the most likely hole not looked at within the last ``window`` turns.

    ``visits[i]`` is compared against the current turn ``t``; when every hole
    is inside the window, all holes are eligible.
    """
    stale = [i for i in range(N) if t - visits[i] > window]
    pool = stale if stale else list(range(N))
    return max(pool, key=lambda i: P[i])
# Registry of strategies to benchmark.
# Each value is (strategy_fn, use_bandit, use_visits, use_recency); the three
# flags are forwarded to play() and determine how strategy_fn is called.
strategies = {
    # Belief-only strategies: called as fn(P).
    "random": (lambda P: random.randint(0, N - 1), False, False, False),
    "argmax": (lambda P, v, t: argmax(P), False, True, False),
    "weighted_random": (
        lambda P: random.choices(range(N), weights=P)[0],
        False,
        False,
        False,
    ),
    # Bandit strategy: called as fn(P, samples, visits, t).
    "thompson": (lambda P, samples, v, t: argmax(samples), True, False, False),
    # Visit-aware strategies: called as fn(P, visits, t).
    "ucb_c1.5": (lambda P, v, t: ucb_choice(P, v, t, c=1.5), False, True, False),
    "epsilon_greedy": (
        lambda P, v, t: epsilon_greedy(P, v, t, eps=0.12),
        False,
        True,
        False,
    ),
    "parity_sweeping": (parity_sweeping, False, True, False),
    "double_coverage": (double_coverage, False, True, False),
    "sweep": (lambda P, v, t: sweep_position(t), False, True, False),
    "smart_sweep": (smart_sweep, False, True, False),
    # Recency mode: visits holds the turn of the last look at each hole.
    "recency_sweep": (
        lambda P, v, t: recency_sweep(P, v, t, window=50),
        False,
        True,
        True,
    ),
}
from concurrent.futures import ThreadPoolExecutor
def run_strategy(args):
    """Run one strategy ``n_runs`` times and summarise the turn counts.

    Args:
        args: Tuple ``(name, fn, use_bandit, use_visits, use_recency, n_runs)``.
            ``name`` labels the progress bar; the next four items are
            forwarded to ``play``; ``n_runs`` is the number of games.

    Returns:
        ``(name, summary)`` where ``summary`` maps ``"mean"``, ``"median"``,
        ``"stdev"``, ``"min"``, ``"max"`` and ``"score"`` (mean + stdev) to
        the corresponding statistic of the per-game turn counts.

    Raises:
        statistics.StatisticsError: If ``n_runs`` is 0 (no data to summarise).
    """
    name, fn, use_bandit, use_visits, use_recency, n_runs = args
    r = [
        play(fn, use_bandit, use_visits, use_recency)
        for _ in tqdm(range(n_runs), desc=name)
    ]
    # Compute each statistic once; the statistics module is pure Python.
    mean = stats.mean(r)
    # stdev() needs at least two samples; a single run has no spread.
    stdev = stats.stdev(r) if len(r) > 1 else 0.0
    return name, {
        "mean": mean,
        "median": stats.median(r),
        "stdev": stdev,
        "min": min(r),
        "max": max(r),
        "score": mean + stdev,
    }
# Number of games simulated per strategy.
n_runs = 1000

# One work item per registered strategy: (name, fn, flags..., n_runs).
strategy_args = [(name, *spec, n_runs) for name, spec in strategies.items()]

# executor.map yields results in submission order, so the table below keeps
# the registry's ordering.
with ThreadPoolExecutor() as executor:
    results = dict(executor.map(run_strategy, strategy_args))

# Summary table: header, rule, then one row per strategy.
print(
    f"{'Strategy':<20} {'Mean':>8} {'Median':>8} {'StdDev':>8} {'Min':>6} {'Max':>6} {'Score':>8}"
)
print("-" * 66)
for name, stats_dict in results.items():
    print(
        f"{name:<20} {stats_dict['mean']:>8.1f} {stats_dict['median']:>8.1f} "
        f"{stats_dict['stdev']:>8.1f} {stats_dict['min']:>6} {stats_dict['max']:>6} {stats_dict['score']:>8.1f}"
    )
@KMJ-007
Copy link
Author

KMJ-007 commented Mar 7, 2026

@KMJ-007
Copy link
Author

KMJ-007 commented Mar 7, 2026

Strategy Mean Median StdDev Min Max Score
random 100.1 72.0 96.2 1 914 196.2
argmax 89.5 52.0 106.1 1 761 195.6
weighted_random 99.0 64.0 103.6 1 779 202.7
thompson 96.9 64.5 93.9 1 508 190.8
ucb_c1.5 79.5 49.5 80.6 1 613 160.2
epsilon_greedy 101.3 55.0 128.4 1 1498 229.7
parity_sweeping 96.5 52.0 130.5 1 1551 227.0
double_coverage 113.8 105.5 85.5 1 665 199.4
sweep 100.4 98.0 57.1 1 198 157.6
smart_sweep 83.2 51.0 91.7 1 659 174.9
recency_sweep 88.5 48.0 113.7 1 1208 202.2

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment