Created
August 27, 2019 21:37
-
-
Save andy-d/b7878d0044a4242c0498ed6d67fd50fe to your computer and use it in GitHub Desktop.
JSON-serializable base class for Python: serialization and deserialization with serial-version checks for backward compatibility.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import base64 | |
import datetime | |
import json | |
import logging | |
import os | |
import re | |
import socket | |
import sys | |
from com.easter.egg.some_abstract_logging_class import SomeAbstractLoggingClass | |
class JsonSerializable(SomeAbstractLoggingClass):
    """
    Base class that round-trips an instance's ``__dict__`` through JSON.

    ``serialize`` dumps a copy of ``__dict__`` together with bookkeeping metadata
    (serial version, timestamp, class name, hostname, interpreter).
    ``deserialize`` parses the string, runs the version checks / migration hooks,
    strips the metadata and rebuilds an instance.  Subclasses customise the
    behaviour through the hook methods below and by bumping ``SERIAL_VERSION_ID``.
    """

    SERIAL_VERSION_ID = 1  # bump this to signal work needs to be done on deserialize
    SERIAL_VERSION_JSON_KEY = '__serial_version_id'
    DEPRECATED_SERIAL_VERSION_JSON_KEYS = ['__serial_version_id__']
    SERIALIZATION_TIMESTAMP_TAG = '__serialization_timestamp'
    SERIALIZED_CLASS_NAME_TAG = '__serialized_class_name'
    HOSTNAME_WHERE_SERIALIZED_TAG = '__hostname_where_serialized'
    # NOTE: the attribute name is misspelled ("SERIALIALIZE"); it is kept as-is
    # because it is part of the class's public surface.
    INTERPRETER_USED_TO_SERIALIALIZE = '__interpreter_used_to_serialize'
    LOG_ATTR = 'log'
    ENCODE_DURING_TRANSPORT = False  # when True, serialize() returns base64 text

    # Attributes ignored by compare_with() / __eq__.
    OMIT_FROM_COMPARISONS = [
        SERIAL_VERSION_JSON_KEY,
        SERIALIZATION_TIMESTAMP_TAG,
        SERIALIZED_CLASS_NAME_TAG,
        LOG_ATTR,
        HOSTNAME_WHERE_SERIALIZED_TAG,
        INTERPRETER_USED_TO_SERIALIALIZE
    ]
    OMIT_FROM_COMPARISONS += DEPRECATED_SERIAL_VERSION_JSON_KEYS

    def embed_serial_metadata(self, dct):
        """
        Add the bookkeeping entries (version, timestamp, class, host, interpreter) to ``dct``.

        Args:
            dct (dict): outgoing state dictionary, modified in place
        """
        dct[JsonSerializable.SERIAL_VERSION_JSON_KEY] = self.__class__.get_current_serial_version_id()
        dct[JsonSerializable.SERIALIZATION_TIMESTAMP_TAG] = datetime.datetime.now().isoformat()
        dct[JsonSerializable.SERIALIZED_CLASS_NAME_TAG] = self.__class__.__name__
        dct[JsonSerializable.HOSTNAME_WHERE_SERIALIZED_TAG] = socket.gethostname()
        dct[JsonSerializable.INTERPRETER_USED_TO_SERIALIALIZE] = sys.executable

    @classmethod
    def attributes_to_set_to_none_when_serializing(cls):
        """
        Override to name attributes whose values are replaced by None on serialize.

        Returns:
            (list[str])
        """
        return []

    def _top_secret_pre_serialization_hook_DO_NOT_OVERRIDE(self, outgoing_dict):
        """
        Null out the logger and every attribute named by
        ``attributes_to_set_to_none_when_serializing`` in ``outgoing_dict``.

        Args:
            outgoing_dict (dict): copy of the instance state, modified in place
        """
        outgoing_dict[JsonSerializable.LOG_ATTR] = None
        for attr in self.attributes_to_set_to_none_when_serializing():
            assert attr in outgoing_dict, "attr {} not valid, serial migration may be necessary".format(attr)
            outgoing_dict[attr] = None

    def _top_secret_post_deserialization_hook_DO_NOT_OVERRIDE(self):
        """Re-attach a logger (via ``get_logger``) when the instance has no truthy ``log``."""
        if getattr(self, 'log', None):
            return
        # get_logger is not defined in this class; presumably inherited from the base class.
        self.log = self.get_logger()

    def write_to(self, file_path):
        """
        Serialize this instance and write the result to ``file_path`` in text mode.

        Args:
            file_path (str): destination path, overwritten if it exists
        """
        serialized = self.serialize()
        with open(file_path, 'w') as _o:
            _o.write(serialized)

    def serialize(self):
        """
        default implementation, override as necessary but be sure to include the version id step (see below)

        Returns:
            (str): indented, key-sorted JSON; base64 text when ENCODE_DURING_TRANSPORT is set

        Raises:
            JsonSerializationError: on any failure; the original exception is chained
        """
        try:
            dct = self.__dict__.copy()
            self.pre_serialization_hook(dct)
            self._top_secret_pre_serialization_hook_DO_NOT_OVERRIDE(dct)
            self.embed_serial_metadata(dct)
            blacklist = set(self.attribute_serialization_blacklist() or [])
            whitelist = set(self.attribute_serialization_whitelist() or [])
            if whitelist:
                if blacklist:
                    raise JsonSerializationError("Blacklist and whitelist cannot both be present")
                # The whitelist also applies to the metadata keys embedded above.
                dct = {k: v for k, v in dct.items() if k in whitelist}
            else:
                dct = {k: v for k, v in dct.items() if k not in blacklist}
            as_string = json.dumps(dct, indent=2, sort_keys=True, cls=JsonSerializableEncoder)
            if self.__class__.ENCODE_DURING_TRANSPORT:
                as_string = str(base64.b64encode(as_string.encode()), encoding="utf-8")
            return as_string
        except Exception as e:
            first_arg = e.args[0] if e.args else None
            if isinstance(first_arg, str):
                # The json module quoted the type name up to Python 3.6 and dropped
                # the quotes from 3.7 on; accept both spellings.
                m = re.match(r"Object of type '?(.*?)'? is not JSON serializable", first_arg)
                if m:
                    object_type = m.group(1)
                    raise JsonSerializationError("{} is composite and must override serialize() to handle type {}".format(
                        self.__class__.__name__, object_type
                    )) from e
            raise JsonSerializationError("Failed to serialize {}".format(self.__class__.__name__)) from e

    # deserialization and serialization
    # override these as necessary
    @classmethod
    def deserialize(cls, as_string):
        """
        Default implementation, override as necessary
        For example, if the class takes args this must be overridden

        Args:
            as_string (str): serialized object

        Returns:
            (JsonSerializable): the rebuilt instance of ``cls``

        Raises:
            JsonDeserializationError: on any failure; the original exception is chained
        """
        try:
            if cls.ENCODE_DURING_TRANSPORT:
                # Accept bytes as well as str, like build_migrated_dictionary does.
                if isinstance(as_string, str):
                    as_string = as_string.encode()
                as_string = base64.decodebytes(as_string)
            dct = cls.build_migrated_dictionary(as_string)
            instance = cls.instantiate_from_state_dict(dct)
            instance._top_secret_post_deserialization_hook_DO_NOT_OVERRIDE()
            instance.post_deserialization_hook()
        except Exception as e:
            if isinstance(e, TypeError):
                first_arg = e.args[0] if e.args else None
                if isinstance(first_arg, str) and "required positional argument" in first_arg:
                    message = "\n The __init__ method for {} has required arguments. Must override instantiate_from_state_dict to support this".format(cls.__name__)
                    raise JsonDeserializationError(message) from e
            raise JsonDeserializationError(*e.args) from e
        return instance

    @classmethod
    def instantiate_from_state_dict(cls, state_dict):
        """
        Build an instance with a no-argument constructor call and adopt ``state_dict`` as its ``__dict__``.

        Args:
            state_dict (dict): migrated state, metadata already removed

        Returns:
            (JsonSerializable)
        """
        i = cls()
        i.__dict__ = state_dict
        return i

    @classmethod
    def build_migrated_dictionary(cls, as_string):
        """
        Parse ``as_string``, run the pre-deserialization and migration hooks, and drop the metadata keys.

        Args:
            as_string (str or bytes): JSON text

        Returns:
            (dict): state dictionary ready for ``instantiate_from_state_dict``
        """
        #### python 3.5 ####
        if isinstance(as_string, bytes):
            as_string = as_string.decode()
        #### python 3.5 ####
        dct = json.loads(as_string)
        cls.pre_deserialization_hook(dct=dct)
        assert isinstance(dct, dict), "serialized payload must decode to a JSON object"
        cls.top_secret_pre_serial_migration_hook_DO_NOT_OVERRIDE(dct)
        cls.serial_version_migration(dct)
        for metadata_key in (
            JsonSerializable.SERIALIZATION_TIMESTAMP_TAG,
            JsonSerializable.SERIAL_VERSION_JSON_KEY,
            JsonSerializable.SERIALIZED_CLASS_NAME_TAG,
            JsonSerializable.HOSTNAME_WHERE_SERIALIZED_TAG,
            JsonSerializable.INTERPRETER_USED_TO_SERIALIALIZE,
        ):
            dct.pop(metadata_key, None)
        return dct

    @classmethod
    def top_secret_pre_serial_migration_hook_DO_NOT_OVERRIDE(cls, dct):
        """
        Reject payloads whose serial version is newer than this class's version.

        Args:
            dct (dict): parsed payload; a missing version key is treated as version 1

        Raises:
            JsonDeserializationError: when the incoming version exceeds the current one
        """
        current_version = cls.get_current_serial_version_id()
        incoming_version = dct.get(JsonSerializable.SERIAL_VERSION_JSON_KEY, 1)
        if incoming_version <= current_version:
            return
        deserialized_with = "{}:{}".format(socket.gethostname(), sys.executable)
        serialized_with = "n/a"
        if JsonSerializable.HOSTNAME_WHERE_SERIALIZED_TAG in dct:
            hostname = dct[JsonSerializable.HOSTNAME_WHERE_SERIALIZED_TAG]
            # Tolerate a missing interpreter tag so the version-mismatch error
            # below is what the caller sees, not an unrelated assertion failure.
            interpreter = dct.get(JsonSerializable.INTERPRETER_USED_TO_SERIALIALIZE, "n/a")
            serialized_with = "{}:{}".format(hostname, interpreter)
        msg = (
            "server/client {} version mismatch error.\n"
            "\tInterpreter {}\n"
            "\tcannot deserialize {} that was serialized by\n"
            "\tinterpreter {}\n"
            "\tEither the former requires an upgrade or the latter requires a downgrade."
        ).format(
            cls.__name__,
            deserialized_with,
            cls.__name__,
            serialized_with
        )
        raise JsonDeserializationError(msg)

    def serialize_to(self, file_path):
        """
        Serialize this instance and write the encoded bytes to ``file_path`` in binary mode.

        Args:
            file_path (str): destination path, overwritten if it exists
        """
        with open(file_path, 'wb') as out_file:
            out_file.write(self.serialize().encode())

    def as_dict(self):
        """
        Returns:
            (dict): the parsed output of ``serialize()``, so any specialized logic is applied
        """
        serialized = self.serialize()
        if self.__class__.ENCODE_DURING_TRANSPORT:
            # serialize() returned base64 text; undo that before parsing the JSON.
            serialized = base64.b64decode(serialized).decode()
        return json.loads(serialized)  # safe means of getting a dict that uses any specialized logic

    @staticmethod
    def attribute_serialization_blacklist():
        """
        Override to name attributes that must not be serialized.

        Returns:
            (list[str] or None)
        """
        return None

    @staticmethod
    def attribute_serialization_whitelist():
        """
        Override to name the only attributes that may be serialized.

        Returns:
            (list[str] or None)
        """
        return None

    @classmethod
    def serial_version_migration(cls, dct):
        """
        when bumping version this must be overridden to accommodate

        Args:
            dct (dict): your class dict as deserialized, may be out of date

        Raises:
            NotImplementedError: when the payload's version differs from the current one
        """
        current_version = cls.get_current_serial_version_id()
        incoming_version = dct.get(JsonSerializable.SERIAL_VERSION_JSON_KEY, current_version)
        if incoming_version == current_version:
            return
        raise NotImplementedError("{} has new serial version, must override {}".format(cls.__name__, JsonSerializable.serial_version_migration.__name__))

    @classmethod
    def get_current_serial_version_id(cls):
        """
        Returns:
            the class's ``SERIAL_VERSION_ID``
        """
        return cls.SERIAL_VERSION_ID

    def pre_serialization_hook(self, dct):
        """
        Override to adjust the outgoing state before it is serialized.

        Args:
            dct (dict): copy of the instance ``__dict__``
        """
        pass

    @classmethod
    def pre_deserialization_hook(cls, dct):
        """
        Args:
            dct: the parsed JSON payload, before validation and migration
        """
        pass  # override to update the dct before deserialize

    def post_deserialization_hook(self):
        """Override to finish setting up an instance after it has been deserialized."""
        pass

    def get_additional_attributes_to_exclude_from_equality_comparison(self):
        """
        Override to exclude more attributes from ``compare_with`` / ``__eq__``.

        Returns:
            (list[str])
        """
        return []

    @classmethod
    def deserialize_from_file_path(cls, local_file_path):
        """
        Read ``local_file_path`` and deserialize its contents.

        Args:
            local_file_path (str): path to a file written by ``write_to`` / ``serialize_to``

        Raises:
            FileNotFoundError: when the path is not an existing file
        """
        assert isinstance(local_file_path, str)
        if not os.path.isfile(local_file_path):
            raise FileNotFoundError(local_file_path + " could not be found")
        with open(local_file_path) as _i:
            return cls.deserialize(_i.read())

    def compare_with(self, other):
        """
        Compare the instance attributes of ``self`` and ``other``.

        Attributes listed in ``OMIT_FROM_COMPARISONS`` or returned by
        ``get_additional_attributes_to_exclude_from_equality_comparison`` are ignored.

        Args:
            other: object to compare against; must be of exactly the same type

        Returns:
            (list[str]): one message per difference, empty when nothing differs
        """
        diffs = []
        if type(self) is not type(other):
            diffs.append("Type Mismatch, {} vs {}".format(type(self), type(other)))
            return diffs
        excluded_attrs = set(
            self.OMIT_FROM_COMPARISONS + self.get_additional_attributes_to_exclude_from_equality_comparison()
        )
        ldct = {k: v for k, v in self.__dict__.items() if k not in excluded_attrs}
        rdct = {k: v for k, v in other.__dict__.items() if k not in excluded_attrs}
        if ldct.keys() != rdct.keys():
            a, b = sorted(ldct), sorted(rdct)
            diffs.append("Atribute List Mismatch\n\tSelf:{}\n\tOther:{}".format(a, b))
        # Sorted so that the reported differences come out in a stable order.
        for k in sorted(set(ldct) & set(rdct)):
            left, right = ldct[k], rdct[k]
            try:
                eq = left == right
            except Exception:
                self.log.exception("Failed to compare values for key {} due to".format(k))
                raise
            if not eq:
                diffs.append("Atribute Mismatch\n\tKey:{}\n\tSelf:{}\n\tOther:{}".format(k, left, right))
            elif self.log.isEnabledFor(logging.DEBUG):
                self.log.debug("Attribute {} matches for instances of {}".format(k, self.__class__.__name__))
        return diffs

    def __eq__(self, other):
        """Return True when ``compare_with`` finds no differences; otherwise log each one as a warning."""
        diffs = self.compare_with(other)
        if diffs:
            self.log.warning("Equality check failure for {}".format(self))
            for d in diffs:
                self.log.warning(d)
            return False
        return True

    # Defining __eq__ without __hash__ already makes instances unhashable;
    # this spells that out.
    __hash__ = None
class JsonSerializationError(Exception):
    """Raised when an object cannot be turned into its JSON string form."""
class JsonDeserializationError(NotImplementedError):
    """Raised when a serialized string cannot be turned back into an object."""
class SerialVersionError(NotImplementedError):
    """Error type for serial-version problems; a NotImplementedError subclass."""
class JsonSerializableEncoder(json.JSONEncoder):
    """JSON encoder that also handles ``datetime.datetime`` and ``JsonSerializable`` values."""

    def default(self, o):
        """
        Convert a value the stock encoder cannot handle.

        Args:
            o: the object json could not encode natively

        Returns:
            (str): ISO-8601 text for datetimes, or the ``serialize()`` output for
            JsonSerializable objects (so a nested object is embedded as a JSON string)

        Raises:
            TypeError: for any other type, via the base implementation
        """
        if isinstance(o, datetime.datetime):
            return o.isoformat()
        if isinstance(o, JsonSerializable):
            return o.serialize()
        # Delegate to the base class, which raises the standard
        # "Object of type ... is not JSON serializable" TypeError.
        return super().default(o)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment