Skip to content

Instantly share code, notes, and snippets.

@pfirsich
Created August 26, 2018 18:04
Show Gist options
  • Save pfirsich/3a79b57b63b7548d545f8df4c48672c4 to your computer and use it in GitHub Desktop.
Save pfirsich/3a79b57b63b7548d545f8df4c48672c4 to your computer and use it in GitHub Desktop.
A LuaJIT bitmask class (like std::bitset)
local bit = require("bit")
local band, bor, bxor, bnot, lshift = bit.band, bit.bor, bit.bxor, bit.bnot, bit.lshift
--- Power of two as a bit pattern: a word with only bit `n` (0-based) set.
-- Implemented as a left shift of 1 by `n` using LuaJIT's BitOp library.
local pot = function(n)
    local one = 1
    return lshift(one, n)
end
--- The BitMask class table.
-- Instances use BitMask as their metatable (methods resolve via `__index`).
-- Calling the class, `BitMask(...)`, creates a fresh instance and forwards
-- every argument to `BitMask:initialize`.
local BitMask = {}
BitMask.__index = BitMask

setmetatable(BitMask, {
    __call = function(class, ...)
        local instance = setmetatable({}, class)
        instance:initialize(...)
        return instance
    end,
})
--- Initializes a BitMask (called by `BitMask(...)`).
-- @param arg either a number of bits (defaults to 32 when nil/false), which
--   creates an all-zero mask stored in ceil(arg / 32) words, or another
--   BitMask, whose size and words are copied (copy constructor).
function BitMask:initialize(arg)
    arg = arg or 32
    local argType = type(arg)
    assert(argType == "number" or argType == "table")

    local words = {}
    if argType == "number" then
        self.size = arg
        self.numbers = words
        -- 32 bits are packed into each stored number.
        for w = 1, math.ceil(arg / 32) do
            words[w] = 0
        end
    else
        local source = arg
        self.size = source.size
        self.numbers = words
        for w = 1, #source.numbers do
            words[w] = source.numbers[w]
        end
    end
end
-- NOTE(review): `count`, `all` and `any` (std::bitset-style queries) were
-- listed here but are not implemented in this file; presumably a TODO.

--- Maps a 1-based bit index to its storage location.
-- Bit `i` lives in word `numIndex` of `self.numbers`, at bit position
-- `(i - 1) % 32` inside that word (32 bits are stored per word).
-- @param i integer bit index in 1 .. self.size
-- @return numIndex 1-based index into `self.numbers`
-- @return mask word with only bit `i`'s position set
-- Raises an error (reported at the caller of get/set) when `i` is not a
-- number, not an integer, or outside 1 .. self.size.
function BitMask:_getNumIndex(i)
    -- Fractional indices used to pass the range check and then address an
    -- unrelated bit once LuaJIT rounded the shift count, so reject them too.
    if type(i) ~= "number" or i % 1 ~= 0 or i < 1 or i > self.size then
        local msg = string.format("BitMask: bit index %s is not an integer in [1, %s]",
            tostring(i), tostring(self.size))
        error(msg, 3)
    end
    local offset = i - 1
    local numIndex = math.floor(offset / 32) + 1
    local mask = pot(offset % 32)
    return numIndex, mask
end
--- Sets bit `i` to 1 when `value` is truthy, or clears it to 0 otherwise.
-- @param i bit index in 1 .. self.size (validated by `_getNumIndex`)
-- @param value any Lua value; only its truthiness matters
function BitMask:set(i, value)
    local word, mask = self:_getNumIndex(i)
    local current = self.numbers[word]
    -- bor(...) always yields a number (truthy), so the and/or idiom is safe.
    self.numbers[word] = value and bor(current, mask) or band(current, bnot(mask))
end
--- Returns true when bit `i` is set, false otherwise.
-- @param i bit index in 1 .. self.size (validated by `_getNumIndex`)
function BitMask:get(i)
    local word, mask = self:_getNumIndex(i)
    -- `mask` has exactly one bit set, so the AND is either 0 or `mask`.
    local masked = band(self.numbers[word], mask)
    return masked ~= 0
end
--- Flips bit `i` (0 becomes 1, 1 becomes 0).
-- @param i bit index in 1 .. self.size (validated by `_getNumIndex`)
function BitMask:toggle(i)
    local word, mask = self:_getNumIndex(i)
    -- XOR with a single-bit mask inverts exactly that bit.
    self.numbers[word] = bxor(self.numbers[word], mask)
end
--- Tests whether every bit set in `other` is also set in `self`,
-- i.e. whether `self & other == other` holds word by word.
-- @param other BitMask of the same size (asserted)
-- @return boolean
function BitMask:check(other)
    assert(self.size == other.size)
    local mine, theirs = self.numbers, other.numbers
    for w = 1, #mine do
        local required = theirs[w]
        if band(mine[w], required) ~= required then
            return false
        end
    end
    return true
end
--- Clears from `self` every bit that is set in `other`
-- (in-place `self = self & ~other`).
-- @param other BitMask of the same size (asserted)
function BitMask:remove(other)
    assert(self.size == other.size)
    local dst, src = self.numbers, other.numbers
    for w = 1, #dst do
        local keep = bnot(src[w])
        dst[w] = band(dst[w], keep)
    end
end
--- Combines `other` into `self` word by word: `self[w] = op(self[w], other[w])`.
-- @param op binary function on two words (this file passes band/bor/bxor)
-- @param other BitMask of the same size (asserted)
function BitMask:setOp(op, other)
    assert(self.size == other.size)
    local lhs, rhs = self.numbers, other.numbers
    local wordCount = #lhs
    for w = 1, wordCount do
        lhs[w] = op(lhs[w], rhs[w])
    end
end
--- In-place intersection: `self = self & other`.
function BitMask:setAnd(other)
    return self:setOp(band, other)
end
--- In-place union: `self = self | other`.
function BitMask:setOr(other)
    return self:setOp(bor, other)
end
--- In-place symmetric difference: `self = self ~ other` (XOR).
function BitMask:setXor(other)
    return self:setOp(bxor, other)
end
--- Inverts every bit of the mask in place (`self = ~self`).
-- Only bits 1 .. self.size are part of the mask. When the size is not a
-- multiple of 32, the last stored word has unused high bits. Those bits are
-- cleared again after inverting, so padding stays zero and word-wise
-- operations (`check`, `remove`, `setOp`) only ever see real bits.
-- Previously the padding was flipped to 1, which made e.g. a 10-bit mask
-- with all bits set fail `check` against `BitMask(10):retNot()`.
function BitMask:setNot()
    local words = self.numbers
    for w = 1, #words do
        words[w] = bnot(words[w])
    end

    local usedBits = self.size % 32
    local last = #words
    if usedBits ~= 0 and last > 0 then
        -- lshift(-1, usedBits) has ones from bit `usedBits` upward; its
        -- complement keeps only the low `usedBits` bits of the last word.
        words[last] = band(words[last], bnot(lshift(-1, usedBits)))
    end
end
--- Returns `self & other` as a new BitMask; neither operand is modified.
function BitMask:retAnd(other)
    local result = BitMask(self) -- copy constructor
    result:setAnd(other)
    return result
end
--- Returns `self | other` as a new BitMask; neither operand is modified.
function BitMask:retOr(other)
    local result = BitMask(self) -- copy constructor
    result:setOr(other)
    return result
end
--- Returns `self ~ other` (XOR) as a new BitMask; neither operand is modified.
function BitMask:retXor(other)
    local result = BitMask(self) -- copy constructor
    result:setXor(other)
    return result
end
--- Returns `~self` as a new BitMask; `self` is not modified.
function BitMask:retNot()
    local result = BitMask(self) -- copy constructor
    result:setNot()
    return result
end
--- Renders the mask as a string of "0"/"1" characters, one per bit.
-- Character k of the result is bit k, so bit 1 is the leftmost character.
-- NOTE(review): the loop walks from the highest bit down, yet the output is
-- in index order (bit 1 first), unlike std::bitset::to_string, which prints
-- the highest bit first. Confirm which order is intended.
function BitMask:string()
    local digits = {}
    local i = self.size
    while i >= 1 do
        if self:get(i) then
            digits[i] = "1"
        else
            digits[i] = "0"
        end
        i = i - 1
    end
    return table.concat(digits)
end
-- Self-test, run with: `luajit bitmask.lua test`.
-- The main chunk's varargs (`...`) hold the command-line arguments when this
-- file is executed directly, but only the module name when it is loaded via
-- require(). So the tests run only on direct execution. The global `arg`
-- used before could be nil in embedded hosts, which crashed at load time. It
-- also carried the *host* script's arguments when this module was merely
-- required.
if (...) == "test" then
    -- Builds a mask of `size` bits (nil = default size) and sets `numSet`
    -- random bits. Positions may repeat, so fewer bits can end up set.
    local function randomMask(size, numSet)
        local mask = BitMask(size)
        for _ = 1, numSet do
            mask:set(math.random(1, mask.size), true)
        end
        return mask
    end

    local masks = {}

    -- Appends a random mask and a copy of it. Checks that the copy contains
    -- every bit of the original.
    local function addWithCopy(size, maxSet)
        local original = randomMask(size, math.random(1, maxSet))
        local copy = BitMask(original)
        assert(copy:check(original))
        masks[#masks + 1] = original
        masks[#masks + 1] = copy
    end

    for _ = 1, 1000 do
        addWithCopy(nil, 16) -- default size
        addWithCopy(32, 16)
        addWithCopy(64, 32)
        addWithCopy(128, 64)
    end

    for m = 1, #masks do
        local mask = masks[m]
        -- mask.size / 4 >= 8 draws, so `other` is never empty and the final
        -- `not mask:check(other)` assertion is meaningful.
        local other = randomMask(mask.size, mask.size / 4)
        local i = math.random(1, mask.size)
        local val = math.random() > 0.5
        mask:set(i, val)
        assert(mask:get(i) == val)
        assert(mask:check(mask))
        mask:setOr(other)
        assert(mask:check(other))
        mask:remove(other)
        assert(not mask:check(other))
    end
end

return BitMask
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment