Last active
December 21, 2024 19:40
-
-
Save jbarrow/c35279cbed578eeba4e1253dc6907c8c to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| from typing_extensions import TypedDict, Literal | |
| from PIL import Image, ImageDraw | |
| import google.generativeai as genai | |
| import argparse | |
| import json | |
| import os | |
def draw_bounding_box(
    image: Image.Image, type: str, bbox: list[int]
) -> Image.Image:
    """Draw a labelled bounding box on ``image`` and return it.

    The image is modified in place; the returned object is the same
    ``image`` that was passed in.

    Args:
        image: Image to draw on.
        type: Label text drawn at the box's top-left corner. (The name
            shadows the builtin but is kept so existing callers still work.)
        bbox: ``[ymin, xmin, ymax, xmax]``, each value on a 0-1000 scale
            relative to the image height (y) or width (x).

    Returns:
        The same ``image`` instance with the rectangle and label drawn.

    Raises:
        ValueError: If ``bbox`` does not contain exactly four values.
    """
    if len(bbox) != 4:
        raise ValueError(
            f"bbox must be [ymin, xmin, ymax, xmax], got {bbox!r}"
        )

    width, height = image.size
    ymin, xmin, ymax, xmax = (coord / 1000 for coord in bbox)

    # Order each pair of edges: model output occasionally swaps min/max,
    # and newer Pillow releases reject a rectangle whose second corner lies
    # before the first. Well-formed boxes are unaffected.
    left, right = sorted((int(xmin * width), int(xmax * width)))
    top, bottom = sorted((int(ymin * height), int(ymax * height)))

    draw = ImageDraw.Draw(image)
    draw.rectangle(
        [(left, top), (right, bottom)],
        outline="blue",
        width=2,
    )
    draw.text(
        (left, top),
        type,
        fill="white",
    )
    return image
# One detected object. Used both as the element type of the response schema
# sent to the model (list[TeaSet]) and to wrap each parsed JSON result.
class TeaSet(TypedDict):
    # Which kind of object was found.
    type: Literal["Teacup", "Teapot"]
    # [ymin, xmin, ymax, xmax], as requested in the prompt; each value is
    # divided by 1000 and scaled to the image size when the box is drawn.
    bounding_box: list[int]
def main(
    image_path: str,
    size: float = 1024,
    resized_path: "str | None" = (
        "/Users/jbarrow/vaults/DocAI/Images/gemini-in.jpeg"
    ),
) -> None:
    """Locate teacups and teapots in an image with Gemini and display them.

    The image is shrunk to fit within ``size`` x ``size``, sent to the
    ``gemini-2.0-flash-exp`` model with a JSON response schema of
    ``list[TeaSet]``, and every returned bounding box is drawn on the image,
    which is then opened in the default viewer. The raw and per-object JSON
    are printed to stdout.

    Args:
        image_path: Path of the image file to analyse.
        size: Maximum width/height the image is reduced to before upload.
        resized_path: Where to save a copy of the resized input image, or
            ``None`` to skip saving. A failure to write this copy is
            reported on stderr and does not stop the run.
    """
    import sys  # local import: only used for stderr diagnostics below

    image = Image.open(image_path)
    # I've found that localization works better when the image is smaller
    image.thumbnail((size, size))

    # The saved copy is a convenience artefact, so a missing directory or an
    # unwritable format must not abort the detection itself.
    if resized_path is not None:
        try:
            image.save(resized_path)
        except OSError as err:
            print(
                f"warning: could not save resized image to "
                f"{resized_path!r}: {err}",
                file=sys.stderr,
            )

    prompt = (
        "Find all the teacups and teapots in the image.\n"
        "Return your answer as a list of JSON objects with the type and "
        "bounding box.\n"
        "Return the bounding box in [ymin, xmin, ymax, xmax] format."
    )
    model = genai.GenerativeModel("gemini-2.0-flash-exp")
    result = model.generate_content(
        [prompt, image],
        generation_config=genai.GenerationConfig(
            response_mime_type="application/json",
            response_schema=list[TeaSet],
        ),
    )

    objects = json.loads(str(result.candidates[0].content.parts[0].text))
    print(json.dumps(objects, indent=2))

    for entry in objects:
        details = TeaSet(**entry)
        print(json.dumps(details, indent=2))
        if "bounding_box" not in details:
            continue
        try:
            image = draw_bounding_box(
                image,
                details.get("type", "unknown"),
                details["bounding_box"],
            )
        except ValueError as err:
            # One malformed box should not discard the boxes that are fine.
            print(
                f"warning: skipping {details!r}: {err}",
                file=sys.stderr,
            )

    image.show()
if __name__ == "__main__":
    # The API key is read from the GOOGLE_API_KEY environment variable.
    genai.configure(api_key=os.getenv("GOOGLE_API_KEY"))

    parser = argparse.ArgumentParser()
    parser.add_argument("image_path", help="Path to image file to process")
    # Expose main()'s resize limit on the command line; the default matches
    # main()'s own default so existing invocations behave as before.
    parser.add_argument(
        "--size",
        type=int,
        default=1024,
        help=(
            "Maximum width/height in pixels the image is shrunk to before "
            "being sent to the model (default: %(default)s)"
        ),
    )
    args = parser.parse_args()
    main(args.image_path, size=args.size)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment