@shuuji3
Forked from Terminus-IMRC/cannon.py
Last active July 1, 2019 10:07
Matrix-matrix multiplication: Cannon's algorithm
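The script below works because its shifts maintain a simple index invariant. Write \tilde{A}^{(k)} and \tilde{B}^{(k)} (notation introduced here for illustration) for what the grid holds at iteration k of Procedure 4. Here k = 0 means right after Procedures 1 and 2, and indices are 0-based as in NumPy:

```latex
\begin{aligned}
\tilde{A}^{(k)}_{ij} &= A_{i,\,(i+j+k) \bmod N},
\qquad
\tilde{B}^{(k)}_{ij} = B_{(i+j+k) \bmod N,\,j},
\qquad k = 0, 1, \dots, N-1, \\
C_{ij} &= \sum_{k=0}^{N-1} \tilde{A}^{(k)}_{ij}\,\tilde{B}^{(k)}_{ij}
= \sum_{\ell=0}^{N-1} A_{i\ell}\,B_{\ell j},
\qquad \ell = (i+j+k) \bmod N.
\end{aligned}
```

As k runs over 0, ..., N-1, the index ell = (i + j + k) mod N takes every value 0, ..., N-1 exactly once. The accumulated sum is therefore the (i, j) entry of AB. For example, with N = 4 the setup output has A[1, 0] = 10 (originally A[1, 1]) and B[0, 1] = 11 (originally B[1, 1]), both with ell = 1.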
#!/usr/bin/env python3
import numpy as np
# Assume P = N * N processors on an N x N grid; processor (i, j) holds A[i, j], B[i, j] and C[i, j].
N = 4
A = np.arange(0, 2 * N * N, 2).reshape(N, N)
B = np.arange(1, 2 * N * N, 2).reshape(N, N)
C_gt = A.dot(B)
# Procedure 1 (initial skew): Shift row A[i, :] left by i
print('Original A:')
print(A)
print()
for i in range(N):
    A[i, :] = np.roll(A[i, :], -i)
print('A after setup:')
print(A)
print()
# Procedure 2 (initial skew): Shift column B[:, j] up by j
print('Original B:')
print(B)
print()
for j in range(N):
    B[:, j] = np.roll(B[:, j], -j)
print('B after setup:')
print(B)
print()
# Procedure 3: Zero-clear C
C = np.zeros((N, N), dtype=int)
# Procedure 4: Repeat N times (k = 0, 1, ..., N-1)
for k in range(N):
    # Each (i, j) is the local multiply-accumulate done by processor (i, j).
    for i in range(N):
        for j in range(N):
            C[i, j] += A[i, j] * B[i, j]
    # Procedure 4.1: Shift every row of A left by 1
    A = np.roll(A, -1, axis=1)
    # Procedure 4.2: Shift every column of B up by 1
    B = np.roll(B, -1, axis=0)
print('C (ground truth):')
print(C_gt)
print()
print('C with Cannon\'s algorithm:')
print(C)
print()
print('Matched' if np.array_equal(C, C_gt) else 'Not matched')  # exact check: all values are integers
Original A:
[[ 0 2 4 6]
[ 8 10 12 14]
[16 18 20 22]
[24 26 28 30]]
A after setup:
[[ 0 2 4 6]
[10 12 14 8]
[20 22 16 18]
[30 24 26 28]]
Original B:
[[ 1 3 5 7]
[ 9 11 13 15]
[17 19 21 23]
[25 27 29 31]]
B after setup:
[[ 1 11 21 31]
[ 9 19 29 7]
[17 27 5 15]
[25 3 13 23]]
C (ground truth):
[[ 236 260 284 308]
[ 652 740 828 916]
[1068 1220 1372 1524]
[1484 1700 1916 2132]]
C with Cannon's algorithm:
[[ 236 260 284 308]
[ 652 740 828 916]
[1068 1220 1372 1524]
[1484 1700 1916 2132]]
Matched
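The script assumes P = N * N, with one matrix element per processor. Cannon's algorithm is commonly stated for a sqrt(P) x sqrt(P) grid in which each processor owns a block rather than a single element. The following is a minimal simulation of that blocked variant, under these assumptions: a q x q grid, square N x N inputs with N divisible by q, and a function name `cannon_blocked` that is hypothetical. It uses the same skew and `np.roll` shifts as the script. The per-processor multiply becomes a b x b matrix product, batched over the grid with `@`.

```python
import numpy as np


def cannon_blocked(A: np.ndarray, B: np.ndarray, q: int) -> np.ndarray:
    """Simulate Cannon's algorithm on a hypothetical q x q processor grid.

    Sketch only: assumes square N x N inputs with N divisible by q, so each
    simulated processor (i, j) owns one b x b block of A, B and C (b = N // q).
    With q == N every block is 1 x 1, which is the P = N * N case of the gist.
    """
    N = A.shape[0]
    if A.shape != (N, N) or B.shape != (N, N) or N % q != 0:
        raise ValueError("expected square N x N inputs with N divisible by q")
    b = N // q

    # View each matrix as a q x q grid of b x b blocks: X[i, j] is block (i, j).
    Ab = A.reshape(q, b, q, b).swapaxes(1, 2)
    Bb = B.reshape(q, b, q, b).swapaxes(1, 2)

    # Initial skew (fancy indexing returns copies, so A and B are untouched):
    # block row i of A moves left by i, block column j of B moves up by j.
    rows = np.arange(q)[:, None]
    cols = np.arange(q)[None, :]
    skew = (rows + cols) % q
    As = Ab[rows, skew]  # As[i, j] = Ab[i, (i + j) % q]
    Bs = Bb[skew, cols]  # Bs[i, j] = Bb[(i + j) % q, j]

    C = np.zeros((q, q, b, b), dtype=np.result_type(A, B))
    for _ in range(q):
        # Every processor multiplies its local blocks (batched over the grid).
        C += As @ Bs
        # Shift all A blocks left by one and all B blocks up by one.
        As = np.roll(As, -1, axis=1)
        Bs = np.roll(Bs, -1, axis=0)

    # Reassemble the q x q grid of blocks into an ordinary N x N matrix.
    return C.swapaxes(1, 2).reshape(N, N)


if __name__ == "__main__":
    N = 4
    A = np.arange(0, 2 * N * N, 2).reshape(N, N)
    B = np.arange(1, 2 * N * N, 2).reshape(N, N)
    for q in (1, 2, 4):
        assert np.array_equal(cannon_blocked(A, B, q), A @ B)
    print(cannon_blocked(A, B, 2))
```

With q = 4 this reduces to the element-wise script above. With q = 1 it is a single ordinary matrix product. Replacing the per-element products with block products reduces the number of shift rounds to q, and each round moves b * b values per processor.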