Skip to content

Instantly share code, notes, and snippets.

@AdamWhiteHat
Last active January 25, 2023 17:56
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save AdamWhiteHat/71d548ebfb2ee67fcbd78a39a9751423 to your computer and use it in GitHub Desktop.
Save AdamWhiteHat/71d548ebfb2ee67fcbd78a39a9751423 to your computer and use it in GitHub Desktop.
Generic arithmetic operations with an example
/// <summary>
/// Builds and caches compiled delegates that perform arithmetic, comparison,
/// square-root, rounding and conversion operations on an arbitrary type
/// <typeparamref name="T"/> by way of LINQ expression trees.
/// </summary>
/// <remarks>
/// The caches are plain static fields with no synchronization; callers that
/// share this class across threads must provide their own locking.
/// </remarks>
/// <typeparam name="T">The operand type the generated delegates work on.</typeparam>
public static class GenericArithmeticFactory<T>
{
    private static readonly Dictionary<ExpressionType, Func<T, T, T>> _binaryOperationDictionary = new Dictionary<ExpressionType, Func<T, T, T>>();
    private static readonly Dictionary<ExpressionType, Func<T, T, bool>> _comparisonOperationDictionary = new Dictionary<ExpressionType, Func<T, T, bool>>();
    private static Func<T, T> _sqrtOperation = null;
    private static Func<T, T> _roundOperation = null;

    /// <summary>
    /// Returns a compiled delegate that applies the requested binary arithmetic
    /// operator to two <typeparamref name="T"/> values. Delegates are compiled
    /// once per operator and then served from a cache.
    /// </summary>
    /// <param name="operationType">
    /// One of Add, Subtract, Multiply, Divide, Modulo or Power.
    /// </param>
    /// <returns>A delegate computing <c>left (op) right</c>.</returns>
    /// <exception cref="NotSupportedException">
    /// <paramref name="operationType"/> is not one of the supported operators.
    /// </exception>
    /// <exception cref="InvalidOperationException">
    /// Thrown by the expression factory when the operator is not defined for
    /// <typeparamref name="T"/>.
    /// </exception>
    public static Func<T, T, T> GetArithmeticOperation(ExpressionType operationType)
    {
        Func<T, T, T> cached;
        if (_binaryOperationDictionary.TryGetValue(operationType, out cached))
        {
            return cached;
        }

        ParameterExpression left = Expression.Parameter(typeof(T), "left");
        ParameterExpression right = Expression.Parameter(typeof(T), "right");

        BinaryExpression operation;
        switch (operationType)
        {
            case ExpressionType.Add: operation = Expression.Add(left, right); break;
            case ExpressionType.Subtract: operation = Expression.Subtract(left, right); break;
            case ExpressionType.Multiply: operation = Expression.Multiply(left, right); break;
            case ExpressionType.Divide: operation = Expression.Divide(left, right); break;
            case ExpressionType.Modulo: operation = Expression.Modulo(left, right); break;
            case ExpressionType.Power: operation = Expression.Power(left, right); break;
            default: throw CreateNotSupportedException(operationType);
        }

        Func<T, T, T> result = Expression.Lambda<Func<T, T, T>>(operation, left, right).Compile();
        // Indexer assignment instead of Add: a duplicate insert simply overwrites
        // rather than throwing ArgumentException.
        _binaryOperationDictionary[operationType] = result;
        return result;
    }

    /// <summary>
    /// Returns a compiled delegate that applies the requested comparison
    /// operator to two <typeparamref name="T"/> values. Delegates are compiled
    /// once per operator and then served from a cache.
    /// </summary>
    /// <param name="operationType">
    /// One of GreaterThan, LessThan, GreaterThanOrEqual, LessThanOrEqual,
    /// Equal or NotEqual.
    /// </param>
    /// <returns>A delegate computing <c>left (op) right</c>.</returns>
    /// <exception cref="NotSupportedException">
    /// <paramref name="operationType"/> is not one of the supported comparisons.
    /// </exception>
    /// <exception cref="InvalidOperationException">
    /// Thrown by the expression factory when the operator is not defined for
    /// <typeparamref name="T"/>.
    /// </exception>
    public static Func<T, T, bool> GetComparisonOperation(ExpressionType operationType)
    {
        Func<T, T, bool> cached;
        if (_comparisonOperationDictionary.TryGetValue(operationType, out cached))
        {
            return cached;
        }

        ParameterExpression left = Expression.Parameter(typeof(T), "left");
        ParameterExpression right = Expression.Parameter(typeof(T), "right");

        BinaryExpression comparison;
        switch (operationType)
        {
            case ExpressionType.GreaterThan: comparison = Expression.GreaterThan(left, right); break;
            case ExpressionType.LessThan: comparison = Expression.LessThan(left, right); break;
            case ExpressionType.GreaterThanOrEqual: comparison = Expression.GreaterThanOrEqual(left, right); break;
            case ExpressionType.LessThanOrEqual: comparison = Expression.LessThanOrEqual(left, right); break;
            case ExpressionType.Equal: comparison = Expression.Equal(left, right); break;
            case ExpressionType.NotEqual: comparison = Expression.NotEqual(left, right); break;
            // Without this branch a null body reached Expression.Lambda and
            // surfaced as an unrelated ArgumentNullException.
            default: throw CreateNotSupportedException(operationType);
        }

        Func<T, T, bool> result = Expression.Lambda<Func<T, T, bool>>(comparison, left, right).Compile();
        _comparisonOperationDictionary[operationType] = result;
        return result;
    }

    /// <summary>
    /// Returns a cached delegate that computes the square root of a
    /// <typeparamref name="T"/> value.
    /// </summary>
    /// <remarks>
    /// For <see cref="Complex"/> the delegate is <c>Complex.Sqrt</c>. For every
    /// other type the value is converted to <see cref="double"/>, passed to
    /// <c>Math.Sqrt</c>, and the result converted back to <typeparamref name="T"/>.
    /// </remarks>
    /// <returns>A delegate computing the square root of its argument.</returns>
    public static Func<T, T> GetSquareRootOperation()
    {
        if (_sqrtOperation == null)
        {
            if (typeof(T) == typeof(Complex))
            {
                // Complex has no conversion to double, so the Math.Sqrt path
                // cannot be built for it; use the dedicated implementation.
                _sqrtOperation = (Func<T, T>)(object)new Func<Complex, Complex>(Complex.Sqrt);
            }
            else
            {
                _sqrtOperation = BuildDoubleRoundTrip(nameof(Math.Sqrt));
            }
        }
        return _sqrtOperation;
    }

    /// <summary>
    /// Returns a cached delegate that rounds a <typeparamref name="T"/> value
    /// using <c>Math.Round(double)</c>.
    /// </summary>
    /// <remarks>
    /// For <see cref="Complex"/> the delegate returns its argument unchanged.
    /// </remarks>
    /// <returns>A delegate rounding its argument.</returns>
    /// <exception cref="MissingMethodException">
    /// <c>Math.Round(double)</c> could not be located via reflection.
    /// </exception>
    public static Func<T, T> GetRoundOperation()
    {
        if (_roundOperation == null)
        {
            if (typeof(T) == typeof(Complex))
            {
                _roundOperation = (n) => n;
            }
            else
            {
                _roundOperation = BuildDoubleRoundTrip(nameof(Math.Round));
            }
        }
        return _roundOperation;
    }

    /// <summary>
    /// Compiles <c>input =&gt; (T)Math.{methodName}((double)input)</c>, omitting
    /// either conversion when <typeparamref name="T"/> is already double.
    /// </summary>
    private static Func<T, T> BuildDoubleRoundTrip(string methodName)
    {
        // Look the method up by its exact (double) signature so the match is
        // never ambiguous between overloads.
        MethodInfo method = typeof(Math).GetMethod(methodName, new Type[] { typeof(double) });
        if (method == null) { throw new MissingMethodException("Math", methodName); }

        ParameterExpression parameter = Expression.Parameter(typeof(T), "input");
        Expression argument = ConvertIfNeeded(parameter, typeof(double));
        MethodCallExpression methodCall = Expression.Call(method, argument);
        Expression convertedMethod = ConvertIfNeeded(methodCall, typeof(T));
        return Expression.Lambda<Func<T, T>>(convertedMethod, parameter).Compile();
    }

    /// <summary>
    /// Wraps <paramref name="valueExpression"/> in a conversion to
    /// <paramref name="targetType"/> unless it already has that static type.
    /// </summary>
    private static Expression ConvertIfNeeded(Expression valueExpression, Type targetType)
    {
        // Expression.Type is the static result type for every node kind, so no
        // per-NodeType inspection is required.
        if (valueExpression.Type != targetType)
        {
            return Expression.Convert(valueExpression, targetType);
        }
        return valueExpression;
    }

    /// <summary>
    /// Creates the exception reported for an operator this factory does not handle.
    /// </summary>
    private static NotSupportedException CreateNotSupportedException(ExpressionType operationType)
    {
        return new NotSupportedException($"{nameof(ExpressionType)} not supported: {Enum.GetName(typeof(ExpressionType), operationType)}.");
    }

    /// <summary>
    /// Converts values from <typeparamref name="TFrom"/> to
    /// <typeparamref name="TTo"/> using a cached, compiled conversion expression.
    /// </summary>
    /// <typeparam name="TFrom">The source type.</typeparam>
    /// <typeparam name="TTo">The destination type.</typeparam>
    public static class ConvertImplementation<TFrom, TTo>
    {
        private static Func<TFrom, TTo> _convertFunction = null;

        /// <summary>
        /// Converts <paramref name="value"/> to <typeparamref name="TTo"/>.
        /// </summary>
        /// <remarks>
        /// When both the enclosing factory's <typeparamref name="T"/> and
        /// <typeparamref name="TTo"/> are <see cref="Complex"/>, the value is
        /// first changed to <see cref="double"/> via
        /// <c>System.Convert.ChangeType</c> and used as the real part with a
        /// zero imaginary part. In every other case the compiled
        /// <c>Expression.Convert</c> delegate is used.
        /// </remarks>
        /// <param name="value">The value to convert.</param>
        /// <returns>The converted value.</returns>
        public static TTo Convert(TFrom value)
        {
            // The shortcut casts a boxed Complex to TTo, so it is only valid when
            // TTo really is Complex; checking T alone produced an
            // InvalidCastException for any other destination type.
            if (typeof(T) == typeof(Complex) && typeof(TTo) == typeof(Complex))
            {
                return (TTo)((object)new Complex((double)System.Convert.ChangeType(value, typeof(double)), 0d));
            }
            if (_convertFunction == null)
            {
                _convertFunction = CreateConvertFunction();
            }
            return _convertFunction.Invoke(value);
        }

        /// <summary>
        /// Compiles <c>value =&gt; (TTo)value</c>.
        /// </summary>
        private static Func<TFrom, TTo> CreateConvertFunction()
        {
            ParameterExpression value = Expression.Parameter(typeof(TFrom), "value");
            Expression convert = Expression.Convert(value, typeof(TTo));
            Func<TFrom, TTo> result = Expression.Lambda<Func<TFrom, TTo>>(convert, value).Compile();
            return result;
        }
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment