Skip to content

Instantly share code, notes, and snippets.

@elecnix
Created November 7, 2023 02:56
Show Gist options
  • Save elecnix/574e571204f79529286faed02415cb79 to your computer and use it in GitHub Desktop.
Save elecnix/574e571204f79529286faed02415cb79 to your computer and use it in GitHub Desktop.
"""
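Estimate the input-token cost of images sent to OpenAI's vision (image input) API.
https://platform.openai.com/docs/guides/vision

The estimate is a number of tokens, not a dollar amount; multiply it by the
model's price per input token to convert it to money.

Usage:
    import openai_vision_cost
    tokens = openai_vision_cost.estimate_image_cost(image_width, image_height, detail_level)

detail_level must be 'low' or 'high'; any other value raises ValueError.

Generated with the help of ChatGPT:
https://chat.openai.com/share/bc4e6202-f007-496e-ac18-2f888b93a80f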
"""
def scale_to_fit_square(image_width, image_height, square_size):
    """Shrink an image so it fits inside a ``square_size`` x ``square_size`` box.

    The aspect ratio is preserved: the longer side becomes exactly
    ``square_size`` and the shorter side is scaled proportionally and rounded
    down to a whole pixel. Images that already fit are returned unchanged
    (they are never upscaled). Both dimensions are expected to be positive.

    Args:
        image_width: Image width in pixels.
        image_height: Image height in pixels.
        square_size: Side length of the bounding square in pixels.

    Returns:
        A ``(width, height)`` tuple.
    """
    if image_width <= square_size and image_height <= square_size:
        return image_width, image_height
    # Multiply before dividing so the floor is exact. Going through a float
    # aspect ratio (square_size / (w / h)) rounds twice and can truncate an
    # exact result, e.g. 2176x1683 would give 2048x1583 instead of 2048x1584.
    if image_width > image_height:
        # Landscape: width is the limiting side.
        return square_size, int(image_height * square_size // image_width)
    # Portrait or square: height is the limiting side.
    return int(image_width * square_size // image_height), square_size
def scale_to_shortest_side(image_width, image_height, target_short_side_length):
    """Shrink an image so its shortest side equals ``target_short_side_length``.

    The aspect ratio is preserved: both sides are multiplied by
    ``target_short_side_length / min(image_width, image_height)`` and rounded
    down to whole pixels, so for integer inputs the shortest side comes out
    exactly at the target. Images whose shortest side is already at or below
    the target are returned unchanged (they are never upscaled).

    Args:
        image_width: Image width in pixels.
        image_height: Image height in pixels.
        target_short_side_length: Desired length of the shortest side.

    Returns:
        A ``(width, height)`` tuple.
    """
    current_shortest_side = min(image_width, image_height)
    if current_shortest_side <= target_short_side_length:
        return image_width, image_height
    # Multiply before dividing so the floor is exact. Scaling by a float factor
    # (target / shortest) can truncate exact results: 1568x1176 scaled to a
    # 768px short side would come out as 1023x767 instead of 1024x768.
    scaled_width = int(image_width * target_short_side_length // current_shortest_side)
    scaled_height = int(image_height * target_short_side_length // current_shortest_side)
    return scaled_width, scaled_height
def count_512px_squares(image_width, image_height):
    """Return how many 512x512 tiles are needed to cover an image.

    Each dimension is rounded up to a whole number of tiles, so a partial
    strip along an edge still counts as a full tile; the per-axis counts are
    then multiplied together.
    """
    tile_side = 512
    columns, leftover_x = divmod(image_width, tile_side)
    rows, leftover_y = divmod(image_height, tile_side)
    # Round up: any remainder means one more (partial) tile on that axis.
    columns += leftover_x != 0
    rows += leftover_y != 0
    return columns * rows
def calculate_tokens(num_squares, detail_level, *, base_tokens=85, tokens_per_tile=170):
    """Return the token cost of an image from its tile count and detail level.

    Args:
        num_squares: Number of 512px tiles covering the resized image. Ignored
            when ``detail_level`` is ``'low'``.
        detail_level: ``'low'`` or ``'high'``.
        base_tokens: Flat charge applied to every image (default 85).
        tokens_per_tile: Extra charge per tile in ``'high'`` mode (default 170).
            Both charges are keyword-only so the estimate can be adapted if
            the per-image pricing differs, without changing existing calls.

    Returns:
        The estimated number of input tokens.

    Raises:
        ValueError: If ``detail_level`` is neither ``'low'`` nor ``'high'``.
    """
    if detail_level == 'low':
        # Low detail is a flat charge regardless of image size.
        return base_tokens
    if detail_level == 'high':
        return num_squares * tokens_per_tile + base_tokens
    raise ValueError("Detail level must be 'low' or 'high'.")
def estimate_image_cost(image_width, image_height, detail_level):
    """Estimate the number of input tokens an image costs.

    In ``'high'`` detail mode the image is first shrunk to fit a 2048x2048
    square, then shrunk so its shortest side is 768px, and the result is
    priced per 512px tile. Any other ``detail_level`` is priced as a single
    unit and passed to ``calculate_tokens``, which accepts only ``'low'``.

    Args:
        image_width: Original image width in pixels.
        image_height: Original image height in pixels.
        detail_level: ``'low'`` or ``'high'``.

    Returns:
        The estimated token count, as computed by ``calculate_tokens``.

    Raises:
        ValueError: If ``detail_level`` is neither ``'low'`` nor ``'high'``
            (raised by ``calculate_tokens``).
    """
    if detail_level != 'high':
        # No resizing is needed; 'low' is a flat charge and any other value
        # is rejected by calculate_tokens.
        return calculate_tokens(1, detail_level)
    fitted = scale_to_fit_square(image_width, image_height, 2048)
    resized = scale_to_shortest_side(*fitted, 768)
    tiles = count_512px_squares(*resized)
    return calculate_tokens(tiles, detail_level)
def test_functions():
    """Self-check the helpers against hand-computed expectations.

    Each entry pairs a function and its arguments with the expected result.
    The checks run in the listed order and, when assertions are enabled,
    stop at the first mismatch. Prints a confirmation line at the end.
    """
    checks = [
        # scale_to_fit_square
        (scale_to_fit_square, (3000, 4000, 2048), (1536, 2048)),
        (scale_to_fit_square, (4000, 3000, 2048), (2048, 1536)),
        # scale_to_shortest_side
        (scale_to_shortest_side, (3000, 4000, 768), (768, 1024)),
        (scale_to_shortest_side, (4000, 3000, 768), (1024, 768)),
        # count_512px_squares
        (count_512px_squares, (1024, 768), 4),
        (count_512px_squares, (1536, 2048), 12),
        # calculate_tokens
        (calculate_tokens, (4, 'low'), 85),
        (calculate_tokens, (4, 'high'), 4 * 170 + 85),
        (calculate_tokens, (12, 'high'), 12 * 170 + 85),
        # estimate_image_cost
        (estimate_image_cost, (1000, 500, 'low'), 85),
        (estimate_image_cost, (512, 512, 'high'), 85 + 170),
        (estimate_image_cost, (1024, 2048, 'high'), 6 * 170 + 85),
    ]
    for func, args, expected in checks:
        assert func(*args) == expected
    print("All tests passed.")
if __name__ == '__main__':
    # Only runs when the file is executed as a script, not when imported:
    # execute the built-in self-checks.
    test_functions()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment