Skip to content

Instantly share code, notes, and snippets.

@ghamarian
Created June 6, 2022 07:08
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save ghamarian/bdd790c8c4880eb1752aa7626a971a60 to your computer and use it in GitHub Desktop.
Save ghamarian/bdd790c8c4880eb1752aa7626a971a60 to your computer and use it in GitHub Desktop.
import operator
from itertools import cycle
from numbers import Number
import numpy as np
def broadcast(a, b, op):
    """Apply the binary ``op`` element-wise to ``a`` and ``b`` with broadcasting.

    The operands are walked along their leading axis recursively:

    * if one operand has fewer dimensions than the other, it is reused
      unchanged for every slice of the higher-rank operand;
    * if both have the same rank and their leading lengths differ, the one
      whose leading length is 1 is repeated to match the other;
    * once both operands are scalars, ``op`` is applied to them directly.

    Args:
        a: Left operand. A plain number, a NumPy scalar, an ndarray, or
            anything ``np.asarray`` can convert (e.g. a nested list).
        b: Right operand, same kinds as ``a``.
        op: Callable taking two scalars and returning a scalar, e.g.
            ``operator.add``.

    Returns:
        ``op(a, b)`` when both operands are plain numbers; otherwise an
        ``np.ndarray`` assembled from the element-wise results.

    Raises:
        ValueError: If two axes being matched have different lengths and
            neither of them has length 1.
    """
    # Fast path: two plain numbers need no array machinery at all.
    if isinstance(a, Number) and isinstance(b, Number):
        return op(a, b)

    # Coerce so that Python scalars, 0-d arrays and NumPy scalars that are
    # not registered as numbers.Number (e.g. np.bool_) all expose
    # .ndim / .shape uniformly instead of crashing below.
    a = np.asarray(a)
    b = np.asarray(b)

    # Base case for the coerced operands: both are rank 0, so unwrap them
    # to scalars and apply the operator.
    if a.ndim == 0 and b.ndim == 0:
        return op(a[()], b[()])

    if a.ndim < b.ndim:
        # Lower-rank operand is paired, whole, with every slice of the other.
        lhs, rhs = cycle([a]), b
    elif a.ndim > b.ndim:
        lhs, rhs = a, cycle([b])
    elif a.shape[0] == b.shape[0]:
        lhs, rhs = a, b
    elif a.shape[0] == 1:
        # Stretch the length-1 leading axis by repeating its only slice.
        lhs, rhs = cycle(a), b
    elif b.shape[0] == 1:
        lhs, rhs = a, cycle(b)
    else:
        raise ValueError(
            f"Could not broadcast together with shapes {a.shape} {b.shape}")

    # At most one side is an endless cycle, so zip stops at the finite side.
    return np.array([broadcast(x, y, op) for x, y in zip(lhs, rhs)])
def test():
    """Check ``broadcast`` against NumPy's own broadcasting on sample inputs.

    Each case is an ``(a, b, op)`` triple. The result of ``broadcast`` is
    compared with ``op(a, b)`` through ``np.testing.assert_equal``, which
    raises ``AssertionError`` on a mismatch. Prints ``done`` once every
    case has passed.
    """
    add = operator.add
    cases = [
        # identical shapes: (2, 3) with (2, 3)
        (np.arange(6).reshape(2, 3), np.arange(6).reshape(2, 3), add),
        # leading axis of length 1 on the right: (2, 3) with (1, 3)
        (np.arange(6).reshape(2, 3), np.arange(3).reshape(1, 3), add),
        # trailing axis of length 1 on the right: (3, 3) with (3, 1)
        (np.arange(9).reshape(3, 3), np.arange(3).reshape(3, 1), add),
        # both operands stretched: (1, 3) with (3, 1)
        (np.arange(3).reshape(1, 3), np.arange(3).reshape(3, 1), add),
        # operands of different rank: (3,) with (3, 1)
        (np.arange(3), np.arange(3).reshape(3, 1), add),
    ]

    for lhs, rhs, fn in cases:
        actual = broadcast(lhs, rhs, fn)
        expected = fn(lhs, rhs)
        np.testing.assert_equal(actual, expected)

    print('done')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment