Skip to content

Instantly share code, notes, and snippets.

@ritwikraha
Created December 30, 2023 20:24
Show Gist options
  • Save ritwikraha/804b6d9b835c6ee16a0ccf4fd5cffe95 to your computer and use it in GitHub Desktop.
Save ritwikraha/804b6d9b835c6ee16a0ccf4fd5cffe95 to your computer and use it in GitHub Desktop.
Gradio Pixel Selector Utility
import gradio as gr
import numpy as np
import torch
from PIL import Image
'''
TODOs:
- Fetch the SAM model
- Fetch the inpainting model
- Initialize the pipeline
- Create the mask_generator from a SAM or other similar model
- Create relevant functions for inpainting
Reference: Abhishek Thakur's YouTube Video: https://www.youtube.com/watch?v=CERvlvUvVEI&t=764s
'''
# Pixel-selector UI: a title, four image panes, a prompt box and two buttons.
# NOTE(review): no event listeners are attached in the visible code —
# generate_mask / clear_selection are not bound to input_img.select or
# clear_btn.click, so clicking and the buttons currently do nothing.
# TODO wire them once the SAM predictor and mask generator exist.
with gr.Blocks() as demo:
    gr.Markdown(value="# Pixel Selector using Gradio")

    # Per-session list intended to collect clicked pixel indices
    # (generate_mask appends event.index to the list it is given).
    selected_pixels = gr.State(value=[])

    # Image panes, left to right; the last three are display-only.
    with gr.Row():
        input_img = gr.Image(label="Input")
        mask_img = gr.Image(interactive=False, label="Mask")
        seg_img = gr.Image(interactive=False, label="Segmentation")
        output_img = gr.Image(interactive=False, label="Output")

    # Single-line text prompt (presumably for the planned inpainting step).
    with gr.Row():
        prompt_text = gr.Textbox(label="Prompt", lines=1)

    # Action buttons.
    with gr.Row():
        submit_btn = gr.Button(value="Submit")
        clear_btn = gr.Button(value="Clear")
def generate_mask(image, selected_pixels, event: gr.SelectData):
    """Predict a mask for the clicked object and colour the full segmentation.

    Intended as an ``input_img.select`` handler.

    Args:
        image: The input image as delivered by Gradio. Assumed to be an
            (H, W, 3) uint8 numpy array (gr.Image's default "numpy" type),
            which is what the predictor and mask generator are fed here —
            TODO confirm.
        selected_pixels: Session list of clicked points. Mutated in place:
            the new click is appended, so every earlier click is still used
            as a prompt on the next call.
        event: Gradio select event; ``event.index`` is the clicked pixel.

    Returns:
        ``(mask, final_segmentation)``. ``mask`` is the single predicted mask
        as a PIL image. ``final_segmentation`` is an (H, W, 3) uint8 array in
        which each automatically generated segment is painted its own colour.
        Pixels covered by no segment are black, and the whole array is black
        when the generator returns no segments.

    NOTE(review): ``predictor`` and ``mask_generator`` are not defined in this
    file yet (see the module TODOs); this handler raises NameError until they
    are created.
    """
    selected_pixels.append(event.index)

    # Prompt the predictor with every point clicked so far, all with label 1
    # (SAM's foreground label).
    predictor.set_image(image)
    input_point = np.array(selected_pixels)
    input_label = np.ones(input_point.shape[0])
    mask, _, _ = predictor.predict(
        point_coords=input_point,
        point_labels=input_label,
        multimask_output=False,
    )
    # Release cached GPU memory before running the second model pass.
    torch.cuda.empty_cache()
    mask = Image.fromarray(mask[0, :, :])

    segmentations = mask_generator.generate(image)
    # Assumes the generator's "segmentation" entries are boolean (H, W)
    # arrays, as the original code did.
    boolean_masks = [s["segmentation"] for s in segmentations]

    # Size the canvas from the masks when there are any. Otherwise fall back
    # to the image size instead of failing on boolean_masks[0].
    if boolean_masks:
        height, width = boolean_masks[0].shape[:2]
    else:
        height, width = np.shape(image)[:2]
    final_segmentation = np.zeros((height, width, 3), dtype=np.uint8)

    # Paint every segment. Largest segments go first so smaller ones drawn
    # later remain visible. The fixed seed keeps colours stable between clicks.
    rng = np.random.default_rng(0)
    for seg_mask in sorted(boolean_masks, key=np.count_nonzero, reverse=True):
        final_segmentation[seg_mask] = rng.integers(0, 256, size=3, dtype=np.uint8)

    torch.cuda.empty_cache()
    return mask, final_segmentation
def clear_selection(selected_pixels, input_img, mask_img, seg_img, output_img, prompt_text):
    """Forget the clicked points and blank every image pane and text field.

    Intended as a ``clear_btn.click`` handler. The component arguments are
    accepted but their values are ignored.

    Args:
        selected_pixels: Session list of clicked points. It is emptied in
            place, mirroring how generate_mask appends to it in place.
            Rebinding the local name would leave the caller's list untouched.
            ``None`` is tolerated.
        input_img, mask_img, seg_img, output_img, prompt_text: Ignored.

    Returns:
        The 6-tuple ``(None, None, None, None, "", "")``. The items are the
        cleared input, mask, segmentation and output images, an empty prompt
        and an empty negative prompt.
        NOTE(review): no negative-prompt component exists in the visible UI;
        confirm the intended output list when wiring this handler.
    """
    if selected_pixels is not None:
        selected_pixels.clear()
    img = mask = seg = out = None
    prompt = neg_prompt = ""
    return img, mask, seg, out, prompt, neg_prompt
def _main():
    """Start the Gradio server for the pixel-selector demo with default launch settings."""
    demo.launch()


# Only serve the UI when run as a script, not when imported.
if __name__ == "__main__":
    _main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment