Created
March 6, 2021 09:43
-
-
Save Proteusiq/bd081d9252967b3b0cf9e389d80db30a to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
import torch.nn as nn | |
from transformers import AutoTokenizer | |
from transformers import AutoModelForSequenceClassification | |
# Hugging Face Hub checkpoint used for the default tokenizer and classifier.
MODEL_NAME = "microsoft/deberta-xlarge-mnli"
class Model:
    """Shared-state (Borg) base class holding the NLI model and helpers.

    Every instance gets the very same attribute dictionary, so the
    tokenizer, the classifier, the softmax layer and the ``sentences`` list
    are common to all instances (and to instances of subclasses).
    """

    # Built once, when the class body executes: the tokenizer and the model
    # for MODEL_NAME are loaded at that point, not lazily on first use.
    _shared_model = dict(
        softmax=nn.Softmax(dim=1),
        tokenizer=AutoTokenizer.from_pretrained(MODEL_NAME),
        model=AutoModelForSequenceClassification.from_pretrained(MODEL_NAME),
        sentences=[],
    )

    def __init__(self):
        # Rebind the instance namespace to the shared dict; any attribute
        # assigned on one instance is therefore visible on all of them.
        self.__dict__ = self._shared_model
class Logic(Model):
    """Sentence logic: check whether one sentence entails or contradicts another.

    Each instance wraps one sentence. Comparisons treat ``self`` as the
    premise and ``other`` as the hypothesis.

    Usage:
        >>> driving = Logic("I am driving a car.")      # add a sentence
        >>> in_car = Logic("I am not in a car.")        # add another sentence
        >>> driving.entails(in_car)                     # check for entailment
        >>> driving > in_car                            # check for entailment
        >>> driving != in_car                           # check for contradiction
        >>> driving.contradicts(in_car, verbose=True)   # check with extra info
    """

    # ``Model.__init__`` points every instance's ``__dict__`` at one shared
    # dict, so ordinary attributes are common to all instances. A slot lives
    # outside ``__dict__`` and gives each instance its own sentence.
    __slots__ = ("_sentence",)

    def __init__(self, sentence, tokenizer=None, model=None):
        """Wrap *sentence* and optionally replace the shared model.

        Args:
            sentence: The text this instance represents.
            tokenizer: Tokenizer matching ``model``. Only used when ``model``
                is given; if omitted, the current shared tokenizer is kept.
            model: Replacement sequence-classification model. Because state
                is shared, it replaces the model for every instance.

        Raises:
            RuntimeError: If no model is available.
        """
        super().__init__()
        self._sentence = sentence

        # Kept for backward compatibility: shared list of the two most
        # recently created sentences.
        self.sentences.append(sentence)
        self.sentences = self.sentences[-2:]

        if model is not None:
            self.model = model
            # Do not clobber the shared tokenizer with None.
            if tokenizer is not None:
                self.tokenizer = tokenizer
        elif getattr(self, "model", None) is None:
            raise RuntimeError("no model to perform operations!")

    def __repr__(self):
        # Prefer the checkpoint name when the model exposes one; otherwise
        # fall back to the model's class name.
        model_name = (
            getattr(self.model, "name_or_path", None)
            or type(self.model).__name__
        )
        return f"{self.__class__.__name__}(model={model_name})"

    @staticmethod
    def _as_sentence(other):
        """Return the text of *other*, which may be a ``Logic`` or a ``str``."""
        if isinstance(other, Logic):
            return other._sentence
        if isinstance(other, str):
            return other
        raise TypeError(
            f"expected a Logic instance or str, got {type(other).__name__}"
        )

    def entails(self, other, threashold=0.7, verbose=False):
        """Check whether this sentence entails *other*.

        Args:
            other: A ``Logic`` instance or a plain string (the hypothesis).
            threashold: Minimum "entails" probability for a positive result.
            verbose: If true, return a dict with the result and all scores.

        Returns:
            The truth value of ``score > threashold``, or, when ``verbose``,
            ``{"results": ..., "explanations": scores}``.

        Raises:
            TypeError: If *other* is neither a ``Logic`` nor a ``str``.
        """
        scores = self._predict(self._sentence, self._as_sentence(other))
        return self._post_predict(scores, "entails", threashold, verbose)

    def contradicts(self, other, threashold=0.7, verbose=False):
        """Check whether this sentence contradicts *other*.

        Args:
            other: A ``Logic`` instance or a plain string (the hypothesis).
            threashold: Minimum "contradicts" probability for a positive
                result.
            verbose: If true, return a dict with the result and all scores.

        Returns:
            The truth value of ``score > threashold``, or, when ``verbose``,
            ``{"results": ..., "explanations": scores}``.

        Raises:
            TypeError: If *other* is neither a ``Logic`` nor a ``str``.
        """
        scores = self._predict(self._sentence, self._as_sentence(other))
        return self._post_predict(scores, "contradicts", threashold, verbose)

    def _predict(self, sentence, other_sentence):
        """Score the (premise, hypothesis) pair and return label probabilities.

        Returns:
            Dict mapping "contradicts", "neutral" and "entails" to the
            softmax probability of each class.
        """
        # Let the tokenizer build the pair encoding so the correct separator
        # tokens are used for whichever tokenizer is configured.
        inputs = self.tokenizer(sentence, other_sentence, return_tensors="pt")

        # Inference only: no labels (no loss needed) and no autograd graph.
        with torch.no_grad():
            outputs = self.model(**inputs)
            predictions = self.softmax(outputs.logits)

        # NOTE(review): label order assumes the MNLI head of the default
        # checkpoint (0=contradiction, 1=neutral, 2=entailment); confirm for
        # any custom model passed to __init__.
        return {
            label: score
            for label, score in zip(
                ["contradicts", "neutral", "entails"],
                predictions.detach().numpy()[0],
            )
        }

    @staticmethod
    def _post_predict(scores, task, threashold, verbose):
        """Threshold the score for *task*; optionally attach all scores."""
        results = scores[task] > threashold
        if verbose:
            return {"results": results, "explanations": scores}
        return results

    # Operator shortcuts. All ordering operators test "self entails other".
    __ne__ = contradicts
    __lt__ = entails
    __gt__ = entails
    __le__ = entails
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment