Skip to content

Instantly share code, notes, and snippets.

@drscotthawley
Last active October 4, 2022 12:12
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save drscotthawley/fa7a7c37c25158f53249bcfa5cd5174b to your computer and use it in GitHub Desktop.
Save drscotthawley/fa7a7c37c25158f53249bcfa5cd5174b to your computer and use it in GitHub Desktop.
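Tries to multiply two arrays/matrices in every combination of operand order, transpose, and product type (elementwise `*` or matrix `@`); returns the combinations that "work", with the shapes of their outputs.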
def magic_mult(a, b):
    """
    Tries to multiply two arrays/matrices in a variety of ways.

    Every combination of operand order (``a, b`` or ``b, a``), optional
    transpose (``.T``) of each operand, and product type (elementwise ``*``
    or matrix ``@``) is attempted. Combinations that raise an ordinary
    exception (typically a shape mismatch) are skipped.

    Args:
        a, b: objects supporting ``.T``, ``*``, ``@`` and ``.shape``
            (e.g. NumPy arrays or PyTorch tensors).

    Returns:
        dict mapping each working expression, spelled as a string such as
        ``'a.T @ b'``, to the ``.shape`` of its result. Keys appear in the
        order tried: the 8 elementwise products first, then the 8 matrix
        products.

    Author: Scott H. Hawley, @drscotthawley
    """
    import itertools
    import operator

    operands = {'a': a, 'b': b}
    # (symbol used in the result key, function performing the product).
    # Calling the operator directly instead of eval()-ing a code string
    # avoids executing arbitrary strings and does not rely on local names.
    # The spaces around '@' are kept so keys read like 'a @ b'.
    products = (('*', operator.mul), (' @ ', operator.matmul))
    orders = (('a', 'b'), ('b', 'a'))
    transpose_flags = (False, True)

    working_combos = {}
    for symbol, product in products:
        for (left, right), t_left, t_right in itertools.product(
                orders, transpose_flags, transpose_flags):
            key = (left + ('.T' if t_left else '') + symbol
                   + right + ('.T' if t_right else ''))
            try:
                # .T is evaluated inside the try so operands that cannot be
                # transposed count as a non-working combo, not a crash.
                x = operands[left].T if t_left else operands[left]
                y = operands[right].T if t_right else operands[right]
                working_combos[key] = product(x, y).shape
            except Exception:
                # Narrower than a bare `except:` -- KeyboardInterrupt and
                # SystemExit still propagate instead of being swallowed.
                continue
    return working_combos
@drscotthawley
Copy link
Author

drscotthawley commented Oct 4, 2022

# Usage example (NumPy): a is 4x3 and b is 5x4, so none of the elementwise
# products broadcast; only the two matrix products with matching inner
# dimensions succeed (see the printed result below).
import numpy as np 

a = np.random.rand(4,3)
b = np.random.rand(5,4)
print(magic_mult(a,b))
{'a.T @ b.T': (3, 5), 'b @ a': (5, 3)}

.

# Usage example (NumPy): a length-3 vector against a 3x4 matrix. Broadcasting
# lets several elementwise products succeed as well as the matrix-vector
# products (see the printed result below).
a2 = np.random.rand(3)
b2 = np.random.rand(3,4)
print(magic_mult(a2,b2))
{'a*b.T': (4, 3), 'a.T*b.T': (4, 3), 'b.T*a': (4, 3), 'b.T*a.T': (4, 3), 'a @ b': (4,), 'a.T @ b': (4,), 'b.T @ a': (4,), 'b.T @ a.T': (4,)}

.

# Usage example (NumPy): 3-D arrays. For `@`, NumPy treats the leading axis as
# a stack of matrices, so only a @ b and b @ a line up (see result below).
a3 = np.random.rand(3,4,2)
b3 = np.random.rand(3,2,4)
print(magic_mult(a3,b3))
{'a @ b': (3, 4, 4), 'b @ a': (3, 2, 2)}

.

# Usage example (PyTorch): converts the NumPy arrays a and b defined above
# into tensors; the result mirrors the NumPy one, with torch.Size shapes.
# NOTE(review): magic_mult2 is not defined anywhere in this gist; presumably a
# variant of magic_mult defined elsewhere in the original notebook -- confirm.
import torch

at = torch.tensor(a)
bt = torch.tensor(b)
print(magic_mult2(at,bt))
{'a.T @ b.T': torch.Size([3, 5]), 'b @ a': torch.Size([5, 3])}

.

# Usage example (PyTorch): tensor versions of the NumPy a2 and b2 defined above.
# NOTE(review): magic_mult2 is not defined in this gist -- presumably defined
# elsewhere in the original notebook; confirm.
a2t = torch.tensor(a2)
b2t = torch.tensor(b2)
print(magic_mult2(a2t,b2t))
{'a*b.T': torch.Size([4, 3]), 'a.T*b.T': torch.Size([4, 3]), 'b.T*a': torch.Size([4, 3]), 'b.T*a.T': torch.Size([4, 3]), 'a @ b': torch.Size([4]), 'a.T @ b': torch.Size([4]), 'b.T @ a': torch.Size([4]), 'b.T @ a.T': torch.Size([4])}

.

# Usage example (PyTorch).
# NOTE(review): a4 and b4 (and magic_mult2) are not defined anywhere in this
# gist -- presumably created elsewhere in the original notebook; confirm.
# The UserWarning recorded below concerns `.T` on tensors whose dimension is
# not 2, which suggests at least one of a4/b4 is not 2-D.
a4t = torch.tensor(a4)
b4t = torch.tensor(b4)
print(magic_mult2(a4t,b4t))
{'a @ b.T': torch.Size([5, 4, 3])}
/usr/local/lib/python3.7/dist-packages/ipykernel_launcher.py:1: UserWarning: The use of `x.T` on tensors of dimension other than 2 to reverse their shape is deprecated and it will throw an error in a future release. Consider `x.mT` to transpose batches of matricesor `x.permute(*torch.arange(x.ndim - 1, -1, -1))` to reverse the dimensions of a tensor. (Triggered internally at  ../aten/src/ATen/native/TensorShape.cpp:2981.)
  """Entry point for launching an IPython kernel.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment