Skip to content

Instantly share code, notes, and snippets.

@Koziev
Created December 4, 2018 13:55
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save Koziev/d95b7302529a1467b08cfc6037ea0724 to your computer and use it in GitHub Desktop.
Save Koziev/d95b7302529a1467b08cfc6037ea0724 to your computer and use it in GitHub Desktop.
# -*- coding: utf-8 -*-
"""
Использование автоматического дифференцирования autograd (https://github.com/HIPS/autograd)
для решения линейной регрессии МНК.
Код может решать только задачу линейно регрессии, так как
в нем отдельно выписывается градиентный спуск по каждому из двух
компонентов решения через частные производные.
"""
from __future__ import print_function
import random
import autograd.numpy as np
from autograd import grad
def calc_y(x):
y = 0.3 + 2.0*x + random.gauss(mu=0.0, sigma=0.000001)
return y
def loss(a, b, x_data, y_data):
#y_pred = np.add(np.multiply(x_data, b), a)
y_pred = a + b * x_data
return np.sqrt(((y_pred - y_data) ** 2).mean(axis=None))
# Сформируем слегка зашумленный датасет
nb_samples = 100
x_data = np.linspace(start=0.0, stop=99.0, num=nb_samples)
y_data = np.array(list(map(calc_y, x_data)))
# Частные производные функции потерь по двум подбираемым параметрам соответственно.
a_grad = grad(loss, 0)
b_grad = grad(loss, 1)
# начальное приближение для решения
a = 0.0
b = 0.0
learning_rate0 = 0.01 # начальное значения скорости обучения
learning_rate = learning_rate0
lr_decay = 0.9999 # для постепенно уменьшения скорости обучения
min_lr = 1e-5
tolerance = 1e-3 # при достижении такого значения функции потерь итерации прекращаем
iter = 0
while True:
iter += 1
da = a_grad(a, b, x_data, y_data)
db = b_grad(a, b, x_data, y_data)
a = a - da*learning_rate
b = b - db*learning_rate
cur_loss = loss(a, b, x_data, y_data)
print('='*30)
print('i={} a={} b={} loss={} da={} db={}'.format(iter, a, b, cur_loss, da, db))
learning_rate = max(min_lr, learning_rate * lr_decay)
print('learning_rate={:8.6f}'.format(learning_rate))
if cur_loss <= tolerance:
print('Current loss={} is smaller than tolerance={}, optimization complete.'.format(cur_loss, tolerance))
break
print('Finish: a={} b={}'.format(a, b))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment