Last active
February 22, 2021 17:24
-
-
Save larswaechter/c2df0bc9b0f15c64220eb5699b25ddf1 to your computer and use it in GitHub Desktop.
A decent introduction to Gradient Descent in Python
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Medium Article: https://larswaechter.medium.com/a-decent-introduction-to-gradient-descent-in-python-846be2e41592 | |
import numpy as np | |
import matplotlib.pyplot as plt | |
import seaborn as sns | |
# The model's output (prediction)
def predict(X, w, b):
    """Return the linear model's prediction b + w * X for input X."""
    return b + w * X
# Calculating the loss using MSE
def loss(X, Y, w, b):
    """Return the mean squared error between the model's predictions and Y."""
    residuals = predict(X, w, b) - Y
    return np.mean(residuals ** 2)
# The partial derivatives of "loss" with respect to w and b
def gradient(X, Y, w, b):
    """Return (d loss / d w, d loss / d b) for the MSE loss.

    The prediction error is computed once and reused for both partial
    derivatives; the original evaluated predict() twice per call.
    """
    error = predict(X, w, b) - Y
    w_gradient = 2 * np.mean(error * X)
    b_gradient = 2 * np.mean(error)
    return (w_gradient, b_gradient)
# Train the model using the Gradient Descent algorithm
def train(X, Y, iterations, lr):
    """Fit a straight line to (X, Y) by gradient descent.

    Parameters: X/Y training arrays, iterations (number of update steps),
    lr (learning rate). Returns (w, b, loss_values) where loss_values[i]
    is the loss *before* the i-th parameter update.
    """
    w = b = 0  # Initial weight and bias
    # Preallocate the loss history: np.append inside the loop copies the
    # whole array on every iteration (accidental O(n^2)).
    loss_values = np.empty(iterations)
    # Gradient Descent iterations
    for i in range(iterations):
        w_gradient, b_gradient = gradient(X, Y, w, b)
        # Keep track of the loss after each iteration
        loss_values[i] = loss(X, Y, w, b)
        w -= w_gradient * lr
        b -= b_gradient * lr
    return w, b, loss_values
# Plot the training data
def plot_data(X, Y, w, b):
    """Scatter-plot the samples and overlay the fitted regression line."""
    sns.set()
    # Style the tick marks and axis labels
    plt.xticks(fontsize=15)
    plt.yticks(fontsize=15)
    plt.xlabel("Weather", fontsize=30)
    plt.ylabel("Visitors", fontsize=30)
    # Fixed viewport for the plot
    x_max, y_max = 40, 800
    plt.axis([0, x_max, 0, y_max])
    # Data points as blue dots
    plt.plot(X, Y, "bo")
    # Regression line from x=0 (intercept b) to the right edge
    plt.plot([0, x_max], [b, predict(x_max, w, b)], linewidth=1.0, color="g")
    plt.show()
# Plot the course of the loss function
def plot_loss(X, Y):
    """Plot the loss as a function of the weight (bias fixed at 0) and mark its minimum."""
    sns.set()
    # Hide y ticks; style x ticks and the axis labels
    plt.yticks([])
    plt.xticks(fontsize=15)
    plt.xlabel("Weight", fontsize=30)
    plt.ylabel("Loss", fontsize=30)
    # Fixed viewport for the curve
    plt.axis([0, 30, 0, 100000])
    # Sample the loss over a dense range of candidate weights
    weights = np.linspace(-50, 50, 10000)
    losses = [loss(X, Y, candidate, 0) for candidate in weights]
    # Draw the curve and mark the sampled minimum with a green X
    plt.plot(weights, losses, color="black")
    best = np.argmin(losses)
    plt.plot(weights[best], losses[best], "gX", markersize=20)
    plt.show()
# Plot the loss curve
def plot_loss_curve(loss_values):
    """Plot the recorded loss per training iteration."""
    sns.set()
    # Style x ticks and the axis labels; hide y ticks
    plt.xticks(fontsize=15)
    plt.yticks([])
    plt.xlabel("Iterations", fontsize=30)
    plt.ylabel("Loss", fontsize=30)
    # One point per iteration, joined into a curve
    plt.plot(loss_values)
    plt.show()
# Load the training samples: first column -> X, second column -> Y
# (skiprows=1 drops the header line of park.txt).
X, Y = np.loadtxt("park.txt", skiprows=1, unpack=True)

# Fit the linear model with 20,000 gradient-descent steps
w, b, loss_values = train(X, Y, iterations=20000, lr=0.001)

# Report the learned parameters and a sample prediction
print("\nw=%.8f; b=%.8f" % (w, b))
x_query = 33
print("Prediction: x=%d => y=%.2f" % (x_query, predict(x_query, w, b)))

# Uncomment to visualize the fit and the loss:
# plot_data(X, Y, w, b)
# plot_loss(X, Y)
# plot_loss_curve(loss_values)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Temperature in °C Number of visitors | |
15 23 | |
18 74 | |
20 65 | |
21 82 | |
23 135 | |
25 321 | |
27 440 | |
28 400 | |
29 290 | |
30 620 | |
32 630 | |
34 610 | |
35 560 | |
36 568 |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment