Skip to content

Instantly share code, notes, and snippets.

Embed
What would you like to do?
Generic arithmetic operations with an example
/// <summary>
/// Builds, compiles and caches arithmetic delegates for an arbitrary type
/// <typeparamref name="T"/> by means of LINQ expression trees.
/// </summary>
/// <remarks>
/// Each closed generic type (for example <c>GenericArithmeticFactory&lt;int&gt;</c>)
/// has its own cache. The caches are plain static fields that are read and
/// written without any locking.
/// </remarks>
/// <typeparam name="T">The operand and result type of the generated delegates.</typeparam>
public static class GenericArithmeticFactory<T>
{
    // Compiled binary operations, keyed by the expression node type that produced them.
    private static readonly Dictionary<ExpressionType, Func<T, T, T>> _binaryOperationDictionary =
        new Dictionary<ExpressionType, Func<T, T, T>>();

    // Lazily compiled square-root delegate; null until first requested.
    private static Func<T, T> _sqrtOperation = null;

    /// <summary>
    /// Returns a compiled delegate that applies the requested binary arithmetic
    /// operation to two values of type <typeparamref name="T"/>.
    /// </summary>
    /// <param name="operationType">
    /// One of <see cref="ExpressionType.Add"/>, <see cref="ExpressionType.Subtract"/>,
    /// <see cref="ExpressionType.Multiply"/>, <see cref="ExpressionType.Divide"/>,
    /// <see cref="ExpressionType.Modulo"/> or <see cref="ExpressionType.Power"/>.
    /// </param>
    /// <returns>The cached delegate, compiled on first request.</returns>
    /// <exception cref="NotSupportedException">
    /// <paramref name="operationType"/> is not one of the supported operations.
    /// </exception>
    /// <exception cref="InvalidOperationException">
    /// The operator is not defined for <typeparamref name="T"/> (raised by the
    /// expression-tree factory methods).
    /// </exception>
    public static Func<T, T, T> GetArithmeticOperation(ExpressionType operationType)
    {
        // Single lookup instead of ContainsKey followed by the indexer.
        Func<T, T, T> cached;
        if (_binaryOperationDictionary.TryGetValue(operationType, out cached))
        {
            return cached;
        }

        ParameterExpression left = Expression.Parameter(typeof(T), "left");
        ParameterExpression right = Expression.Parameter(typeof(T), "right");

        Expression body;
        switch (operationType)
        {
            case ExpressionType.Add:
                body = Expression.Add(left, right);
                break;
            case ExpressionType.Subtract:
                body = Expression.Subtract(left, right);
                break;
            case ExpressionType.Multiply:
                body = Expression.Multiply(left, right);
                break;
            case ExpressionType.Divide:
                body = Expression.Divide(left, right);
                break;
            case ExpressionType.Modulo:
                body = Expression.Modulo(left, right);
                break;
            case ExpressionType.Power:
                body = BuildPower(left, right);
                break;
            default:
                throw new NotSupportedException($"{nameof(ExpressionType)} not supported: {Enum.GetName(typeof(ExpressionType), operationType)}.");
        }

        Func<T, T, T> compiled = Expression.Lambda<Func<T, T, T>>(body, left, right).Compile();

        // Indexer assignment rather than Add, so a repeated store can never throw.
        _binaryOperationDictionary[operationType] = compiled;
        return compiled;
    }

    /// <summary>
    /// Returns a compiled delegate that computes the square root of a value of
    /// type <typeparamref name="T"/>.
    /// </summary>
    /// <remarks>
    /// The input is converted to <see cref="double"/> when necessary, passed to
    /// <see cref="Math.Sqrt(double)"/>, and the result is converted back to
    /// <typeparamref name="T"/>.
    /// </remarks>
    /// <returns>The cached delegate, compiled on first request.</returns>
    public static Func<T, T> GetSquareRootOperation()
    {
        if (_sqrtOperation == null)
        {
            // Pin the exact overload so the lookup cannot become ambiguous.
            MethodInfo sqrtMethod = typeof(Math).GetMethod(nameof(Math.Sqrt), new[] { typeof(double) });
            ParameterExpression input = Expression.Parameter(typeof(T), "input");
            Expression asDouble = ConvertIfNeeded(input, typeof(double));
            MethodCallExpression call = Expression.Call(sqrtMethod, asDouble);
            Expression asT = ConvertIfNeeded(call, typeof(T));
            _sqrtOperation = Expression.Lambda<Func<T, T>>(asT, input).Compile();
        }

        return _sqrtOperation;
    }

    /// <summary>
    /// Builds the expression body for exponentiation.
    /// </summary>
    /// <remarks>
    /// <see cref="Expression.Power(Expression, Expression)"/> is tried on the raw
    /// operands first so that any type it already supports behaves as before. When
    /// it rejects the operand type, both operands are converted to
    /// <see cref="double"/>, raised, and the result is converted back to
    /// <typeparamref name="T"/>.
    /// </remarks>
    private static Expression BuildPower(ParameterExpression left, ParameterExpression right)
    {
        try
        {
            return Expression.Power(left, right);
        }
        catch (InvalidOperationException)
        {
            // Power is not defined for T directly; route the computation through double.
            Expression leftAsDouble = ConvertIfNeeded(left, typeof(double));
            Expression rightAsDouble = ConvertIfNeeded(right, typeof(double));
            Expression power = Expression.Power(leftAsDouble, rightAsDouble);
            return ConvertIfNeeded(power, typeof(T));
        }
    }

    /// <summary>
    /// Wraps <paramref name="valueExpression"/> in a conversion node unless its
    /// static type already equals <paramref name="targetType"/>.
    /// </summary>
    /// <param name="valueExpression">The expression whose value may need converting.</param>
    /// <param name="targetType">The type the resulting expression must have.</param>
    /// <returns>The original expression, or a conversion of it to <paramref name="targetType"/>.</returns>
    private static Expression ConvertIfNeeded(Expression valueExpression, Type targetType)
    {
        // Expression.Type is the static type for every node kind, so no per-kind casts are needed.
        if (valueExpression.Type != targetType)
        {
            return Expression.Convert(valueExpression, targetType);
        }

        return valueExpression;
    }
}
/// <summary>
/// A generic vector whose element-wise arithmetic is performed through delegates
/// obtained from <c>GenericArithmeticFactory&lt;T&gt;</c>.
/// </summary>
/// <typeparam name="T">The element type; it must support the arithmetic operators that are used.</typeparam>
public class Vector<T>
{
    /// <summary>Number of elements, or 0 when <see cref="Elements"/> is null.</summary>
    public int Dimensions { get { return Elements == null ? 0 : Elements.Length; } }

    /// <summary>The backing array of the vector.</summary>
    public T[] Elements { get; set; }

    /// <summary>Gets the element at the given zero-based <paramref name="index"/>.</summary>
    public T this[int index] { get { return Elements[index]; } }

    /// <summary>Creates an empty (zero-dimensional) vector.</summary>
    public Vector() { Elements = new T[0]; }

    /// <summary>Creates a vector that uses <paramref name="elements"/> as its backing array (no copy is made).</summary>
    public Vector(T[] elements)
    {
        Elements = elements;
    }

    /// <summary>
    /// Computes the square root of the dot product of <paramref name="input"/> with itself.
    /// </summary>
    public static T Norm(Vector<T> input)
    {
        Func<T, T> squareRootOperation = GenericArithmeticFactory<T>.GetSquareRootOperation();
        T dotProduct = DotProduct(input, input);
        return squareRootOperation(dotProduct);
    }

    /// <summary>
    /// Divides every element of <paramref name="input"/> by its <see cref="Norm"/>.
    /// </summary>
    public static Vector<T> Normalize(Vector<T> input)
    {
        T norm = Norm(input);
        return ScalarDivide(input, norm);
    }

    /// <summary>Element-wise sum of two vectors of equal dimension.</summary>
    public static Vector<T> Add(Vector<T> left, Vector<T> right)
    {
        Func<T, T, T> addOperation = GenericArithmeticFactory<T>.GetArithmeticOperation(ExpressionType.Add);
        return PairwiseForEach(left, right, addOperation);
    }

    /// <summary>Element-wise difference of two vectors of equal dimension.</summary>
    public static Vector<T> Subtract(Vector<T> left, Vector<T> right)
    {
        Func<T, T, T> subtractOperation = GenericArithmeticFactory<T>.GetArithmeticOperation(ExpressionType.Subtract);
        return PairwiseForEach(left, right, subtractOperation);
    }

    /// <summary>Element-wise product of two vectors of equal dimension.</summary>
    public static Vector<T> Multiply(Vector<T> left, Vector<T> right)
    {
        Func<T, T, T> multiplyOperation = GenericArithmeticFactory<T>.GetArithmeticOperation(ExpressionType.Multiply);
        return PairwiseForEach(left, right, multiplyOperation);
    }

    /// <summary>Element-wise quotient of two vectors of equal dimension.</summary>
    public static Vector<T> Divide(Vector<T> left, Vector<T> right)
    {
        Func<T, T, T> divideOperation = GenericArithmeticFactory<T>.GetArithmeticOperation(ExpressionType.Divide);
        return PairwiseForEach(left, right, divideOperation);
    }

    /// <summary>
    /// Sums the element-wise products of two vectors of equal dimension.
    /// </summary>
    /// <returns>
    /// The accumulated sum, or <c>default(T)</c> when the vectors have no elements.
    /// </returns>
    public static T DotProduct(Vector<T> left, Vector<T> right)
    {
        Func<T, T, T> addOperation = GenericArithmeticFactory<T>.GetArithmeticOperation(ExpressionType.Add);
        T[] products = Multiply(left, right).Elements;

        if (products.Length == 0)
        {
            return default(T);
        }

        // Seed with the first product so no additive identity for T is required.
        T sum = products[0];
        for (int i = 1; i < products.Length; i++)
        {
            sum = addOperation(sum, products[i]);
        }

        return sum;
    }

    /// <summary>Adds <paramref name="scalar"/> to every element of <paramref name="vector"/>.</summary>
    public static Vector<T> ScalarAdd(Vector<T> vector, T scalar)
    {
        Func<T, T, T> addOperation = GenericArithmeticFactory<T>.GetArithmeticOperation(ExpressionType.Add);
        return ScalarForEach(vector, scalar, addOperation);
    }

    /// <summary>Multiplies every element of <paramref name="vector"/> by <paramref name="scalar"/>.</summary>
    public static Vector<T> ScalarMultiply(Vector<T> vector, T scalar)
    {
        Func<T, T, T> multiplyOperation = GenericArithmeticFactory<T>.GetArithmeticOperation(ExpressionType.Multiply);
        return ScalarForEach(vector, scalar, multiplyOperation);
    }

    /// <summary>Divides every element of <paramref name="vector"/> by <paramref name="scalar"/>.</summary>
    public static Vector<T> ScalarDivide(Vector<T> vector, T scalar)
    {
        Func<T, T, T> divideOperation = GenericArithmeticFactory<T>.GetArithmeticOperation(ExpressionType.Divide);
        return ScalarForEach(vector, scalar, divideOperation);
    }

    /// <summary>
    /// Applies <paramref name="operation"/> to each pair of corresponding elements
    /// and returns the results as a new vector.
    /// </summary>
    /// <exception cref="ArgumentNullException">Either vector is null.</exception>
    /// <exception cref="ArgumentException">The vectors differ in dimension.</exception>
    private static Vector<T> PairwiseForEach(Vector<T> left, Vector<T> right, Func<T, T, T> operation)
    {
        if (left == null)
        {
            throw new ArgumentNullException(nameof(left));
        }
        if (right == null)
        {
            throw new ArgumentNullException(nameof(right));
        }
        if (left.Dimensions != right.Dimensions)
        {
            // ArgumentException derives from Exception, so existing catch blocks keep working.
            throw new ArgumentException("Both vector dimensions must be the same.");
        }

        int count = left.Dimensions;
        T[] results = new T[count];
        for (int i = 0; i < count; i++)
        {
            results[i] = operation(left[i], right[i]);
        }

        return new Vector<T>(results);
    }

    /// <summary>
    /// Applies <paramref name="operation"/> to each element of <paramref name="vector"/>
    /// (as the left operand) and <paramref name="scalar"/> (as the right operand),
    /// without materialising a vector of repeated scalars.
    /// </summary>
    /// <exception cref="ArgumentNullException"><paramref name="vector"/> is null.</exception>
    private static Vector<T> ScalarForEach(Vector<T> vector, T scalar, Func<T, T, T> operation)
    {
        if (vector == null)
        {
            throw new ArgumentNullException(nameof(vector));
        }

        int count = vector.Dimensions;
        T[] results = new T[count];
        for (int i = 0; i < count; i++)
        {
            results[i] = operation(vector[i], scalar);
        }

        return new Vector<T>(results);
    }

    /// <summary>
    /// Formats the vector as <c>[ e0, e1, ... ]</c>. A null <see cref="Elements"/>
    /// array is rendered like an empty vector, and null elements as empty strings.
    /// </summary>
    public override string ToString()
    {
        IEnumerable<T> items = Elements ?? new T[0];
        return $"[ {string.Join<T>(", ", items)} ]";
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment