Last active
March 6, 2026 12:00
-
-
Save datokrat/e0cfb7d56eedb1bc4f37cd71427371b1 to your computer and use it in GitHub Desktop.
simp set minimization tactic (LLM-generated)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| module | |
| public meta import Lean | |
| open Lean Elab Tactic Meta Simp Tactic.TryThis | |
| namespace MinimizeSimps | |
/-- Whether `stx` is a `simp`, `simp_all`, or `dsimp` tactic node whose optional `only`
slot (at index `simpOnlyPos`) is filled in. -/
private meta def isSimpOnlyNode (stx : Syntax) : Bool :=
  let simpLikeKinds : Array SyntaxNodeKind :=
    #[``Parser.Tactic.simp, ``Parser.Tactic.simpAll, ``Parser.Tactic.dsimp]
  let hasOnly := !stx[simpOnlyPos].isNone
  simpLikeKinds.contains stx.getKind && hasOnly
/-- Total number of `simp only` nodes in `stx`, nested occurrences included.
A node is counted before its children, so the count matches the DFS pre-order numbering
used by `getNthSimpOnly`. -/
private meta partial def countSimpOnlyNodes (stx : Syntax) : Nat :=
  let own := if isSimpOnlyNode stx then 1 else 0
  match stx with
  | .node _ _ children =>
    children.foldl (fun total child => total + countSimpOnlyNodes child) own
  | _ => own
/-- Return the `n`th (0-based) `simp only` node of `stx` in DFS pre-order, i.e. the same order
in which `countSimpOnlyNodes` counts them (a node precedes its children, and the children of a
`simp only` node are searched too). Returns `none` when `stx` has at most `n` such nodes. -/
private meta partial def getNthSimpOnly (stx : Syntax) (n : Nat) : Option Syntax :=
  (visit stx n).2
where
  -- Visit `node` after `toSkip` earlier matches still have to be skipped.
  -- Returns how many matches remain to be skipped and the node found, if any.
  visit (node : Syntax) (toSkip : Nat) : Nat × Option Syntax :=
    let hit := isSimpOnlyNode node
    if hit && toSkip == 0 then
      (0, some node)
    else
      let toSkip := if hit then toSkip - 1 else toSkip
      match node with
      | .node _ _ children => visitAll children toSkip
      | _ => (toSkip, none)
  -- Visit `children` left to right, stopping at the first hit.
  visitAll (children : Array Syntax) (toSkip : Nat) : Nat × Option Syntax := Id.run do
    let mut left := toSkip
    for child in children do
      match visit child left with
      | (_, some found) => return (0, some found)
      | (k, none) => left := k
    return (left, none)
/-- Replace the `n`th (0-based) `simp only` node of `stx` with `replacement`.

Nodes are numbered in DFS pre-order, exactly as in `countSimpOnlyNodes` and `getNthSimpOnly`:
a `simp only` node is numbered before its children, and the children of a non-target
`simp only` node are searched as well (a nested `simp only` can appear, for example, inside a
`by` term passed as a simp argument). If there is no `n`th node, `stx` is returned unchanged.

The traversal state is an `Option Nat`: `some k` means `k` more matches must be skipped,
and `none` means the replacement has been made, so the rest of the tree is copied verbatim. -/
private meta partial def replaceNthSimpOnly (stx : Syntax) (n : Nat) (replacement : Syntax) : Syntax :=
  (go stx (some n)).2
where
  go (stx : Syntax) (remaining? : Option Nat) : Option Nat × Syntax :=
    match remaining? with
    | none => (none, stx) -- Already replaced, done
    | some remaining =>
      let isMatch := isSimpOnlyNode stx
      if isMatch && remaining == 0 then
        (none, replacement)
      else
        -- Consume this match (if it is one), then keep searching *inside* it as well, so that
        -- nested `simp only` nodes get the same numbering as in `getNthSimpOnly`.
        let remaining := if isMatch then remaining - 1 else remaining
        match stx with
        | .node info kind args =>
          let (remaining', args') := args.foldl (fun (rem, acc) arg =>
            let (rem', arg') := go arg rem
            (rem', acc.push arg')) (some remaining, #[])
          (remaining', .node info kind args')
        | other => (some remaining, other)
/-- Check whether `tacSeq`, run from `savedState` with error recovery disabled, succeeds and
leaves no goals. Afterwards the tactic state (info trees included) is put back to what it was
on entry, whatever the outcome.

`withoutRecover` keeps `·` blocks from silently admitting goals they failed to close, which
would otherwise make a broken candidate look successful. An exception caught by the `catch`
counts as failure.

NOTE(review): each call re-runs the whole block; presumably this work is charged to the
enclosing declaration's heartbeat budget — confirm before using on expensive blocks. -/
private meta def tryTacticBlock (savedState : Tactic.SavedState) (tacSeq : Syntax) : TacticM Bool := do
  let resumeState ← Tactic.saveState
  savedState.restore (restoreInfo := true)
  let closedAll ←
    try
      withoutRecover <| evalTactic tacSeq
      pure (← getGoals).isEmpty
    catch _ =>
      pure false
  resumeState.restore (restoreInfo := true)
  pure closedAll
/--
Minimize the arguments of the `idx`th `simp only` call (`simpNode`) of `fullTacSeq`.

Each candidate argument list is spliced into `fullTacSeq` in place of the original call. The
whole block is then re-run from `savedState` via `tryTacticBlock`, and a candidate is accepted
only if the block still closes all goals.

The empty list is tried first. Otherwise arguments are removed greedily in a single pass from
last to first, keeping every removal that succeeds. The result is not guaranteed to be the
smallest possible set.

Returns the retained arguments; these are the original ones if nothing could be removed.
-/
private meta def minimizeSingleSimp (savedState : Tactic.SavedState) (fullTacSeq : Syntax)
    (simpNode : Syntax) (idx : Nat) : TacticM (Array Syntax) := do
  let params := getSimpParams ⟨simpNode⟩
  if params.isEmpty then return params
  -- Does the full block still close all goals when this call uses exactly `ps`?
  let worksWith (ps : Array Syntax) : TacticM Bool :=
    tryTacticBlock savedState
      (replaceNthSimpOnly fullTacSeq idx (setSimpParams ⟨simpNode⟩ ps).raw)
  -- First, try removing all lemmas at once
  if (← worksWith #[]) then
    return #[] -- None needed!
  -- With a single argument, the only removal candidate is `#[]`, which just failed;
  -- running the block again on it would be wasted work.
  if params.size == 1 then
    return params
  -- Greedy removal from back to front. Erasing index `i` does not shift the indices below it,
  -- so a downward scan visits every remaining argument exactly once.
  let mut currentParams := params
  let mut i := currentParams.size
  while 0 < i do
    i := i - 1
    let candidateParams := currentParams.eraseIdxIfInBounds i
    if (← worksWith candidateParams) then
      currentParams := candidateParams
  return currentParams
/-- The core implementation of `minimize_simps`.

* Runs `tacSeq` once from the current state, with error recovery disabled, and throws unless
  it closes every goal.
* Visits the `simp only` calls of `tacSeq` in DFS pre-order and greedily drops simp arguments
  from each one (`minimizeSingleSimp`), keeping the reductions already made to earlier calls.
* Replays the resulting block from the original state. If any argument was dropped, it adds
  a "Try this" suggestion anchored at `tk` that replaces the current ref's span. -/
meta def minimizeSimpsCore (tk : Syntax) (tacSeq : Syntax) : TacticM Unit := do
  let savedState ← Tactic.saveState
  -- Validate the block as written: it must close all goals without relying on recovery.
  withoutRecover <| evalTactic tacSeq
  unless (← getGoals).isEmpty do
    throwError "minimize_simps: the tactic block must close all goals"
  let numSimps := countSimpOnlyNodes tacSeq
  if numSimps == 0 then
    -- Nothing to minimize. The validation run already produced the final state, so keep it
    -- rather than restoring and running the whole block a second time.
    logWarningAt tk "minimize_simps: no `simp only` calls found in the tactic block"
    return
  let mut currentTacSeq := tacSeq
  let mut anyChanged := false
  for idx in [:numSimps] do
    -- `idx` indexes into `currentTacSeq`. Dropping an argument that itself contains a
    -- `simp only` can lower the count, in which case trailing indices are simply absent.
    let some simpNode := getNthSimpOnly currentTacSeq idx | continue
    let origParams := getSimpParams ⟨simpNode⟩
    let minParams ← minimizeSingleSimp savedState currentTacSeq simpNode idx
    if minParams.size < origParams.size then
      anyChanged := true
      let newSimp := (setSimpParams ⟨simpNode⟩ minParams).raw
      currentTacSeq := replaceNthSimpOnly currentTacSeq idx newSimp
  -- Replay the minimized block from the original state. `restoreInfo := true` (as in
  -- `tryTacticBlock`) also discards the info trees of the validation run of the unminimized
  -- block, so only the info of the block that is actually kept remains.
  savedState.restore (restoreInfo := true)
  evalTactic currentTacSeq
  -- Emit suggestion
  if anyChanged then
    addSuggestion tk (TSyntax.mk currentTacSeq : TSyntax `tactic) (origSpan? := ← getRef)
/--
`minimize_simps => tacs` runs the tactic block `tacs`, which must close all goals. It then
tries to drop arguments from every `simp only`, `simp_all only`, and `dsimp only` call in the
block, accepting a removal only if the whole block still closes all goals. If any argument
could be dropped, the minimized block is offered as a "Try this" suggestion.
-/
syntax (name := minimizeSimps) "minimize_simps " " => " tacticSeq : tactic
| end MinimizeSimps | |
-- Elaborate `minimize_simps => tacs` by handing the tactic sequence to `minimizeSimpsCore`,
-- using the whole `minimize_simps` invocation as the suggestion anchor.
open MinimizeSimps in
elab_rules : tactic
  | `(tactic| minimize_simps => $tacSeq) => do
    let invocation ← getRef
    minimizeSimpsCore invocation tacSeq
-- ============ Tests ============
set_option linter.unusedSimpArgs false

-- Test 1: Remove unused lemma from a terminal simp
example (n : Nat) (h : n = 0) : n = 0 := by
  minimize_simps =>
    simp only [h, Nat.add_zero]

-- Test 2: Keep only needed lemmas
example (n m : Nat) (h : n = 0) : n + m = m := by
  minimize_simps =>
    simp only [h, Nat.add_zero, Nat.zero_add]

-- Test 3: Multiple simp calls, each minimized in turn
example (n m : Nat) (h1 : n = 0) (h2 : m = 0) : n = 0 ∧ m = 0 := by
  minimize_simps =>
    constructor
    · simp only [h1, h2, Nat.add_zero]
    · simp only [h1, h2, Nat.add_zero]

-- Test 4: dsimp only - both lemmas removable since dsimp handles it definitionally
example : (fun x : Nat => x) 0 = 0 + 0 := by
  minimize_simps =>
    dsimp only [Nat.add_zero, Nat.zero_add]

-- Test 5: Terminal simp with an irrelevant lemma (there is no subsequent tactic).
-- The goal `n * n = 0` never contains a `0 + _` subterm, so `Nat.zero_add` can never fire
-- and is expected to be dropped.
example (n : Nat) (h : n = 0) : n * n = 0 := by
  minimize_simps =>
    simp only [h, Nat.zero_add, Nat.mul_zero]

-- Test 6: All lemmas needed (no suggestion expected)
example (a b : Nat) (h1 : a = 1) (h2 : b = 2) : a + b = 3 := by
  minimize_simps =>
    simp only [h1, h2]
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment