Last active
November 16, 2018 02:21
-
-
Save chethega/d41971bf9ad432df65280688ca2e596d to your computer and use it in GitHub Desktop.
static bitvectors, set version
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# `_blsr(x)` clears the lowest set bit of `x`.
# Reuse Base's internal helper when this Julia build provides it; otherwise
# define an equivalent fallback. Probing with `isdefined` rather than comparing
# `VERSION` keeps the file loadable if the unexported helper is missing.
if isdefined(Base, :_blsr)
    import Base: _blsr
else
    @inline _blsr(x) = x & (x - 1)
end
"""
    SBitSet{N} <: AbstractSet{Int64}

Immutable set of integers backed by a tuple of `N` 64-bit chunks.
Element `k` is a member when the corresponding bit (1-based position `k`,
counted across the chunks in order) is set.
"""
struct SBitSet{N} <: AbstractSet{Int64}
    chunks::NTuple{N,UInt64}
end

# Number of elements in the set; delegates to `count` on the set.
@inline Base.length(B::SBitSet) = count(B)

# Bit lookup: `B[s]` is `true` when the bit at 1-based position `s` is set.
# Indexing the chunk tuple is bounds-checked unless the caller uses `@inbounds`.
Base.@propagate_inbounds function Base.getindex(B::SBitSet, s)
    i1, i2 = Base.get_chunks_id(s)  # (chunk index, bit offset inside the chunk)
    return !iszero(B.chunks[i1] & (UInt64(1) << i2))
end

# `i in B` performs the same lookup as `B[i]`.
Base.@propagate_inbounds Base.in(i, B::SBitSet) = B[i]
# `true` when every chunk is zero, i.e. no bit is set (vacuously `true` when
# the set has no chunks).
@inline Base.iszero(B::SBitSet) = all(iszero, B.chunks)
# A set has no elements exactly when all of its chunks are zero.
@inline Base.isempty(B::SBitSet) = iszero(B)
# Number of set bits across all chunks.
# The explicit `init` keeps the reduction defined for a set with zero chunks,
# where reducing over the empty tuple without a starting value raises an error.
@inline function Base.count(B::SBitSet)
    return mapreduce(count_ones, +, B.chunks; init=0)
end
# `count_ones` on a set is the number of set bits, same as `count`.
@inline Base.count_ones(B::SBitSet) = count(B)
# Intersection: chunk-wise AND of two sets with the same number of chunks.
# `Val(N)` gives `ntuple` a compile-time length, so the tuple is built without
# the allocating fallback that the integer-length form uses for larger `N`.
@inline function Base.:&(B::SBitSet{N}, C::SBitSet{N}) where N
    chunks = @inbounds ntuple(i -> B.chunks[i] & C.chunks[i], Val(N))
    return SBitSet{N}(chunks)
end
# Union: chunk-wise OR of two sets with the same number of chunks.
# `Val(N)` gives `ntuple` a compile-time length, so the tuple is built without
# the allocating fallback that the integer-length form uses for larger `N`.
@inline function Base.:|(B::SBitSet{N}, C::SBitSet{N}) where N
    chunks = @inbounds ntuple(i -> B.chunks[i] | C.chunks[i], Val(N))
    return SBitSet{N}(chunks)
end
# Complement: flips every bit of every chunk (all `64N` positions).
# `Val(N)` gives `ntuple` a compile-time length, so the tuple is built without
# the allocating fallback that the integer-length form uses for larger `N`.
@inline function Base.:~(B::SBitSet{N}) where N
    chunks = @inbounds ntuple(i -> ~B.chunks[i], Val(N))
    return SBitSet{N}(chunks)
end
# Symmetric difference: chunk-wise XOR of two sets with the same chunk count.
# `Val(N)` gives `ntuple` a compile-time length, so the tuple is built without
# the allocating fallback that the integer-length form uses for larger `N`.
@inline function Base.xor(B::SBitSet{N}, C::SBitSet{N}) where N
    chunks = @inbounds ntuple(i -> xor(B.chunks[i], C.chunks[i]), Val(N))
    return SBitSet{N}(chunks)
end
# Singleton constructor: the set whose only element is `k`.
# When `k` maps to a chunk index outside `1:N`, no chunk receives the bit and
# the result is the empty set.
@inline function SBitSet{N}(k::Integer) where N
    i1, i2 = Base.get_chunks_id(k)  # (chunk index, bit offset inside the chunk)
    u = UInt64(1) << i2
    # `Val(N)` gives `ntuple` a compile-time length for every `N`.
    return SBitSet{N}(ntuple(i -> ifelse(i == i1, u, UInt64(0)), Val(N)))
end
# Start of iteration: the state is `(chunk index, bits of that chunk not yet
# visited)`. A set with zero chunks has nothing to iterate.
@inline function Base.iterate(B::SBitSet{N}) where N
    N == 0 && return nothing
    first_chunk = @inbounds B.chunks[1]
    return iterate(B, (1, first_chunk))
end
# Iteration step. `s` is `(chunk index, remaining bits of that chunk)`.
# Yields elements in increasing order and returns `nothing` once every chunk
# has been exhausted.
@inline function Base.iterate(B::SBitSet{N}, s) where N
    N == 0 && return nothing
    idx, word = s
    # Skip forward over chunks that have no unvisited bits left.
    while word == 0
        # Compared as unsigned values; stops when `idx` is the last chunk.
        idx % UInt >= N % UInt && return nothing
        idx += 1
        @inbounds word = B.chunks[idx]
    end
    pos = trailing_zeros(word) + 1      # 1-based position of the lowest set bit
    remaining = _blsr(word)             # drop that bit from the pending word
    return ((idx - 1) << 6 + pos, (idx, remaining))
end
"""
    BitMat{N} <: AbstractMatrix{Bool}

Boolean matrix whose bits are packed into the `UInt64` matrix `chunks`.
The constructors in this file allocate `chunks` with `N` rows, so each column
holds `64N` bits; the struct itself does not check that `size(chunks, 1) == N`.
"""
struct BitMat{N} <: AbstractMatrix{Bool}
    chunks::Matrix{UInt64}
end
"""
    BitMat(nr, nc)

Create an all-`false` `BitMat` with `nc` columns and enough 64-bit chunks per
column to hold `nr` bits. The chunk count becomes the type parameter, so the
concrete return type depends on the value of `nr`.
"""
function BitMat(nr, nc)
    nchunks = Base.num_bit_chunks(nr)
    return BitMat{nchunks}(zeros(UInt64, nchunks, nc))
end
"""
    BitMat(::Val{N}, nc)

Create an all-`false` `BitMat{N}` with `nc` columns of `N` chunks each. Passing
the chunk count as a `Val` makes the return type known from the argument types.
"""
function BitMat(::Val{N}, nc) where N
    return BitMat{N}(zeros(UInt64, N, nc))
end
# Total number of bits: 64 for every stored chunk.
@inline Base.length(B::BitMat) = length(B.chunks) << 6
# Shape in bits: `64N` rows by as many columns as `chunks` has.
@inline Base.size(B::BitMat{N}) where N = (N << 6, size(B.chunks, 2))
# Declare linear indexing on the *type*, as the AbstractArray interface expects.
# Base's instance method forwards to this one, so `IndexStyle(B)` and
# `IndexStyle(typeof(B))` now agree.
Base.IndexStyle(::Type{<:BitMat}) = IndexLinear()
# Linear bit lookup: reads bit `i` from the chunk matrix viewed as one long
# sequence of 64-bit words.
Base.@propagate_inbounds function Base.getindex(B::BitMat, i)
    i1, i2 = Base.get_chunks_id(i)  # (linear chunk index, bit offset)
    mask = UInt64(1) << i2
    return !iszero(B.chunks[i1] & mask)
end
# Cartesian bit lookup: reads bit `i` within column `j`.
Base.@propagate_inbounds function Base.getindex(B::BitMat, i, j)
    i1, i2 = Base.get_chunks_id(i)  # (chunk row, bit offset)
    mask = UInt64(1) << i2
    return !iszero(B.chunks[i1, j] & mask)
end
# Linear bit store: sets (`b == true`) or clears (`b == false`) bit `i` and
# returns `b`.
Base.@propagate_inbounds function Base.setindex!(B::BitMat, b::Bool, i)
    i1, i2 = Base.get_chunks_id(i)  # (linear chunk index, bit offset)
    mask = UInt64(1) << i2
    old = B.chunks[i1]
    B.chunks[i1] = ifelse(b, old | mask, old & ~mask)
    return b
end
# Cartesian bit store: sets (`b == true`) or clears (`b == false`) bit `i` of
# column `j` and returns `b`.
Base.@propagate_inbounds function Base.setindex!(B::BitMat, b::Bool, i, j)
    i1, i2 = Base.get_chunks_id(i)  # (chunk row, bit offset)
    mask = UInt64(1) << i2
    old = B.chunks[i1, j]
    B.chunks[i1, j] = ifelse(b, old | mask, old & ~mask)
    return b
end
# Column store: overwrites column `j` with the chunks of the set `b` and
# returns `b`. The set must have the same chunk count `N` as the matrix.
Base.@propagate_inbounds function Base.setindex!(
    B::BitMat{N}, b::SBitSet{N}, ::Colon, j::Integer
) where N
    for row in 1:N
        B.chunks[row, j] = b.chunks[row]
    end
    return b
end
# Column load: returns column `j` as an `SBitSet{N}` built from its chunks.
# `Val(N)` gives `ntuple` a compile-time length, so the tuple is built without
# the allocating fallback that the integer-length form uses for larger `N`.
Base.@propagate_inbounds function Base.getindex(B::BitMat{N}, ::Colon, j::Integer) where N
    Bc = B.chunks
    return SBitSet{N}(ntuple(i -> Bc[i, j], Val(N)))
end
# Optional: pretty printing. The MIME variants are used by IJulia.
# Rich (text/plain) display of a set uses the same output as plain `show`.
Base.show(io::IO, ::MIME"text/plain", bv::SBitSet) = show(io, bv)
# Prints the chunk count followed by the collected elements of the set.
function Base.show(io::IO, bv::SBitSet{N}) where N
    elements = collect(bv)
    print(io, "SBitSet{$N}(", elements, ")")
    return nothing
end
# Rich (text/plain) display of a matrix uses the same output as plain `show`.
Base.show(io::IO, ::MIME"text/plain", bm::BitMat) = show(io, bm)
# Prints a header, then one line per column: the column number, a tab, and the
# column rendered as an `SBitSet`.
function Base.show(io::IO, bm::BitMat{N}) where N
    print(io, "BitMat{$N} (")
    for i in 1:size(bm, 2)
        print(io, "\n$(i)\t")
        print(io, bm[:, i])
    end
    print(io, " )")
    return nothing
end
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment