Skip to content

Instantly share code, notes, and snippets.

@mgravell
Last active December 3, 2019 18:26
Show Gist options
Save mgravell/c12d3749cdff9090093874ff27c51a5b to your computer and use it in GitHub Desktop.
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Diagnostics;
using System.IO.Pipelines;
using System.Net.Sockets;
using System.Runtime.CompilerServices;
using System.Threading;
using System.Threading.Tasks;
namespace Microsoft.AspNetCore.Server.Kestrel.Transport.Sockets.Internal
{
/// <summary>
/// A <see cref="SocketAsyncEventArgs"/> that can be awaited directly: awaiting it yields the
/// number of bytes transferred, or throws a <see cref="SocketException"/> when the operation
/// reported an error.
/// </summary>
public class SocketAwaitableEventArgs : SocketAsyncEventArgs, ICriticalNotifyCompletion
{
    // Sentinel stored in _callback once the operation has been marked complete.
    private static readonly Action _callbackCompleted = () => { };

    // Adapter that lets a continuation travel through PipeScheduler.Schedule as its state argument.
    private static readonly Action<object> _runContinuation = boxed => ((Action)boxed)();

    private readonly PipeScheduler _ioScheduler;

    // null: nothing registered; the sentinel: completed; anything else: the pending continuation.
    private Action _callback;

    /// <summary>
    /// Creates an instance whose continuations are dispatched through <paramref name="ioScheduler"/>.
    /// </summary>
    /// <param name="ioScheduler">Scheduler used to run the continuation when the operation completes.</param>
    public SocketAwaitableEventArgs(PipeScheduler ioScheduler)
    {
        _ioScheduler = ioScheduler;
    }

    /// <summary>Returns this instance, which acts as its own awaiter.</summary>
    public SocketAwaitableEventArgs GetAwaiter()
    {
        return this;
    }

    /// <summary>True once the operation has been marked complete.</summary>
    public bool IsCompleted
    {
        get { return ReferenceEquals(_callback, _callbackCompleted); }
    }

    /// <summary>
    /// Resets the instance for the next operation and returns the bytes transferred.
    /// </summary>
    /// <returns>The value of <see cref="SocketAsyncEventArgs.BytesTransferred"/>.</returns>
    /// <exception cref="SocketException">The operation finished with a socket error.</exception>
    public int GetResult()
    {
        Debug.Assert(IsCompleted);

        // Clear the sentinel so the same instance can be awaited again.
        _callback = null;

        var status = SocketError;
        if (status == SocketError.Success)
        {
            return BytesTransferred;
        }

        throw new SocketException((int)status);
    }

    /// <summary>
    /// Registers <paramref name="continuation"/> to run on completion; if the operation has
    /// already completed the continuation is queued via <see cref="Task.Run(Action)"/> instead.
    /// </summary>
    public void OnCompleted(Action continuation)
    {
        var alreadyDone = IsCompleted;
        if (!alreadyDone)
        {
            // Install the continuation only if nothing is registered yet; finding the sentinel
            // means completion won the race.
            var previous = Interlocked.CompareExchange(ref _callback, continuation, null);
            alreadyDone = ReferenceEquals(previous, _callbackCompleted);
        }

        if (alreadyDone)
        {
            Task.Run(continuation);
        }
    }

    /// <summary>Same as <see cref="OnCompleted(Action)"/>.</summary>
    public void UnsafeOnCompleted(Action continuation) => OnCompleted(continuation);

    /// <summary>
    /// Marks the operation complete; used by callers when the socket call returned false
    /// (i.e. finished synchronously).
    /// </summary>
    public void Complete() => OnCompleted(this);

    /// <summary>
    /// Swaps in the completed sentinel and, when a callback had been stored, hands it to the scheduler.
    /// </summary>
    protected override void OnCompleted(SocketAsyncEventArgs _)
    {
        var pending = Interlocked.Exchange(ref _callback, _callbackCompleted);
        if (pending == null)
        {
            // Nobody is awaiting yet; the awaiter will observe IsCompleted later.
            return;
        }

        _ioScheduler.Schedule(_runContinuation, pending);
    }
}
}
using System;
using System.Diagnostics;
using System.IO.Pipelines;
using System.Net.Sockets;
using System.Runtime.CompilerServices;
using System.Threading;
namespace Pipelines.Sockets.Unofficial
{
/// <summary>
/// An awaitable reusable token that is compatible with SocketAsyncEventArgs usage
/// </summary>
/// <remarks>
/// The result (bytes transferred and socket error) is stored on this token and published by
/// swapping a "completed" sentinel into the callback slot. The result is always written before
/// the sentinel becomes visible, so an awaiter that observes completion reads the matching result.
/// </remarks>
public sealed class SocketAwaitable : ICriticalNotifyCompletion
{
    private static void NoOp() { }

    // Sentinel marking completion; get better debug info by avoiding inline delegate here
    private static readonly Action _callbackCompleted = NoOp;

    // null: nothing registered; the sentinel: completed; anything else: the pending continuation.
    private Action _callback;
    private int _bytesTransferred;
    private SocketError _error;

    // null means "invoke continuations inline on the completing thread".
    private readonly PipeScheduler _scheduler;

    /// <summary>
    /// Create a new SocketAwaitable instance, optionally providing a callback scheduler
    /// </summary>
    /// <param name="scheduler">
    /// Scheduler used to run the continuation on completion; <c>null</c> or
    /// <see cref="PipeScheduler.Inline"/> runs it directly on the completing thread.
    /// </param>
    public SocketAwaitable(PipeScheduler scheduler = null) => _scheduler =
        ReferenceEquals(scheduler, PipeScheduler.Inline) ? null : scheduler;

    /// <summary>
    /// Gets an awaiter that represents the pending operation
    /// </summary>
    public SocketAwaitable GetAwaiter() => this;

    /// <summary>
    /// Indicates whether the pending operation is complete
    /// </summary>
    // Volatile read: the result fields read afterwards in GetResult must not be observed
    // from before the completion was published.
    public bool IsCompleted => ReferenceEquals(Volatile.Read(ref _callback), _callbackCompleted);

    /// <summary>
    /// Gets the result of the pending operation
    /// </summary>
    /// <returns>The number of bytes transferred by the completed operation.</returns>
    /// <exception cref="SocketException">The operation completed with a socket error.</exception>
    public int GetResult()
    {
        Debug.Assert(ReferenceEquals(_callback, _callbackCompleted));

        // Clear the sentinel so the token can be reused for the next operation.
        _callback = null;
        if (_error != SocketError.Success)
        {
            throw new SocketException((int)_error);
        }
        return _bytesTransferred;
    }

    /// <summary>
    /// Schedule a callback to be invoked when the operation completes
    /// </summary>
    public void OnCompleted(Action continuation)
    {
        if (ReferenceEquals(Volatile.Read(ref _callback), _callbackCompleted)
            || ReferenceEquals(Interlocked.CompareExchange(ref _callback, continuation, null), _callbackCompleted))
        {
            continuation(); // sync completion; don't use scheduler
        }
    }

    /// <summary>
    /// Schedule a callback to be invoked when the operation completes
    /// </summary>
    public void UnsafeOnCompleted(Action continuation) => OnCompleted(continuation);

    /// <summary>
    /// Provides a callback suitable for use with SocketAsyncEventArgs where the UserToken is a SocketAwaitable
    /// </summary>
    public static EventHandler<SocketAsyncEventArgs> Callback = (sender, args) => ((SocketAwaitable)args.UserToken).TryComplete(args.BytesTransferred, args.SocketError);

    /// <summary>
    /// Mark the pending operation as complete by reading the state of a SocketAsyncEventArgs instance where the UserToken is a SocketAwaitable
    /// </summary>
    [MethodImpl(MethodImplOptions.AggressiveInlining)]
    public static void OnCompleted(SocketAsyncEventArgs args)
        => ((SocketAwaitable)args.UserToken).TryComplete(args.BytesTransferred, args.SocketError);

    /// <summary>
    /// Mark the pending operation as complete by providing the state explicitly
    /// </summary>
    /// <param name="bytesTransferred">Byte count to report from <see cref="GetResult"/>.</param>
    /// <param name="socketError">Outcome to report; anything but Success makes <see cref="GetResult"/> throw.</param>
    /// <returns><c>true</c> if this call completed the operation; <c>false</c> if it was already complete.</returns>
    public bool TryComplete(int bytesTransferred, SocketError socketError)
    {
        // Already completed: leave the published result untouched.
        if (ReferenceEquals(Volatile.Read(ref _callback), _callbackCompleted))
        {
            return false;
        }

        // Store the result BEFORE publishing the sentinel. Otherwise an awaiter on another
        // thread can see IsCompleted == true and read stale values in GetResult.
        // NOTE(review): two truly concurrent TryComplete calls for one operation can still
        // overwrite each other's result here; one completion per operation is assumed.
        _error = socketError;
        _bytesTransferred = bytesTransferred;

        // Full-fence exchange: publishes the result and claims any registered continuation.
        var action = Interlocked.Exchange(ref _callback, _callbackCompleted);
        if (ReferenceEquals(action, _callbackCompleted))
        {
            // Lost a race with another completion.
            return false;
        }

        if (action == null)
        {
            // No continuation registered yet; the awaiter will observe IsCompleted.
            //Helpers.Incr(Counter.SocketAwaitableCallbackNone);
        }
        else if (_scheduler == null)
        {
            //Helpers.Incr(Counter.SocketAwaitableCallbackDirect);
            action();
        }
        else
        {
            //Helpers.Incr(Counter.SocketAwaitableCallbackSchedule);
            _scheduler.Schedule(InvokeStateAsAction, action);
        }
        return true;
    }

    private static void InvokeStateAsActionImpl(object state) => ((Action)state).Invoke();

    // Cached delegate so scheduling a continuation does not allocate a new adapter each time.
    internal static readonly Action<object> InvokeStateAsAction = InvokeStateAsActionImpl;
}
}
// #define GOOD
using Microsoft.AspNetCore.Server.Kestrel.Transport.Sockets.Internal;
using Pipelines.Sockets.Unofficial;
using System;
using System.Buffers;
using System.IO.Pipelines;
using System.Net;
using System.Net.Sockets;
using System.Threading;
using System.Threading.Tasks;
namespace sockettest
{
/// <summary>
/// Loopback harness: a client task sends random-sized buffers to a server task over TCP port 9005
/// while <see cref="Main"/> prints the running byte totals once a second. Define GOOD to drive the
/// sockets through SocketAwaitableEventArgs instead of SocketAwaitable.
/// </summary>
static class Program
{
    // Running byte totals; updated by the send/receive loops, read by the reporting loop in Main.
    static long sent, received;

    // Completed once the server socket is listening, so the client cannot connect too early.
    static readonly TaskCompletionSource<bool> listenerReady =
        new TaskCompletionSource<bool>(TaskCreationOptions.RunContinuationsAsynchronously);

    /// <summary>
    /// Starts both workers and reports throughput every second until one of them ends;
    /// a worker failure is rethrown instead of being silently dropped.
    /// </summary>
    static void Main()
    {
        var server = Task.Run(Server);
        var client = Task.Run(Client);
        var workers = new[] { server, client };

        int finished;
        // WaitAny returns -1 on timeout, so this doubles as the one-second reporting interval.
        while ((finished = Task.WaitAny(workers, 1000)) < 0)
        {
            Console.WriteLine($"sent: {Interlocked.Read(ref sent):###,###,###,##0}, received: {Interlocked.Read(ref received):###,###,###,##0}");
        }

        // Observe the finished worker so any exception it hit surfaces with its original type.
        workers[finished].GetAwaiter().GetResult();
    }

    /// <summary>Connects to the loopback listener and runs the send loop (seed 312).</summary>
    static async Task Client()
    {
        using (var client = new Socket(AddressFamily.InterNetwork,
            SocketType.Stream, ProtocolType.Tcp))
        {
            // Wait until the server has bound and is listening before connecting.
            await listenerReady.Task;
            Console.WriteLine("connecting...");
            client.Connect(new IPEndPoint(IPAddress.Loopback, 9005));
            Console.WriteLine("connected");
            await SendAsync(client, 312);
        }
    }

    /// <summary>Listens on loopback port 9005, accepts one connection and runs the receive loop (seed 112).</summary>
    static async Task Server()
    {
        using (var listener = new Socket(AddressFamily.InterNetwork,
            SocketType.Stream, ProtocolType.Tcp))
        {
            listener.Bind(new IPEndPoint(IPAddress.Loopback, 9005));
            listener.Listen(5);
            Console.WriteLine("listening");
            listenerReady.TrySetResult(true);
            await Task.Yield();
            Console.WriteLine("accepting...");
            using (var server = listener.Accept())
            {
                Console.WriteLine("accepted");
                await ReceiveAsync(server, 112);
            }
        }
    }

    /// <summary>
    /// Receives into pooled buffers of random size (100-299 bytes) until the peer closes the
    /// connection, adding each byte count to <c>received</c> and logging any mismatch between
    /// the awaited result and <c>args.BytesTransferred</c>.
    /// </summary>
    /// <param name="socket">Connected socket to read from.</param>
    /// <param name="seed">Seed for the buffer-size random sequence.</param>
    private static async Task ReceiveAsync(Socket socket, int seed)
    {
        var rand = new Random(seed);
#if GOOD
        var args = new SocketAwaitableEventArgs(PipeScheduler.ThreadPool);
#else
        var awaitable = new SocketAwaitable(PipeScheduler.ThreadPool);
        var args = new SocketAsyncEventArgs();
        args.UserToken = awaitable;
        args.Completed += SocketAwaitable.Callback;
#endif
        using (args)
        {
            while (true)
            {
                var size = rand.Next(100, 300);
                var buffer = ArrayPool<byte>.Shared.Rent(size);
                int bytes;
                try
                {
                    args.SetBuffer(buffer, 0, size);
#if GOOD
                    if (!socket.ReceiveAsync(args)) args.Complete();
                    bytes = await args;
#else
                    // ReceiveAsync returning false means it finished synchronously and the
                    // Completed event will not fire, so complete the awaitable by hand.
                    if (!socket.ReceiveAsync(args)) SocketAwaitable.OnCompleted(args);
                    bytes = await awaitable;
#endif
                    if (bytes != args.BytesTransferred)
                    {
                        Console.WriteLine($"\trecv: {size} vs {bytes} vs {args.BytesTransferred}");
                    }
                }
                finally
                {
                    // Return the rented buffer even when the receive throws.
                    ArrayPool<byte>.Shared.Return(buffer);
                }

                if (bytes == 0)
                {
                    // Zero bytes on a stream socket: the peer closed the connection.
                    return;
                }

                Interlocked.Add(ref received, bytes);
                await Task.Yield();
            }
        }
    }

    /// <summary>
    /// Sends pooled buffers of random size (100-299 bytes) filled with random data in an endless
    /// loop, adding each byte count to <c>sent</c> and logging any mismatch between the requested
    /// size, the awaited result and <c>args.BytesTransferred</c>.
    /// </summary>
    /// <param name="socket">Connected socket to write to.</param>
    /// <param name="seed">Seed for the size/content random sequence.</param>
    static async Task SendAsync(Socket socket, int seed)
    {
        var rand = new Random(seed);
#if GOOD
        var args = new SocketAwaitableEventArgs(PipeScheduler.Inline);
#else
        var awaitable = new SocketAwaitable(PipeScheduler.ThreadPool);
        var args = new SocketAsyncEventArgs();
        args.UserToken = awaitable;
        args.Completed += SocketAwaitable.Callback;
#endif
        using (args)
        {
            while (true)
            {
                var size = rand.Next(100, 300);
                var buffer = ArrayPool<byte>.Shared.Rent(size);
                try
                {
                    rand.NextBytes(buffer);
                    args.SetBuffer(buffer, 0, size);
#if GOOD
                    if (!socket.SendAsync(args)) args.Complete();
                    int bytes = await args;
#else
                    // SendAsync returning false means it finished synchronously and the
                    // Completed event will not fire, so complete the awaitable by hand.
                    if (!socket.SendAsync(args)) SocketAwaitable.OnCompleted(args);
                    var bytes = await awaitable;
#endif
                    if (bytes != args.BytesTransferred || bytes != size)
                    {
                        Console.WriteLine($"\tsend: {size} vs {bytes} vs {args.BytesTransferred}");
                    }
                    Interlocked.Add(ref sent, bytes);
                }
                finally
                {
                    // Return the rented buffer even when the send throws.
                    ArrayPool<byte>.Shared.Return(buffer);
                }
                await Task.Yield();
            }
        }
    }
}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment