Created
June 28, 2023 20:48
-
-
Save tkoz0/3ce99f5784d964560f4978b55373a96e to your computer and use it in GitHub Desktop.
bit packing scheme for matrices of 0,1,-1 entries, with detection of operations that show the matrix is not TU (totally unimodular)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <cassert>
#include <cstdint>
#include <cstdlib>
#include <cstring>
#include <vector>
// TUMatrix: an m x n matrix with entries in {-1, 0, 1}, packed 2 bits per
// entry (16 entries per uint32_t word; each row is padded to whole words and
// the unused trailing slots stay 0).
// Entry j of a row lives in word j/16 of that row, at bit 2*(j%16) (lo bit)
// and bit 2*(j%16)+1 (hi bit), encoded as 2-bit two's complement:
//   00 = 0, 01 = 1, 11 = -1, 10 = +2/-2 (never a valid entry).
// Row operations use word-parallel arithmetic modulo 4, so an operation that
// would produce 2 or -2 leaves the 10 pattern behind. rowvalid()/rowcheck()
// detect that pattern; hitting it during rank() is the signal this class uses
// to report that the matrix is not totally unimodular (TU).
class TUMatrix
{
private:
    // matrix dimensions: m rows, n columns
    uint32_t m, n;
    // number of uint32_t words per row (16 entries per word)
    uint32_t rowsize;
    // packed entries, row-major, m*rowsize words. Owning container, so the
    // implicitly generated copy/move/destructor are correct (incl. self-assign).
    // A moved-from TUMatrix may only be assigned to or destroyed.
    std::vector<uint32_t> mat;

    // pointer to the first word of row i (size_t math avoids uint32 overflow)
    uint32_t *_row_ptr(uint32_t i)
    {
        assert(i < m);
        return mat.data() + static_cast<std::size_t>(i) * rowsize;
    }
    const uint32_t *_row_ptr(uint32_t i) const
    {
        assert(i < m);
        return mat.data() + static_cast<std::size_t>(i) * rowsize;
    }
    // rotate right by one bit
    static uint32_t _rotr(uint32_t u)
    {
        return (u >> 1) | (u << 31);
    }
    // rotate left by one bit: moves each entry's lo bit into its hi position
    static uint32_t _rotl(uint32_t u)
    {
        return (u << 1) | (u >> 31);
    }
    // keep only the hi bit of every 2-bit entry
    static uint32_t _his(uint32_t u)
    {
        return u & 0xAAAAAAAAu;
    }
    // keep only the lo bit of every 2-bit entry
    static uint32_t _los(uint32_t u)
    {
        return u & 0x55555555u;
    }
public:
    // Create an m x n matrix with every entry 0. Both dimensions must be > 0.
    TUMatrix(uint32_t m, uint32_t n)
        : m(m), n(n), rowsize(n / 16 + (n % 16 != 0 ? 1 : 0)),
          mat(static_cast<std::size_t>(m) * rowsize, 0)
    {
        assert(m > 0);
        assert(n > 0);
    }
    // number of rows
    uint32_t numrows() const
    {
        return m;
    }
    // number of cols
    uint32_t numcols() const
    {
        return n;
    }
    // Store v at (i, j). v must be 1, 0, or -1 (passed as (uint32_t)-1).
    void setval(uint32_t i, uint32_t j, uint32_t v)
    {
        assert(v == static_cast<uint32_t>(-1) || v == 0 || v == 1);
        assert(j < n);
        const uint32_t b = (j & 0xF) << 1; // bit position of the entry
        uint32_t &word = _row_ptr(i)[j >> 4];
        word = (word & ~(3u << b)) | ((v & 3u) << b);
    }
    // Return the raw 2-bit code of entry (i, j) in the low 2 bits:
    // 0 -> 0, 1 -> 1, -1 -> 3, and 2 for the invalid +/-2 pattern.
    uint32_t getval(uint32_t i, uint32_t j) const
    {
        assert(j < n);
        const uint32_t b = (j & 0xF) << 1;
        return (_row_ptr(i)[j >> 4] >> b) & 3u;
    }
    // multiply row i by -1
    void rowneg(uint32_t i)
    {
        uint32_t *r = _row_ptr(i);
        for (uint32_t k = 0; k < rowsize; ++k)
            // flip the hi bit wherever the lo bit is set (01 <-> 11)
            r[k] ^= _his(_rotl(r[k]));
    }
    // add row i1 to row i2 (row_i2 += row_i1), per entry modulo 4
    void rowadd(uint32_t i1, uint32_t i2)
    {
        const uint32_t *r1 = _row_ptr(i1);
        uint32_t *r2 = _row_ptr(i2);
        for (uint32_t k = 0; k < rowsize; ++k)
        {
            const uint32_t u1 = r1[k];
            const uint32_t u2 = r2[k];
            const uint32_t ls = _los(u1 ^ u2);         // lo sum
            const uint32_t lc = _rotl(_los(u1 & u2));  // lo carry (in hi pos)
            const uint32_t hs = _his(u1 ^ u2) ^ lc;    // hi sum
            r2[k] = ls | hs;
        }
    }
    // subtract row i1 from row i2 (row_i2 -= row_i1), per entry modulo 4.
    // Computed as r2 + ~r1 + 1 in every 2-bit field.
    void rowsub(uint32_t i1, uint32_t i2)
    {
        const uint32_t *r1 = _row_ptr(i1);
        uint32_t *r2 = _row_ptr(i2);
        for (uint32_t k = 0; k < rowsize; ++k)
        {
            const uint32_t u1 = ~r1[k]; // complement of the subtrahend
            const uint32_t u2 = r2[k];
            // lo bit: u1 ^ u2 ^ 1 (the +1 of two's complement negation)
            const uint32_t ls = _los(~(u1 ^ u2));
            // carry out of the lo bit: majority(u1, u2, 1) = u1 | u2
            const uint32_t lc = _rotl(_los(u1 | u2));
            const uint32_t hs = _his(u1 ^ u2) ^ lc;
            r2[k] = ls | hs;
        }
    }
    // swap 2 rows
    void rowswap(uint32_t i1, uint32_t i2)
    {
        uint32_t *r1 = _row_ptr(i1);
        uint32_t *r2 = _row_ptr(i2);
        for (uint32_t k = 0; k < rowsize; ++k)
        {
            const uint32_t t = r1[k];
            r1[k] = r2[k];
            r2[k] = t;
        }
    }
    // True iff row i contains only -1, 0, 1 (no entry has the 10 pattern that
    // a +/-2 result leaves behind). Usable in release builds, unlike rowcheck.
    bool rowvalid(uint32_t i) const
    {
        const uint32_t *r = _row_ptr(i);
        for (uint32_t k = 0; k < rowsize; ++k)
            // hi bit set while lo bit clear -> invalid entry
            if (r[k] & _his(_rotl(~r[k])))
                return false;
        return true;
    }
    // Assert that row i has only -1, 0, 1 entries (no-op under NDEBUG).
    void rowcheck(uint32_t i) const
    {
        assert(rowvalid(i));
        (void)i;
    }
    // Row-reduce the matrix IN PLACE (forward elimination) and return its rank.
    // Each pivot is normalized to +1; entries below it are cleared with
    // rowsub/rowadd. In debug builds every updated row is checked, so an
    // entry of +/-2 (the not-TU signal) triggers an assertion failure.
    uint32_t rank()
    {
        uint32_t r = 0; // pivots found so far == next pivot row
        // once every row holds a pivot, no further column can add to the rank
        for (uint32_t j = 0; j < n && r < m; ++j)
        {
            // search column j (rows r..m-1) for a nonzero value
            uint32_t p = r;
            uint32_t a = 0;
            for (; p < m; ++p)
            {
                a = getval(p, j);
                assert(a != 2);
                if (a != 0)
                    break;
            }
            if (p == m) // all zero, no pivot in this column
                continue;
            if (p != r)
                rowswap(r, p);
            if (a == 3) // -1 in 2 bit value: normalize pivot to +1
                rowneg(r);
            // eliminate down the column
            for (uint32_t ii = r + 1; ii < m; ++ii)
            {
                const uint32_t b = getval(ii, j);
                if (b == 1)
                    rowsub(r, ii);
                else if (b == 3) // -1 in 2 bit value
                    rowadd(r, ii);
                rowcheck(ii);
            }
            ++r;
        }
        return r;
    }
};
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment