Created
June 10, 2021 13:25
-
-
Save zitterbewegung/589a869f59ae54b32c964c08c2f7bb80 to your computer and use it in GitHub Desktop.
Can't figure out the shape of an input to BERT.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import coremltools as ct
import numpy as np
import torch
import torchvision
from transformers import BertModel, BertTokenizer, BertConfig
# Trace a pretrained BERT encoder with TorchScript and convert the saved trace
# to Core ML. The Core ML input shapes are taken directly from the example
# tensors used for tracing, so they always match what the trace was recorded
# with.

enc = BertTokenizer.from_pretrained("bert-base-uncased")

# Tokenizing input text
text = "[CLS] Who was Jim Henson ? [SEP] Jim Henson was a puppeteer [SEP]"
tokenized_text = enc.tokenize(text)

# Masking one of the input tokens
masked_index = 8
tokenized_text[masked_index] = "[MASK]"
indexed_tokens = enc.convert_tokens_to_ids(tokenized_text)
# One segment id per token: 0 for the first sentence, 1 for the second.
segments_ids = [0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1]

# Creating a dummy input: each tensor has shape (1, number_of_tokens).
tokens_tensor = torch.tensor([indexed_tokens])
segments_tensors = torch.tensor([segments_ids])
dummy_input = [tokens_tensor, segments_tensors]

# Load the pretrained weights with the TorchScript flag set so the model
# can be traced.
model = BertModel.from_pretrained("bert-base-uncased", torchscript=True)
# The model that is actually traced needs to be in evaluation mode.
model.eval()

# Creating the trace.
# NOTE(review): the tensors are passed positionally; BertModel.forward appears
# to take attention_mask as its second positional argument, so
# segments_tensors may be consumed as the mask rather than as token_type_ids
# -- confirm against the transformers version in use.
traced_model = torch.jit.trace(model, dummy_input)
torch.jit.save(traced_model, "traced_bert.pt")

# Convert the saved PyTorch model to Core ML.
# `inputs` needs one TensorType per positional input of the traced model, in
# the same order that was used for tracing. The shape is simply the shape of
# the corresponding example tensor, and the dtype is int32 because the inputs
# are integer ids rather than floats (the TensorType default).
mlmodel = ct.convert(
    "traced_bert.pt",
    inputs=[
        ct.TensorType(name="tokens", shape=tokens_tensor.shape, dtype=np.int32),
        ct.TensorType(name="segments", shape=segments_tensors.shape, dtype=np.int32),
    ],
)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment