Skip to content

Instantly share code, notes, and snippets.

@hexagit
Last active March 19, 2020 13:12
Show Gist options
  • Save hexagit/aab1b226032f440801b80f31e874428f to your computer and use it in GitHub Desktop.
Save hexagit/aab1b226032f440801b80f31e874428f to your computer and use it in GitHub Desktop.
using System;
using System.IO;
using System.Linq;
// .Net Framework 4.6.1
namespace DeepLearningTest01
{
class Program
{
// Dense two-dimensional matrix of doubles.
// Every operation allocates and returns a new Matrix; neither the receiver nor
// the arguments are modified. Matrix is a struct wrapping an array reference,
// so copies of a Matrix share the same backing array `m`, and a
// default-constructed Matrix has m == null.
struct Matrix
{
    // Number of rows (length of dimension 0 of the backing array).
    public int shape0 { get { return m.GetLength(0); } }
    // Number of columns (length of dimension 1 of the backing array).
    public int shape1 { get { return m.GetLength(1); } }

    // Creates a d0 x d1 matrix whose elements are all 0.0.
    public Matrix(int d0, int d1)
    {
        m = new double[d0, d1];
    }

    // Matrix product of this matrix and `matrix`.
    // The result has this.shape0 rows and matrix.shape1 columns; the number of
    // terms summed per element is this.shape1. Shapes are not validated.
    public Matrix Dot(Matrix matrix)
    {
        var left = m;
        var right = matrix.m;
        var rows = left.GetLength(0);
        var inner = left.GetLength(1);
        var cols = right.GetLength(1);
        // A freshly allocated array is already zero-filled.
        var r = new double[rows, cols];
        for (var i = 0; i < rows; ++i)
        {
            // i-k-j order walks both `right` and `r` along rows (contiguous
            // memory) while adding the terms of each r[i, j] in the same
            // k-ascending order as the textbook i-j-k loop.
            for (var k = 0; k < inner; ++k)
            {
                var a = left[i, k];
                for (var j = 0; j < cols; ++j)
                {
                    r[i, j] += a * right[k, j];
                }
            }
        }
        return new Matrix() { m = r };
    }

    // Element-wise addition. An operand with a single row and/or a single
    // column is broadcast along that dimension (see Combine).
    public Matrix Add(Matrix matrix)
    {
        return Combine(m, matrix.m, (a, b) => a + b);
    }

    // Adds `val` to every element.
    public Matrix Add(double val)
    {
        return Map(m, v => v + val);
    }

    // Element-wise subtraction (this - matrix) with the same broadcasting
    // rules as Add.
    public Matrix Sub(Matrix matrix)
    {
        return Combine(m, matrix.m, (a, b) => a - b);
    }

    // Subtracts `val` from every element.
    public Matrix Sub(double val)
    {
        return Map(m, v => v - val);
    }

    // Element-wise (Hadamard) product with the same broadcasting rules as Add.
    public Matrix Mul(Matrix matrix)
    {
        return Combine(m, matrix.m, (a, b) => a * b);
    }

    // Multiplies every element by `val`.
    public Matrix Mul(double val)
    {
        return Map(m, v => v * val);
    }

    // Divides every element by `val`.
    public Matrix Div(double val)
    {
        return Map(m, v => v / val);
    }

    // Element-wise maximum of each element and the constant `c`.
    public Matrix Max(double c)
    {
        return Map(m, v => Math.Max(v, c));
    }

    // Row-wise softmax: each row is exponentiated and divided by the row's sum
    // of exponentials.
    public Matrix Softmax()
    {
        var rows = m.GetLength(0);
        var cols = m.GetLength(1);
        var r = new double[rows, cols];
        for (var i = 0; i < rows; ++i)
        {
            // Shift the row by its maximum before exponentiating. This leaves
            // the result mathematically unchanged but keeps Math.Exp from
            // overflowing to Infinity (and the quotient from becoming NaN)
            // for large inputs.
            var rowMax = double.NegativeInfinity;
            for (var j = 0; j < cols; ++j)
            {
                if (m[i, j] > rowMax)
                {
                    rowMax = m[i, j];
                }
            }
            // An infinite maximum cannot be subtracted meaningfully; fall back
            // to the unshifted formula in that degenerate case.
            var shift = double.IsInfinity(rowMax) ? 0.0 : rowMax;
            var expSum = 0.0;
            for (var j = 0; j < cols; ++j)
            {
                var e = Math.Exp(m[i, j] - shift);
                r[i, j] = e;
                expSum += e;
            }
            for (var j = 0; j < cols; ++j)
            {
                r[i, j] /= expSum;
            }
        }
        return new Matrix() { m = r };
    }

    // Element-wise natural exponential.
    public Matrix Exp()
    {
        return Map(m, v => Math.Exp(v));
    }

    // Element-wise natural logarithm.
    public Matrix Log()
    {
        return Map(m, v => Math.Log(v));
    }

    // Sum of all elements, accumulated in row-major order.
    public double Sum()
    {
        var sum = 0.0;
        // foreach over a rectangular array visits elements in row-major order.
        foreach (var v in m)
        {
            sum += v;
        }
        return sum;
    }

    // Returns the transpose (shape1 x shape0).
    public Matrix Transpose()
    {
        var rows = m.GetLength(0);
        var cols = m.GetLength(1);
        var r = new double[cols, rows];
        for (var i = 0; i < rows; ++i)
        {
            for (var j = 0; j < cols; ++j)
            {
                r[j, i] = m[i, j];
            }
        }
        return new Matrix() { m = r };
    }

    // Returns a NEW matrix of the same shape filled with 0.0. The receiver is
    // left untouched, so calling this as a bare statement has no effect.
    public Matrix Zeros()
    {
        return new Matrix(m.GetLength(0), m.GetLength(1));
    }

    // Returns a NEW matrix of the same shape filled with 1.0.
    public Matrix Ones()
    {
        return Map(m, v => 1.0);
    }

    // Returns a NEW matrix of the same shape filled, in row-major order, with
    // values from new Random(seed).NextDouble(). The same seed therefore
    // always produces the same sequence.
    public Matrix Rands(int seed)
    {
        var rand = new Random(seed);
        return Map(m, unused => rand.NextDouble());
    }

    // Returns a matrix holding `x` where condition(element) is true and `y`
    // elsewhere.
    public Matrix Where(Func<double, bool> condition, double x, double y)
    {
        return Map(m, v => condition(v) ? x : y);
    }

    // Returns a copy of `len` consecutive rows starting at row `s`.
    // Throws ArgumentOutOfRangeException when the requested rows do not lie
    // inside this matrix.
    public Matrix Range(int s, int len)
    {
        var rows = m.GetLength(0);
        var cols = m.GetLength(1);
        if (s < 0 || len < 0 || s > rows - len)
        {
            throw new ArgumentOutOfRangeException("s", "Requested row range lies outside the matrix.");
        }
        var r = new double[len, cols];
        // Array.Copy treats rectangular arrays as one flat row-major sequence,
        // so a block of whole rows is a single contiguous copy.
        Array.Copy(m, s * cols, r, 0, len * cols);
        return new Matrix() { m = r };
    }

    // Applies `f` to every element of `source` in row-major order.
    // Static (takes the array explicitly) because lambdas inside a struct
    // cannot capture `this`.
    static Matrix Map(double[,] source, Func<double, double> f)
    {
        var rows = source.GetLength(0);
        var cols = source.GetLength(1);
        var r = new double[rows, cols];
        for (var i = 0; i < rows; ++i)
        {
            for (var j = 0; j < cols; ++j)
            {
                r[i, j] = f(source[i, j]);
            }
        }
        return new Matrix() { m = r };
    }

    // Combines `left` and `right` element by element; the result has the shape
    // of `left`. When `right` has exactly one row its row 0 is reused for
    // every row, and when it has exactly one column its column 0 is reused for
    // every column (both at once for a 1x1 operand). No other shape checking
    // is done: a larger `right` contributes only its top-left part, and a
    // smaller non-broadcastable one fails with IndexOutOfRangeException.
    static Matrix Combine(double[,] left, double[,] right, Func<double, double, double> op)
    {
        var rows = left.GetLength(0);
        var cols = left.GetLength(1);
        var singleRow = right.GetLength(0) == 1;
        var singleCol = right.GetLength(1) == 1;
        var r = new double[rows, cols];
        for (var i = 0; i < rows; ++i)
        {
            var ri = singleRow ? 0 : i;
            for (var j = 0; j < cols; ++j)
            {
                r[i, j] = op(left[i, j], right[ri, singleCol ? 0 : j]);
            }
        }
        return new Matrix() { m = r };
    }

    // Backing storage, indexed [row, column]. Public and mutable: callers
    // write elements directly.
    public double[,] m;
}
// Data-set file names. They are relative paths, so they are resolved against
// the process's current working directory when opened.
// NOTE(review): the names match the MNIST IDX distribution files — presumably
// MNIST; confirm against the data actually shipped with the program.
// Training images, read by Training().
const string TRAIN_IMAGE_FNAME = "train-images.idx3-ubyte";
// Training labels, read by Training().
const string TRAIN_LABEL_FNAME = "train-labels.idx1-ubyte";
// Test images; not referenced anywhere in the visible code.
const string TEST_IMAGE_FNAME = "t10k-images.idx3-ubyte";
// Test labels; not referenced anywhere in the visible code.
const string TEST_LABEL_FNAME = "t10k-labels.idx1-ubyte";
// Program entry point: prints a start banner, runs the training loop, prints
// an end marker and then waits for a key press before exiting.
// `args` is not used.
static void Main(string[] args)
{
    var stdout = Console.Out;
    stdout.WriteLine("Let's Start Leap Learning Test.");

    Training();

    stdout.WriteLine("End.");
    // Block until a key is pressed so the console window stays open.
    Console.ReadKey(false);
}
// Loads the training labels and images, then trains a three-layer network
// (ReLU, ReLU, softmax) with mini-batch gradient descent for 30 epochs.
// After every epoch the accuracy and cross-entropy loss over the whole
// training set are written to the console.
// Throws InvalidDataException when the label and image files disagree on the
// sample count or contain out-of-range header/label values.
static void Training()
{
    const int classCount = 10;

    var t = LoadLabels(TRAIN_LABEL_FNAME, classCount);
    var x = LoadImages(TRAIN_IMAGE_FNAME);
    if (x.shape0 != t.shape0)
    {
        throw new InvalidDataException("Image count and label count differ.");
    }

    // Parameter initialisation: layer widths are input -> 120 -> 60 -> classes.
    var d0 = x.shape1;
    var d1 = 120;
    var d2 = 60;
    var d3 = classCount;
    // Weights are uniform in [-0.1, 0.1): NextDouble() * 0.2 - 0.1.
    // NOTE(review): all three matrices use seed 8, so they share the same
    // random sequence; kept as-is to preserve the existing results.
    var w1 = new Matrix(d0, d1).Rands(8).Mul(0.2).Sub(0.1);
    var w2 = new Matrix(d1, d2).Rands(8).Mul(0.2).Sub(0.1);
    var w3 = new Matrix(d2, d3).Rands(8).Mul(0.2).Sub(0.1);
    var b1 = new Matrix(1, d1).Zeros();
    var b2 = new Matrix(1, d2).Zeros();
    var b3 = new Matrix(1, d3).Zeros();

    // Training epochs.
    for (var i = 0; i < 30; ++i)
    {
        // Overlapping mini-batches: 200 rows each, advancing by half a batch.
        var batchCount = 200;
        // `j + batchCount <= shape0` admits the final full batch; the previous
        // strict bound stopped one step early and never trained on the last
        // half-batch of samples.
        for (var j = 0; j + batchCount <= x.shape0; j += batchCount / 2)
        {
            var p = Learn(x.Range(j, batchCount), t.Range(j, batchCount), w1, b1, w2, b2, w3, b3, 0.5f);
            w1 = p.Item1;
            b1 = p.Item2;
            w2 = p.Item3;
            b2 = p.Item4;
            w3 = p.Item5;
            b3 = p.Item6;
        }

        // Evaluate on the full training set.
        var y = Predict(x, w1, b1, w2, b2, w3, b3);
        var a = CountCorrect(y, t);
        var aRate = (double)a / t.shape0;
        var e = CrossEntropyLoss(y, t);
        Console.WriteLine(string.Format("{0:D4} aRate:{1}, error:{2}", i, aRate, e));
    }
}

// Reads an IDX label file into a one-hot matrix (one row per label,
// classCount columns).
static Matrix LoadLabels(string path, int classCount)
{
    // File.OpenRead requests read access only, so read-only files work.
    using (var fs = File.OpenRead(path))
    using (var r = new BinaryReader(fs))
    {
        // Skip the magic number.
        r.ReadInt32();
        // The header stores the count big-endian.
        var labelCount = ReadInt32(r);
        if (labelCount < 0)
        {
            throw new InvalidDataException("Negative label count in " + path);
        }
        // A new Matrix is already zero-filled.
        var labels = new Matrix(labelCount, classCount);
        for (var i = 0; i < labelCount; ++i)
        {
            var val = r.ReadByte();
            if (val >= classCount)
            {
                throw new InvalidDataException("Label value " + val + " is out of range in " + path);
            }
            labels.m[i, val] = 1;
        }
        return labels;
    }
}

// Reads an IDX image file into a matrix with one row per image and one column
// per pixel, scaling each byte to the range [0, 1].
static Matrix LoadImages(string path)
{
    using (var fs = File.OpenRead(path))
    using (var r = new BinaryReader(fs))
    {
        // Skip the magic number.
        r.ReadInt32();
        // Header fields are stored big-endian: image count, rows, columns.
        var imageCount = ReadInt32(r);
        var imageRows = ReadInt32(r);
        var imageColumns = ReadInt32(r);
        if (imageCount < 0 || imageRows < 0 || imageColumns < 0)
        {
            throw new InvalidDataException("Negative dimension in header of " + path);
        }
        var images = new Matrix(imageCount, imageRows * imageColumns);
        for (var i = 0; i < images.shape0; ++i)
        {
            for (var j = 0; j < images.shape1; ++j)
            {
                var pixel = r.ReadByte();
                images.m[i, j] = pixel / 255.0;
            }
        }
        return images;
    }
}

// Counts the rows in which the largest entry of y and of t share a column.
static int CountCorrect(Matrix y, Matrix t)
{
    var correct = 0;
    for (var j = 0; j < t.shape0; ++j)
    {
        var tMax = -1.0;
        var tIdx = -1;
        var yMax = -1.0;
        var yIdx = -1;
        for (var k = 0; k < t.shape1; ++k)
        {
            if (t.m[j, k] > tMax)
            {
                tMax = t.m[j, k];
                tIdx = k;
            }
            if (y.m[j, k] > yMax)
            {
                yMax = y.m[j, k];
                yIdx = k;
            }
        }
        // An index of -1 means no entry exceeded -1.0 (e.g. a row of NaN).
        if (tIdx >= 0 && tIdx == yIdx)
        {
            ++correct;
        }
    }
    return correct;
}
// Reads four bytes from `r` and returns them as a big-endian (most significant
// byte first) 32-bit signed integer, regardless of the host's byte order.
// Throws EndOfStreamException when fewer than four bytes remain.
static int ReadInt32(BinaryReader r)
{
    var bytes = r.ReadBytes(4);
    // ReadBytes returns a shorter array at end of stream instead of throwing.
    if (bytes.Length < 4)
    {
        throw new EndOfStreamException("Expected 4 bytes for a 32-bit integer but the stream ended.");
    }
    // Assemble most-significant byte first; shifts do not depend on host endianness.
    return (bytes[0] << 24) | (bytes[1] << 16) | (bytes[2] << 8) | bytes[3];
}
// Performs one gradient-descent step on the batch (x, t).
// Runs the forward pass (affine -> ReLU -> affine -> ReLU -> affine -> softmax),
// back-propagates through the same layers in reverse, and returns the updated
// parameters in the order (w1, b1, w2, b2, w3, b3). Each parameter is updated
// as param - lr * gradient; the inputs themselves are not modified.
static Tuple<Matrix, Matrix, Matrix, Matrix, Matrix, Matrix> Learn(Matrix x, Matrix t, Matrix w1, Matrix b1, Matrix w2, Matrix b2, Matrix w3, Matrix b3, float lr)
{
    // Forward pass; the pre-activations are kept because ReluFB needs them.
    var pre1 = Affine(x, w1, b1);
    var act1 = Relu(pre1);
    var pre2 = Affine(act1, w2, b2);
    var act2 = Relu(pre2);
    var probs = Softmax(Affine(act2, w3, b3));

    // Backward pass. Each AffineFB result is (input gradient, weight gradient,
    // bias gradient).
    var outputGrad = SoftmaxCrossEntropyLossFB(probs, t);
    var layer3 = AffineFB(outputGrad, act2, w3, b3);
    var layer2 = AffineFB(ReluFB(layer3.Item1, pre2), act1, w2, b2);
    var layer1 = AffineFB(ReluFB(layer2.Item1, pre1), x, w1, b1);

    // Parameter update.
    return Tuple.Create(
        w1.Sub(layer1.Item2.Mul(lr)),
        b1.Sub(layer1.Item3.Mul(lr)),
        w2.Sub(layer2.Item2.Mul(lr)),
        b2.Sub(layer2.Item3.Mul(lr)),
        w3.Sub(layer3.Item2.Mul(lr)),
        b3.Sub(layer3.Item3.Mul(lr)));
}
// Forward pass only: returns the softmax output of the three-layer network
// (affine -> ReLU -> affine -> ReLU -> affine -> softmax) for the input x.
static Matrix Predict(Matrix x, Matrix w1, Matrix b1, Matrix w2, Matrix b2, Matrix w3, Matrix b3)
{
    var hidden1 = Relu(Affine(x, w1, b1));
    var hidden2 = Relu(Affine(hidden1, w2, b2));
    var logits = Affine(hidden2, w3, b3);
    return Softmax(logits);
}
// Affine transform: the matrix product z * w with the bias b added
// (Matrix.Add broadcasts a single-row b over every row).
static Matrix Affine(Matrix z, Matrix w, Matrix b)
{
    var weighted = z.Dot(w);
    return weighted.Add(b);
}
// Backward pass of the affine layer u = z * w + b.
// Given du (gradient with respect to the layer output), returns the tuple
//   Item1: dz = du * w^T          (gradient with respect to the input z)
//   Item2: dw = z^T * du          (gradient with respect to the weights)
//   Item3: db = column sums of du (gradient with respect to the bias, 1 row)
// `b` is not used; it is kept so the signature stays unchanged.
static Tuple<Matrix, Matrix, Matrix> AffineFB(Matrix du, Matrix z, Matrix w, Matrix b)
{
    var dz = du.Dot(w.Transpose());
    var dw = z.Transpose().Dot(du);
    // The bias is added to every row of the batch, so its gradient is the sum
    // of du over all rows: a (1 x N) row of ones times du (N x d) gives the
    // (1 x d) column sums. The former version transposed the ones vector,
    // which produced N copies of du's first row and made the bias update
    // depend on the first sample of the batch only.
    var db = new Matrix(1, du.shape0).Ones().Dot(du);
    return Tuple.Create(dz, dw, db);
}
// ReLU activation: every element of u below zero is replaced by zero.
static Matrix Relu(Matrix u)
{
    const double floor = 0;
    var activated = u.Max(floor);
    return activated;
}
// Backward pass of ReLU: passes dz through where the pre-activation u was
// strictly positive and zeroes it elsewhere.
static Matrix ReluFB(Matrix dz, Matrix u)
{
    Func<double, bool> isPositive = value => value > 0;
    // 1 where the unit was active, 0 where it was clamped.
    var mask = u.Where(isPositive, 1, 0);
    return dz.Mul(mask);
}
// Softmax layer: delegates to Matrix.Softmax, which normalises each row of u.
static Matrix Softmax(Matrix u)
{
    var probabilities = u.Softmax();
    return probabilities;
}
// Cross-entropy loss: -sum(t * log(y)) divided by the number of rows of y.
static double CrossEntropyLoss(Matrix y, Matrix t)
{
    // Clamp to the smallest positive double so Log never receives 0.
    var clamped = y.Max(double.Epsilon);
    var logY = clamped.Log();
    var total = t.Mul(logY).Sum();
    return -total / y.shape0;
}
// Combined backward pass of softmax followed by cross-entropy loss:
// returns (y - t) divided by the number of rows of y.
static Matrix SoftmaxCrossEntropyLossFB(Matrix y, Matrix t)
{
    var residual = y.Sub(t);
    var rowCount = y.shape0;
    return residual.Div(rowCount);
}
}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment