Skip to content

Instantly share code, notes, and snippets.

@bhaprayan
Last active June 8, 2020 16:17
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save bhaprayan/64a0a53a2cbb4949773ea4abd7c4d2f9 to your computer and use it in GitHub Desktop.
Save bhaprayan/64a0a53a2cbb4949773ea4abd7c4d2f9 to your computer and use it in GitHub Desktop.
computing derivatives
"""
Example call, backpropagating dz17 through the node z17 = z15 + z16:
ret = deriv(dz17, z15, z16, "+")
dz15 += ret[0]
dz16 += ret[1]
"""
def sigdx(x):
    """Return ``x * (1 - x)``.

    If ``x`` is a sigmoid output ``sigmoid(u)``, this equals the sigmoid
    derivative ``d sigmoid(u) / du`` expressed in terms of that output.
    """
    complement = 1 - x
    return x * complement
def tanhdx(x):
    """Return ``1 - x**2``.

    If ``x`` is a tanh output ``tanh(u)``, this equals the tanh derivative
    ``d tanh(u) / du`` expressed in terms of that output.
    """
    squared = x ** 2
    return 1 - squared
def deriv(dz, arg1, arg2, op):
    """Backpropagate the upstream gradient ``dz`` through one graph node.

    The node computes ``z = arg1 <op> arg2`` for the binary ops, or
    ``z = op(arg1)`` for the activations ("tanh", "sigmoid").  The
    transposes and ``reshape(-1, 1)...T`` below assume a row-gradient
    convention: ``dz`` is a row and operands are columns/1-D -- TODO confirm
    against the caller's shapes.

    Args:
        dz: Upstream gradient dL/dz.
        arg1: Left operand, or the *input* (pre-activation) of an activation.
        arg2: Right operand; ignored for "tanh" and "sigmoid".
        op: One of "*", "@", "+", "-", "tanh", "sigmoid".

    Returns:
        For binary ops, a tuple ``(grad_arg1, grad_arg2)``.  For activations,
        the single gradient with respect to ``arg1``.

    Raises:
        ValueError: If ``op`` is not one of the supported operations.
    """
    import numpy as np  # local import: the file has no module-level imports

    if op == "*":
        # Elementwise product: each operand's gradient is dz times the other.
        return dz * arg2.T, dz * arg1.T
    elif op == "@":
        # z = arg1 @ arg2.  For a row dz and column arg2, ``arg2 @ dz`` is the
        # arg1 gradient laid out like arg1.T (numerator layout) and
        # ``dz @ arg1`` is the row gradient for arg2.
        return arg2 @ dz, dz @ arg1
    elif op == "+":
        return dz, dz
    elif op == "-":
        return dz, -dz
    elif op == "tanh":
        # d tanh(u)/du = 1 - tanh(u)**2, evaluated at the input u = arg1.
        # (The previous code squared tanhdx(arg1), i.e. 1 - (1 - u**2)**2.)
        t = np.tanh(np.asarray(arg1)).reshape(-1, 1)
        return dz * (1 - t ** 2).T
    elif op == "sigmoid":
        # d sigmoid(u)/du = s * (1 - s) with s = sigmoid(arg1).
        # sigmoid(u) == 0.5 * (1 + tanh(u / 2)) exactly, and this form cannot
        # overflow the way exp(-u) can for large negative u.
        s = (0.5 * (1 + np.tanh(0.5 * np.asarray(arg1)))).reshape(-1, 1)
        return dz * (s * (1 - s)).T
    raise ValueError(f"unsupported op: {op!r}")
"""
Complete GRU cell backward pass, starting from the upstream gradient delta (gradients are accumulated with +=):
dz17 = delta
ret = deriv(dz17, z15, z16, "+")
dz15 += ret[0]
dz16 += ret[1]
ret = deriv(dz16, z4, z13, "*")
dz4 += ret[0]
dz13 += ret[1]
ret = deriv(dz15, z14, h, "*")
dz14 += ret[0]
dh += ret[1]
ret = deriv(dz14, 1, z4, "-")
dz4 += ret[1]
ret = deriv(dz13, z12, 0, "tanh")
dz12 += ret
ret = deriv(dz12, z10, z11, "+")
dz10 += ret[0]
dz11 += ret[1]
ret = deriv(dz11, Wx, x, "@")
dWx += ret[0]
dx += ret[1]
ret = deriv(dz10, Wh, z9, "@")
dWh += ret[0]
dz9 += ret[1]
ret = deriv(dz9, z8, h, "*")
dz8 += ret[0]
dh += ret[1]
ret = deriv(dz8, z7, 0, "sigmoid")
dz7 += ret
ret = deriv(dz7, z5, z6, "+")
dz5 += ret[0]
dz6 += ret[1]
ret = deriv(dz6, Wrx, x, "@")
dWrx += ret[0]
dx += ret[1]
ret = deriv(dz5, Wrh, h, "@")
dWrh += ret[0]
dh += ret[1]
ret = deriv(dz4, z3, 0, "sigmoid")
dz3 += ret
ret = deriv(dz3, z1, z2, "+")
dz1 += ret[0]
dz2 += ret[1]
ret = deriv(dz2, Wzx, x, "@")
dWzx += ret[0]
dx += ret[1]
ret = deriv(dz1, Wzh, h, "@")
dWzh += ret[0]
dh += ret[1]
(Note: the z17 "+" step and the z8 "sigmoid" step are each performed once above.
Repeating them here would add to dz15, dz16 and dz7 a second time, double-counting those gradients.)
"""
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment