@nbertagnolli
Last active November 23, 2021 05:05
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from typing import Any, Dict, List, Callable, Optional, Tuple, Union\n",
"import json\n",
"import torch\n",
"import transformers\n",
"import pandas as pd\n",
"from transformers import BertModel, BertTokenizer, DistilBertModel, DistilBertTokenizer\n",
"import numpy as np\n",
"from torch.utils.data import Dataset, DataLoader\n",
"from torch import optim, nn\n",
"from sklearn.base import BaseEstimator, TransformerMixin\n",
"from sklearn.linear_model import LogisticRegression\n",
"from sklearn import svm\n",
"from sklearn.pipeline import FeatureUnion, Pipeline\n",
"from sklearn.utils.multiclass import unique_labels\n",
"from sklearn import metrics as sk_metrics\n",
"import dill\n",
"import os\n",
"import seaborn as sns\n",
"sns.set()\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def split_random(train: float, val: float, test: float) -> str:\n",
" if train + val + test != 1.0:\n",
" raise ValueError(\"train + val + test must equal 1\")\n",
" rand_num = np.random.rand()\n",
" \n",
" if rand_num <= train:\n",
" return \"train\"\n",
" elif rand_num <= train + val:\n",
" return \"val\"\n",
" else:\n",
" return \"test\"\n",
" \n",
"class BertTransformer(BaseEstimator, TransformerMixin):\n",
" def __init__(\n",
" self,\n",
" bert_tokenizer,\n",
" bert_model,\n",
" max_length: int = 60,\n",
" embedding_func: Optional[Callable[[torch.Tensor], torch.Tensor]] = None,\n",
" ):\n",
" self.tokenizer = bert_tokenizer\n",
" self.model = bert_model\n",
" self.model.eval()\n",
" self.max_length = max_length\n",
" self.embedding_func = embedding_func\n",
"\n",
" if self.embedding_func is None:\n",
" self.embedding_func = lambda x: x[0][:, 0, :].squeeze()\n",
"\n",
" # TODO:: PADDING\n",
"\n",
" def _tokenize(self, text: str) -> Tuple[torch.Tensor, torch.Tensor]:\n",
" # Tokenize the text with the provided tokenizer\n",
" tokenized_text = self.tokenizer.encode_plus(\n",
" text, add_special_tokens=True, max_length=self.max_length, truncation=True\n",
" )[\"input_ids\"]\n",
"\n",
" # padding\n",
" padded_text = tokenized_text + [0]*(self.max_length-len(tokenized_text))\n",
" # Create an attention mask telling BERT to use all words\n",
" attention_mask = np.where(np.array(padded_text) != 0, 1, 0)\n",
"\n",
" # bert takes in a batch so we need to unsqueeze the rows\n",
" return (\n",
" torch.tensor(padded_text).unsqueeze(0),\n",
" torch.tensor(attention_mask).unsqueeze(0),\n",
" )\n",
"\n",
" def _tokenize_and_predict(self, text: str) -> torch.Tensor:\n",
" tokenized, attention_mask = self._tokenize(text)\n",
"\n",
" embeddings = self.model(tokenized, attention_mask)\n",
" return self.embedding_func(embeddings)\n",
"\n",
" def transform(self, text: List[str]):\n",
" if isinstance(text, pd.Series):\n",
" text = text.tolist()\n",
"\n",
" with torch.no_grad():\n",
" return torch.stack([self._tokenize_and_predict(string) for string in text])\n",
"\n",
" def fit(self, X, y=None):\n",
" \"\"\"No fitting necessary so we just return ourselves\"\"\"\n",
" return self\n",
" \n",
"def calculate_classification_metrics(\n",
" y_true: np.array,\n",
" y_pred: np.array,\n",
" average: Optional[str] = None,\n",
" return_df: bool = True,\n",
") -> Union[Dict[str, float], pd.DataFrame]:\n",
" \"\"\"Computes f1, precision, recall, precision, kappa, accuracy, and support\n",
"\n",
" Args:\n",
" y_true: The true labels\n",
" y_pred: The predicted labels\n",
" average: How to average multiclass results\n",
"\n",
" Returns:\n",
" Either a dataframe of the performance metrics or a single dictionary\n",
" \"\"\"\n",
" labels = unique_labels(y_true, y_pred)\n",
"\n",
" # get results\n",
" precision, recall, f_score, support = sk_metrics.precision_recall_fscore_support(\n",
" y_true, y_pred, labels=labels, average=average\n",
" )\n",
"\n",
" kappa = sk_metrics.cohen_kappa_score(y_true, y_pred, labels=labels)\n",
" accuracy = sk_metrics.accuracy_score(y_true, y_pred)\n",
"\n",
" # create a pandas DataFrame\n",
" if return_df:\n",
" results = pd.DataFrame(\n",
" {\n",
" \"class\": labels,\n",
" \"f_score\": f_score,\n",
" \"precision\": precision,\n",
" \"recall\": recall,\n",
" \"support\": support,\n",
" \"kappa\": kappa,\n",
" \"accuracy\": accuracy,\n",
" }\n",
" )\n",
" else:\n",
" results = {\n",
" \"f1\": f_score,\n",
" \"precision\": precision,\n",
" \"recall\": recall,\n",
" \"kappa\": kappa,\n",
" \"accuracy\": accuracy,\n",
" }\n",
"\n",
" return results"
]
},
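{
"cell_type": "markdown",
"metadata": {},
"source": [
"The default `embedding_func` above, `lambda x: x[0][:, 0, :].squeeze()`, takes the model's hidden states of shape `(batch, seq_len, hidden)` and keeps the first token's vector, i.e. the `[CLS]` embedding. A minimal shape-only sanity check on a dummy tensor (illustrative; 768 is just DistilBERT's hidden size, and no model download is needed):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Illustrative only: show what the default embedding_func extracts.\n",
"dummy_hidden = torch.randn(1, 60, 768)  # (batch, seq_len, hidden_size)\n",
"dummy_output = (dummy_hidden,)  # the model returns a tuple; x[0] is the hidden states\n",
"cls_embedding = (lambda x: x[0][:, 0, :].squeeze())(dummy_output)\n",
"cls_embedding.shape  # torch.Size([768])"
]
},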
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"figure_8_classes = [\"joy\", \"sadness\", \"fear\", \"anger\", \"disgust\", \"surprise\", \"love\", \"noemo\"]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"figure8_df = pd.read_csv(\"/Users/tetracycline/data/text_emotion.csv\")\n",
"figure8_df[\"split\"] = figure8_df[\"sentiment\"].apply(lambda _: split_random(.7, .15, .15))\n",
"x_train = figure8_df[figure8_df[\"split\"] == \"train\"]\n",
"y_train = x_train[\"sentiment\"]\n",
"x_val = figure8_df[figure8_df[\"split\"] == \"val\"]\n",
"y_val = x_val[\"sentiment\"]"
]
},
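{
"cell_type": "markdown",
"metadata": {},
"source": [
"A quick sanity check that the random assignment lands near the requested 70/15/15 proportions:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Empirical split fractions should be close to .7/.15/.15.\n",
"figure8_df[\"split\"].value_counts(normalize=True)"
]
},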
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"classifier = svm.LinearSVC(C=1.0, class_weight=\"balanced\")\n",
"\n",
"dbt = BertTransformer(DistilBertTokenizer.from_pretrained(\"distilbert-base-uncased\"),\n",
" DistilBertModel.from_pretrained(\"distilbert-base-uncased\"),\n",
" embedding_func=lambda x: x[0][:, 0, :].squeeze())"
]
},
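{
"cell_type": "markdown",
"metadata": {},
"source": [
"If you want to experiment with the sentence representation, mean pooling over the token dimension is a common alternative to the `[CLS]` vector. A sketch of a drop-in `embedding_func` (note the naive mean below also averages `[PAD]` positions; masking those out would require threading the attention mask into `embedding_func`):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Alternative embedding: mean over tokens instead of the [CLS] vector.\n",
"dbt_mean = BertTransformer(\n",
"    DistilBertTokenizer.from_pretrained(\"distilbert-base-uncased\"),\n",
"    DistilBertModel.from_pretrained(\"distilbert-base-uncased\"),\n",
"    embedding_func=lambda x: x[0].mean(dim=1).squeeze(),\n",
")"
]
},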
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import cloudpickle\n",
"dill._dill._reverse_typemap['ClassType'] = type\n",
"model = Pipeline(\n",
" [\n",
" (\"vectorizer\", dbt),\n",
" (\"classifier\", classifier),\n",
" ]\n",
")\n",
"\n",
"model.fit(x_train[\"content\"], y_train)"
]
},
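{
"cell_type": "markdown",
"metadata": {},
"source": [
"To reuse the fitted pipeline later you can serialize it with `dill` (imported in the first cell), which, unlike the standard `pickle` module, can handle the lambda `embedding_func`. A minimal sketch, assuming a local `model.pkl` path:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Persist the fitted pipeline with dill (pickle cannot serialize the lambda).\n",
"with open(\"model.pkl\", \"wb\") as f:\n",
"    dill.dump(model, f)\n",
"\n",
"# Load it back later.\n",
"with open(\"model.pkl\", \"rb\") as f:\n",
"    loaded_model = dill.load(f)"
]
},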
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Original figure 8 dataset\n",
"preds = model.predict(x_val[\"content\"])\n",
"calculate_classification_metrics(preds, y_val)"
]
},
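{
"cell_type": "markdown",
"metadata": {},
"source": [
"To see which emotions the classifier confuses with one another, a confusion-matrix heatmap (using the `seaborn` import from the first cell) is a quick follow-up; a sketch:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Visualize per-class confusion; rows are true labels, columns are predictions.\n",
"cm_labels = unique_labels(y_val, preds)\n",
"cm = sk_metrics.confusion_matrix(y_val, preds, labels=cm_labels)\n",
"sns.heatmap(cm, annot=True, fmt=\"d\", xticklabels=cm_labels, yticklabels=cm_labels)"
]
}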
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.6"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
@nialloriordan

@nbertagnolli I really enjoyed your BERT tutorial with scikit-learn. I have updated the code to include padding and to explicitly truncate long text. Please feel free to update the code from my fork here: https://gist.github.com/nialloriordan/4deec5ad99613f02201b65f26d66cf48

@nbertagnolli
Author

Thanks!  I'm glad that you liked it!  I'll take a look later this week! 

@teddddddy

Thanks a lot for sharing

@ddenz

ddenz commented Nov 23, 2021

Hi! Thanks for this great tutorial. I have tried running your notebook, but model.fit() raises AttributeError: 'BertTransformer' object has no attribute 'bert_model'. Any ideas how to fix this?
