The accompanying gist for the blog post A Dead Simple Example of Fine-tuning BERT with AllenNLP.
import logging
from typing import Dict

from allennlp.data.dataset_readers import SnliReader
from allennlp.data.dataset_readers.dataset_reader import DatasetReader
from allennlp.data.fields import Field
from allennlp.data.fields import LabelField
from allennlp.data.fields import TextField
from allennlp.data.instance import Instance
from allennlp.data.token_indexers import TokenIndexer
from allennlp.data.tokenizers import Tokenizer
from overrides import overrides

logger = logging.getLogger(__name__)


@DatasetReader.register("bert_snli")
class BertSnliReader(SnliReader):
    """
    Reads a file from the Stanford Natural Language Inference (SNLI) dataset. This data is
    formatted as jsonl, one json-formatted instance per line. The keys in the data are
    "gold_label", "sentence1", and "sentence2". We convert these keys into fields named
    "label" and "tokens".

    Parameters
    ----------
    tokenizer : ``Tokenizer``, optional (default=``SpacyTokenizer()``)
        We use this ``Tokenizer`` for both the premise and the hypothesis. See :class:`Tokenizer`.
    token_indexers : ``Dict[str, TokenIndexer]``, optional (default=``{"tokens": SingleIdTokenIndexer()}``)
        We similarly use this for both the premise and the hypothesis. See :class:`TokenIndexer`.
    """

    def __init__(
        self,
        tokenizer: Tokenizer = None,
        token_indexers: Dict[str, TokenIndexer] = None,
        lazy: bool = False,
    ) -> None:
        super().__init__(tokenizer, token_indexers, lazy)

    @overrides
    def text_to_instance(
        self,  # type: ignore
        premise: str,
        hypothesis: str,
        label: str = None,
    ) -> Instance:
        fields: Dict[str, Field] = {}
        premise_tokens = self._tokenizer.tokenize(premise)
        hypothesis_tokens = self._tokenizer.tokenize(hypothesis)
        # Join the premise with the hypothesis, dropping the leading [CLS] token from the
        # hypothesis. This gives us our desired input: "[CLS] premise [SEP] hypothesis [SEP]"
        tokens = premise_tokens + hypothesis_tokens[1:]
        fields["tokens"] = TextField(tokens, self._token_indexers)
        if label:
            fields["label"] = LabelField(label)
        return Instance(fields)
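
To make the effect of text_to_instance concrete, here is a minimal sketch in plain Python (no AllenNLP required) of how the premise and hypothesis are packed into a single BERT-style sequence. The token lists are illustrative stand-ins, assuming the tokenizer wraps each sentence in [CLS] ... [SEP]:

# Illustrative stand-ins for what self._tokenizer.tokenize() might return,
# assuming the tokenizer adds [CLS] and [SEP] around each sentence.
premise_tokens = ["[CLS]", "a", "man", "is", "sleeping", "[SEP]"]
hypothesis_tokens = ["[CLS]", "a", "person", "is", "asleep", "[SEP]"]

# Dropping the hypothesis's leading [CLS] yields BERT's sentence-pair format.
tokens = premise_tokens + hypothesis_tokens[1:]
print(tokens)
# ['[CLS]', 'a', 'man', 'is', 'sleeping', '[SEP]', 'a', 'person', 'is', 'asleep', '[SEP]']

The Jsonnet configuration below wires this reader into an AllenNLP training run: the pretrained_transformer tokenizer produces the WordPiece tokens, the bert-pretrained indexer maps them to BERT's vocabulary ids, and bert_for_classification fine-tunes the model on the three SNLI labels.
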
local bert_model = "bert-base-uncased";

{
    "dataset_reader": {
        "lazy": false,
        "type": "bert_snli",
        "tokenizer": {
            "type": "pretrained_transformer",
            "model_name": bert_model,
            "do_lowercase": true
        },
        "token_indexers": {
            "bert": {
                "type": "bert-pretrained",
                "pretrained_model": bert_model,
            }
        }
    },
    "train_data_path": "snli_1.0/snli_1.0_train.jsonl",
    "validation_data_path": "snli_1.0/snli_1.0_dev.jsonl",
    "model": {
        "type": "bert_for_classification",
        "bert_model": bert_model,
        "dropout": 0.1,
        "num_labels": 3,
    },
    "iterator": {
        "type": "bucket",
        "sorting_keys": [["tokens", "num_tokens"]],
        "batch_size": 32
    },
    "trainer": {
        "optimizer": {
            "type": "bert_adam",
            "lr": 2e-5
        },
        "validation_metric": "+accuracy",
        "num_serialized_models_to_keep": 1,
        "num_epochs": 4,
        "grad_norm": 1.0,
        "cuda_device": 0
    }
}
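
Assuming the reader above lives in an importable module and the config in a local file (the names bert_snli and bert_snli.jsonnet below are placeholders, not part of the original gist), a quick sanity check is to build the reader from the config and convert a single example:

from allennlp.common.params import Params
from allennlp.data.dataset_readers import DatasetReader

import bert_snli  # noqa: F401  # placeholder module name; importing it registers the "bert_snli" reader

params = Params.from_file("bert_snli.jsonnet")  # placeholder path to the config above
reader = DatasetReader.from_params(params.pop("dataset_reader"))

instance = reader.text_to_instance(
    premise="A man is sleeping on a couch.",
    hypothesis="A person is asleep.",
    label="entailment",
)
print(instance.fields["tokens"].tokens)

Training itself would then be launched with AllenNLP's train command, along the lines of allennlp train bert_snli.jsonnet -s ./output --include-package bert_snli, where the config path, serialization directory, and package name are again placeholders.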