Skip to content

Instantly share code, notes, and snippets.

@clement911
Last active February 28, 2024 09:37
Show Gist options
  • Save clement911/695d7a2ef2ac2845a8090309c5d184e7 to your computer and use it in GitHub Desktop.
Save clement911/695d7a2ef2ac2845a8090309c5d184e7 to your computer and use it in GitHub Desktop.
Record the DB commands generated by EF Core queries
/// <summary>
/// Runs each <paramref name="getResults"/> delegate against <paramref name="dbCtx"/> inside a new
/// <see cref="RecordCommandsScope"/> and returns the database commands recorded while they ran.
/// </summary>
/// <remarks>
/// Commands are only recorded if the context has RecordCommandsInterceptor registered.
/// Only work the delegates actually execute is recorded; a delegate that returns an unexecuted query
/// does not send any command.
/// </remarks>
/// <param name="dbCtx">The context the delegates operate on.</param>
/// <param name="getResults">Synchronous operations whose commands should be recorded, in order.</param>
/// <returns>The recorded commands, in the order they were intercepted.</returns>
/// <exception cref="ArgumentNullException"><paramref name="dbCtx"/> or <paramref name="getResults"/> is null.</exception>
public static IReadOnlyCollection<DbBatchCommand> RecordCommands<TContext, TResult>(this TContext dbCtx, params Func<TContext, TResult>[] getResults) where TContext : DbContext
{
    ArgumentNullException.ThrowIfNull(dbCtx);
    ArgumentNullException.ThrowIfNull(getResults);

    // Keep a reference to the scope we started rather than re-reading the ambient Current,
    // so we always return this scope's recording even if a delegate leaves another scope active.
    using var scope = RecordCommandsScope.StartNew();
    foreach (var getResult in getResults)
    {
        getResult(dbCtx);
    }
    return scope.RecordedCommands;
}
/// <summary>
/// Registers the shared <see cref="RecordCommandsInterceptor"/> on this context, so that commands
/// issued while a RecordCommandsScope is active are captured instead of executed.
/// Then applies the base configuration.
/// </summary>
protected override void OnConfiguring(DbContextOptionsBuilder optionsBuilder)
{
    var recorder = RecordCommandsInterceptor.Instance;
    optionsBuilder.AddInterceptors(recorder);

    base.OnConfiguring(optionsBuilder);
}
// An ambient scope, similar to TransactionScope, that collects the commands recorded while it is active.
// The AmbientScopeBase<T> base class is not provided here; there are a number of implementations,
// which usually leverage AsyncLocal<T>.
public sealed class RecordCommandsScope : AmbientScopeBase<RecordCommandsScope>
{
    // Commands recorded into this scope, in the order Record was called.
    private readonly List<DbBatchCommand> _recordedCommands = new();

    /// <summary>Read-only view of the commands recorded so far, in recording order.</summary>
    public IReadOnlyCollection<DbBatchCommand> RecordedCommands => _recordedCommands.AsReadOnly();

    /// <summary>Starts a new recording scope through the base class's StartNew.</summary>
    public static RecordCommandsScope StartNew() => StartNew(new RecordCommandsScope());

    // Private so that instances are only created through StartNew().
    private RecordCommandsScope()
    {
    }

    /// <summary>Appends a command to this scope's recording.</summary>
    /// <param name="dbCommand">The command to record.</param>
    /// <exception cref="ArgumentNullException"><paramref name="dbCommand"/> is null.</exception>
    public void Record(DbBatchCommand dbCommand)
    {
        // Reject null here rather than letting it surface later when the commands are put in a batch.
        ArgumentNullException.ThrowIfNull(dbCommand);
        _recordedCommands.Add(dbCommand);
    }
}
/// <summary>
/// EF Core command interceptor. While a <see cref="RecordCommandsScope"/> is current, it records each
/// command as a SqlClient SqlBatchCommand and suppresses its execution.
/// When no scope is current, commands pass through unchanged.
/// </summary>
public class RecordCommandsInterceptor : DbCommandInterceptor
{
    /// <summary>Shared instance; the interceptor holds no state of its own.</summary>
    public static RecordCommandsInterceptor Instance { get; } = new();

    private RecordCommandsInterceptor()
    {
    }

    /// <summary>Fails fast if an async command is created while recording.</summary>
    public override InterceptionResult<DbCommand> CommandCreating(CommandCorrelatedEventData eventData, InterceptionResult<DbCommand> result)
    {
        // Only the synchronous *Executing hooks are overridden here, so async executions would not be recorded.
        if (RecordCommandsScope.Current is not null && eventData.IsAsync)
            throw new InvalidOperationException("Cannot use Async operation when recording a command");
        return base.CommandCreating(eventData, result);
    }

    public override InterceptionResult<int> NonQueryExecuting(DbCommand cmd, CommandEventData eventData, InterceptionResult<int> result)
        => TryRecord(cmd) ? InterceptionResult<int>.SuppressWithResult(0) : result;

    // When recording, suppress with an empty reader rather than null. EF Core uses the suppressed
    // result as the command's reader, so a null reader fails as soon as EF tries to read from it.
    public override InterceptionResult<DbDataReader> ReaderExecuting(DbCommand cmd, CommandEventData eventData, InterceptionResult<DbDataReader> result)
        => TryRecord(cmd)
            ? InterceptionResult<DbDataReader>.SuppressWithResult(new System.Data.DataTable().CreateDataReader())
            : result;

    public override InterceptionResult<object> ScalarExecuting(DbCommand cmd, CommandEventData eventData, InterceptionResult<object> result)
        => TryRecord(cmd) ? InterceptionResult<object>.SuppressWithResult(null!) : result;

    // Records a copy of cmd into the current scope, if there is one.
    // Returns true when the command was recorded (the caller must then suppress execution),
    // and false when no scope is active.
    private static bool TryRecord(DbCommand cmd)
    {
        var scope = RecordCommandsScope.Current;
        if (scope is null)
            return false;

        var cloned = new Microsoft.Data.SqlClient.SqlBatchCommand
        {
            CommandText = cmd.CommandText,
            CommandType = cmd.CommandType,
            // CommandBehavior is not copied. TODO: confirm whether it needs to be.
        };

        // Each parameter is cloned because a SqlParameter cannot belong to two parameter collections.
        // SqlParameter's ICloneable.Clone copies all of its settings (e.g. TypeName, UdtTypeName,
        // SourceColumn), not just a hand-picked subset.
        cloned.Parameters.AddRange(cmd.Parameters
            .OfType<SqlParameter>()
            .Select(p => (SqlParameter)((ICloneable)p).Clone())
            .ToArray());

        scope.Record(cloned);
        return true;
    }
}
long productId = 123;
long[] variantIds = [4, 5, 6];
// Record (without executing) the DELETE statements for this product's variants that are not in variantIds.
// Dependent rows are deleted before their parent variants. ProductVariantOption and ProductVariantMetafield
// appear to reference ProductVariant through VariantId (TODO: confirm the FK). Without ON DELETE CASCADE,
// deleting the parent rows first would be rejected by the database.
var dbCommands = ctx.RecordCommands(
    db => db.Set<ProductVariantOption>()
        .Where(v => v.ProductId == productId)
        .Where(v => !variantIds.Contains(v.VariantId))
        .ExecuteDelete(),
    db => db.Set<ProductVariantMetafield>()
        .Where(v => v.ProductId == productId)
        .Where(v => !variantIds.Contains(v.VariantId))
        .ExecuteDelete(),
    db => db.Set<ProductVariant>()
        .Where(v => v.ProductId == productId)
        .Where(v => !variantIds.Contains(v.Id))
        .ExecuteDelete());
//do other stuff with context
// Run the recorded commands and the context's pending changes in one transaction managed by the execution strategy.
int rowAffected = await ctxPublic.Database
    .CreateExecutionStrategy()
    .ExecuteInTransactionAsync(async (token) =>
    {
        int rowCountAffected = 0;

        // Use direct casts so that a non-SqlClient connection or transaction fails loudly,
        // instead of `as` silently passing null into SqlBatch.
        var connection = (SqlConnection)ctxPublic.Database.GetDbConnection();
        var transaction = (SqlTransaction)ctxPublic.Database.CurrentTransaction.GetDbTransaction();

        // The batch is disposable; release it when this attempt completes.
        await using var batch = new SqlBatch(connection, transaction);
        // The recorded commands are typed as DbBatchCommand; the recorder builds SqlBatchCommand instances.
        batch.Commands.AddRange(dbCommands.Cast<SqlBatchCommand>());
        rowCountAffected += await batch.ExecuteNonQueryAsync(token);

        // Don't accept changes inside the retriable operation. If the attempt fails and is retried, the
        // change tracker must still hold the pending changes. They are accepted after the transaction succeeds.
        rowCountAffected += await ctxPublic.SaveChangesAsync(acceptAllChangesOnSuccess: false, token);
        return rowCountAffected;
    });
ctxPublic.ChangeTracker.AcceptAllChanges();
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment