Skip to content

Instantly share code, notes, and snippets.

@Pangoraw
Last active August 29, 2022 14:29
Show Gist options
  • Save Pangoraw/07f21ae6ad25bf549db585e5f355ef2a to your computer and use it in GitHub Desktop.
Stable Diffusion txt2img using the CPU
From 4e32b5ebfc3bb54cabf192488e43e768781eeafb Mon Sep 17 00:00:00 2001
From: Paul Berg <paul.berg@univ-ubs.fr>
Date: Mon, 29 Aug 2022 16:28:55 +0200
Subject: [PATCH] Allow using the CPU for txt2img
---
configs/stable-diffusion/v1-inference.yaml | 2 ++
scripts/txt2img.py | 18 ++++++++++++------
2 files changed, 14 insertions(+), 6 deletions(-)
diff --git a/configs/stable-diffusion/v1-inference.yaml b/configs/stable-diffusion/v1-inference.yaml
index d4effe5..b7239eb 100644
--- a/configs/stable-diffusion/v1-inference.yaml
+++ b/configs/stable-diffusion/v1-inference.yaml
@@ -68,3 +68,5 @@ model:
cond_stage_config:
target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
+ params:
+ device: "cpu"
diff --git a/scripts/txt2img.py b/scripts/txt2img.py
index 59c16a1..c49d450 100644
--- a/scripts/txt2img.py
+++ b/scripts/txt2img.py
@@ -45,7 +45,7 @@ def numpy_to_pil(images):
return pil_images
-def load_model_from_config(config, ckpt, verbose=False):
+def load_model_from_config(config, ckpt, device="cuda", verbose=False):
print(f"Loading model from {ckpt}")
pl_sd = torch.load(ckpt, map_location="cpu")
if "global_step" in pl_sd:
@@ -60,7 +60,7 @@ def load_model_from_config(config, ckpt, verbose=False):
print("unexpected keys:")
print(u)
- model.cuda()
+ model.to(device)
model.eval()
return model
@@ -226,6 +226,12 @@ def main():
choices=["full", "autocast"],
default="autocast"
)
+ parser.add_argument(
+ "--device",
+ type=str,
+ choices=["cuda", "cpu"],
+ default="cuda",
+ )
opt = parser.parse_args()
if opt.laion400m:
@@ -237,15 +243,15 @@ def main():
seed_everything(opt.seed)
config = OmegaConf.load(f"{opt.config}")
- model = load_model_from_config(config, f"{opt.ckpt}")
- device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+ device = opt.device
+ model = load_model_from_config(config, f"{opt.ckpt}", device)
model = model.to(device)
if opt.plms:
- sampler = PLMSSampler(model)
+ sampler = PLMSSampler(model, device=device)
else:
- sampler = DDIMSampler(model)
+ sampler = DDIMSampler(model, device=device)
os.makedirs(opt.outdir, exist_ok=True)
outpath = opt.outdir
--
2.34.1
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment