Last active: August 14, 2021 09:29
Save JustinSDK/1ff71a25a8d4d749b272929e44a526cb to your computer and use it in GitHub Desktop.
PyTorch求線性迴歸
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
import cv2 | |
import matplotlib.pyplot as plt | |
def training_loop(epochs, lr, params, x, y, verbose=False):
    """Fit a single-feature linear model ``y ~ w * x + b`` by full-batch SGD on MSE.

    Args:
        epochs: Number of gradient steps to take (0 returns the initial values).
        lr: SGD learning rate.
        params: Initial ``[w, b]`` guess -- any 2-element sequence or tensor.
            ``None`` keeps ``torch.nn.Linear``'s own random initialization.
        x: 1-D tensor of inputs; converted to float internally.
        y: 1-D tensor of targets, same length as ``x``; converted to float.
        verbose: If true, print each epoch's loss (computed before that
            epoch's update) followed by the updated parameters.

    Returns:
        A 1-D float tensor ``[w, b]`` holding the fitted slope and intercept,
        so callers can unpack it with ``w, b = training_loop(...)``.

    Raises:
        ValueError: If ``params`` does not contain exactly two values.
    """
    # nn.Linear expects a trailing feature dimension: reshape (N,) -> (N, 1).
    mx = torch.unsqueeze(x, 1).float()
    my = torch.unsqueeze(y, 1).float()

    model = torch.nn.Linear(mx.size(1), my.size(1))
    if params is not None:
        # Start from the caller's initial guess instead of the random init,
        # so the fit is reproducible and `params` actually has an effect.
        w0, b0 = (float(p) for p in params)
        with torch.no_grad():
            model.weight.fill_(w0)
            model.bias.fill_(b0)

    mse_loss = torch.nn.MSELoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=lr)
    for epoch in range(1, epochs + 1):
        optimizer.zero_grad()
        p = model(mx)
        loss = mse_loss(p, my)
        loss.backward()
        optimizer.step()
        if verbose:
            # Labels: epoch / loss / model parameters.
            print('週期', epoch, '--')
            print('\t損失:', float(loss))
            print('\t模型參數:')
            for param in model.parameters():
                print('\t\t', param.item())

    # Read the scalars explicitly instead of converting the differently-shaped
    # Parameter objects (weight is (1, 1), bias is (1,)) with torch.tensor.
    return torch.tensor([model.weight.item(), model.bias.item()])
# https://openhome.cc/Gossip/DCHardWay/images/PolynomialRegression-1.JPG
image_path = 'PolynomialRegression-1.JPG'
gray = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
if gray is None:
    # cv2.imread reports a missing/unreadable file by returning None instead
    # of raising; fail here with a clear message rather than in from_numpy.
    raise FileNotFoundError(f'cannot read image: {image_path}')
img = torch.from_numpy(gray)

idx = torch.where(img < 127)  # (row, col) indices of the dark pixels
x = idx[1]                    # column index -> x coordinate
y = -idx[0] + img.shape[0]    # flip rows so that y grows upward
if x.numel() == 0:
    raise ValueError(f'no dark pixels (< 127) found in {image_path}')

plt.gca().set_aspect(1)
plt.scatter(x, y)

w, b = training_loop(
    epochs=100,
    lr=0.001,
    params=torch.tensor([1.0, 0.0], requires_grad=True),
    x=x,
    y=y,
)

# Draw the fitted line across the observed x range instead of a fixed 0..50,
# which bore no relation to the image's width.
x = torch.linspace(float(x.min()), float(x.max()), 50)
y = w * x + b
plt.plot(x, y)
plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.