Skip to content

Instantly share code, notes, and snippets.

@laksjdjf
Created May 22, 2023 11:38
Show Gist options
  • Save laksjdjf/537f13ee32835f6395484e385f0d46b6 to your computer and use it in GitHub Desktop.
Save laksjdjf/537f13ee32835f6395484e385f0d46b6 to your computer and use it in GitHub Desktop.
import torch
import modules.scripts as scripts
import gradio as gr
from modules.processing import StableDiffusionProcessing, process_images
class Script(scripts.Script):
    """Web UI script titled "any resolution".

    While enabled, the ``forward`` of the upsample modules at VAE decoder
    levels 1-3 is temporarily replaced so each one resizes its input to an
    explicit size derived from the requested image height/width, instead of
    whatever the module's own ``forward`` would produce.
    """

    # Decoder levels (indices into ``decoder.up``) whose ``upsample`` module
    # is patched for the duration of a run.
    _UPSAMPLE_LEVELS = range(1, 4)

    def __init__(self):
        super().__init__()

    def title(self):
        """Return the name under which the script is listed."""
        return "any resolution"

    def forward_hijack(self, target, index, height, width, hr_scale):
        """Build a replacement ``forward`` for one decoder upsample module.

        Args:
            target: The upsample module being patched; its ``with_conv`` flag
                and ``conv`` layer are reused by the replacement.
            index: Decoder level of ``target``; the output size is
                ``height // 2**(index - 1)`` by ``width // 2**(index - 1)``.
            height: Requested image height in pixels.
            width: Requested image width in pixels.
            hr_scale: Factor applied to the output size when the incoming
                tensor does not have the height expected for the base
                resolution (treated as the hires-fix pass).

        Returns:
            A function taking the input tensor and returning the resized
            (and, if ``target.with_conv`` is set, convolved) tensor.
        """
        base_height = height // 2 ** (index - 1)
        base_width = width // 2 ** (index - 1)
        # Height the input has at this level when decoding at base resolution.
        expected_input_height = int(height / (2 ** index))

        def forward(x):
            # x.shape[2] is compared against a height, so x is presumably
            # laid out as (batch, channels, height, width).
            target_height, target_width = base_height, base_width
            # if hires fix
            if x.shape[2] != expected_input_height:
                # hr_scale may be a float; interpolate() needs integer sizes.
                target_height = int(target_height * hr_scale)
                target_width = int(target_width * hr_scale)
            x = torch.nn.functional.interpolate(
                x, size=(target_height, target_width), mode="nearest"
            )
            if target.with_conv:
                x = target.conv(x)
            return x

        return forward

    def ui(self, is_img2img):
        """Create the script's controls: a single "Enable" checkbox."""
        with gr.Group():
            with gr.Row():
                enable = gr.Checkbox(label="Enable", value=False)
        return [enable]

    def run(
        self,
        p: StableDiffusionProcessing,
        enable: bool
    ):
        """Process ``p``, patching the decoder upsamplers when enabled.

        Args:
            p: The processing job; ``p.height``, ``p.width`` and (when
                present) ``p.hr_scale`` determine the patched output sizes.
            enable: Value of the "Enable" checkbox. When false the job is
                processed without touching the model.

        Returns:
            Whatever ``process_images(p)`` returns.
        """
        if not enable:
            return process_images(p)

        # Not every processing object is guaranteed to carry hr_scale;
        # a factor of 1 leaves the computed sizes unchanged.
        hr_scale = getattr(p, "hr_scale", 1)
        decoder = p.sd_model.first_stage_model.decoder

        patched = []
        try:
            for level in self._UPSAMPLE_LEVELS:
                upsample = decoder.up[level].upsample
                patched.append((upsample, upsample.forward))
                upsample.forward = self.forward_hijack(
                    upsample, level, p.height, p.width, hr_scale
                )
            return process_images(p)
        finally:
            # Always undo the patch, even if processing fails, so later
            # generations do not run with the hijacked forwards.
            for upsample, original_forward in patched:
                upsample.forward = original_forward
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment