@tomokishii
Chainer vs. PyTorch - Linear Regression

  1. linear_reg_chainer.py (Chainer version)
  2. linear_reg_pytorch.py (PyTorch version)

Tested with:
  • Python 3.5.2 (or 3.5.3)
  • Chainer 2.0.0
  • PyTorch 0.1.12
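
Both scripts fit the same target function, y = 3.0 x + 4.0, by manual gradient descent on W and b. The most visible API difference between the two frameworks is how the parameter gradients are zeroed between updates:

# Chainer: clear each parameter's gradient via Variable.cleargrad()
W.cleargrad()
b.cleargrad()

# PyTorch (0.1.12): zero each gradient tensor in place
W.grad.data.zero_()
b.grad.data.zero_()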
# -*- coding: utf-8 -*-
#
# linear_reg_chainer.py - Chainer version
# date. 6/2/2017
#
import numpy as np
import chainer
from chainer import Variable
import chainer.functions as F
# Target values (3.0, 4.0); training samples are generated from them.
W_target = np.array([[3.0]], dtype=np.float32)  # size = [1, 1]
b_target = 4.0

# Model Parameters
W = Variable(np.random.randn(1, 1).astype(np.float32) * 0.01,
             requires_grad=True)
b = Variable(np.zeros([1, 1], dtype=np.float32), requires_grad=True)
def model(x, W, b):
    # Define the linear regression model: y = x W + b
    y1 = F.matmul(x, W)
    b1 = F.broadcast_to(b, x.shape)
    y = y1 + b1
    return y
def get_batch(W_target, b_target, batch_size=32):
    # Prepare a batch of training data: y = W_target * x + b_target
    x = np.random.randn(batch_size, 1).astype(np.float32)
    y = x * W_target + b_target
    return Variable(x), Variable(y)
# Train loop
for batch_idx in range(100):
    # Get data
    batch_x, batch_y = get_batch(W_target, b_target)

    # Forward pass
    y_pred = model(batch_x, W, b)

    # Loss function: MSE (mean squared error)
    loss = F.mean_squared_error(y_pred, batch_y)

    # Manually zero the parameter gradients before the backward pass (important)
    W.cleargrad()
    b.cleargrad()

    # Backward pass
    loss.backward()

    # Apply gradients (plain SGD step)
    learning_rate = 0.1
    W.data = W.data - learning_rate * W.grad
    b.data = b.data - learning_rate * b.grad

    # Stop criterion
    if loss.data < 1.e-3:
        break
# Print the results
print('Loss: {:>8.4f} after {:d} batches'.format(
    float(loss.data), batch_idx))
print('==> Learned function:\t' + 'y = {:>8.4f} x + {:>8.4f}'.format(
    float(W.data), float(b.data)))
print('==> Actual function:\t' + 'y = {:>8.4f} x + {:>8.4f}'.format(
    float(W_target), float(b_target)))
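
For comparison, the manual parameter update above could also be expressed with Chainer's built-in Linear link and SGD optimizer. A minimal, untested sketch assuming Chainer 2.0.0 and reusing get_batch, W_target, and b_target from the script above:

# Sketch: idiomatic Chainer version of the manual loop above (assumes Chainer 2.0.0)
import chainer
import chainer.functions as F
import chainer.links as L

lin = L.Linear(1, 1)                       # learnable linear layer: y = x W^T + b
optimizer = chainer.optimizers.SGD(lr=0.1)
optimizer.setup(lin)

for batch_idx in range(100):
    batch_x, batch_y = get_batch(W_target, b_target)
    loss = F.mean_squared_error(lin(batch_x), batch_y)
    lin.cleargrads()                       # zero all parameter gradients at once
    loss.backward()
    optimizer.update()                     # apply one SGD step
    if loss.data < 1.e-3:
        break

Here Link.cleargrads() replaces the per-Variable cleargrad() calls, and optimizer.update() applies the SGD step that the script does by hand.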
# -*- coding: utf-8 -*-
#
# linear_reg_pytorch.py - PyTorch version
# date. 5/24/2017
#
import torch
from torch.autograd import Variable
# Target values (3.0, 4.0)
W_target = torch.FloatTensor([[3.0]]) # size = [1, 1]
b_target = 4.0
# Model Parameters
dtype = torch.FloatTensor
# dtype = torch.cuda.FloatTensor # Uncomment this to run on GPU
W = Variable((torch.randn(1, 1) * 0.01).type(dtype), requires_grad=True)
b = Variable(torch.zeros(1, 1).type(dtype), requires_grad=True)
def model(x):
    # Define the linear regression model: y = x W + b
    y = torch.mm(x, W) + b.expand_as(x)
    return y
def get_batch(batch_size=32):
    # Prepare a batch of training data: y = x W_target + b_target
    x = torch.randn(batch_size, 1)
    y = torch.mm(x, W_target) + b_target
    return Variable(x), Variable(y)
# Loss function: MSE (mean squared error)
loss_fn = torch.nn.MSELoss(size_average=True)
# Train loop
for batch_idx in range(20):
    # Get data
    batch_x, batch_y = get_batch()

    # Forward pass
    y_pred = model(batch_x)
    loss = loss_fn(y_pred, batch_y)
    loss_np = loss.data[0]

    # Backward pass
    loss.backward()

    # Apply gradients (plain SGD step)
    learning_rate = 0.1
    W.data = W.data - learning_rate * W.grad.data
    b.data = b.data - learning_rate * b.grad.data

    # Manually zero the gradients with torch.Tensor.zero_() after updating weights (important)
    W.grad.data.zero_()
    b.grad.data.zero_()

    # Stop criterion
    if loss_np < 1.e-3:
        break
# Print the results
def model_desc(W, b):
    # Helper to format a linear function y = W x + b as a string.
    if type(W) == torch.FloatTensor:
        W = W.numpy()
    if type(b) == torch.FloatTensor:
        b = b.numpy()
    b = float(b)
    result = 'y = {0} x + {1:>8.4f}'.format(W, b)
    return result

print('Loss: {:>8.4e} after {:d} batches'.format(loss_np, batch_idx))
print('==> Learned function:\t' + model_desc(W.data.view(-1), b.data))
print('==> Actual function:\t' + model_desc(W_target.view(-1), b_target))
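
Likewise, the manual loop in the PyTorch script corresponds to the idiomatic torch.nn.Linear + torch.optim.SGD pattern. A minimal, untested sketch assuming PyTorch 0.1.12 and reusing get_batch from the script above:

# Sketch: idiomatic PyTorch version of the manual loop above (assumes PyTorch 0.1.12)
import torch
import torch.nn as nn
import torch.optim as optim

lin = nn.Linear(1, 1)                      # learnable linear layer: y = x W^T + b
loss_fn = nn.MSELoss()
optimizer = optim.SGD(lin.parameters(), lr=0.1)

for batch_idx in range(20):
    batch_x, batch_y = get_batch()
    loss = loss_fn(lin(batch_x), batch_y)
    optimizer.zero_grad()                  # zero gradients of all parameters at once
    loss.backward()
    optimizer.step()                       # apply one SGD step
    if loss.data[0] < 1.e-3:
        break

Here optimizer.zero_grad() replaces the explicit grad.data.zero_() calls, and optimizer.step() applies the SGD update done by hand in the script.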