@EricWiener
Last active December 30, 2023 02:00
Save tensors to pickle file
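```python
import pickle
from typing import Any

import torch


def save_tensors_to_pickle(output_file: str, to_numpy: bool, **kwargs: Any) -> None:
    """Save tensors to a pickle file.

    Example usage:

        save_tensors_to_pickle(
            "train_output.pkl",
            to_numpy=False,
            inputs=inputs,
            logits=logits,
        )

    Args:
        output_file (str): Output file path.
        to_numpy (bool): Whether to convert tensors to NumPy arrays before saving.
        **kwargs: Tensors to save, keyed by name. Values that are not
            torch.Tensor instances are stored unchanged instead of being
            silently dropped.
    """
    data = {}
    for key, value in kwargs.items():
        if isinstance(value, torch.Tensor):
            # Detach from the autograd graph first, then copy to CPU so the
            # saved data can be loaded on a machine without a GPU.
            value = value.detach().cpu()
            if to_numpy:
                value = value.numpy()
        data[key] = value
    # "wb" is enough: the file is only written here, never read back.
    with open(output_file, "wb") as f:
        pickle.dump(data, f)
```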
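The gist covers only the saving side. Below is a minimal sketch of reading the file back, assuming it was produced by `save_tensors_to_pickle` above; `load_tensors_from_pickle` is a hypothetical helper name, not part of the original gist. The Python `pickle` documentation warns that unpickling untrusted data can execute arbitrary code, so only load files you produced yourself. Also note that a file saved with `to_numpy=False` contains pickled `torch.Tensor` objects, so `torch` must be importable when it is loaded; saving with `to_numpy=True` trades that dependency for NumPy.

```python
import os
import pickle
import tempfile
from typing import Any, Dict


def load_tensors_from_pickle(input_file: str) -> Dict[str, Any]:
    """Load the dict written by save_tensors_to_pickle (hypothetical helper).

    Only load files you trust: unpickling can execute arbitrary code.
    """
    with open(input_file, "rb") as f:
        data = pickle.load(f)
    # The saver always writes a dict keyed by the kwarg names.
    if not isinstance(data, dict):
        raise TypeError(
            f"expected a dict in {input_file!r}, got {type(data).__name__}"
        )
    return data


if __name__ == "__main__":
    # Round-trip plain Python values so the demo runs without torch or NumPy.
    path = os.path.join(tempfile.mkdtemp(), "train_output.pkl")
    with open(path, "wb") as f:
        pickle.dump({"inputs": [1, 2, 3], "step": 10}, f)
    print(load_tensors_from_pickle(path))
```

The loader only checks that the top-level object is a dict; it does not validate the individual values, which come back as whatever type was pickled (CPU tensors, NumPy arrays, or untouched non-tensor values).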