Text classification in PyTorch to refactor with petastorm.ipynb
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "Text classification in PyTorch to refactor with petastorm.ipynb",
"provenance": [],
"collapsed_sections": [],
"authorship_tag": "ABX9TyOa6VtIQB5zcrstMpyPUWiu",
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"accelerator": "GPU"
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/davidefiocco/6c77070d838328c0b13546886de5c06a/text-classification-in-pytorch-to-refactor-with-petastorm.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "code",
"metadata": {
"id": "C3dhjnf3cJcd"
},
"source": [
"!pip install transformers --quiet"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "ZplY26R4oz50"
},
"source": [
"The code below is a PyTorch text classifier obtained by getting code from https://pytorch.org/tutorials/beginner/text_sentiment_ngrams_tutorial.html, changed a little to work on a custom dataframe. \r\n",
"How can I transform this to work with pyspark dataframes instead of pandas dataframes?"
]
},
{
"cell_type": "code",
"metadata": {
"id": "FVfGgGQycB7Y"
},
"source": [
"import pandas as pd\r\n",
"import torch\r\n",
"from torch.utils.data.dataset import Dataset\r\n",
"from transformers import BertTokenizer\r\n",
"import torch.nn as nn\r\n",
"import torch.nn.functional as F\r\n",
"from torch.utils.data import DataLoader\r\n",
"\r\n",
"# using HuggingFace tokenization\r\n",
"tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')\r\n",
"\r\n",
"text = [\"This is a test.\", \"This is not a test.\"]*100\r\n",
"label = [1, 0]*100\r\n",
"\r\n",
"df = pd.DataFrame({\"text\": text, \"label\": label})\r\n",
"df[\"tokenized\"] = df[\"text\"].apply(tokenizer.encode)"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 206
},
"id": "SLL5N35nn8CL",
"outputId": "630fcf0f-c1be-4556-86ae-664538d9c966"
},
"source": [
"df.head()"
],
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>text</th>\n",
" <th>label</th>\n",
" <th>tokenized</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>This is a test.</td>\n",
" <td>1</td>\n",
" <td>[101, 2023, 2003, 1037, 3231, 1012, 102]</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>This is not a test.</td>\n",
" <td>0</td>\n",
" <td>[101, 2023, 2003, 2025, 1037, 3231, 1012, 102]</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>This is a test.</td>\n",
" <td>1</td>\n",
" <td>[101, 2023, 2003, 1037, 3231, 1012, 102]</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>This is not a test.</td>\n",
" <td>0</td>\n",
" <td>[101, 2023, 2003, 2025, 1037, 3231, 1012, 102]</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>This is a test.</td>\n",
" <td>1</td>\n",
" <td>[101, 2023, 2003, 1037, 3231, 1012, 102]</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" text label tokenized\n",
"0 This is a test. 1 [101, 2023, 2003, 1037, 3231, 1012, 102]\n",
"1 This is not a test. 0 [101, 2023, 2003, 2025, 1037, 3231, 1012, 102]\n",
"2 This is a test. 1 [101, 2023, 2003, 1037, 3231, 1012, 102]\n",
"3 This is not a test. 0 [101, 2023, 2003, 2025, 1037, 3231, 1012, 102]\n",
"4 This is a test. 1 [101, 2023, 2003, 1037, 3231, 1012, 102]"
]
},
"metadata": {
"tags": []
},
"execution_count": 3
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "dRMFGrjPcFI_"
},
"source": [
"# Using pandas dataframe I use this class to create a custom PyTorch dataset\r\n",
"\r\n",
"class TokenizedDataset(Dataset):\r\n",
"\r\n",
" def __init__(self, df):\r\n",
" self.data = df\r\n",
" \r\n",
" def __getitem__(self, index):\r\n",
" text = self.data.loc[index].tokenized\r\n",
" text = torch.LongTensor(text)\r\n",
" label = self.data.loc[index].label\r\n",
" return (text, label)\r\n",
"\r\n",
" def __len__(self):\r\n",
" count = len(self.data)\r\n",
" return count\r\n",
"\r\n",
"train_dataset = TokenizedDataset(df)"
],
"execution_count": null,
"outputs": []
},
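{
"cell_type": "markdown",
"metadata": {},
"source": [
"A quick sanity check (added here for illustration, not part of the original code): each dataset item should be a (LongTensor of token ids, integer label) pair."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"text, label = train_dataset[0]\r\n",
"print(text, label)"
],
"execution_count": null,
"outputs": []
},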
{
"cell_type": "code",
"metadata": {
"id": "1bv2nnj_co5K"
},
"source": [
"def generate_batch(batch):\r\n",
" label = torch.LongTensor([entry[1] for entry in batch])\r\n",
" text = [entry[0] for entry in batch]\r\n",
" offsets = [0] + [len(entry) for entry in text]\r\n",
" # torch.Tensor.cumsum returns the cumulative sum\r\n",
" # of elements in the dimension dim.\r\n",
" # torch.Tensor([1.0, 2.0, 3.0]).cumsum(dim=0)\r\n",
"\r\n",
" offsets = torch.tensor(offsets[:-1]).cumsum(dim=0)\r\n",
" text = torch.cat(text)\r\n",
" return text, offsets, label"
],
"execution_count": null,
"outputs": []
},
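{
"cell_type": "markdown",
"metadata": {},
"source": [
"A minimal, self-contained illustration (added for clarity, not in the original tutorial) of how the offsets produced by `generate_batch` drive `nn.EmbeddingBag`: variable-length sequences are concatenated into one 1-D tensor, and each offset marks where a sequence starts."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"import torch\r\n",
"import torch.nn as nn\r\n",
"\r\n",
"# Two variable-length token sequences concatenated into one 1-D tensor.\r\n",
"text = torch.LongTensor([101, 2023, 102, 101, 2023, 2003, 102])\r\n",
"offsets = torch.LongTensor([0, 3])  # sequence 0 starts at index 0, sequence 1 at index 3\r\n",
"\r\n",
"bag = nn.EmbeddingBag(num_embeddings=30522, embedding_dim=4)\r\n",
"print(bag(text, offsets).shape)  # torch.Size([2, 4]): one mean-pooled vector per sequence"
],
"execution_count": null,
"outputs": []
},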
{
"cell_type": "code",
"metadata": {
"id": "mNzkJINlcpl6"
},
"source": [
"def train_func(sub_train_):\r\n",
"\r\n",
" # Train the model\r\n",
" train_loss = 0\r\n",
" train_acc = 0\r\n",
" data = DataLoader(sub_train_, batch_size=BATCH_SIZE, shuffle=True,\r\n",
" collate_fn=generate_batch)\r\n",
" for i, (text, offsets, cls) in enumerate(data):\r\n",
" optimizer.zero_grad()\r\n",
" text, offsets, cls = text.to(device), offsets.to(device), cls.to(device)\r\n",
" output = model(text, offsets)\r\n",
" loss = criterion(output, cls)\r\n",
" train_loss += loss.item()\r\n",
" loss.backward()\r\n",
" optimizer.step()\r\n",
" train_acc += (output.argmax(1) == cls).sum().item()\r\n",
"\r\n",
" # Adjust the learning rate\r\n",
" scheduler.step()\r\n",
"\r\n",
" return train_loss / len(sub_train_), train_acc / len(sub_train_)\r\n",
"\r\n",
"def test(data_):\r\n",
" loss = 0\r\n",
" acc = 0\r\n",
" data = DataLoader(data_, batch_size=BATCH_SIZE, collate_fn=generate_batch)\r\n",
" for text, offsets, cls in data:\r\n",
" text, offsets, cls = text.to(device), offsets.to(device), cls.to(device)\r\n",
" with torch.no_grad():\r\n",
" output = model(text, offsets)\r\n",
" loss = criterion(output, cls)\r\n",
" loss += loss.item()\r\n",
" acc += (output.argmax(1) == cls).sum().item()\r\n",
"\r\n",
" return loss / len(data_), acc / len(data_)"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "b63NVBxqd5T1"
},
"source": [
"class TextSentiment(nn.Module):\r\n",
" def __init__(self, vocab_size, embed_dim, num_class):\r\n",
" super().__init__()\r\n",
" self.embedding = nn.EmbeddingBag(vocab_size, embed_dim, sparse=True)\r\n",
" self.fc = nn.Linear(embed_dim, num_class)\r\n",
" self.init_weights()\r\n",
"\r\n",
" def init_weights(self):\r\n",
" initrange = 0.5\r\n",
" self.embedding.weight.data.uniform_(-initrange, initrange)\r\n",
" self.fc.weight.data.uniform_(-initrange, initrange)\r\n",
" self.fc.bias.data.zero_()\r\n",
"\r\n",
" def forward(self, text, offsets):\r\n",
" embedded = self.embedding(text, offsets)\r\n",
" return self.fc(embedded)"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "ak81UL1FkKly"
},
"source": [
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "I-qVHqzqmoxE"
},
"source": [
"VOCAB_SIZE = 31090\r\n",
"EMBED_DIM = 768\r\n",
"NUM_CLASS = 2\r\n",
"model = TextSentiment(VOCAB_SIZE, EMBED_DIM, NUM_CLASS).to(device)"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "QYrKEBF9d565",
"outputId": "6a9a86f5-be10-4a1e-929f-8931d420e43e"
},
"source": [
"import time\r\n",
"from torch.utils.data.dataset import random_split\r\n",
"N_EPOCHS = 5\r\n",
"BATCH_SIZE = 8\r\n",
"min_valid_loss = float('inf')\r\n",
"\r\n",
"criterion = torch.nn.CrossEntropyLoss().to(device)\r\n",
"optimizer = torch.optim.SGD(model.parameters(), lr=4.0)\r\n",
"scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1, gamma=0.9)\r\n",
"\r\n",
"train_len = int(len(train_dataset) * 0.95)\r\n",
"sub_train_, sub_valid_ = \\\r\n",
" random_split(train_dataset, [train_len, len(train_dataset) - train_len])\r\n",
"\r\n",
"for epoch in range(N_EPOCHS):\r\n",
"\r\n",
" start_time = time.time()\r\n",
" train_loss, train_acc = train_func(sub_train_)\r\n",
" valid_loss, valid_acc = test(sub_valid_)\r\n",
"\r\n",
" secs = int(time.time() - start_time)\r\n",
" mins = secs / 60\r\n",
" secs = secs % 60\r\n",
"\r\n",
" print('Epoch: %d' %(epoch + 1), \" | time in %d minutes, %d seconds\" %(mins, secs))\r\n",
" print(f'\\tLoss: {train_loss:.4f}(train)\\t|\\tAcc: {train_acc * 100:.1f}%(train)')\r\n",
" print(f'\\tLoss: {valid_loss:.4f}(valid)\\t|\\tAcc: {valid_acc * 100:.1f}%(valid)')"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"text": [
"Epoch: 1 | time in 0 minutes, 0 seconds\n",
"\tLoss: 0.2762(train)\t|\tAcc: 84.7%(train)\n",
"\tLoss: 0.0000(valid)\t|\tAcc: 100.0%(valid)\n",
"Epoch: 2 | time in 0 minutes, 0 seconds\n",
"\tLoss: 0.0000(train)\t|\tAcc: 100.0%(train)\n",
"\tLoss: 0.0000(valid)\t|\tAcc: 100.0%(valid)\n",
"Epoch: 3 | time in 0 minutes, 0 seconds\n",
"\tLoss: 0.0000(train)\t|\tAcc: 100.0%(train)\n",
"\tLoss: 0.0000(valid)\t|\tAcc: 100.0%(valid)\n",
"Epoch: 4 | time in 0 minutes, 0 seconds\n",
"\tLoss: 0.0000(train)\t|\tAcc: 100.0%(train)\n",
"\tLoss: 0.0000(valid)\t|\tAcc: 100.0%(valid)\n",
"Epoch: 5 | time in 0 minutes, 0 seconds\n",
"\tLoss: 0.0000(train)\t|\tAcc: 100.0%(train)\n",
"\tLoss: 0.0000(valid)\t|\tAcc: 100.0%(valid)\n"
],
"name": "stdout"
}
]
}
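,
{
"cell_type": "markdown",
"metadata": {},
"source": [
"A possible direction for the refactor (a hedged sketch, not tested here): petastorm's `SparkDatasetConverter` can cache a Spark dataframe as Parquet and expose it to PyTorch as a `DataLoader`, replacing `TokenizedDataset` and the pandas dataframe. The local Spark session, the cache path, the `MAX_LEN` padding, and the `padded` column below are assumptions made for illustration, not a definitive implementation."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"# !pip install petastorm pyspark --quiet\r\n",
"from pyspark.sql import SparkSession\r\n",
"from petastorm.spark import SparkDatasetConverter, make_spark_converter\r\n",
"\r\n",
"spark = SparkSession.builder.master(\"local[2]\").getOrCreate()\r\n",
"# petastorm materializes the dataframe as Parquet under this cache dir\r\n",
"spark.conf.set(SparkDatasetConverter.PARENT_CACHE_DIR_URL_CONF, \"file:///tmp/petastorm_cache\")\r\n",
"\r\n",
"# pad the token lists to a fixed length so rows batch cleanly (assumption: MAX_LEN=16)\r\n",
"MAX_LEN = 16\r\n",
"df[\"padded\"] = df[\"tokenized\"].apply(lambda t: (t + [0] * MAX_LEN)[:MAX_LEN])\r\n",
"spark_df = spark.createDataFrame(df[[\"padded\", \"label\"]])\r\n",
"\r\n",
"converter = make_spark_converter(spark_df)\r\n",
"with converter.make_torch_dataloader(batch_size=BATCH_SIZE) as loader:\r\n",
"    for batch in loader:\r\n",
"        # batch is a dict of tensors keyed by column name; the offsets-based\r\n",
"        # model above would still need adapting (e.g. padding-aware pooling\r\n",
"        # instead of the offsets built in generate_batch)\r\n",
"        print(batch[\"padded\"].shape, batch[\"label\"][:4])\r\n",
"        break"
],
"execution_count": null,
"outputs": []
}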
]
}