Created
November 14, 2025 22:38
-
-
Save girving/db7a31fe4df064f4a9aba76927cab14a to your computer and use it in GitHub Desktop.
Dyadic tree data structure for Karatsuba multiplication
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import Series.Misc.Polynomial | |
| /-! | |
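| # Dyadic trees for Karatsuba polynomial multiplication | |
| A `Tree α n` stores the coefficients of a polynomial with at most `2 ^ n` coefficients, split recursively into lower and upper halves; a `left` node marks an all-zero upper half, which is not stored. Addition, subtraction, scalar multiplication and Karatsuba multiplication operate directly on this representation and are proved correct against `Tree.poly`. | |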
| -/ | |
| open Polynomial (C X) | |
| open scoped Polynomial | |
| namespace Series | |
| /- Section variables, auto-bound as implicit arguments wherever a declaration mentions them; `𝕜` is not referenced in the visible part of this file. -/ variable {α β 𝕜 : Type} {n m : ℕ} | |
| /-! | |
| ### Definitions | |
| -/ | |
| /-- A dyadic tree of depth `n` holding polynomial coefficients in `α`: it has room for at most `2 ^ n` coefficients (see `size_le_pow`), and `Tree.poly` reads off the polynomial. -/ | |
| inductive Tree (α : Type) : ℕ → Type where | |
| | /-- A single coefficient `a`, standing for the constant polynomial `C a`. -/ leaf : α → Tree α 0 | |
| | /-- A tree whose upper half is all zero; only the lower subtree is stored. -/ left : {n : ℕ} → Tree α n → Tree α (n + 1) | |
| | /-- A lower subtree `x` and an upper subtree `y`, standing for `x.poly + y.poly * X ^ (2 ^ n)`. -/ both : {n : ℕ} → Tree α n → Tree α n → Tree α (n + 1) | |
| namespace Tree | |
| /-! | |
| ### Size (length of the possibly nonzero coefficient prefix) | |
| -/ | |
/-- Length of the prefix of coefficients of `x.poly` that may be nonzero.

Every coefficient of `X ^ k` with `k ≥ x.size` is structurally zero. A `left` node adds nothing
for its (zero) upper half. A `both` node counts its lower half in full — `2 ^ n`, even when that
lower subtree is itself sparse — and then adds the size of its upper subtree. This is therefore
an upper bound on the number of nonzero terms, not a count of them: `both (left a) b` still
counts the zero upper half of `left a`. -/
def size {n : ℕ} (x : Tree α n) : ℕ := match n, x with
  | 0, leaf _ => 1
  | _ + 1, left x => x.size
  | n + 1, both _ y => 2 ^ n + y.size
| /-- A leaf occupies exactly one coefficient slot. -/ @[simp] lemma size_leaf (x : α) : (leaf x).size = 1 := rfl | |
| /-- A zero upper half contributes nothing, so `left` keeps the size of its subtree. -/ @[simp] lemma size_left (x : Tree α n) : (left x).size = x.size := rfl | |
| /-- The lower half of a `both` node is counted in full (`2 ^ n`), followed by the upper subtree. -/ @[simp] lemma size_both (x y : Tree α n) : (both x y).size = 2 ^ n + y.size := rfl | |
| /-- Trees have size at most `2 ^ n`, the number of coefficient slots available at depth `n`. -/ | |
| lemma size_le_pow {n : ℕ} (x : Tree α n) : x.size ≤ 2 ^ n := by | |
| induction' x | |
| all_goals simp_all [pow_add] <;> omega /- `pow_add` turns `2 ^ (n + 1)` into a multiple of `2 ^ n`, which `omega` can treat as an atom -/ | |
| /-! | |
| ### Polynomial expansions | |
| -/ | |
/-- The polynomial that a tree stands for.

A `leaf a` is the constant `C a`. A `left lo` has an all-zero upper half, so it is just
`lo.poly`. At depth `k + 1`, `both lo hi` is `lo.poly + hi.poly * X ^ (2 ^ k)`, which places
`hi` in the upper half. -/
noncomputable def poly [Semiring α] {n : ℕ} (x : Tree α n) : α[X] :=
  match n, x with
  | 0, leaf a => C a
  | _ + 1, left lo => lo.poly
  | k + 1, both lo hi => lo.poly + hi.poly * X ^ (2 ^ k)
| /-- A leaf is a constant polynomial. -/ @[simp] lemma poly_leaf [Semiring α] (x : α) : (leaf x).poly = C x := rfl | |
| /-- Wrapping in `left` (zero upper half) does not change the polynomial. -/ @[simp] lemma poly_left [Semiring α] (x : Tree α n) : (left x).poly = x.poly := rfl | |
| /-- A `both` node places its second subtree `2 ^ n` coefficients up. -/ @[simp] lemma poly_both [Semiring α] {n : ℕ} (x y : Tree α n) : | |
| (both x y).poly = x.poly + y.poly * X ^ (2 ^ n) := rfl | |
| /-- `poly` is unchanged by transporting a tree along an equality of depths (the casts used by `add`, `sub` and `mul`). -/ @[simp] lemma poly_cast [Semiring α] (e : n = m) (x : Tree α n) : (e ▸ x).poly = x.poly := by | |
| subst e; rfl /- once the depths are identified the cast is the identity -/ | |
| /-- Variant of `poly_cast` for a depth of the form `f n`, where the cast has to rewrite under `f`. -/ @[simp] lemma poly_cast_fun [Semiring α] {f : ℕ → ℕ} (e : n = m) (x : Tree α (f n)) : | |
| (e ▸ x).poly = x.poly := by | |
| subst e; rfl | |
| /-! | |
| ### Addition | |
| -/ | |
/-- Coefficientwise sum of two trees of the same depth.

When one side is `left` (zero upper half), the other side's upper subtree is reused unchanged,
so at the root the result is a `both` node exactly when at least one input is. -/
def adds [Add α] {n : ℕ} (x y : Tree α n) : Tree α n :=
  match x, y with
  | leaf a, leaf b => leaf (a + b)
  | left a, left b => left (a.adds b)
  | left a, both b₀ b₁ => both (a.adds b₀) b₁
  | both a₀ a₁, left b => both (a₀.adds b) a₁
  | both a₀ a₁, both b₀ b₁ => both (a₀.adds b₀) (a₁.adds b₁)
/-- Add a tree `x` of depth `n` to a deeper tree `y` of depth `n + m`.

`x` only touches the lowest `2 ^ n` coefficients. We therefore descend `m` levels along the lower
branch of `y` and finish with `adds`, keeping every upper half of `y` as it is. -/
def add_le [Add α] {m : ℕ} (x : Tree α n) (y : Tree α (n + m)) : Tree α (n + m) :=
  match m, y with
  | 0, b => x.adds b
  | j + 1, left b => left (x.add_le b)
  | j + 1, both b₀ b₁ => both (x.add_le b₀) b₁
| /-- Add two trees of possibly different depths; the result has depth `max n m`. The shallower tree is added into the deeper one with `add_le`, swapping the operands when `m < n`. -/ def add [Add α] (x : Tree α n) (y : Tree α m) : Tree α (max n m) := | |
| if le : n ≤ m then way x y le else max_comm n m ▸ way y x (not_le.mp le).le where | |
| way {n m : ℕ} (x : Tree α n) (y : Tree α m) (le : n ≤ m) : Tree α (max n m) := | |
| let s := Nat.add_sub_cancel' le /- s : n + (m - n) = m -/ | |
| let e := s.trans (max_eq_right le).symm /- e : n + (m - n) = max n m -/ | |
| e ▸ add_le x (s.symm.ndrec y) /- view `y` at depth `n + (m - n)`, add, then cast the depth to `max n m` -/ | |
| /-- `adds` is polynomial addition on trees of equal depth. -/ @[simp] lemma poly_adds [CommRing α] {n : ℕ} (x y : Tree α n) : | |
| (x.adds y).poly = x.poly + y.poly := by | |
| induction' x | |
| all_goals cases y | |
| all_goals simp_all [adds] <;> ring | |
| /-- `add_le` is polynomial addition when the second operand is deeper. -/ @[simp] lemma poly_add_le [CommRing α] {m : ℕ} (x : Tree α n) (y : Tree α (n + m)) : | |
| (x.add_le y).poly = x.poly + y.poly := by | |
| induction' m /- induct on the depth gap, mirroring the recursion of `add_le` -/ | |
| · simp [add_le] | |
| · cases y | |
| all_goals simp_all [add_le] <;> ring | |
| /-- `add` is polynomial addition for trees of arbitrary depths. -/ @[simp] lemma poly_add [CommRing α] (x : Tree α n) (y : Tree α m) : | |
| (x.add y).poly = x.poly + y.poly := by | |
| simp only [add, add.way] | |
| split_ifs with h | |
| all_goals simp only [poly_cast, poly_add_le, add_comm] /- `add_comm` reconciles the swapped branch, which computes `y + x` -/ | |
| /-! | |
| ### Negation | |
| -/ | |
/-- Coefficientwise negation; the shape of the tree is unchanged. -/
def neg [Neg α] {n : ℕ} (x : Tree α n) : Tree α n :=
  match x with
  | leaf a => leaf (-a)
  | left lo => left lo.neg
  | both lo hi => both lo.neg hi.neg
| /-- `neg` is polynomial negation. -/ @[simp] lemma poly_neg [CommRing α] (x : Tree α n) : x.neg.poly = -x.poly := by | |
| induction' x | |
| all_goals simp_all [neg] <;> ring | |
| /-! | |
| ### Subtraction | |
| -/ | |
/-- Coefficientwise difference `x - y` of two trees of the same depth.

If `x` has a zero upper half (`left`) but `y` does not, the upper half of the result is the
negation of `y`'s. In the opposite situation, `x`'s upper half is kept unchanged. -/
def subs [Sub α] [Neg α] {n : ℕ} (x y : Tree α n) : Tree α n :=
  match x, y with
  | leaf a, leaf b => leaf (a - b)
  | left a, left b => left (a.subs b)
  | left a, both b₀ b₁ => both (a.subs b₀) b₁.neg
  | both a₀ a₁, left b => both (a₀.subs b) a₁
  | both a₀ a₁, both b₀ b₁ => both (a₀.subs b₀) (a₁.subs b₁)
/-- Subtract a deeper tree `y` (depth `n + m`) from a shallower tree `x` (depth `n`).

We descend `m` levels along the lower branch of `y`, where `x` lives. Every upper half passed on
the way is negated, since `x` contributes nothing there. -/
def sub_le [Sub α] [Neg α] {m : ℕ} (x : Tree α n) (y : Tree α (n + m)) : Tree α (n + m) :=
  match m, y with
  | 0, b => x.subs b
  | j + 1, left b => left (x.sub_le b)
  | j + 1, both b₀ b₁ => both (x.sub_le b₀) b₁.neg
/-- Subtract a shallower tree `y` (depth `m`) from a deeper tree `x` (depth `m + k`).

We descend `k` levels along the lower branch of `x` and keep every upper half unchanged. -/
def sub_ge [Sub α] [Neg α] {k : ℕ} (x : Tree α (m + k)) (y : Tree α m) : Tree α (m + k) :=
  match k, x with
  | 0, a => a.subs y
  | j + 1, left a => left (a.sub_ge y)
  | j + 1, both a₀ a₁ => both (a₀.sub_ge y) a₁
| /-- Subtract trees of possibly different depths; the result has depth `max n m`. Subtraction is not commutative, so the two orders are handled by `sub_le` and `sub_ge` instead of by swapping operands. -/ def sub [Sub α] [Neg α] (x : Tree α n) (y : Tree α m) : Tree α (max n m) := | |
| if le : n ≤ m then | |
| let s := Nat.add_sub_cancel' le /- s : n + (m - n) = m -/ | |
| let e := s.trans (max_eq_right le).symm /- e : n + (m - n) = max n m -/ | |
| e ▸ sub_le x (s.symm.ndrec y) | |
| else | |
| let ge := (not_le.mp le).le /- ge : m ≤ n -/ | |
| let s := Nat.add_sub_cancel' ge /- s : m + (n - m) = n -/ | |
| let e := s.trans (max_eq_left ge).symm /- e : m + (n - m) = max n m -/ | |
| e ▸ sub_ge (s.symm.ndrec x) y | |
| /-- `subs` is polynomial subtraction on trees of equal depth. -/ @[simp] lemma poly_subs [CommRing α] {n : ℕ} (x y : Tree α n) : | |
| (x.subs y).poly = x.poly - y.poly := by | |
| induction' x | |
| all_goals cases y | |
| all_goals simp_all [subs] <;> ring /- the simp lemma `poly_neg` handles the negated upper halves -/ | |
| /-- `sub_le` is polynomial subtraction when the subtrahend is deeper. -/ @[simp] lemma poly_sub_le [CommRing α] {m : ℕ} (x : Tree α n) (y : Tree α (n + m)) : | |
| (x.sub_le y).poly = x.poly - y.poly := by | |
| induction' m /- induct on the depth gap, mirroring the recursion of `sub_le` -/ | |
| · simp [sub_le] | |
| · cases y | |
| all_goals simp_all [sub_le] <;> ring | |
| /-- `sub_ge` is polynomial subtraction when the minuend is deeper. -/ @[simp] lemma poly_sub_ge [CommRing α] {k : ℕ} (x : Tree α (m + k)) (y : Tree α m) : | |
| (x.sub_ge y).poly = x.poly - y.poly := by | |
| induction' k /- induct on the depth gap, mirroring the recursion of `sub_ge` -/ | |
| · simp [sub_ge] | |
| · cases x | |
| all_goals simp_all [sub_ge] <;> ring | |
| /-- `sub` is polynomial subtraction for trees of arbitrary depths. -/ @[simp] lemma poly_sub [CommRing α] (x : Tree α n) (y : Tree α m) : | |
| (x.sub y).poly = x.poly - y.poly := by | |
| unfold sub | |
| split_ifs with h /- one branch per direction of `n ≤ m` -/ | |
| · simp only [poly_cast, poly_sub_le] | |
| · simp only [poly_cast, poly_sub_ge] | |
| /-! | |
| ### Scalar multiplication | |
| -/ | |
/-- Multiply every coefficient by the scalar `s`; the shape of the tree is unchanged. -/
def smul [SMul α β] (s : α) {n : ℕ} (x : Tree β n) : Tree β n :=
  match x with
  | leaf b => leaf (s • b)
  | left lo => left (smul s lo)
  | both lo hi => both (smul s lo) (smul s hi)
| /-- `smul` is scalar multiplication of the represented polynomial. -/ @[simp] lemma poly_smul [Semiring α] [CommRing β] [Module α β] [IsScalarTower α β β] | |
| (s : α) (x : Tree β n) : (smul s x).poly = s • x.poly := by | |
| induction' x | |
| all_goals simp_all [smul, Polynomial.smul_eq_C_smul] <;> ring_nf | |
| rw [← smul_one_mul, Polynomial.C_mul] /- NOTE(review): closes the goal `ring_nf` leaves (presumably the `leaf` case) by rewriting `s • _` as `(s • 1) * _` so that `Polynomial.C_mul` applies -/ | |
| /-! | |
| ### Karatsuba multiplication | |
| -/ | |
/-- Add two trees of depth `n + 1`, with the second shifted up by `2 ^ n` coefficients. The
result represents `p0.poly + p1.poly * X ^ (2 ^ n)` at depth `n + 2`.

A half-subtree shift lines up the upper half of `p0` with the lower half of `p1`, and these are
added. The upper half of `p1`, if any, becomes the upper half of the result; otherwise the
result is a `left` node. -/
def add_shift [Add α] (p0 p1 : Tree α (n + 1)) : Tree α (n + 2) :=
  match p0, p1 with
  | left a, left b => left (both a b)
  | left a, both b₀ b₁ => both (both a b₀) (left b₁)
  | both a₀ a₁, left b => left (both a₀ (a₁.adds b))
  | both a₀ a₁, both b₀ b₁ => both (both a₀ (a₁.adds b₀)) (left b₁)
/-- Assemble the three Karatsuba partial products, each of depth `n + 1`, into
`p0.poly + p1.poly * X ^ (2 ^ n) + p2.poly * X ^ (2 ^ (n + 1))` at depth `n + 2`.

Consecutive products overlap by half a subtree: the upper half of `p0` is added to the lower
half of `p1`, and the upper half of `p1` is added to the lower half of `p2`. A missing upper
half (`left`) contributes nothing. -/
def add_karatsuba [Add α] (p0 p1 p2 : Tree α (n + 1)) : Tree α (n + 2) :=
  match p0, p1, p2 with
  | left a, left b, left c => both (both a b) (left c)
  | left a, left b, both c₀ c₁ => both (both a b) (both c₀ c₁)
  | left a, both b₀ b₁, left c => both (both a b₀) (left (b₁.adds c))
  | left a, both b₀ b₁, both c₀ c₁ => both (both a b₀) (both (b₁.adds c₀) c₁)
  | both a₀ a₁, left b, left c => both (both a₀ (a₁.adds b)) (left c)
  | both a₀ a₁, left b, both c₀ c₁ => both (both a₀ (a₁.adds b)) (both c₀ c₁)
  | both a₀ a₁, both b₀ b₁, left c => both (both a₀ (a₁.adds b₀)) (left (b₁.adds c))
  | both a₀ a₁, both b₀ b₁, both c₀ c₁ =>
    both (both a₀ (a₁.adds b₀)) (both (b₁.adds c₀) c₁)
/-- Multiply two trees of the same depth `n` with Karatsuba's algorithm.

The product has fewer than `2 ^ (n + 1)` coefficients, so the result lives one level deeper.

- When both inputs are split (`both x0 x1`, `both y0 y1`), only three recursive products are
  formed: `x0 * y0`, `x1 * y1` and `(x0 + x1) * (y0 + y1)`. The cross term `x0 * y1 + x1 * y0`
  is recovered by subtraction.
- When one input has a zero upper half (`left`), two recursive products against the halves of
  the other input suffice.

Only `Add`, `Sub`, `Neg` and `Mul` are needed on the coefficients. -/
def muls [Add α] [Sub α] [Neg α] [Mul α] {n : ℕ} (x y : Tree α n) : Tree α (n + 1) :=
  match x, y with
  | leaf x, leaf y => left (leaf (x * y))
  | left x, left y => left (muls x y)
  | left x, both y z => add_shift (muls x y) (muls x z)
  | both x y, left z => add_shift (muls x z) (muls y z)
  | both x0 x1, both y0 y1 =>
    let p0 := muls x0 y0
    let p2 := muls x1 y1
    let pm := muls (x0.adds x1) (y0.adds y1)
    -- Karatsuba: p1 = pm - (p0 + p2) = x0*y1 + x1*y0
    let p1 := pm.subs (p0.adds p2)
    add_karatsuba p0 p1 p2

/-- Multiply a depth-`n` tree `x` by a deeper depth-`n + m` tree `y`.

We descend `m` levels into `y`. An all-zero upper half (`left`) stays zero. A split node
`both y z` becomes `x * y + (x * z) * X ^ (2 ^ (n + m))` via `add_shift`. Once the depths are
equal, `muls` takes over. -/
def mul_le [Add α] [Sub α] [Neg α] [Mul α] {m : ℕ} (x : Tree α n) (y : Tree α (n + m)) :
    Tree α (n + m + 1) :=
  match m, y with
  | 0, y => x.muls y
  | m + 1, left y => left (mul_le x y)
  | m + 1, both y z => add_shift (mul_le x y) (mul_le x z)

/-- Depth bookkeeping for `mul.way`: viewing the deeper operand at depth `n + (m - n)` gives a
product of depth `max n m + 1`. -/
lemma mul_nm (le : n ≤ m) : n + (m - n) + 1 = max n m + 1 := by omega

/-- Karatsuba multiplication of two trees of possibly different depths.

The shallower operand is multiplied into the deeper one via `mul_le`. When `m < n` the operands
are swapped, so each coefficient product is formed as `y_j * x_i`. `poly_mul` states correctness
only for a `CommRing`, where this order makes no difference. -/
def mul [Add α] [Sub α] [Neg α] [Mul α] (x : Tree α n) (y : Tree α m) :
    Tree α (max n m + 1) :=
  if le : n ≤ m then way x y le else max_comm n m ▸ way y x (not_le.mp le).le
where
  -- Case `n ≤ m`: view `y` at depth `n + (m - n)`, multiply, then cast the resulting depth.
  way {n m : ℕ} (x : Tree α n) (y : Tree α m) (le : n ≤ m) : Tree α (max n m + 1) :=
    mul_nm le ▸ x.mul_le ((Nat.add_sub_cancel' le).symm ▸ y)
| /-- `add_shift p0 p1` represents `p0.poly + p1.poly * X ^ (2 ^ n)`. -/ @[simp] lemma poly_add_shift [CommRing α] (p0 p1 : Tree α (n + 1)) : | |
| (add_shift p0 p1).poly = p0.poly + p1.poly * X ^ (2 ^ n) := by | |
| cases p0 <;> cases p1 /- the four `left`/`both` shape combinations (a `leaf` cannot have depth `n + 1`) -/ | |
| all_goals simp [add_shift] <;> ring | |
| /-- `add_karatsuba` assembles the three partial products at offsets `0`, `2 ^ n` and `2 ^ (n + 1)`. -/ @[simp] lemma poly_add_karatsuba [CommRing α] (p0 p1 p2 : Tree α (n + 1)) : | |
| (add_karatsuba p0 p1 p2).poly = | |
| p0.poly + p1.poly * X ^ (2 ^ n) + p2.poly * X ^ (2 ^ (n + 1)) := by | |
| cases p0 <;> cases p1 <;> cases p2 /- the eight `left`/`both` shape combinations -/ | |
| all_goals simp [add_karatsuba] <;> ring | |
| /-- Correctness of Karatsuba multiplication `muls` on trees of equal depth. -/ @[simp] lemma poly_muls [CommRing α] (x y : Tree α n) : (x.muls y).poly = x.poly * y.poly := by | |
| induction' n /- induct on the depth, not on `x`: `muls` recurses on `x0.adds x1`, which is not a subterm of `x` -/ | |
| all_goals cases x <;> cases y | |
| all_goals simp_all [muls] <;> ring | |
| /-- `mul_le` is polynomial multiplication when the second operand is deeper. -/ @[simp] lemma poly_mul_le [CommRing α] {m : ℕ} (x : Tree α n) (y : Tree α (n + m)) : | |
| (x.mul_le y).poly = x.poly * y.poly := by | |
| induction' m /- induct on the depth gap, mirroring the recursion of `mul_le` -/ | |
| · simp [mul_le] | |
| · cases y | |
| all_goals simp_all [mul_le] <;> ring | |
| /-- `mul` is polynomial multiplication for trees of arbitrary depths. -/ @[simp] lemma poly_mul [CommRing α] (x : Tree α n) (y : Tree α m) : | |
| (x.mul y).poly = x.poly * y.poly := by | |
| simp only [mul, mul.way] | |
| split_ifs with h | |
| all_goals simp [mul_comm] /- `mul_comm` reconciles the swapped branch, which computes `y * x` -/ | |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment