Created
February 21, 2026 16:52
-
-
Save vacmar01/c677bc96c96224dce77b926184fef9f7 to your computer and use it in GitHub Desktop.
Optimized Sudoku solver from GEPA blog post
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| # Fast Sudoku solver: bitmask constraints + MRV + forward-checking | |
| # Drop-in replacement: solve(puzzle_str)->81-char string | |
# --- Precomputed board geometry ---
# Cells are numbered 0..80 in row-major order: cell = row * 9 + col.

# The 27 units (9 rows, 9 columns, 9 3x3 boxes), each a list of 9 cells.
ROWS = [[row * 9 + col for col in range(9)] for row in range(9)]
COLS = [[row * 9 + col for row in range(9)] for col in range(9)]
# Boxes run left-to-right, then top-to-bottom; cells inside a box are row-major.
BOXS = [
    [row * 9 + col for row in range(top, top + 3) for col in range(left, left + 3)]
    for top in range(0, 9, 3)
    for left in range(0, 9, 3)
]
UNITS = ROWS + COLS + BOXS

# CELL_UNITS[cell] -> the units containing that cell, in UNITS order
# (its row, then its column, then its box).
CELL_UNITS = [[unit for unit in UNITS if cell in unit] for cell in range(81)]

# PEERS[cell] -> tuple of every other cell that shares a unit with `cell`.
PEERS = []
for cell in range(81):
    neighbours = set()
    for unit in CELL_UNITS[cell]:
        neighbours.update(unit)
    neighbours.discard(cell)
    PEERS.append(tuple(neighbours))

# Row, column and box number (each 0..8) of every cell.
R_IDX = [cell // 9 for cell in range(81)]
C_IDX = [cell % 9 for cell in range(81)]
B_IDX = [(R_IDX[cell] // 3) * 3 + C_IDX[cell] // 3 for cell in range(81)]
# --- Digit bitmask tables ---
# A candidate set is a 9-bit mask: bit (d - 1) set means digit d is allowed.
ALL = 0x1FF  # all nine digits

# BIT[d] -> mask holding only digit d (index 0 is an unused 0 placeholder).
BIT = [0] + [1 << (digit - 1) for digit in range(1, 10)]

# POPC[mask] -> how many digits the mask holds.
POPC = [bin(mask).count("1") for mask in range(512)]

# LOWDIG[mask] -> the digit of a single-bit mask; 0 for every other mask.
LOWDIG = [0] * 512
for digit in range(1, 10):
    LOWDIG[BIT[digit]] = digit

# DIGITS[mask] -> the digits present in the mask, in ascending order.
DIGITS = [
    [digit for digit in range(1, 10) if mask & BIT[digit]]
    for mask in range(512)
]
# --- Auxiliary unit tables ---
# solve() reads neither table. They remain module-level names only in case
# code outside this file imports them (NOTE(review): confirm, then delete).
#
# UNIT_POS[u][p] is a 10-element list whose slot 0 holds the cell at position
# p (0..8) of unit u; slots 1..9 are always 0. It is NOT a digit->position map.
UNIT_POS = [[[cell] + [0] * 9 for cell in unit] for unit in UNITS]

# POS_IN_UNIT[u][cell] -> position 0..8 of `cell` inside unit u
# (left at 0 for cells that do not belong to that unit).
POS_IN_UNIT = [[0] * 81 for _ in range(27)]
for unit_positions, unit in zip(POS_IN_UNIT, UNITS):
    for p, cell in enumerate(unit):
        unit_positions[cell] = p
def solve(puzzle_str: str) -> str:
    """Solve a 9x9 Sudoku puzzle.

    ``puzzle_str`` lists the 81 cells in row-major order: '1'..'9' are givens
    and any other character ('.', '0', ...) marks an empty cell.

    Returns the solved grid as 81 digit characters. On failure (wrong length,
    conflicting givens, or no solution) ``puzzle_str`` is returned unchanged.

    Method: row/column/box digit bitmasks plus a candidate mask per empty
    cell; constraint propagation (naked and hidden singles); then depth-first
    search that branches on the cell with the fewest candidates. Every state
    change is logged on a "trail" so a failed branch can be rolled back exactly.
    """
    # All lookup tables are indexed by cell 0..80. Any other length would
    # raise IndexError (too long) or return a grid with unfilled '.' cells
    # (too short), so treat it like any other unsolvable input.
    if len(puzzle_str) != 81:
        return puzzle_str

    # Localize globals for speed
    PEERS_l = PEERS
    UNITS_l = UNITS
    R_l = R_IDX
    C_l = C_IDX
    B_l = B_IDX
    BIT_l = BIT
    ALL_l = ALL
    POPC_l = POPC
    LOWDIG_l = LOWDIG
    DIGITS_l = DIGITS

    # State
    grid = [0] * 81      # placed digit per cell, 0 = empty
    row_mask = [0] * 9   # digits already placed in each row
    col_mask = [0] * 9   # ... in each column
    box_mask = [0] * 9   # ... in each box
    empties = []         # cells that were empty in the input
    for i, ch in enumerate(puzzle_str):
        d = ord(ch) - 48
        if not 1 <= d <= 9:
            # '.', '0' and any other non-digit character mean "empty".
            empties.append(i)
            continue
        r = R_l[i]
        c = C_l[i]
        b = B_l[i]
        bm = BIT_l[d]
        if (row_mask[r] | col_mask[c] | box_mask[b]) & bm:
            return puzzle_str  # duplicate given in a row/column/box
        grid[i] = d
        row_mask[r] |= bm
        col_mask[c] |= bm
        box_mask[b] |= bm

    # Candidate mask per cell; 0 for filled cells.
    cand = [0] * 81
    for i in empties:
        m = ALL_l & ~(row_mask[R_l[i]] | col_mask[C_l[i]] | box_mask[B_l[i]])
        if m == 0:
            return puzzle_str  # an empty cell with no legal digit
        cand[i] = m

    # Trail encoding: (typ, a, v) = "restore slot a of the structure named by
    # typ to value v"; undo() replays entries in reverse.
    T_CELL = 0
    T_CAND = 1
    T_RM = 2
    T_CM = 3
    T_BM = 4

    def assign(i: int, d: int, trail: list, queue: list | None = None) -> bool:
        """Place digit d in cell i and remove d from the candidates of its peers.

        Every mutation is logged on ``trail``. When ``queue`` is given, peers
        left with exactly one candidate are appended to it. Returns False on a
        contradiction: d is already used in the cell's row/column/box, or a
        peer is left with no candidates.
        """
        r = R_l[i]
        c = C_l[i]
        b = B_l[i]
        bm = BIT_l[d]
        if (row_mask[r] | col_mask[c] | box_mask[b]) & bm:
            return False
        trail.append((T_CELL, i, grid[i]))
        grid[i] = d
        trail.append((T_RM, r, row_mask[r]))
        trail.append((T_CM, c, col_mask[c]))
        trail.append((T_BM, b, box_mask[b]))
        row_mask[r] |= bm
        col_mask[c] |= bm
        box_mask[b] |= bm
        if cand[i]:
            trail.append((T_CAND, i, cand[i]))
            cand[i] = 0
        for p in PEERS_l[i]:
            if grid[p]:
                continue
            pm = cand[p]
            if pm & bm:
                newm = pm & ~bm
                trail.append((T_CAND, p, pm))
                cand[p] = newm
                if newm == 0:
                    return False
                if queue is not None and (newm & (newm - 1)) == 0:
                    queue.append(p)
        return True

    def queue_singles(q: list) -> bool:
        """Append every empty cell that has exactly one candidate to q.

        Returns False if some empty cell has no candidates left.
        """
        for i in empties:
            if grid[i] == 0:
                m = cand[i]
                if m == 0:
                    return False
                if (m & (m - 1)) == 0:
                    q.append(i)
        return True

    def hidden_singles(trail: list) -> bool:
        """Assign every digit that fits in only one empty cell of some unit.

        Returns False if an assignment causes a contradiction.
        """
        for unit in UNITS_l:
            # Union of candidates over the unit's empty cells, so only digits
            # that can still go somewhere in the unit are examined.
            union = 0
            for i in unit:
                if not grid[i]:
                    union |= cand[i]
            if not union:
                continue
            for d in DIGITS_l[union]:
                bm = BIT_l[d]
                found = -1  # -1: no cell yet; -2: more than one cell
                for i in unit:
                    if grid[i] == 0 and (cand[i] & bm):
                        if found != -1:
                            found = -2
                            break
                        found = i
                if found >= 0:
                    if not assign(found, d, trail):
                        return False
        return True

    def propagate(trail: list) -> bool:
        """Apply naked and hidden singles until neither makes progress.

        Returns False if a contradiction is found. The caller undoes ``trail``.
        """
        q = []
        if not queue_singles(q):
            return False
        while True:
            while q:
                i = q.pop()
                if grid[i]:
                    continue  # already filled via an earlier queue entry
                m = cand[i]
                if m == 0:
                    return False
                if m & (m - 1):
                    continue  # defensive: not a single candidate
                if not assign(i, LOWDIG_l[m], trail, q):
                    return False
            # A hidden single that assigns something can create new naked
            # singles through elimination, so rescan and keep going.
            before = len(trail)
            if not hidden_singles(trail):
                return False
            if len(trail) == before:
                return True
            if not queue_singles(q):
                return False

    def undo(trail: list) -> None:
        """Revert every mutation recorded on ``trail``, newest first."""
        for typ, a, v in reversed(trail):
            if typ == T_CELL:
                grid[a] = v
            elif typ == T_CAND:
                cand[a] = v
            elif typ == T_RM:
                row_mask[a] = v
            elif typ == T_CM:
                col_mask[a] = v
            else:  # T_BM
                box_mask[a] = v

    # Propagation only makes forced deductions, so a contradiction found before
    # any guess proves there is no solution. Searching anyway would just spend
    # time exhausting an empty tree.
    if not propagate([]):
        return puzzle_str

    def choose_cell():
        """Return (cell, mask, count) for an empty cell with few candidates.

        Stops early at the first cell with <= 2 candidates. The cell is -1 when
        every cell is filled.
        """
        best_i = -1
        best_n = 10
        best_m = 0
        for i in empties:
            if grid[i]:
                continue
            m = cand[i]
            n = POPC_l[m]
            if n < best_n:
                best_n = n
                best_i = i
                best_m = m
                if n <= 2:
                    break
        return best_i, best_m, best_n

    def dfs() -> bool:
        """Depth-first search. Leaves the solved state in place on success."""
        i, m, n = choose_cell()
        if i == -1:
            return True
        # Least-constraining-ish: try digits with minimal immediate peer hits (cheap)
        ds = DIGITS_l[m]
        if n <= 2:
            order = ds
        else:
            scores = []
            peers = PEERS_l[i]
            for d in ds:
                bm = BIT_l[d]
                hit = 0
                for p in peers:
                    if grid[p] == 0 and (cand[p] & bm):
                        hit += 1
                scores.append((hit, d))
            scores.sort()
            order = [d for _, d in scores]
        for d in order:
            trail = []
            if assign(i, d, trail) and propagate(trail) and dfs():
                return True
            undo(trail)
        return False

    if not dfs():
        return puzzle_str
    return "".join(chr(48 + v) if v else '.' for v in grid)
# Demo / smoke test for solve().
if __name__ == "__main__":
    # Standard benchmark puzzle with a known solution.
    puzzle = "53..7....6..195....98....6.8...6...34..8.3..17...2...6.6....28....419..5....8..79"
    expected = "534678912672195348198342567859761423426853791713924856961537284287419635345286179"
    result = solve(puzzle)
    report = (
        f"Input: {puzzle}",
        f"Output: {result}",
        f"Expected: {expected}",
        f"Correct: {result == expected}",
    )
    for line in report:
        print(line)

    # Harder puzzle ('0' and '.' both mark empty cells). No expected answer is
    # given, so only check that a fully filled 81-character grid came back.
    puzzle2 = ".00000010400000000020000000000050407008000300001090000300400200050100000000806000"
    result2 = solve(puzzle2)
    fully_filled = len(result2) == 81 and '.' not in result2 and '0' not in result2
    print(f"\nHard puzzle output: {result2}")
    print(f"Hard puzzle solved: {fully_filled}")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment