Navigation Menu

Skip to content

Instantly share code, notes, and snippets.

@bzamecnik
Created April 4, 2019 10:35
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save bzamecnik/4b96bd6c6b3e2c65509c6681f4178417 to your computer and use it in GitHub Desktop.
Save bzamecnik/4b96bd6c6b3e2c65509c6681f4178417 to your computer and use it in GitHub Desktop.
Allows comparing the data produced by two DataFlows, e.g. for regression tests.
import numpy as np
from tensorpack import DataFlow
class CompareData(DataFlow):
    """
    Compares that two DataFlows generate equal data, raises ValueError if not.

    The two wrapped flows are iterated in lockstep and every pair of
    datapoints is checked with ``compare_trees``. A ValueError is also raised
    when one flow runs out of datapoints before the other, so a truncated
    flow cannot pass as "equal".
    """

    def __init__(self, a, b):
        """
        Args:
            a: first DataFlow.
            b: second DataFlow, expected to produce the same data as ``a``.
        """
        self.a = a
        self.b = b
        # Kept as an assert so the exception type seen by existing callers
        # (AssertionError) does not change.
        size_a, size_b = a.size(), b.size()
        assert size_a == size_b, \
            "Both DataFlows must have the same size! {} != {}".format(size_a, size_b)

    def reset_state(self):
        """Reset the state of both wrapped DataFlows."""
        for d in (self.a, self.b):
            d.reset_state()

    def size(self):
        """
        Return the minimum size among all.
        """
        return min(self.a.size(), self.b.size())

    def get_data(self):
        """
        Iterate both DataFlows in lockstep, comparing each pair of datapoints.

        Yields:
            The return value of ``compare_trees`` for each pair of datapoints.

        Raises:
            ValueError: if a pair of datapoints differs (raised by
                ``compare_trees``) or if one DataFlow is exhausted before
                the other.
        """
        it_a = self.a.get_data()
        it_b = self.b.get_data()
        # A private sentinel marks exhaustion, so a datapoint that happens to
        # be None/falsy is never mistaken for the end of a flow.
        exhausted = object()
        count = 0
        while True:
            data_a = next(it_a, exhausted)
            data_b = next(it_b, exhausted)
            if data_a is exhausted and data_b is exhausted:
                return
            if data_a is exhausted or data_b is exhausted:
                shorter = 'first' if data_a is exhausted else 'second'
                raise ValueError(
                    'DataFlows produced a different number of datapoints: '
                    'the %s one ended after %d datapoints' % (shorter, count))
            yield compare_trees(data_a, data_b)
            count += 1
def compare_trees(a, b, path=""):
    """
    Recursively check that two nested data structures hold equal data.

    Comparison rules:

    * lists and tuples are interchangeable sequences; they must have the same
      length and are compared item by item,
    * dicts must have the same keys and are compared value by value,
    * numpy arrays must have the same dtype and shape; numeric arrays are
      compared with ``np.allclose``, all other dtypes (strings, objects, ...)
      with ``np.array_equal``,
    * anything else is compared with ``!=``.

    Args:
        a: first tree.
        b: second tree.
        path: slash-separated location of the current node within the tree;
            it prefixes every message.

    Returns:
        None. A line ``"<path>: OK"`` is printed for every node that matched.

    Raises:
        ValueError: on the first difference found; the message starts with
            the path of the offending node.
    """
    types = [type(x) for x in [a, b]]
    # Exact types must match, except that list and tuple may be mixed.
    if types[0] != types[1] and set(types) != {list, tuple}:
        raise ValueError('%s: Non-equal types: %s' % (path, types))
    if isinstance(a, (list, tuple)):
        if len(a) != len(b):
            raise ValueError('%s: Non-equal sequence lengths: %d %d' % (path, len(a), len(b)))
        for i, (item_a, item_b) in enumerate(zip(a, b)):
            compare_trees(item_a, item_b, path='%s/%d' % (path, i))
    elif isinstance(a, dict):
        # Recurse into dicts: a plain `a != b` on dicts holding numpy arrays
        # raises an unrelated "truth value is ambiguous" error.
        if a.keys() != b.keys():
            raise ValueError('%s: Non-equal dict keys: %s %s' % (path, list(a), list(b)))
        for key in a:
            compare_trees(a[key], b[key], path='%s/%s' % (path, key))
    elif isinstance(a, np.ndarray):
        if a.dtype != b.dtype:
            raise ValueError('%s: Non-equal dtypes of numpy arrays: %s, %s' % (path, a.dtype, b.dtype))
        if a.shape != b.shape:
            raise ValueError('%s: Non-equal shapes of numpy arrays: %s, %s' % (path, a.shape, b.shape))
        if a.dtype.kind in 'biufc':
            # bool / int / uint / float / complex: tolerate rounding noise.
            if not np.allclose(a, b):
                raise ValueError('%s: Numpy array values are not close: %s %s' % (path, a, b))
        # np.allclose raises TypeError on non-numeric dtypes, so those are
        # compared exactly instead.
        elif not np.array_equal(a, b):
            raise ValueError('%s: Numpy array values are not equal: %s %s' % (path, a, b))
    else:
        if a != b:
            raise ValueError('%s: Non-equal values: %s %s' % (path, a, b))
    print('%s: OK' % path)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment