Created September 30, 2022 at 00:26
-
-
Save dsvensson/b9a1b19a1ad9974866070aee9ba34586 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
diff --git a/predict.py b/predict.py | |
index 4b9df0a..ecc9b74 100644 | |
--- a/predict.py | |
+++ b/predict.py | |
@@ -14,6 +14,9 @@ from image_to_image import ( | |
preprocess_mask, | |
) | |
+def dummy(images, **kwargs): | |
+ return images, False | |
+ | |
def patch_conv(**patch): | |
cls = torch.nn.Conv2d | |
init = cls.__init__ | |
@@ -41,6 +44,7 @@ class Predictor(BasePredictor): | |
cache_dir=MODEL_CACHE, | |
local_files_only=True, | |
).to("cuda") | |
+ self.pipe.safety_checker = dummy | |
@torch.inference_mode() | |
@torch.cuda.amp.autocast() | |
@@ -124,8 +128,8 @@ class Predictor(BasePredictor): | |
generator=generator, | |
num_inference_steps=num_inference_steps, | |
) | |
- if any(output["nsfw_content_detected"]): | |
- raise Exception("NSFW content detected, please try a different prompt") | |
+ #if any(output["nsfw_content_detected"]): | |
+ # raise Exception("NSFW content detected, please try a different prompt") | |
output_paths = [] | |
for i, sample in enumerate(output["sample"]): |
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.