Skip to content

Instantly share code, notes, and snippets.

@starwing
Last active May 4, 2021 07:45
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 1 You must be signed in to fork a gist
  • Save starwing/5893607 to your computer and use it in GitHub Desktop.
Save starwing/5893607 to your computer and use it in GitHub Desktop.
A vararg module compatible with Lua 5.3
#define LUA_LIB
#include <limits.h>
#include <lua.h>
#include <lauxlib.h>
/*
 * Translate a possibly negative position into an absolute one.
 * Non-negative positions are returned unchanged; a negative position counts
 * back from the end of a sequence of `len` items (-1 is the last item).
 * A negative position that reaches before the first item yields 0.
 */
static lua_Integer posrelat(lua_Integer pos, size_t len) {
    if (pos < 0) {
        /* magnitude computed in unsigned arithmetic to avoid negating pos */
        if (0u - (size_t)pos > len)
            return 0;
        return (lua_Integer)len + pos + 1;
    }
    return pos;
}
/*
 * Body of the closure created by pack().  Upvalue 1 holds the number of
 * packed values; upvalues 2..count+1 hold the values themselves.
 * The behaviour is selected by the type of argument 1:
 *   nil      iterator step: argument 2 is the previous index; returns the
 *            next index and its value, or nothing when exhausted
 *   string   a string starting with '#' returns the number of packed values
 *   (none)   returns every packed value
 *   number   p(i [, j]) returns the values at positions i..j (positions
 *            outside 1..count read as nil); p(i, j, ...) stores the extra
 *            arguments into positions i..j and returns them
 * Any other argument raises "invalid argument".
 */
static int tuple(lua_State *L) {
    int count = (int)lua_tointeger(L, lua_upvalueindex(1));
    lua_Integer i, j;
    switch (lua_type(L, 1)) {
    case LUA_TNIL: /* as iterator */
        i = lua_tointeger(L, 2);
        /* test before incrementing so the maximum integer cannot overflow */
        if (i < 0 || i >= count) return 0;
        ++i;
        lua_pushinteger(L, i);
        lua_pushvalue(L, lua_upvalueindex((int)i + 1));
        return 2;
    case LUA_TSTRING: /* as length operator */
        if (*lua_tostring(L, 1) == '#') {
            lua_pushinteger(L, count);
            return 1;
        }
        break;
    case LUA_TNONE: { /* get all varargs */
        int k;
        luaL_checkstack(L, count, "too many values");
        for (k = 1; k <= count; ++k)
            lua_pushvalue(L, lua_upvalueindex(k + 1));
        return count;
    }
    case LUA_TNUMBER: { /* get/set a range */
        int width, slot;
        i = posrelat(luaL_checkinteger(L, 1), (size_t)count);
        j = posrelat(luaL_optinteger(L, 2, i), (size_t)count);
        if (i > j) return 0;
        /* i and j are both >= 0 here, so j - i cannot overflow; reject
         * widths that would be truncated when converted to int */
        if (j - i >= INT_MAX)
            return luaL_error(L, "stack overflow (%s)", "too many values");
        width = (int)(j - i) + 1;
        luaL_checkstack(L, width, "too many values");
        if (lua_gettop(L) <= 2) { /* get */
            for (slot = 0; slot < width; ++slot) {
                lua_Integer pos = i + slot;
                /* never touch upvalue 1 (the count) or slots past the end */
                if (pos >= 1 && pos <= count)
                    lua_pushvalue(L, lua_upvalueindex((int)pos + 1));
                else
                    lua_pushnil(L);
            }
        }
        else { /* set */
            /* writing outside 1..count would clobber the count upvalue or
             * memory that does not belong to this closure */
            luaL_argcheck(L, i >= 1, 1, "index out of bounds");
            luaL_argcheck(L, j <= count, 2, "index out of bounds");
            lua_settop(L, width + 2); /* pad or trim the new values */
            for (slot = 0; slot < width; ++slot) {
                lua_pushvalue(L, slot + 3);
                lua_replace(L, lua_upvalueindex((int)i + slot + 1));
            }
        }
        return width;
    }
    default:
        break;
    }
    return luaL_argerror(L, 1, "invalid argument");
}
/*
 * pack(...): capture the arguments in a `tuple` closure.
 * The closure receives the argument count as its first upvalue followed by
 * the arguments themselves.  Raises an error for 255 or more arguments.
 */
static int Lpack(lua_State *L) {
    int count = lua_gettop(L);
    if (count >= 255)
        return luaL_error(L, "too many values to pack");
    lua_pushinteger(L, count);
    lua_insert(L, 1); /* move the count below the values */
    lua_pushcclosure(L, tuple, count + 1);
    return 1;
}
/*
 * range(i, j, ...): return the values at positions i..j of the trailing
 * arguments.  Negative positions count from the end; positions past the
 * last value read as nil.  A start position below 1 is treated as 1.
 * Returns nothing when fewer than two arguments are given or the range is
 * empty; raises an error when the range cannot fit on the stack.
 */
static int Lrange(lua_State *L) {
    int n = lua_gettop(L) - 2;
    lua_Integer i, j;
    if (n < 0) return 0;
    i = posrelat(luaL_checkinteger(L, 1), (size_t)n);
    j = posrelat(luaL_checkinteger(L, 2), (size_t)n);
    /* position 0 would make the result window include the `j` argument */
    if (i < 1) i = 1;
    if (i > j) return 0;
    if (j > n) {
        /* j + 2 must fit in an int before it is handed to the stack API */
        if (j > INT_MAX - 2)
            return luaL_error(L, "stack overflow (%s)", "range is too big");
        luaL_checkstack(L, (int)(j - n), "range is too big");
    }
    lua_settop(L, (int)j + 2); /* trim or nil-pad so that value j is on top */
    return (int)(j - i + 1);
}
/*
 * insert(v, i, ...): return the trailing values with `v` inserted at
 * position i.  Negative positions count from the end.  A position past the
 * last value pads the gap with nil; a position below 1 is treated as 1.
 * Raises an error when the position cannot fit on the stack.
 */
static int Linsert(lua_State *L) {
    int n = lua_gettop(L) - 2;
    lua_Integer i = posrelat(luaL_checkinteger(L, 2), (size_t)n);
    /* position 0 would insert below the index argument and leak it */
    if (i < 1) i = 1;
    if (i > n) {
        /* i + 1 must fit in an int before it is handed to the stack API */
        if (i > INT_MAX - 2)
            return luaL_error(L, "stack overflow (%s)", "index is too big");
        luaL_checkstack(L, (int)(i - n), "index is too big");
        lua_settop(L, (int)i + 1); /* nil-pad up to position i - 1 */
        lua_pushvalue(L, 1);
        return (int)i;
    }
    lua_pushvalue(L, 1);
    lua_insert(L, (int)i + 2); /* values start at stack slot 3 */
    return n + 1;
}
/*
 * remove(i, ...): return the trailing values without the one at position i.
 * Negative positions count from the end.  When i does not name an existing
 * value (below 1 or past the end) the values are returned unchanged.
 */
static int Lremove(lua_State *L) {
    int n = lua_gettop(L) - 1;
    lua_Integer i = posrelat(luaL_checkinteger(L, 1), (size_t)n);
    /* position 0 would remove the index argument itself and, with no
     * values at all, make the function return a negative count */
    if (i >= 1 && i <= n) {
        lua_remove(L, (int)i + 1); /* values start at stack slot 2 */
        --n;
    }
    return n;
}
/*
 * replace(v, i, ...): return the trailing values with the one at position i
 * replaced by `v`.  Negative positions count from the end.  A position past
 * the last value pads the gap with nil; a position below 1 leaves the
 * values unchanged.  Raises an error when the position cannot fit on the
 * stack.
 */
static int Lreplace(lua_State *L) {
    int n = lua_gettop(L) - 2;
    lua_Integer i = posrelat(luaL_checkinteger(L, 2), (size_t)n);
    if (i < 1) return n; /* no such position: nothing to replace */
    if (i > n) {
        /* i + 1 must fit in an int before it is handed to the stack API */
        if (i > INT_MAX - 2)
            return luaL_error(L, "stack overflow (%s)", "index is too big");
        luaL_checkstack(L, (int)(i - n), "index is too big");
        lua_settop(L, (int)i + 1); /* nil-pad up to position i - 1 */
        lua_pushvalue(L, 1);
        return (int)i;
    }
    lua_pushvalue(L, 1);
    lua_replace(L, (int)i + 2); /* values start at stack slot 3 */
    return n;
}
/*
 * push(v, ...): return the trailing values followed by `v`.
 */
static int Lpush(lua_State *L) {
    int results;
    lua_pushvalue(L, 1); /* copy `v` to the top of the stack */
    results = lua_gettop(L) - 1; /* everything except the original `v` */
    return results;
}
/*
 * pop(...): return the arguments without the last one.
 * Returns nothing when called without arguments.
 */
static int Lpop(lua_State *L) {
    int n = lua_gettop(L);
    /* popping an empty stack would move the top below the frame base */
    if (n == 0) return 0;
    lua_pop(L, 1);
    return n - 1;
}
/*
 * take(i, ...): return the first i of the trailing values.
 * Negative positions count from the end.  Returns nothing when i is past
 * the last value.
 */
static int Ltake(lua_State *L) {
    int total = lua_gettop(L) - 1;
    lua_Integer keep = posrelat(luaL_checkinteger(L, 1), total);
    if (keep <= total) {
        lua_pop(L, total - keep); /* drop everything after position `keep` */
        return keep;
    }
    return 0;
}
/*
 * tail(i, ...): return the trailing values from position i to the end.
 * Negative positions count from the end; a position below 1 is treated
 * as 1.  Returns nothing when i is past the last value.
 */
static int Ltail(lua_State *L) {
    int n = lua_gettop(L) - 1;
    lua_Integer i = posrelat(luaL_checkinteger(L, 1), (size_t)n);
    /* position 0 would return n + 1 values and leak the index argument */
    if (i < 1) i = 1;
    if (i > n) return 0;
    return (int)(n - i + 1);
}
/*
 * shift(...): return the arguments without the first one.
 * Returns nothing when called without arguments.
 */
static int Lshift(lua_State *L) {
    int n = lua_gettop(L);
    /* a C function must never report a negative number of results */
    return n > 0 ? n - 1 : 0;
}
/*
 * map(f, ...): call `f` once with each trailing value and return the first
 * result of every call, in order.  Argument 1 must be present.
 */
static int Lmap(lua_State *L) {
    int top = lua_gettop(L);
    int slot = 2;
    luaL_checkany(L, 1);
    while (slot <= top) {
        lua_pushvalue(L, 1);    /* the function */
        lua_pushvalue(L, slot); /* its single argument */
        lua_call(L, 1, 1);
        lua_replace(L, slot);   /* overwrite the value with the result */
        ++slot;
    }
    return top - 1;
}
/*
 * filter(f, ...): call `f` once with each trailing value and return only
 * the values for which the call produced a true result.  Argument 1 must
 * be present.
 */
static int Lfilter(lua_State *L) {
    int top = lua_gettop(L);
    int slot = 2;
    luaL_checkany(L, 1);
    while (slot <= top) {
        int keep;
        lua_pushvalue(L, 1);    /* the predicate */
        lua_pushvalue(L, slot); /* the candidate value */
        lua_call(L, 1, 1);
        keep = lua_toboolean(L, -1);
        lua_pop(L, 1);          /* discard the predicate result */
        if (keep) {
            ++slot;
        } else {
            /* later values shift down, so the same slot is examined again */
            lua_remove(L, slot);
            --top;
        }
    }
    return top - 1;
}
/*
 * reduce(f, ...): fold the trailing values with the binary function `f`.
 * With at most two values, `f` is called once with whatever values were
 * given.  Otherwise f(v1, v2) seeds an accumulator that is combined with
 * each remaining value as f(acc, v).  Always returns exactly one result.
 */
static int Lreduce(lua_State *L) {
    int top = lua_gettop(L);
    int next;
    luaL_checkany(L, 1);
    if (top > 3) {
        lua_pushvalue(L, 1);
        lua_pushvalue(L, 2);
        lua_pushvalue(L, 3);
        lua_call(L, 2, 1); /* seed: f(v1, v2) */
        for (next = 4; next <= top; ++next) {
            lua_pushvalue(L, 1);
            lua_insert(L, -2); /* put the function below the accumulator */
            lua_pushvalue(L, next);
            lua_call(L, 2, 1);
        }
    } else {
        /* the function and its arguments are already in call position */
        lua_call(L, top - 1, 1);
    }
    return 1;
}
/*
 * unpack(...) / concat(...): call every argument with no arguments and
 * return all the results of all the calls, in order.
 */
static int Lunpack(lua_State *L) {
    int argc = lua_gettop(L);
    int slot = 1;
    while (slot <= argc) {
        lua_pushvalue(L, slot);
        lua_call(L, 0, LUA_MULTRET); /* results accumulate above the args */
        ++slot;
    }
    return lua_gettop(L) - argc;
}
/*
 * rotate(c, ...): return the trailing values rotated by c positions.
 * A positive c moves values toward the end (the last c values wrap to the
 * front); a negative c rotates the other way.  c is reduced modulo the
 * number of values.  Returns nothing when there are no values to rotate.
 */
static int Lrotate(lua_State *L) {
    int n = lua_gettop(L) - 1;
    lua_Integer c = luaL_checkinteger(L, 1);
    if (n <= 0) return 0; /* avoids a modulo by zero below */
    c %= n;
#if LUA_VERSION_NUM >= 503
    lua_rotate(L, 2, (int)c);
#else
    {
        lua_Integer i;
        /* copy the leading values to the top so that the top n slots hold
         * the rotated sequence */
        c = -c + ((c > 0) ? n+1 : 1);
        if (c > 1) luaL_checkstack(L, (int)(c-1), "too many values");
        for (i = 2; i <= c; ++i)
            lua_pushvalue(L, (int)i);
    }
#endif
    return n;
}
/*
 * reverse(...): return the arguments in reverse order.
 */
static int Lreverse(lua_State *L) {
    int count = lua_gettop(L);
    int lo = 1;
    int hi = count;
    while (lo < hi) {
        /* swap slots lo and hi using two temporary copies on the top */
        lua_pushvalue(L, lo);
        lua_pushvalue(L, hi);
        lua_replace(L, lo);
        lua_replace(L, hi);
        ++lo;
        --hi;
    }
    return count;
}
/*
 * rep(count, ...): return the trailing values repeated `count` times.
 * With no trailing values, returns `count` nils.  Returns nothing when
 * count is not positive.  Raises an error when the result cannot fit on
 * the stack.
 */
static int Lrep(lua_State *L) {
    int top = lua_gettop(L), n = top - 1;
    lua_Integer round, count = luaL_checkinteger(L, 1);
    int j;
    if (count <= 0) return 0;
    if (n == 0) {
        /* count + 1 must fit in an int before it reaches the stack API */
        if (count > INT_MAX - 1)
            return luaL_error(L, "stack overflow (%s)", "too many values");
        luaL_checkstack(L, (int)count, "too many values");
        lua_settop(L, (int)count + 1);
        return (int)count;
    }
    /* n * count is the result count and must be representable as an int;
     * checking first also keeps the multiplication from overflowing */
    if (count > INT_MAX / n)
        return luaL_error(L, "stack overflow (%s)", "too many values");
    luaL_checkstack(L, n * (int)(count - 1), "too many values");
    for (round = 1; round < count; ++round)
        for (j = 2; j <= top; ++j)
            lua_pushvalue(L, j);
    return n * (int)count;
}
/*
 * Module entry point: builds the `vararg` library table and leaves it on
 * the stack.  "concat" and "unpack" are two names for the same function.
 */
LUALIB_API int luaopen_vararg(lua_State *L) {
    static const luaL_Reg libs[] = {
        { "concat",  Lunpack  },
        { "unpack",  Lunpack  },
        { "filter",  Lfilter  },
        { "insert",  Linsert  },
        { "map",     Lmap     },
        { "pack",    Lpack    },
        { "pop",     Lpop     },
        { "push",    Lpush    },
        { "range",   Lrange   },
        { "reduce",  Lreduce  },
        { "remove",  Lremove  },
        { "rep",     Lrep     },
        { "replace", Lreplace },
        { "reverse", Lreverse },
        { "rotate",  Lrotate  },
        { "shift",   Lshift   },
        { "tail",    Ltail    },
        { "take",    Ltake    },
        { NULL, NULL }
    };
#if LUA_VERSION_NUM >= 502
    luaL_newlib(L, libs);
#else
    luaL_register(L, "vararg", libs);
#endif
    return 1;
}
/* cc: flags+='-s -O3 -mdll -DLUA_BUILD_AS_DLL' libs+='-llua53'
* cc: output='vararg.dll' */
local _G = require "_G"
local assert = _G.assert
local pcall = _G.pcall
local print = _G.print
local select = _G.select
local type = _G.type
local math = require "math"
local ceil = math.ceil
local huge = math.huge
local min = math.min
local table = require "table"
local unpack = table.unpack or _G.unpack
local vararg = require "vararg"
local pack = vararg.pack
local range = vararg.range
local insert = vararg.insert
local remove = vararg.remove
local replace = vararg.replace
local push = vararg.push
local concat = vararg.concat
local map = vararg.map
local rotate = vararg.rotate
-- auxiliary functions----------------------------------------------------------
-- Probe how many values can be unpacked at once: double the count until
-- unpack fails, then binary-search between the last two powers of two.
local values = {}
local maxstack
for exponent = 1, huge do
  local upper = 2^exponent
  if not pcall(unpack, values, 1, upper) then
    local lo, hi = 2^(exponent-1), upper
    while lo < hi do
      local mid = ceil((lo+hi)/2)
      local fits = pcall(unpack, values, 1, mid)
      if fits then
        lo = mid
      else
        hi = mid-1
      end
    end
    maxstack = hi
    break
  end
end
-- Fill only the odd slots so the sample data contains nil holes.
for slot = 1, maxstack, 2 do
  values[slot] = slot
end
-- Packs the arguments into a table and records their count in field `n`.
local function tpack(...)
  -- `...` must be the last item of the constructor: in any other position
  -- a vararg expression is truncated to its first value.
  return {n=select("#", ...), ...}
end
-- Asserts that the trailing arguments are exactly v[i], v[i+1], ..., v[j]:
-- first the count must equal j-i+1, then each value must compare equal.
local function assertsame(v, i, j, ...)
  local got = select("#", ...)
  assert(got == j-i+1, got..","..i..","..j)
  local k = 1
  while k <= got do
    local actual = select(k, ...)
    assert(v[i+k-1] == actual)
    k = k + 1
  end
end
-- Calls f(...) in protected mode and asserts that it fails with a message
-- matching the Lua pattern `expected`.
local function asserterror(expected, f, ...)
  local succeeded, message = pcall(f, ...)
  assert(succeeded == false, "error was expected")
  local found = message:find(expected, 1)
  assert(found, "wrong error, got "..message)
end
-- test 'pack' function --------------------------------------------------------
-- Packs the arguments and checks every way of reading them back: all at
-- once, the "#" length query, iteration, single positions (positive and
-- negative) and ranges of ten positions.
local function testpack(...)
  local expected = {...}
  local count = select("#", ...)
  local packed = pack(...)
  assertsame(expected, 1, count, packed())
  assert(count == packed("#"))
  for index, value in packed do
    assert(expected[index] == value)
  end
  for index = 1, count do
    assert(expected[index] == packed(index))
    if count > 0 then
      -- the same position addressed from the end
      assert(expected[index] == packed(index-count-1))
    end
  end
  for first = 1, count, 10 do
    local last = first+9
    assertsame(expected, first, last, packed(first, last))
    if count > last then
      assertsame(expected, first, last, packed(first-count-1, last-count-1))
    end
  end
end
-- Exercise pack() with no values, tables, nil-only lists, a list with nil
-- holes and 254 values.
testpack()
testpack({},{},{})
testpack(nil)
testpack(nil, nil)
testpack(nil, 1, nil)
testpack(unpack(values, 1, 254))
-- Packing 255 values must either succeed and yield a function, or fail
-- with exactly the message "too many values to pack".
local ok, err = pcall(pack, unpack(values, 1, 255))
if ok then -- Lua version
assert(type(err) == "function")
else -- C version
assert(ok == false and err == "too many values to pack")
end
-- test 'range' function -------------------------------------------------------
-- Checks range() over positions 1..n of the given values, using windows of
-- one, two and three positions, addressed both from the start and (when
-- values were supplied) from the end.
local function testrange(n, ...)
  local expected = {...}
  local argc = select("#", ...)
  for width = 1, 3 do
    for first = 1, n, width do
      local last = min(first+width-1, n)
      assertsame(expected, first, last, range(first, last, ...))
      if argc > 0 then
        assertsame(expected, first, last, range(first-argc-1, last-argc-1, ...))
      end
    end
  end
end
-- range(0, 0, ...) must either return nothing or fail with the exact
-- out-of-bounds message below.
-- NOTE(review): `...` here is presumably the vararg list of the enclosing
-- chunk; confirm that this code runs at chunk level.
local ok, err = pcall(range, 0, 0, ...)
if ok then -- Lua version
assert(err == nil)
else -- C version
assert(ok == false and err == "bad argument #1 to '?' (index out of bounds (0))")
end
-- Ranges over an empty list, a short list and a long list with nil holes.
testrange(10)
testrange(10, 1,2,3,4,5,6,7,8,9,0)
maxstack = 10000 -- use a smaller value
testrange(maxstack, unpack(values, 1, maxstack))
-- test other functions --------------------------------------------------------
-- insert(value, position, ...): positive, negative and past-the-end positions.
assertsame({1,2,3,4,5}, 1, 5, insert(3, 3, 1,2,4,5))
assertsame({1,2,3,4,5}, 1, 5, insert(4,-1, 1,2,3,5))
assertsame({1,2,nil,4}, 1, 4, insert(4, 4, 1,2))
assertsame({nil,nil,3}, 1, 3, insert(3, 3))
-- replace(value, position, ...): same position cases as insert.
assertsame({1,2,3,4,5}, 1, 5, replace(3, 3, 1,2,0,4,5))
assertsame({1,2,3,4,5}, 1, 5, replace(5,-1, 1,2,3,4,0))
assertsame({1,2,nil,4}, 1, 4, replace(4, 4, 1,2))
assertsame({nil,nil,3}, 1, 3, replace(3, 3))
-- remove(position, ...): the last case uses a position past the end, which
-- must leave the values unchanged.
assertsame({1,2,3,4,5}, 1, 5, remove( 3, 1,2,0,3,4,5))
assertsame({1,2,3,4,5}, 1, 5, remove(-1, 1,2,3,4,5,0))
assertsame({1,2,nil,4}, 1, 4, remove( 4, 1,2,nil,0,4))
assertsame({nil,nil,3}, 1, 3, remove( 3, nil,nil,0,3))
assertsame({1,2,3,4,5}, 1, 5, remove(10, 1,2,3,4,5))
-- push(value, ...): the value is appended after the other values.
assertsame({1,2,3,4,5}, 1, 5, push(5, 1,2,3,4))
assertsame({1,2,nil,4}, 1, 4, push(4, 1,2,nil))
assertsame({nil,nil,3}, 1, 3, push(3, nil,nil))
-- rotate(count, ...): counts larger than the number of values wrap around.
assertsame({5,1,2,3,4}, 1, 5, rotate(1, 1,2,3,4,5))
assertsame({2,3,4,5,1}, 1, 5, rotate(-1, 1,2,3,4,5))
assertsame({5,1,2,3,4}, 1, 5, rotate(11, 1,2,3,4,5))
assertsame({2,3,4,5,1}, 1, 5, rotate(-11, 1,2,3,4,5))
-- concat(...): the results of calling each packed tuple are joined in order.
assertsame({1,2,3,4,5,6,7,8,9}, 1, 9, concat(pack(1,2,3),
pack(4,5,6),
pack(7,8,9)))
-- test function errors and exceptional conditions -----------------------------
-- map(f, ...): nil values are passed to f like any other value.
assertsame({"1","2","3","4","5"}, 1, 5, map(tostring, 1,2,3,4,5))
assertsame({"1","2","nil","4" }, 1, 4, map(tostring, 1,2,nil,4))
assertsame({"nil","nil","3" }, 1, 3, map(tostring, nil,nil,3))
assertsame({"1","nil","nil" }, 1, 3, map(tostring, 1,nil,nil))
-- insert: a missing or non-numeric position is an argument error; position
-- 0 is accepted and only the number of results is checked.
asserterror("bad argument #2 to '[^']*' %(number expected, got no value%)", insert)
asserterror("bad argument #2 to '[^']*' %(number expected, got no value%)", insert, nil)
asserterror("bad argument #2 to '[^']*' %(number expected, got nil%)", insert, nil, nil)
--asserterror("bad argument #2 to '[^']*' %(index out of bounds %(0%)%)", insert, nil, 0)
assert(select('#', insert(nil, 0)) == 1)
-- replace: same argument checks as insert; position 0 yields no results here.
asserterror("bad argument #2 to '[^']*' %(number expected, got no value%)", replace)
asserterror("bad argument #2 to '[^']*' %(number expected, got no value%)", replace, nil)
asserterror("bad argument #2 to '[^']*' %(number expected, got nil%)", replace, nil, nil)
--asserterror("bad argument #2 to '[^']*' %(index out of bounds %(0%)%)", replace, nil, 0)
assert(select('#', replace(nil, 0)) == 0)
-- remove: the position is argument 1; position 0 with no values yields no results.
asserterror("bad argument #1 to '[^']*' %(number expected, got no value%)", remove)
asserterror("bad argument #1 to '[^']*' %(number expected, got nil%)", remove, nil)
--asserterror("bad argument #1 to '[^']*' %(index out of bounds %(0%)%)", remove, 0)
assert(select('#', remove(0)) == 0)
-- push and concat with no or nil arguments.
assertsame({}, 1, 0, push())
assertsame({nil}, 1, 1, push(nil))
assertsame({}, 1, 0, concat())
asserterror("attempt to call a nil value", concat, nil)
-- map requires its first argument to be present; a nil function only fails
-- once there is a value to apply it to.
asserterror("bad argument #1 to '[^']*' %(value expected%)", map)
assertsame({}, 1, 0, map(nil))
asserterror("attempt to call a nil value", map, nil, nil)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment