Skip to content

Instantly share code, notes, and snippets.

@akr4
Last active July 28, 2016 08:24
Show Gist options
  • Save akr4/86d5612145bab8885e57320a7752a86b to your computer and use it in GitHub Desktop.
Save akr4/86d5612145bab8885e57320a7752a86b to your computer and use it in GitHub Desktop.
import numpy as np
import chainer
from chainer import cuda, Function, gradient_check, report, training, utils, Variable, Chain
from chainer import datasets, iterators, optimizers, serializers
from chainer import Link, Chain, ChainList
import chainer.functions as F
import chainer.links as L
from chainer.training import extensions
class MyChain(Chain):
    """Three-layer fully connected network mapping 8 features to 8 outputs.

    ``l1`` and ``l2`` are each followed by a ReLU; ``l3`` has no activation,
    so the network's output is an unbounded linear projection (suitable as a
    regression predictor).
    """

    def __init__(self):
        # Layers must be *links* from ``chainer.links`` (imported as L).
        # ``F.Linear`` was only a deprecated alias in Chainer v1 and was
        # removed in v2, so using it breaks on any modern Chainer.
        # Keyword registration of child links works in v1 and (deprecated)
        # in v2+; switch to ``with self.init_scope():`` once v1 is dropped.
        super(MyChain, self).__init__(
            l1=L.Linear(8, 8),
            l2=L.Linear(8, 8),
            l3=L.Linear(8, 8),
        )
        # Legacy Chainer v1-style mode flag. Nothing visible here reads it;
        # kept so any external code that inspects it keeps working.
        self.train = True

    def __call__(self, x):
        """Run the forward pass and return the raw output of ``l3``.

        ``x`` presumably is a float32 batch of shape (batch, 8), matching
        ``l1``'s input size — TODO confirm for callers other than this script.
        """
        h = F.relu(self.l1(x))
        h = F.relu(self.l2(h))
        return self.l3(h)
# Single training example: the script fits MyChain to map x_data -> y_data.
x_data = np.array([[1, 2, 3, 4, 5, 6, 7, 8]], dtype=np.float32)
x = Variable(x_data)
# NOTE(review): targets follow 2*x + 1 except the last element, which is 9
# rather than 17 — confirm whether that is intentional.
y_data = np.array([[3, 5, 7, 9, 11, 13, 15, 9]], dtype=np.float32)
y = Variable(y_data)

# Classifier wraps the predictor and computes lossfun(predictor(x), y).
# Use the public ``F.mean_squared_error`` instead of reaching into the
# internal ``chainer.functions.loss.mean_squared_error`` module path.
model = L.Classifier(MyChain(), lossfun=F.mean_squared_error)
# Accuracy reporting is switched off, presumably because accuracy is not
# meaningful for a regression target.
model.compute_accuracy = False
optimizer = optimizers.SGD()
optimizer.setup(model)

print(x.data)
print(y.data)

# Number of full-batch optimisation steps on the single example.
N_ITERATIONS = 1000
for _ in range(N_ITERATIONS):
    # Each call evaluates model(x, y) as the loss and applies one SGD step
    # (per Chainer's Optimizer.update(lossfun, *args) API).
    optimizer.update(model, x, y)

# Query the trained predictor directly, bypassing the loss wrapper.
test_data = np.array([[2, 4, 6, 8, 10, 12, 14, 16]], dtype=np.float32)
test = Variable(test_data)
test_out = model.predictor(test)
print("test=", test_data, "\nout=", test_out.data)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment