Skip to content

Instantly share code, notes, and snippets.

Embed
What would you like to do?
JSON-serializable base class for Python objects: serialization, deserialization, and serial-version migration for backward compatibility.
import base64
import datetime
import json
import logging
import os
import re
import socket
import sys
from com.easter.egg.some_abstract_logging_class import SomeAbstractLoggingClass
class JsonSerializable(SomeAbstractLoggingClass):
    """Base class that round-trips an object's ``__dict__`` through JSON.

    Serialization copies ``__dict__``, runs the pre-serialization hooks, embeds
    metadata (serial version id, timestamp, class name, hostname, interpreter
    path), applies the optional attribute whitelist/blacklist and dumps the
    result with ``JsonSerializableEncoder``. Deserialization parses the JSON,
    refuses payloads whose serial version is newer than this class's, runs
    ``serial_version_migration``, strips the metadata and rebuilds an instance
    via ``instantiate_from_state_dict``.

    Subclasses bump ``SERIAL_VERSION_ID`` when their state layout changes and
    override ``serial_version_migration`` to upgrade older payloads.

    Note: ``__eq__`` is defined without ``__hash__``, so instances are unhashable.
    """

    SERIAL_VERSION_ID = 1  # bump this to signal work needs to be done on deserialize
    SERIAL_VERSION_JSON_KEY = '__serial_version_id'
    DEPRECATED_SERIAL_VERSION_JSON_KEYS = ['__serial_version_id__']
    SERIALIZATION_TIMESTAMP_TAG = '__serialization_timestamp'
    SERIALIZED_CLASS_NAME_TAG = '__serialized_class_name'
    HOSTNAME_WHERE_SERIALIZED_TAG = '__hostname_where_serialized'
    INTERPRETER_USED_TO_SERIALIALIZE = '__interpreter_used_to_serialize'
    LOG_ATTR = 'log'
    # When True, serialize() base64-encodes the JSON text and deserialize() decodes it.
    ENCODE_DURING_TRANSPORT = False
    # Attribute names ignored by compare_with()/__eq__: serialization metadata and the logger.
    OMIT_FROM_COMPARISONS = [
        SERIAL_VERSION_JSON_KEY,
        SERIALIZATION_TIMESTAMP_TAG,
        SERIALIZED_CLASS_NAME_TAG,
        LOG_ATTR,
        HOSTNAME_WHERE_SERIALIZED_TAG,
        INTERPRETER_USED_TO_SERIALIALIZE
    ]
    OMIT_FROM_COMPARISONS += DEPRECATED_SERIAL_VERSION_JSON_KEYS
    # Metadata keys written by embed_serial_metadata() and stripped again by
    # build_migrated_dictionary(). Deprecated version keys are NOT stripped.
    _SERIAL_METADATA_KEYS = (
        SERIALIZATION_TIMESTAMP_TAG,
        SERIAL_VERSION_JSON_KEY,
        SERIALIZED_CLASS_NAME_TAG,
        HOSTNAME_WHERE_SERIALIZED_TAG,
        INTERPRETER_USED_TO_SERIALIALIZE,
    )
    # json's TypeError reads "Object of type 'X' is not JSON serializable" on
    # Python <= 3.6 and "Object of type X is not JSON serializable" on 3.7+,
    # so the quotes are optional.
    _NOT_SERIALIZABLE_RE = re.compile(r"Object of type '?(.+?)'? is not JSON serializable")

    def embed_serial_metadata(self, dct):
        """Write serialization metadata into ``dct`` in place.

        Stores the current serial version id, a local (naive) ISO-8601
        timestamp, the class name, the hostname and ``sys.executable``.
        """
        dct[JsonSerializable.SERIAL_VERSION_JSON_KEY] = self.__class__.get_current_serial_version_id()
        dct[JsonSerializable.SERIALIZATION_TIMESTAMP_TAG] = datetime.datetime.now().isoformat()
        dct[JsonSerializable.SERIALIZED_CLASS_NAME_TAG] = self.__class__.__name__
        dct[JsonSerializable.HOSTNAME_WHERE_SERIALIZED_TAG] = socket.gethostname()
        dct[JsonSerializable.INTERPRETER_USED_TO_SERIALIALIZE] = sys.executable

    @classmethod
    def attributes_to_set_to_none_when_serializing(cls):
        """Return attribute names whose values are replaced by None on serialize (default: none)."""
        return []

    def _top_secret_pre_serialization_hook_DO_NOT_OVERRIDE(self, outgoing_dict):
        """Null out the logger and every attribute named by attributes_to_set_to_none_when_serializing().

        Raises:
            JsonSerializationError: if a listed attribute is not in ``outgoing_dict``.
        """
        outgoing_dict[JsonSerializable.LOG_ATTR] = None
        for attr in self.attributes_to_set_to_none_when_serializing():
            # Explicit check rather than assert so it still runs under `python -O`.
            if attr not in outgoing_dict:
                raise JsonSerializationError("attr {} not valid, serial migration may be necessary".format(attr))
            outgoing_dict[attr] = None

    def _top_secret_post_deserialization_hook_DO_NOT_OVERRIDE(self):
        """Attach a logger if the deserialized state has none (it is serialized as None)."""
        if getattr(self, 'log', None):
            return
        self.log = self.get_logger()

    def write_to(self, file_path):
        """Write serialize() output to ``file_path`` in text mode."""
        serialized = self.serialize()
        with open(file_path, 'w') as _o:
            _o.write(serialized)

    def serialize(self):
        """Return this object as a JSON string (base64 text if ENCODE_DURING_TRANSPORT).

        Default implementation; override as necessary but be sure to keep the
        version id step (embed_serial_metadata).

        Raises:
            JsonSerializationError: if any step fails; the original error is chained.
        """
        try:
            dct = self.__dict__.copy()
            self.pre_serialization_hook(dct)
            self._top_secret_pre_serialization_hook_DO_NOT_OVERRIDE(dct)
            self.embed_serial_metadata(dct)
            blacklist = set(self.attribute_serialization_blacklist() or [])
            whitelist = set(self.attribute_serialization_whitelist() or [])
            if whitelist:
                if blacklist:
                    raise JsonSerializationError("Blacklist and whitelist cannot both be present")
                # NOTE(review): the whitelist also drops the serial metadata keys
                # unless they are listed in it, which disables the newer-version
                # check on deserialize.
                dct = {k: v for k, v in dct.items() if k in whitelist}
            else:
                dct = {k: v for k, v in dct.items() if k not in blacklist}
            as_string = json.dumps(dct, indent=2, sort_keys=True, cls=JsonSerializableEncoder)
            if self.__class__.ENCODE_DURING_TRANSPORT:
                as_string = str(base64.b64encode(as_string.encode()), encoding="utf-8")
            return as_string
        except Exception as e:
            match = None
            if e.args:
                # str() guards against exceptions whose first arg is not a string.
                match = JsonSerializable._NOT_SERIALIZABLE_RE.match(str(e.args[0]))
            if match:
                object_type = match.group(1)
                raise JsonSerializationError("{} is composite and must override serialize() to handle type {}".format(
                    self.__class__.__name__, object_type
                )) from e
            raise JsonSerializationError("Failed to serialize {}".format(self.__class__.__name__)) from e

    # deserialization and serialization
    # override these as necessary
    @classmethod
    def deserialize(cls, as_string):
        """Rebuild an instance from serialize() output.

        Default implementation, override as necessary.
        For example, if the class takes args this must be overridden
        (or instantiate_from_state_dict must be).

        Args:
            as_string (str): serialized object

        Returns:
            cls: the rebuilt instance, after both post-deserialization hooks ran.

        Raises:
            JsonDeserializationError: on any failure; non-JsonDeserializationError
                causes are chained.
        """
        try:
            if cls.ENCODE_DURING_TRANSPORT:
                as_string = base64.decodebytes(as_string.encode())
            dct = cls.build_migrated_dictionary(as_string)
            instance = cls.instantiate_from_state_dict(dct)
            instance._top_secret_post_deserialization_hook_DO_NOT_OVERRIDE()
            instance.post_deserialization_hook()
        except JsonDeserializationError:
            # Already descriptive (e.g. the newer-serial-version check); don't re-wrap.
            raise
        except Exception as e:
            # str(e) rather than e.args[0]: a TypeError may carry no args at all.
            if isinstance(e, TypeError) and "required positional argument" in str(e):
                message = "\n The __init__ method for {} has required arguments. Must override instantiate_from_state_dict to support this".format(cls.__name__)
                raise JsonDeserializationError(message) from e
            raise JsonDeserializationError(*e.args) from e
        return instance

    @classmethod
    def instantiate_from_state_dict(cls, state_dict):
        """Create an instance with the no-arg constructor and replace its __dict__ with ``state_dict``."""
        i = cls()
        i.__dict__ = state_dict
        return i

    @classmethod
    def build_migrated_dictionary(cls, as_string):
        """Parse ``as_string``, check/migrate its serial version and strip serial metadata.

        Args:
            as_string (str or bytes): JSON text.

        Returns:
            dict: the migrated state dict without the serial metadata keys.

        Raises:
            JsonDeserializationError: if the JSON is not an object, or was written
                by a newer serial version.
        """
        # json.loads only accepts bytes from Python 3.6 on; decode for 3.5.
        if isinstance(as_string, bytes):
            as_string = as_string.decode()
        dct = json.loads(as_string)
        # Validate before any hook sees it; explicit check survives `python -O`.
        if not isinstance(dct, dict):
            raise JsonDeserializationError("{} expected a JSON object but got {}".format(
                cls.__name__, type(dct).__name__
            ))
        cls.pre_deserialization_hook(dct=dct)
        cls.top_secret_pre_serial_migration_hook_DO_NOT_OVERRIDE(dct)
        cls.serial_version_migration(dct)
        for key in JsonSerializable._SERIAL_METADATA_KEYS:
            dct.pop(key, None)
        return dct

    @classmethod
    def top_secret_pre_serial_migration_hook_DO_NOT_OVERRIDE(cls, dct):
        """Refuse payloads serialized with a newer serial version than this class's.

        A payload without a version key is treated as version 1.

        Raises:
            JsonDeserializationError: naming both the local and the serializing
                host/interpreter when available.
        """
        current_version = cls.get_current_serial_version_id()
        incoming_version = dct.get(JsonSerializable.SERIAL_VERSION_JSON_KEY, 1)
        if incoming_version > current_version:
            deserialized_with = "{}:{}".format(socket.gethostname(), sys.executable)
            serialized_with = "n/a"
            if JsonSerializable.HOSTNAME_WHERE_SERIALIZED_TAG in dct:
                hostname = dct[JsonSerializable.HOSTNAME_WHERE_SERIALIZED_TAG]
                # A missing interpreter tag must not mask the version error below.
                interpreter = dct.get(JsonSerializable.INTERPRETER_USED_TO_SERIALIALIZE, "n/a")
                serialized_with = "{}:{}".format(hostname, interpreter)
            msg = "server/client {} version mismatch error.\n\tInterpreter {}\n\tcannot deserialize {} that was serialized by\n\tinterpreter {}\n\tEither the former requires an upgrade or the latter requires a downgrade.".format(
                cls.__name__,
                deserialized_with,
                cls.__name__,
                serialized_with
            )
            raise JsonDeserializationError(msg)

    def serialize_to(self, file_path):
        """Write serialize() output to ``file_path`` as UTF-8 bytes."""
        with open(file_path, 'wb') as fh:
            fh.write(self.serialize().encode())

    def as_dict(self):
        """Return the output of serialize() parsed back into a dict.

        Going through serialize() applies any specialized serialization logic.
        """
        as_string = self.serialize()
        if self.__class__.ENCODE_DURING_TRANSPORT:
            # serialize() returned base64 text; json.loads needs the raw JSON.
            as_string = base64.b64decode(as_string).decode()
        return json.loads(as_string)

    @staticmethod
    def attribute_serialization_blacklist():
        """
        Returns:
            (list[str] or None): attribute names to drop when serializing.
        """
        return None

    @staticmethod
    def attribute_serialization_whitelist():
        """
        Returns:
            (list[str] or None): if non-empty, the only keys kept when serializing.
        """
        return None

    @classmethod
    def serial_version_migration(cls, dct):
        """
        when bumping version this must be overridden to accommodate

        Args:
            dct (dict): your class dict as deserialized, may be out of date

        Raises:
            SerialVersionError: (a NotImplementedError) if the payload version
                differs from the current one and this method was not overridden.
        """
        current_version = cls.get_current_serial_version_id()
        incoming_version = dct.get(JsonSerializable.SERIAL_VERSION_JSON_KEY, current_version)
        if incoming_version == current_version:
            return
        raise SerialVersionError("{} has new serial version, must override {}".format(cls.__name__, JsonSerializable.serial_version_migration.__name__))

    @classmethod
    def get_current_serial_version_id(cls):
        """Return this class's SERIAL_VERSION_ID."""
        return cls.SERIAL_VERSION_ID

    def pre_serialization_hook(self, dct):
        """Override to mutate the outgoing copy of __dict__ before serialization."""
        pass

    @classmethod
    def pre_deserialization_hook(cls, dct):
        """Override to update the parsed dict in place before migration."""
        pass

    def post_deserialization_hook(self):
        """Override to finish initializing a freshly deserialized instance."""
        pass

    def get_additional_attributes_to_exclude_from_equality_comparison(self):
        """Return extra attribute names that compare_with() should ignore."""
        return []

    @classmethod
    def deserialize_from_file_path(cls, local_file_path):
        """Read ``local_file_path`` and deserialize its contents.

        Args:
            local_file_path (str or os.PathLike): path to serialized text.

        Raises:
            TypeError: if the path does not resolve to a str.
            FileNotFoundError: if no regular file exists at the path.
        """
        local_file_path = os.fspath(local_file_path)
        if not isinstance(local_file_path, str):
            raise TypeError("local_file_path must be a str or os.PathLike, got {}".format(type(local_file_path).__name__))
        if not os.path.isfile(local_file_path):
            raise FileNotFoundError(local_file_path + " could not be found")
        with open(local_file_path) as _i:
            return cls.deserialize(_i.read())

    def compare_with(self, other):
        """Return human-readable differences between ``self`` and ``other``.

        Attributes in OMIT_FROM_COMPARISONS and those returned by
        get_additional_attributes_to_exclude_from_equality_comparison() are
        ignored. A type mismatch short-circuits the comparison.

        Returns:
            list[str]: one entry per mismatch; empty when the objects match.
        """
        diffs = []
        if type(self) != type(other):
            diffs.append("Type Mismatch, {} vs {}".format(type(self), type(other)))
            return diffs
        # set.update accepts any iterable (list + tuple would raise).
        excluded_attrs = set(self.OMIT_FROM_COMPARISONS)
        excluded_attrs.update(self.get_additional_attributes_to_exclude_from_equality_comparison())
        ldct = {k: v for k, v in self.__dict__.items() if k not in excluded_attrs}
        rdct = {k: v for k, v in other.__dict__.items() if k not in excluded_attrs}
        if ldct.keys() != rdct.keys():
            diffs.append("Atribute List Mismatch\n\tSelf:{}\n\tOther:{}".format(sorted(ldct), sorted(rdct)))
        # Sorted so the diff report comes out in a stable order.
        for k in sorted(ldct.keys() & rdct.keys()):
            left, right = ldct[k], rdct[k]
            try:
                # bool() inside the try: some types (e.g. arrays) fail on truthiness, not on ==.
                eq = bool(left == right)
            except Exception:
                self.log.exception("Failed to compare values for key {} due to".format(k))
                raise
            if not eq:
                diffs.append("Atribute Mismatch\n\tKey:{}\n\tSelf:{}\n\tOther:{}".format(k, left, right))
            elif self.log.isEnabledFor(logging.DEBUG):
                self.log.debug("Attribute {} matches for instances of {}".format(k, self.__class__.__name__))
        return diffs

    def __eq__(self, other):
        """Equal iff compare_with() finds no differences; differences are logged as warnings."""
        diffs = self.compare_with(other)
        if diffs:
            self.log.warning("Equality check failure for {}".format(self))
            for d in diffs:
                self.log.warning(d)
            return False
        return True
class JsonSerializationError(Exception):
    """Raised when a JsonSerializable object cannot be converted to JSON."""
class JsonDeserializationError(NotImplementedError):
    """Raised when JSON text cannot be turned back into a JsonSerializable object.

    Derives from NotImplementedError, so handlers for that type also catch it.
    """
class SerialVersionError(NotImplementedError):
    """Signals a serial-version problem; derives from NotImplementedError."""
class JsonSerializableEncoder(json.JSONEncoder):
    """JSON encoder used by JsonSerializable.serialize().

    Dates, datetimes and times become ISO-8601 strings. Nested
    JsonSerializable objects are embedded as the *string* returned by their
    own serialize(), not as a JSON object. Any other unsupported type raises
    the standard TypeError ("Object of type X is not JSON serializable").
    """

    def default(self, o):
        # datetime.datetime is a subclass of datetime.date, so it is covered here.
        if isinstance(o, (datetime.date, datetime.time)):
            return o.isoformat()
        if isinstance(o, JsonSerializable):
            return o.serialize()
        # The base implementation raises the standard TypeError for unknown types.
        return super().default(o)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
You can’t perform that action at this time.