Skip to content

Instantly share code, notes, and snippets.

@neremin
Last active August 29, 2015 14:15
Show Gist options
  • Save neremin/becc8ddff220be98fa7d to your computer and use it in GitHub Desktop.
Save neremin/becc8ddff220be98fa7d to your computer and use it in GitHub Desktop.
/// <summary>
/// Helpers for comparing byte arrays, byte-array segments, and stream contents for equality.
/// </summary>
/// <remarks>
/// Argument validation is done with explicit throws: <c>Contract.Requires&lt;TException&gt;</c>
/// only throws when the Code Contracts binary rewriter is applied, and otherwise fails fast.
/// </remarks>
public static class BytesHelper
{
    /// <summary>
    /// Determines whether two byte arrays have the same length and identical contents.
    /// </summary>
    /// <param name="arrayA">First array.</param>
    /// <param name="arrayB">Second array.</param>
    /// <returns><c>true</c> when both arrays have equal length and equal bytes.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="arrayA"/> or <paramref name="arrayB"/> is null.</exception>
    public static bool Equals(byte[] arrayA, byte[] arrayB)
    {
        if (arrayA == null) throw new ArgumentNullException("arrayA");
        if (arrayB == null) throw new ArgumentNullException("arrayB");

        var length = arrayA.Length;
        return length == arrayB.Length && UnsafeEquals(arrayA, 0, arrayB, 0, length);
    }

    /// <summary>
    /// Determines whether <paramref name="count"/> bytes of <paramref name="arrayA"/> starting at
    /// <paramref name="startIndexA"/> equal <paramref name="count"/> bytes of <paramref name="arrayB"/>
    /// starting at <paramref name="startIndexB"/>.
    /// </summary>
    /// <param name="arrayA">First array.</param>
    /// <param name="startIndexA">Start of the segment in <paramref name="arrayA"/>.</param>
    /// <param name="arrayB">Second array (may be the same instance as <paramref name="arrayA"/>).</param>
    /// <param name="startIndexB">Start of the segment in <paramref name="arrayB"/>.</param>
    /// <param name="count">Number of bytes to compare.</param>
    /// <returns><c>true</c> when both segments contain identical bytes.</returns>
    /// <exception cref="ArgumentNullException">Either array is null.</exception>
    /// <exception cref="ArgumentOutOfRangeException">
    /// <paramref name="count"/> is negative, or a segment does not lie within its array.
    /// </exception>
    public static bool Equals(byte[] arrayA, int startIndexA, byte[] arrayB, int startIndexB, int count)
    {
        if (arrayA == null) throw new ArgumentNullException("arrayA");
        if (arrayB == null) throw new ArgumentNullException("arrayB");
        if (count < 0) throw new ArgumentOutOfRangeException("count");
        // "start > Length - count" is "start + count > Length" without the risk of int overflow;
        // a segment that ends exactly at the end of the array is valid.
        if (startIndexA < 0 || startIndexA > arrayA.Length - count) throw new ArgumentOutOfRangeException("startIndexA");
        if (startIndexB < 0 || startIndexB > arrayB.Length - count) throw new ArgumentOutOfRangeException("startIndexB");

        return UnsafeEquals(arrayA, startIndexA, arrayB, startIndexB, count);
    }

    /// <summary>
    /// Determines whether the stream's total length equals the buffer's length and the bytes read
    /// from the stream's current position equal the whole buffer.
    /// </summary>
    /// <param name="stream">Stream to read; must support <c>Length</c> and <c>Position</c>.</param>
    /// <param name="buffer">Expected contents.</param>
    /// <param name="chunkSize">Maximum number of bytes read from the stream per call.</param>
    /// <returns><c>true</c> when the stream content matches <paramref name="buffer"/>.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="stream"/> or <paramref name="buffer"/> is null.</exception>
    public static bool StreamEqualsBuffer(Stream stream, byte[] buffer, int chunkSize = 4096)
    {
        if (stream == null) throw new ArgumentNullException("stream");
        if (buffer == null) throw new ArgumentNullException("buffer");

        return stream.Length == buffer.Length &&
               StreamSegmentEqualsBufferSegment(stream, buffer, 0, buffer.Length, chunkSize);
    }

    /// <summary>
    /// Determines whether the stream's total length equals <paramref name="count"/> and the bytes read
    /// from the stream's current position equal <c>buffer[offset .. offset + count)</c>.
    /// </summary>
    /// <param name="stream">Stream to read; must support <c>Length</c> and <c>Position</c>.</param>
    /// <param name="buffer">Buffer holding the expected contents.</param>
    /// <param name="offset">Start of the expected segment in <paramref name="buffer"/>.</param>
    /// <param name="count">Length of the expected segment.</param>
    /// <param name="chunkSize">Maximum number of bytes read from the stream per call.</param>
    /// <returns><c>true</c> when the stream content matches the buffer segment.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="stream"/> or <paramref name="buffer"/> is null.</exception>
    public static bool StreamEqualsBufferSegment(Stream stream, byte[] buffer, int offset, int count, int chunkSize = 4096)
    {
        if (stream == null) throw new ArgumentNullException("stream");
        if (buffer == null) throw new ArgumentNullException("buffer");

        return count == stream.Length && StreamSegmentEqualsBufferSegment(stream, buffer, offset, count, chunkSize);
    }

    /// <summary>
    /// Reads <paramref name="count"/> bytes from the stream's current position and compares them with
    /// <c>buffer[offset .. offset + count)</c>. The stream position advances by the bytes consumed.
    /// </summary>
    /// <param name="stream">Stream to read; must support <c>Length</c> and <c>Position</c>.</param>
    /// <param name="buffer">Buffer holding the expected contents.</param>
    /// <param name="offset">Start of the expected segment in <paramref name="buffer"/>.</param>
    /// <param name="count">Number of bytes to compare.</param>
    /// <param name="chunkSize">Maximum number of bytes read from the stream per call.</param>
    /// <returns>
    /// <c>true</c> when the stream supplies <paramref name="count"/> bytes equal to the buffer segment;
    /// <c>false</c> on a mismatch or when fewer than <paramref name="count"/> bytes remain in the stream.
    /// </returns>
    /// <exception cref="ArgumentNullException"><paramref name="stream"/> or <paramref name="buffer"/> is null.</exception>
    /// <exception cref="ArgumentOutOfRangeException">
    /// <paramref name="offset"/> or <paramref name="count"/> is negative, or <paramref name="chunkSize"/> is not positive.
    /// </exception>
    /// <exception cref="ArgumentException">The segment extends past the end of <paramref name="buffer"/>.</exception>
    public static bool StreamSegmentEqualsBufferSegment(Stream stream, byte[] buffer, int offset, int count, int chunkSize = 4096)
    {
        if (stream == null) throw new ArgumentNullException("stream");
        if (buffer == null) throw new ArgumentNullException("buffer");
        if (offset < 0) throw new ArgumentOutOfRangeException("offset");
        if (count < 0) throw new ArgumentOutOfRangeException("count");
        if (chunkSize <= 0) throw new ArgumentOutOfRangeException("chunkSize");
        // Overflow-safe form of "offset + count > buffer.Length".
        if (offset > buffer.Length - count) throw new ArgumentException("offset + count exceeds the buffer length.");

        if (stream.Position + count > stream.Length)
        {
            return false;
        }

        var chunk = new byte[Math.Min(count, chunkSize)];
        var compared = 0;
        while (compared < count)
        {
            var bytesRead = stream.Read(chunk, 0, Math.Min(chunk.Length, count - compared));
            if (bytesRead == 0)
            {
                return false; // stream ended before `count` bytes were available
            }

            // Compare against the buffer segment, which starts at `offset`, not at index 0.
            if (!UnsafeEquals(chunk, 0, buffer, offset + compared, bytesRead))
            {
                return false;
            }

            compared += bytesRead;
        }

        return true;
    }

    /// <summary>
    /// Reads <paramref name="count"/> bytes from the current position of each stream and compares them.
    /// Both stream positions advance by the bytes consumed.
    /// </summary>
    /// <param name="streamA">First stream.</param>
    /// <param name="streamB">Second stream.</param>
    /// <param name="count">Number of bytes to compare.</param>
    /// <param name="chunkSize">Maximum number of bytes compared per iteration.</param>
    /// <returns>
    /// <c>true</c> when both streams supply <paramref name="count"/> identical bytes;
    /// <c>false</c> on a mismatch or when either stream ends early.
    /// </returns>
    /// <exception cref="ArgumentNullException">Either stream is null.</exception>
    /// <exception cref="ArgumentOutOfRangeException">
    /// <paramref name="count"/> is negative or <paramref name="chunkSize"/> is not positive.
    /// </exception>
    public static bool StreamSegmentEquals(Stream streamA, Stream streamB, int count, int chunkSize = 4096)
    {
        if (streamA == null) throw new ArgumentNullException("streamA");
        if (streamB == null) throw new ArgumentNullException("streamB");
        if (count < 0) throw new ArgumentOutOfRangeException("count");
        if (chunkSize <= 0) throw new ArgumentOutOfRangeException("chunkSize");

        if (count == 0)
        {
            return true;
        }

        var maxChunkSize = Math.Min(count, chunkSize);
        var aChunk = new byte[maxChunkSize];
        var bChunk = new byte[maxChunkSize];
        while (count > 0)
        {
            var chunkLength = Math.Min(maxChunkSize, count);

            // Stream.Read may return fewer bytes than requested, so fill each chunk completely;
            // a short fill means that stream ended before `count` bytes.
            if (ReadFully(streamA, aChunk, chunkLength) != chunkLength ||
                ReadFully(streamB, bChunk, chunkLength) != chunkLength)
            {
                return false;
            }

            if (!UnsafeEquals(aChunk, 0, bChunk, 0, chunkLength))
            {
                return false;
            }

            count -= chunkLength;
        }

        return true;
    }

    /// <summary>
    /// Reads from <paramref name="stream"/> into <c>buffer[0 .. count)</c> until <paramref name="count"/>
    /// bytes have been read or the stream reports end of data; returns the number of bytes read.
    /// </summary>
    static int ReadFully(Stream stream, byte[] buffer, int count)
    {
        var total = 0;
        while (total < count)
        {
            var read = stream.Read(buffer, total, count - total);
            if (read == 0)
            {
                break;
            }

            total += read;
        }

        return total;
    }

    /// <devdoc>
    /// Adapted from .NET's EqualsHelper in String.cs.
    /// Compares <paramref name="bytesCount"/> bytes of two arrays without bounds checks;
    /// callers must have validated that both ranges lie within their arrays.
    /// </devdoc>
    [SecuritySafeCritical]
    [ReliabilityContract(Consistency.WillNotCorruptState, Cer.MayFail)]
    static unsafe bool UnsafeEquals(byte[] arrayA, int startIndexA, byte[] arrayB, int startIndexB, int bytesCount)
    {
        if (bytesCount <= 0)
        {
            return true;
        }

        // Only the identical region of the same array can be short-circuited;
        // different offsets into one array must still be compared byte by byte.
        if (ReferenceEquals(arrayA, arrayB) && startIndexA == startIndexB)
        {
            return true;
        }

        fixed (byte* ap = &arrayA[0])
        fixed (byte* bp = &arrayB[0])
        {
            byte* a = ap + startIndexA;
            byte* b = bp + startIndexB;

            // NOTE(review): the wide reads below may be unaligned for arbitrary start indices;
            // fine on x86/x64, confirm before targeting platforms that trap on misalignment.
            // unroll the loop
            if (IntPtr.Size == sizeof(long))
            {
                // for AMD64 bit platform we unroll by 12 and
                // check 3 qword at a time. This is less code
                // than the 32 bit case and is shorter
                // path length
                const int step = 3;
                const int stepInBytes = step * sizeof(long);
                var al = (long*) a;
                var bl = (long*) b;
                while (bytesCount >= stepInBytes)
                {
                    if (al[0] != bl[0]) return false;
                    if (al[1] != bl[1]) return false;
                    if (al[2] != bl[2]) return false;
                    al += step; bl += step; bytesCount -= stepInBytes;
                }
                a = (byte*) al;
                b = (byte*) bl;
            }
            else
            {
                const int step = 5;
                const int stepInBytes = step * sizeof(int);
                var ai = (int*) a;
                var bi = (int*) b;
                while (bytesCount >= stepInBytes)
                {
                    if (ai[0] != bi[0]) return false;
                    if (ai[1] != bi[1]) return false;
                    if (ai[2] != bi[2]) return false;
                    if (ai[3] != bi[3]) return false;
                    if (ai[4] != bi[4]) return false;
                    ai += step; bi += step; bytesCount -= stepInBytes;
                }
                a = (byte*) ai;
                b = (byte*) bi;
            }

            // Remaining whole ints, then the tail byte by byte.
            while (bytesCount >= sizeof(int))
            {
                if (*(int*)a != *(int*)b) return false;
                a += sizeof(int); b += sizeof(int); bytesCount -= sizeof(int);
            }
            while (bytesCount > 0)
            {
                if (*a != *b) return false;
                ++a; ++b; --bytesCount;
            }
            return true;
        }
    }
}
/// <summary>
/// NUnit tests for <see cref="BytesHelper"/>; each test method returns the comparison result,
/// which NUnit checks against the expected value attached to each test case.
/// </summary>
[TestFixture]
public class BytesHelperTests
{
    /// <summary>Compares two whole arrays.</summary>
    [TestCaseSource(typeof(ByteSequences))]
    public bool Equals(byte[] s1, byte[] s2)
    {
        return BytesHelper.Equals(s1, s2);
    }

    /// <summary>Compares a stream over <paramref name="s1"/> with the buffer <paramref name="s2"/>.</summary>
    [TestCaseSource(typeof(ByteSequences))]
    public bool StreamEqualsBuffer(byte[] s1, byte[] s2)
    {
        using (var stream = new MemoryStream(s1))
        {
            return BytesHelper.StreamEqualsBuffer(stream, s2);
        }
    }

    /// <summary>Compares the first <paramref name="count"/> bytes of two streams.</summary>
    [TestCaseSource(typeof(StreamSegmentsSequences))]
    public bool StreamSegmentEquals(byte[] s1, byte[] s2, int count, int chunkSize)
    {
        using (var streamA = new MemoryStream(s1))
        using (var streamB = new MemoryStream(s2))
        {
            return BytesHelper.StreamSegmentEquals(streamA, streamB, count, chunkSize);
        }
    }

    /// <summary>
    /// Cases of (s1, s2, count, chunkSize) for <see cref="BytesHelper.StreamSegmentEquals"/>:
    /// explicit equal-prefix cases, plus every <see cref="ByteSequences"/> pair compared over the
    /// longer of the two lengths with both a 1-byte and a 4096-byte chunk size.
    /// </summary>
    sealed class StreamSegmentsSequences : IEnumerable
    {
        public IEnumerator GetEnumerator()
        {
            // Equal 3-byte prefixes: bytes beyond `count` must not affect the result.
            yield return new TestCaseData(new byte[] { 1, 2, 3, 4, 5 }, new byte[] { 1, 2, 3, 3, 6 }, 3, 1).Returns(true);
            yield return new TestCaseData(new byte[] { 1, 2, 3, 4, 5 }, new byte[] { 1, 2, 3, 3, 6 }, 3, 16).Returns(true);
            yield return new TestCaseData(new byte[] { 1, 2, 3 }, new byte[] { 1, 2, 3, 3, 6 }, 3, 1).Returns(true);
            yield return new TestCaseData(new byte[] { 1, 2, 3 }, new byte[] { 1, 2, 3, 3, 6 }, 3, 16).Returns(true);
            yield return new TestCaseData(new byte[] { 1, 2, 3, 4, 5 }, new byte[] { 1, 2, 3 }, 3, 1).Returns(true);
            yield return new TestCaseData(new byte[] { 1, 2, 3, 4, 5 }, new byte[] { 1, 2, 3 }, 3, 16).Returns(true);

            foreach (TestCaseData sequence in new ByteSequences())
            {
                var s1 = (byte[])sequence.Arguments[0];
                var s2 = (byte[])sequence.Arguments[1];
                // Comparing over the longer length means a shorter stream runs out early and fails,
                // so the expected result matches the whole-array comparison.
                var count = Math.Max(s1.Length, s2.Length);
                yield return new TestCaseData(s1, s2, count, 1).Returns(sequence.Result);
                yield return new TestCaseData(s1, s2, count, 4096).Returns(sequence.Result);
            }
        }
    }

    /// <summary>
    /// Pairs of byte arrays with the expected whole-array equality result. Lengths are chosen around
    /// the unroll widths of <c>BytesHelper</c>'s comparison loops (4, 20 and 24 bytes per step).
    /// </summary>
    sealed class ByteSequences : IEnumerable
    {
        /// <summary>
        /// Builds a deterministic sequence of <paramref name="length"/> bytes.
        /// Values cycle through 0..254 because of the modulo by <see cref="byte.MaxValue"/>.
        /// </summary>
        static byte[] GenerateSequence(int length)
        {
            var sequence = new byte[length];
            for (int i = 0; i < length; ++i)
            {
                sequence[i] = (byte)(i % byte.MaxValue);
            }
            return sequence;
        }

        public IEnumerator GetEnumerator()
        {
            yield return new TestCaseData(new byte[0], new byte[0]).Returns(true);
            yield return new TestCaseData(new byte[] { 2 }, new byte[] { 1 }).Returns(false);
            yield return new TestCaseData(new byte[] { 1, 2 }, new byte[] { 1, 1 }).Returns(false);
            yield return new TestCaseData(new byte[] { 1, 2, 3 }, new byte[] { 1, 2, 2 }).Returns(false);
            yield return new TestCaseData(new byte[] { 1, 2, 3, 4 }, new byte[] { 1, 2, 3, 3 }).Returns(false);
            yield return new TestCaseData(GenerateSequence(1), GenerateSequence(2)).Returns(false);
            yield return new TestCaseData(GenerateSequence(2), GenerateSequence(1)).Returns(false);
            yield return new TestCaseData(GenerateSequence(5), GenerateSequence(5)).Returns(true);
            yield return new TestCaseData(GenerateSequence(27), GenerateSequence(27)).Returns(true);
            yield return new TestCaseData(GenerateSequence(13), GenerateSequence(13)).Returns(true);
            yield return new TestCaseData(GenerateSequence(15), GenerateSequence(15)).Returns(true);
            yield return new TestCaseData(GenerateSequence(19), GenerateSequence(19)).Returns(true);
            yield return new TestCaseData(GenerateSequence(16), GenerateSequence(16)).Returns(true);
            yield return new TestCaseData(GenerateSequence(20), GenerateSequence(20)).Returns(true);
            // Exactly one 3 x 8-byte unrolled step on 64-bit platforms (replaces a duplicated 29-byte case).
            yield return new TestCaseData(GenerateSequence(24), GenerateSequence(24)).Returns(true);
            yield return new TestCaseData(GenerateSequence(29), GenerateSequence(29)).Returns(true);
            yield return new TestCaseData(GenerateSequence(32), GenerateSequence(32)).Returns(true);
            yield return new TestCaseData(GenerateSequence(48), GenerateSequence(48)).Returns(true);
            yield return new TestCaseData(GenerateSequence(39), GenerateSequence(39).Reverse().ToArray()).Returns(false);
        }
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment