積和演算の並列化
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
using System; | |
using System.Diagnostics; | |
using System.Linq; | |
using System.Numerics; | |
using System.Threading.Tasks; | |
/// <summary>
/// Benchmarks the dot product (multiply-accumulate) of two float signals in four variants:
/// single/multi-threaded × scalar/<see cref="Vector4"/> (SIMD).
/// </summary>
class Program
{
    /// <summary>Number of times <see cref="Measure"/> invokes each benchmark body inside the timed region.</summary>
    private const int Loops = 500;

    /// <summary>Number of float samples in each generated input signal (a multiple of 4).</summary>
    private const int N = 4 * 1024 * 1024;

    /// <summary>Number of chunks (tasks) the multi-threaded variants split their work into.</summary>
    private static readonly int NumWorkerThread = Environment.ProcessorCount;

    /// <summary>
    /// Generates two white-noise signals of <see cref="N"/> samples and times every dot-product variant on them.
    /// </summary>
    static void Main()
    {
        var r = new Random();
        var x = CreateWhiteNoise(r, N);
        var y = CreateWhiteNoise(r, N);

        Measure("single/scalar", () => SingleThreadScalar(x, y));
        Measure("multi /scalar", () => MultiThreadScalar(x, y));

        var vx = CopyToVector4(x);
        var vy = CopyToVector4(y);

        Measure("single/vector", () => SingleThreadVector(vx, vy));
        Measure("multi /vector", () => MultiThreadVector(vx, vy));
    }

    /// <summary>
    /// Measures how long <see cref="Loops"/> calls of <paramref name="f"/> take and prints the elapsed time,
    /// followed by the (normalized) value of the last call as a sanity check.
    /// </summary>
    /// <param name="tag">Label printed before the measurement.</param>
    /// <param name="f">Benchmark body; expected to return the dot product of two length-<see cref="N"/> signals.</param>
    private static void Measure(string tag, Func<float> f)
    {
        Console.WriteLine("-------");
        Console.WriteLine(tag);

        // One untimed warm-up call so one-time costs (e.g. JIT compilation of f's body) are not measured.
        var prod = f();

        var sw = Stopwatch.StartNew();
        for (int l = 0; l < Loops; l++)
        {
            prod = f();
        }
        sw.Stop();

        Console.WriteLine(sw.Elapsed);

        // Inner product of two independent N-sample noise signals drawn from [0, 1) (mean ~0.5),
        // so the expected value is about N * 0.5 * 0.5 = N / 4.
        // Dividing by N / 4 should therefore print a value close to 1; small deviations come from
        // sampling noise and from rounding while accumulating in single precision.
        Console.WriteLine(prod / N * 4);
    }

    /// <summary>
    /// Plain single-threaded <c>for</c> loop.
    /// </summary>
    /// <param name="x">First signal.</param>
    /// <param name="y">Second signal; must have the same length as <paramref name="x"/>.</param>
    /// <returns>The sum of <c>x[i] * y[i]</c> over all indices, accumulated in <see cref="float"/>.</returns>
    /// <exception cref="ArgumentException">The arrays have different lengths.</exception>
    internal static float SingleThreadScalar(float[] x, float[] y)
    {
        ThrowIfLengthMismatch(x.Length, y.Length);

        var prod = 0f;
        for (int i = 0; i < x.Length; i++)
        {
            prod += x[i] * y[i];
        }
        return prod;
    }

    /// <summary>
    /// Multi-threaded version of <see cref="SingleThreadScalar"/>: each of <see cref="NumWorkerThread"/>
    /// tasks sums a contiguous slice, then the partial sums are added together.
    /// </summary>
    /// <exception cref="ArgumentException">The arrays have different lengths.</exception>
    internal static float MultiThreadScalar(float[] x, float[] y)
    {
        ThrowIfLengthMismatch(x.Length, y.Length);

        return SumPartitioned(x.Length, (start, end) =>
        {
            var local = 0f;
            for (int i = start; i < end; i++)
            {
                local += x[i] * y[i];
            }
            return local;
        });
    }

    /// <summary>
    /// SIMD version: sums <see cref="Vector4.Dot"/> of corresponding 4-lane vectors on a single thread.
    /// </summary>
    /// <exception cref="ArgumentException">The arrays have different lengths.</exception>
    internal static float SingleThreadVector(Vector4[] vx, Vector4[] vy)
    {
        ThrowIfLengthMismatch(vx.Length, vy.Length);

        var prod = 0f;
        for (int i = 0; i < vx.Length; i++)
        {
            prod += Vector4.Dot(vx[i], vy[i]);
        }
        return prod;
    }

    /// <summary>
    /// Multi-threaded + SIMD version: each task sums <see cref="Vector4.Dot"/> over a contiguous
    /// slice of vectors, then the partial sums are added together.
    /// </summary>
    /// <exception cref="ArgumentException">The arrays have different lengths.</exception>
    internal static float MultiThreadVector(Vector4[] vx, Vector4[] vy)
    {
        ThrowIfLengthMismatch(vx.Length, vy.Length);

        return SumPartitioned(vx.Length, (start, end) =>
        {
            var local = 0f;
            for (int i = start; i < end; i++)
            {
                local += Vector4.Dot(vx[i], vy[i]);
            }
            return local;
        });
    }

    /// <summary>
    /// Splits the index range [0, <paramref name="length"/>) into <see cref="NumWorkerThread"/> contiguous
    /// chunks, runs <paramref name="sumRange"/>(start, end) for each chunk via <see cref="Task.Run(Func{float})"/>,
    /// waits for all of them, and returns the sum of the partial results added in chunk order.
    /// </summary>
    /// <remarks>
    /// Boundaries are <c>length * k / workers</c>, so every index is covered exactly once even when
    /// <paramref name="length"/> is not divisible by the worker count (chunk sizes differ by at most one;
    /// some chunks may be empty when there are more workers than elements).
    /// </remarks>
    private static float SumPartitioned(int length, Func<int, int, float> sumRange)
    {
        var workers = Math.Max(1, NumWorkerThread);

        var partialProds = Task.WhenAll(
            Enumerable.Range(0, workers)
                .Select(n =>
                {
                    // long arithmetic: length * (n + 1) can exceed int.MaxValue for large inputs.
                    var start = (int)((long)length * n / workers);
                    var end = (int)((long)length * (n + 1) / workers);
                    return Task.Run(() => sumRange(start, end));
                })
        ).GetAwaiter().GetResult();

        // Plain float loop (not Enumerable.Sum, which accumulates floats in double) to keep
        // the same single-precision accumulation as the single-threaded variants.
        var prod = 0f;
        for (int i = 0; i < partialProds.Length; i++)
        {
            prod += partialProds[i];
        }
        return prod;
    }

    /// <summary>
    /// Throws <see cref="ArgumentException"/> when the two input lengths differ.
    /// </summary>
    private static void ThrowIfLengthMismatch(int xLength, int yLength)
    {
        if (xLength != yLength)
        {
            throw new ArgumentException($"Input lengths differ: {xLength} vs {yLength}.");
        }
    }

    /// <summary>
    /// Copies a float array into a <see cref="Vector4"/> array for the SIMD variants.
    /// <see cref="Buffer.BlockCopy(Array, int, Array, int, int)"/> cannot be used because
    /// <see cref="Vector4"/> is not a primitive type.
    /// </summary>
    /// <returns>
    /// <c>ceil(x.Length / 4)</c> vectors; when the length is not a multiple of 4, the unused lanes of the
    /// last vector are zero, so dot products over the result equal those over <paramref name="x"/>.
    /// </returns>
    internal static Vector4[] CopyToVector4(float[] x)
    {
        var vx = new Vector4[(x.Length + 3) / 4];
        var full = x.Length / 4;
        for (int i = 0; i < full; i++)
        {
            vx[i] = new Vector4(x[4 * i], x[4 * i + 1], x[4 * i + 2], x[4 * i + 3]);
        }

        // Pack the trailing 1-3 samples (if any) into the last vector, leaving the rest at zero.
        var rest = x.Length - 4 * full;
        if (rest > 0)
        {
            var tail = new float[4];
            Array.Copy(x, 4 * full, tail, 0, rest);
            vx[full] = new Vector4(tail[0], tail[1], tail[2], tail[3]);
        }
        return vx;
    }

    /// <summary>
    /// Creates a signal of <paramref name="length"/> samples from <see cref="Random.NextDouble"/>,
    /// i.e. values in [0, 1): white noise with mean about 0.5.
    /// </summary>
    private static float[] CreateWhiteNoise(Random r, int length)
    {
        var x = new float[length];
        for (int i = 0; i < x.Length; i++)
        {
            x[i] = (float)r.NextDouble();
        }
        return x;
    }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment