Created
March 30, 2020 21:42
-
-
Save macks22/0c6a3c89582cce2f358cf69035d0ef90 to your computer and use it in GitHub Desktop.
Example of using a custom JSONEncoder and an object_hook to add NumPy array support to Python's built-in JSON serialization
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import json | |
import numpy as np | |
def is_diagonal(matrix):
    """Return True if *matrix* is a square 2-D array with only zeros off the main diagonal.

    Inputs that are not square 2-D arrays (0-d, 1-D, N-D, or non-square 2-D)
    return False instead of raising. Such inputs cannot be rebuilt with
    ``np.diag`` anyway, since that always produces a square matrix.

    Args:
        matrix: Array-like value; it is converted with ``np.asarray``.

    Returns:
        ``bool``: whether every off-diagonal element is zero.
    """
    matrix = np.asarray(matrix)
    if matrix.ndim != 2 or matrix.shape[0] != matrix.shape[1]:
        return False
    # Inspect only the off-diagonal entries. Unlike subtracting the diagonal,
    # this works for bool arrays (no boolean subtract in NumPy) and is not
    # affected by inf/NaN values on the diagonal.
    off_diagonal = matrix[~np.eye(matrix.shape[0], dtype=bool)]
    return np.count_nonzero(off_diagonal) == 0
def is_identity(matrix):
    """Return whether *matrix* is diagonal with every main-diagonal entry equal to 1.

    The diagonal test is delegated to ``is_diagonal``. Only when that passes is
    the main diagonal (extracted with ``np.diag``) compared against 1.
    """
    if not is_diagonal(matrix):
        return False
    main_diagonal = np.diag(matrix)
    return np.all(main_diagonal == 1)
def is_zeros(array):
    """Return whether every element of *array* compares equal to zero.

    An empty array yields True, because ``np.all`` of an empty selection is True.
    """
    zero_mask = array == 0
    return np.all(zero_mask)
class NumpyEncoder(json.JSONEncoder):
    """``json.JSONEncoder`` that also serializes NumPy scalars and arrays.

    NumPy bool and integer scalars become ``int``, and floating scalars become
    ``float``. An ``np.ndarray`` becomes a dict describing how to rebuild it::

        {'dtype': str, 'constructor': str, 'values': list | None, 'shape': tuple}

    ``constructor`` is one of the following:

    * ``'zeros'``: ``values`` is None.
    * ``'identity'``: ``values`` is None.
    * ``'diag'``: ``values`` holds the main diagonal.
    * ``'array'``: ``values`` holds the full nested list.

    ``json`` writes the ``shape`` tuple as a JSON list. Any other type is
    handed to the base class, which raises ``TypeError``.
    """

    def default(self, obj):
        if isinstance(obj, np.bool_):
            # NOTE(review): bool *arrays* go through ndarray.tolist() and are
            # written as JSON true/false, while bool scalars become 0/1 here.
            return int(obj)
        if isinstance(obj, np.integer):  # any int type
            return int(obj)
        if isinstance(obj, np.floating):  # any float type
            return float(obj)
        if isinstance(obj, np.ndarray):
            return self._encode_array(obj)
        return super().default(obj)

    @staticmethod
    def _encode_array(arr):
        """Build the rebuild-description dict for *arr* (see the class docstring)."""
        values = None
        # The identity/diagonal probes only make sense for square 2-D arrays:
        # is_diagonal can raise on other shapes (np.diagonal rejects 0-d/1-D
        # input, and non-square shapes fail to broadcast). The decoder also
        # rebuilds 'identity' and 'diag' as square matrices.
        is_square = arr.ndim == 2 and arr.shape[0] == arr.shape[1]
        if is_zeros(arr):
            constructor = 'zeros'
        elif is_square and is_identity(arr):
            constructor = 'identity'
        elif is_square and is_diagonal(arr):
            constructor = 'diag'
            values = np.diagonal(arr).tolist()
        else:  # need to express as full array
            constructor = 'array'
            values = arr.tolist()
        return {
            'dtype': f'{arr.dtype}',
            'constructor': constructor,
            'values': values,
            'shape': arr.shape,
        }
def decode_object(decoded_dict):
    """``object_hook`` for ``json.loads`` that rebuilds arrays written by ``NumpyEncoder``.

    A dict that has both a ``'dtype'`` and a ``'constructor'`` key is treated as
    an encoded array. Every other dict is returned unchanged.

    Args:
        decoded_dict: A dict produced by the JSON decoder.

    Returns:
        An ``np.ndarray`` for encoded arrays, otherwise ``decoded_dict`` itself.

    Raises:
        ValueError: If ``'constructor'`` names an unsupported constructor.
    """
    # Requiring 'constructor' as well keeps ordinary dicts that merely contain
    # a 'dtype' key from being mistaken for arrays (and failing with KeyError).
    if 'dtype' not in decoded_dict or 'constructor' not in decoded_dict:
        return decoded_dict
    dtype = decoded_dict['dtype']
    constructor_name = decoded_dict['constructor']
    if constructor_name == 'array':
        return np.array(decoded_dict['values'], dtype=dtype)
    if constructor_name == 'diag':
        # Convert the diagonal to the target dtype before building the matrix,
        # instead of going through NumPy's inferred dtype and casting afterwards.
        return np.diag(np.asarray(decoded_dict['values'], dtype=dtype))
    if constructor_name == 'identity':
        return np.identity(decoded_dict['shape'][0], dtype=dtype)
    # Shape-only constructors come from an explicit allowlist rather than
    # getattr(np, name), so JSON input cannot choose an arbitrary numpy callable.
    shape_constructors = {'zeros': np.zeros, 'ones': np.ones, 'empty': np.empty}
    if constructor_name not in shape_constructors:
        raise ValueError(f'unsupported array constructor: {constructor_name!r}')
    return shape_constructors[constructor_name](decoded_dict['shape'], dtype=dtype)
def to_json(obj):
    """Serialize *obj* to a JSON string indented by 4 spaces.

    NumPy scalars and arrays are handled by ``NumpyEncoder``. Values it cannot
    handle make ``json.dumps`` raise ``TypeError``.
    """
    encoded = json.dumps(obj, cls=NumpyEncoder, indent=4)
    return encoded
def from_json(json_blob):
    """Parse *json_blob* and restore any NumPy arrays encoded by ``NumpyEncoder``.

    Args:
        json_blob: JSON document as ``str``, ``bytes`` or ``bytearray``
            (for example, the output of ``to_json``).

    Returns:
        The decoded Python object. Every decoded dict is passed through
        ``decode_object``, which turns encoded-array dicts into ``np.ndarray``.
    """
    # The document must be passed positionally; object_hook alone is not enough.
    return json.loads(json_blob, object_hook=decode_object)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment