Model: hf.co/lmstudio-community/Qwen2.5-7B-Instruct-1M-GGUF:Q8_0
files-to-prompt ~/Dropbox/Development/llm -e py -c
<documents>
<document index="1">
<source>/Users/simon/Dropbox/Development/llm/setup.py</source>
<document_content>
from setuptools import setup, find_packages
import os
VERSION = "0.20"
def get_long_description():
with open(
os.path.join(os.path.dirname(os.path.abspath(__file__)), "README.md"),
encoding="utf8",
) as fp:
return fp.read()
setup(
name="llm",
description=(
"A CLI utility and Python library for interacting with Large Language Models, "
"including OpenAI, PaLM and local models installed on your own machine."
),
long_description=get_long_description(),
long_description_content_type="text/markdown",
author="Simon Willison",
url="https://github.com/simonw/llm",
project_urls={
"Documentation": "https://llm.datasette.io/",
"Issues": "https://github.com/simonw/llm/issues",
"CI": "https://github.com/simonw/llm/actions",
"Changelog": "https://github.com/simonw/llm/releases",
},
license="Apache License, Version 2.0",
version=VERSION,
packages=find_packages(),
entry_points="""
[console_scripts]
llm=llm.cli:cli
""",
install_requires=[
"click",
"openai>=1.55.3",
"click-default-group>=1.2.3",
"sqlite-utils>=3.37",
"sqlite-migrate>=0.1a2",
"pydantic>=1.10.2",
"PyYAML",
"pluggy",
"python-ulid",
"setuptools",
"pip",
"pyreadline3; sys_platform == 'win32'",
"puremagic",
],
extras_require={
"test": [
"pytest",
"numpy",
"pytest-httpx>=0.33.0",
"pytest-asyncio",
"cogapp",
"mypy>=1.10.0",
"black>=24.1.0",
"ruff",
"types-click",
"types-PyYAML",
"types-setuptools",
]
},
python_requires=">=3.9",
)
</document_content>
</document>
<document index="2">
<source>/Users/simon/Dropbox/Development/llm/llm/__init__.py</source>
<document_content>
from .hookspecs import hookimpl
from .errors import (
ModelError,
NeedsKeyException,
)
from .models import (
AsyncModel,
AsyncResponse,
Attachment,
Conversation,
Model,
ModelWithAliases,
EmbeddingModel,
EmbeddingModelWithAliases,
Options,
Prompt,
Response,
)
from .embeddings import Collection
from .templates import Template
from .plugins import pm, load_plugins
import click
from typing import Dict, List, Optional
import json
import os
import pathlib
import struct
__all__ = [
"hookimpl",
"get_async_model",
"get_model",
"get_key",
"user_dir",
"AsyncResponse",
"Attachment",
"Collection",
"Conversation",
"Model",
"Options",
"Prompt",
"Response",
"Template",
"ModelError",
"NeedsKeyException",
]
DEFAULT_MODEL = "gpt-4o-mini"
def get_plugins(all=False):
plugins = []
plugin_to_distinfo = dict(pm.list_plugin_distinfo())
for plugin in pm.get_plugins():
if not all and plugin.__name__.startswith("llm.default_plugins."):
continue
plugin_info = {
"name": plugin.__name__,
"hooks": [h.name for h in pm.get_hookcallers(plugin)],
}
distinfo = plugin_to_distinfo.get(plugin)
if distinfo:
plugin_info["version"] = distinfo.version
plugin_info["name"] = (
getattr(distinfo, "name", None) or distinfo.project_name
)
plugins.append(plugin_info)
return plugins
def get_models_with_aliases() -> List["ModelWithAliases"]:
model_aliases = []
# Include aliases from aliases.json
aliases_path = user_dir() / "aliases.json"
extra_model_aliases: Dict[str, list] = {}
if aliases_path.exists():
configured_aliases = json.loads(aliases_path.read_text())
for alias, model_id in configured_aliases.items():
extra_model_aliases.setdefault(model_id, []).append(alias)
def register(model, async_model=None, aliases=None):
alias_list = list(aliases or [])
if model.model_id in extra_model_aliases:
alias_list.extend(extra_model_aliases[model.model_id])
model_aliases.append(ModelWithAliases(model, async_model, alias_list))
load_plugins()
pm.hook.register_models(register=register)
return model_aliases
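# A plugin supplies models by implementing the register_models hook and calling the
# register callback above. A minimal sketch (MyModel is a hypothetical llm.Model subclass):
#
#   @llm.hookimpl
#   def register_models(register):
#       register(MyModel(), aliases=("my-alias",))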
def get_embedding_models_with_aliases() -> List["EmbeddingModelWithAliases"]:
model_aliases = []
# Include aliases from aliases.json
aliases_path = user_dir() / "aliases.json"
extra_model_aliases: Dict[str, list] = {}
if aliases_path.exists():
configured_aliases = json.loads(aliases_path.read_text())
for alias, model_id in configured_aliases.items():
extra_model_aliases.setdefault(model_id, []).append(alias)
def register(model, aliases=None):
alias_list = list(aliases or [])
if model.model_id in extra_model_aliases:
alias_list.extend(extra_model_aliases[model.model_id])
model_aliases.append(EmbeddingModelWithAliases(model, alias_list))
load_plugins()
pm.hook.register_embedding_models(register=register)
return model_aliases
def get_embedding_models():
models = []
def register(model, aliases=None):
models.append(model)
load_plugins()
pm.hook.register_embedding_models(register=register)
return models
def get_embedding_model(name):
aliases = get_embedding_model_aliases()
try:
return aliases[name]
except KeyError:
raise UnknownModelError("Unknown model: " + str(name))
def get_embedding_model_aliases() -> Dict[str, EmbeddingModel]:
model_aliases = {}
for model_with_aliases in get_embedding_models_with_aliases():
for alias in model_with_aliases.aliases:
model_aliases[alias] = model_with_aliases.model
model_aliases[model_with_aliases.model.model_id] = model_with_aliases.model
return model_aliases
def get_async_model_aliases() -> Dict[str, AsyncModel]:
async_model_aliases = {}
for model_with_aliases in get_models_with_aliases():
if model_with_aliases.async_model:
for alias in model_with_aliases.aliases:
async_model_aliases[alias] = model_with_aliases.async_model
async_model_aliases[model_with_aliases.model.model_id] = (
model_with_aliases.async_model
)
return async_model_aliases
def get_model_aliases() -> Dict[str, Model]:
model_aliases = {}
for model_with_aliases in get_models_with_aliases():
if model_with_aliases.model:
for alias in model_with_aliases.aliases:
model_aliases[alias] = model_with_aliases.model
model_aliases[model_with_aliases.model.model_id] = model_with_aliases.model
return model_aliases
class UnknownModelError(KeyError):
pass
def get_models() -> List[Model]:
"Get all registered models"
models_with_aliases = get_models_with_aliases()
return [mwa.model for mwa in models_with_aliases if mwa.model]
def get_async_models() -> List[AsyncModel]:
"Get all registered async models"
models_with_aliases = get_models_with_aliases()
return [mwa.async_model for mwa in models_with_aliases if mwa.async_model]
def get_async_model(name: Optional[str] = None) -> AsyncModel:
"Get an async model by name or alias"
aliases = get_async_model_aliases()
name = name or get_default_model()
try:
return aliases[name]
except KeyError:
# Does a sync model exist?
sync_model = None
try:
sync_model = get_model(name, _skip_async=True)
except UnknownModelError:
pass
if sync_model:
raise UnknownModelError("Unknown async model (sync model exists): " + name)
else:
raise UnknownModelError("Unknown model: " + name)
def get_model(name: Optional[str] = None, _skip_async: bool = False) -> Model:
"Get a model by name or alias"
aliases = get_model_aliases()
name = name or get_default_model()
try:
return aliases[name]
except KeyError:
# Does an async model exist?
if _skip_async:
raise UnknownModelError("Unknown model: " + name)
async_model = None
try:
async_model = get_async_model(name)
except UnknownModelError:
pass
if async_model:
raise UnknownModelError("Unknown model (async model exists): " + name)
else:
raise UnknownModelError("Unknown model: " + name)
def get_key(
explicit_key: Optional[str], key_alias: str, env_var: Optional[str] = None
) -> Optional[str]:
"""
Return an API key based on a hierarchy of potential sources.
:param explicit_key: A key provided by the user. This may be the key itself, or the alias of a key in keys.json.
:param key_alias: The alias used to retrieve the key from the keys.json file.
:param env_var: Name of the environment variable to check for the key.
"""
stored_keys = load_keys()
# If user specified an alias, use the key stored for that alias
if explicit_key in stored_keys:
return stored_keys[explicit_key]
if explicit_key:
# User specified a key that's not an alias, use that
return explicit_key
# A key stored for this alias over-rides any environment variable
if key_alias in stored_keys:
return stored_keys[key_alias]
# Finally try environment variable
if env_var and os.environ.get(env_var):
return os.environ[env_var]
# Couldn't find it
return None
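# A sketch of the resolution order implemented above, assuming a hypothetical keys.json
# of {"openai": "sk-stored", "work": "sk-work"}:
#
#   get_key(None, "openai", "OPENAI_API_KEY")    # -> "sk-stored" (stored key beats the env var)
#   get_key("work", "openai", "OPENAI_API_KEY")  # -> "sk-work" (explicit value matching an alias)
#   get_key("sk-raw", "openai", None)            # -> "sk-raw" (explicit value used verbatim)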
def load_keys():
path = user_dir() / "keys.json"
if path.exists():
return json.loads(path.read_text())
else:
return {}
def user_dir():
llm_user_path = os.environ.get("LLM_USER_PATH")
if llm_user_path:
path = pathlib.Path(llm_user_path)
else:
path = pathlib.Path(click.get_app_dir("io.datasette.llm"))
path.mkdir(exist_ok=True, parents=True)
return path
def set_alias(alias, model_id_or_alias):
"""
Set an alias to point to the specified model.
"""
path = user_dir() / "aliases.json"
path.parent.mkdir(parents=True, exist_ok=True)
if not path.exists():
path.write_text("{}\n")
try:
current = json.loads(path.read_text())
except json.decoder.JSONDecodeError:
# We're going to write a valid JSON file in a moment:
current = {}
# Resolve model_id_or_alias to a model_id
try:
model = get_model(model_id_or_alias)
model_id = model.model_id
except UnknownModelError:
# Try to resolve it to an embedding model
try:
model = get_embedding_model(model_id_or_alias)
model_id = model.model_id
except UnknownModelError:
# Set the alias to the exact string they provided instead
model_id = model_id_or_alias
current[alias] = model_id
path.write_text(json.dumps(current, indent=4) + "\n")
def remove_alias(alias):
"""
Remove an alias.
"""
path = user_dir() / "aliases.json"
if not path.exists():
raise KeyError("No aliases.json file exists")
try:
current = json.loads(path.read_text())
except json.decoder.JSONDecodeError:
raise KeyError("aliases.json file is not valid JSON")
if alias not in current:
raise KeyError("No such alias: {}".format(alias))
del current[alias]
path.write_text(json.dumps(current, indent=4) + "\n")
def encode(values):
return struct.pack("<" + "f" * len(values), *values)
def decode(binary):
return struct.unpack("<" + "f" * (len(binary) // 4), binary)
def cosine_similarity(a, b):
dot_product = sum(x * y for x, y in zip(a, b))
magnitude_a = sum(x * x for x in a) ** 0.5
magnitude_b = sum(x * x for x in b) ** 0.5
return dot_product / (magnitude_a * magnitude_b)
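# A quick sketch of how the helpers above behave (hypothetical two-element vectors):
#
#   decode(encode([1.0, 0.0]))                 # -> (1.0, 0.0): little-endian float32 round trip
#   cosine_similarity([1.0, 0.0], [0.0, 1.0])  # -> 0.0 for orthogonal vectors
#   cosine_similarity([1.0, 0.0], [2.0, 0.0])  # -> 1.0 for parallel vectors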
def get_default_model(filename="default_model.txt", default=DEFAULT_MODEL):
path = user_dir() / filename
if path.exists():
return path.read_text().strip()
else:
return default
def set_default_model(model, filename="default_model.txt"):
path = user_dir() / filename
if model is None and path.exists():
path.unlink()
else:
path.write_text(model)
def get_default_embedding_model():
return get_default_model("default_embedding_model.txt", None)
def set_default_embedding_model(model):
set_default_model(model, "default_embedding_model.txt")
</document_content>
</document>
<document index="3">
<source>/Users/simon/Dropbox/Development/llm/llm/__main__.py</source>
<document_content>
from .cli import cli
if __name__ == "__main__":
cli()
</document_content>
</document>
<document index="4">
<source>/Users/simon/Dropbox/Development/llm/llm/cli.py</source>
<document_content>
import asyncio
import click
from click_default_group import DefaultGroup
from dataclasses import asdict
import io
import json
from llm import (
Attachment,
AsyncResponse,
Collection,
Conversation,
Response,
Template,
UnknownModelError,
encode,
get_async_model,
get_default_model,
get_default_embedding_model,
get_embedding_models_with_aliases,
get_embedding_model_aliases,
get_embedding_model,
get_key,
get_plugins,
get_model,
get_model_aliases,
get_models_with_aliases,
user_dir,
set_alias,
set_default_model,
set_default_embedding_model,
remove_alias,
)
from .migrations import migrate
from .plugins import pm, load_plugins
from .utils import (
mimetype_from_path,
mimetype_from_string,
token_usage_string,
extract_fenced_code_block,
)
import base64
import httpx
import pathlib
import pydantic
import readline
from runpy import run_module
import shutil
import sqlite_utils
from sqlite_utils.utils import rows_from_file, Format
import sys
import textwrap
from typing import cast, Optional, Iterable, Union, Tuple
import warnings
import yaml
warnings.simplefilter("ignore", ResourceWarning)
DEFAULT_TEMPLATE = "prompt: "
class AttachmentType(click.ParamType):
name = "attachment"
def convert(self, value, param, ctx):
if value == "-":
content = sys.stdin.buffer.read()
# Try to guess type
mimetype = mimetype_from_string(content)
if mimetype is None:
raise click.BadParameter("Could not determine mimetype of stdin")
return Attachment(type=mimetype, path=None, url=None, content=content)
if "://" in value:
# Confirm URL exists and try to guess type
try:
response = httpx.head(value)
response.raise_for_status()
mimetype = response.headers.get("content-type")
except httpx.HTTPError as ex:
raise click.BadParameter(str(ex))
return Attachment(mimetype, None, value, None)
# Check that the file exists
path = pathlib.Path(value)
if not path.exists():
self.fail(f"File {value} does not exist", param, ctx)
path = path.resolve()
# Try to guess type
mimetype = mimetype_from_path(str(path))
if mimetype is None:
raise click.BadParameter(f"Could not determine mimetype of {value}")
return Attachment(type=mimetype, path=str(path), url=None, content=None)
def attachment_types_callback(ctx, param, values):
collected = []
for value, mimetype in values:
if "://" in value:
attachment = Attachment(mimetype, None, value, None)
elif value == "-":
content = sys.stdin.buffer.read()
attachment = Attachment(mimetype, None, None, content)
else:
# Look for file
path = pathlib.Path(value)
if not path.exists():
raise click.BadParameter(f"File {value} does not exist")
path = path.resolve()
attachment = Attachment(mimetype, str(path), None, None)
collected.append(attachment)
return collected
def _validate_metadata_json(ctx, param, value):
if value is None:
return value
try:
obj = json.loads(value)
if not isinstance(obj, dict):
raise click.BadParameter("Metadata must be a JSON object")
return obj
except json.JSONDecodeError:
raise click.BadParameter("Metadata must be valid JSON")
@click.group(
cls=DefaultGroup,
default="prompt",
default_if_no_args=True,
)
@click.version_option()
def cli():
"""
Access Large Language Models from the command-line
Documentation: https://llm.datasette.io/
LLM can run models from many different providers. Consult the
plugin directory for a list of available models:
https://llm.datasette.io/en/stable/plugins/directory.html
To get started with OpenAI, obtain an API key from them and:
\b
$ llm keys set openai
Enter key: ...
Then execute a prompt like this:
llm 'Five outrageous names for a pet pelican'
"""
@cli.command(name="prompt")
@click.argument("prompt", required=False)
@click.option("-s", "--system", help="System prompt to use")
@click.option("model_id", "-m", "--model", help="Model to use")
@click.option(
"attachments",
"-a",
"--attachment",
type=AttachmentType(),
multiple=True,
help="Attachment path or URL or -",
)
@click.option(
"attachment_types",
"--at",
"--attachment-type",
type=(str, str),
multiple=True,
callback=attachment_types_callback,
help="Attachment with explicit mimetype",
)
@click.option(
"options",
"-o",
"--option",
type=(str, str),
multiple=True,
help="key/value options for the model",
)
@click.option("-t", "--template", help="Template to use")
@click.option(
"-p",
"--param",
multiple=True,
type=(str, str),
help="Parameters for template",
)
@click.option("--no-stream", is_flag=True, help="Do not stream output")
@click.option("-n", "--no-log", is_flag=True, help="Don't log to database")
@click.option("--log", is_flag=True, help="Log prompt and response to the database")
@click.option(
"_continue",
"-c",
"--continue",
is_flag=True,
flag_value=-1,
help="Continue the most recent conversation.",
)
@click.option(
"conversation_id",
"--cid",
"--conversation",
help="Continue the conversation with the given ID.",
)
@click.option("--key", help="API key to use")
@click.option("--save", help="Save prompt with this template name")
@click.option("async_", "--async", is_flag=True, help="Run prompt asynchronously")
@click.option("-u", "--usage", is_flag=True, help="Show token usage")
@click.option("-x", "--extract", is_flag=True, help="Extract first fenced code block")
@click.option(
"extract_last",
"--xl",
"--extract-last",
is_flag=True,
help="Extract last fenced code block",
)
def prompt(
prompt,
system,
model_id,
attachments,
attachment_types,
options,
template,
param,
no_stream,
no_log,
log,
_continue,
conversation_id,
key,
save,
async_,
usage,
extract,
extract_last,
):
"""
Execute a prompt
Documentation: https://llm.datasette.io/en/stable/usage.html
Examples:
\b
llm 'Capital of France?'
llm 'Capital of France?' -m gpt-4o
llm 'Capital of France?' -s 'answer in Spanish'
Multi-modal models can be called with attachments like this:
\b
llm 'Extract text from this image' -a image.jpg
llm 'Describe' -a https://static.simonwillison.net/static/2024/pelicans.jpg
cat image | llm 'describe image' -a -
# With an explicit mimetype:
cat image | llm 'describe image' --at - image/jpeg
The -x/--extract option returns just the content of the first ``` fenced code
block, if one is present. If none are present it returns the full response.
\b
llm 'JavaScript function for reversing a string' -x
"""
if log and no_log:
raise click.ClickException("--log and --no-log are mutually exclusive")
model_aliases = get_model_aliases()
def read_prompt():
nonlocal prompt
# Is there extra prompt available on stdin?
stdin_prompt = None
if not sys.stdin.isatty():
stdin_prompt = sys.stdin.read()
if stdin_prompt:
bits = [stdin_prompt]
if prompt:
bits.append(prompt)
prompt = " ".join(bits)
if (
prompt is None
and not save
and sys.stdin.isatty()
and not attachments
and not attachment_types
):
# Hang waiting for input to stdin (unless --save)
prompt = sys.stdin.read()
return prompt
if save:
# We are saving their prompt/system/etc to a new template
# Fields to save: prompt, system, model - and more in the future
disallowed_options = []
for option, var in (
("--template", template),
("--continue", _continue),
("--cid", conversation_id),
):
if var:
disallowed_options.append(option)
if disallowed_options:
raise click.ClickException(
"--save cannot be used with {}".format(", ".join(disallowed_options))
)
path = template_dir() / f"{save}.yaml"
to_save = {}
if model_id:
try:
to_save["model"] = model_aliases[model_id].model_id
except KeyError:
raise click.ClickException("'{}' is not a known model".format(model_id))
prompt = read_prompt()
if prompt:
to_save["prompt"] = prompt
if system:
to_save["system"] = system
if param:
to_save["defaults"] = dict(param)
if extract:
to_save["extract"] = True
if extract_last:
to_save["extract_last"] = True
path.write_text(
yaml.dump(
to_save,
indent=4,
default_flow_style=False,
),
"utf-8",
)
return
if template:
params = dict(param)
# Cannot be used with system
if system:
raise click.ClickException("Cannot use -t/--template and --system together")
template_obj = load_template(template)
extract = template_obj.extract
extract_last = template_obj.extract_last
prompt = read_prompt()
try:
prompt, system = template_obj.evaluate(prompt, params)
except Template.MissingVariables as ex:
raise click.ClickException(str(ex))
if model_id is None and template_obj.model:
model_id = template_obj.model
if extract or extract_last:
no_stream = True
conversation = None
if conversation_id or _continue:
# Load the conversation - loads most recent if no ID provided
try:
conversation = load_conversation(conversation_id)
except UnknownModelError as ex:
raise click.ClickException(str(ex))
# Figure out which model we are using
if model_id is None:
if conversation:
model_id = conversation.model.model_id
else:
model_id = get_default_model()
# Now resolve the model
try:
if async_:
model = get_async_model(model_id)
else:
model = get_model(model_id)
except UnknownModelError as ex:
raise click.ClickException(ex)
# Provide the API key, if one is needed and has been provided
if model.needs_key:
model.key = get_key(key, model.needs_key, model.key_env_var)
if conversation:
# To ensure it can see the key
conversation.model = model
# Validate options
validated_options = {}
if options:
# Validate with pydantic
try:
validated_options = dict(
(key, value)
for key, value in model.Options(**dict(options))
if value is not None
)
except pydantic.ValidationError as ex:
raise click.ClickException(render_errors(ex.errors()))
resolved_attachments = [*attachments, *attachment_types]
should_stream = model.can_stream and not no_stream
if not should_stream:
validated_options["stream"] = False
prompt = read_prompt()
response = None
prompt_method = model.prompt
if conversation:
prompt_method = conversation.prompt
try:
if async_:
async def inner():
if should_stream:
response = prompt_method(
prompt,
attachments=resolved_attachments,
system=system,
**validated_options,
)
async for chunk in response:
print(chunk, end="")
sys.stdout.flush()
print("")
else:
response = prompt_method(
prompt,
attachments=resolved_attachments,
system=system,
**validated_options,
)
text = await response.text()
if extract or extract_last:
text = (
extract_fenced_code_block(text, last=extract_last) or text
)
print(text)
return response
response = asyncio.run(inner())
else:
response = prompt_method(
prompt,
attachments=resolved_attachments,
system=system,
**validated_options,
)
if should_stream:
for chunk in response:
print(chunk, end="")
sys.stdout.flush()
print("")
else:
text = response.text()
if extract or extract_last:
text = extract_fenced_code_block(text, last=extract_last) or text
print(text)
# List of exceptions that should never be raised in pytest:
except (ValueError, NotImplementedError) as ex:
raise click.ClickException(str(ex))
except Exception as ex:
# All other exceptions should raise in pytest, show to user otherwise
if getattr(sys, "_called_from_test", False):
raise
raise click.ClickException(str(ex))
if isinstance(response, AsyncResponse):
response = asyncio.run(response.to_sync_response())
if usage:
# Show token usage to stderr in yellow
click.echo(
click.style(
"Token usage: {}".format(response.token_usage()), fg="yellow", bold=True
),
err=True,
)
# Log to the database
if (logs_on() or log) and not no_log:
log_path = logs_db_path()
(log_path.parent).mkdir(parents=True, exist_ok=True)
db = sqlite_utils.Database(log_path)
migrate(db)
response.log_to_db(db)
@cli.command()
@click.option("-s", "--system", help="System prompt to use")
@click.option("model_id", "-m", "--model", help="Model to use")
@click.option(
"_continue",
"-c",
"--continue",
is_flag=True,
flag_value=-1,
help="Continue the most recent conversation.",
)
@click.option(
"conversation_id",
"--cid",
"--conversation",
help="Continue the conversation with the given ID.",
)
@click.option("-t", "--template", help="Template to use")
@click.option(
"-p",
"--param",
multiple=True,
type=(str, str),
help="Parameters for template",
)
@click.option(
"options",
"-o",
"--option",
type=(str, str),
multiple=True,
help="key/value options for the model",
)
@click.option("--no-stream", is_flag=True, help="Do not stream output")
@click.option("--key", help="API key to use")
def chat(
system,
model_id,
_continue,
conversation_id,
template,
param,
options,
no_stream,
key,
):
"""
Hold an ongoing chat with a model.
"""
# Left and right arrow keys to move cursor:
if sys.platform != "win32":
readline.parse_and_bind("\\e[D: backward-char")
readline.parse_and_bind("\\e[C: forward-char")
else:
readline.parse_and_bind("bind -x '\\e[D: backward-char'")
readline.parse_and_bind("bind -x '\\e[C: forward-char'")
log_path = logs_db_path()
(log_path.parent).mkdir(parents=True, exist_ok=True)
db = sqlite_utils.Database(log_path)
migrate(db)
conversation = None
if conversation_id or _continue:
# Load the conversation - loads most recent if no ID provided
try:
conversation = load_conversation(conversation_id)
except UnknownModelError as ex:
raise click.ClickException(str(ex))
template_obj = None
if template:
params = dict(param)
# Cannot be used with system
if system:
raise click.ClickException("Cannot use -t/--template and --system together")
template_obj = load_template(template)
if model_id is None and template_obj.model:
model_id = template_obj.model
# Figure out which model we are using
if model_id is None:
if conversation:
model_id = conversation.model.model_id
else:
model_id = get_default_model()
# Now resolve the model
try:
model = get_model(model_id)
except KeyError:
raise click.ClickException("'{}' is not a known model".format(model_id))
# Provide the API key, if one is needed and has been provided
if model.needs_key:
model.key = get_key(key, model.needs_key, model.key_env_var)
if conversation is None:
# Start a fresh conversation for this chat
conversation = Conversation(model=model)
else:
# Ensure it can see the API key
conversation.model = model
# Validate options
validated_options = {}
if options:
try:
validated_options = dict(
(key, value)
for key, value in model.Options(**dict(options))
if value is not None
)
except pydantic.ValidationError as ex:
raise click.ClickException(render_errors(ex.errors()))
should_stream = model.can_stream and not no_stream
if not should_stream:
validated_options["stream"] = False
click.echo("Chatting with {}".format(model.model_id))
click.echo("Type 'exit' or 'quit' to exit")
click.echo("Type '!multi' to enter multiple lines, then '!end' to finish")
in_multi = False
accumulated = []
end_token = "!end"
while True:
prompt = click.prompt("", prompt_suffix="> " if not in_multi else "")
if prompt.strip().startswith("!multi"):
in_multi = True
bits = prompt.strip().split()
if len(bits) > 1:
end_token = "!end {}".format(" ".join(bits[1:]))
continue
if in_multi:
if prompt.strip() == end_token:
prompt = "\n".join(accumulated)
in_multi = False
accumulated = []
else:
accumulated.append(prompt)
continue
if template_obj:
try:
prompt, system = template_obj.evaluate(prompt, params)
except Template.MissingVariables as ex:
raise click.ClickException(str(ex))
if prompt.strip() in ("exit", "quit"):
break
response = conversation.prompt(prompt, system=system, **validated_options)
# System prompt only sent for the first message:
system = None
for chunk in response:
print(chunk, end="")
sys.stdout.flush()
response.log_to_db(db)
print("")
def load_conversation(conversation_id: Optional[str]) -> Optional[Conversation]:
db = sqlite_utils.Database(logs_db_path())
migrate(db)
if conversation_id is None:
# Return the most recent conversation, or None if there are none
matches = list(db["conversations"].rows_where(order_by="id desc", limit=1))
if matches:
conversation_id = matches[0]["id"]
else:
return None
try:
row = cast(sqlite_utils.db.Table, db["conversations"]).get(conversation_id)
except sqlite_utils.db.NotFoundError:
raise click.ClickException(
"No conversation found with id={}".format(conversation_id)
)
# Inflate that conversation
conversation = Conversation.from_row(row)
for response in db["responses"].rows_where(
"conversation_id = ?", [conversation_id]
):
conversation.responses.append(Response.from_row(db, response))
return conversation
@cli.group(
cls=DefaultGroup,
default="list",
default_if_no_args=True,
)
def keys():
"Manage stored API keys for different models"
@keys.command(name="list")
def keys_list():
"List names of all stored keys"
path = user_dir() / "keys.json"
if not path.exists():
click.echo("No keys found")
return
keys = json.loads(path.read_text())
for key in sorted(keys.keys()):
if key != "// Note":
click.echo(key)
@keys.command(name="path")
def keys_path_command():
"Output the path to the keys.json file"
click.echo(user_dir() / "keys.json")
@keys.command(name="get")
@click.argument("name")
def keys_get(name):
"""
Return the value of a stored key
Example usage:
\b
export OPENAI_API_KEY=$(llm keys get openai)
"""
path = user_dir() / "keys.json"
if not path.exists():
raise click.ClickException("No keys found")
keys = json.loads(path.read_text())
try:
click.echo(keys[name])
except KeyError:
raise click.ClickException("No key found with name '{}'".format(name))
@keys.command(name="set")
@click.argument("name")
@click.option("--value", prompt="Enter key", hide_input=True, help="Value to set")
def keys_set(name, value):
"""
Save a key in the keys.json file
Example usage:
\b
$ llm keys set openai
Enter key: ...
"""
default = {"// Note": "This file stores secret API credentials. Do not share!"}
path = user_dir() / "keys.json"
path.parent.mkdir(parents=True, exist_ok=True)
if not path.exists():
path.write_text(json.dumps(default))
path.chmod(0o600)
try:
current = json.loads(path.read_text())
except json.decoder.JSONDecodeError:
current = default
current[name] = value
path.write_text(json.dumps(current, indent=2) + "\n")
@cli.group(
cls=DefaultGroup,
default="list",
default_if_no_args=True,
)
def logs():
"Tools for exploring logged prompts and responses"
@logs.command(name="path")
def logs_path():
"Output the path to the logs.db file"
click.echo(logs_db_path())
@logs.command(name="status")
def logs_status():
"Show current status of database logging"
path = logs_db_path()
if not path.exists():
click.echo("No log database found at {}".format(path))
return
if logs_on():
click.echo("Logging is ON for all prompts".format())
else:
click.echo("Logging is OFF".format())
db = sqlite_utils.Database(path)
migrate(db)
click.echo("Found log database at {}".format(path))
click.echo("Number of conversations logged:\t{}".format(db["conversations"].count))
click.echo("Number of responses logged:\t{}".format(db["responses"].count))
click.echo(
"Database file size: \t\t{}".format(_human_readable_size(path.stat().st_size))
)
@logs.command(name="on")
def logs_turn_on():
"Turn on logging for all prompts"
path = user_dir() / "logs-off"
if path.exists():
path.unlink()
@logs.command(name="off")
def logs_turn_off():
"Turn off logging for all prompts"
path = user_dir() / "logs-off"
path.touch()
LOGS_COLUMNS = """ responses.id,
responses.model,
responses.prompt,
responses.system,
responses.prompt_json,
responses.options_json,
responses.response,
responses.response_json,
responses.conversation_id,
responses.duration_ms,
responses.datetime_utc,
responses.input_tokens,
responses.output_tokens,
responses.token_details,
conversations.name as conversation_name,
conversations.model as conversation_model"""
LOGS_SQL = """
select
{columns}
from
responses
left join conversations on responses.conversation_id = conversations.id{extra_where}
order by responses.id desc{limit}
"""
LOGS_SQL_SEARCH = """
select
{columns}
from
responses
left join conversations on responses.conversation_id = conversations.id
join responses_fts on responses_fts.rowid = responses.rowid
where responses_fts match :query{extra_where}
order by responses_fts.rank desc{limit}
"""
ATTACHMENTS_SQL = """
select
response_id,
attachments.id,
attachments.type,
attachments.path,
attachments.url,
length(attachments.content) as content_length
from attachments
join prompt_attachments
on attachments.id = prompt_attachments.attachment_id
where prompt_attachments.response_id in ({})
order by prompt_attachments."order"
"""
@logs.command(name="list")
@click.option(
"-n",
"--count",
type=int,
default=None,
help="Number of entries to show - defaults to 3, use 0 for all",
)
@click.option(
"-p",
"--path",
type=click.Path(readable=True, exists=True, dir_okay=False),
help="Path to log database",
)
@click.option("-m", "--model", help="Filter by model or model alias")
@click.option("-q", "--query", help="Search for logs matching this string")
@click.option("-t", "--truncate", is_flag=True, help="Truncate long strings in output")
@click.option("-u", "--usage", is_flag=True, help="Include token usage")
@click.option("-r", "--response", is_flag=True, help="Just output the last response")
@click.option("-x", "--extract", is_flag=True, help="Extract first fenced code block")
@click.option(
"extract_last",
"--xl",
"--extract-last",
is_flag=True,
help="Extract last fenced code block",
)
@click.option(
"current_conversation",
"-c",
"--current",
is_flag=True,
flag_value=-1,
help="Show logs from the current conversation",
)
@click.option(
"conversation_id",
"--cid",
"--conversation",
help="Show logs for this conversation ID",
)
@click.option(
"json_output",
"--json",
is_flag=True,
help="Output logs as JSON",
)
def logs_list(
count,
path,
model,
query,
truncate,
usage,
response,
extract,
extract_last,
current_conversation,
conversation_id,
json_output,
):
"Show recent logged prompts and their responses"
path = pathlib.Path(path or logs_db_path())
if not path.exists():
raise click.ClickException("No log database found at {}".format(path))
db = sqlite_utils.Database(path)
migrate(db)
if response and not current_conversation and not conversation_id:
current_conversation = True
if current_conversation:
try:
conversation_id = next(
db.query(
"select conversation_id from responses order by id desc limit 1"
)
)["conversation_id"]
except StopIteration:
# No conversations yet
raise click.ClickException("No conversations found")
# For --conversation, default to limit 0 (no limit) unless a count was set explicitly
if count is None:
if conversation_id:
count = 0
else:
count = 3
model_id = None
if model:
# Resolve alias, if any
try:
model_id = get_model(model).model_id
except UnknownModelError:
# Maybe they uninstalled a model, use the -m option as-is
model_id = model
sql = LOGS_SQL
if query:
sql = LOGS_SQL_SEARCH
limit = ""
if count is not None and count > 0:
limit = " limit {}".format(count)
sql_format = {
"limit": limit,
"columns": LOGS_COLUMNS,
"extra_where": "",
}
where_bits = []
if model_id:
where_bits.append("responses.model = :model")
if conversation_id:
where_bits.append("responses.conversation_id = :conversation_id")
if where_bits:
where_ = " and " if query else " where "
sql_format["extra_where"] = where_ + " and ".join(where_bits)
final_sql = sql.format(**sql_format)
rows = list(
db.query(
final_sql,
{"model": model_id, "query": query, "conversation_id": conversation_id},
)
)
# Reverse the order - we do this because we 'order by id desc limit 3' to get the
# 3 most recent results, but we still want to display them in chronological order
# ... except for searches where we don't do this
if not query:
rows.reverse()
# Fetch any attachments
ids = [row["id"] for row in rows]
attachments = list(db.query(ATTACHMENTS_SQL.format(",".join("?" * len(ids))), ids))
attachments_by_id = {}
for attachment in attachments:
attachments_by_id.setdefault(attachment["response_id"], []).append(attachment)
for row in rows:
if truncate:
row["prompt"] = _truncate_string(row["prompt"])
row["response"] = _truncate_string(row["response"])
# Either decode or remove all JSON keys
keys = list(row.keys())
for key in keys:
if key.endswith("_json") and row[key] is not None:
if truncate:
del row[key]
else:
row[key] = json.loads(row[key])
output = None
if json_output:
# Output as JSON if requested
for row in rows:
row["attachments"] = [
{k: v for k, v in attachment.items() if k != "response_id"}
for attachment in attachments_by_id.get(row["id"], [])
]
output = json.dumps(list(rows), indent=2)
elif extract or extract_last:
# Extract and return first code block
for row in rows:
output = extract_fenced_code_block(row["response"], last=extract_last)
if output is not None:
break
elif response:
# Just output the last response
if rows:
output = rows[-1]["response"]
if output is not None:
click.echo(output)
else:
# Output neatly formatted human-readable logs
current_system = None
should_show_conversation = True
for row in rows:
click.echo(
"# {}{}\n{}".format(
row["datetime_utc"].split(".")[0],
(
" conversation: {}".format(row["conversation_id"])
if should_show_conversation
else ""
),
(
"\nModel: **{}**\n".format(row["model"])
if should_show_conversation
else ""
),
)
)
# In conversation log mode only show it for the first one
if conversation_id:
should_show_conversation = False
click.echo("## Prompt:\n\n{}".format(row["prompt"]))
if row["system"] != current_system:
if row["system"] is not None:
click.echo("\n## System:\n\n{}".format(row["system"]))
current_system = row["system"]
attachments = attachments_by_id.get(row["id"])
if attachments:
click.echo("\n### Attachments\n")
for i, attachment in enumerate(attachments, 1):
if attachment["path"]:
path = attachment["path"]
click.echo(
"{}. **{}**: `{}`".format(i, attachment["type"], path)
)
elif attachment["url"]:
click.echo(
"{}. **{}**: {}".format(
i, attachment["type"], attachment["url"]
)
)
elif attachment["content_length"]:
click.echo(
"{}. **{}**: `<{} bytes>`".format(
i,
attachment["type"],
f"{attachment['content_length']:,}",
)
)
click.echo("\n## Response:\n\n{}\n".format(row["response"]))
if usage:
token_usage = token_usage_string(
row["input_tokens"],
row["output_tokens"],
json.loads(row["token_details"]) if row["token_details"] else None,
)
if token_usage:
click.echo("## Token usage:\n\n{}\n".format(token_usage))
@cli.group(
cls=DefaultGroup,
default="list",
default_if_no_args=True,
)
def models():
"Manage available models"
_type_lookup = {
"number": "float",
"integer": "int",
"string": "str",
"object": "dict",
}
@models.command(name="list")
@click.option(
"--options", is_flag=True, help="Show options for each model, if available"
)
@click.option("async_", "--async", is_flag=True, help="List async models")
@click.option("-q", "--query", help="Search for models matching this string")
def models_list(options, async_, query):
"List available models"
models_that_have_shown_options = set()
for model_with_aliases in get_models_with_aliases():
if async_ and not model_with_aliases.async_model:
continue
if query and not model_with_aliases.matches(query):
continue
extra = ""
if model_with_aliases.aliases:
extra = " (aliases: {})".format(", ".join(model_with_aliases.aliases))
model = (
model_with_aliases.model if not async_ else model_with_aliases.async_model
)
output = str(model) + extra
if options and model.Options.schema()["properties"]:
output += "\n Options:"
for name, field in model.Options.schema()["properties"].items():
any_of = field.get("anyOf")
if any_of is None:
any_of = [{"type": field["type"]}]
types = ", ".join(
[
_type_lookup.get(item["type"], item["type"])
for item in any_of
if item["type"] != "null"
]
)
bits = ["\n ", name, ": ", types]
description = field.get("description", "")
if description and (
model.__class__ not in models_that_have_shown_options
):
wrapped = textwrap.wrap(description, 70)
bits.append("\n ")
bits.extend("\n ".join(wrapped))
output += "".join(bits)
models_that_have_shown_options.add(model.__class__)
if options and model.attachment_types:
attachment_types = ", ".join(sorted(model.attachment_types))
wrapper = textwrap.TextWrapper(
width=min(max(shutil.get_terminal_size().columns, 30), 70),
initial_indent=" ",
subsequent_indent=" ",
)
output += "\n Attachment types:\n{}".format(wrapper.fill(attachment_types))
click.echo(output)
if not query:
click.echo(f"Default: {get_default_model()}")
@models.command(name="default")
@click.argument("model", required=False)
def models_default(model):
"Show or set the default model"
if not model:
click.echo(get_default_model())
return
# Validate it is a known model
try:
model = get_model(model)
set_default_model(model.model_id)
except KeyError:
raise click.ClickException("Unknown model: {}".format(model))
@cli.group(
cls=DefaultGroup,
default="list",
default_if_no_args=True,
)
def templates():
"Manage stored prompt templates"
@templates.command(name="list")
def templates_list():
"List available prompt templates"
path = template_dir()
pairs = []
for file in path.glob("*.yaml"):
name = file.stem
template = load_template(name)
text = []
if template.system:
text.append(f"system: {template.system}")
if template.prompt:
text.append(f" prompt: {template.prompt}")
else:
text = [template.prompt if template.prompt else ""]
pairs.append((name, "".join(text).replace("\n", " ")))
try:
max_name_len = max(len(p[0]) for p in pairs)
except ValueError:
return
else:
fmt = "{name:<" + str(max_name_len) + "} : {prompt}"
for name, prompt in sorted(pairs):
text = fmt.format(name=name, prompt=prompt)
click.echo(display_truncated(text))
@cli.group(
cls=DefaultGroup,
default="list",
default_if_no_args=True,
)
def aliases():
"Manage model aliases"
@aliases.command(name="list")
@click.option("json_", "--json", is_flag=True, help="Output as JSON")
def aliases_list(json_):
"List current aliases"
to_output = []
for alias, model in get_model_aliases().items():
if alias != model.model_id:
to_output.append((alias, model.model_id, ""))
for alias, embedding_model in get_embedding_model_aliases().items():
if alias != embedding_model.model_id:
to_output.append((alias, embedding_model.model_id, "embedding"))
if json_:
click.echo(
json.dumps({key: value for key, value, type_ in to_output}, indent=4)
)
return
max_alias_length = max(len(a) for a, _, _ in to_output)
fmt = "{alias:<" + str(max_alias_length) + "} : {model_id}{type_}"
for alias, model_id, type_ in to_output:
click.echo(
fmt.format(
alias=alias, model_id=model_id, type_=f" ({type_})" if type_ else ""
)
)
@aliases.command(name="set")
@click.argument("alias")
@click.argument("model_id")
def aliases_set(alias, model_id):
"""
Set an alias for a model
Example usage:
\b
$ llm aliases set turbo gpt-3.5-turbo
"""
set_alias(alias, model_id)
@aliases.command(name="remove")
@click.argument("alias")
def aliases_remove(alias):
"""
Remove an alias
Example usage:
\b
$ llm aliases remove turbo
"""
try:
remove_alias(alias)
except KeyError as ex:
raise click.ClickException(ex.args[0])
@aliases.command(name="path")
def aliases_path():
"Output the path to the aliases.json file"
click.echo(user_dir() / "aliases.json")
@cli.command(name="plugins")
@click.option("--all", help="Include built-in default plugins", is_flag=True)
def plugins_list(all):
"List installed plugins"
click.echo(json.dumps(get_plugins(all), indent=2))
def display_truncated(text):
console_width = shutil.get_terminal_size()[0]
if len(text) > console_width:
return text[: console_width - 3] + "..."
else:
return text
@templates.command(name="show")
@click.argument("name")
def templates_show(name):
"Show the specified prompt template"
template = load_template(name)
click.echo(
yaml.dump(
dict((k, v) for k, v in template.dict().items() if v is not None),
indent=4,
default_flow_style=False,
)
)
@templates.command(name="edit")
@click.argument("name")
def templates_edit(name):
"Edit the specified prompt template using the default $EDITOR"
# First ensure it exists
path = template_dir() / f"{name}.yaml"
if not path.exists():
path.write_text(DEFAULT_TEMPLATE, "utf-8")
click.edit(filename=path)
# Validate that template
load_template(name)
@templates.command(name="path")
def templates_path():
"Output the path to the templates directory"
click.echo(template_dir())
@cli.command()
@click.argument("packages", nargs=-1, required=False)
@click.option(
"-U", "--upgrade", is_flag=True, help="Upgrade packages to latest version"
)
@click.option(
"-e",
"--editable",
help="Install a project in editable mode from this path",
)
@click.option(
"--force-reinstall",
is_flag=True,
help="Reinstall all packages even if they are already up-to-date",
)
@click.option(
"--no-cache-dir",
is_flag=True,
help="Disable the cache",
)
def install(packages, upgrade, editable, force_reinstall, no_cache_dir):
"""Install packages from PyPI into the same environment as LLM"""
args = ["pip", "install"]
if upgrade:
args += ["--upgrade"]
if editable:
args += ["--editable", editable]
if force_reinstall:
args += ["--force-reinstall"]
if no_cache_dir:
args += ["--no-cache-dir"]
args += list(packages)
sys.argv = args
run_module("pip", run_name="__main__")
@cli.command()
@click.argument("packages", nargs=-1, required=True)
@click.option("-y", "--yes", is_flag=True, help="Don't ask for confirmation")
def uninstall(packages, yes):
"""Uninstall Python packages from the LLM environment"""
sys.argv = ["pip", "uninstall"] + list(packages) + (["-y"] if yes else [])
run_module("pip", run_name="__main__")
@cli.command()
@click.argument("collection", required=False)
@click.argument("id", required=False)
@click.option(
"-i",
"--input",
type=click.Path(exists=True, readable=True, allow_dash=True),
help="File to embed",
)
@click.option("-m", "--model", help="Embedding model to use")
@click.option("--store", is_flag=True, help="Store the text itself in the database")
@click.option(
"-d",
"--database",
type=click.Path(file_okay=True, allow_dash=False, dir_okay=False, writable=True),
envvar="LLM_EMBEDDINGS_DB",
)
@click.option(
"-c",
"--content",
help="Content to embed",
)
@click.option("--binary", is_flag=True, help="Treat input as binary data")
@click.option(
"--metadata",
help="JSON object metadata to store",
callback=_validate_metadata_json,
)
@click.option(
"format_",
"-f",
"--format",
type=click.Choice(["json", "blob", "base64", "hex"]),
help="Output format",
)
def embed(
collection, id, input, model, store, database, content, binary, metadata, format_
):
"""Embed text and store or return the result"""
if collection and not id:
raise click.ClickException("Must provide both collection and id")
if store and not collection:
raise click.ClickException("Must provide collection when using --store")
# Lazy load this because we do not need it for -c or -i versions
def get_db():
if database:
return sqlite_utils.Database(database)
else:
return sqlite_utils.Database(user_dir() / "embeddings.db")
collection_obj = None
model_obj = None
if collection:
db = get_db()
if Collection.exists(db, collection):
# Load existing collection and use its model
collection_obj = Collection(collection, db)
model_obj = collection_obj.model()
else:
# We will create a new one, but that means model is required
if not model:
model = get_default_embedding_model()
if model is None:
raise click.ClickException(
"You need to specify an embedding model (no default model is set)"
)
collection_obj = Collection(collection, db=db, model_id=model)
model_obj = collection_obj.model()
if model_obj is None:
if model is None:
model = get_default_embedding_model()
try:
model_obj = get_embedding_model(model)
except UnknownModelError:
raise click.ClickException(
"You need to specify an embedding model (no default model is set)"
)
show_output = True
if collection and (format_ is None):
show_output = False
# Resolve input text
if not content:
if not input or input == "-":
# Read from stdin
input_source = sys.stdin.buffer if binary else sys.stdin
content = input_source.read()
else:
mode = "rb" if binary else "r"
with open(input, mode) as f:
content = f.read()
if not content:
raise click.ClickException("No content provided")
if collection_obj:
embedding = collection_obj.embed(id, content, metadata=metadata, store=store)
else:
embedding = model_obj.embed(content)
if show_output:
if format_ == "json" or format_ is None:
click.echo(json.dumps(embedding))
elif format_ == "blob":
click.echo(encode(embedding))
elif format_ == "base64":
click.echo(base64.b64encode(encode(embedding)).decode("ascii"))
elif format_ == "hex":
click.echo(encode(embedding).hex())
@cli.command()
@click.argument("collection")
@click.argument(
"input_path",
type=click.Path(exists=True, dir_okay=False, allow_dash=True, readable=True),
required=False,
)
@click.option(
"--format",
type=click.Choice(["json", "csv", "tsv", "nl"]),
help="Format of input file - defaults to auto-detect",
)
@click.option(
"--files",
type=(click.Path(file_okay=False, dir_okay=True, allow_dash=False), str),
multiple=True,
help="Embed files in this directory - specify directory and glob pattern",
)
@click.option(
"encodings",
"--encoding",
help="Encoding to use when reading --files",
multiple=True,
)
@click.option("--binary", is_flag=True, help="Treat --files as binary data")
@click.option("--sql", help="Read input using this SQL query")
@click.option(
"--attach",
type=(str, click.Path(file_okay=True, dir_okay=False, allow_dash=False)),
multiple=True,
help="Additional databases to attach - specify alias and file path",
)
@click.option(
"--batch-size", type=int, help="Batch size to use when running embeddings"
)
@click.option("--prefix", help="Prefix to add to the IDs", default="")
@click.option("-m", "--model", help="Embedding model to use")
@click.option("--store", is_flag=True, help="Store the text itself in the database")
@click.option(
"-d",
"--database",
type=click.Path(file_okay=True, allow_dash=False, dir_okay=False, writable=True),
envvar="LLM_EMBEDDINGS_DB",
)
def embed_multi(
collection,
input_path,
format,
files,
encodings,
binary,
sql,
attach,
batch_size,
prefix,
model,
store,
database,
):
"""
Store embeddings for multiple strings at once
Input can be CSV, TSV or a JSON list of objects.
The first column is treated as an ID - all other columns
are assumed to be text that should be concatenated together
in order to calculate the embeddings.
Input data can come from one of three sources:
\b
1. A CSV, JSON, TSV or JSON-nl file (including on standard input)
2. A SQL query against a SQLite database
3. A directory of files
"""
if binary and not files:
raise click.UsageError("--binary must be used with --files")
if binary and encodings:
raise click.UsageError("--binary cannot be used with --encoding")
if not input_path and not sql and not files:
raise click.UsageError("Either --sql or input path or --files is required")
if files:
if input_path or sql or format:
raise click.UsageError(
"Cannot use --files with --sql, input path or --format"
)
if database:
db = sqlite_utils.Database(database)
else:
db = sqlite_utils.Database(user_dir() / "embeddings.db")
for alias, attach_path in attach:
db.attach(alias, attach_path)
try:
collection_obj = Collection(
collection, db=db, model_id=model or get_default_embedding_model()
)
except ValueError:
raise click.ClickException(
"You need to specify an embedding model (no default model is set)"
)
expected_length = None
if files:
encodings = encodings or ("utf-8", "latin-1")
def count_files():
i = 0
for directory, pattern in files:
for path in pathlib.Path(directory).glob(pattern):
i += 1
return i
def iterate_files():
for directory, pattern in files:
p = pathlib.Path(directory)
if not p.exists() or not p.is_dir():
# fixes issue/274 - raise error if directory does not exist
raise click.UsageError(f"Invalid directory: {directory}")
for path in pathlib.Path(directory).glob(pattern):
if path.is_dir():
continue # fixed issue/280 - skip directories
relative = path.relative_to(directory)
content = None
if binary:
content = path.read_bytes()
else:
for encoding in encodings:
try:
content = path.read_text(encoding=encoding)
except UnicodeDecodeError:
continue
if content is None:
# Log to stderr
click.echo(
"Could not decode text in file {}".format(path),
err=True,
)
else:
yield {"id": str(relative), "content": content}
expected_length = count_files()
rows = iterate_files()
elif sql:
rows = db.query(sql)
count_sql = "select count(*) as c from ({})".format(sql)
expected_length = next(db.query(count_sql))["c"]
else:
def load_rows(fp):
return rows_from_file(fp, Format[format.upper()] if format else None)[0]
try:
if input_path != "-":
# Read the file twice - first time is to get a count
expected_length = 0
with open(input_path, "rb") as fp:
for _ in load_rows(fp):
expected_length += 1
rows = load_rows(
open(input_path, "rb")
if input_path != "-"
else io.BufferedReader(sys.stdin.buffer)
)
except json.JSONDecodeError as ex:
raise click.ClickException(str(ex))
with click.progressbar(
rows, label="Embedding", show_percent=True, length=expected_length
) as rows:
def tuples() -> Iterable[Tuple[str, Union[bytes, str]]]:
for row in rows:
values = list(row.values())
id = prefix + str(values[0])
if binary:
yield id, cast(bytes, values[1])
else:
yield id, " ".join(v or "" for v in values[1:])
embed_kwargs = {"store": store}
if batch_size:
embed_kwargs["batch_size"] = batch_size
collection_obj.embed_multi(tuples(), **embed_kwargs)
@cli.command()
@click.argument("collection")
@click.argument("id", required=False)
@click.option(
"-i",
"--input",
type=click.Path(exists=True, readable=True, allow_dash=True),
help="File to embed for comparison",
)
@click.option("-c", "--content", help="Content to embed for comparison")
@click.option("--binary", is_flag=True, help="Treat input as binary data")
@click.option(
"-n", "--number", type=int, default=10, help="Number of results to return"
)
@click.option(
"-d",
"--database",
type=click.Path(file_okay=True, allow_dash=False, dir_okay=False, writable=True),
envvar="LLM_EMBEDDINGS_DB",
)
def similar(collection, id, input, content, binary, number, database):
"""
Return top N similar IDs from a collection
Example usage:
\b
llm similar my-collection -c "I like cats"
Or to find content similar to a specific stored ID:
\b
llm similar my-collection 1234
"""
if not id and not content and not input:
raise click.ClickException("Must provide content or an ID for the comparison")
if database:
db = sqlite_utils.Database(database)
else:
db = sqlite_utils.Database(user_dir() / "embeddings.db")
if not db["embeddings"].exists():
raise click.ClickException("No embeddings table found in database")
try:
collection_obj = Collection(collection, db, create=False)
except Collection.DoesNotExist:
raise click.ClickException("Collection does not exist")
if id:
try:
results = collection_obj.similar_by_id(id, number)
except Collection.DoesNotExist:
raise click.ClickException("ID not found in collection")
else:
# Resolve input text
if not content:
if not input or input == "-":
# Read from stdin
input_source = sys.stdin.buffer if binary else sys.stdin
content = input_source.read()
else:
mode = "rb" if binary else "r"
with open(input, mode) as f:
content = f.read()
if not content:
raise click.ClickException("No content provided")
results = collection_obj.similar(content, number)
for result in results:
click.echo(json.dumps(asdict(result)))
@cli.group(
cls=DefaultGroup,
default="list",
default_if_no_args=True,
)
def embed_models():
"Manage available embedding models"
@embed_models.command(name="list")
def embed_models_list():
"List available embedding models"
output = []
for model_with_aliases in get_embedding_models_with_aliases():
s = str(model_with_aliases.model.model_id)
if model_with_aliases.aliases:
s += " (aliases: {})".format(", ".join(model_with_aliases.aliases))
output.append(s)
click.echo("\n".join(output))
@embed_models.command(name="default")
@click.argument("model", required=False)
@click.option(
"--remove-default", is_flag=True, help="Reset to specifying no default model"
)
def embed_models_default(model, remove_default):
"Show or set the default embedding model"
if not model and not remove_default:
default = get_default_embedding_model()
if default is None:
click.echo("<No default embedding model set>", err=True)
else:
click.echo(default)
return
# Validate it is a known model
try:
if remove_default:
set_default_embedding_model(None)
else:
model = get_embedding_model(model)
set_default_embedding_model(model.model_id)
except KeyError:
raise click.ClickException("Unknown embedding model: {}".format(model))
@cli.group(
cls=DefaultGroup,
default="list",
default_if_no_args=True,
)
def collections():
"View and manage collections of embeddings"
@collections.command(name="path")
def collections_path():
"Output the path to the embeddings database"
click.echo(user_dir() / "embeddings.db")
@collections.command(name="list")
@click.option(
"-d",
"--database",
type=click.Path(file_okay=True, allow_dash=False, dir_okay=False, writable=True),
envvar="LLM_EMBEDDINGS_DB",
help="Path to embeddings database",
)
@click.option("json_", "--json", is_flag=True, help="Output as JSON")
def embed_db_collections(database, json_):
"View a list of collections"
database = database or (user_dir() / "embeddings.db")
db = sqlite_utils.Database(str(database))
if not db["collections"].exists():
raise click.ClickException("No collections table found in {}".format(database))
rows = db.query(
"""
select
collections.name,
collections.model,
count(embeddings.id) as num_embeddings
from
collections left join embeddings
on collections.id = embeddings.collection_id
group by
collections.name, collections.model
"""
)
if json_:
click.echo(json.dumps(list(rows), indent=4))
else:
for row in rows:
click.echo("{}: {}".format(row["name"], row["model"]))
click.echo(
" {} embedding{}".format(
row["num_embeddings"], "s" if row["num_embeddings"] != 1 else ""
)
)
@collections.command(name="delete")
@click.argument("collection")
@click.option(
"-d",
"--database",
type=click.Path(file_okay=True, allow_dash=False, dir_okay=False, writable=True),
envvar="LLM_EMBEDDINGS_DB",
help="Path to embeddings database",
)
def collections_delete(collection, database):
"""
Delete the specified collection
Example usage:
\b
llm collections delete my-collection
"""
database = database or (user_dir() / "embeddings.db")
db = sqlite_utils.Database(str(database))
try:
collection_obj = Collection(collection, db, create=False)
except Collection.DoesNotExist:
raise click.ClickException("Collection does not exist")
collection_obj.delete()
def template_dir():
path = user_dir() / "templates"
path.mkdir(parents=True, exist_ok=True)
return path
def _truncate_string(s, max_length=100):
if len(s) > max_length:
return s[: max_length - 3] + "..."
return s
def logs_db_path():
return user_dir() / "logs.db"
def load_template(name):
path = template_dir() / f"{name}.yaml"
if not path.exists():
raise click.ClickException(f"Invalid template: {name}")
try:
loaded = yaml.safe_load(path.read_text())
except yaml.YAMLError as ex:
raise click.ClickException("Invalid YAML: {}".format(str(ex)))
if isinstance(loaded, str):
return Template(name=name, prompt=loaded)
loaded["name"] = name
try:
return Template(**loaded)
except pydantic.ValidationError as ex:
msg = "A validation error occurred:\n"
msg += render_errors(ex.errors())
raise click.ClickException(msg)
def get_history(chat_id):
if chat_id is None:
return None, []
log_path = logs_db_path()
db = sqlite_utils.Database(log_path)
migrate(db)
if chat_id == -1:
# Return the most recent chat
last_row = list(db["logs"].rows_where(order_by="-id", limit=1))
if last_row:
chat_id = last_row[0].get("chat_id") or last_row[0].get("id")
else: # Database is empty
return None, []
rows = db["logs"].rows_where(
"id = ? or chat_id = ?", [chat_id, chat_id], order_by="id"
)
return chat_id, rows
def render_errors(errors):
output = []
for error in errors:
output.append(", ".join(error["loc"]))
output.append(" " + error["msg"])
return "\n".join(output)
load_plugins()
pm.hook.register_commands(cli=cli)
def _human_readable_size(size_bytes):
if size_bytes == 0:
return "0B"
size_name = ("B", "KB", "MB", "GB", "TB", "PB", "EB", "ZB", "YB")
i = 0
while size_bytes >= 1024 and i < len(size_name) - 1:
size_bytes /= 1024.0
i += 1
return "{:.2f}{}".format(size_bytes, size_name[i])
def logs_on():
return not (user_dir() / "logs-off").exists()
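# Illustrative sketch added for documentation (not part of the original file):
# quick hand-checked examples of the small helpers defined above.
if __name__ == "__main__":
    assert _human_readable_size(0) == "0B"
    assert _human_readable_size(1536) == "1.50KB"
    # _truncate_string trims to max_length, reserving three characters for "..."
    assert _truncate_string("x" * 200, max_length=10) == "x" * 7 + "..."
    print("helper examples OK; logging is", "on" if logs_on() else "off")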
</document_content>
</document>
<document index="5">
<source>/Users/simon/Dropbox/Development/llm/llm/embeddings.py</source>
<document_content>
from .models import EmbeddingModel
from .embeddings_migrations import embeddings_migrations
from dataclasses import dataclass
import hashlib
from itertools import islice
import json
from sqlite_utils import Database
from sqlite_utils.db import Table
import time
from typing import cast, Any, Dict, Iterable, List, Optional, Tuple, Union
@dataclass
class Entry:
id: str
score: Optional[float]
content: Optional[str] = None
metadata: Optional[Dict[str, Any]] = None
class Collection:
class DoesNotExist(Exception):
pass
def __init__(
self,
name: str,
db: Optional[Database] = None,
*,
model: Optional[EmbeddingModel] = None,
model_id: Optional[str] = None,
create: bool = True,
) -> None:
"""
A collection of embeddings
Returns the collection with the given name, creating it if it does not exist.
If you set create=False a Collection.DoesNotExist exception will be raised if the
collection does not already exist.
Args:
db (sqlite_utils.Database): Database to store the collection in
name (str): Name of the collection
model (llm.models.EmbeddingModel, optional): Embedding model to use
model_id (str, optional): Alternatively, ID of the embedding model to use
create (bool, optional): Whether to create the collection if it does not exist
"""
import llm
self.db = db or Database(memory=True)
self.name = name
self._model = model
embeddings_migrations.apply(self.db)
rows = list(self.db["collections"].rows_where("name = ?", [self.name]))
if rows:
row = rows[0]
self.id = row["id"]
self.model_id = row["model"]
else:
if create:
# Collection does not exist, so model or model_id is required
if not model and not model_id:
raise ValueError(
"Either model= or model_id= must be provided when creating a new collection"
)
# Create it
if model_id:
# Resolve alias
model = llm.get_embedding_model(model_id)
self._model = model
model_id = cast(EmbeddingModel, model).model_id
self.id = (
cast(Table, self.db["collections"])
.insert(
{
"name": self.name,
"model": model_id,
}
)
.last_pk
)
else:
raise self.DoesNotExist(f"Collection '{name}' does not exist")
def model(self) -> EmbeddingModel:
"Return the embedding model used by this collection"
import llm
if self._model is None:
self._model = llm.get_embedding_model(self.model_id)
return cast(EmbeddingModel, self._model)
def count(self) -> int:
"""
Count the number of items in the collection.
Returns:
int: Number of items in the collection
"""
return next(
self.db.query(
"""
select count(*) as c from embeddings where collection_id = (
select id from collections where name = ?
)
""",
(self.name,),
)
)["c"]
def embed(
self,
id: str,
value: Union[str, bytes],
metadata: Optional[Dict[str, Any]] = None,
store: bool = False,
) -> None:
"""
Embed value and store it in the collection with a given ID.
Args:
id (str): ID for the value
value (str or bytes): value to be embedded
metadata (dict, optional): Metadata to be stored
store (bool, optional): Whether to store the value in the content or content_blob column
"""
from llm import encode
content_hash = self.content_hash(value)
if self.db["embeddings"].count_where(
"content_hash = ? and collection_id = ?", [content_hash, self.id]
):
return
embedding = self.model().embed(value)
cast(Table, self.db["embeddings"]).insert(
{
"collection_id": self.id,
"id": id,
"embedding": encode(embedding),
"content": value if (store and isinstance(value, str)) else None,
"content_blob": value if (store and isinstance(value, bytes)) else None,
"content_hash": content_hash,
"metadata": json.dumps(metadata) if metadata else None,
"updated": int(time.time()),
},
replace=True,
)
def embed_multi(
self,
entries: Iterable[Tuple[str, Union[str, bytes]]],
store: bool = False,
batch_size: int = 100,
) -> None:
"""
Embed multiple texts and store them in the collection with given IDs.
Args:
entries (iterable): Iterable of (id: str, text: str) tuples
store (bool, optional): Whether to store the text in the content column
batch_size (int, optional): custom maximum batch size to use
"""
self.embed_multi_with_metadata(
((id, value, None) for id, value in entries),
store=store,
batch_size=batch_size,
)
def embed_multi_with_metadata(
self,
entries: Iterable[Tuple[str, Union[str, bytes], Optional[Dict[str, Any]]]],
store: bool = False,
batch_size: int = 100,
) -> None:
"""
Embed multiple values along with metadata and store them in the collection with given IDs.
Args:
entries (iterable): Iterable of (id: str, value: str or bytes, metadata: None or dict)
store (bool, optional): Whether to store the value in the content or content_blob column
batch_size (int, optional): custom maximum batch size to use
"""
import llm
batch_size = min(batch_size, (self.model().batch_size or batch_size))
iterator = iter(entries)
collection_id = self.id
while True:
batch = list(islice(iterator, batch_size))
if not batch:
break
# Calculate hashes first
items_and_hashes = [(item, self.content_hash(item[1])) for item in batch]
# Any of those hashes already exist?
existing_ids = [
row["id"]
for row in self.db.query(
"""
select id from embeddings
where collection_id = ? and content_hash in ({})
""".format(
",".join("?" for _ in items_and_hashes)
),
[collection_id]
+ [item_and_hash[1] for item_and_hash in items_and_hashes],
)
]
filtered_batch = [item for item in batch if item[0] not in existing_ids]
embeddings = list(
self.model().embed_multi(item[1] for item in filtered_batch)
)
with self.db.conn:
cast(Table, self.db["embeddings"]).insert_all(
(
{
"collection_id": collection_id,
"id": id,
"embedding": llm.encode(embedding),
"content": (
value if (store and isinstance(value, str)) else None
),
"content_blob": (
value if (store and isinstance(value, bytes)) else None
),
"content_hash": self.content_hash(value),
"metadata": json.dumps(metadata) if metadata else None,
"updated": int(time.time()),
}
for (embedding, (id, value, metadata)) in zip(
embeddings, filtered_batch
)
),
replace=True,
)
def similar_by_vector(
self, vector: List[float], number: int = 10, skip_id: Optional[str] = None
) -> List[Entry]:
"""
Find similar items in the collection by a given vector.
Args:
vector (list): Vector to search by
number (int, optional): Number of similar items to return
Returns:
list: List of Entry objects
"""
import llm
def distance_score(other_encoded):
other_vector = llm.decode(other_encoded)
return llm.cosine_similarity(other_vector, vector)
self.db.register_function(distance_score, replace=True)
where_bits = ["collection_id = ?"]
where_args = [str(self.id)]
if skip_id:
where_bits.append("id != ?")
where_args.append(skip_id)
return [
Entry(
id=row["id"],
score=row["score"],
content=row["content"],
metadata=json.loads(row["metadata"]) if row["metadata"] else None,
)
for row in self.db.query(
"""
select id, content, metadata, distance_score(embedding) as score
from embeddings
where {where}
order by score desc limit {number}
""".format(
where=" and ".join(where_bits),
number=number,
),
where_args,
)
]
def similar_by_id(self, id: str, number: int = 10) -> List[Entry]:
"""
Find similar items in the collection by a given ID.
Args:
id (str): ID to search by
number (int, optional): Number of similar items to return
Returns:
list: List of Entry objects
"""
import llm
matches = list(
self.db["embeddings"].rows_where(
"collection_id = ? and id = ?", (self.id, id)
)
)
if not matches:
raise self.DoesNotExist("ID not found")
embedding = matches[0]["embedding"]
comparison_vector = llm.decode(embedding)
return self.similar_by_vector(comparison_vector, number, skip_id=id)
def similar(self, value: Union[str, bytes], number: int = 10) -> List[Entry]:
"""
Find similar items in the collection by a given value.
Args:
value (str or bytes): value to search by
number (int, optional): Number of similar items to return
Returns:
list: List of Entry objects
"""
comparison_vector = self.model().embed(value)
return self.similar_by_vector(comparison_vector, number)
@classmethod
def exists(cls, db: Database, name: str) -> bool:
"""
Does this collection exist in the database?
Args:
name (str): Name of the collection
"""
rows = list(db["collections"].rows_where("name = ?", [name]))
return bool(rows)
def delete(self):
"""
Delete the collection and its embeddings from the database
"""
with self.db.conn:
self.db.execute("delete from embeddings where collection_id = ?", [self.id])
self.db.execute("delete from collections where id = ?", [self.id])
@staticmethod
def content_hash(input: Union[str, bytes]) -> bytes:
"Hash content for deduplication. Override to change hashing behavior."
if isinstance(input, str):
input = input.encode("utf8")
return hashlib.md5(input).digest()
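# Illustrative sketch added for documentation (not part of the original file):
# embedding a few strings into an in-memory Collection. The ToyEmbeddingModel
# below is hypothetical and exists only for this example; real code would use
# llm.get_embedding_model("3-small") or similar.
if __name__ == "__main__":
    class ToyEmbeddingModel(EmbeddingModel):
        model_id = "toy-ascii"
        def embed_batch(self, items):
            # Two crude features per string: its length and the mean character code
            for item in items:
                text = item.decode("utf8") if isinstance(item, bytes) else item
                yield [float(len(text)), sum(map(ord, text)) / max(len(text), 1)]
    collection = Collection("demo", model=ToyEmbeddingModel())  # in-memory database
    collection.embed("greeting", "hello world", store=True)
    collection.embed("farewell", "goodbye", store=True)
    for entry in collection.similar("hello there", number=2):
        print(entry.id, round(entry.score or 0.0, 3), entry.content)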
</document_content>
</document>
<document index="6">
<source>/Users/simon/Dropbox/Development/llm/llm/embeddings_migrations.py</source>
<document_content>
from sqlite_migrate import Migrations
import hashlib
import time
embeddings_migrations = Migrations("llm.embeddings")
@embeddings_migrations()
def m001_create_tables(db):
db["collections"].create({"id": int, "name": str, "model": str}, pk="id")
db["collections"].create_index(["name"], unique=True)
db["embeddings"].create(
{
"collection_id": int,
"id": str,
"embedding": bytes,
"content": str,
"metadata": str,
},
pk=("collection_id", "id"),
)
@embeddings_migrations()
def m002_foreign_key(db):
db["embeddings"].add_foreign_key("collection_id", "collections", "id")
@embeddings_migrations()
def m003_add_updated(db):
db["embeddings"].add_column("updated", int)
# Pretty-print the schema
db["embeddings"].transform()
# Assume anything existing was last updated right now
db.query(
"update embeddings set updated = ? where updated is null", [int(time.time())]
)
@embeddings_migrations()
def m004_store_content_hash(db):
db["embeddings"].add_column("content_hash", bytes)
db["embeddings"].transform(
column_order=(
"collection_id",
"id",
"embedding",
"content",
"content_hash",
"metadata",
"updated",
)
)
# Register functions manually so we can de-register later
def md5(text):
return hashlib.md5(text.encode("utf8")).digest()
def random_md5():
return hashlib.md5(str(time.time()).encode("utf8")).digest()
db.conn.create_function("temp_md5", 1, md5)
db.conn.create_function("temp_random_md5", 0, random_md5)
with db.conn:
db.execute(
"""
update embeddings
set content_hash = temp_md5(content)
where content is not null
"""
)
db.execute(
"""
update embeddings
set content_hash = temp_random_md5()
where content is null
"""
)
db["embeddings"].create_index(["content_hash"])
# De-register functions
db.conn.create_function("temp_md5", 1, None)
db.conn.create_function("temp_random_md5", 0, None)
@embeddings_migrations()
def m005_add_content_blob(db):
db["embeddings"].add_column("content_blob", bytes)
db["embeddings"].transform(
column_order=("collection_id", "id", "embedding", "content", "content_blob")
)
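# Illustrative sketch added for documentation (not part of the original file):
# applying the migration chain above to a fresh in-memory database.
if __name__ == "__main__":
    from sqlite_utils import Database
    db = Database(memory=True)
    embeddings_migrations.apply(db)
    print(sorted(db.table_names()))  # includes "collections" and "embeddings"
    print([column.name for column in db["embeddings"].columns])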
</document_content>
</document>
<document index="7">
<source>/Users/simon/Dropbox/Development/llm/llm/errors.py</source>
<document_content>
class ModelError(Exception):
"Models can raise this error, which will be displayed to the user"
pass
class NeedsKeyException(ModelError):
"Model needs an API key which has not been provided"
pass
</document_content>
</document>
<document index="8">
<source>/Users/simon/Dropbox/Development/llm/llm/hookspecs.py</source>
<document_content>
from pluggy import HookimplMarker
from pluggy import HookspecMarker
hookspec = HookspecMarker("llm")
hookimpl = HookimplMarker("llm")
@hookspec
def register_commands(cli):
"""Register additional CLI commands, e.g. 'llm mycommand ...'"""
@hookspec
def register_models(register):
"Register additional model instances representing LLM models that can be called"
@hookspec
def register_embedding_models(register):
"Register additional model instances that can be used for embedding"
</document_content>
</document>
<document index="9">
<source>/Users/simon/Dropbox/Development/llm/llm/migrations.py</source>
<document_content>
import datetime
from typing import Callable, List
MIGRATIONS: List[Callable] = []
migration = MIGRATIONS.append
def migrate(db):
ensure_migrations_table(db)
already_applied = {r["name"] for r in db["_llm_migrations"].rows}
for fn in MIGRATIONS:
name = fn.__name__
if name not in already_applied:
fn(db)
db["_llm_migrations"].insert(
{
"name": name,
"applied_at": str(datetime.datetime.now(datetime.timezone.utc)),
}
)
already_applied.add(name)
def ensure_migrations_table(db):
if not db["_llm_migrations"].exists():
db["_llm_migrations"].create(
{
"name": str,
"applied_at": str,
},
pk="name",
)
@migration
def m001_initial(db):
# Ensure the original table design exists, so other migrations can run
if db["log"].exists():
# It needs to have the chat_id column
if "chat_id" not in db["log"].columns_dict:
db["log"].add_column("chat_id")
return
db["log"].create(
{
"provider": str,
"system": str,
"prompt": str,
"chat_id": str,
"response": str,
"model": str,
"timestamp": str,
}
)
@migration
def m002_id_primary_key(db):
db["log"].transform(pk="id")
@migration
def m003_chat_id_foreign_key(db):
db["log"].transform(types={"chat_id": int})
db["log"].add_foreign_key("chat_id", "log", "id")
@migration
def m004_column_order(db):
db["log"].transform(
column_order=(
"id",
"model",
"timestamp",
"prompt",
"system",
"response",
"chat_id",
)
)
@migration
def m004_drop_provider(db):
db["log"].transform(drop=("provider",))
@migration
def m005_debug(db):
db["log"].add_column("debug", str)
db["log"].add_column("duration_ms", int)
@migration
def m006_new_logs_table(db):
columns = db["log"].columns_dict
for column, type in (
("options_json", str),
("prompt_json", str),
("response_json", str),
("reply_to_id", int),
):
# It's possible people running development code like myself
# might have accidentally created these columns already
if column not in columns:
db["log"].add_column(column, type)
# Use .transform() to rename options and timestamp_utc, and set new order
db["log"].transform(
column_order=(
"id",
"model",
"prompt",
"system",
"prompt_json",
"options_json",
"response",
"response_json",
"reply_to_id",
"chat_id",
"duration_ms",
"timestamp_utc",
),
rename={
"timestamp": "timestamp_utc",
"options": "options_json",
},
)
@migration
def m007_finish_logs_table(db):
db["log"].transform(
drop={"debug"},
rename={"timestamp_utc": "datetime_utc"},
drop_foreign_keys=("chat_id",),
)
with db.conn:
db.execute("alter table log rename to logs")
@migration
def m008_reply_to_id_foreign_key(db):
db["logs"].add_foreign_key("reply_to_id", "logs", "id")
@migration
def m008_fix_column_order_in_logs(db):
# reply_to_id ended up at the end after foreign key added
db["logs"].transform(
column_order=(
"id",
"model",
"prompt",
"system",
"prompt_json",
"options_json",
"response",
"response_json",
"reply_to_id",
"chat_id",
"duration_ms",
"timestamp_utc",
),
)
@migration
def m009_delete_logs_table_if_empty(db):
# We moved to a new table design, but we don't delete the table
# if someone has put data in it
if not db["logs"].count:
db["logs"].drop()
@migration
def m010_create_new_log_tables(db):
db["conversations"].create(
{
"id": str,
"name": str,
"model": str,
},
pk="id",
)
db["responses"].create(
{
"id": str,
"model": str,
"prompt": str,
"system": str,
"prompt_json": str,
"options_json": str,
"response": str,
"response_json": str,
"conversation_id": str,
"duration_ms": int,
"datetime_utc": str,
},
pk="id",
foreign_keys=(("conversation_id", "conversations", "id"),),
)
@migration
def m011_fts_for_responses(db):
db["responses"].enable_fts(["prompt", "response"], create_triggers=True)
@migration
def m012_attachments_tables(db):
db["attachments"].create(
{
"id": str,
"type": str,
"path": str,
"url": str,
"content": bytes,
},
pk="id",
)
db["prompt_attachments"].create(
{
"response_id": str,
"attachment_id": str,
"order": int,
},
foreign_keys=(
("response_id", "responses", "id"),
("attachment_id", "attachments", "id"),
),
pk=("response_id", "attachment_id"),
)
@migration
def m013_usage(db):
db["responses"].add_column("input_tokens", int)
db["responses"].add_column("output_tokens", int)
db["responses"].add_column("token_details", str)
</document_content>
</document>
<document index="10">
<source>/Users/simon/Dropbox/Development/llm/llm/models.py</source>
<document_content>
import asyncio
import base64
from dataclasses import dataclass, field
import datetime
from .errors import NeedsKeyException
import hashlib
import httpx
from itertools import islice
import re
import time
from typing import (
Any,
AsyncGenerator,
Callable,
Dict,
Iterable,
Iterator,
List,
Optional,
Set,
Union,
)
from .utils import mimetype_from_path, mimetype_from_string, token_usage_string
from abc import ABC, abstractmethod
import json
from pydantic import BaseModel
from ulid import ULID
CONVERSATION_NAME_LENGTH = 32
@dataclass
class Usage:
input: Optional[int] = None
output: Optional[int] = None
details: Optional[Dict[str, Any]] = None
@dataclass
class Attachment:
type: Optional[str] = None
path: Optional[str] = None
url: Optional[str] = None
content: Optional[bytes] = None
_id: Optional[str] = None
def id(self):
# Hash of the binary content, or of '{"url": "https://..."}' for URL attachments
if self._id is None:
if self.content:
self._id = hashlib.sha256(self.content).hexdigest()
elif self.path:
self._id = hashlib.sha256(open(self.path, "rb").read()).hexdigest()
else:
self._id = hashlib.sha256(
json.dumps({"url": self.url}).encode("utf-8")
).hexdigest()
return self._id
def resolve_type(self):
if self.type:
return self.type
# Derive it from path or url or content
if self.path:
return mimetype_from_path(self.path)
if self.url:
response = httpx.head(self.url)
response.raise_for_status()
return response.headers.get("content-type")
if self.content:
return mimetype_from_string(self.content)
raise ValueError("Attachment has no type and no content to derive it from")
def content_bytes(self):
content = self.content
if not content:
if self.path:
content = open(self.path, "rb").read()
elif self.url:
response = httpx.get(self.url)
response.raise_for_status()
content = response.content
return content
def base64_content(self):
return base64.b64encode(self.content_bytes()).decode("utf-8")
@classmethod
def from_row(cls, row):
return cls(
_id=row["id"],
type=row["type"],
path=row["path"],
url=row["url"],
content=row["content"],
)
@dataclass
class Prompt:
prompt: str
model: "Model"
attachments: Optional[List[Attachment]]
system: Optional[str]
prompt_json: Optional[str]
options: "Options"
def __init__(
self,
prompt,
model,
*,
attachments=None,
system=None,
prompt_json=None,
options=None,
):
self.prompt = prompt
self.model = model
self.attachments = list(attachments or [])
self.system = system
self.prompt_json = prompt_json
self.options = options or {}
@dataclass
class _BaseConversation:
model: "_BaseModel"
id: str = field(default_factory=lambda: str(ULID()).lower())
name: Optional[str] = None
responses: List["_BaseResponse"] = field(default_factory=list)
@classmethod
def from_row(cls, row):
from llm import get_model
return cls(
model=get_model(row["model"]),
id=row["id"],
name=row["name"],
)
@dataclass
class Conversation(_BaseConversation):
def prompt(
self,
prompt: Optional[str],
*,
attachments: Optional[List[Attachment]] = None,
system: Optional[str] = None,
stream: bool = True,
**options,
) -> "Response":
return Response(
Prompt(
prompt,
model=self.model,
attachments=attachments,
system=system,
options=self.model.Options(**options),
),
self.model,
stream,
conversation=self,
)
@dataclass
class AsyncConversation(_BaseConversation):
def prompt(
self,
prompt: Optional[str],
*,
attachments: Optional[List[Attachment]] = None,
system: Optional[str] = None,
stream: bool = True,
**options,
) -> "AsyncResponse":
return AsyncResponse(
Prompt(
prompt,
model=self.model,
attachments=attachments,
system=system,
options=self.model.Options(**options),
),
self.model,
stream,
conversation=self,
)
class _BaseResponse:
"""Base response class shared between sync and async responses"""
prompt: "Prompt"
stream: bool
conversation: Optional["_BaseConversation"] = None
def __init__(
self,
prompt: Prompt,
model: "_BaseModel",
stream: bool,
conversation: Optional[_BaseConversation] = None,
):
self.prompt = prompt
self._prompt_json = None
self.model = model
self.stream = stream
self._chunks: List[str] = []
self._done = False
self.response_json = None
self.conversation = conversation
self.attachments: List[Attachment] = []
self._start: Optional[float] = None
self._end: Optional[float] = None
self._start_utcnow: Optional[datetime.datetime] = None
self.input_tokens: Optional[int] = None
self.output_tokens: Optional[int] = None
self.token_details: Optional[dict] = None
self.done_callbacks: List[Callable] = []
def set_usage(
self,
*,
input: Optional[int] = None,
output: Optional[int] = None,
details: Optional[dict] = None,
):
self.input_tokens = input
self.output_tokens = output
self.token_details = details
@classmethod
def from_row(cls, db, row):
from llm import get_model
model = get_model(row["model"])
response = cls(
model=model,
prompt=Prompt(
prompt=row["prompt"],
model=model,
attachments=[],
system=row["system"],
options=model.Options(**json.loads(row["options_json"])),
),
stream=False,
)
response.id = row["id"]
response._prompt_json = json.loads(row["prompt_json"] or "null")
response.response_json = json.loads(row["response_json"] or "null")
response._done = True
response._chunks = [row["response"]]
# Attachments
response.attachments = [
Attachment.from_row(arow)
for arow in db.query(
"""
select attachments.* from attachments
join prompt_attachments on attachments.id = prompt_attachments.attachment_id
where prompt_attachments.response_id = ?
order by prompt_attachments."order"
""",
[row["id"]],
)
]
return response
def token_usage(self) -> str:
return token_usage_string(
self.input_tokens, self.output_tokens, self.token_details
)
def log_to_db(self, db):
conversation = self.conversation
if not conversation:
conversation = Conversation(model=self.model)
db["conversations"].insert(
{
"id": conversation.id,
"name": _conversation_name(
self.prompt.prompt or self.prompt.system or ""
),
"model": conversation.model.model_id,
},
ignore=True,
)
response_id = str(ULID()).lower()
response = {
"id": response_id,
"model": self.model.model_id,
"prompt": self.prompt.prompt,
"system": self.prompt.system,
"prompt_json": self._prompt_json,
"options_json": {
key: value
for key, value in dict(self.prompt.options).items()
if value is not None
},
"response": self.text_or_raise(),
"response_json": self.json(),
"conversation_id": conversation.id,
"duration_ms": self.duration_ms(),
"datetime_utc": self.datetime_utc(),
"input_tokens": self.input_tokens,
"output_tokens": self.output_tokens,
"token_details": (
json.dumps(self.token_details) if self.token_details else None
),
}
db["responses"].insert(response)
# Persist any attachments - loop through with index
for index, attachment in enumerate(self.prompt.attachments):
attachment_id = attachment.id()
db["attachments"].insert(
{
"id": attachment_id,
"type": attachment.resolve_type(),
"path": attachment.path,
"url": attachment.url,
"content": attachment.content,
},
replace=True,
)
db["prompt_attachments"].insert(
{
"response_id": response_id,
"attachment_id": attachment_id,
"order": index,
},
)
class Response(_BaseResponse):
model: "Model"
conversation: Optional["Conversation"] = None
def on_done(self, callback):
if not self._done:
self.done_callbacks.append(callback)
else:
callback(self)
def _on_done(self):
for callback in self.done_callbacks:
callback(self)
def __str__(self) -> str:
return self.text()
def _force(self):
if not self._done:
list(self)
def text(self) -> str:
self._force()
return "".join(self._chunks)
def text_or_raise(self) -> str:
return self.text()
def json(self) -> Optional[Dict[str, Any]]:
self._force()
return self.response_json
def duration_ms(self) -> int:
self._force()
return int(((self._end or 0) - (self._start or 0)) * 1000)
def datetime_utc(self) -> str:
self._force()
return self._start_utcnow.isoformat() if self._start_utcnow else ""
def usage(self) -> Usage:
self._force()
return Usage(
input=self.input_tokens,
output=self.output_tokens,
details=self.token_details,
)
def __iter__(self) -> Iterator[str]:
self._start = time.monotonic()
self._start_utcnow = datetime.datetime.now(datetime.timezone.utc)
if self._done:
yield from self._chunks
return
for chunk in self.model.execute(
self.prompt,
stream=self.stream,
response=self,
conversation=self.conversation,
):
yield chunk
self._chunks.append(chunk)
if self.conversation:
self.conversation.responses.append(self)
self._end = time.monotonic()
self._done = True
self._on_done()
def __repr__(self):
text = "... not yet done ..."
if self._done:
text = "".join(self._chunks)
return "<Response prompt='{}' text='{}'>".format(self.prompt.prompt, text)
class AsyncResponse(_BaseResponse):
model: "AsyncModel"
conversation: Optional["AsyncConversation"] = None
async def on_done(self, callback):
if not self._done:
self.done_callbacks.append(callback)
else:
if callable(callback):
callback = callback(self)
if asyncio.iscoroutine(callback):
await callback
async def _on_done(self):
for callback in self.done_callbacks:
if callable(callback):
callback = callback(self)
if asyncio.iscoroutine(callback):
await callback
def __aiter__(self):
self._start = time.monotonic()
self._start_utcnow = datetime.datetime.now(datetime.timezone.utc)
return self
async def __anext__(self) -> str:
if self._done:
if not self._chunks:
raise StopAsyncIteration
chunk = self._chunks.pop(0)
if not self._chunks:
raise StopAsyncIteration
return chunk
if not hasattr(self, "_generator"):
self._generator = self.model.execute(
self.prompt,
stream=self.stream,
response=self,
conversation=self.conversation,
)
try:
chunk = await self._generator.__anext__()
self._chunks.append(chunk)
return chunk
except StopAsyncIteration:
if self.conversation:
self.conversation.responses.append(self)
self._end = time.monotonic()
self._done = True
await self._on_done()
raise
async def _force(self):
if not self._done:
async for _ in self:
pass
return self
def text_or_raise(self) -> str:
if not self._done:
raise ValueError("Response not yet awaited")
return "".join(self._chunks)
async def text(self) -> str:
await self._force()
return "".join(self._chunks)
async def json(self) -> Optional[Dict[str, Any]]:
await self._force()
return self.response_json
async def duration_ms(self) -> int:
await self._force()
return int(((self._end or 0) - (self._start or 0)) * 1000)
async def datetime_utc(self) -> str:
await self._force()
return self._start_utcnow.isoformat() if self._start_utcnow else ""
async def usage(self) -> Usage:
await self._force()
return Usage(
input=self.input_tokens,
output=self.output_tokens,
details=self.token_details,
)
def __await__(self):
return self._force().__await__()
async def to_sync_response(self) -> Response:
await self._force()
response = Response(
self.prompt,
self.model,
self.stream,
conversation=self.conversation,
)
response._chunks = self._chunks
response._done = True
response._end = self._end
response._start = self._start
response._start_utcnow = self._start_utcnow
response.input_tokens = self.input_tokens
response.output_tokens = self.output_tokens
response.token_details = self.token_details
return response
@classmethod
def fake(
cls,
model: "AsyncModel",
prompt: str,
*attachments: List[Attachment],
system: str,
response: str,
):
"Utility method to help with writing tests"
response_obj = cls(
model=model,
prompt=Prompt(
prompt,
model=model,
attachments=attachments,
system=system,
),
stream=False,
)
response_obj._done = True
response_obj._chunks = [response]
return response_obj
def __repr__(self):
text = "... not yet awaited ..."
if self._done:
text = "".join(self._chunks)
return "<AsyncResponse prompt='{}' text='{}'>".format(self.prompt.prompt, text)
class Options(BaseModel):
# Note: using pydantic v1 style Configs,
# these are also compatible with pydantic v2
class Config:
extra = "forbid"
_Options = Options
class _get_key_mixin:
def get_key(self):
from llm import get_key
if self.needs_key is None:
# This model doesn't use an API key
return None
if self.key is not None:
# Someone already set model.key='...'
return self.key
# Attempt to load a key using llm.get_key()
key = get_key(
explicit_key=None, key_alias=self.needs_key, env_var=self.key_env_var
)
if key:
return key
# Show a useful error message
message = "No key found - add one using 'llm keys set {}'".format(
self.needs_key
)
if self.key_env_var:
message += " or set the {} environment variable".format(self.key_env_var)
raise NeedsKeyException(message)
class _BaseModel(ABC, _get_key_mixin):
model_id: str
key: Optional[str] = None
needs_key: Optional[str] = None
key_env_var: Optional[str] = None
can_stream: bool = False
attachment_types: Set = set()
class Options(_Options):
pass
def _validate_attachments(
self, attachments: Optional[List[Attachment]] = None
) -> None:
if attachments and not self.attachment_types:
raise ValueError("This model does not support attachments")
for attachment in attachments or []:
attachment_type = attachment.resolve_type()
if attachment_type not in self.attachment_types:
raise ValueError(
f"This model does not support attachments of type '{attachment_type}', "
f"only {', '.join(self.attachment_types)}"
)
def __str__(self) -> str:
return "{}: {}".format(self.__class__.__name__, self.model_id)
def __repr__(self):
return "<{} '{}'>".format(self.__class__.__name__, self.model_id)
class Model(_BaseModel):
def conversation(self) -> Conversation:
return Conversation(model=self)
@abstractmethod
def execute(
self,
prompt: Prompt,
stream: bool,
response: Response,
conversation: Optional[Conversation],
) -> Iterator[str]:
pass
def prompt(
self,
prompt: str,
*,
attachments: Optional[List[Attachment]] = None,
system: Optional[str] = None,
stream: bool = True,
**options,
) -> Response:
self._validate_attachments(attachments)
return Response(
Prompt(
prompt,
attachments=attachments,
system=system,
model=self,
options=self.Options(**options),
),
self,
stream,
)
class AsyncModel(_BaseModel):
def conversation(self) -> AsyncConversation:
return AsyncConversation(model=self)
@abstractmethod
async def execute(
self,
prompt: Prompt,
stream: bool,
response: AsyncResponse,
conversation: Optional[AsyncConversation],
) -> AsyncGenerator[str, None]:
yield ""
def prompt(
self,
prompt: str,
*,
attachments: Optional[List[Attachment]] = None,
system: Optional[str] = None,
stream: bool = True,
**options,
) -> AsyncResponse:
self._validate_attachments(attachments)
return AsyncResponse(
Prompt(
prompt,
attachments=attachments,
system=system,
model=self,
options=self.Options(**options),
),
self,
stream,
)
class EmbeddingModel(ABC, _get_key_mixin):
model_id: str
key: Optional[str] = None
needs_key: Optional[str] = None
key_env_var: Optional[str] = None
supports_text: bool = True
supports_binary: bool = False
batch_size: Optional[int] = None
def _check(self, item: Union[str, bytes]):
if not self.supports_binary and isinstance(item, bytes):
raise ValueError(
"This model does not support binary data, only text strings"
)
if not self.supports_text and isinstance(item, str):
raise ValueError(
"This model does not support text strings, only binary data"
)
def embed(self, item: Union[str, bytes]) -> List[float]:
"Embed a single text string or binary blob, return a list of floats"
self._check(item)
return next(iter(self.embed_batch([item])))
def embed_multi(
self, items: Iterable[Union[str, bytes]], batch_size: Optional[int] = None
) -> Iterator[List[float]]:
"Embed multiple items in batches according to the model batch_size"
iter_items = iter(items)
batch_size = self.batch_size if batch_size is None else batch_size
if (not self.supports_binary) or (not self.supports_text):
def checking_iter(items):
for item in items:
self._check(item)
yield item
iter_items = checking_iter(items)
if batch_size is None:
yield from self.embed_batch(iter_items)
return
while True:
batch_items = list(islice(iter_items, batch_size))
if not batch_items:
break
yield from self.embed_batch(batch_items)
@abstractmethod
def embed_batch(self, items: Iterable[Union[str, bytes]]) -> Iterator[List[float]]:
"""
Embed a batch of strings or blobs, return a list of lists of floats
"""
pass
@dataclass
class ModelWithAliases:
model: Model
async_model: AsyncModel
aliases: Set[str]
def matches(self, query: str) -> bool:
query = query.lower()
all_strings: List[str] = []
all_strings.extend(self.aliases)
if self.model:
all_strings.append(str(self.model))
if self.async_model:
all_strings.append(str(self.async_model.model_id))
return any(query in alias.lower() for alias in all_strings)
@dataclass
class EmbeddingModelWithAliases:
model: EmbeddingModel
aliases: Set[str]
def _conversation_name(text):
# Collapse whitespace, including newlines
text = re.sub(r"\s+", " ", text)
if len(text) <= CONVERSATION_NAME_LENGTH:
return text
return text[: CONVERSATION_NAME_LENGTH - 1] + "…"
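# Illustrative sketch added for documentation (not part of the original file):
# the smallest useful Model subclass and how its Response is consumed. The
# "upper-echo" model is hypothetical and simply shouts the prompt back.
if __name__ == "__main__":
    class UpperEchoModel(Model):
        model_id = "upper-echo"
        can_stream = True
        def execute(self, prompt, stream, response, conversation):
            # A real implementation would call an API here
            yield prompt.prompt.upper()
    model = UpperEchoModel()
    response = model.prompt("hello world")
    print(response.text())   # -> HELLO WORLD
    print(response.usage())  # token counts stay None for this toy model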
</document_content>
</document>
<document index="11">
<source>/Users/simon/Dropbox/Development/llm/llm/plugins.py</source>
<document_content>
import importlib
from importlib import metadata
import os
import pluggy
import sys
from . import hookspecs
DEFAULT_PLUGINS = ("llm.default_plugins.openai_models",)
pm = pluggy.PluginManager("llm")
pm.add_hookspecs(hookspecs)
LLM_LOAD_PLUGINS = os.environ.get("LLM_LOAD_PLUGINS", None)
_loaded = False
def load_plugins():
global _loaded
if _loaded:
return
_loaded = True
if not hasattr(sys, "_called_from_test") and LLM_LOAD_PLUGINS is None:
# Only load plugins if not running tests
pm.load_setuptools_entrypoints("llm")
# Load any plugins specified in LLM_LOAD_PLUGINS
if LLM_LOAD_PLUGINS is not None:
for package_name in [
name for name in LLM_LOAD_PLUGINS.split(",") if name.strip()
]:
try:
distribution = metadata.distribution(package_name) # Updated call
llm_entry_points = [
ep for ep in distribution.entry_points if ep.group == "llm"
]
for entry_point in llm_entry_points:
mod = entry_point.load()
pm.register(mod, name=entry_point.name)
# Ensure name can be found in plugin_to_distinfo later:
pm._plugin_distinfo.append((mod, distribution)) # type: ignore
except metadata.PackageNotFoundError:
sys.stderr.write(f"Plugin {package_name} could not be found\n")
for plugin in DEFAULT_PLUGINS:
mod = importlib.import_module(plugin)
pm.register(mod, plugin)
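# Illustrative sketch added for documentation (not part of the original file):
# loading plugins and listing what ended up on the pluggy manager. Setting
# LLM_LOAD_PLUGINS="" before running skips setuptools entry points, leaving
# only DEFAULT_PLUGINS registered.
if __name__ == "__main__":
    load_plugins()
    for name, _plugin in pm.list_name_plugin():
        print(name)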
</document_content>
</document>
<document index="12">
<source>/Users/simon/Dropbox/Development/llm/llm/serve.py</source>
<document_content>
import json
import llm
async def error(send, status_code: int, message: str):
await send(
{
"type": "http.response.start",
"status": status_code,
"headers": [(b"content-type", b"application/json")],
}
)
await send(
{
"type": "http.response.body",
"body": json.dumps({"error": message}).encode("utf-8"),
"more_body": False,
}
)
return
async def read_request_body(receive):
"""
Reads and concatenates all HTTP request body chunks into a single bytes object.
"""
body = b""
more_body = True
while more_body:
message = await receive()
if message["type"] == "http.request":
body += message.get("body", b"")
more_body = message.get("more_body", False)
return body
async def handle_completions_request(scope, receive, send):
"""
Handle POST /v1/completions with possible streaming (SSE) or non-streamed JSON output.
"""
# Read and parse JSON payload
raw_body = await read_request_body(receive)
try:
data = json.loads(raw_body.decode("utf-8"))
print(data)
except json.JSONDecodeError:
await error(send, 400, "Invalid JSON")
return
prompt = data.get("prompt", "")
is_stream = data.get("stream", False)
try:
model = llm.get_async_model(data.get("model"))
except llm.UnknownModelError:
await error(send, 400, "Unknown model")
return
if is_stream:
# Streamed SSE response
await send(
{
"type": "http.response.start",
"status": 200,
"headers": [
(b"content-type", b"text/event-stream"),
(b"cache-control", b"no-cache"),
(b"connection", b"keep-alive"),
],
}
)
# Each chunk from the model is sent as an SSE "data: ..." line
async for chunk in model.prompt(prompt):
# For OpenAI-compatible SSE, each chunk is wrapped in a JSON envelope;
# the "choices" list carries this chunk's partial text in its "text" field
chunk_data = {
"id": "cmpl-xxx",
"object": "text_completion",
"created": 1234567890,
"model": "gpt-4",
"choices": [
{"text": chunk, "index": 0, "logprobs": None, "finish_reason": None}
],
}
sse_line = f"data: {json.dumps(chunk_data)}\n\n"
await send(
{
"type": "http.response.body",
"body": sse_line.encode("utf-8"),
"more_body": True,
}
)
# Signal that the stream is complete
await send(
{
"type": "http.response.body",
"body": b"data: [DONE]\n\n",
"more_body": False,
}
)
else:
# Non-streamed JSON response: collect all chunks first
full_output = []
async for chunk in model.prompt(prompt):
full_output.append(chunk)
concatenated = "".join(full_output)
# Build an OpenAI-like JSON response
response_body = {
"id": "cmpl-xxx",
"object": "text_completion",
"created": 1234567890,
"model": "gpt-4",
"choices": [
{
"text": concatenated,
"index": 0,
"logprobs": None,
"finish_reason": "stop",
}
],
# "usage" field omitted for brevity
}
# Send JSON response
await send(
{
"type": "http.response.start",
"status": 200,
"headers": [(b"content-type", b"application/json")],
}
)
await send(
{
"type": "http.response.body",
"body": json.dumps(response_body).encode("utf-8"),
"more_body": False,
}
)
async def app(scope, receive, send):
"""
A simple ASGI application that routes /v1/completions to our OpenAI-compatible handler.
"""
if scope["type"] == "http":
path = scope.get("path", "")
method = scope.get("method", "").upper()
# Route to /v1/completions
if path == "/v1/completions":
if method == "POST":
await handle_completions_request(scope, receive, send)
else:
await error(send, 405, "Method not allowed")
return
else:
# Handle unrecognized paths with a simple 404
await error(send, 404, "Not found")
return
else:
pass
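# Illustrative sketch added for documentation (not part of the original file):
# serving this ASGI app locally. uvicorn is an assumption, it is not one of
# this project's dependencies; any ASGI server would do. A POST to
# /v1/completions with {"model": ..., "prompt": ...} then reaches
# handle_completions_request above.
if __name__ == "__main__":
    import uvicorn  # assumed to be installed separately
    uvicorn.run(app, host="127.0.0.1", port=8000)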
</document_content>
</document>
<document index="13">
<source>/Users/simon/Dropbox/Development/llm/llm/templates.py</source>
<document_content>
from pydantic import BaseModel
import string
from typing import Optional, Any, Dict, List, Tuple
class Template(BaseModel):
name: str
prompt: Optional[str] = None
system: Optional[str] = None
model: Optional[str] = None
defaults: Optional[Dict[str, Any]] = None
# Should a fenced code block be extracted?
extract: Optional[bool] = None
extract_last: Optional[bool] = None
class Config:
extra = "forbid"
class MissingVariables(Exception):
pass
def evaluate(
self, input: str, params: Optional[Dict[str, Any]] = None
) -> Tuple[Optional[str], Optional[str]]:
params = params or {}
params["input"] = input
if self.defaults:
for k, v in self.defaults.items():
if k not in params:
params[k] = v
prompt: Optional[str] = None
system: Optional[str] = None
if not self.prompt:
system = self.interpolate(self.system, params)
prompt = input
else:
prompt = self.interpolate(self.prompt, params)
system = self.interpolate(self.system, params)
return prompt, system
@classmethod
def interpolate(cls, text: Optional[str], params: Dict[str, Any]) -> Optional[str]:
if not text:
return text
# Confirm all variables in text are provided
string_template = string.Template(text)
vars = cls.extract_vars(string_template)
missing = [p for p in vars if p not in params]
if missing:
raise cls.MissingVariables(
"Missing variables: {}".format(", ".join(missing))
)
return string_template.substitute(**params)
@staticmethod
def extract_vars(string_template: string.Template) -> List[str]:
return [
match.group("named")
for match in string_template.pattern.finditer(string_template.template)
]
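# Illustrative sketch added for documentation (not part of the original file):
# how evaluate() interpolates $input plus defaults. The field values below are
# made up for demonstration.
if __name__ == "__main__":
    template = Template(
        name="summarize",
        system="Summarize in $style style",
        prompt="Summarize this: $input",
        defaults={"style": "terse"},
    )
    prompt, system = template.evaluate("A long article ...")
    print(system)  # -> Summarize in terse style
    print(prompt)  # -> Summarize this: A long article ...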
</document_content>
</document>
<document index="14">
<source>/Users/simon/Dropbox/Development/llm/llm/utils.py</source>
<document_content>
import click
import httpx
import json
import puremagic
import re
import textwrap
from typing import List, Dict, Optional
MIME_TYPE_FIXES = {
"audio/wave": "audio/wav",
}
def mimetype_from_string(content) -> Optional[str]:
try:
type_ = puremagic.from_string(content, mime=True)
return MIME_TYPE_FIXES.get(type_, type_)
except puremagic.PureError:
return None
def mimetype_from_path(path) -> Optional[str]:
try:
type_ = puremagic.from_file(path, mime=True)
return MIME_TYPE_FIXES.get(type_, type_)
except puremagic.PureError:
return None
def dicts_to_table_string(
headings: List[str], dicts: List[Dict[str, str]]
) -> List[str]:
max_lengths = [len(h) for h in headings]
# Compute maximum length for each column
for d in dicts:
for i, h in enumerate(headings):
if h in d and len(str(d[h])) > max_lengths[i]:
max_lengths[i] = len(str(d[h]))
# Generate formatted table strings
res = []
res.append(" ".join(h.ljust(max_lengths[i]) for i, h in enumerate(headings)))
for d in dicts:
row = []
for i, h in enumerate(headings):
row.append(str(d.get(h, "")).ljust(max_lengths[i]))
res.append(" ".join(row))
return res
def remove_dict_none_values(d):
"""
Recursively remove keys with value of None or value of a dict that is all values of None
"""
if not isinstance(d, dict):
return d
new_dict = {}
for key, value in d.items():
if value is not None:
if isinstance(value, dict):
nested = remove_dict_none_values(value)
if nested:
new_dict[key] = nested
elif isinstance(value, list):
new_dict[key] = [remove_dict_none_values(v) for v in value]
else:
new_dict[key] = value
return new_dict
class _LogResponse(httpx.Response):
def iter_bytes(self, *args, **kwargs):
for chunk in super().iter_bytes(*args, **kwargs):
click.echo(chunk.decode(), err=True)
yield chunk
class _LogTransport(httpx.BaseTransport):
def __init__(self, transport: httpx.BaseTransport):
self.transport = transport
def handle_request(self, request: httpx.Request) -> httpx.Response:
response = self.transport.handle_request(request)
return _LogResponse(
status_code=response.status_code,
headers=response.headers,
stream=response.stream,
extensions=response.extensions,
)
def _no_accept_encoding(request: httpx.Request):
request.headers.pop("accept-encoding", None)
def _log_response(response: httpx.Response):
request = response.request
click.echo(f"Request: {request.method} {request.url}", err=True)
click.echo(" Headers:", err=True)
for key, value in request.headers.items():
if key.lower() == "authorization":
value = "[...]"
if key.lower() == "cookie":
value = value.split("=")[0] + "=..."
click.echo(f" {key}: {value}", err=True)
click.echo(" Body:", err=True)
try:
request_body = json.loads(request.content)
click.echo(
textwrap.indent(json.dumps(request_body, indent=2), " "), err=True
)
except json.JSONDecodeError:
click.echo(textwrap.indent(request.content.decode(), " "), err=True)
click.echo(f"Response: status_code={response.status_code}", err=True)
click.echo(" Headers:", err=True)
for key, value in response.headers.items():
if key.lower() == "set-cookie":
value = value.split("=")[0] + "=..."
click.echo(f" {key}: {value}", err=True)
click.echo(" Body:", err=True)
def logging_client() -> httpx.Client:
return httpx.Client(
transport=_LogTransport(httpx.HTTPTransport()),
event_hooks={"request": [_no_accept_encoding], "response": [_log_response]},
)
def simplify_usage_dict(d):
# Recursively remove keys with value 0 and empty dictionaries
def remove_empty_and_zero(obj):
if isinstance(obj, dict):
cleaned = {
k: remove_empty_and_zero(v)
for k, v in obj.items()
if v != 0 and v != {}
}
return {k: v for k, v in cleaned.items() if v is not None and v != {}}
return obj
return remove_empty_and_zero(d) or {}
def token_usage_string(input_tokens, output_tokens, token_details) -> str:
bits = []
if input_tokens is not None:
bits.append(f"{format(input_tokens, ',')} input")
if output_tokens is not None:
bits.append(f"{format(output_tokens, ',')} output")
if token_details:
bits.append(json.dumps(token_details))
return ", ".join(bits)
def extract_fenced_code_block(text: str, last: bool = False) -> Optional[str]:
"""
Extracts and returns Markdown fenced code block found in the given text.
The function handles fenced code blocks that:
- Use at least three backticks (`).
- May include a language tag immediately after the opening backticks.
- Use more than three backticks as long as the closing fence has the same number.
If no fenced code block is found, the function returns None.
Args:
text (str): The input text to search for a fenced code block.
last (bool): Extract the last code block if True, otherwise the first.
Returns:
Optional[str]: The content of the fenced code block, or None if not found.
"""
# Regex pattern to match fenced code blocks
# - (?m)^ ensures that the fence is at the start of a line
# - (`{3,}) captures the opening backticks (at least three)
# - (\w+)? optionally captures the language tag
# - \n matches the newline after the opening fence
# - (.*?) non-greedy match for the code block content
# - (?P=fence) ensures that the closing fence has the same number of backticks
# - [ ]* allows for optional spaces between the closing fence and newline
# - (?=\n|$) ensures that the closing fence is followed by a newline or end of string
pattern = re.compile(
r"""(?m)^(?P<fence>`{3,})(?P<lang>\w+)?\n(?P<code>.*?)^(?P=fence)[ ]*(?=\n|$)""",
re.DOTALL,
)
matches = list(pattern.finditer(text))
if matches:
match = matches[-1] if last else matches[0]
return match.group("code")
return None
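# Illustrative sketch added for documentation (not part of the original file):
# a couple of hand-checked calls into the helpers above. The sample text is
# made up.
if __name__ == "__main__":
    sample = "Here you go:\n```python\nprint('hi')\n```\nThat is all."
    print(repr(extract_fenced_code_block(sample)))  # -> "print('hi')\n"
    print(token_usage_string(1000, 250, None))      # -> 1,000 input, 250 output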
</document_content>
</document>
<document index="15">
<source>/Users/simon/Dropbox/Development/llm/llm/default_plugins/__init__.py</source>
<document_content>
</document_content>
</document>
<document index="16">
<source>/Users/simon/Dropbox/Development/llm/llm/default_plugins/chunkers.py</source>
<document_content>
from llm import hookimpl
def lines(text):
"Chunk text into lines"
for line in text.split("\n"):
if line.strip():
yield line
@hookimpl
def register_chunker_functions(register):
register(lines, name="lines")
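# Illustrative sketch added for documentation (not part of the original file):
# the lines() chunker skips blank lines and yields one chunk per remaining line.
if __name__ == "__main__":
    print(list(lines("first\n\nsecond\nthird")))  # -> ['first', 'second', 'third']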
</document_content>
</document>
<document index="17">
<source>/Users/simon/Dropbox/Development/llm/llm/default_plugins/openai_models.py</source>
<document_content>
from llm import AsyncModel, EmbeddingModel, Model, hookimpl
import llm
from llm.utils import (
dicts_to_table_string,
remove_dict_none_values,
logging_client,
simplify_usage_dict,
)
import click
import datetime
import httpx
import openai
import os
try:
# Pydantic 2
from pydantic import field_validator, Field # type: ignore
except ImportError:
# Pydantic 1
from pydantic.fields import Field
from pydantic.class_validators import validator as field_validator # type: ignore [no-redef]
from typing import AsyncGenerator, List, Iterable, Iterator, Optional, Union
import json
import yaml
@hookimpl
def register_models(register):
# GPT-4o
register(
Chat("gpt-4o", vision=True), AsyncChat("gpt-4o", vision=True), aliases=("4o",)
)
register(
Chat("gpt-4o-mini", vision=True),
AsyncChat("gpt-4o-mini", vision=True),
aliases=("4o-mini",),
)
for audio_model_id in (
"gpt-4o-audio-preview",
"gpt-4o-audio-preview-2024-12-17",
"gpt-4o-audio-preview-2024-10-01",
"gpt-4o-mini-audio-preview",
"gpt-4o-mini-audio-preview-2024-12-17",
):
register(
Chat(audio_model_id, audio=True),
AsyncChat(audio_model_id, audio=True),
)
# 3.5 and 4
register(
Chat("gpt-3.5-turbo"), AsyncChat("gpt-3.5-turbo"), aliases=("3.5", "chatgpt")
)
register(
Chat("gpt-3.5-turbo-16k"),
AsyncChat("gpt-3.5-turbo-16k"),
aliases=("chatgpt-16k", "3.5-16k"),
)
register(Chat("gpt-4"), AsyncChat("gpt-4"), aliases=("4", "gpt4"))
register(Chat("gpt-4-32k"), AsyncChat("gpt-4-32k"), aliases=("4-32k",))
# GPT-4 Turbo models
register(Chat("gpt-4-1106-preview"), AsyncChat("gpt-4-1106-preview"))
register(Chat("gpt-4-0125-preview"), AsyncChat("gpt-4-0125-preview"))
register(Chat("gpt-4-turbo-2024-04-09"), AsyncChat("gpt-4-turbo-2024-04-09"))
register(
Chat("gpt-4-turbo"),
AsyncChat("gpt-4-turbo"),
aliases=("gpt-4-turbo-preview", "4-turbo", "4t"),
)
# o1
for model_id in ("o1", "o1-2024-12-17"):
register(
Chat(model_id, vision=True, can_stream=False),
AsyncChat(model_id, vision=True, can_stream=False),
)
register(
Chat("o1-preview", allows_system_prompt=False),
AsyncChat("o1-preview", allows_system_prompt=False),
)
register(
Chat("o1-mini", allows_system_prompt=False),
AsyncChat("o1-mini", allows_system_prompt=False),
)
# The -instruct completion model
register(
Completion("gpt-3.5-turbo-instruct", default_max_tokens=256),
aliases=("3.5-instruct", "chatgpt-instruct"),
)
# Load extra models
extra_path = llm.user_dir() / "extra-openai-models.yaml"
if not extra_path.exists():
return
with open(extra_path) as f:
extra_models = yaml.safe_load(f)
for extra_model in extra_models:
model_id = extra_model["model_id"]
aliases = extra_model.get("aliases", [])
model_name = extra_model["model_name"]
api_base = extra_model.get("api_base")
api_type = extra_model.get("api_type")
api_version = extra_model.get("api_version")
api_engine = extra_model.get("api_engine")
headers = extra_model.get("headers")
kwargs = {}
if extra_model.get("can_stream") is False:
kwargs["can_stream"] = False
if extra_model.get("completion"):
klass = Completion
else:
klass = Chat
chat_model = klass(
model_id,
model_name=model_name,
api_base=api_base,
api_type=api_type,
api_version=api_version,
api_engine=api_engine,
headers=headers,
**kwargs,
)
if api_base:
chat_model.needs_key = None
if extra_model.get("api_key_name"):
chat_model.needs_key = extra_model["api_key_name"]
register(
chat_model,
aliases=aliases,
)
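# Illustrative sketch added for documentation (not part of the original file):
# the rough shape of an extra-openai-models.yaml entry consumed by the loop
# above. Every value here is hypothetical; only keys actually read by the code
# (model_id, model_name, aliases, api_base, api_key_name, api_type,
# api_version, api_engine, headers, completion, can_stream) have any effect.
#
#   - model_id: my-local-model
#     model_name: my-local-model
#     api_base: http://localhost:8080/v1
#     aliases: ["local"]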
@hookimpl
def register_embedding_models(register):
register(
OpenAIEmbeddingModel("text-embedding-ada-002", "text-embedding-ada-002"),
aliases=(
"ada",
"ada-002",
),
)
register(
OpenAIEmbeddingModel("text-embedding-3-small", "text-embedding-3-small"),
aliases=("3-small",),
)
register(
OpenAIEmbeddingModel("text-embedding-3-large", "text-embedding-3-large"),
aliases=("3-large",),
)
# With varying dimensions
register(
OpenAIEmbeddingModel(
"text-embedding-3-small-512", "text-embedding-3-small", 512
),
aliases=("3-small-512",),
)
register(
OpenAIEmbeddingModel(
"text-embedding-3-large-256", "text-embedding-3-large", 256
),
aliases=("3-large-256",),
)
register(
OpenAIEmbeddingModel(
"text-embedding-3-large-1024", "text-embedding-3-large", 1024
),
aliases=("3-large-1024",),
)
class OpenAIEmbeddingModel(EmbeddingModel):
needs_key = "openai"
key_env_var = "OPENAI_API_KEY"
batch_size = 100
def __init__(self, model_id, openai_model_id, dimensions=None):
self.model_id = model_id
self.openai_model_id = openai_model_id
self.dimensions = dimensions
def embed_batch(self, items: Iterable[Union[str, bytes]]) -> Iterator[List[float]]:
kwargs = {
"input": items,
"model": self.openai_model_id,
}
if self.dimensions:
kwargs["dimensions"] = self.dimensions
client = openai.OpenAI(api_key=self.get_key())
results = client.embeddings.create(**kwargs).data
return ([float(r) for r in result.embedding] for result in results)
@hookimpl
def register_commands(cli):
@cli.group(name="openai")
def openai_():
"Commands for working directly with the OpenAI API"
@openai_.command()
@click.option("json_", "--json", is_flag=True, help="Output as JSON")
@click.option("--key", help="OpenAI API key")
def models(json_, key):
"List models available to you from the OpenAI API"
from llm.cli import get_key
api_key = get_key(key, "openai", "OPENAI_API_KEY")
response = httpx.get(
"https://api.openai.com/v1/models",
headers={"Authorization": f"Bearer {api_key}"},
)
if response.status_code != 200:
raise click.ClickException(
f"Error {response.status_code} from OpenAI API: {response.text}"
)
models = response.json()["data"]
if json_:
click.echo(json.dumps(models, indent=4))
else:
to_print = []
for model in models:
# Print id, owned_by, root, created as ISO 8601
created_str = datetime.datetime.fromtimestamp(
model["created"], datetime.timezone.utc
).isoformat()
to_print.append(
{
"id": model["id"],
"owned_by": model["owned_by"],
"created": created_str,
}
)
done = dicts_to_table_string("id owned_by created".split(), to_print)
print("\n".join(done))
class SharedOptions(llm.Options):
temperature: Optional[float] = Field(
description=(
"What sampling temperature to use, between 0 and 2. Higher values like "
"0.8 will make the output more random, while lower values like 0.2 will "
"make it more focused and deterministic."
),
ge=0,
le=2,
default=None,
)
max_tokens: Optional[int] = Field(
description="Maximum number of tokens to generate.", default=None
)
top_p: Optional[float] = Field(
description=(
"An alternative to sampling with temperature, called nucleus sampling, "
"where the model considers the results of the tokens with top_p "
"probability mass. So 0.1 means only the tokens comprising the top "
"10% probability mass are considered. Recommended to use top_p or "
"temperature but not both."
),
ge=0,
le=1,
default=None,
)
frequency_penalty: Optional[float] = Field(
description=(
"Number between -2.0 and 2.0. Positive values penalize new tokens based "
"on their existing frequency in the text so far, decreasing the model's "
"likelihood to repeat the same line verbatim."
),
ge=-2,
le=2,
default=None,
)
presence_penalty: Optional[float] = Field(
description=(
"Number between -2.0 and 2.0. Positive values penalize new tokens based "
"on whether they appear in the text so far, increasing the model's "
"likelihood to talk about new topics."
),
ge=-2,
le=2,
default=None,
)
stop: Optional[str] = Field(
description=("A string where the API will stop generating further tokens."),
default=None,
)
logit_bias: Optional[Union[dict, str]] = Field(
description=(
"Modify the likelihood of specified tokens appearing in the completion. "
'Pass a JSON string like \'{"1712":-100, "892":-100, "1489":-100}\''
),
default=None,
)
seed: Optional[int] = Field(
description="Integer seed to attempt to sample deterministically",
default=None,
)
@field_validator("logit_bias")
def validate_logit_bias(cls, logit_bias):
if logit_bias is None:
return None
if isinstance(logit_bias, str):
try:
logit_bias = json.loads(logit_bias)
except json.JSONDecodeError:
raise ValueError("Invalid JSON in logit_bias string")
validated_logit_bias = {}
for key, value in logit_bias.items():
try:
int_key = int(key)
int_value = int(value)
if -100 <= int_value <= 100:
validated_logit_bias[int_key] = int_value
else:
raise ValueError("Value must be between -100 and 100")
except ValueError:
raise ValueError("Invalid key-value pair in logit_bias dictionary")
return validated_logit_bias
def _attachment(attachment):
url = attachment.url
base64_content = ""
if not url or attachment.resolve_type().startswith("audio/"):
base64_content = attachment.base64_content()
url = f"data:{attachment.resolve_type()};base64,{base64_content}"
if attachment.resolve_type().startswith("image/"):
return {"type": "image_url", "image_url": {"url": url}}
else:
format_ = "wav" if attachment.resolve_type() == "audio/wav" else "mp3"
return {
"type": "input_audio",
"input_audio": {
"data": base64_content,
"format": format_,
},
}
class _Shared:
def __init__(
self,
model_id,
key=None,
model_name=None,
api_base=None,
api_type=None,
api_version=None,
api_engine=None,
headers=None,
can_stream=True,
vision=False,
audio=False,
allows_system_prompt=True,
):
self.model_id = model_id
self.key = key
self.model_name = model_name
self.api_base = api_base
self.api_type = api_type
self.api_version = api_version
self.api_engine = api_engine
self.headers = headers
self.can_stream = can_stream
self.vision = vision
self.allows_system_prompt = allows_system_prompt
self.attachment_types = set()
if vision:
self.attachment_types.update(
{
"image/png",
"image/jpeg",
"image/webp",
"image/gif",
}
)
if audio:
self.attachment_types.update(
{
"audio/wav",
"audio/mpeg",
}
)
def __str__(self):
return "OpenAI Chat: {}".format(self.model_id)
def build_messages(self, prompt, conversation):
messages = []
current_system = None
if conversation is not None:
for prev_response in conversation.responses:
if (
prev_response.prompt.system
and prev_response.prompt.system != current_system
):
messages.append(
{"role": "system", "content": prev_response.prompt.system}
)
current_system = prev_response.prompt.system
if prev_response.attachments:
attachment_message = []
if prev_response.prompt.prompt:
attachment_message.append(
{"type": "text", "text": prev_response.prompt.prompt}
)
for attachment in prev_response.attachments:
attachment_message.append(_attachment(attachment))
messages.append({"role": "user", "content": attachment_message})
else:
messages.append(
{"role": "user", "content": prev_response.prompt.prompt}
)
messages.append(
{"role": "assistant", "content": prev_response.text_or_raise()}
)
if prompt.system and prompt.system != current_system:
messages.append({"role": "system", "content": prompt.system})
if not prompt.attachments:
messages.append({"role": "user", "content": prompt.prompt})
else:
attachment_message = []
if prompt.prompt:
attachment_message.append({"type": "text", "text": prompt.prompt})
for attachment in prompt.attachments:
attachment_message.append(_attachment(attachment))
messages.append({"role": "user", "content": attachment_message})
return messages
def set_usage(self, response, usage):
if not usage:
return
input_tokens = usage.pop("prompt_tokens")
output_tokens = usage.pop("completion_tokens")
usage.pop("total_tokens")
response.set_usage(
input=input_tokens, output=output_tokens, details=simplify_usage_dict(usage)
)
def get_client(self, async_=False):
kwargs = {}
if self.api_base:
kwargs["base_url"] = self.api_base
if self.api_type:
kwargs["api_type"] = self.api_type
if self.api_version:
kwargs["api_version"] = self.api_version
if self.api_engine:
kwargs["engine"] = self.api_engine
if self.needs_key:
kwargs["api_key"] = self.get_key()
else:
# OpenAI-compatible models don't need a key, but the
# openai client library requires one
kwargs["api_key"] = "DUMMY_KEY"
if self.headers:
kwargs["default_headers"] = self.headers
if os.environ.get("LLM_OPENAI_SHOW_RESPONSES"):
kwargs["http_client"] = logging_client()
if async_:
return openai.AsyncOpenAI(**kwargs)
else:
return openai.OpenAI(**kwargs)
def build_kwargs(self, prompt, stream):
kwargs = dict(not_nulls(prompt.options))
json_object = kwargs.pop("json_object", None)
if "max_tokens" not in kwargs and self.default_max_tokens is not None:
kwargs["max_tokens"] = self.default_max_tokens
if json_object:
kwargs["response_format"] = {"type": "json_object"}
if stream:
kwargs["stream_options"] = {"include_usage": True}
return kwargs
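# Synchronous OpenAI chat model. Options extends SharedOptions with a json_object
# flag that switches the API into JSON-object response mode.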
class Chat(_Shared, Model):
needs_key = "openai"
key_env_var = "OPENAI_API_KEY"
default_max_tokens = None
class Options(SharedOptions):
json_object: Optional[bool] = Field(
description="Output a valid JSON object {...}. Prompt must mention JSON.",
default=None,
)
def execute(self, prompt, stream, response, conversation=None):
if prompt.system and not self.allows_system_prompt:
raise NotImplementedError("Model does not support system prompts")
messages = self.build_messages(prompt, conversation)
kwargs = self.build_kwargs(prompt, stream)
client = self.get_client()
usage = None
if stream:
completion = client.chat.completions.create(
model=self.model_name or self.model_id,
messages=messages,
stream=True,
**kwargs,
)
chunks = []
for chunk in completion:
chunks.append(chunk)
if chunk.usage:
usage = chunk.usage.model_dump()
try:
content = chunk.choices[0].delta.content
except IndexError:
content = None
if content is not None:
yield content
response.response_json = remove_dict_none_values(combine_chunks(chunks))
else:
completion = client.chat.completions.create(
model=self.model_name or self.model_id,
messages=messages,
stream=False,
**kwargs,
)
usage = completion.usage.model_dump()
response.response_json = remove_dict_none_values(completion.model_dump())
yield completion.choices[0].message.content
self.set_usage(response, usage)
response._prompt_json = redact_data({"messages": messages})
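# Async variant of Chat: same message and kwargs construction, but uses
# openai.AsyncOpenAI and yields streamed content from an async generator.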
class AsyncChat(_Shared, AsyncModel):
needs_key = "openai"
key_env_var = "OPENAI_API_KEY"
default_max_tokens = None
class Options(SharedOptions):
json_object: Optional[bool] = Field(
description="Output a valid JSON object {...}. Prompt must mention JSON.",
default=None,
)
async def execute(
self, prompt, stream, response, conversation=None
) -> AsyncGenerator[str, None]:
if prompt.system and not self.allows_system_prompt:
raise NotImplementedError("Model does not support system prompts")
messages = self.build_messages(prompt, conversation)
kwargs = self.build_kwargs(prompt, stream)
client = self.get_client(async_=True)
usage = None
if stream:
completion = await client.chat.completions.create(
model=self.model_name or self.model_id,
messages=messages,
stream=True,
**kwargs,
)
chunks = []
async for chunk in completion:
if chunk.usage:
usage = chunk.usage.model_dump()
chunks.append(chunk)
try:
content = chunk.choices[0].delta.content
except IndexError:
content = None
if content is not None:
yield content
response.response_json = remove_dict_none_values(combine_chunks(chunks))
else:
completion = await client.chat.completions.create(
model=self.model_name or self.model_id,
messages=messages,
stream=False,
**kwargs,
)
response.response_json = remove_dict_none_values(completion.model_dump())
usage = completion.usage.model_dump()
yield completion.choices[0].message.content
self.set_usage(response, usage)
response._prompt_json = redact_data({"messages": messages})
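# Model for the legacy completions endpoint. There is no message structure here:
# previous prompts and responses are joined with newlines into a single text prompt.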
class Completion(Chat):
class Options(SharedOptions):
logprobs: Optional[int] = Field(
description="Include the log probabilities of most likely N per token",
default=None,
le=5,
)
def __init__(self, *args, default_max_tokens=None, **kwargs):
super().__init__(*args, **kwargs)
self.default_max_tokens = default_max_tokens
def __str__(self):
return "OpenAI Completion: {}".format(self.model_id)
def execute(self, prompt, stream, response, conversation=None):
if prompt.system:
raise NotImplementedError(
"System prompts are not supported for OpenAI completion models"
)
messages = []
if conversation is not None:
for prev_response in conversation.responses:
messages.append(prev_response.prompt.prompt)
messages.append(prev_response.text())
messages.append(prompt.prompt)
kwargs = self.build_kwargs(prompt, stream)
client = self.get_client()
if stream:
completion = client.completions.create(
model=self.model_name or self.model_id,
prompt="\n".join(messages),
stream=True,
**kwargs,
)
chunks = []
for chunk in completion:
chunks.append(chunk)
try:
content = chunk.choices[0].text
except IndexError:
content = None
if content is not None:
yield content
combined = combine_chunks(chunks)
cleaned = remove_dict_none_values(combined)
response.response_json = cleaned
else:
completion = client.completions.create(
model=self.model_name or self.model_id,
prompt="\n".join(messages),
stream=False,
**kwargs,
)
response.response_json = remove_dict_none_values(completion.model_dump())
yield completion.choices[0].text
response._prompt_json = redact_data({"messages": messages})
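# Drop option values that were left as None so they are not sent to the API.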
def not_nulls(data) -> dict:
return {key: value for key, value in data if value is not None}
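# Merge a list of streamed completion chunks into a single response-shaped dict,
# accumulating content deltas and keeping role, finish_reason, usage and any
# logprobs that were present, e.g.
#   {"content": "Hi.", "role": "assistant", "finish_reason": "stop", "usage": {...}}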
def combine_chunks(chunks: List) -> dict:
content = ""
role = None
finish_reason = None
# If any chunks include log probabilities, persist those later on
logprobs = []
usage = {}
for item in chunks:
if item.usage:
usage = item.usage.dict()
for choice in item.choices:
if choice.logprobs and hasattr(choice.logprobs, "top_logprobs"):
logprobs.append(
{
"text": choice.text if hasattr(choice, "text") else None,
"top_logprobs": choice.logprobs.top_logprobs,
}
)
if not hasattr(choice, "delta"):
content += choice.text
continue
role = choice.delta.role
if choice.delta.content is not None:
content += choice.delta.content
if choice.finish_reason is not None:
finish_reason = choice.finish_reason
# Imitations of the OpenAI API may be missing some of these fields
combined = {
"content": content,
"role": role,
"finish_reason": finish_reason,
"usage": usage,
}
if logprobs:
combined["logprobs"] = logprobs
if chunks:
for key in ("id", "object", "model", "created", "index"):
value = getattr(chunks[0], key, None)
if value is not None:
combined[key] = value
return combined
def redact_data(input_dict):
"""
Recursively search the input (dicts and lists) for any 'image_url' keys
and replace the 'url' value with just 'data:...'.
Also redacts the 'data' value of any 'input_audio' keys.
"""
if isinstance(input_dict, dict):
for key, value in input_dict.items():
if (
key == "image_url"
and isinstance(value, dict)
and "url" in value
and value["url"].startswith("data:")
):
value["url"] = "data:..."
elif key == "input_audio" and isinstance(value, dict) and "data" in value:
value["data"] = "..."
else:
redact_data(value)
elif isinstance(input_dict, list):
for item in input_dict:
redact_data(item)
return input_dict
</document_content>
</document>
<document index="18">
<source>/Users/simon/Dropbox/Development/llm/tests/conftest.py</source>
<document_content>
import pytest
import sqlite_utils
import json
import llm
from llm.plugins import pm
from pydantic import Field
from pytest_httpx import IteratorStream
from typing import Optional
def pytest_configure(config):
import sys
sys._called_from_test = True
@pytest.fixture
def user_path(tmpdir):
dir = tmpdir / "llm.datasette.io"
dir.mkdir()
return dir
@pytest.fixture
def logs_db(user_path):
return sqlite_utils.Database(str(user_path / "logs.db"))
@pytest.fixture
def user_path_with_embeddings(user_path):
path = str(user_path / "embeddings.db")
db = sqlite_utils.Database(path)
collection = llm.Collection("demo", db, model_id="embed-demo")
collection.embed("1", "hello world")
collection.embed("2", "goodbye world")
@pytest.fixture
def templates_path(user_path):
dir = user_path / "templates"
dir.mkdir()
return dir
@pytest.fixture(autouse=True)
def env_setup(monkeypatch, user_path):
monkeypatch.setenv("LLM_USER_PATH", str(user_path))
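# Minimal mock chat model used throughout the CLI tests: responses are enqueued
# ahead of time with enqueue() and replayed by execute(), which also records
# token usage based on the prompt length.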
class MockModel(llm.Model):
model_id = "mock"
attachment_types = {"image/png", "audio/wav"}
class Options(llm.Options):
max_tokens: Optional[int] = Field(
description="Maximum number of tokens to generate.", default=None
)
def __init__(self):
self.history = []
self._queue = []
def enqueue(self, messages):
assert isinstance(messages, list)
self._queue.append(messages)
def execute(self, prompt, stream, response, conversation):
self.history.append((prompt, stream, response, conversation))
gathered = []
while True:
try:
messages = self._queue.pop(0)
for message in messages:
gathered.append(message)
yield message
break
except IndexError:
break
response.set_usage(input=len(prompt.prompt.split()), output=len(gathered))
class AsyncMockModel(llm.AsyncModel):
model_id = "mock"
def __init__(self):
self.history = []
self._queue = []
def enqueue(self, messages):
assert isinstance(messages, list)
self._queue.append(messages)
async def execute(self, prompt, stream, response, conversation):
self.history.append((prompt, stream, response, conversation))
gathered = []
while True:
try:
messages = self._queue.pop(0)
for message in messages:
gathered.append(message)
yield message
break
except IndexError:
break
response.set_usage(input=len(prompt.prompt.split()), output=len(gathered))
class EmbedDemo(llm.EmbeddingModel):
model_id = "embed-demo"
batch_size = 10
supports_binary = True
def __init__(self):
self.embedded_content = []
def embed_batch(self, texts):
if not hasattr(self, "batch_count"):
self.batch_count = 0
self.batch_count += 1
for text in texts:
self.embedded_content.append(text)
words = text.split()[:16]
embedding = [len(word) for word in words]
# Pad with zeros so the embedding always has 16 dimensions
embedding += [0] * (16 - len(embedding))
yield embedding
class EmbedBinaryOnly(EmbedDemo):
model_id = "embed-binary-only"
supports_text = False
supports_binary = True
class EmbedTextOnly(EmbedDemo):
model_id = "embed-text-only"
supports_text = True
supports_binary = False
@pytest.fixture
def embed_demo():
return EmbedDemo()
@pytest.fixture
def mock_model():
return MockModel()
@pytest.fixture
def async_mock_model():
return AsyncMockModel()
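# Autouse fixture that registers the mock chat and embedding models as a plugin
# for the duration of each test, then unregisters the plugin afterwards.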
@pytest.fixture(autouse=True)
def register_embed_demo_model(embed_demo, mock_model, async_mock_model):
class MockModelsPlugin:
__name__ = "MockModelsPlugin"
@llm.hookimpl
def register_embedding_models(self, register):
register(embed_demo)
register(EmbedBinaryOnly())
register(EmbedTextOnly())
@llm.hookimpl
def register_models(self, register):
register(mock_model, async_model=async_mock_model)
pm.register(MockModelsPlugin(), name="undo-mock-models-plugin")
try:
yield
finally:
pm.unregister(name="undo-mock-models-plugin")
@pytest.fixture
def mocked_openai_chat(httpx_mock):
httpx_mock.add_response(
method="POST",
url="https://api.openai.com/v1/chat/completions",
json={
"model": "gpt-4o-mini",
"usage": {},
"choices": [{"message": {"content": "Bob, Alice, Eve"}}],
},
headers={"Content-Type": "application/json"},
)
return httpx_mock
@pytest.fixture
def mocked_openai_chat_returning_fenced_code(httpx_mock):
httpx_mock.add_response(
method="POST",
url="https://api.openai.com/v1/chat/completions",
json={
"model": "gpt-4o-mini",
"usage": {},
"choices": [
{
"message": {
"content": "Code:\n\n````javascript\nfunction foo() {\n return 'bar';\n}\n````\nDone.",
}
}
],
},
headers={"Content-Type": "application/json"},
)
return httpx_mock
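# Server-sent events for a mocked streaming chat completion: an initial role chunk,
# two content deltas ("Hi", "."), a stop chunk, then the terminating [DONE] marker.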
def stream_events():
for delta, finish_reason in (
({"role": "assistant", "content": ""}, None),
({"content": "Hi"}, None),
({"content": "."}, None),
({}, "stop"),
):
yield "data: {}\n\n".format(
json.dumps(
{
"id": "chat-1",
"object": "chat.completion.chunk",
"created": 1695096940,
"model": "gpt-3.5-turbo-0613",
"choices": [
{"index": 0, "delta": delta, "finish_reason": finish_reason}
],
}
)
).encode("utf-8")
yield "data: [DONE]\n\n".encode("utf-8")
@pytest.fixture
def mocked_openai_chat_stream(httpx_mock):
httpx_mock.add_response(
method="POST",
url="https://api.openai.com/v1/chat/completions",
stream=IteratorStream(stream_events()),
headers={"Content-Type": "text/event-stream"},
)
@pytest.fixture
def mocked_openai_completion(httpx_mock):
httpx_mock.add_response(
method="POST",
url="https://api.openai.com/v1/completions",
json={
"id": "cmpl-uqkvlQyYK7bGYrRHQ0eXlWi7",
"object": "text_completion",
"created": 1589478378,
"model": "gpt-3.5-turbo-instruct",
"choices": [
{
"text": "\n\nThis is indeed a test",
"index": 0,
"logprobs": None,
"finish_reason": "length",
}
],
"usage": {"prompt_tokens": 5, "completion_tokens": 7, "total_tokens": 12},
},
headers={"Content-Type": "application/json"},
)
return httpx_mock
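# Server-sent events for a mocked streaming legacy completion that includes
# logprobs for each token, ending with a stop chunk and the [DONE] marker.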
def stream_completion_events():
choices_chunks = [
[
{
"text": "\n\n",
"index": 0,
"logprobs": {
"tokens": ["\n\n"],
"token_logprobs": [-0.6],
"top_logprobs": [{"\n\n": -0.6, "\n": -1.9}],
"text_offset": [16],
},
"finish_reason": None,
}
],
[
{
"text": "Hi",
"index": 0,
"logprobs": {
"tokens": ["Hi"],
"token_logprobs": [-1.1],
"top_logprobs": [{"Hi": -1.1, "Hello": -0.7}],
"text_offset": [18],
},
"finish_reason": None,
}
],
[
{
"text": ".",
"index": 0,
"logprobs": {
"tokens": ["."],
"token_logprobs": [-1.1],
"top_logprobs": [{".": -1.1, "!": -0.9}],
"text_offset": [20],
},
"finish_reason": None,
}
],
[
{
"text": "",
"index": 0,
"logprobs": {
"tokens": [],
"token_logprobs": [],
"top_logprobs": [],
"text_offset": [],
},
"finish_reason": "stop",
}
],
]
for choices in choices_chunks:
yield "data: {}\n\n".format(
json.dumps(
{
"id": "cmpl-80MdSaou7NnPuff5ZyRMysWBmgSPS",
"object": "text_completion",
"created": 1695097702,
"choices": choices,
"model": "gpt-3.5-turbo-instruct",
}
)
).encode("utf-8")
yield "data: [DONE]\n\n".encode("utf-8")
@pytest.fixture
def mocked_openai_completion_logprobs_stream(httpx_mock):
httpx_mock.add_response(
method="POST",
url="https://api.openai.com/v1/completions",
stream=IteratorStream(stream_completion_events()),
headers={"Content-Type": "text/event-stream"},
)
return httpx_mock
@pytest.fixture
def mocked_openai_completion_logprobs(httpx_mock):
httpx_mock.add_response(
method="POST",
url="https://api.openai.com/v1/completions",
json={
"id": "cmpl-80MeBfKJutM0uMNJkRrebJLeP3bxL",
"object": "text_completion",
"created": 1695097747,
"model": "gpt-3.5-turbo-instruct",
"choices": [
{
"text": "\n\nHi.",
"index": 0,
"logprobs": {
"tokens": ["\n\n", "Hi", "1"],
"token_logprobs": [-0.6, -1.1, -0.9],
"top_logprobs": [
{"\n\n": -0.6, "\n": -1.9},
{"Hi": -1.1, "Hello": -0.7},
{".": -0.9, "!": -1.1},
],
"text_offset": [16, 18, 20],
},
"finish_reason": "stop",
}
],
"usage": {"prompt_tokens": 5, "completion_tokens": 3, "total_tokens": 8},
},
headers={"Content-Type": "application/json"},
)
return httpx_mock
@pytest.fixture
def mocked_localai(httpx_mock):
httpx_mock.add_response(
method="POST",
url="http://localai.localhost/chat/completions",
json={
"model": "orca",
"usage": {},
"choices": [{"message": {"content": "Bob, Alice, Eve"}}],
},
headers={"Content-Type": "application/json"},
)
httpx_mock.add_response(
method="POST",
url="http://localai.localhost/completions",
json={
"model": "completion-babbage",
"usage": {},
"choices": [{"text": "Hello"}],
},
headers={"Content-Type": "application/json"},
)
return httpx_mock
@pytest.fixture
def collection():
collection = llm.Collection("test", model_id="embed-demo")
collection.embed(1, "hello world")
collection.embed(2, "goodbye world")
return collection
</document_content>
</document>
<document index="19">
<source>/Users/simon/Dropbox/Development/llm/tests/test_aliases.py</source>
<document_content>
from click.testing import CliRunner
from llm.cli import cli
import llm
import json
import pytest
import re
@pytest.mark.parametrize("model_id_or_alias", ("gpt-3.5-turbo", "chatgpt"))
def test_set_alias(model_id_or_alias):
with pytest.raises(llm.UnknownModelError):
llm.get_model("this-is-a-new-alias")
llm.set_alias("this-is-a-new-alias", model_id_or_alias)
assert llm.get_model("this-is-a-new-alias").model_id == "gpt-3.5-turbo"
def test_remove_alias():
with pytest.raises(KeyError):
llm.remove_alias("some-other-alias")
llm.set_alias("some-other-alias", "gpt-3.5-turbo")
assert llm.get_model("some-other-alias").model_id == "gpt-3.5-turbo"
llm.remove_alias("some-other-alias")
with pytest.raises(llm.UnknownModelError):
llm.get_model("some-other-alias")
@pytest.mark.parametrize("args", (["aliases", "list"], ["aliases"]))
def test_cli_aliases_list(args):
llm.set_alias("e-demo", "embed-demo")
runner = CliRunner()
result = runner.invoke(cli, args)
assert result.exit_code == 0
for line in (
"3.5 : gpt-3.5-turbo\n"
"chatgpt : gpt-3.5-turbo\n"
"chatgpt-16k : gpt-3.5-turbo-16k\n"
"3.5-16k : gpt-3.5-turbo-16k\n"
"4 : gpt-4\n"
"gpt4 : gpt-4\n"
"4-32k : gpt-4-32k\n"
"e-demo : embed-demo (embedding)\n"
"ada : text-embedding-ada-002 (embedding)\n"
).split("\n"):
line = line.strip()
if not line:
continue
# Turn the whitespace into a regex
regex = r"\s+".join(re.escape(part) for part in line.split())
assert re.search(regex, result.output)
@pytest.mark.parametrize("args", (["aliases", "list"], ["aliases"]))
def test_cli_aliases_list_json(args):
llm.set_alias("e-demo", "embed-demo")
runner = CliRunner()
result = runner.invoke(cli, args + ["--json"])
assert result.exit_code == 0
assert (
json.loads(result.output).items()
>= {
"3.5": "gpt-3.5-turbo",
"chatgpt": "gpt-3.5-turbo",
"chatgpt-16k": "gpt-3.5-turbo-16k",
"3.5-16k": "gpt-3.5-turbo-16k",
"4": "gpt-4",
"gpt4": "gpt-4",
"4-32k": "gpt-4-32k",
"ada": "text-embedding-ada-002",
"e-demo": "embed-demo",
}.items()
)
def test_cli_aliases_set(user_path):
# Should be no aliases.json at the start
assert not (user_path / "aliases.json").exists()
runner = CliRunner()
result = runner.invoke(cli, ["aliases", "set", "foo", "bar"])
assert result.exit_code == 0
assert (user_path / "aliases.json").exists()
assert json.loads((user_path / "aliases.json").read_text("utf-8")) == {"foo": "bar"}
def test_cli_aliases_path(user_path):
runner = CliRunner()
result = runner.invoke(cli, ["aliases", "path"])
assert result.exit_code == 0
assert result.output.strip() == str(user_path / "aliases.json")
def test_cli_aliases_remove(user_path):
(user_path / "aliases.json").write_text(json.dumps({"foo": "bar"}), "utf-8")
runner = CliRunner()
result = runner.invoke(cli, ["aliases", "remove", "foo"])
assert result.exit_code == 0
assert json.loads((user_path / "aliases.json").read_text("utf-8")) == {}
def test_cli_aliases_remove_invalid(user_path):
(user_path / "aliases.json").write_text(json.dumps({"foo": "bar"}), "utf-8")
runner = CliRunner()
result = runner.invoke(cli, ["aliases", "remove", "invalid"])
assert result.exit_code == 1
assert result.output == "Error: No such alias: invalid\n"
@pytest.mark.parametrize("args", (["models"], ["models", "list"]))
def test_cli_aliases_are_registered(user_path, args):
(user_path / "aliases.json").write_text(
json.dumps({"foo": "bar", "turbo": "gpt-3.5-turbo"}), "utf-8"
)
runner = CliRunner()
result = runner.invoke(cli, args)
assert result.exit_code == 0
assert "gpt-3.5-turbo (aliases: 3.5, chatgpt, turbo)" in result.output
</document_content>
</document>
<document index="20">
<source>/Users/simon/Dropbox/Development/llm/tests/test_async.py</source>
<document_content>
import llm
import pytest
@pytest.mark.asyncio
async def test_async_model(async_mock_model):
gathered = []
async_mock_model.enqueue(["hello world"])
async for chunk in async_mock_model.prompt("hello"):
gathered.append(chunk)
assert gathered == ["hello world"]
# Not as an iterator
async_mock_model.enqueue(["hello world"])
response = await async_mock_model.prompt("hello")
text = await response.text()
assert text == "hello world"
assert isinstance(response, llm.AsyncResponse)
usage = await response.usage()
assert usage.input == 1
assert usage.output == 1
assert usage.details is None
@pytest.mark.asyncio
async def test_async_model_conversation(async_mock_model):
async_mock_model.enqueue(["joke 1"])
conversation = async_mock_model.conversation()
response = await conversation.prompt("joke")
text = await response.text()
assert text == "joke 1"
async_mock_model.enqueue(["joke 2"])
response2 = await conversation.prompt("again")
text2 = await response2.text()
assert text2 == "joke 2"
@pytest.mark.asyncio
async def test_async_on_done(async_mock_model):
async_mock_model.enqueue(["hello world"])
response = await async_mock_model.prompt(prompt="hello")
caught = []
def done(response):
caught.append(response)
assert len(caught) == 0
await response.on_done(done)
await response.text()
assert response._done
assert len(caught) == 1
</document_content>
</document>
<document index="21">
<source>/Users/simon/Dropbox/Development/llm/tests/test_attachments.py</source>
<document_content>
from click.testing import CliRunner
from unittest.mock import ANY
import llm
from llm import cli
import pytest
TINY_PNG = (
b"\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\xa6\x00\x00\x01\x1a"
b"\x02\x03\x00\x00\x00\xe6\x99\xc4^\x00\x00\x00\tPLTE\xff\xff\xff"
b"\x00\xff\x00\xfe\x01\x00\x12t\x01J\x00\x00\x00GIDATx\xda\xed\xd81\x11"
b"\x000\x08\xc0\xc0.]\xea\xaf&Q\x89\x04V\xe0>\xf3+\xc8\x91Z\xf4\xa2\x08EQ\x14E"
b"Q\x14EQ\x14EQ\xd4B\x91$I3\xbb\xbf\x08EQ\x14EQ\x14EQ\x14E\xd1\xa5"
b"\xd4\x17\x91\xc6\x95\x05\x15\x0f\x9f\xc5\t\x9f\xa4\x00\x00\x00\x00IEND\xaeB`"
b"\x82"
)
TINY_WAV = b"RIFF$\x00\x00\x00WAVEfmt \x10\x00\x00\x00\x01\x00\x01\x00D\xac\x00\x00"
@pytest.mark.parametrize(
"attachment_type,attachment_content",
[
("image/png", TINY_PNG),
("audio/wav", TINY_WAV),
],
)
def test_prompt_attachment(mock_model, logs_db, attachment_type, attachment_content):
runner = CliRunner()
mock_model.enqueue(["two boxes"])
result = runner.invoke(
cli.cli,
["prompt", "-m", "mock", "describe file", "-a", "-"],
input=attachment_content,
catch_exceptions=False,
)
assert result.exit_code == 0, result.output
assert result.output == "two boxes\n"
assert mock_model.history[0][0].attachments[0] == llm.Attachment(
type=attachment_type, path=None, url=None, content=attachment_content, _id=ANY
)
# Check it was logged correctly
conversations = list(logs_db["conversations"].rows)
assert len(conversations) == 1
conversation = conversations[0]
assert conversation["model"] == "mock"
assert conversation["name"] == "describe file"
response = list(logs_db["responses"].rows)[0]
attachment = list(logs_db["attachments"].rows)[0]
assert attachment == {
"id": ANY,
"type": attachment_type,
"path": None,
"url": None,
"content": attachment_content,
}
prompt_attachment = list(logs_db["prompt_attachments"].rows)[0]
assert prompt_attachment["attachment_id"] == attachment["id"]
assert prompt_attachment["response_id"] == response["id"]
</document_content>
</document>
<document index="22">
<source>/Users/simon/Dropbox/Development/llm/tests/test_chat.py</source>
<document_content>
from click.testing import CliRunner
import llm.cli
from unittest.mock import ANY
import pytest
import sys
@pytest.mark.xfail(sys.platform == "win32", reason="Expected to fail on Windows")
def test_chat_basic(mock_model, logs_db):
runner = CliRunner()
mock_model.enqueue(["one world"])
mock_model.enqueue(["one again"])
result = runner.invoke(
llm.cli.cli,
["chat", "-m", "mock"],
input="Hi\nHi two\nquit\n",
catch_exceptions=False,
)
assert result.exit_code == 0
assert result.output == (
"Chatting with mock"
"\nType 'exit' or 'quit' to exit"
"\nType '!multi' to enter multiple lines, then '!end' to finish"
"\n> Hi"
"\none world"
"\n> Hi two"
"\none again"
"\n> quit"
"\n"
)
# Should have logged
conversations = list(logs_db["conversations"].rows)
assert conversations[0] == {
"id": ANY,
"name": "Hi",
"model": "mock",
}
conversation_id = conversations[0]["id"]
responses = list(logs_db["responses"].rows)
assert responses == [
{
"id": ANY,
"model": "mock",
"prompt": "Hi",
"system": None,
"prompt_json": None,
"options_json": "{}",
"response": "one world",
"response_json": None,
"conversation_id": conversation_id,
"duration_ms": ANY,
"datetime_utc": ANY,
"input_tokens": 1,
"output_tokens": 1,
"token_details": None,
},
{
"id": ANY,
"model": "mock",
"prompt": "Hi two",
"system": None,
"prompt_json": None,
"options_json": "{}",
"response": "one again",
"response_json": None,
"conversation_id": conversation_id,
"duration_ms": ANY,
"datetime_utc": ANY,
"input_tokens": 2,
"output_tokens": 1,
"token_details": None,
},
]
# Now continue that conversation
mock_model.enqueue(["continued"])
result2 = runner.invoke(
llm.cli.cli,
["chat", "-m", "mock", "-c"],
input="Continue\nquit\n",
catch_exceptions=False,
)
assert result2.exit_code == 0
assert result2.output == (
"Chatting with mock"
"\nType 'exit' or 'quit' to exit"
"\nType '!multi' to enter multiple lines, then '!end' to finish"
"\n> Continue"
"\ncontinued"
"\n> quit"
"\n"
)
new_responses = list(
logs_db.query(
"select * from responses where id not in ({})".format(
", ".join("?" for _ in responses)
),
[r["id"] for r in responses],
)
)
assert new_responses == [
{
"id": ANY,
"model": "mock",
"prompt": "Continue",
"system": None,
"prompt_json": None,
"options_json": "{}",
"response": "continued",
"response_json": None,
"conversation_id": conversation_id,
"duration_ms": ANY,
"datetime_utc": ANY,
"input_tokens": 1,
"output_tokens": 1,
"token_details": None,
}
]
@pytest.mark.xfail(sys.platform == "win32", reason="Expected to fail on Windows")
def test_chat_system(mock_model, logs_db):
runner = CliRunner()
mock_model.enqueue(["I am mean"])
result = runner.invoke(
llm.cli.cli,
["chat", "-m", "mock", "--system", "You are mean"],
input="Hi\nquit\n",
)
assert result.exit_code == 0
assert result.output == (
"Chatting with mock"
"\nType 'exit' or 'quit' to exit"
"\nType '!multi' to enter multiple lines, then '!end' to finish"
"\n> Hi"
"\nI am mean"
"\n> quit"
"\n"
)
responses = list(logs_db["responses"].rows)
assert responses == [
{
"id": ANY,
"model": "mock",
"prompt": "Hi",
"system": "You are mean",
"prompt_json": None,
"options_json": "{}",
"response": "I am mean",
"response_json": None,
"conversation_id": ANY,
"duration_ms": ANY,
"datetime_utc": ANY,
"input_tokens": 1,
"output_tokens": 1,
"token_details": None,
}
]
@pytest.mark.xfail(sys.platform == "win32", reason="Expected to fail on Windows")
def test_chat_options(mock_model, logs_db):
runner = CliRunner()
mock_model.enqueue(["Some text"])
result = runner.invoke(
llm.cli.cli,
["chat", "-m", "mock", "--option", "max_tokens", "10"],
input="Hi\nquit\n",
)
assert result.exit_code == 0
responses = list(logs_db["responses"].rows)
assert responses == [
{
"id": ANY,
"model": "mock",
"prompt": "Hi",
"system": None,
"prompt_json": None,
"options_json": '{"max_tokens": 10}',
"response": "Some text",
"response_json": None,
"conversation_id": ANY,
"duration_ms": ANY,
"datetime_utc": ANY,
"input_tokens": 1,
"output_tokens": 1,
"token_details": None,
}
]
@pytest.mark.xfail(sys.platform == "win32", reason="Expected to fail on Windows")
@pytest.mark.parametrize(
"input,expected",
(
(
"Hi\n!multi\nthis is multiple lines\nuntil the !end\n!end\nquit\n",
[
{"prompt": "Hi", "response": "One\n"},
{
"prompt": "this is multiple lines\nuntil the !end",
"response": "Two\n",
},
],
),
# quit should not work within !multi
(
"!multi\nthis is multiple lines\nquit\nuntil the !end\n!end\nquit\n",
[
{
"prompt": "this is multiple lines\nquit\nuntil the !end",
"response": "One\n",
}
],
),
# Try custom delimiter
(
"!multi abc\nCustom delimiter\n!end\n!end 123\n!end abc\nquit\n",
[{"prompt": "Custom delimiter\n!end\n!end 123", "response": "One\n"}],
),
),
)
def test_chat_multi(mock_model, logs_db, input, expected):
runner = CliRunner()
mock_model.enqueue(["One\n"])
mock_model.enqueue(["Two\n"])
mock_model.enqueue(["Three\n"])
result = runner.invoke(
llm.cli.cli, ["chat", "-m", "mock", "--option", "max_tokens", "10"], input=input
)
assert result.exit_code == 0
rows = list(logs_db["responses"].rows_where(select="prompt, response"))
assert rows == expected
</document_content>
</document>
<document index="23">
<source>/Users/simon/Dropbox/Development/llm/tests/test_cli_openai_models.py</source>
<document_content>
from click.testing import CliRunner
from llm.cli import cli
import pytest
import sqlite_utils
@pytest.fixture
def mocked_models(httpx_mock):
httpx_mock.add_response(
method="GET",
url="https://api.openai.com/v1/models",
json={
"data": [
{
"id": "ada:2020-05-03",
"object": "model",
"created": 1588537600,
"owned_by": "openai",
},
{
"id": "babbage:2020-05-03",
"object": "model",
"created": 1588537600,
"owned_by": "openai",
},
]
},
headers={"Content-Type": "application/json"},
)
return httpx_mock
def test_openai_models(mocked_models):
runner = CliRunner()
result = runner.invoke(cli, ["openai", "models", "--key", "x"])
assert result.exit_code == 0
assert result.output == (
"id owned_by created \n"
"ada:2020-05-03 openai 2020-05-03T20:26:40+00:00\n"
"babbage:2020-05-03 openai 2020-05-03T20:26:40+00:00\n"
)
def test_openai_options_min_max():
options = {
"temperature": [0, 2],
"top_p": [0, 1],
"frequency_penalty": [-2, 2],
"presence_penalty": [-2, 2],
}
runner = CliRunner()
for option, [min_val, max_val] in options.items():
result = runner.invoke(cli, ["-m", "chatgpt", "-o", option, "-10"])
assert result.exit_code == 1
assert f"greater than or equal to {min_val}" in result.output
result2 = runner.invoke(cli, ["-m", "chatgpt", "-o", option, "10"])
assert result2.exit_code == 1
assert f"less than or equal to {max_val}" in result2.output
@pytest.mark.parametrize("model", ("gpt-4o-mini", "gpt-4o-audio-preview"))
@pytest.mark.parametrize("filetype", ("mp3", "wav"))
def test_only_gpt4_audio_preview_allows_mp3_or_wav(httpx_mock, model, filetype):
httpx_mock.add_response(
method="HEAD",
url=f"https://www.example.com/example.{filetype}",
content=b"binary-data",
headers={"Content-Type": "audio/mpeg" if filetype == "mp3" else "audio/wav"},
)
if model == "gpt-4o-audio-preview":
httpx_mock.add_response(
method="POST",
# chat completion request
url="https://api.openai.com/v1/chat/completions",
json={
"id": "chatcmpl-AQT9a30kxEaM1bqxRPepQsPlCyGJh",
"object": "chat.completion",
"created": 1730871958,
"model": "gpt-4o-audio-preview-2024-10-01",
"choices": [
{
"index": 0,
"message": {
"role": "assistant",
"content": "Why did the pelican get kicked out of the restaurant?\n\nBecause he had a big bill and no way to pay it!",
"refusal": None,
},
"finish_reason": "stop",
}
],
"usage": {
"prompt_tokens": 55,
"completion_tokens": 25,
"total_tokens": 80,
"prompt_tokens_details": {
"cached_tokens": 0,
"audio_tokens": 44,
"text_tokens": 11,
"image_tokens": 0,
},
"completion_tokens_details": {
"reasoning_tokens": 0,
"audio_tokens": 0,
"text_tokens": 25,
"accepted_prediction_tokens": 0,
"rejected_prediction_tokens": 0,
},
},
"system_fingerprint": "fp_49254d0e9b",
},
headers={"Content-Type": "application/json"},
)
httpx_mock.add_response(
method="GET",
url=f"https://www.example.com/example.{filetype}",
content=b"binary-data",
headers={
"Content-Type": "audio/mpeg" if filetype == "mp3" else "audio/wav"
},
)
runner = CliRunner()
result = runner.invoke(
cli,
[
"-m",
model,
"-a",
f"https://www.example.com/example.{filetype}",
"--no-stream",
"--key",
"x",
],
)
if model == "gpt-4o-audio-preview":
assert result.exit_code == 0
assert result.output == (
"Why did the pelican get kicked out of the restaurant?\n\n"
"Because he had a big bill and no way to pay it!\n"
)
else:
assert result.exit_code == 1
long = "audio/mpeg" if filetype == "mp3" else "audio/wav"
assert (
f"This model does not support attachments of type '{long}'" in result.output
)
@pytest.mark.parametrize("async_", (False, True))
@pytest.mark.parametrize("usage", (None, "-u", "--usage"))
def test_gpt4o_mini_sync_and_async(monkeypatch, tmpdir, httpx_mock, async_, usage):
user_path = tmpdir / "user_dir"
log_db = user_path / "logs.db"
monkeypatch.setenv("LLM_USER_PATH", str(user_path))
assert not log_db.exists()
httpx_mock.add_response(
method="POST",
# chat completion request
url="https://api.openai.com/v1/chat/completions",
json={
"id": "chatcmpl-AQT9a30kxEaM1bqxRPepQsPlCyGJh",
"object": "chat.completion",
"created": 1730871958,
"model": "gpt-4o-mini",
"choices": [
{
"index": 0,
"message": {
"role": "assistant",
"content": "Ho ho ho",
"refusal": None,
},
"finish_reason": "stop",
}
],
"usage": {
"prompt_tokens": 1000,
"completion_tokens": 2000,
"total_tokens": 12,
},
"system_fingerprint": "fp_49254d0e9b",
},
headers={"Content-Type": "application/json"},
)
runner = CliRunner(mix_stderr=False)
args = ["-m", "gpt-4o-mini", "--key", "x", "--no-stream"]
if usage:
args.append(usage)
if async_:
args.append("--async")
result = runner.invoke(cli, args, catch_exceptions=False)
assert result.exit_code == 0
assert result.output == "Ho ho ho\n"
if usage:
assert result.stderr == "Token usage: 1,000 input, 2,000 output\n"
# Confirm it was correctly logged
assert log_db.exists()
db = sqlite_utils.Database(str(log_db))
assert db["responses"].count == 1
row = next(db["responses"].rows)
assert row["response"] == "Ho ho ho"
</document_content>
</document>
<document index="24">
<source>/Users/simon/Dropbox/Development/llm/tests/test_embed.py</source>
<document_content>
import json
import llm
from llm.embeddings import Entry
import pytest
import sqlite_utils
from unittest.mock import ANY
def test_demo_plugin():
model = llm.get_embedding_model("embed-demo")
assert model.embed("hello world") == [5, 5] + [0] * 14
@pytest.mark.parametrize(
"batch_size,expected_batches",
(
(None, 100),
(10, 100),
),
)
def test_embed_huge_list(batch_size, expected_batches):
model = llm.get_embedding_model("embed-demo")
huge_list = ("hello {}".format(i) for i in range(1000))
kwargs = {}
if batch_size:
kwargs["batch_size"] = batch_size
results = model.embed_multi(huge_list, **kwargs)
assert repr(type(results)) == "<class 'generator'>"
first_twos = {}
for result in results:
key = (result[0], result[1])
first_twos[key] = first_twos.get(key, 0) + 1
assert first_twos == {(5, 1): 10, (5, 2): 90, (5, 3): 900}
assert model.batch_count == expected_batches
def test_embed_store(collection):
collection.embed("3", "hello world again", store=True)
assert collection.db["embeddings"].count == 3
assert (
next(collection.db["embeddings"].rows_where("id = ?", ["3"]))["content"]
== "hello world again"
)
def test_embed_metadata(collection):
collection.embed("3", "hello yet again", metadata={"foo": "bar"}, store=True)
assert collection.db["embeddings"].count == 3
assert json.loads(
next(collection.db["embeddings"].rows_where("id = ?", ["3"]))["metadata"]
) == {"foo": "bar"}
entry = collection.similar("hello yet again")[0]
assert entry.id == "3"
assert entry.metadata == {"foo": "bar"}
assert entry.content == "hello yet again"
def test_collection(collection):
assert collection.id == 1
assert collection.count() == 2
# Check that the embeddings are there
rows = list(collection.db["embeddings"].rows)
assert rows == [
{
"collection_id": 1,
"id": "1",
"embedding": llm.encode([5, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]),
"content": None,
"content_blob": None,
"content_hash": collection.content_hash("hello world"),
"metadata": None,
"updated": ANY,
},
{
"collection_id": 1,
"id": "2",
"embedding": llm.encode([7, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]),
"content": None,
"content_blob": None,
"content_hash": collection.content_hash("goodbye world"),
"metadata": None,
"updated": ANY,
},
]
assert isinstance(rows[0]["updated"], int) and rows[0]["updated"] > 0
def test_similar(collection):
results = list(collection.similar("hello world"))
assert results == [
Entry(id="1", score=pytest.approx(0.9999999999999999)),
Entry(id="2", score=pytest.approx(0.9863939238321437)),
]
def test_similar_by_id(collection):
results = list(collection.similar_by_id("1"))
assert results == [
Entry(id="2", score=pytest.approx(0.9863939238321437)),
]
@pytest.mark.parametrize(
"batch_size,expected_batches",
(
(None, 100),
(5, 200),
),
)
@pytest.mark.parametrize("with_metadata", (False, True))
def test_embed_multi(with_metadata, batch_size, expected_batches):
db = sqlite_utils.Database(memory=True)
collection = llm.Collection("test", db, model_id="embed-demo")
model = collection.model()
assert getattr(model, "batch_count", 0) == 0
ids_and_texts = ((str(i), "hello {}".format(i)) for i in range(1000))
kwargs = {}
if batch_size is not None:
kwargs["batch_size"] = batch_size
if with_metadata:
ids_and_texts = ((id, text, {"meta": id}) for id, text in ids_and_texts)
collection.embed_multi_with_metadata(ids_and_texts, **kwargs)
else:
# Exercise store=True here too
collection.embed_multi(ids_and_texts, store=True, **kwargs)
rows = list(db["embeddings"].rows)
assert len(rows) == 1000
rows_with_metadata = [row for row in rows if row["metadata"] is not None]
rows_with_content = [row for row in rows if row["content"] is not None]
if with_metadata:
assert len(rows_with_metadata) == 1000
assert len(rows_with_content) == 0
else:
assert len(rows_with_metadata) == 0
assert len(rows_with_content) == 1000
# Every row should have content_hash set
assert all(row["content_hash"] is not None for row in rows)
# Check batch count
assert collection.model().batch_count == expected_batches
def test_collection_delete(collection):
db = collection.db
assert db["embeddings"].count == 2
assert db["collections"].count == 1
collection.delete()
assert db["embeddings"].count == 0
assert db["collections"].count == 0
def test_binary_only_and_text_only_embedding_models():
binary_only = llm.get_embedding_model("embed-binary-only")
text_only = llm.get_embedding_model("embed-text-only")
assert binary_only.supports_binary
assert not binary_only.supports_text
assert not text_only.supports_binary
assert text_only.supports_text
with pytest.raises(ValueError):
binary_only.embed("hello world")
binary_only.embed(b"hello world")
with pytest.raises(ValueError):
text_only.embed(b"hello world")
text_only.embed("hello world")
# Try the multi versions too
# Have to call list() on this or the generator is not evaluated
with pytest.raises(ValueError):
list(binary_only.embed_multi(["hello world"]))
list(binary_only.embed_multi([b"hello world"]))
with pytest.raises(ValueError):
list(text_only.embed_multi([b"hello world"]))
list(text_only.embed_multi(["hello world"]))
</document_content>
</document>
<document index="25">
<source>/Users/simon/Dropbox/Development/llm/tests/test_embed_cli.py</source>
<document_content>
from click.testing import CliRunner
from llm.cli import cli
from llm import Collection
import json
import pathlib
import pytest
import sqlite_utils
import sys
from unittest.mock import ANY
@pytest.mark.parametrize(
"format_,expected",
(
("json", "[5, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]\n"),
(
"base64",
(
"AACgQAAAoEAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA"
"AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA==\n"
),
),
(
"hex",
(
"0000a0400000a04000000000000000000000000000000000000000000"
"000000000000000000000000000000000000000000000000000000000"
"00000000000000\n"
),
),
(
"blob",
(
b"\x00\x00\xef\xbf\xbd@\x00\x00\xef\xbf\xbd@\x00\x00\x00"
b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"
b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"
b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"
b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\n"
).decode("utf-8"),
),
),
)
@pytest.mark.parametrize("scenario", ("argument", "file", "stdin"))
def test_embed_output_format(tmpdir, format_, expected, scenario):
runner = CliRunner()
args = ["embed", "--format", format_, "-m", "embed-demo"]
input = None
if scenario == "argument":
args.extend(["-c", "hello world"])
elif scenario == "file":
path = tmpdir / "input.txt"
path.write_text("hello world", "utf-8")
args.extend(["-i", str(path)])
elif scenario == "stdin":
input = "hello world"
args.extend(["-i", "-"])
result = runner.invoke(cli, args, input=input)
assert result.exit_code == 0
assert result.output == expected
@pytest.mark.parametrize(
"args,expected_error",
((["-c", "Content", "stories"], "Must provide both collection and id"),),
)
def test_embed_errors(args, expected_error):
runner = CliRunner()
result = runner.invoke(cli, ["embed"] + args)
assert result.exit_code == 1
assert expected_error in result.output
@pytest.mark.parametrize(
"metadata,metadata_error",
(
(None, None),
('{"foo": "bar"}', None),
('{"foo": [1, 2, 3]}', None),
("[1, 2, 3]", "Metadata must be a JSON object"), # Must be a dictionary
('{"foo": "incomplete}', "Metadata must be valid JSON"),
),
)
def test_embed_store(user_path, metadata, metadata_error):
embeddings_db = user_path / "embeddings.db"
assert not embeddings_db.exists()
runner = CliRunner()
result = runner.invoke(cli, ["embed", "-c", "hello", "-m", "embed-demo"])
assert result.exit_code == 0
# Should not have created the embeddings database yet
assert not embeddings_db.exists()
# Now run it to store
args = ["embed", "-c", "hello", "-m", "embed-demo", "items", "1"]
if metadata is not None:
args.extend(("--metadata", metadata))
result = runner.invoke(cli, args)
if metadata_error:
# Should have returned an error message about invalid metadata
assert result.exit_code == 2
assert metadata_error in result.output
return
# No error, should have succeeded and stored the data
assert result.exit_code == 0
assert embeddings_db.exists()
# Check the contents
db = sqlite_utils.Database(str(embeddings_db))
rows = list(db["collections"].rows)
assert rows == [{"id": 1, "name": "items", "model": "embed-demo"}]
expected_metadata = None
if metadata and not metadata_error:
expected_metadata = metadata
rows = list(db["embeddings"].rows)
assert rows == [
{
"collection_id": 1,
"id": "1",
"embedding": (
b"\x00\x00\xa0@\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"
b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"
b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"
b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"
b"\x00\x00\x00\x00\x00\x00\x00"
),
"content": None,
"content_blob": None,
"content_hash": Collection.content_hash("hello"),
"metadata": expected_metadata,
"updated": ANY,
}
]
# Should show up in 'llm collections list'
for is_json in (False, True):
args = ["collections"]
if is_json:
args.extend(["list", "--json"])
result2 = runner.invoke(cli, args)
assert result2.exit_code == 0
if is_json:
assert json.loads(result2.output) == [
{"name": "items", "model": "embed-demo", "num_embeddings": 1}
]
else:
assert result2.output == "items: embed-demo\n 1 embedding\n"
# And test deleting it too
result = runner.invoke(cli, ["collections", "delete", "items"])
assert result.exit_code == 0
assert db["collections"].count == 0
assert db["embeddings"].count == 0
def test_embed_store_binary(user_path):
runner = CliRunner()
args = ["embed", "-m", "embed-demo", "items", "2", "--binary", "--store"]
result = runner.invoke(cli, args, input=b"\x00\x01\x02")
assert result.exit_code == 0
db = sqlite_utils.Database(str(user_path / "embeddings.db"))
rows = list(db["embeddings"].rows)
assert rows == [
{
"collection_id": 1,
"id": "2",
"embedding": (
b"\x00\x00@@\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"
b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"
b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"
b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"
),
"content": None,
"content_blob": b"\x00\x01\x02",
"content_hash": b'\xb9_g\xf6\x1e\xbb\x03a\x96"\xd7\x98\xf4_\xc2\xd3',
"metadata": None,
"updated": ANY,
}
]
def test_collection_delete_errors(user_path):
db = sqlite_utils.Database(str(user_path / "embeddings.db"))
collection = Collection("items", db, model_id="embed-demo")
collection.embed("1", "hello")
assert db["collections"].count == 1
assert db["embeddings"].count == 1
runner = CliRunner()
result = runner.invoke(
cli, ["collections", "delete", "does-not-exist"], catch_exceptions=False
)
assert result.exit_code == 1
assert "Collection does not exist" in result.output
assert db["collections"].count == 1
@pytest.mark.parametrize(
"args,expected_error",
(
([], "Missing argument 'COLLECTION'"),
(["badcollection", "-c", "content"], "Collection does not exist"),
(["demo", "bad-id"], "ID not found in collection"),
),
)
def test_similar_errors(args, expected_error, user_path_with_embeddings):
runner = CliRunner()
result = runner.invoke(cli, ["similar"] + args, catch_exceptions=False)
assert result.exit_code != 0
assert expected_error in result.output
def test_similar_by_id_cli(user_path_with_embeddings):
runner = CliRunner()
result = runner.invoke(cli, ["similar", "demo", "1"], catch_exceptions=False)
assert result.exit_code == 0
assert json.loads(result.output) == {
"id": "2",
"score": pytest.approx(0.9863939238321437),
"content": None,
"metadata": None,
}
@pytest.mark.parametrize("scenario", ("argument", "file", "stdin"))
def test_similar_by_content_cli(tmpdir, user_path_with_embeddings, scenario):
runner = CliRunner()
args = ["similar", "demo"]
input = None
if scenario == "argument":
args.extend(["-c", "hello world"])
elif scenario == "file":
path = tmpdir / "content.txt"
path.write_text("hello world", "utf-8")
args.extend(["-i", str(path)])
elif scenario == "stdin":
input = "hello world"
args.extend(["-i", "-"])
result = runner.invoke(cli, args, input=input, catch_exceptions=False)
assert result.exit_code == 0
lines = [line for line in result.output.splitlines() if line.strip()]
assert len(lines) == 2
assert json.loads(lines[0]) == {
"id": "1",
"score": pytest.approx(0.9999999999999999),
"content": None,
"metadata": None,
}
assert json.loads(lines[1]) == {
"id": "2",
"score": pytest.approx(0.9863939238321437),
"content": None,
"metadata": None,
}
@pytest.mark.parametrize("use_stdin", (False, True))
@pytest.mark.parametrize("prefix", (None, "prefix"))
@pytest.mark.parametrize(
"filename,content",
(
("phrases.csv", "id,phrase\n1,hello world\n2,goodbye world"),
("phrases.tsv", "id\tphrase\n1\thello world\n2\tgoodbye world"),
(
"phrases.jsonl",
'{"id": 1, "phrase": "hello world"}\n{"id": 2, "phrase": "goodbye world"}',
),
(
"phrases.json",
'[{"id": 1, "phrase": "hello world"}, {"id": 2, "phrase": "goodbye world"}]',
),
),
)
def test_embed_multi_file_input(tmpdir, use_stdin, prefix, filename, content):
db_path = tmpdir / "embeddings.db"
args = ["embed-multi", "phrases", "-d", str(db_path), "-m", "embed-demo"]
input = None
if use_stdin:
input = content
args.append("-")
else:
path = tmpdir / filename
path.write_text(content, "utf-8")
args.append(str(path))
if prefix:
args.extend(("--prefix", prefix))
# Auto-detection can't detect JSON-nl, so make that explicit
if filename.endswith(".jsonl"):
args.extend(("--format", "nl"))
runner = CliRunner()
result = runner.invoke(cli, args, input=input, catch_exceptions=False)
assert result.exit_code == 0
# Check that everything was embedded correctly
db = sqlite_utils.Database(str(db_path))
assert db["embeddings"].count == 2
ids = [row["id"] for row in db["embeddings"].rows]
expected_ids = ["1", "2"]
if prefix:
expected_ids = ["prefix1", "prefix2"]
assert ids == expected_ids
def test_embed_multi_files_binary_store(tmpdir):
db_path = tmpdir / "embeddings.db"
args = ["embed-multi", "binfiles", "-d", str(db_path), "-m", "embed-demo"]
bin_path = tmpdir / "file.bin"
bin_path.write(b"\x00\x01\x02")
args.extend(("--files", str(tmpdir), "*.bin", "--store", "--binary"))
runner = CliRunner()
result = runner.invoke(cli, args, catch_exceptions=False)
assert result.exit_code == 0
db = sqlite_utils.Database(str(db_path))
assert db["embeddings"].count == 1
row = list(db["embeddings"].rows)[0]
assert row == {
"collection_id": 1,
"id": "file.bin",
"embedding": (
b"\x00\x00@@\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"
b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"
b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"
b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"
),
"content": None,
"content_blob": b"\x00\x01\x02",
"content_hash": b'\xb9_g\xf6\x1e\xbb\x03a\x96"\xd7\x98\xf4_\xc2\xd3',
"metadata": None,
"updated": ANY,
}
@pytest.mark.parametrize("use_other_db", (True, False))
@pytest.mark.parametrize("prefix", (None, "prefix"))
def test_embed_multi_sql(tmpdir, use_other_db, prefix):
db_path = str(tmpdir / "embeddings.db")
db = sqlite_utils.Database(db_path)
extra_args = []
if use_other_db:
db_path2 = str(tmpdir / "other.db")
db = sqlite_utils.Database(db_path2)
extra_args = ["--attach", "other", db_path2]
if prefix:
extra_args.extend(("--prefix", prefix))
db["content"].insert_all(
[
{"id": 1, "name": "cli", "description": "Command line interface"},
{"id": 2, "name": "sql", "description": "Structured query language"},
],
pk="id",
)
runner = CliRunner()
result = runner.invoke(
cli,
[
"embed-multi",
"stuff",
"-d",
db_path,
"--sql",
"select * from content",
"-m",
"embed-demo",
"--store",
]
+ extra_args,
)
assert result.exit_code == 0
embeddings_db = sqlite_utils.Database(db_path)
assert embeddings_db["embeddings"].count == 2
rows = list(embeddings_db.query("select id, content from embeddings order by id"))
assert rows == [
{"id": (prefix or "") + "1", "content": "cli Command line interface"},
{"id": (prefix or "") + "2", "content": "sql Structured query language"},
]
def test_embed_multi_batch_size(embed_demo, tmpdir):
db_path = str(tmpdir / "data.db")
runner = CliRunner()
sql = """
with recursive cte (id) as (
select 1
union all
select id+1 from cte where id < 100
)
select id, 'Row ' || cast(id as text) as value from cte
"""
assert getattr(embed_demo, "batch_count", 0) == 0
result = runner.invoke(
cli,
[
"embed-multi",
"rows",
"--sql",
sql,
"-d",
db_path,
"-m",
"embed-demo",
"--store",
"--batch-size",
"8",
],
)
assert result.exit_code == 0
db = sqlite_utils.Database(db_path)
assert db["embeddings"].count == 100
assert embed_demo.batch_count == 13
@pytest.fixture
def multi_files(tmpdir):
db_path = str(tmpdir / "files.db")
files = tmpdir / "files"
for filename, content in (
("file1.txt", b"hello world"),
("file2.txt", b"goodbye world"),
("nested/one.txt", b"one"),
("nested/two.txt", b"two"),
("nested/more/three.txt", b"three"),
# This tests the fallback to latin-1 encoding:
("nested/more/ignored.ini", b"Has weird \x96 character"),
):
path = pathlib.Path(files / filename)
path.parent.mkdir(parents=True, exist_ok=True)
path.write_bytes(content)
return db_path, tmpdir / "files"
@pytest.mark.xfail(sys.platform == "win32", reason="Expected to fail on Windows")
@pytest.mark.parametrize("scenario", ("single", "multi"))
def test_embed_multi_files(multi_files, scenario):
db_path, files = multi_files
for filename, content in (
("file1.txt", b"hello world"),
("file2.txt", b"goodbye world"),
("nested/one.txt", b"one"),
("nested/two.txt", b"two"),
("nested/more/three.txt", b"three"),
# This tests the fallback to latin-1 encoding:
("nested/more.txt/ignored.ini", b"Has weird \x96 character"),
):
path = pathlib.Path(files / filename)
path.parent.mkdir(parents=True, exist_ok=True)
path.write_bytes(content)
if scenario == "single":
extra_args = ["--files", str(files), "**/*.txt"]
else:
extra_args = [
"--files",
str(files / "nested" / "more"),
"**/*.ini",
"--files",
str(files / "nested"),
"*.txt",
]
runner = CliRunner()
result = runner.invoke(
cli,
[
"embed-multi",
"files",
"-d",
db_path,
"-m",
"embed-demo",
"--store",
]
+ extra_args,
)
assert result.exit_code == 0
embeddings_db = sqlite_utils.Database(db_path)
rows = list(embeddings_db.query("select id, content from embeddings order by id"))
if scenario == "single":
assert rows == [
{"id": "file1.txt", "content": "hello world"},
{"id": "file2.txt", "content": "goodbye world"},
{"id": "nested/more/three.txt", "content": "three"},
{"id": "nested/one.txt", "content": "one"},
{"id": "nested/two.txt", "content": "two"},
]
else:
assert rows == [
{"id": "ignored.ini", "content": "Has weird \x96 character"},
{"id": "one.txt", "content": "one"},
{"id": "two.txt", "content": "two"},
]
@pytest.mark.parametrize(
"args,expected_error",
((["not-a-dir", "*.txt"], "Invalid directory: not-a-dir"),),
)
def test_embed_multi_files_errors(multi_files, args, expected_error):
runner = CliRunner()
result = runner.invoke(
cli,
["embed-multi", "files", "-m", "embed-demo", "--files"] + args,
)
assert result.exit_code == 2
assert expected_error in result.output
@pytest.mark.parametrize(
"extra_args,expected_error",
(
# With no extra args, the default utf-8 with latin-1 fallback should work
([], None),
(["--encoding", "utf-8"], "Could not decode text in file"),
(["--encoding", "latin-1"], None),
(["--encoding", "latin-1", "--encoding", "utf-8"], None),
(["--encoding", "utf-8", "--encoding", "latin-1"], None),
),
)
def test_embed_multi_files_encoding(multi_files, extra_args, expected_error):
db_path, files = multi_files
runner = CliRunner(mix_stderr=False)
result = runner.invoke(
cli,
[
"embed-multi",
"files",
"-d",
db_path,
"-m",
"embed-demo",
"--files",
str(files / "nested" / "more"),
"*.ini",
"--store",
]
+ extra_args,
)
if expected_error:
# Should still exit with code 0, but print a warning to stderr
assert result.exit_code == 0
assert expected_error in result.stderr
else:
assert result.exit_code == 0
assert not result.stderr
embeddings_db = sqlite_utils.Database(db_path)
rows = list(
embeddings_db.query("select id, content from embeddings order by id")
)
assert rows == [
{"id": "ignored.ini", "content": "Has weird \x96 character"},
]
def test_default_embedding_model():
runner = CliRunner()
result = runner.invoke(cli, ["embed-models", "default"])
assert result.exit_code == 0
assert result.output == "<No default embedding model set>\n"
result2 = runner.invoke(cli, ["embed-models", "default", "ada-002"])
assert result2.exit_code == 0
result3 = runner.invoke(cli, ["embed-models", "default"])
assert result3.exit_code == 0
assert result3.output == "text-embedding-ada-002\n"
result4 = runner.invoke(cli, ["embed-models", "default", "--remove-default"])
assert result4.exit_code == 0
result5 = runner.invoke(cli, ["embed-models", "default"])
assert result5.exit_code == 0
assert result5.output == "<No default embedding model set>\n"
# Now set the default and actually use it
result6 = runner.invoke(cli, ["embed-models", "default", "embed-demo"])
assert result6.exit_code == 0
result7 = runner.invoke(cli, ["embed", "-c", "hello world"])
assert result7.exit_code == 0
assert result7.output == "[5, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]\n"
@pytest.mark.parametrize("default_is_set", (False, True))
@pytest.mark.parametrize("command", ("embed", "embed-multi"))
def test_default_embed_model_errors(user_path, default_is_set, command):
runner = CliRunner()
if default_is_set:
(user_path / "default_embedding_model.txt").write_text(
"embed-demo", encoding="utf8"
)
args = []
input = None
if command == "embed-multi":
args = ["embed-multi", "example", "-"]
input = "id,name\n1,hello"
else:
args = ["embed", "example", "1", "-c", "hello world"]
result = runner.invoke(cli, args, input=input, catch_exceptions=False)
if default_is_set:
assert result.exit_code == 0
else:
assert result.exit_code == 1
assert (
"You need to specify an embedding model (no default model is set)"
in result.output
)
# Now set the default model and try again
result2 = runner.invoke(cli, ["embed-models", "default", "embed-demo"])
assert result2.exit_code == 0
result3 = runner.invoke(cli, args, input=input, catch_exceptions=False)
assert result3.exit_code == 0
# At the end of this, there should be 1 embedding
db = sqlite_utils.Database(str(user_path / "embeddings.db"))
assert db["embeddings"].count == 1
def test_duplicate_content_embedded_only_once(embed_demo):
# content_hash should avoid embedding the same content twice
# per collection
db = sqlite_utils.Database(memory=True)
assert len(embed_demo.embedded_content) == 0
collection = Collection("test", db, model_id="embed-demo")
collection.embed("1", "hello world")
assert len(embed_demo.embedded_content) == 1
collection.embed("2", "goodbye world")
assert db["embeddings"].count == 2
assert len(embed_demo.embedded_content) == 2
collection.embed("1", "hello world")
assert db["embeddings"].count == 2
assert len(embed_demo.embedded_content) == 2
# The same string in another collection should be embedded
c2 = Collection("test2", db, model_id="embed-demo")
c2.embed("1", "hello world")
assert db["embeddings"].count == 3
assert len(embed_demo.embedded_content) == 3
# Same again for embed_multi
collection.embed_multi(
(("1", "hello world"), ("2", "goodbye world"), ("3", "this is new"))
)
# Should have only embedded one more thing
assert db["embeddings"].count == 4
assert len(embed_demo.embedded_content) == 4
</document_content>
</document>
<document index="26">
<source>/Users/simon/Dropbox/Development/llm/tests/test_encode_decode.py</source>
<document_content>
import llm
import pytest
import numpy as np
@pytest.mark.parametrize(
"array",
(
(0.0, 1.0, 1.5),
(3423.0, 222.0, -1234.5),
),
)
def test_roundtrip(array):
encoded = llm.encode(array)
decoded = llm.decode(encoded)
assert decoded == array
# Try with numpy as well
numpy_decoded = np.frombuffer(encoded, "<f4")
assert tuple(numpy_decoded.tolist()) == array
</document_content>
</document>
<document index="27">
<source>/Users/simon/Dropbox/Development/llm/tests/test_keys.py</source>
<document_content>
from click.testing import CliRunner
import json
from llm.cli import cli
import pathlib
import pytest
import sys
@pytest.mark.xfail(sys.platform == "win32", reason="Expected to fail on Windows")
@pytest.mark.parametrize("env", ({}, {"LLM_USER_PATH": "/tmp/llm-keys-test"}))
def test_keys_in_user_path(monkeypatch, env, user_path):
for key, value in env.items():
monkeypatch.setenv(key, value)
runner = CliRunner()
result = runner.invoke(cli, ["keys", "path"])
assert result.exit_code == 0
if env:
expected = env["LLM_USER_PATH"] + "/keys.json"
else:
expected = user_path + "/keys.json"
assert result.output.strip() == expected
@pytest.mark.xfail(sys.platform == "win32", reason="Expected to fail on Windows")
def test_keys_set(monkeypatch, tmpdir):
user_path = tmpdir / "user/keys"
monkeypatch.setenv("LLM_USER_PATH", str(user_path))
keys_path = user_path / "keys.json"
assert not keys_path.exists()
runner = CliRunner()
result = runner.invoke(cli, ["keys", "set", "openai"], input="foo")
assert result.exit_code == 0
assert keys_path.exists()
# Should be chmod 600
assert oct(keys_path.stat().mode)[-3:] == "600"
content = keys_path.read_text("utf-8")
assert json.loads(content) == {
"// Note": "This file stores secret API credentials. Do not share!",
"openai": "foo",
}
@pytest.mark.xfail(sys.platform == "win32", reason="Expected to fail on Windows")
def test_keys_get(monkeypatch, tmpdir):
user_path = tmpdir / "user/keys"
monkeypatch.setenv("LLM_USER_PATH", str(user_path))
runner = CliRunner()
result = runner.invoke(cli, ["keys", "set", "openai"], input="fx")
assert result.exit_code == 0
result2 = runner.invoke(cli, ["keys", "get", "openai"])
assert result2.exit_code == 0
assert result2.output.strip() == "fx"
@pytest.mark.parametrize("args", (["keys", "list"], ["keys"]))
def test_keys_list(monkeypatch, tmpdir, args):
user_path = str(tmpdir / "user/keys")
monkeypatch.setenv("LLM_USER_PATH", user_path)
runner = CliRunner()
result = runner.invoke(cli, ["keys", "set", "openai"], input="foo")
assert result.exit_code == 0
result2 = runner.invoke(cli, args)
assert result2.exit_code == 0
assert result2.output.strip() == "openai"
@pytest.mark.httpx_mock(
assert_all_requests_were_expected=False, can_send_already_matched_responses=True
)
def test_uses_correct_key(mocked_openai_chat, monkeypatch, tmpdir):
user_dir = tmpdir / "user-dir"
pathlib.Path(user_dir).mkdir()
keys_path = user_dir / "keys.json"
KEYS = {
"openai": "from-keys-file",
"other": "other-key",
}
keys_path.write_text(json.dumps(KEYS), "utf-8")
monkeypatch.setenv("LLM_USER_PATH", str(user_dir))
monkeypatch.setenv("OPENAI_API_KEY", "from-env")
def assert_key(key):
request = mocked_openai_chat.get_requests()[-1]
assert request.headers["Authorization"] == "Bearer {}".format(key)
runner = CliRunner()
# Called without --key uses stored key
result = runner.invoke(cli, ["hello", "--no-stream"], catch_exceptions=False)
assert result.exit_code == 0
assert_key("from-keys-file")
# Called without --key and without keys.json uses environment variable
keys_path.write_text("{}", "utf-8")
result2 = runner.invoke(cli, ["hello", "--no-stream"], catch_exceptions=False)
assert result2.exit_code == 0
assert_key("from-env")
keys_path.write_text(json.dumps(KEYS), "utf-8")
# Called with --key name-in-keys.json uses that value
result3 = runner.invoke(
cli, ["hello", "--key", "other", "--no-stream"], catch_exceptions=False
)
assert result3.exit_code == 0
assert_key("other-key")
# Called with --key something-else uses exactly that
result4 = runner.invoke(
cli, ["hello", "--key", "custom-key", "--no-stream"], catch_exceptions=False
)
assert result4.exit_code == 0
assert_key("custom-key")
</document_content>
</document>
<document index="28">
<source>/Users/simon/Dropbox/Development/llm/tests/test_llm.py</source>
<document_content>
from click.testing import CliRunner
import datetime
import llm
from llm.cli import cli
from llm.migrations import migrate
from llm.models import Usage
import json
import os
import pathlib
import pytest
import re
import sqlite_utils
import sys
from ulid import ULID
from unittest import mock
def test_version():
runner = CliRunner()
with runner.isolated_filesystem():
result = runner.invoke(cli, ["--version"])
assert result.exit_code == 0
assert result.output.startswith("cli, version ")
@pytest.fixture
def log_path(user_path):
log_path = str(user_path / "logs.db")
db = sqlite_utils.Database(log_path)
migrate(db)
start = datetime.datetime.now(datetime.timezone.utc)
db["responses"].insert_all(
{
"id": str(ULID()).lower(),
"system": "system",
"prompt": "prompt",
"response": 'response\n```python\nprint("hello word")\n```',
"model": "davinci",
"datetime_utc": (start + datetime.timedelta(seconds=i)).isoformat(),
"conversation_id": "abc123",
"input_tokens": 2,
"output_tokens": 5,
}
for i in range(100)
)
return log_path
datetime_re = re.compile(r"\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}")
@pytest.mark.parametrize("usage", (False, True))
def test_logs_text(log_path, usage):
runner = CliRunner()
args = ["logs", "-p", str(log_path)]
if usage:
args.append("-u")
result = runner.invoke(cli, args, catch_exceptions=False)
assert result.exit_code == 0
output = result.output
# Replace 2023-08-17T20:53:58 with YYYY-MM-DDTHH:MM:SS
output = datetime_re.sub("YYYY-MM-DDTHH:MM:SS", output)
expected = (
(
"# YYYY-MM-DDTHH:MM:SS conversation: abc123\n\n"
"Model: **davinci**\n\n"
"## Prompt:\n\n"
"prompt\n\n"
"## System:\n\n"
"system\n\n"
"## Response:\n\n"
'response\n```python\nprint("hello word")\n```\n\n'
)
+ ("## Token usage:\n\n2 input, 5 output\n\n" if usage else "")
+ (
"# YYYY-MM-DDTHH:MM:SS conversation: abc123\n\n"
"Model: **davinci**\n\n"
"## Prompt:\n\n"
"prompt\n\n"
"## Response:\n\n"
'response\n```python\nprint("hello word")\n```\n\n'
)
+ ("## Token usage:\n\n2 input, 5 output\n\n" if usage else "")
+ (
"# YYYY-MM-DDTHH:MM:SS conversation: abc123\n\n"
"Model: **davinci**\n\n"
"## Prompt:\n\n"
"prompt\n\n"
"## Response:\n\n"
'response\n```python\nprint("hello word")\n```\n\n'
)
+ ("## Token usage:\n\n2 input, 5 output\n\n" if usage else "")
)
assert output == expected
@pytest.mark.parametrize("n", (None, 0, 2))
def test_logs_json(n, log_path):
"Test that logs command correctly returns requested -n records"
runner = CliRunner()
args = ["logs", "-p", str(log_path), "--json"]
if n is not None:
args.extend(["-n", str(n)])
result = runner.invoke(cli, args, catch_exceptions=False)
assert result.exit_code == 0
logs = json.loads(result.output)
expected_length = 3
if n is not None:
if n == 0:
expected_length = 100
else:
expected_length = n
assert len(logs) == expected_length
@pytest.mark.parametrize(
"args", (["-r"], ["--response"], ["list", "-r"], ["list", "--response"])
)
def test_logs_response_only(args, log_path):
"Test that logs -r/--response returns just the last response"
runner = CliRunner()
result = runner.invoke(cli, ["logs"] + args, catch_exceptions=False)
assert result.exit_code == 0
assert result.output == 'response\n```python\nprint("hello word")\n```\n'
@pytest.mark.parametrize(
"args",
(
["-x"],
["--extract"],
["list", "-x"],
["list", "--extract"],
# Using -xr together should have same effect as just -x
["-xr"],
["-x", "-r"],
["--extract", "--response"],
),
)
def test_logs_extract_first_code(args, log_path):
"Test that logs -x/--extract returns the first code block"
runner = CliRunner()
result = runner.invoke(cli, ["logs"] + args, catch_exceptions=False)
assert result.exit_code == 0
assert result.output == 'print("hello word")\n\n'
@pytest.mark.parametrize(
"args",
(
["--xl"],
["--extract-last"],
["list", "--xl"],
["list", "--extract-last"],
["--xl", "-r"],
["-x", "--xl"],
),
)
def test_logs_extract_last_code(args, log_path):
"Test that logs --xl/--extract-last returns the last code block"
runner = CliRunner()
result = runner.invoke(cli, ["logs"] + args, catch_exceptions=False)
assert result.exit_code == 0
assert result.output == 'print("hello word")\n\n'
@pytest.mark.xfail(sys.platform == "win32", reason="Expected to fail on Windows")
@pytest.mark.parametrize("env", ({}, {"LLM_USER_PATH": "/tmp/llm-user-path"}))
def test_logs_path(monkeypatch, env, user_path):
for key, value in env.items():
monkeypatch.setenv(key, value)
runner = CliRunner()
result = runner.invoke(cli, ["logs", "path"])
assert result.exit_code == 0
if env:
expected = env["LLM_USER_PATH"] + "/logs.db"
else:
expected = str(user_path) + "/logs.db"
assert result.output.strip() == expected
@pytest.mark.parametrize("model", ("davinci", "curie"))
def test_logs_filtered(user_path, model):
log_path = str(user_path / "logs.db")
db = sqlite_utils.Database(log_path)
migrate(db)
db["responses"].insert_all(
{
"id": str(ULID()).lower(),
"system": "system",
"prompt": "prompt",
"response": "response",
"model": "davinci" if i % 2 == 0 else "curie",
}
for i in range(100)
)
runner = CliRunner()
result = runner.invoke(cli, ["logs", "list", "-m", model, "--json"])
assert result.exit_code == 0
records = json.loads(result.output.strip())
assert all(record["model"] == model for record in records)
@pytest.mark.parametrize(
"query,extra_args,expected",
(
# With no search term order should be by datetime
("", [], ["doc1", "doc2", "doc3"]),
# With a search it's order by rank instead
("llama", [], ["doc1", "doc3"]),
("alpaca", [], ["doc2"]),
# Model filter should work too
("llama", ["-m", "davinci"], ["doc1", "doc3"]),
("llama", ["-m", "davinci2"], []),
),
)
def test_logs_search(user_path, query, extra_args, expected):
log_path = str(user_path / "logs.db")
db = sqlite_utils.Database(log_path)
migrate(db)
def _insert(id, text):
db["responses"].insert(
{
"id": id,
"system": "system",
"prompt": text,
"response": "response",
"model": "davinci",
}
)
_insert("doc1", "llama")
_insert("doc2", "alpaca")
_insert("doc3", "llama llama")
runner = CliRunner()
result = runner.invoke(cli, ["logs", "list", "-q", query, "--json"] + extra_args)
assert result.exit_code == 0
records = json.loads(result.output.strip())
assert [record["id"] for record in records] == expected
def test_llm_prompt_creates_log_database(mocked_openai_chat, tmpdir, monkeypatch):
user_path = tmpdir / "user"
monkeypatch.setenv("LLM_USER_PATH", str(user_path))
runner = CliRunner()
result = runner.invoke(
cli,
["three names \nfor a pet pelican", "--no-stream", "--key", "x"],
catch_exceptions=False,
)
assert result.exit_code == 0
assert result.output == "Bob, Alice, Eve\n"
# Should have created user_path and put a logs.db in it
assert (user_path / "logs.db").exists()
assert sqlite_utils.Database(str(user_path / "logs.db"))["responses"].count == 1
@mock.patch.dict(os.environ, {"OPENAI_API_KEY": "X"})
@pytest.mark.parametrize("use_stdin", (True, False, "split"))
@pytest.mark.parametrize(
"logs_off,logs_args,should_log",
(
(True, [], False),
(False, [], True),
(False, ["--no-log"], False),
(False, ["--log"], True),
(True, ["-n"], False), # Short for --no-log
(True, ["--log"], True),
),
)
def test_llm_default_prompt(
mocked_openai_chat, use_stdin, user_path, logs_off, logs_args, should_log
):
# Reset the log_path database
log_path = user_path / "logs.db"
log_db = sqlite_utils.Database(str(log_path))
log_db["responses"].delete_where()
logs_off_path = user_path / "logs-off"
if logs_off:
# Turn off logging
assert not logs_off_path.exists()
CliRunner().invoke(cli, ["logs", "off"])
assert logs_off_path.exists()
else:
# Turn on logging
CliRunner().invoke(cli, ["logs", "on"])
assert not logs_off_path.exists()
# Run the prompt
runner = CliRunner()
prompt = "three names \nfor a pet pelican"
input = None
args = ["--no-stream"]
if use_stdin == "split":
input = "three names"
args.append("\nfor a pet pelican")
elif use_stdin:
input = prompt
else:
args.append(prompt)
args += logs_args
result = runner.invoke(cli, args, input=input, catch_exceptions=False)
assert result.exit_code == 0
assert result.output == "Bob, Alice, Eve\n"
last_request = mocked_openai_chat.get_requests()[-1]
assert last_request.headers["Authorization"] == "Bearer X"
# Was it logged?
rows = list(log_db["responses"].rows)
if not should_log:
assert len(rows) == 0
return
assert len(rows) == 1
expected = {
"model": "gpt-4o-mini",
"prompt": "three names \nfor a pet pelican",
"system": None,
"options_json": "{}",
"response": "Bob, Alice, Eve",
}
row = rows[0]
assert expected.items() <= row.items()
assert isinstance(row["duration_ms"], int)
assert isinstance(row["datetime_utc"], str)
assert json.loads(row["prompt_json"]) == {
"messages": [{"role": "user", "content": "three names \nfor a pet pelican"}]
}
assert json.loads(row["response_json"]) == {
"model": "gpt-4o-mini",
"choices": [{"message": {"content": "Bob, Alice, Eve"}}],
}
# Test "llm logs"
log_result = runner.invoke(
cli, ["logs", "-n", "1", "--json"], catch_exceptions=False
)
log_json = json.loads(log_result.output)
# Should have logged correctly:
assert (
log_json[0].items()
>= {
"model": "gpt-4o-mini",
"prompt": "three names \nfor a pet pelican",
"system": None,
"prompt_json": {
"messages": [
{"role": "user", "content": "three names \nfor a pet pelican"}
]
},
"options_json": {},
"response": "Bob, Alice, Eve",
"response_json": {
"model": "gpt-4o-mini",
"choices": [{"message": {"content": "Bob, Alice, Eve"}}],
},
# This doesn't have the \n after three names:
"conversation_name": "three names for a pet pelican",
"conversation_model": "gpt-4o-mini",
}.items()
)
@pytest.mark.parametrize(
"args,expect_just_code",
(
(["-x"], True),
(["--extract"], True),
(["-x", "--async"], True),
(["--extract", "--async"], True),
# Use --no-stream here to ensure it passes test same as -x/--extract cases
(["--no-stream"], False),
),
)
def test_extract_fenced_code(
mocked_openai_chat_returning_fenced_code, args, expect_just_code
):
runner = CliRunner()
result = runner.invoke(
cli,
["-m", "gpt-4o-mini", "--key", "x", "Write code"] + args,
catch_exceptions=False,
)
output = result.output
if expect_just_code:
assert "```" not in output
else:
assert "```" in output
def test_openai_chat_stream(mocked_openai_chat_stream, user_path):
runner = CliRunner()
result = runner.invoke(cli, ["-m", "gpt-3.5-turbo", "--key", "x", "Say hi"])
assert result.exit_code == 0
assert result.output == "Hi.\n"
def test_openai_completion(mocked_openai_completion, user_path):
log_path = user_path / "logs.db"
log_db = sqlite_utils.Database(str(log_path))
log_db["responses"].delete_where()
runner = CliRunner()
result = runner.invoke(
cli,
[
"-m",
"gpt-3.5-turbo-instruct",
"Say this is a test",
"--no-stream",
"--key",
"x",
],
catch_exceptions=False,
)
assert result.exit_code == 0
assert result.output == "\n\nThis is indeed a test\n"
# Should have requested 256 tokens
last_request = mocked_openai_completion.get_requests()[-1]
assert json.loads(last_request.content) == {
"model": "gpt-3.5-turbo-instruct",
"prompt": "Say this is a test",
"stream": False,
"max_tokens": 256,
}
# Check it was logged
rows = list(log_db["responses"].rows)
assert len(rows) == 1
expected = {
"model": "gpt-3.5-turbo-instruct",
"prompt": "Say this is a test",
"system": None,
"prompt_json": '{"messages": ["Say this is a test"]}',
"options_json": "{}",
"response": "\n\nThis is indeed a test",
}
row = rows[0]
assert expected.items() <= row.items()
def test_openai_completion_system_prompt_error():
runner = CliRunner()
result = runner.invoke(
cli,
[
"-m",
"gpt-3.5-turbo-instruct",
"Say this is a test",
"--no-stream",
"--key",
"x",
"--system",
"system prompts not allowed",
],
catch_exceptions=False,
)
assert result.exit_code == 1
assert (
result.output
== "Error: System prompts are not supported for OpenAI completion models\n"
)
def test_openai_completion_logprobs_stream(
mocked_openai_completion_logprobs_stream, user_path
):
log_path = user_path / "logs.db"
log_db = sqlite_utils.Database(str(log_path))
log_db["responses"].delete_where()
runner = CliRunner()
args = [
"-m",
"gpt-3.5-turbo-instruct",
"Say hi",
"-o",
"logprobs",
"2",
"--key",
"x",
]
result = runner.invoke(cli, args, catch_exceptions=False)
assert result.exit_code == 0
assert result.output == "\n\nHi.\n"
rows = list(log_db["responses"].rows)
assert len(rows) == 1
row = rows[0]
assert json.loads(row["response_json"]) == {
"content": "\n\nHi.",
"logprobs": [
{"text": "\n\n", "top_logprobs": [{"\n\n": -0.6, "\n": -1.9}]},
{"text": "Hi", "top_logprobs": [{"Hi": -1.1, "Hello": -0.7}]},
{"text": ".", "top_logprobs": [{".": -1.1, "!": -0.9}]},
{"text": "", "top_logprobs": []},
],
"id": "cmpl-80MdSaou7NnPuff5ZyRMysWBmgSPS",
"object": "text_completion",
"model": "gpt-3.5-turbo-instruct",
"created": 1695097702,
}
def test_openai_completion_logprobs_nostream(
mocked_openai_completion_logprobs, user_path
):
log_path = user_path / "logs.db"
log_db = sqlite_utils.Database(str(log_path))
log_db["responses"].delete_where()
runner = CliRunner()
args = [
"-m",
"gpt-3.5-turbo-instruct",
"Say hi",
"-o",
"logprobs",
"2",
"--key",
"x",
"--no-stream",
]
result = runner.invoke(cli, args, catch_exceptions=False)
assert result.exit_code == 0
assert result.output == "\n\nHi.\n"
rows = list(log_db["responses"].rows)
assert len(rows) == 1
row = rows[0]
assert json.loads(row["response_json"]) == {
"choices": [
{
"finish_reason": "stop",
"index": 0,
"logprobs": {
"text_offset": [16, 18, 20],
"token_logprobs": [-0.6, -1.1, -0.9],
"tokens": ["\n\n", "Hi", "1"],
"top_logprobs": [
{"\n": -1.9, "\n\n": -0.6},
{"Hello": -0.7, "Hi": -1.1},
{"!": -1.1, ".": -0.9},
],
},
"text": "\n\nHi.",
}
],
"created": 1695097747,
"id": "cmpl-80MeBfKJutM0uMNJkRrebJLeP3bxL",
"model": "gpt-3.5-turbo-instruct",
"object": "text_completion",
"usage": {"completion_tokens": 3, "prompt_tokens": 5, "total_tokens": 8},
}
EXTRA_MODELS_YAML = """
- model_id: orca
model_name: orca-mini-3b
api_base: "http://localai.localhost"
- model_id: completion-babbage
model_name: babbage
api_base: "http://localai.localhost"
completion: 1
"""
def test_openai_localai_configuration(mocked_localai, user_path):
log_path = user_path / "logs.db"
sqlite_utils.Database(str(log_path))
# Write the configuration file
config_path = user_path / "extra-openai-models.yaml"
config_path.write_text(EXTRA_MODELS_YAML, "utf-8")
# Run the prompt
runner = CliRunner()
prompt = "three names \nfor a pet pelican"
result = runner.invoke(cli, ["--no-stream", "--model", "orca", prompt])
assert result.exit_code == 0
assert result.output == "Bob, Alice, Eve\n"
last_request = mocked_localai.get_requests()[-1]
assert json.loads(last_request.content) == {
"model": "orca-mini-3b",
"messages": [{"role": "user", "content": "three names \nfor a pet pelican"}],
"stream": False,
}
# And check the completion model too
result2 = runner.invoke(cli, ["--no-stream", "--model", "completion-babbage", "hi"])
assert result2.exit_code == 0
assert result2.output == "Hello\n"
last_request2 = mocked_localai.get_requests()[-1]
assert json.loads(last_request2.content) == {
"model": "babbage",
"prompt": "hi",
"stream": False,
}
EXPECTED_OPTIONS = """
OpenAI Chat: gpt-4o (aliases: 4o)
Options:
temperature: float
What sampling temperature to use, between 0 and 2. Higher values like
0.8 will make the output more random, while lower values like 0.2 will
make it more focused and deterministic.
max_tokens: int
Maximum number of tokens to generate.
top_p: float
An alternative to sampling with temperature, called nucleus sampling,
where the model considers the results of the tokens with top_p
probability mass. So 0.1 means only the tokens comprising the top 10%
probability mass are considered. Recommended to use top_p or
temperature but not both.
frequency_penalty: float
Number between -2.0 and 2.0. Positive values penalize new tokens based
on their existing frequency in the text so far, decreasing the model's
likelihood to repeat the same line verbatim.
presence_penalty: float
Number between -2.0 and 2.0. Positive values penalize new tokens based
on whether they appear in the text so far, increasing the model's
likelihood to talk about new topics.
stop: str
A string where the API will stop generating further tokens.
logit_bias: dict, str
Modify the likelihood of specified tokens appearing in the completion.
Pass a JSON string like '{"1712":-100, "892":-100, "1489":-100}'
seed: int
Integer seed to attempt to sample deterministically
json_object: boolean
Output a valid JSON object {...}. Prompt must mention JSON.
Attachment types:
image/gif, image/jpeg, image/png, image/webp
"""
def test_llm_models_options(user_path):
runner = CliRunner()
result = runner.invoke(cli, ["models", "--options"], catch_exceptions=False)
assert result.exit_code == 0
assert EXPECTED_OPTIONS.strip() in result.output
assert "AsyncMockModel: mock" not in result.output
def test_llm_models_async(user_path):
runner = CliRunner()
result = runner.invoke(cli, ["models", "--async"], catch_exceptions=False)
assert result.exit_code == 0
assert "AsyncMockModel: mock" in result.output
@pytest.mark.parametrize("option", ("-q", "--query"))
def test_llm_models_query(user_path, option):
runner = CliRunner()
result = runner.invoke(cli, ["models", option, "mockmodel"], catch_exceptions=False)
assert result.exit_code == 0
assert result.output == "MockModel: mock\n"
def test_llm_user_dir(tmpdir, monkeypatch):
user_dir = str(tmpdir / "u")
monkeypatch.setenv("LLM_USER_PATH", user_dir)
assert not os.path.exists(user_dir)
user_dir2 = llm.user_dir()
assert user_dir == str(user_dir2)
assert os.path.exists(user_dir)
def test_model_defaults(tmpdir, monkeypatch):
user_dir = str(tmpdir / "u")
monkeypatch.setenv("LLM_USER_PATH", user_dir)
config_path = pathlib.Path(user_dir) / "default_model.txt"
assert not config_path.exists()
assert llm.get_default_model() == "gpt-4o-mini"
assert llm.get_model().model_id == "gpt-4o-mini"
llm.set_default_model("gpt-4o")
assert config_path.exists()
assert llm.get_default_model() == "gpt-4o"
assert llm.get_model().model_id == "gpt-4o"
def test_get_models():
models = llm.get_models()
assert all(isinstance(model, llm.Model) for model in models)
model_ids = [model.model_id for model in models]
assert "gpt-4o-mini" in model_ids
# Ensure no model_ids are duplicated
# https://github.com/simonw/llm/issues/667
assert len(model_ids) == len(set(model_ids))
def test_get_async_models():
models = llm.get_async_models()
assert all(isinstance(model, llm.AsyncModel) for model in models)
model_ids = [model.model_id for model in models]
assert "gpt-4o-mini" in model_ids
def test_mock_model(mock_model):
mock_model.enqueue(["hello world"])
mock_model.enqueue(["second"])
model = llm.get_model("mock")
response = model.prompt(prompt="hello")
assert response.text() == "hello world"
assert str(response) == "hello world"
assert model.history[0][0].prompt == "hello"
assert response.usage() == Usage(input=1, output=1, details=None)
response2 = model.prompt(prompt="hello again")
assert response2.text() == "second"
assert response2.usage() == Usage(input=2, output=1, details=None)
def test_sync_on_done(mock_model):
mock_model.enqueue(["hello world"])
model = llm.get_model("mock")
response = model.prompt(prompt="hello")
caught = []
def done(response):
caught.append(response)
response.on_done(done)
assert len(caught) == 0
str(response)
assert len(caught) == 1
</document_content>
</document>
<document index="29">
<source>/Users/simon/Dropbox/Development/llm/tests/test_migrate.py</source>
<document_content>
import llm
from llm.migrations import migrate
from llm.embeddings_migrations import embeddings_migrations
import pytest
import sqlite_utils
EXPECTED = {
"id": str,
"model": str,
"prompt": str,
"system": str,
"prompt_json": str,
"options_json": str,
"response": str,
"response_json": str,
"conversation_id": str,
"duration_ms": int,
"datetime_utc": str,
"input_tokens": int,
"output_tokens": int,
"token_details": str,
}
def test_migrate_blank():
db = sqlite_utils.Database(memory=True)
migrate(db)
assert set(db.table_names()).issuperset(
{"_llm_migrations", "conversations", "responses", "responses_fts"}
)
assert db["responses"].columns_dict == EXPECTED
foreign_keys = db["responses"].foreign_keys
for expected_fk in (
sqlite_utils.db.ForeignKey(
table="responses",
column="conversation_id",
other_table="conversations",
other_column="id",
),
):
assert expected_fk in foreign_keys
@pytest.mark.parametrize("has_record", [True, False])
def test_migrate_from_original_schema(has_record):
db = sqlite_utils.Database(memory=True)
if has_record:
db["log"].insert(
{
"provider": "provider",
"system": "system",
"prompt": "prompt",
"chat_id": None,
"response": "response",
"model": "model",
"timestamp": "timestamp",
},
)
else:
# Create empty logs table
db["log"].create(
{
"provider": str,
"system": str,
"prompt": str,
"chat_id": str,
"response": str,
"model": str,
"timestamp": str,
}
)
migrate(db)
expected_tables = {"_llm_migrations", "conversations", "responses", "responses_fts"}
if has_record:
expected_tables.add("logs")
assert set(db.table_names()).issuperset(expected_tables)
def test_migrations_with_legacy_alter_table():
# https://github.com/simonw/llm/issues/162
db = sqlite_utils.Database(memory=True)
db.execute("pragma legacy_alter_table=on")
migrate(db)
def test_migrations_for_embeddings():
db = sqlite_utils.Database(memory=True)
embeddings_migrations.apply(db)
assert db["collections"].columns_dict == {"id": int, "name": str, "model": str}
assert db["embeddings"].columns_dict == {
"collection_id": int,
"id": str,
"embedding": bytes,
"content": str,
"content_blob": bytes,
"content_hash": bytes,
"metadata": str,
"updated": int,
}
assert db["embeddings"].foreign_keys[0].column == "collection_id"
assert db["embeddings"].foreign_keys[0].other_table == "collections"
def test_backfill_content_hash():
db = sqlite_utils.Database(memory=True)
# Run migrations up to but not including m004_store_content_hash
embeddings_migrations.apply(db, stop_before="m004_store_content_hash")
assert "content_hash" not in db["embeddings"].columns_dict
# Add some rows directly, because llm.Collection would run migrations
db["embeddings"].insert_all(
[
{
"collection_id": 1,
"id": "1",
"embedding": (
b"\x00\x00\xa0@\x00\x00\xa0@\x00\x00\x00\x00\x00\x00\x00\x00"
b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"
b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"
b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"
),
"content": None,
"metadata": None,
"updated": 1693763088,
},
{
"collection_id": 1,
"id": "2",
"embedding": (
b"\x00\x00\xe0@\x00\x00\xa0@\x00\x00\x00\x00\x00\x00\x00\x00\x00"
b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"
b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"
b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"
),
"content": "goodbye world",
"metadata": None,
"updated": 1693763088,
},
]
)
# Now finish the migrations
embeddings_migrations.apply(db)
row1, row2 = db["embeddings"].rows
# This one should be random:
assert row1["content_hash"] is not None
# This should be a hash of 'goodbye world'
assert row2["content_hash"] == llm.Collection.content_hash("goodbye world")
</document_content>
</document>
<document index="30">
<source>/Users/simon/Dropbox/Development/llm/tests/test_plugins.py</source>
<document_content>
from click.testing import CliRunner
import click
import importlib
import llm
from llm import cli, hookimpl, plugins
def test_register_commands():
importlib.reload(cli)
def plugin_names():
return [plugin["name"] for plugin in llm.get_plugins()]
assert "HelloWorldPlugin" not in plugin_names()
class HelloWorldPlugin:
__name__ = "HelloWorldPlugin"
@hookimpl
def register_commands(self, cli):
@cli.command(name="hello-world")
def hello_world():
"Print hello world"
click.echo("Hello world!")
try:
plugins.pm.register(HelloWorldPlugin(), name="HelloWorldPlugin")
importlib.reload(cli)
assert "HelloWorldPlugin" in plugin_names()
runner = CliRunner()
result = runner.invoke(cli.cli, ["hello-world"])
assert result.exit_code == 0
assert result.output == "Hello world!\n"
finally:
plugins.pm.unregister(name="HelloWorldPlugin")
importlib.reload(cli)
assert "HelloWorldPlugin" not in plugin_names()
</document_content>
</document>
<document index="31">
<source>/Users/simon/Dropbox/Development/llm/tests/test_templates.py</source>
<document_content>
from click.testing import CliRunner
import json
from llm import Template
from llm.cli import cli
import os
from unittest import mock
import pytest
import yaml
@pytest.mark.parametrize(
"prompt,system,defaults,params,expected_prompt,expected_system,expected_error",
(
("S: $input", None, None, {}, "S: input", None, None),
("S: $input", "system", None, {}, "S: input", "system", None),
("No vars", None, None, {}, "No vars", None, None),
("$one and $two", None, None, {}, None, None, "Missing variables: one, two"),
("$one and $two", None, None, {"one": 1, "two": 2}, "1 and 2", None, None),
("$one and $two", None, {"one": 1}, {"two": 2}, "1 and 2", None, None),
(
"$one and $two",
None,
{"one": 99},
{"one": 1, "two": 2},
"1 and 2",
None,
None,
),
),
)
def test_template_evaluate(
prompt, system, defaults, params, expected_prompt, expected_system, expected_error
):
t = Template(name="t", prompt=prompt, system=system, defaults=defaults)
if expected_error:
with pytest.raises(Template.MissingVariables) as ex:
prompt, system = t.evaluate("input", params)
assert ex.value.args[0] == expected_error
else:
prompt, system = t.evaluate("input", params)
assert prompt == expected_prompt
assert system == expected_system
def test_templates_list_no_templates_found():
runner = CliRunner()
result = runner.invoke(cli, ["templates", "list"])
assert result.exit_code == 0
assert result.output == ""
@pytest.mark.parametrize("args", (["templates", "list"], ["templates"]))
def test_templates_list(templates_path, args):
(templates_path / "one.yaml").write_text("template one", "utf-8")
(templates_path / "two.yaml").write_text("template two", "utf-8")
(templates_path / "three.yaml").write_text(
"template three is very long " * 4, "utf-8"
)
(templates_path / "four.yaml").write_text(
"'this one\n\nhas newlines in it'", "utf-8"
)
(templates_path / "both.yaml").write_text(
"system: summarize this\nprompt: $input", "utf-8"
)
(templates_path / "sys.yaml").write_text("system: Summarize this", "utf-8")
runner = CliRunner()
result = runner.invoke(cli, args)
assert result.exit_code == 0
assert result.output == (
"both : system: summarize this prompt: $input\n"
"four : this one has newlines in it\n"
"one : template one\n"
"sys : system: Summarize this\n"
"three : template three is very long template three is very long template thre...\n"
"two : template two\n"
)
@pytest.mark.parametrize(
"args,expected_prompt,expected_error",
(
(["-m", "gpt4", "hello"], {"model": "gpt-4", "prompt": "hello"}, None),
(["hello $foo"], {"prompt": "hello $foo"}, None),
(["--system", "system"], {"system": "system"}, None),
(["-t", "template"], None, "--save cannot be used with --template"),
(["--continue"], None, "--save cannot be used with --continue"),
(["--cid", "123"], None, "--save cannot be used with --cid"),
(["--conversation", "123"], None, "--save cannot be used with --cid"),
(
["Say hello as $name", "-p", "name", "default-name"],
{"prompt": "Say hello as $name", "defaults": {"name": "default-name"}},
None,
),
# -x/--extract should be persisted:
(
["--system", "write python", "--extract"],
{"system": "write python", "extract": True},
None,
),
),
)
def test_templates_prompt_save(templates_path, args, expected_prompt, expected_error):
assert not (templates_path / "saved.yaml").exists()
runner = CliRunner()
result = runner.invoke(cli, args + ["--save", "saved"], catch_exceptions=False)
if not expected_error:
assert result.exit_code == 0
assert (
yaml.safe_load((templates_path / "saved.yaml").read_text("utf-8"))
== expected_prompt
)
else:
assert result.exit_code == 1
assert expected_error in result.output
@mock.patch.dict(os.environ, {"OPENAI_API_KEY": "X"})
@pytest.mark.parametrize(
"template,extra_args,expected_model,expected_input,expected_error",
(
(
"'Summarize this: $input'",
[],
"gpt-4o-mini",
"Summarize this: Input text",
None,
),
(
"prompt: 'Summarize this: $input'\nmodel: gpt-4",
[],
"gpt-4",
"Summarize this: Input text",
None,
),
(
"prompt: 'Summarize this: $input'",
["-m", "4"],
"gpt-4",
"Summarize this: Input text",
None,
),
pytest.param(
"boo",
["-s", "s"],
None,
None,
"Error: Cannot use -t/--template and --system together",
marks=pytest.mark.httpx_mock(),
),
pytest.param(
"prompt: 'Say $hello'",
[],
None,
None,
"Error: Missing variables: hello",
marks=pytest.mark.httpx_mock(),
),
(
"prompt: 'Say $hello'",
["-p", "hello", "Blah"],
"gpt-4o-mini",
"Say Blah",
None,
),
),
)
def test_template_basic(
templates_path,
mocked_openai_chat,
template,
extra_args,
expected_model,
expected_input,
expected_error,
):
(templates_path / "template.yaml").write_text(template, "utf-8")
runner = CliRunner()
result = runner.invoke(
cli,
["--no-stream", "-t", "template", "Input text"] + extra_args,
catch_exceptions=False,
)
if expected_error is None:
assert result.exit_code == 0
last_request = mocked_openai_chat.get_requests()[-1]
assert json.loads(last_request.content) == {
"model": expected_model,
"messages": [{"role": "user", "content": expected_input}],
"stream": False,
}
else:
assert result.exit_code == 1
assert result.output.strip() == expected_error
mocked_openai_chat.reset()
</document_content>
</document>
<document index="32">
<source>/Users/simon/Dropbox/Development/llm/tests/test_utils.py</source>
<document_content>
import pytest
from llm.utils import simplify_usage_dict, extract_fenced_code_block
@pytest.mark.parametrize(
"input_data,expected_output",
[
(
{
"prompt_tokens_details": {"cached_tokens": 0, "audio_tokens": 0},
"completion_tokens_details": {
"reasoning_tokens": 0,
"audio_tokens": 1,
"accepted_prediction_tokens": 0,
"rejected_prediction_tokens": 0,
},
},
{"completion_tokens_details": {"audio_tokens": 1}},
),
(
{
"details": {"tokens": 5, "audio_tokens": 2},
"more_details": {"accepted_tokens": 3},
},
{
"details": {"tokens": 5, "audio_tokens": 2},
"more_details": {"accepted_tokens": 3},
},
),
({"details": {"tokens": 0, "audio_tokens": 0}, "more_details": {}}, {}),
({"level1": {"level2": {"value": 0, "another_value": {}}}}, {}),
(
{
"level1": {"level2": {"value": 0, "another_value": 1}},
"level3": {"empty_dict": {}, "valid_token": 10},
},
{"level1": {"level2": {"another_value": 1}}, "level3": {"valid_token": 10}},
),
],
)
def test_simplify_usage_dict(input_data, expected_output):
assert simplify_usage_dict(input_data) == expected_output
@pytest.mark.parametrize(
"input,last,expected",
[
["This is a sample text without any code blocks.", False, None],
[
"Here is some text.\n\n```\ndef foo():\n return 'bar'\n```\n\nMore text.",
False,
"def foo():\n return 'bar'\n",
],
[
"Here is some text.\n\n```python\ndef foo():\n return 'bar'\n```\n\nMore text.",
False,
"def foo():\n return 'bar'\n",
],
[
"Here is some text.\n\n````\ndef foo():\n return 'bar'\n````\n\nMore text.",
False,
"def foo():\n return 'bar'\n",
],
[
"Here is some text.\n\n````javascript\nfunction foo() {\n return 'bar';\n}\n````\n\nMore text.",
False,
"function foo() {\n return 'bar';\n}\n",
],
[
"Here is some text.\n\n```python\ndef foo():\n return 'bar'\n````\n\nMore text.",
False,
None,
],
[
"First code block:\n\n```python\ndef foo():\n return 'bar'\n```\n\n"
"Second code block:\n\n```javascript\nfunction foo() {\n return 'bar';\n}\n```",
False,
"def foo():\n return 'bar'\n",
],
[
"First code block:\n\n```python\ndef foo():\n return 'bar'\n```\n\n"
"Second code block:\n\n```javascript\nfunction foo() {\n return 'bar';\n}\n```",
True,
"function foo() {\n return 'bar';\n}\n",
],
[
"First code block:\n\n```python\ndef foo():\n return 'bar'\n```\n\n"
# This one has trailing whitespace after the second code block:
# https://github.com/simonw/llm/pull/718#issuecomment-2613177036
"Second code block:\n\n```javascript\nfunction foo() {\n return 'bar';\n}\n``` ",
True,
"function foo() {\n return 'bar';\n}\n",
],
[
"Here is some text.\n\n```python\ndef foo():\n return `bar`\n```\n\nMore text.",
False,
"def foo():\n return `bar`\n",
],
],
)
def test_extract_fenced_code_block(input, last, expected):
actual = extract_fenced_code_block(input, last=last)
assert actual == expected
</document_content>
</document>
<document index="33">
<source>/Users/simon/Dropbox/Development/llm/docs/conf.py</source>
<document_content>
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from subprocess import PIPE, Popen
# This file is execfile()d with the current directory set to its
# containing dir.
#
# Note that not all possible configuration values are present in this
# autogenerated file.
#
# All configuration values have a default; values that are commented out
# serve to show the default.
# If extensions (or modules to document with autodoc) are in another directory,
# add these directories to sys.path here. If the directory is relative to the
# documentation root, use os.path.abspath to make it absolute, like shown here.
#
# import os
# import sys
# sys.path.insert(0, os.path.abspath('.'))
# -- General configuration ------------------------------------------------
# If your documentation needs a minimal Sphinx version, state it here.
#
# needs_sphinx = '1.0'
# Add any Sphinx extension module names here, as strings. They can be
# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
# ones.
extensions = ["myst_parser", "sphinx_copybutton"]
myst_enable_extensions = ["colon_fence"]
# Add any paths that contain templates here, relative to this directory.
templates_path = ["_templates"]
# The suffix(es) of source filenames.
# You can specify multiple suffix as a list of string:
#
# source_suffix = ['.rst', '.md']
source_suffix = ".rst"
# The master toctree document.
master_doc = "index"
# General information about the project.
project = "LLM"
copyright = "2023, Simon Willison"
author = "Simon Willison"
# The version info for the project you're documenting, acts as replacement for
# |version| and |release|, also used in various other places throughout the
# built documents.
#
# The short X.Y version.
pipe = Popen("git describe --tags --always", stdout=PIPE, shell=True)
git_version = pipe.stdout.read().decode("utf8")
if git_version:
version = git_version.rsplit("-", 1)[0]
release = git_version
else:
version = ""
release = ""
# The language for content autogenerated by Sphinx. Refer to documentation
# for a list of supported languages.
#
# This is also used if you do content translation via gettext catalogs.
# Usually you set "language" from the command line for these cases.
language = "en"
# List of patterns, relative to source directory, that match files and
# directories to ignore when looking for source files.
# This patterns also effect to html_static_path and html_extra_path
exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"]
# The name of the Pygments (syntax highlighting) style to use.
pygments_style = "sphinx"
# If true, `todo` and `todoList` produce output, else they produce nothing.
todo_include_todos = False
# -- Options for HTML output ----------------------------------------------
# The theme to use for HTML and HTML Help pages. See the documentation for
# a list of builtin themes.
#
html_theme = "furo"
# Theme options are theme-specific and customize the look and feel of a theme
# further. For a list of options available for each theme, see the
# documentation.
html_theme_options = {}
html_title = "LLM"
# Add any paths that contain custom static files (such as style sheets) here,
# relative to this directory. They are copied after the builtin static files,
# so a file named "default.css" will overwrite the builtin "default.css".
html_static_path = []
# -- Options for HTMLHelp output ------------------------------------------
# Output file base name for HTML help builder.
htmlhelp_basename = "llm-doc"
# -- Options for LaTeX output ---------------------------------------------
latex_elements = {
# The paper size ('letterpaper' or 'a4paper').
#
# 'papersize': 'letterpaper',
# The font size ('10pt', '11pt' or '12pt').
#
# 'pointsize': '10pt',
# Additional stuff for the LaTeX preamble.
#
# 'preamble': '',
# Latex figure (float) alignment
#
# 'figure_align': 'htbp',
}
# Grouping the document tree into LaTeX files. List of tuples
# (source start file, target name, title,
# author, documentclass [howto, manual, or own class]).
latex_documents = [
(
master_doc,
"llm.tex",
"LLM documentation",
"Simon Willison",
"manual",
)
]
# -- Options for manual page output ---------------------------------------
# One entry per manual page. List of tuples
# (source start file, name, description, authors, manual section).
man_pages = [
(
master_doc,
"llm",
"LLM documentation",
[author],
1,
)
]
# -- Options for Texinfo output -------------------------------------------
# Grouping the document tree into Texinfo files. List of tuples
# (source start file, target name, title, author,
# dir menu entry, description, category)
texinfo_documents = [
(
master_doc,
"llm",
"LLM documentation",
author,
"llm",
" Access large language models from the command-line ",
"Miscellaneous",
)
]
</document_content>
</document>
<document index="34">
<source>/Users/simon/Dropbox/Development/llm/docs/plugins/llm-markov/llm_markov.py</source>
<document_content>
import llm
import random
import time
from typing import Optional
from pydantic import field_validator, Field
@llm.hookimpl
def register_models(register):
register(Markov())
def build_markov_table(text):
words = text.split()
transitions = {}
# Loop through all but the last word
for i in range(len(words) - 1):
word = words[i]
next_word = words[i + 1]
transitions.setdefault(word, []).append(next_word)
return transitions
def generate(transitions, length, start_word=None):
all_words = list(transitions.keys())
next_word = start_word or random.choice(all_words)
for i in range(length):
yield next_word
options = transitions.get(next_word) or all_words
next_word = random.choice(options)
class Markov(llm.Model):
model_id = "markov"
can_stream = True
class Options(llm.Options):
length: Optional[int] = Field(
description="Number of words to generate", default=None
)
delay: Optional[float] = Field(
description="Seconds to delay between each token", default=None
)
@field_validator("length")
def validate_length(cls, length):
if length is None:
return None
if length < 2:
raise ValueError("length must be >= 2")
return length
@field_validator("delay")
def validate_delay(cls, delay):
if delay is None:
return None
if not 0 <= delay <= 10:
raise ValueError("delay must be between 0 and 10")
return delay
def execute(self, prompt, stream, response, conversation):
text = prompt.prompt
transitions = build_markov_table(text)
length = prompt.options.length or 20
for word in generate(transitions, length):
yield word + " "
if prompt.options.delay:
time.sleep(prompt.options.delay)
</document_content>
</document>
</documents>
describe this codebase in detail
This is a Python codebase that extends the capabilities of `llm`, an existing large language model (LLM) command-line interface. The main focus here appears to be adding a new feature called "Markov Model" for generating text based on user-provided input data using Markov chains.

Here's a detailed breakdown:
1. **Main Modules:**
   - `llm`: This seems to import the base framework (`llm`) that this codebase is extending.
   - `random`, `time`, and `typing`: These are standard Python libraries used for generating random choices, handling time delays, and type hinting respectively.
2. **Markov Chain Generation** (a usage sketch follows this list):
   - The `build_markov_table` function takes raw text input and builds a dictionary of word transitions, where each key is a word from the document and its value is a list of words that follow it.
     ```python
     def build_markov_table(text):
         words = text.split()
         transitions = {}
         for i in range(len(words) - 1):
             word = words[i]
             next_word = words[i + 1]
             transitions.setdefault(word, []).append(next_word)
         return transitions
     ```
   - The `generate` function takes a transition table and a specified length to generate a sequence of words. It uses the provided start word, or randomly selects one from the available keys if none is provided.
     ```python
     def generate(transitions, length, start_word=None):
         all_words = list(transitions.keys())
         next_word = start_word or random.choice(all_words)
         for i in range(length):
             yield next_word
             options = transitions.get(next_word) or all_words
             next_word = random.choice(options)
     ```
3. **Markov Model Implementation:**
   - The `Markov` class inherits from `llm.Model`, suggesting it's part of a larger framework for handling different types of models.
     ```python
     @llm.hookimpl
     def register_models(register):
         register(Markov())
     ```
     This decorator is likely used by the `llm` framework to automatically register this model when initialized.
   - The class contains an `Options` inner class for defining configuration parameters specific to the Markov model. These include:
     - `length`: Number of words to generate.
     - `delay`: Seconds between each word generation, with a validation that it must be within 0 and 10 seconds.
     ```python
     @field_validator("length")
     def validate_length(cls, length):
         if length is None:
             return None
         if length < 2:
             raise ValueError("length must be >= 2")
         return length

     @field_validator("delay")
     def validate_delay(cls, delay):
         if delay is None:
             return None
         if not 0 <= delay <= 10:
             raise ValueError("delay must be between 0 and 10")
         return delay
     ```
   - The `execute` method processes the provided prompt and generates text using the Markov model. It first builds a transition table from the user-provided input (`prompt.prompt`) and then starts generating words according to the specified length and delay.
     ```python
     def execute(self, prompt, stream, response, conversation):
         text = prompt.prompt
         transitions = build_markov_table(text)
         length = prompt.options.length or 20
         for word in generate(transitions, length):
             yield word + " "
             if prompt.options.delay:
                 time.sleep(prompt.options.delay)
     ```
   - The `register_models` function from the framework is used to automatically register this model when the application initializes.
     ```python
     @llm.hookimpl
     def register_models(register):
         register(Markov())
     ```
     This implies that there's a parent or main module (`llm`) which provides hooks for extending its capabilities, and this script registers itself as one of those extensions.
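To make the table-then-walk flow above concrete, here is a small standalone sketch (not from the repository) that wires `build_markov_table` and `generate` together the way `Markov.execute` does; the sample sentence and the fixed random seed are illustrative assumptions only.

```python
import random


def build_markov_table(text):
    # word -> list of words observed immediately after it
    words = text.split()
    transitions = {}
    for i in range(len(words) - 1):
        transitions.setdefault(words[i], []).append(words[i + 1])
    return transitions


def generate(transitions, length, start_word=None):
    all_words = list(transitions.keys())
    next_word = start_word or random.choice(all_words)
    for _ in range(length):
        yield next_word
        # Fall back to any known word if next_word has no recorded successors
        options = transitions.get(next_word) or all_words
        next_word = random.choice(options)


random.seed(0)  # deterministic output for this example only (assumption, not in the plugin)
table = build_markov_table("the cat sat on the mat and the cat ran")
print(" ".join(generate(table, 8)))
```

Because words with no recorded successors fall back to `all_words`, the walk never raises a `KeyError`, even when it reaches a word that only appears at the end of the input.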
The code makes use of `pydantic` for data validation. Pydantic is used to ensure that the options provided (like `length` and `delay`) are within acceptable ranges before they're processed by the model.

```python
from pydantic import field_validator, Field
```
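As a minimal, self-contained sketch of that validation pattern, the hypothetical `DemoOptions` model below (not part of the codebase) uses the same `field_validator` style to reject out-of-range values at construction time:

```python
from typing import Optional
from pydantic import BaseModel, Field, ValidationError, field_validator


class DemoOptions(BaseModel):
    # Hypothetical stand-in for Markov.Options, using the same validator style
    length: Optional[int] = Field(default=None, description="Number of words to generate")
    delay: Optional[float] = Field(default=None, description="Seconds to delay between each token")

    @field_validator("delay")
    def validate_delay(cls, delay):
        if delay is None:
            return None
        if not 0 <= delay <= 10:
            raise ValueError("delay must be between 0 and 10")
        return delay


print(DemoOptions(length=20, delay=0.1))  # accepted
try:
    DemoOptions(delay=99)
except ValidationError as ex:
    print(ex)  # reports: delay must be between 0 and 10
```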
In summary, this codebase provides an extension to a larger LLM framework (`llm`) by adding functionality for generating text using Markov chains. It allows users to specify parameters such as how many words to generate and the delay between each word, offering a controlled way to explore or analyze sequences generated from provided input texts.

The design leverages Python's standard libraries alongside additional utilities like Pydantic for enhanced readability and maintainability.
2,048 input, 999 output