Skip to content

Instantly share code, notes, and snippets.

@zchrissirhcz
Forked from MarisaKirisame/AD.py
Created November 11, 2021 10:26
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save zchrissirhcz/e915777343b9cf18d23f95ae6558159f to your computer and use it in GitHub Desktop.
Save zchrissirhcz/e915777343b9cf18d23f95ae6558159f to your computer and use it in GitHub Desktop.
import math
def sin(x):
    """Return the sine of ``x``, carrying a derivative when ``x`` is a ``Dual``.

    For a ``Dual`` the value part is ``sin(x.x)`` and the tangent part is
    ``cos(x.x) * x.dx`` (chain rule). Any other argument is handed straight
    to ``math.sin``.
    """
    if not isinstance(x, Dual):
        return math.sin(x)
    value = sin(x.x)
    slope = cos(x.x)
    return Dual(value, slope * x.dx)
def cos(x):
    """Return the cosine of ``x``, carrying a derivative when ``x`` is a ``Dual``.

    For a ``Dual`` the value part is ``cos(x.x)`` and the tangent part is
    ``-sin(x.x) * x.dx`` (chain rule). Any other argument is handed straight
    to ``math.cos``.
    """
    if isinstance(x, Dual):
        # The sign is applied with the float -1.0 rather than the int -1:
        # scalar-multiplication methods in this file assert that the scalar
        # is a float, so an int factor fails as soon as sin(x.x) is not a
        # plain number (e.g. when x.x is itself a Dual for a second
        # derivative). For plain floats the result is numerically identical.
        return Dual(cos(x.x), -1.0 * sin(x.x) * x.dx)
    return math.cos(x)
class Dual:
    """Dual number ``(x, dx)``: a value paired with its tangent.

    ``x`` is the primal value and ``dx`` is the derivative information carried
    with it. The methods below only ever add two tangents together or multiply
    a tangent by a scalar, so ``dx`` may be any object supporting those two
    operations.
    """

    def __init__(self, x, dx):
        self.x = x
        self.dx = dx

    def __add__(self, r):
        """Return ``self + r``.

        ``r`` may be another ``Dual`` (values and tangents are added) or a
        plain ``int``/``float`` constant (only the value shifts; the tangent
        is reused unchanged because a constant has zero derivative).
        Any other operand yields ``NotImplemented`` so Python raises
        ``TypeError`` instead of relying on an ``assert`` that ``-O`` strips.
        """
        if isinstance(r, Dual):
            return Dual(self.x + r.x, self.dx + r.dx)
        if isinstance(r, (int, float)):
            return Dual(self.x + r, self.dx)
        return NotImplemented

    def __radd__(self, r):
        """Return ``r + self``; addition is commutative here."""
        return self + r

    def __mul__(self, r):
        """Return ``self * r``.

        For a ``Dual`` operand the product rule is applied. For a plain
        ``int``/``float`` the value and tangent are both scaled. Any other
        operand yields ``NotImplemented``.
        """
        if isinstance(r, Dual):
            return Dual(self.x * r.x, self.x * r.dx + r.x * self.dx)
        if isinstance(r, (int, float)):
            # Normalise to float before touching the tangent, so an int
            # scalar reaches it in the same form a float scalar would.
            scale = float(r)
            return Dual(self.x * scale, scale * self.dx)
        return NotImplemented

    def __rmul__(self, r):
        """Return ``r * self``; multiplication is commutative here."""
        return self * r

    def __repr__(self):
        return repr((self.x, self.dx))
class Raw:
    """Placeholder tangent that carries no derivative information.

    Scaling by a float or adding another ``Raw`` produces a fresh ``Raw``;
    its repr is that of an empty tuple.
    """

    def __repr__(self):
        return repr(tuple())

    def __add__(self, r):
        assert isinstance(r, Raw)
        return Raw()

    def __mul__(self, r):
        assert isinstance(r, float)
        return Raw()

    def __rmul__(self, r):
        # Scalar on the left behaves exactly like scalar on the right.
        return self.__mul__(r)
def f(x, y):
    """Return the pair ``(sin(x) * y, sin(y) + x)``."""
    # The first component is computed before the second, as in the
    # single-expression form, so any recording tangent sees the same order.
    first = sin(x) * y
    second = sin(y) + x
    return first, second
# Demo at the point (x, y) = (1.0, 2.0).
# With Raw tangents only the value parts of the results carry information.
print("raw:", f(Dual(1.0, Raw()), Dual(2.0, Raw())), sep="\n")
# Forward mode: one pass per input, seeding that input's tangent with 1.0
# and the other with 0.0.
print("forward mode:")
print(
    f(Dual(1.0, 1.0), Dual(2.0, 0.0)),
    f(Dual(1.0, 0.0), Dual(2.0, 1.0)),
    sep="\n",
)
class Ref:
    """Mutable cell holding a single value in the attribute ``v``."""

    def __init__(self, v):
        # Kept as a plain attribute so every holder of the cell can read and
        # rebind the same slot.
        self.v = v
class WithBP:
    """Reverse-mode tangent: an adjoint cell plus a shared backprop callable.

    ``rdx`` is a cell whose ``v`` accumulates the adjoint of this node.
    ``bp`` is a cell whose ``v`` is a zero-argument callable. Each operation
    replaces ``bp.v`` with a closure that first pushes the result's adjoint
    into the operand(s) and then invokes the callable that was installed
    before it, so the most recently recorded operation is processed first.
    """

    def __init__(self, rdx, bp):
        self.rdx = rdx
        self.bp = bp

    def __mul__(self, rhs):
        """Record scaling by the float ``rhs`` and return the result node."""
        assert isinstance(rhs, float)
        out_adjoint = Ref(0.0)
        earlier_steps = self.bp.v

        def propagate():
            contribution = out_adjoint.v * rhs
            self.rdx.v = self.rdx.v + contribution
            earlier_steps()

        self.bp.v = propagate
        return WithBP(out_adjoint, self.bp)

    def __rmul__(self, rhs):
        return self.__mul__(rhs)

    def __add__(self, rhs):
        """Record addition of another ``WithBP`` and return the result node.

        The closure is installed on this operand's ``bp`` cell only.
        """
        assert isinstance(rhs, WithBP)
        out_adjoint = Ref(0.0)
        earlier_steps = self.bp.v

        def propagate():
            # Both operands receive the full adjoint of the sum, left first.
            for operand in (self, rhs):
                operand.rdx.v = operand.rdx.v + out_adjoint.v
            earlier_steps()

        self.bp.v = propagate
        return WithBP(out_adjoint, self.bp)
print("reverse mode:")
# One backward pass per output of f. The recorded steps and the input
# adjoint cells are rebuilt for each pass so nothing carries over.
bp = Ref(lambda: ())
x, y = WithBP(Ref(0.0), bp), WithBP(Ref(0.0), bp)
a, b = f(Dual(1.0, x), Dual(2.0, y))
a.dx.rdx.v = 1.0  # seed the adjoint of the first output
a.dx.bp.v()  # run the recorded steps via the cell held by the output tangent
print((x.rdx.v, y.rdx.v))

bp.v = lambda: ()  # drop the recorded steps while keeping the same cell
x, y = WithBP(Ref(0.0), bp), WithBP(Ref(0.0), bp)
a, b = f(Dual(1.0, x), Dual(2.0, y))
b.dx.rdx.v = 1.0  # seed the adjoint of the second output
bp.v()
print((x.rdx.v, y.rdx.v))
class Batched:
    """Fixed-length batch of tangents combined component by component.

    The components are stored, in the order given, as the tuple ``l``.
    Scaling applies the float factor to every component; addition pairs up
    components of two batches of equal length.
    """

    def __init__(self, *l):
        self.l = l

    def __mul__(self, rhs):
        """Return a new batch with every component multiplied by ``rhs``."""
        assert isinstance(rhs, float)
        return Batched(*(component * rhs for component in self.l))

    def __rmul__(self, rhs):
        return self * rhs

    def __add__(self, rhs):
        """Return the component-wise sum of two equally sized batches."""
        assert isinstance(rhs, Batched)
        # zip alone would silently truncate, so the length check stays.
        assert len(self.l) == len(rhs.l)
        return Batched(*(mine + theirs for mine, theirs in zip(self.l, rhs.l)))

    def __repr__(self):
        return repr(self.l)
# Batched forward mode: both seed directions travel through a single call,
# slot 0 seeding x and slot 1 seeding y.
print(
    "batched forward mode:",
    f(Dual(1.0, Batched(1.0, 0.0)), Dual(2.0, Batched(0.0, 1.0))),
    sep="\n",
)
print("batched reverse mode:")
bp = Ref(lambda: ())
# Four adjoint cells sharing one bp cell, created in the order ax, bx, ay, by.
ax, bx, ay, by = (WithBP(Ref(0.0), bp) for _ in range(4))
a, b = f(Dual(1.0, Batched(ax, bx)), Dual(2.0, Batched(ay, by)))
# Slot 0 is seeded on the first output and slot 1 on the second, so a single
# backward run fills ax/ay from the first output and bx/by from the second.
a.dx.l[0].rdx.v = 1.0
b.dx.l[1].rdx.v = 1.0
bp.v()
print(tuple(node.rdx.v for node in (ax, bx, ay, by)))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment