Skip to content

Instantly share code, notes, and snippets.

@to11mtm
Created December 31, 2023 03:38
Show Gist options
  • Save to11mtm/3866f1641c9e68d025bb600d733d7a8c to your computer and use it in GitHub Desktop.
Save to11mtm/3866f1641c9e68d025bb600d733d7a8c to your computer and use it in GitHub Desktop.
WIP Akka Streams ValueTask SelectAsync stage with Pooling
using System;
using System.Collections.Concurrent;
using System.Runtime.CompilerServices;
using System.Threading;
using System.Threading.Tasks;
using System.Threading.Tasks.Sources;
using Akka.Annotations;
using Akka.Event;
using Akka.Streams.Stage;
using Akka.Streams.Supervision;
using Akka.Util;
namespace Akka.Streams.Implementation.Fusing;
// WIP USE AT OWN RISK
/// <summary>
/// INTERNAL API. Variant of the <c>SelectAsync</c> stage whose mapping function returns a
/// <see cref="ValueTask{TResult}"/>. At most <c>parallelism</c> mapped tasks are outstanding at once,
/// and results are emitted downstream in the order their inputs arrived from upstream.
/// The per-element holder objects are pooled and reused to reduce allocations.
/// </summary>
/// <remarks>Work in progress: use at your own risk.</remarks>
/// <typeparam name="TIn">Element type received from upstream.</typeparam>
/// <typeparam name="TOut">Element type produced by the mapping function.</typeparam>
[InternalApi]
public sealed class SelectValueTaskAsync<TIn, TOut> : GraphStage<FlowShape<TIn, TOut>>
{
    #region internal classes

    private sealed class Logic : InAndOutGraphStageLogic
    {
        /// <summary>
        /// Slot for one in-flight mapped element. Whichever thread completes the task writes the result
        /// into <see cref="Element"/> and then invokes the callback handed to the constructor
        /// (the stage's async callback). Every completion path below invokes it exactly once.
        /// </summary>
        private sealed class Holder<T>
        {
            private static readonly Action<object> OnCompletedAction = CompletionActionVt;

            private static readonly Action<Task<T>, object> TaskCompletedAction =
                (task, state) => ((Holder<T>)state).Invoke(FromCompletedTask(task));

            private readonly Action<Holder<T>> _callback;

            // The IValueTaskSource-backed ValueTask currently being awaited; cleared on completion
            // so a pooled holder does not keep the source alive.
            private ValueTask<T> _pending;

            public Holder(Result<T> element, Action<Holder<T>> callback)
            {
                _callback = callback;
                Element = element;
            }

            /// <summary>The (possibly not yet available) result for this slot.</summary>
            public Result<T> Element { get; private set; }

            /// <summary>
            /// Stage-thread-only flag, set once the stage has processed this holder's completion
            /// callback. Only delivered holders may be consumed and returned to the pool. This avoids
            /// two problems: recycling a holder whose callback is still queued, and reading an
            /// <see cref="Element"/> that another thread is still writing.
            /// </summary>
            public bool IsDelivered { get; set; }

            /// <summary>
            /// Stores <paramref name="result"/>. A successful <c>null</c> value is converted into the
            /// Reactive Streams "element must not be null" failure.
            /// </summary>
            public void SetElement(Result<T> result)
            {
                Element = result.IsSuccess && result.Value == null
                    ? Result.Failure<T>(ReactiveStreamsCompliance.ElementMustNotBeNullException)
                    : result;
            }

            /// <summary>Stores <paramref name="result"/> and notifies the stage.</summary>
            public void Invoke(Result<T> result)
            {
                SetElement(result);
                _callback(this);
            }

            /// <summary>
            /// Arranges for <see cref="Invoke"/> to be called exactly once when <paramref name="vt"/>
            /// completes, whether it succeeds, faults or is canceled.
            /// </summary>
            public void SetContinuation(ValueTask<T> vt)
            {
                if (vt.IsCompletedSuccessfully)
                {
                    Invoke(Result.Success(vt.Result));
                    return;
                }

                // Peek at ValueTask<T>'s private fields via the ValueTaskCheatingPeeker layout mirror,
                // so a continuation can be registered without allocating a closure or wrapper Task.
                // A ValueTask with no backing object is always IsCompletedSuccessfully (handled above).
                var peeker = Unsafe.As<ValueTask<T>, ValueTaskCheatingPeeker<T>>(ref vt);
                if (peeker._obj is Task<T> task)
                {
                    // No NotOnCanceled: a canceled task must still complete this slot, otherwise the
                    // stage waits forever. Explicit scheduler avoids inheriting TaskScheduler.Current.
                    task.ContinueWith(TaskCompletedAction, this, CancellationToken.None,
                        TaskContinuationOptions.ExecuteSynchronously, TaskScheduler.Default);
                }
                else
                {
                    _pending = vt;
                    var source = Unsafe.As<IValueTaskSource<T>>(peeker._obj);
                    source.OnCompleted(OnCompletedAction, this, peeker._token,
                        ValueTaskSourceOnCompletedFlags.None);
                }
            }

            /// <summary>Continuation for IValueTaskSource-backed tasks; always reports a result.</summary>
            private static void CompletionActionVt(object state)
            {
                var holder = (Holder<T>)state;
                var completed = holder._pending;
                holder._pending = default;

                Result<T> result;
                try
                {
                    // Reading Result forwards to IValueTaskSource.GetResult, which throws when the
                    // source faulted or was canceled; both are reported as failures.
                    result = Result.Success(completed.Result);
                }
                catch (Exception ex)
                {
                    result = Result.Failure<T>(ex);
                }

                holder.Invoke(result);
            }

            /// <summary>
            /// Converts a completed task into a <see cref="Result{T}"/>. Cancellation becomes a
            /// <see cref="TaskCanceledException"/>; a fault with a single inner exception is unwrapped.
            /// </summary>
            private static Result<T> FromCompletedTask(Task<T> task)
            {
                if (task.IsCanceled)
                    return Result.Failure<T>(new TaskCanceledException(task));

                if (task.IsFaulted)
                {
                    var aggregate = task.Exception;
                    Exception exception = aggregate?.InnerExceptions.Count == 1
                        ? aggregate.InnerExceptions[0]
                        : aggregate;
                    return Result.Failure<T>(exception);
                }

                return Result.Success(task.Result);
            }
        }

        private static readonly Result<TOut> NotYetThere = Result.Failure<TOut>(NotYetThereSentinel.Instance);

        private readonly SelectValueTaskAsync<TIn, TOut> _stage;
        private readonly Decider _decider;
        private readonly Action<Holder<TOut>> _taskCallback;

        // Holders available for reuse; rented in OnPush and returned in PushOne.
        private readonly ConcurrentQueue<Holder<TOut>> _pool = new();

        // In-flight holders in upstream order; created in PreStart with capacity = parallelism.
        private IBuffer<Holder<TOut>> _buffer;

        public Logic(Attributes inheritedAttributes, SelectValueTaskAsync<TIn, TOut> stage) : base(stage.Shape)
        {
            _stage = stage;
            var attr = inheritedAttributes.GetAttribute<ActorAttributes.SupervisionStrategy>(null);
            _decider = attr != null ? attr.Decider : Deciders.StoppingDecider;
            _taskCallback = GetAsyncCallback<Holder<TOut>>(HolderCompleted);
            SetHandlers(stage.In, stage.Out, this);
        }

        private int Todo => _buffer.Used;

        public override void PreStart() => _buffer = Buffer.Create<Holder<TOut>>(_stage._parallelism, Materializer);

        public override void OnPush()
        {
            var message = Grab(_stage.In);
            try
            {
                var task = _stage._mapFunc(message);
                var holder = Rent();
                _buffer.Enqueue(holder);

                // Already-completed tasks are handled inline, skipping a round trip through the
                // async callback.
                if (task.IsCompletedSuccessfully)
                {
                    holder.SetElement(Result.Success(task.Result));
                    HolderCompleted(holder);
                }
                else
                {
                    holder.SetContinuation(task);
                }
            }
            catch (Exception e)
            {
                var strategy = _decider(e);
                Log.Error(e, "An exception occured inside SelectAsync while processing message [{0}]. Supervision strategy: {1}", message, strategy);
                switch (strategy)
                {
                    case Directive.Stop:
                        FailStage(e);
                        break;
                    case Directive.Resume:
                    case Directive.Restart:
                        break;
                    default:
                        throw new AggregateException($"Unknown SupervisionStrategy directive: {strategy}", e);
                }
            }

            PullIfBelowParallelism();
        }

        public override void OnUpstreamFinish()
        {
            if (Todo == 0)
                CompleteStage();
        }

        public override void OnPull() => PushOne();

        /// <summary>Takes a holder from the pool, or creates one when the pool is empty.</summary>
        private Holder<TOut> Rent()
            => _pool.TryDequeue(out var holder) ? holder : new Holder<TOut>(NotYetThere, _taskCallback);

        /// <summary>
        /// Resets a consumed holder and returns it to the pool. Only call this once the holder's
        /// callback has been delivered, so no further callback for it can arrive.
        /// </summary>
        private void Recycle(Holder<TOut> holder)
        {
            // Reset so the pooled holder does not retain the last element.
            holder.SetElement(NotYetThere);
            holder.IsDelivered = false;
            _pool.Enqueue(holder);
        }

        /// <summary>Requests more input while fewer than <c>parallelism</c> elements are in flight.</summary>
        private void PullIfBelowParallelism()
        {
            if (Todo < _stage._parallelism && !HasBeenPulled(_stage.In))
                TryPull(_stage.In);
        }

        /// <summary>
        /// Emits the head of the buffer if its result has been delivered. Failed heads are dropped
        /// (the decider already ran for them in HolderCompleted). Completes the stage once upstream
        /// is closed and nothing is left in flight.
        /// </summary>
        private void PushOne()
        {
            var inlet = _stage.In;
            while (true)
            {
                if (_buffer.IsEmpty)
                {
                    if (IsClosed(inlet))
                        CompleteStage();
                    else if (!HasBeenPulled(inlet))
                        Pull(inlet);
                    return;
                }

                var head = _buffer.Peek();
                if (!head.IsDelivered)
                {
                    PullIfBelowParallelism();
                    return;
                }

                _buffer.Dequeue();
                var result = head.Element;
                Recycle(head);

                if (!result.IsSuccess)
                    continue;

                Push(_stage.Out, result.Value);
                PullIfBelowParallelism();
                return;
            }
        }

        /// <summary>
        /// Runs on the stage thread once per holder, after its task completes. It marks the holder
        /// as delivered, applies the supervision decider to failures, and tries to push downstream.
        /// </summary>
        private void HolderCompleted(Holder<TOut> holder)
        {
            var element = holder.Element;
            holder.IsDelivered = true;

            if (element.IsSuccess)
            {
                if (IsAvailable(_stage.Out))
                    PushOne();
                return;
            }

            var exception = element.Exception;
            var strategy = _decider(exception);
            Log.Error(exception, "An exception occured inside SelectAsync while executing Task. Supervision strategy: {0}", strategy);
            switch (strategy)
            {
                case Directive.Stop:
                    FailStage(exception);
                    break;
                case Directive.Resume:
                case Directive.Restart:
                    if (IsAvailable(_stage.Out))
                        PushOne();
                    break;
                default:
                    throw new AggregateException($"Unknown SupervisionStrategy directive: {strategy}", exception);
            }
        }

        public override string ToString() => $"SelectAsync.Logic(buffer={_buffer})";
    }

    #endregion

    private readonly int _parallelism;
    private readonly Func<TIn, ValueTask<TOut>> _mapFunc;

    /// <summary>Input port receiving the elements to map.</summary>
    public readonly Inlet<TIn> In = new("SelectAsync.in");

    /// <summary>Output port emitting mapped results in upstream order.</summary>
    public readonly Outlet<TOut> Out = new("SelectAsync.out");

    /// <summary>Creates the stage.</summary>
    /// <param name="parallelism">Maximum number of mapped tasks in flight at once; must be at least 1.</param>
    /// <param name="mapFunc">Asynchronous mapping function applied to each upstream element.</param>
    /// <exception cref="ArgumentOutOfRangeException"><paramref name="parallelism"/> is less than 1.</exception>
    /// <exception cref="ArgumentNullException"><paramref name="mapFunc"/> is <c>null</c>.</exception>
    public SelectValueTaskAsync(int parallelism, Func<TIn, ValueTask<TOut>> mapFunc)
    {
        if (parallelism < 1)
            throw new ArgumentOutOfRangeException(nameof(parallelism), parallelism, "Parallelism must be at least 1.");

        _parallelism = parallelism;
        _mapFunc = mapFunc ?? throw new ArgumentNullException(nameof(mapFunc));
        Shape = new FlowShape<TIn, TOut>(In, Out);
    }

    /// <summary>Default attributes: the stage is named <c>selectAsync</c>.</summary>
    protected override Attributes InitialAttributes { get; } = Attributes.CreateName("selectAsync");

    /// <summary>The single-input, single-output shape formed by <see cref="In"/> and <see cref="Out"/>.</summary>
    public override FlowShape<TIn, TOut> Shape { get; }

    /// <summary>Creates the stage logic for one materialization.</summary>
    /// <param name="inheritedAttributes">Attributes used to look up the supervision strategy.</param>
    /// <returns>A new logic instance.</returns>
    protected override GraphStageLogic CreateLogic(Attributes inheritedAttributes)
        => new Logic(inheritedAttributes, this);
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment