Skip to content

Instantly share code, notes, and snippets.

@helena-intel
Last active May 3, 2021 12:34
Show Gist options
  • Save helena-intel/a33905119d71e9a63ece9da3e475b4a0 to your computer and use it in GitHub Desktop.
import io
from enum import Enum
from typing import Optional
import cv2
import numpy as np
from openvino.inference_engine import IECore
from opyrator.components.types import FileContent
from PIL import Image, ImageDraw, ImageFont
from pydantic import BaseModel, Field
class ModelSelection(str, Enum):
    """U^2-Net model variants the user can pick from the UI.

    The member values double as the human-readable labels shown in the
    dropdown; `remove_background` maps each member to an OpenVINO IR file.
    """

    # Segments people specifically.
    HUMAN = "Human Segmentation"
    # General salient-object segmentation (regular, more accurate model).
    SALIENT = "Salient Object Detection"
    # Smaller, faster, less accurate salient-object model.
    LITE = "Salient Object Detection (Lite)"
class ImageInput(BaseModel):
    """Input schema for `remove_background` (rendered as a form by opyrator)."""

    # Required: raw bytes of the image whose foreground should be kept.
    image: FileContent = Field(
        ..., description="The image with the foreground object to detect/segment"
    )
    # Optional: raw bytes of a replacement background; when present, the
    # segmented foreground is composited onto it.
    background_image: Optional[FileContent] = Field(
        None,
        description="An optional background image. If a background image is "
        "chosen, the background image will be added to the salient object from the input image",
    )
    # Optional: text drawn onto the top-left of the result image.
    caption: Optional[str] = Field(
        None, description="An optional caption text to write on the result image"
    )
    # Required: which U^2-Net variant to run (see ModelSelection).
    model_selection: ModelSelection = Field(
        ...,
        description="Select the U^2-Net model to use. Use Human segmentation to detect people. "
        "The Lite salient model is faster but less accurate than the regular salient model.",
    )
class ImageOutput(BaseModel):
    """Output schema for `remove_background`: original plus processed image."""

    # The uploaded image, echoed back unmodified.
    # NOTE(review): mime_type is declared "image/jpeg" but the bytes are
    # whatever the user uploaded — confirm callers tolerate other formats.
    input_image: FileContent = Field(
        ...,
        mime_type="image/jpeg",
        description="Input image",
    )
    # PNG-encoded result produced at the end of remove_background.
    result_image: FileContent = Field(
        ...,
        mime_type="image/png",
        description="U^2-Net result",
    )
def remove_background(input: ImageInput) -> ImageOutput:
    """
    Remove the background of an image and optionally add a new background. On mobile,
    click on the ">" icon in the top left of the screen to add a background image.

    Steps: decode the upload, run the selected U^2-Net IR model on CPU with
    OpenVINO, threshold the network output into a 0/1 foreground mask, then
    either whiten the background or composite onto the optional background
    image, optionally draw a caption, and return the result as PNG bytes.

    :param input: form data — image bytes, optional background bytes,
        optional caption, and the model selection.
    :return: ImageOutput with the untouched input bytes and the PNG result.
    """
    # Decode the upload and drop any alpha channel so we always work on RGB.
    image = np.array(Image.open(io.BytesIO(input.image.as_bytes())))[:, :, :3]

    # Map the selected model variant to its OpenVINO IR file.
    if input.model_selection == ModelSelection.HUMAN:
        ir_path = "saved_models/u2net_human_seg/u2net_human_seg.xml"
    elif input.model_selection == ModelSelection.LITE:
        ir_path = "saved_models/u2net_lite/u2net_lite.xml"
    else:  # ModelSelection.SALIENT
        # Fix: this branch previously loaded the Lite IR, so the regular
        # salient model was never used. Load the full u2net model instead.
        ir_path = "saved_models/u2net/u2net.xml"

    # Load the network into the Inference Engine and compile it for CPU.
    ie = IECore()
    net_ir = ie.read_network(model=ir_path)
    exec_net_ir = ie.load_network(network=net_ir, device_name="CPU")
    # Get names of input and output layers
    input_layer_ir = next(iter(exec_net_ir.input_info))
    output_layer_ir = next(iter(exec_net_ir.outputs))

    # U^2-Net expects a 512x512 NCHW batch; convert HWC -> 1xCxHxW.
    resized_image = cv2.resize(image, (512, 512))
    input_image = np.expand_dims(np.transpose(resized_image, (2, 0, 1)), 0)

    # Run inference and keep only the relevant output blob.
    res_ir = exec_net_ir.infer(inputs={input_layer_ir: input_image})
    res_ir = res_ir[output_layer_ir]

    # Resize the network result back to the original image size, then round
    # to an integer 0/1 mask: 0 = background pixel, 1 = foreground pixel.
    resized_result = cv2.resize(res_ir[0][0], (image.shape[1], image.shape[0]))
    resized_result = np.rint(resized_result).astype(np.uint8)

    if input.background_image is not None:
        # Composite: zero the background region of the foreground image and
        # the foreground region of the background image, then add. The two
        # terms are disjoint, so uint8 addition cannot overflow.
        # (Fix: previously the background region was set to 255 *before*
        # adding the background image, wrapping uint8 values and corrupting
        # the composited background.)
        new_image = image.copy()
        new_image[resized_result < 1] = 0
        background_image = np.array(
            Image.open(io.BytesIO(input.background_image.as_bytes()))
        )[:, :, :3]
        background_image = cv2.resize(
            background_image, (image.shape[1], image.shape[0])
        )
        background_image[resized_result == 1] = 0
        new_image += background_image
    else:
        # No replacement background: paint the background pixels white.
        new_image = image.copy()
        new_image[resized_result < 1] = 255

    if input.caption is not None:
        # Draw the caption twice — a thick black pass under a thinner white
        # pass — so the text stays readable on any background.
        size = cv2.getTextSize(input.caption, cv2.FONT_HERSHEY_DUPLEX, 2, 8)[0]
        cv2.putText(
            new_image,
            input.caption,
            (10, size[1] + 10),
            cv2.FONT_HERSHEY_DUPLEX,
            2,
            (0, 0, 0),
            8,
        )
        cv2.putText(
            new_image,
            input.caption,
            (10, size[1] + 10),
            cv2.FONT_HERSHEY_DUPLEX,
            2,
            (255, 255, 255),
            6,
        )

    # Encode the result as PNG; getvalue() must run while the buffer is open.
    pil_image = Image.fromarray(new_image)
    with io.BytesIO() as output_f:
        pil_image.save(output_f, format="PNG")
        return ImageOutput(input_image=input.image, result_image=output_f.getvalue())
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment