import torch
import torch.nn as nn
from transformers import AutoTokenizer
from transformers import AutoModelForSequenceClassification
MODEL_NAME = "microsoft/deberta-xlarge-mnli"


class Model:
    # Borg-style shared state: every instance shares the same tokenizer,
    # model, softmax layer, and sentence buffer
    _shared_model = {
        "softmax": nn.Softmax(dim=1),
        "tokenizer": AutoTokenizer.from_pretrained(MODEL_NAME),
        "model": AutoModelForSequenceClassification.from_pretrained(MODEL_NAME),
        "sentences": [],
    }

    def __init__(self):
        self.__dict__ = self._shared_model


class Logic(Model):
    """Sentence Logic

    Compare two sentences to see if they entail or contradict each other.

    Usage:
        >>> driving = Logic("I am driving a car.")  # add a sentence
        >>> in_car = Logic("I am not in a car.")  # add another sentence
        >>> driving.entails(in_car)  # check for entailment
        >>> driving > in_car  # check for entailment
        >>> driving != in_car  # check for contradiction
        >>> driving.contradicts(in_car, verbose=True)  # check with extra info
    """

    def __init__(self, sentence, tokenizer=None, model=None):
        super().__init__()
        self.sentences.append(sentence)
        self.sentences = self.sentences[-2:]  # keeps only two sentences
        if model:
            self.tokenizer = tokenizer
            self.model = model
        else:
            # initiate the first instance with default model
            if not hasattr(self, "model"):
                raise RuntimeError("no model to perform operations!")

    def __repr__(self):
        return f"{self.__class__.__name__}(model={self.model.config.name_or_path})"

    def entails(self, other, threshold=0.7, verbose=False):
        sentence, other = self.sentences  # overriding other with the stored pair
        scores = self._predict(sentence, other)
        return self._post_predict(scores, "entails", threshold, verbose)

    def contradicts(self, other, threshold=0.7, verbose=False):
        sentence, other = self.sentences  # overriding other with the stored pair
        scores = self._predict(sentence, other)
        return self._post_predict(scores, "contradicts", threshold, verbose)

    def _predict(self, sentence, other_sentence):
        # encode the sentence pair; the tokenizer adds the separator token itself
        inputs = self.tokenizer(sentence, other_sentence, return_tensors="pt")
        with torch.no_grad():  # inference only, no gradients needed
            outputs = self.model(**inputs)
        predictions = self.softmax(outputs.logits)
        return {
            label: score
            for label, score in zip(
                ["contradicts", "neutral", "entails"], predictions.numpy()[0]
            )
        }

    @staticmethod
    def _post_predict(scores, task, threshold, verbose):
        results = scores[task] > threshold
        if verbose:
            return {"results": results, "explanations": scores}
        return results

    # some other methods we can use
    __ne__ = contradicts
    __lt__ = entails
    __gt__ = entails
    __le__ = entails
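

# A minimal usage sketch mirroring the docstring examples above: loading
# microsoft/deberta-xlarge-mnli downloads several GB of weights, so the first
# run needs network access and memory; the printed results depend on the model
# scores and are not guaranteed here.
if __name__ == "__main__":
    driving = Logic("I am driving a car.")  # premise
    in_car = Logic("I am not in a car.")  # hypothesis

    print(driving.entails(in_car))  # True if the entailment score > 0.7
    print(driving != in_car)  # operator alias for contradicts()
    print(driving.contradicts(in_car, verbose=True))  # result plus raw scores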