// Originally published as a GitHub Gist ("Instantly share code, notes, and snippets.").
// Author: @OswaldHurlem; created October 3, 2017 08:58.
// Gist id: 9ed9d74eaaae99331ef279dfb4513a22
using System;
using System.Collections;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using System.Linq.Expressions;
using System.Reflection;
// Ill-advised experiment making a SoA container in C# using Reflection
namespace SoaExperiment
{
/// <summary>
/// A four-channel (R, G, B, A) byte colour value, used as the element type of the SoA demo.
/// Two pixels are equal when all four channels are equal.
/// </summary>
public struct Pixel : IEquatable<Pixel>
{
    public byte R, G, B, A;

    /// <summary>Creates a pixel from its four channel values.</summary>
    public Pixel(byte r, byte g, byte b, byte a)
    {
        R = r;
        G = g;
        B = b;
        A = a;
    }

    /// <inheritdoc/>
    public override bool Equals(object obj) => obj is Pixel other && Equals(other);

    /// <summary>Channel-wise equality; no boxing thanks to <see cref="IEquatable{T}"/>.</summary>
    public bool Equals(Pixel pixel)
    {
        return R == pixel.R &&
               G == pixel.G &&
               B == pixel.B &&
               A == pixel.A;
    }

    public static bool operator ==(Pixel a, Pixel b) => a.Equals(b);
    public static bool operator !=(Pixel a, Pixel b) => !a.Equals(b);

    /// <summary>
    /// Combines the four channels. Consistent with <see cref="Equals(Pixel)"/>: equal pixels
    /// hash equally.
    /// </summary>
    public override int GetHashCode()
    {
        // The multiply-add is expected to wrap around; 'unchecked' keeps it from throwing
        // OverflowException when the project is compiled with overflow checking enabled.
        // ValueType.GetHashCode() is deliberately not mixed in: it would re-hash the same
        // field data that is already combined below.
        unchecked
        {
            var hashCode = 1960784236;
            hashCode = hashCode * -1521134295 + R.GetHashCode();
            hashCode = hashCode * -1521134295 + G.GetHashCode();
            hashCode = hashCode * -1521134295 + B.GetHashCode();
            hashCode = hashCode * -1521134295 + A.GetHashCode();
            return hashCode;
        }
    }
}
/// <summary>
/// Demo driver: sanity-checks <see cref="SoA{TStruct}"/> on a small image, then times a set of
/// fill / extract / scan operations on a large SoA versus a plain <c>Pixel[]</c>.
/// </summary>
class Program
{
    static void Main(string[] args)
    {
        var pixels = new SoA<Pixel>(100);

        // A struct written through the indexer must be visible in its field column...
        pixels[16] = new Pixel(10, 20, 30, 40);
        Debug.Assert(pixels.FieldArray(px => px.R)[16] == 10);

        // ...and a value written into a field column must be visible through the indexer.
        pixels.FieldArray(px => px.G)[44] = 100;
        Debug.Assert(pixels[44].G == 100);

        Benchmark();
    }

    /// <summary>
    /// Runs eight timed phases over 10,000,000 pixels, each in its own scope with its own
    /// stopwatch, and prints one line per phase to the console. Each phase runs once
    /// (no warm-up or repetition), so the numbers are only indicative.
    /// </summary>
    public static void Benchmark()
    {
        var elementCount = 10_000_000;
        var soaPixels = new SoA<Pixel>(elementCount);
        var arrayPixels = new Pixel[elementCount];
        var fillValue = new Pixel(10, 20, 30, 40);

        // Phase 1: fill the SoA one whole struct at a time through its indexer.
        {
            var timer = Stopwatch.StartNew();
            for (var idx = 0; idx < elementCount; idx++)
            {
                soaPixels[idx] = fillValue;
            }
            timer.Stop();
            Console.WriteLine($"SoA took {timer.ElapsedMilliseconds}ms to fill naively");
        }

        // Phase 2: fill the SoA by writing straight into its four field columns
        // (the column lookups are included in the timed region).
        {
            var timer = Stopwatch.StartNew();
            var reds = soaPixels.FieldArray(px => px.R);
            var greens = soaPixels.FieldArray(px => px.G);
            var blues = soaPixels.FieldArray(px => px.B);
            var alphas = soaPixels.FieldArray(px => px.A);
            for (var idx = 0; idx < elementCount; idx++)
            {
                reds[idx] = 50;
                greens[idx] = 60;
                blues[idx] = 70;
                alphas[idx] = 80;
            }
            timer.Stop();
            Console.WriteLine($"SoA took {timer.ElapsedMilliseconds}ms to fill with some extra prep work");
        }

        // Phase 3: fill the plain array, for comparison.
        {
            var timer = Stopwatch.StartNew();
            for (var idx = 0; idx < elementCount; idx++)
            {
                arrayPixels[idx] = fillValue;
            }
            timer.Stop();
            Console.WriteLine($"bigPlainArray took {timer.ElapsedMilliseconds}ms to fill naively");
        }

        // Phase 4: get the red channel from the SoA (hands back the existing column).
        {
            var timer = Stopwatch.StartNew();
            var redColumn = soaPixels.FieldArray(px => px.R);
            timer.Stop();
            Console.WriteLine($"Getting reds from SoA took {timer.ElapsedMilliseconds}ms");
        }

        // Phase 5: get the red channel from the plain array (builds a new array via LINQ).
        {
            var timer = Stopwatch.StartNew();
            var redColumn = arrayPixels.Select(px => px.R).ToArray();
            timer.Stop();
            Console.WriteLine($"Getting reds from bigPlainArray took {timer.ElapsedMilliseconds}ms");
        }

        // Phase 6: scan the SoA through its indexer, which reassembles every element.
        {
            var timer = Stopwatch.StartNew();
            var wanted = new Pixel(200, 200, 200, 200);
            for (var idx = 0; idx < elementCount; idx++)
            {
                if (soaPixels[idx] == wanted)
                {
                    Console.WriteLine("!");
                }
            }
            timer.Stop();
            Console.WriteLine($"SoA took {timer.ElapsedMilliseconds}ms to scan naively");
        }

        // Phase 7: scan the SoA by rebuilding each pixel from the columns by hand.
        {
            var timer = Stopwatch.StartNew();
            var reds = soaPixels.FieldArray(px => px.R);
            var greens = soaPixels.FieldArray(px => px.G);
            var blues = soaPixels.FieldArray(px => px.B);
            var alphas = soaPixels.FieldArray(px => px.A);
            var wanted = new Pixel(200, 200, 200, 200);
            for (var idx = 0; idx < elementCount; idx++)
            {
                var candidate = new Pixel(reds[idx], greens[idx], blues[idx], alphas[idx]);
                if (candidate == wanted)
                {
                    Console.WriteLine("!");
                }
            }
            timer.Stop();
            Console.WriteLine($"SoA took {timer.ElapsedMilliseconds}ms to scan with some extra prep work");
        }

        // Phase 8: scan the plain array, for comparison.
        {
            var timer = Stopwatch.StartNew();
            var wanted = new Pixel(200, 200, 200, 200);
            for (var idx = 0; idx < elementCount; idx++)
            {
                if (arrayPixels[idx] == wanted)
                {
                    Console.WriteLine("!");
                }
            }
            timer.Stop();
            Console.WriteLine($"bigPlainArray took {timer.ElapsedMilliseconds}ms to scan naively");
        }

        /*
        Sample output recorded by the original author (timings are machine-dependent):
        SoA took 4495ms to fill naively
        SoA took 31ms to fill with some extra prep work
        bigPlainArray took 29ms to fill naively
        Getting reds from SoA took 0ms
        Getting reds from bigPlainArray took 125ms
        SoA took 6874ms to scan naively
        SoA took 220ms to scan with some extra prep work
        bigPlainArray took 167ms to scan naively
        */
    }
}
// Stores TStruct as separate arrays, one per public instance field. SoA = Structure of Arrays.
// For 8 pixels this means "RRRRRRRRGGGGGGGGBBBBBBBBAAAAAAAA" instead of "RGBARGBA...".
// This layout is useful for SIMD code, which wants each field stored contiguously.
// Only a new() constraint is needed; TStruct does not technically have to be a struct.
public class SoA<TStruct> : IList<TStruct> where TStruct : new()
{
    // Public *instance* fields of TStruct. Their order defines the layout of the object[]
    // used by MakeStruct / SplitStruct. Static and const fields are excluded: they are not
    // per-element data, and Expression.Field(instance, staticField) throws, which would make
    // the static constructor fail for types such as Guid (public static readonly Guid Empty).
    // Public readonly fields are not supported: Expression.Assign rejects them.
    static readonly IList<FieldInfo> Fields =
        typeof(TStruct).GetFields(BindingFlags.Public | BindingFlags.Instance);

    /// <summary>Builds a TStruct from boxed field values ordered like <c>Fields</c>.</summary>
    public static Func<object[], TStruct> MakeStruct;

    /// <summary>Writes the boxed field values of a TStruct into an array ordered like <c>Fields</c>.</summary>
    public static Action<TStruct, object[]> SplitStruct;

    static SoA()
    {
        MakeStruct = BuildMakeStruct();
        SplitStruct = BuildSplitStruct();
    }

    // Compiles: values => { var s = new TStruct(); s.F0 = (T0)values[0]; ...; return s; }
    private static Func<object[], TStruct> BuildMakeStruct()
    {
        var valuesParam = Expression.Parameter(typeof(object[]), "arrayInst");
        var structVar = Expression.Variable(typeof(TStruct), "structInst");
        var body = new List<Expression>
        {
            Expression.Assign(structVar, Expression.New(typeof(TStruct)))
        };
        for (var i = 0; i < Fields.Count; i++)
        {
            var field = Fields[i];
            var element = Expression.ArrayAccess(valuesParam, Expression.Constant(i));
            body.Add(Expression.Assign(
                Expression.Field(structVar, field),
                Expression.Convert(element, field.FieldType)));
        }
        // The last expression is the block's value: the populated instance.
        body.Add(structVar);
        var block = Expression.Block(new[] { structVar }, body);
        return Expression.Lambda<Func<object[], TStruct>>(block, valuesParam).Compile();
    }

    // Compiles: (s, values) => { values[0] = (object)s.F0; ...; }
    private static Action<TStruct, object[]> BuildSplitStruct()
    {
        var structParam = Expression.Parameter(typeof(TStruct), "structInst");
        var valuesParam = Expression.Parameter(typeof(object[]), "arrayInst");
        var body = new List<Expression>();
        for (var i = 0; i < Fields.Count; i++)
        {
            var slot = Expression.ArrayAccess(valuesParam, Expression.Constant(i));
            var boxed = Expression.Convert(Expression.Field(structParam, Fields[i]), typeof(object));
            body.Add(Expression.Assign(slot, boxed));
        }
        // A block needs at least one expression (TStruct may have no public fields);
        // ending with Empty() also makes the block void-typed, matching Action<>.
        body.Add(Expression.Empty());
        return Expression.Lambda<Action<TStruct, object[]>>(
            Expression.Block(body), structParam, valuesParam).Compile();
    }

    // Field -> backing column; used by FieldArray to hand out a field's array.
    private readonly Dictionary<FieldInfo, Array> fieldArrays = new Dictionary<FieldInfo, Array>();

    // The same arrays as fieldArrays, indexed like Fields, so per-element access needs no lookups.
    private readonly Array[] columns;

    /// <summary>Creates a container of <paramref name="length"/> default-valued elements.</summary>
    /// <exception cref="ArgumentOutOfRangeException"><paramref name="length"/> is negative.</exception>
    public SoA(int length)
    {
        if (length < 0)
        {
            throw new ArgumentOutOfRangeException(nameof(length), "Length must be non-negative.");
        }
        columns = new Array[Fields.Count];
        for (var i = 0; i < columns.Length; i++)
        {
            var column = Array.CreateInstance(Fields[i].FieldType, length);
            columns[i] = column;
            fieldArrays[Fields[i]] = column;
        }
        Count = length;
    }

    /// <summary>Number of elements; fixed at construction.</summary>
    public int Count { get; private set; }

    /// <summary>
    /// Returns the live backing array for one field, e.g. <c>soa.FieldArray(p =&gt; p.R)</c>.
    /// Writes to the returned array are visible through the indexer and vice versa.
    /// </summary>
    /// <exception cref="ArgumentNullException"><paramref name="expr"/> is null.</exception>
    /// <exception cref="ArgumentException">
    /// The lambda body is not a direct access to a public instance field of TStruct.
    /// </exception>
    public TField[] FieldArray<TField>(Expression<Func<TStruct, TField>> expr)
    {
        if (expr == null)
        {
            throw new ArgumentNullException(nameof(expr));
        }
        if (expr.Body is MemberExpression member
            && member.Member is FieldInfo field
            && fieldArrays.TryGetValue(field, out var column))
        {
            return (TField[])column;
        }
        throw new ArgumentException(
            "Expression must directly access a public instance field of " + typeof(TStruct).Name + ", e.g. p => p.X.",
            nameof(expr));
    }

    // Reassembles element 'index' from the columns (boxes every field value).
    private TStruct GetAtIndex(int index)
    {
        var values = new object[columns.Length];
        for (var i = 0; i < columns.Length; i++)
        {
            values[i] = columns[i].GetValue(index);
        }
        return MakeStruct(values);
    }

    // Scatters the fields of 'val' into the columns at 'index'.
    private void SetAtIndex(TStruct val, int index)
    {
        var values = new object[columns.Length];
        SplitStruct(val, values);
        for (var i = 0; i < columns.Length; i++)
        {
            columns[i].SetValue(values[i], index);
        }
    }

    public TStruct this[int index]
    {
        get => GetAtIndex(index);
        set => SetAtIndex(value, index);
    }

    /// <summary>
    /// Linear search for <paramref name="item"/>; returns its index or -1.
    /// Elements are pre-filtered on their first field before being fully reassembled. This
    /// assumes TStruct equality implies equal first fields.
    /// </summary>
    public int IndexOf(TStruct item)
    {
        // Stored elements are always freshly constructed, so a null item never matches
        // (and SplitStruct cannot read fields from a null reference).
        if (item == null)
        {
            return -1;
        }

        object probe = null;
        Array probeColumn = null;
        if (columns.Length > 0)
        {
            var unpacked = new object[columns.Length];
            SplitStruct(item, unpacked);
            probe = unpacked[0];
            probeColumn = columns[0];
        }

        var comparer = EqualityComparer<TStruct>.Default;
        for (var i = 0; i < Count; i++)
        {
            // Static object.Equals is null-safe, unlike probe.Equals(...).
            if (probeColumn != null && !object.Equals(probe, probeColumn.GetValue(i)))
            {
                continue;
            }
            if (comparer.Equals(item, GetAtIndex(i)))
            {
                return i;
            }
        }
        return -1;
    }

    public bool IsReadOnly => false;

    // The container has a fixed size, so structural modifications are not supported.
    public void Insert(int index, TStruct item) => throw new NotSupportedException();
    public void RemoveAt(int index) => throw new NotSupportedException();
    public void Add(TStruct item) => throw new NotSupportedException();
    public bool Remove(TStruct item) => throw new NotSupportedException();

    /// <summary>Resets every element to its default field values; <see cref="Count"/> is unchanged.</summary>
    public void Clear()
    {
        foreach (var column in columns)
        {
            Array.Clear(column, 0, Count);
        }
    }

    public bool Contains(TStruct item)
    {
        return IndexOf(item) != -1;
    }

    /// <summary>Copies all elements into <paramref name="array"/> starting at <paramref name="arrayIndex"/>.</summary>
    /// <exception cref="ArgumentNullException"><paramref name="array"/> is null.</exception>
    /// <exception cref="ArgumentOutOfRangeException"><paramref name="arrayIndex"/> is negative.</exception>
    /// <exception cref="ArgumentException">The destination does not have room for all elements.</exception>
    public void CopyTo(TStruct[] array, int arrayIndex)
    {
        if (array == null)
        {
            throw new ArgumentNullException(nameof(array));
        }
        if (arrayIndex < 0)
        {
            throw new ArgumentOutOfRangeException(nameof(arrayIndex), "Index must be non-negative.");
        }
        // Validate up front so a failed copy does not leave the destination partially written.
        if (array.Length - arrayIndex < Count)
        {
            throw new ArgumentException("Destination array is not long enough to copy all the items in the collection.");
        }
        for (var i = 0; i < Count; i++)
        {
            array[i + arrayIndex] = GetAtIndex(i);
        }
    }

    /// <summary>Lazily yields each element in index order, reassembling it from the columns.</summary>
    public IEnumerator<TStruct> GetEnumerator()
    {
        for (var i = 0; i < Count; i++)
        {
            yield return GetAtIndex(i);
        }
    }

    IEnumerator IEnumerable.GetEnumerator() => GetEnumerator();
}
}
// (Gist page footer) Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment