exploring CLIP token embeddings as used for Stable Diffusion inputs
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "27582118-8452-4e16-bbdf-d6bdb7e665d6",
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"\n",
"from pathlib import Path\n",
"\n",
"import diffusers, torch, transformers, tokenizers"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "3ea9f738-051b-41d1-b5a4-641554ba4ace",
"metadata": {},
"outputs": [],
"source": [
"MODEL_PATH = Path(os.environ['INVOKEAI_ROOT']) / 'models' / 'sdxl' / 'main' / 'stable-diffusion-xl-base-1-0'\n",
"assert MODEL_PATH.exists() and MODEL_PATH.is_dir()\n",
"\n",
"device = torch.device(\"cuda:0\")\n",
"torch.set_default_device(device)\n",
"cx = torch.inference_mode(True).__enter__()\n",
"\n",
"clip = transformers.CLIPTextModelWithProjection.from_pretrained(MODEL_PATH / 'text_encoder_2', use_safetensors=True)\n",
"embeddings = clip.get_input_embeddings()\n",
"positional_embeddings = clip.text_model.embeddings.position_embedding\n",
"\n",
"tokenizer = transformers.CLIPTokenizerFast.from_pretrained(MODEL_PATH / 'tokenizer_2', use_safetensors=True)"
]
},
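{
"cell_type": "code",
"execution_count": null,
"id": "added-embedding-shapes",
"metadata": {},
"outputs": [],
"source": [
"# [Added illustrative cell, not part of the original gist.]\n",
"# Sanity-check the two lookup tables pulled out of the text encoder: `embeddings`\n",
"# should have one row per vocabulary entry and `positional_embeddings` one row per\n",
"# position up to the model's context length, both with the 1280-dim hidden size.\n",
"display(embeddings.weight.shape, positional_embeddings.weight.shape)"
]
},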
{
"cell_type": "code",
"execution_count": 3,
"id": "9fea5e58-052a-4c75-bd17-b29200bbcb70",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([49406, 320, 2713, 2870, 3086, 633, 518, 7991, 6913, 281,\n",
" 1170, 970, 593, 8922, 11729, 269, 49407], device='cuda:0')"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"prompt = \"a classic oil painting from the dutch masters: still life with banana sushi.\"\n",
"token_ids = torch.tensor(tokenizer(prompt)['input_ids'])\n",
"display(token_ids)"
]
},
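{
"cell_type": "code",
"execution_count": null,
"id": "added-token-strings",
"metadata": {},
"outputs": [],
"source": [
"# [Added illustrative cell, not part of the original gist.]\n",
"# Map the ids back to token strings to see how the prompt was split up;\n",
"# 49406 and 49407 are the <|startoftext|> / <|endoftext|> special tokens.\n",
"display(tokenizer.convert_ids_to_tokens(token_ids.tolist()))"
]
},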
{
"cell_type": "code",
"execution_count": 4,
"id": "a3309e51-8688-4421-92d8-698a1f7f8382",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([17, 1280])"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"tensor([[-0.0030, 0.0038, 0.0004, ..., -0.0003, 0.0038, -0.0007],\n",
" [-0.0020, -0.0234, 0.0159, ..., -0.0003, -0.0149, 0.0073],\n",
" [-0.0136, -0.0032, -0.0109, ..., 0.0050, -0.0047, -0.0118],\n",
" ...,\n",
" [ 0.0173, 0.0183, -0.0243, ..., -0.0076, -0.0182, -0.0043],\n",
" [ 0.0042, 0.0105, -0.0081, ..., -0.0052, 0.0157, -0.0053],\n",
" [ 0.0024, 0.0029, 0.0049, ..., 0.0023, -0.0013, 0.0016]],\n",
" device='cuda:0', grad_fn=<EmbeddingBackward0>)"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"prompt_token_embeddings = embeddings(token_ids)\n",
"display(prompt_token_embeddings.shape, prompt_token_embeddings)"
]
},
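{
"cell_type": "code",
"execution_count": null,
"id": "added-weight-lookup",
"metadata": {},
"outputs": [],
"source": [
"# [Added illustrative cell, not part of the original gist.]\n",
"# The embedding layer is just a row lookup, so indexing the weight matrix with the\n",
"# token ids should reproduce the same [17, 1280] tensor exactly.\n",
"display(torch.equal(prompt_token_embeddings, embeddings.weight[token_ids]))"
]
},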
{
"cell_type": "code",
"execution_count": 5,
"id": "6e67c6e6-7af4-45b0-b467-11dbe96a3734",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]],\n",
" device='cuda:0')"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"torch.Size([1, 17, 1280])"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"tensor([[[-2.9507e-03, 3.7556e-03, 4.4179e-04, ..., -2.7585e-04,\n",
" 3.7937e-03, -7.4005e-04],\n",
" [-4.5052e-03, 4.9324e-03, -2.1343e-03, ..., -3.2425e-05,\n",
" 2.7370e-03, -2.0905e-03],\n",
" [-1.7939e-03, -5.5611e-05, -2.3365e-03, ..., 1.3590e-04,\n",
" 2.1152e-03, -7.2002e-04],\n",
" ...,\n",
" [-6.4888e-03, 1.2619e-02, 1.1116e-02, ..., 1.4324e-03,\n",
" 1.5526e-03, -4.5509e-03],\n",
" [-2.0618e-03, 8.5602e-03, 9.1782e-03, ..., 4.4751e-04,\n",
" -2.1400e-03, -5.3635e-03],\n",
" [-5.2986e-03, 5.6686e-03, 6.6223e-03, ..., 1.4315e-03,\n",
" 1.6890e-03, -4.6959e-03]]], device='cuda:0',\n",
" grad_fn=<EmbeddingBackward0>)"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"position_ids = clip.text_model.embeddings.position_ids[:, :token_ids.shape[-1]]\n",
"display(position_ids)\n",
"prompt_position_embeddings = positional_embeddings(position_ids)\n",
"display(prompt_position_embeddings.shape, prompt_position_embeddings)"
]
},
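{
"cell_type": "code",
"execution_count": null,
"id": "added-magnitude-compare",
"metadata": {},
"outputs": [],
"source": [
"# [Added illustrative cell, not part of the original gist.]\n",
"# For most token positions the positional embeddings are a few times smaller in\n",
"# magnitude than the token embeddings, which is what makes the reduced-precision\n",
"# question below interesting: the position signal sits in the low-order bits of the sum.\n",
"display(prompt_token_embeddings.abs().mean(), prompt_position_embeddings.abs().mean())"
]
},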
{
"cell_type": "code",
"execution_count": 6,
"id": "1553e9e6-f706-444f-b887-6c8f25eb0492",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([1, 17, 1280])"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"tensor([[[-0.0059, 0.0075, 0.0009, ..., -0.0006, 0.0076, -0.0015],\n",
" [-0.0065, -0.0185, 0.0137, ..., -0.0003, -0.0121, 0.0053],\n",
" [-0.0154, -0.0032, -0.0133, ..., 0.0051, -0.0026, -0.0125],\n",
" ...,\n",
" [ 0.0109, 0.0309, -0.0131, ..., -0.0062, -0.0166, -0.0089],\n",
" [ 0.0021, 0.0191, 0.0011, ..., -0.0048, 0.0136, -0.0107],\n",
" [-0.0029, 0.0086, 0.0116, ..., 0.0037, 0.0004, -0.0031]]],\n",
" device='cuda:0', grad_fn=<AddBackward0>)"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Use CLIPTextEmbeddings' complete forward method, which both gets the token embeddings and applies the positional embeddings to them.\n",
"prompt_combined_embeddings = clip.text_model.embeddings(token_ids)\n",
"display(prompt_combined_embeddings.shape, prompt_combined_embeddings)"
]
},
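{
"cell_type": "code",
"execution_count": null,
"id": "added-sum-check",
"metadata": {},
"outputs": [],
"source": [
"# [Added illustrative cell, not part of the original gist.]\n",
"# That forward pass should amount to token embeddings + positional embeddings,\n",
"# so the combined result ought to match the manual sum of the two pieces above.\n",
"display(torch.allclose(prompt_combined_embeddings, prompt_token_embeddings + prompt_position_embeddings))"
]
},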
{
"cell_type": "code",
"execution_count": 12,
"id": "4acdee70-e769-459d-96fe-cdb43399310a",
"metadata": {},
"outputs": [],
"source": [
"# reversing that addition in reduced precision to see how much of the positional information is intact\n",
"bf16_positions = prompt_combined_embeddings.to(dtype=torch.bfloat16) - prompt_token_embeddings.to(dtype=torch.bfloat16)\n",
"fp16_positions = prompt_combined_embeddings.to(dtype=torch.float16) - prompt_token_embeddings.to(dtype=torch.float16)"
]
},
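{
"cell_type": "code",
"execution_count": null,
"id": "added-residual-error",
"metadata": {},
"outputs": [],
"source": [
"# [Added illustrative cell, not part of the original gist.]\n",
"# Before looking at exponents and mantissas, a coarse check: how far do the\n",
"# low-precision residuals drift from the true float32 positional embeddings?\n",
"display((bf16_positions.to(torch.float32) - prompt_position_embeddings).abs().max(),\n",
"        (fp16_positions.to(torch.float32) - prompt_position_embeddings).abs().max())"
]
},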
{
"cell_type": "code",
"execution_count": 13,
"id": "ae908245-1dd8-43e8-b297-e7632c0cdd87",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[-0.0291, 2.8052, 3.0012, 2.8337, 2.8358, 2.9264, 2.6316, 2.7296,\n",
" 2.6183, 2.4568, 2.6012, 2.5254, 2.3077, 2.2098, 1.9971, 1.9277,\n",
" 0.5930]], device='cuda:0', grad_fn=<SubBackward0>)"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# trying to get an idea of how many bits the position info is shifted from the token info\n",
"prompt_token_embeddings.abs().mean(-1).log2() - bf16_positions.abs().mean(-1).log2()"
]
},
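{
"cell_type": "code",
"execution_count": null,
"id": "added-fp16-bit-shift",
"metadata": {},
"outputs": [],
"source": [
"# [Added illustrative cell, not part of the original gist.]\n",
"# The same exponent-gap estimate for the float16 residuals, for comparison with bfloat16.\n",
"prompt_token_embeddings.abs().mean(-1).log2() - fp16_positions.abs().mean(-1).log2()"
]
},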
{
"cell_type": "code",
"execution_count": 14,
"id": "6de74ec4-539d-4432-95e6-8f306ca83641",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[0.0041, 0.0696, nan, 0.0773, 0.0743, 0.0745, 0.0678, 0.0708, 0.0742,\n",
" 0.0604, 0.0655, 0.0639, 0.0599, 0.0488, 0.0423, 0.0431, 0.0181]],\n",
" device='cuda:0', grad_fn=<MeanBackward1>)"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"position_mantissas = prompt_position_embeddings.frexp().mantissa\n",
"bf16_position_mantissas = bf16_positions.to(dtype=torch.float32).frexp().mantissa # bf16 does not implement frexp itself\n",
"mantissa_error = (bf16_position_mantissas - position_mantissas) / position_mantissas\n",
"display(mantissa_error.abs().mean(-1))"
]
},
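{
"cell_type": "code",
"execution_count": null,
"id": "added-nan-diagnostic",
"metadata": {},
"outputs": [],
"source": [
"# [Added illustrative cell, not part of the original gist; a guess at the nan above.]\n",
"# The relative error is undefined wherever the reference mantissa is zero, i.e. where a\n",
"# positional-embedding value is exactly 0.0. Count such entries at each token position.\n",
"display((prompt_position_embeddings == 0).sum(-1))"
]
},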
{
"cell_type": "code",
"execution_count": 15,
"id": "8680e9dd-5bc0-4c2a-8c55-9b24bf624db1",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[8.7450e-06, 1.1561e-02, nan, 1.1818e-02, 1.1582e-02, 1.1504e-02,\n",
" 9.6999e-03, 1.2325e-02, 1.4067e-02, 1.2303e-02, 9.1884e-03, 1.1376e-02,\n",
" 9.2605e-03, 8.0006e-03, 9.0526e-03, 6.5079e-03, 2.7143e-03]],\n",
" device='cuda:0', grad_fn=<MeanBackward1>)"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"fp16_position_mantissas = fp16_positions.frexp().mantissa\n",
"fp16_mantissa_error = (fp16_position_mantissas - position_mantissas) / position_mantissas\n",
"display(fp16_mantissa_error.abs().mean(-1))"
]
},
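{
"cell_type": "code",
"execution_count": null,
"id": "added-error-summary",
"metadata": {},
"outputs": [],
"source": [
"# [Added illustrative cell, not part of the original gist.]\n",
"# Overall comparison of the two formats, averaging only over entries where both\n",
"# relative errors are finite (so the zero-mantissa positions don't dominate).\n",
"finite = mantissa_error.isfinite() & fp16_mantissa_error.isfinite()\n",
"display(mantissa_error[finite].abs().mean(), fp16_mantissa_error[finite].abs().mean())"
]
},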
{
"cell_type": "code",
"execution_count": null,
"id": "c29a0cce-ce2d-4041-9db7-2abd88184519",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "v311_invoke",
"language": "python",
"name": "v311_invoke"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.5"
}
},
"nbformat": 4,
"nbformat_minor": 5
}