Skip to content

Instantly share code, notes, and snippets.

@vacmar01
Created February 21, 2026 16:52
Show Gist options
  • Select an option

  • Save vacmar01/c677bc96c96224dce77b926184fef9f7 to your computer and use it in GitHub Desktop.

Select an option

Save vacmar01/c677bc96c96224dce77b926184fef9f7 to your computer and use it in GitHub Desktop.
Optimized Sudoku solver from GEPA blog post
# Fast Sudoku solver: bitmask constraints + MRV + forward-checking.
# Drop-in replacement: solve(puzzle_str) -> 81-char string.
#
# --- Precomputed board geometry (cells are indexed 0..80, row-major) ---
# Each unit is a list of the 9 cell indices it covers.
ROWS = [[row * 9 + col for col in range(9)] for row in range(9)]
COLS = [[row * 9 + col for row in range(9)] for col in range(9)]
BOXS = [
    [(top + dr) * 9 + left + dc for dr in range(3) for dc in range(3)]
    for top in range(0, 9, 3)
    for left in range(0, 9, 3)
]
UNITS = ROWS + COLS + BOXS  # 27 units: 9 rows, then 9 columns, then 9 boxes

# CELL_UNITS[i]: the unit lists containing cell i, in UNITS order
# (row, column, box). Entries are the same list objects held in UNITS.
CELL_UNITS = [[unit for unit in UNITS if cell in unit] for cell in range(81)]

# PEERS[i]: tuple of the 20 distinct cells sharing a unit with cell i
# (i itself excluded). The set is filled unit by unit, so the tuple order
# is the resulting set iteration order.
PEERS = []
for cell, cell_units in enumerate(CELL_UNITS):
    neighbours = set()
    for unit in cell_units:
        neighbours.update(unit)
    neighbours.discard(cell)
    PEERS.append(tuple(neighbours))

# Row / column / box number (each 0..8) of every cell.
R_IDX = [cell // 9 for cell in range(81)]
C_IDX = [cell % 9 for cell in range(81)]
B_IDX = [3 * (row // 3) + col // 3 for row, col in zip(R_IDX, C_IDX)]
# --- Bitmask helpers: digit d (1..9) is represented by bit d-1 ---
ALL = 0x1FF  # all nine digit bits set
# BIT[d] is the mask of digit d; BIT[0] is a 0 placeholder.
BIT = [0] + [1 << shift for shift in range(9)]
# POPC[m]: number of digits present in mask m, for every 9-bit mask.
POPC = [bin(mask).count("1") for mask in range(512)]
# LOWDIG[m]: the digit encoded by m when m has exactly one bit set, else 0.
LOWDIG = [0] * 512
for digit in range(1, 10):
    LOWDIG[BIT[digit]] = digit
# DIGITS[m]: the digits whose bits are set in m, in ascending order.
DIGITS = [[digit for digit in range(1, 10) if mask & BIT[digit]] for mask in range(512)]
# UNIT_POS is not read anywhere in the visible code; it is kept so the module
# still exports the same name and value. Despite the name, it does not map
# digits to positions: UNIT_POS[ui][p][0] is the cell index at position p of
# UNITS[ui], and every other entry is 0 (shape: units x 9 x 10).
UNIT_POS = [[[cell] + [0] * 9 for cell in unit] for unit in UNITS]

# POS_IN_UNIT[ui][cell] = position (0..8) of ``cell`` inside UNITS[ui];
# cells that are not in that unit map to 0. Not referenced by solve().
POS_IN_UNIT = [[0] * 81 for _ in range(len(UNITS))]
for ui, unit in enumerate(UNITS):
    positions = POS_IN_UNIT[ui]
    for p, cell in enumerate(unit):
        positions[cell] = p
def solve(puzzle_str: str) -> str:
    """Solve a Sudoku puzzle given as an 81-character row-major string.

    Characters '1'..'9' are clues; every other character (typically '.' or
    '0') marks an empty cell. Candidates are tracked as 9-bit masks; the
    search propagates naked and hidden singles, branches on the cell with the
    fewest candidates, and reverts changes through an undo trail.

    Returns the solved grid as an 81-character string of digits. When the
    input is not exactly 81 characters long, has conflicting clues, or has no
    solution, ``puzzle_str`` itself is returned unchanged.
    """
    if len(puzzle_str) != 81:
        # Longer input would index past the 81-entry tables (IndexError);
        # shorter input would leave the missing cells outside every constraint.
        return puzzle_str

    # Localize globals for speed
    PEERS_l = PEERS
    UNITS_l = UNITS
    R_l = R_IDX
    C_l = C_IDX
    B_l = B_IDX
    BIT_l = BIT
    ALL_l = ALL
    POPC_l = POPC
    LOWDIG_l = LOWDIG
    DIGITS_l = DIGITS

    # Mutable search state shared by the nested helpers below.
    grid = [0] * 81      # placed digit per cell, 0 = empty
    row_mask = [0] * 9   # BIT[d] is set when digit d is placed in that row
    col_mask = [0] * 9
    box_mask = [0] * 9
    empties = []         # cells empty in the input, in index order

    for i, ch in enumerate(puzzle_str):
        d = ord(ch) - 48
        if d < 1 or d > 9:
            # '.', '0' and any other non-digit character mean "empty"
            empties.append(i)
            continue
        r = R_l[i]
        c = C_l[i]
        b = B_l[i]
        bm = BIT_l[d]
        if (row_mask[r] | col_mask[c] | box_mask[b]) & bm:
            return puzzle_str  # duplicate clue in a row, column or box
        grid[i] = d
        row_mask[r] |= bm
        col_mask[c] |= bm
        box_mask[b] |= bm

    # Candidate mask per empty cell (0 for filled cells).
    cand = [0] * 81
    for i in empties:
        m = ALL_l & ~(row_mask[R_l[i]] | col_mask[C_l[i]] | box_mask[B_l[i]])
        if m == 0:
            return puzzle_str  # an empty cell already has no legal digit
        cand[i] = m

    # Trail entries are (typ, index, previous_value); undo() replays them backwards.
    T_CELL = 0
    T_CAND = 1
    T_RM = 2
    T_CM = 3
    T_BM = 4

    def assign(i: int, d: int, trail, singles=None) -> bool:
        """Place digit d in cell i and remove d from the candidates of its peers.

        Every mutation is recorded on ``trail``. If ``singles`` is a list,
        peers left with exactly one candidate are appended to it. Returns
        False on contradiction (d already used in one of i's units, or a peer
        left with no candidates); the partial changes remain on the trail.
        """
        r = R_l[i]
        c = C_l[i]
        b = B_l[i]
        bm = BIT_l[d]
        if (row_mask[r] | col_mask[c] | box_mask[b]) & bm:
            return False
        trail.append((T_CELL, i, grid[i]))
        grid[i] = d
        trail.append((T_RM, r, row_mask[r]))
        trail.append((T_CM, c, col_mask[c]))
        trail.append((T_BM, b, box_mask[b]))
        row_mask[r] |= bm
        col_mask[c] |= bm
        box_mask[b] |= bm
        if cand[i]:
            trail.append((T_CAND, i, cand[i]))
            cand[i] = 0
        for p in PEERS_l[i]:
            if grid[p]:
                continue
            pm = cand[p]
            if pm & bm:
                newm = pm & ~bm
                trail.append((T_CAND, p, pm))
                cand[p] = newm
                if newm == 0:
                    return False
                if singles is not None and not (newm & (newm - 1)):
                    singles.append(p)
        return True

    def hidden_singles(trail) -> bool:
        """Assign every digit that has exactly one possible cell in some unit.

        Returns False if an assignment leads to a contradiction.
        """
        for unit in UNITS_l:
            union = 0
            for i in unit:
                if not grid[i]:
                    union |= cand[i]
            if not union:
                continue
            for d in DIGITS_l[union]:
                bm = BIT_l[d]
                found = -1  # -1: no cell yet, -2: more than one cell
                for i in unit:
                    if grid[i] == 0 and (cand[i] & bm):
                        if found != -1:
                            found = -2
                            break
                        found = i
                if found >= 0:
                    if not assign(found, d, trail):
                        return False
        return True

    def seed_singles(q) -> bool:
        """Queue every empty cell with one candidate; False if any has none."""
        for i in empties:
            if grid[i] == 0:
                m = cand[i]
                if m == 0:
                    return False
                if not (m & (m - 1)):
                    q.append(i)
        return True

    def propagate(trail) -> bool:
        """Apply naked and hidden singles until nothing changes.

        Returns False if a contradiction is found.
        """
        q = []
        if not seed_singles(q):
            return False
        while True:
            while q:
                i = q.pop()
                if grid[i]:
                    continue
                m = cand[i]
                if m == 0:
                    return False
                if m & (m - 1):
                    continue  # no longer a single (should not normally happen)
                # m is a single bit, so LOWDIG yields its digit directly.
                if not assign(i, LOWDIG_l[m], trail, q):
                    return False
            # Hidden singles may create new naked singles through elimination.
            before = len(trail)
            if not hidden_singles(trail):
                return False
            if len(trail) == before:
                return True
            if not seed_singles(q):
                return False

    def undo(trail):
        """Restore the state recorded on ``trail``, newest change first."""
        for typ, a, v in reversed(trail):
            if typ == T_CELL:
                grid[a] = v
            elif typ == T_CAND:
                cand[a] = v
            elif typ == T_RM:
                row_mask[a] = v
            elif typ == T_CM:
                col_mask[a] = v
            else:  # T_BM
                box_mask[a] = v

    def choose_cell():
        """Return (cell, mask, count) for an empty cell with the fewest
        candidates, stopping early at two or fewer; cell is -1 if none remain.
        """
        best_i = -1
        best_n = 10
        best_m = 0
        for i in empties:
            if grid[i]:
                continue
            m = cand[i]
            n = POPC_l[m]
            if n < best_n:
                best_n = n
                best_i = i
                best_m = m
                if n <= 2:
                    break
        return best_i, best_m, best_n

    def dfs():
        """Depth-first search; True once every empty cell is filled."""
        i, m, n = choose_cell()
        if i == -1:
            return True
        # Least-constraining-ish: try digits with minimal immediate peer hits (cheap)
        ds = DIGITS_l[m]
        if n <= 2:
            order = ds
        else:
            scores = []
            peers = PEERS_l[i]
            for d in ds:
                bm = BIT_l[d]
                hit = 0
                for p in peers:
                    if grid[p] == 0 and (cand[p] & bm):
                        hit += 1
                scores.append((hit, d))
            scores.sort()
            order = [d for _, d in scores]
        for d in order:
            trail = []
            if assign(i, d, trail) and propagate(trail) and dfs():
                return True
            undo(trail)
        return False

    if not propagate([]):
        # Propagation makes only forced deductions, so a contradiction here
        # proves the puzzle unsolvable; a full search would be wasted work.
        return puzzle_str
    if not dfs():
        return puzzle_str
    return "".join(chr(48 + v) if v else "." for v in grid)
# Demo / smoke test: runs only when this file is executed directly.
if __name__ == "__main__":
    # Well-known benchmark puzzle together with its unique solution.
    demo_puzzle = "53..7....6..195....98....6.8...6...34..8.3..17...2...6.6....28....419..5....8..79"
    demo_solution = "534678912672195348198342567859761423426853791713924856961537284287419635345286179"
    demo_result = solve(demo_puzzle)
    for label, value in (
        ("Input", demo_puzzle),
        ("Output", demo_result),
        ("Expected", demo_solution),
        ("Correct", demo_result == demo_solution),
    ):
        print(f"{label}: {value}")

    # Harder puzzle: both '0' and '.' mark empty cells.
    hard_puzzle = ".00000010400000000020000000000050407008000300001090000300400200050100000000806000"
    hard_result = solve(hard_puzzle)
    hard_solved = len(hard_result) == 81 and not ({".", "0"} & set(hard_result))
    print(f"\nHard puzzle output: {hard_result}")
    print(f"Hard puzzle solved: {hard_solved}")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment