Skip to content

Instantly share code, notes, and snippets.

@dsvensson
Created September 30, 2022 00:26
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save dsvensson/b9a1b19a1ad9974866070aee9ba34586 to your computer and use it in GitHub Desktop.
diff --git a/predict.py b/predict.py
index 4b9df0a..ecc9b74 100644
--- a/predict.py
+++ b/predict.py
@@ -14,6 +14,9 @@ from image_to_image import (
preprocess_mask,
)
+def dummy(images, **kwargs):
+ return images, False
+
def patch_conv(**patch):
cls = torch.nn.Conv2d
init = cls.__init__
@@ -41,6 +44,7 @@ class Predictor(BasePredictor):
cache_dir=MODEL_CACHE,
local_files_only=True,
).to("cuda")
+ self.pipe.safety_checker = dummy
@torch.inference_mode()
@torch.cuda.amp.autocast()
@@ -124,8 +128,8 @@ class Predictor(BasePredictor):
generator=generator,
num_inference_steps=num_inference_steps,
)
- if any(output["nsfw_content_detected"]):
- raise Exception("NSFW content detected, please try a different prompt")
+ #if any(output["nsfw_content_detected"]):
+ # raise Exception("NSFW content detected, please try a different prompt")
output_paths = []
for i, sample in enumerate(output["sample"]):
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.