Skip to content

Instantly share code, notes, and snippets.

@odashi
Last active December 26, 2017 03:10
Show Gist options
  • Save odashi/83bef0a8d24d293811bd7852ba1916ad to your computer and use it in GitHub Desktop.
Save odashi/83bef0a8d24d293811bd7852ba1916ad to your computer and use it in GitHub Desktop.
primitiv examples for Qiita (C++11/Python3)
// 実行方法:
// g++ -std=c++11 xor.cc -lprimitiv && ./a.out
#include <cstdio>
#include <iostream>
#include <primitiv/primitiv.h>
using namespace primitiv;
namespace D = primitiv::devices;
namespace F = primitiv::functions;
namespace I = primitiv::initializers;
namespace O = primitiv::optimizers;
int main() {
// デバイスと計算グラフの設定
devices::Naive dev;
Device::set_default(dev);
Graph g;
Graph::set_default(g);
// 入力データ
std::vector<float> input_data {
1, 1, // 第一象限
-1, 1, // 第二象限
-1, -1, // 第三象限
1, -1, // 第四象限
};
// 対応する正解
std::vector<float> label_data {
1, // 第一象限
-1, // 第二象限
1, // 第三象限
-1, // 第四象限
};
// パラメータ
const int N = 8;
Parameter pw({1, N}, I::XavierUniform());
Parameter pb({}, I::Constant(0));
Parameter pu({N, 2}, I::XavierUniform());
Parameter pc({N}, I::Constant(0));
// 学習器
O::SGD optimizer(0.5);
optimizer.add(pw, pb, pu, pc);
// ネットワークの定義
auto build_graph = [&] {
auto x = F::input<Node>(Shape({2}, 4), input_data);
auto w = F::parameter<Node>(pw);
auto b = F::parameter<Node>(pb);
auto u = F::parameter<Node>(pu);
auto c = F::parameter<Node>(pc);
auto h = F::tanh(F::matmul(u, x) + c);
return F::tanh(F::matmul(w, h) + b);
};
// 損失の定義
auto calc_loss = [&](Node y) {
auto t = F::input<Node>(Shape({}, 4), label_data);
auto diff = y - t;
return F::batch::mean(diff * diff);
};
// 学習ループ
for (int epoch = 0; epoch < 20; ++epoch) {
std::cout << epoch << ' ';
// グラフの初期化
g.clear();
// 出力の計算
auto y = build_graph();
for (float val : y.to_vector()) {
std::printf("%+.6f, ", val);
}
// 損失の計算
auto loss = calc_loss(y);
std::printf("loss=%.6f", loss.to_float());
std::cout << std::endl;
// 勾配の計算・パラメータの更新
optimizer.reset_gradients();
loss.backward();
optimizer.update();
}
return 0;
}
#!/usr/bin/env python3
# 実行方法: ./xor.py
import numpy as np
from primitiv import *
from primitiv import devices as D
from primitiv import functions as F
from primitiv import initializers as I
from primitiv import optimizers as O
# デバイスと計算グラフの設定
dev = devices.Naive()
Device.set_default(dev)
g = Graph()
Graph.set_default(g)
# 入力データ
input_data = [
np.array([[ 1], [ 1]]), # 第一象限
np.array([[-1], [ 1]]), # 第二象限
np.array([[-1], [-1]]), # 第三象限
np.array([[ 1], [-1]]), # 第四象限
]
# 対応する正解
label_data = [
np.array([ 1]), # 第一象限
np.array([-1]), # 第二象限
np.array([ 1]), # 第三象限
np.array([-1]), # 第四象限
]
# パラメータ
N = 8
pw = Parameter([1, N], I.XavierUniform())
pb = Parameter([], I.Constant(0))
pu = Parameter([N, 2], I.XavierUniform())
pc = Parameter([N], I.Constant(0))
# 学習器
optimizer = O.SGD(0.5)
optimizer.add(pw, pb, pu, pc)
# ネットワークの定義
def build_graph():
x = F.input(input_data)
w = F.parameter(pw)
b = F.parameter(pb)
u = F.parameter(pu)
c = F.parameter(pc)
h = F.tanh(u @ x + c)
return F.tanh(w @ h + b)
# 損失の定義
def calc_loss(y):
t = F.input(label_data)
diff = y - t
return F.batch.mean(diff * diff)
# 学習ループ
for epoch in range(20):
print(epoch, end=' ')
# グラフの初期化
g.clear()
# 出力の計算
y = build_graph()
for val in y.to_list():
print('{:+.6f},'.format(val), end=' ')
# 損失の計算
loss = calc_loss(y)
print('loss={:.6f}'.format(loss.to_float()))
# 勾配の計算・パラメータの更新
optimizer.reset_gradients()
loss.backward()
optimizer.update()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment