@avidale
Last active May 25, 2023 21:16
BERT-toxicity-classification.ipynb
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "BERT-toxicity-classification.ipynb",
"provenance": [],
"collapsed_sections": [],
"authorship_tag": "ABX9TyOlRuBAiOilI73kwx/pYsaT",
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/avidale/7abc1aa027afd69f6b50eaf7527ed294/bert-toxicity-classification.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "T7mxnesuCO7T"
},
"source": [
"This notebook lets you play with a model that classifies Russian-language texts as toxic or politically incorrect.\n",
"\n",
"The fun part is in the last code cells.\n",
"\n",
"https://huggingface.co/cointegrated/rubert-tiny-toxicity"
]
},
{
"cell_type": "code",
"metadata": {
"id": "aJCuD52e7HSn"
},
"source": [
"!pip install transformers sentencepiece --quiet"
],
"execution_count": 6,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "mkJuRwEp7Kmn"
},
"source": [
"import torch\n",
"from transformers import AutoTokenizer, AutoModelForSequenceClassification"
],
"execution_count": 7,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "zgJAWNaz7RBO"
},
"source": [
"model_checkpoint = 'cointegrated/rubert-tiny-toxicity'\n",
"tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)\n",
"model = AutoModelForSequenceClassification.from_pretrained(model_checkpoint)\n",
"# Move the model to the GPU if one is available; otherwise it stays on the CPU\n",
"if torch.cuda.is_available():\n",
"    model.cuda()"
],
"execution_count": 8,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "ggRV-Z5R7Z34",
"outputId": "cbbe368c-8a2f-4d69-c45f-8382a3c5a45f"
},
"source": [
"def text2toxicity(text, aggregate=True):\n",
"    \"\"\"Calculate the toxicity of a text (aggregate=True) or its vector of per-aspect toxicity probabilities (aggregate=False).\n",
"    `text` may be a single string or a list of strings.\"\"\"\n",
"    with torch.no_grad():\n",
"        inputs = tokenizer(text, return_tensors='pt', truncation=True, padding=True).to(model.device)\n",
"        # Sigmoid per label: each aspect is scored independently (multi-label), not via a softmax\n",
"        proba = torch.sigmoid(model(**inputs).logits).cpu().numpy()\n",
"    if isinstance(text, str):\n",
"        proba = proba[0]  # drop the batch dimension for a single text\n",
"    if aggregate:\n",
"        # 1 - P(first label) * (1 - P(last label)); .T lets this index labels for one text or a batch\n",
"        return 1 - proba.T[0] * (1 - proba.T[-1])\n",
"    return proba\n",
"\n",
"print(text2toxicity('я люблю нигеров', True))\n",
"# 0.9350118728093193\n",
"\n",
"print(text2toxicity('я люблю нигеров', False))\n",
"# [0.9715758 0.0180863 0.0045551 0.00189755 0.9331106 ]\n",
"\n",
"print(text2toxicity(['я люблю нигеров', 'я люблю африканцев'], True))\n",
"# [0.93501186 0.04156357]\n",
"\n",
"print(text2toxicity(['я люблю нигеров', 'я люблю африканцев'], False))\n",
"# [[9.7157580e-01 1.8086294e-02 4.5550885e-03 1.8975559e-03 9.3311059e-01]\n",
"#  [9.9979788e-01 1.9048342e-04 1.5297388e-04 1.7452303e-04 4.1369814e-02]]"
],
"execution_count": 9,
"outputs": [
{
"output_type": "stream",
"text": [
"0.9350118728093193\n",
"[0.9715758 0.0180863 0.0045551 0.00189755 0.9331106 ]\n",
"[0.93501186 0.04156357]\n",
"[[9.7157580e-01 1.8086294e-02 4.5550885e-03 1.8975559e-03 9.3311059e-01]\n",
" [9.9979788e-01 1.9048342e-04 1.5297388e-04 1.7452303e-04 4.1369814e-02]]\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "qhleF3BY77bq",
"outputId": "ca4156e1-cbce-4143-9f68-478372ab3384"
},
"source": [
"%%time\n",
"print(text2toxicity('Иди ты нафиг!'))"
],
"execution_count": 11,
"outputs": [
{
"output_type": "stream",
"text": [
"0.4770178304282737\n",
"CPU times: user 9.94 ms, sys: 147 µs, total: 10.1 ms\n",
"Wall time: 17.5 ms\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "a6G0bvwJ7gg2"
},
"source": [
""
],
"execution_count": 10,
"outputs": []
}
]
}
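The aggregated score returned by `text2toxicity(..., aggregate=True)` can be checked by hand. The sketch below assumes only NumPy and re-applies the same formula, `1 - proba.T[0] * (1 - proba.T[-1])`, to the per-label probabilities printed in the notebook's recorded output, so no model download is needed. `aggregate_toxicity` is a hypothetical helper name. Reading the first label as `non-toxic` and the last as `dangerous` is an assumption based on the model card, since the notebook itself does not name the labels.

```python
import numpy as np


def aggregate_toxicity(proba: np.ndarray):
    """Hypothetical helper using the same formula as text2toxicity(aggregate=True).

    Accepts one probability vector of shape (n_labels,) or a batch of shape
    (batch, n_labels). Transposing puts labels on the first axis in both
    cases, so proba.T[0] is the first label and proba.T[-1] the last one.
    """
    # Assumption: label 0 is 'non-toxic' and the last label is 'dangerous'
    return 1 - proba.T[0] * (1 - proba.T[-1])


# Per-label probabilities copied from the notebook's recorded output
batch = np.array([
    [9.7157580e-01, 1.8086294e-02, 4.5550885e-03, 1.8975559e-03, 9.3311059e-01],
    [9.9979788e-01, 1.9048342e-04, 1.5297388e-04, 1.7452303e-04, 4.1369814e-02],
])

scores = aggregate_toxicity(batch)     # one score per text
single = aggregate_toxicity(batch[0])  # a scalar for a single text
print(scores)  # approximately [0.93501186 0.04156357]
print(single)  # approximately 0.935

# Row sums exceed 1 because the sigmoid scores each label independently
print(batch.sum(axis=1))
```

The first row sums to about 1.93: `text2toxicity` applies `torch.sigmoid` to every logit separately (a multi-label setup) instead of a softmax across labels, so the per-aspect probabilities are not expected to add up to 1.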