Skip to content

Instantly share code, notes, and snippets.

@starwing starwing/vararg.c
Last active Mar 25, 2016

Embed
What would you like to do?
A vararg module compatible with Lua 5.3 (the C source also contains code paths for Lua 5.1 and 5.2).
#define LUA_LIB
#include <limits.h>
#include <lua.h>
#include <lauxlib.h>
/* Convert a possibly negative position into an absolute one for a sequence
 * of 'len' items: -1 is the last item, -len the first.  Non-negative
 * positions are returned untouched (so 0 and positions past 'len' pass
 * through); negative positions reaching before the first item become 0. */
static lua_Integer posrelat(lua_Integer pos, size_t len) {
  size_t back;
  if (pos >= 0)
    return pos;
  back = 0u - (size_t)pos; /* magnitude of pos, computed in unsigned math */
  return (back > len) ? 0 : (lua_Integer)len + pos + 1;
}
/*
 * Body of the closures created by 'pack'.  Upvalue 1 holds the number n of
 * packed values; upvalues 2..n+1 hold the values.  Dispatches on the type
 * of the first argument:
 *   p()           returns all n values;
 *   p("#")        returns n;
 *   p(nil, k)     iterator step: returns k+1 and value k+1, or nothing once
 *                 k >= n (so 'for i, v in p do ... end' works);
 *   p(i [, j])    returns values i..j (j defaults to i; negative indices
 *                 count from the end); positions outside 1..n yield nil;
 *   p(i, j, ...)  stores the extra arguments into positions i..j (missing
 *                 ones store nil) and returns them; i..j must lie in 1..n.
 * Any other first argument raises an "invalid argument" error.
 */
static int tuple(lua_State *L) {
  int top, count, n = (int)lua_tointeger(L, lua_upvalueindex(1));
  lua_Integer i, j;
  switch (lua_type(L, 1)) {
  case LUA_TNIL: /* as iterator */
    i = lua_tointeger(L, 2);
    if (i < 0 || i >= n) return 0; /* test before '+1' so it cannot overflow */
    ++i;
    lua_pushinteger(L, i);
    lua_pushvalue(L, lua_upvalueindex((int)i + 1));
    return 2;
  case LUA_TSTRING: /* as length operator */
    if (*lua_tostring(L, 1) == '#') {
      lua_pushinteger(L, n);
      return 1;
    }
    break;
  case LUA_TNONE: /* get all values */
    luaL_checkstack(L, n, "too many values");
    for (count = 1; count <= n; ++count)
      lua_pushvalue(L, lua_upvalueindex(count + 1));
    return n;
  case LUA_TNUMBER: /* get/set a range */
    i = posrelat(luaL_checkinteger(L, 1), n);
    j = posrelat(luaL_optinteger(L, 2, i), n);
    if (i > j) return 0;
    /* i, j >= 0 here; the count plus the two leading argument slots must
     * fit in an int before it is handed to the stack API */
    if (j - i > INT_MAX - 3)
      return luaL_error(L, "stack overflow (%s)", "too many values");
    count = (int)(j - i + 1);
    luaL_checkstack(L, count, "too many values");
    if ((top = lua_gettop(L)) <= 2) { /* get */
      for (; i <= j; ++i) {
        /* only positions 1..n are packed values; never expose upvalue 1
         * (the count) nor read past the last upvalue */
        if (i >= 1 && i <= n)
          lua_pushvalue(L, lua_upvalueindex((int)i + 1));
        else
          lua_pushnil(L);
      }
    }
    else { /* set */
      int idx;
      /* position 0 would overwrite the count and positions past n have no
       * upvalue to store into: writing there corrupts the Lua state */
      if (i < 1 || j > n)
        return luaL_argerror(L, 1, "index out of range");
      lua_settop(L, top = count + 2);
      for (idx = 3; idx <= top; ++idx) {
        lua_pushvalue(L, idx);
        lua_replace(L, lua_upvalueindex((int)i + idx - 2));
      }
    }
    return count;
  default:
    break;
  }
  return luaL_argerror(L, 1, "invalid argument");
}
/* pack(...) -> closure
 * Captures every argument as an upvalue of a 'tuple' closure.  Upvalue 1
 * stores the argument count, so at most 254 values can be packed (Lua
 * allows at most 255 upvalues on a C closure). */
static int Lpack(lua_State *L) {
  int count = lua_gettop(L);
  if (count > 254)
    return luaL_error(L, "too many values to pack");
  lua_pushinteger(L, count);
  lua_insert(L, 1); /* count first, then the values */
  lua_pushcclosure(L, tuple, count + 1);
  return 1;
}
/* range(i, j, ...) -> the trailing values at positions i..j.
 * Negative indices count from the end of the trailing values; i is clamped
 * to 1 and positions past the last value yield nil.  Returns nothing when
 * the range is empty or when fewer than two arguments are given. */
static int Lrange(lua_State *L) {
  int n = lua_gettop(L) - 2;
  lua_Integer i, j;
  if (n < 0) return 0;
  i = posrelat(luaL_checkinteger(L, 1), n);
  j = posrelat(luaL_checkinteger(L, 2), n);
  if (i < 1) i = 1; /* position 0 would select the 'j' argument itself */
  if (i > j) return 0;
  if (j > n) {
    /* j + 2 becomes the new stack top: keep it within an int */
    if (j > INT_MAX - 2)
      return luaL_error(L, "stack overflow (%s)", "range is too big");
    luaL_checkstack(L, (int)(j - n), "range is too big");
  }
  lua_settop(L, (int)j + 2); /* pads with nils when j > n */
  return (int)(j - i + 1);
}
/* insert(v, i, ...) -> the trailing values with v inserted at position i.
 * Negative i counts from the end (-1 inserts before the last value); i
 * below 1 is clamped to 1; an i past the end pads with nils so that v ends
 * up exactly at position i. */
static int Linsert(lua_State *L) {
  int n = lua_gettop(L) - 2;
  lua_Integer i = posrelat(luaL_checkinteger(L, 2), n);
  if (i < 1) i = 1; /* stack slot 2 is the index argument, not a value */
  if (i > n) {
    /* i + 2 becomes the stack top: keep it within an int */
    if (i > INT_MAX - 2)
      return luaL_error(L, "stack overflow (%s)", "index is too big");
    luaL_checkstack(L, (int)(i - n), "index is too big");
    lua_settop(L, (int)i + 1); /* nil padding up to position i-1 */
    lua_pushvalue(L, 1);
    return (int)i;
  }
  lua_pushvalue(L, 1);
  lua_insert(L, (int)i + 2);
  return n + 1;
}
/* remove(i, ...) -> the trailing values without the one at position i.
 * Negative i counts from the end; an i outside 1..n leaves the values
 * unchanged. */
static int Lremove(lua_State *L) {
  int n = lua_gettop(L) - 1;
  lua_Integer i = posrelat(luaL_checkinteger(L, 1), n);
  /* stack slot 1 is the index argument: removing it (i == 0) would shift
   * the values and, with no values, make the result count negative */
  if (i >= 1 && i <= n) {
    lua_remove(L, (int)i + 1);
    --n;
  }
  return n;
}
/* replace(v, i, ...) -> the trailing values with position i set to v.
 * Negative i counts from the end; an i below 1 leaves the values unchanged;
 * an i past the end pads with nils so that v ends up at position i. */
static int Lreplace(lua_State *L) {
  int n = lua_gettop(L) - 2;
  lua_Integer i = posrelat(luaL_checkinteger(L, 2), n);
  if (i < 1) return n; /* no value at or before position 0 */
  if (i > n) {
    /* i + 2 becomes the stack top: keep it within an int */
    if (i > INT_MAX - 2)
      return luaL_error(L, "stack overflow (%s)", "index is too big");
    luaL_checkstack(L, (int)(i - n), "index is too big");
    lua_settop(L, (int)i + 1); /* nil padding up to position i-1 */
    lua_pushvalue(L, 1);
    return (int)i;
  }
  lua_pushvalue(L, 1);
  lua_replace(L, (int)i + 2);
  return n;
}
/* push(v, ...) -> the trailing values followed by v. */
static int Lpush(lua_State *L) {
  const int before = lua_gettop(L); /* v plus the trailing values */
  lua_pushvalue(L, 1);              /* append a copy of v */
  return before;                    /* skip the original v in slot 1 */
}
/* pop(...) -> the arguments without the last one; nothing when called
 * without arguments (popping an empty frame would move the stack top below
 * the frame base and yield a negative result count). */
static int Lpop(lua_State *L) {
  int n = lua_gettop(L);
  if (n == 0) return 0;
  lua_pop(L, 1);
  return n - 1;
}
/* take(i, ...) -> the first i trailing values.  Negative i counts from the
 * end; when i is past the last value, nothing is returned. */
static int Ltake(lua_State *L) {
  int total = lua_gettop(L) - 1;
  lua_Integer keep = posrelat(luaL_checkinteger(L, 1), total);
  if (keep <= total) {
    lua_settop(L, (int)keep + 1); /* drop everything after the first 'keep' */
    return (int)keep;
  }
  return 0;
}
/* tail(i, ...) -> the trailing values from position i to the end.
 * Negative i counts from the end; i below 1 is clamped to 1; nothing is
 * returned when i is past the last value. */
static int Ltail(lua_State *L) {
  int n = lua_gettop(L) - 1;
  lua_Integer i = posrelat(luaL_checkinteger(L, 1), n);
  if (i < 1) i = 1; /* position 0 would also return the index argument */
  if (i > n) return 0;
  return n - (int)i + 1;
}
/* shift(...) -> the arguments without the first one; nothing when called
 * without arguments (a C function must never return a negative count). */
static int Lshift(lua_State *L) {
  int n = lua_gettop(L);
  return n > 0 ? n - 1 : 0;
}
/* map(f, ...) -> f(v) for every trailing value v.  Each call gets exactly
 * one argument and contributes exactly one result.  'f' must be given but
 * is only called when there are values to map. */
static int Lmap(lua_State *L) {
  int top = lua_gettop(L);
  int idx = 2;
  luaL_checkany(L, 1);
  while (idx <= top) {
    lua_pushvalue(L, 1);   /* function */
    lua_pushvalue(L, idx); /* its single argument */
    lua_call(L, 1, 1);
    lua_replace(L, idx);   /* result takes the argument's slot */
    ++idx;
  }
  return top - 1;
}
/* filter(pred, ...) -> the trailing values v for which pred(v) is truthy,
 * in their original order. */
static int Lfilter(lua_State *L) {
  int last = lua_gettop(L);
  int idx = 2;
  luaL_checkany(L, 1);
  while (idx <= last) {
    int keep;
    lua_pushvalue(L, 1);
    lua_pushvalue(L, idx);
    lua_call(L, 1, 1);
    keep = lua_toboolean(L, -1);
    lua_pop(L, 1); /* discard the predicate result */
    if (keep) {
      ++idx;
    } else {
      lua_remove(L, idx); /* later values slide down: re-test this slot */
      --last;
    }
  }
  return last - 1;
}
/* reduce(f, ...) -> left fold of the trailing values with f.
 * With three or more values: f(...f(f(v1, v2), v3)..., vn).  With at most
 * two values, f is called once with all of them (f(), f(v1) or f(v1, v2)).
 * Only the first result of each call is kept. */
static int Lreduce(lua_State *L) {
  int idx, top = lua_gettop(L);
  luaL_checkany(L, 1);
  if (top > 3) {
    lua_pushvalue(L, 1);
    lua_pushvalue(L, 2);
    lua_pushvalue(L, 3);
    lua_call(L, 2, 1);         /* acc = f(v1, v2) */
    for (idx = 4; idx <= top; ++idx) {
      lua_pushvalue(L, 1);
      lua_insert(L, -2);       /* f, acc */
      lua_pushvalue(L, idx);   /* f, acc, v */
      lua_call(L, 2, 1);       /* acc = f(acc, v) */
    }
    return 1;
  }
  lua_call(L, top - 1, 1);     /* consumes f and its 0..2 values */
  return 1;
}
/* unpack(f1, f2, ...) / concat(f1, f2, ...) -> all results of f1(), then
 * all results of f2(), and so on (e.g. the fi may be closures returned by
 * 'pack', which makes this a concatenation of packed tuples). */
static int Lunpack(lua_State *L) {
  int i, n = lua_gettop(L);
  for (i = 1; i <= n; ++i) {
    /* a LUA_MULTRET call leaves no guaranteed free slot after its results,
     * so reserve room before pushing the next function */
    luaL_checkstack(L, 1, "too many results");
    lua_pushvalue(L, i);
    lua_call(L, 0, LUA_MULTRET);
  }
  return lua_gettop(L) - n;
}
/* rotate(c, ...) -> the trailing values rotated c positions toward the end
 * (negative c rotates toward the front); c is taken modulo the number of
 * values.  Returns nothing when there are no values. */
static int Lrotate(lua_State *L) {
  int n = lua_gettop(L) - 1;
  lua_Integer c = luaL_checkinteger(L, 1);
  if (n <= 0) return 0; /* '% 0' would be undefined behavior */
  c %= n;
#if LUA_VERSION_NUM >= 503
  lua_rotate(L, 2, (int)c);
#else
  {
    lua_Integer i;
    /* c becomes the stack index of the last value that must wrap around;
     * re-pushing values 2..c places them after the rest, and returning the
     * top n values yields the rotated sequence */
    c = -c + ((c > 0) ? n+1 : 1);
    if (c > 1) luaL_checkstack(L, (int)(c-1), "too many values");
    for (i = 2; i <= c; ++i)
      lua_pushvalue(L, (int)i);
  }
#endif
  return n;
}
/* reverse(...) -> the arguments in reverse order. */
static int Lreverse(lua_State *L) {
  int top = lua_gettop(L);
  int lo = 1, hi = top;
  while (lo < hi) {
    /* swap slots lo and hi using two temporaries on the stack top */
    lua_pushvalue(L, hi);
    lua_pushvalue(L, lo);
    lua_replace(L, hi);
    lua_replace(L, lo);
    ++lo;
    --hi;
  }
  return top;
}
/* rep(count, ...) -> the trailing values repeated 'count' times in a row;
 * with no trailing values, 'count' nils.  Returns nothing when count <= 0. */
static int Lrep(lua_State *L) {
  int top = lua_gettop(L), n = top - 1;
  int j;
  lua_Integer i, count = luaL_checkinteger(L, 1);
  if (count <= 0) return 0;
  /* the total number of results, plus the slots already in use, must fit
   * in an int: otherwise n * count can overflow and the truncated size
   * would slip past luaL_checkstack */
  if (count > (INT_MAX - top) / (n > 0 ? n : 1))
    return luaL_error(L, "stack overflow (%s)", "too many values");
  if (n == 0) {
    luaL_checkstack(L, (int)count, "too many values");
    lua_settop(L, (int)count + 1); /* 'count' nils after the argument */
    return (int)count;
  }
  luaL_checkstack(L, (int)(n * (count - 1)), "too many values");
  for (i = 1; i < count; ++i)
    for (j = 2; j <= top; ++j)
      lua_pushvalue(L, j);
  return (int)(n * count);
}
/* Module entry point: leaves a table with the functions below on the stack
 * and returns it.  On Lua 5.1 (no luaL_newlib) luaL_register is used, which
 * also stores the table in the global 'vararg'.  Note that "concat" and
 * "unpack" are the same function. */
LUALIB_API int luaopen_vararg(lua_State *L) {
  /* static const: built once at compile time, not on every call */
  static const luaL_Reg libs[] = {
    { "concat", Lunpack },
    { "unpack", Lunpack },
    { "filter", Lfilter },
    { "insert", Linsert },
    { "map", Lmap },
    { "pack", Lpack },
    { "pop", Lpop },
    { "push", Lpush },
    { "range", Lrange },
    { "reduce", Lreduce },
    { "remove", Lremove },
    { "rep", Lrep },
    { "replace", Lreplace },
    { "reverse", Lreverse },
    { "rotate", Lrotate },
    { "shift", Lshift },
    { "tail", Ltail },
    { "take", Ltake },
    { NULL, NULL }
  };
#if LUA_VERSION_NUM >= 502
  luaL_newlib(L, libs);
#else
  luaL_register(L, "vararg", libs);
#endif
  return 1;
}
/* cc: flags+='-s -O3 -mdll -DLUA_BUILD_AS_DLL' libs+='-llua53'
* cc: output='vararg.dll' */
local _G = require "_G"
local assert = _G.assert
local pcall = _G.pcall
local print = _G.print
local select = _G.select
local type = _G.type
local math = require "math"
local ceil = math.ceil
local huge = math.huge
local min = math.min
local table = require "table"
local unpack = table.unpack or _G.unpack
local vararg = require "vararg"
local pack = vararg.pack
local range = vararg.range
local insert = vararg.insert
local remove = vararg.remove
local replace = vararg.replace
local push = vararg.push
local concat = vararg.concat
local map = vararg.map
local rotate = vararg.rotate
-- auxiliary functions----------------------------------------------------------

-- 'values' holds the test data; 'maxstack' starts as the largest count that
-- 'unpack' managed to push in this interpreter.  It is found by doubling the
-- count until 'unpack' fails, then binary-searching below that failure.
local values = {}
local maxstack
for i = 1, huge do
  if not pcall(unpack, values, 1, 2^i) then
    -- lo: previous power of two (treated as working); hi: known failure
    local lo, hi = 2^(i-1), 2^i
    while lo < hi do
      local mid = ceil((lo+hi)/2)
      if pcall(unpack, values, 1, mid) then
        lo = mid
      else
        hi = mid-1
      end
    end
    maxstack = hi
    break
  end
end
-- Only odd positions get a value, so even positions stay nil and the tests
-- below also exercise nils embedded in vararg lists.
for i = 1, maxstack, 2 do
  values[i] = i
end

-- Pack the arguments into a table, storing the argument count in field 'n'.
-- The vararg must be the LAST field of the constructor: in '{..., n=x}' it is
-- adjusted to a single value and the remaining arguments are lost.
local function tpack(...)
  return {n=select("#", ...), ...}
end

-- Assert that the extra arguments equal v[i], ..., v[j] in count and value.
local function assertsame(v, i, j, ...)
  local count = select("#", ...)
  assert(count == j-i+1, count..","..i..","..j)
  for pos = 1, count do
    assert(v[i+pos-1] == select(pos, ...))
  end
end

-- Assert that f(...) raises an error whose message matches the Lua pattern
-- 'expected'.
local function asserterror(expected, f, ...)
  local ok, actual = pcall(f, ...)
  assert(ok == false, "error was expected")
  assert(actual:find(expected, 1), "wrong error, got "..actual)
end
-- test 'pack' function --------------------------------------------------------
-- Compare every access form of a packed tuple against a table built from the
-- same arguments.
local function testpack(...)
  local expected = {...}
  local count = select("#", ...)
  local tuple = pack(...)
  assertsame(expected, 1, count, tuple())  -- p(): all values
  assert(count == tuple("#"))              -- p("#"): the count
  for idx, val in tuple do                 -- p used as a for-in iterator
    assert(expected[idx] == val)
  end
  for idx = 1, count do                    -- p(k) with k > 0 and k < 0
    assert(expected[idx] == tuple(idx))
    assert(expected[idx] == tuple(idx - count - 1))
  end
  for first = 1, count, 10 do              -- p(i, j) over windows of ten
    local last = first + 9
    assertsame(expected, first, last, tuple(first, last))
    if last < count then
      assertsame(expected, first, last, tuple(first - count - 1, last - count - 1))
    end
  end
end
testpack()
testpack({},{},{})
testpack(nil)
testpack(nil, nil)
testpack(nil, 1, nil)
testpack(unpack(values, 1, 254))
-- 255 values are one too many for the C 'pack' (upvalue 1 holds the count);
-- an implementation without that limit returns a function instead.
local ok, err = pcall(pack, unpack(values, 1, 255))
if ok then -- Lua version
  assert(type(err) == "function")
else -- C version
  assert(ok == false and err == "too many values to pack")
end
-- test 'range' function -------------------------------------------------------
-- Split positions 1..n into consecutive windows of width 1, 2 and 3 and
-- compare 'range' (positive indices, plus negative ones when values were
-- given) against a table of the same values.
local function testrange(n, ...)
  local expected = {...}
  local nargs = select("#", ...)
  for width = 1, 3 do
    for first = 1, n, width do
      local last = min(first + width - 1, n)
      assertsame(expected, first, last, range(first, last, ...))
      if nargs > 0 then
        assertsame(expected, first, last,
                   range(first - nargs - 1, last - nargs - 1, ...))
      end
    end
  end
end
-- Index 0: the implementation either returns nothing or rejects it.
local ok, err = pcall(range, 0, 0, ...)
if ok then -- accepted: no values returned
  assert(err == nil)
else -- rejected with an argument error
  assert(ok == false and err == "bad argument #1 to '?' (index out of bounds (0))")
end
testrange(10)
testrange(10, 1,2,3,4,5,6,7,8,9,0)
maxstack = 10000 -- use a smaller value
testrange(maxstack, unpack(values, 1, maxstack))
-- test other functions --------------------------------------------------------
-- insert(v, i, ...): v lands at position i; past the end, nils pad the gap
assertsame({1,2,3,4,5}, 1, 5, insert(3, 3, 1,2,4,5))
assertsame({1,2,3,4,5}, 1, 5, insert(4,-1, 1,2,3,5))
assertsame({1,2,nil,4}, 1, 4, insert(4, 4, 1,2))
assertsame({nil,nil,3}, 1, 3, insert(3, 3))
-- replace(v, i, ...): position i becomes v; past the end, nils pad the gap
assertsame({1,2,3,4,5}, 1, 5, replace(3, 3, 1,2,0,4,5))
assertsame({1,2,3,4,5}, 1, 5, replace(5,-1, 1,2,3,4,0))
assertsame({1,2,nil,4}, 1, 4, replace(4, 4, 1,2))
assertsame({nil,nil,3}, 1, 3, replace(3, 3))
-- remove(i, ...): drops position i; an out-of-range i changes nothing
assertsame({1,2,3,4,5}, 1, 5, remove( 3, 1,2,0,3,4,5))
assertsame({1,2,3,4,5}, 1, 5, remove(-1, 1,2,3,4,5,0))
assertsame({1,2,nil,4}, 1, 4, remove( 4, 1,2,nil,0,4))
assertsame({nil,nil,3}, 1, 3, remove( 3, nil,nil,0,3))
assertsame({1,2,3,4,5}, 1, 5, remove(10, 1,2,3,4,5))
-- push(v, ...): appends v
assertsame({1,2,3,4,5}, 1, 5, push(5, 1,2,3,4))
assertsame({1,2,nil,4}, 1, 4, push(4, 1,2,nil))
assertsame({nil,nil,3}, 1, 3, push(3, nil,nil))
-- rotate(c, ...): c is taken modulo the number of values
assertsame({5,1,2,3,4}, 1, 5, rotate(1, 1,2,3,4,5))
assertsame({2,3,4,5,1}, 1, 5, rotate(-1, 1,2,3,4,5))
assertsame({5,1,2,3,4}, 1, 5, rotate(11, 1,2,3,4,5))
assertsame({2,3,4,5,1}, 1, 5, rotate(-11, 1,2,3,4,5))
-- concat(p1, p2, ...): results of each packed tuple, in order
assertsame({1,2,3,4,5,6,7,8,9}, 1, 9, concat(pack(1,2,3),
pack(4,5,6),
pack(7,8,9)))
-- test function errors and exceptional conditions -----------------------------
-- map(f, ...): one result per value, nils included
assertsame({"1","2","3","4","5"}, 1, 5, map(tostring, 1,2,3,4,5))
assertsame({"1","2","nil","4" }, 1, 4, map(tostring, 1,2,nil,4))
assertsame({"nil","nil","3" }, 1, 3, map(tostring, nil,nil,3))
assertsame({"1","nil","nil" }, 1, 3, map(tostring, 1,nil,nil))
-- Missing or non-numeric indices raise argument errors.  Index 0 is not
-- rejected, so the "index out of bounds" checks stay disabled and only the
-- number of results is asserted instead.
asserterror("bad argument #2 to '[^']*' %(number expected, got no value%)", insert)
asserterror("bad argument #2 to '[^']*' %(number expected, got no value%)", insert, nil)
asserterror("bad argument #2 to '[^']*' %(number expected, got nil%)", insert, nil, nil)
--asserterror("bad argument #2 to '[^']*' %(index out of bounds %(0%)%)", insert, nil, 0)
assert(select('#', insert(nil, 0)) == 1)
asserterror("bad argument #2 to '[^']*' %(number expected, got no value%)", replace)
asserterror("bad argument #2 to '[^']*' %(number expected, got no value%)", replace, nil)
asserterror("bad argument #2 to '[^']*' %(number expected, got nil%)", replace, nil, nil)
--asserterror("bad argument #2 to '[^']*' %(index out of bounds %(0%)%)", replace, nil, 0)
assert(select('#', replace(nil, 0)) == 0)
asserterror("bad argument #1 to '[^']*' %(number expected, got no value%)", remove)
asserterror("bad argument #1 to '[^']*' %(number expected, got nil%)", remove, nil)
--asserterror("bad argument #1 to '[^']*' %(index out of bounds %(0%)%)", remove, 0)
assert(select('#', remove(0)) == 0)
-- Empty and nil inputs.
assertsame({}, 1, 0, push())
assertsame({nil}, 1, 1, push(nil))
assertsame({}, 1, 0, concat())
asserterror("attempt to call a nil value", concat, nil)
asserterror("bad argument #1 to '[^']*' %(value expected%)", map)
assertsame({}, 1, 0, map(nil))
asserterror("attempt to call a nil value", map, nil, nil)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
You can’t perform that action at this time.