# PyTorch反向傳播求梯度 (PyTorch: computing gradients with backpropagation)
# Gist: JustinSDK/5205e4583e981a8f51a821a8be3eda39 (last active August 11, 2021 07:03)
# Note: this file contains non-ASCII (Traditional Chinese) text; view it in a Unicode-aware editor.
import torch
import cv2
import matplotlib.pyplot as plt
def model(x, w, b):
    """Linear model: scale ``x`` by the weight ``w`` and shift by the bias ``b``."""
    scaled = w * x
    return scaled + b
def mse_loss(p, y):
    """Mean squared error between predictions ``p`` and targets ``y``."""
    residual = p - y
    squared = residual ** 2
    return squared.mean()
def training_loop(epochs, lr, params, x, y, verbose=False, *, model_fn=None, loss_fn=None):
    """Fit ``params`` to ``(x, y)`` with plain gradient descent driven by autograd.

    Each epoch runs a forward pass and back-propagates the loss with
    ``loss.backward()``. It then updates ``params`` in place with
    ``params -= lr * params.grad``.

    Args:
        epochs: number of gradient-descent steps to run.
        lr: learning rate (step size).
        params: tensor holding ``[w, b]``, created with ``requires_grad=True``.
            It is updated in place.
        x: inputs passed to the model.
        y: targets the model's predictions are compared against.
        verbose: if True, print the epoch number, loss and parameters each epoch.
        model_fn: callable ``(x, w, b) -> prediction``. Defaults to the
            module-level ``model``.
        loss_fn: callable ``(prediction, y) -> scalar loss``. Defaults to the
            module-level ``mse_loss``.

    Returns:
        ``params.detach()``: the trained values without gradient tracking.
        The result shares storage with ``params``.
    """
    if model_fn is None:
        model_fn = model
    if loss_fn is None:
        loss_fn = mse_loss

    for epoch in range(1, epochs + 1):
        # backward() accumulates into .grad, so clear the previous epoch's gradient.
        if params.grad is not None:
            params.grad.zero_()

        w, b = params
        p = model_fn(x, w, b)
        loss = loss_fn(p, y)
        loss.backward()

        # The parameter update must not be recorded in the autograd graph.
        with torch.no_grad():
            # Must be in place. `params = params - lr * params.grad` would
            # rebind `params` to a new tensor created under no_grad. That
            # tensor has requires_grad=False, so later epochs could no longer
            # compute its gradient.
            params -= lr * params.grad

        if verbose:
            print('週期', epoch, '--')
            print('\t損失:', float(loss))
            print('\t模型參數:', params)

    return params.detach()
# https://openhome.cc/Gossip/DCHardWay/images/PolynomialRegression-1.JPG
IMAGE_PATH = 'PolynomialRegression-1.JPG'

gray = cv2.imread(IMAGE_PATH, cv2.IMREAD_GRAYSCALE)
# cv2.imread does not raise on a missing/unreadable file; it returns None,
# which torch.from_numpy would reject with an unhelpful TypeError.
if gray is None:
    raise FileNotFoundError(f'cannot read image: {IMAGE_PATH}')
img = torch.from_numpy(gray)

idx = torch.where(img < 127)  # (row indices, column indices) of the dark pixels
if idx[0].numel() == 0:
    raise ValueError('no dark pixels (< 127) found in the image')
x = idx[1]                    # column index -> x coordinate
y = -idx[0] + img.shape[0]    # flip the y axis: image row 0 is the top

plt.gca().set_aspect(1)
plt.scatter(x, y)

w, b = training_loop(
    epochs=100,
    lr=0.001,
    params=torch.tensor([1.0, 0.0], requires_grad=True),  # requires_grad must be True
    x=x,
    y=y,
)

# Draw the fitted line across the horizontal span of the data rather than a
# fixed 0..50 range, so it always overlaps the scattered points.
line_x = torch.linspace(float(x.min()), float(x.max()), 50)
line_y = w * line_x + b
plt.plot(line_x, line_y)
plt.show()
# (End of the GitHub gist page: sign up for free or sign in to GitHub to join the conversation and comment.)