Skip to content

Instantly share code, notes, and snippets.

@szalapski
Created January 8, 2018 15:39
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save szalapski/96131c43ceefbf66149e5f6f855a4449 to your computer and use it in GitHub Desktop.
Save szalapski/96131c43ceefbf66149e5f6f855a4449 to your computer and use it in GitHub Desktop.
using System;
using System.Collections.Generic;
using System.Linq;
using System.Linq.Expressions;
namespace GeneralMills.ResourceDemand.Core.Extensions
{
    /// <summary>
    /// Contains static methods for working with data queries and LINQ.
    /// </summary>
    public static class DataUtilities
    {
        /// <summary>
        /// Gives an expression that tests whether a value selected from an element matches any of a list of values
        /// (for example, whether a database column's value is in a list).
        /// </summary>
        /// <typeparam name="TElement">The type of the parent object or entity.</typeparam>
        /// <typeparam name="TValue">The type of the value to match.</typeparam>
        /// <param name="valueSelector">An expression that, given the parent object or entity, returns the value to match.</param>
        /// <param name="values">The values to match. The sequence is enumerated exactly once.</param>
        /// <returns>
        /// An expression suitable for use in a call to <c>.Where(expression)</c>.
        /// When <paramref name="values"/> is empty the expression is <c>e =&gt; false</c>; otherwise it is
        /// <c>selector(e) == v1 || selector(e) == v2 || ...</c>, with each value embedded as a constant.
        /// </returns>
        /// <exception cref="ArgumentNullException"><paramref name="valueSelector"/> or <paramref name="values"/> is null.</exception>
        /// <remarks>Useful to help generate an efficient IN query when you have a list of values to match.
        /// Without this, the alternative (the standard LINQ way using .Contains) usually results in long, inefficient UNION queries.
        /// The equality tests are joined as a balanced tree of OrElse nodes, so the tree depth grows logarithmically
        /// with the number of values; a linear (left-deep) chain may overflow the stack in recursive expression visitors
        /// when the list is large.
        /// </remarks>
        /// <example>
        /// To query a database table (entities) `Foos` of type `FooSql` to return any row where `Foo.Id` matches a list of IDs `idList`:
        /// <code>Foos.Where(DataUtilities.In&lt;FooSql, Guid&gt;(x => x.Id, idList));</code>
        /// </example>
        public static Expression<Func<TElement, bool>> In<TElement, TValue>(
            Expression<Func<TElement, TValue>> valueSelector, IEnumerable<TValue> values)
        {
            if (valueSelector == null) throw new ArgumentNullException(nameof(valueSelector));
            if (values == null) throw new ArgumentNullException(nameof(values));

            ParameterExpression parameter = valueSelector.Parameters.Single();

            // Materialize once: a lazy sequence (or query) would otherwise be enumerated twice,
            // once for the emptiness check and again to build the comparisons.
            List<Expression> comparisons = values
                .Select(value => (Expression)Expression.Equal(valueSelector.Body, Expression.Constant(value, typeof(TValue))))
                .ToList();

            if (comparisons.Count == 0)
            {
                return e => false;
            }

            Expression body = CombineOrElse(comparisons, 0, comparisons.Count);
            return Expression.Lambda<Func<TElement, bool>>(body, parameter);
        }

        /// <summary>
        /// Joins <paramref name="count"/> boolean expressions, starting at index <paramref name="start"/>, with
        /// short-circuit OR, halving the range at each step so the resulting tree has logarithmic depth.
        /// Requires <paramref name="count"/> &gt;= 1.
        /// </summary>
        private static Expression CombineOrElse(IReadOnlyList<Expression> terms, int start, int count)
        {
            if (count == 1)
            {
                return terms[start];
            }

            int half = count / 2;
            return Expression.OrElse(
                CombineOrElse(terms, start, half),
                CombineOrElse(terms, start + half, count - half));
        }

        /// <summary>
        /// Sorts the elements of a query by a key, in descending order when <paramref name="descending"/> is true
        /// and ascending order otherwise.
        /// </summary>
        /// <typeparam name="TSource">The element type of the query.</typeparam>
        /// <typeparam name="TKey">The type of the sort key.</typeparam>
        /// <param name="source">The query to sort.</param>
        /// <param name="keySelector">An expression that extracts the sort key from an element.</param>
        /// <param name="descending">True to sort descending; false to sort ascending.</param>
        /// <returns>The ordered query, which can be further ordered with <see cref="ThenByDirection{TSource, TKey}"/>.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="source"/> or <paramref name="keySelector"/> is null.</exception>
        public static IOrderedQueryable<TSource> OrderByDirection<TSource, TKey>(
            this IQueryable<TSource> source, Expression<Func<TSource, TKey>> keySelector, bool descending)
        {
            if (source == null) throw new ArgumentNullException(nameof(source));
            if (keySelector == null) throw new ArgumentNullException(nameof(keySelector));
            return descending ? source.OrderByDescending(keySelector) : source.OrderBy(keySelector);
        }

        /// <summary>
        /// Performs a subsequent ordering of an already-ordered query by a key, in descending order when
        /// <paramref name="descending"/> is true and ascending order otherwise.
        /// </summary>
        /// <typeparam name="TSource">The element type of the query.</typeparam>
        /// <typeparam name="TKey">The type of the sort key.</typeparam>
        /// <param name="source">The ordered query to refine.</param>
        /// <param name="keySelector">An expression that extracts the secondary sort key from an element.</param>
        /// <param name="descending">True to sort descending; false to sort ascending.</param>
        /// <returns>The query with the additional ordering applied.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="source"/> or <paramref name="keySelector"/> is null.</exception>
        public static IOrderedQueryable<TSource> ThenByDirection<TSource, TKey>(
            this IOrderedQueryable<TSource> source, Expression<Func<TSource, TKey>> keySelector, bool descending)
        {
            if (source == null) throw new ArgumentNullException(nameof(source));
            if (keySelector == null) throw new ArgumentNullException(nameof(keySelector));
            return descending ? source.ThenByDescending(keySelector) : source.ThenBy(keySelector);
        }

        /// <summary>
        /// Returns a queryable object filtered to where a string value from the queryable object equals a value,
        /// but only when the value is not null or whitespace.
        /// If the value is null or whitespace, returns the queryable object unaltered.
        /// </summary>
        /// <param name="query">The query to filter.</param>
        /// <param name="filterExpression">A function on each item in query that returns the string value to compare against.</param>
        /// <param name="filterValue">The value to filter on.</param>
        /// <returns>The filtered query, or <paramref name="query"/> itself when no filter value is given.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="query"/> or <paramref name="filterExpression"/> is null.</exception>
        public static IQueryable<T> WhereEqualIfSpecified<T>(
            this IQueryable<T> query,
            Expression<Func<T, string>> filterExpression,
            string filterValue)
        {
            if (query == null) throw new ArgumentNullException(nameof(query));
            if (filterExpression == null) throw new ArgumentNullException(nameof(filterExpression));
            return string.IsNullOrWhiteSpace(filterValue)
                ? query
                : query.Where(AreEqual(filterExpression, filterValue));
        }

        /// <summary>
        /// Returns a queryable object filtered to where a nullable guid value from the queryable object equals a value if it is specified.
        /// If the value is null, returns the queryable object unaltered.
        /// </summary>
        /// <param name="query">The query to filter.</param>
        /// <param name="sourceExpression">A function on each item in query that returns the nullable guid value to compare against.</param>
        /// <param name="filterValue">The value to filter on.</param>
        /// <returns>The filtered query, or <paramref name="query"/> itself when <paramref name="filterValue"/> is null.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="query"/> or <paramref name="sourceExpression"/> is null.</exception>
        public static IQueryable<T> WhereEqualIfSpecified<T>(
            this IQueryable<T> query,
            Expression<Func<T, Guid?>> sourceExpression,
            Guid? filterValue)
        {
            if (query == null) throw new ArgumentNullException(nameof(query));
            if (sourceExpression == null) throw new ArgumentNullException(nameof(sourceExpression));
            return filterValue.HasValue
                ? query.Where(AreEqual(sourceExpression, filterValue))
                : query;
        }

        /// <summary>
        /// Returns a queryable object filtered to where a nullable decimal value from the queryable object equals a value if it is specified.
        /// If the value is null, returns the queryable object unaltered.
        /// </summary>
        /// <param name="query">The query to filter.</param>
        /// <param name="sourceExpression">A function on each item in query that returns the nullable decimal value to compare against.</param>
        /// <param name="filterValue">The value to filter on.</param>
        /// <returns>The filtered query, or <paramref name="query"/> itself when <paramref name="filterValue"/> is null.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="query"/> or <paramref name="sourceExpression"/> is null.</exception>
        public static IQueryable<T> WhereEqualIfSpecified<T>(
            this IQueryable<T> query,
            Expression<Func<T, decimal?>> sourceExpression,
            decimal? filterValue)
        {
            if (query == null) throw new ArgumentNullException(nameof(query));
            if (sourceExpression == null) throw new ArgumentNullException(nameof(sourceExpression));
            return filterValue.HasValue
                ? query.Where(AreEqual(sourceExpression, filterValue))
                : query;
        }

        /// <summary>
        /// Returns a queryable object filtered to where a nullable integer value from the queryable object equals a value if it is specified.
        /// If the value is null, returns the queryable object unaltered.
        /// </summary>
        /// <param name="query">The query to filter.</param>
        /// <param name="sourceExpression">A function on each item in query that returns the nullable integer value to compare against.</param>
        /// <param name="filterValue">The value to filter on.</param>
        /// <returns>The filtered query, or <paramref name="query"/> itself when <paramref name="filterValue"/> is null.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="query"/> or <paramref name="sourceExpression"/> is null.</exception>
        public static IQueryable<T> WhereEqualIfSpecified<T>(
            this IQueryable<T> query,
            Expression<Func<T, int?>> sourceExpression,
            int? filterValue)
        {
            if (query == null) throw new ArgumentNullException(nameof(query));
            if (sourceExpression == null) throw new ArgumentNullException(nameof(sourceExpression));
            return filterValue.HasValue
                ? query.Where(AreEqual(sourceExpression, filterValue))
                : query;
        }

        /// <summary>
        /// Returns a queryable object filtered to where a nullable boolean value from the queryable object equals a value if it is specified.
        /// If the value is null, returns the queryable object unaltered.
        /// </summary>
        /// <param name="query">The query to filter.</param>
        /// <param name="sourceExpression">A function on each item in query that returns the nullable boolean value to compare against.</param>
        /// <param name="filterValue">The value to filter on.</param>
        /// <returns>The filtered query, or <paramref name="query"/> itself when <paramref name="filterValue"/> is null.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="query"/> or <paramref name="sourceExpression"/> is null.</exception>
        public static IQueryable<T> WhereEqualIfSpecified<T>(
            this IQueryable<T> query,
            Expression<Func<T, bool?>> sourceExpression,
            bool? filterValue)
        {
            if (query == null) throw new ArgumentNullException(nameof(query));
            if (sourceExpression == null) throw new ArgumentNullException(nameof(sourceExpression));
            return filterValue.HasValue
                ? query.Where(AreEqual(sourceExpression, filterValue))
                : query;
        }

        /// <summary>
        /// Returns a queryable object filtered to where a nullable decimal from the queryable object is in an inclusive range.
        /// A null bound means "no limit" on that side; if both bounds are null, returns the queryable object unaltered.
        /// Items whose value is null are excluded whenever at least one bound is given.
        /// </summary>
        /// <param name="query">The query to filter.</param>
        /// <param name="sourceExpression">A function on each item in query that returns the nullable value to compare against.</param>
        /// <param name="filterValueLow">The lowest value of the range to accept, or null for no lower limit.</param>
        /// <param name="filterValueHigh">The highest value of the range to accept, or null for no upper limit.</param>
        /// <returns>The filtered query.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="query"/> or <paramref name="sourceExpression"/> is null.</exception>
        public static IQueryable<T> WhereBetweenIfSpecified<T>(
            this IQueryable<T> query,
            Expression<Func<T, decimal?>> sourceExpression,
            decimal? filterValueLow,
            decimal? filterValueHigh)
        {
            if (query == null) throw new ArgumentNullException(nameof(query));
            if (sourceExpression == null) throw new ArgumentNullException(nameof(sourceExpression));
            // Only add predicates for the bounds actually supplied, so no "@p IS NULL OR ..." checks reach the provider.
            if (filterValueLow.HasValue)
            {
                query = query.Where(sourceExpression.Compose(value => value >= filterValueLow));
            }
            if (filterValueHigh.HasValue)
            {
                query = query.Where(sourceExpression.Compose(value => value <= filterValueHigh));
            }
            return query;
        }

        /// <summary>
        /// Returns a queryable object filtered to where a nullable integer from the queryable object is in an inclusive range.
        /// A null bound means "no limit" on that side; if both bounds are null, returns the queryable object unaltered.
        /// Items whose value is null are excluded whenever at least one bound is given.
        /// </summary>
        /// <param name="query">The query to filter.</param>
        /// <param name="sourceExpression">A function on each item in query that returns the nullable value to compare against.</param>
        /// <param name="filterValueLow">The lowest value of the range to accept, or null for no lower limit.</param>
        /// <param name="filterValueHigh">The highest value of the range to accept, or null for no upper limit.</param>
        /// <returns>The filtered query.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="query"/> or <paramref name="sourceExpression"/> is null.</exception>
        public static IQueryable<T> WhereBetweenIfSpecified<T>(
            this IQueryable<T> query,
            Expression<Func<T, int?>> sourceExpression,
            int? filterValueLow,
            int? filterValueHigh)
        {
            if (query == null) throw new ArgumentNullException(nameof(query));
            if (sourceExpression == null) throw new ArgumentNullException(nameof(sourceExpression));
            if (filterValueLow.HasValue)
            {
                query = query.Where(sourceExpression.Compose(value => value >= filterValueLow));
            }
            if (filterValueHigh.HasValue)
            {
                query = query.Where(sourceExpression.Compose(value => value <= filterValueHigh));
            }
            return query;
        }

        /// <summary>
        /// Returns a queryable object filtered to where a nullable date/time from the queryable object is in an inclusive range.
        /// A null bound means "no limit" on that side; if both bounds are null, returns the queryable object unaltered.
        /// Items whose value is null are excluded whenever at least one bound is given.
        /// </summary>
        /// <param name="query">The query to filter.</param>
        /// <param name="sourceExpression">A function on each item in query that returns the nullable value to compare against.</param>
        /// <param name="filterValueLow">The lowest value of the range to accept, or null for no lower limit.</param>
        /// <param name="filterValueHigh">The highest value of the range to accept, or null for no upper limit.</param>
        /// <returns>The filtered query.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="query"/> or <paramref name="sourceExpression"/> is null.</exception>
        public static IQueryable<T> WhereBetweenIfSpecified<T>(
            this IQueryable<T> query,
            Expression<Func<T, DateTime?>> sourceExpression,
            DateTime? filterValueLow,
            DateTime? filterValueHigh)
        {
            if (query == null) throw new ArgumentNullException(nameof(query));
            if (sourceExpression == null) throw new ArgumentNullException(nameof(sourceExpression));
            if (filterValueLow.HasValue)
            {
                query = query.Where(sourceExpression.Compose(value => value >= filterValueLow));
            }
            if (filterValueHigh.HasValue)
            {
                query = query.Where(sourceExpression.Compose(value => value <= filterValueHigh));
            }
            return query;
        }

        /// <summary>
        /// Builds a predicate that is true when the string selected by <paramref name="sourceExpression"/> equals
        /// <paramref name="filterValue"/> (C# <c>==</c>; how a query provider evaluates it, e.g. case sensitivity, is provider-specific).
        /// </summary>
        public static Expression<Func<T, bool>> AreEqual<T>(
            Expression<Func<T, string>> sourceExpression,
            string filterValue) =>
            sourceExpression.Compose(value => value == filterValue);

        /// <summary>
        /// Builds a predicate that is true when the nullable guid selected by <paramref name="sourceExpression"/>
        /// equals <paramref name="filterValue"/> (lifted <c>==</c>).
        /// </summary>
        public static Expression<Func<T, bool>> AreEqual<T>(
            Expression<Func<T, Guid?>> sourceExpression,
            Guid? filterValue) =>
            sourceExpression.Compose(value => value == filterValue);

        /// <summary>
        /// Builds a predicate that is true when the nullable decimal selected by <paramref name="sourceExpression"/>
        /// equals <paramref name="filterValue"/> (lifted <c>==</c>).
        /// </summary>
        public static Expression<Func<T, bool>> AreEqual<T>(
            Expression<Func<T, decimal?>> sourceExpression,
            decimal? filterValue) =>
            sourceExpression.Compose(value => value == filterValue);

        /// <summary>
        /// Builds a predicate that is true when the nullable integer selected by <paramref name="sourceExpression"/>
        /// equals <paramref name="filterValue"/> (lifted <c>==</c>).
        /// </summary>
        public static Expression<Func<T, bool>> AreEqual<T>(
            Expression<Func<T, int?>> sourceExpression,
            int? filterValue) =>
            sourceExpression.Compose(value => value == filterValue);

        /// <summary>
        /// Builds a predicate that is true when the nullable boolean selected by <paramref name="sourceExpression"/>
        /// equals <paramref name="filterValue"/> (lifted <c>==</c>).
        /// </summary>
        public static Expression<Func<T, bool>> AreEqual<T>(
            Expression<Func<T, bool?>> sourceExpression,
            bool? filterValue) =>
            sourceExpression.Compose(value => value == filterValue);

        /// <summary>
        /// Composes two expressions into one by feeding the result of <paramref name="first"/> into <paramref name="second"/>.
        /// In other words, if first is x => f(x), and second is y => g(y),
        /// first.Compose(second) yields x => g(f(x)).
        /// </summary>
        /// <typeparam name="T">The input type of the composed expression (the parameter type of <paramref name="first"/>).</typeparam>
        /// <typeparam name="TIntermediate">The result type of <paramref name="first"/> and parameter type of <paramref name="second"/>.</typeparam>
        /// <typeparam name="TResult">The result type of the composed expression.</typeparam>
        /// <param name="first">The inner expression, applied first.</param>
        /// <param name="second">The outer expression, applied to the result of <paramref name="first"/>.</param>
        /// <returns>A single lambda over <paramref name="first"/>'s parameter, with no invocation nodes.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="first"/> or <paramref name="second"/> is null.</exception>
        /// <remarks>See http://stackoverflow.com/questions/37602729 </remarks>
        public static Expression<Func<T, TResult>> Compose<T, TIntermediate, TResult>(
            this Expression<Func<T, TIntermediate>> first,
            Expression<Func<TIntermediate, TResult>> second)
        {
            if (first == null) throw new ArgumentNullException(nameof(first));
            if (second == null) throw new ArgumentNullException(nameof(second));
            // Substitute first's body for second's parameter so query providers see one plain expression tree.
            return Expression.Lambda<Func<T, TResult>>(
                second.Body.Replace(second.Parameters[0], first.Body),
                first.Parameters[0]);
        }

        /// <summary>
        /// Expression visitor that substitutes every node that is reference-equal to one expression with another expression.
        /// </summary>
        private sealed class ReplaceVisitor : ExpressionVisitor
        {
            private readonly Expression _from;
            private readonly Expression _to;

            public ReplaceVisitor(Expression from, Expression to)
            {
                _from = from;
                _to = to;
            }

            public override Expression Visit(Expression node) =>
                node == _from ? _to : base.Visit(node);
        }

        /// <summary>
        /// Returns <paramref name="ex"/> with every occurrence (by reference) of <paramref name="from"/> replaced by <paramref name="to"/>.
        /// </summary>
        private static Expression Replace(this Expression ex,
            Expression from,
            Expression to) =>
            new ReplaceVisitor(from, to).Visit(ex);
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment