Skip to content

Instantly share code, notes, and snippets.

@43x2
Created September 5, 2016 15:06
Show Gist options
  • Save 43x2/9db1da7789fa93e00b1fdea2d3e74669 to your computer and use it in GitHub Desktop.
Save 43x2/9db1da7789fa93e00b1fdea2d3e74669 to your computer and use it in GitHub Desktop.
半加算器ニューラルネットワークの学習 (「誤差逆伝播法をはじめからていねいに」 ソースコード)
// main.cpp is placed in PUBLIC DOMAIN.
//
#include <cmath>
#include <vector>
#include <random>
#include <iostream>
#include "SimpleMatrix.h"
// 入力データと教師データ
struct Data
{
struct
{
double A; // 入力 a
double B; // 入力 b
} input;
struct
{
double S; // 出力 s (和)
double C; // 出力 c (キャリー)
} supervisor;
};
int main( int /*argc*/, char * /*argv*/[] )
{
// 入力データと教師データ
const std::vector< Data > datas {
// input supervisor
{ { 0.0, 0.0 }, { 0.0, 0.0 } },
{ { 0.0, 1.0 }, { 1.0, 0.0 } },
{ { 1.0, 0.0 }, { 1.0, 0.0 } },
{ { 1.0, 1.0 }, { 0.0, 1.0 } }
};
// 乱数生成器・一様分布器
std::random_device device;
std::mt19937_64 generator( device() );
std::uniform_real_distribution< double > distributor( -8.0, 8.0 );
// 隠れ層への重み
SimpleMatrix< double > weightA( 3, 2 );
for ( auto i = 0; i < 3; ++i )
{
for ( auto j = 0; j < 2; ++j )
{
weightA( i, j ) = distributor( generator );
}
}
// 隠れ層バイアス
SimpleMatrix< double > biasA( 3, 1 );
for ( auto i = 0; i < 3; ++i )
{
biasA( i, 0 ) = distributor( generator );
}
// 出力層への重み
SimpleMatrix< double > weightB( 2, 3 );
for ( auto i = 0; i < 2; ++i )
{
for ( auto j = 0; j < 3; ++j )
{
weightB( i, j ) = distributor( generator );
}
}
// 出力層バイアス
SimpleMatrix< double > biasB( 2, 1 );
for ( auto i = 0; i < 2; ++i )
{
biasB( i, 0 ) = distributor( generator );
}
// 学習回数
constexpr std::size_t training_times = 100000;
// 学習率
constexpr double learning_rate = 0.3;
//
// 学習フェーズ
//
for ( auto i = 0; i < training_times; ++i )
{
// 誤差
auto error = 0.0;
// 誤差に対する隠れ層への重みの偏微分 (∂E/∂WA)
SimpleMatrix< double > dE_dWA( 3, 2, 0.0 );
// 誤差に対する隠れ層バイアスの偏微分 (∂E/∂bA)
SimpleMatrix< double > dE_dbA( 3, 1, 0.0 );
// 誤差に対する出力層への重みの偏微分 (∂E/∂WB)
SimpleMatrix< double > dE_dWB( 2, 3, 0.0 );
// 誤差に対する出力層バイアスの偏微分 (∂E/∂bB)
SimpleMatrix< double > dE_dbB( 2, 1, 0.0 );
for ( auto & data : datas )
{
// 入力層
SimpleMatrix< double > input( 2, 1, { data.input.A, data.input.B } );
// 隠れ層を計算する
auto hidden = weightA * input + biasA;
hidden( 0, 0 ) = 1.0 / ( 1.0 + std::exp( -hidden( 0, 0 ) ) ); //
hidden( 1, 0 ) = 1.0 / ( 1.0 + std::exp( -hidden( 1, 0 ) ) ); //
hidden( 2, 0 ) = 1.0 / ( 1.0 + std::exp( -hidden( 2, 0 ) ) ); // シグモイド関数による非線形変換
// 出力層を計算する
auto output = weightB * hidden + biasB;
output( 0, 0 ) = 1.0 / ( 1.0 + std::exp( -output( 0, 0 ) ) ); //
output( 1, 0 ) = 1.0 / ( 1.0 + std::exp( -output( 1, 0 ) ) ); // シグモイド関数による非線形変換
// 誤差
error += 0.5 * ( std::pow( output( 0, 0 ) - data.supervisor.S, 2.0 ) + std::pow( output( 1, 0 ) - data.supervisor.C, 2.0 ) );
// ∂E/∂bB
SimpleMatrix< double > supervisor( 2, 1, { data.supervisor.S, data.supervisor.C } );
auto temp0 = output - supervisor;
auto temp1 = temp0.element_product( output ).element_product( SimpleMatrix< double >( 2, 1, 1.0 ) - output );
dE_dbB += temp1;
// ∂E/∂WB
auto temp2 = temp1 * hidden.transpose();
dE_dWB += temp2;
// ∂E/∂bA
auto temp3 = ( temp1.transpose() * weightB ).transpose();
auto temp4 = temp3.element_product( hidden ).element_product( SimpleMatrix< double >( 3, 1, 1.0 ) - hidden );
dE_dbA += temp4;
// ∂E/∂WA
auto temp5 = temp4 * input.transpose();
dE_dWA += temp5;
}
// 100 回ごとに報告
if ( i % 100 == 0 )
{
std::cout << i << " -- " << error << std::endl;
}
// 重みとバイアスに反映
weightA -= ( dE_dWA *= learning_rate );
biasA -= ( dE_dbA *= learning_rate );
weightB -= ( dE_dWB *= learning_rate );
biasB -= ( dE_dbB *= learning_rate );
}
//
// 結果 (学習フェーズとほぼ同じコードなのでコメントは省略)
//
auto error = 0.0;
for ( auto data : datas )
{
SimpleMatrix< double > input( 2, 1, { data.input.A, data.input.B } );
auto hidden = weightA * input + biasA;
hidden( 0, 0 ) = 1.0 / ( 1.0 + std::exp( -hidden( 0, 0 ) ) );
hidden( 1, 0 ) = 1.0 / ( 1.0 + std::exp( -hidden( 1, 0 ) ) );
hidden( 2, 0 ) = 1.0 / ( 1.0 + std::exp( -hidden( 2, 0 ) ) );
auto output = weightB * hidden + biasB;
output( 0, 0 ) = 1.0 / ( 1.0 + std::exp( -output( 0, 0 ) ) );
output( 1, 0 ) = 1.0 / ( 1.0 + std::exp( -output( 1, 0 ) ) );
error += 0.5 * ( std::pow( output( 0, 0 ) - data.supervisor.S, 2.0 ) + std::pow( output( 1, 0 ) - data.supervisor.C, 2.0 ) );
}
std::cout << training_times << " -- " << error << std::endl;
return 0;
}
// SimpleMatrix.h is placed in PUBLIC DOMAIN.
//
#if !defined( SIMPLE_MATRIX_H )
#define SIMPLE_MATRIX_H
#include <cstddef>
#include <cassert>
#include <vector>
#include <utility>
// シンプルな行列演算クラス
template< class NumberType >
class SimpleMatrix
{
//
// メンバ関数
//
public:
// コンストラクタ
explicit SimpleMatrix( const std::size_t rows, const std::size_t columns );
// コンストラクタ (単一の値で初期化)
explicit SimpleMatrix( const std::size_t rows, const std::size_t columns, const NumberType initial_value );
// コンストラクタ (初期化リストで初期化)
explicit SimpleMatrix( const std::size_t rows, const std::size_t columns, std::initializer_list< NumberType > initial_values );
// コンストラクタ (ベクタで初期化)
explicit SimpleMatrix( const std::size_t rows, const std::size_t columns, const std::vector< NumberType > & initial_values );
// コピーコンストラクタ
SimpleMatrix( const SimpleMatrix & ) = default;
// ムーブコンストラクタ
SimpleMatrix( SimpleMatrix && source );
// デストラクタ
~SimpleMatrix() = default;
// 代入演算子
SimpleMatrix & operator=( const SimpleMatrix & ) = default;
// ムーブ代入演算子
SimpleMatrix & operator=( SimpleMatrix && source );
// すべての要素を指定の値にする
SimpleMatrix & fill( const NumberType & value );
// 行数を取得する
std::size_t get_number_of_rows() const;
// 列数を取得する
std::size_t get_number_of_columns() const;
// 要素参照 (non-const)
NumberType & operator()( const std::size_t row, const std::size_t column );
// 要素参照 (const)
const NumberType & operator()( const std::size_t row, const std::size_t column ) const;
// 加算
SimpleMatrix operator+( const SimpleMatrix & rhs ) const;
// 減算
SimpleMatrix operator-( const SimpleMatrix & rhs ) const;
// 乗算
SimpleMatrix operator*( const SimpleMatrix & rhs ) const;
// アダマール積 (要素ごとの積)
SimpleMatrix element_product( const SimpleMatrix & rhs ) const;
// 加算複合代入
SimpleMatrix & operator+=( const SimpleMatrix & rhs );
// 減算複合代入
SimpleMatrix & operator-=( const SimpleMatrix & rhs );
// 乗算複合代入 (対スカラー)
SimpleMatrix & operator*=( const NumberType & rhs );
// 除算複合代入 (対スカラー)
SimpleMatrix & operator/=( const NumberType & rhs );
// インプレースでアダマール積 (要素ごとの積)
SimpleMatrix & element_product_inplace( const SimpleMatrix & rhs );
// 転置
SimpleMatrix transpose() const;
//
// メンバ変数
//
private:
// 行数
std::size_t rows_;
// 列数
std::size_t columns_;
// 保存領域
std::vector< NumberType > buffer_;
};
// コンストラクタ
template< class NumberType >
SimpleMatrix< NumberType >::SimpleMatrix( const std::size_t rows, const std::size_t columns )
: rows_( rows )
, columns_( columns )
, buffer_( rows * columns )
{
// アサーション
assert( rows > 0 && columns > 0 );
}
// コンストラクタ (単一の値で初期化)
template< class NumberType >
SimpleMatrix< NumberType >::SimpleMatrix( const std::size_t rows, const std::size_t columns, const NumberType initial_value )
: rows_( rows )
, columns_( columns )
, buffer_( rows * columns, initial_value )
{
// アサーション
assert( rows > 0 && columns > 0 );
}
// コンストラクタ (初期化リストで初期化)
template< class NumberType >
SimpleMatrix< NumberType >::SimpleMatrix( const std::size_t rows, const std::size_t columns, std::initializer_list< NumberType > initial_values )
: rows_( rows )
, columns_( columns )
, buffer_( initial_values )
{
// アサーション
assert( rows > 0 && columns > 0 && initial_values.size() == rows * columns );
}
// コンストラクタ (ベクタで初期化)
template< class NumberType >
SimpleMatrix< NumberType >::SimpleMatrix( const std::size_t rows, const std::size_t columns, const std::vector< NumberType > & initial_values )
: rows_( rows )
, columns_( columns )
, buffer_( initial_values )
{
// アサーション
assert( rows > 0 && columns > 0 && initial_values.size() == rows * columns );
}
// ムーブコンストラクタ
template< class NumberType >
SimpleMatrix< NumberType >::SimpleMatrix( SimpleMatrix && source )
: rows_( source.rows_ )
, columns_( source.columns_ )
, buffer_( std::move( source.buffer_ ) )
{
// source.buffer_ の内容が保証されなくなったので rows_ と columns_ も更新する
source.rows_ = source.columns_ = 0;
}
// ムーブ代入演算子
template< class NumberType >
SimpleMatrix< NumberType > & SimpleMatrix< NumberType >::operator=( SimpleMatrix && source )
{
rows_ = source.rows_;
columns_ = source.columns_;
buffer_ = std::move( source.buffer_ );
// source.buffer_ の内容が保証されなくなったので rows_ と columns_ も更新する
source.rows_ = source.columns_ = 0;
return *this;
}
// すべての要素を指定の値にする
template< class NumberType >
SimpleMatrix< NumberType > & SimpleMatrix< NumberType >::fill( const NumberType & value )
{
std::vector< NumberType > temp( rows_ * columns_, value );
buffer_.swap( temp );
return *this;
// 一時的にメモリ消費量がほぼ倍になるので,
// 要素数が非常に大きい場合は std::fill を使うほうがよいと思う.
// by Atsushi OHTA (2016/7/2)
}
// 行数を取得する
template< class NumberType >
std::size_t SimpleMatrix< NumberType >::get_number_of_rows() const
{
return rows_;
}
// 列数を取得する
template< class NumberType >
std::size_t SimpleMatrix< NumberType >::get_number_of_columns() const
{
return columns_;
}
// 要素参照 (non-const)
template< class NumberType >
NumberType & SimpleMatrix< NumberType >::operator()( const std::size_t row, const std::size_t column )
{
// アサーション
assert( row < rows_ && column < columns_ );
return buffer_[ row * columns_ + column ];
}
// 要素参照 (const)
template< class NumberType >
const NumberType & SimpleMatrix< NumberType >::operator()( const std::size_t row, const std::size_t column ) const
{
// アサーション
assert( row < rows_ && column < columns_ );
return buffer_[ row * columns_ + column ];
}
// 加算
template< class NumberType >
SimpleMatrix< NumberType > SimpleMatrix< NumberType >::operator+( const SimpleMatrix & rhs ) const
{
SimpleMatrix result( *this );
return result += rhs;
}
// 減算
template< class NumberType >
SimpleMatrix< NumberType > SimpleMatrix< NumberType >::operator-( const SimpleMatrix & rhs ) const
{
SimpleMatrix result( *this );
return result -= rhs;
}
// 乗算
template< class NumberType >
SimpleMatrix< NumberType > SimpleMatrix< NumberType >::operator*( const SimpleMatrix & rhs ) const
{
// アサーション
assert( columns_ == rhs.rows_ );
SimpleMatrix result( rows_, rhs.columns_ );
for ( auto i = 0; i < rows_; ++i )
{
for ( auto j = 0; j < rhs.columns_; ++j )
{
NumberType temp = NumberType( 0 );
for ( auto k = 0; k < columns_; ++k )
{
temp += ( *this )( i, k ) * rhs( k, j );
}
result( i, j ) = temp;
}
}
return result;
}
// アダマール積 (要素ごとの積)
template< class NumberType >
SimpleMatrix< NumberType > SimpleMatrix< NumberType >::element_product( const SimpleMatrix & rhs ) const
{
SimpleMatrix result( *this );
return result.element_product_inplace( rhs );
}
// 加算複合代入
template< class NumberType >
SimpleMatrix< NumberType > & SimpleMatrix< NumberType >::operator+=( const SimpleMatrix & rhs )
{
// アサーション
assert( rows_ == rhs.rows_ && columns_ == rhs.columns_ );
auto it = rhs.buffer_.cbegin();
for ( auto & elem : buffer_ )
{
elem += *it++;
}
return *this;
}
// 減算複合代入
template< class NumberType >
SimpleMatrix< NumberType > & SimpleMatrix< NumberType >::operator-=( const SimpleMatrix & rhs )
{
// アサーション
assert( rows_ == rhs.rows_ && columns_ == rhs.columns_ );
auto it = rhs.buffer_.cbegin();
for ( auto & elem : buffer_ )
{
elem -= *it++;
}
return *this;
}
// 乗算複合代入 (対スカラー)
template< class NumberType >
SimpleMatrix< NumberType > & SimpleMatrix< NumberType >::operator*=( const NumberType & rhs )
{
for ( auto & elem : buffer_ )
{
elem *= rhs;
}
return *this;
}
// 除算複合代入 (対スカラー)
template< class NumberType >
SimpleMatrix< NumberType > & SimpleMatrix< NumberType >::operator/=( const NumberType & rhs )
{
for ( auto & elem : buffer_ )
{
elem /= rhs;
}
return *this;
}
// インプレースでアダマール積 (要素ごとの積)
template< class NumberType >
SimpleMatrix< NumberType > & SimpleMatrix< NumberType >::element_product_inplace( const SimpleMatrix & rhs )
{
// アサーション
assert( rows_ == rhs.rows_ && columns_ == rhs.columns_ );
auto it = rhs.buffer_.cbegin();
for ( auto & elem : buffer_ )
{
elem *= *it++;
}
return *this;
}
// 転置
template< class NumberType >
SimpleMatrix< NumberType > SimpleMatrix< NumberType >::transpose() const
{
SimpleMatrix result( columns_, rows_ );
for ( auto i = 0; i < columns_; ++i )
{
for ( auto j = 0; j < rows_; ++j )
{
result( i, j ) = ( *this )( j, i );
}
}
return result;
}
#endif // !defined( SIMPLE_MATRIX_H )
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment