Skip to content

Instantly share code, notes, and snippets.

@driazati
Created February 12, 2019 00:32
Show Gist options
  • Save driazati/72d411556d9766dd550f65e669adec75 to your computer and use it in GitHub Desktop.
Save driazati/72d411556d9766dd550f65e669adec75 to your computer and use it in GitHub Desktop.
from typing import Dict

import torch
from torch import nn, jit
@jit.script
def lookup_all(d, nested_keys):
    # type: (Dict[str, Tensor], List[str]) -> List[Tensor]
    """Look up every key of ``nested_keys`` in ``d``, preserving order.

    Args:
        d: Mapping from string keys to tensors.
        nested_keys: Keys to fetch; a repeated key yields a repeated entry.

    Returns:
        The tensors ``d[key]``, one per key, in the order the keys were given.
        An empty ``nested_keys`` yields an empty list.

    A key that is missing from ``d`` is not handled here; the ``d[key]``
    lookup itself raises.
    """
    # TorchScript infers an empty list literal as a list of tensors, which
    # matches the declared return type.
    tensors = []
    # Iterate the keys directly instead of indexing through range(len(...)).
    for key in nested_keys:
        tensors.append(d[key])
    return tensors
class MyDictModel(jit.ScriptModule):
    """Script module whose forward pass fetches tensors from a dict by key."""

    @jit.script_method
    def forward(self, keys, lookup):
        # type: (List[str], Dict[str, Tensor]) -> List[Tensor]
        """Return the tensor stored in ``lookup`` for each entry of ``keys``.

        The work is delegated to ``lookup_all``, which takes its arguments in
        the opposite order (dict first, then keys).
        """
        selected = lookup_all(lookup, keys)
        return selected
# Map each lowercase letter to a scalar int64 tensor holding its code point.
# torch.LongTensor(o) with an int argument would instead allocate an
# uninitialized tensor of length o, so the value goes through torch.tensor.
vocab = {
    chr(o): torch.tensor(o, dtype=torch.long)
    for o in range(ord('a'), ord('z') + 1)
}
model = MyDictModel()
# One single-character key per letter; named `keys` so the builtin `input`
# is not shadowed.
keys = list('cheese')
print(model(keys, vocab))
# Serialize the scripted module to disk.
model.save("my_model.pt")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment