Skip to content

Instantly share code, notes, and snippets.

@quattro
Last active October 12, 2019 00:36
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save quattro/daa2298bc22bae412ac708ae0ad6ac8d to your computer and use it in GitHub Desktop.
Save quattro/daa2298bc22bae412ac708ae0ad6ac8d to your computer and use it in GitHub Desktop.
mailman algorithm for binary matrix / real vector multiplication
#! /usr/bin/env python
import argparse as ap
import datetime as dt
import os
import sys
import numpy as np
from scipy.sparse import csr_matrix
def mailman(A, X):
    """
    Compute ``A @ X`` for a binary matrix ``A`` using the mailman algorithm.

    The work is delegated to ``mm_compute_p`` (builds the sparse matrix P of
    the factorisation A = U P) and ``mm_compute_prod`` (applies P and then the
    implicit U).  The test driver in this file calls it with n == 2**m.

    Parameters
    ----------
    A : array of shape (m, n) with 0/1 entries
    X : array of shape (n, trials), or a vector of shape (n,)

    Returns
    -------
    Array of shape (m, trials) for a 2-D ``X``, or (m,) for a 1-D ``X``.

    Raises
    ------
    ValueError
        If ``X`` is not 1-D or 2-D, or if its leading dimension does not match
        the number of columns of ``A``.
    """
    m, n = A.shape

    x_shape = tuple(X.shape)
    if len(x_shape) not in (1, 2):
        raise ValueError("X must be 1-D or 2-D, got {} dimensions".format(len(x_shape)))
    # check the inner dimension up front instead of letting the sparse product
    # fail later with a less specific message
    if x_shape[0] != n:
        raise ValueError(
            "dimension mismatch: A has {} columns but X has {} rows".format(n, x_shape[0])
        )

    # (m, trials) for a matrix X, (m,) for a vector X
    out_shape = (m,) + x_shape[1:]

    P = mm_compute_p(A)
    return mm_compute_prod(P, X, out_shape)
def mm_compute_p(A):
    """
    Build the sparse selection matrix P of the mailman factorisation A = U P.

    U is the implicit m x 2**m matrix whose k-th column is the m-bit binary
    expansion of k, with row 0 holding the most significant bit.  Each column
    of a binary A is therefore one column of U, and P records which one:
    ``P[k, j] == 1`` exactly when column j of A encodes the integer k.

    Parameters
    ----------
    A : array of shape (m, n) with 0/1 entries and an integer dtype
        (``np.left_shift`` is an integer-only operation).

    Returns
    -------
    scipy.sparse.csr_matrix of shape (2**m, n) with exactly one 1 per column.
    It is square when n == 2**m.
    """
    m, n = A.shape

    # bit position of each row: row 0 is the most significant bit
    b = np.arange(m)[::-1]

    # Compute the integer row index from the binary string in each column
    # (bit-shift and add).  This can only work for binary matrices.
    # The numpy-heavy form keeps this fast, but it is still the slow part
    # compared with the U @ z loop; when many products A @ x are needed for
    # different x, P can be computed once and reused.
    idxs = np.sum(np.left_shift(A.T, b), axis=1)

    # P needs one row per possible bit pattern (2**m of them), not one per
    # column of A; the two counts only coincide when n == 2**m.
    P = csr_matrix((np.ones(n), (idxs, np.arange(n))), shape=(2 ** m, n))
    return P
def mm_compute_prod(P, X, shape):
    """
    Compute ``U @ (P @ X)``, i.e. ``A @ X`` for the binary matrix A = U P.

    U is never built explicitly.  A product with it obeys a recurrence: the
    upper half of the current vector contributes to the current output row,
    and is then folded into the lower half before moving on to the next row.
    The loop below unrolls that recurrence, handling all columns of ``X`` at
    once.

    Parameters
    ----------
    P : sparse matrix with 2**m rows (as produced by ``mm_compute_p``)
    X : array with as many rows as ``P`` has columns; 1-D or 2-D
    shape : int or tuple
        Shape of the result.  A scalar ``m`` gives a length-m vector;
        otherwise ``shape[0]`` is taken as m.

    Returns
    -------
    Array of the requested ``shape`` with the dtype of ``P @ X``.

    Raises
    ------
    ValueError
        If ``P`` does not have exactly 2**m rows; the halving recurrence is
        only meaningful in that case.
    """
    m = shape if np.isscalar(shape) else shape[0]

    # the fold runs over the rows of Z = P @ X, so the bound is P's row count
    rows = P.shape[0]
    if rows != 2 ** m:
        raise ValueError(
            "P must have 2**m = {} rows for m = {}, got {}".format(2 ** m, m, rows)
        )

    # compute P @ x = z (a fresh array, so the in-place folding below does
    # not touch X)
    Z = P.dot(X)

    # compute U @ z = output; use Z's dtype so e.g. complex input is preserved
    out = np.zeros(shape, dtype=Z.dtype)
    ubound = rows
    for idx in range(m):
        div = ubound // 2
        upper = Z[div:ubound]
        # entries whose current leading bit is 1 make up output row idx
        out[idx] = np.sum(upper, axis=0)
        # drop the leading bit: merge the upper half into the lower half
        Z[:div] += upper
        ubound = div
    return out
def main(args):
    """
    Command-line driver: compare ``A @ X`` with ``mailman(A, X)``.

    Generates a random binary matrix A of shape (m, 2**m) and a random normal
    matrix X of shape (2**m, iters), computes the product both ways, and
    reports whether the results agree together with the elapsed time of each.

    Parameters
    ----------
    args : list of str
        Command-line arguments without the program name.

    Returns
    -------
    int
        0 on success.  Invalid arguments terminate through argparse
        (``SystemExit``).
    """
    argp = ap.ArgumentParser(description="Testing ground for binary-matrix multiplication with mailman algorithm")
    argp.add_argument("m", type=int, help="generate m x 2**m matrix")
    argp.add_argument("--iters", type=int, default=10, help="Number of multiplications to perform")
    argp.add_argument("-o", "--output", type=ap.FileType("w"), default=sys.stdout)
    args = argp.parse_args(args)

    m = args.m
    if m < 0:
        # 2**m would become a float and fail obscurely inside numpy
        argp.error("m must be non-negative")
    n = 2 ** m

    A = np.random.randint(0, 2, size=(m, n))
    X = np.random.normal(size=(n, args.iters))

    start = dt.datetime.now()
    y1s = A @ X
    stop1 = dt.datetime.now()
    y2s = mailman(A, X)
    stop2 = dt.datetime.now()

    # element-wise closeness over the whole result, same default tolerances
    # as np.isclose
    equal = bool(np.allclose(y1s, y2s))

    # write the report to the destination selected with -o/--output
    out = args.output
    try:
        print("equal? = {}".format(equal), file=out)
        print("time1 = {}".format((stop1 - start).total_seconds()), file=out)
        print("time2 = {}".format((stop2 - stop1).total_seconds()), file=out)
    finally:
        # close a file opened by argparse, but never the process stdout
        if out is not sys.stdout:
            out.close()

    return 0
# Script entry point: pass the command-line arguments (without the program
# name) to main() and use its return value as the process exit status.
if __name__ == "__main__":
    raise SystemExit(main(sys.argv[1:]))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment