Skip to content

Instantly share code, notes, and snippets.

@craiga
Created July 15, 2024 14:17
Show Gist options
  • Save craiga/ac09b21908f7ddbab0bb9e899c8b07b2 to your computer and use it in GitHub Desktop.
Save craiga/ac09b21908f7ddbab0bb9e899c8b07b2 to your computer and use it in GitHub Desktop.
#!/usr/bin/env python
"""Get images from PhotoPrism and label them."""
import json
import logging
import tempfile
import time
from pprint import pformat
import boto3
import click
import click_log
import httpx
import ollama as ollamalib
logger = logging.getLogger(__name__)
click_log.basic_config(logger)
class SessionIDAuth(httpx.Auth):
    """httpx authentication hook that sends a PhotoPrism session ID.

    Every request passing through the client is given an ``X-Session-ID``
    header carrying the ID supplied at construction time.
    """

    def __init__(self, session_id):
        """Store the session ID that will be attached to outgoing requests."""
        self.session_id = session_id

    def auth_flow(self, request):
        """Set the ``X-Session-ID`` header on *request*, then hand it to httpx."""
        outgoing_headers = request.headers
        outgoing_headers["X-Session-ID"] = self.session_id
        yield request
def image_search(session, api_url, query, order):
    """Lazily yield every photo returned by a paginated ``/photos`` search.

    Pages of 100 results are requested with increasing offsets; iteration
    stops at the first page that comes back empty (falsy).

    :param session: HTTP client exposing ``get(url, params=...)`` (an
        ``httpx.Client`` when called from this CLI).
    :param api_url: Base URL of the PhotoPrism API.
    :param query: Search expression, sent as the ``q`` parameter.
    :param order: Sort order, sent as the ``order`` parameter.
    :raises: whatever the response's ``raise_for_status()`` raises for a
        failed page request.
    """
    page_size = 100
    endpoint = api_url + "/photos"
    offset = 0
    while True:
        params = {"count": page_size, "offset": offset, "q": query, "order": order}
        reply = session.get(endpoint, params=params)
        reply.raise_for_status()
        batch = reply.json()
        # An empty page means we have walked past the last result.
        if not batch:
            return
        yield from batch
        offset += page_size
@click.group()
@click.option("-u", "--username")
@click.option("-p", "--password")
# "-u" belongs to --username; --api-url previously declared "-u" as well, so
# the two options clashed on the same short flag. The URL now gets "-a", and it
# is required because every request below is built from it.
@click.option("-a", "--api-url", required=True)
@click.option(
    "-s", "--sleep-between-labels", "--sleep", default=0, show_default=True, type=int
)
@click.option("--sleep-between-images", default=0, show_default=True, type=int)
@click.option("-q", "--query", default="original:*", show_default=True)
@click.option("-o", "--order", default="added", show_default=True)
@click_log.simple_verbosity_option(logger)
@click.pass_context
def cli(
    context,
    username,
    password,
    api_url,
    sleep_between_labels,
    sleep_between_images,
    query,
    order,
):
    """Get images from PhotoPrism and label them."""
    # Log in to PhotoPrism once, then store an authenticated client and the
    # shared options on ``context.obj`` for the labelling subcommands to read.

    # Strip trailing slashes so that appending "/session", "/photos", ... never
    # produces a double slash in the request path.
    api_url = api_url.rstrip("/")

    logger.debug("Establish an httpx session.")
    response = httpx.post(
        api_url + "/session",
        json={"username": username, "password": password},
        timeout=None,
    )
    response.raise_for_status()
    response_data = response.json()

    # The login response provides the session ID (sent on every later request
    # via SessionIDAuth) and the token passed as ``t`` to the /dl endpoint.
    session = httpx.Client(auth=SessionIDAuth(response_data["id"]), timeout=None)
    # Close the client's connection pool when click closes this context,
    # instead of leaking it for the lifetime of the process.
    context.call_on_close(session.close)
    download_token = response_data["config"]["downloadToken"]

    context.ensure_object(dict)
    context.obj["api_url"] = api_url
    context.obj["download_token"] = download_token
    context.obj["order"] = order
    context.obj["query"] = query
    context.obj["session"] = session
    context.obj["sleep_between_images"] = sleep_between_images
    context.obj["sleep_between_labels"] = sleep_between_labels
@cli.command()
@click.pass_context
def rekognition(context):
    """Label images using Amazon Rekognition."""
    # Settings placed on the context object by the ``cli`` group.
    api_url = context.obj["api_url"]
    download_token = context.obj["download_token"]
    order = context.obj["order"]
    query = context.obj["query"]
    session = context.obj["session"]
    sleep_between_images = context.obj["sleep_between_images"]
    sleep_between_labels = context.obj["sleep_between_labels"]

    # Named so it does not shadow this command function.
    rekognition_client = boto3.client("rekognition")

    for search_result in image_search(session, api_url, query, order):
        uid = search_result["UID"]
        logger.info("Processing uid:%s…", uid)
        photo_url = api_url + "/photos/" + uid

        info_response = session.get(photo_url)
        info_response.raise_for_status()
        info = info_response.json()
        # A "Rekognition" marker label is posted after an image's labels, so
        # images that already carry it were handled by an earlier run.
        # NOTE(review): "Labels" is treated as possibly absent/null here —
        # confirm against the PhotoPrism API.
        existing_names = {
            entry["Label"]["Name"] for entry in info.get("Labels") or []
        }
        if "Rekognition" in existing_names:
            logger.info("Skipping as it already has Rekognition label.")
            continue

        try:
            download_response = session.get(
                api_url + "/dl/" + search_result["Hash"], params={"t": download_token}
            )
            download_response.raise_for_status()
        except Exception as exc:
            # Best effort: a failed download skips this image, not the run.
            logger.warning(
                "Error while downloading image, moving on to next image.",
                exc_info=exc,
            )
            continue

        try:
            rekognition_response = rekognition_client.detect_labels(
                Image={"Bytes": download_response.content}
            )
        except Exception as exc:
            logger.warning(
                "Error while calling Rekognition, moving on to next image.",
                exc_info=exc,
            )
            continue

        labels = [
            (detected["Name"], detected["Confidence"])
            for detected in rekognition_response["Labels"]
        ]
        # The marker goes last: if any earlier label POST fails (and raises),
        # the marker is never written and the image is retried next run.
        labels.append(("Rekognition", 100))

        for label, confidence in labels:
            uncertainty = int(100 - confidence)
            logger.info(
                "Adding label %s with uncertainty of %s.", label, uncertainty
            )
            label_response = session.post(
                photo_url + "/label",
                json={"Name": label, "Priority": 0, "Uncertainty": uncertainty},
            )
            label_response.raise_for_status()
            time.sleep(sleep_between_labels)
        time.sleep(sleep_between_images)
@cli.command()
@click.pass_context
@click.option("--url", default="http://localhost:11434", show_default=True, type=str)
@click.option("--model", default="llava", show_default=True, type=str)
@click.option(
    "--prompt",
    default=(
        "Your job is to generate tags for images, along with a confidence score for"
        " each tag. Generate as many tags as possible. Tags are generally one or two"
        " words long. Confidence should be a score from 0 to 100 of how confident you"
        " are in that tag. Return results as a JSON object with the tags as the keys"
        " and the confidence as the values."
    ),
    show_default=True,
    type=str,
)
def ollama(context, url, model, prompt):
    """Label images using Ollama."""
    # Settings placed on the context object by the ``cli`` group.
    api_url = context.obj["api_url"]
    download_token = context.obj["download_token"]
    order = context.obj["order"]
    query = context.obj["query"]
    session = context.obj["session"]
    sleep_between_images = context.obj["sleep_between_images"]
    sleep_between_labels = context.obj["sleep_between_labels"]

    # The host never changes between images, so one client serves the run.
    ollama_client = ollamalib.Client(host=url)

    for search_result in image_search(session, api_url, query, order):
        uid = search_result["UID"]
        logger.info("Processing uid:%s…", uid)

        try:
            download_response = session.get(
                api_url + "/dl/" + search_result["Hash"], params={"t": download_token}
            )
            download_response.raise_for_status()
        except Exception as exc:
            logger.warning(
                "Error while downloading image, moving on to next image.",
                exc_info=exc,
            )
            continue

        with tempfile.NamedTemporaryFile() as tmp_file:
            tmp_file.write(download_response.content)
            # The file is passed to Ollama by *name*, so the buffered bytes
            # must be flushed to disk first or the image may be read
            # truncated or empty.
            tmp_file.flush()
            try:
                ollama_response = ollama_client.generate(
                    model=model, format="json", images=[tmp_file.name], prompt=prompt
                )
            except Exception as exc:
                # Same best-effort policy as the download: skip the image.
                logger.warning(
                    "Error while calling Ollama, moving on to next image.",
                    exc_info=exc,
                )
                continue

        raw_response = ollama_response["response"]
        logger.debug("Got response from Ollama:\n%s", pformat(raw_response))
        try:
            labels = json.loads(raw_response)
        except json.decoder.JSONDecodeError as exc:
            exc.add_note(f'String trying to be decoded: "{raw_response}"')
            logger.warning(
                "Error decoding JSON response from Ollama, moving on to next image.",
                exc_info=exc,
            )
            continue

        # The prompt asks for {tag: confidence}, but model output is not
        # guaranteed to follow it; validate before using it.
        if not isinstance(labels, dict):
            logger.warning(
                "Expected a JSON object from Ollama but got %s, moving on to next image.",
                type(labels).__name__,
            )
            continue

        for label, confidence in labels.items():
            # bool is an int subclass; reject it along with strings, lists, etc.
            if isinstance(confidence, bool) or not isinstance(
                confidence, (int, float)
            ):
                logger.warning(
                    "Ignoring label %s with non-numeric confidence %r.",
                    label,
                    confidence,
                )
                continue
            # Clamp to the 0-100 scale requested in the prompt so an
            # out-of-range value cannot yield a negative or >100 uncertainty.
            clamped = min(max(confidence, 0), 100)
            uncertainty = int(100 - clamped)
            logger.info(
                "Adding label %s with uncertainty of %s.", label, uncertainty
            )
            label_response = session.post(
                api_url + "/photos/" + uid + "/label",
                json={"Name": label, "Priority": 0, "Uncertainty": uncertainty},
            )
            label_response.raise_for_status()
            time.sleep(sleep_between_labels)
        time.sleep(sleep_between_images)
if __name__ == "__main__":
    # Run the click group directly; click reads the arguments from sys.argv.
    # (Calling a click command object is an alias for its .main() method.)
    cli.main()
boto3
click
click-log
pip-tools
httpx
ollama
#
# This file is autogenerated by pip-compile with Python 3.12
# by the following command:
#
# pip-compile --allow-unsafe --strip-extras
#
anyio==4.4.0
# via httpx
boto3==1.34.115
# via -r requirements.in
botocore==1.34.115
# via
# boto3
# s3transfer
build==1.2.1
# via pip-tools
certifi==2024.2.2
# via
# httpcore
# httpx
click==8.1.7
# via
# -r requirements.in
# click-log
# pip-tools
click-log==0.4.0
# via -r requirements.in
h11==0.14.0
# via httpcore
httpcore==1.0.5
# via httpx
httpx==0.27.0
# via
# -r requirements.in
# ollama
idna==3.7
# via
# anyio
# httpx
jmespath==1.0.1
# via
# boto3
# botocore
ollama==0.2.1
# via -r requirements.in
packaging==24.0
# via build
pip-tools==7.4.1
# via -r requirements.in
pyproject-hooks==1.1.0
# via
# build
# pip-tools
python-dateutil==2.9.0.post0
# via botocore
s3transfer==0.10.1
# via boto3
six==1.16.0
# via python-dateutil
sniffio==1.3.1
# via
# anyio
# httpx
urllib3==2.2.1
# via botocore
wheel==0.43.0
# via pip-tools
# The following packages are considered to be unsafe in a requirements file:
pip==24.0
# via pip-tools
setuptools==70.0.0
# via pip-tools
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment