Skip to content

Instantly share code, notes, and snippets.

@skrawcz
Last active April 6, 2024 22:47
Show Gist options
  • Save skrawcz/6b21ceb0789c5c0d2ec42885e3362093 to your computer and use it in GitHub Desktop.
Save skrawcz/6b21ceb0789c5c0d2ec42885e3362093 to your computer and use it in GitHub Desktop.
"""
This module demonstrates a telephone application
using Burr that:
- captions an image
- creates caption embeddings (for analysis)
- creates a new image based on the created caption
"""
import os
import uuid
from hamilton import dataflows, driver
import requests
from burr.core import Action, ApplicationBuilder, State, default, expr
from burr.core.action import action
from burr.lifecycle import PostRunStepHook
# Load the two Hamilton dataflow modules used by the actions below via
# `dataflows.import_module(<dataflow name>, <user>)`.
# NOTE(review): presumably this fetches the dataflows published by the user
# "elijahbenizzy" from the Hamilton dataflow hub on first use -- confirm
# network/caching requirements before running offline.
caption_images = dataflows.import_module("caption_images", "elijahbenizzy")
generate_images = dataflows.import_module("generate_images", "elijahbenizzy")
@action(
    reads=["current_image_location"],
    writes=["current_image_caption", "image_location_history"],
)
def image_caption(state: State, caption_image_driver: driver.Driver) -> tuple[dict, State]:
    """Caption the current image and record its location in the history.

    Args:
        state: Application state; ``current_image_location`` is read from it.
        caption_image_driver: Hamilton driver asked for the
            ``generated_caption`` output, given the image location as the
            ``image_url`` input.

    Returns:
        The raw driver result, and a new state in which
        ``current_image_caption`` holds the generated caption and the
        captioned image's location has been appended to
        ``image_location_history``.
    """
    image_url = state["current_image_location"]
    outputs = caption_image_driver.execute(
        ["generated_caption"],
        inputs={"image_url": image_url},
    )
    # could save to S3 here.
    captioned_state = state.update(current_image_caption=outputs["generated_caption"])
    new_state = captioned_state.append(image_location_history=image_url)
    return outputs, new_state
@action(
    reads=["current_image_caption"],
    writes=["caption_analysis"],
)
def caption_embeddings(state: State, caption_image_driver: driver.Driver) -> tuple[dict, State]:
    """Compute analysis metadata for the current caption and store it.

    Args:
        state: Application state; ``current_image_caption`` is read from it.
        caption_image_driver: Hamilton driver asked for the ``metadata``
            output. The already-known caption is supplied as an override for
            ``generated_caption`` rather than as an input.

    Returns:
        The raw driver result, and a new state with the ``metadata`` value
        appended to ``caption_analysis``.
    """
    caption = state["current_image_caption"]
    outputs = caption_image_driver.execute(
        ["metadata"],
        overrides={"generated_caption": caption},
    )
    # could save to S3 here.
    new_state = state.append(caption_analysis=outputs["metadata"])
    return outputs, new_state
@action(
    reads=["current_image_caption"],
    writes=["current_image_location", "image_caption_history"],
)
def image_generation(state: State, generate_image_driver: driver.Driver) -> tuple[dict, State]:
    """Generate a new image from the current caption and record the caption.

    Args:
        state: Application state; ``current_image_caption`` is read from it.
        generate_image_driver: Hamilton driver asked for the
            ``generated_image`` output, given the caption as the
            ``image_generation_prompt`` input.

    Returns:
        The raw driver result, and a new state in which
        ``current_image_location`` holds the generated image and the caption
        used as the prompt has been appended to ``image_caption_history``.
    """
    prompt = state["current_image_caption"]
    outputs = generate_image_driver.execute(
        ["generated_image"],
        inputs={"image_generation_prompt": prompt},
    )
    # could save to S3 here.
    relocated_state = state.update(current_image_location=outputs["generated_image"])
    new_state = relocated_state.append(image_caption_history=prompt)
    return outputs, new_state
@action(
    reads=["image_location_history", "image_caption_history", "caption_analysis"],
    writes=[],
)
def terminal_step(state: State) -> tuple[dict, State]:
    """Collect the accumulated histories into the final result.

    Args:
        state: Application state holding the three history/analysis fields.

    Returns:
        A dict with ``image_location_history``, ``image_caption_history`` and
        ``caption_analysis`` copied from the state, and the state itself,
        unmodified.
    """
    reported_keys = (
        "image_location_history",
        "image_caption_history",
        "caption_analysis",
    )
    summary = {key: state[key] for key in reported_keys}
    # could save to S3 here.
    return summary, state
def build_application(starting_image: str = "statemachine.png",
                      number_of_images_to_caption: int = 4):
    """Build the Burr application that plays image "telephone".

    The state machine loops ``caption -> analysis -> generate -> caption`` and
    leaves the loop from ``analysis`` to ``terminal`` once
    ``image_caption_history`` holds at least ``number_of_images_to_caption``
    entries.

    Args:
        starting_image: Location of the first image to caption; used as the
            initial ``current_image_location``.
        number_of_images_to_caption: Number of entries in
            ``image_caption_history`` after which the application moves to
            the ``terminal`` action.

    Returns:
        The built Burr application, with ``caption`` as its entrypoint and
        tracking enabled under the ``image-telephone`` project.
    """
    # instantiate hamilton drivers and then bind them to the actions.
    caption_image_driver = (
        driver.Builder()
        .with_config({"include_embeddings": True})
        .with_modules(caption_images)
        .build()
    )
    generate_image_driver = (
        driver.Builder()
        .with_config({})
        .with_modules(generate_images)
        .build()
    )
    # Use ">=" rather than "==": the history grows by one per loop, so the two
    # are equivalent for non-negative targets, but "==" can never match a
    # negative target and would leave the application looping forever.
    done_captioning = expr(
        f"len(image_caption_history) >= {number_of_images_to_caption}"
    )
    app = (
        ApplicationBuilder()
        .with_state(
            current_image_location=starting_image,
            current_image_caption="",
            image_location_history=[],
            image_caption_history=[],
            caption_analysis=[],
        )
        .with_actions(
            caption=image_caption.bind(caption_image_driver=caption_image_driver),
            analysis=caption_embeddings.bind(caption_image_driver=caption_image_driver),
            generate=image_generation.bind(generate_image_driver=generate_image_driver),
            terminal=terminal_step,
        )
        .with_transitions(
            ("caption", "analysis", default),
            # Listed before the default so the exit condition is checked first.
            ("analysis", "terminal", done_captioning),
            ("analysis", "generate", default),
            ("generate", "caption", default),
        )
        .with_entrypoint("caption")
        .with_tracker(project="image-telephone")
        .build()
    )
    return app
if __name__ == "__main__":
    import random

    # Randomly choose between the two ways of running the application so that
    # both are demonstrated.
    coin_flip = random.choice([True, False])
    # app = build_application("path/to/my/image.png")
    app = build_application()
    # NOTE(review): this presumably writes "statemachine.png", which is also
    # build_application's default starting image -- confirm that is intended.
    app.visualize(
        output_file_path="statemachine", include_conditions=True, view=True, format="png"
    )
    if coin_flip:
        last_action, result, state = app.run(halt_after=["terminal"])
        # save to S3 / download images etc.
        print(state)
    else:
        # alternate way to run:
        while True:
            # Named `step_action` so the module-level `action` decorator
            # imported from burr.core.action is not rebound.
            step_action, result, state = app.step()
            print("action=====\n", step_action)
            print("result=====\n", result)
            # you could save S3 / download images etc. here.
            if step_action.name == "terminal":
                break
        print(state)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment