Skip to content

Instantly share code, notes, and snippets.

@JustinSDK
Last active August 14, 2021 09:29
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save JustinSDK/1ff71a25a8d4d749b272929e44a526cb to your computer and use it in GitHub Desktop.
PyTorch求線性迴歸
import torch
import cv2
import matplotlib.pyplot as plt
def training_loop(epochs, lr, params, x, y, verbose = False):
    """Fit ``y ≈ w * x + b`` with full-batch SGD on the mean-squared error.

    Args:
        epochs: Number of gradient steps; every step uses all samples.
        lr: SGD learning rate.
        params: Initial ``[w, b]`` (anything ``torch.as_tensor`` accepts that
            holds exactly two values). ``None`` keeps ``torch.nn.Linear``'s
            own random initialisation.
        x: 1-D tensor of inputs; any numeric dtype, cast to float32.
        y: 1-D tensor of targets with the same length as ``x``.
        verbose: If true, print the loss and every parameter after each epoch.

    Returns:
        A detached 1-D float tensor ``[w, b]`` holding the fitted parameters,
        so callers can write ``w, b = training_loop(...)``.

    Raises:
        ValueError: If ``params`` does not hold exactly two values.
    """
    # (N,) -> (N, 1): nn.Linear expects one feature per column.
    mx = torch.unsqueeze(x, 1).float()
    my = torch.unsqueeze(y, 1).float()
    model = torch.nn.Linear(mx.size(1), my.size(1))

    # Seed the model with the caller's starting point; without this the
    # ``params`` argument would be ignored in favour of a random init.
    if params is not None:
        init = torch.as_tensor(params, dtype=torch.float32).detach().reshape(-1)
        if init.numel() != 2:
            raise ValueError(
                f'params must hold exactly two values [w, b], got {init.numel()}'
            )
        with torch.no_grad():
            model.weight.fill_(init[0])
            model.bias.fill_(init[1])

    mse_loss = torch.nn.MSELoss()
    optimizer = torch.optim.SGD(model.parameters(), lr = lr)
    for epoch in range(1, epochs + 1):
        optimizer.zero_grad()
        p = model(mx)
        loss = mse_loss(p, my)
        loss.backward()
        optimizer.step()
        if verbose:
            print('週期', epoch, '--')
            print('\t損失:', float(loss))
            print('\t模型參數:')
            for param in model.parameters():
                print('\t\t', param.item())

    # nn.Linear registers weight (1, 1) then bias (1,); flatten both into one
    # detached [w, b] vector so the result carries no autograd history.
    return torch.cat([p.detach().reshape(-1) for p in model.parameters()])
# Source image: https://openhome.cc/Gossip/DCHardWay/images/PolynomialRegression-1.JPG
# Dark pixels of the picture are treated as data points, and a straight line
# is fitted to them with training_loop.
raw = cv2.imread('PolynomialRegression-1.JPG', cv2.IMREAD_GRAYSCALE)
if raw is None:
    # cv2.imread reports a missing/unreadable file by returning None rather
    # than raising, which would otherwise fail obscurely in torch.from_numpy.
    raise FileNotFoundError('cannot read image: PolynomialRegression-1.JPG')
img = torch.from_numpy(raw)
idx = torch.where(img < 127)  # (row indices, column indices) of dark pixels
x = idx[1]                    # column index -> x coordinate
y = -idx[0] + img.shape[0]    # flip the row index so y grows upward
if x.numel() == 0:
    # Fitting on zero samples would produce a NaN loss and NaN parameters.
    raise ValueError('no dark pixels (value < 127) found in the image')

plt.gca().set_aspect(1)
plt.scatter(x, y)

w, b = training_loop(
    epochs = 100,
    lr = 0.001,
    params = torch.tensor([1.0, 0.0], requires_grad = True),
    x = x,
    y = y
)

# Draw the fitted line across the horizontal extent of the data instead of a
# hard-coded 0..50 range, so it matches images of any width.
line_x = torch.linspace(float(x.min()), float(x.max()), 50)
line_y = w * line_x + b
plt.plot(line_x, line_y)
plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment