Skip to content

Instantly share code, notes, and snippets.

@TrueGeek
Created December 23, 2019 02:10
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save TrueGeek/960641d12e42d477bf8f35e7365a46de to your computer and use it in GitHub Desktop.
Save TrueGeek/960641d12e42d477bf8f35e7365a46de to your computer and use it in GitHub Desktop.
Custom TlsHandler for DotNetty, needed when using Matrix vNext with Xamarin (Mono). Adapted from https://github.com/Azure/DotNetty/pull/374.
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
namespace CustomMonoTlsHandler
{
using System.Security.Authentication;
using System.Security.Cryptography.X509Certificates;
using DotNetty.Handlers.Tls;
/// <summary>
/// Server-side TLS configuration: the certificate presented to connecting clients and whether a
/// client certificate is requested, on top of the protocol/revocation options of <see cref="TlsSettings"/>.
/// </summary>
public sealed class ServerTlsSettings : TlsSettings
{
    // Protocol set used by every overload that does not take an explicit SslProtocols value.
    const SslProtocols DefaultEnabledProtocols = SslProtocols.Tls | SslProtocols.Tls11 | SslProtocols.Tls12;

    /// <summary>Creates settings that neither request a client certificate nor check revocation.</summary>
    public ServerTlsSettings(X509Certificate certificate)
        : this(certificate, negotiateClientCertificate: false)
    {
    }

    /// <summary>Creates settings that do not check certificate revocation.</summary>
    public ServerTlsSettings(X509Certificate certificate, bool negotiateClientCertificate)
        : this(certificate, negotiateClientCertificate, checkCertificateRevocation: false)
    {
    }

    /// <summary>Creates settings that enable TLS 1.0, 1.1 and 1.2.</summary>
    public ServerTlsSettings(X509Certificate certificate, bool negotiateClientCertificate, bool checkCertificateRevocation)
        : this(certificate, negotiateClientCertificate, checkCertificateRevocation, DefaultEnabledProtocols)
    {
    }

    /// <summary>Creates settings with every option specified explicitly.</summary>
    /// <param name="certificate">Certificate used to authenticate the server.</param>
    /// <param name="negotiateClientCertificate">Whether the client is asked for a certificate.</param>
    /// <param name="checkCertificateRevocation">Whether revocation is checked during authentication.</param>
    /// <param name="enabledProtocols">Protocols allowed for the handshake.</param>
    public ServerTlsSettings(X509Certificate certificate, bool negotiateClientCertificate, bool checkCertificateRevocation, SslProtocols enabledProtocols)
        : base(enabledProtocols, checkCertificateRevocation)
    {
        this.NegotiateClientCertificate = negotiateClientCertificate;
        this.Certificate = certificate;
    }

    /// <summary>Certificate presented to clients.</summary>
    public X509Certificate Certificate { get; }

    /// <summary>Whether a client certificate is requested during the handshake.</summary>
    public bool NegotiateClientCertificate { get; }
}
}
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
// from: https://github.com/Azure/DotNetty/pull/374
namespace CustomMonoTlsHandler
{
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Diagnostics.Contracts;
using System.IO;
using System.Net.Security;
using System.Runtime.ExceptionServices;
using System.Security.Cryptography.X509Certificates;
using System.Threading;
using System.Threading.Tasks;
using DotNetty.Buffers;
using DotNetty.Codecs;
using DotNetty.Common.Concurrency;
using DotNetty.Common.Utilities;
using DotNetty.Handlers.Tls;
using DotNetty.Transport.Channels;
public sealed class TlsHandler : ByteToMessageDecoder
{
// TLS configuration; its runtime type decides server vs. client mode (see IsServer).
readonly TlsSettings settings;
// Read-buffer size used when the estimated size of the next SslStream read is not positive.
const int FallbackReadBufferSize = 256;
// Batch size handed to BatchingPendingWriteQueue for pending plaintext writes.
const int UnencryptedWriteBatchSize = 14 * 1024;
// Shared cause passed to HandleFailure when the channel goes inactive.
static readonly Exception ChannelClosedException = new IOException("Channel is closed");
// Delegate instances reused for every ContinueWith on handshake and SslStream read tasks.
static readonly Action<Task, object> HandshakeCompletionCallback = new Action<Task, object>(HandleHandshakeCompleted);
static readonly Action<Task<int>, object> UnwrapCompletedCallback = new Action<Task<int>, object>(UnwrapCompleted);
// SslStream layered over mediationStream (built by the factory passed to the constructor).
readonly SslStream sslStream;
// Adapter that feeds received ciphertext to sslStream and routes what sslStream writes back to the channel.
readonly MediationStream mediationStream;
// Completed by CloseAsync; IgnoreException uses it to suppress ObjectDisposedException after close.
readonly TaskCompletionSource closeFuture;
// Handshake lifecycle and deferred-work flags.
TlsHandlerState state;
// Length of a TLS record whose header was seen in Decode but whose body has not fully arrived; 0 if none.
int packetLength;
// Context captured in HandlerAdded. Volatile because FinishWrap may run off the event loop (see its comment).
volatile IChannelHandlerContext capturedContext;
// Plaintext writes waiting to be encrypted by Wrap.
BatchingPendingWriteQueue pendingUnencryptedWrites;
// Last channel write issued by FinishWrap; Wrap links its outcome to the matching write promise.
Task lastContextWriteTask;
// Whether Decode produced output since the last ChannelReadComplete (consulted by ReadIfNeeded).
bool firedChannelRead;
// Controls whether FinishWrap flushes each write; volatile for the same cross-thread reason as capturedContext.
volatile FlushMode flushMode = FlushMode.ForceFlush;
// State of an SslStream read that had not completed when Unwrap returned; picked up again by the
// next Unwrap call or by UnwrapCompleted.
IByteBuffer pendingSslStreamReadBuffer;
int pendingSslStreamReadLength;
Task<int> pendingSslStreamReadFuture;
/// <summary>
/// Creates a handler whose <see cref="SslStream"/> wraps the internal mediation stream and leaves
/// that stream open when the SslStream itself is disposed.
/// </summary>
/// <param name="settings">Client or server TLS settings.</param>
public TlsHandler(TlsSettings settings)
    : this(innerStream => new SslStream(innerStream, leaveInnerStreamOpen: true), settings)
{
}
/// <summary>
/// Creates a handler that lets <paramref name="sslStreamFactory"/> build the <see cref="SslStream"/>
/// layered over the handler's internal mediation stream.
/// </summary>
/// <param name="sslStreamFactory">Builds the SslStream from the mediation stream it is given.</param>
/// <param name="settings">Client or server TLS settings; the runtime type selects the mode.</param>
/// <exception cref="ArgumentNullException">Either argument is null.</exception>
public TlsHandler(Func<Stream, SslStream> sslStreamFactory, TlsSettings settings)
{
    // Contract.Requires is compiled out unless CONTRACTS_FULL is defined, so validate explicitly.
    if (sslStreamFactory == null)
    {
        throw new ArgumentNullException(nameof(sslStreamFactory));
    }
    if (settings == null)
    {
        throw new ArgumentNullException(nameof(settings));
    }

    this.settings = settings;
    this.closeFuture = new TaskCompletionSource();
    this.mediationStream = new MediationStream(this);
    this.sslStream = sslStreamFactory(this.mediationStream);
}
/// <summary>Creates a client-mode handler that authenticates against <paramref name="targetHost"/>.</summary>
public static TlsHandler Client(string targetHost)
{
    var clientSettings = new ClientTlsSettings(targetHost);
    return new TlsHandler(clientSettings);
}

/// <summary>
/// Creates a client-mode handler for <paramref name="targetHost"/> whose settings carry
/// <paramref name="clientCertificate"/> as the single configured certificate.
/// </summary>
public static TlsHandler Client(string targetHost, X509Certificate clientCertificate)
{
    var certificates = new List<X509Certificate>(1);
    certificates.Add(clientCertificate);
    return new TlsHandler(new ClientTlsSettings(targetHost, certificates));
}

/// <summary>Creates a server-mode handler that authenticates with <paramref name="certificate"/>.</summary>
public static TlsHandler Server(X509Certificate certificate)
{
    var serverSettings = new ServerTlsSettings(certificate);
    return new TlsHandler(serverSettings);
}
/// <summary>
/// Certificate used to authenticate the local endpoint, as an <see cref="X509Certificate2"/>,
/// or <c>null</c> when the SslStream has none.
/// </summary>
public X509Certificate2 LocalCertificate => ToCertificate2(this.sslStream.LocalCertificate);

/// <summary>
/// Certificate used to authenticate the remote endpoint, as an <see cref="X509Certificate2"/>,
/// or <c>null</c> when the SslStream has none.
/// </summary>
public X509Certificate2 RemoteCertificate => ToCertificate2(this.sslStream.RemoteCertificate);

// using workaround mentioned here: https://github.com/dotnet/corefx/issues/4510
// SslStream may expose a plain X509Certificate; re-import it from its DER encoding when needed.
// Returns null rather than throwing (new X509Certificate2(null) throws) when there is no certificate.
static X509Certificate2 ToCertificate2(X509Certificate certificate)
{
    if (certificate == null)
    {
        return null;
    }
    return certificate as X509Certificate2 ?? new X509Certificate2(certificate.Export(X509ContentType.Cert));
}
// Server mode is selected purely by the runtime type of the settings object.
bool IsServer
{
    get { return this.settings is ServerTlsSettings; }
}
/// <summary>Disposes the underlying <see cref="SslStream"/>, if one was created.</summary>
public void Dispose()
{
    SslStream stream = this.sslStream;
    if (stream != null)
    {
        stream.Dispose();
    }
}
/// <summary>
/// Propagates the event and, in client mode, starts the handshake right away. In server mode the
/// handshake is started from Unwrap once inbound records arrive.
/// </summary>
public override void ChannelActive(IChannelHandlerContext context)
{
    base.ChannelActive(context);
    if (this.IsServer)
    {
        return;
    }
    this.EnsureAuthenticated();
}
/// <summary>
/// Fails the handler with the shared "Channel is closed" error before propagating the event.
/// HandleFailure releases the SslStream and fails a still-running handshake and any queued writes.
/// </summary>
public override void ChannelInactive(IChannelHandlerContext context)
{
    Exception cause = ChannelClosedException;
    this.HandleFailure(cause);
    base.ChannelInactive(context);
}
/// <summary>
/// Propagates unexpected exceptions. Exceptions accepted by IgnoreException are swallowed, but the
/// channel is still closed if it is active.
/// </summary>
public override void ExceptionCaught(IChannelHandlerContext context, Exception exception)
{
    if (!this.IgnoreException(exception))
    {
        base.ExceptionCaught(context, exception);
        return;
    }

    // Close the connection explicitly just in case the transport did not close it automatically.
    if (context.Channel.Active)
    {
        context.CloseAsync();
    }
}
// An ObjectDisposedException after CloseAsync has run is expected: CloseAsync completes closeFuture
// and disposes the SslStream.
bool IgnoreException(Exception t) => t is ObjectDisposedException && this.closeFuture.Task.IsCompleted;
// Continuation attached to AuthenticateAsServerAsync/AuthenticateAsClientAsync in EnsureAuthenticated;
// `state` is the owning TlsHandler.
// On success: marks the handler Authenticated, fires TlsHandshakeCompletionEvent.Success, re-issues a read
// that was requested before the handshake finished (only when auto-read is off), and flushes writes that
// were queued by a Flush issued before the handshake finished.
// On cancellation/fault: delegates to HandleFailure.
static void HandleHandshakeCompleted(Task task, object state)
{
var self = (TlsHandler)state;
switch (task.Status)
{
case TaskStatus.RanToCompletion:
{
TlsHandlerState oldState = self.state;
Contract.Assert(!oldState.HasAny(TlsHandlerState.AuthenticationCompleted));
// The new state clears FlushedBeforeHandshake, so the checks below must use oldState.
self.state = (oldState | TlsHandlerState.Authenticated) & ~(TlsHandlerState.Authenticating | TlsHandlerState.FlushedBeforeHandshake);
self.capturedContext.FireUserEventTriggered(TlsHandshakeCompletionEvent.Success);
if (oldState.Has(TlsHandlerState.ReadRequestedBeforeAuthenticated) && !self.capturedContext.Channel.Configuration.AutoRead)
{
self.capturedContext.Read();
}
if (oldState.Has(TlsHandlerState.FlushedBeforeHandshake))
{
self.WrapAndFlush(self.capturedContext);
}
break;
}
case TaskStatus.Canceled:
case TaskStatus.Faulted:
{
// ReSharper disable once AssignNullToNotNullAttribute -- task.Exception will be present as task is faulted
// NOTE(review): for a Canceled task, Task.Exception is null, so HandleFailure receives a null cause here.
TlsHandlerState oldState = self.state;
Contract.Assert(!oldState.HasAny(TlsHandlerState.Authenticated));
self.HandleFailure(task.Exception);
break;
}
default:
throw new ArgumentOutOfRangeException(nameof(task), "Unexpected task status: " + task.Status);
}
}
/// <summary>
/// Captures the context and creates the pending-write queue. In client mode on an already active
/// channel, it also starts the handshake.
/// </summary>
public override void HandlerAdded(IChannelHandlerContext context)
{
    base.HandlerAdded(context);
    this.capturedContext = context;
    this.pendingUnencryptedWrites = new BatchingPendingWriteQueue(context, UnencryptedWriteBatchSize);

    bool startClientHandshakeNow = context.Channel.Active && !this.IsServer;
    if (!startClientHandshakeNow)
    {
        return;
    }

    // todo: support delayed initialization on an existing/active channel if in client mode
    this.EnsureAuthenticated();
}
/// <summary>Fails every write still queued when the handler leaves the pipeline.</summary>
protected override void HandlerRemovedInternal(IChannelHandlerContext context)
{
    // Building the ChannelException is comparatively expensive, so only do it when something is queued.
    if (this.pendingUnencryptedWrites.IsEmpty)
    {
        return;
    }
    this.pendingUnencryptedWrites.RemoveAndFailAll(new ChannelException("Write has failed due to TlsHandler being removed from channel pipeline."));
}
// Splits the cumulated input into complete TLS records and hands them to Unwrap in one batch.
// The batch holds at most TlsUtils.MAX_ENCRYPTED_PACKET_LENGTH bytes in total.
// The length of a record whose body has not fully arrived is remembered in packetLength for the next call.
// Input that is not a TLS record is consumed, reported as NotSslRecordException, and fails the handler.
protected override void Decode(IChannelHandlerContext context, IByteBuffer input, List<object> output)
{
int startOffset = input.ReaderIndex;
int endOffset = input.WriterIndex;
int offset = startOffset;
int totalLength = 0;
// Lengths of the complete records in this batch, in order; Unwrap feeds them to SslStream one at a time.
List<int> packetLengths;
// if we calculated the length of the current SSL record before, use that information.
if (this.packetLength > 0)
{
if (endOffset - startOffset < this.packetLength)
{
// input does not contain a single complete SSL record
return;
}
else
{
packetLengths = new List<int>(4);
packetLengths.Add(this.packetLength);
offset += this.packetLength;
totalLength = this.packetLength;
this.packetLength = 0;
}
}
else
{
packetLengths = new List<int>(4);
}
bool nonSslRecord = false;
// Collect further complete records until the header is incomplete, a record is partial,
// the data is not TLS, or the batch would exceed the size cap.
while (totalLength < TlsUtils.MAX_ENCRYPTED_PACKET_LENGTH)
{
int readableBytes = endOffset - offset;
if (readableBytes < TlsUtils.SSL_RECORD_HEADER_LENGTH)
{
break;
}
int encryptedPacketLength = TlsUtils.GetEncryptedPacketLength(input, offset);
if (encryptedPacketLength == -1)
{
nonSslRecord = true;
break;
}
Contract.Assert(encryptedPacketLength > 0);
if (encryptedPacketLength > readableBytes)
{
// wait until the whole packet can be read
this.packetLength = encryptedPacketLength;
break;
}
int newTotalLength = totalLength + encryptedPacketLength;
if (newTotalLength > TlsUtils.MAX_ENCRYPTED_PACKET_LENGTH)
{
// Don't read too much.
break;
}
// 1. call unwrap with packet boundaries - call SslStream.ReadAsync only once.
// 2. once we're through all the whole packets, switch to reading out using fallback sized buffer
// We have a whole packet.
// Increment the offset to handle the next packet.
packetLengths.Add(encryptedPacketLength);
offset += encryptedPacketLength;
totalLength = newTotalLength;
}
if (totalLength > 0)
{
// The buffer contains one or more full SSL records.
// Slice out the whole packet so unwrap will only be called with complete packets.
// Also directly reset the packetLength. This is needed as unwrap(..) may trigger
// decode(...) again via:
// 1) unwrap(..) is called
// 2) wrap(...) is called from within unwrap(...)
// 3) wrap(...) calls unwrapLater(...)
// 4) unwrapLater(...) calls decode(...)
//
// See https://github.com/netty/netty/issues/1534
input.SkipBytes(totalLength);
this.Unwrap(context, input, startOffset, totalLength, packetLengths, output);
if (!this.firedChannelRead)
{
// Check first if firedChannelRead is not set yet as it may have been set in a
// previous decode(...) call.
this.firedChannelRead = output.Count > 0;
}
}
if (nonSslRecord)
{
// Not an SSL/TLS packet
var ex = new NotSslRecordException(
"not an SSL/TLS record: " + ByteBufferUtil.HexDump(input));
input.SkipBytes(input.ReadableBytes);
context.FireExceptionCaught(ex);
this.HandleFailure(ex);
}
}
/// <summary>
/// Compacts the cumulation buffer, requests more data if needed, resets the per-batch read flag,
/// and then propagates the event.
/// </summary>
public override void ChannelReadComplete(IChannelHandlerContext ctx)
{
    // Discard bytes of the cumulation buffer if needed.
    this.DiscardSomeReadBytes();

    this.ReadIfNeeded(ctx);

    // Start the next read batch with a clean flag.
    this.firedChannelRead = false;

    ctx.FireChannelReadComplete();
}
// With auto-read disabled, a read is issued unless a message was already passed on AND the
// handshake has finished; otherwise the pipeline could stall waiting for data nobody requested.
void ReadIfNeeded(IChannelHandlerContext ctx)
{
    if (ctx.Channel.Configuration.AutoRead)
    {
        // The transport keeps reading on its own.
        return;
    }

    bool handshakeFinished = this.state.HasAny(TlsHandlerState.AuthenticationCompleted);
    if (this.firedChannelRead && handshakeFinished)
    {
        return;
    }

    ctx.Read();
}
/// <summary>Unwraps inbound SSL records.</summary>
/// <remarks>
/// The records occupy <paramref name="length"/> bytes of <paramref name="packet"/> starting at
/// <paramref name="offset"/>, split as given by <paramref name="packetLengths"/>. They are exposed to
/// SslStream one record at a time through the mediation stream. While the handshake is not complete,
/// records are only fed to it. Afterwards, decrypted data is read into freshly allocated buffers
/// that are added to <paramref name="output"/>. A read that is still outstanding when the records run
/// out is parked in pendingSslStreamRead* and finished by the next call or by UnwrapCompleted.
/// </remarks>
void Unwrap(IChannelHandlerContext ctx, IByteBuffer packet, int offset, int length, List<int> packetLengths, List<object> output)
{
Contract.Requires(packetLengths.Count > 0);
//bool notifyClosure = false; // todo: netty/issues/137
// true once ownership of outputBuffer has moved to the pending-read fields
bool pending = false;
IByteBuffer outputBuffer = null;
try
{
ArraySegment<byte> inputIoBuffer = packet.GetIoBuffer(offset, length);
this.mediationStream.SetSource(inputIoBuffer.Array, inputIoBuffer.Offset, ctx.Allocator);
int packetIndex = 0;
// Handshake phase: each EnsureAuthenticated call can start the handshake; records are fed until it reports success.
while (!this.EnsureAuthenticated())
{
this.mediationStream.ExpandSource(packetLengths[packetIndex]);
if (++packetIndex == packetLengths.Count)
{
return;
}
}
Task<int> currentReadFuture = this.pendingSslStreamReadFuture;
int outputBufferLength;
if (currentReadFuture != null)
{
// restoring context from previous read
Contract.Assert(this.pendingSslStreamReadBuffer != null);
outputBuffer = this.pendingSslStreamReadBuffer;
outputBufferLength = this.pendingSslStreamReadLength;
this.pendingSslStreamReadFuture = null;
this.pendingSslStreamReadBuffer = null;
this.pendingSslStreamReadLength = 0;
}
else
{
outputBufferLength = 0;
}
// go through packets one by one (because SslStream does not consume more than 1 packet at a time)
for (; packetIndex < packetLengths.Count; packetIndex++)
{
int currentPacketLength = packetLengths[packetIndex];
this.mediationStream.ExpandSource(currentPacketLength);
while (true)
{
// NOTE(review): totalRead is re-initialised on every iteration, so it only ever equals the latest `read`.
int totalRead = 0;
if (currentReadFuture != null)
{
// there was a read pending already, so we make sure we completed that first
if (!currentReadFuture.IsCompleted)
{
// we did feed the whole current packet to SslStream yet it did not produce any result -> move to the next packet in input
break;
}
int read = currentReadFuture.Result;
totalRead += read;
if (read == 0)
{
//Stream closed
// outputBuffer is released by the finally block (nothing was written into it).
return;
}
// Now output the result of previous read and decide whether to do an extra read on the same source or move forward
AddBufferToOutput(outputBuffer, read, output);
currentReadFuture = null;
outputBuffer = null;
if (this.mediationStream.TotalReadableBytes == 0)
{
// we just made a frame available for reading but there was already pending read so SslStream read it out to make further progress there
if (read < outputBufferLength)
{
// SslStream returned non-full buffer and there's no more input to go through ->
// typically it means SslStream is done reading current frame so we skip
break;
}
// we've read out `read` bytes out of current packet to fulfil previously outstanding read
outputBufferLength = currentPacketLength - totalRead;
if (outputBufferLength <= 0)
{
// after feeding to SslStream current frame it read out more bytes than current packet size
outputBufferLength = FallbackReadBufferSize;
}
}
else
{
// SslStream did not get to reading current frame so it completed previous read sync
// and the next read will likely read out the new frame
outputBufferLength = currentPacketLength;
}
}
else
{
// there was no pending read before so we estimate buffer of `currentPacketLength` bytes to be sufficient
outputBufferLength = currentPacketLength;
}
outputBuffer = ctx.Allocator.Buffer(outputBufferLength);
currentReadFuture = this.ReadFromSslStreamAsync(outputBuffer, outputBufferLength);
}
}
if (currentReadFuture != null)
{
// Park the outstanding read; UnwrapCompleted (or the next Unwrap) takes ownership of outputBuffer.
pending = true;
this.pendingSslStreamReadBuffer = outputBuffer;
this.pendingSslStreamReadFuture = currentReadFuture;
this.pendingSslStreamReadLength = outputBufferLength;
//Can't use ExecuteSynchronously here for it may change the order of output if task is already completed here.
currentReadFuture.ContinueWith(UnwrapCompletedCallback, this, TaskContinuationOptions.None);
}
}
catch (Exception ex)
{
this.HandleFailure(ex);
throw;
}
finally
{
// Unconsumed source bytes are copied into the mediation stream's own buffer, because `packet` is only valid for this call.
this.mediationStream.ResetSource(ctx.Allocator);
if (!pending && outputBuffer != null)
{
if (outputBuffer.IsReadable())
{
output.Add(outputBuffer);
}
else
{
outputBuffer.SafeRelease();
}
}
}
}
/// <summary>
/// Continuation for an SslStream read that was still outstanding when Unwrap returned.
/// </summary>
/// <remarks>
/// Fires the decrypted data and keeps reading while the mediation stream still holds input. When a
/// new read does not complete immediately, it parks that read again. The buffer is released on
/// end-of-stream, cancellation and fault.
/// </remarks>
/// <param name="task">The completed SslStream read.</param>
/// <param name="state">The owning <see cref="TlsHandler"/>.</param>
static void UnwrapCompleted(Task<int> task, object state)
{
    // Mono(with legacy provider) finish ReadAsync in async,
    // so extra check is needed to receive data in async
    var self = (TlsHandler)state;
    Debug.Assert(self.capturedContext.Executor.InEventLoop);

    // Ignore a task that Unwrap already consumed (Unwrap clears pendingSslStreamReadFuture when it does).
    if (task != self.pendingSslStreamReadFuture)
    {
        return;
    }

    IByteBuffer buf = self.pendingSslStreamReadBuffer;
    int outputBufferLength = self.pendingSslStreamReadLength;
    self.pendingSslStreamReadFuture = null;
    self.pendingSslStreamReadBuffer = null;
    self.pendingSslStreamReadLength = 0;

    while (true)
    {
        switch (task.Status)
        {
            case TaskStatus.RanToCompletion:
            {
                var read = task.Result;
                if (read == 0)
                {
                    // Stream closed: nothing was written into buf and nobody else owns it, so release it.
                    buf.SafeRelease();
                    return;
                }

                self.capturedContext.FireChannelRead(buf.SetWriterIndex(buf.WriterIndex + read));
                if (self.mediationStream.TotalReadableBytes == 0)
                {
                    self.capturedContext.FireChannelReadComplete();
                    self.mediationStream.ResetSource(self.capturedContext.Allocator);
                    if (read < outputBufferLength)
                    {
                        // SslStream returned non-full buffer and there's no more input to go through ->
                        // typically it means SslStream is done reading current frame so we skip
                        return;
                    }
                }

                outputBufferLength = self.mediationStream.TotalReadableBytes;
                if (outputBufferLength <= 0)
                {
                    outputBufferLength = FallbackReadBufferSize;
                }

                buf = self.capturedContext.Allocator.Buffer(outputBufferLength);
                task = self.ReadFromSslStreamAsync(buf, outputBufferLength);
                if (task.IsCompleted)
                {
                    // Handle the already-finished read in the same loop instead of scheduling a continuation.
                    continue;
                }

                self.pendingSslStreamReadFuture = task;
                self.pendingSslStreamReadBuffer = buf;
                self.pendingSslStreamReadLength = outputBufferLength;
                task.ContinueWith(UnwrapCompletedCallback, self, TaskContinuationOptions.ExecuteSynchronously);
                return;
            }
            case TaskStatus.Canceled:
            case TaskStatus.Faulted:
            {
                buf.SafeRelease();
                // A canceled task has no Exception; supply one so HandleFailure always receives a cause.
                self.HandleFailure(task.Exception ?? (Exception)new TaskCanceledException(task));
                return;
            }
            default:
            {
                buf.SafeRelease();
                throw new ArgumentOutOfRangeException(nameof(task), "Unexpected task status: " + task.Status);
            }
        }
    }
}
// SslStream decrypted straight into the buffer's backing array (see ReadFromSslStreamAsync), so only
// the writer index has to be advanced before the buffer is emitted.
static void AddBufferToOutput(IByteBuffer outputBuffer, int length, List<object> output)
{
    Contract.Assert(length > 0);
    int newWriterIndex = outputBuffer.WriterIndex + length;
    IByteBuffer advanced = outputBuffer.SetWriterIndex(newWriterIndex);
    output.Add(advanced);
}
// Exposes the writable region starting at WriterIndex as an array segment and lets SslStream decrypt
// into it; the caller advances WriterIndex once the read completes.
Task<int> ReadFromSslStreamAsync(IByteBuffer outputBuffer, int outputBufferLength)
{
    ArraySegment<byte> target = outputBuffer.GetIoBuffer(outputBuffer.WriterIndex, outputBufferLength);
    byte[] array = target.Array;
    return this.sslStream.ReadAsync(array, target.Offset, target.Count);
}
/// <summary>
/// Forwards the read request. If the handshake is still running, it also records that a read was
/// requested, so HandleHandshakeCompleted can re-issue it when auto-read is off.
/// </summary>
public override void Read(IChannelHandlerContext context)
{
    TlsHandlerState current = this.state;
    bool handshakeDone = current.HasAny(TlsHandlerState.AuthenticationCompleted);
    if (!handshakeDone)
    {
        this.state = current | TlsHandlerState.ReadRequestedBeforeAuthenticated;
    }
    context.Read();
}
/// <summary>
/// Starts the TLS handshake on the first call and reports whether it has completed successfully.
/// </summary>
/// <returns>
/// <c>true</c> once the handshake has succeeded. <c>false</c> when this call just started it, when it
/// is still running, or when it failed.
/// </returns>
bool EnsureAuthenticated()
{
    TlsHandlerState oldState = this.state;
    if (!oldState.HasAny(TlsHandlerState.AuthenticationStarted))
    {
        this.state = oldState | TlsHandlerState.Authenticating;
        if (this.IsServer)
        {
            var serverSettings = (ServerTlsSettings)this.settings;
            this.sslStream.AuthenticateAsServerAsync(serverSettings.Certificate, serverSettings.NegotiateClientCertificate, serverSettings.EnabledProtocols, serverSettings.CheckCertificateRevocation)
                .ContinueWith(HandshakeCompletionCallback, this, TaskContinuationOptions.ExecuteSynchronously);
        }
        else
        {
            var clientSettings = (ClientTlsSettings)this.settings;
            // Hand the configured client certificates (e.g. from Client(targetHost, clientCertificate))
            // to SslStream so they can be presented when the server asks for one.
            X509CertificateCollection clientCertificates = CreateClientCertificateCollection(clientSettings);
            this.sslStream.AuthenticateAsClientAsync(clientSettings.TargetHost, clientCertificates, clientSettings.EnabledProtocols, clientSettings.CheckCertificateRevocation)
                .ContinueWith(HandshakeCompletionCallback, this, TaskContinuationOptions.ExecuteSynchronously);
        }
        return false;
    }
    return oldState.Has(TlsHandlerState.Authenticated);
}

// Copies ClientTlsSettings.Certificates into the collection type SslStream expects, skipping null
// entries. Returns null when no certificate is configured.
static X509CertificateCollection CreateClientCertificateCollection(ClientTlsSettings clientSettings)
{
    var certificates = clientSettings.Certificates;
    if (certificates == null || certificates.Count == 0)
    {
        return null;
    }

    var collection = new X509CertificateCollection();
    foreach (X509Certificate certificate in certificates)
    {
        if (certificate != null)
        {
            collection.Add(certificate);
        }
    }
    return collection.Count == 0 ? null : collection;
}
/// <summary>
/// Queues a plaintext <see cref="IByteBuffer"/> for encryption on the next flush. Any other message
/// type fails immediately with <see cref="UnsupportedMessageTypeException"/>.
/// </summary>
public override Task WriteAsync(IChannelHandlerContext context, object message)
{
    if (message is IByteBuffer)
    {
        return this.pendingUnencryptedWrites.Add(message);
    }
    return TaskEx.FromException(new UnsupportedMessageTypeException(message, typeof(IByteBuffer)));
}
/// <summary>
/// Encrypts and flushes queued writes. If the handshake has not finished yet, the flush is deferred:
/// it is recorded in the state and carried out by HandleHandshakeCompleted.
/// </summary>
public override void Flush(IChannelHandlerContext context)
{
    // Enqueue an empty buffer when nothing is pending (presumably so Wrap always has an entry to process).
    if (this.pendingUnencryptedWrites.IsEmpty)
    {
        this.pendingUnencryptedWrites.Add(Unpooled.Empty);
    }

    if (this.EnsureAuthenticated())
    {
        this.WrapAndFlush(context);
        return;
    }

    this.state |= TlsHandlerState.FlushedBeforeHandshake;
}
// Encrypts all queued writes with per-record flushing disabled, then flushes once.
// If FinishWrap ran off the event loop during Wrap (flushMode became PendingFlush), the flush is posted
// to the executor instead of being issued inline.
void WrapAndFlush(IChannelHandlerContext context)
{
    this.flushMode = FlushMode.NoFlush;
    try
    {
        this.Wrap(context);
    }
    finally
    {
        // We may have written some parts of data before an exception was thrown so ensure we always flush.
        bool stillNoFlush = this.flushMode == FlushMode.NoFlush;
        if (!stillNoFlush)
        {
            context.Executor.Execute(
                s =>
                {
                    var handler = (TlsHandler)s;
                    handler.flushMode = FlushMode.ForceFlush;
                    handler.capturedContext.Flush();
                },
                this);
        }
        else
        {
            this.flushMode = FlushMode.ForceFlush;
            context.Flush();
        }
    }
}
/// <summary>
/// Drains the pending-write queue batch by batch. Each batch is merged into one buffer when it holds
/// several messages and is written through the SslStream; the resulting records reach the channel via
/// FinishWrap. Each batch's promise is then completed, linked to the last channel write if there was one.
/// </summary>
/// <remarks>On failure the handler is failed via HandleFailure and the exception is rethrown.</remarks>
void Wrap(IChannelHandlerContext context)
{
    Contract.Assert(context == this.capturedContext);

    IByteBuffer buf = null;
    try
    {
        while (true)
        {
            List<object> messages = this.pendingUnencryptedWrites.Current;
            if (messages == null || messages.Count == 0)
            {
                break;
            }

            if (messages.Count == 1)
            {
                buf = (IByteBuffer)messages[0];
            }
            else
            {
                buf = context.Allocator.Buffer((int)this.pendingUnencryptedWrites.CurrentSize);
                foreach (IByteBuffer buffer in messages)
                {
                    buffer.ReadBytes(buf, buffer.ReadableBytes);
                    buffer.Release();
                }
            }

            buf.ReadBytes(this.sslStream, buf.ReadableBytes); // this leads to FinishWrap being called 0+ times
            buf.Release();
            // The buffer is gone now; clear the local so the catch block cannot release it a second time
            // (the pooled object may already have been handed out again by then).
            buf = null;

            TaskCompletionSource promise = this.pendingUnencryptedWrites.Remove();
            Task task = this.lastContextWriteTask;
            if (task != null)
            {
                task.LinkOutcome(promise);
                this.lastContextWriteTask = null;
            }
            else
            {
                promise.TryComplete();
            }
        }
    }
    catch (Exception ex)
    {
        // Only a buffer that has not been released yet is still referenced here; SafeRelease tolerates null.
        buf.SafeRelease();
        this.HandleFailure(ex);
        throw;
    }
}
// Receives bytes that SslStream writes synchronously (via MediationStream.Write). It copies them into a
// channel buffer and writes them, flushing only in ForceFlush mode. The resulting task is kept in
// lastContextWriteTask for Wrap.
void FinishWrap(byte[] buffer, int offset, int count)
{
    // In Mono(with btls provider) on linux, and maybe also for apple provider, Write is called in another thread,
    // so it will run after the call to Flush.
    bool calledOffLoopDuringWrap = this.flushMode == FlushMode.NoFlush && !this.capturedContext.Executor.InEventLoop;
    if (calledOffLoopDuringWrap)
    {
        this.flushMode = FlushMode.PendingFlush;
    }

    IByteBuffer output = Unpooled.Empty;
    if (count != 0)
    {
        output = this.capturedContext.Allocator.Buffer(count);
        output.WriteBytes(buffer, offset, count);
    }

    if (this.flushMode == FlushMode.ForceFlush)
    {
        this.lastContextWriteTask = this.capturedContext.WriteAndFlushAsync(output);
    }
    else
    {
        this.lastContextWriteTask = this.capturedContext.WriteAsync(output);
    }
}
// Receives bytes SslStream writes through the async path (presumably handshake/control records; Wrap
// uses the synchronous Write). The bytes are wrapped (not copied), written and flushed immediately, and
// then a read is requested if needed.
Task FinishWrapNonAppDataAsync(byte[] buffer, int offset, int count)
{
    IByteBuffer wrapped = Unpooled.WrappedBuffer(buffer, offset, count);
    Task writeTask = this.capturedContext.WriteAndFlushAsync(wrapped);
    this.ReadIfNeeded(this.capturedContext);
    return writeTask;
}
/// <summary>
/// Marks the handler closed, disposes the SslStream, and closes the channel. closeFuture is completed
/// first so that IgnoreException treats a subsequent ObjectDisposedException as expected.
/// </summary>
public override Task CloseAsync(IChannelHandlerContext context)
{
    this.closeFuture.TryComplete();
    this.sslStream.Dispose();
    Task closeTask = base.CloseAsync(context);
    return closeTask;
}
// Common failure path. It disposes both streams, releases the buffer of a parked SslStream read,
// reports a handshake failure if the handshake had not finished, and fails all queued writes with `cause`.
void HandleFailure(Exception cause)
{
    // Disposing the mediation stream also completes a pending SslStream read with 0 bytes.
    this.mediationStream.Dispose();
    try
    {
        this.sslStream.Dispose();
    }
    catch (Exception)
    {
        // Deliberate best effort: an exception while tearing down a failed SslStream is most likely
        // harmless (see https://github.com/netty/netty/issues/1340) and must not mask `cause`.
    }

    IByteBuffer parkedReadBuffer = this.pendingSslStreamReadBuffer;
    if (parkedReadBuffer != null)
    {
        parkedReadBuffer.SafeRelease();
    }
    this.pendingSslStreamReadBuffer = null;
    this.pendingSslStreamReadFuture = null;

    this.NotifyHandshakeFailure(cause);
    this.pendingUnencryptedWrites.RemoveAndFailAll(cause);
}
// If the handshake had not finished yet, this marks it failed, fires a failed
// TlsHandshakeCompletionEvent, and closes the channel. Otherwise it does nothing.
void NotifyHandshakeFailure(Exception cause)
{
    if (this.state.HasAny(TlsHandlerState.AuthenticationCompleted))
    {
        return;
    }

    TlsHandlerState failed = this.state | TlsHandlerState.FailedAuthentication;
    this.state = failed & ~TlsHandlerState.Authenticating;
    this.capturedContext.FireUserEventTriggered(new TlsHandshakeCompletionEvent(cause));
    this.CloseAsync(this.capturedContext);
}
/// <summary>How FinishWrap treats flushing of the records it writes to the channel.</summary>
enum FlushMode : byte
{
    /// <summary>Write only; WrapAndFlush issues a single flush when Wrap returns.</summary>
    NoFlush,

    /// <summary>A flush has been, or will be, posted to the channel's executor.</summary>
    PendingFlush,

    /// <summary>FinishWrap writes and flushes every record immediately.</summary>
    ForceFlush,
}
sealed class MediationStream : Stream
{
// Handler that receives everything SslStream writes (FinishWrap / FinishWrapNonAppDataAsync).
readonly TlsHandler owner;
// Guards input/ownBuffer; per the comment in ResetSource, Mono may run BeginRead concurrently with ResetSource.
object sourceLock = new object();
// Unread bytes carried over from a previous source by ResetSource.
IByteBuffer ownBuffer;
// Current source array (set by SetSource) and the window of it exposed so far by ExpandSource:
// bytes [inputStartOffset + inputOffset, inputStartOffset + inputLength) are still readable.
byte[] input;
int inputStartOffset;
int inputOffset;
int inputLength;
// Read that SslStream issued when no bytes were available; completed by ExpandSource or Dispose.
TaskCompletionSource<int> readCompletionSource;
// Destination buffer of that pending read (Array == null when no read is pending).
ArraySegment<byte> sslOwnedBuffer;
#if NETSTANDARD1_3
// Byte count handed from ExpandSource to the task it runs synchronously.
int readByteCount;
#else
// Reused IAsyncResult for reads that complete inside BeginRead.
SynchronousAsyncResult<int> syncReadResult;
// Callback supplied to the pending BeginRead.
AsyncCallback readCallback;
// Promise and callback of a BeginWrite whose channel write had not completed yet.
TaskCompletionSource writeCompletion;
AsyncCallback writeCallback;
#endif
/// <summary>Creates a stream whose writes are forwarded to <paramref name="owner"/>.</summary>
public MediationStream(TlsHandler owner)
{
    this.owner = owner;
}
// Bytes still available to SslStream: leftovers carried in ownBuffer plus the unread exposed part of the current source.
public int TotalReadableBytes
{
    get
    {
        IByteBuffer carried = this.ownBuffer;
        int carriedBytes = carried == null ? 0 : carried.ReadableBytes;
        return carriedBytes + this.SourceReadableBytes;
    }
}

// Part of the current source that ExpandSource has exposed but that has not been read yet.
public int SourceReadableBytes
{
    get { return this.inputLength - this.inputOffset; }
}
// Installs a new source array. Unread bytes of the previous source are first carried into ownBuffer.
// Nothing of the new source is readable until ExpandSource exposes it.
public void SetSource(byte[] source, int offset, IByteBufferAllocator alloc)
{
    lock (this.sourceLock)
    {
        this.ResetSource(alloc);
        this.input = source;
        this.inputStartOffset = offset;
        this.inputOffset = 0;
        this.inputLength = 0;
    }
}
// Detaches the current source array. Any exposed-but-unread bytes are copied into ownBuffer (allocated
// on demand). An emptied ownBuffer is released, and a non-empty one is compacted.
public void ResetSource(IByteBufferAllocator alloc)
{
    //Mono will run BeginRead in async and it's running with ResetSource at the same time
    lock (this.sourceLock)
    {
        int unreadFromSource = this.SourceReadableBytes;
        IByteBuffer carried = this.ownBuffer;

        if (unreadFromSource <= 0)
        {
            if (carried != null)
            {
                if (carried.IsReadable())
                {
                    carried.DiscardSomeReadBytes();
                }
                else
                {
                    carried.SafeRelease();
                    this.ownBuffer = null;
                }
            }
        }
        else
        {
            if (carried == null)
            {
                carried = alloc.Buffer(unreadFromSource);
                this.ownBuffer = carried;
            }
            else
            {
                carried.DiscardSomeReadBytes();
            }
            carried.WriteBytes(this.input, this.inputStartOffset + this.inputOffset, unreadFromSource);
        }

        this.input = null;
        this.inputStartOffset = 0;
        this.inputOffset = 0;
        this.inputLength = 0;
    }
}
// Exposes the next `count` bytes of the current source to SslStream. If SslStream has a read parked in
// sslOwnedBuffer, that read is fulfilled right away from the available bytes and its completion is signalled.
public void ExpandSource(int count)
{
Contract.Assert(this.input != null);
lock (sourceLock)
{
this.inputLength += count;
ArraySegment<byte> sslBuffer = this.sslOwnedBuffer;
if (sslBuffer.Array == null)
{
// there is no pending read operation - keep for future
return;
}
this.sslOwnedBuffer = default(ArraySegment<byte>);
#if NETSTANDARD1_3
this.readByteCount = this.ReadFromInput(sslBuffer.Array, sslBuffer.Offset, sslBuffer.Count);
// hack: this tricks SslStream's continuation to run synchronously instead of dispatching to TP. Remove once Begin/EndRead are available.
new Task(
ms =>
{
var self = (MediationStream)ms;
TaskCompletionSource<int> p = self.readCompletionSource;
self.readCompletionSource = null;
p.TrySetResult(self.readByteCount);
},
this)
.RunSynchronously(TaskScheduler.Default);
#else
int read = this.ReadFromInput(sslBuffer.Array, sslBuffer.Offset, sslBuffer.Count);
TaskCompletionSource<int> promise = this.readCompletionSource;
this.readCompletionSource = null;
promise.TrySetResult(read);
AsyncCallback callback = this.readCallback;
this.readCallback = null;
// NOTE(review): the callback runs while sourceLock is held; Monitor is re-entrant, so a nested
// BeginRead/ReadFromInput on this thread can re-acquire it.
callback?.Invoke(promise.Task);
#endif
}
}
// Synchronous read implemented by blocking on the asynchronous path.
public override int Read(byte[] buffer, int offset, int count)
{
    Task<int> readTask = this.ReadAsync(buffer, offset, count);
    return readTask.Result;
}
#if NETSTANDARD1_3
// Serves the read synchronously when bytes are buffered. Otherwise it parks SslStream's buffer until
// ExpandSource supplies data. cancellationToken is not observed.
public override Task<int> ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
{
if (this.TotalReadableBytes > 0)
{
// we have the bytes available upfront - write out synchronously
int read = this.ReadFromInput(buffer, offset, count);
return Task.FromResult(read);
}
Contract.Assert(this.sslOwnedBuffer.Array == null);
// take note of buffer - we will pass bytes there once available
this.sslOwnedBuffer = new ArraySegment<byte>(buffer, offset, count);
this.readCompletionSource = new TaskCompletionSource<int>();
return this.readCompletionSource.Task;
}
#else
// APM read used by SslStream. When bytes are buffered, it completes synchronously (callback invoked
// before returning). Otherwise it parks the buffer and callback until ExpandSource supplies data.
// NOTE(review): TotalReadableBytes is checked outside sourceLock; only ReadFromInput takes the lock.
public override IAsyncResult BeginRead(byte[] buffer, int offset, int count, AsyncCallback callback, object state)
{
if (this.TotalReadableBytes > 0)
{
// we have the bytes available upfront - write out synchronously
int read = this.ReadFromInput(buffer, offset, count);
var res = this.PrepareSyncReadResult(read, state);
callback?.Invoke(res);
return res;
}
Contract.Assert(this.sslOwnedBuffer.Array == null);
// take note of buffer - we will pass bytes there once available
this.sslOwnedBuffer = new ArraySegment<byte>(buffer, offset, count);
this.readCompletionSource = new TaskCompletionSource<int>(state);
this.readCallback = callback;
return this.readCompletionSource.Task;
}
// Returns the byte count of a read started by BeginRead. A task fault is rethrown with its inner
// exception and original stack trace.
public override int EndRead(IAsyncResult asyncResult)
{
SynchronousAsyncResult<int> syncResult = this.syncReadResult;
if (ReferenceEquals(asyncResult, syncResult))
{
return syncResult.Result;
}
Debug.Assert(this.readCompletionSource == null || this.readCompletionSource.Task == asyncResult);
Contract.Assert(!((Task<int>)asyncResult).IsCanceled);
try
{
return ((Task<int>)asyncResult).Result;
}
catch (AggregateException ex)
{
ExceptionDispatchInfo.Capture(ex.InnerException).Throw();
throw; // unreachable
}
}
// Fills the single reused synchronous result object with the byte count and caller state.
IAsyncResult PrepareSyncReadResult(int readBytes, object state)
{
// it is safe to reuse sync result object as it can't lead to leak (no way to attach to it via handle)
SynchronousAsyncResult<int> result = this.syncReadResult ?? (this.syncReadResult = new SynchronousAsyncResult<int>());
result.Result = readBytes;
result.AsyncState = state;
return result;
}
#endif
// Synchronous writes from SslStream (used for application data when Wrap writes into the stream) go to FinishWrap.
public override void Write(byte[] buffer, int offset, int count)
{
    this.owner.FinishWrap(buffer, offset, count);
}

// Asynchronous writes are written and flushed immediately by the owner; cancellationToken is not observed.
public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
{
    return this.owner.FinishWrapNonAppDataAsync(buffer, offset, count);
}
#if !NETSTANDARD1_3
// Cached continuation delegate used by BeginWrite.
static readonly Action<Task, object> WriteCompleteCallback = HandleChannelWriteComplete;

// APM write used by SslStream on this target. It forwards to WriteAsync, which ends in the owner's
// FinishWrapNonAppDataAsync. A channel write that already succeeded completes synchronously; anything
// else is adapted via writeCompletion and HandleChannelWriteComplete.
public override IAsyncResult BeginWrite(byte[] buffer, int offset, int count, AsyncCallback callback, object state)
{
    Task task = this.WriteAsync(buffer, offset, count);
    switch (task.Status)
    {
        case TaskStatus.RanToCompletion:
            // write+flush completed synchronously (and successfully)
            var result = new SynchronousAsyncResult<int>();
            result.AsyncState = state;
            callback?.Invoke(result);
            return result;
        default:
            this.writeCallback = callback;
            Contract.Assert(this.writeCompletion == null);
            var tcs = new TaskCompletionSource(state);
            this.writeCompletion = tcs;
            task.ContinueWith(WriteCompleteCallback, this, TaskContinuationOptions.ExecuteSynchronously);
            return tcs.Task;
    }
}

// Continuation for a channel write started by BeginWrite. It completes the IAsyncResult handed to
// SslStream with the write's outcome, then invokes SslStream's callback.
static void HandleChannelWriteComplete(Task writeTask, object state)
{
    var self = (MediationStream)state;
    AsyncCallback callback = self.writeCallback;
    self.writeCallback = null;
    var promise = self.writeCompletion;
    self.writeCompletion = null;
    switch (writeTask.Status)
    {
        case TaskStatus.RanToCompletion:
            promise.TryComplete();
            break;
        case TaskStatus.Canceled:
            promise.TrySetCanceled();
            break;
        case TaskStatus.Faulted:
            promise.TrySetException(writeTask.Exception);
            break;
        default:
            // The single-string constructor treats its argument as the parameter name, so pass both explicitly.
            throw new ArgumentOutOfRangeException(nameof(writeTask), "Unexpected task status: " + writeTask.Status);
    }
    callback?.Invoke(promise.Task);
}

// Completes a BeginWrite. A synchronous result needs no work; for a task-backed result, a fault is
// rethrown with its inner exception and original stack trace.
public override void EndWrite(IAsyncResult asyncResult)
{
    if (asyncResult is SynchronousAsyncResult<int>)
    {
        return;
    }
    Debug.Assert(this.writeCompletion == null || this.writeCompletion.Task == asyncResult);
    try
    {
        ((Task<int>)asyncResult).Wait();
    }
    catch (AggregateException ex)
    {
        ExceptionDispatchInfo.Capture(ex.InnerException).Throw();
        throw;
    }
}
#endif
/// <summary>
/// Copies up to <paramref name="destinationCapacity"/> bytes into <paramref name="destination"/>.
/// Bytes carried in ownBuffer come first, followed by the exposed part of the current source.
/// </summary>
/// <returns>Total number of bytes copied.</returns>
int ReadFromInput(byte[] destination, int destinationOffset, int destinationCapacity)
{
    Contract.Assert(destination != null);
    lock (this.sourceLock)
    {
        int length = 0;

        // 1) leftovers carried over from an earlier source
        IByteBuffer carried = this.ownBuffer;
        if (carried != null && carried.ReadableBytes > 0)
        {
            int fromCarried = Math.Min(carried.ReadableBytes, destinationCapacity);
            carried.ReadBytes(destination, destinationOffset, fromCarried);
            length += fromCarried;
            // Advance past the copied bytes so the source copy below does not overwrite them.
            destinationOffset += fromCarried;
            destinationCapacity -= fromCarried;
            if (destinationCapacity == 0)
            {
                return length;
            }
        }

        // 2) the exposed, unread part of the current source
        byte[] source = this.input;
        if (source != null)
        {
            int fromSource = Math.Min(this.SourceReadableBytes, destinationCapacity);
            if (fromSource > 0)
            {
                Buffer.BlockCopy(source, this.inputStartOffset + this.inputOffset, destination, destinationOffset, fromSource);
                length += fromSource;
                this.inputOffset += fromSource;
            }
        }

        return length;
    }
}
public override void Flush()
{
    // Intentionally empty: SslStream calls Flush on close, and Write/WriteAsync already hand every byte
    // to the channel through the owning handler.
}
// When disposing, completes an outstanding SslStream read with 0 bytes (end of stream).
protected override void Dispose(bool disposing)
{
    base.Dispose(disposing);
    if (!disposing)
    {
        return;
    }

    TaskCompletionSource<int> pendingRead = this.readCompletionSource;
    this.readCompletionSource = null;
    if (pendingRead != null)
    {
        pendingRead.TrySetResult(0);
    }
}
#region plumbing
// A non-seekable duplex pipe: reading and writing are supported, everything position-related is not.
public override bool CanRead
{
    get { return true; }
}

public override bool CanSeek
{
    get { return false; }
}

public override bool CanWrite
{
    get { return true; }
}

public override long Length
{
    get { throw new NotSupportedException(); }
}

public override long Position
{
    get { throw new NotSupportedException(); }
    set { throw new NotSupportedException(); }
}

public override long Seek(long offset, SeekOrigin origin)
{
    throw new NotSupportedException();
}

public override void SetLength(long value)
{
    throw new NotSupportedException();
}
#endregion
#region sync result
/// <summary>
/// An <see cref="IAsyncResult"/> for an operation that already finished when it was started.
/// It always reports completion (synchronously) and carries the result value and caller state.
/// Its wait handle is not supported.
/// </summary>
sealed class SynchronousAsyncResult<T> : IAsyncResult
{
    /// <summary>Always true: the operation completed before the result object was handed out.</summary>
    public bool CompletedSynchronously
    {
        get { return true; }
    }

    /// <summary>Always true.</summary>
    public bool IsCompleted
    {
        get { return true; }
    }

    /// <summary>Caller-supplied state object.</summary>
    public object AsyncState { get; set; }

    /// <summary>The value produced by the completed operation.</summary>
    public T Result { get; set; }

    /// <summary>Not supported; throws <see cref="InvalidOperationException"/>.</summary>
    public WaitHandle AsyncWaitHandle
    {
        get
        {
            throw new InvalidOperationException("Cannot wait on a synchronous result.");
        }
    }
}
#endregion
}
}
/// <summary>
/// Bit flags tracking the progress of the TLS handshake within the handler.
/// </summary>
[Flags]
enum TlsHandlerState
{
    /// <summary>No flags set (the default value of the enum).</summary>
    None = 0,
    Authenticating = 1,
    Authenticated = 1 << 1,
    FailedAuthentication = 1 << 2,
    ReadRequestedBeforeAuthenticated = 1 << 3,
    FlushedBeforeHandshake = 1 << 4,

    /// <summary>Combined mask: any of Authenticating, Authenticated or FailedAuthentication.</summary>
    AuthenticationStarted = Authenticating | Authenticated | FailedAuthentication,

    /// <summary>Combined mask: Authenticated or FailedAuthentication.</summary>
    AuthenticationCompleted = Authenticated | FailedAuthentication
}

/// <summary>
/// Bit-test helpers for <see cref="TlsHandlerState"/>.
/// </summary>
static class TlsHandlerStateExtensions
{
    /// <summary>Returns true when every bit of <paramref name="testValue"/> is set in <paramref name="value"/>.</summary>
    public static bool Has(this TlsHandlerState value, TlsHandlerState testValue)
    {
        TlsHandlerState common = value & testValue;
        return common == testValue;
    }

    /// <summary>Returns true when at least one bit of <paramref name="testValue"/> is set in <paramref name="value"/>.</summary>
    public static bool HasAny(this TlsHandlerState value, TlsHandlerState testValue)
    {
        TlsHandlerState common = value & testValue;
        return common != 0;
    }
}
}
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
namespace CustomMonoTlsHandler
{
using System;
using System.Diagnostics.Contracts;
/// <summary>
/// Event describing the outcome of a TLS handshake: either success (the shared <see cref="Success"/>
/// instance) or failure carrying the causing <see cref="System.Exception"/>.
/// </summary>
public sealed class TlsHandshakeCompletionEvent
{
    /// <summary>Shared instance signalling a successful handshake.</summary>
    public static readonly TlsHandshakeCompletionEvent Success = new TlsHandshakeCompletionEvent();

    /// <summary>
    /// Creates the success event. It is private; use <see cref="Success"/> to signal a successful handshake.
    /// </summary>
    TlsHandshakeCompletionEvent()
    {
        this.Exception = null;
    }

    /// <summary>
    /// Creates an event that indicates an unsuccessful handshake.
    /// Use <see cref="Success"/> to indicate a successful handshake.
    /// </summary>
    /// <param name="exception">The cause of the handshake failure; must not be null.</param>
    /// <exception cref="ArgumentNullException"><paramref name="exception"/> is null.</exception>
    public TlsHandshakeCompletionEvent(Exception exception)
    {
        // Contract.Requires is [Conditional("CONTRACTS_FULL")] and is compiled out of normal builds,
        // so a null cause used to slip through and produce an event reporting IsSuccessful == true.
        if (exception == null)
        {
            throw new ArgumentNullException(nameof(exception));
        }

        this.Exception = exception;
    }

    /// <summary>Returns <c>true</c> if the handshake succeeded (no failure cause is attached).</summary>
    public bool IsSuccessful => this.Exception == null;

    /// <summary>
    /// The failure cause when <see cref="IsSuccessful"/> is <c>false</c>; <c>null</c> for the success event.
    /// </summary>
    public Exception Exception { get; }

    public override string ToString()
    {
        Exception ex = this.Exception;
        return ex == null ? "TlsHandshakeCompletionEvent(SUCCESS)" : $"TlsHandshakeCompletionEvent({ex})";
    }
}
}
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
namespace CustomMonoTlsHandler
{
using System;
using DotNetty.Buffers;
using DotNetty.Transport.Channels;
/// <summary>
/// Helpers for inspecting TLS record headers and reporting handshake failures.
/// </summary>
static class TlsUtils
{
    const int MAX_PLAINTEXT_LENGTH = 1 << 14; // 16 KiB
    const int MAX_COMPRESSED_LENGTH = MAX_PLAINTEXT_LENGTH + 1024;
    const int MAX_CIPHERTEXT_LENGTH = MAX_COMPRESSED_LENGTH + 1024;

    /// <summary>
    /// Upper bound for one encrypted record: header (5) + data (2^14) + compression (1024)
    /// + encryption (1024) + MAC (20) + padding (256).
    /// </summary>
    public const int MAX_ENCRYPTED_PACKET_LENGTH = MAX_CIPHERTEXT_LENGTH + SSL_RECORD_HEADER_LENGTH + 20 + 256;

    /// <summary>Record content type: change cipher spec.</summary>
    public const int SSL_CONTENT_TYPE_CHANGE_CIPHER_SPEC = 20;

    /// <summary>Record content type: alert.</summary>
    public const int SSL_CONTENT_TYPE_ALERT = 21;

    /// <summary>Record content type: handshake.</summary>
    public const int SSL_CONTENT_TYPE_HANDSHAKE = 22;

    /// <summary>Record content type: application data.</summary>
    public const int SSL_CONTENT_TYPE_APPLICATION_DATA = 23;

    /// <summary>Size of an SSL/TLS record header, in bytes.</summary>
    public const int SSL_RECORD_HEADER_LENGTH = 5;

    /// <summary>Sentinel: the buffer does not yet hold enough bytes to parse a record length.</summary>
    public const int NOT_ENOUGH_DATA = -1;

    /// <summary>Sentinel: the data is not encrypted.</summary>
    public const int NOT_ENCRYPTED = -2;

    /// <summary>
    /// Computes the total length (header included) of the SSLv3/TLS record starting at
    /// <paramref name="offset"/>. Only absolute getters are used, so the buffer's reader index is not moved.
    /// </summary>
    /// <param name="buffer">
    /// Buffer to inspect. The caller should ensure at least <see cref="SSL_RECORD_HEADER_LENGTH"/> bytes
    /// are available from <paramref name="offset"/>; this method performs no bounds check of its own
    /// (what an out-of-range read does is up to the <see cref="IByteBuffer"/> implementation).
    /// </param>
    /// <param name="offset">Index of the first byte of the record.</param>
    /// <returns>
    /// The record length including its 5-byte header, or <c>-1</c> (note: not <see cref="NOT_ENCRYPTED"/>)
    /// when the bytes do not look like an SSLv3/TLS record. That covers an unknown content type, a major
    /// version other than 3, or an empty payload.
    /// </returns>
    public static int GetEncryptedPacketLength(IByteBuffer buffer, int offset)
    {
        // Content types 20..23 are the only ones accepted; anything else is SSLv2 or not TLS at all.
        byte contentType = buffer.GetByte(offset);
        bool knownContentType = contentType >= SSL_CONTENT_TYPE_CHANGE_CIPHER_SPEC
            && contentType <= SSL_CONTENT_TYPE_APPLICATION_DATA;
        if (!knownContentType)
        {
            return -1;
        }

        // SSLv3 and every TLS version use protocol major version 3.
        int majorVersion = buffer.GetByte(offset + 1);
        if (majorVersion != 3)
        {
            return -1;
        }

        // Bytes 3..4 hold the payload length; a zero-length payload is treated as invalid.
        int recordLength = buffer.GetUnsignedShort(offset + 3) + SSL_RECORD_HEADER_LENGTH;
        return recordLength > SSL_RECORD_HEADER_LENGTH ? recordLength : -1;
    }

    /// <summary>
    /// Reports a handshake failure: flushes the context, fires a failed
    /// <see cref="TlsHandshakeCompletionEvent"/>, then starts closing the channel (the returned task is not awaited).
    /// </summary>
    /// <param name="ctx">Handler context to notify.</param>
    /// <param name="cause">Reason the handshake failed.</param>
    public static void NotifyHandshakeFailure(IChannelHandlerContext ctx, Exception cause)
    {
        // Some data may already have been written before the failure, so always flush first.
        // See https://github.com/netty/netty/issues/3900#issuecomment-172481830
        ctx.Flush();

        var failureEvent = new TlsHandshakeCompletionEvent(cause);
        ctx.FireUserEventTriggered(failureEvent);

        ctx.CloseAsync();
    }
}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment