Created
March 6, 2021 09:44
-
-
Save Proteusiq/4a9d93f29c7042b757f6e2a502ce6322 to your computer and use it in GitHub Desktop.
original_logic.ipynb
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"nbformat": 4, | |
"nbformat_minor": 0, | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Hugging", | |
"language": "python", | |
"name": "huggingface" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.8.8" | |
}, | |
"colab": { | |
"name": "original_logic.ipynb", | |
"provenance": [], | |
"include_colab_link": true | |
} | |
}, | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "view-in-github", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"<a href=\"https://colab.research.google.com/gist/Proteusiq/4a9d93f29c7042b757f6e2a502ce6322/original_logic.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "dRVfaljk7CIH" | |
}, | |
"source": [ | |
"import torch\n", | |
"import torch.nn as nn\n", | |
"from transformers import AutoTokenizer\n", | |
"from transformers import AutoModelForSequenceClassification\n", | |
"\n", | |
"\n", | |
MODEL_NAME = "microsoft/deberta-xlarge-mnli"


class Model:
    """Holder for state shared by every instance ("Borg" pattern).

    The softmax layer, tokenizer, model and ``sentences`` list are created
    once, when this class body runs. Each instance then aliases that single
    dict as its own ``__dict__``, so reading or writing an attribute on any
    instance reads or writes the shared dict.
    """

    # Entries are evaluated in order: softmax, tokenizer, model, sentences.
    _shared_model = dict(
        softmax=nn.Softmax(dim=1),
        tokenizer=AutoTokenizer.from_pretrained(MODEL_NAME),
        model=AutoModelForSequenceClassification.from_pretrained(MODEL_NAME),
        sentences=[],
    )

    def __init__(self):
        # Rebind (do not copy) the instance namespace so state stays shared.
        self.__dict__ = self._shared_model
], | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "i0HKWkfc7CIM" | |
}, | |
"source": [ | |
class Logic(Model):
    """Sentence logic.

    Compare two sentences with an NLI (natural language inference) model to
    see whether one entails or contradicts the other.

    Usage:

    >>> driving = Logic("I am driving a car.")  # add a sentence
    >>> in_car = Logic("I am in a car.")  # add another sentence
    >>> driving.entails(in_car)  # check for entailment
    >>> driving > in_car  # check for entailment
    >>> driving != in_car  # check for contradiction
    >>> driving.contradicts(in_car, verbose=True)  # contradiction check with scores
    """

    # Model rebinds every instance's ``__dict__`` to one dict shared by all
    # instances, so an ordinary attribute would be shared as well. A slot is
    # stored outside ``__dict__``, which gives each Logic object its own sentence.
    __slots__ = ("sentence",)

    # Output order of the default MNLI head. Assumed to also hold for any
    # custom ``model`` passed in -- TODO confirm against model.config.id2label.
    _LABELS = ("contradicts", "neutral", "entails")

    def __init__(self, sentence, tokenizer=None, model=None):
        """Wrap ``sentence`` and optionally replace the shared tokenizer/model.

        Args:
            sentence: the text this instance represents.
            tokenizer: optional tokenizer to install in the shared state.
            model: optional sequence-classification model to install in the
                shared state (this affects every Logic instance).

        Raises:
            RuntimeError: if no model is available after initialisation.
        """
        super().__init__()
        self.sentence = sentence
        # Shared history of the two most recently created sentences. It is no
        # longer used for prediction but is kept for backward compatibility.
        self.sentences.append(sentence)
        self.sentences = self.sentences[-2:]
        # Use `is not None` rather than truthiness, because container modules
        # (e.g. an empty nn.ModuleList) are falsy. Only override what was
        # supplied, so passing a model alone does not wipe the tokenizer.
        if tokenizer is not None:
            self.tokenizer = tokenizer
        if model is not None:
            self.model = model
        if getattr(self, "model", None) is None:
            raise RuntimeError("no model to perform operations!")

    def __repr__(self):
        # Fall back to the class name if the model does not expose its source.
        name = getattr(self.model, "name_or_path", None) or type(self.model).__name__
        return f"{self.__class__.__name__}({self.sentence!r}, model={name})"

    def entails(self, other, threashold=0.7, verbose=False):
        """Return whether this sentence entails ``other``.

        Args:
            other: hypothesis, either another Logic instance or a plain str.
            threashold: minimum "entails" probability required for True.
            verbose: if True, return ``{"results": bool, "explanations": scores}``.
        """
        scores = self._predict(self.sentence, self._text_of(other))
        return self._post_predict(scores, "entails", threashold, verbose)

    def contradicts(self, other, threashold=0.7, verbose=False):
        """Return whether this sentence contradicts ``other``.

        Args:
            other: hypothesis, either another Logic instance or a plain str.
            threashold: minimum "contradicts" probability required for True.
            verbose: if True, return ``{"results": bool, "explanations": scores}``.
        """
        scores = self._predict(self.sentence, self._text_of(other))
        return self._post_predict(scores, "contradicts", threashold, verbose)

    @staticmethod
    def _text_of(other):
        """Return the sentence text held by ``other`` (a Logic or a str)."""
        if isinstance(other, Logic):
            return other.sentence
        if isinstance(other, str):
            return other
        raise TypeError(f"expected Logic or str, got {type(other).__name__}")

    def _predict(self, sentence, other_sentence):
        """Return label -> probability for premise ``sentence`` and hypothesis ``other_sentence``."""
        # Encode as a text pair, so the tokenizer inserts the model's own
        # separator and special tokens instead of a hand-written "[SEP]".
        inputs = self.tokenizer(sentence, other_sentence, return_tensors="pt")
        # Inference only: pass no labels (no loss) and build no autograd graph.
        with torch.no_grad():
            logits = self.model(**inputs).logits
        probabilities = self.softmax(logits)[0].numpy()
        return dict(zip(self._LABELS, probabilities))

    @staticmethod
    def _post_predict(scores, task, threashold, verbose):
        """Threshold ``scores[task]``, optionally returning the raw scores too."""
        results = bool(scores[task] > threashold)
        if verbose:
            return {"results": results, "explanations": scores}
        return results

    # Operator shortcuts: `a != b` means a contradicts b, while `a > b`,
    # `a < b` and `a <= b` all mean a entails b.
    __ne__ = contradicts
    __lt__ = entails
    __gt__ = entails
    __le__ = entails
], | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "AWkpvfxK7CIO", | |
"outputId": "79519910-a00d-4574-c24f-8cddcafe2801" | |
}, | |
"source": [ | |
# NOTE: the names are swapped relative to the text: `in_car` holds the
# "driving" sentence and `driving` holds the "in a car" sentence.
in_car = Logic("I am driving a Ford C-Max.")
driving = Logic("I am in a car.")

# Does the first sentence entail the second?
in_car.entails(driving)
], | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"True" | |
] | |
}, | |
"metadata": { | |
"tags": [] | |
}, | |
"execution_count": 3 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "mP8lkzm_7CIO", | |
"outputId": "7c69c679-a323-40e9-8360-952b3c76facf" | |
}, | |
"source": [ | |
in_car > driving  # `>` (like `<` and `<=`) is an alias of Logic.entails
], | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"True" | |
] | |
}, | |
"metadata": { | |
"tags": [] | |
}, | |
"execution_count": 4 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "hNRoAc0F7CIP", | |
"outputId": "a81aa8f5-fe34-4171-a9f6-44aca9953d59" | |
}, | |
"source": [ | |
in_car.contradicts(driving, verbose=True)  # verbose=True also returns the per-label scores under "explanations"
], | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"{'results': False,\n", | |
" 'explanations': {'contradicts': 0.0013153073,\n", | |
" 'neutral': 0.02768614,\n", | |
" 'entails': 0.9709986}}" | |
] | |
}, | |
"metadata": { | |
"tags": [] | |
}, | |
"execution_count": 5 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "aB6Zie1E7CIP", | |
"outputId": "7b262a3d-9b9b-4747-e990-e2a4dd1e3aa2" | |
}, | |
"source": [ | |
black_cat = Logic("That cat is black.")
white_cat = Logic("The cat's color is white.")

# verbose=True returns {"results": <bool>, "explanations": <per-label scores>}.
black_cat.entails(white_cat, verbose=True)
], | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"{'results': False,\n", | |
" 'explanations': {'contradicts': 0.99652416,\n", | |
" 'neutral': 0.001198517,\n", | |
" 'entails': 0.0022773156}}" | |
] | |
}, | |
"metadata": { | |
"tags": [] | |
}, | |
"execution_count": 6 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "9Zvl7Qgr7CIQ", | |
"outputId": "db7ca20a-6391-4a4b-b5ba-70af0c07bb1b" | |
}, | |
"source": [ | |
black_cat != white_cat  # `!=` is an alias of Logic.contradicts
], | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"True" | |
] | |
}, | |
"metadata": { | |
"tags": [] | |
}, | |
"execution_count": 7 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "AbqgHvrW7CIQ" | |
}, | |
"source": [ | |
"" | |
], | |
"execution_count": null, | |
"outputs": [] | |
} | |
] | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment