Skip to content

Instantly share code, notes, and snippets.

@anilzeybek
Created August 24, 2023 23:58
Show Gist options
  • Save anilzeybek/3e96b4fc48e59612dc3f56586b233718 to your computer and use it in GitHub Desktop.
Save anilzeybek/3e96b4fc48e59612dc3f56586b233718 to your computer and use it in GitHub Desktop.
NN from Scratch
import math
import random
class Var:
    """A scalar that records the operations applied to it, for reverse-mode autodiff.

    Every arithmetic operation returns a new ``Var`` that remembers its operand(s)
    (``prev_var1`` / ``prev_var2``) and the operation name (``prev_op``). Calling
    ``backward()`` on a result propagates gradients back through that graph and
    accumulates them into ``.grad`` of every node created with ``requires_grad=True``.

    Plain numbers are accepted on either side of ``+ - * / **`` and are wrapped in
    constant ``Var`` nodes (``requires_grad=False``).
    """

    def __init__(self, value, requires_grad=False, prev_var1=None, prev_var2=None, prev_op=None):
        """Create a graph node.

        Args:
            value: Numeric value of this node.
            requires_grad: If true, ``backward()`` accumulates this node's gradient
                into ``self.grad`` (initialised to 0). Only such nodes have ``grad``.
            prev_var1: First operand of the operation that produced this node.
            prev_var2: Second operand, for binary operations.
            prev_op: ``"add"``, ``"mul"``, ``"pow"``, ``"sigmoid"``, or ``None`` for a leaf.
        """
        self.value = value
        self.requires_grad = requires_grad
        if requires_grad:
            self.grad = 0
        self.prev_var1 = prev_var1
        self.prev_var2 = prev_var2
        self.prev_op = prev_op

    def __repr__(self):
        return f"Var(value={self.value!r}, requires_grad={self.requires_grad!r}, prev_op={self.prev_op!r})"

    @staticmethod
    def _wrap(other):
        """Return *other* if it is already a ``Var``, else a constant ``Var`` holding it."""
        return other if isinstance(other, Var) else Var(other)

    def __add__(self, other):
        other = self._wrap(other)
        return Var(self.value + other.value, prev_var1=self, prev_var2=other, prev_op="add")

    def __mul__(self, other):
        other = self._wrap(other)
        return Var(self.value * other.value, prev_var1=self, prev_var2=other, prev_op="mul")

    def __pow__(self, other):
        other = self._wrap(other)
        return Var(self.value**other.value, prev_var1=self, prev_var2=other, prev_op="pow")

    def __rpow__(self, other):
        # Computes ``other ** self`` where ``other`` was the non-Var left operand,
        # so the base is ``other`` and the exponent is ``self``.
        other = self._wrap(other)
        return Var(other.value**self.value, prev_var1=other, prev_var2=self, prev_op="pow")

    def __radd__(self, other):
        return self + other

    def __sub__(self, other):
        return self + (-other)

    def __rsub__(self, other):
        return other + (-self)

    def __neg__(self):
        return self * -1

    def __rmul__(self, other):
        return self * other

    def __truediv__(self, other):
        return self * other**-1

    def __rtruediv__(self, other):
        return other * self**-1

    @staticmethod
    def _sigmoid_value(x):
        """Numerically stable logistic function on a plain number."""
        # Only ever exponentiate a non-positive number, so math.exp cannot overflow.
        if x >= 0:
            return 1 / (1 + math.exp(-x))
        z = math.exp(x)
        return z / (1 + z)

    def sigmoid(self):
        """Return ``1 / (1 + e**-self)`` as a single differentiable ``"sigmoid"`` node."""
        return Var(self._sigmoid_value(self.value), prev_var1=self, prev_op="sigmoid")

    def _local_grads(self):
        """Return ``(operand, d self / d operand)`` pairs for this node's operation.

        Raises:
            ValueError: if ``prev_op`` is not a known operation.
        """
        a, b = self.prev_var1, self.prev_var2
        op = self.prev_op
        if op is None:
            return ()
        if op == "add":
            return ((a, 1), (b, 1))
        if op == "mul":
            return ((a, b.value), (b, a.value))
        if op == "pow":
            # d(a**b)/da = b * a**(b-1). For b == 0 the result is the constant 1,
            # so the derivative is 0 (and evaluating 0 ** -1 would raise).
            d_base = b.value * a.value ** (b.value - 1) if b.value != 0 else 0
            # d(a**b)/db = a**b * ln(a), only real for a > 0; 0 is used otherwise.
            d_exp = a.value**b.value * math.log(a.value) if a.value > 0 else 0
            return ((a, d_base), (b, d_exp))
        if op == "sigmoid":
            s = self._sigmoid_value(a.value)
            return ((a, s * (1 - s)),)
        raise ValueError(f"No such operation: {op!r}")

    def backward(self, current_grad=1):
        """Back-propagate ``current_grad`` (d output / d self) through the graph.

        Adds d output / d node to ``node.grad`` for every reachable node created
        with ``requires_grad=True`` (including ``self``). Gradients accumulate across
        calls, so callers must reset ``.grad`` between optimisation steps.

        Each node is visited once, in reverse topological order. Shared
        sub-expressions (e.g. ``x * x`` or ``y + y``) therefore cost linear rather
        than exponential time, and deep graphs do not hit the recursion limit.

        Raises:
            ValueError: if a reachable node carries an unknown ``prev_op``.
        """
        # Iterative post-order DFS: each node is appended after all of its operands.
        order = []
        visited = set()
        stack = [(self, False)]
        while stack:
            node, operands_done = stack.pop()
            if operands_done:
                order.append(node)
                continue
            if node in visited:
                continue
            visited.add(node)
            stack.append((node, True))
            for operand in (node.prev_var1, node.prev_var2):
                if operand is not None and operand not in visited:
                    stack.append((operand, False))

        # Walk outputs-first so each node's gradient is complete before it is pushed
        # to its operands. Var defines no __eq__, so nodes hash by identity.
        grads = {self: current_grad}
        for node in reversed(order):
            if node not in grads:
                # Reached structurally but not through any differentiated operand.
                continue
            grad = grads[node]
            for operand, local in node._local_grads():
                grads[operand] = grads.get(operand, 0) + grad * local
            if node.requires_grad:
                node.grad += grad
class NN:
    """A 2-2-1 multilayer perceptron with sigmoid activations, built on ``Var``.

    Two hidden units each see both inputs; the output unit sees both hidden
    activations. Every weight and bias is a leaf ``Var`` with
    ``requires_grad=True``, and all of them are collected in ``self.parameters``.
    """

    def __init__(self, rng=None):
        """Initialise every weight and bias uniformly in [0, 1).

        Args:
            rng: Optional object with a ``random()`` method (e.g.
                ``random.Random(seed)``) for reproducible initialisation.
                Defaults to the module-level ``random`` generator.
        """
        draw = random.random if rng is None else rng.random

        def param():
            # One trainable scalar parameter.
            return Var(draw(), requires_grad=True)

        # Draw order matches the attribute order below.
        self.weights1 = [param(), param()]
        self.bias1 = param()
        self.weights2 = [param(), param()]
        self.bias2 = param()
        self.weights3 = [param(), param()]
        self.bias3 = param()
        self.parameters = [
            *self.weights1,
            self.bias1,
            *self.weights2,
            self.bias2,
            *self.weights3,
            self.bias3,
        ]

    def sigmoid(self, x):
        """Apply the logistic function to a ``Var`` (differentiably) or a plain number."""
        if isinstance(x, Var):
            # Reuse Var's own implementation instead of duplicating the formula.
            return x.sigmoid()
        return 1 / (1 + math.e ** (-x))

    def forward(self, p, q):
        """Run the network on inputs ``p`` and ``q`` and return the output ``Var``."""
        out1 = p * self.weights1[0] + q * self.weights1[1] + self.bias1
        out1 = self.sigmoid(out1)
        out2 = p * self.weights2[0] + q * self.weights2[1] + self.bias2
        out2 = self.sigmoid(out2)
        out3 = out1 * self.weights3[0] + out2 * self.weights3[1] + self.bias3
        output = self.sigmoid(out3)
        return output

    def calculate_loss(self, data):
        """Return the summed squared error of the model over *data*.

        Args:
            data: Iterable of rows ``[bit1, bit2, target, ...]``; only the first
                three entries of each row are used.

        Returns:
            A ``Var`` holding the total loss. It is ``Var(0)`` for empty *data*, so
            ``.value`` and ``.backward()`` are always available.
        """
        loss = Var(0)
        for row in data:
            bit1, bit2, real_value = row[0], row[1], row[2]
            prediction = self.forward(bit1, bit2)
            loss = loss + (prediction - real_value) ** 2
        return loss
# XOR truth table; each row is [bit1, bit2, bit1 XOR bit2].
data = [[bit1, bit2, bit1 ^ bit2] for bit1 in (0, 1) for bit2 in (0, 1)]
model = NN()
lr = 1  # step size for plain gradient descent
epochs = 2000
for _ in range(epochs):
    # The loss graph is rebuilt every epoch; its leaves are the model's
    # parameter Vars, which are shared across epochs.
    loss = model.calculate_loss(data)
    print(f"loss: {loss.value}")
    loss.backward()
    for param in model.parameters:
        param.value -= lr * param.grad
        # backward() accumulates into .grad, so clear it before the next epoch.
        param.grad = 0
# Print the trained network's prediction for each XOR input pair.
# The label uses ^ because the training data is the XOR table, not AND.
for bit1, bit2, *_ in data:
    prediction = model.forward(bit1, bit2).value
    print(f"{bit1} ^ {bit2} = {prediction}")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment