Skip to content

Instantly share code, notes, and snippets.

@danieldk
Last active December 7, 2020 15:10
Show Gist options
  • Save danieldk/fe4a98ee8fae3af10c38f3975258b1bb to your computer and use it in GitHub Desktop.
Save danieldk/fe4a98ee8fae3af10c38f3975258b1bb to your computer and use it in GitHub Desktop.
import torch
from torch import nn
class TensorModule(nn.Module):
    """Wrap a name -> tensor mapping in an ``nn.Module`` so it can be scripted and saved.

    Each entry is registered as an ``nn.Parameter`` attribute named after its
    key. The tensors therefore appear in ``named_parameters()`` /
    ``state_dict()`` and are carried along by ``torch.jit.script(...).save(...)``.

    Args:
        tensors: Mapping from attribute name (str) to ``torch.Tensor``.
            ``nn.Module`` raises ``KeyError`` if a name is empty, contains a
            ``"."``, or clashes with an existing module attribute.
    """

    def __init__(self, tensors):
        super().__init__()
        for tensor_name, tensor in tensors.items():
            # nn.Parameter defaults to requires_grad=True, which PyTorch only
            # permits for floating-point and complex dtypes. An integer or
            # bool tensor would raise a RuntimeError. Keep the default where
            # it is legal and disable it otherwise.
            trainable = tensor.is_floating_point() or tensor.is_complex()
            setattr(
                self,
                tensor_name,
                nn.Parameter(tensor, requires_grad=trainable),
            )
# Given a dictionary mapping names (str) to tensors, wrap the tensors in a
# TensorModule, compile it with TorchScript, and save the result.
# NOTE(review): `tensor_dict` and `args` are not defined in this snippet; they
# must be provided by the surrounding code. `args` looks like an argparse
# namespace whose `tensors` attribute is the output path -- confirm.
wrapper = TensorModule(tensor_dict)
# torch.jit.script returns a ScriptModule, which is what provides .save().
script = torch.jit.script(wrapper)
script.save(args.tensors)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment