Skip to content

Instantly share code, notes, and snippets.

@enijkamp
Created July 29, 2021 13:04
Show Gist options
  • Save enijkamp/aaacc540ebefccc1bc5fef457e54ddbe to your computer and use it in GitHub Desktop.
reshard.py
def apply_reshard(pytree_params_in, pytree_params_out, shards_in, shards_out):
    """Reshard every array leaf of a nested parameter dict to a target layout.

    The key structure of ``pytree_params_out`` drives the traversal: for each
    key in it, the same key is looked up in ``pytree_params_in``. NumPy-array
    leaves of the input tree are copied and passed, together with the shape of
    the corresponding output leaf, to ``reshard`` (defined elsewhere in the
    project). Any other value is treated as a nested mapping and recursed into.
    Keys present only in ``pytree_params_in`` are ignored.

    Args:
        pytree_params_in: Nested mapping whose leaves are NumPy arrays holding
            the source parameters.
        pytree_params_out: Nested mapping with matching keys; each leaf only
            needs a ``.shape`` attribute giving the target shape.
        shards_in: Not used by this function; kept for interface compatibility.
        shards_out: Not used by this function; kept for interface compatibility.

    Returns:
        A new nested ``dict`` mirroring the keys of ``pytree_params_out`` whose
        leaves are the values returned by ``reshard``.

    Raises:
        KeyError: if a key of ``pytree_params_out`` is missing at the
            corresponding level of ``pytree_params_in``.
    """

    def override_dtype(x):
        # Reinterpret raw 2-byte void ('V2') data as bfloat16 without copying.
        # Presumably these are bfloat16 values serialised without a
        # bfloat16-aware dtype -- TODO confirm against the checkpoint writer.
        # ``view`` replaces the discouraged in-place ``x.dtype = ...`` assignment.
        if x.dtype == np.dtype('V2'):
            return x.view(jnp.bfloat16)
        return x

    def is_leaf(x):
        # isinstance (not an exact type check) so that ndarray subclasses such
        # as np.memmap (e.g. from np.load(..., mmap_mode='r')) count as leaves
        # instead of being recursed into as if they were dicts.
        return isinstance(x, np.ndarray)

    def traverse(pysubtree_in, pysubtree_out, path=''):
        tree_out = {}
        for k in pysubtree_out:
            tree_in_k = pysubtree_in[k]
            tree_out_k = pysubtree_out[k]
            key_path = f'{path}/{k}'
            if is_leaf(tree_in_k):
                # Progress/debug output: parameter path, source and target shapes.
                print(key_path)
                print(tree_in_k.shape)
                print(tree_out_k.shape)
                # np.array yields a fresh in-memory ndarray copy (also for
                # memmaps and 0-d arrays, where np.stack would fail), so the
                # caller's input arrays are never handed to reshard directly.
                leaf = override_dtype(np.array(tree_in_k))
                tree_out[k] = reshard(leaf, tree_out_k.shape)
            else:
                tree_out[k] = traverse(tree_in_k, tree_out_k, key_path)
        return tree_out

    return traverse(pysubtree_in=pytree_params_in, pysubtree_out=pytree_params_out)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment