Skip to content

Instantly share code, notes, and snippets.

@zitterbewegung
Created June 10, 2021 13:25
Show Gist options
  • Save zitterbewegung/589a869f59ae54b32c964c08c2f7bb80 to your computer and use it in GitHub Desktop.
Save zitterbewegung/589a869f59ae54b32c964c08c2f7bb80 to your computer and use it in GitHub Desktop.
Can't figure out the shape of the input to BERT.
from transformers import BertModel, BertTokenizer, BertConfig
import torch
import coremltools as ct
import torch
import torchvision
# Trace a pretrained BERT model to TorchScript and convert the trace to Core ML.
#
# The Core ML input shapes are exactly the shapes of the example tensors used
# for tracing, so they are read from those tensors instead of being guessed.
import numpy as np  # local to the script: only needed for the Core ML dtype

enc = BertTokenizer.from_pretrained("bert-base-uncased")

# Tokenizing input text
text = "[CLS] Who was Jim Henson ? [SEP] Jim Henson was a puppeteer [SEP]"
tokenized_text = enc.tokenize(text)

# Segment ids: 0 up to and including the first [SEP], 1 for everything after.
# Derived from the tokens so the list always matches the token count.
first_sep = tokenized_text.index("[SEP]")
segments_ids = [0] * (first_sep + 1) + [1] * (len(tokenized_text) - first_sep - 1)

# Masking one of the input tokens
masked_index = 8
tokenized_text[masked_index] = '[MASK]'
indexed_tokens = enc.convert_tokens_to_ids(tokenized_text)

# Creating a dummy input; both tensors have shape (1, number_of_tokens)
tokens_tensor = torch.tensor([indexed_tokens])
segments_tensors = torch.tensor([segments_ids])
dummy_input = [tokens_tensor, segments_tensors]

# Load the pretrained weights with the TorchScript flag set, and put the model
# that is actually traced into evaluation mode.
model = BertModel.from_pretrained("bert-base-uncased", torchscript=True)
model.eval()

# Creating the trace
# NOTE(review): the example tensors are passed positionally; which forward()
# parameter the second tensor binds to depends on the installed transformers
# version -- confirm it is the one you intend (token_type_ids vs attention_mask).
traced_model = torch.jit.trace(model, dummy_input)
torch.jit.save(traced_model, "traced_bert.pt")

# Convert the saved PyTorch model to Core ML.
# One TensorType per traced input, in the same order as `dummy_input`, each with
# the shape of the corresponding example tensor. The ids are integers, so an
# integer dtype is requested instead of the float32 default.
mlmodel = ct.convert(
    "traced_bert.pt",
    inputs=[
        ct.TensorType(name="tokens", shape=tuple(tokens_tensor.shape), dtype=np.int32),
        ct.TensorType(name="segments", shape=tuple(segments_tensors.shape), dtype=np.int32),
    ],
)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment