Skip to content

Instantly share code, notes, and snippets.

@laksjdjf
Created September 27, 2023 16:21
Show Gist options
  • Save laksjdjf/6a66646954578555b9c24f93558aa194 to your computer and use it in GitHub Desktop.
Save laksjdjf/6a66646954578555b9c24f93558aa194 to your computer and use it in GitHub Desktop.
import torch
import numpy as np
from rembg import remove
class RembgMask:
    """Remove the background of every image in a batch with ``rembg.remove``.

    Outputs the cut-out colour channels and the alpha channel that rembg
    produced, as separate IMAGE and MASK outputs.
    """

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "image": ("IMAGE", ),
            }
        }

    RETURN_TYPES = ("IMAGE", "MASK")
    FUNCTION = "rembg_mask"
    CATEGORY = "image/preprocessors"

    def rembg_mask(self, image):
        """Run rembg on each image of ``image`` and split its RGBA result.

        Args:
            image: batch of images, indexed along the first dimension. Values
                are scaled by 255, clipped to [0, 255] and rounded to uint8
                before being handed to rembg, so they are expected to lie in
                [0, 1]. Presumably (B, H, W, 3) float — TODO confirm.

        Returns:
            ``(rgb, masks)`` where
            - ``rgb`` is a float32 tensor of shape (B, H, W, 3) in [0, 1]
              holding the first three channels of rembg's output;
            - ``masks`` is the float32 alpha channel in [0, 1], shaped
              (B, H, W) when the batch holds more than one image and (H, W)
              when it holds exactly one.

        Raises:
            RuntimeError: from ``torch.stack`` if ``image`` has an empty batch.
        """
        rgb_list = []
        alpha_list = []
        for img in image:
            # rembg consumes uint8 arrays. detach()/cpu() make the NumPy
            # conversion work for tensors on an accelerator or with grad.
            pixels = (
                torch.clip(255.0 * img, 0, 255)
                .round()
                .detach()
                .cpu()
                .numpy()
                .astype(np.uint8)
            )
            rem_image = remove(pixels)
            # The code expects an H x W x 4 (RGBA) array back from rembg.
            # Convert once, then slice colour and alpha.
            rgba = torch.from_numpy(rem_image.astype(np.float32)) / 255
            rgb_list.append(rgba[:, :, :3])
            alpha_list.append(rgba[:, :, 3])

        rem_tensors = torch.stack(rgb_list)
        masks = torch.stack(alpha_list)
        # A single image yields a 2-D (H, W) mask, which downstream code in this
        # module (ImageMasking) accepts directly. Kept for compatibility.
        masks = masks if masks.shape[0] > 1 else masks[0]
        return (rem_tensors, masks)
class MakeNoise:
    """Create one image filled with uniform random noise in [0, 1)."""

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "width": ("INT", {
                    "default": 512,  # the old default of 0 produced an empty image
                    "min": 1,  # a zero-length side yields a (1, 0, W, 3) tensor
                    "max": 4096,
                    "step": 1,  # slider step
                    "display": "number",  # cosmetic only: "number" or "slider"
                }),
                "height": ("INT", {
                    "default": 512,
                    "min": 1,
                    "max": 4096,
                    "step": 1,
                    "display": "number",
                }),
                "grayscale": (["false", "true"], ),
            }
        }

    RETURN_TYPES = ("IMAGE", )
    FUNCTION = "make_noise"
    CATEGORY = "image"

    def make_noise(self, width, height, grayscale):
        """Return a 1-tuple holding a (1, height, width, 3) noise tensor.

        Args:
            width: image width in pixels.
            height: image height in pixels.
            grayscale: the string ``"true"`` for gray noise (identical R, G
                and B values per pixel); any other value gives independent
                per-channel colour noise.

        Returns:
            ``(noise,)`` with ``noise`` drawn from ``torch.rand`` (the global
            torch RNG), so values lie in [0, 1).
        """
        if grayscale == "true":
            # Draw a single channel and copy it into all three.
            noise = torch.rand((1, height, width, 1)).repeat(1, 1, 1, 3)
        else:
            noise = torch.rand((1, height, width, 3))
        return (noise,)
class ImageMasking:
    """Multiply an image by a mask, blended with the unmasked image.

    ``weight`` = 1.0 gives ``image * mask``; ``weight`` = 0.0 returns the
    image unchanged; values in between interpolate linearly.
    """

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "image": ("IMAGE", ),
                "mask": ("MASK", ),
                "weight": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01, "display": "number"}),
            }
        }

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "image_masking"
    CATEGORY = "image"

    def image_masking(self, image, mask, weight):
        """Apply ``mask`` to ``image`` with strength ``weight``.

        Args:
            image: image batch; the mask is broadcast against it with a
                trailing channel axis, so presumably (B, H, W, C) — TODO confirm.
            mask: either a single (H, W) mask shared by the whole batch, or a
                batched (B, H, W) mask (as returned by RembgMask for B > 1).
            weight: blend factor between the original (0.0) and the fully
                masked image (1.0).

        Returns:
            ``(result,)`` where
            ``result = image * mask * weight + image * (1 - weight)``.
        """
        # Previously a 3-D mask became (1, B, H, 1, W) and failed to broadcast.
        # Normalise both layouts to (B or 1, H, W, 1).
        if mask.dim() == 2:
            mask = mask.unsqueeze(0)
        mask = mask.unsqueeze(-1).to(image.device)
        return (image * mask * weight + image * (1 - weight),)
# Registration tables exported by this module (presumably read by ComfyUI's
# custom-node loader). Each node is keyed by its class name, in this order.
NODE_CLASS_MAPPINGS = {
    node_cls.__name__: node_cls
    for node_cls in (RembgMask, MakeNoise, ImageMasking)
}

# The UI label of every node is simply its registration key.
NODE_DISPLAY_NAME_MAPPINGS = {key: key for key in NODE_CLASS_MAPPINGS}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment