# Gist cottrell/f3d78b27a9dcd9d47dd7fd74f1841ab1 (last active October 3, 2023)
# Equinox module state extraction and serialization
import dataclasses
import importlib
# begin serialization lib
import io
import json
import lzma
import pickle
from base64 import b64decode, b64encode
from types import FunctionType
import cloudpickle
import equinox as eqx
import jax
import jax.numpy as jnp
import numpy as np
from jaxtyping import Array, Int
# NOTE: see
def save_model_state(model, filename, array_flavour='tolist'):
    """Save a model to a "json" file with arrays encoded according to array_flavour.

    - model: the (equinox) object to save; its state is extracted with
      ``recurse_get_state``.
    - filename: path of the JSON file to write.
    - array_flavour: one of 'tolist', 'save', 'save_xz_b64'.
    """
    params = recurse_get_state(model)
    # Forward array_flavour so the caller's choice of array encoding takes effect.
    jsonifiable = params_to_jsonifiable(params, array_flavour=array_flavour)
    print(f'saving to {filename}')
    # Context manager guarantees the file is flushed and closed.
    with open(filename, 'w') as fout:
        json.dump(jsonifiable, fout)
def load_model_state(filename):
    """Load a model previously written by ``save_model_state``.

    Reads the JSON file, decodes the encoded leaves with ``jsonifiable_to_params``
    and rebuilds the object tree with ``reconstitute``.

    SECURITY: encoded 'pickle:' / 'cloudpickle:' entries are unpickled during
    decoding, which can execute arbitrary code -- only load trusted files.
    """
    # Context manager guarantees the file handle is closed.
    with open(filename) as fin:
        jsonifiable = json.load(fin)
    params = jsonifiable_to_params(jsonifiable)
    return reconstitute(params)
def io_helper(f_save):
    """Adapt a file-writing function ``f_save(file, x)`` into a ``x -> bytes`` function.

    The returned callable writes ``x`` into an in-memory ``io.BytesIO`` buffer with
    ``f_save`` and returns the buffer's contents.
    """
    def inner(x):
        fout = io.BytesIO()
        f_save(fout, x)
        return fout.getvalue()

    return inner


# ``.npy`` bytes for a NumPy array.
np_save = io_helper(np.save)
# ``.npy`` bytes for a JAX array, materialised as a NumPy array before writing.
jnp_save = io_helper(lambda fout, x: np.save(fout, np.asarray(x)))
def _maybe_json_loads(x):
    """Decode the text payload of a ``*_tolist`` entry back into Python values.

    Non-string inputs are returned unchanged. Strings are parsed as JSON first.
    Text that is not valid JSON is parsed as a Python literal instead. This happens
    because the writer formats ``x.tolist()`` with ``str``, so booleans appear as
    ``True``/``False``, possibly nested in lists.
    """
    if not isinstance(x, str):
        return x
    try:
        return json.loads(x)
    except json.JSONDecodeError:
        # Parse rather than cast: ``bool('False')`` would be True.
        import ast

        return ast.literal_eval(x)
# Registry of codecs keyed by the prefix used in encoded strings ('<key>:<payload>').
# 'write' turns a value into a JSON-friendly str (or list for *_tolist);
# 'read' inverts it.
# SECURITY: the 'pickle' and 'cloudpickle' readers unpickle their input, which can
# execute arbitrary code -- never feed them untrusted data.
serializers_deserializers = {
    'np_save': {
        'write': lambda x: b64encode(np_save(x)).decode(),
        'read': lambda x: np.load(io.BytesIO(b64decode(x))),
    },
    'jnp_save': {
        'write': lambda x: b64encode(jnp_save(x)).decode(),
        'read': lambda x: jnp.load(io.BytesIO(b64decode(x))),
    },
    'np_save_xz_b64': {
        'write': lambda x: b64encode(lzma.compress(np_save(x))).decode(),
        'read': lambda x: np.load(io.BytesIO(lzma.decompress(b64decode(x)))),
    },
    'jnp_save_xz_b64': {
        'write': lambda x: b64encode(lzma.compress(jnp_save(x))).decode(),
        'read': lambda x: jnp.load(io.BytesIO(lzma.decompress(b64decode(x)))),
    },
    # NOTE: these are pretty awful now as they are not even really json anymore with this str: prefix thing
    'np_tolist': {'write': lambda x: x.tolist(), 'read': lambda x: np.array(_maybe_json_loads(x))},
    'jnp_tolist': {'write': lambda x: x.tolist(), 'read': lambda x: jnp.array(_maybe_json_loads(x))},
    'pickle': {'write': lambda x: b64encode(pickle.dumps(x)).decode(), 'read': lambda x: pickle.loads(b64decode(x))},
    'cloudpickle': {
        'write': lambda x: b64encode(cloudpickle.dumps(x)).decode(),
        'read': lambda x: cloudpickle.loads(b64decode(x)),
    },
}
def params_to_jsonifiable(params, array_flavour='tolist'):
    """Convert a params pytree into something that should be JSON-serialisable.

    Leaves are handled as follows:
    - JAX / NumPy arrays become ``'<key>:<payload>'`` strings, with the codec
      chosen by ``array_flavour``.
    - Python functions, and any other leaf that ``json.dumps`` rejects with a
      TypeError, become ``'cloudpickle:<base64>'`` strings.
    - Everything else is returned unchanged.

    - array_flavour: one of 'tolist', 'save', 'save_xz_b64' (anything else raises
      KeyError when an array leaf is encountered).
    """
    # NOTE: probably awful just do something for now ... look for someone to have done something sane on the jax side
    # that isn't pickle. Likely the equinox pattern with some way to get at the typing would be fine.
    jnp_keys = {'tolist': 'jnp_tolist', 'save': 'jnp_save', 'save_xz_b64': 'jnp_save_xz_b64'}
    np_keys = {'tolist': 'np_tolist', 'save': 'np_save', 'save_xz_b64': 'np_save_xz_b64'}

    def encode(key, x):
        # Prefix the payload with its codec key so the reader can dispatch on it.
        fun = serializers_deserializers[key]['write']
        return f'{key}:{fun(x)}'

    def inner(x):
        if isinstance(x, jax.Array):
            return encode(jnp_keys[array_flavour], x)
        if isinstance(x, np.ndarray):
            return encode(np_keys[array_flavour], x)
        if isinstance(x, FunctionType):
            # NOTE: bad but just for functions not sure what else to do here
            return encode('cloudpickle', x)
        try:
            json.dumps(x)
        except TypeError:
            # Not JSON-serialisable: fall back to cloudpickle.
            return encode('cloudpickle', x)
        return x  # f'json:{x}'

    return jax.tree_util.tree_map(inner, params)
def jsonifiable_to_params(jsonifiable):
    """Inverse of ``params_to_jsonifiable``: decode ``'<key>:<payload>'`` string leaves.

    A string leaf whose prefix before the first ':' is a key of
    ``serializers_deserializers`` is decoded with that codec's reader. Every other
    leaf is returned unchanged, including strings without a recognised prefix.

    SECURITY: 'pickle:' / 'cloudpickle:' payloads are unpickled, which can execute
    arbitrary code -- only decode trusted data.
    """
    def inner(x):
        if not isinstance(x, str):
            return x
        key, sep, val = x.partition(':')
        # Strings that merely contain ':' (e.g. 'a:b') are ordinary values, not
        # encoded payloads.
        if not sep or key not in serializers_deserializers:
            return x
        return serializers_deserializers[key]['read'](val)

    return jax.tree_util.tree_map(inner, jsonifiable)
# end serialization lib
def recurse_get_state(x):
    """Recursively convert ``x`` into a nested, tagged structure of plain containers.

    - ``eqx.Module`` instances become
      ``{'module': {<class module>: {<class qualname>: <recursed __getstate__()>}}}``.
    - dicts become ``{'dict': {...}}``; string keys starting with '__' are dropped.
    - lists and tuples are recursed element-wise, keeping their container type.
    - anything else is returned unchanged.
    """
    # NOTE: this is a somewhat custom recursion due to eqx.Module detection
    if isinstance(x, eqx.Module):
        # Module name and qualname are nested string keys, not a single tuple key,
        # because some libraries (e.g. msgpack) do not allow non-string dict keys.
        return {'module': {x.__class__.__module__: {x.__class__.__qualname__: recurse_get_state(x.__getstate__())}}}
    if isinstance(x, dict):
        # TODO: review this. Dunder keys ('__doc__', '__annotations__', '__module__')
        # showed up in the state of a diffrax solution (see test_diffrax), so they are
        # dropped. Only string keys can be dunders; other keys are kept as-is instead
        # of failing on ``.startswith``.
        return {
            'dict': {
                k: recurse_get_state(v)
                for k, v in x.items()
                if not (isinstance(k, str) and k.startswith('__'))
            }
        }
    if isinstance(x, list):
        return [recurse_get_state(v) for v in x]
    if isinstance(x, tuple):
        return tuple(recurse_get_state(v) for v in x)
    return x
def init_from_state_params(class_, params):
    """Create an instance of dataclass ``class_`` without calling its ``__init__``.

    ``params`` must map every dataclass field name of ``class_`` to its value. The
    values are set with ``object.__setattr__``, so frozen dataclasses (such as
    ``eqx.Module``) work. ``params`` may be None only when ``class_`` has no fields.

    Raises:
        ValueError: if ``params`` does not provide exactly the dataclass fields.
    """
    module = object.__new__(class_)
    fieldnames = {f.name for f in dataclasses.fields(class_)}
    if params is None:
        if fieldnames:
            raise ValueError(f'{class_.__qualname__} has fields {sorted(fieldnames)} but params is None')
        # Nothing to set; without this return the code below would call None.keys().
        return module
    if set(params.keys()) != fieldnames:
        raise ValueError(
            f'params keys {sorted(params.keys())} do not match fields of '
            f'{class_.__qualname__}: {sorted(fieldnames)}'
        )
    for key, value in params.items():
        object.__setattr__(module, key, value)
    return module
def get_object_from_module_and_qualname(module_name, qualname):
    """Import ``module_name`` and follow the dotted ``qualname`` path inside it.

    E.g. ``('json.decoder', 'JSONDecoder.decode')`` returns the ``decode`` function.
    Raises ImportError / AttributeError if any step cannot be resolved.
    """
    current = importlib.import_module(module_name)
    attribute_chain = qualname.split('.')
    while attribute_chain:
        current = getattr(current, attribute_chain.pop(0))
    return current
def reconstitute_from_root(params):
    """Inverse of ``recurse_get_state``: rebuild objects from the tagged structure.

    - ``{'module': {module_name: {qualname: state}}}``: an instance of that class,
      built with ``init_from_state_params`` from the reconstituted ``state``.
    - ``{'dict': {...}}``: a plain dict with reconstituted values.
    - lists and tuples are rebuilt element-wise; any other value is returned as-is.

    Raises:
        ValueError: for a dict that is not a well-formed single-entry
            ``'module'`` / ``'dict'`` tag.
    """
    if isinstance(params, dict):
        if len(params) != 1:
            raise ValueError(f'expected a single-key tagged dict, got keys {list(params)}')
        tag, payload = next(iter(params.items()))
        if tag == 'module':
            if len(payload) != 1:
                raise ValueError(f'expected a single module name, got {list(payload)}')
            module_name, by_qualname = next(iter(payload.items()))
            if len(by_qualname) != 1:
                raise ValueError(f'expected a single qualname, got {list(by_qualname)}')
            qualname, state = next(iter(by_qualname.items()))
            class_ = get_object_from_module_and_qualname(module_name, qualname)
            return init_from_state_params(class_, reconstitute_from_root(state))
        if tag == 'dict':
            return {k: reconstitute_from_root(v) for k, v in payload.items()}
        raise ValueError(f'unknown key {tag}')
    if isinstance(params, list):
        return [reconstitute_from_root(v) for v in params]
    if isinstance(params, tuple):
        return tuple(reconstitute_from_root(v) for v in params)
    return params
def reconstitute(params):
    """Rebuild the object tree described by ``params`` (output of ``recurse_get_state``)."""
    return reconstitute_from_root(params)
def serialization_test_fun(params):
    """Assert that ``params`` (from ``recurse_get_state``) survives a JSON round trip.

    For every array flavour, ``params`` is encoded with ``params_to_jsonifiable``,
    passed through ``json.dumps``/``json.loads``, and decoded with
    ``jsonifiable_to_params``. Both the encoded form and the decoded params must be
    identical to their originals. Tuples come back from JSON as lists, so convert
    them first (see ``tuple_to_list``).
    """
    for array_flavour in ['tolist', 'save', 'save_xz_b64']:
        jsonifiable = params_to_jsonifiable(params, array_flavour=array_flavour)
        jsonifiable_ = json.loads(json.dumps(jsonifiable))
        assert check_identical(jsonifiable, jsonifiable_), (
            f'array_flavour={array_flavour} json round trip failed'
        )
        params_ = jsonifiable_to_params(jsonifiable_)
        assert check_identical(params, params_), f'array_flavour={array_flavour} failed'
def tuple_to_list(tree):
    """Return a copy of ``tree`` where every tuple (at any depth) is now a list.

    Lists and dict values are recursed into; other values are returned unchanged.
    """
    if isinstance(tree, dict):
        return {key: tuple_to_list(value) for key, value in tree.items()}
    if isinstance(tree, (tuple, list)):
        return [tuple_to_list(elem) for elem in tree]
    return tree
def check_identical(tree1, tree2):
    """Return True if two pytrees have matching structure and equal leaves.

    Leaf comparison:
    - functions: equal when ``y`` is also a function whose bytecode (``co_code``)
      and constants (``co_consts``) match. Comparing constants as well means
      ``lambda x: x + 1`` and ``lambda x: x + 2`` are told apart.
    - NumPy / JAX arrays: ``jnp.array_equal``, so a shape mismatch counts as
      unequal instead of broadcasting.
    - anything else: ``jnp.all(x == y)``.

    ``jax.tree_util.tree_map`` raises if ``tree2`` does not match ``tree1``'s structure.
    """
    def compare_elements(x, y):
        if isinstance(x, FunctionType):
            return (
                isinstance(y, FunctionType)
                and x.__code__.co_code == y.__code__.co_code
                and x.__code__.co_consts == y.__code__.co_consts
            )
        if isinstance(x, (np.ndarray, jax.Array)) or isinstance(y, (np.ndarray, jax.Array)):
            return bool(jnp.array_equal(x, y))
        return bool(jnp.all(x == y))

    comparison_tree = jax.tree_util.tree_map(compare_elements, tree1, tree2)
    return all(jax.tree_util.tree_leaves(comparison_tree))
def check_identical_with_debug(tree1, tree2):
    """Compare two pytrees like ``check_identical`` and print any differing leaves.

    Leaves are compared as follows:
    - functions: by bytecode and constants;
    - arrays: with ``jnp.array_equal``;
    - everything else: with ``jnp.all(x == y)``.

    Prints an overall ``all_identical`` line, then every disagreeing ``(x, y)``
    pair. Returns True when all leaves agree.
    """
    disagreements = []

    def compare_elements(x, y):
        if isinstance(x, FunctionType):
            identical = (
                isinstance(y, FunctionType)
                and x.__code__.co_code == y.__code__.co_code
                and x.__code__.co_consts == y.__code__.co_consts
            )
        elif isinstance(x, (np.ndarray, jax.Array)) or isinstance(y, (np.ndarray, jax.Array)):
            identical = bool(jnp.array_equal(x, y))
        else:
            identical = bool(jnp.all(x == y))
        if not identical:
            disagreements.append((x, y))
        return identical

    comparison_tree = jax.tree_util.tree_map(compare_elements, tree1, tree2)
    all_identical = all(jax.tree_util.tree_leaves(comparison_tree))
    print(f"all_identical: {all_identical}")
    if not all_identical:
        print("Disagreeing elements:")
        for x, y in disagreements:
            print(f"x: {x}, y: {y}")
    return all_identical
class Linear(eqx.Module):
    """Minimal affine-layer fixture: ``weight`` is (out_size, in_size), ``bias`` is (out_size,)."""

    weight: jax.Array
    bias: jax.Array

    def __init__(self, in_size, out_size, key):
        # One sub-key per parameter, drawn from a single split of ``key``.
        weight_key, bias_key = jax.random.split(key)
        self.weight = jax.random.normal(weight_key, (out_size, in_size))
        self.bias = jax.random.normal(bias_key, (out_size,))
class Another(eqx.Module):
    """Fixture module holding a plain Python list of ``n`` ``Linear`` layers."""

    layers: list

    def __init__(self, n, in_size, out_size, key):
        # Split the key so each layer gets independent weights. Reusing one key would
        # make every layer an identical copy, weakening the round-trip tests.
        # (n <= 0 yields no layers, as range(n) did.)
        layer_keys = jax.random.split(key, n) if n > 0 else []
        self.layers = [Linear(in_size, out_size, layer_key) for layer_key in layer_keys]
def test_simple():
    """Round-trip a small nested module (a list of Linear layers) through state extraction."""
    original = Another(n=5, in_size=12, out_size=3, key=jax.random.PRNGKey(0))
    rebuilt = reconstitute(recurse_get_state(original))
    assert check_identical(original, rebuilt), 'failed'
class Func(eqx.Module):
    """Fixture module whose only field is a Python function (an identity lambda)."""

    func: FunctionType

    def __init__(self):
        self.func = lambda x: x
def test_func():
    """Round-trip a module that stores a function as a field."""
    original = Func()
    rebuilt = reconstitute(recurse_get_state(original))
    assert check_identical(original, rebuilt), 'failed'
def test_lineax():
    """Round-trip a selection of lineax solver modules."""
    from lineax import CG, GMRES, LU, QR, SVD, BiCGStab, Diagonal, NormalCG, Triangular, Tridiagonal

    # The first group is constructed with tolerances, the second with no arguments.
    cases = [(cls, dict(atol=1e-3, rtol=1e-4)) for cls in (BiCGStab, CG, GMRES, NormalCG)]
    cases += [(cls, {}) for cls in (Diagonal, LU, QR, SVD, Triangular, Tridiagonal)]
    for module_, kwargs in cases:
        original = module_(**kwargs)
        rebuilt = reconstitute(recurse_get_state(original))
        assert check_identical(original, rebuilt), f'{module_} failed'
def test_diffrax():
    """Round-trip the solution object returned by ``diffrax.diffeqsolve``."""
    from diffrax import Dopri5, ODETerm, diffeqsolve

    def vector_field(t, y, args):
        return -y

    solution = diffeqsolve(
        ODETerm(vector_field),
        Dopri5(),
        t0=0,
        t1=1,
        dt0=0.1,
        y0=jnp.array([2.0, 3.0]),
    )
    rebuilt = reconstitute(recurse_get_state(solution))
    assert check_identical(solution, rebuilt), 'diffrax failed'
class Model_stateful(eqx.Module):
    """Fixture model mixing stateful layers (BatchNorm, SpectralNorm) with plain Linears."""

    norm1: eqx.nn.BatchNorm
    spectral_linear: eqx.nn.SpectralNorm[eqx.nn.Linear]
    norm2: eqx.nn.BatchNorm
    linear1: eqx.nn.Linear
    linear2: eqx.nn.Linear

    def __init__(self, key):
        key1, key2, key3, key4 = jax.random.split(key, 4)
        self.norm1 = eqx.nn.BatchNorm(input_size=3, axis_name="batch")
        # SpectralNorm takes the name of the wrapped layer's weight attribute and
        # its own PRNG key (key2 is reserved for it).
        self.spectral_linear = eqx.nn.SpectralNorm(
            layer=eqx.nn.Linear(in_features=3, out_features=32, key=key1),
            weight_name="weight",
            key=key2,
        )
        self.norm2 = eqx.nn.BatchNorm(input_size=32, axis_name="batch")
        self.linear1 = eqx.nn.Linear(in_features=32, out_features=32, key=key3)
        self.linear2 = eqx.nn.Linear(in_features=32, out_features=3, key=key4)

    def __call__(self, x, state):
        """Apply the layers, threading ``state`` through the stateful ones.

        Returns ``(output, updated_state)``.
        """
        x, state = self.norm1(x, state)
        x, state = self.spectral_linear(x, state)
        x = jax.nn.relu(x)
        x, state = self.norm2(x, state)
        x = self.linear1(x)
        x = jax.nn.relu(x)
        x = self.linear2(x)
        return x, state
def test_stateful():
    """Round-trip a model containing stateful layers (BatchNorm, SpectralNorm)."""
    # from: (source link missing in this copy)
    model = Model_stateful(key=jax.random.PRNGKey(0))
    params = recurse_get_state(model)
    rebuilt = reconstitute(params)
    assert check_identical(model, rebuilt), 'stateful failed'
    # TODO: NOTE: abolish tuples as they are not json serializable round trip
    # NOTE(review): the list-converted params are not used afterwards; presumably
    # intended for serialization_test_fun -- confirm before wiring it up.
    params = tuple_to_list(params)
class LanguageModel_shared(eqx.Module):
    """Fixture: an Embedding and a Linear layer whose weights are tied via ``eqx.nn.Shared``."""

    shared: eqx.nn.Shared

    def __init__(self, *, key):
        embedding = eqx.nn.Embedding(num_embeddings=3, embedding_size=4, key=key)
        linear = eqx.nn.Linear(in_features=4, out_features=3, key=key)
        # These two weights will now be tied together.
        where = lambda pair: pair[1].weight
        get = lambda pair: pair[0].weight
        self.shared = eqx.nn.Shared((embedding, linear), where, get)

    def __call__(self, tokens: Int[Array, "sequence"]):
        # Expand back out so we can evaluate these layers.
        embedding, linear = self.shared()
        assert embedding.weight is linear.weight  # same parameter!
        # ... other layers, probably
        return jax.vmap(linear)(jax.vmap(embedding)(tokens))
def test_shared():
    """Round-trip a model that ties parameters with ``eqx.nn.Shared``."""
    # from: (source link missing in this copy)
    key = jax.random.PRNGKey(0)
    a = LanguageModel_shared(key=key)
    params = recurse_get_state(a)
    b = reconstitute(params)
    # Message names this test (was a copy-paste of the stateful test's message).
    assert check_identical(a, b), 'shared failed'
def test_all():
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment