Created
March 2, 2025 15:04
-
-
Save arya2004/0b7280b7cbbdfe98c4c2a59f9e5881bc to your computer and use it in GitHub Desktop.
Forward Mode Automatic Differentiation in Go
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| package main | |
| import ( | |
| "fmt" | |
| "math" | |
| ) | |
// Dual is a dual number a + bε, where ε is an infinitesimal with ε² = 0.
// Real holds the value a; Dual holds the coefficient b, which the methods
// below propagate as a derivative (forward-mode automatic differentiation).
type Dual struct {
	Real float64
	Dual float64
}

// Add returns the component-wise sum d + other.
func (d Dual) Add(other Dual) Dual {
	var sum Dual
	sum.Real = d.Real + other.Real
	sum.Dual = d.Dual + other.Dual
	return sum
}

// Mul returns d * other. Expanding (a + bε)(c + eε) and dropping the ε²
// term leaves ac + (ae + bc)ε, which is exactly the product rule.
func (d Dual) Mul(other Dual) Dual {
	var prod Dual
	prod.Real = d.Real * other.Real
	prod.Dual = d.Real*other.Dual + d.Dual*other.Real
	return prod
}

// Sin returns sin(d). By the chain rule the dual part is scaled by
// cos(d.Real), the derivative of sine evaluated at the real part.
func (d Dual) Sin() Dual {
	var out Dual
	out.Real = math.Sin(d.Real)
	out.Dual = d.Dual * math.Cos(d.Real)
	return out
}

// Exp returns exp(d). Because exp is its own derivative, the dual part is
// scaled by the same value exp(d.Real) that forms the real part.
func (d Dual) Exp() Dual {
	e := math.Exp(d.Real)
	return Dual{Real: e, Dual: e * d.Dual}
}
| // g defines the function: g(x) = sin(x) * exp(x). | |
| func g(x Dual) Dual { | |
| return x.Sin().Mul(x.Exp()) | |
| } | |
| // differentiate computes the derivative of function f at x0. | |
| func differentiate(f func(Dual) Dual, x0 float64) float64 { | |
| x := Dual{Real: x0, Dual: 1} | |
| return f(x).Dual | |
| } | |
| func main() { | |
| x0 := 1.0 | |
| value := g(Dual{Real: x0, Dual: 0}).Real | |
| deriv := differentiate(g, x0) | |
| fmt.Printf("Advanced Example: At x = %v, g(x) = %v and g'(x) = %v\n", x0, value, deriv) | |
| // Expected derivative: e^x (sin(x) + cos(x)) at x = 1 | |
| expected := math.Exp(1) * (math.Sin(1) + math.Cos(1)) | |
| fmt.Printf("Expected derivative: %v\n", expected) | |
| } |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| package main | |
| import ( | |
| "fmt" | |
| ) | |
// Dual represents a dual number a + bε, where ε² = 0. Real holds a and Dual
// holds b; the methods below propagate b as a derivative.
type Dual struct {
	Real float64
	Dual float64
}

// Add returns the sum of two dual numbers, computed component-wise.
func (d Dual) Add(other Dual) Dual {
	return Dual{Real: d.Real + other.Real, Dual: d.Dual + other.Dual}
}

// Mul returns the product of two dual numbers:
// (a + bε)(c + eε) = ac + (ae + bc)ε, since the ε² term vanishes.
func (d Dual) Mul(other Dual) Dual {
	return Dual{
		Real: d.Real * other.Real,
		Dual: d.Real*other.Dual + d.Dual*other.Real,
	}
}

// Pow returns d raised to the integer power n.
//
// n == 0 yields 1 + 0ε. For n > 0 the result is built by repeated
// multiplication. For n < 0 the result is the reciprocal of d^|n|. A zero
// real part then produces ±Inf/NaN components, following float64 division.
func (d Dual) Pow(n int) Dual {
	result := Dual{Real: 1, Dual: 0}
	if n >= 0 {
		for i := 0; i < n; i++ {
			result = result.Mul(d)
		}
		return result
	}
	// Count up from n to 0 instead of negating n, which would overflow for
	// the most negative int.
	for i := n; i < 0; i++ {
		result = result.Mul(d)
	}
	// 1/(a + bε) = 1/a - (b/a²)ε, because (a + bε)(1/a - (b/a²)ε) = 1.
	return Dual{
		Real: 1 / result.Real,
		Dual: -result.Dual / (result.Real * result.Real),
	}
}
| // f defines the function: f(x) = x² + 3x + 5. | |
| func f(x Dual) Dual { | |
| return x.Pow(2).Add(Dual{Real: 3, Dual: 0}.Mul(x)).Add(Dual{Real: 5, Dual: 0}) | |
| } | |
| // differentiate computes the derivative of f at x0. | |
| func differentiate(f func(Dual) Dual, x0 float64) float64 { | |
| x := Dual{Real: x0, Dual: 1} | |
| return f(x).Dual | |
| } | |
| func main() { | |
| x0 := 2.0 | |
| value := f(Dual{Real: x0, Dual: 0}).Real | |
| deriv := differentiate(f, x0) | |
| fmt.Printf("Simple Example: At x = %v, f(x) = %v and f'(x) = %v\n", x0, value, deriv) | |
| } |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment