Last active
October 12, 2019 00:36
-
-
Save quattro/daa2298bc22bae412ac708ae0ad6ac8d to your computer and use it in GitHub Desktop.
mailman algorithm for binary matrix / real vector multiplication
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#! /usr/bin/env python | |
import argparse as ap | |
import datetime as dt | |
import os | |
import sys | |
import numpy as np | |
from scipy.sparse import csr_matrix | |
def mailman(A, X):
    """
    Multiply a binary matrix by a real matrix (or vector) with the mailman algorithm.

    A = (m, n) binary matrix
    X = (n, trials) matrix, or (n,) vector
    out = A @ X; shape (m, trials) for a matrix X, (m,) for a vector X

    Raises ValueError if X is not 1-D or 2-D, or if the number of rows of X
    differs from the number of columns of A.
    """
    m, n = A.shape
    x_ndim = len(X.shape)
    if x_ndim not in (1, 2):
        raise ValueError("X must be a vector or a matrix, got shape {}".format(X.shape))
    if X.shape[0] != n:
        raise ValueError(
            "dimension mismatch: A is {}x{} but X has {} rows".format(m, n, X.shape[0])
        )

    # mm_compute_prod takes a scalar output shape when the right-hand side is a vector
    out_shape = m if x_ndim == 1 else (m, X.shape[1])

    P = mm_compute_p(A)
    return mm_compute_prod(P, X, out_shape)
def mm_compute_p(A):
    """
    Build the sparse selection matrix P of the mailman factorization A = U P.

    A = (m, n) binary matrix (every entry equal to 0 or 1)
    out = scipy.sparse CSR matrix P of shape (2**m, n)

    U is the implicit (m, 2**m) matrix whose k-th column is the m-bit binary
    expansion of k, with the most significant bit in row 0.  Column j of P
    holds a single 1, in the row whose index is the integer spelled out by
    column j of A, so U @ P equals A for any number of columns n.

    P depends only on A, so when A is multiplied with many different X it can
    be computed once and reused.

    Raises ValueError if A is not 2-D, is not binary, or has more than 62
    rows (row indices would no longer fit a signed 64-bit integer).
    """
    A = np.asarray(A)
    if A.ndim != 2:
        raise ValueError("A must be a 2-D matrix, got shape {}".format(A.shape))
    m, n = A.shape
    if m > 62:
        raise ValueError("A has {} rows; at most 62 are supported".format(m))
    # the bit-shift encoding below is only meaningful for 0/1 entries
    if not np.all((A == 0) | (A == 1)):
        raise ValueError("A must be a binary (0/1) matrix")
    bits = A.astype(np.int64, copy=False)

    # shift applied to each row: row 0 is the most significant bit
    shifts = np.arange(m - 1, -1, -1, dtype=np.int64)
    # read every column of A as a binary number (bit shift and add)
    idxs = np.left_shift(bits.T, shifts).sum(axis=1)

    # P needs one row per possible m-bit value, not one row per column of A
    return csr_matrix((np.ones(n), (idxs, np.arange(n))), shape=(2**m, n))


def mm_compute_prod(P, X, shape):
    """
    Compute U @ (P @ X) without ever constructing U.

    P = (2**k, n) matrix as returned by mm_compute_p, with k >= m (normally k == m)
    X = (n, trials) matrix, or (n,) vector
    shape = shape of the result: (m, trials), or just m when X is a vector
    out = array of the requested shape; row i is the sum of the rows of P @ X
          whose row index has bit i set, counting from the most significant bit

    Raises ValueError if the number of rows of P is not a power of two or is
    smaller than 2**m.
    """
    m = shape if np.isscalar(shape) else shape[0]

    # the recurrence halves the ROW range of P @ X, so the bound is P.shape[0]
    n_rows = P.shape[0]
    if n_rows < 1 or n_rows & (n_rows - 1) or n_rows < 2**m:
        raise ValueError(
            "P has {} rows; expected a power of two >= 2**{}".format(n_rows, m)
        )

    # compute P @ x = z; this is a fresh array, so folding it in place below
    # never touches the caller's X
    Z = P.dot(X)

    # U @ z defines a recurrence: the top half of the active range feeds the
    # current output row and is then folded onto the bottom half, which
    # becomes the active range for the next bit.  All columns of X are
    # processed together at each step.
    out = np.zeros(shape)
    ubound = n_rows
    for idx in range(m):
        div = ubound // 2
        upper = Z[div:ubound]
        out[idx] = np.sum(upper, axis=0)
        Z[:div] += upper
        ubound = div
    return out
def main(args):
    """
    Compare mailman() with NumPy's A @ X on a random m x 2**m binary matrix.

    args = list of command-line arguments, without the program name
    out = 0, used as the process exit code

    Writes three lines to --output (stdout by default): whether the two
    products agree, the seconds taken by A @ X, and the seconds taken by
    mailman().
    """
    argp = ap.ArgumentParser(
        description="Testing ground for binary-matrix multiplication with mailman algorithm"
    )
    argp.add_argument("m", type=int, help="generate m x 2**m matrix")
    argp.add_argument("--iters", type=int, default=10, help="Number of multiplications to perform")
    argp.add_argument("-o", "--output", type=ap.FileType("w"), default=sys.stdout)
    # keep the raw argument list and the parsed namespace under separate names
    opts = argp.parse_args(args)

    m = opts.m
    n = 2**m
    A = np.random.randint(0, 2, size=(m, n))
    X = np.random.normal(size=(n, opts.iters))

    start = dt.datetime.now()
    y1s = A @ X
    stop1 = dt.datetime.now()
    y2s = mailman(A, X)
    stop2 = dt.datetime.now()

    # one vectorized comparison; np.allclose uses the same default tolerances
    # as np.isclose
    equal = bool(np.allclose(y1s, y2s))

    out = opts.output
    try:
        # report to the requested destination instead of always using stdout
        print("equal? = {}".format(equal), file=out)
        print("time1 = {}".format((stop1 - start).total_seconds()), file=out)
        print("time2 = {}".format((stop2 - stop1).total_seconds()), file=out)
    finally:
        # close a file opened by argparse, but never close stdout itself
        if out is not sys.stdout:
            out.close()

    return 0
# Script entry point: run main() on the command-line arguments (program name
# excluded) and use its return value as the process exit status.
if __name__ == "__main__":
    exit_code = main(sys.argv[1:])
    sys.exit(exit_code)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment