Skip to content

Instantly share code, notes, and snippets.

@elijahbenizzy
Created April 9, 2024 13:04
Show Gist options
  • Save elijahbenizzy/b0524c5153319f470c29dd5cec471315 to your computer and use it in GitHub Desktop.
Save elijahbenizzy/b0524c5153319f470c29dd5cec471315 to your computer and use it in GitHub Desktop.
"""
This module demonstrates an image "telephone" application
built with Burr. The application repeatedly:
- captions an image
- creates caption embeddings (for analysis)
- creates a new image based on the created caption
"""
import os
import uuid
from hamilton import dataflows, driver
import requests
from burr.core import Action, ApplicationBuilder, State, default, expr
from burr.core.action import action
from burr.lifecycle import PostRunStepHook
# Import the Hamilton dataflow modules that the drivers in build_application()
# are built from: one for captioning images, one for generating images.
# NOTE(review): the second argument looks like the username of the dataflow's
# author on the Hamilton hub — confirm against the `dataflows.import_module` docs.
caption_images = dataflows.import_module("caption_images", "elijahbenizzy")
generate_images = dataflows.import_module("generate_images", "elijahbenizzy")
@action(
    reads=["current_image_location"],
    writes=["current_image_caption", "image_location_history"],
)
def image_caption(state: State, caption_image_driver: driver.Driver) -> tuple[dict, State]:
    """Caption the image that the state currently points at.

    Reads ``current_image_location`` from state and executes the bound
    Hamilton driver for its ``generated_caption`` output, passing the
    location as the ``image_url`` input.

    :param state: Burr state holding ``current_image_location``.
    :param caption_image_driver: Driver executed to produce the caption.
    :return: The raw driver result, and a new state in which
        ``current_image_caption`` holds the caption and the captioned image's
        location has been appended to ``image_location_history``.
    """
    image_location = state["current_image_location"]
    driver_output = caption_image_driver.execute(
        ["generated_caption"],
        inputs={"image_url": image_location},
    )
    # could save to S3 here.
    captioned_state = state.update(
        current_image_caption=driver_output["generated_caption"],
    )
    # Record which image this caption was produced from.
    new_state = captioned_state.append(image_location_history=image_location)
    return driver_output, new_state
@action(
    reads=["current_image_caption"],
    writes=["caption_analysis"],
)
def caption_embeddings(state: State, caption_image_driver: driver.Driver) -> tuple[dict, State]:
    """Compute the ``metadata`` output for the current caption and record it.

    The caption already stored in state is handed to the driver as an
    override for ``generated_caption``, and the driver is asked for its
    ``metadata`` output.

    :param state: Burr state holding ``current_image_caption``.
    :param caption_image_driver: Driver executed to produce the metadata.
    :return: The raw driver result, and a new state with the metadata
        appended to ``caption_analysis``.
    """
    caption = state["current_image_caption"]
    driver_output = caption_image_driver.execute(
        ["metadata"],
        overrides={"generated_caption": caption},
    )
    # could save to S3 here.
    new_state = state.append(caption_analysis=driver_output["metadata"])
    return driver_output, new_state
@action(
    reads=["current_image_caption"],
    writes=["current_image_location", "image_caption_history"],
)
def image_generation(state: State, generate_image_driver: driver.Driver) -> tuple[dict, State]:
    """Generate a new image from the caption currently held in state.

    Reads ``current_image_caption`` and executes the bound Hamilton driver
    for its ``generated_image`` output, passing the caption as the
    ``image_generation_prompt`` input.

    :param state: Burr state holding ``current_image_caption``.
    :param generate_image_driver: Driver executed to produce the image.
    :return: The raw driver result, and a new state in which
        ``current_image_location`` holds the generated image and the caption
        used as the prompt has been appended to ``image_caption_history``.
    """
    prompt = state["current_image_caption"]
    driver_output = generate_image_driver.execute(
        ["generated_image"],
        inputs={"image_generation_prompt": prompt},
    )
    # could save to S3 here.
    relocated_state = state.update(
        current_image_location=driver_output["generated_image"],
    )
    # Record which caption this image was generated from.
    new_state = relocated_state.append(image_caption_history=prompt)
    return driver_output, new_state
@action(
    reads=["image_location_history", "image_caption_history", "caption_analysis"],
    writes=[],
)
def terminal_step(state: State) -> tuple[dict, State]:
    """Collect the accumulated histories as the final result.

    :param state: Burr state holding the three history/analysis lists.
    :return: A dict with ``image_location_history``,
        ``image_caption_history`` and ``caption_analysis`` copied from state,
        and the state itself, unchanged.
    """
    reported_keys = (
        "image_location_history",
        "image_caption_history",
        "caption_analysis",
    )
    summary = {key: state[key] for key in reported_keys}
    # could save to S3 here.
    return summary, state
def build_application(
    starting_image: str = "statemachine.png",
    number_of_images_to_caption: int = 4,
):
    """Build the Burr application that plays image "telephone".

    Two Hamilton drivers are instantiated (one per dataflow module) and bound
    to the actions that need them. The resulting state machine is::

        caption -> analysis -> generate -> caption -> ...
                            \\-> terminal

    The ``analysis -> terminal`` transition is taken once
    ``image_caption_history`` (appended to by the ``generate`` action) holds
    at least ``number_of_images_to_caption`` entries; otherwise the loop
    continues with ``generate``.

    :param starting_image: Location of the first image to caption; used as
        the initial ``current_image_location``.
    :param number_of_images_to_caption: Length of ``image_caption_history``
        at which the application moves to the terminal action.
    :return: The built Burr application, starting at ``caption`` and tracked
        under the ``image-telephone`` project.
    """
    # instantiate hamilton drivers and then bind them to the actions.
    caption_image_driver = (
        driver.Builder()
        .with_config({"include_embeddings": True})
        .with_modules(caption_images)
        .build()
    )
    generate_image_driver = (
        driver.Builder()
        .with_config({})
        .with_modules(generate_images)
        .build()
    )
    # `>=` rather than `==`: the history only ever grows, so for any
    # non-negative target the two are equivalent, but `==` could never match a
    # negative target and the caption/generate loop would then run forever.
    done_condition = expr(
        f"len(image_caption_history) >= {number_of_images_to_caption}"
    )
    app = (
        ApplicationBuilder()
        .with_state(
            current_image_location=starting_image,
            current_image_caption="",
            image_location_history=[],
            image_caption_history=[],
            caption_analysis=[],
        )
        .with_actions(
            caption=image_caption.bind(caption_image_driver=caption_image_driver),
            analysis=caption_embeddings.bind(caption_image_driver=caption_image_driver),
            generate=image_generation.bind(generate_image_driver=generate_image_driver),
            terminal=terminal_step,
        )
        .with_transitions(
            ("caption", "analysis", default),
            # Listed before the default so it is checked first.
            ("analysis", "terminal", done_condition),
            ("analysis", "generate", default),
            ("generate", "caption", default),
        )
        .with_entrypoint("caption")
        .with_tracker(project="image-telephone")
        .build()
    )
    return app
if __name__ == "__main__":
    telephone_app = build_application()
    # Render the state machine (with transition conditions) as a PNG and open it.
    # NOTE(review): the output name matches build_application's default
    # starting_image ("statemachine.png") — presumably the rendered diagram is
    # meant to be the first image in the chain; confirm.
    telephone_app.visualize(
        output_file_path="statemachine",
        include_conditions=True,
        view=True,
        format="png",
    )
    # Run until the terminal action has executed.
    _last_action, _result, final_state = telephone_app.run(halt_after=["terminal"])
    # save to S3 / download images etc.
    print(final_state)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment