Created
May 22, 2023 11:38
-
-
Save laksjdjf/537f13ee32835f6395484e385f0d46b6 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
import modules.scripts as scripts | |
import gradio as gr | |
from modules.processing import StableDiffusionProcessing, process_images | |
class Script(scripts.Script):
    """Decode latents to the exact requested pixel resolution.

    When enabled, the ``upsample`` layers of ``decoder.up[1]``, ``decoder.up[2]``
    and ``decoder.up[3]`` of the first-stage (VAE) model have their ``forward``
    temporarily replaced. Each replacement nearest-interpolates to a size
    derived from ``p.height`` / ``p.width`` rather than whatever the original
    layer would produce. The original ``forward`` methods are restored once
    ``process_images`` finishes, whether it returns or raises.
    """

    def __init__(self):
        pass

    def title(self):
        return "any resolution"

    def forward_hijack(self, target, index, height, width, hr_scale):
        """Build a replacement ``forward`` for one decoder upsample layer.

        Args:
            target: the upsample module being patched. Its ``with_conv`` flag
                and ``conv`` layer are reused after interpolation.
            index: position of ``target`` in ``decoder.up`` (1, 2 or 3).
            height: requested output image height in pixels.
            width: requested output image width in pixels.
            hr_scale: hires-fix scale factor, applied when the incoming tensor
                does not have the first-pass height for this layer.

        Returns:
            A callable ``forward(x)`` that nearest-interpolates ``x`` to
            ``(height // 2**(index-1), width // 2**(index-1))``. On a hires
            pass, that size is multiplied by ``hr_scale`` and truncated to int.
            It then applies ``target.conv`` when ``target.with_conv`` is set.
        """
        base_height = height // 2 ** (index - 1)
        base_width = width // 2 ** (index - 1)
        # Height this layer's input is expected to have on a normal
        # (non-hires) pass. NOTE(review): assumes the tensor reaching
        # up[index] is about height / 2**index tall — confirm against the VAE.
        first_pass_height = int(height / (2 ** index))

        def forward(x):
            target_height, target_width = base_height, base_width
            # Any other input height is treated as a hires-fix pass.
            if x.shape[2] != first_pass_height:
                # interpolate() needs integer sizes; hr_scale is typically a float.
                target_height = int(target_height * hr_scale)
                target_width = int(target_width * hr_scale)
            x = torch.nn.functional.interpolate(
                x, size=(target_height, target_width), mode="nearest"
            )
            if target.with_conv:
                x = target.conv(x)
            return x

        return forward

    def ui(self, is_img2img):
        with gr.Group():
            with gr.Row():
                # Gradio components take their initial state via ``value``.
                enable = gr.Checkbox(label="Enable", value=False)
        return [enable]

    def run(
        self,
        p: StableDiffusionProcessing,
        enable: bool
    ):
        """Run generation, patching the VAE decoder upsamplers if enabled.

        Args:
            p: the processing object passed in by the web UI.
            enable: value of the "Enable" checkbox. When False, ``p`` is
                processed unchanged.

        Returns:
            Whatever ``process_images(p)`` returns.
        """
        if not enable:
            return process_images(p)

        # Processing objects without hires fix (e.g. img2img) may not define
        # hr_scale; fall back to a neutral factor instead of raising.
        # NOTE(review): if hires fix targets an explicit size rather than
        # hr_scale, the hires-pass sizes computed here may not match — confirm.
        hr_scale = getattr(p, "hr_scale", 1)
        decoder = p.sd_model.first_stage_model.decoder

        patched = []  # (module, original forward) pairs to restore
        try:
            for index in range(1, 4):
                target = decoder.up[index].upsample
                patched.append((target, target.forward))
                target.forward = self.forward_hijack(
                    target, index, p.height, p.width, hr_scale
                )
            return process_images(p)
        finally:
            # Always restore: the decoder is shared by later generations, so
            # an exception or interrupt must not leave it patched.
            for target, original_forward in patched:
                target.forward = original_forward
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment