@Saren-Arterius
Last active October 20, 2023 17:53
diff --git a/modules/api/api.py b/modules/api/api.py
index e6edffe7..86f51fa3 100644
--- a/modules/api/api.py
+++ b/modules/api/api.py
@@ -337,6 +337,10 @@ class Api:
         return script_args
 
     def text2imgapi(self, txt2imgreq: models.StableDiffusionTxt2ImgProcessingAPI):
+        txt2imgreq.width = round(txt2imgreq.width / 64) * 64
+        txt2imgreq.height = round(txt2imgreq.height / 64) * 64
+        print('[t2i]', txt2imgreq.width, 'x', txt2imgreq.height, '|', txt2imgreq.prompt)
+
         script_runner = scripts.scripts_txt2img
         if not script_runner.scripts:
             script_runner.initialize_scripts(False)
@@ -387,6 +391,10 @@ class Api:
         return models.TextToImageResponse(images=b64images, parameters=vars(txt2imgreq), info=processed.js())
 
     def img2imgapi(self, img2imgreq: models.StableDiffusionImg2ImgProcessingAPI):
+        img2imgreq.width = round(img2imgreq.width / 64) * 64
+        img2imgreq.height = round(img2imgreq.height / 64) * 64
+        print('[i2i]', img2imgreq.width, 'x', img2imgreq.height, '|', img2imgreq.prompt)
+
         init_images = img2imgreq.init_images
         if init_images is None:
             raise HTTPException(status_code=404, detail="Init image not found")
diff --git a/modules/processing.py b/modules/processing.py
index e124e7f0..7c13be80 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -233,6 +233,8 @@ class StableDiffusionProcessing:
         self.cached_uc = StableDiffusionProcessing.cached_uc
         self.cached_c = StableDiffusionProcessing.cached_c
 
+        shared.current_prompt = f'{self.prompt.lower()}|{self.width}*{self.height}*{self.batch_size}'
+
     @property
     def sd_model(self):
         return shared.sd_model
diff --git a/modules/sd_unet.py b/modules/sd_unet.py
index 5525cfbc..49daf1d4 100644
--- a/modules/sd_unet.py
+++ b/modules/sd_unet.py
@@ -1,8 +1,8 @@
 import torch.nn
 import ldm.modules.diffusionmodules.openaimodel
+import time
 
 from modules import script_callbacks, shared, devices
-
 unet_options = []
 current_unet_option = None
 current_unet = None
@@ -85,8 +85,17 @@ class SdUnet(torch.nn.Module):
 
 
 def UNetModel_forward(self, x, timesteps=None, context=None, *args, **kwargs):
-    if current_unet is not None:
-        return current_unet.forward(x, timesteps, context, *args, **kwargs)
-
-    return ldm.modules.diffusionmodules.openaimodel.copy_of_UNetModel_forward_for_webui(self, x, timesteps, context, *args, **kwargs)
-
+    try:
+        if current_unet is not None and shared.current_prompt != shared.skip_unet_prompt:
+            if '[TRT]' in shared.opts.sd_unet and '<lora:' in shared.current_prompt:
+                raise Exception('LoRA unsupported in TRT UNet')
+            f = current_unet.forward(x, timesteps, context, *args, **kwargs)
+            return f
+    except Exception as e:
+        start = time.time()
+        print('[UNet] Skipping TRT UNet for this request:', e, '-', shared.current_prompt)
+        shared.sd_model.model.diffusion_model.to(devices.device)
+        shared.skip_unet_prompt = shared.current_prompt
+        print('[UNet] Used', time.time() - start, 'seconds')
+
+    return ldm.modules.diffusionmodules.openaimodel.copy_of_UNetModel_forward_for_webui(self, x, timesteps, context, *args, **kwargs)
\ No newline at end of file
diff --git a/modules/shared.py b/modules/shared.py
index 63661939..577bd100 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -85,3 +85,6 @@ list_checkpoint_tiles = shared_items.list_checkpoint_tiles
 refresh_checkpoints = shared_items.refresh_checkpoints
 list_samplers = shared_items.list_samplers
 reload_hypernetworks = shared_items.reload_hypernetworks
+
+current_prompt = ''
+skip_unet_prompt = ''
\ No newline at end of file
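
For reference, a minimal standalone sketch of what the patch does, kept outside the diff so the patch itself still applies cleanly. It only restates the expressions visible in the hunks above; the helper names (snap_to_64, prompt_signature, should_skip_trt_unet) are hypothetical and do not exist in the webui codebase.

# Standalone sketch (not part of the patch): the three ideas the diff inlines.
# All names here are illustrative only.

def snap_to_64(value: int) -> int:
    """Round a requested dimension to the nearest multiple of 64,
    mirroring round(width / 64) * 64 in the api.py hunks."""
    return round(value / 64) * 64


def prompt_signature(prompt: str, width: int, height: int, batch_size: int) -> str:
    """Build the per-request key stored in shared.current_prompt by the
    processing.py hunk: prompt|width*height*batch_size."""
    return f'{prompt.lower()}|{width}*{height}*{batch_size}'


def should_skip_trt_unet(signature: str, skip_signature: str, sd_unet_option: str) -> bool:
    """Approximate the fallback condition in UNetModel_forward: skip the TRT
    UNet when this exact request has already failed once, or when a LoRA tag
    is present while a '[TRT]' UNet is selected."""
    if signature == skip_signature:
        return True
    return '[TRT]' in sd_unet_option and '<lora:' in signature


if __name__ == '__main__':
    # 1000x600 snaps to 1024x576; 512x512 is already aligned.
    print(snap_to_64(1000), snap_to_64(600), snap_to_64(512))
    sig = prompt_signature('a cat <lora:style:0.8>', 1024, 576, 1)
    print(should_skip_trt_unet(sig, '', '[TRT] v1-5-pruned'))  # True: LoRA with a TRT UNet selected

The rounding runs before any processing starts, so requests with odd dimensions (e.g. 1000x600) arrive at the sampler already aligned to 1024x576; the signature comparison is what lets later UNet calls for the same request skip the TRT path and fall through to the stock forward after the first failure.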