Skip to content

Instantly share code, notes, and snippets.

@girving
Created November 14, 2025 22:38
Show Gist options
  • Select an option

  • Save girving/db7a31fe4df064f4a9aba76927cab14a to your computer and use it in GitHub Desktop.

Select an option

Save girving/db7a31fe4df064f4a9aba76927cab14a to your computer and use it in GitHub Desktop.
Dyadic tree data structure for Karatsuba multiplication
import Series.Misc.Polynomial
/-!
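# Dyadic trees for Karatsuba multiplication

A `Tree α n` stores the coefficients of a polynomial in at most `2 ^ n` slots, recursively split
into lower and upper halves, so that Karatsuba's divide-and-conquer multiplication can recurse on
the halves (see `Tree.muls` and `Tree.mul`).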
-/
open Polynomial (C X)
open scoped Polynomial
namespace Series
variable {α β 𝕜 : Type} {n m : ℕ}
/-!
### Definitions
-/
/-- A dyadic tree of depth `n`, storing the coefficients of a polynomial in at most `2 ^ n`
slots (see `poly` for the represented polynomial and `size_le_pow` for the bound).

A node of depth `n + 1` splits its `2 ^ (n + 1)` slots into a lower half and an upper half of
`2 ^ n` slots each. -/
inductive Tree (α : Type) : ℕ → Type where
| leaf : α → Tree α 0  -- a single coefficient
| left : {n : ℕ} → Tree α n → Tree α (n + 1)  -- lower half only; the upper half is absent (zero)
| both : {n : ℕ} → Tree α n → Tree α n → Tree α (n + 1)  -- lower half, then upper half
namespace Tree
/-!
### Size (number of possibly nonzero terms)
-/
/-- Number of possibly nonzero terms, counted as a prefix of the coefficient slots.

A `leaf` counts as one slot. A `left` node contributes only its lower half. A `both` node at
depth `n + 1` counts its lower half as the full `2 ^ n` slots, whatever the lower subtree's own
size, plus the size of its upper half. -/
def size {n : ℕ} (x : Tree α n) : ℕ := match n, x with
| 0, leaf _ => 1
| _ + 1, left x => x.size  -- the absent upper half adds nothing
| n + 1, both _ y => 2 ^ n + y.size  -- full lower half, then the upper half's own size
-- Unfolding lemmas for `size`, one per constructor; each holds by `rfl`.
@[simp] lemma size_leaf (x : α) : (leaf x).size = 1 := rfl
@[simp] lemma size_left (x : Tree α n) : (left x).size = x.size := rfl
@[simp] lemma size_both (x y : Tree α n) : (both x y).size = 2 ^ n + y.size := rfl
/-- Trees have size at most `2 ^ n`, the number of slots available at depth `n`. -/
lemma size_le_pow {n : ℕ} (x : Tree α n) : x.size ≤ 2 ^ n := by
-- Induct on the tree; `pow_add` splits `2 ^ (n + 1)` so that `omega` can finish the arithmetic.
induction' x
all_goals simp_all [pow_add] <;> omega
/-!
### Polynomial expansions
-/
/-- The polynomial represented by a tree.

A `leaf` is a constant, a `left` node is its lower half alone, and `both x y` at depth `n + 1`
is `x.poly + y.poly * X ^ (2 ^ n)`, i.e. the upper half is shifted by `2 ^ n` slots. -/
noncomputable def poly [Semiring α] {n : ℕ} (x : Tree α n) : α[X] := match n, x with
| 0, leaf x => C x
| _ + 1, left x => x.poly
| n + 1, both x y => x.poly + y.poly * X ^ (2 ^ n)
-- Unfolding lemmas for `poly`, one per constructor; each holds by `rfl`.
@[simp] lemma poly_leaf [Semiring α] (x : α) : (leaf x).poly = C x := rfl
@[simp] lemma poly_left [Semiring α] (x : Tree α n) : (left x).poly = x.poly := rfl
@[simp] lemma poly_both [Semiring α] {n : ℕ} (x y : Tree α n) :
(both x y).poly = x.poly + y.poly * X ^ (2 ^ n) := rfl
/-- Transporting a tree along an equality of depths does not change its polynomial. -/
@[simp] lemma poly_cast [Semiring α] (e : n = m) (x : Tree α n) : (e ▸ x).poly = x.poly := by
subst e; rfl
/-- Variant of `poly_cast` for trees whose depth has the form `f n` rather than a variable. -/
@[simp] lemma poly_cast_fun [Semiring α] {f : ℕ → ℕ} (e : n = m) (x : Tree α (f n)) :
(e ▸ x).poly = x.poly := by
subst e; rfl
/-!
### Addition
-/
/-- Add two trees of the same depth coefficientwise.

A missing upper half (`left`) is treated as zero, so the other tree's upper half is kept
unchanged. The result is a `left` node only when both inputs are `left` nodes. -/
def adds [Add α] {n : ℕ} (x y : Tree α n) : Tree α n := match x, y with
| leaf x, leaf y => leaf (x + y)
| left x, left y => left (adds x y)
| left x, both y z => both (adds x y) z  -- only `y`'s side has an upper half
| both x y, left z => both (adds x z) y  -- only `x`'s side has an upper half
| both x y, both z w => both (adds x z) (adds y w)
/-- Add a tree `x` of depth `n` into a deeper tree `y` of depth `n + m`, keeping depth `n + m`.

`x` occupies the lowest `2 ^ n` slots, so this descends `m` times into the lower half of `y`
and combines with `adds` once the depths agree. Upper halves of `y` are kept unchanged. -/
def add_le [Add α] {m : ℕ} (x : Tree α n) (y : Tree α (n + m)) : Tree α (n + m) :=
match m, y with
| 0, y => x.adds y
| m + 1, left y => left (add_le x y)
| m + 1, both y z => both (add_le x y) z
/-- Add two trees of arbitrary depths; the result has the larger depth `max n m`.

The shallower tree is always added into the deeper one with `add_le`. When `m < n`, the
arguments are swapped and the result is cast along `max_comm`.
NOTE(review): in the swapped branch the leaf sums are computed as `y + x` rather than `x + y`.
This only matters if `+` on `α` is not commutative. -/
def add [Add α] (x : Tree α n) (y : Tree α m) : Tree α (max n m) :=
if le : n ≤ m then way x y le else max_comm n m ▸ way y x (not_le.mp le).le where
-- `way` handles `n ≤ m`: view `y` at depth `n + (m - n)`, add with `add_le`, then cast the
-- depth back to `max n m`.
way {n m : ℕ} (x : Tree α n) (y : Tree α m) (le : n ≤ m) : Tree α (max n m) :=
let s := Nat.add_sub_cancel' le
let e := s.trans (max_eq_right le).symm
e ▸ add_le x (s.symm.ndrec y)
/-- `adds` adds the represented polynomials. -/
@[simp] lemma poly_adds [CommRing α] {n : ℕ} (x y : Tree α n) :
(x.adds y).poly = x.poly + y.poly := by
-- Induct on `x`, split on the shape of `y`, unfold `adds` and normalise with `ring`.
induction' x
all_goals cases y
all_goals simp_all [adds] <;> ring
/-- `add_le` adds the represented polynomials. -/
@[simp] lemma poly_add_le [CommRing α] {m : ℕ} (x : Tree α n) (y : Tree α (n + m)) :
(x.add_le y).poly = x.poly + y.poly := by
-- Induct on the depth difference `m`; in the successor case, split on the shape of `y`.
induction' m
· simp [add_le]
· cases y
all_goals simp_all [add_le] <;> ring
/-- `add` adds the represented polynomials, whatever the relative depths of the inputs. -/
@[simp] lemma poly_add [CommRing α] (x : Tree α n) (y : Tree α m) :
(x.add y).poly = x.poly + y.poly := by
simp only [add, add.way]
split_ifs with h
-- Both branches reduce via `poly_cast` and `poly_add_le`; the swapped branch yields
-- `y.poly + x.poly`, hence `add_comm`.
all_goals simp only [poly_cast, poly_add_le, add_comm]
/-!
### Negation
-/
/-- Negate every stored coefficient while keeping the tree's shape: `leaf`, `left` and `both`
nodes map to nodes of the same kind. -/
def neg [Neg α] {n : ℕ} (x : Tree α n) : Tree α n :=
  match x with
  | leaf a => leaf (-a)
  | left t => left (neg t)
  | both lo hi => both (neg lo) (neg hi)
/-- `neg` negates the represented polynomial. -/
@[simp] lemma poly_neg [CommRing α] (x : Tree α n) : x.neg.poly = -x.poly := by
-- Induct on the tree, unfold `neg`, and normalise with `ring`.
induction' x
all_goals simp_all [neg] <;> ring
/-!
### Subtraction
-/
/-- Subtract two trees of the same depth coefficientwise.

A missing upper half (`left`) is treated as zero. `x - y` keeps `x`'s upper half when `y` has
none, and negates `y`'s upper half when `x` has none; that case is why `[Neg α]` is required. -/
def subs [Sub α] [Neg α] {n : ℕ} (x y : Tree α n) : Tree α n := match x, y with
| leaf x, leaf y => leaf (x - y)
| left x, left y => left (subs x y)
| left x, both y z => both (subs x y) (neg z)  -- 0 - z in the upper half
| both x y, left z => both (subs x z) y  -- y - 0 in the upper half
| both x y, both z w => both (subs x z) (subs y w)
/-- Subtract a deeper tree `y` of depth `n + m` from a tree `x` of depth `n`; the result has
depth `n + m`.

This descends `m` times into the lower half of `y` and combines with `subs` once the depths
agree. Every upper half of `y` enters the result negated, via `neg`. -/
def sub_le [Sub α] [Neg α] {m : ℕ} (x : Tree α n) (y : Tree α (n + m)) : Tree α (n + m) :=
match m, y with
| 0, y => x.subs y
| m + 1, left y => left (sub_le x y)
| m + 1, both y z => both (sub_le x y) (neg z)
/-- Subtract a shallower tree `y` of depth `m` from a tree `x` of depth `m + k`; the result has
depth `m + k`.

This descends `k` times into the lower half of `x` and combines with `subs` once the depths
agree. Upper halves of `x` are kept unchanged. -/
def sub_ge [Sub α] [Neg α] {k : ℕ} (x : Tree α (m + k)) (y : Tree α m) : Tree α (m + k) :=
match k, x with
| 0, x => x.subs y
| k + 1, left x => left (sub_ge x y)
| k + 1, both x z => both (sub_ge x y) z
/-- Subtract two trees of arbitrary depths; the result has the larger depth `max n m`.

Unlike `add`, the arguments cannot be swapped, so the case `m < n` uses `sub_ge` instead of
`sub_le`. In each branch the deeper tree is first viewed at depth `n + (m - n)` (resp.
`m + (n - m)`), and the result is cast to depth `max n m`. -/
def sub [Sub α] [Neg α] (x : Tree α n) (y : Tree α m) : Tree α (max n m) :=
if le : n ≤ m then
-- `y` is deeper: view it at depth `n + (m - n)` and use `sub_le`.
let s := Nat.add_sub_cancel' le
let e := s.trans (max_eq_right le).symm
e ▸ sub_le x (s.symm.ndrec y)
else
-- `x` is deeper: view it at depth `m + (n - m)` and use `sub_ge`.
let ge := (not_le.mp le).le
let s := Nat.add_sub_cancel' ge
let e := s.trans (max_eq_left ge).symm
e ▸ sub_ge (s.symm.ndrec x) y
/-- `subs` subtracts the represented polynomials. -/
@[simp] lemma poly_subs [CommRing α] {n : ℕ} (x y : Tree α n) :
(x.subs y).poly = x.poly - y.poly := by
-- Induct on `x`, split on the shape of `y`, unfold `subs` and normalise with `ring`.
induction' x
all_goals cases y
all_goals simp_all [subs] <;> ring
/-- `sub_le` subtracts the represented polynomials. -/
@[simp] lemma poly_sub_le [CommRing α] {m : ℕ} (x : Tree α n) (y : Tree α (n + m)) :
(x.sub_le y).poly = x.poly - y.poly := by
-- Induct on the depth difference `m`; in the successor case, split on the shape of `y`.
induction' m
· simp [sub_le]
· cases y
all_goals simp_all [sub_le] <;> ring
/-- `sub_ge` subtracts the represented polynomials. -/
@[simp] lemma poly_sub_ge [CommRing α] {k : ℕ} (x : Tree α (m + k)) (y : Tree α m) :
(x.sub_ge y).poly = x.poly - y.poly := by
-- Induct on the depth difference `k`; in the successor case, split on the shape of `x`.
induction' k
· simp [sub_ge]
· cases x
all_goals simp_all [sub_ge] <;> ring
/-- `sub` subtracts the represented polynomials, whatever the relative depths of the inputs. -/
@[simp] lemma poly_sub [CommRing α] (x : Tree α n) (y : Tree α m) :
(x.sub y).poly = x.poly - y.poly := by
unfold sub
split_ifs with h
-- `n ≤ m`: the `sub_le` branch.
· simp only [poly_cast, poly_sub_le]
-- `m < n`: the `sub_ge` branch.
· simp only [poly_cast, poly_sub_ge]
/-!
### Scalar multiplication
-/
/-- Scale every stored coefficient by `s` (acting on the left) while keeping the tree's shape:
`leaf`, `left` and `both` nodes map to nodes of the same kind. -/
def smul [SMul α β] (s : α) {n : ℕ} (x : Tree β n) : Tree β n :=
  match x with
  | leaf b => leaf (s • b)
  | left t => left (smul s t)
  | both lo hi => both (smul s lo) (smul s hi)
/-- `smul s` scales the represented polynomial by `s`. -/
@[simp] lemma poly_smul [Semiring α] [CommRing β] [Module α β] [IsScalarTower α β β]
(s : α) (x : Tree β n) : (smul s x).poly = s • x.poly := by
induction' x
all_goals simp_all [smul, Polynomial.smul_eq_C_smul] <;> ring_nf
-- NOTE(review): one goal presumably survives `ring_nf` here. It is closed by rewriting with
-- `smul_one_mul` (right to left) and then `Polynomial.C_mul`.
rw [← smul_one_mul, Polynomial.C_mul]
/-!
### Karatsuba multiplication
-/
/-- Add two trees where the second is shifted: p0 + p1*X^(2^n)

Both inputs have depth `n + 1`. The output has depth `n + 2`, i.e. four quarters of `2 ^ n`
slots. The shift is half an input's width, so:
* `p1`'s lower half overlaps `p0`'s upper half, and the two are combined with `adds`;
* `p1`'s upper half, if present, fills the third quarter;
* the fourth quarter is always absent. -/
def add_shift [Add α] (p0 p1 : Tree α (n + 1)) : Tree α (n + 2) :=
match p0, p1 with
| left p0, left p1 => left (both p0 p1)
| left p0, both p10 p11 => both (both p0 p10) (left p11)
| both p00 p01, left p1 => left (both p00 (p01.adds p1))
| both p00 p01, both p10 p11 => both (both p00 (p01.adds p10)) (left p11)
/-- Combine three products for Karatsuba: p0 + p1*X^(2^n) + p2*X^(2^(n+1))

All inputs have depth `n + 1`. The output has depth `n + 2`, i.e. four quarters of `2 ^ n`
slots, filled as follows:
1. the lower half of `p0`;
2. the upper half of `p0` plus the lower half of `p1`;
3. the upper half of `p1` plus the lower half of `p2`;
4. the upper half of `p2`.

A missing upper half (`left`) contributes nothing, so no addition is performed for it. -/
def add_karatsuba [Add α] (p0 p1 p2 : Tree α (n + 1)) : Tree α (n + 2) :=
match p0, p1, p2 with
| left p0, left p1, left p2 => both (both p0 p1) (left p2)
| left p0, left p1, both p20 p21 => both (both p0 p1) (both p20 p21)
| left p0, both p10 p11, left p2 => both (both p0 p10) (left (p11.adds p2))
| left p0, both p10 p11, both p20 p21 => both (both p0 p10) (both (p11.adds p20) p21)
| both p00 p01, left p1, left p2 => both (both p00 (p01.adds p1)) (left p2)
| both p00 p01, left p1, both p20 p21 => both (both p00 (p01.adds p1)) (both p20 p21)
| both p00 p01, both p10 p11, left p2 => both (both p00 (p01.adds p10)) (left (p11.adds p2))
| both p00 p01, both p10 p11, both p20 p21 =>
both (both p00 (p01.adds p10)) (both (p11.adds p20) p21)
/-- Multiply two trees of the same depth `n`, producing a tree of depth `n + 1`.

* `leaf`/`leaf`: a single product, stored in the lower half of the result.
* `left`/`left`: only lower halves are present, so one recursive product suffices.
* mixed `left`/`both`: two recursive products, combined with `add_shift`.
* `both`/`both`: Karatsuba's trick. Three recursive products (`p0`, `p2`, and the product of the
  half-sums `pm`) replace four, and the middle term is recovered as `pm - (p0 + p2)`.

Only `Add`, `Sub`, `Neg` and `Mul` on `α` are used. No `Zero` instance is required, because
missing halves are represented structurally by `left` rather than by stored zeros. -/
def muls [Add α] [Sub α] [Neg α] [Mul α] {n : ℕ} (x y : Tree α n) : Tree α (n + 1) :=
  match x, y with
  | leaf x, leaf y => left (leaf (x * y))
  | left x, left y => left (muls x y)
  | left x, both y z => add_shift (muls x y) (muls x z)
  | both x y, left z => add_shift (muls x z) (muls y z)
  | both x0 x1, both y0 y1 =>
    let p0 := muls x0 y0
    let p2 := muls x1 y1
    -- Product of the half-sums; it contains p0, p2 and the cross terms.
    let pm := muls (x0.adds x1) (y0.adds y1)
    -- Karatsuba: p1 = pm - (p0 + p2) = x0*y1 + x1*y0
    let p1 := pm.subs (p0.adds p2)
    add_karatsuba p0 p1 p2

/-- Multiply a tree `x` of depth `n` by a deeper tree `y` of depth `n + m`; the result has
depth `n + m + 1`.

This descends `m` times into `y`. A `left` node needs one recursive product. A `both` node
needs two (one per half of `y`), recombined with `add_shift`. Once the depths agree, `muls`
takes over. Only `Add`, `Sub`, `Neg` and `Mul` on `α` are used. -/
def mul_le [Add α] [Sub α] [Neg α] [Mul α] {m : ℕ} (x : Tree α n) (y : Tree α (n + m)) :
    Tree α (n + m + 1) :=
  match m, y with
  | 0, y => x.muls y
  | m + 1, left y => left (mul_le x y)
  | m + 1, both y z => add_shift (mul_le x y) (mul_le x z)

/-- Depth bookkeeping for `mul`: when `n ≤ m`, `n + (m - n) + 1 = max n m + 1`. `mul.way`
uses it to cast the depth of `mul_le`'s result. -/
lemma mul_nm (le : n ≤ m) : n + (m - n) + 1 = max n m + 1 := by omega

/-- Karatsuba multiplication of two trees of arbitrary depths; the result has depth
`max n m + 1`.

The shallower tree (or `x`, when the depths are equal) is always passed to `mul_le` as the left
factor. When `m < n`, the arguments are swapped and the result is cast along `max_comm`.
NOTE(review): in the swapped branch the leaf products are computed as `y * x`. This only
matters if `*` on `α` is not commutative. Only `Add`, `Sub`, `Neg` and `Mul` on `α` are used. -/
def mul [Add α] [Sub α] [Neg α] [Mul α] (x : Tree α n) (y : Tree α m) :
    Tree α (max n m + 1) :=
  if le : n ≤ m then way x y le else max_comm n m ▸ way y x (not_le.mp le).le where
  -- `way` handles `n ≤ m`: view `y` at depth `n + (m - n)`, multiply with `mul_le`, and cast
  -- the resulting depth `n + (m - n) + 1` to `max n m + 1` via `mul_nm`.
  way {n m : ℕ} (x : Tree α n) (y : Tree α m) (le : n ≤ m) : Tree α (max n m + 1) :=
    mul_nm le ▸ x.mul_le ((Nat.add_sub_cancel' le).symm ▸ y)
/-- `add_shift p0 p1` represents `p0 + p1 * X ^ (2 ^ n)`. -/
@[simp] lemma poly_add_shift [CommRing α] (p0 p1 : Tree α (n + 1)) :
(add_shift p0 p1).poly = p0.poly + p1.poly * X ^ (2 ^ n) := by
-- Split on the shapes of both inputs, unfold `add_shift`, and normalise with `ring`.
cases p0 <;> cases p1
all_goals simp [add_shift] <;> ring
/-- `add_karatsuba p0 p1 p2` represents `p0 + p1 * X ^ (2 ^ n) + p2 * X ^ (2 ^ (n + 1))`. -/
@[simp] lemma poly_add_karatsuba [CommRing α] (p0 p1 p2 : Tree α (n + 1)) :
(add_karatsuba p0 p1 p2).poly =
p0.poly + p1.poly * X ^ (2 ^ n) + p2.poly * X ^ (2 ^ (n + 1)) := by
-- Split on the shapes of all three inputs (eight cases), unfold, and normalise with `ring`.
cases p0 <;> cases p1 <;> cases p2
all_goals simp [add_karatsuba] <;> ring
/-- `muls` multiplies the represented polynomials. -/
@[simp] lemma poly_muls [CommRing α] (x y : Tree α n) : (x.muls y).poly = x.poly * y.poly := by
-- Induct on the depth `n` rather than on `x`. The Karatsuba case multiplies `x0.adds x1`,
-- which is not a subterm of `x`, so structural induction on the tree gives no usable
-- hypothesis for it.
induction' n
all_goals cases x <;> cases y
all_goals simp_all [muls] <;> ring
/-- `mul_le` multiplies the represented polynomials. -/
@[simp] lemma poly_mul_le [CommRing α] {m : ℕ} (x : Tree α n) (y : Tree α (n + m)) :
(x.mul_le y).poly = x.poly * y.poly := by
-- Induct on the depth difference `m`; in the successor case, split on the shape of `y`.
induction' m
· simp [mul_le]
· cases y
all_goals simp_all [mul_le] <;> ring
/-- `mul` multiplies the represented polynomials, whatever the relative depths of the inputs. -/
@[simp] lemma poly_mul [CommRing α] (x : Tree α n) (y : Tree α m) :
(x.mul y).poly = x.poly * y.poly := by
simp only [mul, mul.way]
split_ifs with h
-- The swapped branch yields `y.poly * x.poly`, hence `mul_comm`.
all_goals simp [mul_comm]
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment