
@maxrohleder
Last active January 13, 2023 15:03
update
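import numpy as np
import torch

# set new weights from loaded tf values
# (assumes `m` is the torch.nn.Module and `tf_weights` is a dict of numpy arrays
# whose iteration order matches m.named_parameters(); zip pairs them by position)
assert len(list(m.parameters())) == len(tf_weights), "parameter count mismatch"
with torch.no_grad():
    for (name, param), (tf_name, tf_param) in zip(m.named_parameters(), tf_weights.items()):
        # TF conv kernels are HWIO (H, W, in, out); PyTorch expects OIHW (out, in, H, W).
        # .copy() turns the transposed view into a contiguous array
        tf_param = np.transpose(tf_param, (3, 2, 0, 1)).copy() if tf_param.ndim == 4 else tf_param
        assert tf_param.shape == tuple(param.shape), (name, tf_name)
        # https://discuss.pytorch.org/t/how-to-assign-an-arbitrary-tensor-to-models-parameter/44082/3
        param.copy_(torch.as_tensor(tf_param, dtype=param.dtype))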
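One way to sanity-check the `(3, 2, 0, 1)` permutation without TensorFlow or PyTorch installed is to run the same convolution in both layouts with plain NumPy and compare. This is only a sketch: the helper names (`hwio_to_oihw`, `conv2d_nhwc_hwio`, `conv2d_nchw_oihw`) are hypothetical, the loops are naive stride-1 "valid" cross-correlations with no padding, and it assumes TF's standard (H, W, in, out) kernel layout. The last check covers a case the snippet above does not handle: TF stores Dense kernels as (in, out) while `torch.nn.Linear` weights are (out, in), so a square Dense kernel would pass the shape assert untransposed and load the wrong weights.

```python
import numpy as np


def hwio_to_oihw(kernel: np.ndarray) -> np.ndarray:
    """Reorder a TF conv kernel (H, W, in, out) into PyTorch's (out, in, H, W)."""
    # .copy() yields a C-contiguous array instead of a strided view
    return np.transpose(kernel, (3, 2, 0, 1)).copy()


def conv2d_nhwc_hwio(x: np.ndarray, k: np.ndarray) -> np.ndarray:
    """Naive valid cross-correlation in TF layout: x (N,H,W,C), k (kh,kw,C,O)."""
    n, h, w, _ = x.shape
    kh, kw, _, o = k.shape
    out = np.zeros((n, h - kh + 1, w - kw + 1, o))
    for i in range(h - kh + 1):
        for j in range(w - kw + 1):
            patch = x[:, i:i + kh, j:j + kw, :]  # (N, kh, kw, C)
            out[:, i, j, :] = np.tensordot(patch, k, axes=([1, 2, 3], [0, 1, 2]))
    return out


def conv2d_nchw_oihw(x: np.ndarray, k: np.ndarray) -> np.ndarray:
    """Naive valid cross-correlation in PyTorch layout: x (N,C,H,W), k (O,C,kh,kw)."""
    n, _, h, w = x.shape
    o, _, kh, kw = k.shape
    out = np.zeros((n, o, h - kh + 1, w - kw + 1))
    for i in range(h - kh + 1):
        for j in range(w - kw + 1):
            patch = x[:, :, i:i + kh, j:j + kw]  # (N, C, kh, kw)
            out[:, :, i, j] = np.tensordot(patch, k, axes=([1, 2, 3], [1, 2, 3]))
    return out


rng = np.random.default_rng(0)
x_nhwc = rng.standard_normal((2, 6, 5, 3))   # batch of 2, 6x5 image, 3 channels
k_hwio = rng.standard_normal((3, 2, 3, 4))   # kh=3, kw=2, in=3, out=4

k_oihw = hwio_to_oihw(k_hwio)
print(k_oihw.shape, k_oihw.flags["C_CONTIGUOUS"])  # (4, 3, 3, 2) True

y_tf = conv2d_nhwc_hwio(x_nhwc, k_hwio)                               # (N, H', W', O)
y_pt = conv2d_nchw_oihw(np.transpose(x_nhwc, (0, 3, 1, 2)), k_oihw)   # (N, O, H', W')
print(np.allclose(np.transpose(y_tf, (0, 3, 1, 2)), y_pt))  # True

# Dense layers: TF computes v @ K with K (in, out); torch.nn.Linear computes v @ W.T
k_dense = rng.standard_normal((5, 7))  # TF Dense kernel: (in=5, units=7)
w_linear = k_dense.T.copy()            # torch.nn.Linear weight: (out=7, in=5)
v = rng.standard_normal((2, 5))
print(np.allclose(v @ k_dense, v @ w_linear.T))  # True
```

Comparing full convolution outputs, rather than only shapes, catches permutations that happen to produce the right shape but mix up the input and output channel axes.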