Skip to content

Instantly share code, notes, and snippets.

@safa-dayo
Created September 21, 2023 21:35
Show Gist options
  • Save safa-dayo/551fead5d9ffdabc0fa9a1bef21a3044 to your computer and use it in GitHub Desktop.
Save safa-dayo/551fead5d9ffdabc0fa9a1bef21a3044 to your computer and use it in GitHub Desktop.
kohya-ss/sd-scriptsをGoogle Colab上で実行し、LoRA学習を行うためのサンプルコマンド
### kohya-ss/sd-scriptsをGoogle Colab上で扱うためのスクリプト ###
# 以下の事前準備が必要です
# 1. /content配下(ルート配下)にtrain-dataフォルダを作成する
# 2. train-dataフォルダ内に学習用の画像とtrain.tomlファイルアップロードする
#
# こちらはuroko cloudというLoRAを作成するサンプルとなっています。
# train_network.pyを実行している箇所で、適宜 "--output_name" の引数を変更してください
#
# また学習のベースモデルとしてはchilled_remix v2を用いています。
# ベースモデルを変更する際はモデルのダウンロード箇所と、"--pretrained_model_name_or_path"の引数を変更してください
#############################################################
import os
import sys
%cd /content
# bitsandbytesの導入
!git clone -b 0.41.0 https://github.com/TimDettmers/bitsandbytes
%cd /content/bitsandbytes
!CUDA_VERSION=118 make cuda11x
!python setup.py install
# kohya-ss/sd-scriptsの導入
%cd /content
!git clone https://github.com/kohya-ss/sd-scripts.git
%cd sd-scripts
# 依存関係のインストール
!pip install torch==2.0.1+cu118 torchvision==0.15.2+cu118 --extra-index-url https://download.pytorch.org/whl/cu118
!pip install xformers==0.0.21 bitsandbytes==0.41.1
!pip install tensorboard==2.12.3 tensorflow==2.12.0
# requiest==2.31.0を導入したいため、直接書き換える
!sed -i 's/requests==2\.28\.2/requests==2.31.0/' requirements.txt
!pip install --upgrade -r requirements.txt
# 本来はaccelerateで作成する予定のdefualt_config.yamlをGitHubからダウンロード
!curl -L https://gist.githubusercontent.com/safa-dayo/0672e7bf33d1efbe20c329e5a73e7538/raw/32af2e5b091f197a99878bda6374e97915a175a1/sd-scripts_train_accelerate_default_config.yaml -o /content/default_config.yaml
# train-dataディレクトリとtrain.tomlの存在確認
train_dir_path = "/content/train-data"
required_files = ["train.toml"]
if not os.path.isdir(train_dir_path):
print(f"このスクリプトを実行するには【{train_dir_path}】というディレクトリが存在している必要があります")
sys.exit(1)
for file_name in required_files:
file_path = os.path.join(train_dir_path, file_name)
if not os.path.isfile(file_path):
print(f"このスクリプトを実行するには【{file_path}】が存在している必要があります")
sys.exit(1)
# output用のディレクトリ
!mkdir /content/output
# 学習に利用するモデルファイルのダウンロード(ここではchilled_remix_v2を利用しているが適宜書き換える)
!mkdir /content/models
!wget https://huggingface.co/sazyou-roukaku/chilled_remix/resolve/main/chilled_remix_v2.safetensors --directory-prefix=/content/models/
print("===========================================================================================================================================")
print("セットアップが完了しました。\n【/content/train-data】というディレクトリが存在しており、\ntrainingに利用するtomlファイル(train.toml)が格納されている場合、このまま学習を続けます。")
print("===========================================================================================================================================")
### LoRA学習を開始
# 一応明示的にsd-scripts内に移動
%cd /content/sd-scripts
# 学習実行(LoRA)
!accelerate launch --num_cpu_threads_per_process 1 train_network.py \
--pretrained_model_name_or_path="/content/models/chilled_remix_v2.safetensors" \
--dataset_config="/content/train-data/train.toml" \
--output_dir="/content/output" \
--output_name="uroko_cloud" \
--save_model_as=safetensors \
--prior_loss_weight=1.0 \
--max_train_steps=400 \
--learning_rate=1e-4 \
--optimizer_type="AdamW8bit" \
--xformers \
--mixed_precision="fp16" \
--cache_latents \
--gradient_checkpointing \
--save_every_n_epochs=1 \
--network_module=networks.lora 2>&1 > debug.log
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment