Created
October 14, 2020 14:07
-
-
Save zlepper/7f55ab76547d81eb6eb403ad4feab06b to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
using System.Collections.Generic; | |
using Microsoft.EntityFrameworkCore; | |
using Microsoft.EntityFrameworkCore.Infrastructure; | |
using Microsoft.EntityFrameworkCore.Update; | |
using Microsoft.Extensions.DependencyInjection; | |
using Microsoft.Extensions.DependencyInjection.Extensions; | |
using PocPerf.EfOverrides; | |
namespace PocPerf | |
{ | |
/// <summary>
/// Extension methods that plug the Digizuite EF Core overrides into a context's options.
/// </summary>
public static class DgzEfContextExtensions
{
    /// <summary>
    /// Registers a <see cref="DigizuiteContextExtensions"/> on the given options builder,
    /// reusing the instance already present in the options when there is one.
    /// </summary>
    /// <param name="efContextBuilder">The options builder to extend.</param>
    /// <returns>The same <paramref name="efContextBuilder"/> instance, so calls can be chained.</returns>
    public static DbContextOptionsBuilder ExtendPostgres(this DbContextOptionsBuilder efContextBuilder)
    {
        var digizuiteExtension = efContextBuilder.Options.FindExtension<DigizuiteContextExtensions>();
        if (digizuiteExtension == null)
        {
            digizuiteExtension = new DigizuiteContextExtensions();
        }

        // AddOrUpdateExtension is only exposed through the infrastructure interface.
        var infrastructure = (IDbContextOptionsBuilderInfrastructure) efContextBuilder;
        infrastructure.AddOrUpdateExtension(digizuiteExtension);

        return efContextBuilder;
    }
}
/// <summary>
/// Options extension that swaps the update SQL generator and the modification command
/// batch factory for the Digizuite implementations.
/// </summary>
public class DigizuiteContextExtensions : IDbContextOptionsExtension
{
    /// <summary>
    /// Creates the extension together with its <see cref="DigizuiteContextInfo"/> metadata.
    /// </summary>
    public DigizuiteContextExtensions()
    {
        Info = new DigizuiteContextInfo(this);
    }

    /// <summary>
    /// Metadata describing this extension.
    /// </summary>
    public DbContextOptionsExtensionInfo Info { get; }

    /// <summary>
    /// Replaces the <see cref="IUpdateSqlGenerator"/> (singleton) and
    /// <see cref="IModificationCommandBatchFactory"/> (scoped) registrations
    /// with the Digizuite implementations.
    /// </summary>
    /// <param name="services">The service collection to modify.</param>
    public void ApplyServices(IServiceCollection services)
    {
        var sqlGeneratorDescriptor = new ServiceDescriptor(
            typeof(IUpdateSqlGenerator),
            typeof(DigizuiteUpdateSqlGenerator),
            ServiceLifetime.Singleton);
        services.Replace(sqlGeneratorDescriptor);

        var batchFactoryDescriptor = new ServiceDescriptor(
            typeof(IModificationCommandBatchFactory),
            typeof(DigizuiteModificationCommandBatchFactory),
            ServiceLifetime.Scoped);
        services.Replace(batchFactoryDescriptor);
    }

    /// <summary>
    /// No validation is performed for this extension.
    /// </summary>
    /// <param name="options">The options being validated (unused).</param>
    public void Validate(IDbContextOptions options)
    {
    }
}
/// <summary>
/// Metadata for <see cref="DigizuiteContextExtensions"/>.
/// </summary>
public class DigizuiteContextInfo : DbContextOptionsExtensionInfo
{
    /// <summary>
    /// Creates the info object for the given extension.
    /// </summary>
    /// <param name="extension">The extension this metadata describes.</param>
    public DigizuiteContextInfo(IDbContextOptionsExtension extension)
        : base(extension)
    {
    }

    /// <summary>
    /// This extension is not a database provider.
    /// </summary>
    public override bool IsDatabaseProvider => false;

    /// <summary>
    /// Fragment used when the configured options are logged.
    /// </summary>
    public override string LogFragment => "'Digizuite'";

    /// <summary>
    /// Always returns 0: the extension has no settings that would yield a different hash.
    /// </summary>
    public override long GetServiceProviderHashCode() => 0;

    /// <summary>
    /// Adds nothing to the debug info.
    /// </summary>
    /// <param name="debugInfo">The dictionary to populate (left untouched).</param>
    public override void PopulateDebugInfo(IDictionary<string, string> debugInfo)
    {
    }
}
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
using System; | |
using System.Collections.Generic; | |
using System.Diagnostics; | |
using System.Diagnostics.CodeAnalysis; | |
using System.Linq; | |
using System.Text; | |
using System.Threading; | |
using System.Threading.Tasks; | |
using Microsoft.EntityFrameworkCore; | |
using Microsoft.EntityFrameworkCore.Diagnostics; | |
using Microsoft.EntityFrameworkCore.Infrastructure; | |
using Microsoft.EntityFrameworkCore.Storage; | |
using Microsoft.EntityFrameworkCore.Update; | |
using Npgsql; | |
using Npgsql.EntityFrameworkCore.PostgreSQL.Infrastructure.Internal; | |
using Npgsql.EntityFrameworkCore.PostgreSQL.Update.Internal; | |
namespace PocPerf.EfOverrides | |
{ | |
/// <summary>
/// Factory that hands out <see cref="DigizuiteModificationCommandBatch"/> instances.
/// </summary>
public class DigizuiteModificationCommandBatchFactory : IModificationCommandBatchFactory
{
    readonly ModificationCommandBatchFactoryDependencies _dependencies;
    readonly IDbContextOptions _options;

    /// <summary>
    /// Creates the factory.
    /// </summary>
    /// <param name="dependencies">Dependencies forwarded to every created batch; must not be null.</param>
    /// <param name="options">Context options used to look up the Npgsql max batch size; may be null.</param>
    /// <exception cref="ArgumentNullException"><paramref name="dependencies"/> is null.</exception>
    public DigizuiteModificationCommandBatchFactory(
        ModificationCommandBatchFactoryDependencies dependencies,
        IDbContextOptions options)
    {
        if (dependencies == null)
        {
            throw new ArgumentNullException(nameof(dependencies));
        }

        _dependencies = dependencies;
        _options = options;
    }

    /// <summary>
    /// Creates a new batch, passing along the <c>MaxBatchSize</c> of the first
    /// <see cref="NpgsqlOptionsExtension"/> found in the options (null when there is none).
    /// </summary>
    /// <returns>A new <see cref="DigizuiteModificationCommandBatch"/>.</returns>
    public ModificationCommandBatch Create()
    {
        int? maxBatchSize = null;
        if (_options != null)
        {
            var npgsqlExtension = _options.Extensions.OfType<NpgsqlOptionsExtension>().FirstOrDefault();
            maxBatchSize = npgsqlExtension?.MaxBatchSize;
        }

        return new DigizuiteModificationCommandBatch(_dependencies, maxBatchSize);
    }
}
/// <summary> | |
/// A lot of the code here is stolen from SqlServerModificationCommandBatch, | |
/// which seems to support this somewhat | |
/// </summary> | |
public class DigizuiteModificationCommandBatch : NpgsqlModificationCommandBatch | |
{ | |
// Only assigned in the constructor, hence readonly.
private readonly DigizuiteUpdateSqlGenerator _digizuiteUpdateSqlGenerator;

/// <summary>
/// Creates the batch and captures the update SQL generator as a
/// <see cref="DigizuiteUpdateSqlGenerator"/> so bulk insert statements can be generated.
/// </summary>
/// <param name="dependencies">Dependencies forwarded to the base batch.</param>
/// <param name="maxBatchSize">Maximum batch size forwarded to the base batch; may be null.</param>
/// <exception cref="InvalidOperationException">
/// The configured update SQL generator is not a <see cref="DigizuiteUpdateSqlGenerator"/>.
/// </exception>
public DigizuiteModificationCommandBatch(ModificationCommandBatchFactoryDependencies dependencies,
    int? maxBatchSize) : base(dependencies, maxBatchSize)
{
    // A wrong generator type is a configuration problem, so signal it with a specific
    // exception type instead of the bare System.Exception. Callers that catch Exception
    // still see it.
    // ReSharper disable once VirtualMemberCallInConstructor
    _digizuiteUpdateSqlGenerator = UpdateSqlGenerator as DigizuiteUpdateSqlGenerator
        ?? throw new InvalidOperationException(
            $"{nameof(UpdateSqlGenerator)} is not a {nameof(DigizuiteUpdateSqlGenerator)}");
}
// Insert commands collected since the last flush; they are emitted together as one bulk insert.
private readonly List<ModificationCommand> _bulkInsertCommands = new List<ModificationCommand>();

/// <summary>
/// Resets the base command text and drops any insert commands still waiting to be flushed.
/// </summary>
protected override void ResetCommandText()
{
    base.ResetCommandText();

    _bulkInsertCommands.Clear();
}
/// <summary>
/// Returns the base command text followed by the SQL for the insert commands
/// that have not been flushed yet.
/// </summary>
protected override string GetCommandText()
{
    var baseCommandText = base.GetCommandText();
    var pendingBulkInsertText = GetBulkInsertCommandText(ModificationCommands.Count);

    return baseCommandText + pendingBulkInsertText;
}
/// <summary>
/// Generates the SQL for the pending insert commands and records the resulting
/// result-set mapping for each of them in <c>CommandResultSet</c>.
/// </summary>
/// <param name="lastIndex">
/// Index just past the last pending insert command; the pending commands occupy
/// the positions <c>lastIndex - count</c> up to <c>lastIndex - 1</c>.
/// </param>
/// <returns>The generated SQL, or an empty string when nothing is pending.</returns>
private string GetBulkInsertCommandText(int lastIndex)
{
    var pendingCount = _bulkInsertCommands.Count;
    if (pendingCount == 0)
    {
        return string.Empty;
    }

    var firstIndex = lastIndex - pendingCount;
    var sql = new StringBuilder();
    var mapping = _digizuiteUpdateSqlGenerator.AppendBulkInsertOperation(sql, _bulkInsertCommands, firstIndex);

    for (var position = firstIndex; position < lastIndex; position++)
    {
        CommandResultSet[position] = mapping;
    }

    // Whenever a result set is produced, the final command of the group closes it.
    if (mapping != ResultSetMapping.NoResultSet)
    {
        CommandResultSet[lastIndex - 1] = ResultSetMapping.LastInResultSet;
    }

    return sql.ToString();
}
/// <summary>
/// Adds the command at <paramref name="commandPosition"/> to the cached command text.
/// Inserts are collected so compatible ones can share a single statement; any other
/// command first flushes the collected inserts and is then handled by the base class.
/// </summary>
/// <param name="commandPosition">Index of the newly added command in <c>ModificationCommands</c>.</param>
protected override void UpdateCachedCommandText(int commandPosition)
{
    var incomingCommand = ModificationCommands[commandPosition];
    var isInsert = incomingCommand.EntityState == EntityState.Added;

    // Flush the pending inserts when the new command is not an insert, or when it is
    // an insert that cannot share a statement with the ones already collected.
    var mustFlushPending = !isInsert
        || (_bulkInsertCommands.Count > 0
            && !CanBeInsertedInSameStatement(_bulkInsertCommands[0], incomingCommand));

    if (mustFlushPending)
    {
        CachedCommandText.Append(GetBulkInsertCommandText(commandPosition));
        _bulkInsertCommands.Clear();
    }

    if (!isInsert)
    {
        base.UpdateCachedCommandText(commandPosition);
        return;
    }

    _bulkInsertCommands.Add(incomingCommand);
    LastCachedCommandIndex = commandPosition;
}
/// <summary>
/// Determines whether two insert commands can be merged into one statement: same table,
/// same schema, and identical sequences of written and of read column names.
/// </summary>
/// <param name="firstCommand">The command already collected.</param>
/// <param name="secondCommand">The candidate command.</param>
/// <returns><c>true</c> when both commands can share one insert statement.</returns>
private static bool CanBeInsertedInSameStatement(ModificationCommand firstCommand, ModificationCommand secondCommand)
{
    if (!string.Equals(firstCommand.TableName, secondCommand.TableName, StringComparison.Ordinal))
    {
        return false;
    }

    if (!string.Equals(firstCommand.Schema, secondCommand.Schema, StringComparison.Ordinal))
    {
        return false;
    }

    IEnumerable<string> WrittenColumns(ModificationCommand command)
        => command.ColumnModifications.Where(o => o.IsWrite).Select(o => o.ColumnName);

    IEnumerable<string> ReadColumns(ModificationCommand command)
        => command.ColumnModifications.Where(o => o.IsRead).Select(o => o.ColumnName);

    if (!WrittenColumns(firstCommand).SequenceEqual(WrittenColumns(secondCommand)))
    {
        return false;
    }

    return ReadColumns(firstCommand).SequenceEqual(ReadColumns(secondCommand));
}
/// <summary>
/// Reads the result sets produced by the batch and matches them to the modification
/// commands according to <c>CommandResultSet</c>.
/// </summary>
/// <param name="reader">The reader positioned on the first result set of the batch.</param>
/// <param name="cancellationToken">Token used to cancel the asynchronous reads.</param>
/// <exception cref="DbUpdateException">
/// Wraps any non-<see cref="DbUpdateException"/> failure raised while consuming results.
/// </exception>
protected override async Task ConsumeAsync(
    RelationalDataReader reader,
    CancellationToken cancellationToken = default)
{
    Debug.Assert(CommandResultSet.Count == ModificationCommands.Count);

    // Advances past commands that were mapped as producing no result set.
    int SkipCommandsWithoutResultSet(int index)
    {
        while (index < CommandResultSet.Count
               && CommandResultSet[index] == ResultSetMapping.NoResultSet)
        {
            index++;
        }

        return index;
    }

    var commandIndex = 0;
    try
    {
        var actualResultSetCount = 0;
        var hasMoreResultSets = true;
        while (hasMoreResultSets)
        {
            commandIndex = SkipCommandsWithoutResultSet(commandIndex);

            if (commandIndex < CommandResultSet.Count)
            {
                if (ModificationCommands[commandIndex].RequiresResultPropagation)
                {
                    commandIndex = await ConsumeResultSetWithPropagationAsync(commandIndex, reader, cancellationToken);
                }
                else
                {
                    commandIndex = await ConsumeResultSetWithoutPropagationAsync(commandIndex, reader, cancellationToken);
                }

                actualResultSetCount++;
            }

            // Only move to the next result set while there are commands left to match.
            hasMoreResultSets = commandIndex < CommandResultSet.Count
                                && await reader.DbDataReader.NextResultAsync(cancellationToken);
        }
#if DEBUG
        commandIndex = SkipCommandsWithoutResultSet(commandIndex);
        Debug.Assert(
            commandIndex == ModificationCommands.Count,
            "Expected " + ModificationCommands.Count + " results, got " + commandIndex);
        var expectedResultSetCount = CommandResultSet.Count(e => e == ResultSetMapping.LastInResultSet);
        Debug.Assert(
            actualResultSetCount == expectedResultSetCount,
            "Expected " + expectedResultSetCount + " result sets, got " + actualResultSetCount);
#endif
    }
    catch (Exception ex) when (!(ex is DbUpdateException))
    {
        throw new DbUpdateException(
            RelationalStrings.UpdateStoreException,
            ex,
            ModificationCommands[commandIndex].Entries);
    }
}
/// <summary>
/// Consumes one result set whose rows carry store-generated values, propagating one row
/// to each command until a command not mapped as <c>NotLastInResultSet</c> has been handled.
/// </summary>
/// <param name="commandIndex">Index of the first command served by this result set.</param>
/// <param name="reader">The reader positioned on the result set.</param>
/// <param name="cancellationToken">Token used to cancel the asynchronous reads.</param>
/// <returns>The index of the first command that was not consumed.</returns>
protected virtual async Task<int> ConsumeResultSetWithPropagationAsync(
    int commandIndex, [NotNull] RelationalDataReader reader, CancellationToken cancellationToken)
{
    var propagatedRows = 0;
    bool moreRowsExpected;
    do
    {
        var command = ModificationCommands[commandIndex];
        Debug.Assert(command.RequiresResultPropagation);

        var rowAvailable = await reader.ReadAsync(cancellationToken);
        if (!rowAvailable)
        {
            // Count how many rows this result set should have delivered in total.
            var expectedRowsAffected = propagatedRows + 1;
            while (++commandIndex < CommandResultSet.Count
                   && CommandResultSet[commandIndex - 1] == ResultSetMapping.NotLastInResultSet)
            {
                expectedRowsAffected++;
            }

            ThrowAggregateUpdateConcurrencyException(commandIndex, expectedRowsAffected, propagatedRows);
        }

        var bufferFactory = CreateValueBufferFactory(command.ColumnModifications);
        var rowValues = bufferFactory.Create(reader.DbDataReader);
        command.PropagateResults(rowValues);
        propagatedRows++;

        commandIndex++;
        moreRowsExpected = commandIndex < CommandResultSet.Count
                           && CommandResultSet[commandIndex - 1] == ResultSetMapping.NotLastInResultSet;
    }
    while (moreRowsExpected);

    return commandIndex;
}
/// <summary>
/// Consumes one result set that carries no values to propagate: its first column is read
/// as the number of affected rows and compared with the number of commands it covers.
/// </summary>
/// <param name="commandIndex">Index of the first command served by this result set.</param>
/// <param name="reader">The reader positioned on the result set.</param>
/// <param name="cancellationToken">Token used to cancel the asynchronous read.</param>
/// <returns>The index of the first command that was not consumed.</returns>
protected virtual async Task<int> ConsumeResultSetWithoutPropagationAsync(
    int commandIndex, [NotNull] RelationalDataReader reader, CancellationToken cancellationToken)
{
    // Every command chained through NotLastInResultSet adds one expected row.
    var expectedRowsAffected = 1;
    commandIndex++;
    while (commandIndex < CommandResultSet.Count
           && CommandResultSet[commandIndex - 1] == ResultSetMapping.NotLastInResultSet)
    {
        Debug.Assert(!ModificationCommands[commandIndex].RequiresResultPropagation);
        expectedRowsAffected++;
        commandIndex++;
    }

    var rowAvailable = await reader.ReadAsync(cancellationToken);
    if (!rowAvailable)
    {
        ThrowAggregateUpdateConcurrencyException(commandIndex, 1, 0);
        return commandIndex;
    }

    var rowsAffected = reader.DbDataReader.GetInt32(0);
    if (rowsAffected != expectedRowsAffected)
    {
        ThrowAggregateUpdateConcurrencyException(commandIndex, expectedRowsAffected, rowsAffected);
    }

    return commandIndex;
}
/// <summary>
/// Throws a <see cref="DbUpdateConcurrencyException"/> covering the entries of the
/// <paramref name="expectedRowsAffected"/> commands that end just before
/// <paramref name="commandIndex"/>.
/// </summary>
/// <param name="commandIndex">Index just past the last command involved.</param>
/// <param name="expectedRowsAffected">Number of rows that should have been affected.</param>
/// <param name="rowsAffected">Number of rows that were actually affected.</param>
/// <exception cref="DbUpdateConcurrencyException">Always thrown.</exception>
protected virtual void ThrowAggregateUpdateConcurrencyException(
    int commandIndex,
    int expectedRowsAffected,
    int rowsAffected)
{
    var message = RelationalStrings.UpdateConcurrencyException(expectedRowsAffected, rowsAffected);
    var affectedEntries = AggregateEntries(commandIndex, expectedRowsAffected);

    throw new DbUpdateConcurrencyException(message, affectedEntries);
}
/// <summary>
/// Collects the entries of the <paramref name="commandCount"/> commands that end
/// just before <paramref name="endIndex"/>.
/// </summary>
/// <param name="endIndex">Index just past the last command to include.</param>
/// <param name="commandCount">Number of commands to include.</param>
/// <returns>The entries of all included commands, in command order.</returns>
private IReadOnlyList<IUpdateEntry> AggregateEntries(int endIndex, int commandCount)
{
    var collected = new List<IUpdateEntry>();

    var position = endIndex - commandCount;
    while (position < endIndex)
    {
        collected.AddRange(ModificationCommands[position].Entries);
        position++;
    }

    return collected;
}
} | |
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
using System.Collections.Generic; | |
using System.Diagnostics.CodeAnalysis; | |
using System.Globalization; | |
using System.Linq; | |
using System.Text; | |
using Microsoft.EntityFrameworkCore; | |
using Microsoft.EntityFrameworkCore.Update; | |
using Npgsql.EntityFrameworkCore.PostgreSQL.Metadata; | |
using Npgsql.EntityFrameworkCore.PostgreSQL.Update.Internal; | |
namespace PocPerf.EfOverrides | |
{ | |
public class DigizuiteUpdateSqlGenerator : NpgsqlUpdateSqlGenerator | |
{ | |
/// <summary>
/// Creates the generator; all dependencies are handed to the Npgsql base generator.
/// </summary>
/// <param name="dependencies">Dependencies forwarded to the base class.</param>
public DigizuiteUpdateSqlGenerator(UpdateSqlGeneratorDependencies dependencies)
    : base(dependencies)
{
}
/// <summary>
/// Appends the SQL that inserts all given commands, merging them into a single
/// multi-row statement where possible. The commands are assumed to target the same
/// entity type with exactly the same columns in exactly the same order; only the
/// first command is inspected to decide which statement shape to generate.
/// </summary>
/// <param name="commandStringBuilder">The builder receiving the SQL.</param>
/// <param name="modificationCommands">The insert commands; must contain at least one command.</param>
/// <param name="commandPosition">Position of the first command within the batch.</param>
/// <returns>The result-set mapping describing what the generated SQL returns.</returns>
public ResultSetMapping AppendBulkInsertOperation(
    [NotNull] StringBuilder commandStringBuilder,
    [NotNull] IReadOnlyList<ModificationCommand> modificationCommands,
    int commandPosition)
{
    var firstCommand = modificationCommands[0];
    var firstColumns = firstCommand.ColumnModifications;

    // A single command without store-generated keys needs nothing special.
    var readsGeneratedKey = firstColumns.Any(o => o.IsKey && o.IsRead);
    if (modificationCommands.Count == 1 && !readsGeneratedKey)
    {
        return AppendInsertOperation(commandStringBuilder, firstCommand, commandPosition);
    }

    var readOperations = firstColumns.Where(o => o.IsRead).ToList();
    var writeOperations = firstColumns.Where(o => o.IsWrite).ToList();
    var nonIdentityOperations = firstColumns
        .Where(o => o.Property != null
                    && NpgsqlPropertyExtensions.GetValueGenerationStrategy(o.Property)
                    != NpgsqlValueGenerationStrategy.IdentityAlwaysColumn)
        .ToList();

    var defaultValuesOnly = writeOperations.Count == 0;
    var hasReadOperations = readOperations.Count != 0;

    if (defaultValuesOnly)
    {
        if (nonIdentityOperations.Count == 0 || !hasReadOperations)
        {
            // Fall back to one plain insert statement per command.
            foreach (var command in modificationCommands)
            {
                AppendInsertOperation(commandStringBuilder, command, commandPosition);
            }

            return hasReadOperations
                ? ResultSetMapping.LastInResultSet
                : ResultSetMapping.NoResultSet;
        }

        // Keep just the first non-identity column for the generated statement.
        if (nonIdentityOperations.Count > 1)
        {
            nonIdentityOperations.RemoveRange(1, nonIdentityOperations.Count - 1);
        }
    }

    if (!hasReadOperations)
    {
        return AppendBulkInsertWithoutServerValues(commandStringBuilder, modificationCommands, writeOperations);
    }

    return defaultValuesOnly
        ? AppendBulkInsertWithServerValuesOnly(
            commandStringBuilder, modificationCommands, nonIdentityOperations, readOperations)
        : AppendBulkInsertWithServerValues(
            commandStringBuilder, modificationCommands, writeOperations, readOperations);
}
/// <summary>
/// Appends one multi-row <c>INSERT ... VALUES (...), (...) RETURNING ...</c> statement that
/// writes the given columns for every command and returns the store-generated columns.
/// </summary>
/// <param name="commandStringBuilder">The builder receiving the SQL.</param>
/// <param name="modificationCommands">The insert commands; the first one supplies table and schema.</param>
/// <param name="writeOperations">The written columns of the first command.</param>
/// <param name="readOperations">The columns listed in the RETURNING clause.</param>
/// <returns>
/// <see cref="ResultSetMapping.NotLastInResultSet"/>: all rows come back in the single result
/// set of this one statement, so each command is one row of a shared result set.
/// </returns>
private ResultSetMapping AppendBulkInsertWithServerValues(StringBuilder commandStringBuilder,
    IReadOnlyList<ModificationCommand> modificationCommands, List<ColumnModification> writeOperations, List<ColumnModification> readOperations)
{
    var name = modificationCommands[0].TableName;
    var schema = modificationCommands[0].Schema;

    AppendInsertCommandHeader(commandStringBuilder, name, schema, writeOperations);
    AppendValuesHeader(commandStringBuilder, writeOperations);
    AppendValues(commandStringBuilder, writeOperations);
    for (var i = 1; i < modificationCommands.Count; i++)
    {
        commandStringBuilder.AppendLine(",");
        AppendValues(commandStringBuilder, modificationCommands[i].ColumnModifications.Where(o => o.IsWrite).ToList());
    }

    AppendReturningClause(commandStringBuilder, readOperations);
    commandStringBuilder.AppendLine(SqlGenerationHelper.StatementTerminator);

    // The batch copies this mapping onto every command of the group and then re-marks the
    // last one as LastInResultSet. Its row-propagation loop only keeps reading rows of the
    // current result set while the preceding command is NotLastInResultSet. Returning
    // LastInResultSet here would make it read a single row and skip to the next result
    // set, leaving the remaining commands without their generated values.
    return ResultSetMapping.NotLastInResultSet;
}
/// <summary>
/// Appends one multi-row <c>INSERT ... VALUES (...), (...)</c> statement for commands
/// that do not read any values back.
/// </summary>
/// <param name="commandStringBuilder">The builder receiving the SQL.</param>
/// <param name="modificationCommands">The insert commands; the first one supplies table and schema.</param>
/// <param name="writeOperations">The written columns of the first command.</param>
/// <returns><see cref="ResultSetMapping.NoResultSet"/>, as the statement has no RETURNING clause.</returns>
private ResultSetMapping AppendBulkInsertWithoutServerValues(
    StringBuilder commandStringBuilder,
    IReadOnlyList<ModificationCommand> modificationCommands,
    List<ColumnModification> writeOperations)
{
    var firstCommand = modificationCommands[0];
    var tableName = firstCommand.TableName;
    var tableSchema = firstCommand.Schema;

    AppendInsertCommandHeader(commandStringBuilder, tableName, tableSchema, writeOperations);
    AppendValuesHeader(commandStringBuilder, writeOperations);

    // The first row uses the already materialized write operations.
    AppendValues(commandStringBuilder, writeOperations);

    foreach (var command in modificationCommands.Skip(1))
    {
        commandStringBuilder.AppendLine(",");
        var rowWrites = command.ColumnModifications.Where(o => o.IsWrite).ToList();
        AppendValues(commandStringBuilder, rowWrites);
    }

    commandStringBuilder.AppendLine(SqlGenerationHelper.StatementTerminator);

    return ResultSetMapping.NoResultSet;
}
/// <summary>
/// Appends a single-row <c>INSERT ... RETURNING ...</c> statement for one command.
/// </summary>
/// <remarks>
/// NOTE(review): no call site for this helper is visible in this file; confirm whether it is still needed.
/// </remarks>
/// <param name="commandStringBuilder">The builder receiving the SQL.</param>
/// <param name="command">The insert command.</param>
/// <param name="readOperations">The columns listed in the RETURNING clause.</param>
/// <returns><see cref="ResultSetMapping.LastInResultSet"/>.</returns>
private ResultSetMapping AppendInsertOperationWithServerKeys(
    StringBuilder commandStringBuilder,
    ModificationCommand command,
    IReadOnlyList<ColumnModification> readOperations)
{
    var writtenColumns = command.ColumnModifications.Where(o => o.IsWrite).ToList();

    AppendInsertCommandHeader(commandStringBuilder, command.TableName, command.Schema, writtenColumns);
    AppendValuesHeader(commandStringBuilder, writtenColumns);
    AppendValues(commandStringBuilder, writtenColumns);
    AppendReturningClause(commandStringBuilder, readOperations);

    // The terminator is appended without a trailing newline here.
    commandStringBuilder.Append(SqlGenerationHelper.StatementTerminator);

    return ResultSetMapping.LastInResultSet;
}
/// <summary>
/// Appends one multi-row <c>INSERT ... RETURNING ...</c> statement for commands that write
/// no values themselves; the same <paramref name="nonIdentityOperations"/> value list is
/// repeated once per command.
/// </summary>
/// <param name="commandStringBuilder">The builder receiving the SQL.</param>
/// <param name="modificationCommands">The insert commands; the first one supplies table and schema.</param>
/// <param name="nonIdentityOperations">The columns used for the column list and for each value row.</param>
/// <param name="readOperations">The columns listed in the RETURNING clause.</param>
/// <returns>
/// <see cref="ResultSetMapping.NotLastInResultSet"/>: all rows come back in the single result
/// set of this one statement, so each command is one row of a shared result set.
/// </returns>
private ResultSetMapping AppendBulkInsertWithServerValuesOnly(
    StringBuilder commandStringBuilder,
    IReadOnlyList<ModificationCommand> modificationCommands,
    List<ColumnModification> nonIdentityOperations,
    List<ColumnModification> readOperations)
{
    var name = modificationCommands[0].TableName;
    var schema = modificationCommands[0].Schema;

    AppendInsertCommandHeader(commandStringBuilder, name, schema, nonIdentityOperations);
    AppendValuesHeader(commandStringBuilder, nonIdentityOperations);
    AppendValues(commandStringBuilder, nonIdentityOperations);
    for (var i = 1; i < modificationCommands.Count; i++)
    {
        commandStringBuilder.AppendLine(",");
        AppendValues(commandStringBuilder, nonIdentityOperations);
    }

    AppendReturningClause(commandStringBuilder, readOperations);
    commandStringBuilder.AppendLine(SqlGenerationHelper.StatementTerminator);

    // The batch copies this mapping onto every command of the group and then re-marks the
    // last one as LastInResultSet. Its row-propagation loop only keeps reading rows of the
    // current result set while the preceding command is NotLastInResultSet. Returning
    // LastInResultSet here would make it read a single row and skip to the next result
    // set, leaving the remaining commands without their generated values.
    return ResultSetMapping.NotLastInResultSet;
}
/// <summary>
/// Appends a newline followed by <c>RETURNING</c> and the comma-separated, delimited
/// names of the given columns.
/// </summary>
/// <param name="commandStringBuilder">The builder receiving the SQL.</param>
/// <param name="operations">The columns whose values should be returned.</param>
// ReSharper disable once ParameterTypeCanBeEnumerable.Local
private void AppendReturningClause(
    StringBuilder commandStringBuilder,
    IReadOnlyList<ColumnModification> operations)
{
    var delimitedColumnNames = operations.Select(c => SqlGenerationHelper.DelimitIdentifier(c.ColumnName));

    commandStringBuilder.AppendLine();
    commandStringBuilder.Append("RETURNING ");
    commandStringBuilder.Append(string.Join(",", delimitedColumnNames));
}
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment