@artreven
Last active October 12, 2021 08:24
TSV.ipynb
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "TSV.ipynb",
"provenance": [],
"collapsed_sections": [],
"authorship_tag": "ABX9TyPXgApcuZuQ31Fqkc3QFGqN",
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/artreven/ef532e2a71f3ff4cb404cfeb541f8657/untitled0.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Fht9KB7dygdi"
},
"source": [
"# Training\n",
"\n",
"First clone the repo, install the requirements, and train the model. If you are running the commands below in bash rather than in a notebook, remove the leading `!` and `%` symbols and use `export` instead of `%env` in lines 4 and 6."
]
},
{
"cell_type": "code",
"metadata": {
"id": "96QABfZeqYU9"
},
"source": [
"%cd /content/\n",
"!git clone https://github.com/semantic-web-company/wic-tsv\n",
"%cd wic-tsv/\n",
"%env PYTHONPATH=/content/wic-tsv/\n",
"!pip install -r requirements.txt\n",
"%env HYPERBERT_PATH=HyperBert/eval/\n",
"!python3 HyperBert/HyperBert3.py --dataset_path ./data --model_output_path $HYPERBERT_PATH --model_name bert-base-uncased"
],
"execution_count": null,
"outputs": []
},
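{
"cell_type": "markdown",
"metadata": {},
"source": [
"After training finishes, the fine-tuned model is written to the directory passed as `--model_output_path`. As a quick sanity check you can list that directory; assuming the script saves in the standard `transformers` format (an assumption, check `HyperBert/HyperBert3.py` for the exact files it writes), it should contain at least a `config.json` and a weights file:"
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"import os\n",
"\n",
"# List the files written by training; a `transformers`-style save\n",
"# typically includes `config.json` and a weights file.\n",
"print(sorted(os.listdir(os.environ['HYPERBERT_PATH'])))"
],
"execution_count": null,
"outputs": []
},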
{
"cell_type": "markdown",
"metadata": {
"id": "0Klu8gtly6Wo"
},
"source": [
"# Use the model\n",
"Now you have HyperBert3 fine-tuned on WiC-TSV. The next step is to load the model and run it on new examples."
]
},
{
"cell_type": "code",
"metadata": {
"id": "AQndz-28qrbb"
},
"source": [
"import os\n",
"from pathlib import Path\n",
"\n",
"import torch\n",
"from transformers import AutoTokenizer\n",
"\n",
"from HyperBert.HyperBert3 import HyperBert3\n",
"from model_evaluation.wictsv_dataset import WiCTSVDataset\n",
"\n",
"# Load the fine-tuned model saved during training\n",
"model_path = Path(os.getenv('HYPERBERT_PATH'))\n",
"model = HyperBert3.from_pretrained(model_path)\n",
"model.eval()\n",
"tok = AutoTokenizer.from_pretrained('bert-base-uncased')\n",
"\n",
"# Two contexts containing the target word 'jaguar': the animal sense\n",
"# and the car brand. The target index is the 0-based position of the\n",
"# target word in the context.\n",
"cxt1 = \"The jaguar's present range extends from Southwestern United States and Mexico in North America, \" \\\n",
" \"across much of Central America, and south to Paraguay and northern Argentina in South America.\"\n",
"target_word_ind1 = 1\n",
"cxt2 = \"Jaguar's business was founded as the Swallow Sidecar Company in 1922, originally making motorcycle sidecars \" \\\n",
" \"before developing bodies for passenger cars.\"\n",
"target_word_ind2 = 0\n",
"def_ = 'wild cat'  # definition of the candidate sense\n",
"hypernyms = ['animal']  # hypernyms of the candidate sense\n",
"ds = WiCTSVDataset(contexts=[cxt1, cxt2],\n",
" target_inds=[target_word_ind1, target_word_ind2],\n",
" hypernyms=[hypernyms, hypernyms],\n",
" definitions=[def_, def_],\n",
" tokenizer=tok,\n",
" focus_token='$')\n",
"# A single forward pass over the whole dataset; sigmoid over the\n",
"# logits gives, per context, the probability that the target word\n",
"# matches the given sense.\n",
"logits = model(**ds[:])[0]\n",
"probs = torch.sigmoid(logits).squeeze().tolist()\n",
"print(f'Probabilities: {probs}')"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "0LsMHzTFmbuc"
},
"source": [
"## Output:\n",
"`Probabilities: [0.9429619312286377, 0.052496373653411865]`"
]
}
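,
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The two probabilities match the two contexts: the animal sense of *jaguar* fits the definition `wild cat` (high probability), while the car-brand sense does not (low probability). To turn probabilities into binary WiC-TSV decisions, a simple option is to threshold at 0.5; the threshold is a modeling choice, not part of the model itself:"
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"# Threshold the sigmoid probabilities at 0.5 to get a binary\n",
"# match / no-match decision for each (context, sense) pair.\n",
"preds = [p > 0.5 for p in probs]\n",
"print(f'Predictions: {preds}')"
],
"execution_count": null,
"outputs": []
}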
]
}