Skip to content

Instantly share code, notes, and snippets.

@YCAyca
Created December 19, 2021 17:57
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save YCAyca/a2c625dfe288d72dd45c865a4ba56088 to your computer and use it in GitHub Desktop.
Save YCAyca/a2c625dfe288d72dd45c865a4ba56088 to your computer and use it in GitHub Desktop.
import numpy as np
def loss_function(prediction, ground_truth):
    """Return the squared error of a single prediction.

    This is the mean squared error for a batch of size 1.

    Args:
        prediction: Model output for one sample.
        ground_truth: Target value for that sample.

    Returns:
        ``(ground_truth - prediction) ** 2``; symmetric in its two arguments.
    """
    residual = ground_truth - prediction
    return residual ** 2
def prediction(x, current_weights, current_bias):
    """Evaluate the linear model ``sum(current_weights * x) + current_bias``.

    Args:
        x: Input feature vector, multiplied element-wise with the weights.
        current_weights: Weights, one per feature of ``x``.
        current_bias: Bias added to the weighted sum; if it is a one-element
            list/array, the result is a one-element array as well.

    Returns:
        The predicted value.
    """
    weighted_inputs = current_weights * x
    linear_term = np.sum(weighted_inputs)
    return linear_term + current_bias
def gradient_descent(x, y, current_weights, current_bias, learning_rate):
    """Perform one gradient-descent step on the squared error of one sample.

    The model is the linear predictor from ``prediction``
    (``y_hat = sum(w * x) + b``). The loss is ``loss_function``'s
    ``(y - y_hat) ** 2``.

    Args:
        x: Feature vector of a single sample.
        y: Target value for that sample.
        current_weights: Weights, one per feature of ``x``.
        current_bias: Bias term; a scalar or a one-element list/array.
        learning_rate: Step size applied to both gradients.

    Returns:
        Tuple ``(updated_weight, updated_bias)`` after one step.

    Side effects:
        Prints the cost of the pre-update prediction.
    """
    # Making predictions
    y_predicted = prediction(x, current_weights, current_bias)
    # Calculating the current cost; arguments follow the signature
    # loss_function(prediction, ground_truth).
    current_cost = loss_function(y_predicted, y)
    print("current_cost ", current_cost)

    residual = y - y_predicted
    # d/dw_j (y - w.x - b)^2 = -2 * x_j * (y - y_hat): one gradient per weight.
    # Do not sum over x here -- that collapses all features into one scalar
    # and shifts every weight equally, even those whose input x_j is 0.
    # np.asarray guards against list inputs (where -2 * list would repeat it).
    weight_derivative = -2 * (np.asarray(x) * residual)
    # d/db (y - w.x - b)^2 = -2 * (y - y_hat). np.sum reduces a one-element
    # residual to a scalar and, unlike the builtin sum, also accepts a scalar.
    bias_derivative = -2 * np.sum(residual)

    # Updating weights and bias
    updated_weight = current_weights - (learning_rate * weight_derivative)
    updated_bias = current_bias - (learning_rate * bias_derivative)
    return updated_weight, updated_bias
# ---------------------------------------------------------------- Dataset ---
# Four binary feature vectors with their 0/1 targets.
v1, v2, v3, v4 = (
    np.array([1, 1, 0, 0]),
    np.array([1, 0, 0, 0]),
    np.array([0, 0, 0, 1]),
    np.array([0, 0, 1, 1]),
)
inputs = [v1, v2, v3, v4]
ground_truth = [0, 0, 1, 1]

dataset_length = len(inputs)
print(dataset_length)
# --------------------------------------------------------------- Training ---
iteration_number = 10
learning_rate = 0.01
current_weights = [0.09, 0.2, 0.5, 0.95]
# NOTE(review): the bias is a one-element list, so predictions come back as
# shape-(1,) arrays rather than scalars.
current_bias = [0.02]

for i in range(iteration_number):
    # One pass over the data, updating after every sample (batch size 1).
    for sample, target in zip(inputs, ground_truth):
        current_weights, current_bias = gradient_descent(
            sample, target, current_weights, current_bias, learning_rate
        )
    print("Iteration :", i, "current weights", current_weights, "current bias", current_bias)
# ------------------------------------------------------------------- Test ---
# Print the trained model's prediction for every training sample.
for sample in inputs:
    print("data:", sample, "prediction", prediction(sample, current_weights, current_bias))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment