@icoxfog417
Created April 8, 2019 01:32
allennlp_tutorial.py
from typing import Dict

import torch

from allennlp.data.vocabulary import Vocabulary
from allennlp.models import Model
from allennlp.modules.seq2seq_encoders import Seq2SeqEncoder
from allennlp.modules.text_field_embedders import TextFieldEmbedder
from allennlp.nn.util import get_text_field_mask, sequence_cross_entropy_with_logits
from allennlp.training.metrics import CategoricalAccuracy


class LstmTagger(Model):
    def __init__(self,
                 word_embeddings: TextFieldEmbedder,
                 encoder: Seq2SeqEncoder,
                 vocab: Vocabulary) -> None:
        super().__init__(vocab)
        self.word_embeddings = word_embeddings
        self.encoder = encoder
        # Project the encoder output onto the tag vocabulary ('labels' namespace).
        self.hidden2tag = torch.nn.Linear(in_features=encoder.get_output_dim(),
                                          out_features=vocab.get_vocab_size('labels'))
        self.accuracy = CategoricalAccuracy()

    def forward(self,
                sentence: Dict[str, torch.Tensor],
                labels: torch.Tensor = None) -> Dict[str, torch.Tensor]:
        # Mask out padding tokens so they do not affect the loss or metrics.
        mask = get_text_field_mask(sentence)
        embeddings = self.word_embeddings(sentence)
        encoder_out = self.encoder(embeddings, mask)
        tag_logits = self.hidden2tag(encoder_out)
        output = {"tag_logits": tag_logits}
        if labels is not None:
            self.accuracy(tag_logits, labels, mask)
            output["loss"] = sequence_cross_entropy_with_logits(tag_logits, labels, mask)
        return output

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        return {"accuracy": self.accuracy.get_metric(reset)}