Last active
September 14, 2023 16:04
-
-
Save braingram/52b59a708843aab58f4ea44920825272 to your computer and use it in GitHub Desktop.
deepdiff asdf comparison
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import copy | |
import pprint | |
import asdf | |
import astropy.time | |
from astropy.units import Quantity | |
import deepdiff | |
from deepdiff.operator import BaseOperator | |
import gwcs | |
from gwcs.converters.tests.test_wcs import _assert_wcs_equal | |
import numpy as np | |
# Shared tolerances for the np.isclose / np.allclose calls made by
# NDArrayTypeOperator; equal_nan=True makes NaNs in matching positions
# compare as equal.
rtol, atol, equal_nan = 1e-05, 1e-08, True
class NDArrayTypeOperator(BaseOperator):
    """DeepDiff custom operator that compares arrays as whole values.

    Instead of letting deepdiff walk arrays element by element, this reports
    a single 'arrays_differ' result carrying any of:

    - 'shapes': [a.shape, b.shape] when the shapes differ
    - 'dtypes': [a.dtype, b.dtype] when the dtypes differ
    - 'units': [a.unit, b.unit] when both are Quantities with different units
    - 'abs_diff' / 'n_locations': when shape, dtype and unit all match but the
      values are not all close (module-level rtol/atol/equal_nan), the summed
      absolute difference (NaNs ignored) and the number of elements that are
      NOT close.
    """

    def give_up_diffing(self, level, diff_instance):
        a, b = level.t1, level.t2
        meta = {}
        if a.shape != b.shape:
            meta['shapes'] = [a.shape, b.shape]
        if a.dtype != b.dtype:
            meta['dtypes'] = [a.dtype, b.dtype]
        if isinstance(a, Quantity) and isinstance(b, Quantity) and a.unit != b.unit:
            meta['units'] = [a.unit, b.unit]
        # Only compare values when shapes, dtypes and units match; otherwise
        # the element-wise arithmetic below may fail or be meaningless.
        if not meta:
            close = np.isclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan)
            if not np.all(close):
                meta['abs_diff'] = np.nansum(np.abs(a - b))
                # Count the locations that differ, i.e. are NOT close.
                meta['n_locations'] = np.count_nonzero(~close)
        if meta:
            diff_instance.custom_report_result('arrays_differ', level, meta)
        # Returning True tells deepdiff not to descend into the arrays.
        return True


nd_compare = NDArrayTypeOperator(types=[asdf.tags.core.NDArrayType, np.ndarray])
class TimeOperator(BaseOperator):
    """DeepDiff custom operator for astropy Time objects.

    Reports a 'times_differ' result when the two times are not equal and
    always returns True so deepdiff does not descend into the Time objects.
    """

    def give_up_diffing(self, level, diff_instance):
        # `!=` on array-valued Time objects yields a boolean array whose truth
        # value is ambiguous in an `if`; reduce it with np.any (a no-op for
        # scalar Times).
        if np.any(level.t1 != level.t2):
            # TODO include time difference
            diff_instance.custom_report_result('times_differ', level, {
                "extra": "information",
            })
        return True


time_compare = TimeOperator(types=[astropy.time.Time])
def wcs_equal(a, b):
    """Return True if gwcs's `_assert_wcs_equal` accepts `a` and `b`.

    Returns False when `_assert_wcs_equal` raises AssertionError; any other
    exception propagates to the caller.
    """
    try:
        # can this be made part of the public gwcs api?
        _assert_wcs_equal(a, b)
    except AssertionError:
        # TODO return information about difference
        return False
    else:
        return True
class WCSOperator(BaseOperator):
    """DeepDiff custom operator for gwcs.WCS objects.

    Delegates the equality decision to `wcs_equal` and emits a 'wcs_differ'
    result when it reports a mismatch. Always returns True, so deepdiff
    does not recurse into the WCS internals.
    """

    def give_up_diffing(self, level, diff_instance):
        same = wcs_equal(level.t1, level.t2)
        if not same:
            details = {"extra": "information"}
            diff_instance.custom_report_result('wcs_differ', level, details)
        return True


wcs_compare = WCSOperator(types=[gwcs.WCS])
def asdf_diff(af0, af1):
    """Diff the trees of two AsdfFile objects with DeepDiff.

    Arrays, astropy Times and gwcs WCS objects are compared by the custom
    operators `nd_compare`, `time_compare` and `wcs_compare`. NaN
    inequality is ignored. Any 'type_changes' entry is dropped from the
    result.

    Returns the DeepDiff result.
    """
    diff = deepdiff.DeepDiff(
        af0.tree,
        af1.tree,
        ignore_nan_inequality=True,
        custom_operators=[nd_compare, time_compare, wcs_compare],
    )
    # the conversion between NDArrayType and ndarray adds a bunch
    # of type changes, ignore these for now. Ideally we could find
    # a way to remove just the NDArrayType ones.
    # DeepDiff omits empty report keys, so 'type_changes' may be absent;
    # pop with a default instead of indexing to avoid a KeyError.
    diff.pop('type_changes', None)
    return diff
if __name__ == '__main__':
    import sys

    # Optional first command-line argument overrides the default example file.
    fn = (
        sys.argv[1]
        if len(sys.argv) > 1
        else 'r0000101001001001001_01101_0001_WFI01_cal_repoint.asdf'
    )
    # Pass `mode` by keyword: asdf.open's second positional parameter is `uri`.
    with asdf.open(fn, mode='r') as source:
        modified = asdf.AsdfFile(copy.deepcopy(source.tree))
        # modify: add one unit to index 0 (along the first axis) of the data
        u = modified['roman']['data'].unit
        modified['roman']['data'][0] += 1 * u
        diff = asdf_diff(source, modified)
        pprint.pprint(diff)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Running the script above produces the following output: