Skip to content

Instantly share code, notes, and snippets.

@phillip-haydon
Created June 9, 2020 09:21
Show Gist options
  • Save phillip-haydon/621a4977af7bf2f135e91778262ccbee to your computer and use it in GitHub Desktop.
Save phillip-haydon/621a4977af7bf2f135e91778262ccbee to your computer and use it in GitHub Desktop.
/// <summary>
/// Registration helpers that plug <see cref="PostgreSqlLockProvider"/> into Workflow Core.
/// </summary>
public static class PostgresServiceCollectionExtensions
{
    /// <summary>
    /// Registers <see cref="PostgreSqlLockProvider"/> as the distributed lock manager for these workflow options.
    /// </summary>
    /// <param name="options">The workflow options being configured.</param>
    /// <param name="connectionString">Npgsql connection string for the database that stores the locks.</param>
    /// <param name="schemaName">
    /// Schema that holds the lock table; defaults to <c>wfc</c>. The provider embeds it directly in SQL text,
    /// so it must come from trusted configuration.
    /// </param>
    /// <returns>The same <paramref name="options"/> instance, to allow call chaining.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="options"/> is null.</exception>
    public static WorkflowOptions UsePostgreSqlLocking(this WorkflowOptions options, string connectionString, string schemaName = "wfc")
    {
        if (options == null)
        {
            throw new ArgumentNullException(nameof(options));
        }

        // GetRequiredService fails fast with a descriptive error when logging is not registered,
        // instead of passing a null factory into the provider constructor.
        options.UseDistributedLockManager(sp => new PostgreSqlLockProvider(connectionString, schemaName, sp.GetRequiredService<ILoggerFactory>()));
        return options;
    }
}
/// <summary>
/// Workflow Core distributed lock provider that stores lock leases as rows
/// (<c>id</c>, <c>lock_owner</c>, <c>expires</c>) in <c>{schema}.workflow_lock_provider</c>.
/// A background heartbeat periodically extends the leases held by this node and deletes expired leases.
/// All lease timestamps are computed from the database server clock, as UTC wall-clock values.
/// </summary>
public class PostgreSqlLockProvider : IDistributedLockProvider
{
    // A fresh connection is opened from this string for every operation.
    private readonly string _connectionString;
    private readonly ILogger _logger;

    // Identifies this process as the owner of the leases it acquires (stored in lock_owner).
    private readonly Guid _nodeId;

    // Lease length in milliseconds. A lease that is not refreshed within this window is treated as free.
    private readonly long _ttl = 30000;

    // Delay in milliseconds between heartbeats. It is kept well below _ttl so leases are renewed in time.
    private readonly int _heartbeat = 10000;

    // Ids of the locks this node believes it holds. AcquireLock/ReleaseLock mutate the list and the
    // heartbeat loop reads it from another thread, so every access goes through _localLocksGate.
    private readonly List<string> _localLocks;
    private readonly object _localLocksGate = new object();

    private Task _heartbeatTask;
    private CancellationTokenSource _cancellationTokenSource;

    private string _acquireLockCommand = "";
    private string _releaseLockCommand = "";
    private string _checkTableCommand = "";
    private string _createTableCommand = "";
    private string _heartbeatCommand = "";

    /// <summary>
    /// Creates a lock provider that stores its leases in <c>{schemaName}.workflow_lock_provider</c>.
    /// </summary>
    /// <param name="connectionString">Npgsql connection string for the lock database.</param>
    /// <param name="schemaName">
    /// Schema holding the lock table. It is interpolated into SQL text because identifiers cannot be
    /// bound as parameters, so it must come from trusted configuration, never from user input.
    /// </param>
    /// <param name="logFactory">Factory used to create this provider's logger.</param>
    /// <exception cref="ArgumentNullException"><paramref name="connectionString"/> or <paramref name="logFactory"/> is null.</exception>
    /// <exception cref="ArgumentException"><paramref name="schemaName"/> is null, empty or whitespace.</exception>
    public PostgreSqlLockProvider(string connectionString, string schemaName, ILoggerFactory logFactory)
    {
        if (logFactory == null)
        {
            throw new ArgumentNullException(nameof(logFactory));
        }

        if (string.IsNullOrWhiteSpace(schemaName))
        {
            throw new ArgumentException("A schema name is required.", nameof(schemaName));
        }

        _connectionString = connectionString ?? throw new ArgumentNullException(nameof(connectionString));
        _logger = logFactory.CreateLogger<PostgreSqlLockProvider>();
        _localLocks = new List<string>();
        _nodeId = Guid.NewGuid();
        SetupCommands(schemaName);
    }

    // Lease length passed to PostgreSQL as an interval parameter.
    private TimeSpan LeaseDuration => TimeSpan.FromMilliseconds(_ttl);

    /// <summary>
    /// Tries to take the lock named <paramref name="id"/> for this node.
    /// </summary>
    /// <param name="id">Name of the lock.</param>
    /// <param name="cancellationToken">Cancels opening the connection and running the command.</param>
    /// <returns>
    /// <c>true</c> if a new lease row was inserted for this node. <c>false</c> if an unexpired lease for
    /// <paramref name="id"/> already exists (held by any node, including this one), or if a database error occurred.
    /// </returns>
    public async Task<bool> AcquireLock(string id, CancellationToken cancellationToken)
    {
        try
        {
            await using var conn = new NpgsqlConnection(_connectionString);
            await conn.OpenAsync(cancellationToken);
            await using var cmd = new NpgsqlCommand(_acquireLockCommand, conn);
            cmd.Parameters.AddWithValue("id", NpgsqlDbType.Text, id);
            cmd.Parameters.AddWithValue("lock_owner", NpgsqlDbType.Uuid, _nodeId);
            cmd.Parameters.AddWithValue("ttl", NpgsqlDbType.Interval, LeaseDuration);
            await cmd.PrepareAsync(cancellationToken);

            // The INSERT is the only statement in the batch that affects rows,
            // so 1 means this node now owns the lease.
            var result = await cmd.ExecuteNonQueryAsync(cancellationToken);
            if (result == 1)
            {
                lock (_localLocksGate)
                {
                    // Avoid duplicates; otherwise one ReleaseLock would leave a stale entry behind.
                    if (!_localLocks.Contains(id))
                    {
                        _localLocks.Add(id);
                    }
                }

                return true;
            }
        }
        catch (NpgsqlException exception)
        {
            _logger.LogError(exception, "Could not acquire lock");
        }

        return false;
    }

    /// <summary>
    /// Releases the lock named <paramref name="id"/> if it is held by this node.
    /// Database errors are logged, not thrown.
    /// </summary>
    /// <param name="id">Name of the lock.</param>
    public async Task ReleaseLock(string id)
    {
        // Stop heartbeating the lease before deleting it.
        lock (_localLocksGate)
        {
            _localLocks.Remove(id);
        }

        try
        {
            await using var conn = new NpgsqlConnection(_connectionString);
            await using var cmd = new NpgsqlCommand(_releaseLockCommand, conn);
            cmd.Parameters.AddWithValue("id", NpgsqlDbType.Text, id);
            // Only delete our own lease: if ours expired, another node may legitimately hold it now.
            cmd.Parameters.AddWithValue("lock_owner", NpgsqlDbType.Uuid, _nodeId);
            await conn.OpenAsync(default);
            await cmd.PrepareAsync(default);
            _ = await cmd.ExecuteNonQueryAsync(default);
        }
        catch (NpgsqlException exception)
        {
            _logger.LogError(exception, "Could not release lock");
        }
    }

    /// <summary>
    /// Ensures the lock table exists and starts the background heartbeat.
    /// </summary>
    /// <exception cref="InvalidOperationException">The provider is already started.</exception>
    public async Task Start()
    {
        if (_heartbeatTask != null)
        {
            throw new InvalidOperationException("The lock provider has already been started.");
        }

        await EnsureTable();

        _cancellationTokenSource = new CancellationTokenSource();
        var token = _cancellationTokenSource.Token;
        // Task.Run over an async Task (not async void) makes _heartbeatTask represent the whole loop,
        // so Stop can wait for it to finish.
        _heartbeatTask = Task.Run(() => SendHeartbeatAsync(token));
    }

    /// <summary>
    /// Stops the background heartbeat and waits for it to finish. Does nothing if the provider is not started.
    /// </summary>
    public async Task Stop()
    {
        if (_heartbeatTask == null)
        {
            return;
        }

        _cancellationTokenSource.Cancel();
        try
        {
            await _heartbeatTask;
        }
        finally
        {
            _cancellationTokenSource.Dispose();
            _cancellationTokenSource = null;
            _heartbeatTask = null;
        }
    }

    // Every _heartbeat ms: extends the leases this node holds and deletes expired leases, until cancelled.
    // Errors are logged and the loop continues; cancellation ends the loop quietly.
    private async Task SendHeartbeatAsync(CancellationToken cancellationToken)
    {
        while (!cancellationToken.IsCancellationRequested)
        {
            try
            {
                await Task.Delay(_heartbeat, cancellationToken);

                // Snapshot under the gate so the DB round-trip never blocks AcquireLock/ReleaseLock.
                string[] ids;
                lock (_localLocksGate)
                {
                    ids = _localLocks.ToArray();
                }

                await using var conn = new NpgsqlConnection(_connectionString);
                await using var cmd = new NpgsqlCommand(_heartbeatCommand, conn);
                cmd.Parameters.AddWithValue("ids", NpgsqlDbType.Array | NpgsqlDbType.Text, ids);
                cmd.Parameters.AddWithValue("lock_owner", NpgsqlDbType.Uuid, _nodeId);
                cmd.Parameters.AddWithValue("ttl", NpgsqlDbType.Interval, LeaseDuration);
                await conn.OpenAsync(cancellationToken);
                await cmd.PrepareAsync(cancellationToken);
                _ = await cmd.ExecuteNonQueryAsync(cancellationToken);
            }
            catch (OperationCanceledException) when (cancellationToken.IsCancellationRequested)
            {
                // Normal shutdown via Stop(); not an error.
                break;
            }
            catch (NpgsqlException exception)
            {
                _logger.LogError(exception, "Exception occured when sending heartbeat.");
            }
            catch (Exception ex)
            {
                _logger.LogError(ex, ex.Message);
            }
        }
    }

    // Creates the schema, table and indexes if the table is missing. Best-effort: failures are logged, not thrown.
    private async Task EnsureTable()
    {
        try
        {
            await using var conn = new NpgsqlConnection(_connectionString);
            await conn.OpenAsync(default);
            await using var checkCommand = new NpgsqlCommand(_checkTableCommand, conn);
            var exists = (bool)await checkCommand.ExecuteScalarAsync(default);
            if (!exists)
            {
                await using var createCommand = new NpgsqlCommand(_createTableCommand, conn);
                _ = await createCommand.ExecuteNonQueryAsync(default);
            }
        }
        catch (NpgsqlException exception)
        {
            _logger.LogError(exception, "Error occured when calling EnsureTable");
        }
    }

    // Builds the SQL text for the configured schema. Design notes:
    // - Expiry is always computed and compared with the server clock as UTC wall-clock time
    //   ("now() at time zone 'utc'"), matching the "timestamp without time zone" column. This is
    //   independent of the session TimeZone setting and of clock skew between nodes.
    // - Acquire serialises on a table lock because the table has no unique constraint on id.
    // - "create index if not exists" requires PostgreSQL 9.5+.
    private void SetupCommands(string schemaName)
    {
        // Insert a lease only if no unexpired lease exists for this id, whichever node owns it.
        _acquireLockCommand = $@"
begin;
lock table {schemaName}.workflow_lock_provider in access exclusive mode;
insert into {schemaName}.workflow_lock_provider (id, lock_owner, expires)
select @id, @lock_owner, (now() at time zone 'utc') + @ttl
where not exists (
    select from {schemaName}.workflow_lock_provider
    where id = @id
    and expires > (now() at time zone 'utc')
);
commit;
";

        _releaseLockCommand = $@"
delete from {schemaName}.workflow_lock_provider
where id = @id
and lock_owner = @lock_owner
";

        _checkTableCommand = $@"
select exists (
    select from pg_tables
    where schemaname = '{schemaName}'
    and tablename = 'workflow_lock_provider'
);
";

        // IF NOT EXISTS everywhere, so nodes starting concurrently do not fail on each other's DDL.
        _createTableCommand = $@"
create schema if not exists {schemaName};
create table if not exists {schemaName}.workflow_lock_provider (
    id text not null,
    lock_owner uuid not null,
    expires timestamp not null
);
create index if not exists ix_workflow_lock_provider_id on {schemaName}.workflow_lock_provider using hash(id);
create index if not exists ix_workflow_lock_provider_lock_owner on {schemaName}.workflow_lock_provider using hash(lock_owner);
create index if not exists ix_workflow_lock_provider_expires on {schemaName}.workflow_lock_provider using btree(expires);
";

        _heartbeatCommand = $@"
-- Update any expiring locks
update {schemaName}.workflow_lock_provider
set expires = (now() at time zone 'utc') + @ttl
where lock_owner = @lock_owner
and id = any(@ids);
-- Delete any expired locks
delete from {schemaName}.workflow_lock_provider
where expires < (now() at time zone 'utc');
";
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment