exploring CLIP token embeddings as used for Stable Diffusion inputs
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "27582118-8452-4e16-bbdf-d6bdb7e665d6",
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"\n",
"from pathlib import Path\n",
"\n",
"import diffusers, torch, transformers, tokenizers"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "3ea9f738-051b-41d1-b5a4-641554ba4ace",
"metadata": {},
"outputs": [],
"source": [
"MODEL_PATH = Path(os.environ['INVOKEAI_ROOT']) / 'models' / 'sdxl' / 'main' / 'stable-diffusion-xl-base-1-0'\n",
"assert MODEL_PATH.exists() and MODEL_PATH.is_dir()\n",
"\n",
"device = torch.device(\"cuda:0\")\n",
"torch.set_default_device(device)\n",
"cx = torch.inference_mode(True).__enter__()\n",
"\n",
"clip = transformers.CLIPTextModelWithProjection.from_pretrained(MODEL_PATH / 'text_encoder_2', use_safetensors=True)\n",
"embeddings = clip.get_input_embeddings()\n",
"positional_embeddings = clip.text_model.embeddings.position_embedding\n",
"\n",
"tokenizer = transformers.CLIPTokenizerFast.from_pretrained(MODEL_PATH / 'tokenizer_2', use_safetensors=True)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "9fea5e58-052a-4c75-bd17-b29200bbcb70",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([49406, 320, 2713, 2870, 3086, 633, 518, 7991, 6913, 281,\n",
" 1170, 970, 593, 8922, 11729, 269, 49407], device='cuda:0')"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"prompt = \"a classic oil painting from the dutch masters: still life with banana sushi.\"\n",
"token_ids = torch.tensor(tokenizer(prompt)['input_ids'])\n",
"display(token_ids)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "a3309e51-8688-4421-92d8-698a1f7f8382",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([17, 1280])"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"tensor([[-0.0030, 0.0038, 0.0004, ..., -0.0003, 0.0038, -0.0007],\n",
" [-0.0020, -0.0234, 0.0159, ..., -0.0003, -0.0149, 0.0073],\n",
" [-0.0136, -0.0032, -0.0109, ..., 0.0050, -0.0047, -0.0118],\n",
" ...,\n",
" [ 0.0173, 0.0183, -0.0243, ..., -0.0076, -0.0182, -0.0043],\n",
" [ 0.0042, 0.0105, -0.0081, ..., -0.0052, 0.0157, -0.0053],\n",
" [ 0.0024, 0.0029, 0.0049, ..., 0.0023, -0.0013, 0.0016]],\n",
" device='cuda:0', grad_fn=<EmbeddingBackward0>)"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"prompt_token_embeddings = embeddings(token_ids)\n",
"display(prompt_token_embeddings.shape, prompt_token_embeddings)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "6e67c6e6-7af4-45b0-b467-11dbe96a3734",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]],\n",
" device='cuda:0')"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"torch.Size([1, 17, 1280])"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"tensor([[[-2.9507e-03, 3.7556e-03, 4.4179e-04, ..., -2.7585e-04,\n",
" 3.7937e-03, -7.4005e-04],\n",
" [-4.5052e-03, 4.9324e-03, -2.1343e-03, ..., -3.2425e-05,\n",
" 2.7370e-03, -2.0905e-03],\n",
" [-1.7939e-03, -5.5611e-05, -2.3365e-03, ..., 1.3590e-04,\n",
" 2.1152e-03, -7.2002e-04],\n",
" ...,\n",
" [-6.4888e-03, 1.2619e-02, 1.1116e-02, ..., 1.4324e-03,\n",
" 1.5526e-03, -4.5509e-03],\n",
" [-2.0618e-03, 8.5602e-03, 9.1782e-03, ..., 4.4751e-04,\n",
" -2.1400e-03, -5.3635e-03],\n",
" [-5.2986e-03, 5.6686e-03, 6.6223e-03, ..., 1.4315e-03,\n",
" 1.6890e-03, -4.6959e-03]]], device='cuda:0',\n",
" grad_fn=<EmbeddingBackward0>)"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"position_ids = clip.text_model.embeddings.position_ids[:, :token_ids.shape[-1]]\n",
"display(position_ids)\n",
"prompt_position_embeddings = positional_embeddings(position_ids)\n",
"display(prompt_position_embeddings.shape, prompt_position_embeddings)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "1553e9e6-f706-444f-b887-6c8f25eb0492",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([1, 17, 1280])"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"tensor([[[-0.0059, 0.0075, 0.0009, ..., -0.0006, 0.0076, -0.0015],\n",
" [-0.0065, -0.0185, 0.0137, ..., -0.0003, -0.0121, 0.0053],\n",
" [-0.0154, -0.0032, -0.0133, ..., 0.0051, -0.0026, -0.0125],\n",
" ...,\n",
" [ 0.0109, 0.0309, -0.0131, ..., -0.0062, -0.0166, -0.0089],\n",
" [ 0.0021, 0.0191, 0.0011, ..., -0.0048, 0.0136, -0.0107],\n",
" [-0.0029, 0.0086, 0.0116, ..., 0.0037, 0.0004, -0.0031]]],\n",
" device='cuda:0', grad_fn=<AddBackward0>)"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Use the CLIPTextEmbedding's complete forward method, which both gets the tokene embeddings and applies the positional embeddings to them.\n",
"prompt_combined_embeddings = clip.text_model.embeddings(token_ids)\n",
"display(prompt_combined_embeddings.shape, prompt_combined_embeddings)"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "4acdee70-e769-459d-96fe-cdb43399310a",
"metadata": {},
"outputs": [],
"source": [
"# reversing that addition in reduced precision to see how much of the positional information is intact\n",
"bf16_positions = prompt_combined_embeddings.to(dtype=torch.bfloat16) - prompt_token_embeddings.to(dtype=torch.bfloat16)\n",
"fp16_positions = prompt_combined_embeddings.to(dtype=torch.float16) - prompt_token_embeddings.to(dtype=torch.float16)"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "ae908245-1dd8-43e8-b297-e7632c0cdd87",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[-0.0291, 2.8052, 3.0012, 2.8337, 2.8358, 2.9264, 2.6316, 2.7296,\n",
" 2.6183, 2.4568, 2.6012, 2.5254, 2.3077, 2.2098, 1.9971, 1.9277,\n",
" 0.5930]], device='cuda:0', grad_fn=<SubBackward0>)"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# trying to get an idea of how many bits the position info is shifted from the token info\n",
"prompt_token_embeddings.abs().mean(-1).log2() - bf16_positions.abs().mean(-1).log2()"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "6de74ec4-539d-4432-95e6-8f306ca83641",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[0.0041, 0.0696, nan, 0.0773, 0.0743, 0.0745, 0.0678, 0.0708, 0.0742,\n",
" 0.0604, 0.0655, 0.0639, 0.0599, 0.0488, 0.0423, 0.0431, 0.0181]],\n",
" device='cuda:0', grad_fn=<MeanBackward1>)"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"position_mantissas = prompt_position_embeddings.frexp().mantissa\n",
"bf16_position_mantissas = bf16_positions.to(dtype=torch.float32).frexp().mantissa # bf16 does not implement frexp itself\n",
"mantissa_error = (bf16_position_mantissas - position_mantissas) / position_mantissas\n",
"display(mantissa_error.abs().mean(-1))"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "8680e9dd-5bc0-4c2a-8c55-9b24bf624db1",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[8.7450e-06, 1.1561e-02, nan, 1.1818e-02, 1.1582e-02, 1.1504e-02,\n",
" 9.6999e-03, 1.2325e-02, 1.4067e-02, 1.2303e-02, 9.1884e-03, 1.1376e-02,\n",
" 9.2605e-03, 8.0006e-03, 9.0526e-03, 6.5079e-03, 2.7143e-03]],\n",
" device='cuda:0', grad_fn=<MeanBackward1>)"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"fp16_position_mantissas = fp16_positions.frexp().mantissa\n",
"fp16_mantissa_error = (fp16_position_mantissas - position_mantissas) / position_mantissas\n",
"display(fp16_mantissa_error.abs().mean(-1))"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c29a0cce-ce2d-4041-9db7-2abd88184519",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "v311_invoke",
"language": "python",
"name": "v311_invoke"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.5"
}
},
"nbformat": 4,
"nbformat_minor": 5
}