Skip to content

Instantly share code, notes, and snippets.

@avidale
Last active May 25, 2023 21:23
Show Gist options
  • Save avidale/dc7a26eb3cffc90075bf100e15b5950f to your computer and use it in GitHub Desktop.
Save avidale/dc7a26eb3cffc90075bf100e15b5950f to your computer and use it in GitHub Desktop.
rut5-encoder.ipynb
Display the source blob
Display the rendered blob
Raw
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "rut5-encoder.ipynb",
"provenance": [],
"collapsed_sections": [],
"authorship_tag": "ABX9TyOrHePY8G2KZK7/q0OZZvVv",
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
},
"accelerator": "GPU",
"widgets": {
"application/vnd.jupyter.widget-state+json": {
"04d7817cfbfc4ec3a60248d35852d4b1": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HBoxModel",
"state": {
"_view_name": "HBoxView",
"_dom_classes": [],
"_model_name": "HBoxModel",
"_view_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_view_count": null,
"_view_module_version": "1.5.0",
"box_style": "",
"layout": "IPY_MODEL_cdfcbe3e247647af83a6f51613869af1",
"_model_module": "@jupyter-widgets/controls",
"children": [
"IPY_MODEL_902841146d2a4ddc94b3e41ddaeb5395",
"IPY_MODEL_4b1c8d124ff14c01959a45c98677efe0"
]
}
},
"cdfcbe3e247647af83a6f51613869af1": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"state": {
"_view_name": "LayoutView",
"grid_template_rows": null,
"right": null,
"justify_content": null,
"_view_module": "@jupyter-widgets/base",
"overflow": null,
"_model_module_version": "1.2.0",
"_view_count": null,
"flex_flow": null,
"width": null,
"min_width": null,
"border": null,
"align_items": null,
"bottom": null,
"_model_module": "@jupyter-widgets/base",
"top": null,
"grid_column": null,
"overflow_y": null,
"overflow_x": null,
"grid_auto_flow": null,
"grid_area": null,
"grid_template_columns": null,
"flex": null,
"_model_name": "LayoutModel",
"justify_items": null,
"grid_row": null,
"max_height": null,
"align_content": null,
"visibility": null,
"align_self": null,
"height": null,
"min_height": null,
"padding": null,
"grid_auto_rows": null,
"grid_gap": null,
"max_width": null,
"order": null,
"_view_module_version": "1.2.0",
"grid_template_areas": null,
"object_position": null,
"object_fit": null,
"grid_auto_columns": null,
"margin": null,
"display": null,
"left": null
}
},
"902841146d2a4ddc94b3e41ddaeb5395": {
"model_module": "@jupyter-widgets/controls",
"model_name": "FloatProgressModel",
"state": {
"_view_name": "ProgressView",
"style": "IPY_MODEL_868e55bab60a4528bff43a077fe02f39",
"_dom_classes": [],
"description": "100%",
"_model_name": "FloatProgressModel",
"bar_style": "success",
"max": 30,
"_view_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"value": 30,
"_view_count": null,
"_view_module_version": "1.5.0",
"orientation": "horizontal",
"min": 0,
"description_tooltip": null,
"_model_module": "@jupyter-widgets/controls",
"layout": "IPY_MODEL_4cb0312c42bc4c009a0c6a9b13efbf6b"
}
},
"4b1c8d124ff14c01959a45c98677efe0": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HTMLModel",
"state": {
"_view_name": "HTMLView",
"style": "IPY_MODEL_1d1dbafc1a5e40449a5e75fff6bae439",
"_dom_classes": [],
"description": "",
"_model_name": "HTMLModel",
"placeholder": "​",
"_view_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"value": " 30/30 [00:23<00:00, 1.30ba/s]",
"_view_count": null,
"_view_module_version": "1.5.0",
"description_tooltip": null,
"_model_module": "@jupyter-widgets/controls",
"layout": "IPY_MODEL_071477472574442d9c7a573843676bd8"
}
},
"868e55bab60a4528bff43a077fe02f39": {
"model_module": "@jupyter-widgets/controls",
"model_name": "ProgressStyleModel",
"state": {
"_view_name": "StyleView",
"_model_name": "ProgressStyleModel",
"description_width": "initial",
"_view_module": "@jupyter-widgets/base",
"_model_module_version": "1.5.0",
"_view_count": null,
"_view_module_version": "1.2.0",
"bar_color": null,
"_model_module": "@jupyter-widgets/controls"
}
},
"4cb0312c42bc4c009a0c6a9b13efbf6b": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"state": {
"_view_name": "LayoutView",
"grid_template_rows": null,
"right": null,
"justify_content": null,
"_view_module": "@jupyter-widgets/base",
"overflow": null,
"_model_module_version": "1.2.0",
"_view_count": null,
"flex_flow": null,
"width": null,
"min_width": null,
"border": null,
"align_items": null,
"bottom": null,
"_model_module": "@jupyter-widgets/base",
"top": null,
"grid_column": null,
"overflow_y": null,
"overflow_x": null,
"grid_auto_flow": null,
"grid_area": null,
"grid_template_columns": null,
"flex": null,
"_model_name": "LayoutModel",
"justify_items": null,
"grid_row": null,
"max_height": null,
"align_content": null,
"visibility": null,
"align_self": null,
"height": null,
"min_height": null,
"padding": null,
"grid_auto_rows": null,
"grid_gap": null,
"max_width": null,
"order": null,
"_view_module_version": "1.2.0",
"grid_template_areas": null,
"object_position": null,
"object_fit": null,
"grid_auto_columns": null,
"margin": null,
"display": null,
"left": null
}
},
"1d1dbafc1a5e40449a5e75fff6bae439": {
"model_module": "@jupyter-widgets/controls",
"model_name": "DescriptionStyleModel",
"state": {
"_view_name": "StyleView",
"_model_name": "DescriptionStyleModel",
"description_width": "",
"_view_module": "@jupyter-widgets/base",
"_model_module_version": "1.5.0",
"_view_count": null,
"_view_module_version": "1.2.0",
"_model_module": "@jupyter-widgets/controls"
}
},
"071477472574442d9c7a573843676bd8": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"state": {
"_view_name": "LayoutView",
"grid_template_rows": null,
"right": null,
"justify_content": null,
"_view_module": "@jupyter-widgets/base",
"overflow": null,
"_model_module_version": "1.2.0",
"_view_count": null,
"flex_flow": null,
"width": null,
"min_width": null,
"border": null,
"align_items": null,
"bottom": null,
"_model_module": "@jupyter-widgets/base",
"top": null,
"grid_column": null,
"overflow_y": null,
"overflow_x": null,
"grid_auto_flow": null,
"grid_area": null,
"grid_template_columns": null,
"flex": null,
"_model_name": "LayoutModel",
"justify_items": null,
"grid_row": null,
"max_height": null,
"align_content": null,
"visibility": null,
"align_self": null,
"height": null,
"min_height": null,
"padding": null,
"grid_auto_rows": null,
"grid_gap": null,
"max_width": null,
"order": null,
"_view_module_version": "1.2.0",
"grid_template_areas": null,
"object_position": null,
"object_fit": null,
"grid_auto_columns": null,
"margin": null,
"display": null,
"left": null
}
},
"cc7687822d2c4bf987ce7b26211009d7": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HBoxModel",
"state": {
"_view_name": "HBoxView",
"_dom_classes": [],
"_model_name": "HBoxModel",
"_view_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_view_count": null,
"_view_module_version": "1.5.0",
"box_style": "",
"layout": "IPY_MODEL_fdd32a4e667d4749b1df52f4f6cc428c",
"_model_module": "@jupyter-widgets/controls",
"children": [
"IPY_MODEL_79640c28d08b4c8d953fa0d5177e66d2",
"IPY_MODEL_89112305483a445b999db99748dc24b8"
]
}
},
"fdd32a4e667d4749b1df52f4f6cc428c": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"state": {
"_view_name": "LayoutView",
"grid_template_rows": null,
"right": null,
"justify_content": null,
"_view_module": "@jupyter-widgets/base",
"overflow": null,
"_model_module_version": "1.2.0",
"_view_count": null,
"flex_flow": null,
"width": null,
"min_width": null,
"border": null,
"align_items": null,
"bottom": null,
"_model_module": "@jupyter-widgets/base",
"top": null,
"grid_column": null,
"overflow_y": null,
"overflow_x": null,
"grid_auto_flow": null,
"grid_area": null,
"grid_template_columns": null,
"flex": null,
"_model_name": "LayoutModel",
"justify_items": null,
"grid_row": null,
"max_height": null,
"align_content": null,
"visibility": null,
"align_self": null,
"height": null,
"min_height": null,
"padding": null,
"grid_auto_rows": null,
"grid_gap": null,
"max_width": null,
"order": null,
"_view_module_version": "1.2.0",
"grid_template_areas": null,
"object_position": null,
"object_fit": null,
"grid_auto_columns": null,
"margin": null,
"display": null,
"left": null
}
},
"79640c28d08b4c8d953fa0d5177e66d2": {
"model_module": "@jupyter-widgets/controls",
"model_name": "FloatProgressModel",
"state": {
"_view_name": "ProgressView",
"style": "IPY_MODEL_2b6051386d824569b65edf54becd75e5",
"_dom_classes": [],
"description": "100%",
"_model_name": "FloatProgressModel",
"bar_style": "success",
"max": 1,
"_view_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"value": 1,
"_view_count": null,
"_view_module_version": "1.5.0",
"orientation": "horizontal",
"min": 0,
"description_tooltip": null,
"_model_module": "@jupyter-widgets/controls",
"layout": "IPY_MODEL_21d949a74aad47bc8262f54c51da7f6d"
}
},
"89112305483a445b999db99748dc24b8": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HTMLModel",
"state": {
"_view_name": "HTMLView",
"style": "IPY_MODEL_24a105e32e41429e8d7eaebf199d55d9",
"_dom_classes": [],
"description": "",
"_model_name": "HTMLModel",
"placeholder": "​",
"_view_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"value": " 1/1 [00:00<00:00, 1.31ba/s]",
"_view_count": null,
"_view_module_version": "1.5.0",
"description_tooltip": null,
"_model_module": "@jupyter-widgets/controls",
"layout": "IPY_MODEL_c2c4135046e045708c79ef807a8a9123"
}
},
"2b6051386d824569b65edf54becd75e5": {
"model_module": "@jupyter-widgets/controls",
"model_name": "ProgressStyleModel",
"state": {
"_view_name": "StyleView",
"_model_name": "ProgressStyleModel",
"description_width": "initial",
"_view_module": "@jupyter-widgets/base",
"_model_module_version": "1.5.0",
"_view_count": null,
"_view_module_version": "1.2.0",
"bar_color": null,
"_model_module": "@jupyter-widgets/controls"
}
},
"21d949a74aad47bc8262f54c51da7f6d": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"state": {
"_view_name": "LayoutView",
"grid_template_rows": null,
"right": null,
"justify_content": null,
"_view_module": "@jupyter-widgets/base",
"overflow": null,
"_model_module_version": "1.2.0",
"_view_count": null,
"flex_flow": null,
"width": null,
"min_width": null,
"border": null,
"align_items": null,
"bottom": null,
"_model_module": "@jupyter-widgets/base",
"top": null,
"grid_column": null,
"overflow_y": null,
"overflow_x": null,
"grid_auto_flow": null,
"grid_area": null,
"grid_template_columns": null,
"flex": null,
"_model_name": "LayoutModel",
"justify_items": null,
"grid_row": null,
"max_height": null,
"align_content": null,
"visibility": null,
"align_self": null,
"height": null,
"min_height": null,
"padding": null,
"grid_auto_rows": null,
"grid_gap": null,
"max_width": null,
"order": null,
"_view_module_version": "1.2.0",
"grid_template_areas": null,
"object_position": null,
"object_fit": null,
"grid_auto_columns": null,
"margin": null,
"display": null,
"left": null
}
},
"24a105e32e41429e8d7eaebf199d55d9": {
"model_module": "@jupyter-widgets/controls",
"model_name": "DescriptionStyleModel",
"state": {
"_view_name": "StyleView",
"_model_name": "DescriptionStyleModel",
"description_width": "",
"_view_module": "@jupyter-widgets/base",
"_model_module_version": "1.5.0",
"_view_count": null,
"_view_module_version": "1.2.0",
"_model_module": "@jupyter-widgets/controls"
}
},
"c2c4135046e045708c79ef807a8a9123": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"state": {
"_view_name": "LayoutView",
"grid_template_rows": null,
"right": null,
"justify_content": null,
"_view_module": "@jupyter-widgets/base",
"overflow": null,
"_model_module_version": "1.2.0",
"_view_count": null,
"flex_flow": null,
"width": null,
"min_width": null,
"border": null,
"align_items": null,
"bottom": null,
"_model_module": "@jupyter-widgets/base",
"top": null,
"grid_column": null,
"overflow_y": null,
"overflow_x": null,
"grid_auto_flow": null,
"grid_area": null,
"grid_template_columns": null,
"flex": null,
"_model_name": "LayoutModel",
"justify_items": null,
"grid_row": null,
"max_height": null,
"align_content": null,
"visibility": null,
"align_self": null,
"height": null,
"min_height": null,
"padding": null,
"grid_auto_rows": null,
"grid_gap": null,
"max_width": null,
"order": null,
"_view_module_version": "1.2.0",
"grid_template_areas": null,
"object_position": null,
"object_fit": null,
"grid_auto_columns": null,
"margin": null,
"display": null,
"left": null
}
},
"723c3b192e3540e0a1b606eaec5e11f8": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HBoxModel",
"state": {
"_view_name": "HBoxView",
"_dom_classes": [],
"_model_name": "HBoxModel",
"_view_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_view_count": null,
"_view_module_version": "1.5.0",
"box_style": "",
"layout": "IPY_MODEL_dc0da346fbce43f19d675d825cf20d5f",
"_model_module": "@jupyter-widgets/controls",
"children": [
"IPY_MODEL_ec6456a677704e858f047264bfe9d560",
"IPY_MODEL_24638a87d17f498e8c8f0735097bbadb"
]
}
},
"dc0da346fbce43f19d675d825cf20d5f": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"state": {
"_view_name": "LayoutView",
"grid_template_rows": null,
"right": null,
"justify_content": null,
"_view_module": "@jupyter-widgets/base",
"overflow": null,
"_model_module_version": "1.2.0",
"_view_count": null,
"flex_flow": null,
"width": null,
"min_width": null,
"border": null,
"align_items": null,
"bottom": null,
"_model_module": "@jupyter-widgets/base",
"top": null,
"grid_column": null,
"overflow_y": null,
"overflow_x": null,
"grid_auto_flow": null,
"grid_area": null,
"grid_template_columns": null,
"flex": null,
"_model_name": "LayoutModel",
"justify_items": null,
"grid_row": null,
"max_height": null,
"align_content": null,
"visibility": null,
"align_self": null,
"height": null,
"min_height": null,
"padding": null,
"grid_auto_rows": null,
"grid_gap": null,
"max_width": null,
"order": null,
"_view_module_version": "1.2.0",
"grid_template_areas": null,
"object_position": null,
"object_fit": null,
"grid_auto_columns": null,
"margin": null,
"display": null,
"left": null
}
},
"ec6456a677704e858f047264bfe9d560": {
"model_module": "@jupyter-widgets/controls",
"model_name": "FloatProgressModel",
"state": {
"_view_name": "ProgressView",
"style": "IPY_MODEL_6a7802e18fca437d934bec70d5b34891",
"_dom_classes": [],
"description": "100%",
"_model_name": "FloatProgressModel",
"bar_style": "success",
"max": 17,
"_view_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"value": 17,
"_view_count": null,
"_view_module_version": "1.5.0",
"orientation": "horizontal",
"min": 0,
"description_tooltip": null,
"_model_module": "@jupyter-widgets/controls",
"layout": "IPY_MODEL_505ca854e513405394336cac72f7410a"
}
},
"24638a87d17f498e8c8f0735097bbadb": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HTMLModel",
"state": {
"_view_name": "HTMLView",
"style": "IPY_MODEL_08e7f5aefd4f4827a7a8282c50d1204c",
"_dom_classes": [],
"description": "",
"_model_name": "HTMLModel",
"placeholder": "​",
"_view_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"value": " 17/17 [00:14<00:00, 1.17ba/s]",
"_view_count": null,
"_view_module_version": "1.5.0",
"description_tooltip": null,
"_model_module": "@jupyter-widgets/controls",
"layout": "IPY_MODEL_68618d9b52324cafa13aed6e90b1849e"
}
},
"6a7802e18fca437d934bec70d5b34891": {
"model_module": "@jupyter-widgets/controls",
"model_name": "ProgressStyleModel",
"state": {
"_view_name": "StyleView",
"_model_name": "ProgressStyleModel",
"description_width": "initial",
"_view_module": "@jupyter-widgets/base",
"_model_module_version": "1.5.0",
"_view_count": null,
"_view_module_version": "1.2.0",
"bar_color": null,
"_model_module": "@jupyter-widgets/controls"
}
},
"505ca854e513405394336cac72f7410a": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"state": {
"_view_name": "LayoutView",
"grid_template_rows": null,
"right": null,
"justify_content": null,
"_view_module": "@jupyter-widgets/base",
"overflow": null,
"_model_module_version": "1.2.0",
"_view_count": null,
"flex_flow": null,
"width": null,
"min_width": null,
"border": null,
"align_items": null,
"bottom": null,
"_model_module": "@jupyter-widgets/base",
"top": null,
"grid_column": null,
"overflow_y": null,
"overflow_x": null,
"grid_auto_flow": null,
"grid_area": null,
"grid_template_columns": null,
"flex": null,
"_model_name": "LayoutModel",
"justify_items": null,
"grid_row": null,
"max_height": null,
"align_content": null,
"visibility": null,
"align_self": null,
"height": null,
"min_height": null,
"padding": null,
"grid_auto_rows": null,
"grid_gap": null,
"max_width": null,
"order": null,
"_view_module_version": "1.2.0",
"grid_template_areas": null,
"object_position": null,
"object_fit": null,
"grid_auto_columns": null,
"margin": null,
"display": null,
"left": null
}
},
"08e7f5aefd4f4827a7a8282c50d1204c": {
"model_module": "@jupyter-widgets/controls",
"model_name": "DescriptionStyleModel",
"state": {
"_view_name": "StyleView",
"_model_name": "DescriptionStyleModel",
"description_width": "",
"_view_module": "@jupyter-widgets/base",
"_model_module_version": "1.5.0",
"_view_count": null,
"_view_module_version": "1.2.0",
"_model_module": "@jupyter-widgets/controls"
}
},
"68618d9b52324cafa13aed6e90b1849e": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"state": {
"_view_name": "LayoutView",
"grid_template_rows": null,
"right": null,
"justify_content": null,
"_view_module": "@jupyter-widgets/base",
"overflow": null,
"_model_module_version": "1.2.0",
"_view_count": null,
"flex_flow": null,
"width": null,
"min_width": null,
"border": null,
"align_items": null,
"bottom": null,
"_model_module": "@jupyter-widgets/base",
"top": null,
"grid_column": null,
"overflow_y": null,
"overflow_x": null,
"grid_auto_flow": null,
"grid_area": null,
"grid_template_columns": null,
"flex": null,
"_model_name": "LayoutModel",
"justify_items": null,
"grid_row": null,
"max_height": null,
"align_content": null,
"visibility": null,
"align_self": null,
"height": null,
"min_height": null,
"padding": null,
"grid_auto_rows": null,
"grid_gap": null,
"max_width": null,
"order": null,
"_view_module_version": "1.2.0",
"grid_template_areas": null,
"object_position": null,
"object_fit": null,
"grid_auto_columns": null,
"margin": null,
"display": null,
"left": null
}
},
"c372b84e3dbf4427929e3a835deb2088": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HBoxModel",
"state": {
"_view_name": "HBoxView",
"_dom_classes": [],
"_model_name": "HBoxModel",
"_view_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_view_count": null,
"_view_module_version": "1.5.0",
"box_style": "",
"layout": "IPY_MODEL_758c714071c146ff9523ae43759dc63d",
"_model_module": "@jupyter-widgets/controls",
"children": [
"IPY_MODEL_8614f0fada8b404ab4c5825ebe40d1fa",
"IPY_MODEL_ea9fc2b0eab043b29ebd91787d0b7a3f"
]
}
},
"758c714071c146ff9523ae43759dc63d": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"state": {
"_view_name": "LayoutView",
"grid_template_rows": null,
"right": null,
"justify_content": null,
"_view_module": "@jupyter-widgets/base",
"overflow": null,
"_model_module_version": "1.2.0",
"_view_count": null,
"flex_flow": null,
"width": null,
"min_width": null,
"border": null,
"align_items": null,
"bottom": null,
"_model_module": "@jupyter-widgets/base",
"top": null,
"grid_column": null,
"overflow_y": null,
"overflow_x": null,
"grid_auto_flow": null,
"grid_area": null,
"grid_template_columns": null,
"flex": null,
"_model_name": "LayoutModel",
"justify_items": null,
"grid_row": null,
"max_height": null,
"align_content": null,
"visibility": null,
"align_self": null,
"height": null,
"min_height": null,
"padding": null,
"grid_auto_rows": null,
"grid_gap": null,
"max_width": null,
"order": null,
"_view_module_version": "1.2.0",
"grid_template_areas": null,
"object_position": null,
"object_fit": null,
"grid_auto_columns": null,
"margin": null,
"display": null,
"left": null
}
},
"8614f0fada8b404ab4c5825ebe40d1fa": {
"model_module": "@jupyter-widgets/controls",
"model_name": "FloatProgressModel",
"state": {
"_view_name": "ProgressView",
"style": "IPY_MODEL_1e4253a16b504637ba899766360d1edf",
"_dom_classes": [],
"description": "100%",
"_model_name": "FloatProgressModel",
"bar_style": "success",
"max": 30,
"_view_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"value": 30,
"_view_count": null,
"_view_module_version": "1.5.0",
"orientation": "horizontal",
"min": 0,
"description_tooltip": null,
"_model_module": "@jupyter-widgets/controls",
"layout": "IPY_MODEL_e8ac8cd005ae4a50b06d74277e874f85"
}
},
"ea9fc2b0eab043b29ebd91787d0b7a3f": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HTMLModel",
"state": {
"_view_name": "HTMLView",
"style": "IPY_MODEL_8041caf1f1f84c2db2b80b28d4ed1890",
"_dom_classes": [],
"description": "",
"_model_name": "HTMLModel",
"placeholder": "​",
"_view_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"value": " 30/30 [00:40<00:00, 1.36s/ba]",
"_view_count": null,
"_view_module_version": "1.5.0",
"description_tooltip": null,
"_model_module": "@jupyter-widgets/controls",
"layout": "IPY_MODEL_0df29fe57e8b47efabd964c6c46309ea"
}
},
"1e4253a16b504637ba899766360d1edf": {
"model_module": "@jupyter-widgets/controls",
"model_name": "ProgressStyleModel",
"state": {
"_view_name": "StyleView",
"_model_name": "ProgressStyleModel",
"description_width": "initial",
"_view_module": "@jupyter-widgets/base",
"_model_module_version": "1.5.0",
"_view_count": null,
"_view_module_version": "1.2.0",
"bar_color": null,
"_model_module": "@jupyter-widgets/controls"
}
},
"e8ac8cd005ae4a50b06d74277e874f85": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"state": {
"_view_name": "LayoutView",
"grid_template_rows": null,
"right": null,
"justify_content": null,
"_view_module": "@jupyter-widgets/base",
"overflow": null,
"_model_module_version": "1.2.0",
"_view_count": null,
"flex_flow": null,
"width": null,
"min_width": null,
"border": null,
"align_items": null,
"bottom": null,
"_model_module": "@jupyter-widgets/base",
"top": null,
"grid_column": null,
"overflow_y": null,
"overflow_x": null,
"grid_auto_flow": null,
"grid_area": null,
"grid_template_columns": null,
"flex": null,
"_model_name": "LayoutModel",
"justify_items": null,
"grid_row": null,
"max_height": null,
"align_content": null,
"visibility": null,
"align_self": null,
"height": null,
"min_height": null,
"padding": null,
"grid_auto_rows": null,
"grid_gap": null,
"max_width": null,
"order": null,
"_view_module_version": "1.2.0",
"grid_template_areas": null,
"object_position": null,
"object_fit": null,
"grid_auto_columns": null,
"margin": null,
"display": null,
"left": null
}
},
"8041caf1f1f84c2db2b80b28d4ed1890": {
"model_module": "@jupyter-widgets/controls",
"model_name": "DescriptionStyleModel",
"state": {
"_view_name": "StyleView",
"_model_name": "DescriptionStyleModel",
"description_width": "",
"_view_module": "@jupyter-widgets/base",
"_model_module_version": "1.5.0",
"_view_count": null,
"_view_module_version": "1.2.0",
"_model_module": "@jupyter-widgets/controls"
}
},
"0df29fe57e8b47efabd964c6c46309ea": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"state": {
"_view_name": "LayoutView",
"grid_template_rows": null,
"right": null,
"justify_content": null,
"_view_module": "@jupyter-widgets/base",
"overflow": null,
"_model_module_version": "1.2.0",
"_view_count": null,
"flex_flow": null,
"width": null,
"min_width": null,
"border": null,
"align_items": null,
"bottom": null,
"_model_module": "@jupyter-widgets/base",
"top": null,
"grid_column": null,
"overflow_y": null,
"overflow_x": null,
"grid_auto_flow": null,
"grid_area": null,
"grid_template_columns": null,
"flex": null,
"_model_name": "LayoutModel",
"justify_items": null,
"grid_row": null,
"max_height": null,
"align_content": null,
"visibility": null,
"align_self": null,
"height": null,
"min_height": null,
"padding": null,
"grid_auto_rows": null,
"grid_gap": null,
"max_width": null,
"order": null,
"_view_module_version": "1.2.0",
"grid_template_areas": null,
"object_position": null,
"object_fit": null,
"grid_auto_columns": null,
"margin": null,
"display": null,
"left": null
}
},
"151cb5b916974e68a8a75e65def04e78": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HBoxModel",
"state": {
"_view_name": "HBoxView",
"_dom_classes": [],
"_model_name": "HBoxModel",
"_view_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_view_count": null,
"_view_module_version": "1.5.0",
"box_style": "",
"layout": "IPY_MODEL_806d0f299cbf4786adb6b94311af4351",
"_model_module": "@jupyter-widgets/controls",
"children": [
"IPY_MODEL_4de17bb1d8ef4fe89f488a4561fa3bf3",
"IPY_MODEL_48a1994bbb1b4cdf84f0d480d9e7e8f4"
]
}
},
"806d0f299cbf4786adb6b94311af4351": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"state": {
"_view_name": "LayoutView",
"grid_template_rows": null,
"right": null,
"justify_content": null,
"_view_module": "@jupyter-widgets/base",
"overflow": null,
"_model_module_version": "1.2.0",
"_view_count": null,
"flex_flow": null,
"width": null,
"min_width": null,
"border": null,
"align_items": null,
"bottom": null,
"_model_module": "@jupyter-widgets/base",
"top": null,
"grid_column": null,
"overflow_y": null,
"overflow_x": null,
"grid_auto_flow": null,
"grid_area": null,
"grid_template_columns": null,
"flex": null,
"_model_name": "LayoutModel",
"justify_items": null,
"grid_row": null,
"max_height": null,
"align_content": null,
"visibility": null,
"align_self": null,
"height": null,
"min_height": null,
"padding": null,
"grid_auto_rows": null,
"grid_gap": null,
"max_width": null,
"order": null,
"_view_module_version": "1.2.0",
"grid_template_areas": null,
"object_position": null,
"object_fit": null,
"grid_auto_columns": null,
"margin": null,
"display": null,
"left": null
}
},
"4de17bb1d8ef4fe89f488a4561fa3bf3": {
"model_module": "@jupyter-widgets/controls",
"model_name": "FloatProgressModel",
"state": {
"_view_name": "ProgressView",
"style": "IPY_MODEL_c8118ab02dfa4d049f21c3af8d41d818",
"_dom_classes": [],
"description": "100%",
"_model_name": "FloatProgressModel",
"bar_style": "success",
"max": 1,
"_view_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"value": 1,
"_view_count": null,
"_view_module_version": "1.5.0",
"orientation": "horizontal",
"min": 0,
"description_tooltip": null,
"_model_module": "@jupyter-widgets/controls",
"layout": "IPY_MODEL_f4618b8d7329490bbee49456d908808e"
}
},
"48a1994bbb1b4cdf84f0d480d9e7e8f4": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HTMLModel",
"state": {
"_view_name": "HTMLView",
"style": "IPY_MODEL_18ff39f1c9cd4092ab4599a3575e5530",
"_dom_classes": [],
"description": "",
"_model_name": "HTMLModel",
"placeholder": "​",
"_view_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"value": " 1/1 [00:00<00:00, 2.59ba/s]",
"_view_count": null,
"_view_module_version": "1.5.0",
"description_tooltip": null,
"_model_module": "@jupyter-widgets/controls",
"layout": "IPY_MODEL_e9b471e8db414e79a093cfeac2d97695"
}
},
"c8118ab02dfa4d049f21c3af8d41d818": {
"model_module": "@jupyter-widgets/controls",
"model_name": "ProgressStyleModel",
"state": {
"_view_name": "StyleView",
"_model_name": "ProgressStyleModel",
"description_width": "initial",
"_view_module": "@jupyter-widgets/base",
"_model_module_version": "1.5.0",
"_view_count": null,
"_view_module_version": "1.2.0",
"bar_color": null,
"_model_module": "@jupyter-widgets/controls"
}
},
"f4618b8d7329490bbee49456d908808e": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"state": {
"_view_name": "LayoutView",
"grid_template_rows": null,
"right": null,
"justify_content": null,
"_view_module": "@jupyter-widgets/base",
"overflow": null,
"_model_module_version": "1.2.0",
"_view_count": null,
"flex_flow": null,
"width": null,
"min_width": null,
"border": null,
"align_items": null,
"bottom": null,
"_model_module": "@jupyter-widgets/base",
"top": null,
"grid_column": null,
"overflow_y": null,
"overflow_x": null,
"grid_auto_flow": null,
"grid_area": null,
"grid_template_columns": null,
"flex": null,
"_model_name": "LayoutModel",
"justify_items": null,
"grid_row": null,
"max_height": null,
"align_content": null,
"visibility": null,
"align_self": null,
"height": null,
"min_height": null,
"padding": null,
"grid_auto_rows": null,
"grid_gap": null,
"max_width": null,
"order": null,
"_view_module_version": "1.2.0",
"grid_template_areas": null,
"object_position": null,
"object_fit": null,
"grid_auto_columns": null,
"margin": null,
"display": null,
"left": null
}
},
"18ff39f1c9cd4092ab4599a3575e5530": {
"model_module": "@jupyter-widgets/controls",
"model_name": "DescriptionStyleModel",
"state": {
"_view_name": "StyleView",
"_model_name": "DescriptionStyleModel",
"description_width": "",
"_view_module": "@jupyter-widgets/base",
"_model_module_version": "1.5.0",
"_view_count": null,
"_view_module_version": "1.2.0",
"_model_module": "@jupyter-widgets/controls"
}
},
"e9b471e8db414e79a093cfeac2d97695": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"state": {
"_view_name": "LayoutView",
"grid_template_rows": null,
"right": null,
"justify_content": null,
"_view_module": "@jupyter-widgets/base",
"overflow": null,
"_model_module_version": "1.2.0",
"_view_count": null,
"flex_flow": null,
"width": null,
"min_width": null,
"border": null,
"align_items": null,
"bottom": null,
"_model_module": "@jupyter-widgets/base",
"top": null,
"grid_column": null,
"overflow_y": null,
"overflow_x": null,
"grid_auto_flow": null,
"grid_area": null,
"grid_template_columns": null,
"flex": null,
"_model_name": "LayoutModel",
"justify_items": null,
"grid_row": null,
"max_height": null,
"align_content": null,
"visibility": null,
"align_self": null,
"height": null,
"min_height": null,
"padding": null,
"grid_auto_rows": null,
"grid_gap": null,
"max_width": null,
"order": null,
"_view_module_version": "1.2.0",
"grid_template_areas": null,
"object_position": null,
"object_fit": null,
"grid_auto_columns": null,
"margin": null,
"display": null,
"left": null
}
},
"dccb0b60b13948fcadc6363f58ae0869": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HBoxModel",
"state": {
"_view_name": "HBoxView",
"_dom_classes": [],
"_model_name": "HBoxModel",
"_view_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_view_count": null,
"_view_module_version": "1.5.0",
"box_style": "",
"layout": "IPY_MODEL_8c4af8e3832740299c2b3cc30df5da2f",
"_model_module": "@jupyter-widgets/controls",
"children": [
"IPY_MODEL_90a5ccdf86a34a68917659348c9a585b",
"IPY_MODEL_667bb2c7492349aa972d865ef0aae768"
]
}
},
"8c4af8e3832740299c2b3cc30df5da2f": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"state": {
"_view_name": "LayoutView",
"grid_template_rows": null,
"right": null,
"justify_content": null,
"_view_module": "@jupyter-widgets/base",
"overflow": null,
"_model_module_version": "1.2.0",
"_view_count": null,
"flex_flow": null,
"width": null,
"min_width": null,
"border": null,
"align_items": null,
"bottom": null,
"_model_module": "@jupyter-widgets/base",
"top": null,
"grid_column": null,
"overflow_y": null,
"overflow_x": null,
"grid_auto_flow": null,
"grid_area": null,
"grid_template_columns": null,
"flex": null,
"_model_name": "LayoutModel",
"justify_items": null,
"grid_row": null,
"max_height": null,
"align_content": null,
"visibility": null,
"align_self": null,
"height": null,
"min_height": null,
"padding": null,
"grid_auto_rows": null,
"grid_gap": null,
"max_width": null,
"order": null,
"_view_module_version": "1.2.0",
"grid_template_areas": null,
"object_position": null,
"object_fit": null,
"grid_auto_columns": null,
"margin": null,
"display": null,
"left": null
}
},
"90a5ccdf86a34a68917659348c9a585b": {
"model_module": "@jupyter-widgets/controls",
"model_name": "FloatProgressModel",
"state": {
"_view_name": "ProgressView",
"style": "IPY_MODEL_e878a683929a4e3684d5c77cbdf6aaa1",
"_dom_classes": [],
"description": "100%",
"_model_name": "FloatProgressModel",
"bar_style": "success",
"max": 17,
"_view_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"value": 17,
"_view_count": null,
"_view_module_version": "1.5.0",
"orientation": "horizontal",
"min": 0,
"description_tooltip": null,
"_model_module": "@jupyter-widgets/controls",
"layout": "IPY_MODEL_0555e43593724bb2a4630bb0d52db24b"
}
},
"667bb2c7492349aa972d865ef0aae768": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HTMLModel",
"state": {
"_view_name": "HTMLView",
"style": "IPY_MODEL_181b87b0de0f4b669bad62b8757ce830",
"_dom_classes": [],
"description": "",
"_model_name": "HTMLModel",
"placeholder": "​",
"_view_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"value": " 17/17 [00:18<00:00, 1.10s/ba]",
"_view_count": null,
"_view_module_version": "1.5.0",
"description_tooltip": null,
"_model_module": "@jupyter-widgets/controls",
"layout": "IPY_MODEL_dc586ad12a704d918160381c15bd0b19"
}
},
"e878a683929a4e3684d5c77cbdf6aaa1": {
"model_module": "@jupyter-widgets/controls",
"model_name": "ProgressStyleModel",
"state": {
"_view_name": "StyleView",
"_model_name": "ProgressStyleModel",
"description_width": "initial",
"_view_module": "@jupyter-widgets/base",
"_model_module_version": "1.5.0",
"_view_count": null,
"_view_module_version": "1.2.0",
"bar_color": null,
"_model_module": "@jupyter-widgets/controls"
}
},
"0555e43593724bb2a4630bb0d52db24b": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"state": {
"_view_name": "LayoutView",
"grid_template_rows": null,
"right": null,
"justify_content": null,
"_view_module": "@jupyter-widgets/base",
"overflow": null,
"_model_module_version": "1.2.0",
"_view_count": null,
"flex_flow": null,
"width": null,
"min_width": null,
"border": null,
"align_items": null,
"bottom": null,
"_model_module": "@jupyter-widgets/base",
"top": null,
"grid_column": null,
"overflow_y": null,
"overflow_x": null,
"grid_auto_flow": null,
"grid_area": null,
"grid_template_columns": null,
"flex": null,
"_model_name": "LayoutModel",
"justify_items": null,
"grid_row": null,
"max_height": null,
"align_content": null,
"visibility": null,
"align_self": null,
"height": null,
"min_height": null,
"padding": null,
"grid_auto_rows": null,
"grid_gap": null,
"max_width": null,
"order": null,
"_view_module_version": "1.2.0",
"grid_template_areas": null,
"object_position": null,
"object_fit": null,
"grid_auto_columns": null,
"margin": null,
"display": null,
"left": null
}
},
"181b87b0de0f4b669bad62b8757ce830": {
"model_module": "@jupyter-widgets/controls",
"model_name": "DescriptionStyleModel",
"state": {
"_view_name": "StyleView",
"_model_name": "DescriptionStyleModel",
"description_width": "",
"_view_module": "@jupyter-widgets/base",
"_model_module_version": "1.5.0",
"_view_count": null,
"_view_module_version": "1.2.0",
"_model_module": "@jupyter-widgets/controls"
}
},
"dc586ad12a704d918160381c15bd0b19": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"state": {
"_view_name": "LayoutView",
"grid_template_rows": null,
"right": null,
"justify_content": null,
"_view_module": "@jupyter-widgets/base",
"overflow": null,
"_model_module_version": "1.2.0",
"_view_count": null,
"flex_flow": null,
"width": null,
"min_width": null,
"border": null,
"align_items": null,
"bottom": null,
"_model_module": "@jupyter-widgets/base",
"top": null,
"grid_column": null,
"overflow_y": null,
"overflow_x": null,
"grid_auto_flow": null,
"grid_area": null,
"grid_template_columns": null,
"flex": null,
"_model_name": "LayoutModel",
"justify_items": null,
"grid_row": null,
"max_height": null,
"align_content": null,
"visibility": null,
"align_self": null,
"height": null,
"min_height": null,
"padding": null,
"grid_auto_rows": null,
"grid_gap": null,
"max_width": null,
"order": null,
"_view_module_version": "1.2.0",
"grid_template_areas": null,
"object_position": null,
"object_fit": null,
"grid_auto_columns": null,
"margin": null,
"display": null,
"left": null
}
}
}
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/avidale/dc7a26eb3cffc90075bf100e15b5950f/rut5-encoder.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "4idbg8EzkR-d"
},
"source": [
"Задача этого блокнота - показать, как можно переиспользовать маленькую модель T5 для русского языка (https://huggingface.co/cointegrated/rut5-small) для задачи классификации текстов. \n",
"\n",
"Мы выкинем декодер от этой модели, а к энкодеру прилепим голову для классификации. "
]
},
{
"cell_type": "code",
"metadata": {
"id": "ugBNkEDBg2Om"
},
"source": [
"!pip install sentencepiece transformers datasets"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "_feh8Ypeg6R_"
},
"source": [
"from transformers import T5ForConditionalGeneration, T5Tokenizer"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "ezdcSM35lXK-"
},
"source": [
"tokenizer = T5Tokenizer.from_pretrained('cointegrated/rut5-small')"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "m_e9HnUKi-ge"
},
"source": [
"import torch\n",
"from torch import nn"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "rcHvHi04ngpR"
},
"source": [
"На Huggingface нет готового кода для классификации текстов с помощью T5, но его очень легко написать по аналогии с BertForSequenceClassification. "
]
},
{
"cell_type": "code",
"metadata": {
"id": "REetkbGYhd_B"
},
"source": [
"from transformers import T5PreTrainedModel\n",
"from transformers.models.t5.modeling_t5 import T5Stack\n",
"from transformers.modeling_outputs import SequenceClassifierOutput\n",
"from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss\n",
"import copy\n",
"\n",
"\n",
"def mean_pooling(model_output, attention_mask):\n",
" token_embeddings = model_output[0] #First element of model_output contains all token embeddings\n",
" input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()\n",
" sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1)\n",
" sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)\n",
" return sum_embeddings / sum_mask\n",
"\n",
"\n",
"class T5ForSequenceClassification(T5PreTrainedModel):\n",
" def __init__(self, config):\n",
" super().__init__(config)\n",
" self.num_labels = config.num_labels\n",
" self.config = config\n",
"\n",
" self.shared = nn.Embedding(config.vocab_size, config.d_model)\n",
"\n",
" encoder_config = copy.deepcopy(config)\n",
" encoder_config.is_decoder = False\n",
" encoder_config.use_cache = False\n",
" encoder_config.is_encoder_decoder = False\n",
"\n",
" self.encoder = T5Stack(encoder_config, self.shared)\n",
"\n",
" self.dropout = nn.Dropout(config.dropout_rate)\n",
" self.classifier = nn.Linear(config.hidden_size, config.num_labels)\n",
"\n",
" self.init_weights()\n",
"\n",
" # Model parallel\n",
" self.model_parallel = False\n",
" self.device_map = None\n",
"\n",
" def forward(\n",
" self,\n",
" input_ids=None,\n",
" attention_mask=None,\n",
" head_mask=None,\n",
" inputs_embeds=None,\n",
" output_attentions=None,\n",
" output_hidden_states=None,\n",
" return_dict=None,\n",
" labels=None,\n",
" ):\n",
" r\"\"\"\n",
" labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):\n",
" Labels for computing the sequence classification/regression loss. Indices should be in :obj:`[0, ...,\n",
" config.num_labels - 1]`. If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss),\n",
" If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).\n",
" \"\"\"\n",
" return_dict = return_dict if return_dict is not None else self.config.use_return_dict\n",
"\n",
" outputs = self.encoder(\n",
" input_ids=input_ids,\n",
" attention_mask=attention_mask,\n",
" head_mask=head_mask,\n",
" inputs_embeds=inputs_embeds,\n",
" output_attentions=output_attentions,\n",
" output_hidden_states=output_hidden_states,\n",
" return_dict=return_dict,\n",
" )\n",
"\n",
" if attention_mask is None:\n",
" total_output = outputs[0] # batch, seq_len, emb_dim\n",
" pooled_output = total_output.mean(dim=1)\n",
" else:\n",
" pooled_output = mean_pooling(outputs, attention_mask)\n",
"\n",
" pooled_output = self.dropout(pooled_output)\n",
" logits = self.classifier(pooled_output)\n",
"\n",
" loss = None\n",
" if labels is not None:\n",
" if self.config.problem_type is None:\n",
" if self.num_labels == 1:\n",
" self.config.problem_type = \"regression\"\n",
" elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):\n",
" self.config.problem_type = \"single_label_classification\"\n",
" else:\n",
" self.config.problem_type = \"multi_label_classification\"\n",
"\n",
" if self.config.problem_type == \"regression\":\n",
" loss_fct = MSELoss()\n",
" loss = loss_fct(logits.view(-1, self.num_labels), labels)\n",
" elif self.config.problem_type == \"single_label_classification\":\n",
" loss_fct = CrossEntropyLoss()\n",
" loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))\n",
" elif self.config.problem_type == \"multi_label_classification\":\n",
" loss_fct = BCEWithLogitsLoss()\n",
" loss = loss_fct(logits, labels)\n",
" if not return_dict:\n",
" output = (logits,) + outputs[2:]\n",
" return ((loss,) + output) if loss is not None else output\n",
"\n",
" return SequenceClassifierOutput(\n",
" loss=loss,\n",
" logits=logits,\n",
" hidden_states=outputs.hidden_states,\n",
" attentions=outputs.attentions,\n",
" )"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "us0_IkM3jEB-",
"outputId": "81064a76-5810-4969-93af-c3010e37b798"
},
"source": [
"clf = T5ForSequenceClassification.from_pretrained('cointegrated/rut5-small-normalizer')"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"text": [
"Some weights of the model checkpoint at cointegrated/rut5-small-normalizer were not used when initializing T5ForSequenceClassification: ['decoder.block.7.layer.0.SelfAttention.q.weight', 'decoder.block.5.layer.2.layer_norm.weight', 'decoder.block.1.layer.2.DenseReluDense.wi_1.weight', 'decoder.block.6.layer.1.EncDecAttention.v.weight', 'decoder.block.6.layer.0.layer_norm.weight', 'decoder.block.2.layer.1.layer_norm.weight', 'decoder.block.2.layer.0.SelfAttention.o.weight', 'decoder.block.6.layer.0.SelfAttention.o.weight', 'decoder.block.1.layer.0.SelfAttention.o.weight', 'decoder.block.0.layer.0.SelfAttention.q.weight', 'decoder.block.3.layer.0.SelfAttention.o.weight', 'decoder.block.6.layer.1.EncDecAttention.q.weight', 'decoder.block.1.layer.2.layer_norm.weight', 'decoder.block.2.layer.2.DenseReluDense.wi_1.weight', 'decoder.block.1.layer.1.layer_norm.weight', 'decoder.block.6.layer.1.EncDecAttention.o.weight', 'decoder.block.5.layer.0.layer_norm.weight', 'decoder.block.3.layer.1.EncDecAttention.k.weight', 'decoder.block.1.layer.0.SelfAttention.v.weight', 'decoder.block.2.layer.0.SelfAttention.q.weight', 'decoder.block.2.layer.0.layer_norm.weight', 'decoder.block.3.layer.2.layer_norm.weight', 'decoder.block.0.layer.1.EncDecAttention.v.weight', 'decoder.block.1.layer.0.SelfAttention.q.weight', 'decoder.block.4.layer.1.EncDecAttention.k.weight', 'decoder.block.1.layer.0.SelfAttention.k.weight', 'decoder.block.6.layer.2.DenseReluDense.wo.weight', 'decoder.block.0.layer.1.layer_norm.weight', 'decoder.block.1.layer.2.DenseReluDense.wo.weight', 'decoder.block.7.layer.2.DenseReluDense.wi_0.weight', 'decoder.block.5.layer.2.DenseReluDense.wi_1.weight', 'lm_head.weight', 'decoder.block.7.layer.0.SelfAttention.o.weight', 'decoder.block.3.layer.2.DenseReluDense.wi_1.weight', 'decoder.block.2.layer.0.SelfAttention.k.weight', 'decoder.block.6.layer.0.SelfAttention.q.weight', 'decoder.block.0.layer.2.DenseReluDense.wo.weight', 
'decoder.block.1.layer.1.EncDecAttention.k.weight', 'decoder.block.3.layer.0.SelfAttention.v.weight', 'decoder.block.7.layer.2.layer_norm.weight', 'decoder.block.1.layer.0.layer_norm.weight', 'decoder.block.0.layer.0.SelfAttention.k.weight', 'decoder.block.0.layer.1.EncDecAttention.k.weight', 'decoder.block.0.layer.0.SelfAttention.v.weight', 'decoder.block.2.layer.2.DenseReluDense.wo.weight', 'decoder.block.4.layer.0.SelfAttention.k.weight', 'decoder.block.4.layer.0.SelfAttention.v.weight', 'decoder.block.0.layer.2.DenseReluDense.wi_0.weight', 'decoder.block.3.layer.2.DenseReluDense.wi_0.weight', 'decoder.block.3.layer.1.EncDecAttention.v.weight', 'decoder.block.3.layer.0.layer_norm.weight', 'decoder.final_layer_norm.weight', 'decoder.block.5.layer.2.DenseReluDense.wo.weight', 'decoder.block.4.layer.2.layer_norm.weight', 'decoder.block.4.layer.2.DenseReluDense.wo.weight', 'decoder.block.1.layer.1.EncDecAttention.v.weight', 'decoder.block.5.layer.0.SelfAttention.v.weight', 'decoder.block.7.layer.1.EncDecAttention.q.weight', 'decoder.block.0.layer.0.layer_norm.weight', 'decoder.block.5.layer.1.EncDecAttention.o.weight', 'decoder.block.1.layer.2.DenseReluDense.wi_0.weight', 'decoder.block.3.layer.0.SelfAttention.q.weight', 'decoder.block.2.layer.2.layer_norm.weight', 'decoder.block.0.layer.0.SelfAttention.o.weight', 'decoder.block.4.layer.1.EncDecAttention.v.weight', 'decoder.block.4.layer.1.EncDecAttention.q.weight', 'decoder.block.6.layer.1.EncDecAttention.k.weight', 'decoder.block.5.layer.0.SelfAttention.k.weight', 'decoder.block.7.layer.0.SelfAttention.k.weight', 'decoder.block.5.layer.1.layer_norm.weight', 'decoder.block.7.layer.1.EncDecAttention.v.weight', 'decoder.block.5.layer.0.SelfAttention.o.weight', 'decoder.block.7.layer.1.EncDecAttention.k.weight', 'decoder.block.1.layer.1.EncDecAttention.o.weight', 'decoder.block.3.layer.2.DenseReluDense.wo.weight', 'decoder.block.2.layer.2.DenseReluDense.wi_0.weight', 'decoder.block.2.layer.0.SelfAttention.v.weight', 
'decoder.block.2.layer.1.EncDecAttention.q.weight', 'decoder.block.7.layer.2.DenseReluDense.wo.weight', 'decoder.block.4.layer.1.layer_norm.weight', 'decoder.block.3.layer.1.EncDecAttention.o.weight', 'decoder.block.6.layer.1.layer_norm.weight', 'decoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight', 'decoder.block.6.layer.2.DenseReluDense.wi_0.weight', 'decoder.embed_tokens.weight', 'decoder.block.6.layer.0.SelfAttention.k.weight', 'decoder.block.5.layer.1.EncDecAttention.k.weight', 'decoder.block.7.layer.0.SelfAttention.v.weight', 'decoder.block.2.layer.1.EncDecAttention.v.weight', 'decoder.block.5.layer.0.SelfAttention.q.weight', 'decoder.block.4.layer.2.DenseReluDense.wi_1.weight', 'decoder.block.6.layer.2.layer_norm.weight', 'decoder.block.0.layer.2.layer_norm.weight', 'decoder.block.0.layer.1.EncDecAttention.o.weight', 'decoder.block.3.layer.0.SelfAttention.k.weight', 'decoder.block.4.layer.0.SelfAttention.o.weight', 'decoder.block.7.layer.1.layer_norm.weight', 'decoder.block.3.layer.1.layer_norm.weight', 'decoder.block.4.layer.2.DenseReluDense.wi_0.weight', 'decoder.block.4.layer.1.EncDecAttention.o.weight', 'decoder.block.5.layer.1.EncDecAttention.v.weight', 'decoder.block.6.layer.2.DenseReluDense.wi_1.weight', 'decoder.block.0.layer.1.EncDecAttention.q.weight', 'decoder.block.0.layer.2.DenseReluDense.wi_1.weight', 'decoder.block.2.layer.1.EncDecAttention.k.weight', 'decoder.block.1.layer.1.EncDecAttention.q.weight', 'decoder.block.3.layer.1.EncDecAttention.q.weight', 'decoder.block.7.layer.0.layer_norm.weight', 'decoder.block.5.layer.1.EncDecAttention.q.weight', 'decoder.block.7.layer.1.EncDecAttention.o.weight', 'decoder.block.6.layer.0.SelfAttention.v.weight', 'decoder.block.7.layer.2.DenseReluDense.wi_1.weight', 'decoder.block.4.layer.0.layer_norm.weight', 'decoder.block.2.layer.1.EncDecAttention.o.weight', 'decoder.block.5.layer.2.DenseReluDense.wi_0.weight', 'decoder.block.4.layer.0.SelfAttention.q.weight']\n",
"- This IS expected if you are initializing T5ForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
"- This IS NOT expected if you are initializing T5ForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
"Some weights of T5ForSequenceClassification were not initialized from the model checkpoint at cointegrated/rut5-small-normalizer and are newly initialized: ['classifier.weight', 'classifier.bias']\n",
"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
],
"name": "stderr"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "6mOoJ5YfnaOO"
},
"source": [
"Проверим, что класс корректно работает. "
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "plx_-8BOjJJ-",
"outputId": "233d94d7-3baa-4b3c-e476-10a3be9e6b47"
},
"source": [
"input_ids = tokenizer(\"Твоя каша - самая отвратная из всех каш, что мне довелось отведать!\", return_tensors=\"pt\").input_ids \n",
"with torch.no_grad():\n",
" outputs = clf(input_ids=input_ids)\n",
"outputs"
],
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"SequenceClassifierOutput([('logits', tensor([[ 0.0532, -0.0420]]))])"
]
},
"metadata": {
"tags": []
},
"execution_count": 191
}
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "cdNuQdKPm7iZ",
"outputId": "9c0eb95d-6351-41fc-ac36-0c065f2b147d"
},
"source": [
"with torch.no_grad():\n",
" outputs = clf(input_ids=input_ids, labels=torch.tensor([0]))\n",
"outputs"
],
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"SequenceClassifierOutput([('loss', tensor(0.6467)),\n",
" ('logits', tensor([[ 0.0532, -0.0420]]))])"
]
},
"metadata": {
"tags": []
},
"execution_count": 192
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "B5Qe3wOglNGD"
},
"source": [
"В сохранённом виде такая модель занимает 112 мегабайт. Немного. "
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "hzFKPQLKjThP",
"outputId": "3ca294a2-a191-4f10-82cf-8643a36c87c0"
},
"source": [
"clf.save_pretrained('clf')\n",
"tokenizer.save_pretrained('clf')\n",
"! ls -alsh clf"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"text": [
"total 112M\n",
"4.0K drwxr-xr-x 2 root root 4.0K May 12 21:56 .\n",
"4.0K drwxr-xr-x 1 root root 4.0K May 12 20:05 ..\n",
"4.0K -rw-r--r-- 1 root root 748 May 12 21:56 config.json\n",
"112M -rw-r--r-- 1 root root 112M May 12 21:56 pytorch_model.bin\n",
"4.0K -rw-r--r-- 1 root root 65 May 12 21:56 special_tokens_map.json\n",
"628K -rw-r--r-- 1 root root 625K May 12 21:56 spiece.model\n",
"4.0K -rw-r--r-- 1 root root 2.1K May 12 21:56 tokenizer_config.json\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "5ETFw8n1oYVb"
},
"source": [
"# Файн-тюнинг на задачу классификации"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "wHaiOQiqp5wV"
},
"source": [
"Возьмём задачу классификации неприемлемых текстов. Она достаточно сложная, ибо нужно выявлять и чувствительные темы, и отношение, которое говорящий к этим темам демонстрирует. \n",
"\n",
"https://github.com/skoltech-nlp/inappropriate-sensitive-topics"
]
},
{
"cell_type": "code",
"metadata": {
"id": "FAB5Y6qAobfa"
},
"source": [
"!wget https://github.com/skoltech-nlp/inappropriate-sensitive-topics/blob/main/Version2/appropriateness/train.csv?raw=true -O train.csv\n",
"!wget https://github.com/skoltech-nlp/inappropriate-sensitive-topics/blob/main/Version2/appropriateness/val.csv?raw=true -O val.csv\n",
"!wget https://github.com/skoltech-nlp/inappropriate-sensitive-topics/blob/main/Version2/appropriateness/test.csv?raw=true -O test.csv"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "V4YE2T0FrmPM"
},
"source": [
"Будем пользоваться готовыми инструментами для файн-тюнинга. \n",
"\n",
"https://huggingface.co/transformers/training.html\n",
"\n",
"https://huggingface.co/docs/datasets/loading_datasets.html"
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "OW_e0Zl5qKbz",
"outputId": "ad51341b-2790-4d25-e6a6-88578e40e541"
},
"source": [
"from datasets import load_dataset\n",
"dev_set = load_dataset(\"csv\", data_files={'train': 'train.csv', 'val': 'val.csv', 'test':'test.csv'})"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"text": [
"Using custom data configuration default-068a5f9c22e7d382\n",
"Reusing dataset csv (/root/.cache/huggingface/datasets/csv/default-068a5f9c22e7d382/0.0.0/2dc6629a9ff6b5697d82c25b73731dd440507a69cbce8b425db50b751e8fcfd0)\n"
],
"name": "stderr"
}
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "m3M3Ye62qYbd",
"outputId": "2a6e667e-0842-498f-add4-1158e7b2de97"
},
"source": [
"dev_set"
],
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"DatasetDict({\n",
" train: Dataset({\n",
" features: ['text', 'inappropriate', 'offline_crime', 'online_crime', 'drugs', 'gambling', 'pornography', 'prostitution', 'slavery', 'suicide', 'terrorism', 'weapons', 'body_shaming', 'health_shaming', 'politics', 'racism', 'religion', 'sexual_minorities', 'sexism', 'social_injustice', 'human_labeled'],\n",
" num_rows: 130665\n",
" })\n",
" val: Dataset({\n",
" features: ['text', 'inappropriate', 'offline_crime', 'online_crime', 'drugs', 'gambling', 'pornography', 'prostitution', 'slavery', 'suicide', 'terrorism', 'weapons', 'body_shaming', 'health_shaming', 'politics', 'racism', 'religion', 'sexual_minorities', 'sexism', 'social_injustice', 'human_labeled'],\n",
" num_rows: 16333\n",
" })\n",
" test: Dataset({\n",
" features: ['text', 'inappropriate', 'offline_crime', 'online_crime', 'drugs', 'gambling', 'pornography', 'prostitution', 'slavery', 'suicide', 'terrorism', 'weapons', 'body_shaming', 'health_shaming', 'politics', 'racism', 'religion', 'sexual_minorities', 'sexism', 'social_injustice', 'human_labeled'],\n",
" num_rows: 16334\n",
" })\n",
"})"
]
},
"metadata": {
"tags": []
},
"execution_count": 44
}
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "ECrP8yX8Olc6",
"outputId": "47a2e874-47e1-4be1-bdd4-39259d451020"
},
"source": [
"import pandas as pd\n",
"pd.Series([len(tokenizer.tokenize(t)) for t in dev_set['train'].shuffle(seed=42).select(range(10000))['text']]).quantile([0.5, 0.75, 0.9, 0.95, 0.99, 1])"
],
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"0.50 29.0\n",
"0.75 44.0\n",
"0.90 63.0\n",
"0.95 73.0\n",
"0.99 96.0\n",
"1.00 728.0\n",
"dtype: float64"
]
},
"metadata": {
"tags": []
},
"execution_count": 62
}
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "8rc_4kbjP2Zn",
"outputId": "6e25d56a-6fa2-4110-86e4-92829a5fc16a"
},
"source": [
"import pandas as pd\n",
"pd.Series(dev_set['train'].shuffle(seed=42).select(range(10000))['inappropriate']).describe()"
],
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"count 10000.000000\n",
"mean 0.316345\n",
"std 0.407695\n",
"min 0.000000\n",
"25% 0.000000\n",
"50% 0.060000\n",
"75% 0.780000\n",
"max 1.000000\n",
"dtype: float64"
]
},
"metadata": {
"tags": []
},
"execution_count": 79
}
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 265
},
"id": "9550-ySdQFuc",
"outputId": "6cbb5d34-cf5f-4d20-b099-3e4bb4555040"
},
"source": [
"pd.Series(dev_set['train'].shuffle(seed=42).select(range(10000))['inappropriate']).hist();"
],
"execution_count": null,
"outputs": [
{
"output_type": "display_data",
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAX0AAAD4CAYAAAAAczaOAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAR0UlEQVR4nO3df4wc5X3H8fc3OCTUIdiJkxOy3RxVnLZOrCToBI5StZe4NcapMFIJckSCQW4tpaTqD6ut01ZyC0EiqghNUH5di4WJSMBNm9pKaKllOKFWNcGUBPOjlAsxwS6JG2zcXlBoL/32j30u3Rofu3e7t8vyvF/S6WaeeWbm+d7Zn5mdmd2LzESSVIdX9HsAkqTeMfQlqSKGviRVxNCXpIoY+pJUkQX9HsCLWbJkSQ4PD895/R/84AcsXLiwewN6iautXrDmWljz7Nx///3fz8w3nGrZSzr0h4eHOXDgwJzXHx8fZ3R0tHsDeomrrV6w5lpY8+xExJMzLfPyjiRVxNCXpIoY+pJUEUNfkipi6EtSRQx9SaqIoS9JFTH0Jakihr4kVeQl/Y7cTh08coIrtn2t5/s9dN37er5PSWqHZ/qSVBFDX5IqYuhLUkUMfUmqiKEvSRUx9CWpIoa+JFXE0Jekihj6klQRQ1+SKmLoS1JFDH1JqoihL0kVMfQlqSKGviRVpK3Qj4hDEXEwIr4REQdK2+siYm9EPF6+Ly7tERGfioiJiHgwIs5t2s6m0v/xiNg0PyVJkmYymzP992TmOzJzpMxvA/Zl5gpgX5kHuBBYUb62AJ+FxkEC2A6cD5wHbJ8+UEiSeqOTyzsbgJ1leidwcVP7LdmwH1gUEWcDFwB7M/NYZh4H9gLrOti/JGmW2v1ziQn8fUQk8PnMHAOGMvPpsvy7wFCZXgo81bTu4dI2U/v/ExFbaLxCYGhoiPHx8TaH+EJDZ8DWVVNzXn+uOhlzJyYnJ/u2736x5jpYc/e0G/o/l5lHIuKNwN6I+JfmhZmZ5YDQsXJAGQMYGRnJ0dHROW/rxlt3c/3B3v8Z4EOXjfZ8n9A42HTy8xpE1lwHa+6eti7vZOaR8v0o8BUa1+S/Vy7bUL4fLd2PAMubVl9W2mZqlyT1SMvQj4iFEXHm9DSwFngI2ANMP4GzCdhdpvcAl5eneFYDJ8ploDuBtRGxuNzAXVvaJEk90s61jyHgKxEx3f+Lmfl3EXEfsCsiNgNPApeW/ncA64EJ4DngSoDMPBYR1wD3lX5XZ+axrlUiSWqpZehn5hPA20/R/gyw5hTtCVw1w7Z2ADtmP0xJUjf4jlxJqoihL0kVMfQlqSKGviRVxNCXpIoY+pJUEUNfkipi6EtSRQx9SaqIoS9JFTH0Jakihr4kVcTQl6SKGPqSVBFDX5IqYuhLUkUMfUmqiKEvSRUx9CWpIoa+JFXE0Jekihj6klQRQ1+SKmLoS1JFDH1JqoihL0kVMfQlqSKGviRVpO3Qj4jTIuKBiPhqmT8nIu6NiImIuD0iTi/tryrzE2X5cNM2PlraH4uIC7pdjCTpxc3mTP83gUeb5j8O3JCZbwaOA5tL+2bgeGm/ofQjIlYCG4G3AuuAz0TEaZ0NX5I0G22FfkQsA94H/EWZD+C9wJdLl53AxWV6Q5mnLF9T+m8AbsvM5zPz28AEcF43ipAktWdBm/3+DPg94Mwy/3rg2cycKvOHgaVleinwFEBmTkXEidJ/KbC/aZvN6/xYRGwBtgAMDQ0xPj7ebi0vMHQGbF011bpjl3Uy5k5MTk72bd/9Ys11sObuaRn6EfHLwNHMvD8iRrs+gpNk5hgwBjAyMpKjo3Pf5Y237ub6g+0e17rn0GWjPd8nNA42nfy8BpE118Gau6edRHw3cFFErAdeDbwW+CSwKCIWlLP9ZcCR0v8IsBw4HBELgLOAZ5rapzWvI0nqgZbX9DPzo5m5LDOHadyIvSszLwPuBi4p3TYBu8
v0njJPWX5XZmZp31ie7jkHWAF8vWuVSJJa6uTax+8Dt0XEx4AHgJtK+03AFyJiAjhG40BBZj4cEbuAR4Ap4KrM/FEH+5ckzdKsQj8zx4HxMv0Ep3j6JjN/CLx/hvWvBa6d7SAlSd3hO3IlqSKGviRVxNCXpIoY+pJUEUNfkipi6EtSRQx9SaqIoS9JFTH0Jakihr4kVcTQl6SKGPqSVBFDX5IqYuhLUkUMfUmqiKEvSRUx9CWpIoa+JFXE0Jekihj6klQRQ1+SKmLoS1JFDH1JqoihL0kVMfQlqSKGviRVxNCXpIoY+pJUEUNfkirSMvQj4tUR8fWI+GZEPBwRf1Laz4mIeyNiIiJuj4jTS/uryvxEWT7ctK2PlvbHIuKC+SpKknRq7ZzpPw+8NzPfDrwDWBcRq4GPAzdk5puB48Dm0n8zcLy031D6ERErgY3AW4F1wGci4rRuFiNJenEtQz8bJsvsK8tXAu8FvlzadwIXl+kNZZ6yfE1ERGm/LTOfz8xvAxPAeV2pQpLUlgXtdCpn5PcDbwY+DXwLeDYzp0qXw8DSMr0UeAogM6ci4gTw+tK+v2mzzes072sLsAVgaGiI8fHx2VXUZOgM2LpqqnXHLutkzJ2YnJzs2777xZrrYM3d01boZ+aPgHdExCLgK8DPdH0k/7evMWAMYGRkJEdHR+e8rRtv3c31B9sqsasOXTba831C42DTyc9rEFlzHay5e2b19E5mPgvcDbwLWBQR04m6DDhSpo8AywHK8rOAZ5rbT7GOJKkH2nl65w3lDJ+IOAP4JeBRGuF/Sem2CdhdpveUecryuzIzS/vG8nTPOcAK4OvdKkSS1Fo71z7OBnaW6/qvAHZl5lcj4hHgtoj4GPAAcFPpfxPwhYiYAI7ReGKHzHw4InYBjwBTwFXlspEkqUdahn5mPgi88xTtT3CKp28y84fA+2fY1rXAtbMfpiSpG3xHriRVxNCXpIoY+pJUEUNfkipi6EtSRQx9SaqIoS9JFTH0Jakihr4kVcTQl6SKGPqSVBFDX5IqYuhLUkUMfUmqiKEvSRUx9CWpIoa+JFXE0Jekihj6klQRQ1+SKmLoS1JFDH1JqoihL0kVMfQlqSKGviRVxNCXpIoY+pJUEUNfkiqyoFWHiFgO3AIMAQmMZeYnI+J1wO3AMHAIuDQzj0dEAJ8E1gPPAVdk5j+XbW0C/qhs+mOZubO75UhS9wxv+1rf9n3zuoXzst12zvSngK2ZuRJYDVwVESuBbcC+zFwB7CvzABcCK8rXFuCzAOUgsR04HzgP2B4Ri7tYiySphZahn5lPT5+pZ+Z/Ao8CS4ENwPSZ+k7g4jK9AbglG/YDiyLibOACYG9mHsvM48BeYF1Xq5EkvahZXdOPiGHgncC9wFBmPl0WfZfG5R9oHBCealrtcGmbqV2S1CMtr+lPi4jXAH8F/FZm/kfj0n1DZmZEZDcGFBFbaFwWYmhoiPHx8Tlva+gM2LpqqhvDmpVOxtyJycnJvu27X6y5Dv2quR/5MW2+am4r9CPilTQC/9bM/OvS/L2IODszny6Xb46W9iPA8qbVl5W2I8DoSe3jJ+8rM8eAMYCRkZEcHR09uUvbbrx1N9cfbPu41jWHLhvt+T6hcbDp5Oc1iKy5Dv2q+Yo+38idj5pbXt4pT+PcBDyamZ9oWrQH2FSmNwG7m9ovj4bVwIlyGehOYG1ELC43cNeWNklSj7RzGvxu4EPAwYj4Rmn7A+A6YFdEbAaeBC4ty+6g8bjmBI1HNq8EyMxjEXENcF/pd3VmHutKFZKktrQM/cz8ByBmWLzmFP0TuGqGbe0AdsxmgJKk7vEduZJUEUNfkipi6EtSRQx9SaqIoS9JFTH0Jakihr4kVcTQl6SKGPqSVBFDX5IqYuhLUkUMfUmqiKEvSRUx9CWpIoa+JFXE0Jekihj6klQRQ1+SKmLoS1JFDH1JqoihL0kVMfQlqSKGviRVxNCXpIoY+pJUEUNfkipi6EtSRQx9Sa
qIoS9JFWkZ+hGxIyKORsRDTW2vi4i9EfF4+b64tEdEfCoiJiLiwYg4t2mdTaX/4xGxaX7KkSS9mHbO9G8G1p3Utg3Yl5krgH1lHuBCYEX52gJ8FhoHCWA7cD5wHrB9+kAhSeqdlqGfmfcAx05q3gDsLNM7gYub2m/Jhv3Aoog4G7gA2JuZxzLzOLCXFx5IJEnzbMEc1xvKzKfL9HeBoTK9FHiqqd/h0jZT+wtExBYarxIYGhpifHx8jkOEoTNg66qpOa8/V52MuROTk5N923e/WHMd+lVzP/Jj2nzVPNfQ/7HMzIjIbgymbG8MGAMYGRnJ0dHROW/rxlt3c/3BjkuctUOXjfZ8n9A42HTy8xpE1lyHftV8xbav9Xyf025et3Beap7r0zvfK5dtKN+PlvYjwPKmfstK20ztkqQemmvo7wGmn8DZBOxuar+8PMWzGjhRLgPdCayNiMXlBu7a0iZJ6qGW1z4i4kvAKLAkIg7TeArnOmBXRGwGngQuLd3vANYDE8BzwJUAmXksIq4B7iv9rs7Mk28OS5LmWcvQz8wPzLBozSn6JnDVDNvZAeyY1egkSV3lO3IlqSKGviRVxNCXpIoY+pJUEUNfkipi6EtSRQx9SaqIoS9JFTH0Jakihr4kVaT3nztcgeE+fRzrzesW9mW/kgaHZ/qSVBFDX5IqYuhLUkUMfUmqiKEvSRXx6Z2XkYNHTvTtDzkfuu59fdmvpNnxTF+SKuKZvqSXvH6+in258Uxfkipi6EtSRQx9SaqIoS9JFfFGrrrCD5mTBoOhr4FW43sTaqxZ3WPoS3PUr1c3W1f1ZbdAnTW/3HhNX5IqYuhLUkUMfUmqSM9DPyLWRcRjETEREdt6vX9JqllPQz8iTgM+DVwIrAQ+EBErezkGSapZr8/0zwMmMvOJzPwv4DZgQ4/HIEnViszs3c4iLgHWZeavlvkPAedn5kea+mwBtpTZnwYe62CXS4Dvd7D+oKmtXrDmWljz7LwpM99wqgUvuef0M3MMGOvGtiLiQGaOdGNbg6C2esGaa2HN3dPryztHgOVN88tKmySpB3od+vcBKyLinIg4HdgI7OnxGCSpWj29vJOZUxHxEeBO4DRgR2Y+PI+77MplogFSW71gzbWw5i7p6Y1cSVJ/+Y5cSaqIoS9JFRn40G/1sQ4R8aqIuL0svzcihns/yu5qo+bfiYhHIuLBiNgXEW/qxzi7qd2P74iIX4mIjIiBf7yvnZoj4tLyu344Ir7Y6zF2Wxv/tn8yIu6OiAfKv+/1/Rhnt0TEjog4GhEPzbA8IuJT5efxYESc2/FOM3Ngv2jcDP4W8FPA6cA3gZUn9fl14HNleiNwe7/H3YOa3wP8RJn+cA01l35nAvcA+4GRfo+7B7/nFcADwOIy/8Z+j7sHNY8BHy7TK4FD/R53hzX/PHAu8NAMy9cDfwsEsBq4t9N9DvqZfjsf67AB2FmmvwysiYjo4Ri7rWXNmXl3Zj5XZvfTeD/EIGv34zuuAT4O/LCXg5sn7dT8a8CnM/M4QGYe7fEYu62dmhN4bZk+C/i3Ho6v6zLzHuDYi3TZANySDfuBRRFxdif7HPTQXwo81TR/uLSdsk9mTgEngNf3ZHTzo52am22mcaYwyFrWXF72Ls/M/vxpp+5r5/f8FuAtEfGPEbE/Itb1bHTzo52a/xj4YEQcBu4AfqM3Q+ub2f5/b+kl9zEM6p6I+CAwAvxCv8cynyLiFcAngCv6PJReW0DjEs8ojVdz90TEqsx8tq+jml8fAG7OzOsj4l3AFyLibZn5P/0e2KAY9DP9dj7W4cd9ImIBjZeEz/RkdPOjrY+yiIhfBP4QuCgzn+/R2OZLq5rPBN4GjEfEIRrXPvcM+M3cdn7Ph4E9mfnfmflt4F9pHAQGVTs1bwZ2AWTmPwGvpvHBZC9XXf/omkEP/XY+1mEPsKlMXwLcleUOyYBqWXNEvBP4PI3AH/TrvNCi5sw8kZlLMnM4M4dp3Me4KDMP9Ge4Xd
HOv+2/oXGWT0QsoXG554leDrLL2qn5O8AagIj4WRqh/+89HWVv7QEuL0/xrAZOZObTnWxwoC/v5Awf6xARVwMHMnMPcBONl4ATNG6YbOzfiDvXZs1/CrwG+Mtyz/o7mXlR3wbdoTZrfllps+Y7gbUR8QjwI+B3M3NgX8W2WfNW4M8j4rdp3NS9YpBP4iLiSzQO3EvKfYrtwCsBMvNzNO5brAcmgOeAKzve5wD/vCRJszTol3ckSbNg6EtSRQx9SaqIoS9JFTH0Jakihr4kVcTQl6SK/C/3n6u1FtwkpgAAAABJRU5ErkJggg==\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"tags": [],
"needs_background": "light"
}
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "nyRRVz14D6KM"
},
"source": [
"def preprocess(x):\n",
" return dict(\n",
" labels=[int(v>0.5) for v in x[\"inappropriate\"]],\n",
" **tokenizer(x[\"text\"], padding=\"max_length\", truncation=True, max_length=128,\n",
" )\n",
" )"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 164,
"referenced_widgets": [
"04d7817cfbfc4ec3a60248d35852d4b1",
"cdfcbe3e247647af83a6f51613869af1",
"902841146d2a4ddc94b3e41ddaeb5395",
"4b1c8d124ff14c01959a45c98677efe0",
"868e55bab60a4528bff43a077fe02f39",
"4cb0312c42bc4c009a0c6a9b13efbf6b",
"1d1dbafc1a5e40449a5e75fff6bae439",
"071477472574442d9c7a573843676bd8",
"cc7687822d2c4bf987ce7b26211009d7",
"fdd32a4e667d4749b1df52f4f6cc428c",
"79640c28d08b4c8d953fa0d5177e66d2",
"89112305483a445b999db99748dc24b8",
"2b6051386d824569b65edf54becd75e5",
"21d949a74aad47bc8262f54c51da7f6d",
"24a105e32e41429e8d7eaebf199d55d9",
"c2c4135046e045708c79ef807a8a9123",
"723c3b192e3540e0a1b606eaec5e11f8",
"dc0da346fbce43f19d675d825cf20d5f",
"ec6456a677704e858f047264bfe9d560",
"24638a87d17f498e8c8f0735097bbadb",
"6a7802e18fca437d934bec70d5b34891",
"505ca854e513405394336cac72f7410a",
"08e7f5aefd4f4827a7a8282c50d1204c",
"68618d9b52324cafa13aed6e90b1849e"
]
},
"id": "YLK1_eJADfsf",
"outputId": "3fb02dc0-01cd-4910-89df-2356eaeca9c1"
},
"source": [
"tr_tok = dev_set['train'].shuffle(seed=42).select(range(30_000)).map(preprocess, batched=True)\n",
"val_tok = dev_set['val'].shuffle(seed=42).select(range(300)).map(preprocess, batched=True)\n",
"test_tok = dev_set['test'].map(preprocess, batched=True)"
],
"execution_count": null,
"outputs": [
{
"output_type": "display_data",
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "04d7817cfbfc4ec3a60248d35852d4b1",
"version_minor": 0,
"version_major": 2
},
"text/plain": [
"HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))"
]
},
"metadata": {
"tags": []
}
},
{
"output_type": "stream",
"text": [
"\n"
],
"name": "stdout"
},
{
"output_type": "display_data",
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "cc7687822d2c4bf987ce7b26211009d7",
"version_minor": 0,
"version_major": 2
},
"text/plain": [
"HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))"
]
},
"metadata": {
"tags": []
}
},
{
"output_type": "stream",
"text": [
"\n"
],
"name": "stdout"
},
{
"output_type": "display_data",
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "723c3b192e3540e0a1b606eaec5e11f8",
"version_minor": 0,
"version_major": 2
},
"text/plain": [
"HBox(children=(FloatProgress(value=0.0, max=17.0), HTML(value='')))"
]
},
"metadata": {
"tags": []
}
},
{
"output_type": "stream",
"text": [
"\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "ijHkAnt9qTHr"
},
"source": [
"from transformers import TrainingArguments, Trainer"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "7N8IplVCM2wm"
},
"source": [
"from sklearn.metrics import precision_recall_fscore_support, accuracy_score, roc_auc_score\n",
"import numpy as np\n",
"\n",
"def sm(pred):\n",
" e = np.exp(pred)\n",
" return e[:, 1] / e.sum(axis=1)\n",
"\n",
"def compute_metrics(pred):\n",
" labels = pred.label_ids\n",
" preds = pred.predictions.argmax(-1)\n",
" precision, recall, f1, _ = precision_recall_fscore_support(labels, preds, average='binary')\n",
" acc = accuracy_score(labels, preds)\n",
" return {\n",
" 'accuracy': acc,\n",
" 'roc_auc': roc_auc_score(labels, sm(pred.predictions)),\n",
" 'f1': f1,\n",
" 'precision': precision,\n",
" 'recall': recall\n",
" }"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "TxFjs4iGTVRg",
"outputId": "c4d91065-5d19-4db8-fd71-7015e77cb5dd"
},
"source": [
"clf = T5ForSequenceClassification.from_pretrained('cointegrated/rut5-small-normalizer')"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"text": [
"Some weights of the model checkpoint at cointegrated/rut5-small-normalizer were not used when initializing T5ForSequenceClassification: ['decoder.block.7.layer.0.SelfAttention.q.weight', 'decoder.block.5.layer.2.layer_norm.weight', 'decoder.block.1.layer.2.DenseReluDense.wi_1.weight', 'decoder.block.6.layer.1.EncDecAttention.v.weight', 'decoder.block.6.layer.0.layer_norm.weight', 'decoder.block.2.layer.1.layer_norm.weight', 'decoder.block.2.layer.0.SelfAttention.o.weight', 'decoder.block.6.layer.0.SelfAttention.o.weight', 'decoder.block.1.layer.0.SelfAttention.o.weight', 'decoder.block.0.layer.0.SelfAttention.q.weight', 'decoder.block.3.layer.0.SelfAttention.o.weight', 'decoder.block.6.layer.1.EncDecAttention.q.weight', 'decoder.block.1.layer.2.layer_norm.weight', 'decoder.block.2.layer.2.DenseReluDense.wi_1.weight', 'decoder.block.1.layer.1.layer_norm.weight', 'decoder.block.6.layer.1.EncDecAttention.o.weight', 'decoder.block.5.layer.0.layer_norm.weight', 'decoder.block.3.layer.1.EncDecAttention.k.weight', 'decoder.block.1.layer.0.SelfAttention.v.weight', 'decoder.block.2.layer.0.SelfAttention.q.weight', 'decoder.block.2.layer.0.layer_norm.weight', 'decoder.block.3.layer.2.layer_norm.weight', 'decoder.block.0.layer.1.EncDecAttention.v.weight', 'decoder.block.1.layer.0.SelfAttention.q.weight', 'decoder.block.4.layer.1.EncDecAttention.k.weight', 'decoder.block.1.layer.0.SelfAttention.k.weight', 'decoder.block.6.layer.2.DenseReluDense.wo.weight', 'decoder.block.0.layer.1.layer_norm.weight', 'decoder.block.1.layer.2.DenseReluDense.wo.weight', 'decoder.block.7.layer.2.DenseReluDense.wi_0.weight', 'decoder.block.5.layer.2.DenseReluDense.wi_1.weight', 'lm_head.weight', 'decoder.block.7.layer.0.SelfAttention.o.weight', 'decoder.block.3.layer.2.DenseReluDense.wi_1.weight', 'decoder.block.2.layer.0.SelfAttention.k.weight', 'decoder.block.6.layer.0.SelfAttention.q.weight', 'decoder.block.0.layer.2.DenseReluDense.wo.weight', 
'decoder.block.1.layer.1.EncDecAttention.k.weight', 'decoder.block.3.layer.0.SelfAttention.v.weight', 'decoder.block.7.layer.2.layer_norm.weight', 'decoder.block.1.layer.0.layer_norm.weight', 'decoder.block.0.layer.0.SelfAttention.k.weight', 'decoder.block.0.layer.1.EncDecAttention.k.weight', 'decoder.block.0.layer.0.SelfAttention.v.weight', 'decoder.block.2.layer.2.DenseReluDense.wo.weight', 'decoder.block.4.layer.0.SelfAttention.k.weight', 'decoder.block.4.layer.0.SelfAttention.v.weight', 'decoder.block.0.layer.2.DenseReluDense.wi_0.weight', 'decoder.block.3.layer.2.DenseReluDense.wi_0.weight', 'decoder.block.3.layer.1.EncDecAttention.v.weight', 'decoder.block.3.layer.0.layer_norm.weight', 'decoder.final_layer_norm.weight', 'decoder.block.5.layer.2.DenseReluDense.wo.weight', 'decoder.block.4.layer.2.layer_norm.weight', 'decoder.block.4.layer.2.DenseReluDense.wo.weight', 'decoder.block.1.layer.1.EncDecAttention.v.weight', 'decoder.block.5.layer.0.SelfAttention.v.weight', 'decoder.block.7.layer.1.EncDecAttention.q.weight', 'decoder.block.0.layer.0.layer_norm.weight', 'decoder.block.5.layer.1.EncDecAttention.o.weight', 'decoder.block.1.layer.2.DenseReluDense.wi_0.weight', 'decoder.block.3.layer.0.SelfAttention.q.weight', 'decoder.block.2.layer.2.layer_norm.weight', 'decoder.block.0.layer.0.SelfAttention.o.weight', 'decoder.block.4.layer.1.EncDecAttention.v.weight', 'decoder.block.4.layer.1.EncDecAttention.q.weight', 'decoder.block.6.layer.1.EncDecAttention.k.weight', 'decoder.block.5.layer.0.SelfAttention.k.weight', 'decoder.block.7.layer.0.SelfAttention.k.weight', 'decoder.block.5.layer.1.layer_norm.weight', 'decoder.block.7.layer.1.EncDecAttention.v.weight', 'decoder.block.5.layer.0.SelfAttention.o.weight', 'decoder.block.7.layer.1.EncDecAttention.k.weight', 'decoder.block.1.layer.1.EncDecAttention.o.weight', 'decoder.block.3.layer.2.DenseReluDense.wo.weight', 'decoder.block.2.layer.2.DenseReluDense.wi_0.weight', 'decoder.block.2.layer.0.SelfAttention.v.weight', 
'decoder.block.2.layer.1.EncDecAttention.q.weight', 'decoder.block.7.layer.2.DenseReluDense.wo.weight', 'decoder.block.4.layer.1.layer_norm.weight', 'decoder.block.3.layer.1.EncDecAttention.o.weight', 'decoder.block.6.layer.1.layer_norm.weight', 'decoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight', 'decoder.block.6.layer.2.DenseReluDense.wi_0.weight', 'decoder.embed_tokens.weight', 'decoder.block.6.layer.0.SelfAttention.k.weight', 'decoder.block.5.layer.1.EncDecAttention.k.weight', 'decoder.block.7.layer.0.SelfAttention.v.weight', 'decoder.block.2.layer.1.EncDecAttention.v.weight', 'decoder.block.5.layer.0.SelfAttention.q.weight', 'decoder.block.4.layer.2.DenseReluDense.wi_1.weight', 'decoder.block.6.layer.2.layer_norm.weight', 'decoder.block.0.layer.2.layer_norm.weight', 'decoder.block.0.layer.1.EncDecAttention.o.weight', 'decoder.block.3.layer.0.SelfAttention.k.weight', 'decoder.block.4.layer.0.SelfAttention.o.weight', 'decoder.block.7.layer.1.layer_norm.weight', 'decoder.block.3.layer.1.layer_norm.weight', 'decoder.block.4.layer.2.DenseReluDense.wi_0.weight', 'decoder.block.4.layer.1.EncDecAttention.o.weight', 'decoder.block.5.layer.1.EncDecAttention.v.weight', 'decoder.block.6.layer.2.DenseReluDense.wi_1.weight', 'decoder.block.0.layer.1.EncDecAttention.q.weight', 'decoder.block.0.layer.2.DenseReluDense.wi_1.weight', 'decoder.block.2.layer.1.EncDecAttention.k.weight', 'decoder.block.1.layer.1.EncDecAttention.q.weight', 'decoder.block.3.layer.1.EncDecAttention.q.weight', 'decoder.block.7.layer.0.layer_norm.weight', 'decoder.block.5.layer.1.EncDecAttention.q.weight', 'decoder.block.7.layer.1.EncDecAttention.o.weight', 'decoder.block.6.layer.0.SelfAttention.v.weight', 'decoder.block.7.layer.2.DenseReluDense.wi_1.weight', 'decoder.block.4.layer.0.layer_norm.weight', 'decoder.block.2.layer.1.EncDecAttention.o.weight', 'decoder.block.5.layer.2.DenseReluDense.wi_0.weight', 'decoder.block.4.layer.0.SelfAttention.q.weight']\n",
"- This IS expected if you are initializing T5ForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
"- This IS NOT expected if you are initializing T5ForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
"Some weights of T5ForSequenceClassification were not initialized from the model checkpoint at cointegrated/rut5-small-normalizer and are newly initialized: ['classifier.weight', 'classifier.bias']\n",
"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
],
"name": "stderr"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "_30haDx9rJZm"
},
"source": [
"training_args = TrainingArguments(\n",
" output_dir='t5_clf', \n",
" overwrite_output_dir=True,\n",
" evaluation_strategy=\"steps\",\n",
" gradient_accumulation_steps=1,\n",
" per_device_train_batch_size=64,\n",
" per_device_eval_batch_size=64,\n",
" learning_rate=1e-3,\n",
" num_train_epochs=1,\n",
" save_total_limit=3,\n",
" logging_steps=100,\n",
")"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "-U6jHGUeqVfe"
},
"source": [
"trainer = Trainer(\n",
" model=clf, \n",
" args=training_args, \n",
" train_dataset=tr_tok, \n",
" eval_dataset=val_tok,\n",
" tokenizer=tokenizer, \n",
" compute_metrics=compute_metrics,\n",
")"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "XOInMjeTHFIs"
},
"source": [
"clf.cuda();"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "bzUqXoi7KShE"
},
"source": [
"import gc\n",
"gc.collect()\n",
"torch.cuda.empty_cache()"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "0Glb0dU3KcNU",
"outputId": "7bc4921e-79ce-4217-fe3d-5f19798cf896"
},
"source": [
"!nvidia-smi"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"text": [
"Wed May 12 21:57:19 2021 \n",
"+-----------------------------------------------------------------------------+\n",
"| NVIDIA-SMI 465.19.01 Driver Version: 460.32.03 CUDA Version: 11.2 |\n",
"|-------------------------------+----------------------+----------------------+\n",
"| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |\n",
"| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |\n",
"| | | MIG M. |\n",
"|===============================+======================+======================|\n",
"| 0 Tesla K80 Off | 00000000:00:04.0 Off | 0 |\n",
"| N/A 42C P0 59W / 149W | 906MiB / 11441MiB | 12% Default |\n",
"| | | N/A |\n",
"+-------------------------------+----------------------+----------------------+\n",
" \n",
"+-----------------------------------------------------------------------------+\n",
"| Processes: |\n",
"| GPU GI CI PID Type Process name GPU Memory |\n",
"| ID ID Usage |\n",
"|=============================================================================|\n",
"+-----------------------------------------------------------------------------+\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ascJOgeRMVMe"
},
"source": [
"Модель T5 обучается около 8 минут на этом количестве примеров. В результате получаем AUC 78%, что в принципе неплохо, с учётом того, что параметры мы вообще не тюнили, и скорее всего, модель сильно недообучена (одной эпохи обычно мало). "
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 236
},
"id": "kcFah1BAdP7x",
"outputId": "4c8ddb31-17e3-4484-82dc-ef4eb97911e6"
},
"source": [
"trainer.train()"
],
"execution_count": null,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/html": [
"\n",
" <div>\n",
" \n",
" <progress value='469' max='469' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
" [469/469 07:19, Epoch 1/1]\n",
" </div>\n",
" <table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: left;\">\n",
" <th>Step</th>\n",
" <th>Training Loss</th>\n",
" <th>Validation Loss</th>\n",
" <th>Accuracy</th>\n",
" <th>Roc Auc</th>\n",
" <th>F1</th>\n",
" <th>Precision</th>\n",
" <th>Recall</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <td>100</td>\n",
" <td>0.571900</td>\n",
" <td>0.557053</td>\n",
" <td>0.680000</td>\n",
" <td>0.757408</td>\n",
" <td>0.524752</td>\n",
" <td>0.540816</td>\n",
" <td>0.509615</td>\n",
" </tr>\n",
" <tr>\n",
" <td>200</td>\n",
" <td>0.522100</td>\n",
" <td>0.555676</td>\n",
" <td>0.740000</td>\n",
" <td>0.766042</td>\n",
" <td>0.512500</td>\n",
" <td>0.732143</td>\n",
" <td>0.394231</td>\n",
" </tr>\n",
" <tr>\n",
" <td>300</td>\n",
" <td>0.521700</td>\n",
" <td>0.538386</td>\n",
" <td>0.730000</td>\n",
" <td>0.790326</td>\n",
" <td>0.433566</td>\n",
" <td>0.794872</td>\n",
" <td>0.298077</td>\n",
" </tr>\n",
" <tr>\n",
" <td>400</td>\n",
" <td>0.502500</td>\n",
" <td>0.534034</td>\n",
" <td>0.733333</td>\n",
" <td>0.798617</td>\n",
" <td>0.512195</td>\n",
" <td>0.700000</td>\n",
" <td>0.403846</td>\n",
" </tr>\n",
" </tbody>\n",
"</table><p>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {
"tags": []
}
},
{
"output_type": "execute_result",
"data": {
"text/plain": [
"TrainOutput(global_step=469, training_loss=0.5254120247196287, metrics={'train_runtime': 440.9788, 'train_samples_per_second': 1.064, 'total_flos': 672203289600000.0, 'epoch': 1.0, 'init_mem_cpu_alloc_delta': 0, 'init_mem_gpu_alloc_delta': 118529536, 'init_mem_cpu_peaked_delta': 0, 'init_mem_gpu_peaked_delta': 0, 'train_mem_cpu_alloc_delta': 172032, 'train_mem_gpu_alloc_delta': 354115072, 'train_mem_cpu_peaked_delta': 135168, 'train_mem_gpu_peaked_delta': 4027416576})"
]
},
"metadata": {
"tags": []
},
"execution_count": 201
}
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 258
},
"id": "GgmLFqX_dRHJ",
"outputId": "22a0a9ad-965b-4d8e-ef92-a03eb6e1035b"
},
"source": [
"trainer.evaluate(test_tok)"
],
"execution_count": null,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/html": [
"\n",
" <div>\n",
" \n",
" <progress value='256' max='256' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
" [256/256 01:10]\n",
" </div>\n",
" "
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {
"tags": []
}
},
{
"output_type": "execute_result",
"data": {
"text/plain": [
"{'epoch': 1.0,\n",
" 'eval_accuracy': 0.7575609158809844,\n",
" 'eval_f1': 0.5040080160320642,\n",
" 'eval_loss': 0.4992245137691498,\n",
" 'eval_mem_cpu_alloc_delta': 0,\n",
" 'eval_mem_cpu_peaked_delta': 0,\n",
" 'eval_mem_gpu_alloc_delta': 0,\n",
" 'eval_mem_gpu_peaked_delta': 361152000,\n",
" 'eval_precision': 0.6507115135834411,\n",
" 'eval_recall': 0.41128372853638595,\n",
" 'eval_roc_auc': 0.7869150633226727,\n",
" 'eval_runtime': 70.7379,\n",
" 'eval_samples_per_second': 230.909}"
]
},
"metadata": {
"tags": []
},
"execution_count": 202
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "P6w6qooWdRxs"
},
"source": [
"Имеет ли смысл вообще использовать предобученные веса? Или всё дело в архитектуре?\n",
"\n",
"Видим, что если веса инициализировать заново, то обучается наша модель существенно медленнее. "
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "ZgokRPkjdZFB",
"outputId": "0e3821f3-c0fd-489c-e6a2-215364d7dc40"
},
"source": [
"clf1 = T5ForSequenceClassification.from_pretrained('cointegrated/rut5-small-normalizer')\n",
"clf1.init_weights()\n",
"clf1.shared.weight.data.normal_(mean=0.0, std=clf1.config.initializer_factor * 1.0);"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"text": [
"Some weights of the model checkpoint at cointegrated/rut5-small-normalizer were not used when initializing T5ForSequenceClassification: ['decoder.block.7.layer.0.SelfAttention.q.weight', 'decoder.block.5.layer.2.layer_norm.weight', 'decoder.block.1.layer.2.DenseReluDense.wi_1.weight', 'decoder.block.6.layer.1.EncDecAttention.v.weight', 'decoder.block.6.layer.0.layer_norm.weight', 'decoder.block.2.layer.1.layer_norm.weight', 'decoder.block.2.layer.0.SelfAttention.o.weight', 'decoder.block.6.layer.0.SelfAttention.o.weight', 'decoder.block.1.layer.0.SelfAttention.o.weight', 'decoder.block.0.layer.0.SelfAttention.q.weight', 'decoder.block.3.layer.0.SelfAttention.o.weight', 'decoder.block.6.layer.1.EncDecAttention.q.weight', 'decoder.block.1.layer.2.layer_norm.weight', 'decoder.block.2.layer.2.DenseReluDense.wi_1.weight', 'decoder.block.1.layer.1.layer_norm.weight', 'decoder.block.6.layer.1.EncDecAttention.o.weight', 'decoder.block.5.layer.0.layer_norm.weight', 'decoder.block.3.layer.1.EncDecAttention.k.weight', 'decoder.block.1.layer.0.SelfAttention.v.weight', 'decoder.block.2.layer.0.SelfAttention.q.weight', 'decoder.block.2.layer.0.layer_norm.weight', 'decoder.block.3.layer.2.layer_norm.weight', 'decoder.block.0.layer.1.EncDecAttention.v.weight', 'decoder.block.1.layer.0.SelfAttention.q.weight', 'decoder.block.4.layer.1.EncDecAttention.k.weight', 'decoder.block.1.layer.0.SelfAttention.k.weight', 'decoder.block.6.layer.2.DenseReluDense.wo.weight', 'decoder.block.0.layer.1.layer_norm.weight', 'decoder.block.1.layer.2.DenseReluDense.wo.weight', 'decoder.block.7.layer.2.DenseReluDense.wi_0.weight', 'decoder.block.5.layer.2.DenseReluDense.wi_1.weight', 'lm_head.weight', 'decoder.block.7.layer.0.SelfAttention.o.weight', 'decoder.block.3.layer.2.DenseReluDense.wi_1.weight', 'decoder.block.2.layer.0.SelfAttention.k.weight', 'decoder.block.6.layer.0.SelfAttention.q.weight', 'decoder.block.0.layer.2.DenseReluDense.wo.weight', 
'decoder.block.1.layer.1.EncDecAttention.k.weight', 'decoder.block.3.layer.0.SelfAttention.v.weight', 'decoder.block.7.layer.2.layer_norm.weight', 'decoder.block.1.layer.0.layer_norm.weight', 'decoder.block.0.layer.0.SelfAttention.k.weight', 'decoder.block.0.layer.1.EncDecAttention.k.weight', 'decoder.block.0.layer.0.SelfAttention.v.weight', 'decoder.block.2.layer.2.DenseReluDense.wo.weight', 'decoder.block.4.layer.0.SelfAttention.k.weight', 'decoder.block.4.layer.0.SelfAttention.v.weight', 'decoder.block.0.layer.2.DenseReluDense.wi_0.weight', 'decoder.block.3.layer.2.DenseReluDense.wi_0.weight', 'decoder.block.3.layer.1.EncDecAttention.v.weight', 'decoder.block.3.layer.0.layer_norm.weight', 'decoder.final_layer_norm.weight', 'decoder.block.5.layer.2.DenseReluDense.wo.weight', 'decoder.block.4.layer.2.layer_norm.weight', 'decoder.block.4.layer.2.DenseReluDense.wo.weight', 'decoder.block.1.layer.1.EncDecAttention.v.weight', 'decoder.block.5.layer.0.SelfAttention.v.weight', 'decoder.block.7.layer.1.EncDecAttention.q.weight', 'decoder.block.0.layer.0.layer_norm.weight', 'decoder.block.5.layer.1.EncDecAttention.o.weight', 'decoder.block.1.layer.2.DenseReluDense.wi_0.weight', 'decoder.block.3.layer.0.SelfAttention.q.weight', 'decoder.block.2.layer.2.layer_norm.weight', 'decoder.block.0.layer.0.SelfAttention.o.weight', 'decoder.block.4.layer.1.EncDecAttention.v.weight', 'decoder.block.4.layer.1.EncDecAttention.q.weight', 'decoder.block.6.layer.1.EncDecAttention.k.weight', 'decoder.block.5.layer.0.SelfAttention.k.weight', 'decoder.block.7.layer.0.SelfAttention.k.weight', 'decoder.block.5.layer.1.layer_norm.weight', 'decoder.block.7.layer.1.EncDecAttention.v.weight', 'decoder.block.5.layer.0.SelfAttention.o.weight', 'decoder.block.7.layer.1.EncDecAttention.k.weight', 'decoder.block.1.layer.1.EncDecAttention.o.weight', 'decoder.block.3.layer.2.DenseReluDense.wo.weight', 'decoder.block.2.layer.2.DenseReluDense.wi_0.weight', 'decoder.block.2.layer.0.SelfAttention.v.weight', 
'decoder.block.2.layer.1.EncDecAttention.q.weight', 'decoder.block.7.layer.2.DenseReluDense.wo.weight', 'decoder.block.4.layer.1.layer_norm.weight', 'decoder.block.3.layer.1.EncDecAttention.o.weight', 'decoder.block.6.layer.1.layer_norm.weight', 'decoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight', 'decoder.block.6.layer.2.DenseReluDense.wi_0.weight', 'decoder.embed_tokens.weight', 'decoder.block.6.layer.0.SelfAttention.k.weight', 'decoder.block.5.layer.1.EncDecAttention.k.weight', 'decoder.block.7.layer.0.SelfAttention.v.weight', 'decoder.block.2.layer.1.EncDecAttention.v.weight', 'decoder.block.5.layer.0.SelfAttention.q.weight', 'decoder.block.4.layer.2.DenseReluDense.wi_1.weight', 'decoder.block.6.layer.2.layer_norm.weight', 'decoder.block.0.layer.2.layer_norm.weight', 'decoder.block.0.layer.1.EncDecAttention.o.weight', 'decoder.block.3.layer.0.SelfAttention.k.weight', 'decoder.block.4.layer.0.SelfAttention.o.weight', 'decoder.block.7.layer.1.layer_norm.weight', 'decoder.block.3.layer.1.layer_norm.weight', 'decoder.block.4.layer.2.DenseReluDense.wi_0.weight', 'decoder.block.4.layer.1.EncDecAttention.o.weight', 'decoder.block.5.layer.1.EncDecAttention.v.weight', 'decoder.block.6.layer.2.DenseReluDense.wi_1.weight', 'decoder.block.0.layer.1.EncDecAttention.q.weight', 'decoder.block.0.layer.2.DenseReluDense.wi_1.weight', 'decoder.block.2.layer.1.EncDecAttention.k.weight', 'decoder.block.1.layer.1.EncDecAttention.q.weight', 'decoder.block.3.layer.1.EncDecAttention.q.weight', 'decoder.block.7.layer.0.layer_norm.weight', 'decoder.block.5.layer.1.EncDecAttention.q.weight', 'decoder.block.7.layer.1.EncDecAttention.o.weight', 'decoder.block.6.layer.0.SelfAttention.v.weight', 'decoder.block.7.layer.2.DenseReluDense.wi_1.weight', 'decoder.block.4.layer.0.layer_norm.weight', 'decoder.block.2.layer.1.EncDecAttention.o.weight', 'decoder.block.5.layer.2.DenseReluDense.wi_0.weight', 'decoder.block.4.layer.0.SelfAttention.q.weight']\n",
"- This IS expected if you are initializing T5ForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
"- This IS NOT expected if you are initializing T5ForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
"Some weights of T5ForSequenceClassification were not initialized from the model checkpoint at cointegrated/rut5-small-normalizer and are newly initialized: ['classifier.weight', 'classifier.bias']\n",
"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
],
"name": "stderr"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "HAaUgwMFdV-Z"
},
"source": [
"trainer = Trainer(\n",
" model=clf1, \n",
" args=training_args, \n",
" train_dataset=tr_tok, \n",
" eval_dataset=val_tok,\n",
" tokenizer=tokenizer, \n",
" compute_metrics=compute_metrics,\n",
")"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 338
},
"id": "cBk4UKTDrH1e",
"outputId": "7c5b55bb-d5f1-47d2-cc2f-332cdfd976e0"
},
"source": [
"trainer.train()"
],
"execution_count": null,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/html": [
"\n",
" <div>\n",
" \n",
" <progress value='469' max='469' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
" [469/469 07:39, Epoch 1/1]\n",
" </div>\n",
" <table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: left;\">\n",
" <th>Step</th>\n",
" <th>Training Loss</th>\n",
" <th>Validation Loss</th>\n",
" <th>Accuracy</th>\n",
" <th>Roc Auc</th>\n",
" <th>F1</th>\n",
" <th>Precision</th>\n",
" <th>Recall</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <td>100</td>\n",
" <td>0.677000</td>\n",
" <td>0.635048</td>\n",
" <td>0.653333</td>\n",
" <td>0.603660</td>\n",
" <td>0.000000</td>\n",
" <td>0.000000</td>\n",
" <td>0.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <td>200</td>\n",
" <td>0.603700</td>\n",
" <td>0.638639</td>\n",
" <td>0.653333</td>\n",
" <td>0.606358</td>\n",
" <td>0.000000</td>\n",
" <td>0.000000</td>\n",
" <td>0.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <td>300</td>\n",
" <td>0.611200</td>\n",
" <td>0.635728</td>\n",
" <td>0.656667</td>\n",
" <td>0.590757</td>\n",
" <td>0.019048</td>\n",
" <td>1.000000</td>\n",
" <td>0.009615</td>\n",
" </tr>\n",
" <tr>\n",
" <td>400</td>\n",
" <td>0.600500</td>\n",
" <td>0.655739</td>\n",
" <td>0.653333</td>\n",
" <td>0.601109</td>\n",
" <td>0.000000</td>\n",
" <td>0.000000</td>\n",
" <td>0.000000</td>\n",
" </tr>\n",
" </tbody>\n",
"</table><p>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {
"tags": []
}
},
{
"output_type": "stream",
"text": [
"/usr/local/lib/python3.7/dist-packages/sklearn/metrics/_classification.py:1272: UndefinedMetricWarning: Precision and F-score are ill-defined and being set to 0.0 due to no predicted samples. Use `zero_division` parameter to control this behavior.\n",
" _warn_prf(average, modifier, msg_start, len(result))\n",
"/usr/local/lib/python3.7/dist-packages/sklearn/metrics/_classification.py:1272: UndefinedMetricWarning: Precision and F-score are ill-defined and being set to 0.0 due to no predicted samples. Use `zero_division` parameter to control this behavior.\n",
" _warn_prf(average, modifier, msg_start, len(result))\n",
"/usr/local/lib/python3.7/dist-packages/sklearn/metrics/_classification.py:1272: UndefinedMetricWarning: Precision and F-score are ill-defined and being set to 0.0 due to no predicted samples. Use `zero_division` parameter to control this behavior.\n",
" _warn_prf(average, modifier, msg_start, len(result))\n"
],
"name": "stderr"
},
{
"output_type": "execute_result",
"data": {
"text/plain": [
"TrainOutput(global_step=469, training_loss=0.6192998093074319, metrics={'train_runtime': 460.423, 'train_samples_per_second': 1.019, 'total_flos': 672203289600000.0, 'epoch': 1.0, 'init_mem_cpu_alloc_delta': 0, 'init_mem_gpu_alloc_delta': 117480960, 'init_mem_cpu_peaked_delta': 0, 'init_mem_gpu_peaked_delta': 0, 'train_mem_cpu_alloc_delta': 12288, 'train_mem_gpu_alloc_delta': 353066496, 'train_mem_cpu_peaked_delta': 0, 'train_mem_gpu_peaked_delta': 4094756352})"
]
},
"metadata": {
"tags": []
},
"execution_count": 205
}
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 312
},
"id": "QhAiW9qaMNNY",
"outputId": "59c148c2-b147-4a64-850b-47a4a99ee30d"
},
"source": [
"trainer.evaluate(test_tok)"
],
"execution_count": null,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/html": [
"\n",
" <div>\n",
" \n",
" <progress value='256' max='256' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
" [256/256 01:10]\n",
" </div>\n",
" "
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {
"tags": []
}
},
{
"output_type": "stream",
"text": [
"/usr/local/lib/python3.7/dist-packages/sklearn/metrics/_classification.py:1272: UndefinedMetricWarning: Precision and F-score are ill-defined and being set to 0.0 due to no predicted samples. Use `zero_division` parameter to control this behavior.\n",
" _warn_prf(average, modifier, msg_start, len(result))\n"
],
"name": "stderr"
},
{
"output_type": "execute_result",
"data": {
"text/plain": [
"{'epoch': 1.0,\n",
" 'eval_accuracy': 0.700502020325701,\n",
" 'eval_f1': 0.0,\n",
" 'eval_loss': 0.6001936197280884,\n",
" 'eval_mem_cpu_alloc_delta': 0,\n",
" 'eval_mem_cpu_peaked_delta': 0,\n",
" 'eval_mem_gpu_alloc_delta': 0,\n",
" 'eval_mem_gpu_peaked_delta': 361159680,\n",
" 'eval_precision': 0.0,\n",
" 'eval_recall': 0.0,\n",
" 'eval_roc_auc': 0.598472281832951,\n",
" 'eval_runtime': 70.9211,\n",
" 'eval_samples_per_second': 230.312}"
]
},
"metadata": {
"tags": []
},
"execution_count": 206
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "BoaSJnV9UCEJ"
},
"source": [
"clf.to(torch.device('cpu'));"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "dU_ayAqQIc2O"
},
"source": [
"# Сравнение с BERT "
]
},
{
"cell_type": "code",
"metadata": {
"id": "j1Jjs-EIKHLx"
},
"source": [
"from transformers import BertForSequenceClassification, BertTokenizer"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "khHhmyeIKsRQ",
"outputId": "f1d118bc-322e-44c4-e7f4-f7742ede7df5"
},
"source": [
"bert_clf = BertForSequenceClassification.from_pretrained('DeepPavlov/rubert-base-cased-sentence')"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"text": [
"Some weights of BertForSequenceClassification were not initialized from the model checkpoint at DeepPavlov/rubert-base-cased-sentence and are newly initialized: ['classifier.weight', 'classifier.bias']\n",
"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
],
"name": "stderr"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "1u5i11ldK6ak"
},
"source": [
"bert_tokenizer = BertTokenizer.from_pretrained('DeepPavlov/rubert-base-cased-sentence')"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "Zxy5MW5wVweA"
},
"source": [
"BERT, судя по всему, разбивает текст на более длинные токены, чем маленький T5. "
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "0VZTwcvrVnmg",
"outputId": "690bb4c6-0061-4697-c817-3b7888c196a3"
},
"source": [
"import pandas as pd\n",
"pd.Series([len(bert_tokenizer.tokenize(t)) for t in dev_set['train'].shuffle(seed=42).select(range(10000))['text']]).quantile([0.5, 0.75, 0.9, 0.95, 0.99, 1])"
],
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"0.50 20.00\n",
"0.75 31.00\n",
"0.90 45.00\n",
"0.95 52.00\n",
"0.99 70.01\n",
"1.00 538.00\n",
"dtype: float64"
]
},
"metadata": {
"tags": []
},
"execution_count": 126
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "alcIpNNwLH7Q"
},
"source": [
"def preprocess_bert(x):\n",
" return dict(\n",
" labels=[int(v>0.5) for v in x[\"inappropriate\"]],\n",
" **bert_tokenizer(x[\"text\"], padding=\"max_length\", truncation=True, max_length=128)\n",
" )"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 164,
"referenced_widgets": [
"c372b84e3dbf4427929e3a835deb2088",
"758c714071c146ff9523ae43759dc63d",
"8614f0fada8b404ab4c5825ebe40d1fa",
"ea9fc2b0eab043b29ebd91787d0b7a3f",
"1e4253a16b504637ba899766360d1edf",
"e8ac8cd005ae4a50b06d74277e874f85",
"8041caf1f1f84c2db2b80b28d4ed1890",
"0df29fe57e8b47efabd964c6c46309ea",
"151cb5b916974e68a8a75e65def04e78",
"806d0f299cbf4786adb6b94311af4351",
"4de17bb1d8ef4fe89f488a4561fa3bf3",
"48a1994bbb1b4cdf84f0d480d9e7e8f4",
"c8118ab02dfa4d049f21c3af8d41d818",
"f4618b8d7329490bbee49456d908808e",
"18ff39f1c9cd4092ab4599a3575e5530",
"e9b471e8db414e79a093cfeac2d97695",
"dccb0b60b13948fcadc6363f58ae0869",
"8c4af8e3832740299c2b3cc30df5da2f",
"90a5ccdf86a34a68917659348c9a585b",
"667bb2c7492349aa972d865ef0aae768",
"e878a683929a4e3684d5c77cbdf6aaa1",
"0555e43593724bb2a4630bb0d52db24b",
"181b87b0de0f4b669bad62b8757ce830",
"dc586ad12a704d918160381c15bd0b19"
]
},
"id": "97RcO7dVLFXR",
"outputId": "4a3840da-dd9a-4021-a951-b48128e4b374"
},
"source": [
"tr_tok_bert = dev_set['train'].shuffle(seed=42).select(range(30_000)).map(preprocess_bert, batched=True)\n",
"val_tok_bert = dev_set['val'].shuffle(seed=42).select(range(300)).map(preprocess_bert, batched=True)\n",
"test_tok_bert = dev_set['test'].map(preprocess_bert, batched=True)"
],
"execution_count": null,
"outputs": [
{
"output_type": "display_data",
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "c372b84e3dbf4427929e3a835deb2088",
"version_minor": 0,
"version_major": 2
},
"text/plain": [
"HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))"
]
},
"metadata": {
"tags": []
}
},
{
"output_type": "stream",
"text": [
"\n"
],
"name": "stdout"
},
{
"output_type": "display_data",
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "151cb5b916974e68a8a75e65def04e78",
"version_minor": 0,
"version_major": 2
},
"text/plain": [
"HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))"
]
},
"metadata": {
"tags": []
}
},
{
"output_type": "stream",
"text": [
"\n"
],
"name": "stdout"
},
{
"output_type": "display_data",
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "dccb0b60b13948fcadc6363f58ae0869",
"version_minor": 0,
"version_major": 2
},
"text/plain": [
"HBox(children=(FloatProgress(value=0.0, max=17.0), HTML(value='')))"
]
},
"metadata": {
"tags": []
}
},
{
"output_type": "stream",
"text": [
"\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "tYDB1DXyKqF2"
},
"source": [
"training_args = TrainingArguments(\n",
" output_dir='bert_clf', \n",
" overwrite_output_dir=True,\n",
" evaluation_strategy=\"steps\",\n",
" gradient_accumulation_steps=1,\n",
" per_device_train_batch_size=64,\n",
" per_device_eval_batch_size=64,\n",
" # 1e-3 is far too high for fine-tuning a pretrained BERT and makes it collapse to\n",
" # predicting a single class (F1/recall = 0 in the runs below); the standard\n",
" # fine-tuning range is 1e-5..5e-5 (Devlin et al., 2019).\n",
" learning_rate=2e-5,\n",
" num_train_epochs=1,\n",
" save_total_limit=3,\n",
" logging_steps=100,\n",
")"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "vSMV8_qiLBJU"
},
"source": [
"trainer = Trainer(\n",
" model=bert_clf, \n",
" args=training_args, \n",
" train_dataset=tr_tok_bert, \n",
" eval_dataset=val_tok_bert,\n",
" tokenizer=bert_tokenizer, \n",
" compute_metrics=compute_metrics,\n",
")"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "QB_jlsY7UI1L"
},
"source": [
"import gc\n",
"gc.collect()\n",
"torch.cuda.empty_cache()"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "UK9Rd379L5Pk"
},
"source": [
"BERT на этом же количестве примеров учится 22 минуты - втрое дольше, чем маленький T5. И памяти GPU он требует существенно больше, под 10ГБ. "
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 372
},
"id": "KtyXSnHJLwQl",
"outputId": "ae6feec4-8878-420b-b22e-2bb2ad3d8c38"
},
"source": [
"trainer.train()"
],
"execution_count": null,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/html": [
"\n",
" <div>\n",
" \n",
" <progress value='469' max='469' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
" [469/469 22:16, Epoch 1/1]\n",
" </div>\n",
" <table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: left;\">\n",
" <th>Step</th>\n",
" <th>Training Loss</th>\n",
" <th>Validation Loss</th>\n",
" <th>Accuracy</th>\n",
" <th>Roc Auc</th>\n",
" <th>F1</th>\n",
" <th>Precision</th>\n",
" <th>Recall</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <td>100</td>\n",
" <td>0.687300</td>\n",
" <td>0.646576</td>\n",
" <td>0.653333</td>\n",
" <td>0.493034</td>\n",
" <td>0.000000</td>\n",
" <td>0.000000</td>\n",
" <td>0.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <td>200</td>\n",
" <td>0.617700</td>\n",
" <td>0.650726</td>\n",
" <td>0.653333</td>\n",
" <td>0.537456</td>\n",
" <td>0.000000</td>\n",
" <td>0.000000</td>\n",
" <td>0.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <td>300</td>\n",
" <td>0.624600</td>\n",
" <td>0.655771</td>\n",
" <td>0.653333</td>\n",
" <td>0.527374</td>\n",
" <td>0.000000</td>\n",
" <td>0.000000</td>\n",
" <td>0.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <td>400</td>\n",
" <td>0.614400</td>\n",
" <td>0.687063</td>\n",
" <td>0.653333</td>\n",
" <td>0.510793</td>\n",
" <td>0.000000</td>\n",
" <td>0.000000</td>\n",
" <td>0.000000</td>\n",
" </tr>\n",
" </tbody>\n",
"</table><p>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {
"tags": []
}
},
{
"output_type": "stream",
"text": [
"/usr/local/lib/python3.7/dist-packages/sklearn/metrics/_classification.py:1272: UndefinedMetricWarning: Precision and F-score are ill-defined and being set to 0.0 due to no predicted samples. Use `zero_division` parameter to control this behavior.\n",
" _warn_prf(average, modifier, msg_start, len(result))\n",
"/usr/local/lib/python3.7/dist-packages/sklearn/metrics/_classification.py:1272: UndefinedMetricWarning: Precision and F-score are ill-defined and being set to 0.0 due to no predicted samples. Use `zero_division` parameter to control this behavior.\n",
" _warn_prf(average, modifier, msg_start, len(result))\n",
"/usr/local/lib/python3.7/dist-packages/sklearn/metrics/_classification.py:1272: UndefinedMetricWarning: Precision and F-score are ill-defined and being set to 0.0 due to no predicted samples. Use `zero_division` parameter to control this behavior.\n",
" _warn_prf(average, modifier, msg_start, len(result))\n",
"/usr/local/lib/python3.7/dist-packages/sklearn/metrics/_classification.py:1272: UndefinedMetricWarning: Precision and F-score are ill-defined and being set to 0.0 due to no predicted samples. Use `zero_division` parameter to control this behavior.\n",
" _warn_prf(average, modifier, msg_start, len(result))\n"
],
"name": "stderr"
},
{
"output_type": "execute_result",
"data": {
"text/plain": [
"TrainOutput(global_step=469, training_loss=0.6324021638329349, metrics={'train_runtime': 1339.4361, 'train_samples_per_second': 0.35, 'total_flos': 4097778693120000.0, 'epoch': 1.0, 'init_mem_cpu_alloc_delta': 0, 'init_mem_gpu_alloc_delta': 712210944, 'init_mem_cpu_peaked_delta': 0, 'init_mem_gpu_peaked_delta': 0, 'train_mem_cpu_alloc_delta': 12288, 'train_mem_gpu_alloc_delta': 2138866176, 'train_mem_cpu_peaked_delta': 0, 'train_mem_gpu_peaked_delta': 7208912384})"
]
},
"metadata": {
"tags": []
},
"execution_count": 132
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "hS9krqhWjxpx"
},
"source": [
"Качество предсказания у BERT никакущее: F1 и recall равны нулю, то есть модель всегда предсказывает мажоритарный класс. Вероятная причина — слишком высокий learning rate (1e-3): для дообучения предобученного BERT обычно рекомендуют 1e-5–5e-5. "
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 312
},
"id": "1gRw2_VfL0MJ",
"outputId": "71f1ff1d-6fb2-4b6d-d1fd-91711689794f"
},
"source": [
"trainer.evaluate(test_tok_bert)"
],
"execution_count": null,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/html": [
"\n",
" <div>\n",
" \n",
" <progress value='256' max='256' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
" [256/256 03:48]\n",
" </div>\n",
" "
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {
"tags": []
}
},
{
"output_type": "stream",
"text": [
"/usr/local/lib/python3.7/dist-packages/sklearn/metrics/_classification.py:1272: UndefinedMetricWarning: Precision and F-score are ill-defined and being set to 0.0 due to no predicted samples. Use `zero_division` parameter to control this behavior.\n",
" _warn_prf(average, modifier, msg_start, len(result))\n"
],
"name": "stderr"
},
{
"output_type": "execute_result",
"data": {
"text/plain": [
"{'epoch': 1.0,\n",
" 'eval_accuracy': 0.700502020325701,\n",
" 'eval_f1': 0.0,\n",
" 'eval_loss': 0.6121679544448853,\n",
" 'eval_mem_cpu_alloc_delta': 385024,\n",
" 'eval_mem_cpu_peaked_delta': 0,\n",
" 'eval_mem_gpu_alloc_delta': 0,\n",
" 'eval_mem_gpu_peaked_delta': 528894976,\n",
" 'eval_precision': 0.0,\n",
" 'eval_recall': 0.0,\n",
" 'eval_roc_auc': 0.5066094392951732,\n",
" 'eval_runtime': 229.9186,\n",
" 'eval_samples_per_second': 71.043}"
]
},
"metadata": {
"tags": []
},
"execution_count": 133
}
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "4Z3JcwGiZgNw",
"outputId": "f54b48b1-3f7c-4431-d85f-da788847ab25"
},
"source": [
"!nvidia-smi"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"text": [
"Wed May 12 21:42:14 2021 \n",
"+-----------------------------------------------------------------------------+\n",
"| NVIDIA-SMI 465.19.01 Driver Version: 460.32.03 CUDA Version: 11.2 |\n",
"|-------------------------------+----------------------+----------------------+\n",
"| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |\n",
"| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |\n",
"| | | MIG M. |\n",
"|===============================+======================+======================|\n",
"| 0 Tesla K80 Off | 00000000:00:04.0 Off | 0 |\n",
"| N/A 43C P0 59W / 149W | 1178MiB / 11441MiB | 0% Default |\n",
"| | | N/A |\n",
"+-------------------------------+----------------------+----------------------+\n",
" \n",
"+-----------------------------------------------------------------------------+\n",
"| Processes: |\n",
"| GPU GI CI PID Type Process name GPU Memory |\n",
"| ID ID Usage |\n",
"|=============================================================================|\n",
"+-----------------------------------------------------------------------------+\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "8cllhbxqcSwa"
},
"source": [
"bert_clf.to(torch.device('cpu'));"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "EPICmm0vpIDZ"
},
"source": [
"А ещё модели можно сравнить по числу параметров: 29 миллионов у t5 против 178 миллионов у BERT. "
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "XurqVyOKpAxf",
"outputId": "319dd8e1-8880-49e5-bfd8-68c8eaf1e718"
},
"source": [
"def msize(m):\n",
" return sum(p.numel() for p in m.parameters())\n",
"print(msize(clf))\n",
"print(msize(bert_clf))"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"text": [
"29175490\n",
"177854978\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "L9NvV-HScXaa"
},
"source": [
"# Сравнение с простым бейзлайном"
]
},
{
"cell_type": "code",
"metadata": {
"id": "lFrUb-imh-RI"
},
"source": [
"from sklearn.linear_model import LogisticRegression\n",
"from sklearn.pipeline import make_pipeline\n",
"from sklearn.feature_extraction.text import CountVectorizer\n",
"\n",
"pipe = make_pipeline(CountVectorizer(min_df=3), LogisticRegression(max_iter=1000))"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "39w3dfmliQ2J"
},
"source": [
"def simplify(x):\n",
" return pd.DataFrame({'text': x['text'], 'label': [int(v>0.5) for v in x[\"inappropriate\"]]})"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "Tp_NvqQtiNF5"
},
"source": [
"tr_tok_sk = simplify(dev_set['train'].shuffle(seed=42).select(range(30_000)))\n",
"val_tok_sk = simplify(dev_set['val'].shuffle(seed=42).select(range(300)))\n",
"test_tok_sk = simplify(dev_set['test'])"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "fsTENRxNit0h",
"outputId": "3eed1fa4-26a5-45a4-f79f-a2d9aba02b58"
},
"source": [
"pipe.fit(tr_tok_sk.text, tr_tok_sk.label)"
],
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"Pipeline(memory=None,\n",
" steps=[('countvectorizer',\n",
" CountVectorizer(analyzer='word', binary=False,\n",
" decode_error='strict',\n",
" dtype=<class 'numpy.int64'>, encoding='utf-8',\n",
" input='content', lowercase=True, max_df=1.0,\n",
" max_features=None, min_df=3,\n",
" ngram_range=(1, 1), preprocessor=None,\n",
" stop_words=None, strip_accents=None,\n",
" token_pattern='(?u)\\\\b\\\\w\\\\w+\\\\b',\n",
" tokenizer=None, vocabulary=None)),\n",
" ('logisticregression',\n",
" LogisticRegression(C=1.0, class_weight=None, dual=False,\n",
" fit_intercept=True, intercept_scaling=1,\n",
" l1_ratio=None, max_iter=1000,\n",
" multi_class='auto', n_jobs=None,\n",
" penalty='l2', random_state=None,\n",
" solver='lbfgs', tol=0.0001, verbose=0,\n",
" warm_start=False))],\n",
" verbose=False)"
]
},
"metadata": {
"tags": []
},
"execution_count": 183
}
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "S3m91Tk4jDM6",
"outputId": "f5a29f4d-dd43-41f7-8b2b-5a05d68159f0"
},
"source": [
"labels = test_tok_sk.label\n",
"preds = pipe.predict(test_tok_sk.text)\n",
"preds_proba = pipe.predict_proba(test_tok_sk.text)[:, 1]\n",
"\n",
"print(accuracy_score(labels, preds))\n",
"precision, recall, f1, _ = precision_recall_fscore_support(labels, preds, average='binary')\n",
"print(precision, recall, f1)\n",
"print(roc_auc_score(labels, preds_proba))"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"text": [
"0.7192971715440186\n",
"0.556371648916636 0.3096892886345053 0.39789888378200916\n",
"0.680552539288413\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Gubx27tJjYDz"
},
"source": [
"Что ж, 68% AUC, которые даёт доученная логистическая регрессия, явно хуже, чем 78% от возможно недоученного T5. "
]
},
{
"cell_type": "code",
"metadata": {
"id": "uH2UmOEtjwJF"
},
"source": [
""
],
"execution_count": null,
"outputs": []
}
]
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment