Skip to content

Instantly share code, notes, and snippets.

Embed
What would you like to do?
Custom TlsHandler for DotNetty. This is needed when using Matrix vNext with Xamarin. From https://github.com/Azure/DotNetty/pull/374
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
namespace CustomMonoTlsHandler
{
using System.Security.Authentication;
using System.Security.Cryptography.X509Certificates;
using DotNetty.Handlers.Tls;
public sealed class ServerTlsSettings : TlsSettings
{
    // Protocol set used by every constructor that does not take an explicit SslProtocols value.
    const SslProtocols DefaultEnabledProtocols = SslProtocols.Tls | SslProtocols.Tls11 | SslProtocols.Tls12;

    /// <summary>
    /// Settings that present <paramref name="certificate"/> without requesting a client certificate
    /// and without revocation checking, using TLS 1.0, 1.1 and 1.2.
    /// </summary>
    public ServerTlsSettings(X509Certificate certificate)
        : this(certificate, false, false, DefaultEnabledProtocols)
    {
    }

    /// <summary>
    /// Settings that present <paramref name="certificate"/>, optionally requesting a client certificate;
    /// revocation is not checked and TLS 1.0, 1.1 and 1.2 are enabled.
    /// </summary>
    public ServerTlsSettings(X509Certificate certificate, bool negotiateClientCertificate)
        : this(certificate, negotiateClientCertificate, false, DefaultEnabledProtocols)
    {
    }

    /// <summary>
    /// Settings with explicit client-certificate and revocation options; TLS 1.0, 1.1 and 1.2 are enabled.
    /// </summary>
    public ServerTlsSettings(X509Certificate certificate, bool negotiateClientCertificate, bool checkCertificateRevocation)
        : this(certificate, negotiateClientCertificate, checkCertificateRevocation, DefaultEnabledProtocols)
    {
    }

    /// <summary>
    /// Fully specified server settings; protocol and revocation options are stored by the base class.
    /// </summary>
    /// <param name="certificate">Server certificate presented during the handshake.</param>
    /// <param name="negotiateClientCertificate">Whether the client is asked for a certificate.</param>
    /// <param name="checkCertificateRevocation">Whether revocation is checked.</param>
    /// <param name="enabledProtocols">TLS versions allowed for the handshake.</param>
    public ServerTlsSettings(X509Certificate certificate, bool negotiateClientCertificate, bool checkCertificateRevocation, SslProtocols enabledProtocols)
        : base(enabledProtocols, checkCertificateRevocation)
    {
        this.NegotiateClientCertificate = negotiateClientCertificate;
        this.Certificate = certificate;
    }

    /// <summary>Certificate presented to clients.</summary>
    public X509Certificate Certificate { get; }

    /// <summary>True if the server requests a client certificate.</summary>
    public bool NegotiateClientCertificate { get; }
}
}
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
// from: https://github.com/Azure/DotNetty/pull/374
namespace CustomMonoTlsHandler
{
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Diagnostics.Contracts;
using System.IO;
using System.Net.Security;
using System.Runtime.ExceptionServices;
using System.Security.Cryptography.X509Certificates;
using System.Threading;
using System.Threading.Tasks;
using DotNetty.Buffers;
using DotNetty.Codecs;
using DotNetty.Common.Concurrency;
using DotNetty.Common.Utilities;
using DotNetty.Handlers.Tls;
using DotNetty.Transport.Channels;
public sealed class TlsHandler : ByteToMessageDecoder
{
// Client or server settings; the concrete type decides the handler's role (see IsServer).
readonly TlsSettings settings;
// Read-buffer size used when the estimated decrypted size (remaining packet bytes) is not positive.
const int FallbackReadBufferSize = 256;
// Batch size passed to BatchingPendingWriteQueue for queued plaintext writes.
const int UnencryptedWriteBatchSize = 14 * 1024;
// Shared cause passed to HandleFailure when the channel becomes inactive.
static readonly Exception ChannelClosedException = new IOException("Channel is closed");
// Cached continuation delegates (avoid allocating a delegate per handshake/read).
static readonly Action<Task, object> HandshakeCompletionCallback = new Action<Task, object>(HandleHandshakeCompleted);
static readonly Action<Task<int>, object> UnwrapCompletedCallback = new Action<Task<int>, object>(UnwrapCompleted);
// SslStream built by the factory on top of mediationStream.
readonly SslStream sslStream;
// Bridge between the channel pipeline and SslStream (feeds ciphertext in, receives ciphertext out).
readonly MediationStream mediationStream;
// Completed by CloseAsync; IgnoreException uses it to tell "closing" apart from real errors.
readonly TaskCompletionSource closeFuture;
// Handshake / pending-request flags.
TlsHandlerState state;
// Length of a TLS record whose header was seen but whose body is not fully buffered yet (0 = none).
int packetLength;
// Context captured in HandlerAdded; used from task continuations.
volatile IChannelHandlerContext capturedContext;
// Plaintext writes waiting to be encrypted by Wrap.
BatchingPendingWriteQueue pendingUnencryptedWrites;
// Last channel write issued by FinishWrap; Wrap links the queued write promise to it.
Task lastContextWriteTask;
// True once Decode produced output during the current read batch; reset in ChannelReadComplete.
bool firedChannelRead;
// Coordinates flushing between WrapAndFlush and FinishWrap (see FlushMode).
volatile FlushMode flushMode = FlushMode.ForceFlush;
// An SslStream read left outstanding by Unwrap: its target buffer, requested length and task.
IByteBuffer pendingSslStreamReadBuffer;
int pendingSslStreamReadLength;
Task<int> pendingSslStreamReadFuture;
/// <summary>
/// Creates a handler whose SslStream is constructed with leaveInnerStreamOpen = true,
/// so disposing the SslStream does not dispose the mediation stream.
/// </summary>
public TlsHandler(TlsSettings settings)
    : this(innerStream => new SslStream(innerStream, leaveInnerStreamOpen: true), settings)
{
}
/// <summary>
/// Creates a handler that runs TLS through the SslStream returned by <paramref name="sslStreamFactory"/>.
/// </summary>
/// <param name="sslStreamFactory">Builds the SslStream on top of the handler's mediation stream.</param>
/// <param name="settings">Client or server settings; the concrete type selects the handshake role.</param>
/// <exception cref="ArgumentNullException">Either argument is null.</exception>
public TlsHandler(Func<Stream, SslStream> sslStreamFactory, TlsSettings settings)
{
    // Contract.Requires is [Conditional("CONTRACTS_FULL")] and is removed from normal builds, so the
    // original preconditions never ran: a null factory surfaced as a NullReferenceException below, and
    // null settings only failed later (e.g. when EnsureAuthenticated casts them). Validate eagerly.
    if (sslStreamFactory == null)
    {
        throw new ArgumentNullException(nameof(sslStreamFactory));
    }
    if (settings == null)
    {
        throw new ArgumentNullException(nameof(settings));
    }

    this.settings = settings;
    this.closeFuture = new TaskCompletionSource();
    this.mediationStream = new MediationStream(this);
    this.sslStream = sslStreamFactory(this.mediationStream);
}
/// <summary>Creates a client-side handler that authenticates against <paramref name="targetHost"/>.</summary>
public static TlsHandler Client(string targetHost)
{
    var clientSettings = new ClientTlsSettings(targetHost);
    return new TlsHandler(clientSettings);
}

/// <summary>
/// Creates a client-side handler whose settings carry <paramref name="clientCertificate"/>.
/// NOTE(review): EnsureAuthenticated passes null as the client certificate collection, so this
/// certificate does not appear to be presented during the handshake — confirm.
/// </summary>
public static TlsHandler Client(string targetHost, X509Certificate clientCertificate)
{
    var certificates = new List<X509Certificate> { clientCertificate };
    return new TlsHandler(new ClientTlsSettings(targetHost, certificates));
}

/// <summary>Creates a server-side handler presenting <paramref name="certificate"/> with default settings.</summary>
public static TlsHandler Server(X509Certificate certificate)
{
    return new TlsHandler(new ServerTlsSettings(certificate));
}
// Workaround for https://github.com/dotnet/corefx/issues/4510: SslStream may expose a plain
// X509Certificate, so it is re-created as an X509Certificate2 from its exported DER bytes.
public X509Certificate2 LocalCertificate => ToCertificate2(this.sslStream.LocalCertificate);
public X509Certificate2 RemoteCertificate => ToCertificate2(this.sslStream.RemoteCertificate);

/// <summary>
/// Returns <paramref name="certificate"/> as an <see cref="X509Certificate2"/>, or null when SslStream
/// reports no certificate.
/// </summary>
static X509Certificate2 ToCertificate2(X509Certificate certificate)
{
    if (certificate == null)
    {
        // Previously this reached new X509Certificate2(null) and threw from a property getter.
        return null;
    }
    return certificate as X509Certificate2 ?? new X509Certificate2(certificate.Export(X509ContentType.Cert));
}
// The handler's role is fixed by the type of settings it was constructed with.
bool IsServer
{
    get { return this.settings is ServerTlsSettings; }
}

/// <summary>Disposes the underlying SslStream, if one was created.</summary>
public void Dispose()
{
    if (this.sslStream != null)
    {
        this.sslStream.Dispose();
    }
}
public override void ChannelActive(IChannelHandlerContext context)
{
    base.ChannelActive(context);
    // Servers start the handshake from Unwrap once the first records arrive;
    // clients start it as soon as the channel is active.
    if (this.IsServer)
    {
        return;
    }
    this.EnsureAuthenticated();
}

public override void ChannelInactive(IChannelHandlerContext context)
{
    // Release SslStream and, if the connection closed mid-handshake, report the handshake failure
    // before the inactive event travels further down the pipeline.
    Exception cause = ChannelClosedException;
    this.HandleFailure(cause);
    base.ChannelInactive(context);
}
public override void ExceptionCaught(IChannelHandlerContext context, Exception exception)
{
    if (!this.IgnoreException(exception))
    {
        base.ExceptionCaught(context, exception);
        return;
    }

    // Expected noise while closing. Close explicitly anyway, in case the transport
    // did not close the connection by itself.
    if (context.Channel.Active)
    {
        context.CloseAsync();
    }
}

// ObjectDisposedException after CloseAsync has started comes from the SslStream we disposed ourselves.
bool IgnoreException(Exception cause) => cause is ObjectDisposedException && this.closeFuture.Task.IsCompleted;
/// <summary>
/// Continuation of SslStream.AuthenticateAs{Client,Server}Async. On success, updates the state,
/// fires <see cref="TlsHandshakeCompletionEvent.Success"/>, re-issues a read requested before the handshake
/// (when AutoRead is off), and flushes writes queued before the handshake. On failure or cancellation,
/// tears the handler down via HandleFailure.
/// </summary>
/// <param name="task">The handshake task; it is already complete when this runs.</param>
/// <param name="state">The owning <see cref="TlsHandler"/>.</param>
static void HandleHandshakeCompleted(Task task, object state)
{
    var self = (TlsHandler)state;
    switch (task.Status)
    {
        case TaskStatus.RanToCompletion:
        {
            TlsHandlerState oldState = self.state;
            Contract.Assert(!oldState.HasAny(TlsHandlerState.AuthenticationCompleted));
            self.state = (oldState | TlsHandlerState.Authenticated) & ~(TlsHandlerState.Authenticating | TlsHandlerState.FlushedBeforeHandshake);
            self.capturedContext.FireUserEventTriggered(TlsHandshakeCompletionEvent.Success);
            if (oldState.Has(TlsHandlerState.ReadRequestedBeforeAuthenticated) && !self.capturedContext.Channel.Configuration.AutoRead)
            {
                self.capturedContext.Read();
            }
            if (oldState.Has(TlsHandlerState.FlushedBeforeHandshake))
            {
                self.WrapAndFlush(self.capturedContext);
            }
            break;
        }
        case TaskStatus.Canceled:
        case TaskStatus.Faulted:
        {
            TlsHandlerState oldState = self.state;
            Contract.Assert(!oldState.HasAny(TlsHandlerState.Authenticated));
            // A canceled task has Exception == null. Passing null on would create a "failed" handshake
            // event whose IsSuccessful is true, and would fail queued writes with a null cause,
            // so synthesize a cancellation exception instead.
            Exception cause = (Exception)task.Exception ?? new TaskCanceledException(task);
            self.HandleFailure(cause);
            break;
        }
        default:
            throw new ArgumentOutOfRangeException(nameof(task), "Unexpected task status: " + task.Status);
    }
}
public override void HandlerAdded(IChannelHandlerContext context)
{
    base.HandlerAdded(context);
    this.capturedContext = context;
    this.pendingUnencryptedWrites = new BatchingPendingWriteQueue(context, UnencryptedWriteBatchSize);

    bool activeClient = context.Channel.Active && !this.IsServer;
    if (activeClient)
    {
        // todo: support delayed initialization on an existing/active channel if in client mode
        this.EnsureAuthenticated();
    }
}

protected override void HandlerRemovedInternal(IChannelHandlerContext context)
{
    if (this.pendingUnencryptedWrites.IsEmpty)
    {
        return;
    }
    // Only pay for building a ChannelException when there are writes to fail.
    this.pendingUnencryptedWrites.RemoveAndFailAll(new ChannelException("Write has failed due to TlsHandler being removed from channel pipeline."));
}
/// <summary>
/// Splits the cumulated ciphertext in <paramref name="input"/> into complete TLS records (at most
/// TlsUtils.MAX_ENCRYPTED_PACKET_LENGTH bytes per call) and passes them to Unwrap. If a record header
/// is not a valid SSL/TLS header, fires a NotSslRecordException and fails the handler.
/// </summary>
/// <param name="context">Pipeline context.</param>
/// <param name="input">Cumulation buffer; its reader index is advanced past the records consumed.</param>
/// <param name="output">Receives the decrypted buffers produced by Unwrap.</param>
protected override void Decode(IChannelHandlerContext context, IByteBuffer input, List<object> output)
{
int startOffset = input.ReaderIndex;
int endOffset = input.WriterIndex;
int offset = startOffset;
int totalLength = 0;
// Lengths of the complete records found in this call, in order (Unwrap feeds them one by one).
List<int> packetLengths;
// if we calculated the length of the current SSL record before, use that information.
if (this.packetLength > 0)
{
if (endOffset - startOffset < this.packetLength)
{
// input does not contain a single complete SSL record
return;
}
else
{
packetLengths = new List<int>(4);
packetLengths.Add(this.packetLength);
offset += this.packetLength;
totalLength = this.packetLength;
this.packetLength = 0;
}
}
else
{
packetLengths = new List<int>(4);
}
bool nonSslRecord = false;
while (totalLength < TlsUtils.MAX_ENCRYPTED_PACKET_LENGTH)
{
int readableBytes = endOffset - offset;
// A record header is needed to learn the record's length.
if (readableBytes < TlsUtils.SSL_RECORD_HEADER_LENGTH)
{
break;
}
int encryptedPacketLength = TlsUtils.GetEncryptedPacketLength(input, offset);
if (encryptedPacketLength == -1)
{
nonSslRecord = true;
break;
}
Contract.Assert(encryptedPacketLength > 0);
if (encryptedPacketLength > readableBytes)
{
// wait until the whole packet can be read
// (remember its length so the next Decode call can skip re-parsing the header)
this.packetLength = encryptedPacketLength;
break;
}
int newTotalLength = totalLength + encryptedPacketLength;
if (newTotalLength > TlsUtils.MAX_ENCRYPTED_PACKET_LENGTH)
{
// Don't read too much.
break;
}
// 1. call unwrap with packet boundaries - call SslStream.ReadAsync only once.
// 2. once we're through all the whole packets, switch to reading out using fallback sized buffer
// We have a whole packet.
// Increment the offset to handle the next packet.
packetLengths.Add(encryptedPacketLength);
offset += encryptedPacketLength;
totalLength = newTotalLength;
}
if (totalLength > 0)
{
// The buffer contains one or more full SSL records.
// Slice out the whole packet so unwrap will only be called with complete packets.
// Also directly reset the packetLength. This is needed as unwrap(..) may trigger
// decode(...) again via:
// 1) unwrap(..) is called
// 2) wrap(...) is called from within unwrap(...)
// 3) wrap(...) calls unwrapLater(...)
// 4) unwrapLater(...) calls decode(...)
//
// See https://github.com/netty/netty/issues/1534
input.SkipBytes(totalLength);
// Unwrap still reads the skipped range through startOffset/totalLength.
this.Unwrap(context, input, startOffset, totalLength, packetLengths, output);
if (!this.firedChannelRead)
{
// Check first if firedChannelRead is not set yet as it may have been set in a
// previous decode(...) call.
this.firedChannelRead = output.Count > 0;
}
}
if (nonSslRecord)
{
// Not an SSL/TLS packet
var ex = new NotSslRecordException(
"not an SSL/TLS record: " + ByteBufferUtil.HexDump(input));
// Drop the remaining bytes: nothing after an invalid header can be decoded.
input.SkipBytes(input.ReadableBytes);
context.FireExceptionCaught(ex);
this.HandleFailure(ex);
}
}
public override void ChannelReadComplete(IChannelHandlerContext ctx)
{
    // Compact the cumulation buffer, then decide whether another read must be requested.
    // ReadIfNeeded looks at firedChannelRead for the batch that just ended, so reset it afterwards.
    this.DiscardSomeReadBytes();
    this.ReadIfNeeded(ctx);
    this.firedChannelRead = false;
    ctx.FireChannelReadComplete();
}

// With AutoRead off, request another read unless this batch both delivered data downstream and the
// handshake has completed; otherwise the channel could stall (e.g. mid-handshake).
void ReadIfNeeded(IChannelHandlerContext ctx)
{
    if (ctx.Channel.Configuration.AutoRead)
    {
        return;
    }

    bool deliveredAfterHandshake = this.firedChannelRead && this.state.HasAny(TlsHandlerState.AuthenticationCompleted);
    if (!deliveredAfterHandshake)
    {
        ctx.Read();
    }
}
/// <summary>
/// Unwraps inbound SSL records: feeds the records in <paramref name="packet"/> to SslStream one by one
/// (driving the handshake first if needed) and collects the decrypted output.
/// </summary>
/// <param name="ctx">Pipeline context; also supplies the allocator.</param>
/// <param name="packet">Cumulation buffer holding the records (the caller already skipped past them).</param>
/// <param name="offset">Start index of the first record in <paramref name="packet"/>.</param>
/// <param name="length">Total byte length of all records.</param>
/// <param name="packetLengths">Length of each complete record, in order; must not be empty.</param>
/// <param name="output">Receives decrypted buffers.</param>
/// <remarks>
/// If an SslStream read is still outstanding when the records run out, it is parked in the
/// pendingSslStream* fields and finished either by the next Unwrap call or by UnwrapCompleted.
/// </remarks>
void Unwrap(IChannelHandlerContext ctx, IByteBuffer packet, int offset, int length, List<int> packetLengths, List<object> output)
{
Contract.Requires(packetLengths.Count > 0);
//bool notifyClosure = false; // todo: netty/issues/137
// True when a read is left outstanding; its buffer then belongs to pendingSslStreamReadBuffer,
// so the finally block must neither emit nor release it.
bool pending = false;
IByteBuffer outputBuffer = null;
try
{
ArraySegment<byte> inputIoBuffer = packet.GetIoBuffer(offset, length);
// The packet's array is only borrowed until the finally block calls ResetSource.
this.mediationStream.SetSource(inputIoBuffer.Array, inputIoBuffer.Offset, ctx.Allocator);
int packetIndex = 0;
// Until the handshake has succeeded, feed records one at a time to the handshake
// (EnsureAuthenticated starts it on the first call).
while (!this.EnsureAuthenticated())
{
this.mediationStream.ExpandSource(packetLengths[packetIndex]);
if (++packetIndex == packetLengths.Count)
{
return;
}
}
Task<int> currentReadFuture = this.pendingSslStreamReadFuture;
int outputBufferLength;
if (currentReadFuture != null)
{
// restoring context from previous read
Contract.Assert(this.pendingSslStreamReadBuffer != null);
outputBuffer = this.pendingSslStreamReadBuffer;
outputBufferLength = this.pendingSslStreamReadLength;
this.pendingSslStreamReadFuture = null;
this.pendingSslStreamReadBuffer = null;
this.pendingSslStreamReadLength = 0;
}
else
{
outputBufferLength = 0;
}
// go through packets one by one (because SslStream does not consume more than 1 packet at a time)
for (; packetIndex < packetLengths.Count; packetIndex++)
{
int currentPacketLength = packetLengths[packetIndex];
this.mediationStream.ExpandSource(currentPacketLength);
while (true)
{
int totalRead = 0;
if (currentReadFuture != null)
{
// there was a read pending already, so we make sure we completed that first
if (!currentReadFuture.IsCompleted)
{
// we did feed the whole current packet to SslStream yet it did not produce any result -> move to the next packet in input
break;
}
int read = currentReadFuture.Result;
totalRead += read;
if (read == 0)
{
//Stream closed
// (outputBuffer is still unwritten; the finally block releases it)
return;
}
// Now output the result of previous read and decide whether to do an extra read on the same source or move forward
AddBufferToOutput(outputBuffer, read, output);
currentReadFuture = null;
outputBuffer = null;
if (this.mediationStream.TotalReadableBytes == 0)
{
// we just made a frame available for reading but there was already pending read so SslStream read it out to make further progress there
if (read < outputBufferLength)
{
// SslStream returned non-full buffer and there's no more input to go through ->
// typically it means SslStream is done reading current frame so we skip
break;
}
// we've read out `read` bytes out of current packet to fulfil previously outstanding read
outputBufferLength = currentPacketLength - totalRead;
if (outputBufferLength <= 0)
{
// after feeding to SslStream current frame it read out more bytes than current packet size
outputBufferLength = FallbackReadBufferSize;
}
}
else
{
// SslStream did not get to reading current frame so it completed previous read sync
// and the next read will likely read out the new frame
outputBufferLength = currentPacketLength;
}
}
else
{
// there was no pending read before so we estimate buffer of `currentPacketLength` bytes to be sufficient
outputBufferLength = currentPacketLength;
}
outputBuffer = ctx.Allocator.Buffer(outputBufferLength);
currentReadFuture = this.ReadFromSslStreamAsync(outputBuffer, outputBufferLength);
}
}
if (currentReadFuture != null)
{
// Park the outstanding read; the next Unwrap or UnwrapCompleted finishes it.
pending = true;
this.pendingSslStreamReadBuffer = outputBuffer;
this.pendingSslStreamReadFuture = currentReadFuture;
this.pendingSslStreamReadLength = outputBufferLength;
//Can't use ExecuteSynchronously here for it may change the order of output if task is already completed here.
currentReadFuture.ContinueWith(UnwrapCompletedCallback, this, TaskContinuationOptions.None);
}
}
catch (Exception ex)
{
this.HandleFailure(ex);
throw;
}
finally
{
// Copy any bytes SslStream has not consumed out of the borrowed packet array.
this.mediationStream.ResetSource(ctx.Allocator);
if (!pending && outputBuffer != null)
{
if (outputBuffer.IsReadable())
{
output.Add(outputBuffer);
}
else
{
outputBuffer.SafeRelease();
}
}
}
}
/// <summary>
/// Continuation of an SslStream read that Unwrap (or this method) left outstanding. Mono with the
/// legacy provider completes ReadAsync asynchronously, so decrypted data can arrive after Unwrap has
/// returned. This method forwards that data and keeps reading while input remains.
/// </summary>
/// <param name="task">The completed read; its result is the number of decrypted bytes.</param>
/// <param name="state">The owning <see cref="TlsHandler"/>.</param>
static void UnwrapCompleted(Task<int> task, object state)
{
    var self = (TlsHandler)state;
    Debug.Assert(self.capturedContext.Executor.InEventLoop);

    // Ignore reads that are no longer pending: Unwrap already consumed them, or HandleFailure cleared them.
    if (task != self.pendingSslStreamReadFuture)
    {
        return;
    }

    IByteBuffer buf = self.pendingSslStreamReadBuffer;
    int outputBufferLength = self.pendingSslStreamReadLength;
    self.pendingSslStreamReadFuture = null;
    self.pendingSslStreamReadBuffer = null;
    self.pendingSslStreamReadLength = 0;

    while (true)
    {
        switch (task.Status)
        {
            case TaskStatus.RanToCompletion:
            {
                int read = task.Result;
                if (read == 0)
                {
                    // Stream closed. The pending fields were cleared above, so this method owns buf;
                    // release it instead of leaking it.
                    buf.SafeRelease();
                    return;
                }
                self.capturedContext.FireChannelRead(buf.SetWriterIndex(buf.WriterIndex + read));
                if (self.mediationStream.TotalReadableBytes == 0)
                {
                    self.capturedContext.FireChannelReadComplete();
                    self.mediationStream.ResetSource(self.capturedContext.Allocator);
                    if (read < outputBufferLength)
                    {
                        // SslStream returned non-full buffer and there's no more input to go through ->
                        // typically it means SslStream is done reading current frame so we skip
                        return;
                    }
                }
                outputBufferLength = self.mediationStream.TotalReadableBytes;
                if (outputBufferLength <= 0)
                {
                    outputBufferLength = FallbackReadBufferSize;
                }
                buf = self.capturedContext.Allocator.Buffer(outputBufferLength);
                task = self.ReadFromSslStreamAsync(buf, outputBufferLength);
                if (task.IsCompleted)
                {
                    // Completed inline: handle it in this loop instead of scheduling a continuation.
                    continue;
                }
                self.pendingSslStreamReadFuture = task;
                self.pendingSslStreamReadBuffer = buf;
                self.pendingSslStreamReadLength = outputBufferLength;
                task.ContinueWith(UnwrapCompletedCallback, self, TaskContinuationOptions.ExecuteSynchronously);
                return;
            }
            case TaskStatus.Canceled:
            case TaskStatus.Faulted:
            {
                buf.SafeRelease();
                // A canceled task has no Exception; never pass a null cause to HandleFailure.
                self.HandleFailure((Exception)task.Exception ?? new TaskCanceledException(task));
                return;
            }
            default:
            {
                buf.SafeRelease();
                throw new ArgumentOutOfRangeException(nameof(task), "Unexpected task status: " + task.Status);
            }
        }
    }
}
// SslStream wrote `length` decrypted bytes straight into the buffer's backing array; advance the
// writer index to expose them, then hand the buffer downstream.
static void AddBufferToOutput(IByteBuffer outputBuffer, int length, List<object> output)
{
    Contract.Assert(length > 0);
    int newWriterIndex = outputBuffer.WriterIndex + length;
    IByteBuffer filled = outputBuffer.SetWriterIndex(newWriterIndex);
    output.Add(filled);
}
// Starts an SslStream read into the writable region of outputBuffer that begins at its WriterIndex.
// The writer index is not advanced here; callers do that once the read result is known.
Task<int> ReadFromSslStreamAsync(IByteBuffer outputBuffer, int outputBufferLength)
{
    int start = outputBuffer.WriterIndex;
    ArraySegment<byte> target = outputBuffer.GetIoBuffer(start, outputBufferLength);
    byte[] backingArray = target.Array;
    return this.sslStream.ReadAsync(backingArray, target.Offset, target.Count);
}
public override void Read(IChannelHandlerContext context)
{
    TlsHandlerState current = this.state;
    bool handshakeFinished = current.HasAny(TlsHandlerState.AuthenticationCompleted);
    if (!handshakeFinished)
    {
        // Remember the request so HandleHandshakeCompleted can re-issue it when AutoRead is off.
        this.state = current | TlsHandlerState.ReadRequestedBeforeAuthenticated;
    }
    context.Read();
}
/// <summary>
/// Starts the TLS handshake on the first call. Returns true only once the handshake has succeeded;
/// returns false while it is starting or running, or after it has failed.
/// </summary>
bool EnsureAuthenticated()
{
    TlsHandlerState oldState = this.state;
    if (oldState.HasAny(TlsHandlerState.AuthenticationStarted))
    {
        return oldState.Has(TlsHandlerState.Authenticated);
    }

    this.state = oldState | TlsHandlerState.Authenticating;
    Task handshake;
    if (this.IsServer)
    {
        var serverSettings = (ServerTlsSettings)this.settings;
        handshake = this.sslStream.AuthenticateAsServerAsync(
            serverSettings.Certificate,
            serverSettings.NegotiateClientCertificate,
            serverSettings.EnabledProtocols,
            serverSettings.CheckCertificateRevocation);
    }
    else
    {
        var clientSettings = (ClientTlsSettings)this.settings;
        // NOTE(review): null is passed as the client certificate collection, so certificates given to
        // ClientTlsSettings (e.g. via TlsHandler.Client(targetHost, clientCertificate)) are not presented — confirm.
        handshake = this.sslStream.AuthenticateAsClientAsync(
            clientSettings.TargetHost,
            null,
            clientSettings.EnabledProtocols,
            clientSettings.CheckCertificateRevocation);
    }
    handshake.ContinueWith(HandshakeCompletionCallback, this, TaskContinuationOptions.ExecuteSynchronously);
    return false;
}
/// <summary>
/// Queues a plaintext <see cref="IByteBuffer"/> for encryption on the next flush; any other message
/// type yields a faulted task.
/// </summary>
public override Task WriteAsync(IChannelHandlerContext context, object message)
{
    if (message is IByteBuffer)
    {
        return this.pendingUnencryptedWrites.Add(message);
    }
    return TaskEx.FromException(new UnsupportedMessageTypeException(message, typeof(IByteBuffer)));
}
public override void Flush(IChannelHandlerContext context)
{
    // Queue an empty buffer when nothing is pending so Wrap always has an entry to process.
    if (this.pendingUnencryptedWrites.IsEmpty)
    {
        this.pendingUnencryptedWrites.Add(Unpooled.Empty);
    }

    bool authenticated = this.EnsureAuthenticated();
    if (authenticated)
    {
        this.WrapAndFlush(context);
        return;
    }

    // Handshake not done yet: HandleHandshakeCompleted calls WrapAndFlush once it succeeds.
    this.state |= TlsHandlerState.FlushedBeforeHandshake;
}
// Encrypts all queued writes and issues a single flush at the end instead of one per record.
void WrapAndFlush(IChannelHandlerContext context)
{
    this.flushMode = FlushMode.NoFlush;
    try
    {
        this.Wrap(context);
    }
    finally
    {
        // Flush even if Wrap threw: some records may already have been written.
        if (this.flushMode != FlushMode.NoFlush)
        {
            // FinishWrap saw an off-event-loop write (PendingFlush): defer the flush to a task
            // on the executor rather than flushing from here.
            context.Executor.Execute(
                s =>
                {
                    var handler = (TlsHandler)s;
                    handler.flushMode = FlushMode.ForceFlush;
                    handler.capturedContext.Flush();
                },
                this);
        }
        else
        {
            this.flushMode = FlushMode.ForceFlush;
            context.Flush();
        }
    }
}
/// <summary>
/// Pushes every batch of queued plaintext through SslStream. SslStream's output reaches the channel via
/// MediationStream.Write and FinishWrap. Each batch's write promise is linked to the last channel write
/// FinishWrap issued, or completed immediately if SslStream produced no output.
/// </summary>
/// <remarks>On failure, HandleFailure is called and the exception is rethrown.</remarks>
void Wrap(IChannelHandlerContext context)
{
    Contract.Assert(context == this.capturedContext);
    // The buffer currently owned by this method; null whenever nothing needs releasing on failure.
    IByteBuffer buf = null;
    try
    {
        while (true)
        {
            List<object> messages = this.pendingUnencryptedWrites.Current;
            if (messages == null || messages.Count == 0)
            {
                break;
            }
            if (messages.Count == 1)
            {
                buf = (IByteBuffer)messages[0];
            }
            else
            {
                // Merge the batch so SslStream sees one contiguous write.
                buf = context.Allocator.Buffer((int)this.pendingUnencryptedWrites.CurrentSize);
                foreach (IByteBuffer buffer in messages)
                {
                    buffer.ReadBytes(buf, buffer.ReadableBytes);
                    buffer.Release();
                }
            }
            buf.ReadBytes(this.sslStream, buf.ReadableBytes); // this leads to FinishWrap being called 0+ times
            buf.Release();
            // Forget the released buffer: previously a failure later in this iteration (or a stale
            // reference from an earlier one) made the catch block release it a second time.
            buf = null;
            TaskCompletionSource promise = this.pendingUnencryptedWrites.Remove();
            Task task = this.lastContextWriteTask;
            if (task != null)
            {
                task.LinkOutcome(promise);
                this.lastContextWriteTask = null;
            }
            else
            {
                promise.TryComplete();
            }
        }
    }
    catch (Exception ex)
    {
        buf.SafeRelease(); // null-safe extension
        this.HandleFailure(ex);
        throw;
    }
}
// Receives ciphertext from MediationStream.Write, copies it into a channel buffer, and writes it
// (flushing immediately only in ForceFlush mode). The resulting task is remembered for Wrap.
void FinishWrap(byte[] buffer, int offset, int count)
{
    // In Mono (with the btls provider) on Linux, and maybe also with the Apple provider, Write runs on
    // another thread and therefore after WrapAndFlush's flush. Switching to PendingFlush makes
    // WrapAndFlush defer its flush to the executor.
    bool offLoopDuringWrap = this.flushMode == FlushMode.NoFlush && !this.capturedContext.Executor.InEventLoop;
    if (offLoopDuringWrap)
    {
        this.flushMode = FlushMode.PendingFlush;
    }

    IByteBuffer encrypted;
    if (count != 0)
    {
        encrypted = this.capturedContext.Allocator.Buffer(count);
        encrypted.WriteBytes(buffer, offset, count);
    }
    else
    {
        encrypted = Unpooled.Empty;
    }

    if (this.flushMode == FlushMode.ForceFlush)
    {
        this.lastContextWriteTask = this.capturedContext.WriteAndFlushAsync(encrypted);
    }
    else
    {
        this.lastContextWriteTask = this.capturedContext.WriteAsync(encrypted);
    }
}
// Receives ciphertext from MediationStream.WriteAsync: wraps SslStream's array without copying, writes
// and flushes it immediately, then requests a read if needed (e.g. AutoRead off during the handshake).
Task FinishWrapNonAppDataAsync(byte[] buffer, int offset, int count)
{
    IByteBuffer wrapped = Unpooled.WrappedBuffer(buffer, offset, count);
    Task writeFuture = this.capturedContext.WriteAndFlushAsync(wrapped);
    this.ReadIfNeeded(this.capturedContext);
    return writeFuture;
}

public override Task CloseAsync(IChannelHandlerContext context)
{
    // Mark the handler as closing before disposing SslStream, so IgnoreException treats the
    // resulting ObjectDisposedExceptions as expected.
    this.closeFuture.TryComplete();
    this.sslStream.Dispose();
    return base.CloseAsync(context);
}
/// <summary>
/// Tears down the TLS machinery after an error or channel close: disposes the mediation stream and
/// SslStream, drops any outstanding SslStream read, reports a handshake failure if the handshake was
/// still running, and fails all queued plaintext writes with <paramref name="cause"/>.
/// </summary>
void HandleFailure(Exception cause)
{
    // Disposing the mediation stream completes a read SslStream may still be waiting on.
    this.mediationStream.Dispose();
    try
    {
        this.sslStream.Dispose();
    }
    catch (Exception)
    {
        // Deliberately best effort: we are already failing, and a dispose error must not mask `cause`.
        // Netty only logs this at debug level (see https://github.com/netty/netty/issues/1340);
        // this port has no logger, so the exception is dropped.
    }

    // Drop any read carried over between Unwrap calls. Reset all three fields together, as Unwrap
    // and UnwrapCompleted do when they take over the pending read.
    this.pendingSslStreamReadBuffer?.SafeRelease();
    this.pendingSslStreamReadBuffer = null;
    this.pendingSslStreamReadFuture = null;
    this.pendingSslStreamReadLength = 0;

    this.NotifyHandshakeFailure(cause);
    this.pendingUnencryptedWrites.RemoveAndFailAll(cause);
}
// If the handshake has not finished yet, marks it failed, fires a failed TlsHandshakeCompletionEvent
// and closes the channel; after a completed handshake this does nothing.
void NotifyHandshakeFailure(Exception cause)
{
    if (this.state.HasAny(TlsHandlerState.AuthenticationCompleted))
    {
        return;
    }

    this.state = (this.state | TlsHandlerState.FailedAuthentication) & ~TlsHandlerState.Authenticating;
    this.capturedContext.FireUserEventTriggered(new TlsHandshakeCompletionEvent(cause));
    this.CloseAsync(this.capturedContext);
}
// Coordinates who flushes the channel while WrapAndFlush pushes plaintext through SslStream.
enum FlushMode : byte
{
/// <summary>
/// Set by WrapAndFlush while Wrap runs: FinishWrap only writes, and WrapAndFlush flushes once at the end.
/// </summary>
NoFlush = 0,
/// <summary>
/// Set by FinishWrap when SslStream writes from outside the event loop during a wrap:
/// WrapAndFlush then posts the flush to the executor instead of flushing directly.
/// </summary>
PendingFlush = 1,
/// <summary>
/// Initial/default mode outside WrapAndFlush: FinishWrap writes and flushes each record immediately.
/// </summary>
ForceFlush = 2,
}
sealed class MediationStream : Stream
{
// Handler that receives everything SslStream writes.
readonly TlsHandler owner;
// Guards the input window and ownBuffer (Mono may read concurrently with ResetSource).
// readonly: a lock object must never be reassigned, or two threads could lock different objects.
readonly object sourceLock = new object();
// Leftover ciphertext copied out of a previous input window that SslStream has not consumed yet.
IByteBuffer ownBuffer;
// Current input window: bytes [inputStartOffset, inputStartOffset + inputLength) of `input` are
// available to SslStream, and the first inputOffset of them have already been handed out.
byte[] input;
int inputStartOffset;
int inputOffset;
int inputLength;
// A read SslStream issued while no input was available, and the buffer it wants filled.
TaskCompletionSource<int> readCompletionSource;
ArraySegment<byte> sslOwnedBuffer;
#if NETSTANDARD1_3
int readByteCount;
#else
SynchronousAsyncResult<int> syncReadResult;
AsyncCallback readCallback;
TaskCompletionSource writeCompletion;
AsyncCallback writeCallback;
#endif
public MediationStream(TlsHandler owner)
{
    // Write/WriteAsync forward SslStream's output to this handler.
    this.owner = owner;
}

/// <summary>Bytes SslStream can still read: leftover bytes in ownBuffer plus the unread part of the input window.</summary>
public int TotalReadableBytes
{
    get
    {
        IByteBuffer leftover = this.ownBuffer;
        int buffered = leftover == null ? 0 : leftover.ReadableBytes;
        return buffered + this.SourceReadableBytes;
    }
}

/// <summary>Unread bytes in the current input window.</summary>
public int SourceReadableBytes
{
    get { return this.inputLength - this.inputOffset; }
}
/// <summary>
/// Starts a new, initially empty input window over <paramref name="source"/> at <paramref name="offset"/>;
/// ExpandSource later makes bytes visible. Unread bytes of the previous window are first saved via ResetSource.
/// </summary>
public void SetSource(byte[] source, int offset, IByteBufferAllocator alloc)
{
    lock (sourceLock)
    {
        this.ResetSource(alloc);
        this.inputLength = 0;
        this.inputOffset = 0;
        this.inputStartOffset = offset;
        this.input = source;
    }
}
/// <summary>
/// Detaches the current input window. Bytes of the window that SslStream has not read yet are
/// appended to ownBuffer (allocated from <paramref name="alloc"/> on demand), because the window's
/// array is only borrowed (Unwrap calls this from its finally block). An empty ownBuffer is released.
/// </summary>
public void ResetSource(IByteBufferAllocator alloc)
{
//Mono will run BeginRead in async and it's running with ResetSource at the same time
// (hence the lock shared with ReadFromInput/ExpandSource)
lock (sourceLock)
{
int leftLen = this.SourceReadableBytes;
IByteBuffer buf = this.ownBuffer;
if (leftLen > 0)
{
// Preserve unread window bytes by copying them into ownBuffer.
if (buf != null)
{
buf.DiscardSomeReadBytes();
}
else
{
buf = alloc.Buffer(leftLen);
this.ownBuffer = buf;
}
buf.WriteBytes(this.input, this.inputStartOffset + this.inputOffset, leftLen);
}
else if (buf != null)
{
if (!buf.IsReadable())
{
// Fully consumed: give the buffer back.
buf.SafeRelease();
this.ownBuffer = null;
}
else
{
buf.DiscardSomeReadBytes();
}
}
this.input = null;
this.inputStartOffset = 0;
this.inputOffset = 0;
this.inputLength = 0;
}
}
/// <summary>
/// Makes the next <paramref name="count"/> bytes of the input window visible to SslStream. If SslStream
/// has a read waiting for data (sslOwnedBuffer set by ReadAsync/BeginRead), that read is filled and
/// completed right away.
/// </summary>
public void ExpandSource(int count)
{
// NOTE(review): this assertion reads `input` outside sourceLock.
Contract.Assert(this.input != null);
lock (sourceLock)
{
this.inputLength += count;
ArraySegment<byte> sslBuffer = this.sslOwnedBuffer;
if (sslBuffer.Array == null)
{
// there is no pending read operation - keep for future
return;
}
this.sslOwnedBuffer = default(ArraySegment<byte>);
#if NETSTANDARD1_3
this.readByteCount = this.ReadFromInput(sslBuffer.Array, sslBuffer.Offset, sslBuffer.Count);
// hack: this tricks SslStream's continuation to run synchronously instead of dispatching to TP. Remove once Begin/EndRead are available.
new Task(
ms =>
{
var self = (MediationStream)ms;
TaskCompletionSource<int> p = self.readCompletionSource;
self.readCompletionSource = null;
p.TrySetResult(self.readByteCount);
},
this)
.RunSynchronously(TaskScheduler.Default);
#else
int read = this.ReadFromInput(sslBuffer.Array, sslBuffer.Offset, sslBuffer.Count);
TaskCompletionSource<int> promise = this.readCompletionSource;
this.readCompletionSource = null;
// NOTE(review): completing the promise and invoking the APM callback both happen while sourceLock
// is held, so SslStream's continuation may run inside the lock (the lock is reentrant on this thread).
promise.TrySetResult(read);
AsyncCallback callback = this.readCallback;
this.readCallback = null;
callback?.Invoke(promise.Task);
#endif
}
}
// Synchronous read: blocks on the asynchronous read path.
public override int Read(byte[] buffer, int offset, int count)
{
    Task<int> pendingRead = this.ReadAsync(buffer, offset, count);
    return pendingRead.Result;
}
// Read entry points used by SslStream. If input is already buffered, the read completes synchronously;
// otherwise the destination buffer is parked in sslOwnedBuffer until ExpandSource supplies data.
#if NETSTANDARD1_3
public override Task<int> ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
{
// NOTE(review): TotalReadableBytes is checked outside sourceLock — confirm no concurrent ResetSource here.
if (this.TotalReadableBytes > 0)
{
// we have the bytes available upfront - write out synchronously
int read = this.ReadFromInput(buffer, offset, count);
return Task.FromResult(read);
}
Contract.Assert(this.sslOwnedBuffer.Array == null);
// take note of buffer - we will pass bytes there once available
this.sslOwnedBuffer = new ArraySegment<byte>(buffer, offset, count);
this.readCompletionSource = new TaskCompletionSource<int>();
return this.readCompletionSource.Task;
}
#else
public override IAsyncResult BeginRead(byte[] buffer, int offset, int count, AsyncCallback callback, object state)
{
// NOTE(review): TotalReadableBytes is checked outside sourceLock — confirm no concurrent ResetSource here.
if (this.TotalReadableBytes > 0)
{
// we have the bytes available upfront - write out synchronously
int read = this.ReadFromInput(buffer, offset, count);
var res = this.PrepareSyncReadResult(read, state);
callback?.Invoke(res);
return res;
}
Contract.Assert(this.sslOwnedBuffer.Array == null);
// take note of buffer - we will pass bytes there once available
this.sslOwnedBuffer = new ArraySegment<byte>(buffer, offset, count);
this.readCompletionSource = new TaskCompletionSource<int>(state);
this.readCallback = callback;
return this.readCompletionSource.Task;
}
// Completes a read started by BeginRead: returns the cached synchronous result or the parked task's result,
// rethrowing the task's inner exception with its original stack trace.
public override int EndRead(IAsyncResult asyncResult)
{
SynchronousAsyncResult<int> syncResult = this.syncReadResult;
if (ReferenceEquals(asyncResult, syncResult))
{
return syncResult.Result;
}
Debug.Assert(this.readCompletionSource == null || this.readCompletionSource.Task == asyncResult);
Contract.Assert(!((Task<int>)asyncResult).IsCanceled);
try
{
return ((Task<int>)asyncResult).Result;
}
catch (AggregateException ex)
{
ExceptionDispatchInfo.Capture(ex.InnerException).Throw();
throw; // unreachable
}
}
// Fills the single reusable synchronous result object with this read's outcome.
IAsyncResult PrepareSyncReadResult(int readBytes, object state)
{
// it is safe to reuse sync result object as it can't lead to leak (no way to attach to it via handle)
SynchronousAsyncResult<int> result = this.syncReadResult ?? (this.syncReadResult = new SynchronousAsyncResult<int>());
result.Result = readBytes;
result.AsyncState = state;
return result;
}
#endif
// Synchronous writes (issued while Wrap pushes plaintext through SslStream) are forwarded to FinishWrap.
public override void Write(byte[] buffer, int offset, int count)
{
    this.owner.FinishWrap(buffer, offset, count);
}

// Asynchronous writes are forwarded to FinishWrapNonAppDataAsync, which flushes immediately.
// The cancellation token is not observed.
public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
{
    return this.owner.FinishWrapNonAppDataAsync(buffer, offset, count);
}
#if !NETSTANDARD1_3
// APM write path used by SslStream on non-NETSTANDARD1_3 targets; delegates to WriteAsync.
static readonly Action<Task, object> WriteCompleteCallback = HandleChannelWriteComplete;

public override IAsyncResult BeginWrite(byte[] buffer, int offset, int count, AsyncCallback callback, object state)
{
    Task task = this.WriteAsync(buffer, offset, count);
    switch (task.Status)
    {
        case TaskStatus.RanToCompletion:
            // write+flush completed synchronously (and successfully)
            var result = new SynchronousAsyncResult<int>();
            result.AsyncState = state;
            callback?.Invoke(result);
            return result;
        default:
            this.writeCallback = callback;
            Contract.Assert(this.writeCompletion == null);
            var tcs = new TaskCompletionSource(state);
            this.writeCompletion = tcs;
            task.ContinueWith(WriteCompleteCallback, this, TaskContinuationOptions.ExecuteSynchronously);
            return tcs.Task;
    }
}

// Transfers the channel write's outcome to the APM promise returned by BeginWrite, then invokes the callback.
static void HandleChannelWriteComplete(Task writeTask, object state)
{
    var self = (MediationStream)state;
    AsyncCallback callback = self.writeCallback;
    self.writeCallback = null;
    var promise = self.writeCompletion;
    self.writeCompletion = null;
    switch (writeTask.Status)
    {
        case TaskStatus.RanToCompletion:
            promise.TryComplete();
            break;
        case TaskStatus.Canceled:
            promise.TrySetCanceled();
            break;
        case TaskStatus.Faulted:
            // Pass the inner exceptions, not the AggregateException itself. Otherwise the promise would
            // hold AggregateException(AggregateException(cause)), and EndWrite, which unwraps one level,
            // would rethrow an AggregateException instead of the real cause.
            promise.TrySetException(writeTask.Exception.InnerExceptions);
            break;
        default:
            // The single-argument constructor takes a parameter name, not a message.
            throw new ArgumentOutOfRangeException(nameof(writeTask), "Unexpected task status: " + writeTask.Status);
    }
    callback?.Invoke(promise.Task);
}

public override void EndWrite(IAsyncResult asyncResult)
{
    if (asyncResult is SynchronousAsyncResult<int>)
    {
        return;
    }
    Debug.Assert(this.writeCompletion == null || this.writeCompletion.Task == asyncResult);
    try
    {
        ((Task<int>)asyncResult).Wait();
    }
    catch (AggregateException ex)
    {
        // Rethrow the real cause with its original stack trace.
        ExceptionDispatchInfo.Capture(ex.InnerException).Throw();
        throw; // never reached
    }
}
#endif
/// <summary>
/// Copies up to <paramref name="destinationCapacity"/> bytes into <paramref name="destination"/>, starting
/// at <paramref name="destinationOffset"/>. Leftover bytes in ownBuffer are copied first, then unread bytes
/// of the current input window.
/// </summary>
/// <returns>The number of bytes copied (0 when nothing is available).</returns>
int ReadFromInput(byte[] destination, int destinationOffset, int destinationCapacity)
{
    Contract.Assert(destination != null);
    lock (sourceLock)
    {
        int totalRead = 0;

        IByteBuffer buf = this.ownBuffer;
        if (buf != null && buf.ReadableBytes > 0)
        {
            int chunk = Math.Min(buf.ReadableBytes, destinationCapacity);
            buf.ReadBytes(destination, destinationOffset, chunk);
            totalRead += chunk;
            destinationCapacity -= chunk;
            if (destinationCapacity == 0)
            {
                return totalRead;
            }
        }

        byte[] source = this.input;
        if (source != null)
        {
            int available = this.SourceReadableBytes;
            if (available > 0)
            {
                int chunk = Math.Min(available, destinationCapacity);
                // Append after the bytes already taken from ownBuffer. The previous code copied to
                // destinationOffset and overwrote them, corrupting the ciphertext handed to SslStream.
                Buffer.BlockCopy(source, this.inputStartOffset + this.inputOffset, destination, destinationOffset + totalRead, chunk);
                totalRead += chunk;
                this.inputOffset += chunk;
            }
        }

        return totalRead;
    }
}
// Intentionally a no-op: Write/WriteAsync forward data to the handler immediately, so there is
// nothing buffered here to flush.
public override void Flush()
{
// NOOP: called on SslStream.Close
}
/// <summary>
/// Releases the leftover-ciphertext buffer, clears the input window, and completes any read SslStream
/// is still waiting on with 0 (end of stream). Safe to call more than once.
/// </summary>
protected override void Dispose(bool disposing)
{
    base.Dispose(disposing);
    if (disposing)
    {
        lock (sourceLock)
        {
            // ownBuffer comes from the channel allocator and was never released before, so every
            // closed connection with unconsumed ciphertext leaked it. Clearing the window also stops
            // a later ResetSource (e.g. in Unwrap's finally after a failure) from allocating a new one.
            this.ownBuffer?.SafeRelease();
            this.ownBuffer = null;
            this.input = null;
            this.inputStartOffset = 0;
            this.inputOffset = 0;
            this.inputLength = 0;
        }

        // Completed after the buffers are gone, so a continuation that runs inline and reads again
        // sees no data instead of a released buffer.
        TaskCompletionSource<int> p = this.readCompletionSource;
        this.readCompletionSource = null;
        p?.TrySetResult(0);
    }
}
#region plumbing
// The mediation stream is a read/write pipe between SslStream and the channel: it cannot seek and has
// no length or position.
public override bool CanRead
{
    get { return true; }
}
public override bool CanWrite
{
    get { return true; }
}
public override bool CanSeek
{
    get { return false; }
}
public override long Length
{
    get { throw new NotSupportedException(); }
}
public override long Position
{
    get { throw new NotSupportedException(); }
    set { throw new NotSupportedException(); }
}
public override void SetLength(long value)
{
    throw new NotSupportedException();
}
public override long Seek(long offset, SeekOrigin origin)
{
    throw new NotSupportedException();
}
#endregion
#region sync result
/// <summary>
/// IAsyncResult for an operation that already finished on the calling thread; it carries the result
/// and caller state, and refuses to hand out a wait handle.
/// </summary>
sealed class SynchronousAsyncResult<T> : IAsyncResult
{
    public object AsyncState { get; set; }
    public T Result { get; set; }
    public bool CompletedSynchronously
    {
        get { return true; }
    }
    public bool IsCompleted
    {
        get { return true; }
    }
    public WaitHandle AsyncWaitHandle
    {
        get { throw new InvalidOperationException("Cannot wait on a synchronous result."); }
    }
}
#endregion
}
}
/// <summary>
/// Bit flags describing the handshake state of the TLS handler.
/// </summary>
[Flags]
enum TlsHandlerState
{
    /// <summary>No flags set (the default value of the enum).</summary>
    None = 0,
    Authenticating = 1,
    Authenticated = 1 << 1,
    FailedAuthentication = 1 << 2,
    ReadRequestedBeforeAuthenticated = 1 << 3,
    FlushedBeforeHandshake = 1 << 4,
    /// <summary>Combination of every authentication-related flag.</summary>
    AuthenticationStarted = Authenticating | Authenticated | FailedAuthentication,
    /// <summary>Combination of the terminal authentication flags (success or failure).</summary>
    AuthenticationCompleted = Authenticated | FailedAuthentication
}
/// <summary>
/// Bit-test helpers for <see cref="TlsHandlerState"/>.
/// </summary>
static class TlsHandlerStateExtensions
{
    /// <summary>
    /// Returns <c>true</c> when every bit set in <paramref name="testValue"/> is also set in <paramref name="value"/>.
    /// </summary>
    public static bool Has(this TlsHandlerState value, TlsHandlerState testValue)
    {
        TlsHandlerState common = value & testValue;
        return common == testValue;
    }

    /// <summary>
    /// Returns <c>true</c> when at least one bit set in <paramref name="testValue"/> is also set in <paramref name="value"/>.
    /// </summary>
    public static bool HasAny(this TlsHandlerState value, TlsHandlerState testValue)
    {
        TlsHandlerState common = value & testValue;
        return common != 0;
    }
}
}
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
namespace CustomMonoTlsHandler
{
using System;
using System.Diagnostics.Contracts;
/// <summary>
/// User event describing the outcome of a TLS handshake: either the shared
/// <see cref="Success"/> instance, or a failure carrying the causing exception.
/// </summary>
public sealed class TlsHandshakeCompletionEvent
{
    /// <summary>
    /// Shared instance indicating a successful handshake.
    /// </summary>
    public static readonly TlsHandshakeCompletionEvent Success = new TlsHandshakeCompletionEvent();

    readonly Exception exception;

    /// <summary>
    /// Creates a new event that indicates a successful handshake.
    /// </summary>
    TlsHandshakeCompletionEvent()
    {
        this.exception = null;
    }

    /// <summary>
    /// Creates a new event that indicates an unsuccessful handshake.
    /// Use <see cref="Success"/> to indicate a successful handshake.
    /// </summary>
    /// <param name="exception">The cause of the failure; must not be <c>null</c>.</param>
    /// <exception cref="ArgumentNullException"><paramref name="exception"/> is <c>null</c>.</exception>
    public TlsHandshakeCompletionEvent(Exception exception)
    {
        // Checked explicitly: Contract.Requires is compiled out unless CONTRACTS_FULL is defined,
        // and a null cause would make this failure event report IsSuccessful == true.
        if (exception == null)
        {
            throw new ArgumentNullException(nameof(exception));
        }
        this.exception = exception;
    }

    /// <summary>
    /// Returns <c>true</c> if the handshake was successful.
    /// </summary>
    public bool IsSuccessful => this.exception == null;

    /// <summary>
    /// The exception that caused the handshake to fail, or <c>null</c> when
    /// <see cref="IsSuccessful"/> is <c>true</c>.
    /// </summary>
    public Exception Exception => this.exception;

    public override string ToString()
    {
        Exception ex = this.Exception;
        return ex == null ? "TlsHandshakeCompletionEvent(SUCCESS)" : $"TlsHandshakeCompletionEvent({ex})";
    }
}
}
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
namespace CustomMonoTlsHandler
{
using System;
using DotNetty.Buffers;
using DotNetty.Transport.Channels;
/// <summary>
/// Utilities for TLS packets.
/// </summary>
static class TlsUtils
{
    const int MAX_PLAINTEXT_LENGTH = 16 * 1024; // 2^14
    const int MAX_COMPRESSED_LENGTH = MAX_PLAINTEXT_LENGTH + 1024;
    const int MAX_CIPHERTEXT_LENGTH = MAX_COMPRESSED_LENGTH + 1024;

    /// <summary>
    /// Upper bound on the size of one encrypted record:
    /// Header (5) + Data (2^14) + Compression (1024) + Encryption (1024) + MAC (20) + Padding (256).
    /// </summary>
    public const int MAX_ENCRYPTED_PACKET_LENGTH = MAX_CIPHERTEXT_LENGTH + 5 + 20 + 256;

    /// <summary>Record content type: change cipher spec.</summary>
    public const int SSL_CONTENT_TYPE_CHANGE_CIPHER_SPEC = 20;

    /// <summary>Record content type: alert.</summary>
    public const int SSL_CONTENT_TYPE_ALERT = 21;

    /// <summary>Record content type: handshake.</summary>
    public const int SSL_CONTENT_TYPE_HANDSHAKE = 22;

    /// <summary>Record content type: application data.</summary>
    public const int SSL_CONTENT_TYPE_APPLICATION_DATA = 23;

    /// <summary>The length of the SSL/TLS record header, in bytes.</summary>
    public const int SSL_RECORD_HEADER_LENGTH = 5;

    /// <summary>Not enough data in the buffer to parse the record length.</summary>
    public const int NOT_ENOUGH_DATA = -1;

    /// <summary>The data is not encrypted.</summary>
    public const int NOT_ENCRYPTED = -2;

    // Major protocol-version byte accepted as SSLv3 / TLS by GetEncryptedPacketLength.
    const int SSLV3_OR_TLS_MAJOR_VERSION = 3;

    /// <summary>
    /// Returns how many bytes can be read out of the encrypted data. This method does not
    /// change the reader index of the given <see cref="IByteBuffer"/>.
    /// </summary>
    /// <param name="buffer">
    /// The <see cref="IByteBuffer"/> to read from. The caller must ensure that at least
    /// <see cref="SSL_RECORD_HEADER_LENGTH"/> bytes are available starting at
    /// <paramref name="offset"/>; this method performs no length check of its own.
    /// </param>
    /// <param name="offset">Offset to record start.</param>
    /// <returns>
    /// The length of the encrypted packet (record header included) that starts at
    /// <paramref name="offset"/>, or <c>-1</c> if the data there is not an SSLv3/TLS record.
    /// </returns>
    public static int GetEncryptedPacketLength(IByteBuffer buffer, int offset)
    {
        // SSLv3 or TLS - Check ContentType
        switch (buffer.GetByte(offset))
        {
            case SSL_CONTENT_TYPE_CHANGE_CIPHER_SPEC:
            case SSL_CONTENT_TYPE_ALERT:
            case SSL_CONTENT_TYPE_HANDSHAKE:
            case SSL_CONTENT_TYPE_APPLICATION_DATA:
                break;
            default:
                // SSLv2 or bad data
                return -1;
        }

        // SSLv3 or TLS - Check ProtocolVersion
        int majorVersion = buffer.GetByte(offset + 1);
        if (majorVersion != SSLV3_OR_TLS_MAJOR_VERSION)
        {
            // Neither SSLv3 nor TLS (i.e. SSLv2 or bad data)
            return -1;
        }

        // The 16-bit length field at offset + 3 does not include the record header itself.
        int packetLength = buffer.GetUnsignedShort(offset + 3) + SSL_RECORD_HEADER_LENGTH;
        if (packetLength <= SSL_RECORD_HEADER_LENGTH)
        {
            // Empty record body: treated the same as SSLv2 or bad data.
            return -1;
        }

        return packetLength;
    }

    /// <summary>
    /// Reports a failed handshake. Flushes <paramref name="ctx"/>, fires a failed
    /// <see cref="TlsHandshakeCompletionEvent"/> carrying <paramref name="cause"/>,
    /// and then requests that the channel be closed.
    /// </summary>
    /// <param name="ctx">Context of the TLS handler.</param>
    /// <param name="cause">Exception that made the handshake fail.</param>
    public static void NotifyHandshakeFailure(IChannelHandlerContext ctx, Exception cause)
    {
        // Some data may have been written before the exception was thrown, so always flush.
        // See https://github.com/netty/netty/issues/3900#issuecomment-172481830
        ctx.Flush();
        ctx.FireUserEventTriggered(new TlsHandshakeCompletionEvent(cause));
        // The returned close task is intentionally not awaited.
        ctx.CloseAsync();
    }
}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
You can’t perform that action at this time.