Skip to content

Instantly share code, notes, and snippets.

@zlepper
Created October 14, 2020 14:07
Show Gist options
  • Save zlepper/7f55ab76547d81eb6eb403ad4feab06b to your computer and use it in GitHub Desktop.
Save zlepper/7f55ab76547d81eb6eb403ad4feab06b to your computer and use it in GitHub Desktop.
using System.Collections.Generic;
using Microsoft.EntityFrameworkCore;
using Microsoft.EntityFrameworkCore.Infrastructure;
using Microsoft.EntityFrameworkCore.Update;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.DependencyInjection.Extensions;
using PocPerf.EfOverrides;
namespace PocPerf
{
/// <summary>
/// Extension methods that wire the Digizuite EF Core customisations into a context's options.
/// </summary>
public static class DgzEfContextExtensions
{
    /// <summary>
    /// Registers a <see cref="DigizuiteContextExtensions"/> options extension on the builder.
    /// An instance already present in the builder's options is re-registered; otherwise a new
    /// instance is created and registered.
    /// </summary>
    /// <param name="efContextBuilder">The options builder to extend.</param>
    /// <returns>The same <paramref name="efContextBuilder"/>, so calls can be chained.</returns>
    public static DbContextOptionsBuilder ExtendPostgres(this DbContextOptionsBuilder efContextBuilder)
    {
        var infrastructure = (IDbContextOptionsBuilderInfrastructure) efContextBuilder;

        var digizuiteExtension = efContextBuilder.Options.FindExtension<DigizuiteContextExtensions>();
        if (digizuiteExtension == null)
        {
            digizuiteExtension = new DigizuiteContextExtensions();
        }

        infrastructure.AddOrUpdateExtension(digizuiteExtension);
        return efContextBuilder;
    }
}
/// <summary>
/// EF Core options extension that swaps in the Digizuite update SQL generator and the
/// Digizuite modification command batch factory.
/// </summary>
public class DigizuiteContextExtensions : IDbContextOptionsExtension
{
    /// <summary>Creates the extension together with its <see cref="DigizuiteContextInfo"/>.</summary>
    public DigizuiteContextExtensions()
    {
        Info = new DigizuiteContextInfo(this);
    }

    /// <summary>Metadata describing this extension to EF Core.</summary>
    public DbContextOptionsExtensionInfo Info { get; }

    /// <summary>
    /// Replaces the registered <see cref="IUpdateSqlGenerator"/> (singleton lifetime) and
    /// <see cref="IModificationCommandBatchFactory"/> (scoped lifetime) with the Digizuite implementations.
    /// </summary>
    /// <param name="services">The internal service collection being configured.</param>
    public void ApplyServices(IServiceCollection services)
    {
        var sqlGeneratorDescriptor = new ServiceDescriptor(
            typeof(IUpdateSqlGenerator),
            typeof(DigizuiteUpdateSqlGenerator),
            ServiceLifetime.Singleton);
        var batchFactoryDescriptor = new ServiceDescriptor(
            typeof(IModificationCommandBatchFactory),
            typeof(DigizuiteModificationCommandBatchFactory),
            ServiceLifetime.Scoped);

        services.Replace(sqlGeneratorDescriptor);
        services.Replace(batchFactoryDescriptor);
    }

    /// <summary>Intentionally a no-op: this extension has nothing to validate.</summary>
    public void Validate(IDbContextOptions options)
    {
    }
}
/// <summary>
/// Describes a Digizuite options extension to EF Core. It is not a database provider,
/// contributes a constant service-provider hash code (0) and adds no debug information.
/// </summary>
public class DigizuiteContextInfo : DbContextOptionsExtensionInfo
{
    public DigizuiteContextInfo(IDbContextOptionsExtension extension) : base(extension)
    {
    }

    /// <summary>Always false: this extension does not act as a database provider.</summary>
    public override bool IsDatabaseProvider => false;

    /// <summary>Text included in EF Core's options log output.</summary>
    public override string LogFragment => "'Digizuite'";

    /// <summary>Always 0; this info contributes nothing that varies between instances.</summary>
    public override long GetServiceProviderHashCode() => 0;

    /// <summary>Adds nothing to <paramref name="debugInfo"/>.</summary>
    public override void PopulateDebugInfo(IDictionary<string, string> debugInfo)
    {
        // Nothing to report.
    }
}
}
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Diagnostics.CodeAnalysis;
using System.Linq;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.EntityFrameworkCore;
using Microsoft.EntityFrameworkCore.Diagnostics;
using Microsoft.EntityFrameworkCore.Infrastructure;
using Microsoft.EntityFrameworkCore.Storage;
using Microsoft.EntityFrameworkCore.Update;
using Npgsql;
using Npgsql.EntityFrameworkCore.PostgreSQL.Infrastructure.Internal;
using Npgsql.EntityFrameworkCore.PostgreSQL.Update.Internal;
namespace PocPerf.EfOverrides
{
/// <summary>
/// Creates <see cref="DigizuiteModificationCommandBatch"/> instances, passing along the
/// MaxBatchSize configured on the Npgsql options extension when one is registered.
/// </summary>
public class DigizuiteModificationCommandBatchFactory : IModificationCommandBatchFactory
{
    private readonly ModificationCommandBatchFactoryDependencies _dependencies;
    private readonly IDbContextOptions _options;

    /// <exception cref="ArgumentNullException"><paramref name="dependencies"/> is null.</exception>
    public DigizuiteModificationCommandBatchFactory(
        ModificationCommandBatchFactoryDependencies dependencies,
        IDbContextOptions options)
    {
        if (dependencies is null)
        {
            throw new ArgumentNullException(nameof(dependencies));
        }

        _dependencies = dependencies;
        _options = options;
    }

    /// <summary>Creates a new, empty batch.</summary>
    public ModificationCommandBatch Create()
        => new DigizuiteModificationCommandBatch(_dependencies, GetConfiguredMaxBatchSize());

    // MaxBatchSize of the first NpgsqlOptionsExtension in the options, or null when there are no
    // options, no such extension, or no size was configured on it.
    private int? GetConfiguredMaxBatchSize()
    {
        if (_options == null)
        {
            return null;
        }

        var npgsqlExtension = _options.Extensions.OfType<NpgsqlOptionsExtension>().FirstOrDefault();
        return npgsqlExtension?.MaxBatchSize;
    }
}
/// <summary>
/// Npgsql modification command batch that groups consecutive INSERTs into the same table
/// (with identical written and read columns) into a single statement generated by
/// <see cref="DigizuiteUpdateSqlGenerator.AppendBulkInsertOperation"/>.
/// </summary>
/// <remarks>
/// Much of this code is adapted from EF Core's SqlServerModificationCommandBatch, which batches
/// inserts the same way. Results are consumed according to <c>CommandResultSet</c>:
/// <list type="bullet">
/// <item><description>Commands mapped to <see cref="ResultSetMapping.NoResultSet"/> are skipped.</description></item>
/// <item><description>A run of <see cref="ResultSetMapping.NotLastInResultSet"/> commands closed by a
/// <see cref="ResultSetMapping.LastInResultSet"/> command reads one row per command from one result set.</description></item>
/// </list>
/// NOTE(review): commands mapped to NoResultSet get no rows-affected check here; confirm that
/// optimistic-concurrency checks for UPDATE/DELETE without read-back columns are enforced elsewhere.
/// </remarks>
public class DigizuiteModificationCommandBatch : NpgsqlModificationCommandBatch
{
    // Added commands collected so far that will be emitted together as one INSERT statement.
    private readonly List<ModificationCommand> _bulkInsertCommands = new List<ModificationCommand>();

    private readonly DigizuiteUpdateSqlGenerator _digizuiteUpdateSqlGenerator;

    /// <summary>Creates a batch; the registered update SQL generator must be a <see cref="DigizuiteUpdateSqlGenerator"/>.</summary>
    /// <exception cref="InvalidOperationException">The registered update SQL generator is of another type.</exception>
    public DigizuiteModificationCommandBatch(ModificationCommandBatchFactoryDependencies dependencies,
        int? maxBatchSize) : base(dependencies, maxBatchSize)
    {
        // ReSharper disable once VirtualMemberCallInConstructor
        _digizuiteUpdateSqlGenerator = UpdateSqlGenerator as DigizuiteUpdateSqlGenerator
            ?? throw new InvalidOperationException(
                $"{nameof(UpdateSqlGenerator)} is not a {nameof(DigizuiteUpdateSqlGenerator)}");
    }

    protected override void ResetCommandText()
    {
        base.ResetCommandText();
        _bulkInsertCommands.Clear();
    }

    // Appends the still-pending bulk INSERT (if any) after the text produced by the base batch.
    protected override string GetCommandText()
        => base.GetCommandText() + GetBulkInsertCommandText(ModificationCommands.Count);

    // Generates the INSERT for the pending commands, which occupy batch positions
    // [lastIndex - pending count, lastIndex), and records their result-set mapping.
    private string GetBulkInsertCommandText(int lastIndex)
    {
        if (_bulkInsertCommands.Count == 0)
        {
            return string.Empty;
        }

        var firstIndex = lastIndex - _bulkInsertCommands.Count;
        var stringBuilder = new StringBuilder();
        var resultSetMapping = _digizuiteUpdateSqlGenerator.AppendBulkInsertOperation(
            stringBuilder, _bulkInsertCommands, firstIndex);
        for (var i = firstIndex; i < lastIndex; i++)
        {
            CommandResultSet[i] = resultSetMapping;
        }

        // When the group produces results, its final command always closes the result set.
        if (resultSetMapping != ResultSetMapping.NoResultSet)
        {
            CommandResultSet[lastIndex - 1] = ResultSetMapping.LastInResultSet;
        }

        return stringBuilder.ToString();
    }

    protected override void UpdateCachedCommandText(int commandPosition)
    {
        var newModificationCommand = ModificationCommands[commandPosition];

        if (newModificationCommand.EntityState == EntityState.Added)
        {
            // Flush the pending group when the new insert cannot share its statement.
            if (_bulkInsertCommands.Count > 0
                && !CanBeInsertedInSameStatement(_bulkInsertCommands[0], newModificationCommand))
            {
                CachedCommandText.Append(GetBulkInsertCommandText(commandPosition));
                _bulkInsertCommands.Clear();
            }

            _bulkInsertCommands.Add(newModificationCommand);
            LastCachedCommandIndex = commandPosition;
        }
        else
        {
            // Any non-insert ends the pending group; the group's SQL must precede this command's SQL.
            CachedCommandText.Append(GetBulkInsertCommandText(commandPosition));
            _bulkInsertCommands.Clear();

            base.UpdateCachedCommandText(commandPosition);

            // A command that requires result propagation but is mapped to NoResultSet would be
            // skipped by the consumers below and never receive its values.
            // NOTE(review): assumes the base generator emits a RETURNING result set exactly when the
            // command reads values back — confirm against the Npgsql version in use.
            if (newModificationCommand.RequiresResultPropagation
                && CommandResultSet[commandPosition] == ResultSetMapping.NoResultSet)
            {
                CommandResultSet[commandPosition] = ResultSetMapping.LastInResultSet;
            }
        }
    }

    // Two inserts can share one statement when they target the same table and schema and
    // write and read the same columns in the same order.
    private static bool CanBeInsertedInSameStatement(ModificationCommand firstCommand, ModificationCommand secondCommand)
        => string.Equals(firstCommand.TableName, secondCommand.TableName, StringComparison.Ordinal)
            && string.Equals(firstCommand.Schema, secondCommand.Schema, StringComparison.Ordinal)
            && firstCommand.ColumnModifications.Where(o => o.IsWrite).Select(o => o.ColumnName).SequenceEqual(
                secondCommand.ColumnModifications.Where(o => o.IsWrite).Select(o => o.ColumnName))
            && firstCommand.ColumnModifications.Where(o => o.IsRead).Select(o => o.ColumnName).SequenceEqual(
                secondCommand.ColumnModifications.Where(o => o.IsRead).Select(o => o.ColumnName));

    /// <summary>
    /// Synchronous result consumption. Overridden alongside ConsumeAsync so that SaveChanges and
    /// SaveChangesAsync interpret the batch's result sets the same way.
    /// </summary>
    /// <exception cref="DbUpdateException">Reading the results failed.</exception>
    /// <exception cref="DbUpdateConcurrencyException">Fewer rows than expected were returned or affected.</exception>
    protected override void Consume(RelationalDataReader reader)
    {
        Debug.Assert(CommandResultSet.Count == ModificationCommands.Count);
        var commandIndex = 0;

        try
        {
            var actualResultSetCount = 0;
            do
            {
                commandIndex = SkipCommandsWithoutResultSet(commandIndex);
                if (commandIndex < CommandResultSet.Count)
                {
                    commandIndex = ModificationCommands[commandIndex].RequiresResultPropagation
                        ? ConsumeResultSetWithPropagation(commandIndex, reader)
                        : ConsumeResultSetWithoutPropagation(commandIndex, reader);
                    actualResultSetCount++;
                }
            }
            while (commandIndex < CommandResultSet.Count
                && reader.DbDataReader.NextResult());

            AssertAllResultSetsConsumed(commandIndex, actualResultSetCount);
        }
        catch (Exception ex) when (!(ex is DbUpdateException))
        {
            throw new DbUpdateException(
                RelationalStrings.UpdateStoreException,
                ex,
                ModificationCommands[commandIndex].Entries);
        }
    }

    /// <summary>Asynchronous counterpart of <c>Consume</c>.</summary>
    /// <exception cref="DbUpdateException">Reading the results failed.</exception>
    /// <exception cref="DbUpdateConcurrencyException">Fewer rows than expected were returned or affected.</exception>
    protected override async Task ConsumeAsync(
        RelationalDataReader reader,
        CancellationToken cancellationToken = default)
    {
        Debug.Assert(CommandResultSet.Count == ModificationCommands.Count);
        var commandIndex = 0;

        try
        {
            var actualResultSetCount = 0;
            do
            {
                commandIndex = SkipCommandsWithoutResultSet(commandIndex);
                if (commandIndex < CommandResultSet.Count)
                {
                    commandIndex = ModificationCommands[commandIndex].RequiresResultPropagation
                        ? await ConsumeResultSetWithPropagationAsync(commandIndex, reader, cancellationToken).ConfigureAwait(false)
                        : await ConsumeResultSetWithoutPropagationAsync(commandIndex, reader, cancellationToken).ConfigureAwait(false);
                    actualResultSetCount++;
                }
            }
            while (commandIndex < CommandResultSet.Count
                && await reader.DbDataReader.NextResultAsync(cancellationToken).ConfigureAwait(false));

            AssertAllResultSetsConsumed(commandIndex, actualResultSetCount);
        }
        catch (Exception ex) when (!(ex is DbUpdateException))
        {
            throw new DbUpdateException(
                RelationalStrings.UpdateStoreException,
                ex,
                ModificationCommands[commandIndex].Entries);
        }
    }

    // Returns the first index at or after commandIndex whose command is mapped to a result set,
    // or CommandResultSet.Count when there is none.
    private int SkipCommandsWithoutResultSet(int commandIndex)
    {
        while (commandIndex < CommandResultSet.Count
            && CommandResultSet[commandIndex] == ResultSetMapping.NoResultSet)
        {
            commandIndex++;
        }

        return commandIndex;
    }

    // Debug-only check that every command and every expected result set was consumed.
    // Calls are compiled away unless DEBUG is defined.
    [Conditional("DEBUG")]
    private void AssertAllResultSetsConsumed(int commandIndex, int actualResultSetCount)
    {
        commandIndex = SkipCommandsWithoutResultSet(commandIndex);
        Debug.Assert(
            commandIndex == ModificationCommands.Count,
            "Expected " + ModificationCommands.Count + " results, got " + commandIndex);

        var expectedResultSetCount = CommandResultSet.Count(e => e == ResultSetMapping.LastInResultSet);
        Debug.Assert(
            actualResultSetCount == expectedResultSetCount,
            "Expected " + expectedResultSetCount + " result sets, got " + actualResultSetCount);
    }

    /// <summary>
    /// Reads one row per command from the current result set, starting at
    /// <paramref name="commandIndex"/>, and propagates it into that command. Continues while
    /// the preceding command is mapped NotLastInResultSet.
    /// </summary>
    /// <returns>The index of the first command after the consumed result set.</returns>
    protected virtual int ConsumeResultSetWithPropagation(int commandIndex, [NotNull] RelationalDataReader reader)
    {
        var rowsAffected = 0;
        do
        {
            var tableModification = ModificationCommands[commandIndex];
            Debug.Assert(tableModification.RequiresResultPropagation);

            if (!reader.Read())
            {
                // Fewer rows than commands: count the rest of the group to report the expected total.
                var expectedRowsAffected = rowsAffected + 1;
                while (++commandIndex < CommandResultSet.Count
                    && CommandResultSet[commandIndex - 1] == ResultSetMapping.NotLastInResultSet)
                {
                    expectedRowsAffected++;
                }

                ThrowAggregateUpdateConcurrencyException(commandIndex, expectedRowsAffected, rowsAffected);
            }

            var valueBufferFactory = CreateValueBufferFactory(tableModification.ColumnModifications);
            tableModification.PropagateResults(valueBufferFactory.Create(reader.DbDataReader));
            rowsAffected++;
        }
        while (++commandIndex < CommandResultSet.Count
            && CommandResultSet[commandIndex - 1] == ResultSetMapping.NotLastInResultSet);

        return commandIndex;
    }

    /// <summary>Asynchronous counterpart of <see cref="ConsumeResultSetWithPropagation"/>.</summary>
    protected virtual async Task<int> ConsumeResultSetWithPropagationAsync(
        int commandIndex, [NotNull] RelationalDataReader reader, CancellationToken cancellationToken)
    {
        var rowsAffected = 0;
        do
        {
            var tableModification = ModificationCommands[commandIndex];
            Debug.Assert(tableModification.RequiresResultPropagation);

            if (!await reader.ReadAsync(cancellationToken).ConfigureAwait(false))
            {
                // Fewer rows than commands: count the rest of the group to report the expected total.
                var expectedRowsAffected = rowsAffected + 1;
                while (++commandIndex < CommandResultSet.Count
                    && CommandResultSet[commandIndex - 1] == ResultSetMapping.NotLastInResultSet)
                {
                    expectedRowsAffected++;
                }

                ThrowAggregateUpdateConcurrencyException(commandIndex, expectedRowsAffected, rowsAffected);
            }

            var valueBufferFactory = CreateValueBufferFactory(tableModification.ColumnModifications);
            tableModification.PropagateResults(valueBufferFactory.Create(reader.DbDataReader));
            rowsAffected++;
        }
        while (++commandIndex < CommandResultSet.Count
            && CommandResultSet[commandIndex - 1] == ResultSetMapping.NotLastInResultSet);

        return commandIndex;
    }

    /// <summary>
    /// Consumes a result set shared by the group of non-propagating commands starting at
    /// <paramref name="commandIndex"/>. The first column of its first row is read as the
    /// rows-affected count and compared with the group size.
    /// </summary>
    /// <returns>The index of the first command after the group.</returns>
    protected virtual int ConsumeResultSetWithoutPropagation(int commandIndex, [NotNull] RelationalDataReader reader)
    {
        var expectedRowsAffected = 1;
        while (++commandIndex < CommandResultSet.Count
            && CommandResultSet[commandIndex - 1] == ResultSetMapping.NotLastInResultSet)
        {
            Debug.Assert(!ModificationCommands[commandIndex].RequiresResultPropagation);
            expectedRowsAffected++;
        }

        if (reader.Read())
        {
            var rowsAffected = reader.DbDataReader.GetInt32(0);
            if (rowsAffected != expectedRowsAffected)
            {
                ThrowAggregateUpdateConcurrencyException(commandIndex, expectedRowsAffected, rowsAffected);
            }
        }
        else
        {
            // No count row at all: report (and aggregate the entries of) the whole group.
            ThrowAggregateUpdateConcurrencyException(commandIndex, expectedRowsAffected, 0);
        }

        return commandIndex;
    }

    /// <summary>Asynchronous counterpart of <see cref="ConsumeResultSetWithoutPropagation"/>.</summary>
    protected virtual async Task<int> ConsumeResultSetWithoutPropagationAsync(
        int commandIndex, [NotNull] RelationalDataReader reader, CancellationToken cancellationToken)
    {
        var expectedRowsAffected = 1;
        while (++commandIndex < CommandResultSet.Count
            && CommandResultSet[commandIndex - 1] == ResultSetMapping.NotLastInResultSet)
        {
            Debug.Assert(!ModificationCommands[commandIndex].RequiresResultPropagation);
            expectedRowsAffected++;
        }

        if (await reader.ReadAsync(cancellationToken).ConfigureAwait(false))
        {
            var rowsAffected = reader.DbDataReader.GetInt32(0);
            if (rowsAffected != expectedRowsAffected)
            {
                ThrowAggregateUpdateConcurrencyException(commandIndex, expectedRowsAffected, rowsAffected);
            }
        }
        else
        {
            // No count row at all: report (and aggregate the entries of) the whole group.
            ThrowAggregateUpdateConcurrencyException(commandIndex, expectedRowsAffected, 0);
        }

        return commandIndex;
    }

    /// <summary>
    /// Throws a <see cref="DbUpdateConcurrencyException"/> carrying the entries of the
    /// <paramref name="expectedRowsAffected"/> commands that end just before <paramref name="commandIndex"/>.
    /// </summary>
    protected virtual void ThrowAggregateUpdateConcurrencyException(
        int commandIndex,
        int expectedRowsAffected,
        int rowsAffected)
    {
        throw new DbUpdateConcurrencyException(
            RelationalStrings.UpdateConcurrencyException(expectedRowsAffected, rowsAffected),
            AggregateEntries(commandIndex, expectedRowsAffected));
    }

    // Collects the entries of the commandCount commands in [endIndex - commandCount, endIndex).
    private IReadOnlyList<IUpdateEntry> AggregateEntries(int endIndex, int commandCount)
    {
        var entries = new List<IUpdateEntry>();
        for (var i = endIndex - commandCount; i < endIndex; i++)
        {
            entries.AddRange(ModificationCommands[i].Entries);
        }

        return entries;
    }
}
}
using System.Collections.Generic;
using System.Diagnostics.CodeAnalysis;
using System.Globalization;
using System.Linq;
using System.Text;
using Microsoft.EntityFrameworkCore;
using Microsoft.EntityFrameworkCore.Update;
using Npgsql.EntityFrameworkCore.PostgreSQL.Metadata;
using Npgsql.EntityFrameworkCore.PostgreSQL.Update.Internal;
namespace PocPerf.EfOverrides
{
/// <summary>
/// Npgsql update SQL generator extended with multi-row INSERT generation, used to insert a
/// group of compatible commands with a single statement.
/// </summary>
public class DigizuiteUpdateSqlGenerator : NpgsqlUpdateSqlGenerator
{
    public DigizuiteUpdateSqlGenerator(UpdateSqlGeneratorDependencies dependencies) : base(dependencies)
    {
    }

    /// <summary>
    /// Appends SQL inserting all <paramref name="modificationCommands"/>, as one multi-row INSERT
    /// where possible. Assumes every command inserts into the same table with exactly the same
    /// written and read columns, in the same order.
    /// </summary>
    /// <param name="commandStringBuilder">Builder the SQL is appended to.</param>
    /// <param name="modificationCommands">The non-empty group of insert commands.</param>
    /// <param name="commandPosition">Batch position of the group's first command.</param>
    /// <returns>
    /// The result-set mapping to apply to every command of the group:
    /// <list type="bullet">
    /// <item><description><see cref="ResultSetMapping.NoResultSet"/> when no values are read back.</description></item>
    /// <item><description><see cref="ResultSetMapping.LastInResultSet"/> when each command gets its own
    /// statement and result set.</description></item>
    /// <item><description><see cref="ResultSetMapping.NotLastInResultSet"/> when all commands share one
    /// RETURNING result set. The caller must then mark the group's final command
    /// <see cref="ResultSetMapping.LastInResultSet"/>.</description></item>
    /// </list>
    /// </returns>
    public ResultSetMapping AppendBulkInsertOperation(
        [NotNull] StringBuilder commandStringBuilder,
        [NotNull] IReadOnlyList<ModificationCommand> modificationCommands,
        int commandPosition)
    {
        var firstCommand = modificationCommands[0];
        var readOperations = firstCommand.ColumnModifications.Where(o => o.IsRead).ToList();

        if (modificationCommands.Count == 1
            && firstCommand.ColumnModifications.All(o => !o.IsKey || !o.IsRead))
        {
            AppendInsertOperation(commandStringBuilder, firstCommand, commandPosition);

            // The mapping is derived from the read columns, exactly like the per-command branch below,
            // instead of being taken from the base call.
            // NOTE(review): assumes the base emits a RETURNING clause exactly when there are read
            // columns — confirm against the Npgsql version in use.
            return readOperations.Count == 0
                ? ResultSetMapping.NoResultSet
                : ResultSetMapping.LastInResultSet;
        }

        var writeOperations = firstCommand.ColumnModifications.Where(o => o.IsWrite).ToList();
        var defaultValuesOnly = writeOperations.Count == 0;
        var nonIdentityOperations = firstCommand.ColumnModifications
            .Where(o => o.Property != null
                && NpgsqlPropertyExtensions.GetValueGenerationStrategy(o.Property) != NpgsqlValueGenerationStrategy.IdentityAlwaysColumn)
            .ToList();

        if (defaultValuesOnly)
        {
            if (nonIdentityOperations.Count == 0
                || readOperations.Count == 0)
            {
                // One INSERT statement per command; when values are read back, each statement
                // yields its own result set.
                foreach (var modification in modificationCommands)
                {
                    AppendInsertOperation(commandStringBuilder, modification, commandPosition);
                }

                return readOperations.Count == 0
                    ? ResultSetMapping.NoResultSet
                    : ResultSetMapping.LastInResultSet;
            }

            // Keep a single (unwritten) column to spell out one VALUES row per command.
            // NOTE(review): relies on the base AppendValues rendering unwritten columns as DEFAULT — confirm.
            if (nonIdentityOperations.Count > 1)
            {
                nonIdentityOperations.RemoveRange(1, nonIdentityOperations.Count - 1);
            }
        }

        if (readOperations.Count == 0)
        {
            return AppendBulkInsertWithoutServerValues(commandStringBuilder, modificationCommands, writeOperations);
        }

        return defaultValuesOnly
            ? AppendBulkInsertWithServerValuesOnly(commandStringBuilder, modificationCommands, nonIdentityOperations, readOperations)
            : AppendBulkInsertWithServerValues(commandStringBuilder, modificationCommands, writeOperations, readOperations);
    }

    // One INSERT with a VALUES row per command, returning the read columns for every row.
    private ResultSetMapping AppendBulkInsertWithServerValues(
        StringBuilder commandStringBuilder,
        IReadOnlyList<ModificationCommand> modificationCommands,
        List<ColumnModification> writeOperations,
        List<ColumnModification> readOperations)
    {
        AppendMultiRowInsertHeader(commandStringBuilder, modificationCommands[0], writeOperations);
        AppendWriteValueRows(commandStringBuilder, modificationCommands);
        AppendReturningClause(commandStringBuilder, readOperations);
        commandStringBuilder.AppendLine(SqlGenerationHelper.StatementTerminator);

        // All commands share this single result set, one row each.
        // NOTE(review): the batch consumes RETURNING rows in order, one per command, so this relies
        // on PostgreSQL returning rows in VALUES order — confirm this is acceptable.
        return ResultSetMapping.NotLastInResultSet;
    }

    // One INSERT with a VALUES row per command; nothing is read back.
    private ResultSetMapping AppendBulkInsertWithoutServerValues(
        StringBuilder commandStringBuilder,
        IReadOnlyList<ModificationCommand> modificationCommands,
        List<ColumnModification> writeOperations)
    {
        AppendMultiRowInsertHeader(commandStringBuilder, modificationCommands[0], writeOperations);
        AppendWriteValueRows(commandStringBuilder, modificationCommands);
        commandStringBuilder.AppendLine(SqlGenerationHelper.StatementTerminator);
        return ResultSetMapping.NoResultSet;
    }

    // One INSERT whose rows write no values (only the single kept column), returning the read columns.
    private ResultSetMapping AppendBulkInsertWithServerValuesOnly(
        StringBuilder commandStringBuilder,
        IReadOnlyList<ModificationCommand> modificationCommands,
        List<ColumnModification> nonIdentityOperations,
        List<ColumnModification> readOperations)
    {
        AppendMultiRowInsertHeader(commandStringBuilder, modificationCommands[0], nonIdentityOperations);
        for (var i = 0; i < modificationCommands.Count; i++)
        {
            if (i > 0)
            {
                commandStringBuilder.AppendLine(",");
            }

            AppendValues(commandStringBuilder, nonIdentityOperations);
        }

        AppendReturningClause(commandStringBuilder, readOperations);
        commandStringBuilder.AppendLine(SqlGenerationHelper.StatementTerminator);

        // All commands share this single result set, one row each.
        return ResultSetMapping.NotLastInResultSet;
    }

    // Appends the INSERT command header and VALUES header for the group's table and columns.
    private void AppendMultiRowInsertHeader(
        StringBuilder commandStringBuilder,
        ModificationCommand firstCommand,
        List<ColumnModification> columnOperations)
    {
        AppendInsertCommandHeader(commandStringBuilder, firstCommand.TableName, firstCommand.Schema, columnOperations);
        AppendValuesHeader(commandStringBuilder, columnOperations);
    }

    // Appends one VALUES row per command, built from that command's own written columns;
    // rows are separated by "," and a line break.
    private void AppendWriteValueRows(
        StringBuilder commandStringBuilder,
        IReadOnlyList<ModificationCommand> modificationCommands)
    {
        for (var i = 0; i < modificationCommands.Count; i++)
        {
            if (i > 0)
            {
                commandStringBuilder.AppendLine(",");
            }

            AppendValues(commandStringBuilder, modificationCommands[i].ColumnModifications.Where(o => o.IsWrite).ToList());
        }
    }

    // Appends "RETURNING" followed by the comma-separated, delimited names of the given columns,
    // on a new line.
    // ReSharper disable once ParameterTypeCanBeEnumerable.Local
    private void AppendReturningClause(
        StringBuilder commandStringBuilder,
        IReadOnlyList<ColumnModification> operations)
    {
        commandStringBuilder
            .AppendLine()
            .Append("RETURNING ")
            .AppendJoin(",", operations.Select(c => SqlGenerationHelper.DelimitIdentifier(c.ColumnName)));
    }
}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment