Last active
November 21, 2017 03:11
-
-
Save halter73/8d653817fea22fea9ba1292df4b7e0c5 to your computer and use it in GitHub Desktop.
Bad Request Diagnostic Connection Adapter
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
using System; | |
using System.Buffers; | |
using System.IO; | |
using System.Text; | |
using System.Threading; | |
using System.Threading.Tasks; | |
using Microsoft.AspNetCore.Server.Kestrel.Core.Adapter.Internal; | |
using Microsoft.Extensions.Logging; | |
namespace BadRequestDiagnosticAdapter | |
{ | |
public class BadRequestDiagnosticAdapter : IConnectionAdapter | |
{ | |
private readonly ILogger _logger;
private readonly int _bufferSize;

/// <summary>
/// Creates an adapter that wraps every connection stream in a diagnostic stream which remembers
/// the most recently read request bytes and logs them when a 400 response is written.
/// </summary>
/// <param name="logger">Logger that receives the diagnostic dump.</param>
/// <param name="bufferSize">Number of trailing request bytes to keep per connection; must be positive.</param>
/// <exception cref="ArgumentNullException"><paramref name="logger"/> is null.</exception>
/// <exception cref="ArgumentOutOfRangeException"><paramref name="bufferSize"/> is zero or negative.</exception>
public BadRequestDiagnosticAdapter(ILogger logger, int bufferSize)
{
    // Fail fast here rather than on every accepted connection: a non-positive size would make
    // the per-connection buffer rental throw or the ring-buffer arithmetic divide by zero.
    if (bufferSize <= 0)
    {
        throw new ArgumentOutOfRangeException(nameof(bufferSize), bufferSize, "The buffer size must be positive.");
    }

    _logger = logger ?? throw new ArgumentNullException(nameof(logger));
    _bufferSize = bufferSize;
}

// This adapter only observes the bytes flowing through the connection; it does not add TLS.
public bool IsHttps => false;

/// <summary>
/// Wraps the connection's stream in a diagnostic stream. Completes synchronously.
/// </summary>
public Task<IAdaptedConnection> OnConnectionAsync(ConnectionAdapterContext context)
{
    return Task.FromResult<IAdaptedConnection>(
        new BadRequestDiagnosticAdapted(context.ConnectionStream, _logger, _bufferSize));
}
/// <summary>
/// Adapted connection whose stream is a <c>BadRequestDiagnosticStream</c> layered over the
/// original connection stream.
/// </summary>
private class BadRequestDiagnosticAdapted : IAdaptedConnection
{
    private readonly BadRequestDiagnosticStream _diagnosticStream;

    public BadRequestDiagnosticAdapted(Stream inner, ILogger logger, int bufferSize) =>
        _diagnosticStream = new BadRequestDiagnosticStream(inner, logger, bufferSize);

    // The stream the server reads requests from and writes responses to for this connection.
    public Stream ConnectionStream
    {
        get { return _diagnosticStream; }
    }

    // Disposing the connection disposes the diagnostic stream (which releases its pooled buffer).
    public void Dispose() => _diagnosticStream.Dispose();
}
private class BadRequestDiagnosticStream : Stream | |
{ | |
private readonly Stream _inner;
private readonly ILogger _logger;

// Circular buffer of the most recently read request bytes. The array comes from
// ArrayPool<byte>.Shared and may be longer than requested; only the first _bufferSize
// slots are used.
private readonly byte[] _buffer;
private readonly int _bufferSize;
private readonly object _bufferLock = new object();

// True until the first non-empty read has been recorded.
private bool _empty = true;
// Index of the oldest buffered byte and the next slot to write, respectively.
private int _head;
private int _tail;

// How far into the 400 status line the previous write(s) had matched.
private int _searchOffset;
private bool _disposed;

public BadRequestDiagnosticStream(Stream inner, ILogger logger, int bufferSize)
{
    _bufferSize = bufferSize;
    _buffer = ArrayPool<byte>.Shared.Rent(_bufferSize);
    _inner = inner;
    _logger = logger;
}
// Capability and position members are forwarded unchanged to the wrapped stream.
public override bool CanRead
{
    get { return _inner.CanRead; }
}

public override bool CanSeek
{
    get { return _inner.CanSeek; }
}

public override bool CanWrite
{
    get { return _inner.CanWrite; }
}

public override long Length
{
    get { return _inner.Length; }
}

public override long Position
{
    get { return _inner.Position; }
    set { _inner.Position = value; }
}
// Flushing has nothing to do with diagnostics; delegate straight to the wrapped stream.
public override void Flush() => _inner.Flush();

public override Task FlushAsync(CancellationToken cancellationToken) =>
    _inner.FlushAsync(cancellationToken);
/// <summary>
/// Reads from the wrapped stream and records the bytes received in the diagnostic buffer.
/// </summary>
public override int Read(byte[] buffer, int offset, int count)
{
    int bytesRead = _inner.Read(buffer, offset, count);
    // Remember what the client sent so it can be dumped if a 400 response is written later.
    Copy(buffer, offset, bytesRead);
    return bytesRead;
}
/// <summary>
/// Asynchronously reads from the wrapped stream and records the bytes received in the
/// diagnostic buffer.
/// </summary>
public override async Task<int> ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
{
    var readTask = _inner.ReadAsync(buffer, offset, count, cancellationToken);
    int bytesRead = await readTask;
    Copy(buffer, offset, bytesRead);
    return bytesRead;
}
// Positioning operations pass through to the wrapped stream untouched.
public override long Seek(long offset, SeekOrigin origin) => _inner.Seek(offset, origin);

public override void SetLength(long value) => _inner.SetLength(value);
/// <summary>
/// Checks the outgoing bytes for a 400 status line, then forwards them to the wrapped stream.
/// </summary>
public override void Write(byte[] buffer, int offset, int count)
{
    // Inspect first, so the request dump is logged before the response bytes are sent on.
    Test400(buffer, offset, count);

    _inner.Write(buffer, offset, count);
}
/// <summary>
/// Checks the outgoing bytes for a 400 status line, then starts the asynchronous write on the
/// wrapped stream and returns its task.
/// </summary>
public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
{
    Test400(buffer, offset, count);
    var innerWrite = _inner.WriteAsync(buffer, offset, count, cancellationToken);
    return innerWrite;
}
/// <summary>
/// Renders the buffered request bytes, oldest first, as one "[HEX] " line of two-digit
/// hexadecimal values followed by one "[RAW] " line with each byte cast to a char.
/// </summary>
/// <returns>The rendered dump, or an empty string when nothing has been buffered or the stream is disposed.</returns>
public override string ToString()
{
    lock (_bufferLock)
    {
        // After Dispose the array has been handed back to the pool and may hold someone else's data.
        if (_empty || _disposed)
        {
            return string.Empty;
        }

        // Normalize the indices: an index equal to _bufferSize is the physical end of the buffer
        // and must be treated as 0. Walking with an unnormalized _tail never terminates, and an
        // unnormalized _head reads past the logical buffer.
        var start = _head % _bufferSize;
        var end = _tail % _bufferSize;

        // On a non-empty buffer, start == end means every slot holds data.
        var length = end - start;
        if (length <= 0)
        {
            length += _bufferSize;
        }

        var builder = new StringBuilder(length * 4 + 14);

        builder.Append("[HEX] ");
        for (var i = 0; i < length; i++)
        {
            builder.Append(_buffer[(start + i) % _bufferSize].ToString("X2"));
            builder.Append(" ");
        }

        builder.AppendLine();

        builder.Append("[RAW] ");
        for (var i = 0; i < length; i++)
        {
            builder.Append((char)_buffer[(start + i) % _bufferSize]);
        }

        return builder.ToString();
    }
}
/// <summary>
/// Returns the diagnostic buffer to <see cref="ArrayPool{T}.Shared"/> exactly once.
/// The wrapped stream is not disposed here.
/// </summary>
protected override void Dispose(bool disposing)
{
    // Take the buffer lock so the array is not handed back to the pool while another thread
    // is copying into it or formatting it.
    lock (_bufferLock)
    {
        if (!_disposed)
        {
            _disposed = true;
            // The buffer holds raw request bytes (headers may include credentials or cookies);
            // clear it so the next renter of this pooled array cannot observe them.
            ArrayPool<byte>.Shared.Return(_buffer, clearArray: true);
        }
    }

    base.Dispose(disposing);
}
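// The APM methods below delegate to ReadAsync/WriteAsync, so bytes read through them are still recorded in the diagnostic buffer and bytes written through them are still scanned for a 400 response.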
/// <summary>
/// APM entry point for reads. The returned task doubles as the <see cref="IAsyncResult"/> and
/// carries <paramref name="state"/>; <paramref name="callback"/>, if supplied, is invoked with
/// that task once it completes.
/// </summary>
public override IAsyncResult BeginRead(byte[] buffer, int offset, int count, AsyncCallback callback, object state)
{
    Task<int> readTask = ReadAsync(buffer, offset, count, CancellationToken.None, state);
    if (callback == null)
    {
        return readTask;
    }

    readTask.ContinueWith((completed, cb) => ((AsyncCallback)cb)(completed), callback);
    return readTask;
}
/// <summary>
/// Completes a read started by <c>BeginRead</c>, blocking until it finishes and returning the
/// number of bytes read.
/// </summary>
public override int EndRead(IAsyncResult asyncResult)
{
    var readTask = (Task<int>)asyncResult;
    return readTask.GetAwaiter().GetResult();
}
/// <summary>
/// Runs the public <c>ReadAsync</c> and mirrors its outcome onto a task whose
/// <see cref="Task.AsyncState"/> is <paramref name="state"/>, so the task can serve as the
/// <see cref="IAsyncResult"/> returned from <c>BeginRead</c>.
/// </summary>
private Task<int> ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken, object state)
{
    var tcs = new TaskCompletionSource<int>(state);
    var task = ReadAsync(buffer, offset, count, cancellationToken);
    task.ContinueWith(
        (completed, tcsState) =>
        {
            var source = (TaskCompletionSource<int>)tcsState;
            if (completed.IsCanceled)
            {
                source.SetCanceled();
            }
            else if (completed.IsFaulted)
            {
                // Forward the original exceptions. Passing completed.Exception itself would wrap
                // them in an extra AggregateException, and EndRead would throw that wrapper.
                source.SetException(completed.Exception.InnerExceptions);
            }
            else
            {
                source.SetResult(completed.Result);
            }
        },
        tcs,
        // The token is already honored by the read itself. Cancelling this continuation as well
        // would leave tcs.Task incomplete forever.
        CancellationToken.None,
        TaskContinuationOptions.ExecuteSynchronously,
        TaskScheduler.Default);
    return tcs.Task;
}
/// <summary>
/// APM entry point for writes. The returned task doubles as the <see cref="IAsyncResult"/> and
/// carries <paramref name="state"/>; <paramref name="callback"/>, if supplied, is invoked with
/// that task once it completes.
/// </summary>
public override IAsyncResult BeginWrite(byte[] buffer, int offset, int count, AsyncCallback callback, object state)
{
    Task writeTask = WriteAsync(buffer, offset, count, CancellationToken.None, state);
    if (callback == null)
    {
        return writeTask;
    }

    writeTask.ContinueWith((completed, cb) => ((AsyncCallback)cb)(completed), callback);
    return writeTask;
}
/// <summary>
/// Completes a write started by <c>BeginWrite</c>, blocking until it finishes and rethrowing
/// any failure.
/// </summary>
public override void EndWrite(IAsyncResult asyncResult)
{
    var writeTask = (Task<object>)asyncResult;
    writeTask.GetAwaiter().GetResult();
}
/// <summary>
/// Runs the public <c>WriteAsync</c> and mirrors its outcome onto a task whose
/// <see cref="Task.AsyncState"/> is <paramref name="state"/>, so the task can serve as the
/// <see cref="IAsyncResult"/> returned from <c>BeginWrite</c>. The returned task is a
/// <c>Task&lt;object&gt;</c>, which is what <c>EndWrite</c> casts it to.
/// </summary>
private Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken, object state)
{
    var tcs = new TaskCompletionSource<object>(state);
    var task = WriteAsync(buffer, offset, count, cancellationToken);
    task.ContinueWith(
        (completed, tcsState) =>
        {
            var source = (TaskCompletionSource<object>)tcsState;
            if (completed.IsCanceled)
            {
                source.SetCanceled();
            }
            else if (completed.IsFaulted)
            {
                // Forward the original exceptions. Passing completed.Exception itself would wrap
                // them in an extra AggregateException, and EndWrite would throw that wrapper.
                source.SetException(completed.Exception.InnerExceptions);
            }
            else
            {
                source.SetResult(null);
            }
        },
        tcs,
        // The token is already honored by the write itself. Cancelling this continuation as well
        // would leave tcs.Task incomplete forever.
        CancellationToken.None,
        TaskContinuationOptions.ExecuteSynchronously,
        TaskScheduler.Default);
    return tcs.Task;
}
/// <summary>
/// Appends freshly read request bytes to the circular diagnostic buffer, keeping only the most
/// recent <c>_bufferSize</c> bytes.
/// </summary>
/// <remarks>
/// Invariants maintained: <c>_head</c> and <c>_tail</c> are always in [0, _bufferSize).
/// <c>_tail</c> is the next slot to write. <c>_head</c> is the oldest byte.
/// A non-empty buffer with <c>_head == _tail</c> is full.
/// </remarks>
private void Copy(byte[] buffer, int offset, int count)
{
    if (count <= 0)
    {
        return;
    }

    lock (_bufferLock)
    {
        // Once disposed, the array belongs to the pool again; writing into it would corrupt
        // whatever its next renter stores there.
        if (_disposed)
        {
            return;
        }

        // Only the trailing _bufferSize bytes of a large read can survive in the buffer.
        var copyCount = Math.Min(count, _bufferSize);
        var sourceOffset = offset + count - copyCount;

        // Number of bytes held before this copy (head == tail on a non-empty buffer means full).
        var used = 0;
        if (!_empty)
        {
            used = _tail - _head;
            if (used <= 0)
            {
                used += _bufferSize;
            }
        }

        // Fill up to the physical end of the array, then wrap the remainder to index 0.
        var firstCopyCount = Math.Min(copyCount, _bufferSize - _tail);
        Buffer.BlockCopy(buffer, sourceOffset, _buffer, _tail, firstCopyCount);
        Buffer.BlockCopy(buffer, sourceOffset + firstCopyCount, _buffer, 0, copyCount - firstCopyCount);

        // Wrap the write position so it never equals _bufferSize. An unwrapped tail makes the
        // modulo walk in ToString never terminate and the head index run off the buffer.
        _tail = (_tail + copyCount) % _bufferSize;

        // If this copy filled (or overran) the buffer, the oldest surviving byte is at _tail.
        if (used + copyCount >= _bufferSize)
        {
            _head = _tail;
        }

        _empty = false;
    }
}
/// <summary>
/// Scans outgoing bytes for the status line <c>"HTTP/1.1 400 Bad Request\r\n"</c>. Each time it
/// is seen, the request bytes currently held in the diagnostic buffer are logged as an error.
/// </summary>
/// <remarks>
/// Partial-match progress is carried in <c>_searchOffset</c>, so a status line split across
/// several writes is still detected.
/// </remarks>
private void Test400(byte[] buffer, int offset, int count)
{
    // O(n). On a mismatch the search restarts at the current byte without backtracking. That is
    // only correct because the first character of searchString ('H') never occurs again in it.
    const string searchString = "HTTP/1.1 400 Bad Request\r\n";
    var head = offset;
    var tail = offset + count;
    // Start in the middle of the search string if that's where we left off in the last buffer.
    var searchIndex = _searchOffset;
    while (head < tail)
    {
        while (searchIndex < searchString.Length && head < tail)
        {
            if (buffer[head] == searchString[searchIndex])
            {
                head++;
                searchIndex++;
            }
            else if (searchIndex == 0)
            {
                head++;
            }
            else
            {
                // Re-test this same byte against the start of the pattern.
                searchIndex = 0;
            }
        }

        if (searchIndex == searchString.Length)
        {
            lock (_bufferLock)
            {
                // An empty buffer holds 0 bytes. A non-empty one with head == tail is completely full.
                var bytesBuffered = 0;
                if (!_empty)
                {
                    bytesBuffered = _tail - _head;
                    if (bytesBuffered <= 0)
                    {
                        bytesBuffered += _bufferSize;
                    }
                }

                // Render the dump now, while the lock is held. If `this` were passed instead, a
                // logging provider may call ToString() later, after more data has been read or
                // the pooled buffer has been returned.
                _logger.LogError(
                    "Observed 400 response. The last {bytesBuffered} bytes of request data were: {newLine}{buffer}",
                    bytesBuffered, Environment.NewLine, ToString());
            }

            searchIndex = 0;
        }
    }

    _searchOffset = searchIndex;
}
} | |
} | |
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
using System.Net; | |
using Microsoft.AspNetCore; | |
using Microsoft.AspNetCore.Hosting; | |
using Microsoft.Extensions.DependencyInjection; | |
using Microsoft.Extensions.Logging; | |
namespace BadRequestDiagnosticAdapter | |
{ | |
public class Program
{
    public static void Main(string[] args)
    {
        var host = BuildWebHost(args);
        host.Run();
    }

    /// <summary>
    /// Builds a Kestrel host with one endpoint (IPAddress.IPv6Any, port 5000). A
    /// BadRequestDiagnosticAdapter keeping the last 16384 request bytes per connection is
    /// attached to that endpoint.
    /// </summary>
    public static IWebHost BuildWebHost(string[] args)
    {
        var builder = WebHost.CreateDefaultBuilder(args);
        builder = builder.UseKestrel(kestrelOptions =>
        {
            kestrelOptions.Listen(IPAddress.IPv6Any, 5000, listenOptions =>
            {
                // The adapter's logger comes from the application's DI container.
                var services = listenOptions.KestrelServerOptions.ApplicationServices;
                var loggerFactory = services.GetRequiredService<ILoggerFactory>();
                var logger = loggerFactory.CreateLogger<BadRequestDiagnosticAdapter>();
                var adapter = new BadRequestDiagnosticAdapter(logger, bufferSize: 16384);
                listenOptions.ConnectionAdapters.Add(adapter);
            });
        });
        return builder.UseStartup<Startup>().Build();
    }
}
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment