Skip to content

Instantly share code, notes, and snippets.

@yhs0602
Created February 16, 2024 09:22
Show Gist options
  • Save yhs0602/cb19765c7c572c0bd1767c7886302c4a to your computer and use it in GitHub Desktop.
Save yhs0602/cb19765c7c572c0bd1767c7886302c4a to your computer and use it in GitHub Desktop.
import cv2
import io
import numpy as np
import torch
from u2net import U2Net
from yolov8 import detect_objects
from diffusion import StableDiffusionInpaintPipeline
def generate_mask(image_path, bboxes, u2net_path):
    """Build a mask for *image_path* from U2-Net output plus detected boxes.

    The U2-Net prediction is used as the base mask. Then every pixel inside
    each bounding box is set to 1, so the detected objects are always part
    of the mask.

    Args:
        image_path: Path of the image to read with ``cv2.imread``.
        bboxes: Iterable of ``(x1, y1, x2, y2)`` boxes in pixel coordinates.
            Presumably whatever ``yolov8.detect_objects`` returns (TODO:
            confirm). Coordinates are converted to ``int`` here.
        u2net_path: Pretrained-weights path passed to ``U2Net``.

    Returns:
        The mask returned by ``U2Net.predict``, with the box regions set
        in place.

    Raises:
        FileNotFoundError: If the image cannot be read.
    """
    image = cv2.imread(image_path)
    # cv2.imread signals failure by returning None instead of raising.
    if image is None:
        raise FileNotFoundError(f"Could not read image: {image_path}")

    # NOTE(review): the model is reloaded on every call. Hoist it out if this
    # function is ever called repeatedly.
    u2net = U2Net(pretrained_model=u2net_path)
    mask = u2net.predict(image)

    # Force each detected box into the mask.
    # Assumptions (TODO confirm what U2Net.predict returns):
    # - mask is indexed (row, col) like the image;
    # - 1 means "foreground" (a 0-255 uint8 mask would need 255 here).
    # NOTE(review): the original comment said "keep only the object regions",
    # but pixels outside the boxes keep their U2-Net value. If only the boxes
    # should remain, zero the mask before this loop.
    for bbox in bboxes:
        # int(): detectors often report float coordinates, which cannot be
        # used as slice indices.
        # max(..., 0): a negative index would wrap to the far edge of the
        # array instead of clipping at the border.
        x1, y1, x2, y2 = (max(int(v), 0) for v in bbox)
        mask[y1:y2, x1:x2] = 1
    return mask
def inpaint_image(image_path, mask, prompt, diffusion_path):
    """Inpaint the masked region of *image_path*, guided by *prompt*.

    Args:
        image_path: Path of the source image, read with ``cv2.imread``.
        mask: Mask for this image (e.g. the result of ``generate_mask``).
        prompt: Text prompt for the inpainting model.
        diffusion_path: Model path passed to
            ``StableDiffusionInpaintPipeline.from_pretrained``.

    Returns:
        Whatever ``diffusion.inpaint`` returns.

    Raises:
        FileNotFoundError: If the image cannot be read.
    """
    image = cv2.imread(image_path)
    # cv2.imread signals failure by returning None instead of raising.
    if image is None:
        raise FileNotFoundError(f"Could not read image: {image_path}")

    # NOTE(review): cv2.imread returns pixels in BGR channel order. If the
    # pipeline expects RGB, convert first with
    # cv2.cvtColor(image, cv2.COLOR_BGR2RGB). TODO: confirm.
    diffusion = StableDiffusionInpaintPipeline.from_pretrained(diffusion_path)
    return diffusion.inpaint(image, mask, prompt)
def main():
    """Interactively inpaint a detected object and save the result.

    Steps:
    1. Read model paths from ``config.yaml`` (keys ``model_path``,
       ``u2net_path`` and ``diffusion_path``).
    2. Prompt the user for an image path, an object name and a text prompt.
    3. Detect the object, build a mask and inpaint it.
    4. Write the result to ``output.jpg``.

    Raises:
        OSError: If ``output.jpg`` cannot be written.
    """
    # The module header never imports yaml, so yaml.safe_load raised
    # NameError. Import it (PyYAML) where it is used.
    import yaml

    # Read the config file.
    with open("config.yaml", "r", encoding="utf-8") as f:
        config = yaml.safe_load(f)

    # Ask for the image path, object name and prompt. Strip stray whitespace
    # that pasted paths and names often carry.
    image_path = input("이미지 경로를 입력하세요: ").strip()
    object_name = input("객체 이름을 입력하세요: ").strip()
    prompt = input("Prompt를 입력하세요: ")

    # Search for the object.
    bboxes = detect_objects(image_path, config["model_path"], object_name)
    # With no detection, the mask would come from U2-Net alone, so stop here.
    # len() rather than truthiness, so a NumPy array of boxes also works.
    if bboxes is None or len(bboxes) == 0:
        print(f"객체를 찾을 수 없습니다: {object_name}")
        return

    # Build the mask, then inpaint.
    mask = generate_mask(image_path, bboxes, config["u2net_path"])
    inpainted_image = inpaint_image(
        image_path, mask, prompt, config["diffusion_path"]
    )

    # Save the result. cv2.imwrite reports failure through its return value,
    # not an exception, so check it before claiming success.
    if not cv2.imwrite("output.jpg", inpainted_image):
        raise OSError("결과 이미지를 저장하지 못했습니다: output.jpg")
    print("Inpainting 완료!")
if __name__ == "__main__":
    # Run the interactive pipeline only when executed as a script.
    # main() returns None, so the process exits with status 0.
    raise SystemExit(main())
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment