Skip to content

Instantly share code, notes, and snippets.

@takuma104
Last active January 21, 2023 09:35
Show Gist options
  • Save takuma104/58fbd99a02006c67dbb9ff968c7417f2 to your computer and use it in GitHub Desktop.
Save takuma104/58fbd99a02006c67dbb9ff968c7417f2 to your computer and use it in GitHub Desktop.
Reproducible generation with xFormers for SD2.x or variant models

Reproducible generation with xFormers for SD2.x or variant models

More details and background: Generating (almost) reproducible pictures using Diffusers with xFormers

Limitation

Unfortunately, xFormers' Flash Attention does not accept the attention shapes used by SD1.x models; it only accepts those of SD2.x and its variants.

Test Result

Automatic1111 WebUI (API) result:

$  python launch.py --xformers --listen --api

fig___xformers_only

$  python launch.py --xformers --xformers-flash-attention --listen --api

fig___xformers_with___xformers_flash_attention

Performance

Without the --xformers-flash-attention option:

100%|█████████████████████████████████████████████████████████████████| 20/20 [00:05<00:00,  3.89it/s]
100%|█████████████████████████████████████████████████████████████████| 20/20 [00:04<00:00,  4.30it/s]
100%|█████████████████████████████████████████████████████████████████| 20/20 [00:04<00:00,  4.23it/s]
100%|█████████████████████████████████████████████████████████████████| 20/20 [00:04<00:00,  4.30it/s]
Total progress: 80it [00:19,  4.18it/s]

With the --xformers-flash-attention option:

100%|█████████████████████████████████████████████████████████████████| 20/20 [00:05<00:00,  3.91it/s]
100%|█████████████████████████████████████████████████████████████████| 20/20 [00:04<00:00,  4.35it/s]
100%|█████████████████████████████████████████████████████████████████| 20/20 [00:04<00:00,  4.26it/s]
100%|█████████████████████████████████████████████████████████████████| 20/20 [00:04<00:00,  4.33it/s]
Total progress: 80it [00:19,  4.22it/s]
import json
import requests
import io
import base64
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
import sys
from pprint import pprint
# Global matplotlib defaults for the comparison figure.
plt.rcParams.update({
    "figure.figsize": (8, 6),
    "figure.facecolor": 'white',
})
# Base URL of the WebUI server (launched with --api so /sdapi/v1 is exposed).
url = "http://127.0.0.1:7860"
# prompt by @p1atdev_art
# https://twitter.com/p1atdev_art/status/1616557167087353856
prompt = (
    "masterpiece, best quality, high quality, 1girl, sun hat,\n"
    "frilled white dress, looking at viewer, summer, sky, beach, semi-realistic"
)
neg_prompt = (
    "nsfw, worst quality, low quality, medium quality, deleted,\n"
    "lowres, bad anatomy, bad hands, text, error, missing fingers, extra digits,\n"
    "fewer digits, cropped, jpeg artifacts, signature, watermark, username, blurry,"
)
# Request body for /sdapi/v1/txt2img. The "seed" value here is replaced by
# the seed argument of t2i_repeate before sending.
t2i_payload = dict(
    prompt=prompt,
    negative_prompt=neg_prompt,
    seed=0,
    batch_size=1,
    n_iter=1,
    steps=20,
    cfg_scale=11,
    width=512,
    height=512,
    sampler_index="DPM++ 2S a Karras",
    # https://huggingface.co/p1atdev/pd-archive/tree/main
    # NOTE(review): confirm the txt2img endpoint actually honors a top-level
    # "model" key; it may be ignored, with the checkpoint chosen elsewhere.
    model="plat-v1-3-1-fp16",
)
def t2i_repeate(fn_prefix='run', seed=0, runs=4):
    """Send the same txt2img request `runs` times and save each result.

    Each run POSTs a copy of the module-level ``t2i_payload`` (with ``seed``
    overridden) to ``{url}/sdapi/v1/txt2img``. The first returned image is
    saved as ``{fn_prefix}_{run}.png``.

    Args:
        fn_prefix: Prefix of the output PNG file names.
        seed: Seed sent with every request. All runs use the same seed, so
            the saved images should be identical if generation is reproducible.
        runs: Number of times the request is repeated.

    Returns:
        The first infotext of the last run, or '' when ``runs`` is 0.

    Raises:
        requests.HTTPError: If the server responds with an error status.
    """
    # Work on a copy so the shared module-level payload is not mutated.
    payload = {**t2i_payload, 'seed': seed}
    infotext = ''
    for run in range(runs):
        response = requests.post(url=f'{url}/sdapi/v1/txt2img', json=payload)
        # Surface HTTP errors directly instead of as a KeyError on 'info' below.
        response.raise_for_status()
        r = response.json()
        infotext = json.loads(r['info'])['infotexts'][0]
        # The image may be plain base64 or a "data:image/png;base64,..." URL.
        # Base64 never contains ',', so the last comma-separated part is the
        # payload in both cases (index [0] would return the data-URL header).
        b64_image = r['images'][0].split(',', 1)[-1]
        image = Image.open(io.BytesIO(base64.b64decode(b64_image)))
        fn = f"{fn_prefix}_{run}.png"
        image.save(fn)
        print(f'{fn} generated.')
    return infotext
def calc_difference_image(image0, image1, normalize=False):
    """Return the per-pixel absolute difference of two images as a PIL image.

    Args:
        image0, image1: Images convertible with ``np.array`` (e.g. PIL images).
            Their shapes must be compatible for subtraction.
        normalize: If True, scale the difference so its maximum becomes 255.
            When the images are identical (maximum difference 0), the result
            stays all-zero instead of dividing by zero.

    Returns:
        ``Image.fromarray`` of ``|image0 - image1|`` cast to uint8.
    """
    # Widen to int32 so the subtraction cannot wrap around as uint8 would.
    a = np.array(image0, dtype=np.int32)
    b = np.array(image1, dtype=np.int32)
    abs_diff = np.abs(a - b)
    if normalize:
        peak = abs_diff.max()
        # Guard: identical images give peak == 0, and 0/0 would produce NaNs
        # that cast to arbitrary uint8 values.
        if peak > 0:
            abs_diff = abs_diff / peak * 255
    # NOTE(review): without normalize this assumes 8-bit inputs, so the
    # difference fits in uint8.
    return Image.fromarray(abs_diff.astype(np.uint8))
def render_figure(title, footer, fn_prefix, fig_fn, runs=4):
    """Plot each run and its difference from run #0, then save the figure.

    The top row shows ``{fn_prefix}_{i}.png`` for ``i`` in ``range(runs)``.
    The bottom row shows ``calc_difference_image(run #0, run #i)``; a zero
    difference means the two runs are pixel-identical.

    Args:
        title: Figure title.
        footer: Text drawn below the plots (e.g. the generation infotext).
        fn_prefix: Prefix of the per-run PNG files to load.
        fig_fn: Output path for the rendered figure.
        runs: Number of runs (columns) to plot.
    """
    def plot_normal_row(axs, fn_prefix):
        # One column per run: the generated image itself.
        for i, ax in enumerate(axs):
            ax.set_title(f'Run#{i}')
            ax.imshow(Image.open(f'{fn_prefix}_{i}.png'))

    def plot_diff_row(axs, fn_prefix):
        # Load the reference (run #0) once rather than once per column.
        ref = Image.open(f'{fn_prefix}_0.png')
        for i, ax in enumerate(axs):
            ax.set_title(f'Run#{i} - Run#0')
            res = Image.open(f'{fn_prefix}_{i}.png')
            ax.imshow(calc_difference_image(ref, res))

    # squeeze=False keeps axs 2-D even when runs == 1, so axs[row] is always
    # an iterable row of Axes.
    fig, axs = plt.subplots(2, runs, squeeze=False)
    for ax in axs.flatten():
        ax.set_aspect('equal', 'box')
        ax.axis('off')
    plot_normal_row(axs[0], fn_prefix)
    plot_diff_row(axs[1], fn_prefix)
    fig.text(.02, .05, footer, wrap=True, horizontalalignment='left', fontsize=8)
    fig.suptitle(title, fontsize=16)
    # Reserve the bottom 11% of the figure for the footer text.
    fig.tight_layout(rect=[0, 0.11, 1, 1])
    fig.savefig(fig_fn, dpi=72*2)
    # pyplot keeps every figure alive until it is explicitly closed.
    plt.close(fig)
if __name__ == '__main__':
    # Usage: python <script> TITLE
    # TITLE is also used for the output file name, e.g. "--xformers only"
    # becomes "fig___xformers_only.png".
    if len(sys.argv) < 2:
        sys.exit(f'usage: {sys.argv[0]} TITLE')
    title = sys.argv[1]
    fn_prefix = 'run'
    runs = 4
    # Every run uses the same fixed seed, so any run-to-run difference in the
    # figure indicates non-reproducible generation.
    footer = t2i_repeate(fn_prefix=fn_prefix, seed=7, runs=runs)
    fig_fn = f"fig_{title.replace(' ', '_').replace('-', '_')}.png"
    render_figure(title=title, footer=footer, fn_prefix=fn_prefix, fig_fn=fig_fn, runs=runs)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment