Skip to content

Instantly share code, notes, and snippets.

@SethGA
Created May 14, 2024 20:58
Show Gist options
  • Save SethGA/9b842d93baf253f250a20271de38b77d to your computer and use it in GitHub Desktop.
Save SethGA/9b842d93baf253f250a20271de38b77d to your computer and use it in GitHub Desktop.
EmoCLIP_keysTest.ipynb
Display the source blob
Display the rendered blob
Raw
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"provenance": [],
"gpuType": "T4",
"authorship_tag": "ABX9TyOGeHO89IJM4Mw5+xTMqGH/",
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
},
"accelerator": "GPU"
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/SethGA/9b842d93baf253f250a20271de38b77d/emoclip_keystest.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "WrTWw5qSo2V9"
},
"outputs": [],
"source": [
"! pip install ftfy regex tqdm\n",
"! pip install git+https://github.com/openai/CLIP.git\n",
"! pip install einops\n",
"import numpy as np\n",
"import torch\n",
"from pkg_resources import packaging\n",
"import clip"
]
},
{
"cell_type": "code",
"source": [
"device = torch.device(\"cuda\")\n",
"emoclip_weights = torch.load('downstream.pth', map_location=device)\n",
"print(emoclip_weights.keys())"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "EZd97jRDp0s-",
"outputId": "014d71c5-ddc1-4dce-9143-108c594b66dc"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"odict_keys(['module.logit_scale', 'module.backbone.positional_embedding', 'module.backbone.text_projection', 'module.backbone.logit_scale', 'module.backbone.visual.class_embedding', 'module.backbone.visual.positional_embedding', 'module.backbone.visual.proj', 'module.backbone.visual.conv1.weight', 'module.backbone.visual.ln_pre.weight', 'module.backbone.visual.ln_pre.bias', 'module.backbone.visual.transformer.resblocks.0.attn.in_proj_weight', 'module.backbone.visual.transformer.resblocks.0.attn.in_proj_bias', 'module.backbone.visual.transformer.resblocks.0.attn.out_proj.weight', 'module.backbone.visual.transformer.resblocks.0.attn.out_proj.bias', 'module.backbone.visual.transformer.resblocks.0.ln_1.weight', 'module.backbone.visual.transformer.resblocks.0.ln_1.bias', 'module.backbone.visual.transformer.resblocks.0.mlp.c_fc.weight', 'module.backbone.visual.transformer.resblocks.0.mlp.c_fc.bias', 'module.backbone.visual.transformer.resblocks.0.mlp.c_proj.weight', 'module.backbone.visual.transformer.resblocks.0.mlp.c_proj.bias', 'module.backbone.visual.transformer.resblocks.0.ln_2.weight', 'module.backbone.visual.transformer.resblocks.0.ln_2.bias', 'module.backbone.visual.transformer.resblocks.1.attn.in_proj_weight', 'module.backbone.visual.transformer.resblocks.1.attn.in_proj_bias', 'module.backbone.visual.transformer.resblocks.1.attn.out_proj.weight', 'module.backbone.visual.transformer.resblocks.1.attn.out_proj.bias', 'module.backbone.visual.transformer.resblocks.1.ln_1.weight', 'module.backbone.visual.transformer.resblocks.1.ln_1.bias', 'module.backbone.visual.transformer.resblocks.1.mlp.c_fc.weight', 'module.backbone.visual.transformer.resblocks.1.mlp.c_fc.bias', 'module.backbone.visual.transformer.resblocks.1.mlp.c_proj.weight', 'module.backbone.visual.transformer.resblocks.1.mlp.c_proj.bias', 'module.backbone.visual.transformer.resblocks.1.ln_2.weight', 'module.backbone.visual.transformer.resblocks.1.ln_2.bias', 'module.backbone.visual.transformer.resblocks.2.attn.in_proj_weight', 'module.backbone.visual.transformer.resblocks.2.attn.in_proj_bias', 'module.backbone.visual.transformer.resblocks.2.attn.out_proj.weight', 'module.backbone.visual.transformer.resblocks.2.attn.out_proj.bias', 'module.backbone.visual.transformer.resblocks.2.ln_1.weight', 'module.backbone.visual.transformer.resblocks.2.ln_1.bias', 'module.backbone.visual.transformer.resblocks.2.mlp.c_fc.weight', 'module.backbone.visual.transformer.resblocks.2.mlp.c_fc.bias', 'module.backbone.visual.transformer.resblocks.2.mlp.c_proj.weight', 'module.backbone.visual.transformer.resblocks.2.mlp.c_proj.bias', 'module.backbone.visual.transformer.resblocks.2.ln_2.weight', 'module.backbone.visual.transformer.resblocks.2.ln_2.bias', 'module.backbone.visual.transformer.resblocks.3.attn.in_proj_weight', 'module.backbone.visual.transformer.resblocks.3.attn.in_proj_bias', 'module.backbone.visual.transformer.resblocks.3.attn.out_proj.weight', 'module.backbone.visual.transformer.resblocks.3.attn.out_proj.bias', 'module.backbone.visual.transformer.resblocks.3.ln_1.weight', 'module.backbone.visual.transformer.resblocks.3.ln_1.bias', 'module.backbone.visual.transformer.resblocks.3.mlp.c_fc.weight', 'module.backbone.visual.transformer.resblocks.3.mlp.c_fc.bias', 'module.backbone.visual.transformer.resblocks.3.mlp.c_proj.weight', 'module.backbone.visual.transformer.resblocks.3.mlp.c_proj.bias', 'module.backbone.visual.transformer.resblocks.3.ln_2.weight', 'module.backbone.visual.transformer.resblocks.3.ln_2.bias', 'module.backbone.visual.transformer.resblocks.4.attn.in_proj_weight', 'module.backbone.visual.transformer.resblocks.4.attn.in_proj_bias', 'module.backbone.visual.transformer.resblocks.4.attn.out_proj.weight', 'module.backbone.visual.transformer.resblocks.4.attn.out_proj.bias', 'module.backbone.visual.transformer.resblocks.4.ln_1.weight', 'module.backbone.visual.transformer.resblocks.4.ln_1.bias', 'module.backbone.visual.transformer.resblocks.4.mlp.c_fc.weight', 'module.backbone.visual.transformer.resblocks.4.mlp.c_fc.bias', 'module.backbone.visual.transformer.resblocks.4.mlp.c_proj.weight', 'module.backbone.visual.transformer.resblocks.4.mlp.c_proj.bias', 'module.backbone.visual.transformer.resblocks.4.ln_2.weight', 'module.backbone.visual.transformer.resblocks.4.ln_2.bias', 'module.backbone.visual.transformer.resblocks.5.attn.in_proj_weight', 'module.backbone.visual.transformer.resblocks.5.attn.in_proj_bias', 'module.backbone.visual.transformer.resblocks.5.attn.out_proj.weight', 'module.backbone.visual.transformer.resblocks.5.attn.out_proj.bias', 'module.backbone.visual.transformer.resblocks.5.ln_1.weight', 'module.backbone.visual.transformer.resblocks.5.ln_1.bias', 'module.backbone.visual.transformer.resblocks.5.mlp.c_fc.weight', 'module.backbone.visual.transformer.resblocks.5.mlp.c_fc.bias', 'module.backbone.visual.transformer.resblocks.5.mlp.c_proj.weight', 'module.backbone.visual.transformer.resblocks.5.mlp.c_proj.bias', 'module.backbone.visual.transformer.resblocks.5.ln_2.weight', 'module.backbone.visual.transformer.resblocks.5.ln_2.bias', 'module.backbone.visual.transformer.resblocks.6.attn.in_proj_weight', 'module.backbone.visual.transformer.resblocks.6.attn.in_proj_bias', 'module.backbone.visual.transformer.resblocks.6.attn.out_proj.weight', 'module.backbone.visual.transformer.resblocks.6.attn.out_proj.bias', 'module.backbone.visual.transformer.resblocks.6.ln_1.weight', 'module.backbone.visual.transformer.resblocks.6.ln_1.bias', 'module.backbone.visual.transformer.resblocks.6.mlp.c_fc.weight', 'module.backbone.visual.transformer.resblocks.6.mlp.c_fc.bias', 'module.backbone.visual.transformer.resblocks.6.mlp.c_proj.weight', 'module.backbone.visual.transformer.resblocks.6.mlp.c_proj.bias', 'module.backbone.visual.transformer.resblocks.6.ln_2.weight', 'module.backbone.visual.transformer.resblocks.6.ln_2.bias', 'module.backbone.visual.transformer.resblocks.7.attn.in_proj_weight', 'module.backbone.visual.transformer.resblocks.7.attn.in_proj_bias', 'module.backbone.visual.transformer.resblocks.7.attn.out_proj.weight', 'module.backbone.visual.transformer.resblocks.7.attn.out_proj.bias', 'module.backbone.visual.transformer.resblocks.7.ln_1.weight', 'module.backbone.visual.transformer.resblocks.7.ln_1.bias', 'module.backbone.visual.transformer.resblocks.7.mlp.c_fc.weight', 'module.backbone.visual.transformer.resblocks.7.mlp.c_fc.bias', 'module.backbone.visual.transformer.resblocks.7.mlp.c_proj.weight', 'module.backbone.visual.transformer.resblocks.7.mlp.c_proj.bias', 'module.backbone.visual.transformer.resblocks.7.ln_2.weight', 'module.backbone.visual.transformer.resblocks.7.ln_2.bias', 'module.backbone.visual.transformer.resblocks.8.attn.in_proj_weight', 'module.backbone.visual.transformer.resblocks.8.attn.in_proj_bias', 'module.backbone.visual.transformer.resblocks.8.attn.out_proj.weight', 'module.backbone.visual.transformer.resblocks.8.attn.out_proj.bias', 'module.backbone.visual.transformer.resblocks.8.ln_1.weight', 'module.backbone.visual.transformer.resblocks.8.ln_1.bias', 'module.backbone.visual.transformer.resblocks.8.mlp.c_fc.weight', 'module.backbone.visual.transformer.resblocks.8.mlp.c_fc.bias', 'module.backbone.visual.transformer.resblocks.8.mlp.c_proj.weight', 'module.backbone.visual.transformer.resblocks.8.mlp.c_proj.bias', 'module.backbone.visual.transformer.resblocks.8.ln_2.weight', 'module.backbone.visual.transformer.resblocks.8.ln_2.bias', 'module.backbone.visual.transformer.resblocks.9.attn.in_proj_weight', 'module.backbone.visual.transformer.resblocks.9.attn.in_proj_bias', 'module.backbone.visual.transformer.resblocks.9.attn.out_proj.weight', 'module.backbone.visual.transformer.resblocks.9.attn.out_proj.bias', 'module.backbone.visual.transformer.resblocks.9.ln_1.weight', 'module.backbone.visual.transformer.resblocks.9.ln_1.bias', 'module.backbone.visual.transformer.resblocks.9.mlp.c_fc.weight', 'module.backbone.visual.transformer.resblocks.9.mlp.c_fc.bias', 'module.backbone.visual.transformer.resblocks.9.mlp.c_proj.weight', 'module.backbone.visual.transformer.resblocks.9.mlp.c_proj.bias', 'module.backbone.visual.transformer.resblocks.9.ln_2.weight', 'module.backbone.visual.transformer.resblocks.9.ln_2.bias', 'module.backbone.visual.transformer.resblocks.10.attn.in_proj_weight', 'module.backbone.visual.transformer.resblocks.10.attn.in_proj_bias', 'module.backbone.visual.transformer.resblocks.10.attn.out_proj.weight', 'module.backbone.visual.transformer.resblocks.10.attn.out_proj.bias', 'module.backbone.visual.transformer.resblocks.10.ln_1.weight', 'module.backbone.visual.transformer.resblocks.10.ln_1.bias', 'module.backbone.visual.transformer.resblocks.10.mlp.c_fc.weight', 'module.backbone.visual.transformer.resblocks.10.mlp.c_fc.bias', 'module.backbone.visual.transformer.resblocks.10.mlp.c_proj.weight', 'module.backbone.visual.transformer.resblocks.10.mlp.c_proj.bias', 'module.backbone.visual.transformer.resblocks.10.ln_2.weight', 'module.backbone.visual.transformer.resblocks.10.ln_2.bias', 'module.backbone.visual.transformer.resblocks.11.attn.in_proj_weight', 'module.backbone.visual.transformer.resblocks.11.attn.in_proj_bias', 'module.backbone.visual.transformer.resblocks.11.attn.out_proj.weight', 'module.backbone.visual.transformer.resblocks.11.attn.out_proj.bias', 'module.backbone.visual.transformer.resblocks.11.ln_1.weight', 'module.backbone.visual.transformer.resblocks.11.ln_1.bias', 'module.backbone.visual.transformer.resblocks.11.mlp.c_fc.weight', 'module.backbone.visual.transformer.resblocks.11.mlp.c_fc.bias', 'module.backbone.visual.transformer.resblocks.11.mlp.c_proj.weight', 'module.backbone.visual.transformer.resblocks.11.mlp.c_proj.bias', 'module.backbone.visual.transformer.resblocks.11.ln_2.weight', 'module.backbone.visual.transformer.resblocks.11.ln_2.bias', 'module.backbone.visual.ln_post.weight', 'module.backbone.visual.ln_post.bias', 'module.backbone.transformer.resblocks.0.attn.in_proj_weight', 'module.backbone.transformer.resblocks.0.attn.in_proj_bias', 'module.backbone.transformer.resblocks.0.attn.out_proj.weight', 'module.backbone.transformer.resblocks.0.attn.out_proj.bias', 'module.backbone.transformer.resblocks.0.ln_1.weight', 'module.backbone.transformer.resblocks.0.ln_1.bias', 'module.backbone.transformer.resblocks.0.mlp.c_fc.weight', 'module.backbone.transformer.resblocks.0.mlp.c_fc.bias', 'module.backbone.transformer.resblocks.0.mlp.c_proj.weight', 'module.backbone.transformer.resblocks.0.mlp.c_proj.bias', 'module.backbone.transformer.resblocks.0.ln_2.weight', 'module.backbone.transformer.resblocks.0.ln_2.bias', 'module.backbone.transformer.resblocks.1.attn.in_proj_weight', 'module.backbone.transformer.resblocks.1.attn.in_proj_bias', 'module.backbone.transformer.resblocks.1.attn.out_proj.weight', 'module.backbone.transformer.resblocks.1.attn.out_proj.bias', 'module.backbone.transformer.resblocks.1.ln_1.weight', 'module.backbone.transformer.resblocks.1.ln_1.bias', 'module.backbone.transformer.resblocks.1.mlp.c_fc.weight', 'module.backbone.transformer.resblocks.1.mlp.c_fc.bias', 'module.backbone.transformer.resblocks.1.mlp.c_proj.weight', 'module.backbone.transformer.resblocks.1.mlp.c_proj.bias', 'module.backbone.transformer.resblocks.1.ln_2.weight', 'module.backbone.transformer.resblocks.1.ln_2.bias', 'module.backbone.transformer.resblocks.2.attn.in_proj_weight', 'module.backbone.transformer.resblocks.2.attn.in_proj_bias', 'module.backbone.transformer.resblocks.2.attn.out_proj.weight', 'module.backbone.transformer.resblocks.2.attn.out_proj.bias', 'module.backbone.transformer.resblocks.2.ln_1.weight', 'module.backbone.transformer.resblocks.2.ln_1.bias', 'module.backbone.transformer.resblocks.2.mlp.c_fc.weight', 'module.backbone.transformer.resblocks.2.mlp.c_fc.bias', 'module.backbone.transformer.resblocks.2.mlp.c_proj.weight', 'module.backbone.transformer.resblocks.2.mlp.c_proj.bias', 'module.backbone.transformer.resblocks.2.ln_2.weight', 'module.backbone.transformer.resblocks.2.ln_2.bias', 'module.backbone.transformer.resblocks.3.attn.in_proj_weight', 'module.backbone.transformer.resblocks.3.attn.in_proj_bias', 'module.backbone.transformer.resblocks.3.attn.out_proj.weight', 'module.backbone.transformer.resblocks.3.attn.out_proj.bias', 'module.backbone.transformer.resblocks.3.ln_1.weight', 'module.backbone.transformer.resblocks.3.ln_1.bias', 'module.backbone.transformer.resblocks.3.mlp.c_fc.weight', 'module.backbone.transformer.resblocks.3.mlp.c_fc.bias', 'module.backbone.transformer.resblocks.3.mlp.c_proj.weight', 'module.backbone.transformer.resblocks.3.mlp.c_proj.bias', 'module.backbone.transformer.resblocks.3.ln_2.weight', 'module.backbone.transformer.resblocks.3.ln_2.bias', 'module.backbone.transformer.resblocks.4.attn.in_proj_weight', 'module.backbone.transformer.resblocks.4.attn.in_proj_bias', 'module.backbone.transformer.resblocks.4.attn.out_proj.weight', 'module.backbone.transformer.resblocks.4.attn.out_proj.bias', 'module.backbone.transformer.resblocks.4.ln_1.weight', 'module.backbone.transformer.resblocks.4.ln_1.bias', 'module.backbone.transformer.resblocks.4.mlp.c_fc.weight', 'module.backbone.transformer.resblocks.4.mlp.c_fc.bias', 'module.backbone.transformer.resblocks.4.mlp.c_proj.weight', 'module.backbone.transformer.resblocks.4.mlp.c_proj.bias', 'module.backbone.transformer.resblocks.4.ln_2.weight', 'module.backbone.transformer.resblocks.4.ln_2.bias', 'module.backbone.transformer.resblocks.5.attn.in_proj_weight', 'module.backbone.transformer.resblocks.5.attn.in_proj_bias', 'module.backbone.transformer.resblocks.5.attn.out_proj.weight', 'module.backbone.transformer.resblocks.5.attn.out_proj.bias', 'module.backbone.transformer.resblocks.5.ln_1.weight', 'module.backbone.transformer.resblocks.5.ln_1.bias', 'module.backbone.transformer.resblocks.5.mlp.c_fc.weight', 'module.backbone.transformer.resblocks.5.mlp.c_fc.bias', 'module.backbone.transformer.resblocks.5.mlp.c_proj.weight', 'module.backbone.transformer.resblocks.5.mlp.c_proj.bias', 'module.backbone.transformer.resblocks.5.ln_2.weight', 'module.backbone.transformer.resblocks.5.ln_2.bias', 'module.backbone.transformer.resblocks.6.attn.in_proj_weight', 'module.backbone.transformer.resblocks.6.attn.in_proj_bias', 'module.backbone.transformer.resblocks.6.attn.out_proj.weight', 'module.backbone.transformer.resblocks.6.attn.out_proj.bias', 'module.backbone.transformer.resblocks.6.ln_1.weight', 'module.backbone.transformer.resblocks.6.ln_1.bias', 'module.backbone.transformer.resblocks.6.mlp.c_fc.weight', 'module.backbone.transformer.resblocks.6.mlp.c_fc.bias', 'module.backbone.transformer.resblocks.6.mlp.c_proj.weight', 'module.backbone.transformer.resblocks.6.mlp.c_proj.bias', 'module.backbone.transformer.resblocks.6.ln_2.weight', 'module.backbone.transformer.resblocks.6.ln_2.bias', 'module.backbone.transformer.resblocks.7.attn.in_proj_weight', 'module.backbone.transformer.resblocks.7.attn.in_proj_bias', 'module.backbone.transformer.resblocks.7.attn.out_proj.weight', 'module.backbone.transformer.resblocks.7.attn.out_proj.bias', 'module.backbone.transformer.resblocks.7.ln_1.weight', 'module.backbone.transformer.resblocks.7.ln_1.bias', 'module.backbone.transformer.resblocks.7.mlp.c_fc.weight', 'module.backbone.transformer.resblocks.7.mlp.c_fc.bias', 'module.backbone.transformer.resblocks.7.mlp.c_proj.weight', 'module.backbone.transformer.resblocks.7.mlp.c_proj.bias', 'module.backbone.transformer.resblocks.7.ln_2.weight', 'module.backbone.transformer.resblocks.7.ln_2.bias', 'module.backbone.transformer.resblocks.8.attn.in_proj_weight', 'module.backbone.transformer.resblocks.8.attn.in_proj_bias', 'module.backbone.transformer.resblocks.8.attn.out_proj.weight', 'module.backbone.transformer.resblocks.8.attn.out_proj.bias', 'module.backbone.transformer.resblocks.8.ln_1.weight', 'module.backbone.transformer.resblocks.8.ln_1.bias', 'module.backbone.transformer.resblocks.8.mlp.c_fc.weight', 'module.backbone.transformer.resblocks.8.mlp.c_fc.bias', 'module.backbone.transformer.resblocks.8.mlp.c_proj.weight', 'module.backbone.transformer.resblocks.8.mlp.c_proj.bias', 'module.backbone.transformer.resblocks.8.ln_2.weight', 'module.backbone.transformer.resblocks.8.ln_2.bias', 'module.backbone.transformer.resblocks.9.attn.in_proj_weight', 'module.backbone.transformer.resblocks.9.attn.in_proj_bias', 'module.backbone.transformer.resblocks.9.attn.out_proj.weight', 'module.backbone.transformer.resblocks.9.attn.out_proj.bias', 'module.backbone.transformer.resblocks.9.ln_1.weight', 'module.backbone.transformer.resblocks.9.ln_1.bias', 'module.backbone.transformer.resblocks.9.mlp.c_fc.weight', 'module.backbone.transformer.resblocks.9.mlp.c_fc.bias', 'module.backbone.transformer.resblocks.9.mlp.c_proj.weight', 'module.backbone.transformer.resblocks.9.mlp.c_proj.bias', 'module.backbone.transformer.resblocks.9.ln_2.weight', 'module.backbone.transformer.resblocks.9.ln_2.bias', 'module.backbone.transformer.resblocks.10.attn.in_proj_weight', 'module.backbone.transformer.resblocks.10.attn.in_proj_bias', 'module.backbone.transformer.resblocks.10.attn.out_proj.weight', 'module.backbone.transformer.resblocks.10.attn.out_proj.bias', 'module.backbone.transformer.resblocks.10.ln_1.weight', 'module.backbone.transformer.resblocks.10.ln_1.bias', 'module.backbone.transformer.resblocks.10.mlp.c_fc.weight', 'module.backbone.transformer.resblocks.10.mlp.c_fc.bias', 'module.backbone.transformer.resblocks.10.mlp.c_proj.weight', 'module.backbone.transformer.resblocks.10.mlp.c_proj.bias', 'module.backbone.transformer.resblocks.10.ln_2.weight', 'module.backbone.transformer.resblocks.10.ln_2.bias', 'module.backbone.transformer.resblocks.11.attn.in_proj_weight', 'module.backbone.transformer.resblocks.11.attn.in_proj_bias', 'module.backbone.transformer.resblocks.11.attn.out_proj.weight', 'module.backbone.transformer.resblocks.11.attn.out_proj.bias', 'module.backbone.transformer.resblocks.11.ln_1.weight', 'module.backbone.transformer.resblocks.11.ln_1.bias', 'module.backbone.transformer.resblocks.11.mlp.c_fc.weight', 'module.backbone.transformer.resblocks.11.mlp.c_fc.bias', 'module.backbone.transformer.resblocks.11.mlp.c_proj.weight', 'module.backbone.transformer.resblocks.11.mlp.c_proj.bias', 'module.backbone.transformer.resblocks.11.ln_2.weight', 'module.backbone.transformer.resblocks.11.ln_2.bias', 'module.backbone.token_embedding.weight', 'module.backbone.ln_final.weight', 'module.backbone.ln_final.bias', 'module.temporal.cls_token', 'module.temporal.pos_embedding', 'module.temporal.temporal_transformer.layers.0.0.fn.norm.weight', 'module.temporal.temporal_transformer.layers.0.0.fn.norm.bias', 'module.temporal.temporal_transformer.layers.0.0.fn.fn.to_qkv.weight', 'module.temporal.temporal_transformer.layers.0.0.fn.fn.to_out.0.weight', 'module.temporal.temporal_transformer.layers.0.0.fn.fn.to_out.0.bias', 'module.temporal.temporal_transformer.layers.0.1.fn.norm.weight', 'module.temporal.temporal_transformer.layers.0.1.fn.norm.bias', 'module.temporal.temporal_transformer.layers.0.1.fn.fn.net.0.weight', 'module.temporal.temporal_transformer.layers.0.1.fn.fn.net.0.bias', 'module.temporal.temporal_transformer.layers.0.1.fn.fn.net.3.weight', 'module.temporal.temporal_transformer.layers.0.1.fn.fn.net.3.bias', 'module.temporal.temporal_transformer.layers.1.0.fn.norm.weight', 'module.temporal.temporal_transformer.layers.1.0.fn.norm.bias', 'module.temporal.temporal_transformer.layers.1.0.fn.fn.to_qkv.weight', 'module.temporal.temporal_transformer.layers.1.0.fn.fn.to_out.0.weight', 'module.temporal.temporal_transformer.layers.1.0.fn.fn.to_out.0.bias', 'module.temporal.temporal_transformer.layers.1.1.fn.norm.weight', 'module.temporal.temporal_transformer.layers.1.1.fn.norm.bias', 'module.temporal.temporal_transformer.layers.1.1.fn.fn.net.0.weight', 'module.temporal.temporal_transformer.layers.1.1.fn.fn.net.0.bias', 'module.temporal.temporal_transformer.layers.1.1.fn.fn.net.3.weight', 'module.temporal.temporal_transformer.layers.1.1.fn.fn.net.3.bias'])\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"# Clean state dict\n",
"def remove_module_prefix(state_dict):\n",
" \"\"\"\n",
" Removes the 'module.' prefix from state dictionary keys.\n",
" \"\"\"\n",
" new_state_dict = {}\n",
" for k, v in state_dict.items():\n",
" new_key = k.replace('module.', '')\n",
" new_state_dict[new_key] = v\n",
" return new_state_dict\n",
"\n",
"cleaned_state_dict = remove_module_prefix(emoclip_weights)"
],
"metadata": {
"id": "uQclE7o5p26G"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"from architecture.video_clip import VClip\n",
"model = VClip()"
],
"metadata": {
"id": "DKCoIZP8p6mO"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"missing_keys, unexpected_keys = model.load_state_dict(cleaned_state_dict, strict=False)\n",
"print(\"Missing keys: \" + str(missing_keys))\n",
"print(\"Unexpected keys: \" + str(unexpected_keys))"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "SKX-836Mp7Y-",
"outputId": "cc706a6e-7098-4930-d140-0b91c7bc9921"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Missing keys: ['temporal.temporal_transformer.layers.2.0.fn.norm.weight', 'temporal.temporal_transformer.layers.2.0.fn.norm.bias', 'temporal.temporal_transformer.layers.2.0.fn.fn.to_qkv.weight', 'temporal.temporal_transformer.layers.2.0.fn.fn.to_out.0.weight', 'temporal.temporal_transformer.layers.2.0.fn.fn.to_out.0.bias', 'temporal.temporal_transformer.layers.2.1.fn.norm.weight', 'temporal.temporal_transformer.layers.2.1.fn.norm.bias', 'temporal.temporal_transformer.layers.2.1.fn.fn.net.0.weight', 'temporal.temporal_transformer.layers.2.1.fn.fn.net.0.bias', 'temporal.temporal_transformer.layers.2.1.fn.fn.net.3.weight', 'temporal.temporal_transformer.layers.2.1.fn.fn.net.3.bias', 'temporal.temporal_transformer.layers.3.0.fn.norm.weight', 'temporal.temporal_transformer.layers.3.0.fn.norm.bias', 'temporal.temporal_transformer.layers.3.0.fn.fn.to_qkv.weight', 'temporal.temporal_transformer.layers.3.0.fn.fn.to_out.0.weight', 'temporal.temporal_transformer.layers.3.0.fn.fn.to_out.0.bias', 'temporal.temporal_transformer.layers.3.1.fn.norm.weight', 'temporal.temporal_transformer.layers.3.1.fn.norm.bias', 'temporal.temporal_transformer.layers.3.1.fn.fn.net.0.weight', 'temporal.temporal_transformer.layers.3.1.fn.fn.net.0.bias', 'temporal.temporal_transformer.layers.3.1.fn.fn.net.3.weight', 'temporal.temporal_transformer.layers.3.1.fn.fn.net.3.bias']\n",
"Unexpected keys: []\n"
]
}
]
}
]
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment