Created
June 17, 2018 20:18
-
-
Save catid/751b2f6e6dd8a0be9771cface14cc918 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#pragma once | |
#include "xxxxxxxxxProtocol.h" | |
#include "xxxxxxxxxTools.h" | |
#define ZSTD_STATIC_LINKING_ONLY /* Enable advanced API */ | |
#include "thirdparty/zstd/zstd.h" // Zstd | |
#include "thirdparty/zstd/zstd_errors.h" | |
namespace xxx { | |
//------------------------------------------------------------------------------ | |
// Compression Constants | |
/// Zstd compression level | |
static const int kCompressionLevel = 1; | |
/// Compression history buffer size | |
static const unsigned kCompressionDictBytes = 24 * 1000; | |
/// Bytes allocated per packet | |
static const unsigned kCompressionAllocateBytes = \ | |
protocol::kMaxPossibleDatagramByteLimit - protocol::kMaxOverheadBytes; | |
//------------------------------------------------------------------------------ | |
// Ring Buffer | |
template<size_t kBufferBytes> | |
class RingBuffer | |
{ | |
public: | |
/// Get a contiguous region that is at least `bytes` in size | |
XXXX_FORCE_INLINE void* Allocate(unsigned bytes) | |
{ | |
if (NextWriteOffset + bytes > kBufferBytes) { | |
NextWriteOffset = 0; | |
} | |
return Buffer + NextWriteOffset; | |
} | |
/// Commit some number of bytes up to allocated bytes | |
XXXX_FORCE_INLINE void Commit(unsigned bytes) | |
{ | |
XXXX_DEBUG_ASSERT(NextWriteOffset + bytes <= kBufferBytes); | |
NextWriteOffset += bytes; | |
} | |
protected: | |
/// Ring buffer that eats its own tail | |
uint8_t Buffer[kBufferBytes]; | |
/// Next offset to write to | |
unsigned NextWriteOffset = 0; | |
}; | |
//------------------------------------------------------------------------------ | |
// MessageCompressor | |
class MessageCompressor | |
{ | |
public: | |
Result Initialize(); | |
~MessageCompressor(); | |
/// Compress data to the destination buffer `destBuffer`. | |
/// Returns the number of bytes written in `writtenBytes`. | |
/// Returns writtenBytes = 0 if data should not be compressed | |
Result Compress( | |
const uint8_t* data, | |
unsigned bytes, | |
uint8_t* dest, | |
unsigned& writtenBytes); | |
protected: | |
/// Dictionary history used by decompressor | |
RingBuffer<kCompressionDictBytes> History; | |
/// Zstd context object used to compress packets | |
ZSTD_CCtx* CCtx = nullptr; | |
}; | |
//------------------------------------------------------------------------------ | |
// MessageDecompressor | |
/// View of a decompressed datagram, filled in by MessageDecompressor::Decompress().
/// Default-constructed values describe an empty result.
struct Decompressed
{
    /// Start of the decompressed bytes (not owned by this struct)
    const uint8_t* Data = nullptr;

    /// Number of valid bytes at `Data`
    unsigned Bytes = 0;
};
class MessageDecompressor | |
{ | |
public: | |
Result Initialize(); | |
~MessageDecompressor(); | |
/// Decompress and handle a block of messages | |
Result Decompress( | |
const void* data, | |
unsigned bytes, | |
Decompressed& decompressed); | |
/// Insert uncompressed reliable datagram | |
void InsertUncompressed( | |
const uint8_t* data, | |
unsigned bytes); | |
protected: | |
/// Dictionary history used by decompressor | |
RingBuffer<kCompressionDictBytes> History; | |
/// Zstd context object used to decompress packets | |
ZSTD_DCtx* DCtx = nullptr; | |
}; | |
//------------------------------------------------------------------------------ | |
// MessageCompressor | |
/// Creates the Zstd compression context, begins compression with parameters
/// derived from the compression constants, and verifies that a Zstd block is
/// large enough to hold kCompressionAllocateBytes.
/// Returns a failure Result if any of these steps fails.
Result MessageCompressor::Initialize()
{
    CCtx = ZSTD_createCCtx();
    if (nullptr == CCtx) {
        return Result("SessionOutgoing::Initialize", "ZSTD_createCCtx failed", ErrorType::Zstd);
    }

    // Compression parameters: source size hint is one packet allocation,
    // dictionary size hint is the history ring buffer size
    ZSTD_parameters params;
    params.cParams = ZSTD_getCParams(
        kCompressionLevel,
        static_cast<size_t>(kCompressionAllocateBytes),
        kCompressionDictBytes);

    // Frame parameters: no content size, no checksum, no dictionary ID
    params.fParams.contentSizeFlag = 0;
    params.fParams.checksumFlag = 0;
    params.fParams.noDictIDFlag = 1;

    // Begin without an initial dictionary
    const size_t beginStatus = ZSTD_compressBegin_advanced(
        CCtx,
        nullptr,
        0,
        params,
        ZSTD_CONTENTSIZE_UNKNOWN);
    if (ZSTD_isError(beginStatus)) {
        XXXX_DEBUG_BREAK();
        return Result("SessionOutgoing::Initialize", "ZSTD_compressBegin_advanced failed", ErrorType::Zstd, beginStatus);
    }

    // Compress() hands up to kCompressionAllocateBytes to a single
    // ZSTD_compressBlock() call, so one block must be able to hold that much
    const bool blockTooSmall = ZSTD_getBlockSize(CCtx) < kCompressionAllocateBytes;
    if (blockTooSmall) {
        return Result("SessionOutgoing::Initialize", "Zstd block size is too small", ErrorType::Zstd);
    }

    return Result::Success();
}
/// Frees the Zstd compression context, if one was created
MessageCompressor::~MessageCompressor()
{
    if (nullptr == CCtx) {
        return; // Nothing was created
    }
    ZSTD_freeCCtx(CCtx);
}
/// Compress `bytes` of `data` into `destBuffer` as one Zstd block preceded by
/// a Compressed message frame header.
///
/// The input is first copied into the history ring buffer, and compression
/// runs on that copy.
///
/// `writtenBytes` receives the total size written (frame header included),
/// or 0 when the data should be sent uncompressed: Zstd produced no output,
/// ran out of output space, or the result was not smaller than the input.
///
/// Returns a failure Result if Zstd reports any other error.
Result MessageCompressor::Compress(
    const uint8_t* data,
    unsigned bytes,
    uint8_t* destBuffer,
    unsigned& writtenBytes)
{
    XXXX_DEBUG_ASSERT(bytes >= protocol::kMessageFrameBytes);
    writtenBytes = 0;

    // Only kCompressionAllocateBytes of contiguous history is reserved below,
    // so larger input would overrun it.  Check in every build, mirroring
    // MessageDecompressor::InsertUncompressed(): the data is left out of the
    // history and reported as "do not compress".
    XXXX_DEBUG_ASSERT(kCompressionAllocateBytes >= bytes);
    if (bytes > kCompressionAllocateBytes) {
        return Result::Success();
    }

    // Insert data into history ring buffer
    void* history = History.Allocate(kCompressionAllocateBytes);
    memcpy(history, data, bytes);
    History.Commit(bytes);

    // Compress into scratch buffer, leaving room for a frame header
    // NOTE(review): this tells Zstd that kCompressionAllocateBytes are
    // writable past the frame header -- confirm callers size destBuffer so.
    const size_t result = ZSTD_compressBlock(
        CCtx,
        destBuffer + protocol::kMessageFrameBytes,
        kCompressionAllocateBytes,
        history,
        bytes);

    // Zstd reports errors as (size_t)-code
    const size_t dstTooSmall =
        static_cast<size_t>(-static_cast<int>(ZSTD_error_dstSize_tooSmall));

    // If compression failed for a reason other than lack of output space:
    // This must be checked before the size comparison below, because error
    // codes are huge size_t values that would otherwise satisfy
    // `kMessageFrameBytes + result >= bytes` and be reported as success.
    if (0 != ZSTD_isError(result) && result != dstTooSmall)
    {
        std::string reason = "ZSTD_compressBlock failed: ";
        reason += ZSTD_getErrorName(result);
        XXXX_DEBUG_BREAK();
        return Result("SessionOutgoing::compress", reason, ErrorType::Zstd, result);
    }

    // If no data to compress, or would require too much space,
    // or did not produce a small enough result:
    if (0 == result ||
        dstTooSmall == result ||
        protocol::kMessageFrameBytes + result >= bytes)
    {
        // Note: Input data was accumulated into history ring buffer
        return Result::Success();
    }

    const unsigned compressedBytes = static_cast<unsigned>(result);

    // Write Compressed frame header
    protocol::WriteMessageFrameHeader(
        destBuffer,
        protocol::MessageType_Compressed,
        compressedBytes);

    // Compressed bytes includes the frame header
    writtenBytes = protocol::kMessageFrameBytes + compressedBytes;
    XXXX_DEBUG_ASSERT(writtenBytes <= kCompressionAllocateBytes);

    return Result::Success();
}
//------------------------------------------------------------------------------ | |
// MessageDecompressor | |
/// Creates the Zstd decompression context and begins decompression.
/// Returns a failure Result if either step fails.
Result MessageDecompressor::Initialize()
{
    DCtx = ZSTD_createDCtx();
    if (nullptr == DCtx) {
        return Result("SessionIncoming::Initialize", "ZSTD_createDCtx failed", ErrorType::Zstd);
    }

    const size_t beginStatus = ZSTD_decompressBegin(DCtx);
    if (!ZSTD_isError(beginStatus)) {
        return Result::Success();
    }

    return Result("SessionIncoming::Initialize", "ZSTD_decompressBegin failed", ErrorType::Zstd, beginStatus);
}
/// Frees the Zstd decompression context, if one was created
MessageDecompressor::~MessageDecompressor()
{
    if (nullptr == DCtx) {
        return; // Nothing was created
    }
    ZSTD_freeDCtx(DCtx);
}
void MessageDecompressor::InsertUncompressed( | |
const uint8_t* data, | |
unsigned bytes) | |
{ | |
if (bytes > kCompressionAllocateBytes) { | |
XXXX_DEBUG_BREAK(); // Invalid input | |
return; | |
} | |
void* history = History.Allocate(kCompressionAllocateBytes); | |
memcpy(history, data, bytes); | |
ZSTD_insertBlock(DCtx, history, bytes); | |
History.Commit(bytes); | |
} | |
/// Decompress one Zstd block of `bytes` bytes at `data` directly into the
/// history ring buffer.
///
/// On success `decompressed` refers to the output inside the ring buffer.
/// On failure `decompressed` is reset to empty and a failure Result is
/// returned; an output size of 0 is treated as a failure.
Result MessageDecompressor::Decompress(
    const void* data,
    unsigned bytes,
    Decompressed& decompressed)
{
    // Never leave stale values in the output on the failure path
    decompressed.Data = nullptr;
    decompressed.Bytes = 0;

    // Decompress data into history ring buffer
    void* history = History.Allocate(kCompressionAllocateBytes);
    const size_t result = ZSTD_decompressBlock(
        DCtx,
        history,
        kCompressionAllocateBytes,
        data,
        bytes);

    // If decompression failed:
    if (0 == result || 0 != ZSTD_isError(result))
    {
        std::string reason = "ZSTD_decompressBlock failed: ";
        reason += ZSTD_getErrorName(result);
        XXXX_DEBUG_BREAK();
        // Source tag matches MessageDecompressor::Initialize ("SessionIncoming")
        return Result("SessionIncoming::decompress", reason, ErrorType::Zstd, result);
    }

    const uint8_t* datagramData = static_cast<const uint8_t*>(history);
    const unsigned datagramBytes = static_cast<unsigned>(result);

    // Keep the output in the history by advancing past it
    History.Commit(datagramBytes);

    decompressed.Data = datagramData;
    decompressed.Bytes = datagramBytes;

    return Result::Success();
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment