Skip to content

Instantly share code, notes, and snippets.

@tsvikas
Last active July 4, 2024 01:38
Show Gist options
  • Save tsvikas/186015d1d085e0e01e1e5170d54c2b8c to your computer and use it in GitHub Desktop.
Save tsvikas/186015d1d085e0e01e1e5170d54c2b8c to your computer and use it in GitHub Desktop.
generate images with dall-e
"""
Generate images using DALL-E, and easily save the parameters and the output to a file.
Usage:
```python
import os; os.environ["OPENAI_API_KEY"] = "SECRET_KEY" # set your API key
from generate_images import GeneratedImage, GeneratedImagesFile # import this code
img = GeneratedImage.generate("Astronaut") # generate an image
img # display the generated image + metadata in Jupyter
img.save_image("astronaut.png") # save a specific image
astronaut_images = GeneratedImagesFile("astronaut.jsonl") # load a file with many images
astronaut_images.append(img) # add a generated image to the file
astronaut_images.generate("Cool astronaut") # or generate and add to the file with one function
astronaut_images.generate_many("Psychedelic astronaut", n=10) # you can generate and add more than one image
astronaut_images # display thumbnails in Jupyter
astronaut_images[0] # access a specific image
astronaut_images[1, -1] # display thumbnails for a subset of images
astronaut_images.select(1, -1).display() # display a subset of images
astronaut_images.select(1, -1).copy_to("something.jsonl") # copy a subset of images
```
""" # noqa: E501
# TODO: improve docs, maybe with copilot / claude
# TODO: add testing, maybe with mock
import base64
import dataclasses
import functools
from collections.abc import Iterable, Iterator, Mapping
from pathlib import Path
from typing import Any, Literal, Self

import jsonlines
import openai
from IPython.display import Markdown, display
# Parameter types accepted by the images endpoint. Which values are allowed
# depends on the model; that is checked per model in `GeneratedImage.generate`.
ModelType = Literal["dall-e-2", "dall-e-3"]
SizeType = Literal["256x256", "512x512", "1024x1024", "1792x1024", "1024x1792"]
# None means "not specified": it is required for dall-e-2, and for dall-e-3 a
# default is filled in before the API call.
QualityType = Literal["standard", "hd"] | None
StyleType = Literal["vivid", "natural"] | None

# Shared OpenAI client used by every API call in this module. It is created
# at import time; presumably this fails when no API key is configured (the
# usage example sets OPENAI_API_KEY before importing) — TODO confirm.
# TODO: if no API KEY, make it read-only.
client = openai.OpenAI()
def _check_valid_values(
name: str, value: str | None, valid_values: list[str | None], model: str
) -> None:
if value not in valid_values:
raise ValueError(
f"Invalid {name} for model {model!r}. "
f"Expected one of: {valid_values}. "
f"Received: {value!r}"
)
def _image_b64_to_html(image_b64: str, width: int) -> str:
img_src = f"data:image/png;base64,{image_b64}"
style = f"width:{width}px; display:inline-block; margin-right: 10px;"
return f'<img src="{img_src}" style="{style}"/>'
@dataclasses.dataclass
class GeneratedImage:
    """
    An image generated by OpenAI, together with the parameters that produced it.

    API documentation: https://platform.openai.com/docs/api-reference/images/create
    pricing details: https://openai.com/api/pricing/
    """

    # Prompt sent to the API (includes the pre-prompt when generated with
    # `use_exact_prompt=True`).
    prompt: str | None
    # Prompt as revised by the API; may be None.
    revised_prompt: str | None
    model: ModelType
    size: SizeType
    quality: QualityType
    style: StyleType
    # Base64-encoded image data; excluded from repr (it is a long string).
    image_b64: str = dataclasses.field(repr=False)

    @property
    def image_bytes(self) -> bytes:
        """The image decoded from base64 into raw bytes."""
        return base64.b64decode(self.image_b64)

    def save_image(self, filename: Path | str) -> None:
        """
        Save the image to a file.

        The API does not specify which format it uses, but it seems to return PNG.
        """
        Path(filename).write_bytes(self.image_bytes)

    def _repr_markdown_(self) -> str:
        """Render the prompts, the details and the full image as Markdown (Jupyter)."""
        header_fields = ["prompt", "revised_prompt"]
        detail_fields = ["model", "size", "quality", "style"]
        header_markdowns = [f"**{fld}**: {getattr(self, fld)}" for fld in header_fields]
        # TODO: add max width
        image_markdown = f"![Image](data:image/png;base64,{self.image_b64})"
        details_markdown = "**details**: " + " | ".join(
            str(value)
            for fld in detail_fields
            if (value := getattr(self, fld)) is not None
        )
        # A newline is a hard line break in Markdown only when preceded by two
        # spaces; with a single space every line collapses into one paragraph.
        return "  \n".join([*header_markdowns, details_markdown, image_markdown])

    def to_dict(self) -> dict[str, Any]:
        """Convert to a dictionary of the dataclass fields (inverse of `from_dict`)."""
        return dataclasses.asdict(self)

    @classmethod
    def from_dict(cls, d: Mapping) -> Self:
        """Build an instance from a mapping of field names, e.g. from `to_dict`."""
        return cls(**d)

    @classmethod
    def generate(  # noqa: PLR0913
        cls,
        prompt: str,
        *,
        model: ModelType = "dall-e-3",
        size: SizeType = "1024x1024",
        quality: QualityType = None,
        style: StyleType = None,
        use_exact_prompt: bool = False,
    ) -> Self:
        """
        Create an image from a prompt by calling the OpenAI images API.

        prompt: text description of the image.
        model: "dall-e-2" or "dall-e-3".
        size: must be one of the sizes supported by `model`.
        quality, style: must be None for dall-e-2; for dall-e-3 they default
            to "standard" and "vivid".
        use_exact_prompt: prepend an OpenAI-recommended pre-prompt that asks the
            API not to revise the prompt. The stored `prompt` then includes
            this pre-prompt.

        Raises ValueError if a parameter is not supported by `model`.
        """
        if model == "dall-e-2":
            _check_valid_values(
                "size", size, ["256x256", "512x512", "1024x1024"], model
            )
            _check_valid_values("quality", quality, [None], model)
            _check_valid_values("style", style, [None], model)
        elif model == "dall-e-3":
            _check_valid_values(
                "size", size, ["1024x1024", "1792x1024", "1024x1792"], model
            )
            if quality is None:
                quality = "standard"
            if style is None:
                style = "vivid"
            # Validate explicit values too, mirroring the dall-e-2 branch, so an
            # invalid value fails here instead of in the API request.
            _check_valid_values("quality", quality, ["standard", "hd"], model)
            _check_valid_values("style", style, ["vivid", "natural"], model)
        else:
            raise ValueError("Unsupported model")
        # see https://platform.openai.com/docs/guides/images/prompting
        pre_prompt = (
            "I NEED to test how the tool works with extremely simple prompts. "
            "DO NOT add any detail, just use it AS-IS: "
            if use_exact_prompt
            else ""
        )
        prompt = pre_prompt + prompt
        # One image per request, returned inline as base64 (stored in image_b64).
        response = client.images.generate(
            prompt=prompt,
            model=model,
            n=1,
            quality=quality,
            response_format="b64_json",
            size=size,
            style=style,
        )
        response_data = dict(response.data[0])
        return cls(
            image_b64=response_data["b64_json"],
            prompt=prompt,
            revised_prompt=response_data["revised_prompt"],
            model=model,
            size=size,
            quality=quality,
            style=style,
        )

    def thumbnail(self, width: int = 128) -> Markdown:
        """Return a Markdown object holding an HTML `<img>` `width` pixels wide."""
        # TODO: maybe lower the resolution, to help with filesize
        return Markdown(_image_b64_to_html(self.image_b64, width))

    def display(self) -> None:
        """Display the image with its metadata in Jupyter."""
        display(self)
class GeneratedImages:
    """
    An in-memory list of generated images.

    Index with an int to get one `GeneratedImage`, or with a slice or an
    iterable of ints to get a new `GeneratedImages` subset. Supports `len()`
    and iteration, and renders as a row of thumbnails in Jupyter.
    """

    def __init__(self, images: Iterable[GeneratedImage]):
        self._images = list(images)

    def __len__(self) -> int:
        return len(self._images)

    def __iter__(self) -> Iterator[GeneratedImage]:
        return iter(self._images)

    @functools.singledispatchmethod
    def __getitem__(self, index):  # noqa: ANN001
        raise TypeError(f"Invalid index type: {type(index)}")

    @__getitem__.register(slice)
    def _(self, index: slice) -> "GeneratedImages":
        return GeneratedImages(self._images[index])

    @__getitem__.register(Iterable)
    def _(self, index: Iterable[int]) -> "GeneratedImages":
        # str/bytes are iterable too, but are never meant as a list of indexes.
        if isinstance(index, str | bytes):
            raise TypeError(f"Invalid index type: {type(index)}")
        return GeneratedImages([self._images[i] for i in index])

    @__getitem__.register(int)
    def _(self, index: int) -> GeneratedImage:
        return self._images[index]

    def select(self, *indexes: int) -> "GeneratedImages":
        """Return the images at the given positions as a new `GeneratedImages`."""
        return self[indexes]

    def __repr__(self) -> str:
        return f"{type(self).__name__}({self._images!r})"

    def thumbnails(self, width: int = 128) -> Markdown:
        """Return a Markdown object with a `width`-pixel thumbnail per image."""
        # TODO: add alt text for index / prompt
        return Markdown("\n".join(img.thumbnail(width).data for img in self._images))

    def _repr_markdown_(self) -> str:
        return self.thumbnails(width=128).data

    def copy_to(self, filename: Path | str) -> "GeneratedImagesFile":
        """
        Append these images to the jsonlines file `filename` and return it.

        If the file already exists, its current images are kept and these are
        added after them.
        """
        other = GeneratedImagesFile(filename)
        return other.extend(self._images)

    def display(self) -> None:
        """Display every image, with its metadata, in Jupyter."""
        for img in self._images:
            img.display()
class GeneratedImagesFile(GeneratedImages):
    """
    A jsonlines file containing several generated images, one per line.

    Generated images can be appended or generated directly; each change is
    written to the file immediately.
    """

    def __init__(self, filename: Path | str):
        """Load the images stored in `filename` (none if the file does not exist)."""
        self.filename = Path(filename)
        if self.filename.exists():
            with jsonlines.open(self.filename, "r") as reader:
                images = [GeneratedImage.from_dict(data) for data in reader]
        else:
            images = []
        super().__init__(images)

    def __repr__(self) -> str:
        return f"{type(self).__name__}(filename={self.filename})"

    def _repr_markdown_(self) -> str:
        return (
            # TODO: the filename part is not working with jupyter
            # f"**filename**: {self.filename}\n " +
            super()._repr_markdown_()
        )

    def overwrite(self, images: Iterable[GeneratedImage]) -> Self:
        """
        Replace the file contents, and the in-memory list, with `images`.

        The images are first written to a temporary sibling file, which then
        replaces the original. A failure while writing therefore leaves the
        existing file (and the in-memory list) unchanged.
        """
        # TODO: add backup before?
        new_images = list(images)
        tmp_path = self.filename.with_name(self.filename.name + ".tmp")
        try:
            with jsonlines.open(tmp_path, "w") as writer:
                for img in new_images:
                    writer.write(img.to_dict())
            tmp_path.replace(self.filename)
        finally:
            # No-op after a successful replace; removes a partial file otherwise.
            tmp_path.unlink(missing_ok=True)
        self._images = new_images
        return self

    def remove_last(self) -> Self:
        """
        Remove the last image from the file.

        This is inefficient, since it rewrites all remaining images.
        """
        return self.overwrite(self._images[:-1])

    def append(self, img: GeneratedImage) -> Self:
        """Append an image to the file and to the in-memory list."""
        with jsonlines.open(self.filename, mode="a") as writer:
            writer.write(img.to_dict())
        print(f"Image saved to index {len(self._images)}")
        self._images.append(img)
        return self

    def extend(self, imgs: Iterable[GeneratedImage]) -> Self:
        """Append several images to the file, one at a time."""
        for img in imgs:
            self.append(img)
        return self

    def generate(  # noqa: PLR0913
        self,
        prompt: str,
        *,
        model: ModelType = "dall-e-3",
        size: SizeType = "1024x1024",
        quality: QualityType = None,
        style: StyleType = None,
        use_exact_prompt: bool = False,
    ) -> GeneratedImage:
        """Create an image from a prompt (see `GeneratedImage.generate`) and append it."""
        img = GeneratedImage.generate(
            prompt=prompt,
            model=model,
            size=size,
            quality=quality,
            style=style,
            use_exact_prompt=use_exact_prompt,
        )
        self.append(img)
        return img

    def generate_many(  # noqa: PLR0913
        self,
        prompt: str,
        n: int = 1,
        *,
        model: ModelType = "dall-e-3",
        size: SizeType = "1024x1024",
        quality: QualityType = None,
        style: StyleType = None,
        use_exact_prompt: bool = False,
    ) -> list[GeneratedImage]:
        """Create `n` images with the same prompt, appending each one to the file."""
        return [
            self.generate(
                prompt=prompt,
                model=model,
                size=size,
                quality=quality,
                style=style,
                use_exact_prompt=use_exact_prompt,
            )
            for _ in range(n)
        ]

    def generate_or_load(
        self, index: int | None, prompt: str | None = None, **kwargs: Any
    ) -> GeneratedImage:
        """
        Generate and save a new image, or load an existing one from the file.

        This helper is meant to avoid creating the same image twice. With
        `index=None` it generates and appends a new image (`append` prints the
        new index). With an index, it loads that image and checks that it
        matches the given arguments.

        index: position of a saved image, or None to generate a new one.
        prompt: the prompt; required when generating.
        kwargs: further `generate` arguments (model, size, quality, style,
            use_exact_prompt).

        Raises ValueError if generating without a prompt, or if a loaded image
        does not match `prompt` / `kwargs`.
        """
        # TODO: maybe instead search for the prompt in the file?
        if index is None:
            if prompt is None:
                raise ValueError("A prompt is required when generating a new image.")
            return self.generate(prompt=prompt, **kwargs)
        img = self._images[index]
        print(f"Image loaded from index {index}")
        # `use_exact_prompt` is not a stored field. When it was set, `generate`
        # saved the pre-prompt + prompt, so only the prompt suffix can be checked.
        use_exact_prompt = kwargs.pop("use_exact_prompt", False)
        if prompt is not None:
            saved_prompt = img.prompt
            if use_exact_prompt:
                matches = saved_prompt is not None and saved_prompt.endswith(prompt)
            else:
                matches = saved_prompt == prompt
            if not matches:
                raise ValueError(f"MISMATCH prompt, should be {saved_prompt!r}")
        for name, value in kwargs.items():
            # For quality/style, None means "model default": generate() stores
            # the resolved value, so None cannot be compared against it.
            if value is None and name in ("quality", "style"):
                continue
            if (saved_value := getattr(img, name)) != value:
                raise ValueError(f"MISMATCH {name}, should be {saved_value!r}")
        return img
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment