Created
March 30, 2023 13:40
-
-
Save yasushisakai/f14dd402d8aa4a58d622371d0cad4448 to your computer and use it in GitHub Desktop.
ControlNet Hack
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Settings read by the command-line depth-to-image script via `config.<name>`.
# The ranges quoted below are the limits the original Gradio sliders allowed.

# When True, the script calls model.low_vram_shift(...) around each stage.
save_memory = False

# Input image, output directory and location of the ControlNet checkout.
image_file = 'C:\\Users\\yasushi\\shibuya.png'
save_dir = 'C:\\Users\\yasushi\\images'
repo_dir = 'C:\\Users\\yasushi\\code\\ControlNet'

# Text appended to the user prompt, and the negative prompt.
a_prompt = 'city, birds eye view, best quality, insane details'
n_prompt = 'longbody, lowres, extra digit, fewer digits, cropped, low quality, worst quality'

# Number of images per run (UI range: 1..12, step 1).
num_samples = 1

# Output image resolution (UI range: 256..768, step 64).
image_resolution = 512

# Control strength (UI range: 0.0..2.0, step 0.01).
strength = 1.0

# Guess mode (a checkbox in the UI) is turned off; there is no setting for it.

# Resolution used for depth detection (UI range: 128..1024, step 1).
detect_resolution = 384

# Number of DDIM steps (UI range: 1..100, step 1).
ddim_steps = 20

# Guidance scale (UI range: 0.1..30.0, step 0.1).
scale = 9

# eta (DDIM) is not configured here; the script passes 0.0 directly.

# Seed (UI range: -1..2147483647; -1 makes the script pick a random seed).
# Fixed so runs are a little deterministic. Floor division keeps the seed an
# int (1073741823); true division produced the float 1073741823.5.
seed = 2147483647 // 2
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from share import * | |
import config | |
import sys | |
import cv2 | |
import einops | |
import gradio as gr | |
import numpy as np | |
import torch | |
import random | |
from PIL import Image | |
from pytorch_lightning import seed_everything | |
from annotator.util import resize_image, HWC3 | |
from annotator.midas import MidasDetector | |
from cldm.model import create_model, load_state_dict | |
from cldm.ddim_hacked import DDIMSampler | |
def process(
    input_image,
    prompt,
    a_prompt,
    n_prompt,
    num_samples,
    image_resolution,
    detect_resolution,
    ddim_steps,
    strength,
    scale,
    seed,
    eta,
    model,
    ddim_sampler
):
    """Run depth-conditioned sampling for ``prompt`` on ``input_image``.

    The module-level callable ``apply_midas`` must exist before this is
    called; the script creates it in its ``__main__`` section.

    Args:
        input_image: Image array; passed through ``HWC3`` before use.
        prompt: User prompt; ``a_prompt`` is appended after ``', '``.
        a_prompt: Additional prompt text.
        n_prompt: Negative prompt used for the unconditional branch.
        num_samples: Number of images to generate (batch size).
        image_resolution: Resolution handed to ``resize_image`` for the output.
        detect_resolution: Resolution handed to ``resize_image`` before
            depth detection.
        ddim_steps: Number of steps passed to ``ddim_sampler.sample``.
        strength: Control strength applied to all 13 control scales.
        scale: Unconditional guidance scale.
        seed: Random seed; ``-1`` selects a random seed in ``[0, 65535]``.
        eta: DDIM eta passed to the sampler.
        model: The ControlNet model (moved between devices when
            ``config.save_memory`` is set).
        ddim_sampler: Sampler exposing ``sample(...)``.

    Returns:
        A list whose first element is the detected depth map (resized to the
        output size) followed by ``num_samples`` generated ``uint8`` images.
    """
    with torch.no_grad():
        input_image = HWC3(input_image)
        detected_map, _ = apply_midas(
            resize_image(input_image, detect_resolution))
        detected_map = HWC3(detected_map)

        img = resize_image(input_image, image_resolution)
        H, W = img.shape[:2]

        # Bring the depth map to the exact output size before using it as
        # the control hint.
        detected_map = cv2.resize(
            detected_map, (W, H), interpolation=cv2.INTER_LINEAR)

        control = torch.from_numpy(detected_map.copy()).float().cuda() / 255.0
        control = torch.stack([control for _ in range(num_samples)], dim=0)
        control = einops.rearrange(control, 'b h w c -> b c h w').clone()

        if seed == -1:
            seed = random.randint(0, 65535)
        seed_everything(seed)

        if config.save_memory:
            model.low_vram_shift(is_diffusing=False)

        cond = {"c_concat": [control], "c_crossattn": [
            model.get_learned_conditioning([prompt + ', ' + a_prompt] * num_samples)]}
        un_cond = {"c_concat": [control], "c_crossattn": [
            model.get_learned_conditioning([n_prompt] * num_samples)]}
        shape = (4, H // 8, W // 8)

        if config.save_memory:
            model.low_vram_shift(is_diffusing=True)

        # Guess mode is off in this script (un_cond keeps the control hint),
        # so each of the 13 control scales is simply `strength`. The
        # `0.825 ** (12 - i)` decay used previously is the guess-mode branch
        # of upstream ControlNet and attenuated the control unintentionally.
        model.control_scales = [strength] * 13

        # The sampler also returns intermediates; only the samples are used.
        samples, _ = ddim_sampler.sample(ddim_steps, num_samples,
                                         shape, cond, verbose=False, eta=eta,
                                         unconditional_guidance_scale=scale,
                                         unconditional_conditioning=un_cond)

        if config.save_memory:
            model.low_vram_shift(is_diffusing=False)

        x_samples = model.decode_first_stage(samples)
        # Map the decoded output to 0..255 uint8, channels last.
        x_samples = (einops.rearrange(x_samples, 'b c h w -> b h w c')
                     * 127.5 + 127.5).cpu().numpy().clip(0, 255).astype(np.uint8)

        results = [x_samples[i] for i in range(num_samples)]
    return [detected_map] + results
if __name__ == '__main__':
    # Usage: <script> <output-name> <prompt>
    # Check the command line before the (slow) model load so that a missing
    # argument fails immediately with a readable message instead of an
    # IndexError after the models have been loaded.
    if len(sys.argv) < 3:
        sys.exit(f'usage: {sys.argv[0]} <output-name> <prompt>')

    # collect information for the 'process' function
    # check 'config.py' for the parameters
    file_name = f'{config.save_dir}\\{sys.argv[1]}.png'
    prompt = sys.argv[2]
    print(prompt)

    # init
    # The models are loaded again on every invocation; not the most
    # efficient approach, but it keeps the script self-contained.
    apply_midas = MidasDetector()
    model = create_model(f'{config.repo_dir}/models/cldm_v15.yaml').cpu()
    model.load_state_dict(load_state_dict(
        f'{config.repo_dir}/models/control_sd15_depth.pth', location='cuda'))
    model = model.cuda()
    ddim_sampler = DDIMSampler(model)

    # Read the pixels inside a context manager so the file handle is closed.
    with Image.open(config.image_file) as source_image:
        input_image = np.asarray(source_image)

    # eta is fixed to 0.0 here; everything else comes from config.py.
    result = process(input_image, prompt, config.a_prompt, config.n_prompt, config.num_samples, config.image_resolution,
                     config.detect_resolution, config.ddim_steps, config.strength, config.scale, config.seed, 0.0, model, ddim_sampler)

    # result[0] is the depth map; result[1] is the first generated image.
    output_image = Image.fromarray(result[1])
    output_image.save(file_name)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
This is almost a rip-off of the ControlNet gradio_depth2image.py script. I use it to run depth-to-image generation from the command line so that it can be executed by my queue management service (https://github.com/yasushisakai/ultra_queue).