Skip to content

Instantly share code, notes, and snippets.

@inu-ai
Last active November 28, 2022 11:42
Show Gist options
  • Save inu-ai/7e8a8ecda5f6649d81bd5202ce8e6a21 to your computer and use it in GitHub Desktop.
Save inu-ai/7e8a8ecda5f6649d81bd5202ce8e6a21 to your computer and use it in GitHub Desktop.
stable_diffusion_1_dreambooth_Kohya_S.ipynb
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/thx-pw/7e8a8ecda5f6649d81bd5202ce8e6a21/dreambooth_stable_diffusion_fixed.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"source": [
"移動します。\n",
"\n",
"https://github.com/thx-pw/stable-diffusion-2.0-dreambooth"
],
"metadata": {
"id": "51U6Ge8sCn1z"
}
},
{
"cell_type": "markdown",
"source": [
"このColabのライセンスはApache License 2.0\n",
"\n",
"引用元\n",
"\n",
"https://github.com/ShivamShrirao/diffusers/tree/main/examples/dreambooth\n",
"\n",
"https://note.com/kohya_ss/n/nee3ed1649fb6\n",
"\n",
"https://note.com/kohya_ss/n/nad3bce9a3622"
],
"metadata": {
"id": "8DK981h_7p0X"
}
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "XU7NuMAA2drw"
},
"outputs": [],
"source": [
"#@title GPUチェック\n",
"!nvidia-smi --query-gpu=name,memory.total,memory.free --format=csv,noheader"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "BzM7j0ZSc_9c"
},
"source": [
"https://github.com/ShivamShrirao/diffusers/tree/main/examples/dreambooth"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "aLWXPZqjsZVV",
"cellView": "form"
},
"outputs": [],
"source": [
"#@title 必要なパッケージのインストール\n",
"!pip install -q diffusers[torch]==0.9.0 accelerate transformers==4.21.3 ftfy albumentations opencv-python einops bitsandbytes fairscale==0.4.6 numpy==1.21.6\n",
"!pip install -q https://github.com/camenduru/stable-diffusion-webui-colab/releases/download/0.0.14/xformers-0.0.14.dev0-cp37-cp37m-linux_x86_64.whl\n",
"\n",
"!pip install -q pytorch_lightning\n",
"\n",
"!git clone https://github.com/salesforce/BLIP --quiet\n",
"\n",
"!wget -q https://github.com/thx-pw/stable-diffusion-2.0-dreambooth/raw/main/gen_img_diffusers.py\n",
"!wget -q https://github.com/thx-pw/stable-diffusion-2.0-dreambooth/raw/main/train_db_fixed.py\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "Rxg0y5MBudmd"
},
"outputs": [],
"source": [
"#@title モデルの選択\n",
"\n",
"#@markdown https://huggingface.co/settings/tokens\n",
"HUGGINGFACE_TOKEN = \"\" #@param {type:\"string\"}\n",
"\n",
"MODEL_NAME = \"stable-v14\" #@param ['trinart-characters-19m', 'waifu-v13-float32', 'waifu-v13-float16', 'stable-v14', 'pokemon', 'robo-v1']\n",
"\n",
"# 教師データ(学習データ)の保存場所\n",
"TRAIN_DIR = \"/content/input/train\"\n",
"!mkdir -p $TRAIN_DIR\n",
"\n",
"# 正則化画像(クラスの画像)の保存場所\n",
"REG_DIR = \"/content/input/reg\"\n",
"!mkdir -p $REG_DIR\n",
"\n",
"# 学習済みモデルの保存場所\n",
"OUTPUT_DIR = \"/content/output\" \n",
"!mkdir -p $OUTPUT_DIR\n"
]
},
{
"cell_type": "code",
"source": [
"#@title モデルのダウンロード\n",
"\n",
"models_dict = {\n",
" \"trinart-characters-19m\" : \"https://huggingface.co/naclbit/trinart_characters_19.2m_stable_diffusion_v1/blob/main/trinart_characters_it4_v1.ckpt\",\n",
" \"waifu-v13-float32\" : \"https://huggingface.co/hakurei/waifu-diffusion-v1-3/blob/main/wd-v1-3-float32.ckpt\",\n",
" \"waifu-v13-float16\" : \"https://huggingface.co/hakurei/waifu-diffusion-v1-3/blob/main/wd-v1-3-float16.ckpt\",\n",
" \"stable-v14\" : \"https://huggingface.co/CompVis/stable-diffusion-v-1-4-original/blob/main/sd-v1-4.ckpt\",\n",
" \"pokemon\" : \"https://huggingface.co/justinpinkney/pokemon-stable-diffusion/blob/main/ema-only-epoch%3D000142.ckpt\",\n",
" \"robo-v1\" : \"https://huggingface.co/nousr/robo-diffusion/blob/main/models/robo-diffusion-v1.ckpt\",\n",
"}\n",
"\n",
"model_url = models_dict[MODEL_NAME].replace(\"/blob/\", \"/resolve/\")\n",
"user_header = f\"\\\"Authorization: Bearer {HUGGINGFACE_TOKEN}\\\"\"\n",
"!wget --header={user_header} {model_url} -O /content/{MODEL_NAME}.ckpt\n"
],
"metadata": {
"cellView": "form",
"id": "ww7DMIzFxm_h"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"#@title BLIP機能\n",
"import torch\n",
"device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
"\n",
"def load_blip():\n",
" import sys\n",
" sys.path.append('BLIP')\n",
" \n",
" from models.blip import blip_decoder\n",
"\n",
" %cd /content/BLIP\n",
"\n",
" blip_model_url = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model*_base_caption.pth'\n",
" \n",
" blip_model = blip_decoder(pretrained=blip_model_url, image_size=384, vit='base')\n",
" blip_model.eval()\n",
" blip_model = blip_model.to(device)\n",
"\n",
" %cd /content\n",
"\n",
" return blip_model"
],
"metadata": {
"cellView": "form",
"id": "sL8UGjj8tmUH"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "fe-GgtnUVO_e",
"cellView": "form"
},
"outputs": [],
"source": [
"#@title 教師データ(学習データ)のアップロードと正則化画像(クラスの画像)の自動生成\n",
"#@markdown このセルでSKSとCLASSを変更し、複数回実行すると複数キャラの学習が可能\n",
"\n",
"#@markdown Colabへのアップロードが遅いのでリサイズするツール:https://www.birme.net/?target_width=1024&target_height=1024\n",
"SKS = \"zundamon\" #@param {type:\"string\"}\n",
"CLASS = \"boy\" #@param {type:\"string\"}\n",
"TRAIN_N_REPEATS = 20\n",
"REG_N_REPEATS = 1 \n",
"#@markdown (実験的機能)教師データをBLIPでpromptを逆算して、「CLASS, prompt」で正則化画像を自動生成します\n",
"\n",
"#@markdown BLIPを使わない場合は、promptはCLASSだけになります\n",
"use_blip = False #@param {type:\"boolean\"}\n",
"NEGATIVE_PROMPT = \"\" #@param {type:\"string\"}\n",
"\n",
"PROMPTS_PATH = \"/content/prompts.txt\"\n",
"\n",
"import os\n",
"from google.colab import files\n",
"import shutil\n",
"import glob\n",
"from PIL import Image\n",
"\n",
"train_path = os.path.join(TRAIN_DIR, f\"{TRAIN_N_REPEATS}_{SKS} {CLASS}\")\n",
"os.makedirs(train_path, exist_ok=True)\n",
"reg_path = os.path.join(REG_DIR, f\"{REG_N_REPEATS}_{CLASS}\")\n",
"os.makedirs(reg_path, exist_ok=True)\n",
"\n",
"uploaded = files.upload()\n",
"for filename in uploaded.keys():\n",
" dst_path = os.path.join(train_path, filename)\n",
" shutil.move(filename, dst_path)\n",
"\n",
"def get_prompt(blip_model, image):\n",
" from torchvision import transforms\n",
" from torchvision.transforms.functional import InterpolationMode\n",
" image_size = 384\n",
" transform = transforms.Compose([\n",
" transforms.Resize((image_size,image_size),interpolation=InterpolationMode.BICUBIC),\n",
" transforms.ToTensor(),\n",
" transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))\n",
" ])\n",
" \n",
" image = transform(image).unsqueeze(0).to(device)\n",
" with torch.no_grad():\n",
" prompt = blip_model.generate(image, sample=False, num_beams=3, max_length=200, min_length=30)\n",
" return prompt[0]\n",
"\n",
"def get_prompts(blip_model):\n",
" prompts = []\n",
" for image_path in glob.glob(f'{train_path}/*.*'):\n",
" image = Image.open(image_path).convert('RGB')\n",
" prompt = get_prompt(blip_model, image)\n",
" prompts.append(prompt)\n",
" return prompts\n",
"\n",
"def generate_prompts():\n",
" if use_blip:\n",
" blip_model = load_blip()\n",
" prompts = [f\"{CLASS}, {x} --n {NEGATIVE_PROMPT}\" for x in get_prompts(blip_model)]\n",
"\n",
" del blip_model\n",
" if torch.cuda.is_available():\n",
" torch.cuda.empty_cache()\n",
" else:\n",
" prompts = [f\"{CLASS}, --n {NEGATIVE_PROMPT}\"]\n",
"\n",
" with open(PROMPTS_PATH, \"w\") as f:\n",
" f.write('\\n'.join(prompts))\n",
" \n",
"def generate_reg_images():\n",
" reg_num_images = sum(os.path.isfile(os.path.join(reg_path, name)) for name in os.listdir(reg_path))\n",
" reg_num_images = (TRAIN_N_REPEATS * train_num_images) // REG_N_REPEATS - reg_num_images\n",
" \n",
" !python gen_img_diffusers.py \\\n",
" --ckpt {MODEL_NAME}.ckpt \\\n",
" --outdir {reg_path} \\\n",
" --xformers \\\n",
" --fp16 \\\n",
" --W 512 \\\n",
" --H 512 \\\n",
" --scale 12.5 \\\n",
" --sampler ddim \\\n",
" --steps 20 \\\n",
" --batch_size 4 \\\n",
" --images_per_prompt {reg_num_images} \\\n",
" --from_file {PROMPTS_PATH}\n",
"\n",
"train_num_images = sum(os.path.isfile(os.path.join(train_path, name)) for name in os.listdir(train_path))\n",
"if train_num_images > 0:\n",
" generate_prompts()\n",
" generate_reg_images()\n",
"else:\n",
" print(\"cancel upload.\")\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "jjcSXTp-u-Eg"
},
"outputs": [],
"source": [
"!accelerate launch --num_cpu_threads_per_process 2 train_db_fixed.py \\\n",
" --pretrained_model_name_or_path={MODEL_NAME}.ckpt \\\n",
" --train_data_dir=$TRAIN_DIR \\\n",
" --reg_data_dir=$REG_DIR \\\n",
" --output_dir=$OUTPUT_DIR \\\n",
" --prior_loss_weight=1.0 \\\n",
" --resolution=512 \\\n",
" --train_batch_size=4 \\\n",
" --learning_rate=2e-6 \\\n",
" --max_train_steps=400 \\\n",
" --use_8bit_adam \\\n",
" --mixed_precision='fp16' \\\n",
" --xformers \\\n",
" --cache_latents \\\n",
" --gradient_checkpointing \\\n",
" --save_precision='fp16' \\\n",
" --save_every_n_epochs 2 \\\n",
" --logging_dir=logs"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "89Az5NUxOWdy"
},
"outputs": [],
"source": [
"#@title ログの確認\n",
"%load_ext tensorboard\n",
"%tensorboard --logdir=logs"
]
},
{
"cell_type": "code",
"source": [
"#@title Google Driveにckptを保存\n",
"ckpt_name = \"epoch-000010\" #@param {type:\"string\"}\n",
"from google.colab import drive\n",
"drive.mount('/content/drive')\n",
"\n",
"import os\n",
"model_checkpoints = \"/content/drive/MyDrive/sd/stable-diffusion-webui/models/Stable-diffusion\"\n",
"os.makedirs(model_checkpoints, exist_ok=True)\n",
"!cp \"{OUTPUT_DIR}/{ckpt_name}.ckpt\" {model_checkpoints}\n",
"\n",
"print(f\"save to {model_checkpoints}\")"
],
"metadata": {
"cellView": "form",
"id": "MYDjfXf8MB2R"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"#@title 学習済みモデルで画像生成\n",
"ckpt_name = \"epoch-000010\" #@param {type:\"string\"}\n",
"!python gen_img_diffusers.py \\\n",
" --ckpt \"{OUTPUT_DIR}/{ckpt_name}.ckpt\" \\\n",
" --outdir 'tmp' \\\n",
" --xformers \\\n",
" --fp16 \\\n",
" --W 768 \\\n",
" --H 768 \\\n",
" --scale 12.5 \\\n",
" --sampler ddim \\\n",
" --steps 20 \\\n",
" --batch_size 4 \\\n",
" --images_per_prompt 4 \\\n",
" --prompt \"{SKS} {CLASS} eating a lunch in MacDonald's -n\"\n",
"\n",
"print(\"create to /content/tmp\")"
],
"metadata": {
"cellView": "form",
"id": "OuQGBG737QJc"
},
"execution_count": null,
"outputs": []
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"provenance": [],
"private_outputs": true,
"name": "stable_diffusion_1_dreambooth_Kohya_S.ipynb",
"include_colab_link": true
},
"kernelspec": {
"display_name": "Python 3.8.12 ('pytorch')",
"language": "python",
"name": "python3"
},
"language_info": {
"name": "python",
"version": "3.8.12"
},
"vscode": {
"interpreter": {
"hash": "2d58e898dde0263bc564c6968b04150abacfd33eed9b19aaa8e45c040360e146"
}
}
},
"nbformat": 4,
"nbformat_minor": 0
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment