@timmyreilly
Created January 23, 2024 23:40
Get an estimate of GPT token usage and cost 2024
import argparse
import tiktoken
from pydantic import BaseModel, Field, field_validator

pricing = {
    "GPT-3.5-Turbo": {"context_limit": 4000, "prompt": 0.0015, "completion": 0.002},
    "GPT-3.5-Turbo-16K": {
        "context_limit": 16000,
        "prompt": 0.003,
        "completion": 0.004,
    },
    "GPT-3.5-Turbo-1106": {
        "context_limit": 16000,
        "prompt": 0.001,
        "completion": 0.002,
    },
    "GPT-4-Turbo": {"context_limit": 128000, "prompt": 0.01, "completion": 0.03},
    "GPT-4-Turbo-Vision": {
        "context_limit": 128000,
        "prompt": 0.01,
        "completion": 0.03,
    },
    "GPT-4": {"context_limit": 8000, "prompt": 0.03, "completion": 0.06},
    "GPT-4-32K": {"context_limit": 32000, "prompt": 0.06, "completion": 0.12},
    "Ada": {
        "embedding": 0.0001,
        "context_limit": 8191,
    },  # Embedding service for Ada
}
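# Note: the rates above are USD per 1,000 tokens, which is how calculate_cost()
# below interprets them; they are a point-in-time snapshot and may not match
# current OpenAI pricing. Example: 1,000 GPT-4 prompt tokens -> $0.03 under this table.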


# Define Pydantic models for input validation
class TokenCountInput(BaseModel):
    text: str = Field(...)


class CostCalculationInput(BaseModel):
    model: str = Field(...)
    token_count: int = Field(...)
    service_type: str = Field(...)

    @field_validator("model")
    @classmethod
    def model_must_be_valid(cls, v):
        valid_models = pricing.keys()
        if v not in valid_models:
            raise ValueError(f"Model '{v}' is not valid.")
        return v

    @field_validator("service_type")
    @classmethod
    def service_type_must_be_valid(cls, v):
        valid_service_types = ["prompt", "completion", "embedding"]
        if v not in valid_service_types:
            raise ValueError(f"Service type '{v}' is not valid.")
        return v
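# Illustrative only: with the validators above, an invalid model or service type
# should surface as a pydantic ValidationError at construction time, e.g.
#   CostCalculationInput(model="Not-A-Model", token_count=10, service_type="prompt")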


def get_token_count(text: str, encoding="cl100k_base") -> int:
    """
    Get the accurate token count for a given text using tiktoken.
    """
    enc = tiktoken.get_encoding(encoding)
    tokens = enc.encode(text)
    return len(tokens)
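# Rough sanity check (illustrative, not exact): with cl100k_base, English prose
# averages roughly 4 characters per token, so a ~4,000-character document usually
# lands on the order of 1,000 tokens.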


def get_token_count_from_file(file_path: str, encoding="cl100k_base") -> int:
    """
    Get the accurate token count from a text file using tiktoken.
    :param file_path: Path to the text file.
    """
    enc = tiktoken.get_encoding(encoding)
    with open(file_path, "r", encoding="utf-8") as file:
        text = file.read()
    tokens = enc.encode(text)
    return len(tokens)


def calculate_cost(model: str, token_count: int, service_type: str) -> float:
    """
    Calculate the cost of tokens given the model, number of tokens, and service type.
    """
    if model not in pricing:
        raise ValueError(f"Model '{model}' not found in the pricing structure.")
    model_info = pricing[model]
    if service_type not in model_info:
        raise ValueError(
            f"Service type '{service_type}' not found for model '{model}'."
        )
    if token_count > model_info["context_limit"]:
        raise ValueError(
            f"Token count {token_count} exceeds the context limit of"
            f" {model_info['context_limit']} for model '{model}'."
        )
    # Rates are quoted per 1,000 tokens.
    rate = model_info[service_type]
    return (token_count / 1000) * rate
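# Worked examples against the pricing table above (illustrative):
#   calculate_cost("GPT-4", 2000, "prompt")   -> (2000 / 1000) * 0.03   = 0.06 USD
#   calculate_cost("Ada", 8191, "embedding")  -> (8191 / 1000) * 0.0001 ≈ 0.0008 USD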


if __name__ == "__main__":
    available_models = list(pricing.keys())
    available_service_types = ["prompt", "completion", "embedding"]

    parser = argparse.ArgumentParser(
        description="Calculate cost for processing text with different OpenAI models.",
        epilog="Example: python cost_estimation.py text.txt GPT-3.5-Turbo prompt",
    )
    parser.add_argument("file", type=str, help="Path to the text file.")
    parser.add_argument(
        "model",
        type=str,
        choices=available_models,
        help="Model name (e.g., GPT-3.5-Turbo, GPT-4).",
    )
    parser.add_argument(
        "service_type",
        type=str,
        choices=available_service_types,
        help="Service type (prompt, completion, or embedding).",
    )
    args = parser.parse_args()

    try:
        # Validate input using Pydantic
        cost_input = CostCalculationInput(
            model=args.model,
            token_count=get_token_count_from_file(args.file),
            service_type=args.service_type,
        )
        # Calculate cost
        cost = calculate_cost(
            cost_input.model, cost_input.token_count, cost_input.service_type
        )
        print(f"Accurate Token Count: {cost_input.token_count}")
        print(
            f"Estimated Cost for {cost_input.model} ({cost_input.service_type}):"
            f" ${cost:.4f}"
        )
    except Exception as e:
        print(f"Error: {e}")