
@td2sk
Last active September 6, 2022 06:10
Prompt correction patch for Stable Diffusion txt2img.py

Stable Diffusion semantic-correction arithmetic patch

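A patch that adds an option to txt2img.py for adding and subtracting prompts in Stable Diffusion: each extra prompt is run through the text encoder, and its embedding, multiplied by a weight, is added to the conditioning of the main prompt.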

Usage example

Shift a pyramid 20% toward Japan

# The formula is: pyramid + 20% * (Japan - Egypt)
# As prompts, this becomes: pyramid + 0.2 * (japan - egypt)

# Pass the main prompt to txt2img as usual, and give each additional term with --prompt-correction
# Each additional term is written as --prompt-correction 'PROMPT::WEIGHT', with two colons between the prompt and its weight
python scripts/txt2img.py --plms --n_samples 1 --n_iter 1 (other options omitted) \
  --prompt "pyramid" \
  --prompt-correction 'egypt::-0.2' \
  --prompt-correction 'japan::0.2'
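
Each --prompt-correction value is split at its last `::`: everything before it is the prompt, and everything after it must parse as a float. The function below is a hypothetical, standalone sketch of that parsing (the name `parse_prompt_correction` does not appear in txt2img.py); it can be used to check option values before starting a long sampling run.

```python
from __future__ import annotations


def parse_prompt_correction(spec: str) -> tuple[str, float]:
    """Split a --prompt-correction value of the form 'PROMPT::WEIGHT'.

    Hypothetical helper for illustration. The split happens at the *last*
    '::', so the prompt part may itself contain '::'. Raises ValueError if
    there is no '::' or if the weight is not a number.
    """
    prompt, sep, weight = spec.rpartition("::")
    if not sep:
        raise ValueError(f"expected 'PROMPT::WEIGHT', got {spec!r}")
    return prompt, float(weight)


if __name__ == "__main__":
    for spec in ["egypt::-0.2", "japan::0.2"]:
        print(parse_prompt_correction(spec))
    # ('egypt', -0.2)
    # ('japan', 0.2)
```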

Note

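Always write the prompt adjustment as a difference (e.g. 0.2 * (Japan - Egypt)), so that the correction weights add up to zero (+0.2 and -0.2 in the example above).

  • Otherwise the correction becomes inconsistent with how unconditional_conditioning is handled: the patch modifies only the conditional embedding c and leaves unconditional_conditioning untouched.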
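
One way to read this rule is that the weights passed via --prompt-correction must cancel out. A minimal pre-flight check, assuming that reading (`is_difference` is a hypothetical helper, not part of the patch):

```python
from __future__ import annotations

import math


def is_difference(weights: list[float], tol: float = 1e-9) -> bool:
    """Return True if the correction weights sum to (approximately) zero.

    Hypothetical helper. 0.2 * (japan - egypt) has weights [+0.2, -0.2],
    which cancel, so it is a pure difference; a lone 'japan::0.2' is not.
    The tolerance absorbs floating-point rounding such as 0.1 + 0.2 - 0.3.
    """
    return math.isclose(sum(weights), 0.0, abs_tol=tol)


if __name__ == "__main__":
    print(is_difference([-0.2, 0.2]))  # True
    print(is_difference([0.2]))        # False
```
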
diff --git "a/.\\stable-diffusion\\scripts\\txt2img.py" "b/.\\stable-diffusion\\scripts\\txt2img_mod.py"
index 384f38f..c33f1a2 100644
--- "a/.\\stable-diffusion\\scripts\\txt2img.py"
+++ "b/.\\stable-diffusion\\scripts\\txt2img_mod.py"
@@ -177,6 +177,10 @@ def main():
         choices=["full", "autocast"],
         default="autocast"
     )
+    parser.add_argument(
+        '--prompt-correction', action='append', default=[],
+        help="extra prompt term written as 'PROMPT::WEIGHT' (repeatable)",
+    )
     opt = parser.parse_args()
 
     if opt.laion400m:
@@ -189,6 +193,8 @@ def main():
 
     config = OmegaConf.load(f"{opt.config}")
     model = load_model_from_config(config, f"{opt.ckpt}")
+    # use a float16 model (for environments with 8 GB of VRAM)
+    model = model.to(torch.float16)
 
     device = torch.device(
         "cuda") if torch.cuda.is_available() else torch.device("cpu")
@@ -240,6 +246,11 @@ def main():
                         if isinstance(prompts, tuple):
                             prompts = list(prompts)
                         c = model.get_learned_conditioning(prompts)
+                        for pw in opt.prompt_correction:
+                            # split at the last '::' (the prompt may contain '::')
+                            p, weight = pw.rsplit('::', 1)
+                            c += float(weight) * \
+                                model.get_learned_conditioning([p])
                         shape = [opt.C, opt.H // opt.f, opt.W // opt.f]
                         samples_ddim, _ = sampler.sample(S=opt.ddim_steps,
                                                          conditioning=c,
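
The effect of the correction loop on the conditioning can be illustrated without loading a model. In this sketch, `toy_encode` and its 3-element vectors are hypothetical stand-ins for `model.get_learned_conditioning` and the real CLIP text embeddings; only the arithmetic `c += weight * encode([p])` mirrors the patch. Each correction prompt is passed to the encoder as a one-element list, so a single row is added to every sample in the batch. Passing `list(p)` instead would encode one prompt per character, since `list('japan') == ['j', 'a', 'p', 'a', 'n']`.

```python
from __future__ import annotations

# Hypothetical stand-in for the text encoder: each prompt maps to a fixed
# 3-element vector instead of a real CLIP embedding.
TOY_EMBEDDINGS = {
    "pyramid": [1.0, 0.0, 0.0],
    "egypt": [0.0, 1.0, 0.0],
    "japan": [0.0, 0.0, 1.0],
}


def toy_encode(prompts: list[str]) -> list[list[float]]:
    """Return one (copied) embedding row per prompt, i.e. a batch."""
    return [list(TOY_EMBEDDINGS[p]) for p in prompts]


def apply_corrections(
    c: list[list[float]], corrections: list[tuple[str, float]]
) -> list[list[float]]:
    """Add weight * encode([prompt]) to every sample of the batch c, in place."""
    for prompt, weight in corrections:
        (row,) = toy_encode([prompt])  # one-element list -> exactly one row
        for sample in c:  # the single row is reused for every batch entry
            for i, value in enumerate(row):
                sample[i] += weight * value
    return c


if __name__ == "__main__":
    c = toy_encode(["pyramid", "pyramid"])  # batch of two identical prompts
    apply_corrections(c, [("egypt", -0.2), ("japan", 0.2)])
    print(c)  # [[1.0, -0.2, 0.2], [1.0, -0.2, 0.2]]
```
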

td2sk commented Sep 6, 2022

All source code and patches that I have published on gists that do not specify a license are under the CC0 license.
