Skip to content

Instantly share code, notes, and snippets.

@tmmsartor
Last active September 18, 2019 14:34
Show Gist options
  • Save tmmsartor/cba7b8c1c922968f378c9a4eace84904 to your computer and use it in GitHub Desktop.
Save tmmsartor/cba7b8c1c922968f378c9a4eace84904 to your computer and use it in GitHub Desktop.
#! /usr/bin/python3.7
from scipy.linalg import blas
import numpy
# Order of the square triangular test matrix A.
K = 2
# Extra size of the rectangular operand: B is built as K x (K + K1), so for
# side=0 (A times B) K1 = N - K, and for side=1 (B.T times A) K1 = M - K.
K1 = 1
# Truthy value: dump A, B, both products and their difference on failure.
VERBOSE = 0
# dtrmm keyword options that the driver sweeps over every 0/1 combination.
options_tuple = ("diag", "side", "lower", "overwrite_b")
# Report-line template, e.g. "diag:0 side:1 lower:0 overwrite_b:1".
INFO_TPL = " ".join("%s:{%s}" % (name, name) for name in options_tuple)
def build_A(K, diag=0, lower=0):
    """Return a K x K float64 triangular test matrix numbered 1, 2, 3, ...

    The kept triangle (lower if ``lower`` is truthy, upper otherwise, the
    diagonal included) is filled with consecutive integers in row-major
    order.  The opposite strict triangle is zero.  When ``diag`` is truthy,
    the diagonal is afterwards overwritten with 1.  The diagonal slots still
    consume numbers, so the off-diagonal values are the same either way.
    (Presumably this mirrors dtrmm's unit-diagonal mode, which is selected by
    the same ``diag`` option -- confirm against the BLAS docs.)
    """
    full = numpy.ones((K, K), dtype=bool)
    keep = numpy.tril(full) if lower else numpy.triu(full)
    A = numpy.zeros((K, K))
    # Boolean-mask assignment visits the True cells in row-major order.
    A[keep] = numpy.arange(1, numpy.count_nonzero(keep) + 1)
    if diag:
        numpy.fill_diagonal(A, 1)
    return A
def build_B(M, N):
    """Return an M x N float64 matrix holding 1, 2, ..., M*N in row-major order."""
    # zeros() fixes the float64 dtype and rejects invalid (negative) sizes.
    B = numpy.zeros((M, N))
    B[:] = numpy.arange(1, B.size + 1).reshape(B.shape)
    return B
def compare(info_str, A, B, D_trmm, D_gemm):
    """Report whether the TRMM result exactly equals the GEMM reference.

    Prints "<info_str>: Passed" when every element of D_trmm - D_gemm is
    zero.  Otherwise it prints "<info_str>: Error" and, if the module flag
    VERBOSE is truthy, also dumps A, B, both results and their difference.
    The comparison is exact (no tolerance).  Returns None.
    """
    delta = D_trmm - D_gemm
    if not delta.any():
        print(f"{info_str}: Passed")
        return
    print(f"{info_str}: Error")
    if not VERBOSE:
        return
    print("DTRMM differs from DGEMM")
    dump = (("A:", A), ("B:", B), ("TRMM:", D_trmm), ("GEMM:", D_gemm), ("diff:", delta))
    for label, value in dump:
        print(label)
        print(value)
def check_trmm(kargs, side=0, lower=0, diag=0, trans_a=0, overwrite_b=0):
    """Run one dtrmm configuration, compare it with a dgemm reference, print the verdict.

    Parameters
    ----------
    kargs : dict
        Keyword options forwarded verbatim to ``blas.dtrmm``.  They are also
        used to format the report line via ``INFO_TPL``, so kargs must supply
        every key that template names.
    side, lower, diag, trans_a : int
        The same settings as in ``kargs``.  They control how the test matrix
        A and the dgemm reference are built, so they must agree with
        ``kargs``.  side=0 checks op(A) @ B and side=1 checks B.T @ op(A).
        Any other ``side`` value runs no check.
    overwrite_b : int
        Only takes effect through ``kargs``.  It is kept for signature
        compatibility.
    """
    B = build_B(K, K + K1)
    A = build_A(K, diag=diag, lower=lower)
    info_str = INFO_TPL.format(**kargs)
    if side == 0:
        # Left multiplication: reference is op(A) @ B.
        rhs = B
        D_gemm = blas.dgemm(1.0, A, rhs, trans_a=trans_a, trans_b=0, overwrite_c=0)
    elif side == 1:
        # Right multiplication: reference is B.T @ op(A).
        rhs = B.T
        D_gemm = blas.dgemm(1.0, rhs, A, trans_a=0, trans_b=trans_a, overwrite_c=0)
    else:
        return
    # Build the reference and snapshot the operand BEFORE calling dtrmm.
    # With overwrite_b=1 the f2py wrapper may write the result into rhs in
    # place: B.T is Fortran-contiguous, so no defensive copy is needed.
    # A reference computed afterwards would then be built from clobbered data.
    rhs_before = rhs.copy()
    D_trmm = blas.dtrmm(1.0, A, rhs, **kargs)
    compare(info_str, A, rhs_before, D_trmm, D_gemm)
if __name__ == "__main__":
    # Exhaustively sweep every 0/1 assignment of the options in options_tuple.
    # format(conf, "0Nb") renders the counter as N binary digits, so the
    # first option name is driven by the most significant bit.
    options_n = len(options_tuple)
    for conf in range(2 ** options_n):
        bits = [int(digit) for digit in format(conf, f"0{options_n}b")]
        conf_dict = dict(zip(options_tuple, bits))
        # The same dict is forwarded to dtrmm as kargs.  It is also unpacked
        # into the named settings check_trmm uses to build A and the reference.
        check_trmm(conf_dict, **conf_dict)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment