Skip to content

Instantly share code, notes, and snippets.

@FisherKK
Created August 13, 2018 10:59
Show Gist options
  • Save FisherKK/7c19abce0bab8c2a90384f632cc71c52 to your computer and use it in GitHub Desktop.
Save FisherKK/7c19abce0bab8c2a90384f632cc71c52 to your computer and use it in GitHub Desktop.
from copy import deepcopy
def train(X, y, model_parameters, step=0.1, iterations=100,
          predict_fn=None, cost_fn=None):
    """Fit ``w`` and ``b`` with a greedy grid-neighbour search.

    Each iteration builds the eight neighbours of the current ``(w, b)``
    obtained by moving ``w`` and/or ``b`` by ``+step``, ``0`` or ``-step``
    (every combination except "no move"). Each neighbour is scored with the
    cost function. If the best neighbour strictly lowers the cost, the
    parameters move there; otherwise they stay put.

    Args:
        X: Input samples; each one is passed to ``predict_fn``. It is
            iterated once per cost evaluation, so it should be a sequence
            rather than a one-shot iterator.
        y: Target values, passed unchanged to ``cost_fn``.
        model_parameters: Dict holding at least ``"w"`` and ``"b"``. It is
            updated IN PLACE. Both values must support ``+= number``
            (presumably scalars -- NOTE(review): confirm against ``predict``).
            Any other keys are carried along untouched.
        step: Amount added to / subtracted from ``w`` and ``b`` per move.
        iterations: Number of search iterations to run.
        predict_fn: ``predict_fn(x, params)`` -> prediction for one sample.
            Defaults to the module-level ``predict``.
        cost_fn: ``cost_fn(predictions, y)`` -> cost to minimise.
            Defaults to the module-level ``mse``.

    Returns:
        The same ``model_parameters`` dict, after training.
    """
    if predict_fn is None:
        predict_fn = predict
    if cost_fn is None:
        cost_fn = mse

    def evaluate(params):
        # Cost of ``params`` over the whole dataset.
        return cost_fn([predict_fn(x, params) for x in X], y)

    # (w direction, b direction) for each neighbour, tried in the order
    # w+/b+, w+/b0, w+/b-, w0/b+, w0/b-, w-/b+, w-/b0, w-/b-.
    # The order matters: on equal costs the first strictly-better neighbour
    # wins. The (0, 0) "no move" neighbour is omitted. Assuming predict_fn
    # and cost_fn are deterministic, its cost equals the current cost, so
    # the strict "<" below could never select it.
    directions = [
        (w_dir, b_dir)
        for w_dir in (1, 0, -1)
        for b_dir in (1, 0, -1)
        if w_dir or b_dir
    ]

    # Cost of the starting parameters (MSE by default).
    lowest_error = evaluate(model_parameters)
    print("\nInitial state:")
    print(" - error: {}".format(lowest_error))
    print(" - parameters: {}".format(model_parameters))

    for i in range(iterations):
        # Every neighbour is built from the parameters as they stood at the
        # start of this iteration; model_parameters is only updated below.
        best_candidate = None
        for w_dir, b_dir in directions:
            candidate = deepcopy(model_parameters)
            # Shift the whole of "w" (never just w[0]) so every neighbour
            # is treated consistently.
            if w_dir:
                candidate["w"] += w_dir * step
            if b_dir:
                candidate["b"] += b_dir * step
            candidate_error = evaluate(candidate)
            if candidate_error < lowest_error:
                lowest_error = candidate_error
                best_candidate = candidate

        # Move to the parameters for which the loss is smallest (if any improved).
        if best_candidate is not None:
            model_parameters["w"] = best_candidate["w"]
            model_parameters["b"] = best_candidate["b"]

        # Display training progress every 20th iteration.
        if i % 20 == 0:
            print("\nIteration {}:".format(i))
            print(" - error: {}".format(lowest_error))
            print(" - parameters: {}".format(model_parameters))

    print("\nFinal state:")
    print(" - error: {}".format(lowest_error))
    print(" - parameters: {}".format(model_parameters))
    return model_parameters
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment