Navigation Menu

Skip to content

Instantly share code, notes, and snippets.

@larswaechter
Last active February 22, 2021 17:24
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save larswaechter/c2df0bc9b0f15c64220eb5699b25ddf1 to your computer and use it in GitHub Desktop.
Save larswaechter/c2df0bc9b0f15c64220eb5699b25ddf1 to your computer and use it in GitHub Desktop.
A decent introduction to Gradient Descent in Python
# Medium Article: https://larswaechter.medium.com/a-decent-introduction-to-gradient-descent-in-python-846be2e41592
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
# The model's output (prediction)
def predict(X, w, b):
    """Return the linear model's output w * X + b.

    Works element-wise when X is a numpy array, or on a plain scalar.
    """
    return w * X + b
# Calculating the loss using MSE
def loss(X, Y, w, b):
    """Mean squared error between the model's predictions and the targets Y."""
    errors = predict(X, w, b) - Y
    return np.mean(errors ** 2)
# The partial derivatives of "loss" with respect to w and b
def gradient(X, Y, w, b):
    """Partial derivatives of the MSE loss with respect to w and b.

    Args:
        X, Y: training inputs and targets (numpy arrays).
        w, b: current weight and bias.

    Returns:
        A (w_gradient, b_gradient) tuple.
    """
    # Compute the prediction error once; the original evaluated
    # predict(X, w, b) - Y separately for each partial derivative.
    error = predict(X, w, b) - Y
    w_gradient = 2 * np.mean(error * X)
    b_gradient = 2 * np.mean(error)
    return (w_gradient, b_gradient)
# Train the model using Gradient Descents algorithm
def train(X, Y, iterations, lr):
    """Fit w and b to (X, Y) using gradient descent.

    Args:
        X, Y: training inputs and targets (numpy arrays).
        iterations: number of gradient-descent steps.
        lr: learning rate.

    Returns:
        (w, b, loss_values) where loss_values is a numpy array holding the
        loss measured before each update (one entry per iteration).
    """
    w = b = 0  # Initial weight and bias
    # Accumulate losses in a Python list: the original used np.append,
    # which copies the entire array on every call — O(iterations^2) overall.
    loss_values = []
    for _ in range(iterations):
        w_gradient, b_gradient = gradient(X, Y, w, b)
        # Record the loss at the current parameters, before updating them
        loss_values.append(loss(X, Y, w, b))
        # Step against the gradient, scaled by the learning rate
        w -= w_gradient * lr
        b -= b_gradient * lr
    return w, b, np.array(loss_values)
# Plot the training data
def plot_data(X, Y, w, b):
    """Scatter-plot the training points and overlay the fitted line."""
    sns.set()
    # Fixed visible ranges for both axes
    x_max, y_max = 40, 800
    plt.axis([0, x_max, 0, y_max])
    # Axis labels and tick sizing
    plt.xlabel("Weather", fontsize=30)
    plt.ylabel("Visitors", fontsize=30)
    plt.xticks(fontsize=15)
    plt.yticks(fontsize=15)
    # Raw data as blue dots
    plt.plot(X, Y, "bo")
    # The fitted line runs from the intercept at x=0 to the right edge
    line_x = [0, x_max]
    line_y = [b, predict(x_max, w, b)]
    plt.plot(line_x, line_y, linewidth=1.0, color="g")
    plt.show()
# Plot the course of the loss function
def plot_loss(X, Y):
    """Plot the loss as a function of w (with b fixed at 0) and mark its minimum."""
    sns.set()
    # Axis labels; y ticks are hidden since only the curve's shape matters
    plt.xlabel("Weight", fontsize=30)
    plt.ylabel("Loss", fontsize=30)
    plt.xticks(fontsize=15)
    plt.yticks([])
    plt.axis([0, 30, 0, 100000])
    # Sweep the weight over a dense range and evaluate the loss at each value
    weights = np.linspace(-50, 50, 10000)
    losses = [loss(X, Y, w, 0) for w in weights]
    plt.plot(weights, losses, color="black")
    # Highlight the weight that yields the smallest loss
    best = np.argmin(losses)
    plt.plot(weights[best], losses[best], "gX", markersize=20)
    plt.show()
# Plot the loss curve
def plot_loss_curve(loss_values):
    """Plot the recorded training loss against the iteration number."""
    sns.set()
    # Label the axes; y ticks hidden — only the downward trend is of interest
    plt.xlabel("Iterations", fontsize=30)
    plt.ylabel("Loss", fontsize=30)
    plt.xticks(fontsize=15)
    plt.yticks([])
    plt.plot(loss_values)
    plt.show()
# Load the training set: park.txt has one header row, then two columns
# (temperature, visitor count) that unpack into X and Y respectively.
X, Y = np.loadtxt("park.txt", skiprows=1, unpack=True)
# Fit the linear model with gradient descent
w, b, loss_values = train(X, Y, iterations=20000, lr=0.001)
# Report the learned parameters and a sample prediction at x=33
print("\nw=%.8f; b=%.8f" % (w, b))
print("Prediction: x=%d => y=%.2f" % (33, predict(33, w, b)))
# Optional visualizations, disabled by default:
# plot_data(X, Y, w, b)
# plot_loss(X, Y)
# plot_loss_curve(loss_values)
Temperature in °C Number of visitors
15 23
18 74
20 65
21 82
23 135
25 321
27 440
28 400
29 290
30 620
32 630
34 610
35 560
36 568
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment