Created
June 6, 2022 07:08
-
-
Save ghamarian/bdd790c8c4880eb1752aa7626a971a60 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import operator | |
from itertools import cycle | |
from numbers import Number | |
import numpy as np | |
def broadcast(a, b, op):
    """Apply the binary callable ``op`` element-wise to ``a`` and ``b`` with broadcasting.

    The operands are aligned recursively along their leading axis:

    * two plain numbers are combined directly with ``op``;
    * operands of equal rank are walked in lockstep along axis 0, and a
      leading axis of length 1 is repeated to match the other operand;
    * when the ranks differ, the lower-rank operand is paired as a whole
      with every axis-0 slice of the higher-rank operand.

    Args:
        a: Left operand; a number or anything ``np.asarray`` accepts.
        b: Right operand; a number or anything ``np.asarray`` accepts.
        op: Callable taking two scalars, e.g. ``operator.add``.

    Returns:
        ``op(a, b)`` when both operands are scalars or 0-d arrays, otherwise
        an ``np.ndarray`` built from the per-slice results.

    Raises:
        ValueError: If two leading axes differ and neither has length 1.
    """
    # Fast path: this is also where the recursion bottoms out for numeric
    # array elements.
    if isinstance(a, Number) and isinstance(b, Number):
        return op(a, b)

    # Coerce so that Python scalars, lists and NumPy scalars all expose
    # ``ndim``/``shape``; without this a lone Python number has no ``ndim``.
    a = np.asarray(a)
    b = np.asarray(b)

    # 0-d operands have no axis 0 to walk (``shape[0]`` would raise), so they
    # form a second base case. This also covers element types that are not
    # registered as ``numbers.Number``, such as ``np.bool_``.
    if a.ndim == 0 and b.ndim == 0:
        return op(a[()], b[()])

    if a.ndim == b.ndim:
        if a.shape[0] != b.shape[0]:
            # ``cycle`` repeats the single slice indefinitely; ``zip`` below
            # stops when the other (finite) operand is exhausted.
            if a.shape[0] == 1:
                a = cycle(a)
            elif b.shape[0] == 1:
                b = cycle(b)
            else:
                raise ValueError(
                    f"Could not broadcast together with shapes {a.shape} {b.shape}")
    elif a.ndim < b.ndim:
        # Wrap the lower-rank operand so each iteration yields it whole.
        a = cycle([a])
    else:
        b = cycle([b])

    return np.array([broadcast(a_in, b_in, op) for a_in, b_in in zip(a, b)])
def test():
    """Compare ``broadcast`` with NumPy's own result on a fixed set of shape pairs.

    Each case is an ``(a, b, op)`` triple. ``broadcast(a, b, op)`` must equal
    ``op(a, b)`` as computed by NumPy; ``np.testing.assert_equal`` raises on
    the first mismatch. Prints ``done`` once every case has passed.
    """
    add = operator.add
    cases = [
        # identical shapes
        (np.arange(6).reshape(2, 3), np.arange(6).reshape(2, 3), add),
        # leading axis of length 1 on the right operand
        (np.arange(6).reshape(2, 3), np.arange(3).reshape(1, 3), add),
        # trailing axis of length 1 on the right operand
        (np.arange(9).reshape(3, 3), np.arange(3).reshape(3, 1), add),
        # row against column
        (np.arange(3).reshape(1, 3), np.arange(3).reshape(3, 1), add),
        # operands of different rank
        (np.arange(3), np.arange(3).reshape(3, 1), add),
    ]

    for lhs, rhs, fn in cases:
        actual = broadcast(lhs, rhs, fn)
        expected = fn(lhs, rhs)
        np.testing.assert_equal(actual, expected)

    print('done')
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment