Skip to content

Instantly share code, notes, and snippets.

@manisnesan
Created October 30, 2021 23:20
Show Gist options
  • Save manisnesan/404ec6f61376b5a1f3b6798816ffe311 to your computer and use it in GitHub Desktop.
Save manisnesan/404ec6f61376b5a1f3b6798816ffe311 to your computer and use it in GitHub Desktop.
04_modeling-question-answering.ipynb
Display the source blob
Display the rendered blob
Raw
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"colab": {
"name": "04_modeling-question-answering.ipynb",
"provenance": [],
"toc_visible": true,
"include_colab_link": true
},
"accelerator": "GPU",
"widgets": {
"application/vnd.jupyter.widget-state+json": {
"c08d84f274b94ec9ad9510edacfa9c4c": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HBoxModel",
"model_module_version": "1.5.0",
"state": {
"_view_name": "HBoxView",
"_dom_classes": [],
"_model_name": "HBoxModel",
"_view_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_view_count": null,
"_view_module_version": "1.5.0",
"box_style": "",
"layout": "IPY_MODEL_7869c1b391c24a5ab631bee8803b8fcd",
"_model_module": "@jupyter-widgets/controls",
"children": [
"IPY_MODEL_b01b12fa8fe44206bded46144f1ebfe9",
"IPY_MODEL_730360df01a44642ba42c420797ed4d0",
"IPY_MODEL_344d0b927c9741bbaeb284918490352f"
]
}
},
"7869c1b391c24a5ab631bee8803b8fcd": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"model_module_version": "1.2.0",
"state": {
"_view_name": "LayoutView",
"grid_template_rows": null,
"right": null,
"justify_content": null,
"_view_module": "@jupyter-widgets/base",
"overflow": null,
"_model_module_version": "1.2.0",
"_view_count": null,
"flex_flow": null,
"width": null,
"min_width": null,
"border": null,
"align_items": null,
"bottom": null,
"_model_module": "@jupyter-widgets/base",
"top": null,
"grid_column": null,
"overflow_y": null,
"overflow_x": null,
"grid_auto_flow": null,
"grid_area": null,
"grid_template_columns": null,
"flex": null,
"_model_name": "LayoutModel",
"justify_items": null,
"grid_row": null,
"max_height": null,
"align_content": null,
"visibility": null,
"align_self": null,
"height": null,
"min_height": null,
"padding": null,
"grid_auto_rows": null,
"grid_gap": null,
"max_width": null,
"order": null,
"_view_module_version": "1.2.0",
"grid_template_areas": null,
"object_position": null,
"object_fit": null,
"grid_auto_columns": null,
"margin": null,
"display": null,
"left": null
}
},
"b01b12fa8fe44206bded46144f1ebfe9": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HTMLModel",
"model_module_version": "1.5.0",
"state": {
"_view_name": "HTMLView",
"style": "IPY_MODEL_dfc4d83151e34222b6cfd99cd18f6631",
"_dom_classes": [],
"description": "",
"_model_name": "HTMLModel",
"placeholder": "​",
"_view_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"value": "Downloading: 100%",
"_view_count": null,
"_view_module_version": "1.5.0",
"description_tooltip": null,
"_model_module": "@jupyter-widgets/controls",
"layout": "IPY_MODEL_bae40a51df184386ab0992849a1c4d2a"
}
},
"730360df01a44642ba42c420797ed4d0": {
"model_module": "@jupyter-widgets/controls",
"model_name": "FloatProgressModel",
"model_module_version": "1.5.0",
"state": {
"_view_name": "ProgressView",
"style": "IPY_MODEL_cf902c61da9a4ae1abc81397862cda3d",
"_dom_classes": [],
"description": "",
"_model_name": "FloatProgressModel",
"bar_style": "success",
"max": 443,
"_view_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"value": 443,
"_view_count": null,
"_view_module_version": "1.5.0",
"orientation": "horizontal",
"min": 0,
"description_tooltip": null,
"_model_module": "@jupyter-widgets/controls",
"layout": "IPY_MODEL_de76897c873247ca91e5c76d0fa1c99c"
}
},
"344d0b927c9741bbaeb284918490352f": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HTMLModel",
"model_module_version": "1.5.0",
"state": {
"_view_name": "HTMLView",
"style": "IPY_MODEL_f29ed6fc01f34892a23478c04694291c",
"_dom_classes": [],
"description": "",
"_model_name": "HTMLModel",
"placeholder": "​",
"_view_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"value": " 443/443 [00:00<00:00, 6.75kB/s]",
"_view_count": null,
"_view_module_version": "1.5.0",
"description_tooltip": null,
"_model_module": "@jupyter-widgets/controls",
"layout": "IPY_MODEL_50edce2966df4a83a508937f86c69a9f"
}
},
"dfc4d83151e34222b6cfd99cd18f6631": {
"model_module": "@jupyter-widgets/controls",
"model_name": "DescriptionStyleModel",
"model_module_version": "1.5.0",
"state": {
"_view_name": "StyleView",
"_model_name": "DescriptionStyleModel",
"description_width": "",
"_view_module": "@jupyter-widgets/base",
"_model_module_version": "1.5.0",
"_view_count": null,
"_view_module_version": "1.2.0",
"_model_module": "@jupyter-widgets/controls"
}
},
"bae40a51df184386ab0992849a1c4d2a": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"model_module_version": "1.2.0",
"state": {
"_view_name": "LayoutView",
"grid_template_rows": null,
"right": null,
"justify_content": null,
"_view_module": "@jupyter-widgets/base",
"overflow": null,
"_model_module_version": "1.2.0",
"_view_count": null,
"flex_flow": null,
"width": null,
"min_width": null,
"border": null,
"align_items": null,
"bottom": null,
"_model_module": "@jupyter-widgets/base",
"top": null,
"grid_column": null,
"overflow_y": null,
"overflow_x": null,
"grid_auto_flow": null,
"grid_area": null,
"grid_template_columns": null,
"flex": null,
"_model_name": "LayoutModel",
"justify_items": null,
"grid_row": null,
"max_height": null,
"align_content": null,
"visibility": null,
"align_self": null,
"height": null,
"min_height": null,
"padding": null,
"grid_auto_rows": null,
"grid_gap": null,
"max_width": null,
"order": null,
"_view_module_version": "1.2.0",
"grid_template_areas": null,
"object_position": null,
"object_fit": null,
"grid_auto_columns": null,
"margin": null,
"display": null,
"left": null
}
},
"cf902c61da9a4ae1abc81397862cda3d": {
"model_module": "@jupyter-widgets/controls",
"model_name": "ProgressStyleModel",
"model_module_version": "1.5.0",
"state": {
"_view_name": "StyleView",
"_model_name": "ProgressStyleModel",
"description_width": "",
"_view_module": "@jupyter-widgets/base",
"_model_module_version": "1.5.0",
"_view_count": null,
"_view_module_version": "1.2.0",
"bar_color": null,
"_model_module": "@jupyter-widgets/controls"
}
},
"de76897c873247ca91e5c76d0fa1c99c": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"model_module_version": "1.2.0",
"state": {
"_view_name": "LayoutView",
"grid_template_rows": null,
"right": null,
"justify_content": null,
"_view_module": "@jupyter-widgets/base",
"overflow": null,
"_model_module_version": "1.2.0",
"_view_count": null,
"flex_flow": null,
"width": null,
"min_width": null,
"border": null,
"align_items": null,
"bottom": null,
"_model_module": "@jupyter-widgets/base",
"top": null,
"grid_column": null,
"overflow_y": null,
"overflow_x": null,
"grid_auto_flow": null,
"grid_area": null,
"grid_template_columns": null,
"flex": null,
"_model_name": "LayoutModel",
"justify_items": null,
"grid_row": null,
"max_height": null,
"align_content": null,
"visibility": null,
"align_self": null,
"height": null,
"min_height": null,
"padding": null,
"grid_auto_rows": null,
"grid_gap": null,
"max_width": null,
"order": null,
"_view_module_version": "1.2.0",
"grid_template_areas": null,
"object_position": null,
"object_fit": null,
"grid_auto_columns": null,
"margin": null,
"display": null,
"left": null
}
},
"f29ed6fc01f34892a23478c04694291c": {
"model_module": "@jupyter-widgets/controls",
"model_name": "DescriptionStyleModel",
"model_module_version": "1.5.0",
"state": {
"_view_name": "StyleView",
"_model_name": "DescriptionStyleModel",
"description_width": "",
"_view_module": "@jupyter-widgets/base",
"_model_module_version": "1.5.0",
"_view_count": null,
"_view_module_version": "1.2.0",
"_model_module": "@jupyter-widgets/controls"
}
},
"50edce2966df4a83a508937f86c69a9f": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"model_module_version": "1.2.0",
"state": {
"_view_name": "LayoutView",
"grid_template_rows": null,
"right": null,
"justify_content": null,
"_view_module": "@jupyter-widgets/base",
"overflow": null,
"_model_module_version": "1.2.0",
"_view_count": null,
"flex_flow": null,
"width": null,
"min_width": null,
"border": null,
"align_items": null,
"bottom": null,
"_model_module": "@jupyter-widgets/base",
"top": null,
"grid_column": null,
"overflow_y": null,
"overflow_x": null,
"grid_auto_flow": null,
"grid_area": null,
"grid_template_columns": null,
"flex": null,
"_model_name": "LayoutModel",
"justify_items": null,
"grid_row": null,
"max_height": null,
"align_content": null,
"visibility": null,
"align_self": null,
"height": null,
"min_height": null,
"padding": null,
"grid_auto_rows": null,
"grid_gap": null,
"max_width": null,
"order": null,
"_view_module_version": "1.2.0",
"grid_template_areas": null,
"object_position": null,
"object_fit": null,
"grid_auto_columns": null,
"margin": null,
"display": null,
"left": null
}
},
"0885d1d23c8d4c5f983f1e28ce0ea06d": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HBoxModel",
"model_module_version": "1.5.0",
"state": {
"_view_name": "HBoxView",
"_dom_classes": [],
"_model_name": "HBoxModel",
"_view_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_view_count": null,
"_view_module_version": "1.5.0",
"box_style": "",
"layout": "IPY_MODEL_469d0e9deb994bc394bb258b2e35e90c",
"_model_module": "@jupyter-widgets/controls",
"children": [
"IPY_MODEL_e41aac826c2547589577f1e1b041780a",
"IPY_MODEL_ea1a71d26db64bc6ab8818d20f9d9194",
"IPY_MODEL_b9cdadac1d964031b1c00247f724adf6"
]
}
},
"469d0e9deb994bc394bb258b2e35e90c": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"model_module_version": "1.2.0",
"state": {
"_view_name": "LayoutView",
"grid_template_rows": null,
"right": null,
"justify_content": null,
"_view_module": "@jupyter-widgets/base",
"overflow": null,
"_model_module_version": "1.2.0",
"_view_count": null,
"flex_flow": null,
"width": null,
"min_width": null,
"border": null,
"align_items": null,
"bottom": null,
"_model_module": "@jupyter-widgets/base",
"top": null,
"grid_column": null,
"overflow_y": null,
"overflow_x": null,
"grid_auto_flow": null,
"grid_area": null,
"grid_template_columns": null,
"flex": null,
"_model_name": "LayoutModel",
"justify_items": null,
"grid_row": null,
"max_height": null,
"align_content": null,
"visibility": null,
"align_self": null,
"height": null,
"min_height": null,
"padding": null,
"grid_auto_rows": null,
"grid_gap": null,
"max_width": null,
"order": null,
"_view_module_version": "1.2.0",
"grid_template_areas": null,
"object_position": null,
"object_fit": null,
"grid_auto_columns": null,
"margin": null,
"display": null,
"left": null
}
},
"e41aac826c2547589577f1e1b041780a": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HTMLModel",
"model_module_version": "1.5.0",
"state": {
"_view_name": "HTMLView",
"style": "IPY_MODEL_b8bd5a08eff14bfbb7b845553615676b",
"_dom_classes": [],
"description": "",
"_model_name": "HTMLModel",
"placeholder": "​",
"_view_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"value": "Downloading: 100%",
"_view_count": null,
"_view_module_version": "1.5.0",
"description_tooltip": null,
"_model_module": "@jupyter-widgets/controls",
"layout": "IPY_MODEL_a3304a4fb65b47e69e5c67b131fe522e"
}
},
"ea1a71d26db64bc6ab8818d20f9d9194": {
"model_module": "@jupyter-widgets/controls",
"model_name": "FloatProgressModel",
"model_module_version": "1.5.0",
"state": {
"_view_name": "ProgressView",
"style": "IPY_MODEL_a8053bcb7f8f4a7b8ac580922c297bb7",
"_dom_classes": [],
"description": "",
"_model_name": "FloatProgressModel",
"bar_style": "success",
"max": 28,
"_view_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"value": 28,
"_view_count": null,
"_view_module_version": "1.5.0",
"orientation": "horizontal",
"min": 0,
"description_tooltip": null,
"_model_module": "@jupyter-widgets/controls",
"layout": "IPY_MODEL_a1e6401ee0164915a1973772e914fe75"
}
},
"b9cdadac1d964031b1c00247f724adf6": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HTMLModel",
"model_module_version": "1.5.0",
"state": {
"_view_name": "HTMLView",
"style": "IPY_MODEL_c515f9bb4d2b4e21be13cc533bb66ff8",
"_dom_classes": [],
"description": "",
"_model_name": "HTMLModel",
"placeholder": "​",
"_view_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"value": " 28.0/28.0 [00:00<00:00, 532B/s]",
"_view_count": null,
"_view_module_version": "1.5.0",
"description_tooltip": null,
"_model_module": "@jupyter-widgets/controls",
"layout": "IPY_MODEL_c740233e00fc4e51b87f463408ca3b26"
}
},
"b8bd5a08eff14bfbb7b845553615676b": {
"model_module": "@jupyter-widgets/controls",
"model_name": "DescriptionStyleModel",
"model_module_version": "1.5.0",
"state": {
"_view_name": "StyleView",
"_model_name": "DescriptionStyleModel",
"description_width": "",
"_view_module": "@jupyter-widgets/base",
"_model_module_version": "1.5.0",
"_view_count": null,
"_view_module_version": "1.2.0",
"_model_module": "@jupyter-widgets/controls"
}
},
"a3304a4fb65b47e69e5c67b131fe522e": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"model_module_version": "1.2.0",
"state": {
"_view_name": "LayoutView",
"grid_template_rows": null,
"right": null,
"justify_content": null,
"_view_module": "@jupyter-widgets/base",
"overflow": null,
"_model_module_version": "1.2.0",
"_view_count": null,
"flex_flow": null,
"width": null,
"min_width": null,
"border": null,
"align_items": null,
"bottom": null,
"_model_module": "@jupyter-widgets/base",
"top": null,
"grid_column": null,
"overflow_y": null,
"overflow_x": null,
"grid_auto_flow": null,
"grid_area": null,
"grid_template_columns": null,
"flex": null,
"_model_name": "LayoutModel",
"justify_items": null,
"grid_row": null,
"max_height": null,
"align_content": null,
"visibility": null,
"align_self": null,
"height": null,
"min_height": null,
"padding": null,
"grid_auto_rows": null,
"grid_gap": null,
"max_width": null,
"order": null,
"_view_module_version": "1.2.0",
"grid_template_areas": null,
"object_position": null,
"object_fit": null,
"grid_auto_columns": null,
"margin": null,
"display": null,
"left": null
}
},
"a8053bcb7f8f4a7b8ac580922c297bb7": {
"model_module": "@jupyter-widgets/controls",
"model_name": "ProgressStyleModel",
"model_module_version": "1.5.0",
"state": {
"_view_name": "StyleView",
"_model_name": "ProgressStyleModel",
"description_width": "",
"_view_module": "@jupyter-widgets/base",
"_model_module_version": "1.5.0",
"_view_count": null,
"_view_module_version": "1.2.0",
"bar_color": null,
"_model_module": "@jupyter-widgets/controls"
}
},
"a1e6401ee0164915a1973772e914fe75": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"model_module_version": "1.2.0",
"state": {
"_view_name": "LayoutView",
"grid_template_rows": null,
"right": null,
"justify_content": null,
"_view_module": "@jupyter-widgets/base",
"overflow": null,
"_model_module_version": "1.2.0",
"_view_count": null,
"flex_flow": null,
"width": null,
"min_width": null,
"border": null,
"align_items": null,
"bottom": null,
"_model_module": "@jupyter-widgets/base",
"top": null,
"grid_column": null,
"overflow_y": null,
"overflow_x": null,
"grid_auto_flow": null,
"grid_area": null,
"grid_template_columns": null,
"flex": null,
"_model_name": "LayoutModel",
"justify_items": null,
"grid_row": null,
"max_height": null,
"align_content": null,
"visibility": null,
"align_self": null,
"height": null,
"min_height": null,
"padding": null,
"grid_auto_rows": null,
"grid_gap": null,
"max_width": null,
"order": null,
"_view_module_version": "1.2.0",
"grid_template_areas": null,
"object_position": null,
"object_fit": null,
"grid_auto_columns": null,
"margin": null,
"display": null,
"left": null
}
},
"c515f9bb4d2b4e21be13cc533bb66ff8": {
"model_module": "@jupyter-widgets/controls",
"model_name": "DescriptionStyleModel",
"model_module_version": "1.5.0",
"state": {
"_view_name": "StyleView",
"_model_name": "DescriptionStyleModel",
"description_width": "",
"_view_module": "@jupyter-widgets/base",
"_model_module_version": "1.5.0",
"_view_count": null,
"_view_module_version": "1.2.0",
"_model_module": "@jupyter-widgets/controls"
}
},
"c740233e00fc4e51b87f463408ca3b26": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"model_module_version": "1.2.0",
"state": {
"_view_name": "LayoutView",
"grid_template_rows": null,
"right": null,
"justify_content": null,
"_view_module": "@jupyter-widgets/base",
"overflow": null,
"_model_module_version": "1.2.0",
"_view_count": null,
"flex_flow": null,
"width": null,
"min_width": null,
"border": null,
"align_items": null,
"bottom": null,
"_model_module": "@jupyter-widgets/base",
"top": null,
"grid_column": null,
"overflow_y": null,
"overflow_x": null,
"grid_auto_flow": null,
"grid_area": null,
"grid_template_columns": null,
"flex": null,
"_model_name": "LayoutModel",
"justify_items": null,
"grid_row": null,
"max_height": null,
"align_content": null,
"visibility": null,
"align_self": null,
"height": null,
"min_height": null,
"padding": null,
"grid_auto_rows": null,
"grid_gap": null,
"max_width": null,
"order": null,
"_view_module_version": "1.2.0",
"grid_template_areas": null,
"object_position": null,
"object_fit": null,
"grid_auto_columns": null,
"margin": null,
"display": null,
"left": null
}
},
"019bafae01244276a930515267e191ea": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HBoxModel",
"model_module_version": "1.5.0",
"state": {
"_view_name": "HBoxView",
"_dom_classes": [],
"_model_name": "HBoxModel",
"_view_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_view_count": null,
"_view_module_version": "1.5.0",
"box_style": "",
"layout": "IPY_MODEL_63765de959c942849a43e1a7f1fa9aea",
"_model_module": "@jupyter-widgets/controls",
"children": [
"IPY_MODEL_74f5f6a4baeb4b2c96c55377f62960d8",
"IPY_MODEL_020dc99a21514968b2e54652dc49b4ff",
"IPY_MODEL_09d56dc69a394dbb98b741aec5075209"
]
}
},
"63765de959c942849a43e1a7f1fa9aea": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"model_module_version": "1.2.0",
"state": {
"_view_name": "LayoutView",
"grid_template_rows": null,
"right": null,
"justify_content": null,
"_view_module": "@jupyter-widgets/base",
"overflow": null,
"_model_module_version": "1.2.0",
"_view_count": null,
"flex_flow": null,
"width": null,
"min_width": null,
"border": null,
"align_items": null,
"bottom": null,
"_model_module": "@jupyter-widgets/base",
"top": null,
"grid_column": null,
"overflow_y": null,
"overflow_x": null,
"grid_auto_flow": null,
"grid_area": null,
"grid_template_columns": null,
"flex": null,
"_model_name": "LayoutModel",
"justify_items": null,
"grid_row": null,
"max_height": null,
"align_content": null,
"visibility": null,
"align_self": null,
"height": null,
"min_height": null,
"padding": null,
"grid_auto_rows": null,
"grid_gap": null,
"max_width": null,
"order": null,
"_view_module_version": "1.2.0",
"grid_template_areas": null,
"object_position": null,
"object_fit": null,
"grid_auto_columns": null,
"margin": null,
"display": null,
"left": null
}
},
"74f5f6a4baeb4b2c96c55377f62960d8": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HTMLModel",
"model_module_version": "1.5.0",
"state": {
"_view_name": "HTMLView",
"style": "IPY_MODEL_2efd453ded5b481ea12833bc41f06a39",
"_dom_classes": [],
"description": "",
"_model_name": "HTMLModel",
"placeholder": "​",
"_view_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"value": "Downloading: 100%",
"_view_count": null,
"_view_module_version": "1.5.0",
"description_tooltip": null,
"_model_module": "@jupyter-widgets/controls",
"layout": "IPY_MODEL_aa585eb2fc184f5e972829a852b3c0a0"
}
},
"020dc99a21514968b2e54652dc49b4ff": {
"model_module": "@jupyter-widgets/controls",
"model_name": "FloatProgressModel",
"model_module_version": "1.5.0",
"state": {
"_view_name": "ProgressView",
"style": "IPY_MODEL_101fde751876428a901b70c6ef3a5e91",
"_dom_classes": [],
"description": "",
"_model_name": "FloatProgressModel",
"bar_style": "success",
"max": 231508,
"_view_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"value": 231508,
"_view_count": null,
"_view_module_version": "1.5.0",
"orientation": "horizontal",
"min": 0,
"description_tooltip": null,
"_model_module": "@jupyter-widgets/controls",
"layout": "IPY_MODEL_078dcffa3fea40beb3d38f6b54db630f"
}
},
"09d56dc69a394dbb98b741aec5075209": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HTMLModel",
"model_module_version": "1.5.0",
"state": {
"_view_name": "HTMLView",
"style": "IPY_MODEL_64af53679ce749ec85c7ac0dc852e561",
"_dom_classes": [],
"description": "",
"_model_name": "HTMLModel",
"placeholder": "​",
"_view_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"value": " 226k/226k [00:00<00:00, 619kB/s]",
"_view_count": null,
"_view_module_version": "1.5.0",
"description_tooltip": null,
"_model_module": "@jupyter-widgets/controls",
"layout": "IPY_MODEL_2a3bfe5c7b3e4d3290be14142e7f4c69"
}
},
"2efd453ded5b481ea12833bc41f06a39": {
"model_module": "@jupyter-widgets/controls",
"model_name": "DescriptionStyleModel",
"model_module_version": "1.5.0",
"state": {
"_view_name": "StyleView",
"_model_name": "DescriptionStyleModel",
"description_width": "",
"_view_module": "@jupyter-widgets/base",
"_model_module_version": "1.5.0",
"_view_count": null,
"_view_module_version": "1.2.0",
"_model_module": "@jupyter-widgets/controls"
}
},
"aa585eb2fc184f5e972829a852b3c0a0": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"model_module_version": "1.2.0",
"state": {
"_view_name": "LayoutView",
"grid_template_rows": null,
"right": null,
"justify_content": null,
"_view_module": "@jupyter-widgets/base",
"overflow": null,
"_model_module_version": "1.2.0",
"_view_count": null,
"flex_flow": null,
"width": null,
"min_width": null,
"border": null,
"align_items": null,
"bottom": null,
"_model_module": "@jupyter-widgets/base",
"top": null,
"grid_column": null,
"overflow_y": null,
"overflow_x": null,
"grid_auto_flow": null,
"grid_area": null,
"grid_template_columns": null,
"flex": null,
"_model_name": "LayoutModel",
"justify_items": null,
"grid_row": null,
"max_height": null,
"align_content": null,
"visibility": null,
"align_self": null,
"height": null,
"min_height": null,
"padding": null,
"grid_auto_rows": null,
"grid_gap": null,
"max_width": null,
"order": null,
"_view_module_version": "1.2.0",
"grid_template_areas": null,
"object_position": null,
"object_fit": null,
"grid_auto_columns": null,
"margin": null,
"display": null,
"left": null
}
},
"101fde751876428a901b70c6ef3a5e91": {
"model_module": "@jupyter-widgets/controls",
"model_name": "ProgressStyleModel",
"model_module_version": "1.5.0",
"state": {
"_view_name": "StyleView",
"_model_name": "ProgressStyleModel",
"description_width": "",
"_view_module": "@jupyter-widgets/base",
"_model_module_version": "1.5.0",
"_view_count": null,
"_view_module_version": "1.2.0",
"bar_color": null,
"_model_module": "@jupyter-widgets/controls"
}
},
"078dcffa3fea40beb3d38f6b54db630f": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"model_module_version": "1.2.0",
"state": {
"_view_name": "LayoutView",
"grid_template_rows": null,
"right": null,
"justify_content": null,
"_view_module": "@jupyter-widgets/base",
"overflow": null,
"_model_module_version": "1.2.0",
"_view_count": null,
"flex_flow": null,
"width": null,
"min_width": null,
"border": null,
"align_items": null,
"bottom": null,
"_model_module": "@jupyter-widgets/base",
"top": null,
"grid_column": null,
"overflow_y": null,
"overflow_x": null,
"grid_auto_flow": null,
"grid_area": null,
"grid_template_columns": null,
"flex": null,
"_model_name": "LayoutModel",
"justify_items": null,
"grid_row": null,
"max_height": null,
"align_content": null,
"visibility": null,
"align_self": null,
"height": null,
"min_height": null,
"padding": null,
"grid_auto_rows": null,
"grid_gap": null,
"max_width": null,
"order": null,
"_view_module_version": "1.2.0",
"grid_template_areas": null,
"object_position": null,
"object_fit": null,
"grid_auto_columns": null,
"margin": null,
"display": null,
"left": null
}
},
"64af53679ce749ec85c7ac0dc852e561": {
"model_module": "@jupyter-widgets/controls",
"model_name": "DescriptionStyleModel",
"model_module_version": "1.5.0",
"state": {
"_view_name": "StyleView",
"_model_name": "DescriptionStyleModel",
"description_width": "",
"_view_module": "@jupyter-widgets/base",
"_model_module_version": "1.5.0",
"_view_count": null,
"_view_module_version": "1.2.0",
"_model_module": "@jupyter-widgets/controls"
}
},
"2a3bfe5c7b3e4d3290be14142e7f4c69": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"model_module_version": "1.2.0",
"state": {
"_view_name": "LayoutView",
"grid_template_rows": null,
"right": null,
"justify_content": null,
"_view_module": "@jupyter-widgets/base",
"overflow": null,
"_model_module_version": "1.2.0",
"_view_count": null,
"flex_flow": null,
"width": null,
"min_width": null,
"border": null,
"align_items": null,
"bottom": null,
"_model_module": "@jupyter-widgets/base",
"top": null,
"grid_column": null,
"overflow_y": null,
"overflow_x": null,
"grid_auto_flow": null,
"grid_area": null,
"grid_template_columns": null,
"flex": null,
"_model_name": "LayoutModel",
"justify_items": null,
"grid_row": null,
"max_height": null,
"align_content": null,
"visibility": null,
"align_self": null,
"height": null,
"min_height": null,
"padding": null,
"grid_auto_rows": null,
"grid_gap": null,
"max_width": null,
"order": null,
"_view_module_version": "1.2.0",
"grid_template_areas": null,
"object_position": null,
"object_fit": null,
"grid_auto_columns": null,
"margin": null,
"display": null,
"left": null
}
},
"47c2942844e64023b3b2ed86fec4a265": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HBoxModel",
"model_module_version": "1.5.0",
"state": {
"_view_name": "HBoxView",
"_dom_classes": [],
"_model_name": "HBoxModel",
"_view_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_view_count": null,
"_view_module_version": "1.5.0",
"box_style": "",
"layout": "IPY_MODEL_e6854974367d483c95c96f94cd6718c4",
"_model_module": "@jupyter-widgets/controls",
"children": [
"IPY_MODEL_182fe1165a0c4ae0a22c8b15e9e57193",
"IPY_MODEL_cf365272b084408e81c93eb9b43fcc63",
"IPY_MODEL_80fdd131238e41c0ac8ee3eec5eb7b5a"
]
}
},
"e6854974367d483c95c96f94cd6718c4": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"model_module_version": "1.2.0",
"state": {
"_view_name": "LayoutView",
"grid_template_rows": null,
"right": null,
"justify_content": null,
"_view_module": "@jupyter-widgets/base",
"overflow": null,
"_model_module_version": "1.2.0",
"_view_count": null,
"flex_flow": null,
"width": null,
"min_width": null,
"border": null,
"align_items": null,
"bottom": null,
"_model_module": "@jupyter-widgets/base",
"top": null,
"grid_column": null,
"overflow_y": null,
"overflow_x": null,
"grid_auto_flow": null,
"grid_area": null,
"grid_template_columns": null,
"flex": null,
"_model_name": "LayoutModel",
"justify_items": null,
"grid_row": null,
"max_height": null,
"align_content": null,
"visibility": null,
"align_self": null,
"height": null,
"min_height": null,
"padding": null,
"grid_auto_rows": null,
"grid_gap": null,
"max_width": null,
"order": null,
"_view_module_version": "1.2.0",
"grid_template_areas": null,
"object_position": null,
"object_fit": null,
"grid_auto_columns": null,
"margin": null,
"display": null,
"left": null
}
},
"182fe1165a0c4ae0a22c8b15e9e57193": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HTMLModel",
"model_module_version": "1.5.0",
"state": {
"_view_name": "HTMLView",
"style": "IPY_MODEL_bd04152927aa4c369f6d47730f4f8d4a",
"_dom_classes": [],
"description": "",
"_model_name": "HTMLModel",
"placeholder": "​",
"_view_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"value": "Downloading: 100%",
"_view_count": null,
"_view_module_version": "1.5.0",
"description_tooltip": null,
"_model_module": "@jupyter-widgets/controls",
"layout": "IPY_MODEL_99f85f26fa9540f985ee270dcde0fdec"
}
},
"cf365272b084408e81c93eb9b43fcc63": {
"model_module": "@jupyter-widgets/controls",
"model_name": "FloatProgressModel",
"model_module_version": "1.5.0",
"state": {
"_view_name": "ProgressView",
"style": "IPY_MODEL_24f0532a427c4506a2fb580a8a9c8343",
"_dom_classes": [],
"description": "",
"_model_name": "FloatProgressModel",
"bar_style": "success",
"max": 466062,
"_view_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"value": 466062,
"_view_count": null,
"_view_module_version": "1.5.0",
"orientation": "horizontal",
"min": 0,
"description_tooltip": null,
"_model_module": "@jupyter-widgets/controls",
"layout": "IPY_MODEL_f85befe20f1b49cbaf859a2295e98a09"
}
},
"80fdd131238e41c0ac8ee3eec5eb7b5a": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HTMLModel",
"model_module_version": "1.5.0",
"state": {
"_view_name": "HTMLView",
"style": "IPY_MODEL_dbc5483bb5084a4a9a46f50966f2f489",
"_dom_classes": [],
"description": "",
"_model_name": "HTMLModel",
"placeholder": "​",
"_view_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"value": " 455k/455k [00:00<00:00, 654kB/s]",
"_view_count": null,
"_view_module_version": "1.5.0",
"description_tooltip": null,
"_model_module": "@jupyter-widgets/controls",
"layout": "IPY_MODEL_f851735f974d4de3bc012b9827863e10"
}
},
"bd04152927aa4c369f6d47730f4f8d4a": {
"model_module": "@jupyter-widgets/controls",
"model_name": "DescriptionStyleModel",
"model_module_version": "1.5.0",
"state": {
"_view_name": "StyleView",
"_model_name": "DescriptionStyleModel",
"description_width": "",
"_view_module": "@jupyter-widgets/base",
"_model_module_version": "1.5.0",
"_view_count": null,
"_view_module_version": "1.2.0",
"_model_module": "@jupyter-widgets/controls"
}
},
"99f85f26fa9540f985ee270dcde0fdec": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"model_module_version": "1.2.0",
"state": {
"_view_name": "LayoutView",
"grid_template_rows": null,
"right": null,
"justify_content": null,
"_view_module": "@jupyter-widgets/base",
"overflow": null,
"_model_module_version": "1.2.0",
"_view_count": null,
"flex_flow": null,
"width": null,
"min_width": null,
"border": null,
"align_items": null,
"bottom": null,
"_model_module": "@jupyter-widgets/base",
"top": null,
"grid_column": null,
"overflow_y": null,
"overflow_x": null,
"grid_auto_flow": null,
"grid_area": null,
"grid_template_columns": null,
"flex": null,
"_model_name": "LayoutModel",
"justify_items": null,
"grid_row": null,
"max_height": null,
"align_content": null,
"visibility": null,
"align_self": null,
"height": null,
"min_height": null,
"padding": null,
"grid_auto_rows": null,
"grid_gap": null,
"max_width": null,
"order": null,
"_view_module_version": "1.2.0",
"grid_template_areas": null,
"object_position": null,
"object_fit": null,
"grid_auto_columns": null,
"margin": null,
"display": null,
"left": null
}
},
"24f0532a427c4506a2fb580a8a9c8343": {
"model_module": "@jupyter-widgets/controls",
"model_name": "ProgressStyleModel",
"model_module_version": "1.5.0",
"state": {
"_view_name": "StyleView",
"_model_name": "ProgressStyleModel",
"description_width": "",
"_view_module": "@jupyter-widgets/base",
"_model_module_version": "1.5.0",
"_view_count": null,
"_view_module_version": "1.2.0",
"bar_color": null,
"_model_module": "@jupyter-widgets/controls"
}
},
"f85befe20f1b49cbaf859a2295e98a09": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"model_module_version": "1.2.0",
"state": {
"_view_name": "LayoutView",
"grid_template_rows": null,
"right": null,
"justify_content": null,
"_view_module": "@jupyter-widgets/base",
"overflow": null,
"_model_module_version": "1.2.0",
"_view_count": null,
"flex_flow": null,
"width": null,
"min_width": null,
"border": null,
"align_items": null,
"bottom": null,
"_model_module": "@jupyter-widgets/base",
"top": null,
"grid_column": null,
"overflow_y": null,
"overflow_x": null,
"grid_auto_flow": null,
"grid_area": null,
"grid_template_columns": null,
"flex": null,
"_model_name": "LayoutModel",
"justify_items": null,
"grid_row": null,
"max_height": null,
"align_content": null,
"visibility": null,
"align_self": null,
"height": null,
"min_height": null,
"padding": null,
"grid_auto_rows": null,
"grid_gap": null,
"max_width": null,
"order": null,
"_view_module_version": "1.2.0",
"grid_template_areas": null,
"object_position": null,
"object_fit": null,
"grid_auto_columns": null,
"margin": null,
"display": null,
"left": null
}
},
"dbc5483bb5084a4a9a46f50966f2f489": {
"model_module": "@jupyter-widgets/controls",
"model_name": "DescriptionStyleModel",
"model_module_version": "1.5.0",
"state": {
"_view_name": "StyleView",
"_model_name": "DescriptionStyleModel",
"description_width": "",
"_view_module": "@jupyter-widgets/base",
"_model_module_version": "1.5.0",
"_view_count": null,
"_view_module_version": "1.2.0",
"_model_module": "@jupyter-widgets/controls"
}
},
"f851735f974d4de3bc012b9827863e10": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"model_module_version": "1.2.0",
"state": {
"_view_name": "LayoutView",
"grid_template_rows": null,
"right": null,
"justify_content": null,
"_view_module": "@jupyter-widgets/base",
"overflow": null,
"_model_module_version": "1.2.0",
"_view_count": null,
"flex_flow": null,
"width": null,
"min_width": null,
"border": null,
"align_items": null,
"bottom": null,
"_model_module": "@jupyter-widgets/base",
"top": null,
"grid_column": null,
"overflow_y": null,
"overflow_x": null,
"grid_auto_flow": null,
"grid_area": null,
"grid_template_columns": null,
"flex": null,
"_model_name": "LayoutModel",
"justify_items": null,
"grid_row": null,
"max_height": null,
"align_content": null,
"visibility": null,
"align_self": null,
"height": null,
"min_height": null,
"padding": null,
"grid_auto_rows": null,
"grid_gap": null,
"max_width": null,
"order": null,
"_view_module_version": "1.2.0",
"grid_template_areas": null,
"object_position": null,
"object_fit": null,
"grid_auto_columns": null,
"margin": null,
"display": null,
"left": null
}
},
"e2d94a3e334944e0a1fb906a648f957c": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HBoxModel",
"model_module_version": "1.5.0",
"state": {
"_view_name": "HBoxView",
"_dom_classes": [],
"_model_name": "HBoxModel",
"_view_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_view_count": null,
"_view_module_version": "1.5.0",
"box_style": "",
"layout": "IPY_MODEL_068338dad5ed4c5084086137f46e33bc",
"_model_module": "@jupyter-widgets/controls",
"children": [
"IPY_MODEL_001600de55e647e9bfd0fc0086ddece3",
"IPY_MODEL_65d43f0160fe4990b29a6571ecb34a48",
"IPY_MODEL_7c82a4b80b8e45d4993d8f17bb9d42b8"
]
}
},
"068338dad5ed4c5084086137f46e33bc": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"model_module_version": "1.2.0",
"state": {
"_view_name": "LayoutView",
"grid_template_rows": null,
"right": null,
"justify_content": null,
"_view_module": "@jupyter-widgets/base",
"overflow": null,
"_model_module_version": "1.2.0",
"_view_count": null,
"flex_flow": null,
"width": null,
"min_width": null,
"border": null,
"align_items": null,
"bottom": null,
"_model_module": "@jupyter-widgets/base",
"top": null,
"grid_column": null,
"overflow_y": null,
"overflow_x": null,
"grid_auto_flow": null,
"grid_area": null,
"grid_template_columns": null,
"flex": null,
"_model_name": "LayoutModel",
"justify_items": null,
"grid_row": null,
"max_height": null,
"align_content": null,
"visibility": null,
"align_self": null,
"height": null,
"min_height": null,
"padding": null,
"grid_auto_rows": null,
"grid_gap": null,
"max_width": null,
"order": null,
"_view_module_version": "1.2.0",
"grid_template_areas": null,
"object_position": null,
"object_fit": null,
"grid_auto_columns": null,
"margin": null,
"display": null,
"left": null
}
},
"001600de55e647e9bfd0fc0086ddece3": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HTMLModel",
"model_module_version": "1.5.0",
"state": {
"_view_name": "HTMLView",
"style": "IPY_MODEL_18133ede387345ca9dd16a9f28902eaf",
"_dom_classes": [],
"description": "",
"_model_name": "HTMLModel",
"placeholder": "​",
"_view_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"value": "Downloading: 100%",
"_view_count": null,
"_view_module_version": "1.5.0",
"description_tooltip": null,
"_model_module": "@jupyter-widgets/controls",
"layout": "IPY_MODEL_21ab512903014ce4ab90c3491e4ea0f2"
}
},
"65d43f0160fe4990b29a6571ecb34a48": {
"model_module": "@jupyter-widgets/controls",
"model_name": "FloatProgressModel",
"model_module_version": "1.5.0",
"state": {
"_view_name": "ProgressView",
"style": "IPY_MODEL_45dd8d2f31814fc4957f3195ae316c54",
"_dom_classes": [],
"description": "",
"_model_name": "FloatProgressModel",
"bar_style": "success",
"max": 1340675298,
"_view_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"value": 1340675298,
"_view_count": null,
"_view_module_version": "1.5.0",
"orientation": "horizontal",
"min": 0,
"description_tooltip": null,
"_model_module": "@jupyter-widgets/controls",
"layout": "IPY_MODEL_406355b5090641ccb2c7cb6ad04a852b"
}
},
"7c82a4b80b8e45d4993d8f17bb9d42b8": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HTMLModel",
"model_module_version": "1.5.0",
"state": {
"_view_name": "HTMLView",
"style": "IPY_MODEL_2f664fe953e548a29c00962487997c8a",
"_dom_classes": [],
"description": "",
"_model_name": "HTMLModel",
"placeholder": "​",
"_view_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"value": " 1.25G/1.25G [00:41<00:00, 31.4MB/s]",
"_view_count": null,
"_view_module_version": "1.5.0",
"description_tooltip": null,
"_model_module": "@jupyter-widgets/controls",
"layout": "IPY_MODEL_ff22853468d243bebc072beb9591b707"
}
},
"18133ede387345ca9dd16a9f28902eaf": {
"model_module": "@jupyter-widgets/controls",
"model_name": "DescriptionStyleModel",
"model_module_version": "1.5.0",
"state": {
"_view_name": "StyleView",
"_model_name": "DescriptionStyleModel",
"description_width": "",
"_view_module": "@jupyter-widgets/base",
"_model_module_version": "1.5.0",
"_view_count": null,
"_view_module_version": "1.2.0",
"_model_module": "@jupyter-widgets/controls"
}
},
"21ab512903014ce4ab90c3491e4ea0f2": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"model_module_version": "1.2.0",
"state": {
"_view_name": "LayoutView",
"grid_template_rows": null,
"right": null,
"justify_content": null,
"_view_module": "@jupyter-widgets/base",
"overflow": null,
"_model_module_version": "1.2.0",
"_view_count": null,
"flex_flow": null,
"width": null,
"min_width": null,
"border": null,
"align_items": null,
"bottom": null,
"_model_module": "@jupyter-widgets/base",
"top": null,
"grid_column": null,
"overflow_y": null,
"overflow_x": null,
"grid_auto_flow": null,
"grid_area": null,
"grid_template_columns": null,
"flex": null,
"_model_name": "LayoutModel",
"justify_items": null,
"grid_row": null,
"max_height": null,
"align_content": null,
"visibility": null,
"align_self": null,
"height": null,
"min_height": null,
"padding": null,
"grid_auto_rows": null,
"grid_gap": null,
"max_width": null,
"order": null,
"_view_module_version": "1.2.0",
"grid_template_areas": null,
"object_position": null,
"object_fit": null,
"grid_auto_columns": null,
"margin": null,
"display": null,
"left": null
}
},
"45dd8d2f31814fc4957f3195ae316c54": {
"model_module": "@jupyter-widgets/controls",
"model_name": "ProgressStyleModel",
"model_module_version": "1.5.0",
"state": {
"_view_name": "StyleView",
"_model_name": "ProgressStyleModel",
"description_width": "",
"_view_module": "@jupyter-widgets/base",
"_model_module_version": "1.5.0",
"_view_count": null,
"_view_module_version": "1.2.0",
"bar_color": null,
"_model_module": "@jupyter-widgets/controls"
}
},
"406355b5090641ccb2c7cb6ad04a852b": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"model_module_version": "1.2.0",
"state": {
"_view_name": "LayoutView",
"grid_template_rows": null,
"right": null,
"justify_content": null,
"_view_module": "@jupyter-widgets/base",
"overflow": null,
"_model_module_version": "1.2.0",
"_view_count": null,
"flex_flow": null,
"width": null,
"min_width": null,
"border": null,
"align_items": null,
"bottom": null,
"_model_module": "@jupyter-widgets/base",
"top": null,
"grid_column": null,
"overflow_y": null,
"overflow_x": null,
"grid_auto_flow": null,
"grid_area": null,
"grid_template_columns": null,
"flex": null,
"_model_name": "LayoutModel",
"justify_items": null,
"grid_row": null,
"max_height": null,
"align_content": null,
"visibility": null,
"align_self": null,
"height": null,
"min_height": null,
"padding": null,
"grid_auto_rows": null,
"grid_gap": null,
"max_width": null,
"order": null,
"_view_module_version": "1.2.0",
"grid_template_areas": null,
"object_position": null,
"object_fit": null,
"grid_auto_columns": null,
"margin": null,
"display": null,
"left": null
}
},
"2f664fe953e548a29c00962487997c8a": {
"model_module": "@jupyter-widgets/controls",
"model_name": "DescriptionStyleModel",
"model_module_version": "1.5.0",
"state": {
"_view_name": "StyleView",
"_model_name": "DescriptionStyleModel",
"description_width": "",
"_view_module": "@jupyter-widgets/base",
"_model_module_version": "1.5.0",
"_view_count": null,
"_view_module_version": "1.2.0",
"_model_module": "@jupyter-widgets/controls"
}
},
"ff22853468d243bebc072beb9591b707": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"model_module_version": "1.2.0",
"state": {
"_view_name": "LayoutView",
"grid_template_rows": null,
"right": null,
"justify_content": null,
"_view_module": "@jupyter-widgets/base",
"overflow": null,
"_model_module_version": "1.2.0",
"_view_count": null,
"flex_flow": null,
"width": null,
"min_width": null,
"border": null,
"align_items": null,
"bottom": null,
"_model_module": "@jupyter-widgets/base",
"top": null,
"grid_column": null,
"overflow_y": null,
"overflow_x": null,
"grid_auto_flow": null,
"grid_area": null,
"grid_template_columns": null,
"flex": null,
"_model_name": "LayoutModel",
"justify_items": null,
"grid_row": null,
"max_height": null,
"align_content": null,
"visibility": null,
"align_self": null,
"height": null,
"min_height": null,
"padding": null,
"grid_auto_rows": null,
"grid_gap": null,
"max_width": null,
"order": null,
"_view_module_version": "1.2.0",
"grid_template_areas": null,
"object_position": null,
"object_fit": null,
"grid_auto_columns": null,
"margin": null,
"display": null,
"left": null
}
}
}
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/manisnesan/404ec6f61376b5a1f3b6798816ffe311/04_modeling-question-answering.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "code",
"metadata": {
"id": "WyV8kufvex9Y"
},
"source": [
"%%capture\n",
"%pip install ohmeow-blurr\n",
"%pip install fastcore\n",
"%pip install transformers\n",
"%pip install fastai\n",
"%pip install nbdev\n",
"%pip install wwf"
],
"execution_count": 1,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "7ZAGc9bCe7O9"
},
"source": [
"from wwf.utils import *"
],
"execution_count": 3,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 148
},
"id": "ZyByEQoMe963",
"outputId": "f468805e-9395-49d0-b260-c3cf53c453f4"
},
"source": [
"state_versions('fastai', 'transformers', 'ohmeow-blurr')"
],
"execution_count": 6,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/markdown": "\n---\nThis article is also a Jupyter Notebook available to be run from the top down. There\nwill be code snippets that you can then run in any environment.\n\nBelow are the versions of `fastai`, `transformers`, and `ohmeow-blurr` currently running at the time of writing this:\n* `fastai` : 2.5.3 \n* `transformers` : 4.12.2 \n* `ohmeow-blurr` : 0.1.0 \n---",
"text/plain": [
"<IPython.core.display.Markdown object>"
]
},
"metadata": {},
"execution_count": 6
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "sjsL_4vZeoZ6"
},
"source": [
"# default_exp modeling.question_answering"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "VM15rO2VeoZ9"
},
"source": [
"#all_slow"
],
"execution_count": 7,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "exrOPJigeoZ9"
},
"source": [
"#hide\n",
"%reload_ext autoreload\n",
"%autoreload 2\n",
"%matplotlib inline"
],
"execution_count": 8,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "peI3qAy6eoZ-"
},
"source": [
"# modeling.question_answering\n",
"\n",
"> This module contains custom models, loss functions, custom splitters, etc... for question answering tasks"
]
},
{
"cell_type": "code",
"metadata": {
"id": "z2BEVafUeoZ_"
},
"source": [
"#export\n",
"import os, ast, inspect\n",
"from typing import Any, Callable, Dict, List, Optional, Union, Type\n",
"\n",
"from fastcore.all import *\n",
"from fastai.callback.all import *\n",
"from fastai.data.block import DataBlock, CategoryBlock, ColReader, ItemGetter, ColSplitter, RandomSplitter\n",
"from fastai.data.core import DataLoader, DataLoaders, TfmdDL\n",
"from fastai.imports import *\n",
"from fastai.learner import *\n",
"from fastai.losses import CrossEntropyLossFlat\n",
"from fastai.optimizer import Adam, OptimWrapper, params\n",
"from fastai.torch_core import *\n",
"from fastai.torch_imports import *\n",
"from fastprogress.fastprogress import progress_bar,master_bar\n",
"from seqeval import metrics as seq_metrics\n",
"from transformers import (\n",
" AutoModelForQuestionAnswering, logging,\n",
" PretrainedConfig, PreTrainedTokenizerBase, PreTrainedModel\n",
")\n",
"\n",
"from blurr.utils import BLURR\n",
"from blurr.data.core import HF_TextBlock, BlurrDataLoader, first_blurr_tfm\n",
"from blurr.modeling.core import HF_BaseModelCallback, HF_PreCalculatedLoss, Blearner\n",
"from blurr.data.question_answering import HF_QuestionAnswerInput, HF_QABeforeBatchTransform\n",
"\n",
"logging.set_verbosity_error()"
],
"execution_count": 9,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "J2Wd3ij8eoaA",
"outputId": "a5dfb03c-0d7c-484d-f824-7298ae26f9ab"
},
"source": [
"#hide_input\n",
"import pdb\n",
"\n",
"from fastai.data.external import untar_data, URLs\n",
"from fastcore.test import *\n",
"from nbverbose.showdoc import show_doc\n",
"from transformers import AutoConfig\n",
"\n",
"from blurr.utils import print_versions\n",
"from blurr.modeling.core import HF_BaseModelWrapper, HF_PreCalculatedLoss, hf_splitter\n",
"from blurr.data.question_answering import pre_process_squad\n",
"\n",
"os.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\"\n",
"print(\"What we're running with at the time this documentation was generated:\")\n",
"print_versions('torch fastai transformers')"
],
"execution_count": 10,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"What we're running with at the time this documentation was generated:\n",
"torch: 1.9.0+cu111\n",
"fastai: 2.5.3\n",
"transformers: 4.12.2\n"
]
}
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "q6DzU5QveoaB",
"outputId": "c2d92d14-841a-4c06-dc33-24fc3d297847"
},
"source": [
"#hide\n",
"#cuda\n",
"#torch.cuda.set_device(1)\n",
"print(f'Using GPU #{torch.cuda.current_device()}: {torch.cuda.get_device_name()}')"
],
"execution_count": 12,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Using GPU #0: Tesla K80\n"
]
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "zpc-4O6FeoaC"
},
"source": [
"## Question Answer\n",
"\n",
"Given a document (context) and a question, the objective of these models is to predict the start and end token of the correct answer as it exists in the context."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "rxaGNJJPeoaC"
},
"source": [
"Again, we'll use a subset of pre-processed SQUAD v2 for our purposes below."
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "abeMiD_2eoaD",
"outputId": "0aa45b0b-32db-406d-f957-35d53e216a9b"
},
"source": [
"# full\n",
"# squad_df = pd.read_csv('./data/task-question-answering/squad_cleaned.csv'); len(squad_df)\n",
"!wget https://raw.githubusercontent.com/ohmeow/blurr/master/nbs/squad_sample.csv\n",
"# sample\n",
"squad_df = pd.read_csv('./squad_sample.csv'); len(squad_df)"
],
"execution_count": 14,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"1000"
]
},
"metadata": {},
"execution_count": 14
}
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 243
},
"id": "qtfPxS4leoaD",
"outputId": "e5ce4524-186e-4ede-a832-b55619e0dad1"
},
"source": [
"squad_df.head(2)"
],
"execution_count": 16,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>id</th>\n",
" <th>title</th>\n",
" <th>context</th>\n",
" <th>question</th>\n",
" <th>answers</th>\n",
" <th>ds_type</th>\n",
" <th>answer_text</th>\n",
" <th>is_impossible</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>56be85543aeaaa14008c9063</td>\n",
" <td>Beyoncé</td>\n",
" <td>Beyoncé Giselle Knowles-Carter (/biːˈjɒnseɪ/ bee-YON-say) (born September 4, 1981) is an American singer, songwriter, record producer and actress. Born and raised in Houston, Texas, she performed in various singing and dancing competitions as a child, and rose to fame in the late 1990s as lead singer of R&amp;B girl-group Destiny's Child. Managed by her father, Mathew Knowles, the group became one of the world's best-selling girl groups of all time. Their hiatus saw the release of Beyoncé's debut album, Dangerously in Love (2003), which established her as a solo artist worldwide, earned five G...</td>\n",
" <td>When did Beyonce start becoming popular?</td>\n",
" <td>{'text': ['in the late 1990s'], 'answer_start': [269]}</td>\n",
" <td>train</td>\n",
" <td>in the late 1990s</td>\n",
" <td>False</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>56be85543aeaaa14008c9065</td>\n",
" <td>Beyoncé</td>\n",
" <td>Beyoncé Giselle Knowles-Carter (/biːˈjɒnseɪ/ bee-YON-say) (born September 4, 1981) is an American singer, songwriter, record producer and actress. Born and raised in Houston, Texas, she performed in various singing and dancing competitions as a child, and rose to fame in the late 1990s as lead singer of R&amp;B girl-group Destiny's Child. Managed by her father, Mathew Knowles, the group became one of the world's best-selling girl groups of all time. Their hiatus saw the release of Beyoncé's debut album, Dangerously in Love (2003), which established her as a solo artist worldwide, earned five G...</td>\n",
" <td>What areas did Beyonce compete in when she was growing up?</td>\n",
" <td>{'text': ['singing and dancing'], 'answer_start': [207]}</td>\n",
" <td>train</td>\n",
" <td>singing and dancing</td>\n",
" <td>False</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" id title ... answer_text is_impossible\n",
"0 56be85543aeaaa14008c9063 Beyoncé ... in the late 1990s False\n",
"1 56be85543aeaaa14008c9065 Beyoncé ... singing and dancing False\n",
"\n",
"[2 rows x 8 columns]"
]
},
"metadata": {},
"execution_count": 16
}
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 176,
"referenced_widgets": [
"c08d84f274b94ec9ad9510edacfa9c4c",
"7869c1b391c24a5ab631bee8803b8fcd",
"b01b12fa8fe44206bded46144f1ebfe9",
"730360df01a44642ba42c420797ed4d0",
"344d0b927c9741bbaeb284918490352f",
"dfc4d83151e34222b6cfd99cd18f6631",
"bae40a51df184386ab0992849a1c4d2a",
"cf902c61da9a4ae1abc81397862cda3d",
"de76897c873247ca91e5c76d0fa1c99c",
"f29ed6fc01f34892a23478c04694291c",
"50edce2966df4a83a508937f86c69a9f",
"0885d1d23c8d4c5f983f1e28ce0ea06d",
"469d0e9deb994bc394bb258b2e35e90c",
"e41aac826c2547589577f1e1b041780a",
"ea1a71d26db64bc6ab8818d20f9d9194",
"b9cdadac1d964031b1c00247f724adf6",
"b8bd5a08eff14bfbb7b845553615676b",
"a3304a4fb65b47e69e5c67b131fe522e",
"a8053bcb7f8f4a7b8ac580922c297bb7",
"a1e6401ee0164915a1973772e914fe75",
"c515f9bb4d2b4e21be13cc533bb66ff8",
"c740233e00fc4e51b87f463408ca3b26",
"019bafae01244276a930515267e191ea",
"63765de959c942849a43e1a7f1fa9aea",
"74f5f6a4baeb4b2c96c55377f62960d8",
"020dc99a21514968b2e54652dc49b4ff",
"09d56dc69a394dbb98b741aec5075209",
"2efd453ded5b481ea12833bc41f06a39",
"aa585eb2fc184f5e972829a852b3c0a0",
"101fde751876428a901b70c6ef3a5e91",
"078dcffa3fea40beb3d38f6b54db630f",
"64af53679ce749ec85c7ac0dc852e561",
"2a3bfe5c7b3e4d3290be14142e7f4c69",
"47c2942844e64023b3b2ed86fec4a265",
"e6854974367d483c95c96f94cd6718c4",
"182fe1165a0c4ae0a22c8b15e9e57193",
"cf365272b084408e81c93eb9b43fcc63",
"80fdd131238e41c0ac8ee3eec5eb7b5a",
"bd04152927aa4c369f6d47730f4f8d4a",
"99f85f26fa9540f985ee270dcde0fdec",
"24f0532a427c4506a2fb580a8a9c8343",
"f85befe20f1b49cbaf859a2295e98a09",
"dbc5483bb5084a4a9a46f50966f2f489",
"f851735f974d4de3bc012b9827863e10",
"e2d94a3e334944e0a1fb906a648f957c",
"068338dad5ed4c5084086137f46e33bc",
"001600de55e647e9bfd0fc0086ddece3",
"65d43f0160fe4990b29a6571ecb34a48",
"7c82a4b80b8e45d4993d8f17bb9d42b8",
"18133ede387345ca9dd16a9f28902eaf",
"21ab512903014ce4ab90c3491e4ea0f2",
"45dd8d2f31814fc4957f3195ae316c54",
"406355b5090641ccb2c7cb6ad04a852b",
"2f664fe953e548a29c00962487997c8a",
"ff22853468d243bebc072beb9591b707"
]
},
"id": "AZ6IcdGleoaE",
"outputId": "eccb3e6b-ab2f-49a7-9e8a-098cbaa9aa6e"
},
"source": [
"pretrained_model_name = 'bert-large-uncased-whole-word-masking-finetuned-squad'\n",
"hf_model_cls = AutoModelForQuestionAnswering\n",
"\n",
"hf_arch, hf_config, hf_tokenizer, hf_model = BLURR.get_hf_objects(pretrained_model_name, model_cls=hf_model_cls)\n",
"\n",
"# # here's a pre-trained roberta model for squad you can try too\n",
"# pretrained_model_name = \"ahotrod/roberta_large_squad2\"\n",
"# hf_arch, hf_config, hf_tokenizer, hf_model = BLURR.get_hf_objects(pretrained_model_name, \n",
"# model_cls=AutoModelForQuestionAnswering)\n",
"\n",
"# # here's a pre-trained xlm model for squad you can try too\n",
"# pretrained_model_name = 'xlm-mlm-ende-1024'\n",
"# hf_arch, hf_config, hf_tokenizer, hf_model = BLURR.get_hf_objects(pretrained_model_name,\n",
"# model_cls=AutoModelForQuestionAnswering)"
],
"execution_count": 17,
"outputs": [
{
"output_type": "display_data",
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "c08d84f274b94ec9ad9510edacfa9c4c",
"version_minor": 0,
"version_major": 2
},
"text/plain": [
"Downloading: 0%| | 0.00/443 [00:00<?, ?B/s]"
]
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "0885d1d23c8d4c5f983f1e28ce0ea06d",
"version_minor": 0,
"version_major": 2
},
"text/plain": [
"Downloading: 0%| | 0.00/28.0 [00:00<?, ?B/s]"
]
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "019bafae01244276a930515267e191ea",
"version_minor": 0,
"version_major": 2
},
"text/plain": [
"Downloading: 0%| | 0.00/226k [00:00<?, ?B/s]"
]
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "47c2942844e64023b3b2ed86fec4a265",
"version_minor": 0,
"version_major": 2
},
"text/plain": [
"Downloading: 0%| | 0.00/455k [00:00<?, ?B/s]"
]
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "e2d94a3e334944e0a1fb906a648f957c",
"version_minor": 0,
"version_major": 2
},
"text/plain": [
"Downloading: 0%| | 0.00/1.25G [00:00<?, ?B/s]"
]
},
"metadata": {}
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "P8koyGpNeoaE"
},
"source": [
"squad_df = squad_df.apply(partial(pre_process_squad, hf_arch=hf_arch, hf_tokenizer=hf_tokenizer), axis=1)"
],
"execution_count": 18,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "8GZ2avqQeoaF"
},
"source": [
"max_seq_len= 128"
],
"execution_count": 19,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "8rHfO6nAeoaF"
},
"source": [
"squad_df = squad_df[(squad_df.tokenized_input_len < max_seq_len) & (squad_df.is_impossible == False)]"
],
"execution_count": 20,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 643
},
"id": "u6cQJ1B5eoaF",
"outputId": "6b6a8227-e5c7-4bd2-dd22-69b742e0b110"
},
"source": [
"#hide\n",
"squad_df.head(2)"
],
"execution_count": 21,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>id</th>\n",
" <th>title</th>\n",
" <th>context</th>\n",
" <th>question</th>\n",
" <th>answers</th>\n",
" <th>ds_type</th>\n",
" <th>answer_text</th>\n",
" <th>is_impossible</th>\n",
" <th>tokenized_input</th>\n",
" <th>tokenized_input_len</th>\n",
" <th>tok_answer_start</th>\n",
" <th>tok_answer_end</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>214</th>\n",
" <td>56be97c73aeaaa14008c912a</td>\n",
" <td>Beyoncé</td>\n",
" <td>Beyoncé announced a hiatus from her music career in January 2010, heeding her mother's advice, \"to live life, to be inspired by things again\". During the break she and her father parted ways as business partners. Beyoncé's musical break lasted nine months and saw her visit multiple European cities, the Great Wall of China, the Egyptian pyramids, Australia, English music festivals and various museums and ballet performances.</td>\n",
" <td>Beyonce would take a break from music in which year?</td>\n",
" <td>{'text': ['2010'], 'answer_start': [60]}</td>\n",
" <td>train</td>\n",
" <td>2010</td>\n",
" <td>False</td>\n",
" <td>[[CLS], beyonce, would, take, a, break, from, music, in, which, year, ?, [SEP], beyonce, announced, a, hiatus, from, her, music, career, in, january, 2010, ,, hee, ##ding, her, mother, ', s, advice, ,, \", to, live, life, ,, to, be, inspired, by, things, again, \", ., during, the, break, she, and, her, father, parted, ways, as, business, partners, ., beyonce, ', s, musical, break, lasted, nine, months, and, saw, her, visit, multiple, european, cities, ,, the, great, wall, of, china, ,, the, egyptian, pyramid, ##s, ,, australia, ,, english, music, festivals, and, various, museums, and, ballet...</td>\n",
" <td>99</td>\n",
" <td>23</td>\n",
" <td>24</td>\n",
" </tr>\n",
" <tr>\n",
" <th>215</th>\n",
" <td>56be97c73aeaaa14008c912b</td>\n",
" <td>Beyoncé</td>\n",
" <td>Beyoncé announced a hiatus from her music career in January 2010, heeding her mother's advice, \"to live life, to be inspired by things again\". During the break she and her father parted ways as business partners. Beyoncé's musical break lasted nine months and saw her visit multiple European cities, the Great Wall of China, the Egyptian pyramids, Australia, English music festivals and various museums and ballet performances.</td>\n",
" <td>Which year did Beyonce and her father part business ways?</td>\n",
" <td>{'text': ['2010'], 'answer_start': [60]}</td>\n",
" <td>train</td>\n",
" <td>2010</td>\n",
" <td>False</td>\n",
" <td>[[CLS], which, year, did, beyonce, and, her, father, part, business, ways, ?, [SEP], beyonce, announced, a, hiatus, from, her, music, career, in, january, 2010, ,, hee, ##ding, her, mother, ', s, advice, ,, \", to, live, life, ,, to, be, inspired, by, things, again, \", ., during, the, break, she, and, her, father, parted, ways, as, business, partners, ., beyonce, ', s, musical, break, lasted, nine, months, and, saw, her, visit, multiple, european, cities, ,, the, great, wall, of, china, ,, the, egyptian, pyramid, ##s, ,, australia, ,, english, music, festivals, and, various, museums, and, b...</td>\n",
" <td>99</td>\n",
" <td>23</td>\n",
" <td>24</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" id title ... tok_answer_start tok_answer_end\n",
"214 56be97c73aeaaa14008c912a Beyoncé ... 23 24\n",
"215 56be97c73aeaaa14008c912b Beyoncé ... 23 24\n",
"\n",
"[2 rows x 12 columns]"
]
},
"metadata": {},
"execution_count": 21
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "gce3KJs4eoaF"
},
"source": [
"vocab = list(range(max_seq_len))\n",
"# vocab = dict(enumerate(range(max_seq_len)));"
],
"execution_count": 22,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "r9sA6cUAeoaG"
},
"source": [
"# account for tokenizers that pad on right or left side\n",
"trunc_strat = 'only_second' if (hf_tokenizer.padding_side == 'right') else 'only_first'\n",
"\n",
"before_batch_tfm = HF_QABeforeBatchTransform(hf_arch, hf_config, hf_tokenizer, hf_model,\n",
" max_length=max_seq_len, \n",
" truncation=trunc_strat, \n",
" tok_kwargs={ 'return_special_tokens_mask': True })\n",
"\n",
"blocks = (\n",
" HF_TextBlock(before_batch_tfm=before_batch_tfm, input_return_type=HF_QuestionAnswerInput), \n",
" CategoryBlock(vocab=vocab),\n",
" CategoryBlock(vocab=vocab)\n",
")\n",
"\n",
"def get_x(x):\n",
" return (x.question, x.context) if (hf_tokenizer.padding_side == 'right') else (x.context, x.question)\n",
"\n",
"dblock = DataBlock(blocks=blocks, \n",
" get_x=get_x,\n",
" get_y=[ColReader('tok_answer_start'), ColReader('tok_answer_end')],\n",
" splitter=RandomSplitter(),\n",
" n_inp=1)"
],
"execution_count": 23,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "qiRG9Vx7eoaG"
},
"source": [
"dls = dblock.dataloaders(squad_df, bs=4)"
],
"execution_count": 24,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "_jYP1fLkeoaG",
"outputId": "a6ae5ccc-268f-460c-9930-194ce63d9c0b"
},
"source": [
"len(dls.vocab), dls.vocab[0], dls.vocab[1]"
],
"execution_count": 25,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"(2,\n",
" [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127],\n",
" [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127])"
]
},
"metadata": {},
"execution_count": 25
}
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 176
},
"id": "QtKR9oMweoaG",
"outputId": "bccdab02-8044-4c1e-b1c1-43cb3b5e4304"
},
"source": [
"dls.show_batch(dataloaders=dls, max_n=2)"
],
"execution_count": 26,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/html": [
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>text</th>\n",
" <th>start/end</th>\n",
" <th>answer</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>what language does she mainly sing? beyonce's music is generally r &amp; b, but she also incorporates pop, soul and funk into her songs. 4 demonstrated beyonce's exploration of 90s - style r &amp; b, as well as further use of soul and hip hop than compared to previous releases. while she almost exclusively releases english songs, beyonce recorded several spanish songs for irreemplazable ( re - recordings of songs from b'day for a spanish - language audience ), and the re - release of b'day. to record these, beyonce was coached phonetically by american record producer rudy perez.</td>\n",
" <td>(67, 68)</td>\n",
" <td>english</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>what other language has she sung? beyonce's music is generally r &amp; b, but she also incorporates pop, soul and funk into her songs. 4 demonstrated beyonce's exploration of 90s - style r &amp; b, as well as further use of soul and hip hop than compared to previous releases. while she almost exclusively releases english songs, beyonce recorded several spanish songs for irreemplazable ( re - recordings of songs from b'day for a spanish - language audience ), and the re - release of b'day. to record these, beyonce was coached phonetically by american record producer rudy perez.</td>\n",
" <td>(73, 74)</td>\n",
" <td>spanish</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {}
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "f-X6U3WOeoaH"
},
"source": [
"### Training\n",
"\n",
"Here we create a question/answer-specific subclass of `HF_BaseModelCallback` in order to get both the start and end predictions. We also add a new loss function that can handle multiple targets."
]
},
{
"cell_type": "code",
"metadata": {
"id": "LqjqUIJJeoaH"
},
"source": [
"#export\n",
"class HF_QstAndAnsModelCallback(HF_BaseModelCallback):\n",
"    \"\"\"Model callback for extractive question answering.\n",
"\n",
"    After the base callback's `after_pred` has run, the prediction stored on the `Learner` is replaced\n",
"    by the 2-tuple `(start_logits, end_logits)` taken from the model output, so that each logit tensor\n",
"    can be paired with its own target downstream (e.g. by a multi-target loss function).\n",
"    \"\"\"\n",
"    def after_pred(self):\n",
"        super().after_pred()\n",
"        model_out = self.pred\n",
"        self.learn.pred = (model_out.start_logits, model_out.end_logits)"
],
"execution_count": 27,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "6wgZUtpzeoaH"
},
"source": [
"And here we provide a custom loss function for our question answering task, expanding on some techniques learned from here and here (**TODO**: restore the missing links to these sources).\n",
"\n",
"In fact, this new loss function can be used in many other multi-modal architectures, with any mix of loss functions. For example, it can be amended to include the `is_impossible` task, as well as the start/end token tasks in the SQuAD v2 dataset."
]
},
{
"cell_type": "code",
"metadata": {
"id": "4nuwWYIbeoaH"
},
"source": [
"#export\n",
"class MultiTargetLoss(Module):\n",
"    \"\"\"Provides the ability to apply different loss functions to multi-modal targets/predictions\n",
"\n",
"    The i-th loss function is applied to the i-th model output and the i-th target. Each result is\n",
"    multiplied by the i-th weight, and the weighted losses are summed.\n",
"    \"\"\"\n",
"    def __init__(\n",
"        self, \n",
"        # The loss function for each target\n",
"        loss_classes:List[Callable]=[CrossEntropyLossFlat, CrossEntropyLossFlat], \n",
"        # Any kwargs you want to pass to the loss functions above (one dict per loss function)\n",
"        loss_classes_kwargs:List[dict]=[{}, {}], \n",
"        # The weights you want to apply to each loss (default: [1,1])\n",
"        weights:Union[List[float], List[int]]=[1, 1], \n",
"        # The `reduction` parameter of the loss function (default: 'mean')\n",
"        reduction:str='mean'\n",
"    ):\n",
"        # `zip` would otherwise silently drop any loss whose class/kwargs/weight has no counterpart\n",
"        if not (len(loss_classes) == len(loss_classes_kwargs) == len(weights)):\n",
"            raise ValueError(f'loss_classes ({len(loss_classes)}), loss_classes_kwargs ({len(loss_classes_kwargs)}) '\n",
"                             f'and weights ({len(weights)}) must all have the same length')\n",
"\n",
"        loss_funcs = [cls(reduction=reduction, **kwargs) for cls, kwargs in zip(loss_classes, loss_classes_kwargs)]\n",
"        # copy so instances never share (or mutate) the mutable default list;\n",
"        # `store_attr` picks `weights` up from this frame's locals\n",
"        weights = list(weights)\n",
"        store_attr(self=self, names='loss_funcs, weights')\n",
"        self._reduction = reduction\n",
"\n",
"    # custom loss function must have either a reduction attribute or a reduction argument (like all fastai and\n",
"    # PyTorch loss functions) so that the framework can change this as needed (e.g., when doing learn.get_preds\n",
"    # it will set = 'none'). see this forum topic for more info: https://bit.ly/3br2Syz\n",
"    @property\n",
"    def reduction(self): return self._reduction\n",
"\n",
"    @reduction.setter\n",
"    def reduction(self, v):\n",
"        # propagate the new reduction to every wrapped loss function\n",
"        self._reduction = v\n",
"        for lf in self.loss_funcs: lf.reduction = v\n",
"\n",
"    def forward(self, outputs, *targets):\n",
"        \"Return the weighted sum of each loss function applied to its (output, target) pair\"\n",
"        if not (len(outputs) == len(targets) == len(self.loss_funcs)):\n",
"            raise ValueError(f'Expected {len(self.loss_funcs)} outputs and targets, '\n",
"                             f'got {len(outputs)} outputs and {len(targets)} targets')\n",
"        loss = 0.\n",
"        for loss_func, weight, output, target in zip(self.loss_funcs, self.weights, outputs, targets):\n",
"            loss += weight * loss_func(output, target)\n",
"        return loss\n",
"\n",
"    def activation(self, outs):\n",
"        \"Apply each loss function's `activation` to its corresponding output\"\n",
"        acts = [ self.loss_funcs[i].activation(o) for i, o in enumerate(outs) ]\n",
"        return acts\n",
"\n",
"    def decodes(self, outs):\n",
"        \"Apply each loss function's `decodes` to its corresponding output\"\n",
"        decodes = [ self.loss_funcs[i].decodes(o) for i, o in enumerate(outs) ]\n",
"        return decodes\n"
],
"execution_count": 28,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "ETwtCUKueoaI"
},
"source": [
"model = HF_BaseModelWrapper(hf_model)\n",
"\n",
"# `HF_QstAndAnsModelCallback` exposes the (start_logits, end_logits) pair as the learner's prediction\n",
"learn = Learner(dls, \n",
" model,\n",
" opt_func=partial(Adam, decouple_wd=True),\n",
" cbs=[HF_QstAndAnsModelCallback],\n",
" splitter=hf_splitter)\n",
"\n",
"# assigned after construction instead of via `Learner(loss_func=...)`; see the markdown note below on exporting\n",
"learn.loss_func=MultiTargetLoss()\n",
"learn.create_opt() # -> will create your layer groups based on your \"splitter\" function\n",
"# the first training stage runs frozen; `learn.unfreeze()` is called further down before fine-tuning everything\n",
"learn.freeze()"
],
"execution_count": 31,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "5rYtg3steoaI"
},
"source": [
"Notice above how I assign the loss function *after* creating the `Learner` object rather than passing it to the constructor. I'm not sure why, but passing `MultiTargetLoss` to the `Learner` constructor prevents the learner from being exported."
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 334
},
"id": "rBHeNwDXeoaI",
"outputId": "6ac031d6-93d7-47ac-fdf4-c056bc6c3eb1"
},
"source": [
"#hide_output\n",
"# `Learner.summary` fails for this model: the batch input is a dict of tensors, and fastai's shape printer\n",
"# slices it like a tuple (`TypeError: unhashable type: 'slice'`). Catch that so \"Run All\" does not stop here.\n",
"try:\n",
"    print(learn.summary())\n",
"except TypeError as e:\n",
"    print(f'learn.summary() is not supported for these dict-typed inputs: {e}')"
],
"execution_count": 32,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/html": [
""
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {}
},
{
"output_type": "error",
"ename": "TypeError",
"evalue": "ignored",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m<ipython-input-32-f3eae8d3a65c>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0;31m#hide_output\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 2\u001b[0;31m \u001b[0mlearn\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msummary\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
"\u001b[0;32m/usr/local/lib/python3.7/dist-packages/fastai/callback/hook.py\u001b[0m in \u001b[0;36msummary\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 205\u001b[0m \u001b[0;34m\"Print a summary of the model, optimizer and loss function.\"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 206\u001b[0m \u001b[0mxb\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdls\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrain\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mone_batch\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0mgetattr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdls\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrain\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m\"n_inp\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 207\u001b[0;31m \u001b[0mres\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmodule_summary\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0mxb\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 208\u001b[0m \u001b[0mres\u001b[0m \u001b[0;34m+=\u001b[0m \u001b[0;34mf\"Optimizer used: {self.opt_func}\\nLoss function: {self.loss_func}\\n\\n\"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 209\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mopt\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.7/dist-packages/fastai/callback/hook.py\u001b[0m in \u001b[0;36mmodule_summary\u001b[0;34m(learn, *xb)\u001b[0m\n\u001b[1;32m 175\u001b[0m \u001b[0minfos\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mlayer_info\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlearn\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0mxb\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 176\u001b[0m \u001b[0mn\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mbs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;36m76\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mfind_bs\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mxb\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 177\u001b[0;31m \u001b[0minp_sz\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0m_print_shapes\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mapply\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;32mlambda\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mxb\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 178\u001b[0m \u001b[0mres\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34mf\"{type(learn.model).__name__} (Input shape: {inp_sz})\\n\"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 179\u001b[0m \u001b[0mres\u001b[0m \u001b[0;34m+=\u001b[0m \u001b[0;34m\"=\"\u001b[0m \u001b[0;34m*\u001b[0m \u001b[0mn\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0;34m\"\\n\"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.7/dist-packages/fastai/callback/hook.py\u001b[0m in \u001b[0;36m_print_shapes\u001b[0;34m(o, bs)\u001b[0m\n\u001b[1;32m 164\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_print_shapes\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mo\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 165\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mo\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mSize\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0m_get_shapes\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mo\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 166\u001b[0;31m \u001b[0;32melif\u001b[0m \u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mo\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtuple\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0m_get_shapes\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mo\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 167\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mstr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0m_print_shapes\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbs\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mx\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mo\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 168\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.7/dist-packages/fastai/callback/hook.py\u001b[0m in \u001b[0;36m_get_shapes\u001b[0;34m(o, bs)\u001b[0m\n\u001b[1;32m 161\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 162\u001b[0m \u001b[0;31m# Cell\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 163\u001b[0;31m \u001b[0;32mdef\u001b[0m \u001b[0m_get_shapes\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mo\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0;34m' x '\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mjoin\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mstr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0mstr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mt\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mt\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mo\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 164\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_print_shapes\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mo\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 165\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mo\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mSize\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0m_get_shapes\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mo\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;31mTypeError\u001b[0m: unhashable type: 'slice'"
]
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "OhMKlbaFeoaI",
"outputId": "9217df1c-49ea-4008-fb7b-587d754b9a55"
},
"source": [
"print(len(learn.opt.param_groups))"
],
"execution_count": null,
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"3\n"
]
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "f1Uobp2PeoaI",
"outputId": "4e51947c-041d-458c-fbe0-29d6267f26f1"
},
"source": [
"x, y_start, y_end = dls.one_batch()\n",
"preds = learn.model(x)\n",
"len(preds),preds[0].shape"
],
"execution_count": null,
"outputs": [
{
"data": {
"text/plain": [
"(2, torch.Size([4, 127]))"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "q-3PWw1IeoaJ",
"outputId": "96b05545-2534-4579-cbf9-f22b383bd0bb"
},
"source": [
"learn.lr_find(suggest_funcs=[minimum, steep, valley, slide])"
],
"execution_count": null,
"outputs": [
{
"data": {
"text/html": [],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/wgilliam/miniconda3/envs/blurr/lib/python3.9/site-packages/fastai/callback/schedule.py:270: UserWarning: color is redundantly defined by the 'color' keyword argument and the fmt string \"ro\" (-> color='r'). The keyword argument will take precedence.\n",
" ax.plot(val, idx, 'ro', label=nm, c=color)\n"
]
},
{
"data": {
"text/plain": [
"SuggestedLRs(minimum=0.003981071710586548, steep=0.0010000000474974513, valley=0.0008317637839354575, slide=0.0020892962347716093)"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "mcuW52ZveoaJ",
"outputId": "85654380-7081-4693-b5bb-143351757e18"
},
"source": [
"learn.fit_one_cycle(3, lr_max=1e-3)"
],
"execution_count": null,
"outputs": [
{
"data": {
"text/html": [
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: left;\">\n",
" <th>epoch</th>\n",
" <th>train_loss</th>\n",
" <th>valid_loss</th>\n",
" <th>time</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <td>0</td>\n",
" <td>4.139011</td>\n",
" <td>1.324671</td>\n",
" <td>00:04</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1</td>\n",
" <td>2.428752</td>\n",
" <td>0.630240</td>\n",
" <td>00:04</td>\n",
" </tr>\n",
" <tr>\n",
" <td>2</td>\n",
" <td>1.736481</td>\n",
" <td>0.554712</td>\n",
" <td>00:04</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "lS7xPxaoeoaJ"
},
"source": [
"### Showing results\n",
"\n",
"Below we'll add in additional functionality to more intuitively show the results of our model."
]
},
{
"cell_type": "code",
"metadata": {
"id": "KYjjtD0JeoaJ"
},
"source": [
"#export\n",
"@typedispatch\n",
"def show_results(\n",
"    # This typedispatched `show_results` will be called for `HF_QuestionAnswerInput` typed inputs\n",
"    x:HF_QuestionAnswerInput, \n",
"    # The targets\n",
"    y, \n",
"    # Your raw inputs/targets\n",
"    samples, \n",
"    # The model's predictions\n",
"    outs, \n",
"    # Your `Learner`. This is required so as to get at the Hugging Face objects for decoding them into \n",
"    # something understandable\n",
"    learner, \n",
"    # Whether you want to remove special tokens when decoding the input text\n",
"    skip_special_tokens=True, \n",
"    # Your `show_results` context\n",
"    ctxs=None, \n",
"    # The maximum number of items to show\n",
"    max_n=6, \n",
"    # Any truncation you want applied to your decoded inputs\n",
"    trunc_at=None, \n",
"    # Any other keyword arguments you want applied to `show_results`\n",
"    **kwargs\n",
"): \n",
"    \"\"\"Show a table of each decoded input with its target and predicted start/end token spans and answers.\n",
"\n",
"    `y` holds the start and end targets; each item of `outs` holds the (start, end) prediction.\n",
"    \"\"\"\n",
"    tfm = first_blurr_tfm(learner.dls)\n",
"    hf_tokenizer = tfm.hf_tokenizer\n",
"\n",
"    res = L()\n",
"    for sample, input_ids, start, end, pred in zip(samples, x, *y, outs):\n",
"        # honor the caller's `skip_special_tokens` flag when decoding the full input text\n",
"        txt = hf_tokenizer.decode(sample[0], skip_special_tokens=skip_special_tokens)[:trunc_at]\n",
"        # start/end index the full token list, so special tokens must be kept for the span slicing\n",
"        toks = hf_tokenizer.convert_ids_to_tokens(input_ids, skip_special_tokens=False)\n",
"        tgt_start, tgt_end = start.item(), end.item()\n",
"        pred_start, pred_end = int(pred[0]), int(pred[1])\n",
"\n",
"        res.append((txt,\n",
"                    (tgt_start, tgt_end), hf_tokenizer.convert_tokens_to_string(toks[tgt_start:tgt_end]),\n",
"                    (pred_start, pred_end), hf_tokenizer.convert_tokens_to_string(toks[pred_start:pred_end])))\n",
"\n",
"    df = pd.DataFrame(res, columns=['text', 'start/end', 'answer', 'pred start/end', 'pred answer'])\n",
"    display_df(df[:max_n])\n",
"    return ctxs"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "1vJfgVsGeoaJ",
"outputId": "e31f9cd8-5e98-4c5c-cd31-f9cee901cbd0"
},
"source": [
"learn.show_results(learner=learn, skip_special_tokens=True, max_n=2, trunc_at=500)"
],
"execution_count": null,
"outputs": [
{
"data": {
"text/html": [],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>text</th>\n",
" <th>start/end</th>\n",
" <th>answer</th>\n",
" <th>pred start/end</th>\n",
" <th>pred answer</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>where did beyonce exclusively release her single, formation? on february 6, 2016, one day before her performance at the super bowl, beyonce released a new single exclusively on music streaming service tidal called \" formation \".</td>\n",
" <td>(38, 39)</td>\n",
" <td>tidal</td>\n",
" <td>(38, 39)</td>\n",
" <td>tidal</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>what word does \" bey hive \" derive from? the bey hive is the name given to beyonce's fan base. fans were previously titled \" the beyontourage \", ( a portmanteau of beyonce and entourage ). the name bey hive derives from the word beehive, purposely misspelled to resemble her first name, and was penned by fans after petitions on the online social networking service twitter and online news reports during competitions.</td>\n",
" <td>(58, 61)</td>\n",
" <td>beehive</td>\n",
" <td>(58, 61)</td>\n",
" <td>beehive</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "TCDlngdyeoaK"
},
"source": [
"... and let's see how `Learner.blurr_predict` works with question/answering tasks"
]
},
{
"cell_type": "code",
"metadata": {
"id": "djQt6L4meoaK",
"outputId": "7fb1e7cd-afb4-4212-842d-c00ea99b7cc3"
},
"source": [
"inf_df = pd.DataFrame.from_dict([{\n",
" 'question': 'What did George Lucas make?',\n",
" 'context': 'George Lucas created Star Wars in 1977. He directed and produced it.' \n",
"}], \n",
" orient='columns')\n",
"\n",
"learn.blurr_predict(inf_df.iloc[0])"
],
"execution_count": null,
"outputs": [
{
"data": {
"text/plain": [
"[(('11', '13'),\n",
" (#2) [tensor(11),tensor(13)],\n",
" (#2) [tensor([3.0268e-07, 6.9921e-08, 5.9632e-09, 1.2420e-08, 8.5584e-09, 7.5558e-09,\n",
" 9.2788e-10, 3.0270e-07, 3.8582e-04, 2.7305e-05, 8.3689e-04, 9.9857e-01,\n",
" 1.5739e-04, 4.2566e-07, 7.8813e-06, 5.0365e-07, 4.5226e-06, 4.6080e-06,\n",
" 3.3246e-08, 2.2053e-06, 8.2759e-07, 1.2332e-07, 2.5745e-07]),tensor([1.6131e-03, 8.3521e-05, 5.9296e-06, 2.2950e-06, 9.9383e-06, 6.0271e-06,\n",
" 2.4133e-05, 1.6132e-03, 3.1182e-05, 1.2563e-04, 9.7756e-05, 1.4331e-05,\n",
" 5.6870e-02, 5.2701e-01, 2.6772e-02, 3.5492e-01, 2.6779e-04, 8.0594e-05,\n",
" 2.0279e-04, 1.8910e-04, 6.8663e-03, 2.1703e-02, 1.4861e-03])])]"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "cLn3RutNeoaK",
"outputId": "e14a1c45-551e-4c0a-b89a-518f14a4e653"
},
"source": [
"inf_df = pd.DataFrame.from_dict([\n",
" {\n",
" 'question': 'What did George Lucas make?',\n",
" 'context': 'George Lucas created Star Wars in 1977. He directed and produced it.' \n",
" }, {\n",
" 'question': 'What year did Star Wars come out?',\n",
" 'context': 'George Lucas created Star Wars in 1977. He directed and produced it.' \n",
" }, {\n",
" 'question': 'What did George Lucas do?',\n",
" 'context': 'George Lucas created Star Wars in 1977. He directed and produced it.' \n",
" }], \n",
" orient='columns')\n",
"\n",
"learn.blurr_predict(inf_df)"
],
"execution_count": null,
"outputs": [
{
"data": {
"text/plain": [
"[(('11', '13'),\n",
" (#2) [tensor(11),tensor(13)],\n",
" (#2) [tensor([3.0268e-07, 6.9921e-08, 5.9632e-09, 1.2420e-08, 8.5584e-09, 7.5558e-09,\n",
" 9.2787e-10, 3.0270e-07, 3.8582e-04, 2.7305e-05, 8.3689e-04, 9.9857e-01,\n",
" 1.5739e-04, 4.2566e-07, 7.8812e-06, 5.0365e-07, 4.5226e-06, 4.6080e-06,\n",
" 3.3246e-08, 2.2053e-06, 8.2759e-07, 1.2332e-07, 2.5745e-07, 5.4841e-10,\n",
" 7.3210e-10]),tensor([1.6131e-03, 8.3521e-05, 5.9296e-06, 2.2950e-06, 9.9383e-06, 6.0271e-06,\n",
" 2.4133e-05, 1.6132e-03, 3.1182e-05, 1.2563e-04, 9.7755e-05, 1.4331e-05,\n",
" 5.6870e-02, 5.2701e-01, 2.6772e-02, 3.5492e-01, 2.6779e-04, 8.0594e-05,\n",
" 2.0279e-04, 1.8910e-04, 6.8663e-03, 2.1703e-02, 1.4861e-03, 1.3298e-06,\n",
" 8.1794e-07])]),\n",
" (('16', '17'),\n",
" (#2) [tensor(16),tensor(17)],\n",
" (#2) [tensor([1.8138e-06, 3.6914e-06, 7.9606e-08, 5.7100e-08, 4.9475e-08, 3.7448e-08,\n",
" 5.3773e-08, 6.4744e-08, 2.9010e-08, 1.8139e-06, 1.5196e-06, 1.7933e-06,\n",
" 2.9139e-06, 5.4099e-06, 2.6638e-06, 6.3237e-05, 9.9991e-01, 2.0777e-06,\n",
" 3.3066e-07, 3.0617e-07, 6.3106e-08, 3.2725e-07, 6.9513e-07, 9.8958e-07,\n",
" 1.8122e-06]),tensor([3.1355e-03, 6.7197e-04, 5.7665e-04, 2.1287e-04, 8.4543e-05, 1.6099e-04,\n",
" 1.0164e-04, 1.8573e-04, 4.5048e-04, 3.1355e-03, 6.3488e-04, 1.0317e-03,\n",
" 7.2071e-04, 3.0524e-04, 1.0807e-03, 1.1295e-03, 7.2798e-03, 9.5853e-01,\n",
" 1.0582e-02, 8.9976e-04, 8.4763e-04, 7.5183e-04, 2.1189e-03, 2.2401e-03,\n",
" 3.1360e-03])]),\n",
" (('17', '21'),\n",
" (#2) [tensor(17),tensor(21)],\n",
" (#2) [tensor([8.9343e-06, 3.5278e-07, 8.2588e-08, 1.7797e-07, 7.2601e-08, 9.2827e-08,\n",
" 2.2630e-08, 8.9359e-06, 4.9041e-03, 5.3172e-04, 1.2702e-01, 4.9925e-04,\n",
" 3.2418e-05, 2.7297e-06, 4.7746e-05, 1.4660e-05, 1.0720e-01, 7.3418e-01,\n",
" 2.3097e-05, 2.5464e-02, 3.5826e-05, 1.3317e-05, 8.6433e-06, 1.0467e-08,\n",
" 1.6150e-08]),tensor([3.4201e-03, 2.1435e-05, 1.3353e-05, 5.0382e-06, 2.2327e-05, 1.2700e-05,\n",
" 3.4319e-05, 3.4203e-03, 3.7694e-05, 2.4082e-04, 3.4410e-04, 3.1209e-04,\n",
" 1.6885e-02, 2.4353e-02, 7.8388e-03, 7.9421e-02, 4.1039e-05, 8.2386e-05,\n",
" 1.3766e-03, 3.6992e-03, 3.0023e-01, 5.5487e-01, 3.3109e-03, 3.7449e-06,\n",
" 1.6923e-06])])]"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "vuO4JCLpeoaL",
"outputId": "97d55a40-c88b-43a2-c19e-830356b2757b"
},
"source": [
"inp_ids = hf_tokenizer.encode('What did George Lucas make?',\n",
" 'George Lucas created Star Wars in 1977. He directed and produced it.')\n",
"\n",
"hf_tokenizer.convert_ids_to_tokens(inp_ids, skip_special_tokens=False)[11:13]"
],
"execution_count": null,
"outputs": [
{
"data": {
"text/plain": [
"['star', 'wars']"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "eSMKKuTFeoaL"
},
"source": [
"Note that a bug, either in fastai v2 or in how I'm assembling everything, currently prevents us from seeing the decoded predictions and probabilities for the \"end\" token."
]
},
{
"cell_type": "code",
"metadata": {
"id": "cauKqimneoaL",
"outputId": "475b554b-a457-4082-a54d-c4e195f68d53"
},
"source": [
"inf_df = pd.DataFrame.from_dict([{\n",
" 'question': 'When was Star Wars made?',\n",
" 'context': 'George Lucas created Star Wars in 1977. He directed and produced it.'\n",
"}], \n",
" orient='columns')\n",
"\n",
"test_dl = dls.test_dl(inf_df)\n",
"inp = test_dl.one_batch()[0]['input_ids']\n",
"probs, _, preds = learn.get_preds(dl=test_dl, with_input=False, with_decoded=True)"
],
"execution_count": null,
"outputs": [
{
"data": {
"text/html": [],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "9m5gyJpbeoaL",
"outputId": "218a61ae-b873-4cda-bc1e-8e3b5e684a17"
},
"source": [
"hf_tokenizer.convert_ids_to_tokens(inp.tolist()[0], \n",
" skip_special_tokens=False)[torch.argmax(probs[0]):torch.argmax(probs[1])]"
],
"execution_count": null,
"outputs": [
{
"data": {
"text/plain": [
"['1977']"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "irH_58a6eoaL"
},
"source": [
"We can unfreeze and continue training like normal"
]
},
{
"cell_type": "code",
"metadata": {
"id": "xiSHPsOueoaL"
},
"source": [
"learn.unfreeze()"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "hVYPxVPTeoaM",
"outputId": "b392cbbb-128c-4cf4-e010-4bdcee493e24"
},
"source": [
"learn.fit_one_cycle(3, lr_max=slice(1e-7, 1e-4))"
],
"execution_count": null,
"outputs": [
{
"data": {
"text/html": [
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: left;\">\n",
" <th>epoch</th>\n",
" <th>train_loss</th>\n",
" <th>valid_loss</th>\n",
" <th>time</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <td>0</td>\n",
" <td>0.998746</td>\n",
" <td>0.454062</td>\n",
" <td>00:07</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1</td>\n",
" <td>0.788146</td>\n",
" <td>0.426064</td>\n",
" <td>00:07</td>\n",
" </tr>\n",
" <tr>\n",
" <td>2</td>\n",
" <td>0.641020</td>\n",
" <td>0.407723</td>\n",
" <td>00:08</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "WKTuTOzyeoaM",
"outputId": "62204780-b8c8-4f2b-aea1-496cd2349300"
},
"source": [
"learn.recorder.plot_loss()"
],
"execution_count": null,
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "wcNpjYNfeoaM",
"outputId": "7895529c-1e3e-43f5-8f7f-7330d656a3af"
},
"source": [
"learn.show_results(learner=learn, max_n=2, trunc_at=100)"
],
"execution_count": null,
"outputs": [
{
"data": {
"text/html": [],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>text</th>\n",
" <th>start/end</th>\n",
" <th>answer</th>\n",
" <th>pred start/end</th>\n",
" <th>pred answer</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>where did beyonce exclusively release her single, formation? on february 6, 2016, one day before her</td>\n",
" <td>(38, 39)</td>\n",
" <td>tidal</td>\n",
" <td>(38, 39)</td>\n",
" <td>tidal</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>her first appearance performing since giving birth was where? on january 7, 2012, beyonce gave birth</td>\n",
" <td>(52, 61)</td>\n",
" <td>revel atlantic city's ovation hall</td>\n",
" <td>(52, 61)</td>\n",
" <td>revel atlantic city's ovation hall</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "t38wXhG1eoaM",
"outputId": "1c698074-00cd-4c51-ab87-ce588d8a6b97"
},
"source": [
"learn.blurr_predict(inf_df.iloc[0])"
],
"execution_count": null,
"outputs": [
{
"data": {
"text/plain": [
"[(('14', '15'),\n",
" (#2) [tensor(14),tensor(15)],\n",
" (#2) [tensor([2.0139e-07, 6.0186e-08, 1.3398e-08, 8.2077e-09, 7.6575e-09, 2.5011e-08,\n",
" 3.6197e-09, 2.0140e-07, 1.6501e-06, 7.0809e-07, 6.9617e-06, 4.8223e-06,\n",
" 1.0504e-06, 9.7600e-04, 9.9901e-01, 9.0924e-07, 8.0747e-08, 4.3993e-08,\n",
" 6.1624e-09, 3.3514e-08, 7.0549e-08, 6.1190e-08, 1.9218e-07]),tensor([4.5972e-04, 1.5797e-05, 6.4806e-06, 2.9688e-06, 4.6566e-06, 2.9982e-06,\n",
" 1.1183e-05, 4.5972e-04, 2.5425e-05, 3.4846e-05, 5.2598e-05, 9.6193e-06,\n",
" 3.9660e-05, 1.1938e-04, 1.1622e-03, 9.9376e-01, 3.1741e-03, 2.8255e-05,\n",
" 2.0109e-05, 1.9298e-05, 7.3358e-05, 6.7907e-05, 4.4704e-04])])]"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "sTbMvuP-eoaM",
"outputId": "3d8b2cab-dde8-46f4-fc50-b7e366ae7120"
},
"source": [
"preds, pred_classes, probs = zip(*learn.blurr_predict(inf_df.iloc[0]))\n",
"preds"
],
"execution_count": null,
"outputs": [
{
"data": {
"text/plain": [
"(('14', '15'),)"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "WAT5sH6yeoaM",
"outputId": "09bb7234-d4bd-475a-f5fa-777242b2bb16"
},
"source": [
"# predicted (start, end) token positions of the first example, as ints for slicing\n",
"ans_start, ans_end = (int(pos) for pos in preds[0])\n",
"\n",
"inp_ids = hf_tokenizer.encode('When was Star Wars made?',\n",
"                              'George Lucas created Star Wars in 1977. He directed and produced it.')\n",
"hf_tokenizer.convert_ids_to_tokens(inp_ids, skip_special_tokens=False)[ans_start:ans_end]"
],
"execution_count": null,
"outputs": [
{
"data": {
"text/plain": [
"['1977']"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "spNglbaleoaN"
},
"source": [
"### Inference\n",
"\n",
"Note that I had to replace the loss function before exporting because of the above-mentioned issue with exporting a model that uses the `MultiTargetLoss` loss function. After loading our inference learner, we put `MultiTargetLoss` back and we're good to go!"
]
},
{
"cell_type": "code",
"metadata": {
"id": "xJZhPFfUeoaN"
},
"source": [
"# base file name (no extension) shared by learn.export and load_learner below\n",
"export_name = \"q_and_a_learn_export\""
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "cX1kzewseoaN"
},
"source": [
"# swap in a stock fastai loss so the learner can be exported (see the note above), then export\n",
"learn.loss_func = CrossEntropyLossFlat()\n",
"export_fname = f'{export_name}.pkl'\n",
"learn.export(fname=export_fname)"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "XfkLrgaeeoaN",
"outputId": "2a3cd6d6-ea1b-4035-ac19-3d71675ce97e"
},
"source": [
"inf_learn = load_learner(fname=f'{export_name}.pkl')\n",
"# restore the multi-target loss that was swapped out before exporting\n",
"inf_learn.loss_func = MultiTargetLoss()\n",
"\n",
"star_wars_ctx = 'George Lucas created Star Wars in 1977. He directed and produced it.'\n",
"questions = ['What did George Lucas make?',\n",
"             'What year did Star Wars come out?',\n",
"             'What did George Lucas do?']\n",
"inf_df = pd.DataFrame([{'question': qst, 'context': star_wars_ctx} for qst in questions])\n",
"\n",
"inf_learn.blurr_predict(inf_df)"
],
"execution_count": null,
"outputs": [
{
"data": {
"text/plain": [
"[(('11', '13'),\n",
" (#2) [tensor(11),tensor(13)],\n",
" (#2) [tensor([1.8343e-07, 5.0318e-08, 4.8847e-09, 8.4643e-09, 5.8875e-09, 7.1094e-09,\n",
" 7.2085e-10, 1.8345e-07, 1.3470e-04, 1.2801e-05, 5.9846e-04, 9.9913e-01,\n",
" 1.1949e-04, 1.4157e-07, 3.5095e-06, 1.6165e-07, 1.3399e-06, 1.5361e-06,\n",
" 1.6692e-08, 7.6845e-07, 2.2477e-07, 4.0865e-08, 1.0515e-07, 4.5671e-10,\n",
" 6.0244e-10]),tensor([6.7522e-04, 2.6496e-05, 1.9889e-06, 7.5134e-07, 2.1462e-06, 1.3290e-06,\n",
" 5.0080e-06, 6.7526e-04, 7.9513e-06, 2.0008e-05, 2.3254e-05, 4.8967e-06,\n",
" 2.2808e-02, 7.4460e-01, 1.1859e-02, 2.1369e-01, 7.4602e-05, 2.3168e-05,\n",
" 4.3674e-05, 4.5102e-05, 1.5111e-03, 3.3360e-03, 5.5911e-04, 4.4061e-07,\n",
" 3.0934e-07])]),\n",
" (('16', '17'),\n",
" (#2) [tensor(16),tensor(17)],\n",
" (#2) [tensor([1.1055e-06, 1.8704e-06, 4.5308e-08, 3.7884e-08, 3.0064e-08, 2.3709e-08,\n",
" 3.2830e-08, 3.9723e-08, 1.7936e-08, 1.1055e-06, 5.2583e-07, 6.4640e-07,\n",
" 9.6506e-07, 1.8904e-06, 1.1272e-06, 2.7888e-05, 9.9996e-01, 5.5272e-07,\n",
" 1.1390e-07, 1.1871e-07, 3.3400e-08, 1.2625e-07, 2.4034e-07, 2.9475e-07,\n",
" 1.1031e-06]),tensor([1.4690e-03, 3.0677e-04, 2.0203e-04, 8.3740e-05, 3.4471e-05, 5.6418e-05,\n",
" 3.7654e-05, 5.9635e-05, 1.2717e-04, 1.4690e-03, 1.7384e-04, 2.6960e-04,\n",
" 2.0993e-04, 1.0054e-04, 2.9895e-04, 3.2731e-04, 3.5134e-03, 9.8572e-01,\n",
" 2.8103e-03, 2.1733e-04, 1.6969e-04, 1.6633e-04, 3.8966e-04, 3.1913e-04,\n",
" 1.4693e-03])]),\n",
" (('17', '21'),\n",
" (#2) [tensor(17),tensor(21)],\n",
" (#2) [tensor([9.8801e-06, 4.2964e-07, 1.1790e-07, 2.2178e-07, 9.1010e-08, 1.3551e-07,\n",
" 2.9758e-08, 9.8821e-06, 3.2094e-03, 4.5714e-04, 1.5957e-01, 3.8280e-04,\n",
" 2.7563e-05, 1.3641e-06, 4.6444e-05, 1.0686e-05, 1.1351e-01, 6.9439e-01,\n",
" 2.3044e-05, 2.8310e-02, 2.4390e-05, 8.8607e-06, 9.0558e-06, 1.2868e-08,\n",
" 2.0218e-08]),tensor([1.9702e-03, 8.6120e-06, 5.9526e-06, 2.1202e-06, 6.5693e-06, 4.6015e-06,\n",
" 1.1350e-05, 1.9704e-03, 1.2754e-05, 6.1397e-05, 1.4907e-04, 1.4174e-04,\n",
" 8.3424e-03, 1.6695e-02, 2.8036e-03, 3.3022e-02, 1.5188e-05, 3.1796e-05,\n",
" 5.4052e-04, 1.2613e-03, 3.0156e-01, 6.2954e-01, 1.8494e-03, 1.7348e-06,\n",
" 8.5971e-07])])]"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "c1NGB3bueoaN",
"outputId": "27821cbd-cdc3-4f31-e1c8-01cd790f0ecc"
},
"source": [
"# (start, end) of the first prediction returned by `inf_learn.blurr_predict` above\n",
"ans_start, ans_end = 11, 13\n",
"\n",
"inp_ids = hf_tokenizer.encode('What did George Lucas make?',\n",
"                              'George Lucas created Star Wars in 1977. He directed and produced it.')\n",
"hf_tokenizer.convert_ids_to_tokens(inp_ids, skip_special_tokens=False)[ans_start:ans_end]"
],
"execution_count": null,
"outputs": [
{
"data": {
"text/plain": [
"['star', 'wars']"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Y-Jo6bGWeoaN"
},
"source": [
"## High-level API"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "mtIRN6tweoaN"
},
"source": [
"### BlearnerForQuestionAnswering"
]
},
{
"cell_type": "code",
"metadata": {
"id": "6mJZopB_eoaO"
},
"source": [
"#hide\n",
"import gc\n",
"\n",
"# drop whichever learners from the examples above exist (a missing one is no longer an error that\n",
"# silently skips the rest), then reclaim their GPU memory\n",
"for _name in ('learn', 'inf_learn'):\n",
"    globals().pop(_name, None)\n",
"del _name\n",
"gc.collect()  # release any reference cycles still holding the deleted learners' tensors\n",
"torch.cuda.empty_cache()"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "nsgDQSgPeoaO"
},
"source": [
"#export\n",
"@delegates(Blearner.__init__)\n",
"class BlearnerForQuestionAnswering(Blearner):\n",
"    \"A `Blearner` for extractive question answering; `loss_func` defaults to `MultiTargetLoss`\"\n",
"\n",
"    def __init__(\n",
"        self,\n",
"        dls:DataLoaders,\n",
"        hf_model: PreTrainedModel,\n",
"        **kwargs\n",
"    ):\n",
"        # only fall back to `MultiTargetLoss` when the caller didn't pass a `loss_func`\n",
"        kwargs['loss_func'] = kwargs.get('loss_func', MultiTargetLoss())\n",
"        super().__init__(dls, hf_model, base_model_cb=HF_QstAndAnsModelCallback, **kwargs)\n",
"\n",
"    @classmethod\n",
"    def get_model_cls(cls):\n",
"        \"The Hugging Face auto model class used to instantiate `hf_model` for this task\"\n",
"        return AutoModelForQuestionAnswering\n",
"\n",
"    @classmethod\n",
"    def _get_x(\n",
"        cls,\n",
"        x,\n",
"        qst,\n",
"        ctx,\n",
"        padding_side='right'\n",
"    ):\n",
"        \"Return `(x[qst], x[ctx])`, or `(x[ctx], x[qst])` when `padding_side` isn't 'right'\"\n",
"        return (x[qst], x[ctx]) if (padding_side == 'right') else (x[ctx], x[qst])\n",
"\n",
"    @classmethod\n",
"    def _create_learner(\n",
"        cls,\n",
"        # Your raw dataset\n",
"        data,\n",
"        # The name or path of the pretrained model you want to fine-tune\n",
"        pretrained_model_name_or_path:Optional[Union[str, os.PathLike]],\n",
"        # A function to perform any preprocessing required for your Dataset\n",
"        preprocess_func:Callable=None,\n",
"        # The maximum sequence length to constrain our data\n",
"        max_seq_len:int=None,\n",
"        # The attribute in your dataset that contains the context (where the answer is included) (default: 'context')\n",
"        context_attr:str='context',\n",
"        # The attribute in your dataset that contains the question being asked (default: 'question')\n",
"        question_attr:str='question',\n",
"        # The attribute in your dataset that contains the actual answer (default: 'answer_text')\n",
"        answer_text_attr:str='answer_text',\n",
"        # The attribute in your dataset that contains the tokenized answer start (default: 'tok_answer_start')\n",
"        tok_ans_start_attr:str='tok_answer_start',\n",
"        # The attribute in your dataset that contains the tokenized answer end (default: 'tok_answer_end')\n",
"        tok_ans_end_attr:str='tok_answer_end',\n",
"        # A function that will split your Dataset into a training and validation set\n",
"        # See [here](https://docs.fast.ai/data.transforms.html#Split) for a list of fast.ai splitters\n",
"        dblock_splitter:Callable=RandomSplitter(),\n",
"        # Any kwargs to pass to your `DataLoaders`\n",
"        dl_kwargs={},\n",
"        # Any kwargs to pass to your task specific `Blearner`\n",
"        learner_kwargs={}\n",
"    ):\n",
"        \"Build the `DataLoaders` for `data` and return a `cls` learner wrapping the pretrained model\"\n",
"        hf_arch, hf_config, hf_tokenizer, hf_model = BLURR.get_hf_objects(pretrained_model_name_or_path,\n",
"                                                                          model_cls=cls.get_model_cls())\n",
"\n",
"        # potentially used by our preprocess_func, it is the basis for our CategoryBlock vocab.\n",
"        # NOTE(review): assumes `hf_config` is a transformers config object, which exposes settings as\n",
"        # attributes rather than through a dict-style `.get` -- hence `getattr` with the same fallback\n",
"        if (max_seq_len is None):\n",
"            max_seq_len = getattr(hf_config, 'max_position_embeddings', 128)\n",
"\n",
"        # client can pass in a function that takes the raw data, hf objects, and max_seq_len ... and\n",
"        # returns a DataFrame with the expected format\n",
"        if (preprocess_func):\n",
"            data = preprocess_func(data, hf_arch, hf_config, hf_tokenizer, hf_model, max_seq_len,\n",
"                                   context_attr, question_attr, answer_text_attr,\n",
"                                   tok_ans_start_attr, tok_ans_end_attr)\n",
"\n",
"        # bits required by our \"before_batch_tfm\" and DataBlock; the start/end targets are token\n",
"        # positions, so the vocab is every position in [0, max_seq_len)\n",
"        vocab = list(range(max_seq_len))\n",
"        padding_side = hf_tokenizer.padding_side\n",
"        # truncate only the context: `_get_x` makes it the 2nd sequence when padding right, else the 1st\n",
"        trunc_strat = 'only_second' if (padding_side == 'right') else 'only_first'\n",
"\n",
"        before_batch_tfm = HF_QABeforeBatchTransform(hf_arch, hf_config, hf_tokenizer, hf_model,\n",
"                                                     max_length=max_seq_len,\n",
"                                                     truncation=trunc_strat,\n",
"                                                     tok_kwargs={ 'return_special_tokens_mask': True })\n",
"\n",
"        # define getters: `_get_x` indexes by key, so it serves DataFrame rows and dicts alike;\n",
"        # only the target readers depend on the container type\n",
"        get_x = partial(cls._get_x, qst=question_attr, ctx=context_attr, padding_side=padding_side)\n",
"        if (isinstance(data, pd.DataFrame)):\n",
"            get_y = [ColReader(tok_ans_start_attr), ColReader(tok_ans_end_attr)]\n",
"        else:\n",
"            get_y = [ItemGetter(tok_ans_start_attr), ItemGetter(tok_ans_end_attr)]\n",
"\n",
"        # define DataBlock and DataLoaders\n",
"        blocks = (\n",
"            HF_TextBlock(before_batch_tfm=before_batch_tfm, input_return_type=HF_QuestionAnswerInput),\n",
"            CategoryBlock(vocab=vocab),\n",
"            CategoryBlock(vocab=vocab)\n",
"        )\n",
"\n",
"        dblock = DataBlock(blocks=blocks,\n",
"                           get_x=get_x,\n",
"                           get_y=get_y,\n",
"                           splitter=dblock_splitter,\n",
"                           n_inp=1)\n",
"\n",
"        dls = dblock.dataloaders(data, **dl_kwargs.copy())\n",
"\n",
"        # return BLearner instance\n",
"        return cls(dls, hf_model, **learner_kwargs.copy())\n",
"\n",
"    @classmethod\n",
"    def from_dataframe(\n",
"        cls,\n",
"        # Your pandas DataFrame\n",
"        df:pd.DataFrame,\n",
"        # The name or path of the pretrained model you want to fine-tune\n",
"        pretrained_model_name_or_path:Optional[Union[str, os.PathLike]],\n",
"        # A function to perform any preprocessing required for your Dataset\n",
"        preprocess_func:Callable=None,\n",
"        # The maximum sequence length to constrain our data\n",
"        max_seq_len:int=None,\n",
"        # The attribute in your dataset that contains the context (where the answer is included) (default: 'context')\n",
"        context_attr:str='context',\n",
"        # The attribute in your dataset that contains the question being asked (default: 'question')\n",
"        question_attr:str='question',\n",
"        # The attribute in your dataset that contains the actual answer (default: 'answer_text')\n",
"        answer_text_attr:str='answer_text',\n",
"        # The attribute in your dataset that contains the tokenized answer start (default: 'tok_answer_start')\n",
"        tok_ans_start_attr:str='tok_answer_start',\n",
"        # The attribute in your dataset that contains the tokenized answer end (default: 'tok_answer_end')\n",
"        tok_ans_end_attr:str='tok_answer_end',\n",
"        # A function that will split your Dataset into a training and validation set\n",
"        # See [here](https://docs.fast.ai/data.transforms.html#Split) for a list of fast.ai splitters\n",
"        dblock_splitter:Callable=ColSplitter(),\n",
"        # Any kwargs to pass to your `DataLoaders`\n",
"        dl_kwargs={},\n",
"        # Any kwargs to pass to your task specific `Blearner`\n",
"        learner_kwargs={}\n",
"    ):\n",
"        \"Create a learner from a pandas DataFrame (split with `ColSplitter()` by default)\"\n",
"        return cls._create_learner(df, pretrained_model_name_or_path, preprocess_func, max_seq_len,\n",
"                                   context_attr, question_attr, answer_text_attr,\n",
"                                   tok_ans_start_attr, tok_ans_end_attr, dblock_splitter,\n",
"                                   dl_kwargs, learner_kwargs)\n",
"\n",
"    @classmethod\n",
"    def from_csv(\n",
"        cls,\n",
"        # The path to your csv file\n",
"        csv_file:Union[Path, str],\n",
"        # The name or path of the pretrained model you want to fine-tune\n",
"        pretrained_model_name_or_path:Optional[Union[str, os.PathLike]],\n",
"        # A function to perform any preprocessing required for your Dataset\n",
"        preprocess_func:Callable=None,\n",
"        # The maximum sequence length to constrain our data\n",
"        max_seq_len:int=None,\n",
"        # The attribute in your dataset that contains the context (where the answer is included) (default: 'context')\n",
"        context_attr:str='context',\n",
"        # The attribute in your dataset that contains the question being asked (default: 'question')\n",
"        question_attr:str='question',\n",
"        # The attribute in your dataset that contains the actual answer (default: 'answer_text')\n",
"        answer_text_attr:str='answer_text',\n",
"        # The attribute in your dataset that contains the tokenized answer start (default: 'tok_answer_start')\n",
"        tok_ans_start_attr:str='tok_answer_start',\n",
"        # The attribute in your dataset that contains the tokenized answer end (default: 'tok_answer_end')\n",
"        tok_ans_end_attr:str='tok_answer_end',\n",
"        # A function that will split your Dataset into a training and validation set\n",
"        # See [here](https://docs.fast.ai/data.transforms.html#Split) for a list of fast.ai splitters\n",
"        dblock_splitter:Callable=ColSplitter(),\n",
"        # Any kwargs to pass to your `DataLoaders`\n",
"        dl_kwargs={},\n",
"        # Any kwargs to pass to your task specific `Blearner`\n",
"        learner_kwargs={}\n",
"    ):\n",
"        \"Create a learner from a csv file, read with `pd.read_csv` and passed on to `from_dataframe`\"\n",
"        df = pd.read_csv(csv_file)\n",
"\n",
"        return cls.from_dataframe(df, pretrained_model_name_or_path, preprocess_func, max_seq_len,\n",
"                                  context_attr, question_attr, answer_text_attr,\n",
"                                  tok_ans_start_attr, tok_ans_end_attr, dblock_splitter,\n",
"                                  dl_kwargs, learner_kwargs)\n",
"\n",
"    @classmethod\n",
"    def from_dictionaries(\n",
"        cls,\n",
"        # A list of dictionaries\n",
"        ds:List[Dict],\n",
"        # The name or path of the pretrained model you want to fine-tune\n",
"        pretrained_model_name_or_path:Optional[Union[str, os.PathLike]],\n",
"        # A function to perform any preprocessing required for your Dataset\n",
"        preprocess_func:Callable=None,\n",
"        # The maximum sequence length to constrain our data\n",
"        max_seq_len:int=None,\n",
"        # The attribute in your dataset that contains the context (where the answer is included) (default: 'context')\n",
"        context_attr:str='context',\n",
"        # The attribute in your dataset that contains the question being asked (default: 'question')\n",
"        question_attr:str='question',\n",
"        # The attribute in your dataset that contains the actual answer (default: 'answer_text')\n",
"        answer_text_attr:str='answer_text',\n",
"        # The attribute in your dataset that contains the tokenized answer start (default: 'tok_answer_start')\n",
"        tok_ans_start_attr:str='tok_answer_start',\n",
"        # The attribute in your dataset that contains the tokenized answer end (default: 'tok_answer_end')\n",
"        tok_ans_end_attr:str='tok_answer_end',\n",
"        # A function that will split your Dataset into a training and validation set\n",
"        # See [here](https://docs.fast.ai/data.transforms.html#Split) for a list of fast.ai splitters\n",
"        dblock_splitter:Callable=RandomSplitter(),\n",
"        # Any kwargs to pass to your `DataLoaders`\n",
"        dl_kwargs={},\n",
"        # Any kwargs to pass to your task specific `Blearner`\n",
"        learner_kwargs={}\n",
"    ):\n",
"        \"Create a learner from a list of dictionaries (split with `RandomSplitter()` by default)\"\n",
"        return cls._create_learner(ds, pretrained_model_name_or_path, preprocess_func, max_seq_len,\n",
"                                   context_attr, question_attr, answer_text_attr,\n",
"                                   tok_ans_start_attr, tok_ans_end_attr, dblock_splitter,\n",
"                                   dl_kwargs, learner_kwargs)"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "ZbJl8PQ9eoaO"
},
"source": [
"`BlearnerForQuestionAnswering` requires a question, a context (within which to find the answer to the question), and the start/end indices of where the answer lies in the *tokenized* question + context input. Because those indices vary by tokenizer, we can pass a `preprocess_func` that takes our raw data, performs any preprocessing we want, and returns it in a format that works for extractive QA."
]
},
{
"cell_type": "code",
"metadata": {
"id": "MnJSKgbfeoaO"
},
"source": [
"def preprocess_df(df, hf_arch, hf_config, hf_tokenizer, hf_model, max_seq_len,\n",
"                  context_attr, question_attr, answer_text_attr, tok_ans_start_attr, tok_ans_end_attr):\n",
"    \"Apply `pre_process_squad` row-wise, then keep rows shorter than `max_seq_len` whose `is_impossible` is False\"\n",
"    squad_row_tfm = partial(pre_process_squad, hf_arch=hf_arch, hf_tokenizer=hf_tokenizer,\n",
"                            ctx_attr=context_attr, qst_attr=question_attr, ans_attr=answer_text_attr)\n",
"    processed = df.apply(squad_row_tfm, axis=1)\n",
"\n",
"    fits = processed.tokenized_input_len < max_seq_len\n",
"    # `.eq(False)` rather than `~`: the latter is not a logical NOT if the column holds Python objects\n",
"    answerable = processed.is_impossible.eq(False)\n",
"    return processed[fits & answerable]"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "_1OF6B_geoaP"
},
"source": [
"Let's re-grab the raw data and use the high-level API to train."
]
},
{
"cell_type": "code",
"metadata": {
"id": "SxeaQ9eQeoaP"
},
"source": [
"squad_df = pd.read_csv('./squad_sample.csv')\n",
"\n",
"pretrained_model_name = 'bert-large-uncased-whole-word-masking-finetuned-squad'\n",
"\n",
"# seed the splitter so the train/valid partition is the same on every Restart & Run All\n",
"learn = BlearnerForQuestionAnswering.from_dataframe(squad_df, pretrained_model_name,\n",
"                                                    preprocess_func=preprocess_df, max_seq_len=128,\n",
"                                                    dblock_splitter=RandomSplitter(seed=42),\n",
"                                                    dl_kwargs={ 'bs': 4 }).to_fp16()"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "-DEYU82GeoaP",
"outputId": "a11ce1a9-18bc-496d-fd39-5654a8b35e0d"
},
"source": [
"learn.dls.show_batch(dataloaders=learn.dls, trunc_at=500, max_n=2)"
],
"execution_count": null,
"outputs": [
{
"data": {
"text/html": [
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>text</th>\n",
" <th>start/end</th>\n",
" <th>answer</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>where did beyonce perform in 2011? in 2011, documents obtained by wikileaks revealed that beyonce was one of many entertainers who performed for the family of libyan ruler muammar gaddafi. rolling stone reported that the music industry was urging them to return the money they earned for the concerts ; a spokesperson for beyonce later confirmed to the huffington post that she donated the money to the clinton bush haiti fund. later that year she became the first solo female artist to headline the</td>\n",
" <td>(102, 107)</td>\n",
" <td>glastonbury festival</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>what language does she mainly sing? beyonce's music is generally r &amp; b, but she also incorporates pop, soul and funk into her songs. 4 demonstrated beyonce's exploration of 90s - style r &amp; b, as well as further use of soul and hip hop than compared to previous releases. while she almost exclusively releases english songs, beyonce recorded several spanish songs for irreemplazable ( re - recordings of songs from b'day for a spanish - language audience ), and the re - release of b'day. to record th</td>\n",
" <td>(67, 68)</td>\n",
" <td>english</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "TT0jSci5eoaP",
"outputId": "06d92855-ef22-4be2-8d60-c77cc58e851a"
},
"source": [
"n_epochs, peak_lr = 3, 1e-3\n",
"learn.fit_one_cycle(n_epochs, lr_max=peak_lr)"
],
"execution_count": null,
"outputs": [
{
"data": {
"text/html": [
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: left;\">\n",
" <th>epoch</th>\n",
" <th>train_loss</th>\n",
" <th>valid_loss</th>\n",
" <th>time</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <td>0</td>\n",
" <td>4.252031</td>\n",
" <td>1.568532</td>\n",
" <td>00:05</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1</td>\n",
" <td>2.521215</td>\n",
" <td>0.815515</td>\n",
" <td>00:05</td>\n",
" </tr>\n",
" <tr>\n",
" <td>2</td>\n",
" <td>1.775412</td>\n",
" <td>0.717965</td>\n",
" <td>00:05</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "JBST4s-deoaP",
"outputId": "891e45a2-9495-42af-f3ee-f5214eed436f"
},
"source": [
"learn.show_results(learner=learn, max_n=2, trunc_at=500, skip_special_tokens=True)"
],
"execution_count": null,
"outputs": [
{
"data": {
"text/html": [],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>text</th>\n",
" <th>start/end</th>\n",
" <th>answer</th>\n",
" <th>pred start/end</th>\n",
" <th>pred answer</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>how much bail money did they spend? following the death of freddie gray, beyonce and jay - z, among other notable figures, met with his family. after the imprisonment of protesters of gray's death, beyonce and jay - z donated thousands of dollars to bail them out.</td>\n",
" <td>(50, 53)</td>\n",
" <td>thousands of dollars</td>\n",
" <td>(50, 53)</td>\n",
" <td>thousands of dollars</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>how was the suit settled? the release of a video - game starpower : beyonce was cancelled after beyonce pulled out of a $ 100 million with gatefive who alleged the cancellation meant the sacking of 70 staff and millions of pounds lost in development. it was settled out of court by her lawyers in june 2013 who said that they had cancelled because gatefive had lost its financial backers. beyonce also has had deals with american express, nintendo ds and l'oreal since the age of 18.</td>\n",
" <td>(56, 59)</td>\n",
" <td>out of court</td>\n",
" <td>(56, 59)</td>\n",
" <td>out of court</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "Z3OHHY3KeoaP"
},
"source": [
"# swap in a stock fastai loss so the learner can be exported; the next cell loads it back\n",
"learn.loss_func = CrossEntropyLossFlat()\n",
"learn.export(fname=f'{export_name}.pkl')"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "1uY1SWQ2eoaP",
"outputId": "3c8d0c46-e358-4876-cda5-768e72c17ac2"
},
"source": [
"inf_learn = load_learner(fname=f'{export_name}.pkl')\n",
"# restore the multi-target loss that was swapped out before exporting\n",
"inf_learn.loss_func = MultiTargetLoss()\n",
"\n",
"star_wars_ctx = 'George Lucas created Star Wars in 1977. He directed and produced it.'\n",
"questions = ['What did George Lucas make?',\n",
"             'What year did Star Wars come out?',\n",
"             'What did George Lucas do?']\n",
"inf_df = pd.DataFrame([{'question': qst, 'context': star_wars_ctx} for qst in questions])\n",
"\n",
"inf_learn.blurr_predict(inf_df)"
],
"execution_count": null,
"outputs": [
{
"data": {
"text/plain": [
"[(('11', '13'),\n",
" (#2) [tensor(11),tensor(13)],\n",
" (#2) [tensor([2.7284e-07, 8.3190e-08, 3.7982e-09, 7.9587e-09, 6.0532e-09, 5.2288e-09,\n",
" 6.7031e-10, 2.7283e-07, 1.3389e-04, 1.2072e-05, 4.0296e-04, 9.9912e-01,\n",
" 2.9206e-04, 4.0871e-07, 1.7246e-05, 5.3299e-07, 7.4445e-06, 9.2675e-06,\n",
" 2.7095e-08, 4.0778e-06, 1.4063e-06, 2.7503e-07, 2.7421e-07, 3.4118e-10,\n",
" 4.1309e-10]),tensor([2.2255e-03, 1.9398e-04, 1.4811e-05, 6.0629e-06, 2.3273e-05, 1.2139e-05,\n",
" 4.9242e-05, 2.2257e-03, 6.3393e-05, 2.5946e-04, 2.7473e-04, 4.4387e-05,\n",
" 8.3237e-02, 4.8414e-01, 1.9951e-02, 3.9145e-01, 3.2640e-04, 1.6923e-04,\n",
" 4.2092e-04, 4.4922e-04, 1.0046e-02, 2.2117e-03, 2.1988e-03, 2.9638e-06,\n",
" 1.8428e-06])]),\n",
" (('16', '17'),\n",
" (#2) [tensor(16),tensor(17)],\n",
" (#2) [tensor([5.6653e-07, 2.0562e-06, 2.7093e-08, 1.3462e-08, 1.3675e-08, 9.8854e-09,\n",
" 1.6482e-08, 1.7449e-08, 8.9821e-09, 5.6657e-07, 6.3263e-07, 7.9726e-07,\n",
" 9.9635e-07, 2.6835e-06, 7.8746e-07, 2.4374e-05, 9.9996e-01, 1.0222e-06,\n",
" 1.4942e-07, 1.1580e-07, 2.3458e-08, 1.3362e-07, 3.7979e-07, 5.6670e-07,\n",
" 5.6669e-07]),tensor([2.9526e-03, 7.3836e-04, 6.0738e-04, 2.0252e-04, 7.6138e-05, 1.5352e-04,\n",
" 8.7644e-05, 1.7705e-04, 5.1405e-04, 2.9525e-03, 7.0142e-04, 1.1698e-03,\n",
" 7.6684e-04, 3.1254e-04, 1.2962e-03, 1.3482e-03, 1.0334e-02, 9.5583e-01,\n",
" 9.3413e-03, 8.3610e-04, 8.9849e-04, 7.6981e-04, 2.0273e-03, 2.9528e-03,\n",
" 2.9529e-03])]),\n",
" (('17', '21'),\n",
" (#2) [tensor(17),tensor(21)],\n",
" (#2) [tensor([7.6245e-06, 2.5309e-07, 4.6802e-08, 1.1821e-07, 5.4243e-08, 5.6511e-08,\n",
" 1.4816e-08, 7.6252e-06, 3.1773e-03, 3.0238e-04, 1.2024e-01, 4.0060e-04,\n",
" 3.5399e-05, 2.5930e-06, 7.0493e-05, 7.6382e-06, 9.2087e-02, 7.4348e-01,\n",
" 1.6772e-05, 4.0094e-02, 5.2537e-05, 1.6554e-05, 7.6619e-06, 6.5125e-09,\n",
" 9.3592e-09]),tensor([3.2386e-03, 3.2943e-05, 2.0672e-05, 7.7881e-06, 3.0920e-05, 1.7127e-05,\n",
" 5.0501e-05, 3.2389e-03, 6.1905e-05, 3.5101e-04, 6.3880e-04, 3.3710e-04,\n",
" 1.3215e-02, 1.3914e-02, 6.1565e-03, 3.2390e-03, 7.0665e-05, 1.4825e-04,\n",
" 1.5560e-03, 6.0942e-03, 3.0879e-01, 6.3557e-01, 3.2132e-03, 5.1384e-06,\n",
" 2.4964e-06])])]"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "_UoJcLXZeoaQ"
},
"source": [
"## Summary"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "O9KJkRTFeoaQ"
},
"source": [
"This module includes all the low-, mid-, and high-level API bits for training and inference on extractive Q&A tasks."
]
},
{
"cell_type": "code",
"metadata": {
"id": "0MI30ReceoaQ",
"outputId": "95d4aeda-bea4-4b8f-ad1d-b5e4b0efef64"
},
"source": [
"#hide\n",
"# regenerate the library modules from the `#export` cells of every notebook (see output)\n",
"import nbdev.export\n",
"nbdev.export.notebook2script()"
],
"execution_count": null,
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Converted 00_utils.ipynb.\n",
"Converted 01_data-core.ipynb.\n",
"Converted 01_modeling-core.ipynb.\n",
"Converted 02_data-language-modeling.ipynb.\n",
"Converted 02_modeling-language-modeling.ipynb.\n",
"Converted 03_data-token-classification.ipynb.\n",
"Converted 03_modeling-token-classification.ipynb.\n",
"Converted 04_data-question-answering.ipynb.\n",
"Converted 04_modeling-question-answering.ipynb.\n",
"Converted 10_data-seq2seq-core.ipynb.\n",
"Converted 10_modeling-seq2seq-core.ipynb.\n",
"Converted 11_data-seq2seq-summarization.ipynb.\n",
"Converted 11_modeling-seq2seq-summarization.ipynb.\n",
"Converted 12_data-seq2seq-translation.ipynb.\n",
"Converted 12_modeling-seq2seq-translation.ipynb.\n",
"Converted 99a_examples-high-level-api.ipynb.\n",
"Converted 99b_examples-glue.ipynb.\n",
"Converted 99c_examples-glue-plain-pytorch.ipynb.\n",
"Converted 99d_examples-multilabel.ipynb.\n",
"Converted 99e_examples-causal-lm-gpt2.ipynb.\n",
"Converted index.ipynb.\n"
]
}
]
}
]
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment