Skip to content

Instantly share code, notes, and snippets.

@un1tz3r0
Last active June 21, 2024 07:17
Show Gist options
  • Save un1tz3r0/6a0a0808cebed689fff6742c3df17469 to your computer and use it in GitHub Desktop.
Save un1tz3r0/6a0a0808cebed689fff6742c3df17469 to your computer and use it in GitHub Desktop.
refusal_demo.ipynb
Display the source blob
Display the rendered blob
Raw
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"provenance": [],
"gpuType": "T4",
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
},
"accelerator": "GPU",
"widgets": {
"application/vnd.jupyter.widget-state+json": {
"41720b13bae74e54a9fa2c570dde6f9c": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HBoxModel",
"model_module_version": "1.5.0",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HBoxModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HBoxView",
"box_style": "",
"children": [
"IPY_MODEL_fac24f9d8d264ec2961bd61a023ac333",
"IPY_MODEL_86f0014193c24d32955c0c0e21b90bf5",
"IPY_MODEL_ad8735b93fef48cb833a884bc820c9b5"
],
"layout": "IPY_MODEL_c0b8b455ffa04c5ea06fad6fb9b5b98c"
}
},
"fac24f9d8d264ec2961bd61a023ac333": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HTMLModel",
"model_module_version": "1.5.0",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HTMLModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HTMLView",
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_6c801f43bb784821b7395f4f290e92ba",
"placeholder": "​",
"style": "IPY_MODEL_6b5f2f051c724d09bf2826b0bebd2ba0",
"value": "config.json: 100%"
}
},
"86f0014193c24d32955c0c0e21b90bf5": {
"model_module": "@jupyter-widgets/controls",
"model_name": "FloatProgressModel",
"model_module_version": "1.5.0",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "FloatProgressModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "ProgressView",
"bar_style": "success",
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_cb3903028c014332aa97961d4768013e",
"max": 654,
"min": 0,
"orientation": "horizontal",
"style": "IPY_MODEL_9c569b2737964b5ba453f21497e032d2",
"value": 654
}
},
"ad8735b93fef48cb833a884bc820c9b5": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HTMLModel",
"model_module_version": "1.5.0",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HTMLModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HTMLView",
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_04e216219058474fbab6de50812e0435",
"placeholder": "​",
"style": "IPY_MODEL_358a422051194b1fa581b0c58bf05d73",
"value": " 654/654 [00:00<00:00, 45.4kB/s]"
}
},
"c0b8b455ffa04c5ea06fad6fb9b5b98c": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"model_module_version": "1.2.0",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"6c801f43bb784821b7395f4f290e92ba": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"model_module_version": "1.2.0",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"6b5f2f051c724d09bf2826b0bebd2ba0": {
"model_module": "@jupyter-widgets/controls",
"model_name": "DescriptionStyleModel",
"model_module_version": "1.5.0",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "DescriptionStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"description_width": ""
}
},
"cb3903028c014332aa97961d4768013e": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"model_module_version": "1.2.0",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"9c569b2737964b5ba453f21497e032d2": {
"model_module": "@jupyter-widgets/controls",
"model_name": "ProgressStyleModel",
"model_module_version": "1.5.0",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "ProgressStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"bar_color": null,
"description_width": ""
}
},
"04e216219058474fbab6de50812e0435": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"model_module_version": "1.2.0",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"358a422051194b1fa581b0c58bf05d73": {
"model_module": "@jupyter-widgets/controls",
"model_name": "DescriptionStyleModel",
"model_module_version": "1.5.0",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "DescriptionStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"description_width": ""
}
},
"afcb5207e4dd4f4e8de7a93967866745": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HBoxModel",
"model_module_version": "1.5.0",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HBoxModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HBoxView",
"box_style": "",
"children": [
"IPY_MODEL_a41abec678404238856a68cb4ea45632",
"IPY_MODEL_cb199eb476254ff788eb53ee72a2e927",
"IPY_MODEL_f6bb201707d04190b1045654791e56f7"
],
"layout": "IPY_MODEL_17077d8f905a4a47a810ef462c2a15c1"
}
},
"a41abec678404238856a68cb4ea45632": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HTMLModel",
"model_module_version": "1.5.0",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HTMLModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HTMLView",
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_aeca3c94581d45f0ae61b8e3e5b57c2e",
"placeholder": "​",
"style": "IPY_MODEL_5ba652d347174a849853bbda04cbba14",
"value": "model.safetensors.index.json: 100%"
}
},
"cb199eb476254ff788eb53ee72a2e927": {
"model_module": "@jupyter-widgets/controls",
"model_name": "FloatProgressModel",
"model_module_version": "1.5.0",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "FloatProgressModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "ProgressView",
"bar_style": "success",
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_45d2d5a22e0b4c4a865b522aa23875e2",
"max": 23950,
"min": 0,
"orientation": "horizontal",
"style": "IPY_MODEL_94535cd1198341a48cb22e53a81d7877",
"value": 23950
}
},
"f6bb201707d04190b1045654791e56f7": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HTMLModel",
"model_module_version": "1.5.0",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HTMLModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HTMLView",
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_ebafc69cdc0f4a669401242b3e579eb1",
"placeholder": "​",
"style": "IPY_MODEL_303f73ea47d84871aea4dd645147dbe7",
"value": " 23.9k/23.9k [00:00<00:00, 1.64MB/s]"
}
},
"17077d8f905a4a47a810ef462c2a15c1": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"model_module_version": "1.2.0",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"aeca3c94581d45f0ae61b8e3e5b57c2e": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"model_module_version": "1.2.0",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"5ba652d347174a849853bbda04cbba14": {
"model_module": "@jupyter-widgets/controls",
"model_name": "DescriptionStyleModel",
"model_module_version": "1.5.0",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "DescriptionStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"description_width": ""
}
},
"45d2d5a22e0b4c4a865b522aa23875e2": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"model_module_version": "1.2.0",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"94535cd1198341a48cb22e53a81d7877": {
"model_module": "@jupyter-widgets/controls",
"model_name": "ProgressStyleModel",
"model_module_version": "1.5.0",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "ProgressStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"bar_color": null,
"description_width": ""
}
},
"ebafc69cdc0f4a669401242b3e579eb1": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"model_module_version": "1.2.0",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"303f73ea47d84871aea4dd645147dbe7": {
"model_module": "@jupyter-widgets/controls",
"model_name": "DescriptionStyleModel",
"model_module_version": "1.5.0",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "DescriptionStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"description_width": ""
}
},
"7b30f98e36fd4b2fa71231791fff4514": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HBoxModel",
"model_module_version": "1.5.0",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HBoxModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HBoxView",
"box_style": "",
"children": [
"IPY_MODEL_051f7608c8b24882b04ad0e287e33f6e",
"IPY_MODEL_94df276378e74ae0a17f3919277a2102",
"IPY_MODEL_110d7422b9f4466f83b9b0de741ca85d"
],
"layout": "IPY_MODEL_7b3e5d3cc2b5495ba8295cfc12085473"
}
},
"051f7608c8b24882b04ad0e287e33f6e": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HTMLModel",
"model_module_version": "1.5.0",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HTMLModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HTMLView",
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_6315bee9cbbe4966becb22b83faddfea",
"placeholder": "​",
"style": "IPY_MODEL_2efde713df514325ab2fc4861d7e74f1",
"value": "Downloading shards:   0%"
}
},
"94df276378e74ae0a17f3919277a2102": {
"model_module": "@jupyter-widgets/controls",
"model_name": "FloatProgressModel",
"model_module_version": "1.5.0",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "FloatProgressModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "ProgressView",
"bar_style": "",
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_fd99a60b4fb141409e20a73c10dd604e",
"max": 4,
"min": 0,
"orientation": "horizontal",
"style": "IPY_MODEL_01e09df184d7431a9211f763fbdcb116",
"value": 0
}
},
"110d7422b9f4466f83b9b0de741ca85d": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HTMLModel",
"model_module_version": "1.5.0",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HTMLModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HTMLView",
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_af96395f4ca3498dba84e1cd3510c345",
"placeholder": "​",
"style": "IPY_MODEL_0163186da54a4b909bef76b2c8776e0c",
"value": " 0/4 [00:00<?, ?it/s]"
}
},
"7b3e5d3cc2b5495ba8295cfc12085473": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"model_module_version": "1.2.0",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"6315bee9cbbe4966becb22b83faddfea": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"model_module_version": "1.2.0",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"2efde713df514325ab2fc4861d7e74f1": {
"model_module": "@jupyter-widgets/controls",
"model_name": "DescriptionStyleModel",
"model_module_version": "1.5.0",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "DescriptionStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"description_width": ""
}
},
"fd99a60b4fb141409e20a73c10dd604e": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"model_module_version": "1.2.0",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"01e09df184d7431a9211f763fbdcb116": {
"model_module": "@jupyter-widgets/controls",
"model_name": "ProgressStyleModel",
"model_module_version": "1.5.0",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "ProgressStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"bar_color": null,
"description_width": ""
}
},
"af96395f4ca3498dba84e1cd3510c345": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"model_module_version": "1.2.0",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"0163186da54a4b909bef76b2c8776e0c": {
"model_module": "@jupyter-widgets/controls",
"model_name": "DescriptionStyleModel",
"model_module_version": "1.5.0",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "DescriptionStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"description_width": ""
}
},
"7e46fdf68e9f4b62ad152fddb4e47ee1": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HBoxModel",
"model_module_version": "1.5.0",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HBoxModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HBoxView",
"box_style": "",
"children": [
"IPY_MODEL_f35e9c05295b47d4b84cfadd4a199a43",
"IPY_MODEL_a1d7c504aba74e95a2cb87b0890ef007",
"IPY_MODEL_35d1329b8294434798603fc32af10dc7"
],
"layout": "IPY_MODEL_775028d9af414b01a2474461956d784b"
}
},
"f35e9c05295b47d4b84cfadd4a199a43": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HTMLModel",
"model_module_version": "1.5.0",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HTMLModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HTMLView",
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_2f3ad398f20c42089719ea517c51b98d",
"placeholder": "​",
"style": "IPY_MODEL_5ecac4746dd648c4b854922d36fc46c0",
"value": "model-00001-of-00004.safetensors:  60%"
}
},
"a1d7c504aba74e95a2cb87b0890ef007": {
"model_module": "@jupyter-widgets/controls",
"model_name": "FloatProgressModel",
"model_module_version": "1.5.0",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "FloatProgressModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "ProgressView",
"bar_style": "",
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_86355ba0065f4a1087ab97c76e2848ae",
"max": 4976698672,
"min": 0,
"orientation": "horizontal",
"style": "IPY_MODEL_a47411eddbe8467fb42dcaf03c13ba8c",
"value": 3019898880
}
},
"35d1329b8294434798603fc32af10dc7": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HTMLModel",
"model_module_version": "1.5.0",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HTMLModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HTMLView",
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_8cc19ef932124884b6b630e090026fa2",
"placeholder": "​",
"style": "IPY_MODEL_c5dfe0778dc04bb2b864c42cea4f1c11",
"value": " 2.99G/4.98G [00:22<00:09, 204MB/s]"
}
},
"775028d9af414b01a2474461956d784b": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"model_module_version": "1.2.0",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"2f3ad398f20c42089719ea517c51b98d": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"model_module_version": "1.2.0",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"5ecac4746dd648c4b854922d36fc46c0": {
"model_module": "@jupyter-widgets/controls",
"model_name": "DescriptionStyleModel",
"model_module_version": "1.5.0",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "DescriptionStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"description_width": ""
}
},
"86355ba0065f4a1087ab97c76e2848ae": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"model_module_version": "1.2.0",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"a47411eddbe8467fb42dcaf03c13ba8c": {
"model_module": "@jupyter-widgets/controls",
"model_name": "ProgressStyleModel",
"model_module_version": "1.5.0",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "ProgressStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"bar_color": null,
"description_width": ""
}
},
"8cc19ef932124884b6b630e090026fa2": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"model_module_version": "1.2.0",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"c5dfe0778dc04bb2b864c42cea4f1c11": {
"model_module": "@jupyter-widgets/controls",
"model_name": "DescriptionStyleModel",
"model_module_version": "1.5.0",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "DescriptionStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"description_width": ""
}
}
}
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/un1tz3r0/6a0a0808cebed689fff6742c3df17469/refusal_demo.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"source": [
"# Demo of bypassing refusal"
],
"metadata": {
"id": "82acAhWYGIPx"
}
},
{
"cell_type": "markdown",
"source": [
">[Demo of bypassing refusal](#scrollTo=82acAhWYGIPx)\n",
"\n",
">>[Setup](#scrollTo=fcxHyDZw6b86)\n",
"\n",
">>>[Load model](#scrollTo=6ZOoJagxD49V)\n",
"\n",
">>>[Load harmful / harmless datasets](#scrollTo=rF7e-u20EFTe)\n",
"\n",
">>>[Tokenization utils](#scrollTo=KOKYA61k8LWt)\n",
"\n",
">>>[Generation utils](#scrollTo=gtrIK8x78SZh)\n",
"\n",
">>[Finding the \"refusal direction\"](#scrollTo=W9O8dm0_EQRk)\n",
"\n",
">>[Ablate \"refusal direction\" via inference-time intervention](#scrollTo=2EoxY5i1CWe3)\n",
"\n",
">>[Orthogonalize weights w.r.t. \"refusal direction\"](#scrollTo=t9KooaWaCDc_)\n",
"\n"
],
"metadata": {
"colab_type": "toc",
"id": "s-_vu8HuGEb-"
}
},
{
"cell_type": "markdown",
"source": [
"This notebook demonstrates our method for bypassing refusal, leveraging the insight that refusal is mediated by a 1-dimensional subspace. We recommend reading the [research post](https://www.lesswrong.com/posts/jGuXSZgv6qfdhMCuJ/refusal-in-llms-is-mediated-by-a-single-direction) for a more thorough explanation.\n",
"\n",
"In this demo, we use [Qwen-1_8B-Chat](https://huggingface.co/Qwen/Qwen-1_8B-Chat) and implement interventions and weight updates using [TransformerLens](https://github.com/neelnanda-io/TransformerLens). To extract the \"refusal direction,\" we use just 32 harmful instructions from [AdvBench](https://github.com/llm-attacks/llm-attacks/blob/main/data/advbench/harmful_behaviors.csv) and 32 harmless instructions from [Alpaca](https://huggingface.co/datasets/tatsu-lab/alpaca)."
],
"metadata": {
"id": "j7hOtw7UHXdD"
}
},
{
"cell_type": "markdown",
"source": [
"## Setup"
],
"metadata": {
"id": "fcxHyDZw6b86"
}
},
{
"cell_type": "code",
"source": [
"%pip install transformers transformers_stream_generator tiktoken transformer_lens einops jaxtyping colorama"
],
"metadata": {
"id": "dLeei4-T6Wef",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "58ac03f5-94b8-460c-c1e0-2b0bfcafc6c4"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Requirement already satisfied: transformers in /usr/local/lib/python3.10/dist-packages (4.41.0)\n",
"Requirement already satisfied: transformers_stream_generator in /usr/local/lib/python3.10/dist-packages (0.0.5)\n",
"Requirement already satisfied: tiktoken in /usr/local/lib/python3.10/dist-packages (0.7.0)\n",
"Requirement already satisfied: transformer_lens in /usr/local/lib/python3.10/dist-packages (1.17.0)\n",
"Requirement already satisfied: einops in /usr/local/lib/python3.10/dist-packages (0.8.0)\n",
"Requirement already satisfied: jaxtyping in /usr/local/lib/python3.10/dist-packages (0.2.28)\n",
"Requirement already satisfied: colorama in /usr/local/lib/python3.10/dist-packages (0.4.6)\n",
"Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from transformers) (3.14.0)\n",
"Requirement already satisfied: huggingface-hub<1.0,>=0.23.0 in /usr/local/lib/python3.10/dist-packages (from transformers) (0.23.1)\n",
"Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.10/dist-packages (from transformers) (1.25.2)\n",
"Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.10/dist-packages (from transformers) (24.0)\n",
"Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.10/dist-packages (from transformers) (6.0.1)\n",
"Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.10/dist-packages (from transformers) (2023.12.25)\n",
"Requirement already satisfied: requests in /usr/local/lib/python3.10/dist-packages (from transformers) (2.31.0)\n",
"Requirement already satisfied: tokenizers<0.20,>=0.19 in /usr/local/lib/python3.10/dist-packages (from transformers) (0.19.1)\n",
"Requirement already satisfied: safetensors>=0.4.1 in /usr/local/lib/python3.10/dist-packages (from transformers) (0.4.3)\n",
"Requirement already satisfied: tqdm>=4.27 in /usr/local/lib/python3.10/dist-packages (from transformers) (4.66.4)\n",
"Requirement already satisfied: accelerate>=0.23.0 in /usr/local/lib/python3.10/dist-packages (from transformer_lens) (0.30.1)\n",
"Requirement already satisfied: beartype<0.15.0,>=0.14.1 in /usr/local/lib/python3.10/dist-packages (from transformer_lens) (0.14.1)\n",
"Requirement already satisfied: better-abc<0.0.4,>=0.0.3 in /usr/local/lib/python3.10/dist-packages (from transformer_lens) (0.0.3)\n",
"Requirement already satisfied: datasets>=2.7.1 in /usr/local/lib/python3.10/dist-packages (from transformer_lens) (2.19.1)\n",
"Requirement already satisfied: fancy-einsum>=0.0.3 in /usr/local/lib/python3.10/dist-packages (from transformer_lens) (0.0.3)\n",
"Requirement already satisfied: pandas>=1.1.5 in /usr/local/lib/python3.10/dist-packages (from transformer_lens) (2.0.3)\n",
"Requirement already satisfied: rich>=12.6.0 in /usr/local/lib/python3.10/dist-packages (from transformer_lens) (13.7.1)\n",
"Requirement already satisfied: sentencepiece in /usr/local/lib/python3.10/dist-packages (from transformer_lens) (0.1.99)\n",
"Requirement already satisfied: torch>=1.10 in /usr/local/lib/python3.10/dist-packages (from transformer_lens) (2.3.0+cu121)\n",
"Requirement already satisfied: typing-extensions in /usr/local/lib/python3.10/dist-packages (from transformer_lens) (4.11.0)\n",
"Requirement already satisfied: wandb>=0.13.5 in /usr/local/lib/python3.10/dist-packages (from transformer_lens) (0.17.0)\n",
"Requirement already satisfied: typeguard==2.13.3 in /usr/local/lib/python3.10/dist-packages (from jaxtyping) (2.13.3)\n",
"Requirement already satisfied: psutil in /usr/local/lib/python3.10/dist-packages (from accelerate>=0.23.0->transformer_lens) (5.9.5)\n",
"Requirement already satisfied: pyarrow>=12.0.0 in /usr/local/lib/python3.10/dist-packages (from datasets>=2.7.1->transformer_lens) (14.0.2)\n",
"Requirement already satisfied: pyarrow-hotfix in /usr/local/lib/python3.10/dist-packages (from datasets>=2.7.1->transformer_lens) (0.6)\n",
"Requirement already satisfied: dill<0.3.9,>=0.3.0 in /usr/local/lib/python3.10/dist-packages (from datasets>=2.7.1->transformer_lens) (0.3.8)\n",
"Requirement already satisfied: xxhash in /usr/local/lib/python3.10/dist-packages (from datasets>=2.7.1->transformer_lens) (3.4.1)\n",
"Requirement already satisfied: multiprocess in /usr/local/lib/python3.10/dist-packages (from datasets>=2.7.1->transformer_lens) (0.70.16)\n",
"Requirement already satisfied: fsspec[http]<=2024.3.1,>=2023.1.0 in /usr/local/lib/python3.10/dist-packages (from datasets>=2.7.1->transformer_lens) (2023.6.0)\n",
"Requirement already satisfied: aiohttp in /usr/local/lib/python3.10/dist-packages (from datasets>=2.7.1->transformer_lens) (3.9.5)\n",
"Requirement already satisfied: python-dateutil>=2.8.2 in /usr/local/lib/python3.10/dist-packages (from pandas>=1.1.5->transformer_lens) (2.8.2)\n",
"Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.10/dist-packages (from pandas>=1.1.5->transformer_lens) (2023.4)\n",
"Requirement already satisfied: tzdata>=2022.1 in /usr/local/lib/python3.10/dist-packages (from pandas>=1.1.5->transformer_lens) (2024.1)\n",
"Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (3.3.2)\n",
"Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (3.7)\n",
"Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (2.0.7)\n",
"Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (2024.2.2)\n",
"Requirement already satisfied: markdown-it-py>=2.2.0 in /usr/local/lib/python3.10/dist-packages (from rich>=12.6.0->transformer_lens) (3.0.0)\n",
"Requirement already satisfied: pygments<3.0.0,>=2.13.0 in /usr/local/lib/python3.10/dist-packages (from rich>=12.6.0->transformer_lens) (2.16.1)\n",
"Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from torch>=1.10->transformer_lens) (1.12)\n",
"Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch>=1.10->transformer_lens) (3.3)\n",
"Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch>=1.10->transformer_lens) (3.1.4)\n",
"Requirement already satisfied: nvidia-cuda-nvrtc-cu12==12.1.105 in /usr/local/lib/python3.10/dist-packages (from torch>=1.10->transformer_lens) (12.1.105)\n",
"Requirement already satisfied: nvidia-cuda-runtime-cu12==12.1.105 in /usr/local/lib/python3.10/dist-packages (from torch>=1.10->transformer_lens) (12.1.105)\n",
"Requirement already satisfied: nvidia-cuda-cupti-cu12==12.1.105 in /usr/local/lib/python3.10/dist-packages (from torch>=1.10->transformer_lens) (12.1.105)\n",
"Requirement already satisfied: nvidia-cudnn-cu12==8.9.2.26 in /usr/local/lib/python3.10/dist-packages (from torch>=1.10->transformer_lens) (8.9.2.26)\n",
"Requirement already satisfied: nvidia-cublas-cu12==12.1.3.1 in /usr/local/lib/python3.10/dist-packages (from torch>=1.10->transformer_lens) (12.1.3.1)\n",
"Requirement already satisfied: nvidia-cufft-cu12==11.0.2.54 in /usr/local/lib/python3.10/dist-packages (from torch>=1.10->transformer_lens) (11.0.2.54)\n",
"Requirement already satisfied: nvidia-curand-cu12==10.3.2.106 in /usr/local/lib/python3.10/dist-packages (from torch>=1.10->transformer_lens) (10.3.2.106)\n",
"Requirement already satisfied: nvidia-cusolver-cu12==11.4.5.107 in /usr/local/lib/python3.10/dist-packages (from torch>=1.10->transformer_lens) (11.4.5.107)\n",
"Requirement already satisfied: nvidia-cusparse-cu12==12.1.0.106 in /usr/local/lib/python3.10/dist-packages (from torch>=1.10->transformer_lens) (12.1.0.106)\n",
"Requirement already satisfied: nvidia-nccl-cu12==2.20.5 in /usr/local/lib/python3.10/dist-packages (from torch>=1.10->transformer_lens) (2.20.5)\n",
"Requirement already satisfied: nvidia-nvtx-cu12==12.1.105 in /usr/local/lib/python3.10/dist-packages (from torch>=1.10->transformer_lens) (12.1.105)\n",
"Requirement already satisfied: triton==2.3.0 in /usr/local/lib/python3.10/dist-packages (from torch>=1.10->transformer_lens) (2.3.0)\n",
"Requirement already satisfied: nvidia-nvjitlink-cu12 in /usr/local/lib/python3.10/dist-packages (from nvidia-cusolver-cu12==11.4.5.107->torch>=1.10->transformer_lens) (12.5.40)\n",
"Requirement already satisfied: click!=8.0.0,>=7.1 in /usr/local/lib/python3.10/dist-packages (from wandb>=0.13.5->transformer_lens) (8.1.7)\n",
"Requirement already satisfied: docker-pycreds>=0.4.0 in /usr/local/lib/python3.10/dist-packages (from wandb>=0.13.5->transformer_lens) (0.4.0)\n",
"Requirement already satisfied: gitpython!=3.1.29,>=1.0.0 in /usr/local/lib/python3.10/dist-packages (from wandb>=0.13.5->transformer_lens) (3.1.43)\n",
"Requirement already satisfied: platformdirs in /usr/local/lib/python3.10/dist-packages (from wandb>=0.13.5->transformer_lens) (4.2.2)\n",
"Requirement already satisfied: protobuf!=4.21.0,<5,>=3.19.0 in /usr/local/lib/python3.10/dist-packages (from wandb>=0.13.5->transformer_lens) (3.20.3)\n",
"Requirement already satisfied: sentry-sdk>=1.0.0 in /usr/local/lib/python3.10/dist-packages (from wandb>=0.13.5->transformer_lens) (2.3.1)\n",
"Requirement already satisfied: setproctitle in /usr/local/lib/python3.10/dist-packages (from wandb>=0.13.5->transformer_lens) (1.3.3)\n",
"Requirement already satisfied: setuptools in /usr/local/lib/python3.10/dist-packages (from wandb>=0.13.5->transformer_lens) (67.7.2)\n",
"Requirement already satisfied: six>=1.4.0 in /usr/local/lib/python3.10/dist-packages (from docker-pycreds>=0.4.0->wandb>=0.13.5->transformer_lens) (1.16.0)\n",
"Requirement already satisfied: aiosignal>=1.1.2 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets>=2.7.1->transformer_lens) (1.3.1)\n",
"Requirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets>=2.7.1->transformer_lens) (23.2.0)\n",
"Requirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets>=2.7.1->transformer_lens) (1.4.1)\n",
"Requirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets>=2.7.1->transformer_lens) (6.0.5)\n",
"Requirement already satisfied: yarl<2.0,>=1.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets>=2.7.1->transformer_lens) (1.9.4)\n",
"Requirement already satisfied: async-timeout<5.0,>=4.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets>=2.7.1->transformer_lens) (4.0.3)\n",
"Requirement already satisfied: gitdb<5,>=4.0.1 in /usr/local/lib/python3.10/dist-packages (from gitpython!=3.1.29,>=1.0.0->wandb>=0.13.5->transformer_lens) (4.0.11)\n",
"Requirement already satisfied: mdurl~=0.1 in /usr/local/lib/python3.10/dist-packages (from markdown-it-py>=2.2.0->rich>=12.6.0->transformer_lens) (0.1.2)\n",
"Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch>=1.10->transformer_lens) (2.1.5)\n",
"Requirement already satisfied: mpmath>=0.19 in /usr/local/lib/python3.10/dist-packages (from sympy->torch>=1.10->transformer_lens) (1.3.0)\n",
"Requirement already satisfied: smmap<6,>=3.0.1 in /usr/local/lib/python3.10/dist-packages (from gitdb<5,>=4.0.1->gitpython!=3.1.29,>=1.0.0->wandb>=0.13.5->transformer_lens) (5.0.1)\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"import torch\n",
"import functools\n",
"import einops\n",
"import requests\n",
"import pandas as pd\n",
"import io\n",
"import textwrap\n",
"import gc\n",
"\n",
"from datasets import load_dataset\n",
"from sklearn.model_selection import train_test_split\n",
"from tqdm import tqdm\n",
"from torch import Tensor\n",
"from typing import List, Callable\n",
"from transformer_lens import HookedTransformer, utils\n",
"from transformer_lens.hook_points import HookPoint\n",
"from transformers import AutoTokenizer\n",
"from jaxtyping import Float, Int\n",
"from colorama import Fore"
],
"metadata": {
"id": "_vhhwl-2-jPg"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"### Load model"
],
"metadata": {
"id": "6ZOoJagxD49V"
}
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 1000,
"referenced_widgets": [
"41720b13bae74e54a9fa2c570dde6f9c",
"fac24f9d8d264ec2961bd61a023ac333",
"86f0014193c24d32955c0c0e21b90bf5",
"ad8735b93fef48cb833a884bc820c9b5",
"c0b8b455ffa04c5ea06fad6fb9b5b98c",
"6c801f43bb784821b7395f4f290e92ba",
"6b5f2f051c724d09bf2826b0bebd2ba0",
"cb3903028c014332aa97961d4768013e",
"9c569b2737964b5ba453f21497e032d2",
"04e216219058474fbab6de50812e0435",
"358a422051194b1fa581b0c58bf05d73",
"afcb5207e4dd4f4e8de7a93967866745",
"a41abec678404238856a68cb4ea45632",
"cb199eb476254ff788eb53ee72a2e927",
"f6bb201707d04190b1045654791e56f7",
"17077d8f905a4a47a810ef462c2a15c1",
"aeca3c94581d45f0ae61b8e3e5b57c2e",
"5ba652d347174a849853bbda04cbba14",
"45d2d5a22e0b4c4a865b522aa23875e2",
"94535cd1198341a48cb22e53a81d7877",
"ebafc69cdc0f4a669401242b3e579eb1",
"303f73ea47d84871aea4dd645147dbe7",
"7b30f98e36fd4b2fa71231791fff4514",
"051f7608c8b24882b04ad0e287e33f6e",
"94df276378e74ae0a17f3919277a2102",
"110d7422b9f4466f83b9b0de741ca85d",
"7b3e5d3cc2b5495ba8295cfc12085473",
"6315bee9cbbe4966becb22b83faddfea",
"2efde713df514325ab2fc4861d7e74f1",
"fd99a60b4fb141409e20a73c10dd604e",
"01e09df184d7431a9211f763fbdcb116",
"af96395f4ca3498dba84e1cd3510c345",
"0163186da54a4b909bef76b2c8776e0c",
"7e46fdf68e9f4b62ad152fddb4e47ee1",
"f35e9c05295b47d4b84cfadd4a199a43",
"a1d7c504aba74e95a2cb87b0890ef007",
"35d1329b8294434798603fc32af10dc7",
"775028d9af414b01a2474461956d784b",
"2f3ad398f20c42089719ea517c51b98d",
"5ecac4746dd648c4b854922d36fc46c0",
"86355ba0065f4a1087ab97c76e2848ae",
"a47411eddbe8467fb42dcaf03c13ba8c",
"8cc19ef932124884b6b630e090026fa2",
"c5dfe0778dc04bb2b864c42cea4f1c11"
]
},
"id": "Vnp65Vsg5x-5",
"outputId": "3dc71111-52bc-43f0-9be1-7abf82289e83"
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"FullArgSpec(args=['cls', 'model_name', 'fold_ln', 'center_writing_weights', 'center_unembed', 'refactor_factored_attn_matrices', 'fold_value_biases', 'dtype', 'default_prepend_bos', 'default_padding_side'], varargs=None, varkw='from_pretrained_kwargs', defaults=(False, False, False, False, False, torch.float32, True, 'right'), kwonlyargs=[], kwonlydefaults=None, annotations={'model_name': <class 'str'>})\n",
"{ '01-ai/yi-34b': '01-ai/Yi-34B',\n",
" '01-ai/yi-34b-chat': '01-ai/Yi-34B-Chat',\n",
" '01-ai/yi-6b': '01-ai/Yi-6B',\n",
" '01-ai/yi-6b-chat': '01-ai/Yi-6B-Chat',\n",
" 'alias-gpt2-small-x21': 'stanford-crfm/alias-gpt2-small-x21',\n",
" 'arthurconmy/redwood_attn_2l': 'ArthurConmy/redwood_attn_2l',\n",
" 'arwen-gpt2-medium-x21': 'stanford-crfm/arwen-gpt2-medium-x21',\n",
" 'attn-only-1l': 'NeelNanda/Attn_Only_1L512W_C4_Code',\n",
" 'attn-only-1l-c4-code': 'NeelNanda/Attn_Only_1L512W_C4_Code',\n",
" 'attn-only-1l-new': 'NeelNanda/Attn_Only_1L512W_C4_Code',\n",
" 'attn-only-2l': 'NeelNanda/Attn_Only_2L512W_C4_Code',\n",
" 'attn-only-2l-c4-code': 'NeelNanda/Attn_Only_2L512W_C4_Code',\n",
" 'attn-only-2l-demo': 'NeelNanda/Attn-Only-2L512W-Shortformer-6B-big-lr',\n",
" 'attn-only-2l-induction-demo': 'NeelNanda/Attn-Only-2L512W-Shortformer-6B-big-lr',\n",
" 'attn-only-2l-new': 'NeelNanda/Attn_Only_2L512W_C4_Code',\n",
" 'attn-only-2l-shortformer-6b-big-lr': 'NeelNanda/Attn-Only-2L512W-Shortformer-6B-big-lr',\n",
" 'attn-only-3l': 'NeelNanda/Attn_Only_3L512W_C4_Code',\n",
" 'attn-only-3l-c4-code': 'NeelNanda/Attn_Only_3L512W_C4_Code',\n",
" 'attn-only-3l-new': 'NeelNanda/Attn_Only_3L512W_C4_Code',\n",
" 'attn-only-4l': 'NeelNanda/Attn_Only_4L512W_C4_Code',\n",
" 'attn-only-4l-c4-code': 'NeelNanda/Attn_Only_4L512W_C4_Code',\n",
" 'attn-only-4l-new': 'NeelNanda/Attn_Only_4L512W_C4_Code',\n",
" 'attn-only-demo': 'NeelNanda/Attn-Only-2L512W-Shortformer-6B-big-lr',\n",
" 'baidicoot/othello-gpt-transformer-lens': 'Baidicoot/Othello-GPT-Transformer-Lens',\n",
" 'battlestar-gpt2-small-x49': 'stanford-crfm/battlestar-gpt2-small-x49',\n",
" 'beren-gpt2-medium-x49': 'stanford-crfm/beren-gpt2-medium-x49',\n",
" 'bert-base-cased': 'bert-base-cased',\n",
" 'bigcode/santacoder': 'bigcode/santacoder',\n",
" 'bigscience/bloom-1b1': 'bigscience/bloom-1b1',\n",
" 'bigscience/bloom-1b7': 'bigscience/bloom-1b7',\n",
" 'bigscience/bloom-3b': 'bigscience/bloom-3b',\n",
" 'bigscience/bloom-560m': 'bigscience/bloom-560m',\n",
" 'bigscience/bloom-7b1': 'bigscience/bloom-7b1',\n",
" 'bloom-1b1': 'bigscience/bloom-1b1',\n",
" 'bloom-1b7': 'bigscience/bloom-1b7',\n",
" 'bloom-3b': 'bigscience/bloom-3b',\n",
" 'bloom-560m': 'bigscience/bloom-560m',\n",
" 'bloom-7b1': 'bigscience/bloom-7b1',\n",
" 'caprica-gpt2-small-x81': 'stanford-crfm/caprica-gpt2-small-x81',\n",
" 'celebrimbor-gpt2-medium-x81': 'stanford-crfm/celebrimbor-gpt2-medium-x81',\n",
" 'codellama-7b-hf': 'CodeLlama-7b-hf',\n",
" 'codellama-7b-instruct': 'CodeLlama-7b-Instruct-hf',\n",
" 'codellama-7b-instruct-hf': 'CodeLlama-7b-Instruct-hf',\n",
" 'codellama-7b-python': 'CodeLlama-7b-Python-hf',\n",
" 'codellama-7b-python-hf': 'CodeLlama-7b-Python-hf',\n",
" 'codellama/codellama-7b-hf': 'CodeLlama-7b-hf',\n",
" 'codellama/codellama-7b-instruct-hf': 'CodeLlama-7b-Instruct-hf',\n",
" 'codellama/codellama-7b-python-hf': 'CodeLlama-7b-Python-hf',\n",
" 'codellamallama-2-7b': 'CodeLlama-7b-hf',\n",
" 'darkmatter-gpt2-small-x343': 'stanford-crfm/darkmatter-gpt2-small-x343',\n",
" 'distil-gpt2': 'distilgpt2',\n",
" 'distilgpt2': 'distilgpt2',\n",
" 'distill-gpt2': 'distilgpt2',\n",
" 'distillgpt2': 'distilgpt2',\n",
" 'durin-gpt2-medium-x343': 'stanford-crfm/durin-gpt2-medium-x343',\n",
" 'eleutherai/gpt-j-6b': 'EleutherAI/gpt-j-6B',\n",
" 'eleutherai/gpt-neo-1.3b': 'EleutherAI/gpt-neo-1.3B',\n",
" 'eleutherai/gpt-neo-125m': 'EleutherAI/gpt-neo-125M',\n",
" 'eleutherai/gpt-neo-2.7b': 'EleutherAI/gpt-neo-2.7B',\n",
" 'eleutherai/gpt-neox-20b': 'EleutherAI/gpt-neox-20b',\n",
" 'eleutherai/pythia-1.3b': 'EleutherAI/pythia-1.4b',\n",
" 'eleutherai/pythia-1.3b-deduped': 'EleutherAI/pythia-1.4b-deduped',\n",
" 'eleutherai/pythia-1.3b-deduped-v0': 'EleutherAI/pythia-1.4b-deduped-v0',\n",
" 'eleutherai/pythia-1.3b-v0': 'EleutherAI/pythia-1.4b-v0',\n",
" 'eleutherai/pythia-1.4b': 'EleutherAI/pythia-1.4b',\n",
" 'eleutherai/pythia-1.4b-deduped': 'EleutherAI/pythia-1.4b-deduped',\n",
" 'eleutherai/pythia-1.4b-deduped-v0': 'EleutherAI/pythia-1.4b-deduped-v0',\n",
" 'eleutherai/pythia-1.4b-v0': 'EleutherAI/pythia-1.4b-v0',\n",
" 'eleutherai/pythia-125m': 'EleutherAI/pythia-160m',\n",
" 'eleutherai/pythia-125m-deduped': 'EleutherAI/pythia-160m-deduped',\n",
" 'eleutherai/pythia-125m-deduped-v0': 'EleutherAI/pythia-160m-deduped-v0',\n",
" 'eleutherai/pythia-125m-seed1': 'EleutherAI/pythia-160m-seed1',\n",
" 'eleutherai/pythia-125m-seed2': 'EleutherAI/pythia-160m-seed2',\n",
" 'eleutherai/pythia-125m-seed3': 'EleutherAI/pythia-160m-seed3',\n",
" 'eleutherai/pythia-125m-v0': 'EleutherAI/pythia-160m-v0',\n",
" 'eleutherai/pythia-12b': 'EleutherAI/pythia-12b',\n",
" 'eleutherai/pythia-12b-deduped': 'EleutherAI/pythia-12b-deduped',\n",
" 'eleutherai/pythia-12b-deduped-v0': 'EleutherAI/pythia-12b-deduped-v0',\n",
" 'eleutherai/pythia-12b-v0': 'EleutherAI/pythia-12b-v0',\n",
" 'eleutherai/pythia-13b': 'EleutherAI/pythia-12b',\n",
" 'eleutherai/pythia-13b-deduped': 'EleutherAI/pythia-12b-deduped',\n",
" 'eleutherai/pythia-13b-deduped-v0': 'EleutherAI/pythia-12b-deduped-v0',\n",
" 'eleutherai/pythia-13b-v0': 'EleutherAI/pythia-12b-v0',\n",
" 'eleutherai/pythia-14m': 'EleutherAI/pythia-14m',\n",
" 'eleutherai/pythia-160m': 'EleutherAI/pythia-160m',\n",
" 'eleutherai/pythia-160m-deduped': 'EleutherAI/pythia-160m-deduped',\n",
" 'eleutherai/pythia-160m-deduped-v0': 'EleutherAI/pythia-160m-deduped-v0',\n",
" 'eleutherai/pythia-160m-seed1': 'EleutherAI/pythia-160m-seed1',\n",
" 'eleutherai/pythia-160m-seed2': 'EleutherAI/pythia-160m-seed2',\n",
" 'eleutherai/pythia-160m-seed3': 'EleutherAI/pythia-160m-seed3',\n",
" 'eleutherai/pythia-160m-v0': 'EleutherAI/pythia-160m-v0',\n",
" 'eleutherai/pythia-19m': 'EleutherAI/pythia-70m',\n",
" 'eleutherai/pythia-19m-deduped': 'EleutherAI/pythia-70m-deduped',\n",
" 'eleutherai/pythia-19m-deduped-v0': 'EleutherAI/pythia-70m-deduped-v0',\n",
" 'eleutherai/pythia-19m-v0': 'EleutherAI/pythia-70m-v0',\n",
" 'eleutherai/pythia-1b': 'EleutherAI/pythia-1b',\n",
" 'eleutherai/pythia-1b-deduped': 'EleutherAI/pythia-1b-deduped',\n",
" 'eleutherai/pythia-1b-deduped-v0': 'EleutherAI/pythia-1b-deduped-v0',\n",
" 'eleutherai/pythia-1b-v0': 'EleutherAI/pythia-1b-v0',\n",
" 'eleutherai/pythia-2.7b': 'EleutherAI/pythia-2.8b',\n",
" 'eleutherai/pythia-2.7b-deduped': 'EleutherAI/pythia-2.8b-deduped',\n",
" 'eleutherai/pythia-2.7b-deduped-v0': 'EleutherAI/pythia-2.8b-deduped-v0',\n",
" 'eleutherai/pythia-2.7b-v0': 'EleutherAI/pythia-2.8b-v0',\n",
" 'eleutherai/pythia-2.8b': 'EleutherAI/pythia-2.8b',\n",
" 'eleutherai/pythia-2.8b-deduped': 'EleutherAI/pythia-2.8b-deduped',\n",
" 'eleutherai/pythia-2.8b-deduped-v0': 'EleutherAI/pythia-2.8b-deduped-v0',\n",
" 'eleutherai/pythia-2.8b-v0': 'EleutherAI/pythia-2.8b-v0',\n",
" 'eleutherai/pythia-31m': 'EleutherAI/pythia-31m',\n",
" 'eleutherai/pythia-350m': 'EleutherAI/pythia-410m',\n",
" 'eleutherai/pythia-350m-deduped': 'EleutherAI/pythia-410m-deduped',\n",
" 'eleutherai/pythia-350m-deduped-v0': 'EleutherAI/pythia-410m-deduped-v0',\n",
" 'eleutherai/pythia-350m-v0': 'EleutherAI/pythia-410m-v0',\n",
" 'eleutherai/pythia-410m': 'EleutherAI/pythia-410m',\n",
" 'eleutherai/pythia-410m-deduped': 'EleutherAI/pythia-410m-deduped',\n",
" 'eleutherai/pythia-410m-deduped-v0': 'EleutherAI/pythia-410m-deduped-v0',\n",
" 'eleutherai/pythia-410m-v0': 'EleutherAI/pythia-410m-v0',\n",
" 'eleutherai/pythia-6.7b': 'EleutherAI/pythia-6.9b',\n",
" 'eleutherai/pythia-6.7b-deduped': 'EleutherAI/pythia-6.9b-deduped',\n",
" 'eleutherai/pythia-6.7b-deduped-v0': 'EleutherAI/pythia-6.9b-deduped-v0',\n",
" 'eleutherai/pythia-6.7b-v0': 'EleutherAI/pythia-6.9b-v0',\n",
" 'eleutherai/pythia-6.9b': 'EleutherAI/pythia-6.9b',\n",
" 'eleutherai/pythia-6.9b-deduped': 'EleutherAI/pythia-6.9b-deduped',\n",
" 'eleutherai/pythia-6.9b-deduped-v0': 'EleutherAI/pythia-6.9b-deduped-v0',\n",
" 'eleutherai/pythia-6.9b-v0': 'EleutherAI/pythia-6.9b-v0',\n",
" 'eleutherai/pythia-70m': 'EleutherAI/pythia-70m',\n",
" 'eleutherai/pythia-70m-deduped': 'EleutherAI/pythia-70m-deduped',\n",
" 'eleutherai/pythia-70m-deduped-v0': 'EleutherAI/pythia-70m-deduped-v0',\n",
" 'eleutherai/pythia-70m-v0': 'EleutherAI/pythia-70m-v0',\n",
" 'eleutherai/pythia-800m': 'EleutherAI/pythia-1b',\n",
" 'eleutherai/pythia-800m-deduped': 'EleutherAI/pythia-1b-deduped',\n",
" 'eleutherai/pythia-800m-deduped-v0': 'EleutherAI/pythia-1b-deduped-v0',\n",
" 'eleutherai/pythia-800m-v0': 'EleutherAI/pythia-1b-v0',\n",
" 'eowyn-gpt2-medium-x777': 'stanford-crfm/eowyn-gpt2-medium-x777',\n",
" 'expanse-gpt2-small-x777': 'stanford-crfm/expanse-gpt2-small-x777',\n",
" 'facebook/opt-1.3b': 'facebook/opt-1.3b',\n",
" 'facebook/opt-125m': 'facebook/opt-125m',\n",
" 'facebook/opt-13b': 'facebook/opt-13b',\n",
" 'facebook/opt-2.7b': 'facebook/opt-2.7b',\n",
" 'facebook/opt-30b': 'facebook/opt-30b',\n",
" 'facebook/opt-6.7b': 'facebook/opt-6.7b',\n",
" 'facebook/opt-66b': 'facebook/opt-66b',\n",
" 'gelu-1l': 'NeelNanda/GELU_1L512W_C4_Code',\n",
" 'gelu-1l-c4-code': 'NeelNanda/GELU_1L512W_C4_Code',\n",
" 'gelu-1l-new': 'NeelNanda/GELU_1L512W_C4_Code',\n",
" 'gelu-2l': 'NeelNanda/GELU_2L512W_C4_Code',\n",
" 'gelu-2l-c4-code': 'NeelNanda/GELU_2L512W_C4_Code',\n",
" 'gelu-2l-new': 'NeelNanda/GELU_2L512W_C4_Code',\n",
" 'gelu-3l': 'NeelNanda/GELU_3L512W_C4_Code',\n",
" 'gelu-3l-c4-code': 'NeelNanda/GELU_3L512W_C4_Code',\n",
" 'gelu-3l-new': 'NeelNanda/GELU_3L512W_C4_Code',\n",
" 'gelu-4l': 'NeelNanda/GELU_4L512W_C4_Code',\n",
" 'gelu-4l-c4-code': 'NeelNanda/GELU_4L512W_C4_Code',\n",
" 'gelu-4l-new': 'NeelNanda/GELU_4L512W_C4_Code',\n",
" 'gemma-2b': 'google/gemma-2b',\n",
" 'gemma-2b-it': 'google/gemma-2b-it',\n",
" 'gemma-7b': 'google/gemma-7b',\n",
" 'gemma-7b-it': 'google/gemma-7b-it',\n",
" 'google/gemma-2b': 'google/gemma-2b',\n",
" 'google/gemma-2b-it': 'google/gemma-2b-it',\n",
" 'google/gemma-7b': 'google/gemma-7b',\n",
" 'google/gemma-7b-it': 'google/gemma-7b-it',\n",
" 'gpt-j': 'EleutherAI/gpt-j-6B',\n",
" 'gpt-j-6b': 'EleutherAI/gpt-j-6B',\n",
" 'gpt-neo-1.3b': 'EleutherAI/gpt-neo-1.3B',\n",
" 'gpt-neo-125m': 'EleutherAI/gpt-neo-125M',\n",
" 'gpt-neo-2.7b': 'EleutherAI/gpt-neo-2.7B',\n",
" 'gpt-neo-large': 'EleutherAI/gpt-neo-2.7B',\n",
" 'gpt-neo-medium': 'EleutherAI/gpt-neo-1.3B',\n",
" 'gpt-neo-small': 'EleutherAI/gpt-neo-125M',\n",
" 'gpt-neox': 'EleutherAI/gpt-neox-20b',\n",
" 'gpt-neox-20b': 'EleutherAI/gpt-neox-20b',\n",
" 'gpt2': 'gpt2',\n",
" 'gpt2-large': 'gpt2-large',\n",
" 'gpt2-medium': 'gpt2-medium',\n",
" 'gpt2-medium-small-a': 'stanford-crfm/arwen-gpt2-medium-x21',\n",
" 'gpt2-medium-small-b': 'stanford-crfm/beren-gpt2-medium-x49',\n",
" 'gpt2-medium-small-c': 'stanford-crfm/celebrimbor-gpt2-medium-x81',\n",
" 'gpt2-medium-small-d': 'stanford-crfm/durin-gpt2-medium-x343',\n",
" 'gpt2-medium-small-e': 'stanford-crfm/eowyn-gpt2-medium-x777',\n",
" 'gpt2-mistral-small-a': 'stanford-crfm/alias-gpt2-small-x21',\n",
" 'gpt2-mistral-small-b': 'stanford-crfm/battlestar-gpt2-small-x49',\n",
" 'gpt2-mistral-small-c': 'stanford-crfm/caprica-gpt2-small-x81',\n",
" 'gpt2-mistral-small-d': 'stanford-crfm/darkmatter-gpt2-small-x343',\n",
" 'gpt2-mistral-small-e': 'stanford-crfm/expanse-gpt2-small-x777',\n",
" 'gpt2-small': 'gpt2',\n",
" 'gpt2-stanford-medium-a': 'stanford-crfm/arwen-gpt2-medium-x21',\n",
" 'gpt2-stanford-medium-b': 'stanford-crfm/beren-gpt2-medium-x49',\n",
" 'gpt2-stanford-medium-d': 'stanford-crfm/durin-gpt2-medium-x343',\n",
" 'gpt2-stanford-medium-e': 'stanford-crfm/eowyn-gpt2-medium-x777',\n",
" 'gpt2-stanford-small-a': 'stanford-crfm/alias-gpt2-small-x21',\n",
" 'gpt2-stanford-small-c': 'stanford-crfm/caprica-gpt2-small-x81',\n",
" 'gpt2-xl': 'gpt2-xl',\n",
" 'gpt2-xs': 'distilgpt2',\n",
" 'gptj': 'EleutherAI/gpt-j-6B',\n",
" 'llama-13b': 'llama-13b-hf',\n",
" 'llama-13b-hf': 'llama-13b-hf',\n",
" 'llama-2-13b': 'meta-llama/Llama-2-13b-hf',\n",
" 'llama-2-13b-chat': 'meta-llama/Llama-2-13b-chat-hf',\n",
" 'llama-2-70b-chat': 'meta-llama/Llama-2-70b-chat-hf',\n",
" 'llama-2-7b': 'meta-llama/Llama-2-7b-hf',\n",
" 'llama-2-7b-chat': 'meta-llama/Llama-2-7b-chat-hf',\n",
" 'llama-30b': 'llama-30b-hf',\n",
" 'llama-30b-hf': 'llama-30b-hf',\n",
" 'llama-65b': 'llama-65b-hf',\n",
" 'llama-65b-hf': 'llama-65b-hf',\n",
" 'llama-7b': 'llama-7b-hf',\n",
" 'llama-7b-hf': 'llama-7b-hf',\n",
" 'meta-llama-2-70b-chat-hf': 'meta-llama/Llama-2-70b-chat-hf',\n",
" 'meta-llama/llama-2-13b-chat-hf': 'meta-llama/Llama-2-13b-chat-hf',\n",
" 'meta-llama/llama-2-13b-hf': 'meta-llama/Llama-2-13b-hf',\n",
" 'meta-llama/llama-2-70b-chat-hf': 'meta-llama/Llama-2-70b-chat-hf',\n",
" 'meta-llama/llama-2-7b-chat-hf': 'meta-llama/Llama-2-7b-chat-hf',\n",
" 'meta-llama/llama-2-7b-hf': 'meta-llama/Llama-2-7b-hf',\n",
" 'meta-llama/meta-llama-3-70b': 'meta-llama/Meta-Llama-3-70B',\n",
" 'meta-llama/meta-llama-3-70b-instruct': 'meta-llama/Meta-Llama-3-70B-Instruct',\n",
" 'meta-llama/meta-llama-3-8b': 'meta-llama/Meta-Llama-3-8B',\n",
" 'meta-llama/meta-llama-3-8b-instruct': 'meta-llama/Meta-Llama-3-8B-Instruct',\n",
" 'microsoft/phi-1': 'microsoft/phi-1',\n",
" 'microsoft/phi-1_5': 'microsoft/phi-1_5',\n",
" 'microsoft/phi-2': 'microsoft/phi-2',\n",
" 'mistral-7b': 'mistralai/Mistral-7B-v0.1',\n",
" 'mistral-7b-instruct': 'mistralai/Mistral-7B-Instruct-v0.1',\n",
" 'mistralai/mistral-7b-instruct-v0.1': 'mistralai/Mistral-7B-Instruct-v0.1',\n",
" 'mistralai/mistral-7b-v0.1': 'mistralai/Mistral-7B-v0.1',\n",
" 'mistralai/mixtral-8x7b-instruct-v0.1': 'mistralai/Mixtral-8x7B-Instruct-v0.1',\n",
" 'mistralai/mixtral-8x7b-v0.1': 'mistralai/Mixtral-8x7B-v0.1',\n",
" 'mixtral': 'mistralai/Mixtral-8x7B-v0.1',\n",
" 'mixtral-8x7b': 'mistralai/Mixtral-8x7B-v0.1',\n",
" 'mixtral-8x7b-instruct': 'mistralai/Mixtral-8x7B-Instruct-v0.1',\n",
" 'mixtral-instruct': 'mistralai/Mixtral-8x7B-Instruct-v0.1',\n",
" 'neelnanda/attn-only-2l512w-shortformer-6b-big-lr': 'NeelNanda/Attn-Only-2L512W-Shortformer-6B-big-lr',\n",
" 'neelnanda/attn_only_1l512w_c4_code': 'NeelNanda/Attn_Only_1L512W_C4_Code',\n",
" 'neelnanda/attn_only_2l512w_c4_code': 'NeelNanda/Attn_Only_2L512W_C4_Code',\n",
" 'neelnanda/attn_only_3l512w_c4_code': 'NeelNanda/Attn_Only_3L512W_C4_Code',\n",
" 'neelnanda/attn_only_4l512w_c4_code': 'NeelNanda/Attn_Only_4L512W_C4_Code',\n",
" 'neelnanda/gelu_1l512w_c4_code': 'NeelNanda/GELU_1L512W_C4_Code',\n",
" 'neelnanda/gelu_2l512w_c4_code': 'NeelNanda/GELU_2L512W_C4_Code',\n",
" 'neelnanda/gelu_3l512w_c4_code': 'NeelNanda/GELU_3L512W_C4_Code',\n",
" 'neelnanda/gelu_4l512w_c4_code': 'NeelNanda/GELU_4L512W_C4_Code',\n",
" 'neelnanda/solu_10l1280w_c4_code': 'NeelNanda/SoLU_10L1280W_C4_Code',\n",
" 'neelnanda/solu_10l_v22_old': 'NeelNanda/SoLU_10L_v22_old',\n",
" 'neelnanda/solu_12l1536w_c4_code': 'NeelNanda/SoLU_12L1536W_C4_Code',\n",
" 'neelnanda/solu_12l_v23_old': 'NeelNanda/SoLU_12L_v23_old',\n",
" 'neelnanda/solu_1l512w_c4_code': 'NeelNanda/SoLU_1L512W_C4_Code',\n",
" 'neelnanda/solu_1l512w_wiki_finetune': 'NeelNanda/SoLU_1L512W_Wiki_Finetune',\n",
" 'neelnanda/solu_1l_v9_old': 'NeelNanda/SoLU_1L_v9_old',\n",
" 'neelnanda/solu_2l512w_c4_code': 'NeelNanda/SoLU_2L512W_C4_Code',\n",
" 'neelnanda/solu_2l_v10_old': 'NeelNanda/SoLU_2L_v10_old',\n",
" 'neelnanda/solu_3l512w_c4_code': 'NeelNanda/SoLU_3L512W_C4_Code',\n",
" 'neelnanda/solu_4l512w_c4_code': 'NeelNanda/SoLU_4L512W_C4_Code',\n",
" 'neelnanda/solu_4l512w_wiki_finetune': 'NeelNanda/SoLU_4L512W_Wiki_Finetune',\n",
" 'neelnanda/solu_4l_v11_old': 'NeelNanda/SoLU_4L_v11_old',\n",
" 'neelnanda/solu_6l768w_c4_code': 'NeelNanda/SoLU_6L768W_C4_Code',\n",
" 'neelnanda/solu_6l_v13_old': 'NeelNanda/SoLU_6L_v13_old',\n",
" 'neelnanda/solu_8l1024w_c4_code': 'NeelNanda/SoLU_8L1024W_C4_Code',\n",
" 'neelnanda/solu_8l_v21_old': 'NeelNanda/SoLU_8L_v21_old',\n",
" 'neo': 'EleutherAI/gpt-neo-125M',\n",
" 'neo-large': 'EleutherAI/gpt-neo-2.7B',\n",
" 'neo-medium': 'EleutherAI/gpt-neo-1.3B',\n",
" 'neo-small': 'EleutherAI/gpt-neo-125M',\n",
" 'neox': 'EleutherAI/gpt-neox-20b',\n",
" 'opt': 'facebook/opt-125m',\n",
" 'opt-1.3b': 'facebook/opt-1.3b',\n",
" 'opt-125m': 'facebook/opt-125m',\n",
" 'opt-13b': 'facebook/opt-13b',\n",
" 'opt-2.7b': 'facebook/opt-2.7b',\n",
" 'opt-30b': 'facebook/opt-30b',\n",
" 'opt-6.7b': 'facebook/opt-6.7b',\n",
" 'opt-66b': 'facebook/opt-66b',\n",
" 'opt-large': 'facebook/opt-2.7b',\n",
" 'opt-medium': 'facebook/opt-1.3b',\n",
" 'opt-small': 'facebook/opt-125m',\n",
" 'opt-xl': 'facebook/opt-6.7b',\n",
" 'opt-xxl': 'facebook/opt-13b',\n",
" 'opt-xxxl': 'facebook/opt-30b',\n",
" 'opt-xxxxl': 'facebook/opt-66b',\n",
" 'othello-gpt': 'Baidicoot/Othello-GPT-Transformer-Lens',\n",
" 'phi-1': 'microsoft/phi-1',\n",
" 'phi-1_5': 'microsoft/phi-1_5',\n",
" 'phi-2': 'microsoft/phi-2',\n",
" 'pythia': 'EleutherAI/pythia-70m',\n",
" 'pythia-1.3b': 'EleutherAI/pythia-1.4b',\n",
" 'pythia-1.3b-deduped': 'EleutherAI/pythia-1.4b-deduped',\n",
" 'pythia-1.3b-deduped-v0': 'EleutherAI/pythia-1.4b-deduped-v0',\n",
" 'pythia-1.3b-v0': 'EleutherAI/pythia-1.4b-v0',\n",
" 'pythia-1.4b': 'EleutherAI/pythia-1.4b',\n",
" 'pythia-1.4b-deduped': 'EleutherAI/pythia-1.4b-deduped',\n",
" 'pythia-1.4b-deduped-v0': 'EleutherAI/pythia-1.4b-deduped-v0',\n",
" 'pythia-1.4b-v0': 'EleutherAI/pythia-1.4b-v0',\n",
" 'pythia-125m': 'EleutherAI/pythia-160m',\n",
" 'pythia-125m-deduped': 'EleutherAI/pythia-160m-deduped',\n",
" 'pythia-125m-deduped-v0': 'EleutherAI/pythia-160m-deduped-v0',\n",
" 'pythia-125m-seed1': 'EleutherAI/pythia-160m-seed1',\n",
" 'pythia-125m-seed2': 'EleutherAI/pythia-160m-seed2',\n",
" 'pythia-125m-seed3': 'EleutherAI/pythia-160m-seed3',\n",
" 'pythia-125m-v0': 'EleutherAI/pythia-160m-v0',\n",
" 'pythia-12b': 'EleutherAI/pythia-12b',\n",
" 'pythia-12b-deduped': 'EleutherAI/pythia-12b-deduped',\n",
" 'pythia-12b-deduped-v0': 'EleutherAI/pythia-12b-deduped-v0',\n",
" 'pythia-12b-v0': 'EleutherAI/pythia-12b-v0',\n",
" 'pythia-13b': 'EleutherAI/pythia-12b',\n",
" 'pythia-13b-deduped': 'EleutherAI/pythia-12b-deduped',\n",
" 'pythia-13b-deduped-v0': 'EleutherAI/pythia-12b-deduped-v0',\n",
" 'pythia-13b-v0': 'EleutherAI/pythia-12b-v0',\n",
" 'pythia-14m': 'EleutherAI/pythia-14m',\n",
" 'pythia-160m': 'EleutherAI/pythia-160m',\n",
" 'pythia-160m-deduped': 'EleutherAI/pythia-160m-deduped',\n",
" 'pythia-160m-deduped-v0': 'EleutherAI/pythia-160m-deduped-v0',\n",
" 'pythia-160m-seed1': 'EleutherAI/pythia-160m-seed1',\n",
" 'pythia-160m-seed2': 'EleutherAI/pythia-160m-seed2',\n",
" 'pythia-160m-seed3': 'EleutherAI/pythia-160m-seed3',\n",
" 'pythia-160m-v0': 'EleutherAI/pythia-160m-v0',\n",
" 'pythia-19m': 'EleutherAI/pythia-70m',\n",
" 'pythia-19m-deduped': 'EleutherAI/pythia-70m-deduped',\n",
" 'pythia-19m-deduped-v0': 'EleutherAI/pythia-70m-deduped-v0',\n",
" 'pythia-19m-v0': 'EleutherAI/pythia-70m-v0',\n",
" 'pythia-1b': 'EleutherAI/pythia-1b',\n",
" 'pythia-1b-deduped': 'EleutherAI/pythia-1b-deduped',\n",
" 'pythia-1b-deduped-v0': 'EleutherAI/pythia-1b-deduped-v0',\n",
" 'pythia-1b-v0': 'EleutherAI/pythia-1b-v0',\n",
" 'pythia-2.7b': 'EleutherAI/pythia-2.8b',\n",
" 'pythia-2.7b-deduped': 'EleutherAI/pythia-2.8b-deduped',\n",
" 'pythia-2.7b-deduped-v0': 'EleutherAI/pythia-2.8b-deduped-v0',\n",
" 'pythia-2.7b-v0': 'EleutherAI/pythia-2.8b-v0',\n",
" 'pythia-2.8b': 'EleutherAI/pythia-2.8b',\n",
" 'pythia-2.8b-deduped': 'EleutherAI/pythia-2.8b-deduped',\n",
" 'pythia-2.8b-deduped-v0': 'EleutherAI/pythia-2.8b-deduped-v0',\n",
" 'pythia-2.8b-v0': 'EleutherAI/pythia-2.8b-v0',\n",
" 'pythia-31m': 'EleutherAI/pythia-31m',\n",
" 'pythia-350m': 'EleutherAI/pythia-410m',\n",
" 'pythia-350m-deduped': 'EleutherAI/pythia-410m-deduped',\n",
" 'pythia-350m-deduped-v0': 'EleutherAI/pythia-410m-deduped-v0',\n",
" 'pythia-350m-v0': 'EleutherAI/pythia-410m-v0',\n",
" 'pythia-410m': 'EleutherAI/pythia-410m',\n",
" 'pythia-410m-deduped': 'EleutherAI/pythia-410m-deduped',\n",
" 'pythia-410m-deduped-v0': 'EleutherAI/pythia-410m-deduped-v0',\n",
" 'pythia-410m-v0': 'EleutherAI/pythia-410m-v0',\n",
" 'pythia-6.7b': 'EleutherAI/pythia-6.9b',\n",
" 'pythia-6.7b-deduped': 'EleutherAI/pythia-6.9b-deduped',\n",
" 'pythia-6.7b-deduped-v0': 'EleutherAI/pythia-6.9b-deduped-v0',\n",
" 'pythia-6.7b-v0': 'EleutherAI/pythia-6.9b-v0',\n",
" 'pythia-6.9b': 'EleutherAI/pythia-6.9b',\n",
" 'pythia-6.9b-deduped': 'EleutherAI/pythia-6.9b-deduped',\n",
" 'pythia-6.9b-deduped-v0': 'EleutherAI/pythia-6.9b-deduped-v0',\n",
" 'pythia-6.9b-v0': 'EleutherAI/pythia-6.9b-v0',\n",
" 'pythia-70m': 'EleutherAI/pythia-70m',\n",
" 'pythia-70m-deduped': 'EleutherAI/pythia-70m-deduped',\n",
" 'pythia-70m-deduped-v0': 'EleutherAI/pythia-70m-deduped-v0',\n",
" 'pythia-70m-v0': 'EleutherAI/pythia-70m-v0',\n",
" 'pythia-800m': 'EleutherAI/pythia-1b',\n",
" 'pythia-800m-deduped': 'EleutherAI/pythia-1b-deduped',\n",
" 'pythia-800m-deduped-v0': 'EleutherAI/pythia-1b-deduped-v0',\n",
" 'pythia-800m-v0': 'EleutherAI/pythia-1b-v0',\n",
" 'pythia-v0': 'EleutherAI/pythia-70m-v0',\n",
" 'qwen-1.8b': 'Qwen/Qwen-1_8B',\n",
" 'qwen-1.8b-chat': 'Qwen/Qwen-1_8B-Chat',\n",
" 'qwen-14b': 'Qwen/Qwen-14B',\n",
" 'qwen-14b-chat': 'Qwen/Qwen-14B-Chat',\n",
" 'qwen-7b': 'Qwen/Qwen-7B',\n",
" 'qwen-7b-chat': 'Qwen/Qwen-7B-Chat',\n",
" 'qwen/qwen-14b': 'Qwen/Qwen-14B',\n",
" 'qwen/qwen-14b-chat': 'Qwen/Qwen-14B-Chat',\n",
" 'qwen/qwen-1_8b': 'Qwen/Qwen-1_8B',\n",
" 'qwen/qwen-1_8b-chat': 'Qwen/Qwen-1_8B-Chat',\n",
" 'qwen/qwen-7b': 'Qwen/Qwen-7B',\n",
" 'qwen/qwen-7b-chat': 'Qwen/Qwen-7B-Chat',\n",
" 'qwen/qwen1.5-0.5b': 'Qwen/Qwen1.5-0.5B',\n",
" 'qwen/qwen1.5-0.5b-chat': 'Qwen/Qwen1.5-0.5B-Chat',\n",
" 'qwen/qwen1.5-1.8b': 'Qwen/Qwen1.5-1.8B',\n",
" 'qwen/qwen1.5-1.8b-chat': 'Qwen/Qwen1.5-1.8B-Chat',\n",
" 'qwen/qwen1.5-14b': 'Qwen/Qwen1.5-14B',\n",
" 'qwen/qwen1.5-14b-chat': 'Qwen/Qwen1.5-14B-Chat',\n",
" 'qwen/qwen1.5-4b': 'Qwen/Qwen1.5-4B',\n",
" 'qwen/qwen1.5-4b-chat': 'Qwen/Qwen1.5-4B-Chat',\n",
" 'qwen/qwen1.5-7b': 'Qwen/Qwen1.5-7B',\n",
" 'qwen/qwen1.5-7b-chat': 'Qwen/Qwen1.5-7B-Chat',\n",
" 'qwen1.5-0.5b': 'Qwen/Qwen1.5-0.5B',\n",
" 'qwen1.5-0.5b-chat': 'Qwen/Qwen1.5-0.5B-Chat',\n",
" 'qwen1.5-1.8b': 'Qwen/Qwen1.5-1.8B',\n",
" 'qwen1.5-1.8b-chat': 'Qwen/Qwen1.5-1.8B-Chat',\n",
" 'qwen1.5-14b': 'Qwen/Qwen1.5-14B',\n",
" 'qwen1.5-14b-chat': 'Qwen/Qwen1.5-14B-Chat',\n",
" 'qwen1.5-4b': 'Qwen/Qwen1.5-4B',\n",
" 'qwen1.5-4b-chat': 'Qwen/Qwen1.5-4B-Chat',\n",
" 'qwen1.5-7b': 'Qwen/Qwen1.5-7B',\n",
" 'qwen1.5-7b-chat': 'Qwen/Qwen1.5-7B-Chat',\n",
" 'redwood_attn_2l': 'ArthurConmy/redwood_attn_2l',\n",
" 'roneneldan/tinystories-1layer-21m': 'roneneldan/TinyStories-1Layer-21M',\n",
" 'roneneldan/tinystories-1m': 'roneneldan/TinyStories-1M',\n",
" 'roneneldan/tinystories-28m': 'roneneldan/TinyStories-28M',\n",
" 'roneneldan/tinystories-2layers-33m': 'roneneldan/TinyStories-2Layers-33M',\n",
" 'roneneldan/tinystories-33m': 'roneneldan/TinyStories-33M',\n",
" 'roneneldan/tinystories-3m': 'roneneldan/TinyStories-3M',\n",
" 'roneneldan/tinystories-8m': 'roneneldan/TinyStories-8M',\n",
" 'roneneldan/tinystories-instruct-1m': 'roneneldan/TinyStories-Instruct-1M',\n",
" 'roneneldan/tinystories-instruct-28m': 'roneneldan/TinyStories-Instruct-28M',\n",
" 'roneneldan/tinystories-instruct-2layers-33m': 'roneneldan/TinyStories-Instruct-2Layers-33M',\n",
" 'roneneldan/tinystories-instruct-33m': 'roneneldan/TinyStories-Instruct-33M',\n",
" 'roneneldan/tinystories-instruct-3m': 'roneneldan/TinyStories-Instruct-3M',\n",
" 'roneneldan/tinystories-instruct-8m': 'roneneldan/TinyStories-Instruct-8M',\n",
" 'roneneldan/tinystories-instuct-1layer-21m': 'roneneldan/TinyStories-Instuct-1Layer-21M',\n",
" 'santacoder': 'bigcode/santacoder',\n",
" 'solu-10l': 'NeelNanda/SoLU_10L1280W_C4_Code',\n",
" 'solu-10l-c4-code': 'NeelNanda/SoLU_10L1280W_C4_Code',\n",
" 'solu-10l-new': 'NeelNanda/SoLU_10L1280W_C4_Code',\n",
" 'solu-10l-old': 'NeelNanda/SoLU_10L_v22_old',\n",
" 'solu-10l-pile': 'NeelNanda/SoLU_10L_v22_old',\n",
" 'solu-12l': 'NeelNanda/SoLU_12L1536W_C4_Code',\n",
" 'solu-12l-c4-code': 'NeelNanda/SoLU_12L1536W_C4_Code',\n",
" 'solu-12l-new': 'NeelNanda/SoLU_12L1536W_C4_Code',\n",
" 'solu-12l-old': 'NeelNanda/SoLU_12L_v23_old',\n",
" 'solu-12l-pile': 'NeelNanda/SoLU_12L_v23_old',\n",
" 'solu-1l': 'NeelNanda/SoLU_1L512W_C4_Code',\n",
" 'solu-1l-c4-code': 'NeelNanda/SoLU_1L512W_C4_Code',\n",
" 'solu-1l-finetune': 'NeelNanda/SoLU_1L512W_Wiki_Finetune',\n",
" 'solu-1l-new': 'NeelNanda/SoLU_1L512W_C4_Code',\n",
" 'solu-1l-old': 'NeelNanda/SoLU_1L_v9_old',\n",
" 'solu-1l-pile': 'NeelNanda/SoLU_1L_v9_old',\n",
" 'solu-1l-wiki': 'NeelNanda/SoLU_1L512W_Wiki_Finetune',\n",
" 'solu-1l-wiki-finetune': 'NeelNanda/SoLU_1L512W_Wiki_Finetune',\n",
" 'solu-2l': 'NeelNanda/SoLU_2L512W_C4_Code',\n",
" 'solu-2l-c4-code': 'NeelNanda/SoLU_2L512W_C4_Code',\n",
" 'solu-2l-new': 'NeelNanda/SoLU_2L512W_C4_Code',\n",
" 'solu-2l-old': 'NeelNanda/SoLU_2L_v10_old',\n",
" 'solu-2l-pile': 'NeelNanda/SoLU_2L_v10_old',\n",
" 'solu-3l': 'NeelNanda/SoLU_3L512W_C4_Code',\n",
" 'solu-3l-c4-code': 'NeelNanda/SoLU_3L512W_C4_Code',\n",
" 'solu-3l-new': 'NeelNanda/SoLU_3L512W_C4_Code',\n",
" 'solu-4l': 'NeelNanda/SoLU_4L512W_C4_Code',\n",
" 'solu-4l-c4-code': 'NeelNanda/SoLU_4L512W_C4_Code',\n",
" 'solu-4l-finetune': 'NeelNanda/SoLU_4L512W_Wiki_Finetune',\n",
" 'solu-4l-new': 'NeelNanda/SoLU_4L512W_C4_Code',\n",
" 'solu-4l-old': 'NeelNanda/SoLU_4L_v11_old',\n",
" 'solu-4l-pile': 'NeelNanda/SoLU_4L_v11_old',\n",
" 'solu-4l-wiki': 'NeelNanda/SoLU_4L512W_Wiki_Finetune',\n",
" 'solu-4l-wiki-finetune': 'NeelNanda/SoLU_4L512W_Wiki_Finetune',\n",
" 'solu-6l': 'NeelNanda/SoLU_6L768W_C4_Code',\n",
" 'solu-6l-c4-code': 'NeelNanda/SoLU_6L768W_C4_Code',\n",
" 'solu-6l-new': 'NeelNanda/SoLU_6L768W_C4_Code',\n",
" 'solu-6l-old': 'NeelNanda/SoLU_6L_v13_old',\n",
" 'solu-6l-pile': 'NeelNanda/SoLU_6L_v13_old',\n",
" 'solu-8l': 'NeelNanda/SoLU_8L1024W_C4_Code',\n",
" 'solu-8l-c4-code': 'NeelNanda/SoLU_8L1024W_C4_Code',\n",
" 'solu-8l-new': 'NeelNanda/SoLU_8L1024W_C4_Code',\n",
" 'solu-8l-old': 'NeelNanda/SoLU_8L_v21_old',\n",
" 'solu-8l-pile': 'NeelNanda/SoLU_8L_v21_old',\n",
" 'stabilityai/stablelm-base-alpha-3b': 'stabilityai/stablelm-base-alpha-3b',\n",
" 'stabilityai/stablelm-base-alpha-7b': 'stabilityai/stablelm-base-alpha-7b',\n",
" 'stabilityai/stablelm-tuned-alpha-3b': 'stabilityai/stablelm-tuned-alpha-3b',\n",
" 'stabilityai/stablelm-tuned-alpha-7b': 'stabilityai/stablelm-tuned-alpha-7b',\n",
" 'stablelm-base-3b': 'stabilityai/stablelm-base-alpha-3b',\n",
" 'stablelm-base-7b': 'stabilityai/stablelm-base-alpha-7b',\n",
" 'stablelm-base-alpha-3b': 'stabilityai/stablelm-base-alpha-3b',\n",
" 'stablelm-base-alpha-7b': 'stabilityai/stablelm-base-alpha-7b',\n",
" 'stablelm-tuned-3b': 'stabilityai/stablelm-tuned-alpha-3b',\n",
" 'stablelm-tuned-7b': 'stabilityai/stablelm-tuned-alpha-7b',\n",
" 'stablelm-tuned-alpha-3b': 'stabilityai/stablelm-tuned-alpha-3b',\n",
" 'stablelm-tuned-alpha-7b': 'stabilityai/stablelm-tuned-alpha-7b',\n",
" 'stanford-crfm/alias-gpt2-small-x21': 'stanford-crfm/alias-gpt2-small-x21',\n",
" 'stanford-crfm/arwen-gpt2-medium-x21': 'stanford-crfm/arwen-gpt2-medium-x21',\n",
" 'stanford-crfm/battlestar-gpt2-small-x49': 'stanford-crfm/battlestar-gpt2-small-x49',\n",
" 'stanford-crfm/beren-gpt2-medium-x49': 'stanford-crfm/beren-gpt2-medium-x49',\n",
" 'stanford-crfm/caprica-gpt2-small-x81': 'stanford-crfm/caprica-gpt2-small-x81',\n",
" 'stanford-crfm/celebrimbor-gpt2-medium-x81': 'stanford-crfm/celebrimbor-gpt2-medium-x81',\n",
" 'stanford-crfm/darkmatter-gpt2-small-x343': 'stanford-crfm/darkmatter-gpt2-small-x343',\n",
" 'stanford-crfm/durin-gpt2-medium-x343': 'stanford-crfm/durin-gpt2-medium-x343',\n",
" 'stanford-crfm/eowyn-gpt2-medium-x777': 'stanford-crfm/eowyn-gpt2-medium-x777',\n",
" 'stanford-crfm/expanse-gpt2-small-x777': 'stanford-crfm/expanse-gpt2-small-x777',\n",
" 'stanford-gpt2-medium-a': 'stanford-crfm/arwen-gpt2-medium-x21',\n",
" 'stanford-gpt2-medium-b': 'stanford-crfm/beren-gpt2-medium-x49',\n",
" 'stanford-gpt2-medium-c': 'stanford-crfm/celebrimbor-gpt2-medium-x81',\n",
" 'stanford-gpt2-medium-d': 'stanford-crfm/durin-gpt2-medium-x343',\n",
" 'stanford-gpt2-medium-e': 'stanford-crfm/eowyn-gpt2-medium-x777',\n",
" 'stanford-gpt2-small-a': 'stanford-crfm/alias-gpt2-small-x21',\n",
" 'stanford-gpt2-small-b': 'stanford-crfm/battlestar-gpt2-small-x49',\n",
" 'stanford-gpt2-small-c': 'stanford-crfm/caprica-gpt2-small-x81',\n",
" 'stanford-gpt2-small-d': 'stanford-crfm/darkmatter-gpt2-small-x343',\n",
" 'stanford-gpt2-small-e': 'stanford-crfm/expanse-gpt2-small-x777',\n",
" 'tiny-stories-1l-21m': 'roneneldan/TinyStories-1Layer-21M',\n",
" 'tiny-stories-1m': 'roneneldan/TinyStories-1M',\n",
" 'tiny-stories-28m': 'roneneldan/TinyStories-28M',\n",
" 'tiny-stories-2l-33m': 'roneneldan/TinyStories-2Layers-33M',\n",
" 'tiny-stories-33m': 'roneneldan/TinyStories-33M',\n",
" 'tiny-stories-3m': 'roneneldan/TinyStories-3M',\n",
" 'tiny-stories-8m': 'roneneldan/TinyStories-8M',\n",
" 'tiny-stories-instruct-1l-21m': 'roneneldan/TinyStories-Instuct-1Layer-21M',\n",
" 'tiny-stories-instruct-1m': 'roneneldan/TinyStories-Instruct-1M',\n",
" 'tiny-stories-instruct-28m': 'roneneldan/TinyStories-Instruct-28M',\n",
" 'tiny-stories-instruct-2l-33m': 'roneneldan/TinyStories-Instruct-2Layers-33M',\n",
" 'tiny-stories-instruct-33m': 'roneneldan/TinyStories-Instruct-33M',\n",
" 'tiny-stories-instruct-3m': 'roneneldan/TinyStories-Instruct-3M',\n",
" 'tiny-stories-instruct-8m': 'roneneldan/TinyStories-Instruct-8M',\n",
" 'yi-34b': '01-ai/Yi-34B',\n",
" 'yi-34b-chat': '01-ai/Yi-34B-Chat',\n",
" 'yi-6b': '01-ai/Yi-6B',\n",
" 'yi-6b-chat': '01-ai/Yi-6B-Chat'}\n"
]
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"config.json: 0%| | 0.00/654 [00:00<?, ?B/s]"
],
"application/vnd.jupyter.widget-view+json": {
"version_major": 2,
"version_minor": 0,
"model_id": "41720b13bae74e54a9fa2c570dde6f9c"
}
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"model.safetensors.index.json: 0%| | 0.00/23.9k [00:00<?, ?B/s]"
],
"application/vnd.jupyter.widget-view+json": {
"version_major": 2,
"version_minor": 0,
"model_id": "afcb5207e4dd4f4e8de7a93967866745"
}
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"Downloading shards: 0%| | 0/4 [00:00<?, ?it/s]"
],
"application/vnd.jupyter.widget-view+json": {
"version_major": 2,
"version_minor": 0,
"model_id": "7b30f98e36fd4b2fa71231791fff4514"
}
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"model-00001-of-00004.safetensors: 0%| | 0.00/4.98G [00:00<?, ?B/s]"
],
"application/vnd.jupyter.widget-view+json": {
"version_major": 2,
"version_minor": 0,
"model_id": "7e46fdf68e9f4b62ad152fddb4e47ee1"
}
},
"metadata": {}
}
],
"source": [
"torch.set_grad_enabled(False)\n",
"\n",
"MODEL_PATH = 'meta-llama/Meta-Llama-3-8B-Instruct'\n",
"# MODEL_PATH = 'Qwen/Qwen-1_8B-Chat'\n",
"DEVICE = 'cuda'\n",
"\n",
"import inspect, pprint\n",
"from transformer_lens.loading_from_pretrained import make_model_alias_map\n",
"pp = pprint.PrettyPrinter(indent=4)\n",
"# debug output: the loader's full argument spec and TransformerLens' model alias map\n",
"pp.pprint(inspect.getfullargspec(HookedTransformer.from_pretrained_no_processing))\n",
"pp.pprint(make_model_alias_map())\n",
"\n",
"model = HookedTransformer.from_pretrained_no_processing(\n",
"    MODEL_PATH,\n",
"    device=DEVICE,\n",
"    dtype=torch.float16,\n",
"    default_padding_side='left',\n",
"    fp16=True\n",
")\n",
"\n",
"# Pad on the left so every prompt's final token sits at position -1, the position that the\n",
"# refusal-direction computation and the greedy generation loop both read.\n",
"model.tokenizer.padding_side = 'left'\n",
"if 'qwen' in MODEL_PATH.lower():\n",
"    # Qwen (v1) chat setup: pad with the reserved '<|extra_0|>' token.\n",
"    model.tokenizer.pad_token = '<|extra_0|>'\n",
"elif model.tokenizer.pad_token is None:\n",
"    # Llama-3 ships without a pad token (and has no '<|extra_0|>'), so reuse EOS.\n",
"    model.tokenizer.pad_token = model.tokenizer.eos_token\n",
"# No explicit pad_token_id assignment: the tokenizer resolves it from pad_token. The previous\n",
"# `model.tokenizer.token_to_id(...)` call raised AttributeError -- Hugging Face tokenizers\n",
"# expose `convert_tokens_to_ids`; `token_to_id` exists only on the Rust `tokenizers` backend."
]
},
{
"cell_type": "markdown",
"source": [
"### Load harmful / harmless datasets"
],
"metadata": {
"id": "rF7e-u20EFTe"
}
},
{
"cell_type": "code",
"source": [
"def get_harmful_instructions():\n",
"    \"\"\"Download the AdvBench harmful behaviors CSV; return an 80/20 (train, test) split of its 'goal' strings.\"\"\"\n",
"    url = 'https://raw.githubusercontent.com/llm-attacks/llm-attacks/main/data/advbench/harmful_behaviors.csv'\n",
"    response = requests.get(url, timeout=60)\n",
"    # Fail loudly on an HTTP error instead of parsing the error page as CSV,\n",
"    # which would typically surface later as a confusing KeyError on 'goal'.\n",
"    response.raise_for_status()\n",
"\n",
"    dataset = pd.read_csv(io.StringIO(response.content.decode('utf-8')))\n",
"    instructions = dataset['goal'].tolist()\n",
"\n",
"    train, test = train_test_split(instructions, test_size=0.2, random_state=42)\n",
"    return train, test\n",
"\n",
"def get_harmless_instructions():\n",
"    \"\"\"Load tatsu-lab/alpaca; return an 80/20 (train, test) split of the instructions that need no input.\"\"\"\n",
"    hf_path = 'tatsu-lab/alpaca'\n",
"    dataset = load_dataset(hf_path)\n",
"\n",
"    # keep only instructions without an accompanying input; iterate the split once\n",
"    # rather than materializing dataset['train'][i] twice per row\n",
"    instructions = [\n",
"        row['instruction']\n",
"        for row in dataset['train']\n",
"        if row['input'].strip() == ''\n",
"    ]\n",
"\n",
"    train, test = train_test_split(instructions, test_size=0.2, random_state=42)\n",
"    return train, test"
],
"metadata": {
"id": "5i1XcVIgHEE1"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# (train, test) splits for both instruction sets\n",
"[harmful_inst_train, harmful_inst_test] = get_harmful_instructions()\n",
"[harmless_inst_train, harmless_inst_test] = get_harmless_instructions()"
],
"metadata": {
"id": "Rth8yvLZJsXs"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# show the first four training examples from each set\n",
"for heading, examples in (\n",
"    (\"Harmful instructions:\", harmful_inst_train),\n",
"    (\"Harmless instructions:\", harmless_inst_train),\n",
"):\n",
"    print(heading)\n",
"    for k in range(4):\n",
"        print(f\"\\t{examples[k]!r}\")"
],
"metadata": {
"id": "Qv2ALDY_J44G"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"### Tokenization utils"
],
"metadata": {
"id": "KOKYA61k8LWt"
}
},
{
"cell_type": "code",
"source": [
"QWEN_CHAT_TEMPLATE = \"\"\"<|im_start|>user\n",
"{instruction}<|im_end|>\n",
"<|im_start|>assistant\n",
"\"\"\"\n",
"\n",
"def tokenize_instructions_qwen_chat(\n",
"    tokenizer: AutoTokenizer,\n",
"    instructions: List[str]\n",
") -> Int[Tensor, 'batch_size seq_len']:\n",
"    \"\"\"Wrap each instruction in QWEN_CHAT_TEMPLATE and tokenize the batch into padded input ids.\"\"\"\n",
"    prompts = [QWEN_CHAT_TEMPLATE.format(instruction=instruction) for instruction in instructions]\n",
"    return tokenizer(prompts, padding=True, truncation=False, return_tensors=\"pt\").input_ids\n",
"\n",
"def tokenize_instructions_chat(\n",
"    tokenizer: AutoTokenizer,\n",
"    instructions: List[str]\n",
") -> Int[Tensor, 'batch_size seq_len']:\n",
"    \"\"\"Format each instruction as one user turn in the tokenizer's own chat template, then tokenize.\n",
"\n",
"    Prompts must match the loaded model's chat format (e.g. Llama-3-Instruct does not use Qwen's\n",
"    <|im_start|>/<|im_end|> markers). Falls back to the Qwen template when the tokenizer defines\n",
"    no chat template.\n",
"    \"\"\"\n",
"    if not getattr(tokenizer, 'chat_template', None):\n",
"        return tokenize_instructions_qwen_chat(tokenizer, instructions)\n",
"    prompts = [\n",
"        tokenizer.apply_chat_template(\n",
"            [{\"role\": \"user\", \"content\": instruction}],\n",
"            tokenize=False,\n",
"            add_generation_prompt=True,  # end each prompt with the assistant-turn header\n",
"        )\n",
"        for instruction in instructions\n",
"    ]\n",
"    # chat templates typically render the BOS token themselves, so don't let the tokenizer add another\n",
"    return tokenizer(prompts, padding=True, truncation=False, add_special_tokens=False, return_tensors=\"pt\").input_ids\n",
"\n",
"tokenize_instructions_fn = functools.partial(tokenize_instructions_chat, tokenizer=model.tokenizer)"
],
"metadata": {
"id": "P8UPQSfpWOSK"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"### Generation utils"
],
"metadata": {
"id": "gtrIK8x78SZh"
}
},
{
"cell_type": "code",
"source": [
"def _generate_with_hooks(\n",
"    model: HookedTransformer,\n",
"    toks: Int[Tensor, 'batch_size seq_len'],\n",
"    max_tokens_generated: int = 64,\n",
"    fwd_hooks = None,\n",
") -> List[str]:\n",
"    \"\"\"Greedily append `max_tokens_generated` tokens to every row of `toks` with `fwd_hooks` active.\n",
"\n",
"    Each step re-runs the full forward pass over the prompt plus everything generated so far\n",
"    (no KV cache). Returns only the generated continuation of each row, decoded with special\n",
"    tokens skipped.\n",
"    \"\"\"\n",
"    if fwd_hooks is None:\n",
"        fwd_hooks = []\n",
"    batch_size, prompt_len = toks.shape\n",
"\n",
"    all_toks = torch.zeros((batch_size, prompt_len + max_tokens_generated), dtype=torch.long, device=toks.device)\n",
"    all_toks[:, :prompt_len] = toks\n",
"\n",
"    # attach the hooks once for the whole loop instead of re-entering the context every step\n",
"    with model.hooks(fwd_hooks=fwd_hooks):\n",
"        for i in range(max_tokens_generated):\n",
"            logits = model(all_toks[:, :prompt_len + i])\n",
"            next_tokens = logits[:, -1, :].argmax(dim=-1) # greedy sampling (temperature=0)\n",
"            all_toks[:, prompt_len + i] = next_tokens\n",
"\n",
"    return model.tokenizer.batch_decode(all_toks[:, prompt_len:], skip_special_tokens=True)\n",
"\n",
"def get_generations(\n",
"    model: HookedTransformer,\n",
"    instructions: List[str],\n",
"    tokenize_instructions_fn: Callable[[List[str]], Int[Tensor, 'batch_size seq_len']],\n",
"    fwd_hooks = None,\n",
"    max_tokens_generated: int = 64,\n",
"    batch_size: int = 4,\n",
") -> List[str]:\n",
"    \"\"\"Generate a greedy completion for every instruction, `batch_size` instructions at a time.\n",
"\n",
"    `fwd_hooks` (default: none) stays active during every forward pass. Returns one decoded\n",
"    completion per instruction, in input order.\n",
"    \"\"\"\n",
"    generations = []\n",
"\n",
"    for i in tqdm(range(0, len(instructions), batch_size)):\n",
"        toks = tokenize_instructions_fn(instructions=instructions[i:i+batch_size])\n",
"        generation = _generate_with_hooks(\n",
"            model,\n",
"            toks,\n",
"            max_tokens_generated=max_tokens_generated,\n",
"            fwd_hooks=fwd_hooks,\n",
"        )\n",
"        generations.extend(generation)\n",
"\n",
"    return generations"
],
"metadata": {
"id": "94jRJDR0DRoY"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"## Finding the \"refusal direction\""
],
"metadata": {
"id": "W9O8dm0_EQRk"
}
},
{
"cell_type": "code",
"source": [
"N_INST_TRAIN = 32\n",
"\n",
"# token ids for the first N_INST_TRAIN training instructions of each set\n",
"harmful_toks = tokenize_instructions_fn(instructions=harmful_inst_train[:N_INST_TRAIN])\n",
"harmless_toks = tokenize_instructions_fn(instructions=harmless_inst_train[:N_INST_TRAIN])\n",
"\n",
"# forward pass over each batch, keeping only hook points whose name contains 'resid'\n",
"def _is_resid_hook(hook_name):\n",
"    return 'resid' in hook_name\n",
"\n",
"harmful_logits, harmful_cache = model.run_with_cache(harmful_toks, names_filter=_is_resid_hook)\n",
"harmless_logits, harmless_cache = model.run_with_cache(harmless_toks, names_filter=_is_resid_hook)"
],
"metadata": {
"id": "MbY79kSP8oOg"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# refusal direction: unit-normalized difference of the mean harmful and mean harmless\n",
"# resid_pre activation at one token position of one intermediate layer\n",
"\n",
"pos = -1  # last token position\n",
"layer = 14\n",
"\n",
"harmful_mean_act, harmless_mean_act = (\n",
"    cache['resid_pre', layer][:, pos, :].mean(dim=0)\n",
"    for cache in (harmful_cache, harmless_cache)\n",
")\n",
"\n",
"raw_dir = harmful_mean_act - harmless_mean_act\n",
"refusal_dir = raw_dir / raw_dir.norm()"
],
"metadata": {
"id": "tqD5E8Vc_w5d"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# drop the cached activations and logits, then release unused cached GPU memory\n",
"del harmful_logits, harmless_logits\n",
"del harmful_cache, harmless_cache\n",
"gc.collect()\n",
"torch.cuda.empty_cache()"
],
"metadata": {
"id": "NU9rjXPT4uQ_"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"## Ablate \"refusal direction\" via inference-time intervention\n",
"\n",
"Given a \"refusal direction\" $\\widehat{r} \\in \\mathbb{R}^{d_{\\text{model}}}$ with unit norm, we can ablate this direction from the model's activations $a_{l}$:\n",
"$${a}_{l}' \\leftarrow a_l - (a_l \\cdot \\widehat{r}) \\widehat{r}$$\n",
"\n",
"By performing this ablation on all intermediate activations, we enforce that the model can never express this direction (or \"feature\")."
],
"metadata": {
"id": "2EoxY5i1CWe3"
}
},
{
"cell_type": "code",
"source": [
"def direction_ablation_hook(\n",
"    activation: Float[Tensor, \"... d_act\"],\n",
"    hook: HookPoint,\n",
"    direction: Float[Tensor, \"d_act\"]\n",
"):\n",
"    \"\"\"Forward hook returning `activation` minus its component along `direction`.\n",
"\n",
"    Computes a - (a . d) d over the last axis, which is an exact projection only when `direction`\n",
"    has unit norm. `hook` is not used; it is accepted to match the hook-function signature.\n",
"    \"\"\"\n",
"    coeff = einops.einsum(activation, direction.view(-1, 1), '... d_act, d_act single -> ... single')\n",
"    return activation - coeff * direction"
],
"metadata": {
"id": "26rf-yncB2PT"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"N_INST_TEST = 32\n",
"intervention_dir = refusal_dir\n",
"intervention_layers = list(range(model.cfg.n_layers)) # all layers\n",
"\n",
"# one ablation hook at resid_pre, resid_mid and resid_post of every intervention layer\n",
"hook_fn = functools.partial(direction_ablation_hook, direction=intervention_dir)\n",
"fwd_hooks = [\n",
"    (utils.get_act_name(act_name, l), hook_fn)\n",
"    for l in intervention_layers\n",
"    for act_name in ('resid_pre', 'resid_mid', 'resid_post')\n",
"]\n",
"\n",
"# same test prompts with the hooks (intervention) and without them (baseline)\n",
"test_instructions = harmful_inst_test[:N_INST_TEST]\n",
"intervention_generations = get_generations(model, test_instructions, tokenize_instructions_fn, fwd_hooks=fwd_hooks)\n",
"baseline_generations = get_generations(model, test_instructions, tokenize_instructions_fn, fwd_hooks=[])"
],
"metadata": {
"id": "sR1G5bXoEDty"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# side-by-side view: each test instruction with its baseline and intervention completion\n",
"for i in range(N_INST_TEST):\n",
"    print(f\"INSTRUCTION {i}: {harmful_inst_test[i]!r}\")\n",
"    for color, label, generations in (\n",
"        (Fore.GREEN, \"BASELINE COMPLETION:\", baseline_generations),\n",
"        (Fore.RED, \"INTERVENTION COMPLETION:\", intervention_generations),\n",
"    ):\n",
"        print(color + label)\n",
"        print(textwrap.fill(repr(generations[i]), width=100, initial_indent='\\t', subsequent_indent='\\t'))\n",
"    print(Fore.RESET)"
],
"metadata": {
"id": "pxbJr4vCFCOL"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"## Orthogonalize weights w.r.t. \"refusal direction\"\n",
"\n",
"We can implement the intervention equivalently by directly orthogonalizing the weight matrices that write to the residual stream with respect to the refusal direction $\\widehat{r}$:\n",
"$$W_{\\text{out}}' \\leftarrow W_{\\text{out}} - \\widehat{r}\\widehat{r}^{\\mathsf{T}} W_{\\text{out}}$$\n",
"\n",
"By orthogonalizing these weight matrices, we enforce that the model is unable to write direction $\\widehat{r}$ to the residual stream at all!"
],
"metadata": {
"id": "t9KooaWaCDc_"
}
},
{
"cell_type": "code",
"source": [
"def get_orthogonalized_matrix(matrix: Float[Tensor, '... d_model'], vec: Float[Tensor, 'd_model']) -> Float[Tensor, '... d_model']:\n",
"    \"\"\"Return `matrix` with the component along `vec` removed from each vector on its last (d_model) axis.\n",
"\n",
"    Computes m - (m . v) v, an exact orthogonal projection only when `vec` has unit norm.\n",
"    \"\"\"\n",
"    coeff = einops.einsum(matrix, vec.view(-1, 1), '... d_model, d_model single -> ... single')\n",
"    return matrix - coeff * vec"
],
"metadata": {
"id": "8fhx0i9vCEou"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# Bake the ablation into the weights: orthogonalize every parameter below that writes into\n",
"# the residual stream (token embedding, attention output, MLP output) against refusal_dir.\n",
"model.W_E.data = get_orthogonalized_matrix(model.W_E, refusal_dir)\n",
"\n",
"for block in model.blocks:\n",
"    block.attn.W_O.data = get_orthogonalized_matrix(block.attn.W_O, refusal_dir)\n",
"    block.mlp.W_out.data = get_orthogonalized_matrix(block.mlp.W_out, refusal_dir)\n",
"    # The output biases are added to the residual stream as well; orthogonalize them too so the\n",
"    # direction is also removed for models whose biases are non-zero (a no-op for zero biases).\n",
"    block.attn.b_O.data = get_orthogonalized_matrix(block.attn.b_O, refusal_dir)\n",
"    block.mlp.b_out.data = get_orthogonalized_matrix(block.mlp.b_out, refusal_dir)"
],
"metadata": {
"id": "GC7cpMXZCG64"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# generations from the weight-orthogonalized model, with no hooks attached\n",
"orthogonalized_generations = get_generations(\n",
"    model,\n",
"    harmful_inst_test[:N_INST_TEST],\n",
"    tokenize_instructions_fn,\n",
"    fwd_hooks=[],\n",
")"
],
"metadata": {
"id": "1Y-qtouNGf3t"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# compare all three completions for each test instruction\n",
"for i in range(N_INST_TEST):\n",
"    print(f\"INSTRUCTION {i}: {harmful_inst_test[i]!r}\")\n",
"    for color, label, generations in (\n",
"        (Fore.GREEN, \"BASELINE COMPLETION:\", baseline_generations),\n",
"        (Fore.RED, \"INTERVENTION COMPLETION:\", intervention_generations),\n",
"        (Fore.MAGENTA, \"ORTHOGONALIZED COMPLETION:\", orthogonalized_generations),\n",
"    ):\n",
"        print(color + label)\n",
"        print(textwrap.fill(repr(generations[i]), width=100, initial_indent='\\t', subsequent_indent='\\t'))\n",
"    print(Fore.RESET)"
],
"metadata": {
"id": "r68O4_4DG3P7"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [],
"metadata": {
"id": "exUh3PEHRe9x"
},
"execution_count": null,
"outputs": []
}
]
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment