Skip to content

Instantly share code, notes, and snippets.

@Proteusiq
Created March 6, 2021 09:44
Show Gist options
  • Save Proteusiq/4a9d93f29c7042b757f6e2a502ce6322 to your computer and use it in GitHub Desktop.
Save Proteusiq/4a9d93f29c7042b757f6e2a502ce6322 to your computer and use it in GitHub Desktop.
original_logic.ipynb
Display the source blob
Display the rendered blob
Raw
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"kernelspec": {
"display_name": "Hugging",
"language": "python",
"name": "huggingface"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.8"
},
"colab": {
"name": "original_logic.ipynb",
"provenance": [],
"include_colab_link": true
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/Proteusiq/4a9d93f29c7042b757f6e2a502ce6322/original_logic.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "code",
"metadata": {
"id": "dRVfaljk7CIH"
},
"source": [
"import torch\n",
"import torch.nn as nn\n",
"from transformers import AutoTokenizer\n",
"from transformers import AutoModelForSequenceClassification\n",
"\n",
"\n",
"MODEL_NAME = \"microsoft/deberta-xlarge-mnli\"\n",
"\n",
"\n",
"class Model:\n",
"    \"\"\"Borg-style holder: every instance shares one attribute dict.\n",
"\n",
"    The softmax layer, tokenizer and classification model are created once,\n",
"    when this class body runs, and kept in ``_shared_model``.\n",
"    \"\"\"\n",
"\n",
"    # Keyword form of the same mapping; entries are built top to bottom.\n",
"    _shared_model = dict(\n",
"        softmax=nn.Softmax(dim=1),\n",
"        tokenizer=AutoTokenizer.from_pretrained(MODEL_NAME),\n",
"        model=AutoModelForSequenceClassification.from_pretrained(MODEL_NAME),\n",
"        sentences=[],\n",
"    )\n",
"\n",
"    def __init__(self):\n",
"        # Point the instance namespace at the class-wide dict, so attribute\n",
"        # reads and writes on one instance are seen by all of them.\n",
"        shared_state = self._shared_model\n",
"        self.__dict__ = shared_state"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "i0HKWkfc7CIM"
},
"source": [
"class Logic(Model):\n",
"    \"\"\"Sentence logic: natural-language inference between two sentences.\n",
"\n",
"    Each instance wraps one sentence (the premise). Comparing it with another\n",
"    instance, or with a plain string, asks the shared model whether the\n",
"    premise entails or contradicts that other sentence (the hypothesis).\n",
"\n",
"    Usage:\n",
"\n",
"    >>> driving = Logic(\"I am driving a car.\")  # wrap a sentence\n",
"    >>> not_in_car = Logic(\"I am not in a car.\")  # wrap another sentence\n",
"    >>> driving.entails(not_in_car)  # check for entailment\n",
"    >>> driving > not_in_car  # operator form of entails()\n",
"    >>> driving != not_in_car  # operator form of contradicts()\n",
"    >>> driving.contradicts(not_in_car, verbose=True)  # also return the scores\n",
"\n",
"    Note: ``>``, ``<`` and ``<=`` all call ``entails`` with the left-hand\n",
"    operand as premise; the direction is never reversed.\n",
"    \"\"\"\n",
"\n",
"    # Attributes served from the class-wide ``_shared_model`` dict.\n",
"    _SHARED_NAMES = (\"softmax\", \"tokenizer\", \"model\")\n",
"    # Label given to each output probability, in index order.\n",
"    _LABELS = (\"contradicts\", \"neutral\", \"entails\")\n",
"\n",
"    def __init__(self, sentence, tokenizer=None, model=None):\n",
"        \"\"\"Wrap ``sentence``; optionally replace the shared tokenizer/model.\n",
"\n",
"        Args:\n",
"            sentence: text used as the premise in comparisons.\n",
"            tokenizer: if given, becomes the tokenizer used by all instances.\n",
"            model: if given, becomes the model used by all instances.\n",
"\n",
"        Raises:\n",
"            RuntimeError: if no model or tokenizer is available.\n",
"        \"\"\"\n",
"        # Model.__init__ is deliberately not called: it rebinds __dict__ to the\n",
"        # class-wide dict, which would make every instance share one sentence\n",
"        # and let the most recently created pair decide every comparison.\n",
"        self.sentence = sentence\n",
"        shared = type(self)._shared_model\n",
"        # Replace only what was supplied, so passing a model without a\n",
"        # tokenizer no longer wipes the default tokenizer.\n",
"        if tokenizer is not None:\n",
"            shared[\"tokenizer\"] = tokenizer\n",
"        if model is not None:\n",
"            shared[\"model\"] = model\n",
"        if shared.get(\"model\") is None or shared.get(\"tokenizer\") is None:\n",
"            raise RuntimeError(\"no model to perform operations!\")\n",
"\n",
"    def __getattr__(self, name):\n",
"        # Only reached when normal lookup fails: expose the shared softmax,\n",
"        # tokenizer and model as if they were instance attributes.\n",
"        shared = type(self)._shared_model\n",
"        if name in type(self)._SHARED_NAMES and name in shared:\n",
"            return shared[name]\n",
"        raise AttributeError(\n",
"            f\"{type(self).__name__!r} object has no attribute {name!r}\"\n",
"        )\n",
"\n",
"    def __repr__(self):\n",
"        model = self.model\n",
"        config = getattr(model, \"config\", None)\n",
"        # Use the checkpoint name when the model's config carries one\n",
"        # (assumed to be ``name_or_path``); otherwise the model class name.\n",
"        model_name = getattr(config, \"name_or_path\", None) or type(model).__name__\n",
"        return f\"{type(self).__name__}({self.sentence!r}, model={model_name})\"\n",
"\n",
"    def entails(self, other, threashold=0.7, verbose=False):\n",
"        \"\"\"Return whether this sentence entails ``other``.\n",
"\n",
"        Args:\n",
"            other: a Logic instance or a plain string (the hypothesis).\n",
"            threashold: score the \"entails\" probability must exceed\n",
"                (spelling kept for backward compatibility).\n",
"            verbose: if True, return a dict with \"results\" and\n",
"                \"explanations\" (all three scores) instead of a bare bool.\n",
"        \"\"\"\n",
"        scores = self._predict(self.sentence, self._text_of(other))\n",
"        return self._post_predict(scores, \"entails\", threashold, verbose)\n",
"\n",
"    def contradicts(self, other, threashold=0.7, verbose=False):\n",
"        \"\"\"Return whether this sentence contradicts ``other``.\n",
"\n",
"        Takes the same arguments as ``entails``; the decision is made on the\n",
"        \"contradicts\" probability instead.\n",
"        \"\"\"\n",
"        scores = self._predict(self.sentence, self._text_of(other))\n",
"        return self._post_predict(scores, \"contradicts\", threashold, verbose)\n",
"\n",
"    @staticmethod\n",
"    def _text_of(other):\n",
"        # Accept either a wrapped sentence or a raw string.\n",
"        return other.sentence if isinstance(other, Logic) else other\n",
"\n",
"    def _predict(self, sentence, other_sentence):\n",
"        \"\"\"Score the (premise, hypothesis) pair; returns label -> probability.\"\"\"\n",
"        # Encode as a text pair so the tokenizer inserts its own separator,\n",
"        # rather than relying on a literal \"[SEP]\" typed into the text.\n",
"        inputs = self.tokenizer(sentence, other_sentence, return_tensors=\"pt\")\n",
"        # Inference only: no labels (no loss needed) and no autograd graph.\n",
"        with torch.no_grad():\n",
"            logits = self.model(**inputs).logits\n",
"        probabilities = self.softmax(logits)[0].numpy()\n",
"        return dict(zip(self._LABELS, probabilities))\n",
"\n",
"    @staticmethod\n",
"    def _post_predict(scores, task, threashold, verbose):\n",
"        \"\"\"Apply the threshold to ``scores[task]`` and shape the result.\"\"\"\n",
"        results = bool(scores[task] > threashold)\n",
"        if verbose:\n",
"            return {\"results\": results, \"explanations\": scores}\n",
"        return results\n",
"\n",
"    def __gt__(self, other):\n",
"        # Let Python fall back to its defaults for unrelated types.\n",
"        if not isinstance(other, (Logic, str)):\n",
"            return NotImplemented\n",
"        return self.entails(other)\n",
"\n",
"    def __ne__(self, other):\n",
"        if not isinstance(other, (Logic, str)):\n",
"            return NotImplemented\n",
"        return self.contradicts(other)\n",
"\n",
"    # Kept from the original API: these also mean \"self entails other\".\n",
"    __lt__ = __le__ = __gt__"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "AWkpvfxK7CIO",
"outputId": "79519910-a00d-4574-c24f-8cddcafe2801"
},
"source": [
"in_car = Logic(\"I am driving a Ford C-Max.\")\n",
"driving = Logic(\"I am in a car.\")\n",
"\n",
"in_car.entails(driving)"
],
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"True"
]
},
"metadata": {
"tags": []
},
"execution_count": 3
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "mP8lkzm_7CIO",
"outputId": "7c69c679-a323-40e9-8360-952b3c76facf"
},
"source": [
"in_car > driving  # `>`, `<` and `<=` are all aliases of entails()"
],
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"True"
]
},
"metadata": {
"tags": []
},
"execution_count": 4
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "hNRoAc0F7CIP",
"outputId": "a81aa8f5-fe34-4171-a9f6-44aca9953d59"
},
"source": [
"in_car.contradicts(driving, verbose=True)"
],
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"{'results': False,\n",
" 'explanations': {'contradicts': 0.0013153073,\n",
" 'neutral': 0.02768614,\n",
" 'entails': 0.9709986}}"
]
},
"metadata": {
"tags": []
},
"execution_count": 5
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "aB6Zie1E7CIP",
"outputId": "7b262a3d-9b9b-4747-e990-e2a4dd1e3aa2"
},
"source": [
"black_cat = Logic(\"That cat is black.\")\n",
"white_cat = Logic(\"The cat's color is white.\")\n",
"\n",
"black_cat.entails(white_cat, verbose=True)"
],
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"{'results': False,\n",
" 'explanations': {'contradicts': 0.99652416,\n",
" 'neutral': 0.001198517,\n",
" 'entails': 0.0022773156}}"
]
},
"metadata": {
"tags": []
},
"execution_count": 6
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "9Zvl7Qgr7CIQ",
"outputId": "db7ca20a-6391-4a4b-b5ba-70af0c07bb1b"
},
"source": [
"black_cat != white_cat"
],
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"True"
]
},
"metadata": {
"tags": []
},
"execution_count": 7
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "AbqgHvrW7CIQ"
},
"source": [
""
],
"execution_count": null,
"outputs": []
}
]
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment