Created
May 14, 2024 20:58
-
-
Save SethGA/9b842d93baf253f250a20271de38b77d to your computer and use it in GitHub Desktop.
EmoCLIP_keysTest.ipynb
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"nbformat": 4, | |
"nbformat_minor": 0, | |
"metadata": { | |
"colab": { | |
"provenance": [], | |
"gpuType": "T4", | |
"authorship_tag": "ABX9TyOGeHO89IJM4Mw5+xTMqGH/", | |
"include_colab_link": true | |
}, | |
"kernelspec": { | |
"name": "python3", | |
"display_name": "Python 3" | |
}, | |
"language_info": { | |
"name": "python" | |
}, | |
"accelerator": "GPU" | |
}, | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "view-in-github", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"<a href=\"https://colab.research.google.com/gist/SethGA/9b842d93baf253f250a20271de38b77d/emoclip_keystest.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"id": "WrTWw5qSo2V9" | |
}, | |
"outputs": [], | |
"source": [ | |
"# Colab setup: CLIP's tokenizer deps, CLIP itself, and einops (presumably\n", | |
"# needed by the local `architecture` package - TODO confirm and pin versions).\n", | |
"# %pip (not !pip) installs into the environment of the running kernel.\n", | |
"%pip install -q ftfy regex tqdm einops\n", | |
"%pip install -q git+https://github.com/openai/CLIP.git\n", | |
"\n", | |
"import numpy as np\n", | |
"import torch\n", | |
"import clip" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# Fall back to CPU so the notebook still runs on a runtime without a GPU.\n", | |
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", | |
"# The checkpoint is a plain state dict (OrderedDict of parameters), so the\n", | |
"# restricted unpickler is sufficient and avoids executing arbitrary pickled code.\n", | |
"emoclip_weights = torch.load('downstream.pth', map_location=device, weights_only=True)\n", | |
"print(emoclip_weights.keys())" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "EZd97jRDp0s-", | |
"outputId": "014d71c5-ddc1-4dce-9143-108c594b66dc" | |
}, | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"odict_keys(['module.logit_scale', 'module.backbone.positional_embedding', 'module.backbone.text_projection', 'module.backbone.logit_scale', 'module.backbone.visual.class_embedding', 'module.backbone.visual.positional_embedding', 'module.backbone.visual.proj', 'module.backbone.visual.conv1.weight', 'module.backbone.visual.ln_pre.weight', 'module.backbone.visual.ln_pre.bias', 'module.backbone.visual.transformer.resblocks.0.attn.in_proj_weight', 'module.backbone.visual.transformer.resblocks.0.attn.in_proj_bias', 'module.backbone.visual.transformer.resblocks.0.attn.out_proj.weight', 'module.backbone.visual.transformer.resblocks.0.attn.out_proj.bias', 'module.backbone.visual.transformer.resblocks.0.ln_1.weight', 'module.backbone.visual.transformer.resblocks.0.ln_1.bias', 'module.backbone.visual.transformer.resblocks.0.mlp.c_fc.weight', 'module.backbone.visual.transformer.resblocks.0.mlp.c_fc.bias', 'module.backbone.visual.transformer.resblocks.0.mlp.c_proj.weight', 'module.backbone.visual.transformer.resblocks.0.mlp.c_proj.bias', 'module.backbone.visual.transformer.resblocks.0.ln_2.weight', 'module.backbone.visual.transformer.resblocks.0.ln_2.bias', 'module.backbone.visual.transformer.resblocks.1.attn.in_proj_weight', 'module.backbone.visual.transformer.resblocks.1.attn.in_proj_bias', 'module.backbone.visual.transformer.resblocks.1.attn.out_proj.weight', 'module.backbone.visual.transformer.resblocks.1.attn.out_proj.bias', 'module.backbone.visual.transformer.resblocks.1.ln_1.weight', 'module.backbone.visual.transformer.resblocks.1.ln_1.bias', 'module.backbone.visual.transformer.resblocks.1.mlp.c_fc.weight', 'module.backbone.visual.transformer.resblocks.1.mlp.c_fc.bias', 'module.backbone.visual.transformer.resblocks.1.mlp.c_proj.weight', 'module.backbone.visual.transformer.resblocks.1.mlp.c_proj.bias', 'module.backbone.visual.transformer.resblocks.1.ln_2.weight', 'module.backbone.visual.transformer.resblocks.1.ln_2.bias', 
'module.backbone.visual.transformer.resblocks.2.attn.in_proj_weight', 'module.backbone.visual.transformer.resblocks.2.attn.in_proj_bias', 'module.backbone.visual.transformer.resblocks.2.attn.out_proj.weight', 'module.backbone.visual.transformer.resblocks.2.attn.out_proj.bias', 'module.backbone.visual.transformer.resblocks.2.ln_1.weight', 'module.backbone.visual.transformer.resblocks.2.ln_1.bias', 'module.backbone.visual.transformer.resblocks.2.mlp.c_fc.weight', 'module.backbone.visual.transformer.resblocks.2.mlp.c_fc.bias', 'module.backbone.visual.transformer.resblocks.2.mlp.c_proj.weight', 'module.backbone.visual.transformer.resblocks.2.mlp.c_proj.bias', 'module.backbone.visual.transformer.resblocks.2.ln_2.weight', 'module.backbone.visual.transformer.resblocks.2.ln_2.bias', 'module.backbone.visual.transformer.resblocks.3.attn.in_proj_weight', 'module.backbone.visual.transformer.resblocks.3.attn.in_proj_bias', 'module.backbone.visual.transformer.resblocks.3.attn.out_proj.weight', 'module.backbone.visual.transformer.resblocks.3.attn.out_proj.bias', 'module.backbone.visual.transformer.resblocks.3.ln_1.weight', 'module.backbone.visual.transformer.resblocks.3.ln_1.bias', 'module.backbone.visual.transformer.resblocks.3.mlp.c_fc.weight', 'module.backbone.visual.transformer.resblocks.3.mlp.c_fc.bias', 'module.backbone.visual.transformer.resblocks.3.mlp.c_proj.weight', 'module.backbone.visual.transformer.resblocks.3.mlp.c_proj.bias', 'module.backbone.visual.transformer.resblocks.3.ln_2.weight', 'module.backbone.visual.transformer.resblocks.3.ln_2.bias', 'module.backbone.visual.transformer.resblocks.4.attn.in_proj_weight', 'module.backbone.visual.transformer.resblocks.4.attn.in_proj_bias', 'module.backbone.visual.transformer.resblocks.4.attn.out_proj.weight', 'module.backbone.visual.transformer.resblocks.4.attn.out_proj.bias', 'module.backbone.visual.transformer.resblocks.4.ln_1.weight', 'module.backbone.visual.transformer.resblocks.4.ln_1.bias', 
'module.backbone.visual.transformer.resblocks.4.mlp.c_fc.weight', 'module.backbone.visual.transformer.resblocks.4.mlp.c_fc.bias', 'module.backbone.visual.transformer.resblocks.4.mlp.c_proj.weight', 'module.backbone.visual.transformer.resblocks.4.mlp.c_proj.bias', 'module.backbone.visual.transformer.resblocks.4.ln_2.weight', 'module.backbone.visual.transformer.resblocks.4.ln_2.bias', 'module.backbone.visual.transformer.resblocks.5.attn.in_proj_weight', 'module.backbone.visual.transformer.resblocks.5.attn.in_proj_bias', 'module.backbone.visual.transformer.resblocks.5.attn.out_proj.weight', 'module.backbone.visual.transformer.resblocks.5.attn.out_proj.bias', 'module.backbone.visual.transformer.resblocks.5.ln_1.weight', 'module.backbone.visual.transformer.resblocks.5.ln_1.bias', 'module.backbone.visual.transformer.resblocks.5.mlp.c_fc.weight', 'module.backbone.visual.transformer.resblocks.5.mlp.c_fc.bias', 'module.backbone.visual.transformer.resblocks.5.mlp.c_proj.weight', 'module.backbone.visual.transformer.resblocks.5.mlp.c_proj.bias', 'module.backbone.visual.transformer.resblocks.5.ln_2.weight', 'module.backbone.visual.transformer.resblocks.5.ln_2.bias', 'module.backbone.visual.transformer.resblocks.6.attn.in_proj_weight', 'module.backbone.visual.transformer.resblocks.6.attn.in_proj_bias', 'module.backbone.visual.transformer.resblocks.6.attn.out_proj.weight', 'module.backbone.visual.transformer.resblocks.6.attn.out_proj.bias', 'module.backbone.visual.transformer.resblocks.6.ln_1.weight', 'module.backbone.visual.transformer.resblocks.6.ln_1.bias', 'module.backbone.visual.transformer.resblocks.6.mlp.c_fc.weight', 'module.backbone.visual.transformer.resblocks.6.mlp.c_fc.bias', 'module.backbone.visual.transformer.resblocks.6.mlp.c_proj.weight', 'module.backbone.visual.transformer.resblocks.6.mlp.c_proj.bias', 'module.backbone.visual.transformer.resblocks.6.ln_2.weight', 'module.backbone.visual.transformer.resblocks.6.ln_2.bias', 
'module.backbone.visual.transformer.resblocks.7.attn.in_proj_weight', 'module.backbone.visual.transformer.resblocks.7.attn.in_proj_bias', 'module.backbone.visual.transformer.resblocks.7.attn.out_proj.weight', 'module.backbone.visual.transformer.resblocks.7.attn.out_proj.bias', 'module.backbone.visual.transformer.resblocks.7.ln_1.weight', 'module.backbone.visual.transformer.resblocks.7.ln_1.bias', 'module.backbone.visual.transformer.resblocks.7.mlp.c_fc.weight', 'module.backbone.visual.transformer.resblocks.7.mlp.c_fc.bias', 'module.backbone.visual.transformer.resblocks.7.mlp.c_proj.weight', 'module.backbone.visual.transformer.resblocks.7.mlp.c_proj.bias', 'module.backbone.visual.transformer.resblocks.7.ln_2.weight', 'module.backbone.visual.transformer.resblocks.7.ln_2.bias', 'module.backbone.visual.transformer.resblocks.8.attn.in_proj_weight', 'module.backbone.visual.transformer.resblocks.8.attn.in_proj_bias', 'module.backbone.visual.transformer.resblocks.8.attn.out_proj.weight', 'module.backbone.visual.transformer.resblocks.8.attn.out_proj.bias', 'module.backbone.visual.transformer.resblocks.8.ln_1.weight', 'module.backbone.visual.transformer.resblocks.8.ln_1.bias', 'module.backbone.visual.transformer.resblocks.8.mlp.c_fc.weight', 'module.backbone.visual.transformer.resblocks.8.mlp.c_fc.bias', 'module.backbone.visual.transformer.resblocks.8.mlp.c_proj.weight', 'module.backbone.visual.transformer.resblocks.8.mlp.c_proj.bias', 'module.backbone.visual.transformer.resblocks.8.ln_2.weight', 'module.backbone.visual.transformer.resblocks.8.ln_2.bias', 'module.backbone.visual.transformer.resblocks.9.attn.in_proj_weight', 'module.backbone.visual.transformer.resblocks.9.attn.in_proj_bias', 'module.backbone.visual.transformer.resblocks.9.attn.out_proj.weight', 'module.backbone.visual.transformer.resblocks.9.attn.out_proj.bias', 'module.backbone.visual.transformer.resblocks.9.ln_1.weight', 'module.backbone.visual.transformer.resblocks.9.ln_1.bias', 
'module.backbone.visual.transformer.resblocks.9.mlp.c_fc.weight', 'module.backbone.visual.transformer.resblocks.9.mlp.c_fc.bias', 'module.backbone.visual.transformer.resblocks.9.mlp.c_proj.weight', 'module.backbone.visual.transformer.resblocks.9.mlp.c_proj.bias', 'module.backbone.visual.transformer.resblocks.9.ln_2.weight', 'module.backbone.visual.transformer.resblocks.9.ln_2.bias', 'module.backbone.visual.transformer.resblocks.10.attn.in_proj_weight', 'module.backbone.visual.transformer.resblocks.10.attn.in_proj_bias', 'module.backbone.visual.transformer.resblocks.10.attn.out_proj.weight', 'module.backbone.visual.transformer.resblocks.10.attn.out_proj.bias', 'module.backbone.visual.transformer.resblocks.10.ln_1.weight', 'module.backbone.visual.transformer.resblocks.10.ln_1.bias', 'module.backbone.visual.transformer.resblocks.10.mlp.c_fc.weight', 'module.backbone.visual.transformer.resblocks.10.mlp.c_fc.bias', 'module.backbone.visual.transformer.resblocks.10.mlp.c_proj.weight', 'module.backbone.visual.transformer.resblocks.10.mlp.c_proj.bias', 'module.backbone.visual.transformer.resblocks.10.ln_2.weight', 'module.backbone.visual.transformer.resblocks.10.ln_2.bias', 'module.backbone.visual.transformer.resblocks.11.attn.in_proj_weight', 'module.backbone.visual.transformer.resblocks.11.attn.in_proj_bias', 'module.backbone.visual.transformer.resblocks.11.attn.out_proj.weight', 'module.backbone.visual.transformer.resblocks.11.attn.out_proj.bias', 'module.backbone.visual.transformer.resblocks.11.ln_1.weight', 'module.backbone.visual.transformer.resblocks.11.ln_1.bias', 'module.backbone.visual.transformer.resblocks.11.mlp.c_fc.weight', 'module.backbone.visual.transformer.resblocks.11.mlp.c_fc.bias', 'module.backbone.visual.transformer.resblocks.11.mlp.c_proj.weight', 'module.backbone.visual.transformer.resblocks.11.mlp.c_proj.bias', 'module.backbone.visual.transformer.resblocks.11.ln_2.weight', 'module.backbone.visual.transformer.resblocks.11.ln_2.bias', 
'module.backbone.visual.ln_post.weight', 'module.backbone.visual.ln_post.bias', 'module.backbone.transformer.resblocks.0.attn.in_proj_weight', 'module.backbone.transformer.resblocks.0.attn.in_proj_bias', 'module.backbone.transformer.resblocks.0.attn.out_proj.weight', 'module.backbone.transformer.resblocks.0.attn.out_proj.bias', 'module.backbone.transformer.resblocks.0.ln_1.weight', 'module.backbone.transformer.resblocks.0.ln_1.bias', 'module.backbone.transformer.resblocks.0.mlp.c_fc.weight', 'module.backbone.transformer.resblocks.0.mlp.c_fc.bias', 'module.backbone.transformer.resblocks.0.mlp.c_proj.weight', 'module.backbone.transformer.resblocks.0.mlp.c_proj.bias', 'module.backbone.transformer.resblocks.0.ln_2.weight', 'module.backbone.transformer.resblocks.0.ln_2.bias', 'module.backbone.transformer.resblocks.1.attn.in_proj_weight', 'module.backbone.transformer.resblocks.1.attn.in_proj_bias', 'module.backbone.transformer.resblocks.1.attn.out_proj.weight', 'module.backbone.transformer.resblocks.1.attn.out_proj.bias', 'module.backbone.transformer.resblocks.1.ln_1.weight', 'module.backbone.transformer.resblocks.1.ln_1.bias', 'module.backbone.transformer.resblocks.1.mlp.c_fc.weight', 'module.backbone.transformer.resblocks.1.mlp.c_fc.bias', 'module.backbone.transformer.resblocks.1.mlp.c_proj.weight', 'module.backbone.transformer.resblocks.1.mlp.c_proj.bias', 'module.backbone.transformer.resblocks.1.ln_2.weight', 'module.backbone.transformer.resblocks.1.ln_2.bias', 'module.backbone.transformer.resblocks.2.attn.in_proj_weight', 'module.backbone.transformer.resblocks.2.attn.in_proj_bias', 'module.backbone.transformer.resblocks.2.attn.out_proj.weight', 'module.backbone.transformer.resblocks.2.attn.out_proj.bias', 'module.backbone.transformer.resblocks.2.ln_1.weight', 'module.backbone.transformer.resblocks.2.ln_1.bias', 'module.backbone.transformer.resblocks.2.mlp.c_fc.weight', 'module.backbone.transformer.resblocks.2.mlp.c_fc.bias', 
'module.backbone.transformer.resblocks.2.mlp.c_proj.weight', 'module.backbone.transformer.resblocks.2.mlp.c_proj.bias', 'module.backbone.transformer.resblocks.2.ln_2.weight', 'module.backbone.transformer.resblocks.2.ln_2.bias', 'module.backbone.transformer.resblocks.3.attn.in_proj_weight', 'module.backbone.transformer.resblocks.3.attn.in_proj_bias', 'module.backbone.transformer.resblocks.3.attn.out_proj.weight', 'module.backbone.transformer.resblocks.3.attn.out_proj.bias', 'module.backbone.transformer.resblocks.3.ln_1.weight', 'module.backbone.transformer.resblocks.3.ln_1.bias', 'module.backbone.transformer.resblocks.3.mlp.c_fc.weight', 'module.backbone.transformer.resblocks.3.mlp.c_fc.bias', 'module.backbone.transformer.resblocks.3.mlp.c_proj.weight', 'module.backbone.transformer.resblocks.3.mlp.c_proj.bias', 'module.backbone.transformer.resblocks.3.ln_2.weight', 'module.backbone.transformer.resblocks.3.ln_2.bias', 'module.backbone.transformer.resblocks.4.attn.in_proj_weight', 'module.backbone.transformer.resblocks.4.attn.in_proj_bias', 'module.backbone.transformer.resblocks.4.attn.out_proj.weight', 'module.backbone.transformer.resblocks.4.attn.out_proj.bias', 'module.backbone.transformer.resblocks.4.ln_1.weight', 'module.backbone.transformer.resblocks.4.ln_1.bias', 'module.backbone.transformer.resblocks.4.mlp.c_fc.weight', 'module.backbone.transformer.resblocks.4.mlp.c_fc.bias', 'module.backbone.transformer.resblocks.4.mlp.c_proj.weight', 'module.backbone.transformer.resblocks.4.mlp.c_proj.bias', 'module.backbone.transformer.resblocks.4.ln_2.weight', 'module.backbone.transformer.resblocks.4.ln_2.bias', 'module.backbone.transformer.resblocks.5.attn.in_proj_weight', 'module.backbone.transformer.resblocks.5.attn.in_proj_bias', 'module.backbone.transformer.resblocks.5.attn.out_proj.weight', 'module.backbone.transformer.resblocks.5.attn.out_proj.bias', 'module.backbone.transformer.resblocks.5.ln_1.weight', 'module.backbone.transformer.resblocks.5.ln_1.bias', 
'module.backbone.transformer.resblocks.5.mlp.c_fc.weight', 'module.backbone.transformer.resblocks.5.mlp.c_fc.bias', 'module.backbone.transformer.resblocks.5.mlp.c_proj.weight', 'module.backbone.transformer.resblocks.5.mlp.c_proj.bias', 'module.backbone.transformer.resblocks.5.ln_2.weight', 'module.backbone.transformer.resblocks.5.ln_2.bias', 'module.backbone.transformer.resblocks.6.attn.in_proj_weight', 'module.backbone.transformer.resblocks.6.attn.in_proj_bias', 'module.backbone.transformer.resblocks.6.attn.out_proj.weight', 'module.backbone.transformer.resblocks.6.attn.out_proj.bias', 'module.backbone.transformer.resblocks.6.ln_1.weight', 'module.backbone.transformer.resblocks.6.ln_1.bias', 'module.backbone.transformer.resblocks.6.mlp.c_fc.weight', 'module.backbone.transformer.resblocks.6.mlp.c_fc.bias', 'module.backbone.transformer.resblocks.6.mlp.c_proj.weight', 'module.backbone.transformer.resblocks.6.mlp.c_proj.bias', 'module.backbone.transformer.resblocks.6.ln_2.weight', 'module.backbone.transformer.resblocks.6.ln_2.bias', 'module.backbone.transformer.resblocks.7.attn.in_proj_weight', 'module.backbone.transformer.resblocks.7.attn.in_proj_bias', 'module.backbone.transformer.resblocks.7.attn.out_proj.weight', 'module.backbone.transformer.resblocks.7.attn.out_proj.bias', 'module.backbone.transformer.resblocks.7.ln_1.weight', 'module.backbone.transformer.resblocks.7.ln_1.bias', 'module.backbone.transformer.resblocks.7.mlp.c_fc.weight', 'module.backbone.transformer.resblocks.7.mlp.c_fc.bias', 'module.backbone.transformer.resblocks.7.mlp.c_proj.weight', 'module.backbone.transformer.resblocks.7.mlp.c_proj.bias', 'module.backbone.transformer.resblocks.7.ln_2.weight', 'module.backbone.transformer.resblocks.7.ln_2.bias', 'module.backbone.transformer.resblocks.8.attn.in_proj_weight', 'module.backbone.transformer.resblocks.8.attn.in_proj_bias', 'module.backbone.transformer.resblocks.8.attn.out_proj.weight', 'module.backbone.transformer.resblocks.8.attn.out_proj.bias', 
'module.backbone.transformer.resblocks.8.ln_1.weight', 'module.backbone.transformer.resblocks.8.ln_1.bias', 'module.backbone.transformer.resblocks.8.mlp.c_fc.weight', 'module.backbone.transformer.resblocks.8.mlp.c_fc.bias', 'module.backbone.transformer.resblocks.8.mlp.c_proj.weight', 'module.backbone.transformer.resblocks.8.mlp.c_proj.bias', 'module.backbone.transformer.resblocks.8.ln_2.weight', 'module.backbone.transformer.resblocks.8.ln_2.bias', 'module.backbone.transformer.resblocks.9.attn.in_proj_weight', 'module.backbone.transformer.resblocks.9.attn.in_proj_bias', 'module.backbone.transformer.resblocks.9.attn.out_proj.weight', 'module.backbone.transformer.resblocks.9.attn.out_proj.bias', 'module.backbone.transformer.resblocks.9.ln_1.weight', 'module.backbone.transformer.resblocks.9.ln_1.bias', 'module.backbone.transformer.resblocks.9.mlp.c_fc.weight', 'module.backbone.transformer.resblocks.9.mlp.c_fc.bias', 'module.backbone.transformer.resblocks.9.mlp.c_proj.weight', 'module.backbone.transformer.resblocks.9.mlp.c_proj.bias', 'module.backbone.transformer.resblocks.9.ln_2.weight', 'module.backbone.transformer.resblocks.9.ln_2.bias', 'module.backbone.transformer.resblocks.10.attn.in_proj_weight', 'module.backbone.transformer.resblocks.10.attn.in_proj_bias', 'module.backbone.transformer.resblocks.10.attn.out_proj.weight', 'module.backbone.transformer.resblocks.10.attn.out_proj.bias', 'module.backbone.transformer.resblocks.10.ln_1.weight', 'module.backbone.transformer.resblocks.10.ln_1.bias', 'module.backbone.transformer.resblocks.10.mlp.c_fc.weight', 'module.backbone.transformer.resblocks.10.mlp.c_fc.bias', 'module.backbone.transformer.resblocks.10.mlp.c_proj.weight', 'module.backbone.transformer.resblocks.10.mlp.c_proj.bias', 'module.backbone.transformer.resblocks.10.ln_2.weight', 'module.backbone.transformer.resblocks.10.ln_2.bias', 'module.backbone.transformer.resblocks.11.attn.in_proj_weight', 'module.backbone.transformer.resblocks.11.attn.in_proj_bias', 
'module.backbone.transformer.resblocks.11.attn.out_proj.weight', 'module.backbone.transformer.resblocks.11.attn.out_proj.bias', 'module.backbone.transformer.resblocks.11.ln_1.weight', 'module.backbone.transformer.resblocks.11.ln_1.bias', 'module.backbone.transformer.resblocks.11.mlp.c_fc.weight', 'module.backbone.transformer.resblocks.11.mlp.c_fc.bias', 'module.backbone.transformer.resblocks.11.mlp.c_proj.weight', 'module.backbone.transformer.resblocks.11.mlp.c_proj.bias', 'module.backbone.transformer.resblocks.11.ln_2.weight', 'module.backbone.transformer.resblocks.11.ln_2.bias', 'module.backbone.token_embedding.weight', 'module.backbone.ln_final.weight', 'module.backbone.ln_final.bias', 'module.temporal.cls_token', 'module.temporal.pos_embedding', 'module.temporal.temporal_transformer.layers.0.0.fn.norm.weight', 'module.temporal.temporal_transformer.layers.0.0.fn.norm.bias', 'module.temporal.temporal_transformer.layers.0.0.fn.fn.to_qkv.weight', 'module.temporal.temporal_transformer.layers.0.0.fn.fn.to_out.0.weight', 'module.temporal.temporal_transformer.layers.0.0.fn.fn.to_out.0.bias', 'module.temporal.temporal_transformer.layers.0.1.fn.norm.weight', 'module.temporal.temporal_transformer.layers.0.1.fn.norm.bias', 'module.temporal.temporal_transformer.layers.0.1.fn.fn.net.0.weight', 'module.temporal.temporal_transformer.layers.0.1.fn.fn.net.0.bias', 'module.temporal.temporal_transformer.layers.0.1.fn.fn.net.3.weight', 'module.temporal.temporal_transformer.layers.0.1.fn.fn.net.3.bias', 'module.temporal.temporal_transformer.layers.1.0.fn.norm.weight', 'module.temporal.temporal_transformer.layers.1.0.fn.norm.bias', 'module.temporal.temporal_transformer.layers.1.0.fn.fn.to_qkv.weight', 'module.temporal.temporal_transformer.layers.1.0.fn.fn.to_out.0.weight', 'module.temporal.temporal_transformer.layers.1.0.fn.fn.to_out.0.bias', 'module.temporal.temporal_transformer.layers.1.1.fn.norm.weight', 'module.temporal.temporal_transformer.layers.1.1.fn.norm.bias', 
'module.temporal.temporal_transformer.layers.1.1.fn.fn.net.0.weight', 'module.temporal.temporal_transformer.layers.1.1.fn.fn.net.0.bias', 'module.temporal.temporal_transformer.layers.1.1.fn.fn.net.3.weight', 'module.temporal.temporal_transformer.layers.1.1.fn.fn.net.3.bias'])\n" | |
] | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# Clean state dict\n", | |
"def remove_module_prefix(state_dict, prefix='module.'):\n", | |
"    \"\"\"\n", | |
"    Return a new dict whose keys have one leading ``prefix`` removed.\n", | |
"\n", | |
"    A 'module.' prefix is what nn.DataParallel / DistributedDataParallel\n", | |
"    typically add to every key when a wrapped model is saved.\n", | |
"\n", | |
"    Args:\n", | |
"        state_dict: Mapping of parameter names to values (e.g. tensors).\n", | |
"        prefix: Leading string to strip. Keys that do not start with it\n", | |
"            are kept unchanged.\n", | |
"\n", | |
"    Returns:\n", | |
"        A new dict in the same key order, sharing the original values.\n", | |
"    \"\"\"\n", | |
"    # Strip only a *leading* prefix: str.replace would also rewrite keys\n", | |
"    # that merely contain it, e.g. 'submodule.w' -> 'subw'.\n", | |
"    return {\n", | |
"        (key[len(prefix):] if key.startswith(prefix) else key): value\n", | |
"        for key, value in state_dict.items()\n", | |
"    }\n", | |
"\n", | |
"# Every key in this checkpoint starts with 'module.' (see the key listing above).\n", | |
"cleaned_state_dict = remove_module_prefix(state_dict=emoclip_weights)" | |
], | |
"metadata": { | |
"id": "uQclE7o5p26G" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# NOTE(review): downstream.pth only contains temporal_transformer.layers.0-1,\n", | |
"# yet the load report in the next cell lists layers.2-3 as missing, so VClip()\n", | |
"# with default arguments appears to build a deeper temporal transformer than\n", | |
"# the checkpoint was trained with. TODO: construct VClip with the checkpoint's\n", | |
"# temporal depth (if its constructor exposes one) before trusting the model.\n", | |
"from architecture.video_clip import VClip\n", | |
"model = VClip()" | |
], | |
"metadata": { | |
"id": "DKCoIZP8p6mO" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# strict=False reports key mismatches instead of raising; any parameter listed\n", | |
"# as missing keeps whatever value VClip() initialised it with.\n", | |
"incompatible = model.load_state_dict(cleaned_state_dict, strict=False)\n", | |
"missing_keys, unexpected_keys = incompatible.missing_keys, incompatible.unexpected_keys\n", | |
"print(f\"Missing keys: {missing_keys}\")\n", | |
"print(f\"Unexpected keys: {unexpected_keys}\")" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "SKX-836Mp7Y-", | |
"outputId": "cc706a6e-7098-4930-d140-0b91c7bc9921" | |
}, | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"Missing keys: ['temporal.temporal_transformer.layers.2.0.fn.norm.weight', 'temporal.temporal_transformer.layers.2.0.fn.norm.bias', 'temporal.temporal_transformer.layers.2.0.fn.fn.to_qkv.weight', 'temporal.temporal_transformer.layers.2.0.fn.fn.to_out.0.weight', 'temporal.temporal_transformer.layers.2.0.fn.fn.to_out.0.bias', 'temporal.temporal_transformer.layers.2.1.fn.norm.weight', 'temporal.temporal_transformer.layers.2.1.fn.norm.bias', 'temporal.temporal_transformer.layers.2.1.fn.fn.net.0.weight', 'temporal.temporal_transformer.layers.2.1.fn.fn.net.0.bias', 'temporal.temporal_transformer.layers.2.1.fn.fn.net.3.weight', 'temporal.temporal_transformer.layers.2.1.fn.fn.net.3.bias', 'temporal.temporal_transformer.layers.3.0.fn.norm.weight', 'temporal.temporal_transformer.layers.3.0.fn.norm.bias', 'temporal.temporal_transformer.layers.3.0.fn.fn.to_qkv.weight', 'temporal.temporal_transformer.layers.3.0.fn.fn.to_out.0.weight', 'temporal.temporal_transformer.layers.3.0.fn.fn.to_out.0.bias', 'temporal.temporal_transformer.layers.3.1.fn.norm.weight', 'temporal.temporal_transformer.layers.3.1.fn.norm.bias', 'temporal.temporal_transformer.layers.3.1.fn.fn.net.0.weight', 'temporal.temporal_transformer.layers.3.1.fn.fn.net.0.bias', 'temporal.temporal_transformer.layers.3.1.fn.fn.net.3.weight', 'temporal.temporal_transformer.layers.3.1.fn.fn.net.3.bias']\n", | |
"Unexpected keys: []\n" | |
] | |
} | |
] | |
} | |
] | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment