Skip to content

Instantly share code, notes, and snippets.

@bymyslf
Last active June 25, 2020 08:42
Show Gist options
  • Save bymyslf/e9b657159decdc0bc7439a77927108b7 to your computer and use it in GitHub Desktop.
Save bymyslf/e9b657159decdc0bc7439a77927108b7 to your computer and use it in GitHub Desktop.
using System;
using System.Collections.Generic;
using System.Data;
using System.IO;
using System.Linq;
using System.Reflection;
using System.Text;
using System.Text.RegularExpressions;
/// <summary>
/// Runs the <c>.sql</c> scripts embedded in this assembly against a database and records every applied
/// script in <c>[schema].[MigrationScripts]</c>, so that each script is executed at most once.
/// </summary>
/// <remarks>
/// The bookkeeping SQL emitted by this class (<c>sys.schemas</c>, <c>sp_executesql</c>, <c>IDENTITY</c>,
/// <c>GETDATE()</c>) and the <c>GO</c> batch separator handling are T-SQL specific.
/// </remarks>
public class SqlScriptMigrator
{
    private const string INSERT_EXECUTED_SCRIPT = @"INSERT INTO [{0}].[MigrationScripts] (ScriptName) VALUES(@scriptName)";

    private const string EXECUTED_SCRIPTS_SCRIPT = @"SELECT [ScriptName] FROM [{0}].[MigrationScripts] ORDER BY [ScriptName]";

    private const string VERIFY_SCHEMA_SCRIPT = @"IF NOT EXISTS (SELECT 1 FROM sys.schemas WHERE name = N'{0}')
BEGIN
    EXEC sp_executesql N'CREATE SCHEMA [{0}]'
END";

    private const string CREATE_TABLE_SCRIPT = @"IF NOT EXISTS (SELECT *
    FROM INFORMATION_SCHEMA.TABLES
    WHERE TABLE_SCHEMA = '{0}'
    AND TABLE_NAME = 'MigrationScripts')
BEGIN
    CREATE TABLE [{0}].[MigrationScripts]
    (
        [Id] INT IDENTITY(1,1) NOT NULL
            CONSTRAINT PK_MigrationScripts PRIMARY KEY,
        [ScriptName] NVARCHAR(255) NOT NULL,
        [Applied] DATETIME NOT NULL DEFAULT GETDATE()
    )
END";

    private readonly Func<IDbConnection> _createConnection;
    private readonly string _schema;
    private readonly IDictionary<string, string> _variables;

    /// <summary>
    /// Creates a migrator that performs no variable substitution in the scripts.
    /// </summary>
    /// <param name="createConnection">Factory producing a new, not yet opened, connection.</param>
    /// <param name="schema">Schema that holds the <c>MigrationScripts</c> bookkeeping table.</param>
    /// <exception cref="ArgumentNullException"><paramref name="createConnection"/> is null.</exception>
    /// <exception cref="ArgumentException"><paramref name="schema"/> is empty or contains <c>'</c> or <c>]</c>.</exception>
    public SqlScriptMigrator(Func<IDbConnection> createConnection, string schema)
        : this(createConnection, schema, new Dictionary<string, string>())
    { }

    /// <summary>
    /// Creates a migrator that replaces every <c>$key$</c> token in a script with the matching value of
    /// <paramref name="variables"/> before executing it.
    /// </summary>
    /// <param name="createConnection">Factory producing a new, not yet opened, connection.</param>
    /// <param name="schema">Schema that holds the <c>MigrationScripts</c> bookkeeping table.</param>
    /// <param name="variables">Substitution variables; <c>null</c> is treated as an empty set.</param>
    /// <exception cref="ArgumentNullException"><paramref name="createConnection"/> is null.</exception>
    /// <exception cref="ArgumentException"><paramref name="schema"/> is empty or contains <c>'</c> or <c>]</c>.</exception>
    public SqlScriptMigrator(Func<IDbConnection> createConnection, string schema, IDictionary<string, string> variables)
    {
        if (createConnection is null)
            throw new ArgumentNullException(nameof(createConnection));
        if (string.IsNullOrWhiteSpace(schema))
            throw new ArgumentException("A schema name is required.", nameof(schema));

        // The schema name is spliced into SQL text inside [..] and '..' delimiters (identifiers cannot be
        // bound as parameters), so a name able to terminate either delimiter must never reach the server.
        if (schema.IndexOfAny(new[] { '\'', ']' }) >= 0)
            throw new ArgumentException("The schema name must not contain ' or ] characters.", nameof(schema));

        _createConnection = createConnection;
        _schema = schema;
        _variables = variables ?? new Dictionary<string, string>();
    }

    /// <summary>
    /// Opens a connection, makes sure the schema and the bookkeeping table exist, and then executes every
    /// embedded script that is not yet recorded as applied. Each script runs in its own transaction.
    /// </summary>
    public void Migrate()
    {
        using (var connection = _createConnection())
        {
            connection.Open();
            EnsureTableExists(connection);

            foreach (var script in GetScriptsToExecute(connection))
            {
                ExecuteWithinTransaction(connection, script);
            }
        }
    }

    /// <summary>Streams the names of the scripts already recorded in the bookkeeping table.</summary>
    private IEnumerable<string> GetExecutedScripts(IDbConnection connection)
    {
        using (var command = connection.CreateCommand())
        {
            command.CommandText = string.Format(EXECUTED_SCRIPTS_SCRIPT, _schema);
            using (var reader = command.ExecuteReader())
            {
                while (reader.Read())
                {
                    yield return (string)reader[0];
                }
            }
        }
    }

    /// <summary>Yields the embedded scripts whose names are not yet recorded as executed.</summary>
    private IEnumerable<SqlScript> GetScriptsToExecute(IDbConnection connection)
    {
        // The HashSet constructor drains the reader completely, so it is closed before any script runs.
        var executed = new HashSet<string>(GetExecutedScripts(connection));
        foreach (var script in GetEmbeddedScripts())
        {
            // Add returns false for names already present, which filters out applied scripts.
            if (executed.Add(script.Name))
            {
                yield return script;
            }
        }
    }

    /// <summary>Creates the configured schema when it does not exist yet.</summary>
    private void EnsureSchemaExists(IDbConnection connection)
    {
        using (var command = connection.CreateCommand())
        {
            command.CommandText = string.Format(VERIFY_SCHEMA_SCRIPT, _schema);
            command.ExecuteNonQuery();
        }
    }

    /// <summary>Creates the schema and the <c>MigrationScripts</c> table when they do not exist yet.</summary>
    private void EnsureTableExists(IDbConnection connection)
    {
        EnsureSchemaExists(connection);
        using (var command = connection.CreateCommand())
        {
            command.CommandText = string.Format(CREATE_TABLE_SCRIPT, _schema);
            command.ExecuteNonQuery();
        }
    }

    /// <summary>Records <paramref name="scriptName"/> as applied, inside the given transaction.</summary>
    private void MarkAsExecuted(IDbConnection connection, IDbTransaction transaction, string scriptName)
    {
        using (var command = connection.CreateCommand())
        {
            command.CommandText = string.Format(INSERT_EXECUTED_SCRIPT, _schema);
            command.Transaction = transaction;

            var scriptNameParameter = command.CreateParameter();
            scriptNameParameter.ParameterName = "scriptName";
            scriptNameParameter.Value = scriptName;
            command.Parameters.Add(scriptNameParameter);

            command.ExecuteNonQuery();
        }
    }

    /// <summary>
    /// Executes all batches of <paramref name="script"/> and records the script as applied in a single
    /// transaction; any failure rolls the transaction back and is rethrown wrapped in a
    /// <see cref="SqlScriptMigrationException"/>.
    /// </summary>
    private void ExecuteWithinTransaction(IDbConnection connection, SqlScript script)
    {
        using (var transaction = connection.BeginTransaction())
        {
            try
            {
                var scriptContent = SubstituteVariables(script.Content);
                foreach (var commandText in SplitCommands(scriptContent))
                {
                    using (var command = connection.CreateCommand())
                    {
                        command.CommandText = commandText;
                        command.Transaction = transaction;
                        command.ExecuteNonQuery();
                    }
                }

                MarkAsExecuted(connection, transaction, script.Name);
                transaction.Commit();
            }
            catch (Exception ex)
            {
                Exception failure = ex;
                try
                {
                    transaction.Rollback();
                }
                catch (Exception rollbackEx)
                {
                    // A failing rollback must not hide the error that caused it; report both.
                    failure = new AggregateException(ex, rollbackEx);
                }

                throw new SqlScriptMigrationException($"ERROR executing script: {script.Name}", failure);
            }
        }
    }

    /// <summary>Replaces every <c>$key$</c> token in <paramref name="content"/> with the variable's value.</summary>
    private string SubstituteVariables(string content)
        => _variables.Aggregate(content, (current, replacement) => current.Replace($"${replacement.Key}$", replacement.Value));

    /// <summary>
    /// Lazily reads every manifest resource of this assembly whose name ends with <c>.sql</c>, sorted by
    /// resource name.
    /// </summary>
    private IEnumerable<SqlScript> GetEmbeddedScripts()
    {
        var assembly = typeof(SqlScriptMigrator).GetTypeInfo().Assembly;

        // NOTE(review): OrderBy uses the default string comparer, which is culture-sensitive, so the
        // execution order can depend on the current culture. It is kept as-is to avoid reordering
        // scripts of existing deployments - confirm before switching to an ordinal comparer.
        var scriptNames = assembly.GetManifestResourceNames()
            .Where(s => s.EndsWith(".sql", StringComparison.Ordinal))
            .OrderBy(s => s);

        foreach (var script in scriptNames)
        {
            using (var stream = assembly.GetManifestResourceStream(script))
            {
                if (stream is null)
                    throw new InvalidOperationException($"Embedded resource, {script}, not found. BUG!");

                using (var reader = new StreamReader(stream))
                {
                    yield return new SqlScript(script, reader.ReadToEnd());
                }
            }
        }
    }

    /// <summary>
    /// Splits a script into the batches delimited by <c>GO</c> / <c>GO &lt;count&gt;</c> lines.
    /// </summary>
    /// <param name="sql">Script text, after variable substitution.</param>
    /// <returns>
    /// One entry per batch execution: a batch followed by <c>GO n</c> is returned <c>n</c> times so that
    /// it is executed <c>n</c> times as separate commands. Whitespace-only batches are omitted.
    /// </returns>
    private static IList<string> SplitCommands(string sql)
    {
        var commands = new List<string>();

        //origin from the Microsoft.EntityFrameworkCore.Migrations.SqlServerMigrationsSqlGenerator.Generate method
        // Join lines that end with a backslash line continuation.
        sql = Regex.Replace(sql, @"\\\r?\n", string.Empty);

        // The pattern has exactly one capturing group, so Regex.Split returns strictly alternating
        // entries: even indexes hold batch text, odd indexes hold the captured "GO" / "GO n" separator.
        var batches = Regex.Split(sql, @"^\s*(GO[ \t]+[0-9]+|GO)(?:\s+|$)", RegexOptions.IgnoreCase | RegexOptions.Multiline);

        for (var i = 0; i < batches.Length; i += 2)
        {
            if (string.IsNullOrWhiteSpace(batches[i]))
                continue;

            var isLast = i == batches.Length - 1;

            var count = 1;
            if (!isLast)
            {
                // The separator that follows this batch may carry a repeat count ("GO 5").
                var match = Regex.Match(batches[i + 1], "([0-9]+)");
                if (match.Success)
                    count = int.Parse(match.Value);
            }

            // Only the final batch has no separator after it; terminate it with a line break.
            var commandText = isLast ? batches[i] + Environment.NewLine : batches[i];

            // Emit the batch once per requested execution instead of gluing the copies together,
            // which would turn "GO n" into a single, longer (and possibly invalid) batch.
            for (var j = 0; j < count; j++)
            {
                commands.Add(commandText);
            }
        }

        return commands;
    }

    /// <summary>An embedded script: its manifest resource name and its full text.</summary>
    private readonly struct SqlScript
    {
        public SqlScript(string name, string content)
            => (Name, Content) = (name, content);

        public string Name { get; }

        public string Content { get; }
    }

    /// <summary>Thrown when executing a migration script fails; the cause is the inner exception.</summary>
    [Serializable]
    private class SqlScriptMigrationException : Exception
    {
        public SqlScriptMigrationException()
        {
        }

        public SqlScriptMigrationException(string message) : base(message)
        {
        }

        public SqlScriptMigrationException(string message, Exception innerException) : base(message, innerException)
        {
        }

        // Fully qualified because the file has no using directive for System.Runtime.Serialization.
        protected SqlScriptMigrationException(
            System.Runtime.Serialization.SerializationInfo info,
            System.Runtime.Serialization.StreamingContext context) : base(info, context)
        {
        }
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment