Skip to content

Instantly share code, notes, and snippets.

@mariano
Created March 12, 2021 18:32
Show Gist options
  • Save mariano/acd5fd20eee5adb9d5db52722b7a5e96 to your computer and use it in GitHub Desktop.
Save mariano/acd5fd20eee5adb9d5db52722b7a5e96 to your computer and use it in GitHub Desktop.
RedisQueue.cs
using Microsoft.Extensions.Logging;
using StackExchange.Redis;
using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Runtime.CompilerServices;
using System.Text.Json;
using System.Text.Json.Serialization;
using System.Threading;
using System.Threading.Tasks;
using Tactic.Actions.Exceptions;
using Tactic.Data.Entities.Calendar;
namespace Tactic.Event.Queue
{
/// <summary>
/// Redis-backed queue. Immediate jobs are kept in Redis lists (pushed on the left, popped on the
/// right). Delayed jobs are kept in a sorted set under the key <c>"{delayedQueueName}:"</c>, scored by
/// the UTC tick count at which they become due.
/// </summary>
/// <typeparam name="T">Base type of the queued items. The concrete item type is stored by full name
/// and, when a namespace root is configured, must belong to that namespace.</typeparam>
public sealed class RedisQueue<T> : IQueue<T> where T : class
{
    /// <summary>
    /// Envelope stored in Redis: the item's JSON plus the type name needed to deserialize it again.
    /// </summary>
    /// <remarks>
    /// The property declaration order defines the serialized JSON. Delete re-serializes an envelope and
    /// removes the byte-identical list element, so do not reorder or rename these properties.
    /// </remarks>
    private sealed record TypedQueueItem
    {
        [JsonPropertyName("queueName")] public string QueueName { get; init; }
        [JsonPropertyName("id")] public Guid Id { get; init; }
        [JsonPropertyName("type")] public string Type { get; init; }
        [JsonPropertyName("json")] public string Json { get; init; }

        /// <summary>Wraps <paramref name="queueItem"/> in an envelope addressed to <paramref name="queueName"/>.</summary>
        /// <exception cref="ArgumentNullException">The item is null or its type has no full name.</exception>
        public static TypedQueueItem CreateFromQueueItem<TI>(string queueName, QueueItem<TI> queueItem) where TI : class
        {
            if (queueItem.Item is null) {
                throw new ArgumentNullException(nameof(queueItem));
            }
            // The runtime type (not TI) is recorded, so subclasses of TI round-trip with all their data.
            Type itemType = queueItem.Item.GetType();
            if (itemType.FullName is null or { Length: 0 }) {
                throw new ArgumentNullException(nameof(queueItem));
            }
            return new(
                queueName,
                queueItem.Id,
                itemType.FullName,
                JsonSerializer.Serialize(queueItem.Item, itemType)
            );
        }

        public TypedQueueItem(string queueName, Guid id, string type, string json) =>
            (this.QueueName, this.Id, this.Type, this.Json) = (queueName, id, type, json);
    }

    private readonly ConnectionMultiplexer connectionFactory;
    private readonly string? namespaceRoot;
    private readonly ILogger<RedisQueue<T>> logger;
    // Keys of every queue this instance has touched; Delete only accepts names registered here.
    // Concurrent because the async methods below may run in parallel on the same instance.
    private readonly ConcurrentDictionary<string, RedisKey> queues = new();

    public RedisQueue(ConnectionMultiplexer connectionFactory, string namespaceRoot, ILogger<RedisQueue<T>> logger) =>
        (this.connectionFactory, this.namespaceRoot, this.logger) = (connectionFactory, namespaceRoot, logger);

    /// <summary>
    /// Pushes <paramref name="item"/> onto <paramref name="pendingQueueName"/>. When
    /// <paramref name="triggerOn"/> is set, the item is instead added to the delayed sorted set of
    /// <paramref name="delayedQueueName"/>. From there, DequeueDelayedAndEnqueue moves it to
    /// <paramref name="pendingQueueName"/> once it is due. Throws CouldNotScheduleException when
    /// Redis reports that the delayed entry was not added.
    /// </summary>
    /// <exception cref="ArgumentNullException"><paramref name="item"/> is null.</exception>
    /// <exception cref="ArgumentException">The item's type is outside the restricted namespace, or
    /// <paramref name="triggerOn"/> is not in the future.</exception>
    public async Task Enqueue(string pendingQueueName, string delayedQueueName, T item, DateTimeWithTimeZone? triggerOn, CancellationToken? _)
    {
        if (item is null) {
            throw new ArgumentNullException(nameof(item));
        } else if (!this.isValidNamespace(item.GetType())) {
            throw new ArgumentException(nameof(item) + " does not belong to the restricted namespace");
        } else if (triggerOn.HasValue && !triggerOn.Value.IsFuture) {
            throw new ArgumentException(nameof(item) + " is being scheduled for a moment not in the future");
        }

        RedisKey queue = this.getQueue(pendingQueueName);
        // The envelope always names the pending queue, even for delayed jobs: that is where they go when due.
        TypedQueueItem typedQueueItem = TypedQueueItem.CreateFromQueueItem<T>(pendingQueueName, new(item));
        string json = JsonSerializer.Serialize<TypedQueueItem>(typedQueueItem);
        IDatabase database = this.connectionFactory.GetDatabase();

        if (triggerOn is not null) {
            DateTime triggerOnDate = triggerOn.Value.UTC.DateTime;
            string triggerOnText = triggerOnDate.ToString(DateTimeWithTimeZone.FORMAT_ISO_8601_DATETIME_EXTENDED);
            this.logger.LogInformation("Enqueuing delayed job #{JobId} in queue {Queue} for {TriggerOn} ({Ticks} ticks)", typedQueueItem.Id, delayedQueueName, triggerOnText, triggerOnDate.Ticks);
            // The score is the UTC tick count; DequeueDelayedAndEnqueue compares it against DateTime.UtcNow.Ticks.
            if (!await database.SortedSetAddAsync($"{delayedQueueName}:", json, (double) triggerOnDate.Ticks)) {
                throw new CouldNotScheduleException(delayedQueueName, json);
            }
            this.logger.LogInformation("Enqueued delayed job #{JobId} in queue {Queue} for {TriggerOn} ({Ticks} ticks)", typedQueueItem.Id, delayedQueueName, triggerOnText, triggerOnDate.Ticks);
        } else {
            this.logger.LogInformation("Enqueuing job #{JobId} in queue {Queue}", typedQueueItem.Id, queue);
            await database.ListLeftPushAsync(queue, json);
            this.logger.LogInformation("Enqueued job #{JobId} in queue {Queue}", typedQueueItem.Id, queue);
        }
    }

    /// <summary>
    /// Atomically moves the oldest job from <paramref name="queueName"/> to
    /// <paramref name="destinationQueueName"/> (RPOPLPUSH) and returns it deserialized.
    /// </summary>
    /// <returns>The dequeued item, or <c>null</c> when the source queue is empty.</returns>
    /// <remarks>
    /// <paramref name="millisecondsTimeout"/> is not used: the pop does not block. The job is already
    /// in the destination queue when parsing happens, so a job that fails to parse stays there.
    /// </remarks>
    /// <exception cref="OperationCanceledException"><paramref name="cancellationToken"/> was cancelled.</exception>
    /// <exception cref="ArgumentException">The stored type is outside the restricted namespace or not assignable to <typeparamref name="T"/>.</exception>
    /// <exception cref="ArgumentNullException">The envelope, its type, or its payload could not be read.</exception>
    /// <exception cref="JsonException">The stored JSON is malformed.</exception>
    public async Task<QueueItem<T>?> DequeueAndMove(string queueName, string destinationQueueName, int millisecondsTimeout, CancellationToken cancellationToken)
    {
        cancellationToken.ThrowIfCancellationRequested();
        RedisKey queue = this.getQueue(queueName);
        RedisKey destinationQueue = this.getQueue(destinationQueueName);
        string? json = await this.connectionFactory.GetDatabase().ListRightPopLeftPushAsync(queue, destinationQueue);
        if (json is null or { Length: 0 }) {
            return null;
        }

        TypedQueueItem typedQueueItem = this.readTypedQueueItem(json);
        Type? itemType = resolveItemType(typedQueueItem.Type);
        if (itemType is null) {
            throw new ArgumentNullException(nameof(typedQueueItem.Type), $"Could not resolve item type {typedQueueItem.Type}");
        }
        // Refuse to instantiate anything that is not a T: the cast below would fail anyway, but only
        // after JsonSerializer had already constructed the foreign type.
        if (!typeof(T).IsAssignableFrom(itemType)) {
            throw new ArgumentException($"Got item of type {itemType.FullName} which is not assignable to {typeof(T).FullName}");
        }
        T? item = (T?) JsonSerializer.Deserialize(typedQueueItem.Json, itemType);
        if (item is null) {
            throw new ArgumentNullException(nameof(typedQueueItem.Json), $"Job #{typedQueueItem.Id} has a null payload");
        }

        QueueItem<T> queueItem = new(typedQueueItem.Id, item);
        this.logger.LogInformation("Dequeued job #{JobId} in queue {Queue}", typedQueueItem.Id, queue);
        return queueItem;
    }

    /// <summary>
    /// Moves at most one due job from the delayed sorted set of <paramref name="delayedQueueName"/>
    /// to the pending queue recorded in its envelope.
    /// </summary>
    /// <returns>
    /// The id of the due job that was found, or <c>null</c> when nothing is due. The id is also returned
    /// when another caller removed the entry first; that case is logged as an error.
    /// </returns>
    /// <remarks>
    /// <paramref name="millisecondsTimeout"/> is not used. An entry that cannot be read is removed
    /// from the sorted set (its payload is logged) and the error is rethrown. Left in place, it would be
    /// returned again on every call and block all jobs due after it.
    /// </remarks>
    /// <exception cref="OperationCanceledException"><paramref name="cancellationToken"/> was cancelled.</exception>
    /// <exception cref="ArgumentException">The entry is outside the restricted namespace or has no destination queue.</exception>
    /// <exception cref="JsonException">The stored JSON is malformed.</exception>
    public async Task<Guid?> DequeueDelayedAndEnqueue(string delayedQueueName, int millisecondsTimeout, CancellationToken cancellationToken)
    {
        cancellationToken.ThrowIfCancellationRequested();
        RedisKey queue = this.getQueue($"{delayedQueueName}:");
        long nowTicks = DateTime.UtcNow.Ticks;
        IDatabase database = this.connectionFactory.GetDatabase();
        // Fetch only the earliest-scored entry whose score is not after now.
        SortedSetEntry[] entries = await database.SortedSetRangeByScoreWithScoresAsync(queue, stop: (double) nowTicks, take: 1);
        if (entries is not { Length: > 0 }) {
            return null;
        }
        SortedSetEntry entry = entries[0];
        if (entry.Score > nowTicks) {
            // The first available entry is still in the future, so leave it there
            this.logger.LogInformation("Got possible delayed job in queue {Queue} but its ticks score is {Score} when we are currently on ticks {NowTicks}", delayedQueueName, entry.Score, nowTicks);
            return null;
        }

        RedisValue entryValue = entry.Element;
        string? json = entryValue;
        if (json is null or { Length: 0 }) {
            // An empty member would otherwise stay at the head of the set and hide every later job.
            this.logger.LogWarning("Discarding empty delayed job in queue {Queue}", delayedQueueName);
            await database.SortedSetRemoveAsync(queue, entryValue);
            return null;
        }

        TypedQueueItem typedQueueItem;
        try {
            typedQueueItem = this.readTypedQueueItem(json);
            // Validate before ZREM: pushing to a missing key after removal would lose the job.
            if (typedQueueItem.QueueName is null or { Length: 0 }) {
                throw new ArgumentException($"Delayed job #{typedQueueItem.Id} has no destination queue");
            }
        } catch (Exception exception) when (exception is JsonException or ArgumentException) {
            this.logger.LogError(exception, "Discarding unreadable delayed job in queue {Queue}: {Json}", delayedQueueName, json);
            await database.SortedSetRemoveAsync(queue, entryValue);
            throw;
        }

        this.logger.LogInformation("Dequeued delayed job #{JobId} from queue {Queue}", typedQueueItem.Id, delayedQueueName);
        RedisKey destinationQueue = typedQueueItem.QueueName;
        // ZREM runs atomically on the server and reports success to exactly one caller. So only the
        // worker that actually removed the entry (in this or any other process) pushes the job.
        // NOTE(review): ZREM and LPUSH are separate round-trips, so a failure between them loses the
        // job; a MULTI/EXEC transaction would close that gap.
        if (await database.SortedSetRemoveAsync(queue, entryValue)) {
            this.logger.LogInformation("Enqueuing delayed job #{JobId} for instant processing in queue {Queue}", typedQueueItem.Id, destinationQueue);
            await database.ListLeftPushAsync(destinationQueue, json);
            this.logger.LogInformation("Enqueued delayed job #{JobId} for instant processing in queue {Queue}", typedQueueItem.Id, destinationQueue);
        } else {
            this.logger.LogError("Could not remove delayed job #{JobId} in queue {Queue}: {Json}", typedQueueItem.Id, queue, json);
        }
        return typedQueueItem.Id;
    }

    /// <summary>
    /// Removes one list element of <paramref name="queueName"/> whose JSON equals a fresh
    /// serialization of <paramref name="queueItem"/> addressed to <paramref name="queueName"/>.
    /// </summary>
    /// <remarks>
    /// NOTE(review): stored envelopes carry the queue name they were originally enqueued to. Jobs moved
    /// by DequeueAndMove keep that name, so deleting them from a differently named queue will not
    /// match. Confirm how callers use this. A miss is logged as a warning.
    /// </remarks>
    /// <exception cref="ArgumentException"><paramref name="queueName"/> has not been used by this instance.</exception>
    /// <exception cref="OperationCanceledException"><paramref name="cancellationToken"/> was cancelled.</exception>
    public async Task Delete(string queueName, QueueItem<T> queueItem, CancellationToken cancellationToken)
    {
        if (!this.queues.TryGetValue(queueName, out RedisKey queue)) {
            throw new ArgumentException($"{nameof(queueName)} specifies an invalid queue name: {queueName}");
        }
        cancellationToken.ThrowIfCancellationRequested();
        TypedQueueItem typedQueueItem = TypedQueueItem.CreateFromQueueItem<T>(queueName, queueItem);
        string json = JsonSerializer.Serialize<TypedQueueItem>(typedQueueItem);
        long removed = await this.connectionFactory.GetDatabase().ListRemoveAsync(queue, json, count: 1);
        if (removed == 0) {
            this.logger.LogWarning("Could not find job #{JobId} in queue {Queue} to delete it", typedQueueItem.Id, queue);
        }
    }

    /// <summary>Deserializes an envelope and enforces the namespace allow-list on its type name.</summary>
    private TypedQueueItem readTypedQueueItem(string json)
    {
        TypedQueueItem? typedQueueItem = JsonSerializer.Deserialize<TypedQueueItem>(json);
        if (typedQueueItem is null) {
            throw new ArgumentNullException(nameof(json), $"Queue entry deserialized to null: {json}");
        }
        // Validate that typedQueueItem.Type belongs to the core namespace to prevent
        // code injection by hacking namespaces into Redis.
        if (!this.isValidNamespace(typedQueueItem.Type)) {
            throw new ArgumentException($"Got item with namespace {typedQueueItem.Type} which does not belong to the restricted namespace");
        }
        return typedQueueItem;
    }

    /// <summary>Resolves a stored type name, or returns <c>null</c> when no loaded assembly defines it.</summary>
    private static Type? resolveItemType(string typeName)
    {
        if (typeName is null or { Length: 0 }) {
            return null;
        }
        // Type.GetType only searches the calling assembly and the core library for names that are
        // not assembly-qualified (the envelope stores FullName). Fall back to T's own assembly and
        // then to every loaded assembly.
        Type? type = Type.GetType(typeName) ?? typeof(T).Assembly.GetType(typeName);
        if (type is not null) {
            return type;
        }
        foreach (System.Reflection.Assembly assembly in AppDomain.CurrentDomain.GetAssemblies()) {
            type = assembly.GetType(typeName);
            if (type is not null) {
                return type;
            }
        }
        return null;
    }

    /// <summary>Returns the cached key for <paramref name="queueName"/>, registering it on first use.</summary>
    [MethodImpl(MethodImplOptions.AggressiveInlining)]
    private RedisKey getQueue(string queueName) => this.queues.GetOrAdd(queueName, name => new RedisKey(name));

    [MethodImpl(MethodImplOptions.AggressiveInlining)]
    private bool isValidNamespace(Type type) => type.FullName is not null && this.isValidNamespace(type.FullName);

    /// <summary>
    /// True when no namespace root is configured, or when <paramref name="namespace"/> is non-empty and
    /// starts with it. The comparison is ordinal because type names are identifiers, not text.
    /// </summary>
    [MethodImpl(MethodImplOptions.AggressiveInlining)]
    private bool isValidNamespace(string @namespace) =>
        (this.namespaceRoot is null or { Length: 0 })
        || (@namespace is { Length: > 0 } && @namespace.StartsWith(this.namespaceRoot, StringComparison.Ordinal));
}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment