Skip to content

Instantly share code, notes, and snippets.

@macks22
Created March 30, 2020 21:42
Show Gist options
  • Save macks22/0c6a3c89582cce2f358cf69035d0ef90 to your computer and use it in GitHub Desktop.
Save macks22/0c6a3c89582cce2f358cf69035d0ef90 to your computer and use it in GitHub Desktop.
Example of using custom JSONEncoder and object_hook to add numpy array support to built-in JSON serialization
import json
import numpy as np
def is_diagonal(matrix):
    """Return True if *matrix* is a square 2-D array whose off-diagonal entries are all zero.

    Anything that is not a square 2-D array (0-d scalars, vectors, non-square
    matrices, higher-rank arrays) is reported as not diagonal, because
    ``decode_object`` rebuilds a ``'diag'`` payload as a square matrix from
    its diagonal alone.

    Args:
        matrix: An ndarray or anything ``np.asarray`` accepts.

    Returns:
        True if only the main diagonal may hold non-zero values, else False.
    """
    matrix = np.asarray(matrix)
    if matrix.ndim != 2 or matrix.shape[0] != matrix.shape[1]:
        return False
    # Inspect the off-diagonal entries through a boolean mask instead of
    # subtracting the diagonal: subtraction is unsupported for bool arrays
    # and turns inf/nan diagonal entries into nan.
    off_diagonal = ~np.eye(matrix.shape[0], dtype=bool)
    return np.count_nonzero(matrix[off_diagonal]) == 0
def is_identity(matrix):
    """Return True if *matrix* is diagonal and every diagonal entry equals 1."""
    diagonal_ok = is_diagonal(matrix)
    if not diagonal_ok:
        # Same short-circuit as ``and``: hand back is_diagonal's own result.
        return diagonal_ok
    ones_on_diagonal = np.diag(matrix) == 1
    return np.all(ones_on_diagonal)
def is_zeros(array):
    """Return True if every element of *array* equals zero.

    Accepts an ndarray or anything ``np.asarray`` accepts. An empty array
    counts as all zeros; NaN never equals zero, so arrays containing NaN do not.
    """
    # asarray makes plain lists compare element-wise; without it
    # ``[0, 0] == 0`` is a single Python False.
    return np.all(np.asarray(array) == 0)
class NumpyEncoder(json.JSONEncoder):
    """JSON encoder that additionally understands NumPy scalars and arrays.

    NumPy booleans and integers are emitted as JSON integers, and NumPy
    floating-point scalars as JSON numbers. An ndarray is emitted as a dict
    with ``dtype``, ``constructor``, ``values`` and ``shape`` keys, which
    ``decode_object`` turns back into an array. All-zero arrays, square
    identity matrices and square diagonal matrices get a compact encoding;
    everything else is stored in full via ``tolist()``.
    """

    def default(self, obj):
        """Return a JSON-serializable stand-in for *obj*.

        Raises:
            TypeError: (from ``json.JSONEncoder.default``) if *obj* is not a
                supported type.
        """
        if isinstance(obj, np.bool_):
            return int(obj)
        if isinstance(obj, np.integer):  # any int type
            return int(obj)
        if isinstance(obj, np.floating):  # any float type
            return float(obj)
        if isinstance(obj, np.ndarray):
            values = None
            # The identity/diagonal shortcuts are only valid for square 2-D
            # arrays: the decoder rebuilds them from a single dimension or
            # from the diagonal alone, and np.diagonal raises on 0-d/1-d input.
            is_square = obj.ndim == 2 and obj.shape[0] == obj.shape[1]
            if is_zeros(obj):
                constructor = 'zeros'
            elif is_square and is_identity(obj):
                constructor = 'identity'
            elif is_square and is_diagonal(obj):
                constructor = 'diag'
                values = np.diagonal(obj).tolist()
            else:  # need to express as full array
                constructor = 'array'
                values = obj.tolist()
            return {
                'dtype': f'{obj.dtype}',
                'constructor': constructor,
                'values': values,
                'shape': obj.shape,
            }
        return super().default(obj)
def decode_object(decoded_dict):
    """``object_hook`` for ``json.loads`` that rebuilds arrays written by ``NumpyEncoder``.

    Dicts carrying both a ``'dtype'`` and a ``'constructor'`` key are turned
    back into ndarrays; every other dict is returned unchanged.

    Args:
        decoded_dict: A dict produced by the JSON decoder.

    Returns:
        An ndarray for array payloads, otherwise *decoded_dict* itself.

    Raises:
        ValueError: if ``'constructor'`` names a constructor this decoder
            does not support.
    """
    if 'dtype' not in decoded_dict or 'constructor' not in decoded_dict:
        # An ordinary dict, possibly one that merely has a 'dtype' key.
        return decoded_dict
    dtype = decoded_dict['dtype']
    constructor_name = decoded_dict['constructor']
    if constructor_name == 'array':
        return np.array(decoded_dict['values'], dtype=dtype)
    if constructor_name == 'diag':
        # Build the diagonal directly in the target dtype, then expand it.
        return np.diag(np.asarray(decoded_dict['values'], dtype=dtype))
    if constructor_name == 'identity':
        return np.identity(decoded_dict['shape'][0], dtype=dtype)
    if constructor_name in ('zeros', 'ones', 'empty'):
        # Whitelist rather than getattr(np, <any name>): JSON input must not
        # be able to invoke arbitrary NumPy functions (e.g. np.fromfile).
        fill = getattr(np, constructor_name)
        return fill(decoded_dict['shape'], dtype=dtype)
    raise ValueError(f'unsupported array constructor: {constructor_name!r}')
def to_json(obj):
    """Serialize *obj* to a JSON string indented by four spaces.

    NumPy scalars and arrays inside *obj* are handled by ``NumpyEncoder``.
    """
    encoder = NumpyEncoder(indent=4)
    return encoder.encode(obj)
def from_json(json_blob):
    """Parse *json_blob*, rebuilding NumPy arrays encoded by ``NumpyEncoder``.

    Args:
        json_blob: A JSON document as ``str``, ``bytes`` or ``bytearray``.

    Returns:
        The decoded Python object, with array payloads converted by
        ``decode_object``.

    Raises:
        json.JSONDecodeError: if *json_blob* is not valid JSON.
        ValueError: if an array payload names an unsupported constructor.
    """
    # The document itself must be passed as the first positional argument.
    return json.loads(json_blob, object_hook=decode_object)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment