Last active
November 23, 2021 05:05
-
-
Save nbertagnolli/3dd87f8702570749dc7c111ae8997189 to your computer and use it in GitHub Desktop.
Holds the code for https://towardsdatascience.com/build-a-bert-sci-kit-transformer-59d60ddd54a5
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"from typing import Any, Dict, List, Callable, Optional, Tuple, Union\n", | |
"import json\n", | |
"import torch\n", | |
"import transformers\n", | |
"import pandas as pd\n", | |
"from transformers import BertModel, BertTokenizer, DistilBertModel, DistilBertTokenizer\n", | |
"import numpy as np\n", | |
"from torch.utils.data import Dataset, DataLoader\n", | |
"from torch import optim, nn\n", | |
"from sklearn.base import BaseEstimator, TransformerMixin\n", | |
"from sklearn.linear_model import LogisticRegression\n", | |
"from sklearn import svm\n", | |
"from sklearn.pipeline import FeatureUnion, Pipeline\n", | |
"from sklearn.utils.multiclass import unique_labels\n", | |
"from sklearn import metrics as sk_metrics\n", | |
"import dill\n", | |
"import os\n", | |
"import seaborn as sns\n", | |
"sns.set()\n" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
def split_random(train: float, val: float, test: float) -> str:
    """Randomly assign a sample to the "train", "val", or "test" split.

    Args:
        train: Probability of assigning the sample to the train split.
        val: Probability of assigning the sample to the validation split.
        test: Probability of assigning the sample to the test split.

    Returns:
        One of "train", "val", or "test".

    Raises:
        ValueError: If the three probabilities do not sum to 1.
    """
    # Compare with a tolerance: sums like .7 + .15 + .15 are not guaranteed
    # to be exactly 1.0 in floating point, so a strict != check can
    # spuriously reject valid splits.
    if abs(train + val + test - 1.0) > 1e-9:
        raise ValueError("train + val + test must equal 1")
    rand_num = np.random.rand()

    # rand_num is uniform in [0, 1); the first matching interval wins.
    if rand_num <= train:
        return "train"
    elif rand_num <= train + val:
        return "val"
    else:
        return "test"
" \n", | |
class BertTransformer(BaseEstimator, TransformerMixin):
    """Scikit-learn transformer that turns raw text into BERT embeddings.

    Bug fix: scikit-learn's ``get_params()``/``clone()`` (used internally by
    ``Pipeline`` and grid search) require every ``__init__`` argument to be
    stored on an attribute of the *same name*. The original stored them as
    ``self.tokenizer`` / ``self.model``, which raises
    ``AttributeError: 'BertTransformer' object has no attribute
    'bert_model'`` when the estimator is cloned or serialized. Parameters are
    now stored under their ``__init__`` names; the old attribute names are
    kept as read-only aliases for backward compatibility.
    """

    def __init__(
        self,
        bert_tokenizer,
        bert_model,
        max_length: int = 60,
        embedding_func: Optional[Callable[[torch.Tensor], torch.Tensor]] = None,
    ):
        # Store under the exact __init__ names so BaseEstimator.get_params()
        # can discover them (required by sklearn.clone / Pipeline).
        self.bert_tokenizer = bert_tokenizer
        self.bert_model = bert_model
        self.bert_model.eval()  # inference only: disable dropout etc.
        self.max_length = max_length
        self.embedding_func = embedding_func

        # Default embedding: the first token's final hidden state
        # (the [CLS] position for BERT-style models).
        if self.embedding_func is None:
            self.embedding_func = lambda x: x[0][:, 0, :].squeeze()

    # Backward-compatible aliases for the original attribute names.
    @property
    def tokenizer(self):
        return self.bert_tokenizer

    @property
    def model(self):
        return self.bert_model

    def _tokenize(self, text: str) -> Tuple[torch.Tensor, torch.Tensor]:
        """Tokenize, truncate to max_length, pad, and build an attention mask.

        Returns:
            (input_ids, attention_mask), each shaped (1, max_length) so the
            model receives a batch of size one.
        """
        # Tokenize the text with the provided tokenizer.
        tokenized_text = self.bert_tokenizer.encode_plus(
            text, add_special_tokens=True, max_length=self.max_length, truncation=True
        )["input_ids"]

        # Pad with id 0 up to max_length (assumed to be the [PAD] id, as in
        # the stock BERT/DistilBERT vocabularies — confirm for other models).
        padded_text = tokenized_text + [0] * (self.max_length - len(tokenized_text))
        # Attend only to real tokens, not padding.
        attention_mask = np.where(np.array(padded_text) != 0, 1, 0)

        # The model takes a batch, so unsqueeze a leading batch dimension.
        return (
            torch.tensor(padded_text).unsqueeze(0),
            torch.tensor(attention_mask).unsqueeze(0),
        )

    def _tokenize_and_predict(self, text: str) -> torch.Tensor:
        """Run one string through the model and reduce it to an embedding."""
        tokenized, attention_mask = self._tokenize(text)

        embeddings = self.bert_model(tokenized, attention_mask)
        return self.embedding_func(embeddings)

    def transform(self, text: List[str]):
        """Embed a list (or pandas Series) of strings into a stacked tensor."""
        if isinstance(text, pd.Series):
            text = text.tolist()

        with torch.no_grad():  # inference only: no gradient bookkeeping
            return torch.stack([self._tokenize_and_predict(string) for string in text])

    def fit(self, X, y=None):
        """No fitting necessary so we just return ourselves"""
        return self
" \n", | |
def calculate_classification_metrics(
    y_true: np.ndarray,
    y_pred: np.ndarray,
    average: Optional[str] = None,
    return_df: bool = True,
) -> Union[Dict[str, float], pd.DataFrame]:
    """Computes f1, precision, recall, kappa, accuracy, and support.

    Args:
        y_true: The true labels
        y_pred: The predicted labels
        average: How to average multiclass results (passed through to
            sklearn's precision_recall_fscore_support)
        return_df: If True return a per-class DataFrame, otherwise a dict

    Returns:
        Either a dataframe of the performance metrics or a single dictionary
    """
    labels = unique_labels(y_true, y_pred)

    # Per-class scores (or averaged, depending on `average`).
    precision, recall, f_score, support = sk_metrics.precision_recall_fscore_support(
        y_true, y_pred, labels=labels, average=average
    )

    # Kappa and accuracy are single overall scores; in the DataFrame below
    # they are broadcast to every class row.
    kappa = sk_metrics.cohen_kappa_score(y_true, y_pred, labels=labels)
    accuracy = sk_metrics.accuracy_score(y_true, y_pred)

    # create a pandas DataFrame
    if return_df:
        results = pd.DataFrame(
            {
                "class": labels,
                "f_score": f_score,
                "precision": precision,
                "recall": recall,
                "support": support,
                "kappa": kappa,
                "accuracy": accuracy,
            }
        )
    else:
        results = {
            "f1": f_score,
            "precision": precision,
            "recall": recall,
            "kappa": kappa,
            "accuracy": accuracy,
        }

    return results
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
# The eight emotion labels present in the Figure Eight text-emotion dataset.
figure_8_classes = [
    "joy",
    "sadness",
    "fear",
    "anger",
    "disgust",
    "surprise",
    "love",
    "noemo",
]
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
# Load the Figure Eight emotion dataset and assign every row to a random
# 70/15/15 train/val/test split. NOTE(review): the row's "sentiment" value
# is ignored by the lambda — the split is purely random, not stratified.
figure8_df = pd.read_csv("/Users/tetracycline/data/text_emotion.csv")
figure8_df["split"] = figure8_df["sentiment"].apply(lambda _: split_random(.7, .15, .15))

train_rows = figure8_df["split"] == "train"
val_rows = figure8_df["split"] == "val"

x_train = figure8_df[train_rows]
y_train = x_train["sentiment"]
x_val = figure8_df[val_rows]
y_val = x_val["sentiment"]
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
# Linear SVM classifier trained on top of frozen DistilBERT embeddings.
classifier = svm.LinearSVC(C=1.0, class_weight="balanced")

# Vectorizer: embed each text as the first token's final hidden state
# from distilbert-base-uncased.
dbt = BertTransformer(
    DistilBertTokenizer.from_pretrained("distilbert-base-uncased"),
    DistilBertModel.from_pretrained("distilbert-base-uncased"),
    embedding_func=lambda x: x[0][:, 0, :].squeeze(),
)
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
import cloudpickle  # presumably used to serialize the pipeline elsewhere — confirm

# Compatibility shim so dill can resolve legacy 'ClassType' entries when
# unpickling (Python 2 era pickles).
dill._dill._reverse_typemap['ClassType'] = type

# Chain the BERT vectorizer and the SVM into one sklearn pipeline.
steps = [
    ("vectorizer", dbt),
    ("classifier", classifier),
]
model = Pipeline(steps)

model.fit(x_train["content"], y_train)
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
# Original figure 8 dataset: evaluate on the held-out validation split.
preds = model.predict(x_val["content"])
# Fix: calculate_classification_metrics expects (y_true, y_pred); the
# original call passed (preds, y_val), which silently swaps precision
# and recall in the report (kappa and accuracy are symmetric).
calculate_classification_metrics(y_val, preds)
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.7.6" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 4 | |
} |
Thanks! I'm glad that you liked it! I'll take a look later this week!
Thanks a lot for sharing
Hi! Thanks for this great tutorial. I have tried running your notebook, but model.fit()
raises AttributeError: 'BertTransformer' object has no attribute 'bert_model'
. Any ideas how to fix this?
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
@nbertagnolli I really enjoyed your BERT tutorial with scikit-learn. I have updated the code to include padding and to explicitly truncate long text. Please feel free to update the code from my fork here: https://gist.github.com/nialloriordan/4deec5ad99613f02201b65f26d66cf48