Last active
August 16, 2023 07:00
-
-
Save bmitc/1e9643fabea63caa74d37a30c67e9888 to your computer and use it in GitHub Desktop.
Implementation of dual numbers and automatic differentiation in Wake
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package dual | |
# This is an implementation of automatic differentiation in the Wake programming language | |
# using dual numbers. A dual number data type is implemented and then arithmetic is implemented | |
# for dual numbers. Once that is done, a few trigonometric functions are implemented. From there, | |
# functions can be freely defined by compositions of these operators and functions and then either | |
# evaluated or automatically differentiated. | |
# | |
# The automatic-ness of the differentiation comes from the dual numbers. For a dual number
# a + b*e, where e is an element satisfying e^2 = 0, a differentiable function f (extended to
# dual numbers) satisfies
#     f(a + b*e) = f(a) + b*f'(a)*e
# That is, to get the derivative of a function at a, we simply evaluate it on the dual number
# a + e (i.e. with b set to 1) and then read off the dual part, which is f'(a).
# | |
# For more about the Wake programming language, see: https://github.com/sifive/wake. | |
from wake import _ | |
############################################################ | |
#### Dual number type ###################################### | |
############################################################ | |
# Represents a dual number a + b*e, where `a` is referred to as the `Real` part and `b` as
# the `Dual` part. The element e is defined such that `e^2 = 0`.
# Declaring this tuple provides the constructor `Dual a b` and the field accessors
# `getDualReal` / `getDualDual`, which the helpers below alias as `getReal` / `getDual`.
export tuple Dual = | |
    Real: Double | |
    Dual: Double | |
############################################################ | |
#### Helper constructor and getter functions ############### | |
############################################################ | |
# Gets the `Real` part of a dual number tuple, i.e. `a` in `a + b*e`.
export def getReal = getDualReal
# Gets the `Dual` part of a dual number tuple, i.e. `b` in `a + b*e`.
# When a function is evaluated at `a + 1*e`, this part holds the derivative.
export def getDual = getDualDual
# Prints the dual number `Dual a b` in the textual form `"a + be"`.
# Both components are rendered with `format`.
export def printDual (dual: Dual) =
    def message = "{format dual.getReal} + {format dual.getDual}e"
    println message
# Lifts a plain value `a` to the dual number `a + 0*e`.
# A constant carries no dual (derivative) part.
export def constant (a: Double): Dual =
    Dual a 0.0
# The dual number constant 0, i.e. `0 + 0*e`.
export def dualZero = Dual 0.0 0.0
# The dual number constant 1, i.e. `1 + 0*e`.
export def dualOne = Dual 1.0 0.0
############################################################ | |
#### Dual number arithmetic ################################
############################################################ | |
# Equality of dual numbers: the real parts and the dual parts must both be equal
# under `==.`, so any NaN component makes the comparison false.
export def (x: Dual) ==& (y: Dual) =
    def Dual xReal xDual = x
    def Dual yReal yDual = y
    (xReal ==. yReal) && (xDual ==. yDual)
# Negation: -(a + b*e) = (-a) + (-b)*e.
export def -&(x: Dual) =
    def Dual a b = x
    Dual (-. a) (-. b)
# Addition is component-wise: (a + b*e) + (c + d*e) = (a + c) + (b + d)*e.
export def (x: Dual) +& (y: Dual) =
    def Dual a b = x
    def Dual c d = y
    Dual (a +. c) (b +. d)
# Subtraction is component-wise: (a + b*e) - (c + d*e) = (a - c) + (b - d)*e.
export def (x: Dual) -& (y: Dual) =
    def Dual a b = x
    def Dual c d = y
    Dual (a -. c) (b -. d)
# Multiplication: (a + b*e) * (c + d*e) = a*c + (a*d + b*c)*e.
# The b*d*e^2 term vanishes because e^2 = 0, which is the product rule.
export def (x: Dual) *& (y: Dual) =
    def Dual a b = x
    def Dual c d = y
    def realPart = a *. c
    def dualPart = a *. d +. b *. c
    Dual realPart dualPart
# Division: (a + b*e) / (c + d*e) = a/c + ((b*c - a*d) / c^2)*e, i.e. the quotient rule.
# Dividing by a dual number whose real part is 0 yields non-finite components.
export def (x: Dual) /& (y: Dual) =
    def Dual numerator numeratorDual = x
    def Dual denominator denominatorDual = y
    def realPart = numerator /. denominator
    def dualPart =
        (numeratorDual *. denominator -. numerator *. denominatorDual) /. (denominator ^. 2.0)
    Dual realPart dualPart
# Raises a dual number to an integer power using the power rule:
#     (a + b*e)^n = a^n + n*a^(n-1)*b*e
# For n = 0 the result is the constant 1 (derivative 0). That case is handled
# explicitly because evaluating n*a^(n-1) at a = 0 would compute 0 * 0^-1 = 0 * inf,
# which is NaN instead of 0.
export def (x: Dual) ^& (n: Integer) =
    def Dual real dual = x
    def m = dint n
    if n == 0 then
        Dual 1.0 0.0
    else
        Dual (real ^. m) (dual *. m *. (real ^. (m -. 1.0)))
############################################################ | |
#### Dual number trigonometric functions ################### | |
############################################################ | |
# Sine of a dual number: sin(a + b*e) = sin a + b*cos(a)*e.
export def dualSin (x: Dual) =
    def Dual a b = x
    Dual (dsin a) (b *. dcos a)
# Cosine of a dual number: cos(a + b*e) = cos a + b*(-sin a)*e.
export def dualCos (x: Dual) =
    def Dual a b = x
    def negatedSine = -. dsin a
    Dual (dcos a) (b *. negatedSine)
# Secant of an angle given in radians: sec x = 1 / cos x.
export def dsec (radians: Double): Double =
    def cosine = dcos radians
    1.0 /. cosine
# Tangent of a dual number: tan(a + b*e) = tan a + b*sec(a)^2*e.
export def dualTan (x: Dual) =
    def Dual a b = x
    def secant = dsec a
    # `^.` binds tighter than `*.`; the parentheses only make that explicit.
    Dual (dtan a) (b *. (secant ^. 2.0))
# Secant of a dual number, sec x = 1 / cos x.
# Computed with dual-number division so the derivative comes from the quotient rule.
export def dualSec (x: Dual) =
    def cosine = dualCos x
    dualOne /& cosine
# Cosecant of a dual number, csc x = 1 / sin x.
# Computed with dual-number division so the derivative comes from the quotient rule.
export def dualCsc (x: Dual) =
    def sine = dualSin x
    dualOne /& sine
# Cotangent of a dual number, cot x = 1 / tan x.
# Computed with dual-number division so the derivative comes from the quotient rule.
export def dualCot (x: Dual) =
    def tangent = dualTan x
    dualOne /& tangent
############################################################ | |
#### Evaluation and automatic differentiation ############## | |
############################################################ | |
# Evaluates the dual-number function `f` at the real point `a`.
# It applies `f` to the dual number a + 0*e (a pure constant) and returns the real
# part of the result, which equals f(a).
export def evaluate (f: Dual => Dual) (a: Double): Double =
    def result = f (Dual a 0.0)
    getReal result
# Computes the derivative of the dual-number function `f` at the real point `a`.
# It applies `f` to the dual number a + 1*e and returns the dual part of the result,
# which equals f'(a).
export def differentiate (f: Dual => Dual) (a: Double): Double =
    def result = f (Dual a 1.0)
    getDual result
############################################################ | |
#### Helper functions ###################################### | |
############################################################ | |
# Builds an ascending list of doubles start, start + increment, start + 2*increment, ...
# containing every such value that does not exceed `end`.
#
# Each element is computed directly as start + i*increment instead of by repeatedly
# adding `increment`, so rounding error does not accumulate along the range. The
# comparison against `end` allows a tolerance of one billionth of `increment`, so an
# endpoint that is only missed through rounding (e.g. `drange 0.0 0.1 0.3`, where
# 3 * 0.1 = 0.30000000000000004) is still included.
#
# Returns `Nil` when start > end, when `increment` is not positive (such a step could
# never reach `end`), or when any argument is NaN. Previously the non-positive and NaN
# cases never terminated.
export def drange (start: Double) (increment: Double) (end: Double): List Double =
    def limit = end +. increment /. 1000000000.0
    # Accumulates elements in reverse order; `i` is the index of the next element.
    def helper (i: Integer) (accumulatedList: List Double): List Double =
        def x = start +. dint i *. increment
        if x >. limit then
            accumulatedList
        else
            helper (i + 1) (x, accumulatedList)
    # Both comparisons are false for NaN, so NaN arguments fall through to Nil.
    if increment >. 0.0 && start <=. limit then
        reverse (helper 0 Nil)
    else
        Nil
############################################################
#### Examples ##############################################
############################################################
# The sinc function, defined by sinc x = (sin x) / x for x != 0 and extended continuously
# by sinc 0 = 1. At a real part of exactly 0 the quotient would be 0/0 = NaN, so that
# point is special-cased. Since sinc'(0) = 0, the dual part there is 0.
def sinc (x: Dual) =
    if x.getReal ==. 0.0 then
        Dual 1.0 0.0
    else
        dualSin x /& x
# Sample points for the sinc example: from -10.0 up to 10.0 in steps of 0.1 (see `drange`).
def sincRange = drange (-. 10.0) 0.1 10.0
# Writes `sinc.csv` with one "x,sinc x,sinc' x" row per sample point in `sincRange`.
# The derivative column is computed automatically by `differentiate`. Import the file
# into a spreadsheet such as Microsoft Excel or Google Sheets and plot it to see
# f(x) = sinc x alongside its derivative.
export def sincData =
    def toRow x = "{format x},{format (evaluate sinc x)},{format (differentiate sinc x)}"
    def rows = map toRow sincRange
    def contents = catWith "\n" rows
    write "sinc.csv" contents
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment