Skip to content

Instantly share code, notes, and snippets.

@bigbes
Created April 22, 2016 16:43
Show Gist options
  • Save bigbes/936329f6db379606838ea64d990af73d to your computer and use it in GitHub Desktop.
Save bigbes/936329f6db379606838ea64d990af73d to your computer and use it in GitHub Desktop.
-- msgpackffi.lua (internal file)
local ffi = require('ffi')
local buffer = require('buffer')
local builtin = ffi.C
local msgpack = require('msgpack') -- .NULL, .array_mt, .map_mt, .cfg
-- Maximum table nesting depth handled by encode_r(); a table found at this
-- depth is encoded as MsgPack nil instead of being descended into.
local MAXNESTING = 16
-- Cached pointer ctypes used by the fixed-width integer encoders/decoders.
local int8_ptr_t = ffi.typeof('int8_t *')
local uint8_ptr_t = ffi.typeof('uint8_t *')
local uint16_ptr_t = ffi.typeof('uint16_t *')
local uint32_ptr_t = ffi.typeof('uint32_t *')
local uint64_ptr_t = ffi.typeof('uint64_t *')
local const_char_ptr_t = ffi.typeof('const char *')
-- C prototypes for the float/double MsgPack helpers called through
-- ffi.C (`builtin`) -- presumably provided by libmsgpuck in the host
-- binary; verify -- plus `union tmpint`, a scratch area used for
-- 16/32/64-bit integers on strict-alignment targets.
ffi.cdef([[
char *
mp_encode_float(char *data, float num);
char *
mp_encode_double(char *data, double num);
float
mp_decode_float(const char **data);
double
mp_decode_double(const char **data);
union tmpint {
uint16_t u16;
uint32_t u32;
uint64_t u64;
};
]])
-- On 32-bit ARM the multi-byte integer encoders/decoders copy through
-- `tmpint` with ffi.copy() instead of storing/loading through casted
-- (possibly unaligned) pointers.
-- NOTE(review): the byte swapping below is unconditional, so this module
-- assumes a little-endian host.
local strict_alignment = (jit.arch == 'arm')
local tmpint
if strict_alignment then
tmpint = ffi.new('union tmpint[1]')
end
--- Reverse the byte order of the low 16 bits of `num`.
-- The value is swapped as a full 32-bit word; the two interesting bytes then
-- sit in the upper half, and a logical right shift brings them back down.
-- @param num Lua number or integer cdata (only the low 16 bits matter)
-- @treturn number value in 0..0xffff
local function bswap_u16(num)
    local swapped32 = bit.bswap(tonumber(num))
    return bit.rshift(swapped32, 16)
end
--------------------------------------------------------------------------------
-- Encoder
--------------------------------------------------------------------------------
-- Registry of cdata encoders keyed by numeric ctype id
-- (tonumber(ffi.typeof(value))).
local encode_ext_cdata = {}
--- Install `callback(buf, obj)` as the encoder for cdata of a given ctype.
-- @param ctype_or_udataname a ctype or a cdata value whose ctype is used
--   (only cdata values are accepted)
-- @param callback function(buf, obj) appending the encoding of obj to buf
-- @return the callback previously registered for that ctype, or nil
-- @raise a usage error when either argument has the wrong Lua type
local function on_encode(ctype_or_udataname, callback)
    local args_ok = type(ctype_or_udataname) == "cdata"
        and type(callback) == "function"
    if not args_ok then
        error("Usage: on_encode(ffi.typeof('mytype'), function(buf, obj)")
    end
    local type_key = tonumber(ffi.typeof(ctype_or_udataname))
    local previous = encode_ext_cdata[type_key]
    encode_ext_cdata[type_key] = callback
    return previous
end
--- Append a one-byte MsgPack item: the type `code` OR-ed with the small
-- inline payload `num` (fixint, fixstr/fixarray/fixmap length, bool).
local function encode_fix(buf, code, num)
    local out = buf:alloc(1)
    local packed = bit.bor(code, tonumber(num))
    out[0] = packed
end
--- Append `code` followed by `num` stored as a single uint8_t (2 bytes).
local function encode_u8(buf, code, num)
    local out = buf:alloc(2)
    out[0] = code
    local payload = ffi.cast(uint8_ptr_t, out + 1)
    payload[0] = num
end
--- Append `code` followed by the low 16 bits of `num` in big-endian byte
-- order (3 bytes in total).
local encode_u16
if not strict_alignment then
    encode_u16 = function(buf, code, num)
        local out = buf:alloc(3)
        out[0] = code
        ffi.cast(uint16_ptr_t, out + 1)[0] = bswap_u16(num)
    end
else
    -- The payload starts right after a 1-byte code, so go through the
    -- scratch union instead of storing via a uint16_t pointer.
    encode_u16 = function(buf, code, num)
        tmpint[0].u16 = bswap_u16(num)
        local out = buf:alloc(3)
        out[0] = code
        ffi.copy(out + 1, tmpint, 2)
    end
end
--- Append `code` followed by `num` as a big-endian 32-bit value
-- (5 bytes in total).
local encode_u32
if not strict_alignment then
    encode_u32 = function(buf, code, num)
        local out = buf:alloc(5)
        out[0] = code
        ffi.cast(uint32_ptr_t, out + 1)[0] =
            ffi.cast('uint32_t', bit.bswap(tonumber(num)))
    end
else
    -- Stage the swapped value in the scratch union, then byte-copy it.
    encode_u32 = function(buf, code, num)
        local swapped = bit.bswap(tonumber(num))
        tmpint[0].u32 = ffi.cast('uint32_t', swapped)
        local out = buf:alloc(5)
        out[0] = code
        ffi.copy(out + 1, tmpint, 4)
    end
end
--- Append `code` followed by `num` (Lua number or 64-bit cdata) as a
-- big-endian 64-bit value (9 bytes in total).
local encode_u64
if not strict_alignment then
    encode_u64 = function(buf, code, num)
        local out = buf:alloc(9)
        out[0] = code
        ffi.cast(uint64_ptr_t, out + 1)[0] =
            bit.bswap(ffi.cast('uint64_t', num))
    end
else
    -- Stage the swapped value in the scratch union, then byte-copy it.
    encode_u64 = function(buf, code, num)
        local as_u64 = ffi.cast('uint64_t', num)
        tmpint[0].u64 = bit.bswap(as_u64)
        local out = buf:alloc(9)
        out[0] = code
        ffi.copy(out + 1, tmpint, 8)
    end
end
--- Append `num` as a MsgPack float32 (type byte + 4 bytes), written by the
-- C helper mp_encode_float() into 5 freshly allocated bytes.
local function encode_float(buf, num)
    builtin.mp_encode_float(buf:alloc(5), num)
end
--- Append `num` as a MsgPack float64 (type byte + 8 bytes), written by the
-- C helper mp_encode_double() into 9 freshly allocated bytes.
local function encode_double(buf, num)
    builtin.mp_encode_double(buf:alloc(9), num)
end
--- Append an integer (Lua number or integer cdata) using the narrowest
-- MsgPack integer form that holds it.
-- Non-negative: positive fixint, uint8, uint16, uint32, uint64.
-- Negative: negative fixint, int8, int16, int32, int64.
local function encode_int(buf, num)
    if num >= 0 then
        if num <= 0x7f then
            return encode_fix(buf, 0, num)
        end
        if num <= 0xff then
            return encode_u8(buf, 0xcc, num)
        end
        if num <= 0xffff then
            return encode_u16(buf, 0xcd, num)
        end
        if num <= 0xffffffff then
            return encode_u32(buf, 0xce, num)
        end
        return encode_u64(buf, 0xcf, 0ULL + num)
    end
    if num >= -0x20 then
        return encode_fix(buf, 0xe0, num)
    end
    if num >= -0x80 then
        return encode_u8(buf, 0xd0, num)
    end
    if num >= -0x8000 then
        return encode_u16(buf, 0xd1, num)
    end
    if num >= -0x80000000 then
        return encode_u32(buf, 0xd2, num)
    end
    return encode_u64(buf, 0xd3, 0LL + num)
end
--- Append a Lua string as MsgPack str: a fixstr/str8/str16/str32 header
-- followed by the raw bytes.
local function encode_str(buf, str)
    local len = #str
    -- Reserve room for the largest header (5 bytes) plus the payload.
    buf:reserve(5 + len)
    if len > 0xffff then
        encode_u32(buf, 0xdb, len)
    elseif len > 0xff then
        encode_u16(buf, 0xda, len)
    elseif len > 31 then
        encode_u8(buf, 0xd9, len)
    else
        encode_fix(buf, 0xa0, len)
    end
    local body = buf:alloc(len)
    ffi.copy(body, str, len)
end
--- Append a MsgPack array header for `size` elements
-- (fixarray, array16 or array32).
local function encode_array(buf, size)
    if size <= 0xf then
        return encode_fix(buf, 0x90, size)
    end
    if size <= 0xffff then
        return encode_u16(buf, 0xdc, size)
    end
    return encode_u32(buf, 0xdd, size)
end
--- Append a MsgPack map header for `size` key/value pairs
-- (fixmap, map16 or map32).
local function encode_map(buf, size)
    if size <= 0xf then
        return encode_fix(buf, 0x80, size)
    end
    if size <= 0xffff then
        return encode_u16(buf, 0xde, size)
    end
    return encode_u32(buf, 0xdf, size)
end
--- Append a Lua value's truthiness as MsgPack false (0xc2) / true (0xc3).
local function encode_bool(buf, val)
    local low_bit = 0
    if val then
        low_bit = 1
    end
    encode_fix(buf, 0xc2, low_bit)
end
--- Append a `bool` cdata as MsgPack false (0xc2) when it equals 0,
-- otherwise as true (0xc3).
local function encode_bool_cdata(buf, val)
    local low_bit = 1
    if val == 0 then
        low_bit = 0
    end
    encode_fix(buf, 0xc2, low_bit)
end
--- Append MsgPack nil (the single byte 0xc0). Extra arguments are ignored.
local function encode_nil(buf)
    local out = buf:alloc(1)
    out[0] = 0xc0
end
--- Recursively append the MsgPack encoding of a Lua value to `buf`.
-- Numbers: integral values in [-2^63, 2^64) are encoded as integers via
--   encode_int(); everything else (fractions, +-inf, NaN, integral values
--   outside the 64-bit range) is encoded as float64.
-- Tables: when `#obj > 0` the table is encoded as an array of
--   obj[1..#obj] (other keys are not written); otherwise as a map, except
--   that an empty table becomes an empty array. A table reached at depth
--   MAXNESTING is encoded as nil.
-- cdata: NULL pointers become nil; other cdata use the callback registered
--   with on_encode() for their ctype.
-- @param buf output buffer (uses :alloc() and :reserve())
-- @param obj value to encode
-- @param level current nesting depth, 0 at the top level
-- @raise error for unsupported Lua types and unregistered FFI types
local function encode_r(buf, obj, level)
    local obj_type = type(obj)
    if obj_type == "number" then
        -- Integral check the Lua way. The bounds are the int64 minimum and
        -- the uint64 maximum + 1: outside them encode_int() would have to
        -- convert an out-of-range double to a 64-bit integer.
        if obj % 1 == 0 and obj >= -2^63 and obj < 2^64 then
            encode_int(buf, obj)
        else
            encode_double(buf, obj)
        end
    elseif obj_type == "string" then
        encode_str(buf, obj)
    elseif obj_type == "table" then
        if level >= MAXNESTING then -- Limit nested tables
            encode_nil(buf)
            return
        end
        local len = #obj
        if len > 0 then
            encode_array(buf, len)
            for i = 1, len do
                encode_r(buf, obj[i], level + 1)
            end
        else
            local size = 0
            for _ in pairs(obj) do -- goodbye, JIT
                size = size + 1
            end
            if size == 0 then
                encode_array(buf, 0) -- encode empty table as an array
                return
            end
            encode_map(buf, size)
            for key, val in pairs(obj) do
                encode_r(buf, key, level + 1)
                encode_r(buf, val, level + 1)
            end
        end
    elseif obj == nil then
        -- Also true for NULL-pointer cdata in LuaJIT.
        encode_nil(buf)
    elseif obj_type == "boolean" then
        encode_bool(buf, obj)
    elseif obj_type == "cdata" then
        if obj == nil then -- a workaround for nil
            encode_nil(buf)
            return
        end
        local ctype = ffi.typeof(obj)
        local fun = encode_ext_cdata[tonumber(ctype)]
        if fun == nil then
            -- A ctype cannot be concatenated directly; convert it first.
            error("can not encode FFI type: '"..tostring(ctype).."'")
        end
        fun(buf, obj)
    else
        error("can not encode Lua type: '"..obj_type.."'")
    end
end
--- Encode `obj` and return the MsgPack bytes as a Lua string.
-- buffer.IBUF_SHARED is used as scratch space: it is reset first, its
-- contents are copied into a Lua string, and it is recycled afterwards.
local function encode(obj)
    local scratch = buffer.IBUF_SHARED
    scratch:reset()
    encode_r(scratch, obj, 0)
    local encoded = ffi.string(scratch.rpos, scratch:size())
    scratch:recycle()
    return encoded
end
--- Append the MsgPack encoding of `obj` to the caller-supplied `ibuf`
-- (no reset, nothing returned).
-- NOTE(review): this helper is not part of the module's returned table;
-- confirm whether it is meant to be exported.
local function encode_ibuf(obj, ibuf)
    local top_level = 0
    encode_r(ibuf, obj, top_level)
end
-- Default cdata encoders, registered in this order: integer-like ctypes go
-- through encode_int(), `bool` through encode_bool_cdata(), and float /
-- double keep their width.
for _, entry in ipairs({
    { 'uint8_t', encode_int },
    { 'uint16_t', encode_int },
    { 'uint32_t', encode_int },
    { 'uint64_t', encode_int },
    { 'int8_t', encode_int },
    { 'int16_t', encode_int },
    { 'int32_t', encode_int },
    { 'int64_t', encode_int },
    { 'char', encode_int },
    { 'const char', encode_int },
    { 'unsigned char', encode_int },
    { 'const unsigned char', encode_int },
    { 'bool', encode_bool_cdata },
    { 'float', encode_float },
    { 'double', encode_double },
}) do
    on_encode(ffi.typeof(entry[1]), entry[2])
end
--------------------------------------------------------------------------------
-- Decoder
--------------------------------------------------------------------------------
-- Forward declaration: decode_r is assigned further down, but the container
-- decoders (decode_array / decode_map) need to call it recursively.
local decode_r
-- See similar constants in utils.cc
-- 64-bit MsgPack integers inside [-DBL_INT_MAX, DBL_INT_MAX] are returned as
-- Lua numbers by decode_u64 / decode_i64; values outside that range are
-- returned as 64-bit cdata.
-- NOTE: DBL_INT_MIN equals -DBL_INT_MAX; the decoders compare against
-- -DBL_INT_MAX directly.
local DBL_INT_MAX = 1e14 - 1
local DBL_INT_MIN = -1e14 + 1
--- Read an unsigned 8-bit integer at the cursor and advance it by 1 byte.
-- @param data one-element pointer array used as a `const char **` cursor
-- @treturn number
local function decode_u8(data)
    local cursor = data[0]
    data[0] = cursor + 1
    return tonumber(ffi.cast(uint8_ptr_t, cursor)[0])
end
--- Read a big-endian unsigned 16-bit integer at the cursor and advance it
-- by 2 bytes. Returns a Lua number.
local decode_u16
if not strict_alignment then
    decode_u16 = function(data)
        local cursor = data[0]
        data[0] = cursor + 2
        return tonumber(bswap_u16(ffi.cast(uint16_ptr_t, cursor)[0]))
    end
else
    -- Copy into the scratch union rather than loading via a uint16_t
    -- pointer.
    decode_u16 = function(data)
        ffi.copy(tmpint, data[0], 2)
        data[0] = data[0] + 2
        return tonumber(bswap_u16(tmpint[0].u16))
    end
end
--- Read a big-endian unsigned 32-bit integer at the cursor and advance it
-- by 4 bytes. Returns a non-negative Lua number.
local decode_u32
if not strict_alignment then
    decode_u32 = function(data)
        local cursor = data[0]
        data[0] = cursor + 4
        local raw = ffi.cast(uint32_ptr_t, cursor)[0]
        -- bit.bswap() yields a signed 32-bit result; the uint32_t cast
        -- restores the unsigned value.
        return tonumber(ffi.cast('uint32_t', bit.bswap(tonumber(raw))))
    end
else
    decode_u32 = function(data)
        ffi.copy(tmpint, data[0], 4)
        data[0] = data[0] + 4
        local swapped = bit.bswap(tonumber(tmpint[0].u32))
        return tonumber(ffi.cast('uint32_t', swapped))
    end
end
--- Read a big-endian unsigned 64-bit integer at the cursor and advance it
-- by 8 bytes. Values up to DBL_INT_MAX come back as Lua numbers, larger ones
-- as uint64_t cdata.
local decode_u64
if not strict_alignment then
    decode_u64 = function(data)
        local cursor = data[0]
        data[0] = cursor + 8
        local num = bit.bswap(ffi.cast(uint64_ptr_t, cursor)[0])
        if num > DBL_INT_MAX then
            return num -- return as 'cdata'
        end
        return tonumber(num) -- return as 'number'
    end
else
    decode_u64 = function(data)
        ffi.copy(tmpint, data[0], 8)
        data[0] = data[0] + 8
        local num = bit.bswap(tmpint[0].u64)
        if num > DBL_INT_MAX then
            return num -- return as 'cdata'
        end
        return tonumber(num) -- return as 'number'
    end
end
--- Read a signed 8-bit integer at the cursor and advance it by 1 byte.
-- @treturn number
local function decode_i8(data)
    local cursor = data[0]
    data[0] = cursor + 1
    return tonumber(ffi.cast(int8_ptr_t, cursor)[0])
end
--- Read a big-endian signed 16-bit integer at the cursor and advance it by
-- 2 bytes. Returns a Lua number.
local decode_i16
if not strict_alignment then
    decode_i16 = function(data)
        local swapped = bswap_u16(ffi.cast(uint16_ptr_t, data[0])[0])
        data[0] = data[0] + 2
        -- Reinterpret the 16-bit pattern as signed; the original author
        -- notes the uint16_t -> int16_t double cast is required.
        local as_u16 = ffi.cast('uint16_t', swapped)
        return tonumber(ffi.cast('int16_t', as_u16))
    end
else
    decode_i16 = function(data)
        ffi.copy(tmpint, data[0], 2)
        local swapped = bswap_u16(tmpint[0].u16)
        data[0] = data[0] + 2
        local as_u16 = ffi.cast('uint16_t', swapped)
        return tonumber(ffi.cast('int16_t', as_u16))
    end
end
--- Read a big-endian signed 32-bit integer at the cursor and advance it by
-- 4 bytes. bit.bswap() already returns a signed 32-bit Lua number, which is
-- exactly the decoded int32 value.
local decode_i32
if not strict_alignment then
    decode_i32 = function(data)
        local cursor = data[0]
        data[0] = cursor + 4
        return bit.bswap(tonumber(ffi.cast(uint32_ptr_t, cursor)[0]))
    end
else
    decode_i32 = function(data)
        ffi.copy(tmpint, data[0], 4)
        data[0] = data[0] + 4
        return bit.bswap(tonumber(tmpint[0].u32))
    end
end
--- Read a big-endian signed 64-bit integer at the cursor and advance it by
-- 8 bytes. Values within [DBL_INT_MIN, DBL_INT_MAX] come back as Lua
-- numbers, others as int64_t cdata.
local decode_i64
if not strict_alignment then
    decode_i64 = function(data)
        local raw = ffi.cast(uint64_ptr_t, data[0])[0]
        local num = bit.bswap(ffi.cast('int64_t', raw))
        data[0] = data[0] + 8
        if num < DBL_INT_MIN or num > DBL_INT_MAX then
            return num -- return as 'cdata'
        end
        return tonumber(num) -- return as 'number'
    end
else
    decode_i64 = function(data)
        ffi.copy(tmpint, data[0], 8)
        data[0] = data[0] + 8
        local num = bit.bswap(ffi.cast('int64_t', tmpint[0].u64))
        if num < DBL_INT_MIN or num > DBL_INT_MAX then
            return num -- return as 'cdata'
        end
        return tonumber(num) -- return as 'number'
    end
end
--- Decode a MsgPack float32. decode_r() has already consumed the type byte,
-- but mp_decode_float() expects to read it, so the cursor steps back one
-- byte first.
local function decode_float(data)
    data[0] = data[0] - 1
    local value = builtin.mp_decode_float(data)
    return tonumber(value)
end
--- Decode a MsgPack float64. decode_r() has already consumed the type byte,
-- but mp_decode_double() expects to read it, so the cursor steps back one
-- byte first.
local function decode_double(data)
    data[0] = data[0] - 1
    local value = builtin.mp_decode_double(data)
    return tonumber(value)
end
--- Copy `size` bytes at the cursor into a Lua string and advance the
-- cursor past them.
local function decode_str(data, size)
    local start = data[0]
    local result = ffi.string(start, size)
    data[0] = start + size
    return result
end
--- Decode `size` consecutive MsgPack items into a Lua array.
-- The result gets msgpack.array_mt when msgpack.cfg.decode_save_metatables
-- is set.
local function decode_array(data, size)
    assert (type(size) == "number")
    local arr = {}
    for idx = 1, size do
        arr[idx] = decode_r(data)
    end
    if msgpack.cfg.decode_save_metatables then
        return setmetatable(arr, msgpack.array_mt)
    end
    return arr
end
--- Decode `size` key/value pairs (key first, then value) into a Lua table.
-- The result gets msgpack.map_mt when msgpack.cfg.decode_save_metatables is
-- set.
local function decode_map(data, size)
    assert (type(size) == "number")
    local map = {}
    for _ = 1, size do
        local k = decode_r(data)
        local v = decode_r(data)
        map[k] = v
    end
    if msgpack.cfg.decode_save_metatables then
        return setmetatable(map, msgpack.map_mt)
    end
    return map
end
-- Dispatch table, keyed by MsgPack type byte, for the types decode_r() does
-- not handle inline. Each handler receives the cursor positioned just after
-- the type byte.
-- NOTE: MP_BIN payloads (0xc4..0xc6) are returned as Lua strings, exactly
-- like MP_STR. Type bytes with no entry here (e.g. the MsgPack ext types
-- 0xc7..0xc9 and 0xd4..0xd8) cannot be decoded by this module.
local decoder_hint = {
--[[{{{ MP_BIN]]
[0xc4] = function(data) return decode_str(data, decode_u8(data)) end;
[0xc5] = function(data) return decode_str(data, decode_u16(data)) end;
[0xc6] = function(data) return decode_str(data, decode_u32(data)) end;
--[[MP_FLOAT, MP_DOUBLE]]
-- (these handlers step the cursor back to re-read the type byte)
[0xca] = decode_float;
[0xcb] = decode_double;
--[[MP_UINT]]
[0xcc] = decode_u8;
[0xcd] = decode_u16;
[0xce] = decode_u32;
[0xcf] = decode_u64;
--[[MP_INT]]
[0xd0] = decode_i8;
[0xd1] = decode_i16;
[0xd2] = decode_i32;
[0xd3] = decode_i64;
--[[MP_STR]]
[0xd9] = function(data) return decode_str(data, decode_u8(data)) end;
[0xda] = function(data) return decode_str(data, decode_u16(data)) end;
[0xdb] = function(data) return decode_str(data, decode_u32(data)) end;
--[[MP_ARRAY]]
[0xdc] = function(data) return decode_array(data, decode_u16(data)) end;
[0xdd] = function(data) return decode_array(data, decode_u32(data)) end;
--[[MP_MAP]]
[0xde] = function(data) return decode_map(data, decode_u16(data)) end;
[0xdf] = function(data) return decode_map(data, decode_u32(data)) end;
}
--- Decode one MsgPack item at the cursor and advance the cursor past it.
-- Positive/negative fixint, fixstr, fixarray, fixmap, nil and booleans are
-- handled inline; every other type byte is dispatched via decoder_hint.
-- @param data one-element `const unsigned char *` array used as a cursor
-- @return the decoded value (MsgPack nil decodes to msgpack.NULL)
-- @raise error naming the type byte when no decoder exists for it
decode_r = function(data)
    local c = data[0][0]
    data[0] = data[0] + 1
    if c <= 0x7f then
        return tonumber(c) -- fixint
    elseif c >= 0xa0 and c <= 0xbf then
        return decode_str(data, bit.band(c, 0x1f)) -- fixstr
    elseif c >= 0x90 and c <= 0x9f then
        return decode_array(data, bit.band(c, 0xf)) -- fixarray
    elseif c >= 0x80 and c <= 0x8f then
        return decode_map(data, bit.band(c, 0xf)) -- fixmap
    elseif c >= 0xe0 then
        return tonumber(ffi.cast('signed char',c)) -- negfixint
    elseif c == 0xc0 then
        return msgpack.NULL
    elseif c == 0xc2 then
        return false
    elseif c == 0xc3 then
        return true
    end
    local fun = decoder_hint[c]
    if type(fun) ~= "function" then
        -- Name the offending byte (e.g. MsgPack ext types, reserved 0xc1)
        -- so the failure is diagnosable.
        error(string.format("msgpackffi: unsupported MsgPack type 0x%02x",
                            tonumber(c)))
    end
    return fun(data)
end
---
-- A reusable one-element pointer array used as the decoding cursor.
-- All decode_XXX functions take such a `const char **`-style cursor as their
-- first argument, as libmsgpuck does: after decoding, data[0] points at the
-- next element. According to the original author, passing a double pointer
-- is significantly faster on LuaJIT than returning a `result, newpos` pair.
-- The element type is `const unsigned char *`, so data[0][0] reads a byte
-- in 0..255.
--
local bufp = ffi.new('const unsigned char *[1]');
--- Validate a 1-based offset into a string of length `len`.
-- @param offset 1-based position, or nil for the default of 1
-- @param len length of the string being decoded
-- @return 1 when offset is nil, otherwise the offset as a ptrdiff_t cdata
-- @raise error when offset lies outside [1..len]
local function check_offset(offset, len)
    if offset == nil then
        return 1
    end
    local pos = ffi.cast('ptrdiff_t', offset)
    if pos >= 1 and pos <= len then
        return pos
    end
    error(string.format("offset = %d is out of bounds [1..%d]",
                        tonumber(pos), len))
end
--- Decode one MsgPack value without checking the payload against the end
-- of the input.
-- decode_unchecked(str, offset) -> res, new_offset   (1-based offsets)
-- decode_unchecked(buf) -> res, new_buf               (buf: const char *)
local function decode_unchecked(str, offset)
    if type(str) ~= "string" then
        if not ffi.istype(const_char_ptr_t, str) then
            error("msgpackffi.decode_unchecked(str, offset) -> res, new_offset | "..
                  "msgpackffi.decode_unchecked(const char *buf) -> res, new_buf")
        end
        bufp[0] = str
        local res = decode_r(bufp)
        return res, bufp[0]
    end
    offset = check_offset(offset, #str)
    local base = ffi.cast(const_char_ptr_t, str)
    bufp[0] = base + offset - 1
    local res = decode_r(bufp)
    return res, bufp[0] - base + 1
end
--------------------------------------------------------------------------------
-- exports
--------------------------------------------------------------------------------
local exports = {
    NULL = msgpack.NULL,
    array_mt = msgpack.array_mt,
    map_mt = msgpack.map_mt,
    encode = encode,
    on_encode = on_encode,
    decode_unchecked = decode_unchecked,
    decode = decode_unchecked, -- just for tests
    -- Low-level encoder building blocks, exposed under `internal`.
    internal = {
        encode_fix = encode_fix,
        encode_array = encode_array,
        encode_r = encode_r,
    },
}
return exports
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment