Created
March 13, 2024 12:03
-
-
Save sayakpaul/83fe388cd4795c01f60a765392ed455d to your computer and use it in GitHub Desktop.
This notebook shows how to use a custom `transformers` model within a Stable Diffusion pipeline.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"id": "41843243-1fdd-4037-a433-24537e154b0f", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"%load_ext autoreload\n", | |
"%autoreload 2" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"id": "02ccbacc-e631-47c0-8ce8-66632409c9d2", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"/Users/sayakpaul/Downloads/custom-pipeline\n" | |
] | |
} | |
], | |
"source": [ | |
"!mkdir custom-pipeline\n", | |
"%cd custom-pipeline" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"id": "111d777a-4976-4ec9-a56f-6fc9c87fdf30", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"from transformers import PretrainedConfig, PreTrainedModel\n", | |
"import torch\n", | |
"\n", | |
"class TestConfig(PretrainedConfig):\n", | |
"\n", | |
" model_type=\"test_model\"\n", | |
" \n", | |
" def __init__(self, d_model: int=None, n_class: int=None, bias: bool=True, **kwargs):\n", | |
" super().__init__(**kwargs)\n", | |
" self.d_model = d_model\n", | |
" self.n_class = n_class\n", | |
" self.bias = bias\n", | |
"\n", | |
"\n", | |
"class TestModel(PreTrainedModel):\n", | |
" \n", | |
" config_class = TestConfig\n", | |
" \n", | |
" def __init__(self, config: TestConfig):\n", | |
" super().__init__(config)\n", | |
" \n", | |
" self.linear = torch.nn.Linear(\n", | |
" in_features=config.d_model, \n", | |
" out_features=config.n_class,\n", | |
" bias=config.bias\n", | |
" )\n", | |
" \n", | |
" def forward(self, x, **kwargs):\n", | |
" random_outputs = torch.randn(x.shape[0], 77, 768)\n", | |
" return (random_outputs, )" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"id": "892853ee-ef56-4434-9ca2-6f74a0460f94", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"test_model_config = TestConfig(d_model=3, n_class=10)\n", | |
"test_model = TestModel(test_model_config)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 7, | |
"id": "4d574ae7-1566-4a74-af80-1653abf7e2e2", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"from transformers import AutoConfig, AutoModel\n", | |
"\n", | |
"AutoConfig.register(\"test_model\", TestConfig)\n", | |
"AutoModel.register(TestConfig, TestModel)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 8, | |
"id": "b34b1a44-140d-49a4-8884-b05d8e2a6e32", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"!mkdir test_model" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 9, | |
"id": "b6714878-4650-40d2-95b6-4264f9cb5399", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Writing test_model/configuration_test_model.py\n" | |
] | |
} | |
], | |
"source": [ | |
"%%writefile test_model/configuration_test_model.py\n", | |
"from transformers import PretrainedConfig \n", | |
"\n", | |
"\n", | |
"class TestConfig(PretrainedConfig):\n", | |
" \n", | |
" model_type = \"test_model\"\n", | |
" \n", | |
" def __init__(self, d_model: int=None, n_class: int=None, bias: bool=True, **kwargs):\n", | |
" super().__init__(**kwargs)\n", | |
" self.d_model = d_model\n", | |
" self.n_class = n_class\n", | |
" self.bias = bias" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 10, | |
"id": "da435782-2522-4c7c-b418-5de70aca54b5", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Writing test_model/modeling_test_model.py\n" | |
] | |
} | |
], | |
"source": [ | |
"%%writefile test_model/modeling_test_model.py\n", | |
"import torch \n", | |
"from transformers import PreTrainedModel \n", | |
"from .configuration_test_model import TestConfig\n", | |
"\n", | |
"\n", | |
"class TestModel(PreTrainedModel):\n", | |
" \n", | |
" config_class = TestConfig\n", | |
" \n", | |
" def __init__(self, config: TestConfig):\n", | |
" super().__init__(config)\n", | |
" \n", | |
" self.linear = torch.nn.Linear(\n", | |
" in_features=config.d_model, \n", | |
" out_features=config.n_class,\n", | |
" bias=config.bias\n", | |
" )\n", | |
" \n", | |
" def forward(self, x, **kwargs):\n", | |
" random_outputs = torch.randn(x.shape[0], 77, 768)\n", | |
" return (random_outputs, )" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 11, | |
"id": "049edb45-b651-4d14-8e73-334a00a8a324", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"!touch test_model/__init__.py" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 12, | |
"id": "bdffa93b-177b-4c46-889a-c541bdf93017", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"total 16\n", | |
"-rw-r--r--@ 1 sayakpaul staff 0B Mar 13 17:27 __init__.py\n", | |
"-rw-r--r--@ 1 sayakpaul staff 332B Mar 13 17:27 configuration_test_model.py\n", | |
"-rw-r--r--@ 1 sayakpaul staff 564B Mar 13 17:27 modeling_test_model.py\n" | |
] | |
} | |
], | |
"source": [ | |
"!ls -lh test_model" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 13, | |
"id": "afd4e1bf-07fc-4f22-9c27-50bab35db90b", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"!rm -rf __pycache__" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 14, | |
"id": "28d88593-d97d-4def-99bc-762d78431801", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"from test_model.configuration_test_model import TestConfig\n", | |
"from test_model.modeling_test_model import TestModel" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 15, | |
"id": "239beaf3-d6fc-4c9f-811d-ad3b47ae6e8d", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"TestConfig.register_for_auto_class()\n", | |
"TestModel.register_for_auto_class(\"AutoModel\")" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 16, | |
"id": "6147a27f-c90f-42a8-93f1-2cccbb0ac44f", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "38c21b85e5c445d18a0abe4b2c970706", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
"model.safetensors: 0%| | 0.00/336 [00:00<?, ?B/s]" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"data": { | |
"text/plain": [ | |
"CommitInfo(commit_url='https://huggingface.co/sayakpaul/test_model_transformers/commit/850cd2f8ae793627b83e382066e3c3a5e5271d5d', commit_message='Upload model', commit_description='', oid='850cd2f8ae793627b83e382066e3c3a5e5271d5d', pr_url=None, pr_revision=None, pr_num=None)" | |
] | |
}, | |
"execution_count": 16, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"test_model_config = TestConfig(d_model=3, n_class=10)\n", | |
"test_model = TestModel(test_model_config)\n", | |
"test_model.push_to_hub(\"test_model_transformers\")" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 17, | |
"id": "bc40c3d9-3674-4969-bb37-0bf15f0be094", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "9c05df91f2d0479e874610fa16060018", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
"config.json: 0%| | 0.00/315 [00:00<?, ?B/s]" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "ace2b8e12c0d4c72a1923826b9b23b3f", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
"configuration_test_model.py: 0%| | 0.00/332 [00:00<?, ?B/s]" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"A new version of the following files was downloaded from https://huggingface.co/sayakpaul/test_model_transformers:\n", | |
"- configuration_test_model.py\n", | |
". Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.\n" | |
] | |
}, | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "04d08a3b8b0c4e4fa836c0a46f21446b", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
"modeling_test_model.py: 0%| | 0.00/564 [00:00<?, ?B/s]" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"A new version of the following files was downloaded from https://huggingface.co/sayakpaul/test_model_transformers:\n", | |
"- modeling_test_model.py\n", | |
". Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.\n" | |
] | |
}, | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "995291e6bf694693b05b6ef111d9d87c", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
"model.safetensors: 0%| | 0.00/336 [00:00<?, ?B/s]" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"data": { | |
"text/plain": [ | |
"torch.Size([10, 77, 768])" | |
] | |
}, | |
"execution_count": 17, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"from transformers import AutoModel \n", | |
"\n", | |
"loaded_test_model = AutoModel.from_pretrained(\"sayakpaul/test_model_transformers\", trust_remote_code=True)\n", | |
"loaded_test_model(torch.randn(10, 3))[0].shape" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 18, | |
"id": "a59cc8f2-5c04-4cec-9786-8c3dce002693", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "5cd2c8267d004c4ca4a8beb0a3160a1b", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
"model_index.json: 0%| | 0.00/541 [00:00<?, ?B/s]" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "eb11bb85c24241ceae3734e79c62ff4d", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
"Fetching 13 files: 0%| | 0/13 [00:00<?, ?it/s]" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "26a462ab605b4c15ae7f5fb64393463d", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
"scheduler/scheduler_config.json: 0%| | 0.00/308 [00:00<?, ?B/s]" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "16c471b8deef4b2ab221c34cccf9739d", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
"tokenizer/special_tokens_map.json: 0%| | 0.00/472 [00:00<?, ?B/s]" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "f95fad9118b54f15b1cb955ba03027d3", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
"tokenizer/tokenizer_config.json: 0%| | 0.00/806 [00:00<?, ?B/s]" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "28683486165b4f458fb2507f37980b91", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
"safety_checker/config.json: 0%| | 0.00/4.72k [00:00<?, ?B/s]" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "ab59f5de437f4c648e735cc4f8c5f41b", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
"tokenizer/merges.txt: 0%| | 0.00/525k [00:00<?, ?B/s]" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "a3bc93857c884ea1be64628a872b3a18", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
"tokenizer/vocab.json: 0%| | 0.00/1.06M [00:00<?, ?B/s]" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "71339a4951ed46649dfa17ab8afc924a", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
"(…)ature_extractor/preprocessor_config.json: 0%| | 0.00/342 [00:00<?, ?B/s]" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "b2188d303c434f15b74cd283fbf66710", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
"unet/config.json: 0%| | 0.00/743 [00:00<?, ?B/s]" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "fd993d6471c046e3acd4b1e6c298bf96", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
"vae/config.json: 0%| | 0.00/547 [00:00<?, ?B/s]" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "761f242b231f423896ad5df01d4046b2", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
"model.fp16.safetensors: 0%| | 0.00/608M [00:00<?, ?B/s]" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "22206d311e4b449c9245a4ea7c857a6e", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
"diffusion_pytorch_model.fp16.safetensors: 0%| | 0.00/167M [00:00<?, ?B/s]" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "7c86b6d1563940b7b6ba3d56ff306986", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
"diffusion_pytorch_model.fp16.safetensors: 0%| | 0.00/1.72G [00:00<?, ?B/s]" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "9412ea1a66ce4a76919273bfdd3abe08", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
"Loading pipeline components...: 0%| | 0/7 [00:00<?, ?it/s]" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"`text_config_dict` is provided which will be used to initialize `CLIPTextConfig`. The value `text_config[\"id2label\"]` will be overriden.\n", | |
"`text_config_dict` is provided which will be used to initialize `CLIPTextConfig`. The value `text_config[\"bos_token_id\"]` will be overriden.\n", | |
"`text_config_dict` is provided which will be used to initialize `CLIPTextConfig`. The value `text_config[\"eos_token_id\"]` will be overriden.\n" | |
] | |
} | |
], | |
"source": [ | |
"from diffusers import DiffusionPipeline\n", | |
"\n", | |
"repo_id = \"runwayml/stable-diffusion-v1-5\"\n", | |
"pipeline = DiffusionPipeline.from_pretrained(repo_id, text_encoder=loaded_test_model, variant=\"fp16\")" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 19, | |
"id": "7d97120e-394c-4cdf-9e6d-ca19765af604", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"pipeline = pipeline.to(\"mps\")" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 20, | |
"id": "a45f1a18-2712-4ba6-932d-f5b95ce49dac", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "498882d863c345cfbdcdbfe315399cfa", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
" 0%| | 0/3 [00:00<?, ?it/s]" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"data": { | |
"text/plain": [ | |
"(512, 512)" | |
] | |
}, | |
"execution_count": 20, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"pipeline(\"okay\", num_inference_steps=3).images[0].size" | |
] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3 (ipykernel)", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.9.17" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 5 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment