Last active
June 8, 2020 16:17
-
-
Save bhaprayan/64a0a53a2cbb4949773ea4abd7c4d2f9 to your computer and use it in GitHub Desktop.
computing derivatives
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""
function call example:
ret = deriv(dz17, z15, z16, "+")
dz15 += ret[0]
dz16 += ret[1]
"""
def sigdx(x):
    """Return ``x * (1 - x)``.

    This equals the derivative of the logistic sigmoid when ``x`` is already
    the sigmoid's *output* (``s'(z) = s(z) * (1 - s(z))``); it is not the
    derivative with respect to a raw pre-activation value.

    Works element-wise on anything supporting ``*`` and ``-`` (scalars or
    arrays).
    """
    return x * (1 - x)
def tanhdx(x):
    """Return ``1 - x**2``.

    This equals the derivative of ``tanh`` when ``x`` is already the tanh
    *output* (``tanh'(z) = 1 - tanh(z)**2``); it is not the derivative with
    respect to a raw pre-activation value.

    Works element-wise on anything supporting ``**`` and ``-`` (scalars or
    arrays).
    """
    return 1 - (x**2)
def deriv(dz, arg1, arg2, op):
    """Back-propagate the upstream gradient ``dz`` through one primitive op.

    Parameters
    ----------
    dz : upstream gradient of the op's output.
    arg1, arg2 : the operands the op was applied to in the forward pass.
        For the unary ops ("tanh", "sigmoid") ``arg1`` is the
        *pre-activation* input and ``arg2`` is ignored.
    op : one of "*", "@", "+", "-", "tanh", "sigmoid".

    Returns
    -------
    For binary ops, a 2-tuple ``(grad_arg1, grad_arg2)``; for unary ops, the
    single gradient with respect to ``arg1``.

    Raises
    ------
    ValueError
        If ``op`` is not one of the supported operations.

    Notes
    -----
    The ``.T`` / ``reshape(-1, 1).T`` calls suggest ``dz`` is a row vector
    and the operands are column vectors -- TODO confirm against the caller.
    For "@", the first result is ``arg2 @ dz``; depending on the caller's
    shapes it may be the transpose of ``arg1``'s layout -- verify.
    """
    if op == "*":
        # Element-wise product: each operand's gradient is dz times the other.
        return dz * arg2.T, dz * arg1.T
    if op == "@":
        return arg2 @ dz, dz @ arg1
    if op == "+":
        return dz, dz
    if op == "-":
        return dz, -dz
    if op in ("tanh", "sigmoid"):
        # Local import: the activation must be evaluated on the
        # pre-activation input before its derivative can be formed.
        import numpy as np

        pre = np.asarray(arg1)
        if op == "tanh":
            out = np.tanh(pre)
            local_grad = 1 - out ** 2
        else:
            # sigmoid(x) == 0.5 * (1 + tanh(x / 2)); this form does not
            # overflow for large |x| the way 1 / (1 + exp(-x)) can.
            out = 0.5 * (1.0 + np.tanh(0.5 * pre))
            local_grad = out * (1 - out)
        # Lay the local gradient out as a row so it lines up with dz.
        return dz * local_grad.reshape(-1, 1).T
    raise ValueError("unsupported op: {!r}".format(op))
"""
complete gru cell backward pass:
dz17 = delta
ret = deriv(dz17, z15, z16, "+")
dz15 += ret[0]
dz16 += ret[1]
ret = deriv(dz16, z4, z13, "*")
dz4 += ret[0]
dz13 += ret[1]
ret = deriv(dz15, z14, h, "*")
dz14 += ret[0]
dh += ret[1]
ret = deriv(dz14, 1, z4, "-")
dz4 += ret[1]
ret = deriv(dz13, z12, 0, "tanh")
dz12 += ret
ret = deriv(dz12, z10, z11, "+")
dz10 += ret[0]
dz11 += ret[1]
ret = deriv(dz11, Wx, x, "@")
dWx += ret[0]
dx += ret[1]
ret = deriv(dz10, Wh, z9, "@")
dWh += ret[0]
dz9 += ret[1]
ret = deriv(dz9, z8, h, "*")
dz8 += ret[0]
dh += ret[1]
ret = deriv(dz8, z7, 0, "sigmoid")
dz7 += ret
ret = deriv(dz7, z5, z6, "+")
dz5 += ret[0]
dz6 += ret[1]
ret = deriv(dz6, Wrx, x, "@")
dWrx += ret[0]
dx += ret[1]
ret = deriv(dz5, Wrh, h, "@")
dWrh += ret[0]
dh += ret[1]
ret = deriv(dz4, z3, 0, "sigmoid")
dz3 += ret
ret = deriv(dz3, z1, z2, "+")
dz1 += ret[0]
dz2 += ret[1]
ret = deriv(dz2, Wzx, x, "@")
dWzx += ret[0]
dx += ret[1]
ret = deriv(dz1, Wzh, h, "@")
dWzh += ret[0]
dh += ret[1]
ret = deriv(dz17, z15, z16, "+")
dz15 += ret[0]
dz16 += ret[1]
ret = deriv(dz8, z7, 0, "sigmoid")
dz7 += ret
"""
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment