Skip to content

Instantly share code, notes, and snippets.

@arya2004
Created March 2, 2025 15:04
Show Gist options
  • Select an option

  • Save arya2004/0b7280b7cbbdfe98c4c2a59f9e5881bc to your computer and use it in GitHub Desktop.

Select an option

Save arya2004/0b7280b7cbbdfe98c4c2a59f9e5881bc to your computer and use it in GitHub Desktop.
Forward Mode Automatic Differentiation in Go
package main
import (
"fmt"
"math"
)
// Dual is a dual number a + bε, where ε is nilpotent (ε² = 0).
// Real holds the coefficient a; Dual holds b, which differentiate reads
// back as the derivative after seeding the input's dual part with 1.
type Dual struct {
	Real, Dual float64
}
// Add returns d + other, summing the real and dual parts componentwise.
func (d Dual) Add(other Dual) Dual {
	var sum Dual
	sum.Real = d.Real + other.Real
	sum.Dual = d.Dual + other.Dual
	return sum
}
// Mul returns the product d·other. Expanding (a + bε)(c + dε) and dropping
// the ε² term gives ac + (ad + bc)ε — the product rule in the dual part.
func (d Dual) Mul(other Dual) Dual {
	var prod Dual
	prod.Real = d.Real * other.Real
	prod.Dual = d.Real*other.Dual + d.Dual*other.Real
	return prod
}
// Sin returns sin(a) + b·cos(a)ε: the sine of the real part, with the dual
// part scaled by the derivative cos(a) (chain rule).
func (d Dual) Sin() Dual {
	slope := math.Cos(d.Real)
	return Dual{Dual: d.Dual * slope, Real: math.Sin(d.Real)}
}
// Exp returns e^a + b·e^a ε. The exponential is its own derivative, so the
// single value e^a serves for both parts.
func (d Dual) Exp() Dual {
	e := math.Exp(d.Real)
	var out Dual
	out.Real = e
	out.Dual = d.Dual * e
	return out
}
// g evaluates g(x) = sin(x)·exp(x) over dual numbers, so the result's dual
// part carries the derivative when x is seeded with dual part 1.
func g(x Dual) Dual {
	sine := x.Sin()
	exponential := x.Exp()
	return sine.Mul(exponential)
}
// differentiate returns f'(x0) using forward-mode automatic differentiation:
// it evaluates f at x0 + 1ε and reads the derivative off the dual part.
func differentiate(f func(Dual) Dual, x0 float64) float64 {
	seed := Dual{Real: x0, Dual: 1}
	out := f(seed)
	return out.Dual
}
// main evaluates g and its forward-mode derivative at x0 and prints them,
// followed by the closed-form derivative g'(x) = e^x (sin x + cos x) at the
// same point for comparison.
func main() {
	x0 := 1.0
	// A zero dual part evaluates g without propagating a derivative.
	value := g(Dual{Real: x0, Dual: 0}).Real
	deriv := differentiate(g, x0)
	fmt.Printf("Advanced Example: At x = %v, g(x) = %v and g'(x) = %v\n", x0, value, deriv)
	// Expected derivative: e^x (sin(x) + cos(x)), evaluated at x0 (not a
	// hard-coded 1) so the comparison stays valid if x0 is changed.
	expected := math.Exp(x0) * (math.Sin(x0) + math.Cos(x0))
	fmt.Printf("Expected derivative: %v\n", expected)
}
package main
import (
"fmt"
)
// Dual is a dual number a + bε with ε² = 0. Real stores a; Dual stores b,
// the derivative part that differentiate reads back after seeding it with 1.
type Dual struct {
	Real, Dual float64
}
// Add returns d + other; real and dual parts are summed independently.
func (d Dual) Add(other Dual) Dual {
	var sum Dual
	sum.Real = d.Real + other.Real
	sum.Dual = d.Dual + other.Dual
	return sum
}
// Mul returns d·other = ac + (ad + bc)ε; the ε² term vanishes, which makes
// the dual part follow the product rule.
func (d Dual) Mul(other Dual) Dual {
	var prod Dual
	prod.Real = d.Real * other.Real
	prod.Dual = d.Real*other.Dual + d.Dual*other.Real
	return prod
}
// Pow returns d raised to the integer power n.
//
// For n >= 0 the result is built by repeated multiplication starting from
// 1 + 0ε, so Pow(0) is 1 + 0ε and the dual part follows the product rule.
// For n < 0 the reciprocal 1/(a + bε) = 1/a − (b/a²)ε is raised to |n|;
// previously every negative n silently returned 1 + 0ε. With a zero real
// part and n < 0 the components become ±Inf/NaN, reflecting the pole of
// x^n at 0.
func (d Dual) Pow(n int) Dual {
	result := Dual{Real: 1, Dual: 0}
	if n >= 0 {
		for i := 0; i < n; i++ {
			result = result.Mul(d)
		}
		return result
	}
	inv := Dual{Real: 1 / d.Real, Dual: -d.Dual / (d.Real * d.Real)}
	// Count up from n instead of negating it, so math.MinInt cannot overflow.
	for i := n; i < 0; i++ {
		result = result.Mul(inv)
	}
	return result
}
// f evaluates the quadratic f(x) = x² + 3x + 5 over dual numbers; the
// constants carry a zero dual part because they do not depend on x.
func f(x Dual) Dual {
	square := x.Pow(2)
	linear := Dual{Real: 3, Dual: 0}.Mul(x)
	constant := Dual{Real: 5, Dual: 0}
	return square.Add(linear).Add(constant)
}
// differentiate returns f'(x0) by forward-mode automatic differentiation:
// f is evaluated at x0 + 1ε and the dual part of the result is returned.
func differentiate(f func(Dual) Dual, x0 float64) float64 {
	seed := Dual{Real: x0, Dual: 1}
	out := f(seed)
	return out.Dual
}
func main() {
x0 := 2.0
value := f(Dual{Real: x0, Dual: 0}).Real
deriv := differentiate(f, x0)
fmt.Printf("Simple Example: At x = %v, f(x) = %v and f'(x) = %v\n", x0, value, deriv)
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment