Created
October 18, 2022 22:55
-
-
Save owahltinez/9c765f3596490cba622313e8276b5f7e to your computer and use it in GitHub Desktop.
GoEmotions Dataset: Dataset Re-Classification
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "view-in-github", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"<a href=\"https://colab.research.google.com/gist/owahltinez/9c765f3596490cba622313e8276b5f7e/goemotions-dataset-dataset-re-classification.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"id": "c-6AaDi8w6n9" | |
}, | |
"outputs": [], | |
"source": [ | |
"#@title Install Dependencies\n", | |
"!pip install -q \"tensorflow==2.9.*\"\n", | |
"!pip install -q \"tensorflow-text==2.9.*\"" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"id": "_tOll1k9jIPV" | |
}, | |
"outputs": [], | |
"source": [ | |
"#@title Downloading and re-classifying the sentiment140 dataset\n", | |
"import csv\n", | |
"import pathlib\n", | |
"import urllib.request\n", | |
"import tensorflow as tf\n", | |
"import tensorflow_datasets as tfds\n", | |
"import tensorflow_text as tftxt\n", | |
"from tqdm.notebook import tqdm\n", | |
"\n", | |
"emotions = urllib.request.urlopen(\n", | |
" 'https://raw.githubusercontent.com/google-research/google-research'\n", | |
" '/master/goemotions/data/emotions.txt').read().decode('utf8').split('\\n')\n", | |
"\n", | |
"batch_size = 128\n", | |
"ds_orig = tfds.load('sentiment140', split='train')\n", | |
"output_path = pathlib.Path('.') / 'sentiment140_goemotions.csv'\n", | |
"\n", | |
"# Download and load our pretrained classifier.\n", | |
"url_root = 'https://huggingface.co/owahltinez/goemotions_bert/resolve/main'\n", | |
"url_model = f'{url_root}/bert_model.tar.gz'\n", | |
"tar_path = tf.keras.utils.get_file(origin=url_model, extract=True)\n", | |
"model_path = pathlib.Path(tar_path) / '..' / 'bert_model'\n", | |
"classifier = tf.keras.models.load_model(model_path)\n", | |
"\n", | |
"# Reclassify the dataset by mapping the original text to fine-grained emotions.\n", | |
"ds_emo = ds_orig.batch(batch_size).map(lambda x: x['text'])\n", | |
"ds_emo = ds_emo.map(lambda x: (x, tf.math.argmax(classifier(x), axis=1)))\n", | |
"batch_count = ds_emo.cardinality().numpy()\n", | |
"\n", | |
"# Save the resulting dataset into a CSV file.\n", | |
"with open(output_path, 'w') as f:\n", | |
" writer = csv.writer(f)\n", | |
" writer.writerow(['sentence', 'label'])\n", | |
" for x_batch, y_batch in tqdm(iter(ds_emo), total=batch_count):\n", | |
" for x, y in zip(x_batch, y_batch):\n", | |
" label = emotions[y.numpy()]\n", | |
" sentence = x.numpy().decode('utf8')\n", | |
" writer.writerow([sentence, label])" | |
] | |
} | |
], | |
"metadata": { | |
"accelerator": "GPU", | |
"colab": { | |
"collapsed_sections": [], | |
"provenance": [], | |
"authorship_tag": "ABX9TyMX3j4ACyvBieE+uUKamzc+", | |
"include_colab_link": true | |
}, | |
"gpuClass": "standard", | |
"kernelspec": { | |
"display_name": "Python 3", | |
"name": "python3" | |
}, | |
"language_info": { | |
"name": "python" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 0 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment