Skip to content

Instantly share code, notes, and snippets.

@ayende
Created Jun 25, 2022
Embed
What would you like to do?
using System.Buffers;
using System.Buffers.Text;
using System.Collections.Concurrent;
using System.IO.Pipelines;
using System.Net.Sockets;
using System.Text;
// Entry point: accept TCP connections on port 6379 and serve each one with its own Client.
// All clients share a single in-memory store. The accept loop never waits for a session to
// finish; each session's completion is handled by a continuation instead.
var listener = new TcpListener(System.Net.IPAddress.Any, 6379);
listener.Start();
var state = new ConcurrentDictionary<object, Item>();
for (;;)
{
    var connection = listener.AcceptTcpClient();
    var session = new Client(state, connection).HandleConnection();
    // When the session ends, report any failure and always release the socket.
    session.ContinueWith(finished =>
    {
        try
        {
            if (finished.Exception is not null)
                Console.WriteLine(finished.Exception);
        }
        finally
        {
            connection.Dispose();
        }
    });
}
/// <summary>
/// Serves one client connection. It parses pipelined RESP arrays of bulk strings from the socket,
/// executes <c>GET key</c> / <c>SET key value</c> against the shared store, and writes the replies.
/// </summary>
public class Client
{
    // Store shared by every connection. Keys are stored as ReusableBuffer and probed through
    // _reusable (a BoxedSeq), so a lookup does not copy the key out of the network buffer.
    readonly ConcurrentDictionary<object, Item> _state;
    readonly PipeReader _netReader;
    readonly PipeWriter _netWriter;
    // Arguments of the command being parsed. They are slices of the current read buffer and are
    // only valid until _netReader.AdvanceTo is called.
    readonly List<ReadOnlySequence<byte>> _cmds = new();
    // Scratch space for a 3-byte command name that straddles buffer segments.
    readonly byte[] _temp = new byte[4];
    readonly BoxedSeq _reusable = new();

    /// <summary>Wraps the socket of <paramref name="tcp"/> in a pipe reader/writer pair.</summary>
    /// <param name="state">Key/value store shared by all connections.</param>
    /// <param name="tcp">An accepted, connected client.</param>
    public Client(ConcurrentDictionary<object, Item> state, TcpClient tcp)
    {
        _state = state;
        var stream = tcp.GetStream();
        _netReader = PipeReader.Create(stream);
        _netWriter = PipeWriter.Create(stream);
    }

    /// <summary>
    /// Reads from the socket until the peer closes it. Every complete command received is
    /// executed, and the replies are flushed after each read.
    /// </summary>
    /// <returns>
    /// A task that completes when the peer closes the connection. It faults on malformed input
    /// or on I/O errors.
    /// </returns>
    public async Task HandleConnection()
    {
        while (true)
        {
            var result = await _netReader.ReadAsync();
            var (consumed, examined) = ParseNetworkData(result);
            _netReader.AdvanceTo(consumed, examined);
            await _netWriter.FlushAsync();
            // Once the peer has closed its side, no more data will arrive (a trailing partial
            // command is dropped). Without this check, ReadAsync keeps handing back the completed
            // result immediately and the loop spins forever, never releasing the connection.
            if (result.IsCompleted)
                return;
        }
    }

    /// <summary>Executes every complete command in the buffer of <paramref name="result"/>.</summary>
    /// <returns>
    /// The position just past the last complete command (consumed) and the end of the buffer
    /// (examined). Together these make the next read wait for more data before an incomplete
    /// tail is parsed again.
    /// </returns>
    /// <exception cref="InvalidDataException">The input is not a well-formed array of bulk strings.</exception>
    (SequencePosition Consumed, SequencePosition Examined) ParseNetworkData(ReadResult result)
    {
        var reader = new SequenceReader<byte>(result.Buffer);
        SequencePosition consumed;
        while (true)
        {
            _cmds.Clear();
            consumed = reader.Position;
            // Array header: *<argc>\r\n
            if (reader.TryReadTo(out ReadOnlySpan<byte> line, (byte)'\n') == false)
                return (consumed, result.Buffer.End);
            if (line.Length == 0 || line[0] != '*' || line[line.Length - 1] != '\r')
                ThrowBadBuffer(result.Buffer);
            if (Utf8Parser.TryParse(line.Slice(1), out int argc, out int bytesConsumed) == false ||
                bytesConsumed + 2 != line.Length || // account for the * and \r
                argc < 1) // a command needs at least its name
                ThrowBadBuffer(result.Buffer);
            for (int i = 0; i < argc; i++)
            {
                // Bulk string header: $<size>\r\n
                if (reader.TryReadTo(out line, (byte)'\n') == false)
                    return (consumed, result.Buffer.End);
                if (line.Length == 0 || line[0] != '$' || line[line.Length - 1] != '\r')
                    ThrowBadBuffer(result.Buffer);
                if (Utf8Parser.TryParse(line.Slice(1), out int size, out bytesConsumed) == false ||
                    bytesConsumed + 2 != line.Length || // accounts for $ and \r
                    size < 0) // negative (null) bulk strings are not valid command arguments
                    ThrowBadBuffer(result.Buffer);
                // long arithmetic: size + 2 would overflow int for sizes near int.MaxValue
                if ((long)size + 2 /*\r\n*/ > reader.Remaining)
                    return (consumed, result.Buffer.End);
                _cmds.Add(reader.UnreadSequence.Slice(0, size));
                reader.Advance(size);
                if (reader.TryReadTo(out line, (byte)'\n') == false)
                    return (consumed, result.Buffer.End);
                // the payload must be followed by exactly \r\n
                if (line.Length != 1 || line[0] != '\r')
                    ThrowBadBuffer(result.Buffer);
            }
            ExecCommand(_cmds);
        }
    }

    /// <summary>
    /// Dispatches a parsed command. Only the upper-case <c>GET key</c> and <c>SET key value</c>
    /// are recognised; extra arguments are ignored.
    /// </summary>
    /// <exception cref="InvalidDataException">Unknown command or too few arguments.</exception>
    private void ExecCommand(List<ReadOnlySequence<byte>> cmds)
    {
        if (cmds[0].Length != 3)
            ThrowBadBuffer(cmds[0]);
        ReadOnlySpan<byte> cmd;
        if (cmds[0].IsSingleSegment == false)
        {
            cmds[0].CopyTo(_temp);
            cmd = new ReadOnlySpan<byte>(_temp, 0, 3);
        }
        else
        {
            cmd = cmds[0].FirstSpan;
        }
        if (cmd[1] != (byte)'E' || cmd[2] != (byte)'T')
            ThrowBadBuffer(cmds[0]);
        bool isGet = cmd[0] == (byte)'G';
        if (isGet == false && cmd[0] != (byte)'S')
            ThrowBadBuffer(cmds[0]);
        // GET needs a key; SET needs a key and a value.
        if (cmds.Count < (isGet ? 2 : 3))
            ThrowBadBuffer(cmds[0]);
        _reusable.Seq = cmds[1];
        if (isGet)
            ExecGet();
        else
            ExecSet(cmds[1], cmds[2]);
    }

    // Replies to GET with the stored value as a bulk string, or with a nil bulk string if the key is missing.
    private void ExecGet()
    {
        if (_state.TryGetValue(_reusable, out var item) == false)
        {
            WriteMissing();
            return;
        }
        var result = item.Value;
        // '$' + at most 10 length digits + \r\n + value + \r\n fits in Length + 16
        var span = _netWriter.GetMemory(result.Length + 16).Span;
        span[0] = (byte)'$';
        if (Utf8Formatter.TryFormat(result.Length, span.Slice(1), out var written) == false)
            ThrowImpossibleFailedWrite();
        written += 1;
        written += WriteEndOfLine(span, written);
        result.Span.CopyTo(span.Slice(written));
        written += result.Length;
        written += WriteEndOfLine(span, written);
        _netWriter.Advance(written);
    }

    // Copies the value (and the key, if it is new) into pooled buffers, stores it, and replies +OK.
    private void ExecSet(ReadOnlySequence<byte> keySeq, ReadOnlySequence<byte> valueSeq)
    {
        var buffer = ArrayPool<byte>.Shared.Rent((int)valueSeq.Length);
        valueSeq.CopyTo(buffer);
        var val = new ReusableBuffer(buffer, (int)valueSeq.Length);
        ReusableBuffer key;
        if (_state.TryGetValue(_reusable, out var item))
        {
            // can reuse key buffer
            key = item.Key;
        }
        else
        {
            var keyBuffer = ArrayPool<byte>.Shared.Rent((int)keySeq.Length);
            keySeq.CopyTo(keyBuffer);
            key = new ReusableBuffer(keyBuffer, (int)keySeq.Length);
        }
        _state[key] = new Item(key, val);
        WriteOk();
    }

    // Writes "\r\n" at span[offset] and returns the number of bytes written (2).
    private int WriteEndOfLine(Span<byte> span, int offset)
    {
        span[offset] = (byte)'\r';
        span[offset + 1] = (byte)'\n';
        return 2;
    }

    // Writes the RESP nil bulk string "$-1\r\n".
    private void WriteMissing()
    {
        var span = _netWriter.GetMemory(5).Span;
        span[0] = (byte)'$';
        span[1] = (byte)'-';
        span[2] = (byte)'1';
        span[3] = (byte)'\r';
        span[4] = (byte)'\n';
        _netWriter.Advance(5);
    }

    // Writes the RESP simple-string success reply "+OK\r\n".
    private void WriteOk()
    {
        var span = _netWriter.GetMemory(5).Span;
        span[0] = (byte)'+';
        span[1] = (byte)'O';
        span[2] = (byte)'K';
        span[3] = (byte)'\r';
        span[4] = (byte)'\n';
        _netWriter.Advance(5);
    }

    private static void ThrowImpossibleFailedWrite()
    {
        throw new InvalidOperationException("Unable to write to memory, impossible");
    }

    // Always throws; the message includes the offending bytes decoded as UTF-8.
    static void ThrowBadBuffer(ReadOnlySequence<byte> buf)
    {
        throw new InvalidDataException("The buffer didn't match the expected value: " + Encoding.UTF8.GetString(buf));
    }
}
/// <summary>
/// A mutable, reusable wrapper that lets a <see cref="ReadOnlySequence{T}"/> of bytes act as a
/// dictionary lookup key. It is compared and hashed by content, regardless of how the bytes are
/// split across segments.
/// </summary>
/// <remarks>
/// NOTE(review): equality with ReusableBuffer is implemented only on ReusableBuffer's side, so
/// this type's Equals is not symmetric with it. Lookups rely on the dictionary calling Equals on
/// the stored key. Confirm before changing the store's comparer.
/// </remarks>
public class BoxedSeq
{
    public ReadOnlySequence<byte> Seq;

    /// <summary>
    /// Hashes the sequence contents so that the result matches the hash of the same bytes held
    /// in a single contiguous span.
    /// </summary>
    public override int GetHashCode()
    {
        var hc = new HashCode();
        if (Seq.IsSingleSegment)
        {
            hc.AddBytes(Seq.FirstSpan);
        }
        else
        {
            // HashCode.AddBytes does not guarantee that feeding bytes in pieces gives the same
            // hash as feeding them at once. Hashing segment-by-segment would therefore make a
            // fragmented key hash differently from the same key stored contiguously.
            // Flatten (rare path) first.
            hc.AddBytes(Seq.ToArray());
        }
        return hc.ToHashCode();
    }

    /// <summary>True when <paramref name="obj"/> is a BoxedSeq with identical bytes.</summary>
    public override bool Equals(object? obj)
    {
        if (obj is not BoxedSeq bs)
            return false;
        return SeqEquals(Seq, bs.Seq);
    }

    /// <summary>Byte-wise equality of two sequences, independent of their segmentation.</summary>
    public static bool SeqEquals(ReadOnlySequence<byte> x, ReadOnlySequence<byte> y)
    {
        if (x.Length != y.Length)
            return false;
        if (x.IsSingleSegment == false)
        {
            if (y.IsSingleSegment)
                return SeqEquals(y, x);
            // Both fragmented: rare, so flatten one side and reuse the single-segment path.
            return SeqEquals(new ReadOnlySequence<byte>(x.ToArray()), y);
        }
        var span = x.FirstSpan;
        if (y.IsSingleSegment)
            return span.SequenceEqual(y.FirstSpan);
        foreach (var mem in y)
        {
            // Compare this segment of y with the same-length prefix of what remains of x.
            // Equal total lengths guarantee the slice is in range.
            if (mem.Span.SequenceEqual(span.Slice(0, mem.Length)) == false)
                return false;
            span = span.Slice(mem.Length);
        }
        return true;
    }
}
/// <summary>
/// A stored entry. The key is kept alongside the value so that a later SET of the same key can
/// reuse the existing key buffer instead of copying it again.
/// </summary>
public class Item
{
    public ReusableBuffer Key;
    public ReusableBuffer Value;

    public Item(ReusableBuffer key, ReusableBuffer value) => (Key, Value) = (key, value);
}
/// <summary>
/// A byte array plus the number of leading bytes actually in use. It is compared and hashed by
/// the content of that used portion.
/// The finalizer hands the array to <see cref="ArrayPool{T}.Shared"/>, so the array is expected
/// to have been rented from that pool.
/// </summary>
public class ReusableBuffer
{
    public byte[] Buffer;
    public int Length;

    /// <summary>The used portion of <see cref="Buffer"/>.</summary>
    public Span<byte> Span => Buffer.AsSpan(0, Length);

    /// <summary>The used portion of <see cref="Buffer"/> as a single-segment sequence.</summary>
    public ReadOnlySequence<byte> Seq => new ReadOnlySequence<byte>(Buffer, 0, Length);

    public ReusableBuffer(byte[] buffer, int length) => (Buffer, Length) = (buffer, length);

    /// <summary>
    /// Content equality against another ReusableBuffer or against a BoxedSeq lookup key.
    /// Any other object compares unequal.
    /// </summary>
    public override bool Equals(object? obj)
    {
        switch (obj)
        {
            case BoxedSeq boxed:
                return BoxedSeq.SeqEquals(boxed.Seq, Seq);
            case ReusableBuffer other:
                return other.Span.SequenceEqual(Span);
            default:
                return false;
        }
    }

    /// <summary>Hash of the used bytes, computed contiguously.</summary>
    public override int GetHashCode()
    {
        var hash = new HashCode();
        hash.AddBytes(Span);
        return hash.ToHashCode();
    }

    // Returns the array to the shared pool once nothing references this buffer any more.
    ~ReusableBuffer() => ArrayPool<byte>.Shared.Return(Buffer);
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment