Skip to content

Instantly share code, notes, and snippets.

@pablomm
Created March 26, 2024 11:36
Show Gist options
  • Save pablomm/984994d3a5b671e7228f308158534fd6 to your computer and use it in GitHub Desktop.
Save pablomm/984994d3a5b671e7228f308158534fd6 to your computer and use it in GitHub Desktop.
Img2Img OVAM example
import torch
from diffusers import StableDiffusionImg2ImgPipeline
from ovam import StableDiffusionHooker
from ovam.utils import set_seed
import requests
from io import BytesIO
import matplotlib.pyplot as plt
from PIL import Image
# Load the Stable Diffusion img2img pipeline in half precision and move it
# to the target device.
device = "cuda"
model_id_or_path = "runwayml/stable-diffusion-v1-5"
pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
    model_id_or_path, torch_dtype=torch.float16
)
pipe = pipe.to(device)

# Download the sketch used as the initial image for img2img.
url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg"
# A timeout keeps the script from hanging forever on a stalled connection.
response = requests.get(url, timeout=30)
# Fail loudly on an HTTP error instead of handing an error page to PIL,
# which would otherwise surface as a confusing "cannot identify image" error.
response.raise_for_status()
init_image = Image.open(BytesIO(response.content)).convert("RGB")
init_image = init_image.resize((512, 512))

prompt = "A fantasy landscape, trending on artstation"
# Generate the image while the hooker stores the pipeline's internal attentions.
set_seed(1)
with StableDiffusionHooker(pipe) as hooker:
    image = pipe(
        prompt=prompt,
        image=init_image,
        strength=0.75,
        guidance_scale=7.5,
    ).images[0]

# Evaluator of attention: maps an attribution prompt to attention maps,
# requested here at expand_size=(512, 512) to match the resized input image.
ovam_evaluator = hooker.get_ovam_callable(expand_size=(512, 512))

# The attribution prompt is evaluated once; an earlier throwaway evaluation
# of "castle" was dropped because its result was overwritten before use.
attribution_prompt = "A castle"
with torch.no_grad():
    attention_maps = ovam_evaluator(attribution_prompt)
    # Keep the first batch element as a NumPy array.
    # NOTE(review): presumably one 512x512 map per token including
    # <SoT>/<EoT>, i.e. (4, 512, 512) for "A castle" -- confirm against OVAM.
    attention_maps = attention_maps[0].cpu().numpy()
# Three panels: input sketch, generated image, generated image + attention overlay.
fig, (ax0, ax1, ax2) = plt.subplots(1, 3, figsize=(10, 4))
ax0.imshow(init_image)
ax1.imshow(image)
ax2.imshow(image)

# Normalize attentions to [0, 1] to use as per-pixel alpha in the plot.
# Castle is the third token: <SoT> a castle <EoT>.
# Cast to float first so the arithmetic does not run in a reduced-precision
# dtype inherited from the model output.
castle_attention = attention_maps[2].astype(float)
castle_attention = castle_attention - castle_attention.min()
peak = castle_attention.max()
# Guard against a constant map: dividing by zero would fill the alpha
# channel with NaN. A flat map stays all-zero, i.e. a fully transparent overlay.
if peak > 0:
    castle_attention = castle_attention / peak
ax2.imshow(castle_attention, alpha=castle_attention, cmap='jet')

ax0.set_title("Init image")
ax0.axis('off')
ax1.set_title("Generated image")
ax1.axis('off')
ax2.set_title("+ Castle attentions")
ax2.axis('off')

fig.savefig("example_img2img.jpg")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment