Skip to content

Instantly share code, notes, and snippets.

@isilence
Last active April 30, 2017 18:18
Show Gist options
  • Save isilence/d6ad8be9b52dec32c800825c63421ad3 to your computer and use it in GitHub Desktop.
Save isilence/d6ad8be9b52dec32c800825c63421ad3 to your computer and use it in GitHub Desktop.
#! /usr/bin/env python3
import numpy as np
def shift_bit_length(x):
    """Return the smallest power of two that is >= ``x``.

    Non-positive inputs yield 1 (``2**0``).

    >>> [shift_bit_length(v) for v in (0, 1, 2, 3, 1024, 1025)]
    [1, 1, 2, 4, 1024, 2048]
    """
    # (x - 1).bit_length() is the exponent of the next power of two. Clamp at
    # 0 because int.bit_length ignores the sign: unclamped, x = 0 would give
    # (-1).bit_length() == 1 -> 2, and x = -5 would give 8.
    return 1 << max(x - 1, 0).bit_length()
def multh(a, b, c, d, e, leaf_size=1024):
    """Write ``a @ b`` into ``c`` using the Strassen-Winograd recursion.

    ``a``, ``b``, ``c``, ``d`` and ``e`` are square 2-D arrays of the same
    size ``n``. Each recursion level performs 7 half-size multiplications and
    15 block additions, reusing the given buffers instead of allocating
    temporaries.

    Args:
        a, b: Operands. WARNING: when the recursion is taken (``n`` is even
            and ``> leaf_size``) both are overwritten with intermediate values.
        c: Output buffer; receives the product.
        d, e: Scratch buffers; their contents are overwritten.
        leaf_size: Sizes ``<= leaf_size`` are multiplied directly with
            ``np.matmul``. Odd sizes also fall back to ``np.matmul``, since
            they cannot be split into four equal quadrants.

    Returns:
        ``c``.
    """
    n = a.shape[0]
    if n <= leaf_size or n % 2:
        c[:] = np.matmul(a, b)
        return c
    nh = n // 2
    a11, a12, a21, a22 = a[:nh, :nh], a[:nh, nh:], a[nh:, :nh], a[nh:, nh:]
    b11, b12, b21, b22 = b[:nh, :nh], b[:nh, nh:], b[nh:, :nh], b[nh:, nh:]
    c11, c12, c21, c22 = c[:nh, :nh], c[:nh, nh:], c[nh:, :nh], c[nh:, nh:]
    d11, d12, d21, d22 = d[:nh, :nh], d[:nh, nh:], d[nh:, :nh], d[nh:, nh:]
    e11, e21, e22 = e[:nh, :nh], e[nh:, :nh], e[nh:, nh:]
    # Pre-additions (Winograd's S_i from A, T_i from B).
    np.subtract(a11, a21, d11)      # d11 = S3 = A11 - A21
    a21 += a22                      # a21 = S1 = A21 + A22
    np.subtract(a21, a11, d12)      # d12 = S2 = S1 - A11
    np.subtract(b12, b11, d21)      # d21 = T1 = B12 - B11
    np.subtract(b22, d21, d22)      # d22 = T2 = B22 - T1
    np.subtract(b22, b12, e22)      # e22 = T3 = B22 - B12
    np.subtract(d22, b21, e21)      # e21 = T4 = T2 - B21
    # The seven products are interleaved with the post-additions. The order
    # is load-bearing: outputs and scratch space reuse blocks whose previous
    # contents are no longer needed at that point.
    multh(a11, b11, c22, b12, e11, leaf_size)  # c22 = M1 = A11 B11 (b12 is now scratch)
    np.subtract(a12, d12, a11)                 # a11 = S4 = A12 - S2
    multh(a12, b21, c11, b12, e11, leaf_size)  # c11 = M2 = A12 B21
    multh(d12, d22, c21, b12, e11, leaf_size)  # c21 = M6 = S2 T2
    c11 += c22                                 # C11 = M1 + M2
    c21 += c22                                 # c21 = U2 = M1 + M6
    multh(a11, b22, c12, b12, e11, leaf_size)  # c12 = M3 = S4 B22
    multh(a21, d21, c22, d12, e11, leaf_size)  # c22 = M5 = S1 T1
    c12 += c21                                 # c12 = M3 + U2
    c12 += c22                                 # C12 = U2 + M5 + M3
    multh(d11, e22, a21, d12, e11, leaf_size)  # a21 = M7 = S3 T3
    multh(a22, e21, b22, d11, e11, leaf_size)  # b22 = M4 = A22 T4
    c21 += a21                                 # c21 = U3 = U2 + M7
    c22 += c21                                 # C22 = U3 + M5
    c21 -= b22                                 # C21 = U3 - M4
    return c
def mult(mat1, mat2):
    """Return the matrix product ``mat1 @ mat2`` computed by ``multh``.

    Both operands are copied into zero-padded square power-of-two buffers, so
    that ``multh`` can split them recursively. The padding is sliced off
    before returning, and the caller's inputs are not modified.

    Args:
        mat1: 2-D array-like of shape (n, k).
        mat2: 2-D array-like of shape (k, m).

    Returns:
        An (n, m) array of dtype ``np.result_type(mat1, mat2)``. It is a view
        into an internal padded buffer.

    Raises:
        ValueError: if either operand is not 2-D, or the inner dimensions
            differ.

    Note:
        Every dimension is padded to the same power of two
        ``s >= max(n, k, m)``, and five ``s x s`` buffers are allocated.
        Strongly rectangular inputs can therefore need far more memory than
        ``np.matmul``. Boolean inputs only work below ``multh``'s leaf size,
        because NumPy rejects boolean subtraction.
    """
    mat1 = np.asarray(mat1)
    mat2 = np.asarray(mat2)
    if mat1.ndim != 2 or mat2.ndim != 2:
        raise ValueError(
            f"mult expects 2-D operands, got shapes {mat1.shape} and {mat2.shape}"
        )
    n, k = mat1.shape
    k2, m = mat2.shape
    # An explicit check rather than `assert`, which is stripped under `python -O`.
    if k != k2:
        raise ValueError(f"inner dimensions differ: {mat1.shape} @ {mat2.shape}")
    s = shift_bit_length(max(n, k, m))
    # Operands, output and scratch share one promoted dtype. Otherwise e.g. a
    # float mat1 times an integer mat2 would be truncated (or would raise on
    # the in-place subtractions).
    dtype = np.result_type(mat1.dtype, mat2.dtype)
    m1, m2, c, d, e = (np.zeros((s, s), dtype=dtype, order="C") for _ in range(5))
    m1[:n, :k] = mat1
    m2[:k, :m] = mat2
    return multh(m1, m2, c, d, e)[:n, :m]
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment