Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Save sayakpaul/83fe388cd4795c01f60a765392ed455d to your computer and use it in GitHub Desktop.
Save sayakpaul/83fe388cd4795c01f60a765392ed455d to your computer and use it in GitHub Desktop.
This notebook shows how to use a custom `transformers` model within a Stable Diffusion pipeline.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "41843243-1fdd-4037-a433-24537e154b0f",
"metadata": {},
"outputs": [],
"source": [
"%load_ext autoreload\n",
"%autoreload 2"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "02ccbacc-e631-47c0-8ce8-66632409c9d2",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"/Users/sayakpaul/Downloads/custom-pipeline\n"
]
}
],
"source": [
"# -p keeps this cell re-runnable: plain `mkdir` errors once the directory exists.\n",
"!mkdir -p custom-pipeline\n",
"%cd custom-pipeline"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "111d777a-4976-4ec9-a56f-6fc9c87fdf30",
"metadata": {},
"outputs": [],
"source": [
"from transformers import PretrainedConfig, PreTrainedModel\n",
"import torch\n",
"\n",
"class TestConfig(PretrainedConfig):\n",
"    \"\"\"Configuration for the toy `TestModel`.\n",
"\n",
"    Stores `d_model`, `n_class` and `bias`, which `TestModel` reads to size its\n",
"    linear layer.\n",
"    \"\"\"\n",
"\n",
"    model_type = \"test_model\"\n",
"\n",
"    def __init__(\n",
"        self,\n",
"        d_model: int = None,\n",
"        n_class: int = None,\n",
"        bias: bool = True,\n",
"        **kwargs,\n",
"    ):\n",
"        # Remaining keyword arguments are forwarded to `PretrainedConfig`.\n",
"        super().__init__(**kwargs)\n",
"        self.d_model, self.n_class, self.bias = d_model, n_class, bias\n",
"\n",
"\n",
"class TestModel(PreTrainedModel):\n",
"    \"\"\"Toy model whose forward pass returns a random `(batch, 77, 768)` tensor.\n",
"\n",
"    A linear layer is built from the config; `forward` does not use it.\n",
"    \"\"\"\n",
"\n",
"    config_class = TestConfig\n",
"\n",
"    def __init__(self, config: TestConfig):\n",
"        super().__init__(config)\n",
"\n",
"        self.linear = torch.nn.Linear(\n",
"            in_features=config.d_model,\n",
"            out_features=config.n_class,\n",
"            bias=config.bias,\n",
"        )\n",
"\n",
"    def forward(self, x, **kwargs):\n",
"        \"\"\"Return a 1-tuple holding a random tensor of shape `(x.shape[0], 77, 768)`.\n",
"\n",
"        Only the leading dimension and the device of `x` are used; `kwargs` is ignored.\n",
"        \"\"\"\n",
"        # Allocate on the input's device and in the model's dtype so the output\n",
"        # follows the model when it is moved (e.g. `.to(\"mps\")`) or cast, instead\n",
"        # of always landing on the CPU in the default dtype.\n",
"        random_outputs = torch.randn(\n",
"            x.shape[0], 77, 768, device=x.device, dtype=self.dtype\n",
"        )\n",
"        return (random_outputs,)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "892853ee-ef56-4434-9ca2-6f74a0460f94",
"metadata": {},
"outputs": [],
"source": [
"test_model_config = TestConfig(d_model=3, n_class=10)\n",
"test_model = TestModel(test_model_config)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "4d574ae7-1566-4a74-af80-1653abf7e2e2",
"metadata": {},
"outputs": [],
"source": [
"from transformers import AutoConfig, AutoModel\n",
"\n",
"AutoConfig.register(\"test_model\", TestConfig)\n",
"AutoModel.register(TestConfig, TestModel)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "b34b1a44-140d-49a4-8884-b05d8e2a6e32",
"metadata": {},
"outputs": [],
"source": [
"# -p keeps this cell re-runnable: plain `mkdir` errors once the directory exists.\n",
"!mkdir -p test_model"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "b6714878-4650-40d2-95b6-4264f9cb5399",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Writing test_model/configuration_test_model.py\n"
]
}
],
"source": [
"%%writefile test_model/configuration_test_model.py\n",
"from transformers import PretrainedConfig \n",
"\n",
"\n",
"class TestConfig(PretrainedConfig):\n",
"    \"\"\"Configuration for the toy `TestModel`.\n",
"\n",
"    Stores `d_model`, `n_class` and `bias`, which `TestModel` reads to size its\n",
"    linear layer.\n",
"    \"\"\"\n",
"\n",
"    model_type = \"test_model\"\n",
"\n",
"    def __init__(\n",
"        self,\n",
"        d_model: int = None,\n",
"        n_class: int = None,\n",
"        bias: bool = True,\n",
"        **kwargs,\n",
"    ):\n",
"        # Remaining keyword arguments are forwarded to `PretrainedConfig`.\n",
"        super().__init__(**kwargs)\n",
"        self.d_model, self.n_class, self.bias = d_model, n_class, bias"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "da435782-2522-4c7c-b418-5de70aca54b5",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Writing test_model/modeling_test_model.py\n"
]
}
],
"source": [
"%%writefile test_model/modeling_test_model.py\n",
"import torch \n",
"from transformers import PreTrainedModel \n",
"from .configuration_test_model import TestConfig\n",
"\n",
"\n",
"class TestModel(PreTrainedModel):\n",
"    \"\"\"Toy model whose forward pass returns a random `(batch, 77, 768)` tensor.\n",
"\n",
"    A linear layer is built from the config; `forward` does not use it.\n",
"    \"\"\"\n",
"\n",
"    config_class = TestConfig\n",
"\n",
"    def __init__(self, config: TestConfig):\n",
"        super().__init__(config)\n",
"\n",
"        self.linear = torch.nn.Linear(\n",
"            in_features=config.d_model,\n",
"            out_features=config.n_class,\n",
"            bias=config.bias,\n",
"        )\n",
"\n",
"    def forward(self, x, **kwargs):\n",
"        \"\"\"Return a 1-tuple holding a random tensor of shape `(x.shape[0], 77, 768)`.\n",
"\n",
"        Only the leading dimension and the device of `x` are used; `kwargs` is ignored.\n",
"        \"\"\"\n",
"        # Allocate on the input's device and in the model's dtype so the output\n",
"        # follows the model when it is moved (e.g. `.to(\"mps\")`) or cast, instead\n",
"        # of always landing on the CPU in the default dtype.\n",
"        random_outputs = torch.randn(\n",
"            x.shape[0], 77, 768, device=x.device, dtype=self.dtype\n",
"        )\n",
"        return (random_outputs,)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "049edb45-b651-4d14-8e73-334a00a8a324",
"metadata": {},
"outputs": [],
"source": [
"!touch test_model/__init__.py"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "bdffa93b-177b-4c46-889a-c541bdf93017",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"total 16\n",
"-rw-r--r--@ 1 sayakpaul staff 0B Mar 13 17:27 __init__.py\n",
"-rw-r--r--@ 1 sayakpaul staff 332B Mar 13 17:27 configuration_test_model.py\n",
"-rw-r--r--@ 1 sayakpaul staff 564B Mar 13 17:27 modeling_test_model.py\n"
]
}
],
"source": [
"!ls -lh test_model"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "afd4e1bf-07fc-4f22-9c27-50bab35db90b",
"metadata": {},
"outputs": [],
"source": [
"!rm -rf __pycache__"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "28d88593-d97d-4def-99bc-762d78431801",
"metadata": {},
"outputs": [],
"source": [
"from test_model.configuration_test_model import TestConfig\n",
"from test_model.modeling_test_model import TestModel"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "239beaf3-d6fc-4c9f-811d-ad3b47ae6e8d",
"metadata": {},
"outputs": [],
"source": [
"TestConfig.register_for_auto_class()\n",
"TestModel.register_for_auto_class(\"AutoModel\")"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "6147a27f-c90f-42a8-93f1-2cccbb0ac44f",
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "38c21b85e5c445d18a0abe4b2c970706",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"model.safetensors: 0%| | 0.00/336 [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"CommitInfo(commit_url='https://huggingface.co/sayakpaul/test_model_transformers/commit/850cd2f8ae793627b83e382066e3c3a5e5271d5d', commit_message='Upload model', commit_description='', oid='850cd2f8ae793627b83e382066e3c3a5e5271d5d', pr_url=None, pr_revision=None, pr_num=None)"
]
},
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"test_model_config = TestConfig(d_model=3, n_class=10)\n",
"test_model = TestModel(test_model_config)\n",
"test_model.push_to_hub(\"test_model_transformers\")"
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "bc40c3d9-3674-4969-bb37-0bf15f0be094",
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "9c05df91f2d0479e874610fa16060018",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"config.json: 0%| | 0.00/315 [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "ace2b8e12c0d4c72a1923826b9b23b3f",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"configuration_test_model.py: 0%| | 0.00/332 [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"A new version of the following files was downloaded from https://huggingface.co/sayakpaul/test_model_transformers:\n",
"- configuration_test_model.py\n",
". Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "04d08a3b8b0c4e4fa836c0a46f21446b",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"modeling_test_model.py: 0%| | 0.00/564 [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"A new version of the following files was downloaded from https://huggingface.co/sayakpaul/test_model_transformers:\n",
"- modeling_test_model.py\n",
". Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "995291e6bf694693b05b6ef111d9d87c",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"model.safetensors: 0%| | 0.00/336 [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"torch.Size([10, 77, 768])"
]
},
"execution_count": 17,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from transformers import AutoModel \n",
"\n",
"# `trust_remote_code=True` executes Python files downloaded from the Hub, so pin the\n",
"# load to the commit recorded by the `push_to_hub` call in this notebook rather than\n",
"# running whatever the repo's default branch holds later.\n",
"# NOTE: update this hash if the model is pushed again.\n",
"loaded_test_model = AutoModel.from_pretrained(\n",
"    \"sayakpaul/test_model_transformers\",\n",
"    trust_remote_code=True,\n",
"    revision=\"850cd2f8ae793627b83e382066e3c3a5e5271d5d\",\n",
")\n",
"loaded_test_model(torch.randn(10, 3))[0].shape"
]
},
{
"cell_type": "code",
"execution_count": 18,
"id": "a59cc8f2-5c04-4cec-9786-8c3dce002693",
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "5cd2c8267d004c4ca4a8beb0a3160a1b",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"model_index.json: 0%| | 0.00/541 [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "eb11bb85c24241ceae3734e79c62ff4d",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Fetching 13 files: 0%| | 0/13 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "26a462ab605b4c15ae7f5fb64393463d",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"scheduler/scheduler_config.json: 0%| | 0.00/308 [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "16c471b8deef4b2ab221c34cccf9739d",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"tokenizer/special_tokens_map.json: 0%| | 0.00/472 [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "f95fad9118b54f15b1cb955ba03027d3",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"tokenizer/tokenizer_config.json: 0%| | 0.00/806 [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "28683486165b4f458fb2507f37980b91",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"safety_checker/config.json: 0%| | 0.00/4.72k [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "ab59f5de437f4c648e735cc4f8c5f41b",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"tokenizer/merges.txt: 0%| | 0.00/525k [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "a3bc93857c884ea1be64628a872b3a18",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"tokenizer/vocab.json: 0%| | 0.00/1.06M [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "71339a4951ed46649dfa17ab8afc924a",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"(…)ature_extractor/preprocessor_config.json: 0%| | 0.00/342 [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "b2188d303c434f15b74cd283fbf66710",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"unet/config.json: 0%| | 0.00/743 [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "fd993d6471c046e3acd4b1e6c298bf96",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"vae/config.json: 0%| | 0.00/547 [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "761f242b231f423896ad5df01d4046b2",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"model.fp16.safetensors: 0%| | 0.00/608M [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "22206d311e4b449c9245a4ea7c857a6e",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"diffusion_pytorch_model.fp16.safetensors: 0%| | 0.00/167M [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "7c86b6d1563940b7b6ba3d56ff306986",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"diffusion_pytorch_model.fp16.safetensors: 0%| | 0.00/1.72G [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "9412ea1a66ce4a76919273bfdd3abe08",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Loading pipeline components...: 0%| | 0/7 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"`text_config_dict` is provided which will be used to initialize `CLIPTextConfig`. The value `text_config[\"id2label\"]` will be overriden.\n",
"`text_config_dict` is provided which will be used to initialize `CLIPTextConfig`. The value `text_config[\"bos_token_id\"]` will be overriden.\n",
"`text_config_dict` is provided which will be used to initialize `CLIPTextConfig`. The value `text_config[\"eos_token_id\"]` will be overriden.\n"
]
}
],
"source": [
"from diffusers import DiffusionPipeline\n",
"\n",
"repo_id = \"runwayml/stable-diffusion-v1-5\"\n",
"pipeline = DiffusionPipeline.from_pretrained(repo_id, text_encoder=loaded_test_model, variant=\"fp16\")"
]
},
{
"cell_type": "code",
"execution_count": 19,
"id": "7d97120e-394c-4cdf-9e6d-ca19765af604",
"metadata": {},
"outputs": [],
"source": [
"# Prefer Apple's MPS backend (the original target), then CUDA, then CPU, so the\n",
"# notebook also runs on machines without an MPS device.\n",
"if torch.backends.mps.is_available():\n",
"    device = \"mps\"\n",
"elif torch.cuda.is_available():\n",
"    device = \"cuda\"\n",
"else:\n",
"    device = \"cpu\"\n",
"pipeline = pipeline.to(device)"
]
},
{
"cell_type": "code",
"execution_count": 20,
"id": "a45f1a18-2712-4ba6-932d-f5b95ce49dac",
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "498882d863c345cfbdcdbfe315399cfa",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/3 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"(512, 512)"
]
},
"execution_count": 20,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"pipeline(\"okay\", num_inference_steps=3).images[0].size"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.17"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment