Skip to content

Instantly share code, notes, and snippets.

@datokrat
Last active March 6, 2026 12:00
Show Gist options
  • Select an option

  • Save datokrat/e0cfb7d56eedb1bc4f37cd71427371b1 to your computer and use it in GitHub Desktop.

Select an option

Save datokrat/e0cfb7d56eedb1bc4f37cd71427371b1 to your computer and use it in GitHub Desktop.
simp set minimization tactic (LLM-generated)
module
public meta import Lean
open Lean Elab Tactic Meta Simp Tactic.TryThis
namespace MinimizeSimps
/-- Recognize a `simp only`-style call: the node's kind is `simp`, `simp_all`, or `dsimp`,
and its `only` modifier slot (child `simpOnlyPos`) is present. -/
private meta def isSimpOnlyNode (stx : Syntax) : Bool :=
  let simpKinds := [``Parser.Tactic.simp, ``Parser.Tactic.simpAll, ``Parser.Tactic.dsimp]
  simpKinds.contains stx.getKind && !stx[simpOnlyPos].isNone
/-- Number of `simp only` nodes (as recognized by `isSimpOnlyNode`) anywhere in `stx`,
including the root and nodes nested inside other `simp only` nodes. -/
private meta partial def countSimpOnlyNodes (stx : Syntax) : Nat :=
  let here := if isSimpOnlyNode stx then 1 else 0
  match stx with
  | .node _ _ children => children.foldl (fun acc child => acc + countSimpOnlyNodes child) here
  | _ => here
/-- Return the `n`th (0-based) `simp only` node of `stx` in pre-order DFS, or `none` if
`stx` contains `n` or fewer such nodes. Children of a `simp only` node are searched too. -/
private meta partial def getNthSimpOnly (stx : Syntax) (n : Nat) : Option Syntax :=
  match visit stx n with
  | .inr found => some found
  | .inl _ => none
where
  -- Pre-order walk. `.inl k`: target not found yet, `k` more matches still to skip.
  -- `.inr s`: `s` is the target node.
  visit (stx : Syntax) (skip : Nat) : Sum Nat Syntax :=
    let isHit := isSimpOnlyNode stx
    if isHit && skip == 0 then
      .inr stx
    else
      let skip := if isHit then skip - 1 else skip
      match stx with
      | .node _ _ children =>
        children.foldl (fun acc child =>
          match acc with
          | .inl k => visit child k
          | .inr _ => acc) (Sum.inl skip : Sum Nat Syntax)
      | _ => .inl skip
/-- Replace the `n`th (0-based) `simp only` node of `stx`, in pre-order DFS, with
`replacement`. If there is no such node, `stx` is returned unchanged.

The traversal order matches `countSimpOnlyNodes` and `getNthSimpOnly`: a `simp only` node
counts as one match, and its children are searched as well. That way the same index
always refers to the same node in all three functions.

Internally `Option Nat` tracks progress: `some k` means still looking, with `k` more
matches to skip; `none` means the replacement has already happened. -/
private meta partial def replaceNthSimpOnly (stx : Syntax) (n : Nat) (replacement : Syntax) : Syntax :=
  (go stx (some n)).2
where
  go (stx : Syntax) (remaining? : Option Nat) : Option Nat × Syntax :=
    match remaining? with
    | none => (none, stx) -- Already replaced, done
    | some remaining =>
      let isSimp := isSimpOnlyNode stx
      if isSimp && remaining == 0 then
        (none, replacement)
      else
        -- A non-target `simp only` node still consumes one index, but we must keep
        -- descending into it: its arguments may contain nested `by simp only ...` terms,
        -- which `countSimpOnlyNodes` / `getNthSimpOnly` also count.
        let remaining := if isSimp then remaining - 1 else remaining
        match stx with
        | .node info kind args =>
          let (remaining', args') := args.foldl (fun (rem, acc) arg =>
            let (rem', arg') := go arg rem
            (rem', acc.push arg')) (some remaining, #[])
          (remaining', .node info kind args')
        | other => (some remaining, other)
/-- Run `tacSeq` starting from `savedState` and report whether it succeeds and leaves no
goals. Error recovery is disabled (`withoutRecover`), so that e.g. `·` blocks cannot
silently admit failed goals. Any exception counts as failure. Whatever the outcome, the
caller's state (including info trees) is restored before returning. -/
private meta def tryTacticBlock (savedState : Tactic.SavedState) (tacSeq : Syntax) : TacticM Bool := do
  let backup ← Tactic.saveState
  savedState.restore (restoreInfo := true)
  let closed ←
    try
      withoutRecover <| evalTactic tacSeq
      pure (← getGoals).isEmpty
    catch _ =>
      pure false
  backup.restore (restoreInfo := true)
  return closed
/--
Shrink the lemma list of the `idx`th `simp only` call (`simpNode`) inside `fullTacSeq`.
Each candidate list is checked by re-running the whole block from `savedState` via
`tryTacticBlock`.

The function first tries the empty list. If that fails, it greedily drops lemmas one at a
time from last to first, keeping every removal after which the block still closes all
goals. It returns the resulting (locally minimal) parameter array.
-/
private meta def minimizeSingleSimp (savedState : Tactic.SavedState) (fullTacSeq : Syntax)
    (simpNode : Syntax) (idx : Nat) : TacticM (Array Syntax) := do
  let params := getSimpParams ⟨simpNode⟩
  if params.isEmpty then return params
  -- Does the full block still close everything when this call uses exactly `ps`?
  let works (ps : Array Syntax) : TacticM Bool :=
    tryTacticBlock savedState
      (replaceNthSimpOnly fullTacSeq idx (setSimpParams ⟨simpNode⟩ ps).raw)
  if (← works #[]) then return #[]
  let mut kept := params
  -- Walk indices from the back, so a successful erase never shifts a not-yet-visited index.
  for j in (List.range params.size).reverse do
    let trial := kept.eraseIdxIfInBounds j
    if (← works trial) then
      kept := trial
  return kept
/-- The core implementation of `minimize_simps`.

`tk` anchors the warning and the "Try this" suggestion; `tacSeq` is the user's block.

1. The block is run once with error recovery disabled. If goals remain afterwards, an
   error is thrown.
2. If it contains no `simp only` call, a warning is logged and the state of that
   successful run is kept as-is.
3. Otherwise every `simp only` call, in DFS order, has its lemma list shrunk by
   `minimizeSingleSimp`.
4. The minimized block is then executed for real, and a suggestion is emitted if any
   lemma list got smaller. -/
meta def minimizeSimpsCore (tk : Syntax) (tacSeq : Syntax) : TacticM Unit := do
  let savedState ← Tactic.saveState
  -- Validation run: the block must succeed and close every goal on its own.
  withoutRecover <| evalTactic tacSeq
  unless (← getGoals).isEmpty do
    throwError "minimize_simps: the tactic block must close all goals"
  let numSimps := countSimpOnlyNodes tacSeq
  if numSimps == 0 then
    -- Nothing to minimize: the successful validation run already is the final result,
    -- so there is no need to roll back and elaborate the whole block a second time.
    logWarningAt tk "minimize_simps: no `simp only` calls found in the tactic block"
    return
  -- Roll back the validation run, including its info trees, so that only the final run
  -- below contributes hover/goal information (otherwise it would be recorded twice).
  savedState.restore (restoreInfo := true)
  let mut currentTacSeq := tacSeq
  let mut anyChanged := false
  for idx in [:numSimps] do
    let some simpNode := getNthSimpOnly currentTacSeq idx | continue
    let origParams := getSimpParams ⟨simpNode⟩
    let minParams ← minimizeSingleSimp savedState currentTacSeq simpNode idx
    if minParams.size < origParams.size then
      anyChanged := true
      let newSimp := (setSimpParams ⟨simpNode⟩ minParams).raw
      currentTacSeq := replaceNthSimpOnly currentTacSeq idx newSimp
  -- Actually execute the minimized block.
  savedState.restore (restoreInfo := true)
  evalTactic currentTacSeq
  -- Emit suggestion
  if anyChanged then
    addSuggestion tk (TSyntax.mk currentTacSeq : TSyntax `tactic) (origSpan? := ← getRef)
/-- `minimize_simps => tacs` runs `tacs`, which must close every goal. It then greedily
removes lemmas from each `simp only`, `simp_all only`, and `dsimp only` call inside `tacs`,
keeping a removal only if the whole block still closes all goals. If any lemma list got
smaller, the minimized block is offered as a "Try this" suggestion. -/
syntax (name := minimizeSimps) "minimize_simps " " => " tacticSeq : tactic
end MinimizeSimps
-- Elaborator for `minimize_simps => tacs`: forwards the tactic sequence to
-- `minimizeSimpsCore`, using the whole invocation (the current ref) as the anchor syntax.
open MinimizeSimps in
elab_rules : tactic
  | `(tactic| minimize_simps => $body) => do
    let invocation ← getRef
    minimizeSimpsCore invocation body
-- ============ Tests ============
-- The examples deliberately pass unused simp arguments; silence the linter that flags them.
set_option linter.unusedSimpArgs false
-- Test 1: Remove unused lemma from a terminal simp
example (n : Nat) (h : n = 0) : n = 0 := by
  minimize_simps =>
    simp only [h, Nat.add_zero]
-- Test 2: Keep only needed lemmas
example (n m : Nat) (h : n = 0) : n + m = m := by
  minimize_simps =>
    simp only [h, Nat.add_zero, Nat.zero_add]
-- Test 3: Multiple simp calls, each minimized independently
example (n m : Nat) (h1 : n = 0) (h2 : m = 0) : n = 0 ∧ m = 0 := by
  minimize_simps =>
    constructor
    · simp only [h1, h2, Nat.add_zero]
    · simp only [h1, h2, Nat.add_zero]
-- Test 4: dsimp only - both lemmas should be removable, since the goal closes definitionally
example : (fun x : Nat => x) 0 = 0 + 0 := by
  minimize_simps =>
    dsimp only [Nat.add_zero, Nat.zero_add]
-- Test 5: Terminal simp with a lemma that has nothing to rewrite.
-- After rewriting with `h` the goal has no addition, so `Nat.zero_add` can never fire
-- and should be dropped.
example (n : Nat) (h : n = 0) : n * n = 0 := by
  minimize_simps =>
    simp only [h, Nat.zero_add, Nat.mul_zero]
-- Test 6: All lemmas needed (no suggestion expected)
example (a b : Nat) (h1 : a = 1) (h2 : b = 2) : a + b = 3 := by
  minimize_simps =>
    simp only [h1, h2]
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment