Skip to content

Instantly share code, notes, and snippets.

@bmitc
Last active August 16, 2023 07:00
Show Gist options
  • Save bmitc/1e9643fabea63caa74d37a30c67e9888 to your computer and use it in GitHub Desktop.
Save bmitc/1e9643fabea63caa74d37a30c67e9888 to your computer and use it in GitHub Desktop.
Implementation of dual numbers and automatic differentiation in Wake
package dual
# This is an implementation of automatic differentiation in the Wake programming language
# using dual numbers. A dual number data type is implemented and then arithmetic is implemented
# for dual numbers. Once that is done, a few trigonometric functions are implemented. From there,
# functions can be freely defined by compositions of these operators and functions and then either
# evaluated or automatically differentiated.
#
# The automatic-ness of the differentiation comes from the dual numbers. It is a fact that for a
# dual number, a + b*e, where e is defined by being an element such that e^2 = 0, then
# f(a + b*e) = f(a) + b*f'(a)*e
# That is, to get the derivative of a function at a, we simply evaluate it on the dual number a + e,
# and then read off the dual part to get f'(a). In other words, b is set to 1.
#
# For more about the Wake programming language, see: https://github.com/sifive/wake.
from wake import _
############################################################
#### Dual number type ######################################
############################################################
# Represents a dual number a + b*e, where a is referred to as the `Real` part and b as
# the `Dual` part. The element e is defined such that `e^2 = 0`.
export tuple Dual =
    # The real part `a` of `a + b*e`
    Real: Double
    # The dual part `b` of `a + b*e`, i.e. the coefficient of `e`
    Dual: Double
############################################################
#### Helper constructor and getter functions ###############
############################################################
# Gets the `Real` part of a dual number tuple (alias for `getDualReal`)
export def getReal = getDualReal
# Gets the `Dual` part of a dual number tuple (alias for `getDualDual`)
export def getDual = getDualDual
# Prints the dual number `Dual a b` on its own line, rendered as `"a + be"`.
export def printDual (dual: Dual) =
    def rendered = "{format dual.getReal} + {format dual.getDual}e"
    println rendered
# Lifts the plain value `a` to the dual number `a + 0*e`, i.e. a constant whose
# dual part is zero.
export def constant (a: Double): Dual =
    Dual a 0.0
# The dual number constant 0, i.e. `0 + 0*e`
export def dualZero: Dual =
    Dual 0.0 0.0
# The dual number constant 1, i.e. `1 + 0*e`
export def dualOne: Dual =
    Dual 1.0 0.0
############################################################
#### Dual number arithmetic ################################
############################################################
# Equality of dual numbers: both the real parts and the dual parts must be equal.
export def (x: Dual) ==& (y: Dual) =
    def Dual a b = x
    def Dual c d = y
    (a ==. c) && (b ==. d)
# Negation of a dual number: -(a + b*e) = (-a) + (-b)*e
export def -&(x: Dual) =
    def Dual a b = x
    Dual (-. a) (-. b)
# Addition of dual numbers: (a + b*e) + (c + d*e) = (a + c) + (b + d)*e
export def (x: Dual) +& (y: Dual) =
    def Dual a b = x
    def Dual c d = y
    Dual (a +. c) (b +. d)
# Subtraction of dual numbers: (a + b*e) - (c + d*e) = (a - c) + (b - d)*e
export def (x: Dual) -& (y: Dual) =
    def Dual a b = x
    def Dual c d = y
    Dual (a -. c) (b -. d)
# Multiplication of dual numbers. Since e^2 = 0, the product is
# (a + b*e) * (c + d*e) = a*c + (a*d + b*c)*e, which is the product rule.
export def (x: Dual) *& (y: Dual) =
    def Dual a b = x
    def Dual c d = y
    Dual (a *. c) (a *. d +. b *. c)
# Division of dual numbers:
# (a + b*e) / (c + d*e) = a/c + ((b*c - a*d) / c^2)*e, which is the quotient rule.
export def (x: Dual) /& (y: Dual) =
    def numReal = x.getReal
    def numDual = x.getDual
    def denReal = y.getReal
    def denDual = y.getDual
    def realPart = numReal /. denReal
    def dualPart = (numDual *. denReal -. numReal *. denDual) /. (denReal ^. 2.0)
    Dual realPart dualPart
# Raises a dual number to an integer power using the power rule:
# (a + b*e)^n = a^n + (b * n * a^(n-1))*e
#
# The exponent 0 is handled explicitly. x^0 is the constant 1, whose derivative is 0,
# but the general formula would compute `b * 0 * a^(-1)`, which is `0 * inf = NaN`
# when the real part a is 0.
export def (x: Dual) ^& (n: Integer) =
    def Dual real dual = x
    def m = dint n
    if n == 0 then
        Dual 1.0 0.0
    else
        Dual (real ^. m) (dual *. m *. (real ^. (m -. 1.0)))
############################################################
#### Dual number trigonometric functions ###################
############################################################
# Sine of a dual number: sin(a + b*e) = sin a + (b * cos a)*e
export def dualSin (x: Dual) =
    def Dual a b = x
    Dual (dsin a) (b *. dcos a)
# Cosine of a dual number: cos(a + b*e) = cos a + (b * (-sin a))*e
export def dualCos (x: Dual) =
    def Dual a b = x
    def negatedSine = -. dsin a
    Dual (dcos a) (b *. negatedSine)
# Secant of a plain double, computed as the reciprocal of the cosine.
export def dsec (radians: Double): Double =
    def cosine = dcos radians
    1.0 /. cosine
# Tangent of a dual number: tan(a + b*e) = tan a + (b * sec^2 a)*e
export def dualTan (x: Dual) =
    def Dual a b = x
    def secant = dsec a
    Dual (dtan a) (b *. secant ^. 2.0)
# Secant of a dual number, computed as the dual reciprocal of the dual cosine.
export def dualSec (x: Dual) =
    def cosine = dualCos x
    dualOne /& cosine
# Cosecant of a dual number, computed as the dual reciprocal of the dual sine.
export def dualCsc (x: Dual) =
    def sine = dualSin x
    dualOne /& sine
# Cotangent of a dual number, computed as the dual reciprocal of the dual tangent.
export def dualCot (x: Dual) =
    def tangent = dualTan x
    dualOne /& tangent
############################################################
#### Evaluation and automatic differentiation ##############
############################################################
# Evaluates the given dual number function at the point a. The function is applied
# to the dual number a + 0*e and the real part of the result is returned.
export def evaluate (f: Dual => Dual) (a: Double): Double =
    Dual a 0.0
    | f
    | getReal
# Differentiates the given dual number function at the point a. The function is
# applied to the dual number a + 1*e and the dual part of the result, which is
# f'(a), is returned.
export def differentiate (f: Dual => Dual) (a: Double): Double =
    Dual a 1.0
    | f
    | getDual
############################################################
#### Helper functions ######################################
############################################################
# Builds a range, as a list of doubles, that begins at `start` and repeatedly adds
# `increment` for as long as the running value does not exceed `end`. The values are
# returned in ascending order. `end` itself is included only if the running sum
# lands on it, or below it, after floating-point rounding.
#
# `increment` must be strictly positive. A zero, negative or NaN increment could
# never carry the running value past `end`, so the empty list is returned instead
# of recursing forever.
export def drange (start: Double) (increment: Double) (end: Double) =
    # Accumulates values in reverse (newest first) so each step is a cheap prepend.
    def helper state accumulatedList: List Double =
        match state
            state if state >. end -> accumulatedList
            x -> helper (x +. increment) (x, accumulatedList)
    # `>.` is false for NaN, so a NaN increment also takes the empty-list branch.
    if increment >. 0.0 then
        reverse (helper start Nil)
    else
        Nil
############################################################
#### Examples ##############################################
############################################################
# The sinc function, sinc x = (sin x) / x, expressed on dual numbers so that it can
# be both evaluated and automatically differentiated. The division is carried out
# as written, so a real part of exactly 0 yields 0/0.
def sinc (x: Dual) =
    def sine = dualSin x
    sine /& x
# Sample points for the sinc example: from -10 up towards 10 in steps of 0.1.
def sincRange =
    def lower = -. 10.0
    def step = 0.1
    def upper = 10.0
    drange lower step upper
# Writes `sinc.csv` with one line per sample point in the form `x,sinc x,sinc' x`.
# Import this CSV file into something like Microsoft Excel or Google Sheets and plot it
# to see the function f(x) = sinc x plotted alongside its derivative, which has been
# computed automatically here.
export def sincData =
    # Renders a single sample point as one CSV row.
    def toCsvRow x = "{format x},{format (evaluate sinc x)},{format (differentiate sinc x)}"
    def rows = map toCsvRow sincRange
    write "sinc.csv" (catWith "\n" rows)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment