Created June 10, 2021 13:25
Can't figure out what input shape to pass to coremltools when converting a traced BERT model to Core ML.
import coremltools as ct
import numpy as np
import torch
import torchvision
from transformers import BertConfig, BertModel, BertTokenizer
# Trace a pretrained BERT encoder with TorchScript and convert the trace to
# Core ML.  The Core ML input shapes are simply the shapes of the example
# tensors used for tracing, so they are read from those tensors below.
enc = BertTokenizer.from_pretrained("bert-base-uncased")

# Tokenizing input text
text = "[CLS] Who was Jim Henson ? [SEP] Jim Henson was a puppeteer [SEP]"
tokenized_text = enc.tokenize(text)

# Masking one of the input tokens
masked_index = 8
tokenized_text[masked_index] = '[MASK]'
indexed_tokens = enc.convert_tokens_to_ids(tokenized_text)
# One segment id per token: 0 for the first sentence, 1 for the second.
segments_ids = [0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1]

# Creating a dummy input; each tensor gets a leading batch dimension of 1.
tokens_tensor = torch.tensor([indexed_tokens])
segments_tensors = torch.tensor([segments_ids])
dummy_input = [tokens_tensor, segments_tensors]

# Initializing the model with the torchscript flag
# Flag set to True even though it is not necessary as this model does not have an LM Head.
# `vocab_size` is the keyword BertConfig expects; the legacy
# `vocab_size_or_config_json_file` spelling would not set the vocabulary size.
config = BertConfig(
    vocab_size=32000,
    hidden_size=768,
    num_hidden_layers=12,
    num_attention_heads=12,
    intermediate_size=3072,
    torchscript=True,
)

# Instantiating the model
model = BertModel(config)

# If you are instantiating the model with `from_pretrained` you can also easily set the TorchScript flag
model = BertModel.from_pretrained("bert-base-uncased", torchscript=True)

# The model needs to be in evaluation mode so tracing records inference
# behaviour (e.g. dropout disabled).
model.eval()

# Creating the trace
# NOTE(review): the example tensors are bound positionally to
# BertModel.forward; in recent transformers releases the second positional
# parameter appears to be `attention_mask`, not `token_type_ids` -- confirm
# against the installed version before relying on the segment input.
traced_model = torch.jit.trace(model, dummy_input)

# Convert the traced PyTorch model to Core ML.
# One TensorType is needed per traced input, in the same order as the trace.
# The shape is the shape of the corresponding example tensor, and the dtype is
# an integer type because both inputs hold ids rather than float features.
mlmodel = ct.convert(
    traced_model,
    inputs=[
        ct.TensorType(name="input_ids", shape=tokens_tensor.shape, dtype=np.int32),
        ct.TensorType(name="segment_ids", shape=segments_tensors.shape, dtype=np.int32),
    ],
)
