Skip to content

Instantly share code, notes, and snippets.

@tkoz0
Created June 28, 2023 20:48
Show Gist options
  • Save tkoz0/3ce99f5784d964560f4978b55373a96e to your computer and use it in GitHub Desktop.
Save tkoz0/3ce99f5784d964560f4978b55373a96e to your computer and use it in GitHub Desktop.
Bit-packing scheme for matrices of 0, 1, -1 entries, with detection of row operations that show the matrix is not TU (totally unimodular).
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <cstring>
// TUMatrix: an m x n matrix with entries in {-1, 0, 1}, packed 2 bits per
// entry (16 entries per uint32_t word).
//
// Each 2-bit lane holds the low 2 bits of the entry's two's complement:
//   0 -> 00, 1 -> 01, -1 -> 11.
// The row operations work lane-wise modulo 4. The remaining pattern 10 means
// an entry of +2 or -2. That can only arise from row arithmetic, and rank()
// (via rowcheck) asserts it never appears. This is how a sequence of
// operations reveals that the matrix is not totally unimodular (TU).
//
// Column j of a row lives in word j/16, at bit 2*(j%16) (lo) and
// 2*(j%16)+1 (hi). Unused lanes at the end of each row's last word stay 0,
// so whole-word row operations never disturb them.
class TUMatrix
{
private:
    // matrix dimensions (rows, columns)
    uint32_t m, n;
    // number of uint32_t words per row: ceil(n / 16)
    uint32_t rowsize;
    // owned storage: m * rowsize words, row-major
    uint32_t *mat;

    // Total storage size in words. Computed in size_t so that large
    // m * rowsize products do not wrap around in 32-bit arithmetic.
    size_t _num_words() const
    {
        return (size_t)m * rowsize;
    }

    // Make *this a deep copy of a. Allocates fresh storage and does NOT free
    // any storage *this may already own.
    void _copy_from(const TUMatrix& a)
    {
        m = a.m;
        n = a.n;
        rowsize = a.rowsize;
        mat = new uint32_t[_num_words()];
        memcpy(mat, a.mat, _num_words() * sizeof(uint32_t));
    }

    // Pointer to the first word of row i.
    inline uint32_t *_row_ptr(uint32_t i) const
    {
        return mat + (size_t)i * rowsize;
    }

    // rotate right by one bit
    static inline uint32_t _rotr(uint32_t u)
    {
        return (u >> 1) | (u << 31);
    }

    // rotate left by one bit: moves each lane's lo bit into its hi slot
    static inline uint32_t _rotl(uint32_t u)
    {
        return (u << 1) | (u >> 31);
    }

    // keep only the hi bit of every lane
    static inline uint32_t _his(uint32_t u)
    {
        return u & 0xAAAAAAAA;
    }

    // keep only the lo bit of every lane
    static inline uint32_t _los(uint32_t u)
    {
        return u & 0x55555555;
    }

public:
    // Create an m x n zero matrix. Both dimensions must be positive.
    TUMatrix(uint32_t m, uint32_t n)
    {
        assert(m > 0);
        assert(n > 0);
        this->m = m;
        this->n = n;
        // ceil(n / 16), written so it cannot overflow for n near UINT32_MAX
        rowsize = n / 16 + ((n & 0xF) != 0);
        mat = new uint32_t[_num_words()](); // value-initialized: all zero
    }

    ~TUMatrix()
    {
        delete[] mat;
    }

    TUMatrix(const TUMatrix& a)
    {
        _copy_from(a);
    }

    // Deep-copy assignment. Safe under self-assignment. The new storage is
    // allocated before the old one is released, so a failed allocation
    // leaves *this unchanged.
    TUMatrix& operator=(const TUMatrix& a)
    {
        if (this != &a)
        {
            uint32_t *fresh = new uint32_t[a._num_words()];
            memcpy(fresh, a.mat, a._num_words() * sizeof(uint32_t));
            delete[] mat;
            mat = fresh;
            m = a.m;
            n = a.n;
            rowsize = a.rowsize;
        }
        return *this;
    }

    // number of rows
    inline uint32_t numrows() const
    {
        return m;
    }

    // number of cols
    inline uint32_t numcols() const
    {
        return n;
    }

    // Store v at (i, j). v must be (uint32_t)-1, 0 or 1.
    inline void setval(uint32_t i, uint32_t j, uint32_t v)
    {
        assert(i < m && j < n);
        assert(v == (uint32_t)(-1) || v == 0 || v == 1);
        uint32_t b = (j & 0xF) << 1; // bit position of the lane's lo bit
        uint32_t *w = _row_ptr(i) + (j >> 4);
        // unsigned shift: well-defined even for the top lane (b == 30)
        uint32_t mask = ~((uint32_t)3 << b);
        // the low 2 bits of v are the lane encoding (-1 -> 11)
        *w = (*w & mask) | ((v & (uint32_t)3) << b);
    }

    // Raw 2-bit lane at (i, j), in the low 2 bits of the result:
    // 0 -> 0, 1 -> 1, -1 -> 3, and 2 means an entry of +-2.
    inline uint32_t getval(uint32_t i, uint32_t j) const
    {
        assert(i < m && j < n);
        uint32_t b = (j & 0xF) << 1;
        uint32_t u = _row_ptr(i)[j >> 4];
        return (u >> b) & 0x3;
    }

    // Multiply row i by -1.
    inline void rowneg(uint32_t i)
    {
        assert(i < m);
        uint32_t *r = _row_ptr(i);
        for (uint32_t w = 0; w < rowsize; ++w)
            // 01 <-> 11: flip the hi bit of every lane whose lo bit is set
            r[w] ^= _his(_rotl(r[w]));
    }

    // Add row i1 to row i2 (row[i2] += row[i1]), lane-wise modulo 4.
    inline void rowadd(uint32_t i1, uint32_t i2)
    {
        assert(i1 < m && i2 < m);
        uint32_t *r1 = _row_ptr(i1);
        uint32_t *r2 = _row_ptr(i2);
        for (uint32_t w = 0; w < rowsize; ++w)
        {
            uint32_t u1 = r1[w];
            uint32_t u2 = r2[w];
            uint32_t ls = _los(u1 ^ u2);        // lo sum
            uint32_t lc = _rotl(_los(u1 & u2)); // lo carry (in hi pos)
            uint32_t hs = _his(u1 ^ u2) ^ lc;   // hi sum
            r2[w] = ls | hs;
        }
    }

    // Subtract row i1 from row i2 (row[i2] -= row[i1]), lane-wise modulo 4.
    inline void rowsub(uint32_t i1, uint32_t i2)
    {
        assert(i1 < m && i2 < m);
        uint32_t *r1 = _row_ptr(i1);
        uint32_t *r2 = _row_ptr(i2);
        for (uint32_t w = 0; w < rowsize; ++w)
        {
            uint32_t u1 = r1[w]; // subtrahend
            uint32_t u2 = r2[w]; // minuend
            uint32_t ls = _los(u1 ^ u2);         // lo difference
            // borrow out of the lo bit happens when lo2 = 0 and lo1 = 1
            uint32_t lb = _rotl(_los(u1 & ~u2)); // lo borrow (in hi pos)
            uint32_t hs = _his(u1 ^ u2) ^ lb;    // hi difference
            r2[w] = ls | hs;
        }
    }

    // Swap rows i1 and i2.
    inline void rowswap(uint32_t i1, uint32_t i2)
    {
        assert(i1 < m && i2 < m);
        uint32_t *r1 = _row_ptr(i1);
        uint32_t *r2 = _row_ptr(i2);
        for (uint32_t w = 0; w < rowsize; ++w)
        {
            uint32_t t = r1[w];
            r1[w] = r2[w];
            r2[w] = t;
        }
    }

    // Assert that row i holds only -1, 0, 1 entries (no lane equal to 10,
    // which is what a TU-violating operation produces). This is a no-op
    // when NDEBUG is defined.
    inline void rowcheck(uint32_t i) const
    {
        assert(i < m);
        const uint32_t *r = _row_ptr(i);
        for (uint32_t w = 0; w < rowsize; ++w)
        {
            uint32_t u = r[w];
            // a lane is 10 exactly when its hi bit is set and its lo bit is
            // clear; _rotl(~u) moves each inverted lo bit into its hi slot
            assert(!(u & _his(_rotl(~u))));
            (void)u;
        }
    }

    // Row-reduce the matrix IN PLACE (forward elimination only) and return
    // its rank. Asserts (in debug builds) that no intermediate entry leaves
    // {-1, 0, 1}. If one does, the matrix is not totally unimodular.
    uint32_t rank()
    {
        uint32_t rk = 0; // number of pivots found so far
        // once every row holds a pivot, the remaining columns cannot add any
        for (uint32_t j = 0; j < n && rk < m; ++j)
        {
            // find the first row at or below rk with a nonzero entry in col j
            uint32_t piv = UINT32_MAX;
            bool isneg = false;
            for (uint32_t ii = rk; ii < m; ++ii)
            {
                uint32_t a = getval(ii, j);
                assert(a != 2); // lane 10: an entry of +-2
                if (a != 0)
                {
                    piv = ii;
                    isneg = (a == 3); // -1 in 2-bit encoding
                    break;
                }
            }
            if (piv == UINT32_MAX) // all zero: no pivot in this column
                continue;
            if (piv != rk)
                rowswap(rk, piv);
            if (isneg)
                rowneg(rk); // normalize the pivot to +1
            // eliminate down the column
            for (uint32_t ii = rk + 1; ii < m; ++ii)
            {
                uint32_t a = getval(ii, j);
                if (a == 1)
                    rowsub(rk, ii);
                else if (a == 3) // -1 in 2-bit encoding
                    rowadd(rk, ii);
                rowcheck(ii);
            }
            ++rk;
        }
        return rk;
    }
};
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment