Skip to content

Instantly share code, notes, and snippets.

@kohya-ss
Last active March 21, 2024 12:12
Show Gist options
  • Save kohya-ss/ad65420f597e0f58256f15f0d1f3d4ae to your computer and use it in GitHub Desktop.
Save kohya-ss/ad65420f597e0f58256f15f0d1f3d4ae to your computer and use it in GitHub Desktop.
LECOっぽいText Encoder only LoRAを学習する検証実装
# sdxl_train_network.py と同じ引数を指定してください
# --network_train_text_encoder_only オプションが必須です
#
# 260行目あたりの src_str, tgt_str に変換元と変換先のテキストを指定してください
# tagger の selected_tags.csv が必要ですので、適宜パスを変更してください
# この辺にあります : https://huggingface.co/SmilingWolf/wd-v1-4-swinv2-tagger-v2/tree/main
#
# "1girl" タグだけは必ず含まれる感じにしているので、必要なら 820 行目あたりを適宜変更してください
#
# 以下のオプションは指定できません(エラーになります):
# --cache_latents
# --cache_latents_to_disk
# --cache_text_encoder_outputs
# --cache_text_encoder_outputs_to_disk
# --sample_every_n_epochs
# --sample_every_n_steps
# --gradient_checkpointing (エラーは出ませんが指定すると学習がうまく行かないかも)
#
# ほかにもエラーになるオプションがあるかもしれませんので、適宜外してください
#
# dataset 設定の .toml は無視されます
# VRAM 使用量を見ながら --train_batch_size を適宜指定してください(わりと大きくできます)
import argparse
import random
import sys
import time
import os
import math
import importlib
import json
import toml
import torch
from tqdm import tqdm
from library.device_utils import init_ipex, clean_memory_on_device
init_ipex()
from library import sdxl_model_util, sdxl_train_util, train_util
import train_network
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
from accelerate.utils import set_seed
from diffusers import DDPMScheduler
from library import model_util
import library.train_util as train_util
from library.train_util import (
DreamBoothDataset,
)
import library.config_util as config_util
from library.config_util import (
ConfigSanitizer,
BlueprintGenerator,
)
import library.huggingface_util as huggingface_util
import library.custom_train_functions as custom_train_functions
from library.custom_train_functions import (
apply_snr_weight,
get_weighted_text_embeddings,
prepare_scheduler_for_custom_training,
scale_v_prediction_loss_like_noise_prediction,
add_v_prediction_like_loss,
apply_debiased_estimation,
)
class SdxlNetworkTrainer(train_network.NetworkTrainer):
def __init__(self):
    """Initialise the base network trainer, then record the SDXL-specific settings."""
    super().__init__()
    # Flag this trainer as SDXL; it is read later when the model spec metadata is built.
    self.is_sdxl = True
    # VAE scale factor constant supplied by the SDXL model utilities.
    self.vae_scale_factor = sdxl_model_util.VAE_SCALE_FACTOR
def assert_extra_args(self, args, train_dataset_group):
    """Validate SDXL-specific argument combinations before training.

    Runs the base-class checks and the generic SDXL argument verification,
    then enforces the constraints tied to Text Encoder output caching and
    finally asks the dataset group to verify its bucket resolution steps (32).

    Raises:
        AssertionError: if Text Encoder outputs are to be cached but the
            dataset group reports they are not cacheable, or if caching is
            requested while the Text Encoder network is also being trained.
    """
    super().assert_extra_args(args, train_dataset_group)
    sdxl_train_util.verify_sdxl_training_args(args)

    not_cacheable_msg = "when caching Text Encoder output, either caption_dropout_rate, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used / Text Encoderの出力をキャッシュするときはcaption_dropout_rate, shuffle_caption, token_warmup_step, caption_tag_dropout_rateは使えません"
    te_training_msg = "network for Text Encoder cannot be trained with caching Text Encoder outputs / Text Encoderの出力をキャッシュしながらText Encoderのネットワークを学習することはできません"

    if args.cache_text_encoder_outputs:
        # Cached outputs are only valid when the dataset says captions are cacheable.
        assert train_dataset_group.is_text_encoder_output_cacheable(), not_cacheable_msg

    # Caching the encoder outputs is incompatible with training the encoder's network.
    assert args.network_train_unet_only or not args.cache_text_encoder_outputs, te_training_msg

    train_dataset_group.verify_bucket_reso_steps(32)
def load_target_model(self, args, weight_dtype, accelerator):
    """Load the SDXL base model components.

    Delegates to ``sdxl_train_util.load_target_model`` and keeps the
    checkpoint format flag, logit scale and checkpoint info on ``self``.

    Returns:
        A 4-tuple ``(model_version, [text_encoder1, text_encoder2], vae, unet)``.
    """
    sdxl_version = sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0
    loaded = sdxl_train_util.load_target_model(args, accelerator, sdxl_version, weight_dtype)
    is_sd_format, te1, te2, vae, unet, logit_scale, ckpt_info = loaded

    # Keep the extra loader outputs on the instance.
    self.load_stable_diffusion_format = is_sd_format
    self.logit_scale = logit_scale
    self.ckpt_info = ckpt_info

    return sdxl_version, [te1, te2], vae, unet
def load_tokenizer(self, args):
    """Return the tokenizer(s) produced by ``sdxl_train_util.load_tokenizers``.

    NOTE(review): presumably a list with one tokenizer per SDXL text encoder,
    since callers index it with [0] and [1] -- confirm against the helper.
    """
    return sdxl_train_util.load_tokenizers(args)
def is_text_encoder_outputs_cached(self, args):
    """Report whether Text Encoder output caching was requested on ``args``."""
    caching_requested = args.cache_text_encoder_outputs
    return caching_requested
def cache_text_encoder_outputs_if_needed(
    self, args, accelerator, unet, vae, tokenizers, text_encoders, dataset: train_util.DatasetGroup, weight_dtype
):
    """Cache Text Encoder outputs when requested, otherwise place the encoders on the device.

    Without ``args.cache_text_encoder_outputs`` both text encoders are moved
    to ``accelerator.device`` in ``weight_dtype``. With it, the dataset caches
    the encoder outputs under autocast and the encoders are then moved to the
    CPU as float32. Unless ``args.lowram`` is set, the VAE and U-Net are parked
    on the CPU during caching and restored to their original devices afterwards.
    """
    if not args.cache_text_encoder_outputs:
        # Outputs are fetched from the Text Encoders on every step, so keep them on the device.
        for idx in (0, 1):
            text_encoders[idx].to(accelerator.device, dtype=weight_dtype)
        return

    offload_vae_unet = not args.lowram
    if offload_vae_unet:
        # Reduce memory consumption while the caching pass runs.
        logger.info("move vae and unet to cpu to save memory")
        org_vae_device = vae.device
        org_unet_device = unet.device
        vae.to("cpu")
        unet.to("cpu")
        clean_memory_on_device(accelerator.device)

    # When the Text Encoder is not being trained it is not prepared by the
    # accelerator, so explicit autocast is needed here.
    with accelerator.autocast():
        dataset.cache_text_encoder_outputs(
            tokenizers,
            text_encoders,
            accelerator.device,
            weight_dtype,
            args.cache_text_encoder_outputs_to_disk,
            accelerator.is_main_process,
        )

    # Text Encoder doesn't work with fp16 on CPU, hence float32.
    for idx in (0, 1):
        text_encoders[idx].to("cpu", dtype=torch.float32)
    clean_memory_on_device(accelerator.device)

    if offload_vae_unet:
        logger.info("move vae and unet back to original device")
        vae.to(org_vae_device)
        unet.to(org_unet_device)
def get_text_cond(self, args, accelerator, batch, tokenizers, text_encoders, weight_dtype):
    """Compute the SDXL text conditioning for a batch.

    Args:
        args: training arguments; ``max_token_length`` and ``full_fp16`` are read.
        accelerator: supplies the target device and is forwarded to the encoder helper.
        batch: mapping holding ``"input_ids"`` and ``"input_ids2"``. It must not
            carry a non-None ``"text_encoder_outputs1_list"`` entry.
        tokenizers: the two tokenizers, indexed [0] and [1].
        text_encoders: the two text encoders, indexed [0] and [1].
        weight_dtype: dtype handed to the helper only when ``args.full_fp16`` is set.

    Returns:
        Tuple ``(encoder_hidden_states1, encoder_hidden_states2, pool2)`` as
        produced by ``train_util.get_hidden_states_sdxl``.

    Raises:
        NotImplementedError: if the batch carries cached text encoder outputs;
            this script does not support consuming them.
    """
    # Cached encoder outputs are not supported here; fail before doing any work.
    if "text_encoder_outputs1_list" in batch and batch["text_encoder_outputs1_list"] is not None:
        raise NotImplementedError("text_encoder_outputs1_list is not implemented yet")

    # TODO support weighted captions
    input_ids1 = batch["input_ids"].to(accelerator.device)
    input_ids2 = batch["input_ids2"].to(accelerator.device)

    encoder_hidden_states1, encoder_hidden_states2, pool2 = train_util.get_hidden_states_sdxl(
        args.max_token_length,
        input_ids1,
        input_ids2,
        tokenizers[0],
        tokenizers[1],
        text_encoders[0],
        text_encoders[1],
        weight_dtype if args.full_fp16 else None,
        accelerator=accelerator,
    )
    return encoder_hidden_states1, encoder_hidden_states2, pool2
def call_unet(self, args, accelerator, unet, noisy_latents, timesteps, text_conds, batch, weight_dtype):
    """Run the U-Net on noisy latents with SDXL text and size conditioning.

    ``text_conds`` is the ``(encoder_hidden_states1, encoder_hidden_states2, pool2)``
    triple. The two hidden states are concatenated along dim 2 to form the
    text embedding; ``pool2`` is concatenated along dim 1 with the size
    embeddings built from the batch's original size, crop offset and target
    size. All inputs are cast to ``weight_dtype``.

    Returns:
        The U-Net output for the given latents and timesteps.
    """
    # TODO check why noisy_latents is not weight_dtype
    latents = noisy_latents.to(weight_dtype)

    # Size embeddings built from the per-sample size/crop information.
    size_embs = sdxl_train_util.get_size_embeddings(
        batch["original_sizes_hw"],
        batch["crop_top_lefts"],
        batch["target_sizes_hw"],
        accelerator.device,
    ).to(weight_dtype)

    # Concatenate the embeddings.
    hidden1, hidden2, pooled = text_conds
    vector_embedding = torch.cat([pooled, size_embs], dim=1).to(weight_dtype)
    text_embedding = torch.cat([hidden1, hidden2], dim=2).to(weight_dtype)

    return unet(latents, timesteps, text_embedding, vector_embedding)
def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet):
    """Forward all arguments, in order, to ``sdxl_train_util.sample_images``."""
    forwarded = (accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet)
    sdxl_train_util.sample_images(*forwarded)
def train(self, args):
assert args.network_train_text_encoder_only
session_id = random.randint(0, 2**32)
training_started_at = time.time()
train_util.verify_training_args(args)
train_util.prepare_dataset_args(args, True)
setup_logging(args, reset=True)
cache_latents = args.cache_latents
use_dreambooth_method = args.in_json is None
use_user_config = args.dataset_config is not None
if args.seed is None:
args.seed = random.randint(0, 2**32)
train_network.set_seed(args.seed)
# tokenizerは単体またはリスト、tokenizersは必ずリスト:既存のコードとの互換性のため
tokenizer = self.load_tokenizer(args)
tokenizers = tokenizer if isinstance(tokenizer, list) else [tokenizer]
# datasetを準備する:単なるテキストなのでPyTorchのDatasetは使わないでOK
# load tags from csv
csv_file = "./wd14_tagger_model/selected_tags.csv"
tags = []
with open(csv_file, "r", encoding="utf-8") as f:
lines = f.readlines()
for line in lines[6:]: # rating tag と 1girl を外す
tags.append(line.split(",")[1].replace("_", " ").strip())
print(f"tags loaded: {len(tags)}, {tags[:5]}")
# srcをtgtに変換するように学習する:たぶんトークン長が同じでないと難しい
src_str = "hatsune miku"
tgt_str = "miki sayaka"
num_steps_per_epoch = 100 # 適当に変えてね
# acceleratorを準備する
logger.info("preparing accelerator")
accelerator = train_util.prepare_accelerator(args)
is_main_process = accelerator.is_main_process
# mixed precisionに対応した型を用意しておき適宜castする
weight_dtype, save_dtype = train_util.prepare_dtype(args)
vae_dtype = torch.float32 if args.no_half_vae else weight_dtype
# モデルを読み込む
model_version, text_encoder, vae, unet = self.load_target_model(args, weight_dtype, accelerator)
# text_encoder is List[CLIPTextModel] or CLIPTextModel
text_encoders = text_encoder if isinstance(text_encoder, list) else [text_encoder]
# モデルに xformers とか memory efficient attention を組み込む
train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers, args.sdpa)
if torch.__version__ >= "2.0.0": # PyTorch 2.0.0 以上対応のxformersなら以下が使える
vae.set_use_memory_efficient_attention_xformers(args.xformers)
# 差分追加学習のためにモデルを読み込む
sys.path.append(os.path.dirname(__file__))
accelerator.print("import network module:", args.network_module)
network_module = importlib.import_module(args.network_module)
if args.base_weights is not None:
# base_weights が指定されている場合は、指定された重みを読み込みマージする
for i, weight_path in enumerate(args.base_weights):
if args.base_weights_multiplier is None or len(args.base_weights_multiplier) <= i:
multiplier = 1.0
else:
multiplier = args.base_weights_multiplier[i]
accelerator.print(f"merging module: {weight_path} with multiplier {multiplier}")
module, weights_sd = network_module.create_network_from_weights(
multiplier, weight_path, vae, text_encoder, unet, for_inference=True
)
module.merge_to(text_encoder, unet, weights_sd, weight_dtype, accelerator.device if args.lowram else "cpu")
accelerator.print(f"all weights merged: {', '.join(args.base_weights)}")
# 学習を準備する
if cache_latents:
vae.to(accelerator.device, dtype=vae_dtype)
vae.requires_grad_(False)
vae.eval()
with torch.no_grad():
train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process)
vae.to("cpu")
clean_memory_on_device(accelerator.device)
accelerator.wait_for_everyone()
# 必要ならテキストエンコーダーの出力をキャッシュする: Text Encoderはcpuまたはgpuへ移される
# cache text encoder outputs if needed: Text Encoder is moved to cpu or gpu
# self.cache_text_encoder_outputs_if_needed(
# args, accelerator, unet, vae, tokenizers, text_encoders, train_dataset_group, weight_dtype
# )
# prepare network
net_kwargs = {}
if args.network_args is not None:
for net_arg in args.network_args:
key, value = net_arg.split("=")
net_kwargs[key] = value
# if a new network is added in future, add if ~ then blocks for each network (;'∀')
if args.dim_from_weights:
network, _ = network_module.create_network_from_weights(1, args.network_weights, vae, text_encoder, unet, **net_kwargs)
else:
if "dropout" not in net_kwargs:
# workaround for LyCORIS (;^ω^)
net_kwargs["dropout"] = args.network_dropout
network = network_module.create_network(
1.0,
args.network_dim,
args.network_alpha,
vae,
text_encoder,
unet,
neuron_dropout=args.network_dropout,
**net_kwargs,
)
if network is None:
return
network_has_multiplier = hasattr(network, "set_multiplier")
if hasattr(network, "prepare_network"):
network.prepare_network(args)
if args.scale_weight_norms and not hasattr(network, "apply_max_norm_regularization"):
logger.warning(
"warning: scale_weight_norms is specified but the network does not support it / scale_weight_normsが指定されていますが、ネットワークが対応していません"
)
args.scale_weight_norms = False
train_unet = not args.network_train_text_encoder_only
train_text_encoder = self.is_train_text_encoder(args)
network.apply_to(text_encoder, unet, train_text_encoder, train_unet)
if args.network_weights is not None:
info = network.load_weights(args.network_weights)
accelerator.print(f"load network weights from {args.network_weights}: {info}")
if args.gradient_checkpointing:
unet.enable_gradient_checkpointing()
for t_enc in text_encoders:
t_enc.gradient_checkpointing_enable()
del t_enc
network.enable_gradient_checkpointing() # may have no effect
# 学習に必要なクラスを準備する
accelerator.print("prepare optimizer, data loader etc.")
# 後方互換性を確保するよ
try:
trainable_params = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr, args.learning_rate)
except TypeError:
accelerator.print(
"Deprecated: use prepare_optimizer_params(text_encoder_lr, unet_lr, learning_rate) instead of prepare_optimizer_params(text_encoder_lr, unet_lr)"
)
trainable_params = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr)
optimizer_name, optimizer_args, optimizer = train_util.get_optimizer(args, trainable_params)
# dataloaderを準備する
# DataLoaderのプロセス数:0 は persistent_workers が使えないので注意
n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers
# train_dataloader = torch.utils.data.DataLoader(
# train_dataset_group,
# batch_size=1,
# shuffle=True,
# collate_fn=collator,
# num_workers=n_workers,
# persistent_workers=args.persistent_data_loader_workers,
# )
# 学習ステップ数を計算する
if args.max_train_epochs is not None:
args.max_train_steps = args.max_train_epochs * math.ceil(
num_steps_per_epoch / accelerator.num_processes / args.gradient_accumulation_steps
)
accelerator.print(
f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}"
)
# データセット側にも学習ステップを送信
# train_dataset_group.set_max_train_steps(args.max_train_steps)
# lr schedulerを用意する
lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)
# 実験的機能:勾配も含めたfp16/bf16学習を行う モデル全体をfp16/bf16にする
if args.full_fp16:
assert (
args.mixed_precision == "fp16"
), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。"
accelerator.print("enable full fp16 training.")
network.to(weight_dtype)
elif args.full_bf16:
assert (
args.mixed_precision == "bf16"
), "full_bf16 requires mixed precision='bf16' / full_bf16を使う場合はmixed_precision='bf16'を指定してください。"
accelerator.print("enable full bf16 training.")
network.to(weight_dtype)
unet_weight_dtype = te_weight_dtype = weight_dtype
# Experimental Feature: Put base model into fp8 to save vram
if args.fp8_base:
assert torch.__version__ >= "2.1.0", "fp8_base requires torch>=2.1.0 / fp8を使う場合はtorch>=2.1.0が必要です。"
assert (
args.mixed_precision != "no"
), "fp8_base requires mixed precision='fp16' or 'bf16' / fp8を使う場合はmixed_precision='fp16'または'bf16'が必要です。"
accelerator.print("enable fp8 training.")
unet_weight_dtype = torch.float8_e4m3fn
te_weight_dtype = torch.float8_e4m3fn
unet.requires_grad_(False)
unet.to(dtype=unet_weight_dtype)
for t_enc in text_encoders:
t_enc.requires_grad_(False)
# in case of cpu, dtype is already set to fp32 because cpu does not support fp8/fp16/bf16
if t_enc.device.type != "cpu":
t_enc.to(dtype=te_weight_dtype)
# nn.Embedding not support FP8
t_enc.text_model.embeddings.to(dtype=(weight_dtype if te_weight_dtype != weight_dtype else te_weight_dtype))
# acceleratorがなんかよろしくやってくれるらしい / accelerator will do something good
if train_unet:
unet = accelerator.prepare(unet)
else:
unet.to(accelerator.device, dtype=unet_weight_dtype) # move to device because unet is not prepared by accelerator
if train_text_encoder:
if len(text_encoders) > 1:
text_encoder = text_encoders = [accelerator.prepare(t_enc) for t_enc in text_encoders]
else:
text_encoder = accelerator.prepare(text_encoder)
text_encoders = [text_encoder]
else:
pass # if text_encoder is not trained, no need to prepare. and device and dtype are already set
# network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(network, optimizer, train_dataloader, lr_scheduler)
network, optimizer, lr_scheduler = accelerator.prepare(network, optimizer, lr_scheduler)
if args.gradient_checkpointing:
# according to TI example in Diffusers, train is required
unet.train()
for t_enc in text_encoders:
t_enc.train()
# set top parameter requires_grad = True for gradient checkpointing works
if train_text_encoder:
t_enc.text_model.embeddings.requires_grad_(True)
else:
unet.eval()
for t_enc in text_encoders:
t_enc.eval()
del t_enc
accelerator.unwrap_model(network).prepare_grad_etc(text_encoder, unet)
if not cache_latents: # キャッシュしない場合はVAEを使うのでVAEを準備する
vae.requires_grad_(False)
vae.eval()
vae.to(accelerator.device, dtype=vae_dtype)
# 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする
if args.full_fp16:
train_util.patch_accelerator_for_fp16_training(accelerator)
# resumeする
train_util.resume_from_local_or_hf_if_specified(accelerator, args)
# epoch数を計算する
num_update_steps_per_epoch = math.ceil(num_steps_per_epoch / args.gradient_accumulation_steps)
num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0):
args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1
# 学習する
# TODO: find a way to handle total batch size when there are multiple datasets
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
accelerator.print("running training / 学習開始")
# accelerator.print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}")
# accelerator.print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}")
accelerator.print(f" num batches per epoch / 1epochのバッチ数: {num_steps_per_epoch}")
accelerator.print(f" num epochs / epoch数: {num_train_epochs}")
# accelerator.print(
# f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}"
# )
# accelerator.print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}")
accelerator.print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
accelerator.print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")
# TODO refactor metadata creation and move to util
metadata = {
"ss_session_id": session_id, # random integer indicating which group of epochs the model came from
"ss_training_started_at": training_started_at, # unix timestamp
"ss_output_name": args.output_name,
"ss_learning_rate": args.learning_rate,
"ss_text_encoder_lr": args.text_encoder_lr,
"ss_unet_lr": args.unet_lr,
# "ss_num_train_images": train_dataset_group.num_train_images,
# "ss_num_reg_images": train_dataset_group.num_reg_images,
"ss_num_batches_per_epoch": num_steps_per_epoch,
"ss_num_epochs": num_train_epochs,
"ss_gradient_checkpointing": args.gradient_checkpointing,
"ss_gradient_accumulation_steps": args.gradient_accumulation_steps,
"ss_max_train_steps": args.max_train_steps,
"ss_lr_warmup_steps": args.lr_warmup_steps,
"ss_lr_scheduler": args.lr_scheduler,
"ss_network_module": args.network_module,
"ss_network_dim": args.network_dim, # None means default because another network than LoRA may have another default dim
"ss_network_alpha": args.network_alpha, # some networks may not have alpha
"ss_network_dropout": args.network_dropout, # some networks may not have dropout
"ss_mixed_precision": args.mixed_precision,
"ss_full_fp16": bool(args.full_fp16),
"ss_v2": bool(args.v2),
"ss_base_model_version": model_version,
"ss_clip_skip": args.clip_skip,
"ss_max_token_length": args.max_token_length,
"ss_cache_latents": bool(args.cache_latents),
"ss_seed": args.seed,
"ss_lowram": args.lowram,
"ss_noise_offset": args.noise_offset,
"ss_multires_noise_iterations": args.multires_noise_iterations,
"ss_multires_noise_discount": args.multires_noise_discount,
"ss_adaptive_noise_scale": args.adaptive_noise_scale,
"ss_zero_terminal_snr": args.zero_terminal_snr,
"ss_training_comment": args.training_comment, # will not be updated after training
"ss_sd_scripts_commit_hash": train_util.get_git_revision_hash(),
"ss_optimizer": optimizer_name + (f"({optimizer_args})" if len(optimizer_args) > 0 else ""),
"ss_max_grad_norm": args.max_grad_norm,
"ss_caption_dropout_rate": args.caption_dropout_rate,
"ss_caption_dropout_every_n_epochs": args.caption_dropout_every_n_epochs,
"ss_caption_tag_dropout_rate": args.caption_tag_dropout_rate,
"ss_face_crop_aug_range": args.face_crop_aug_range,
"ss_prior_loss_weight": args.prior_loss_weight,
"ss_min_snr_gamma": args.min_snr_gamma,
"ss_scale_weight_norms": args.scale_weight_norms,
"ss_ip_noise_gamma": args.ip_noise_gamma,
"ss_debiased_estimation": bool(args.debiased_estimation_loss),
}
"""
if use_user_config:
# save metadata of multiple datasets
# NOTE: pack "ss_datasets" value as json one time
# or should also pack nested collections as json?
datasets_metadata = []
tag_frequency = {} # merge tag frequency for metadata editor
dataset_dirs_info = {} # merge subset dirs for metadata editor
for dataset in train_dataset_group.datasets:
is_dreambooth_dataset = isinstance(dataset, DreamBoothDataset)
dataset_metadata = {
"is_dreambooth": is_dreambooth_dataset,
"batch_size_per_device": dataset.batch_size,
"num_train_images": dataset.num_train_images, # includes repeating
"num_reg_images": dataset.num_reg_images,
"resolution": (dataset.width, dataset.height),
"enable_bucket": bool(dataset.enable_bucket),
"min_bucket_reso": dataset.min_bucket_reso,
"max_bucket_reso": dataset.max_bucket_reso,
"tag_frequency": dataset.tag_frequency,
"bucket_info": dataset.bucket_info,
}
subsets_metadata = []
for subset in dataset.subsets:
subset_metadata = {
"img_count": subset.img_count,
"num_repeats": subset.num_repeats,
"color_aug": bool(subset.color_aug),
"flip_aug": bool(subset.flip_aug),
"random_crop": bool(subset.random_crop),
"shuffle_caption": bool(subset.shuffle_caption),
"keep_tokens": subset.keep_tokens,
"keep_tokens_separator": subset.keep_tokens_separator,
"secondary_separator": subset.secondary_separator,
"enable_wildcard": bool(subset.enable_wildcard),
"caption_prefix": subset.caption_prefix,
"caption_suffix": subset.caption_suffix,
}
image_dir_or_metadata_file = None
if subset.image_dir:
image_dir = os.path.basename(subset.image_dir)
subset_metadata["image_dir"] = image_dir
image_dir_or_metadata_file = image_dir
if is_dreambooth_dataset:
subset_metadata["class_tokens"] = subset.class_tokens
subset_metadata["is_reg"] = subset.is_reg
if subset.is_reg:
image_dir_or_metadata_file = None # not merging reg dataset
else:
metadata_file = os.path.basename(subset.metadata_file)
subset_metadata["metadata_file"] = metadata_file
image_dir_or_metadata_file = metadata_file # may overwrite
subsets_metadata.append(subset_metadata)
# merge dataset dir: not reg subset only
# TODO update additional-network extension to show detailed dataset config from metadata
if image_dir_or_metadata_file is not None:
# datasets may have a certain dir multiple times
v = image_dir_or_metadata_file
i = 2
while v in dataset_dirs_info:
v = image_dir_or_metadata_file + f" ({i})"
i += 1
image_dir_or_metadata_file = v
dataset_dirs_info[image_dir_or_metadata_file] = {
"n_repeats": subset.num_repeats,
"img_count": subset.img_count,
}
dataset_metadata["subsets"] = subsets_metadata
datasets_metadata.append(dataset_metadata)
# merge tag frequency:
for ds_dir_name, ds_freq_for_dir in dataset.tag_frequency.items():
# あるディレクトリが複数のdatasetで使用されている場合、一度だけ数える
# もともと繰り返し回数を指定しているので、キャプション内でのタグの出現回数と、それが学習で何度使われるかは一致しない
# なので、ここで複数datasetの回数を合算してもあまり意味はない
if ds_dir_name in tag_frequency:
continue
tag_frequency[ds_dir_name] = ds_freq_for_dir
metadata["ss_datasets"] = json.dumps(datasets_metadata)
metadata["ss_tag_frequency"] = json.dumps(tag_frequency)
metadata["ss_dataset_dirs"] = json.dumps(dataset_dirs_info)
else:
# conserving backward compatibility when using train_dataset_dir and reg_dataset_dir
assert (
len(train_dataset_group.datasets) == 1
), f"There should be a single dataset but {len(train_dataset_group.datasets)} found. This seems to be a bug. / データセットは1個だけ存在するはずですが、実際には{len(train_dataset_group.datasets)}個でした。プログラムのバグかもしれません。"
dataset = train_dataset_group.datasets[0]
dataset_dirs_info = {}
reg_dataset_dirs_info = {}
if use_dreambooth_method:
for subset in dataset.subsets:
info = reg_dataset_dirs_info if subset.is_reg else dataset_dirs_info
info[os.path.basename(subset.image_dir)] = {"n_repeats": subset.num_repeats, "img_count": subset.img_count}
else:
for subset in dataset.subsets:
dataset_dirs_info[os.path.basename(subset.metadata_file)] = {
"n_repeats": subset.num_repeats,
"img_count": subset.img_count,
}
metadata.update(
{
"ss_batch_size_per_device": args.train_batch_size,
"ss_total_batch_size": total_batch_size,
"ss_resolution": args.resolution,
"ss_color_aug": bool(args.color_aug),
"ss_flip_aug": bool(args.flip_aug),
"ss_random_crop": bool(args.random_crop),
"ss_shuffle_caption": bool(args.shuffle_caption),
"ss_enable_bucket": bool(dataset.enable_bucket),
"ss_bucket_no_upscale": bool(dataset.bucket_no_upscale),
"ss_min_bucket_reso": dataset.min_bucket_reso,
"ss_max_bucket_reso": dataset.max_bucket_reso,
"ss_keep_tokens": args.keep_tokens,
"ss_dataset_dirs": json.dumps(dataset_dirs_info),
"ss_reg_dataset_dirs": json.dumps(reg_dataset_dirs_info),
"ss_tag_frequency": json.dumps(dataset.tag_frequency),
"ss_bucket_info": json.dumps(dataset.bucket_info),
}
)
"""
# add extra args
if args.network_args:
metadata["ss_network_args"] = json.dumps(net_kwargs)
# model name and hash
if args.pretrained_model_name_or_path is not None:
sd_model_name = args.pretrained_model_name_or_path
if os.path.exists(sd_model_name):
metadata["ss_sd_model_hash"] = train_util.model_hash(sd_model_name)
metadata["ss_new_sd_model_hash"] = train_util.calculate_sha256(sd_model_name)
sd_model_name = os.path.basename(sd_model_name)
metadata["ss_sd_model_name"] = sd_model_name
if args.vae is not None:
vae_name = args.vae
if os.path.exists(vae_name):
metadata["ss_vae_hash"] = train_util.model_hash(vae_name)
metadata["ss_new_vae_hash"] = train_util.calculate_sha256(vae_name)
vae_name = os.path.basename(vae_name)
metadata["ss_vae_name"] = vae_name
metadata = {k: str(v) for k, v in metadata.items()}
# make minimum metadata for filtering
minimum_metadata = {}
for key in train_util.SS_METADATA_MINIMUM_KEYS:
if key in metadata:
minimum_metadata[key] = metadata[key]
progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps")
global_step = 0
noise_scheduler = DDPMScheduler(
beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False
)
prepare_scheduler_for_custom_training(noise_scheduler, accelerator.device)
if args.zero_terminal_snr:
custom_train_functions.fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler)
if accelerator.is_main_process:
init_kwargs = {}
if args.wandb_run_name:
init_kwargs["wandb"] = {"name": args.wandb_run_name}
if args.log_tracker_config is not None:
init_kwargs = toml.load(args.log_tracker_config)
accelerator.init_trackers(
"network_train" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs
)
loss_recorder = train_util.LossRecorder()
# del train_dataset_group
# callback for step start
if hasattr(accelerator.unwrap_model(network), "on_step_start"):
on_step_start = accelerator.unwrap_model(network).on_step_start
else:
on_step_start = lambda *args, **kwargs: None
# function for saving/removing
def save_model(ckpt_name, unwrapped_nw, steps, epoch_no, force_sync_upload=False):
os.makedirs(args.output_dir, exist_ok=True)
ckpt_file = os.path.join(args.output_dir, ckpt_name)
accelerator.print(f"\nsaving checkpoint: {ckpt_file}")
metadata["ss_training_finished_at"] = str(time.time())
metadata["ss_steps"] = str(steps)
metadata["ss_epoch"] = str(epoch_no)
metadata_to_save = minimum_metadata if args.no_metadata else metadata
sai_metadata = train_util.get_sai_model_spec(None, args, self.is_sdxl, True, False)
metadata_to_save.update(sai_metadata)
unwrapped_nw.save_weights(ckpt_file, save_dtype, metadata_to_save)
if args.huggingface_repo_id is not None:
huggingface_util.upload(args, ckpt_file, "/" + ckpt_name, force_sync_upload=force_sync_upload)
def remove_model(old_ckpt_name):
    """Delete ``old_ckpt_name`` from ``args.output_dir`` if that file exists.

    Relies on ``args`` and ``accelerator`` from the enclosing scope. Does
    nothing (and prints nothing) when the file is not present.
    """
    old_ckpt_file = os.path.join(args.output_dir, old_ckpt_name)
    if not os.path.exists(old_ckpt_file):
        return
    accelerator.print(f"removing old checkpoint: {old_ckpt_file}")
    os.remove(old_ckpt_file)
# For --sample_at_first
self.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
# training loop
for epoch in range(num_train_epochs):
accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
# current_epoch.value = epoch + 1
metadata["ss_epoch"] = str(epoch + 1)
accelerator.unwrap_model(network).on_epoch_start(text_encoder, unet)
# for step, batch in enumerate(train_dataloader):
for step in range(num_steps_per_epoch):
# current_step.value = global_step
with accelerator.accumulate(network):
on_step_start(text_encoder, unet)
# make prompt
# Danbooru Tag と対象文字列を含むランダムなプロンプトを作成する
src_ids_list1 = []
src_ids_list2 = []
tgt_ids_list1 = []
tgt_ids_list2 = []
for bi in range(args.train_batch_size):
tag_count = random.randint(1, 20)
selected_tags = random.sample(tags, tag_count)
prompt_tags = ["1girl", src_str] + selected_tags # 必要ならここも変える
random.shuffle(prompt_tags) # selected_tags だけ shuffle する感じにすると実際のプロンプトに近くなるかも
src_prompt = ", ".join(prompt_tags)
tgt_prompt = ", ".join(prompt_tags).replace(src_str, tgt_str)
# print(f"src: {src_prompt}")
# print(f"tgt: {tgt_prompt}")
src_ids1 = (
tokenizers[0].encode(src_prompt, return_tensors="pt", truncation=True, padding="max_length").to("cpu")
)
src_ids2 = (
tokenizers[1].encode(src_prompt, return_tensors="pt", truncation=True, padding="max_length").to("cpu")
)
tgt_ids1 = (
tokenizers[0].encode(tgt_prompt, return_tensors="pt", truncation=True, padding="max_length").to("cpu")
)
tgt_ids2 = (
tokenizers[1].encode(tgt_prompt, return_tensors="pt", truncation=True, padding="max_length").to("cpu")
)
src_ids_list1.append(src_ids1)
src_ids_list2.append(src_ids2)
tgt_ids_list1.append(tgt_ids1)
tgt_ids_list2.append(tgt_ids2)
src_prompt_ids1 = torch.stack(src_ids_list1)
src_prompt_ids2 = torch.stack(src_ids_list2)
tgt_prompt_ids1 = torch.stack(tgt_ids_list1)
tgt_prompt_ids2 = torch.stack(tgt_ids_list2)
with torch.no_grad():
# Get the text embedding for conditioning
batch = {"input_ids": tgt_prompt_ids1, "input_ids2": tgt_prompt_ids2}
tgt_text_encoder_conds = self.get_text_cond(
args, accelerator, batch, tokenizers, text_encoders, weight_dtype
)
with torch.set_grad_enabled(True), accelerator.autocast():
batch = {"input_ids": src_prompt_ids1, "input_ids2": src_prompt_ids2}
src_text_encoder_conds = self.get_text_cond(
args, accelerator, batch, tokenizers, text_encoders, weight_dtype
)
src_emb1, src_emb2, src_pool = src_text_encoder_conds
tgt_emb1, tgt_emb2, tgt_pool = tgt_text_encoder_conds
loss1 = torch.nn.functional.mse_loss(src_emb1, tgt_emb1)
loss2 = torch.nn.functional.mse_loss(src_emb2, tgt_emb2)
loss_pool = torch.nn.functional.mse_loss(src_pool, tgt_pool)
# 必要ならここで重みづけする
loss = loss1 + loss2 + loss_pool
accelerator.backward(loss)
if accelerator.sync_gradients:
self.all_reduce_network(accelerator, network) # sync DDP grad manually
if args.max_grad_norm != 0.0:
params_to_clip = accelerator.unwrap_model(network).get_trainable_params()
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad(set_to_none=True)
if args.scale_weight_norms:
keys_scaled, mean_norm, maximum_norm = accelerator.unwrap_model(network).apply_max_norm_regularization(
args.scale_weight_norms, accelerator.device
)
max_mean_logs = {"Keys Scaled": keys_scaled, "Average key norm": mean_norm}
else:
keys_scaled, mean_norm, maximum_norm = None, None, None
# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:
progress_bar.update(1)
global_step += 1
self.sample_images(accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
# 指定ステップごとにモデルを保存
if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0:
accelerator.wait_for_everyone()
if accelerator.is_main_process:
ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, global_step)
save_model(ckpt_name, accelerator.unwrap_model(network), global_step, epoch)
if args.save_state:
train_util.save_and_remove_state_stepwise(args, accelerator, global_step)
remove_step_no = train_util.get_remove_step_no(args, global_step)
if remove_step_no is not None:
remove_ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, remove_step_no)
remove_model(remove_ckpt_name)
current_loss = loss.detach().item()
current_loss1 = loss1.detach().item()
current_loss2 = loss2.detach().item()
current_loss_pool = loss_pool.detach().item()
loss_recorder.add(epoch=epoch, step=step, loss=current_loss)
avr_loss: float = loss_recorder.moving_average
logs = {"avr_loss": avr_loss, "l1": current_loss1, "l2": current_loss2, "lp": current_loss_pool}
progress_bar.set_postfix(**logs)
if args.scale_weight_norms:
progress_bar.set_postfix(**{**max_mean_logs, **logs})
if args.logging_dir is not None:
logs = self.generate_step_logs(args, current_loss, avr_loss, lr_scheduler, keys_scaled, mean_norm, maximum_norm)
accelerator.log(logs, step=global_step)
if global_step >= args.max_train_steps:
break
if args.logging_dir is not None:
logs = {"loss/epoch": loss_recorder.moving_average}
accelerator.log(logs, step=epoch + 1)
accelerator.wait_for_everyone()
# 指定エポックごとにモデルを保存
if args.save_every_n_epochs is not None:
saving = (epoch + 1) % args.save_every_n_epochs == 0 and (epoch + 1) < num_train_epochs
if is_main_process and saving:
ckpt_name = train_util.get_epoch_ckpt_name(args, "." + args.save_model_as, epoch + 1)
save_model(ckpt_name, accelerator.unwrap_model(network), global_step, epoch + 1)
remove_epoch_no = train_util.get_remove_epoch_no(args, epoch + 1)
if remove_epoch_no is not None:
remove_ckpt_name = train_util.get_epoch_ckpt_name(args, "." + args.save_model_as, remove_epoch_no)
remove_model(remove_ckpt_name)
if args.save_state:
train_util.save_and_remove_state_on_epoch_end(args, accelerator, epoch + 1)
self.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
# end of epoch
# metadata["ss_epoch"] = str(num_train_epochs)
metadata["ss_training_finished_at"] = str(time.time())
if is_main_process:
network = accelerator.unwrap_model(network)
accelerator.end_training()
if is_main_process and args.save_state:
train_util.save_state_on_train_end(args, accelerator)
if is_main_process:
ckpt_name = train_util.get_last_ckpt_name(args, "." + args.save_model_as)
save_model(ckpt_name, network, global_step, num_train_epochs, force_sync_upload=True)
logger.info("model saved.")
def setup_parser() -> argparse.ArgumentParser:
    """Return the parser from ``train_network.setup_parser()`` with the SDXL
    training arguments registered on it."""
    sdxl_parser = train_network.setup_parser()
    sdxl_train_util.add_sdxl_training_arguments(sdxl_parser)
    return sdxl_parser
if __name__ == "__main__":
    # Script entry point: build the parser, parse the command line and pass
    # the result through train_util.read_config_from_file before training.
    parser = setup_parser()
    args = train_util.read_config_from_file(parser.parse_args(), parser)

    trainer = SdxlNetworkTrainer()
    trainer.train(args)
@kohya-ss
Copy link
Author

容易に過学習するので、適当なepochで切り上げるか、適用時の重みで調整するかしてください。

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment