Skip to content

Instantly share code, notes, and snippets.

@catid
Created June 17, 2018 20:18
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save catid/751b2f6e6dd8a0be9771cface14cc918 to your computer and use it in GitHub Desktop.
Save catid/751b2f6e6dd8a0be9771cface14cc918 to your computer and use it in GitHub Desktop.
#pragma once
#include "xxxxxxxxxProtocol.h"
#include "xxxxxxxxxTools.h"
#define ZSTD_STATIC_LINKING_ONLY /* Enable advanced API */
#include "thirdparty/zstd/zstd.h" // Zstd
#include "thirdparty/zstd/zstd_errors.h"
namespace xxx {
//------------------------------------------------------------------------------
// Compression Constants

/// Zstd compression level passed to ZSTD_getCParams() during initialization.
static const int kCompressionLevel = 1;

/// Size in bytes of each history ring buffer, which holds recent packets so
/// later packets can be compressed against them.
static const unsigned kCompressionDictBytes = 24000;

/// Bytes reserved in the history ring buffer for each packet: the largest
/// possible datagram minus the protocol overhead.  Every Allocate() on a
/// history ring buffer requests this many bytes, so this value must not
/// exceed kCompressionDictBytes.
static const unsigned kCompressionAllocateBytes =
    protocol::kMaxPossibleDatagramByteLimit - protocol::kMaxOverheadBytes;
//------------------------------------------------------------------------------
// Ring Buffer

/// Fixed-size byte buffer that hands out contiguous regions in increasing
/// address order, wrapping back to offset 0 when the next region would not
/// fit.  After a wrap, new writes overwrite the oldest data ("eats its own
/// tail"), so a pointer returned by Allocate() stays valid only until later
/// allocations come around to it again.
///
/// Usage: Allocate() an upper bound, write up to that many bytes, then
/// Commit() the number of bytes actually written.
template<size_t kBufferBytes>
class RingBuffer
{
    static_assert(kBufferBytes > 0, "RingBuffer must have a non-zero size");

public:
    /// Get a contiguous region that is at least `bytes` in size.
    /// `bytes` must not exceed kBufferBytes: a larger request cannot be
    /// satisfied, and writing `bytes` into the returned region would run
    /// past the end of Buffer.
    XXXX_FORCE_INLINE void* Allocate(unsigned bytes)
    {
        XXXX_DEBUG_ASSERT(bytes <= kBufferBytes);

        // Not enough room before the end: wrap to the start of the buffer
        if (NextWriteOffset + bytes > kBufferBytes) {
            NextWriteOffset = 0;
        }
        return Buffer + NextWriteOffset;
    }

    /// Commit `bytes` written into the region last returned by Allocate().
    /// `bytes` should not exceed the size that was passed to Allocate().
    XXXX_FORCE_INLINE void Commit(unsigned bytes)
    {
        XXXX_DEBUG_ASSERT(NextWriteOffset + bytes <= kBufferBytes);
        NextWriteOffset += bytes;
    }

protected:
    /// Backing storage; the oldest data is overwritten after a wrap
    uint8_t Buffer[kBufferBytes];

    /// Offset at which the next Allocate() region begins
    unsigned NextWriteOffset = 0;
};
//------------------------------------------------------------------------------
// MessageCompressor

/// Compresses outgoing data with Zstd's block-level API, keeping a history
/// ring buffer of everything passed to Compress() so later data can be
/// compressed against it.
class MessageCompressor
{
public:
    MessageCompressor() = default;
    ~MessageCompressor();

    // The destructor frees the raw ZSTD_CCtx; a copy would free it twice.
    MessageCompressor(const MessageCompressor&) = delete;
    MessageCompressor& operator=(const MessageCompressor&) = delete;

    /// Create and configure the Zstd compression context.
    /// Must succeed before Compress() is called.
    Result Initialize();

    /// Compress `bytes` bytes of `data` into `destBuffer`.
    /// The input is recorded in the compression history even when it ends
    /// up not being compressed.
    /// On success `writtenBytes` is the number of bytes written to
    /// `destBuffer` (Compressed frame header included), or 0 if the data
    /// should be sent uncompressed.
    Result Compress(
        const uint8_t* data,
        unsigned bytes,
        uint8_t* destBuffer,
        unsigned& writtenBytes);

protected:
    /// History of data passed to Compress(), used as the compressor's dictionary
    RingBuffer<kCompressionDictBytes> History;

    /// Zstd context object used to compress packets
    ZSTD_CCtx* CCtx = nullptr;
};
//------------------------------------------------------------------------------
// MessageDecompressor
/// Result of MessageDecompressor::Decompress(): a view of the decompressed
/// bytes.  Nothing is owned here.
struct Decompressed
{
/// Start of the decompressed bytes.  Decompress() points this into its
/// history ring buffer, so the data can be overwritten by later calls.
const uint8_t* Data;
/// Number of decompressed bytes available at Data
unsigned Bytes;
};
/// Decompresses incoming Zstd blocks with the block-level API, keeping its
/// own history ring buffer of all received data (decompressed or inserted
/// uncompressed) for later blocks to reference.
class MessageDecompressor
{
public:
    MessageDecompressor() = default;
    ~MessageDecompressor();

    // The destructor frees the raw ZSTD_DCtx; a copy would free it twice.
    MessageDecompressor(const MessageDecompressor&) = delete;
    MessageDecompressor& operator=(const MessageDecompressor&) = delete;

    /// Create the Zstd decompression context.
    /// Must succeed before Decompress() or InsertUncompressed() is called.
    Result Initialize();

    /// Decompress and handle a block of messages (`bytes` bytes at `data`).
    /// On success `decompressed` points into the history ring buffer, so it
    /// is only valid until later calls wrap the buffer back onto it.
    Result Decompress(
        const void* data,
        unsigned bytes,
        Decompressed& decompressed);

    /// Insert an uncompressed reliable datagram into the history so that
    /// later compressed blocks can reference its bytes.
    void InsertUncompressed(
        const uint8_t* data,
        unsigned bytes);

protected:
    /// Dictionary history used by decompressor
    RingBuffer<kCompressionDictBytes> History;

    /// Zstd context object used to decompress packets
    ZSTD_DCtx* DCtx = nullptr;
};
//------------------------------------------------------------------------------
// MessageCompressor

/// Create the Zstd compression context (or reuse an existing one) and start
/// a new dictionary-less session for use with ZSTD_compressBlock().
///
/// Calling Initialize() again reuses the existing context instead of
/// leaking it; ZSTD_compressBegin_advanced() restarts its state.
///
/// NOTE(review): this file begins with `#pragma once`.  If it really is a
/// header included from several translation units, these out-of-class member
/// definitions need `inline` to avoid duplicate-symbol link errors.  If it
/// is a header + .cpp pasted together, they are fine as-is.
Result MessageCompressor::Initialize()
{
    // Allocate only once: a repeated Initialize() used to overwrite (leak) CCtx
    if (!CCtx) {
        CCtx = ZSTD_createCCtx();
        if (!CCtx) {
            return Result("SessionOutgoing::Initialize", "ZSTD_createCCtx failed", ErrorType::Zstd);
        }
    }

    // Tune parameters for packet-sized inputs compressed against
    // kCompressionDictBytes of history.
    const size_t estimatedPacketSize = kCompressionAllocateBytes;
    ZSTD_parameters zParams;
    zParams.cParams = ZSTD_getCParams(
        kCompressionLevel,
        estimatedPacketSize,
        kCompressionDictBytes);

    // Frame-level features are off: only raw blocks from ZSTD_compressBlock() are used
    zParams.fParams.checksumFlag = 0;
    zParams.fParams.contentSizeFlag = 0;
    zParams.fParams.noDictIDFlag = 1;

    const size_t icsResult = ZSTD_compressBegin_advanced(
        CCtx,
        nullptr,
        0,
        zParams,
        ZSTD_CONTENTSIZE_UNKNOWN);
    if (0 != ZSTD_isError(icsResult)) {
        XXXX_DEBUG_BREAK();
        return Result("SessionOutgoing::Initialize", "ZSTD_compressBegin_advanced failed", ErrorType::Zstd, icsResult);
    }

    // Each packet must fit in a single Zstd block (ZSTD_compressBlock rejects larger inputs)
    const size_t blockSizeBytes = ZSTD_getBlockSize(CCtx);
    if (blockSizeBytes < kCompressionAllocateBytes) {
        return Result("SessionOutgoing::Initialize", "Zstd block size is too small", ErrorType::Zstd);
    }

    return Result::Success();
}
/// Release the Zstd compression context, if Initialize() created one.
MessageCompressor::~MessageCompressor()
{
    if (CCtx == nullptr) {
        return;
    }
    ZSTD_freeCCtx(CCtx);
}
/// Compress `bytes` bytes of `data` into `destBuffer` as a Compressed frame.
///
/// Accepted input is always copied into the history ring buffer and handed
/// to Zstd, even when it is not sent compressed, so Zstd's window keeps
/// covering every datagram.  Input larger than kCompressionAllocateBytes is
/// rejected (debug break) and is neither recorded nor compressed, just as
/// MessageDecompressor::InsertUncompressed() ignores such input.
///
/// At most `bytes - 1` bytes are written to `destBuffer`.
/// On success `writtenBytes` is the frame size including its header, or 0
/// when the data should be sent uncompressed.
Result MessageCompressor::Compress(
    const uint8_t* data,
    unsigned bytes,
    uint8_t* destBuffer,
    unsigned& writtenBytes)
{
    XXXX_DEBUG_ASSERT(bytes >= protocol::kMessageFrameBytes);
    writtenBytes = 0;

    // The memcpy below would overflow the history region in release builds
    if (bytes > kCompressionAllocateBytes) {
        XXXX_DEBUG_BREAK(); // Invalid input
        return Result::Success();
    }

    // Insert data into history ring buffer
    void* history = History.Allocate(kCompressionAllocateBytes);
    memcpy(history, data, bytes);
    History.Commit(bytes);

    // A compressed frame is only worth sending if it is smaller than the
    // input, so Zstd's output is capped at bytes - header - 1.  This:
    //  - keeps Zstd from writing past the end of destBuffer (previously it
    //    was allowed kCompressionAllocateBytes starting *after* the header);
    //  - guarantees every successful result is sent.  Zstd commits its repeat
    //    offsets / entropy tables for each successfully compressed block, so
    //    discarding such a block would desynchronize the peer's decompressor.
    const unsigned headerBytes = protocol::kMessageFrameBytes;
    const size_t maxCompressedBytes =
        (bytes > headerBytes + 1) ? (bytes - headerBytes - 1) : 0;

    const size_t result = ZSTD_compressBlock(
        CCtx,
        destBuffer + headerBytes,
        maxCompressedBytes,
        history,
        bytes);

    // Not compressible, or did not fit under the cap: send uncompressed.
    // Zstd has still added the input to its window.
    if (0 == result ||
        (size_t)-ZSTD_error_dstSize_tooSmall == result)
    {
        // Note: Input data was accumulated into history ring buffer
        return Result::Success();
    }

    // If compression failed:
    if (0 != ZSTD_isError(result))
    {
        std::string reason = "ZSTD_compressBlock failed: ";
        reason += ZSTD_getErrorName(result);
        XXXX_DEBUG_BREAK();
        return Result("SessionOutgoing::compress", reason, ErrorType::Zstd, result);
    }

    // Guaranteed by the capacity cap above
    XXXX_DEBUG_ASSERT(headerBytes + result < bytes);
    const unsigned compressedBytes = static_cast<unsigned>(result);

    // Write Compressed frame header
    protocol::WriteMessageFrameHeader(
        destBuffer,
        protocol::MessageType_Compressed,
        compressedBytes);

    // Compressed bytes includes the frame header
    writtenBytes = headerBytes + compressedBytes;
    XXXX_DEBUG_ASSERT(writtenBytes <= kCompressionAllocateBytes);
    return Result::Success();
}
//------------------------------------------------------------------------------
// MessageDecompressor

/// Create the Zstd decompression context (or reuse an existing one) and
/// reset it with ZSTD_decompressBegin() for block-level decoding.
///
/// Calling Initialize() again reuses the existing context instead of
/// leaking it; ZSTD_decompressBegin() resets its state.
Result MessageDecompressor::Initialize()
{
    // Allocate only once: a repeated Initialize() used to overwrite (leak) DCtx
    if (!DCtx) {
        DCtx = ZSTD_createDCtx();
        if (!DCtx) {
            return Result("SessionIncoming::Initialize", "ZSTD_createDCtx failed", ErrorType::Zstd);
        }
    }

    const size_t beginResult = ZSTD_decompressBegin(DCtx);
    if (0 != ZSTD_isError(beginResult)) {
        return Result("SessionIncoming::Initialize", "ZSTD_decompressBegin failed", ErrorType::Zstd, beginResult);
    }

    return Result::Success();
}
/// Release the Zstd decompression context, if Initialize() created one.
MessageDecompressor::~MessageDecompressor()
{
    if (DCtx == nullptr) {
        return;
    }
    ZSTD_freeDCtx(DCtx);
}
void MessageDecompressor::InsertUncompressed(
const uint8_t* data,
unsigned bytes)
{
if (bytes > kCompressionAllocateBytes) {
XXXX_DEBUG_BREAK(); // Invalid input
return;
}
void* history = History.Allocate(kCompressionAllocateBytes);
memcpy(history, data, bytes);
ZSTD_insertBlock(DCtx, history, bytes);
History.Commit(bytes);
}
/// Decompress one Zstd block (`bytes` bytes at `data`) into the history
/// ring buffer.
///
/// On success `decompressed` refers to the output inside the history ring
/// buffer and is only valid until later calls wrap the buffer onto it.
/// On failure `decompressed` is reset to {nullptr, 0}.
///
/// Zstd's output is bounded by kCompressionAllocateBytes, which is exactly
/// the region reserved by Allocate(), so malformed input cannot overrun it.
Result MessageDecompressor::Decompress(
    const void* data,
    unsigned bytes,
    Decompressed& decompressed)
{
    // Never hand back stale pointers on a failure path
    decompressed.Data = nullptr;
    decompressed.Bytes = 0;

    // Decompress data into history ring buffer
    void* history = History.Allocate(kCompressionAllocateBytes);
    const size_t result = ZSTD_decompressBlock(
        DCtx,
        history,
        kCompressionAllocateBytes,
        data,
        bytes);

    // If decompression failed:
    if (0 != ZSTD_isError(result))
    {
        std::string reason = "ZSTD_decompressBlock failed: ";
        reason += ZSTD_getErrorName(result);
        XXXX_DEBUG_BREAK();
        return Result("SessionIncoming::decompress", reason, ErrorType::Zstd, result);
    }

    // MessageCompressor::Compress() sends a 0-byte result uncompressed, so an
    // empty block is invalid.  It is reported separately because
    // ZSTD_getErrorName(0) would not describe the problem.
    if (0 == result)
    {
        XXXX_DEBUG_BREAK();
        return Result("SessionIncoming::decompress", "ZSTD_decompressBlock failed: Empty block", ErrorType::Zstd, result);
    }

    const unsigned datagramBytes = static_cast<unsigned>(result);
    History.Commit(datagramBytes);

    decompressed.Data = static_cast<const uint8_t*>(history);
    decompressed.Bytes = datagramBytes;
    return Result::Success();
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment