Skip to content

Instantly share code, notes, and snippets.

@catboxanon
Last active August 17, 2023 02:35
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save catboxanon/69ce64e0389fa803d26dc59bb444af53 to your computer and use it in GitHub Desktop.
Save catboxanon/69ce64e0389fa803d26dc59bb444af53 to your computer and use it in GitHub Desktop.
import gradio as gr
import numpy as np
import torch
from PIL import Image
from scipy.ndimage import gaussian_filter
from skimage.transform import resize
import modules.scripts as scripts
from modules import devices, script_callbacks, shared
from modules.processing import StableDiffusionProcessingTxt2Img
from modules.script_callbacks import ExtraNoiseParams
from modules.scripts import AlwaysVisible
# Display name of the script; also the accordion label and the prefix of the
# keys written into the generation parameters.
NAME = "Extra Noise Mask"

# State handed from Script.process to the on_extra_noise callback: the drawn
# mask (None while the script is disabled) and the Gaussian blur radius.
MASK, BLUR_RADIUS = None, 0
class Script(scripts.Script):
    """txt2img script that scales the extra noise by a user-drawn mask.

    ``process`` publishes the drawn mask and the blur radius through the
    module-level ``MASK`` and ``BLUR_RADIUS`` globals, which the module's
    ``on_extra_noise`` callback reads.
    """

    def title(self):
        return NAME

    def ui(self, _):
        """Build the script's controls.

        Returns:
            The components in the order ``process`` receives their values:
            ``[enabled, return_mask, canvas, blur_radius]``.
        """
        with gr.Group():
            with gr.Accordion(NAME, open=False):
                enabled = gr.Checkbox(label="Enabled", value=False)
                return_mask = gr.Checkbox(label="Return mask", value=False)
                blur_radius = gr.Slider(label="Blur radius", value=0, minimum=0, maximum=32, step=1)
                canvas = gr.Image(
                    image_mode="RGB",
                    source='upload',
                    tool='sketch',
                    type='numpy',
                    height=768,
                    show_label=False,
                    show_download_button=False,
                    interactive=True,
                    brush_color=shared.opts.data.get('img2img_inpaint_mask_brush_color', '#FFFFFF'),  # type: ignore
                    brush_radius=128,
                )
        return [
            enabled,
            return_mask,
            canvas,
            blur_radius,
        ]

    def show(self, is_img2img):
        """Always visible on txt2img; returns None (hidden) on img2img."""
        if not is_img2img:
            return AlwaysVisible

    def process(self,
                p: StableDiffusionProcessingTxt2Img,
                enabled: bool,
                return_mask: bool,
                canvas: dict,
                blur_radius: float,
                ):
        """Publish the drawn mask and blur radius for the noise callback.

        Args:
            p: the processing object; receives ``_extra_noise_mask`` and the
                extra generation params when a mask is used.
            enabled: whether the script is active for this job.
            return_mask: unused here; read positionally in ``postprocess``.
            canvas: the sketch-tool value, read via its ``"mask"`` key. It can
                be None when no image was uploaded.
            blur_radius: Gaussian sigma applied to the mask by the callback.
        """
        global MASK, BLUR_RADIUS

        # Guard against canvas being None (nothing uploaded). Indexing it
        # would raise before MASK is reassigned, leaving the previous job's
        # mask in place for on_extra_noise to apply again.
        mask = canvas.get("mask") if enabled and canvas is not None else None
        if mask is None:
            MASK = None
            BLUR_RADIUS = 0
            return

        MASK = mask
        p._extra_noise_mask = mask  # type: ignore
        BLUR_RADIUS = blur_radius
        # A blank mask is ignored by the callback, so record nothing in the infotext.
        if not mask.any():
            return
        p.extra_generation_params.update({
            f'{NAME} enabled': enabled,
            f'{NAME} blur radius': blur_radius,
        })

    def postprocess(self, p, processed, *args):
        """Append the stored mask to the outputs when "Return mask" is ticked.

        ``args`` holds the values of the ``ui`` components in order, so
        ``args[1]`` is ``return_mask``.
        """
        if hasattr(p, '_extra_noise_mask') and args[1]:
            processed.images.extend([Image.fromarray(p._extra_noise_mask)])
def on_extra_noise(params: ExtraNoiseParams):
    """Scale ``params.noise`` by the user-drawn mask.

    Reads the module-level ``MASK`` (stored by ``Script.process``) and
    ``BLUR_RADIUS``. If ``MASK`` is None or entirely zero, the noise is left
    untouched. Otherwise the mask is:

    1. resized to the noise's last two dimensions,
    2. averaged over its channel axis (if it has one),
    3. optionally Gaussian-blurred,
    4. multiplied into ``params.noise``.

    ``MASK`` is only read here, never reassigned. The stored numpy mask
    therefore stays valid when the callback runs more than once within a
    single job.

    Args:
        params: callback payload; its ``noise`` attribute is replaced with
            the masked noise. Only the last two dims of ``noise`` are treated
            as spatial. Presumably (batch, channels, height, width) — confirm
            against the caller.
    """
    mask = MASK
    if mask is None or not mask.any():
        return

    noise = params.noise
    height, width = noise.shape[-2], noise.shape[-1]

    # Resize to the noise's spatial size; a trailing colour axis is kept by
    # resize() and averaged away next.
    weights = resize(mask, (height, width))
    if weights.ndim == 3:
        # Collapse colours before blurring so the blur is purely spatial.
        weights = weights.mean(axis=-1)
    if BLUR_RADIUS:
        weights = gaussian_filter(weights, sigma=BLUR_RADIUS)

    # numpy produces float64. Cast to the noise's own dtype and device so the
    # noise is not promoted to float64, which some backends (e.g. MPS) do not
    # support. A (height, width) tensor broadcasts over all leading axes, so
    # any channel count works.
    weights = torch.from_numpy(weights).to(device=noise.device, dtype=noise.dtype)
    params.noise = noise * weights
# Register on_extra_noise with the webui's extra-noise callback hook. This runs
# once, at module import time.
script_callbacks.on_extra_noise(on_extra_noise)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment