Skip to content

Instantly share code, notes, and snippets.

@AlexeyLugovoy
Last active June 11, 2024 19:33
Show Gist options
  • Save AlexeyLugovoy/2da4846db083d6ed81ecafe23695821e to your computer and use it in GitHub Desktop.
Save AlexeyLugovoy/2da4846db083d6ed81ecafe23695821e to your computer and use it in GitHub Desktop.
from transformers import DPTImageProcessor, DPTForDepthEstimation
from diffusers import ControlNetModel, StableDiffusionXLControlNetImg2ImgPipeline
# --- Depth estimation: turn `img_pil` into an 8-bit depth map for ControlNet ---
# NOTE(review): `img_pil` (a PIL image), `CACHE_DIR`, `torch`, `np`, `cv2` and
# `Image` must be defined/imported earlier; they are not visible in this snippet.
processor = DPTImageProcessor.from_pretrained("Intel/dpt-beit-large-512", cache_dir=CACHE_DIR)
# The depth model (and its inputs below) are left on the default device;
# only the diffusion models further down are moved to DEVICE.
model = DPTForDepthEstimation.from_pretrained("Intel/dpt-beit-large-512", cache_dir=CACHE_DIR)

# Prepare the image for the model.
inputs = processor(images=img_pil, return_tensors="pt")
with torch.no_grad():
    # The forward pass belongs inside the no_grad block; in the pasted original
    # it was dedented, which is an IndentationError.
    outputs = model(**inputs)
    predicted_depth = outputs.predicted_depth

# Resize the prediction back to the input resolution. PIL's `size` is
# (width, height) while `interpolate` wants (height, width), hence `[::-1]`.
# `unsqueeze(1)` adds the channel axis bicubic (4-D) interpolation needs;
# presumably predicted_depth is (batch, H, W) — TODO confirm.
prediction = torch.nn.functional.interpolate(
    predicted_depth.unsqueeze(1),
    size=img_pil.size[::-1],
    mode="bicubic",
    align_corners=False,
)

# Scale the depth map to 0..255 relative to its maximum.
output = prediction.squeeze().cpu().numpy()
depth_peak = np.max(output)
# Guard the division: a non-positive (or NaN) peak would otherwise produce
# inf/NaN values before the uint8 cast. Such a map becomes all-black instead.
depth_scaled = output * 255 / depth_peak if depth_peak > 0 else np.zeros_like(output)
# Bicubic interpolation can overshoot below zero; clip so the uint8 cast
# cannot wrap negative values around to bright pixels.
formatted = np.clip(depth_scaled, 0, 255).astype("uint8")
# Edge-preserving smoothing (d=50, sigmaColor=90, sigmaSpace=90).
# NOTE(review): d=50 is a very large neighbourhood and slow; OpenCV's docs
# suggest d<=9 even for offline use — confirm this value is intended.
formatted = cv2.bilateralFilter(formatted, 50, 90, 90)
img_depth = Image.fromarray(formatted)
# --- Load the depth ControlNet and the SDXL base img2img pipeline ---
# Both checkpoints are loaded with identical options: the "fp16" weight
# variant, safetensors files, float16 tensors, cached under CACHE_DIR.
# NOTE(review): fp16 weights presumably assume DEVICE is a CUDA GPU — confirm.
_fp16_load_kwargs = dict(
    variant="fp16",
    use_safetensors=True,
    torch_dtype=torch.float16,
    cache_dir=CACHE_DIR,
)

controlnet = ControlNetModel.from_pretrained(
    'diffusers/controlnet-depth-sdxl-1.0-small',
    **_fp16_load_kwargs,
).to(DEVICE)

# The ControlNet loaded above is plugged into the pipeline, and the whole
# pipeline is then moved to DEVICE.
pipe = StableDiffusionXLControlNetImg2ImgPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    controlnet=controlnet,
    **_fp16_load_kwargs,
).to(DEVICE)
# --- Generate: SDXL img2img on img_pil, conditioned on the depth map img_depth ---
prompt = ["New modern style of livingroom, warm color palette, detailed, 8k"]
negative_prompt = ["low quality, bad quality, sketches"]

# Sampling knobs forwarded to the pipeline call below.
generation_kwargs = dict(
    strength=0.9,
    # NOTE(review): presumably ignored unless the scheduler accepts `eta` (e.g. DDIM) — confirm.
    eta=0.0,
    num_inference_steps=150,
    controlnet_conditioning_scale=0.4,  # weight given to the depth ControlNet
    guidance_scale=12,
    num_images_per_prompt=2,  # two candidates; only images[0] is shown below
)

img_gen = pipe(
    prompt=prompt,
    negative_prompt=negative_prompt,
    image=img_pil,  # init image for img2img
    control_image=img_depth,  # depth map produced by the DPT stage
    generator=torch.Generator(DEVICE).manual_seed(2),  # fixed seed for repeatable sampling
    **generation_kwargs,
)

# Bare expression: a notebook cell renders the first generated image; in a
# plain script this line has no effect.
img_gen.images[0]
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment