Last active
August 29, 2015 14:15
-
-
Save neremin/becc8ddff220be98fa7d to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/// <summary>
/// Helpers for comparing byte arrays and streams for content equality.
/// </summary>
/// <remarks>
/// Argument validation is done with explicit guard clauses rather than
/// <c>Contract.Requires&lt;TException&gt;</c>, which fails fast at runtime unless the
/// assembly is processed by the Code Contracts binary rewriter.
/// </remarks>
public static class BytesHelper
{
    /// <summary>
    /// Determines whether two byte arrays have the same length and identical contents.
    /// </summary>
    /// <param name="arrayA">First array; must not be <c>null</c>.</param>
    /// <param name="arrayB">Second array; must not be <c>null</c>.</param>
    /// <returns><c>true</c> if both arrays have the same length and the same bytes.</returns>
    /// <exception cref="ArgumentNullException">Either array is <c>null</c>.</exception>
    public static bool Equals(byte[] arrayA, byte[] arrayB)
    {
        if (arrayA == null)
        {
            throw new ArgumentNullException("arrayA");
        }
        if (arrayB == null)
        {
            throw new ArgumentNullException("arrayB");
        }

        var length = arrayA.Length;
        return length == arrayB.Length && UnsafeEquals(arrayA, 0, arrayB, 0, length);
    }

    /// <summary>
    /// Determines whether <paramref name="count"/> bytes of <paramref name="arrayA"/> starting at
    /// <paramref name="startIndexA"/> equal <paramref name="count"/> bytes of <paramref name="arrayB"/>
    /// starting at <paramref name="startIndexB"/>.
    /// </summary>
    /// <param name="arrayA">First array; must not be <c>null</c>.</param>
    /// <param name="startIndexA">Start of the segment in <paramref name="arrayA"/>.</param>
    /// <param name="arrayB">Second array; must not be <c>null</c>.</param>
    /// <param name="startIndexB">Start of the segment in <paramref name="arrayB"/>.</param>
    /// <param name="count">Number of bytes to compare; may be zero.</param>
    /// <returns><c>true</c> if both segments contain the same bytes.</returns>
    /// <exception cref="ArgumentNullException">Either array is <c>null</c>.</exception>
    /// <exception cref="ArgumentOutOfRangeException">
    /// <paramref name="count"/> is negative, or a segment does not lie entirely inside its array.
    /// </exception>
    public static bool Equals(byte[] arrayA, int startIndexA, byte[] arrayB, int startIndexB, int count)
    {
        if (arrayA == null)
        {
            throw new ArgumentNullException("arrayA");
        }
        if (arrayB == null)
        {
            throw new ArgumentNullException("arrayB");
        }
        if (count < 0)
        {
            throw new ArgumentOutOfRangeException("count");
        }
        // "start > Length - count" is "start + count > Length" without int overflow;
        // a segment ending exactly at the array end is valid.
        if (startIndexA < 0 || startIndexA > arrayA.Length - count)
        {
            throw new ArgumentOutOfRangeException("startIndexA");
        }
        if (startIndexB < 0 || startIndexB > arrayB.Length - count)
        {
            throw new ArgumentOutOfRangeException("startIndexB");
        }

        return UnsafeEquals(arrayA, startIndexA, arrayB, startIndexB, count);
    }

    /// <summary>
    /// Determines whether the stream's length equals the buffer's length and the bytes read from the
    /// stream's current position match the whole <paramref name="buffer"/>.
    /// </summary>
    /// <param name="stream">Seekable stream (its <c>Length</c> and <c>Position</c> are read); must not be <c>null</c>.</param>
    /// <param name="buffer">Expected contents; must not be <c>null</c>.</param>
    /// <param name="chunkSize">Maximum number of bytes read from the stream per call; must be positive.</param>
    /// <returns><c>true</c> if the stream content matches the buffer.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="stream"/> or <paramref name="buffer"/> is <c>null</c>.</exception>
    public static bool StreamEqualsBuffer(Stream stream, byte[] buffer, int chunkSize = 4096)
    {
        if (stream == null)
        {
            throw new ArgumentNullException("stream");
        }
        if (buffer == null)
        {
            throw new ArgumentNullException("buffer");
        }

        return stream.Length == buffer.Length &&
               StreamSegmentEqualsBufferSegment(stream, buffer, 0, buffer.Length, chunkSize);
    }

    /// <summary>
    /// Determines whether the stream's length equals <paramref name="count"/> and the bytes read from the
    /// stream's current position match <c>buffer[offset .. offset + count)</c>.
    /// </summary>
    /// <param name="stream">Seekable stream; must not be <c>null</c>.</param>
    /// <param name="buffer">Expected contents; must not be <c>null</c>.</param>
    /// <param name="offset">Start of the expected segment in <paramref name="buffer"/>.</param>
    /// <param name="count">Length of the expected segment.</param>
    /// <param name="chunkSize">Maximum number of bytes read from the stream per call; must be positive.</param>
    /// <returns><c>true</c> if the stream content matches the buffer segment.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="stream"/> or <paramref name="buffer"/> is <c>null</c>.</exception>
    public static bool StreamEqualsBufferSegment(Stream stream, byte[] buffer, int offset, int count, int chunkSize = 4096)
    {
        if (stream == null)
        {
            throw new ArgumentNullException("stream");
        }
        if (buffer == null)
        {
            throw new ArgumentNullException("buffer");
        }

        return count == stream.Length && StreamSegmentEqualsBufferSegment(stream, buffer, offset, count, chunkSize);
    }

    /// <summary>
    /// Reads <paramref name="count"/> bytes from the stream's current position and checks that they
    /// match <c>buffer[offset .. offset + count)</c>. Consumed bytes advance the stream position.
    /// </summary>
    /// <param name="stream">Seekable stream (its <c>Length</c> and <c>Position</c> are read); must not be <c>null</c>.</param>
    /// <param name="buffer">Expected contents; must not be <c>null</c>.</param>
    /// <param name="offset">Start of the expected segment in <paramref name="buffer"/>.</param>
    /// <param name="count">Number of bytes to compare; may be zero.</param>
    /// <param name="chunkSize">Maximum number of bytes read from the stream per call; must be positive.</param>
    /// <returns>
    /// <c>true</c> if the next <paramref name="count"/> stream bytes equal the buffer segment;
    /// <c>false</c> if they differ or fewer than <paramref name="count"/> bytes remain in the stream.
    /// </returns>
    /// <exception cref="ArgumentNullException"><paramref name="stream"/> or <paramref name="buffer"/> is <c>null</c>.</exception>
    /// <exception cref="ArgumentOutOfRangeException"><paramref name="offset"/> or <paramref name="count"/> is negative, or <paramref name="chunkSize"/> is not positive.</exception>
    /// <exception cref="ArgumentException"><c>offset + count</c> exceeds <c>buffer.Length</c>.</exception>
    public static bool StreamSegmentEqualsBufferSegment(Stream stream, byte[] buffer, int offset, int count, int chunkSize = 4096)
    {
        if (stream == null)
        {
            throw new ArgumentNullException("stream");
        }
        if (buffer == null)
        {
            throw new ArgumentNullException("buffer");
        }
        if (offset < 0)
        {
            throw new ArgumentOutOfRangeException("offset");
        }
        if (count < 0)
        {
            throw new ArgumentOutOfRangeException("count");
        }
        if (chunkSize <= 0)
        {
            throw new ArgumentOutOfRangeException("chunkSize");
        }
        if (offset > buffer.Length - count)
        {
            throw new ArgumentException("offset + count must not exceed buffer.Length.");
        }

        // Cheap early exit: not enough bytes left in the stream to match.
        if (stream.Position + count > stream.Length)
        {
            return false;
        }

        var remaining = count;
        var bufferIndex = offset;
        var streamChunk = new byte[Math.Min(count, chunkSize)];
        while (remaining > 0)
        {
            var bytesRead = stream.Read(streamChunk, 0, Math.Min(remaining, streamChunk.Length));
            if (bytesRead <= 0)
            {
                // Stream ended before count bytes were read.
                return false;
            }
            // Compare against the buffer at the caller's offset, not at the start of the buffer.
            if (!UnsafeEquals(streamChunk, 0, buffer, bufferIndex, bytesRead))
            {
                return false;
            }
            bufferIndex += bytesRead;
            remaining -= bytesRead;
        }
        return true;
    }

    /// <summary>
    /// Reads <paramref name="count"/> bytes from the current position of each stream and checks that
    /// they are identical. Consumed bytes advance both stream positions.
    /// </summary>
    /// <param name="streamA">First stream; must not be <c>null</c>.</param>
    /// <param name="streamB">Second stream; must not be <c>null</c>.</param>
    /// <param name="count">Number of bytes to compare; may be zero.</param>
    /// <param name="chunkSize">Size of the per-stream read buffer; must be positive.</param>
    /// <returns>
    /// <c>true</c> if both streams yield the same <paramref name="count"/> bytes;
    /// <c>false</c> if they differ or either stream ends first.
    /// </returns>
    /// <exception cref="ArgumentNullException">Either stream is <c>null</c>.</exception>
    /// <exception cref="ArgumentOutOfRangeException"><paramref name="count"/> is negative or <paramref name="chunkSize"/> is not positive.</exception>
    public static bool StreamSegmentEquals(Stream streamA, Stream streamB, int count, int chunkSize = 4096)
    {
        if (streamA == null)
        {
            throw new ArgumentNullException("streamA");
        }
        if (streamB == null)
        {
            throw new ArgumentNullException("streamB");
        }
        if (count < 0)
        {
            throw new ArgumentOutOfRangeException("count");
        }
        if (chunkSize <= 0)
        {
            throw new ArgumentOutOfRangeException("chunkSize");
        }
        if (count == 0)
        {
            return true;
        }

        var chunkLength = Math.Min(count, chunkSize);
        var aChunk = new byte[chunkLength];
        var bChunk = new byte[chunkLength];
        var remaining = count;
        while (remaining > 0)
        {
            var toRead = Math.Min(remaining, chunkLength);
            // Stream.Read may return fewer bytes than requested before the end of the stream,
            // so fill each chunk completely; a short fill means that stream has ended.
            if (ReadFully(streamA, aChunk, toRead) != toRead ||
                ReadFully(streamB, bChunk, toRead) != toRead)
            {
                return false;
            }
            if (!UnsafeEquals(aChunk, 0, bChunk, 0, toRead))
            {
                return false;
            }
            remaining -= toRead;
        }
        return true;
    }

    /// <summary>
    /// Reads into <c>buffer[0 .. count)</c> until <paramref name="count"/> bytes have been read or the
    /// stream returns no more data; returns the number of bytes actually read.
    /// </summary>
    static int ReadFully(Stream stream, byte[] buffer, int count)
    {
        var total = 0;
        while (total < count)
        {
            var read = stream.Read(buffer, total, count - total);
            if (read <= 0)
            {
                break;
            }
            total += read;
        }
        return total;
    }

    /// <devdoc>
    /// Adapted from .NET's EqualsHelper in String.cs.
    /// </devdoc>
    /// <summary>
    /// Compares <paramref name="bytesCount"/> bytes of two arrays without bounds checks, using
    /// word-sized reads for the bulk of the data and a byte loop for the tail.
    /// </summary>
    /// <remarks>
    /// Callers must guarantee that both ranges lie inside their arrays and that
    /// <paramref name="bytesCount"/> is non-negative; only index 0 is bounds-checked (by <c>fixed</c>).
    /// </remarks>
    [SecuritySafeCritical]
    [ReliabilityContract(Consistency.WillNotCorruptState, Cer.MayFail)]
    static unsafe bool UnsafeEquals(byte[] arrayA, int startIndexA, byte[] arrayB, int startIndexB, int bytesCount)
    {
        if (bytesCount == 0)
        {
            // Also avoids taking &array[0] of an empty array.
            return true;
        }
        fixed (byte* ap = &arrayA[0])
        fixed (byte* bp = &arrayB[0])
        {
            byte* a = ap + startIndexA;
            byte* b = bp + startIndexB;
            // Identical start addresses mean identical ranges. Comparing the base pointers would
            // wrongly report equality for the same array at different offsets.
            if (a == b)
            {
                return true;
            }
            // unroll the loop
            if (IntPtr.Size == sizeof(long))
            {
                // for AMD64 bit platform we unroll by 12 and
                // check 3 qword at a time. This is less code
                // than the 32 bit case and is shorter
                // path length
                const int step = 3;
                const int stepInBytes = step * sizeof(long);
                var al = (long*) a;
                var bl = (long*) b;
                while (bytesCount >= stepInBytes)
                {
                    if (al[0] != bl[0]) return false;
                    if (al[1] != bl[1]) return false;
                    if (al[2] != bl[2]) return false;
                    al += step; bl += step; bytesCount -= stepInBytes;
                }
                a = (byte*) al;
                b = (byte*) bl;
            }
            else
            {
                // 32-bit platforms: compare 5 dwords per iteration.
                const int step = 5;
                const int stepInBytes = step * sizeof(int);
                var ai = (int*) a;
                var bi = (int*) b;
                while (bytesCount >= stepInBytes)
                {
                    if (ai[0] != bi[0]) return false;
                    if (ai[1] != bi[1]) return false;
                    if (ai[2] != bi[2]) return false;
                    if (ai[3] != bi[3]) return false;
                    if (ai[4] != bi[4]) return false;
                    ai += step; bi += step; bytesCount -= stepInBytes;
                }
                a = (byte*) ai;
                b = (byte*) bi;
            }
            // Remaining whole dwords.
            while (bytesCount >= sizeof(int))
            {
                if (*(int*)a != *(int*)b) return false;
                a += sizeof(int); b += sizeof(int); bytesCount -= sizeof(int);
            }
            // Remaining tail bytes.
            while (bytesCount > 0)
            {
                if (*a != *b) return false;
                ++a; ++b; --bytesCount;
            }
            return true;
        }
    }
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/// <summary>
/// NUnit fixture exercising <c>BytesHelper</c> against shared tables of byte sequences.
/// </summary>
[TestFixture]
public class BytesHelperTests
{
    [TestCaseSource(typeof (ByteSequences))]
    public bool Equals(byte[] s1, byte[] s2)
    {
        return BytesHelper.Equals(s1, s2);
    }

    [TestCaseSource(typeof(ByteSequences))]
    public bool StreamEqualsBuffer(byte[] s1, byte[] s2)
    {
        using (var stream = new MemoryStream(s1))
        {
            return BytesHelper.StreamEqualsBuffer(stream, s2);
        }
    }

    [TestCaseSource(typeof(StreamSegmentsSequences))]
    public bool StreamSegmentEquals(byte[] s1, byte[] s2, int count, int chunkSize)
    {
        using (var streamA = new MemoryStream(s1))
        using (var streamB = new MemoryStream(s2))
        {
            return BytesHelper.StreamSegmentEquals(streamA, streamB, count, chunkSize);
        }
    }

    /// <summary>
    /// Cases for <c>StreamSegmentEquals</c>: (s1, s2, count, chunkSize) → expected result.
    /// Explicit prefix cases first, then every <see cref="ByteSequences"/> pair compared over the
    /// longer length with a 1-byte and a 4096-byte chunk.
    /// </summary>
    sealed class StreamSegmentsSequences : IEnumerable
    {
        public IEnumerator GetEnumerator()
        {
            // Only the first 3 bytes are compared, so differences (or missing bytes) after that are ignored.
            yield return new TestCaseData(new byte[] { 1, 2, 3, 4, 5 }, new byte[] { 1, 2, 3, 3, 6 }, 3, 1).Returns(true);
            yield return new TestCaseData(new byte[] { 1, 2, 3, 4, 5 }, new byte[] { 1, 2, 3, 3, 6 }, 3, 16).Returns(true);
            yield return new TestCaseData(new byte[] { 1, 2, 3 }, new byte[] { 1, 2, 3, 3, 6 }, 3, 1).Returns(true);
            yield return new TestCaseData(new byte[] { 1, 2, 3 }, new byte[] { 1, 2, 3, 3, 6 }, 3, 16).Returns(true);
            yield return new TestCaseData(new byte[] { 1, 2, 3, 4, 5 }, new byte[] { 1, 2, 3 }, 3, 1).Returns(true);
            yield return new TestCaseData(new byte[] { 1, 2, 3, 4, 5 }, new byte[] { 1, 2, 3 }, 3, 16).Returns(true);
            foreach (TestCaseData sequence in (new ByteSequences()))
            {
                var s1 = (byte[])sequence.Arguments[0];
                var s2 = (byte[])sequence.Arguments[1];
                var count = Math.Max(s1.Length, s2.Length);
                yield return new TestCaseData(s1, s2, count, 1).Returns(sequence.Result);
                yield return new TestCaseData(s1, s2, count, 4096).Returns(sequence.Result);
            }
        }
    }

    /// <summary>
    /// Pairs of byte arrays with their expected equality. The lengths are chosen to hit the
    /// unrolled 3-qword (24-byte) and 5-dword (20-byte) strides, the dword loop and the byte tail.
    /// </summary>
    sealed class ByteSequences : IEnumerable
    {
        // Returns 0, 1, 2, ... wrapping after 255.
        static byte[] GenerateSequence(int length)
        {
            var sequence = new byte[length];
            for (int i = 0; i < length; ++i)
            {
                sequence[i] = (byte)(i % (byte.MaxValue + 1));
            }
            return sequence;
        }

        public IEnumerator GetEnumerator()
        {
            yield return new TestCaseData(new byte[0], new byte[0]).Returns(true);
            yield return new TestCaseData(new byte[] { 2 }, new byte[] { 1 }).Returns(false);
            yield return new TestCaseData(new byte[] { 1, 2 }, new byte[] { 1, 1 }).Returns(false);
            yield return new TestCaseData(new byte[] { 1, 2, 3 }, new byte[] { 1, 2, 2 }).Returns(false);
            yield return new TestCaseData(new byte[] { 1, 2, 3, 4 }, new byte[] { 1, 2, 3, 3 }).Returns(false);
            yield return new TestCaseData(GenerateSequence(1), GenerateSequence(2)).Returns(false);
            yield return new TestCaseData(GenerateSequence(2), GenerateSequence(1)).Returns(false);
            yield return new TestCaseData(GenerateSequence(5), GenerateSequence(5)).Returns(true);
            yield return new TestCaseData(GenerateSequence(27), GenerateSequence(27)).Returns(true);
            yield return new TestCaseData(GenerateSequence(13), GenerateSequence(13)).Returns(true);
            yield return new TestCaseData(GenerateSequence(15), GenerateSequence(15)).Returns(true);
            yield return new TestCaseData(GenerateSequence(19), GenerateSequence(19)).Returns(true);
            yield return new TestCaseData(GenerateSequence(16), GenerateSequence(16)).Returns(true);
            yield return new TestCaseData(GenerateSequence(20), GenerateSequence(20)).Returns(true);
            // Exactly one 3-qword stride with no tail (previously a duplicate of the 29-byte case).
            yield return new TestCaseData(GenerateSequence(24), GenerateSequence(24)).Returns(true);
            yield return new TestCaseData(GenerateSequence(29), GenerateSequence(29)).Returns(true);
            yield return new TestCaseData(GenerateSequence(32), GenerateSequence(32)).Returns(true);
            yield return new TestCaseData(GenerateSequence(48), GenerateSequence(48)).Returns(true);
            yield return new TestCaseData(GenerateSequence(39), GenerateSequence(39).Reverse().ToArray()).Returns(false);
        }
    }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment