# Created January 23, 2024 23:40
# Save timmyreilly/493e202d368033eb8a49b17a256e6b78 to your computer and use it in GitHub Desktop.
# Get an estimate of GPT token usage and cost 2024
# This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
# Learn more about bidirectional Unicode characters
import argparse
import sys

import tiktoken
from pydantic import BaseModel, Field, field_validator
# Per-model pricing table consumed by calculate_cost().
#
# For each model:
#   * "context_limit" is the largest token count calculate_cost() will accept.
#   * "prompt" / "completion" / "embedding" are rates that calculate_cost()
#     multiplies by (token_count / 1000), i.e. a price per 1,000 tokens.
#     The CLI prints the result with a "$" prefix, so these are presumably
#     USD amounts -- confirm against the provider's current price list.
# A model only lists the service types it is priced for (e.g. "Ada" has only
# an "embedding" rate).
pricing = {
    "GPT-3.5-Turbo": {"context_limit": 4000, "prompt": 0.0015, "completion": 0.002},
    "GPT-3.5-Turbo-16K": {
        "context_limit": 16000,
        "prompt": 0.003,
        "completion": 0.004,
    },
    "GPT-3.5-Turbo-1106": {
        "context_limit": 16000,
        "prompt": 0.001,
        "completion": 0.002,
    },
    "GPT-4-Turbo": {"context_limit": 128000, "prompt": 0.01, "completion": 0.03},
    "GPT-4-Turbo-Vision": {
        "context_limit": 128000,
        "prompt": 0.01,
        "completion": 0.03,
    },
    "GPT-4": {"context_limit": 8000, "prompt": 0.03, "completion": 0.06},
    "GPT-4-32K": {"context_limit": 32000, "prompt": 0.06, "completion": 0.12},
    "Ada": {
        "embedding": 0.0001,
        "context_limit": 8191,
    },  # Embedding service for Ada
}
# Define Pydantic models for input validation
class TokenCountInput(BaseModel):
    """Input model holding the text whose tokens are to be counted."""

    # Required field: declared without a default, so it must be supplied.
    text: str
class CostCalculationInput(BaseModel):
    """Validated inputs for one cost calculation.

    Construction fails with a pydantic validation error when the model name
    is not a key of ``pricing``, the service type is not one of the known
    names, or the token count is negative.

    Note: this model does not check that the chosen service type is actually
    priced for the chosen model; ``calculate_cost`` performs that check.
    """

    # Key into the module-level ``pricing`` table.
    model: str = Field(...)
    # Number of tokens to price; negative counts are rejected.
    token_count: int = Field(..., ge=0)
    # One of "prompt", "completion" or "embedding".
    service_type: str = Field(...)

    @field_validator("model")
    @classmethod
    def model_must_be_valid(cls, v):
        """Reject model names that have no entry in ``pricing``."""
        if v not in pricing:
            raise ValueError(f"Model '{v}' is not valid.")
        return v

    @field_validator("service_type")
    @classmethod
    def service_type_must_be_valid(cls, v):
        """Reject service types other than prompt, completion or embedding."""
        valid_service_types = ["prompt", "completion", "embedding"]
        if v not in valid_service_types:
            raise ValueError(f"Service type '{v}' is not valid.")
        return v
def get_token_count(text: str, encoding="cl100k_base") -> int:
    """
    Count the tokens that tiktoken produces for a piece of text.
    params: text: The string to tokenize.
            encoding: Name of the tiktoken encoding to load.
    returns: The number of tokens in the encoded text.
    """
    tokenizer = tiktoken.get_encoding(encoding)
    return len(tokenizer.encode(text))
def get_token_count_from_file(file_path: str, encoding="cl100k_base") -> int:
    """
    Get the accurate token count from a text file using tiktoken.
    params: file_path: Path to the text file; it is read as UTF-8.
            encoding: Name of the tiktoken encoding used for tokenizing
                      (unrelated to the file's text encoding).
    returns: The number of tokens in the file's contents.
    """
    with open(file_path, "r", encoding="utf-8") as file:
        text = file.read()
    # Delegate to the string-based helper so tokenization lives in one place.
    return get_token_count(text, encoding)
def calculate_cost(
    model: str, token_count: int, service_type: str, embedding=False
) -> float:
    """
    Calculate the cost of tokens given the model, number of tokens, and service type.
    params: model: Key into the ``pricing`` table.
            token_count: Number of tokens to price; must be between 0 and the
                         model's ``context_limit`` inclusive.
            service_type: Which rate of the model to apply; must be a key of
                          the model's pricing entry.
            embedding: Unused by this function; kept so existing callers that
                       pass it continue to work.
    returns: ``(token_count / 1000) * rate`` for the selected rate.
    raises: ValueError if the model is unknown, the service type is not priced
            for the model, or the token count is negative or above the limit.
    """
    if model not in pricing:
        raise ValueError(f"Model '{model}' not found in the pricing structure.")
    model_info = pricing[model]
    if service_type not in model_info:
        raise ValueError(
            f"Service type '{service_type}' not found for model '{model}'."
        )
    # A negative count would otherwise yield a meaningless negative cost.
    if token_count < 0:
        raise ValueError(f"Token count {token_count} must not be negative.")
    if token_count > model_info["context_limit"]:
        raise ValueError(
            f"Token count {token_count} exceeds the context limit of"
            f" {model_info['context_limit']} for model '{model}'."
        )
    rate = model_info[service_type]
    # Rates in ``pricing`` are applied per 1,000 tokens.
    return (token_count / 1000) * rate
# Command-line entry point: count the tokens in a text file and print the
# estimated cost for the chosen model and service type.
if __name__ == "__main__":
    # Plain lists give argparse an ordinary sequence for choices and help text.
    available_models = list(pricing)
    available_service_types = ["prompt", "completion", "embedding"]
    parser = argparse.ArgumentParser(
        description="Calculate cost for processing text with different OpenAI models.",
        epilog="Example: python cost_estimation.py text.txt GPT-3.5-Turbo prompt",
    )
    parser.add_argument("file", type=str, help="Path to the text file.")
    parser.add_argument(
        "model",
        type=str,
        choices=available_models,
        help="Model name (e.g., GPT-3.5-Turbo, GPT-4).",
    )
    parser.add_argument(
        "service_type",
        type=str,
        choices=available_service_types,
        help="Service type (prompt, completion, or embedding).",
    )
    args = parser.parse_args()
    try:
        # Validate input using Pydantic
        cost_input = CostCalculationInput(
            model=args.model,
            token_count=get_token_count_from_file(args.file),
            service_type=args.service_type,
        )
        # Calculate cost
        cost = calculate_cost(
            cost_input.model, cost_input.token_count, cost_input.service_type
        )
    except Exception as e:
        # Top-level boundary: report the failure on stderr and exit non-zero
        # so shells and scripts can detect that no estimate was produced.
        print(f"Error: {e}", file=sys.stderr)
        sys.exit(1)
    else:
        print(f"Accurate Token Count: {cost_input.token_count}")
        print(
            f"Estimated Cost for {cost_input.model} ({cost_input.service_type}):"
            f" ${cost:.4f}"
        )
# Sign up for free to join this conversation on GitHub.
# Already have an account? Sign in to comment