{"id": "setup.py:7", "code": "def get_long_description():\n with open(\n os.path.join(os.path.dirname(os.path.abspath(__file__)), \"README.md\"),\n encoding=\"utf8\",\n ) as fp:\n return fp.read()"} | |
{"id": "llm/plugins.py:18", "code": "def load_plugins():\n global _loaded\n if _loaded:\n return\n _loaded = True\n if not hasattr(sys, \"_called_from_test\") and LLM_LOAD_PLUGINS is None:\n # Only load plugins if not running tests\n pm.load_setuptools_entrypoints(\"llm\")\n\n # Load any plugins specified in LLM_LOAD_PLUGINS\")\n if LLM_LOAD_PLUGINS is not None:\n for package_name in [\n name for name in LLM_LOAD_PLUGINS.split(\",\") if name.strip()\n ]:\n try:\n distribution = metadata.distribution(package_name) # Updated call\n llm_entry_points = [\n ep for ep in distribution.entry_points if ep.group == \"llm\"\n ]\n for entry_point in llm_entry_points:\n mod = entry_point.load()\n pm.register(mod, name=entry_point.name)\n # Ensure name can be found in plugin_to_distinfo later:\n pm._plugin_distinfo.append((mod, distribution)) # type: ignore\n except metadata.PackageNotFoundError:\n sys.stderr.write(f\"Plugin {package_name} could not be found\\n\")\n\n for plugin in DEFAULT_PLUGINS:\n mod = importlib.import_module(plugin)\n pm.register(mod, plugin)"} | |
{"id": "llm/embeddings_migrations.py:8", "code": "@embeddings_migrations()\ndef m001_create_tables(db):\n db[\"collections\"].create({\"id\": int, \"name\": str, \"model\": str}, pk=\"id\")\n db[\"collections\"].create_index([\"name\"], unique=True)\n db[\"embeddings\"].create(\n {\n \"collection_id\": int,\n \"id\": str,\n \"embedding\": bytes,\n \"content\": str,\n \"metadata\": str,\n },\n pk=(\"collection_id\", \"id\"),\n )"} | |
{"id": "llm/embeddings_migrations.py:24", "code": "@embeddings_migrations()\ndef m002_foreign_key(db):\n db[\"embeddings\"].add_foreign_key(\"collection_id\", \"collections\", \"id\")"} | |
{"id": "llm/embeddings_migrations.py:29", "code": "@embeddings_migrations()\ndef m003_add_updated(db):\n db[\"embeddings\"].add_column(\"updated\", int)\n # Pretty-print the schema\n db[\"embeddings\"].transform()\n # Assume anything existing was last updated right now\n db.query(\n \"update embeddings set updated = ? where updated is null\", [int(time.time())]\n )"} | |
{"id": "llm/embeddings_migrations.py:40", "code": "@embeddings_migrations()\ndef m004_store_content_hash(db):\n db[\"embeddings\"].add_column(\"content_hash\", bytes)\n db[\"embeddings\"].transform(\n column_order=(\n \"collection_id\",\n \"id\",\n \"embedding\",\n \"content\",\n \"content_hash\",\n \"metadata\",\n \"updated\",\n )\n )\n\n # Register functions manually so we can de-register later\n def md5(text):\n return hashlib.md5(text.encode(\"utf8\")).digest()\n\n def random_md5():\n return hashlib.md5(str(time.time()).encode(\"utf8\")).digest()\n\n db.conn.create_function(\"temp_md5\", 1, md5)\n db.conn.create_function(\"temp_random_md5\", 0, random_md5)\n\n with db.conn:\n db.execute(\n \"\"\"\n update embeddings\n set content_hash = temp_md5(content)\n where content is not null\n \"\"\"\n )\n db.execute(\n \"\"\"\n update embeddings\n set content_hash = temp_random_md5()\n where content is null\n \"\"\"\n )\n\n db[\"embeddings\"].create_index([\"content_hash\"])\n\n # De-register functions\n db.conn.create_function(\"temp_md5\", 1, None)\n db.conn.create_function(\"temp_random_md5\", 0, None)"} | |
{"id": "llm/embeddings_migrations.py:88", "code": "@embeddings_migrations()\ndef m005_add_content_blob(db):\n db[\"embeddings\"].add_column(\"content_blob\", bytes)\n db[\"embeddings\"].transform(\n column_order=(\"collection_id\", \"id\", \"embedding\", \"content\", \"content_blob\")\n )"} | |
{"id": "llm/hookspecs.py:8", "code": "@hookspec\ndef register_commands(cli):\n \"\"\"Register additional CLI commands, e.g. 'llm mycommand ...'\"\"\""} | |
{"id": "llm/hookspecs.py:13", "code": "@hookspec\ndef register_models(register):\n \"Register additional model instances representing LLM models that can be called\""} | |
{"id": "llm/hookspecs.py:18", "code": "@hookspec\ndef register_embedding_models(register):\n \"Register additional model instances that can be used for embedding\""} | |
{"id": "llm/hookspecs.py:23", "code": "@hookspec\ndef register_template_loaders(register):\n \"Register additional template loaders with prefixes\""} | |
{"id": "llm/models.py:37", "code": "@dataclass\nclass Usage:\n input: Optional[int] = None\n output: Optional[int] = None\n details: Optional[Dict[str, Any]] = None"} | |
{"id": "llm/models.py:44", "code": "@dataclass\nclass Attachment:\n type: Optional[str] = None\n path: Optional[str] = None\n url: Optional[str] = None\n content: Optional[bytes] = None\n _id: Optional[str] = None\n\n def id(self):\n # Hash of the binary content, or of '{\"url\": \"https://...\"}' for URL attachments\n if self._id is None:\n if self.content:\n self._id = hashlib.sha256(self.content).hexdigest()\n elif self.path:\n self._id = hashlib.sha256(open(self.path, \"rb\").read()).hexdigest()\n else:\n self._id = hashlib.sha256(\n json.dumps({\"url\": self.url}).encode(\"utf-8\")\n ).hexdigest()\n return self._id\n\n def resolve_type(self):\n if self.type:\n return self.type\n # Derive it from path or url or content\n if self.path:\n return mimetype_from_path(self.path)\n if self.url:\n response = httpx.head(self.url)\n response.raise_for_status()\n return response.headers.get(\"content-type\")\n if self.content:\n return mimetype_from_string(self.content)\n raise ValueError(\"Attachment has no type and no content to derive it from\")\n\n def content_bytes(self):\n content = self.content\n if not content:\n if self.path:\n content = open(self.path, \"rb\").read()\n elif self.url:\n response = httpx.get(self.url)\n response.raise_for_status()\n content = response.content\n return content\n\n def base64_content(self):\n return base64.b64encode(self.content_bytes()).decode(\"utf-8\")\n\n @classmethod\n def from_row(cls, row):\n return cls(\n _id=row[\"id\"],\n type=row[\"type\"],\n path=row[\"path\"],\n url=row[\"url\"],\n content=row[\"content\"],\n )"} | |
{"id": "llm/models.py:52", "code": " def id(self):\n # Hash of the binary content, or of '{\"url\": \"https://...\"}' for URL attachments\n if self._id is None:\n if self.content:\n self._id = hashlib.sha256(self.content).hexdigest()\n elif self.path:\n self._id = hashlib.sha256(open(self.path, \"rb\").read()).hexdigest()\n else:\n self._id = hashlib.sha256(\n json.dumps({\"url\": self.url}).encode(\"utf-8\")\n ).hexdigest()\n return self._id"} | |
{"id": "llm/models.py:65", "code": " def resolve_type(self):\n if self.type:\n return self.type\n # Derive it from path or url or content\n if self.path:\n return mimetype_from_path(self.path)\n if self.url:\n response = httpx.head(self.url)\n response.raise_for_status()\n return response.headers.get(\"content-type\")\n if self.content:\n return mimetype_from_string(self.content)\n raise ValueError(\"Attachment has no type and no content to derive it from\")"} | |
{"id": "llm/models.py:79", "code": " def content_bytes(self):\n content = self.content\n if not content:\n if self.path:\n content = open(self.path, \"rb\").read()\n elif self.url:\n response = httpx.get(self.url)\n response.raise_for_status()\n content = response.content\n return content"} | |
{"id": "llm/models.py:90", "code": " def base64_content(self):\n return base64.b64encode(self.content_bytes()).decode(\"utf-8\")"} | |
{"id": "llm/models.py:93", "code": " @classmethod\n def from_row(cls, row):\n return cls(\n _id=row[\"id\"],\n type=row[\"type\"],\n path=row[\"path\"],\n url=row[\"url\"],\n content=row[\"content\"],\n )"} | |
{"id": "llm/models.py:104", "code": "@dataclass\nclass Prompt:\n prompt: Optional[str]\n model: \"Model\"\n attachments: Optional[List[Attachment]]\n system: Optional[str]\n prompt_json: Optional[str]\n schema: Optional[Union[Dict, type[BaseModel]]]\n options: \"Options\"\n\n def __init__(\n self,\n prompt,\n model,\n *,\n attachments=None,\n system=None,\n prompt_json=None,\n options=None,\n schema=None,\n ):\n self.prompt = prompt\n self.model = model\n self.attachments = list(attachments or [])\n self.system = system\n self.prompt_json = prompt_json\n if schema and not isinstance(schema, dict) and issubclass(schema, BaseModel):\n schema = schema.model_json_schema()\n self.schema = schema\n self.options = options or {}"} | |
{"id": "llm/models.py:114", "code": " def __init__(\n self,\n prompt,\n model,\n *,\n attachments=None,\n system=None,\n prompt_json=None,\n options=None,\n schema=None,\n ):\n self.prompt = prompt\n self.model = model\n self.attachments = list(attachments or [])\n self.system = system\n self.prompt_json = prompt_json\n if schema and not isinstance(schema, dict) and issubclass(schema, BaseModel):\n schema = schema.model_json_schema()\n self.schema = schema\n self.options = options or {}"} | |
{"id": "llm/models.py:136", "code": "@dataclass\nclass _BaseConversation:\n model: \"_BaseModel\"\n id: str = field(default_factory=lambda: str(ULID()).lower())\n name: Optional[str] = None\n responses: List[\"_BaseResponse\"] = field(default_factory=list)\n\n @classmethod\n @abstractmethod\n def from_row(cls, row: Any) -> \"_BaseConversation\":\n raise NotImplementedError"} | |
{"id": "llm/models.py:143", "code": " @classmethod\n @abstractmethod\n def from_row(cls, row: Any) -> \"_BaseConversation\":\n raise NotImplementedError"} | |
{"id": "llm/models.py:149", "code": "@dataclass\nclass Conversation(_BaseConversation):\n def prompt(\n self,\n prompt: Optional[str] = None,\n *,\n attachments: Optional[List[Attachment]] = None,\n system: Optional[str] = None,\n schema: Optional[Union[dict, type[BaseModel]]] = None,\n stream: bool = True,\n key: Optional[str] = None,\n **options,\n ) -> \"Response\":\n return Response(\n Prompt(\n prompt,\n model=self.model,\n attachments=attachments,\n system=system,\n schema=schema,\n options=self.model.Options(**options),\n ),\n self.model,\n stream,\n conversation=self,\n key=key,\n )\n\n @classmethod\n def from_row(cls, row):\n from llm import get_model\n\n return cls(\n model=get_model(row[\"model\"]),\n id=row[\"id\"],\n name=row[\"name\"],\n )\n\n def __repr__(self):\n count = len(self.responses)\n s = \"s\" if count == 1 else \"\"\n return f\"<{self.__class__.__name__}: {self.id} - {count} response{s}\""} | |
{"id": "llm/models.py:151", "code": " def prompt(\n self,\n prompt: Optional[str] = None,\n *,\n attachments: Optional[List[Attachment]] = None,\n system: Optional[str] = None,\n schema: Optional[Union[dict, type[BaseModel]]] = None,\n stream: bool = True,\n key: Optional[str] = None,\n **options,\n ) -> \"Response\":\n return Response(\n Prompt(\n prompt,\n model=self.model,\n attachments=attachments,\n system=system,\n schema=schema,\n options=self.model.Options(**options),\n ),\n self.model,\n stream,\n conversation=self,\n key=key,\n )"} | |
{"id": "llm/models.py:177", "code": " @classmethod\n def from_row(cls, row):\n from llm import get_model\n\n return cls(\n model=get_model(row[\"model\"]),\n id=row[\"id\"],\n name=row[\"name\"],\n )"} | |
{"id": "llm/models.py:187", "code": " def __repr__(self):\n count = len(self.responses)\n s = \"s\" if count == 1 else \"\"\n return f\"<{self.__class__.__name__}: {self.id} - {count} response{s}\""} | |
{"id": "llm/models.py:193", "code": "@dataclass\nclass AsyncConversation(_BaseConversation):\n def prompt(\n self,\n prompt: Optional[str] = None,\n *,\n attachments: Optional[List[Attachment]] = None,\n system: Optional[str] = None,\n schema: Optional[Union[dict, type[BaseModel]]] = None,\n stream: bool = True,\n key: Optional[str] = None,\n **options,\n ) -> \"AsyncResponse\":\n return AsyncResponse(\n Prompt(\n prompt,\n model=self.model,\n attachments=attachments,\n system=system,\n schema=schema,\n options=self.model.Options(**options),\n ),\n self.model,\n stream,\n conversation=self,\n key=key,\n )\n\n @classmethod\n def from_row(cls, row):\n from llm import get_async_model\n\n return cls(\n model=get_async_model(row[\"model\"]),\n id=row[\"id\"],\n name=row[\"name\"],\n )\n\n def __repr__(self):\n count = len(self.responses)\n s = \"s\" if count == 1 else \"\"\n return f\"<{self.__class__.__name__}: {self.id} - {count} response{s}\""} | |
{"id": "llm/models.py:195", "code": " def prompt(\n self,\n prompt: Optional[str] = None,\n *,\n attachments: Optional[List[Attachment]] = None,\n system: Optional[str] = None,\n schema: Optional[Union[dict, type[BaseModel]]] = None,\n stream: bool = True,\n key: Optional[str] = None,\n **options,\n ) -> \"AsyncResponse\":\n return AsyncResponse(\n Prompt(\n prompt,\n model=self.model,\n attachments=attachments,\n system=system,\n schema=schema,\n options=self.model.Options(**options),\n ),\n self.model,\n stream,\n conversation=self,\n key=key,\n )"} | |
{"id": "llm/models.py:221", "code": " @classmethod\n def from_row(cls, row):\n from llm import get_async_model\n\n return cls(\n model=get_async_model(row[\"model\"]),\n id=row[\"id\"],\n name=row[\"name\"],\n )"} | |
{"id": "llm/models.py:231", "code": " def __repr__(self):\n count = len(self.responses)\n s = \"s\" if count == 1 else \"\"\n return f\"<{self.__class__.__name__}: {self.id} - {count} response{s}\""} | |
{"id": "llm/models.py:237", "code": "class Annotation(BaseModel):\n start_index: int\n end_index: int\n data: dict\n\n @classmethod\n def from_row(cls, row):\n return cls(\n start_index=row[\"start_index\"],\n end_index=row[\"end_index\"],\n data=json.loads(row[\"data\"]),\n )"} | |
{"id": "llm/models.py:242", "code": " @classmethod\n def from_row(cls, row):\n return cls(\n start_index=row[\"start_index\"],\n end_index=row[\"end_index\"],\n data=json.loads(row[\"data\"]),\n )"} | |
{"id": "llm/models.py:251", "code": "class Chunk(BaseModel):\n text: str\n annotation: Dict[str, Any] = Field(default_factory=dict)\n start_index: Optional[int] = None\n end_index: Optional[int] = None\n\n def __str__(self):\n return self.text"} | |
{"id": "llm/models.py:257", "code": " def __str__(self):\n return self.text"} | |
{"id": "llm/models.py:261", "code": "class _BaseResponse:\n \"\"\"Base response class shared between sync and async responses\"\"\"\n\n prompt: \"Prompt\"\n stream: bool\n _annotations: List[Annotation] = field(default_factory=list)\n conversation: Optional[\"_BaseConversation\"] = None\n _key: Optional[str] = None\n\n def __init__(\n self,\n prompt: Prompt,\n model: \"_BaseModel\",\n stream: bool,\n conversation: Optional[_BaseConversation] = None,\n key: Optional[str] = None,\n ):\n self.prompt = prompt\n self._prompt_json = None\n self.model = model\n self.stream = stream\n self._key = key\n self._annotations: List[Annotation] = []\n self._chunks: List[Union[Chunk, str]] = []\n self._done = False\n self.response_json = None\n self.conversation = conversation\n self.attachments: List[Attachment] = []\n self._start: Optional[float] = None\n self._end: Optional[float] = None\n self._start_utcnow: Optional[datetime.datetime] = None\n self.input_tokens: Optional[int] = None\n self.output_tokens: Optional[int] = None\n self.token_details: Optional[dict] = None\n self.done_callbacks: List[Callable] = []\n\n if self.prompt.schema and not self.model.supports_schema:\n raise ValueError(f\"{self.model} does not support schemas\")\n\n def set_usage(\n self,\n *,\n input: Optional[int] = None,\n output: Optional[int] = None,\n details: Optional[dict] = None,\n ):\n self.input_tokens = input\n self.output_tokens = output\n self.token_details = details\n\n def add_annotations(self, annotations: List[Annotation]):\n self._annotations.extend(annotations)\n\n @property\n def annotations(self):\n return self._annotations or []\n\n @classmethod\n def from_row(cls, db, row, _async=False):\n from llm import get_model, get_async_model\n\n if _async:\n model = get_async_model(row[\"model\"])\n else:\n model = get_model(row[\"model\"])\n\n # Schema\n schema = None\n schema_id = row.get(\"schema_id\")\n if schema_id:\n schema = json.loads(db[\"schemas\"].get(row[\"schema_id\"])[\"content\"])\n\n response = cls(\n model=model,\n prompt=Prompt(\n prompt=row[\"prompt\"],\n model=model,\n attachments=[],\n system=row[\"system\"],\n schema=schema,\n options=model.Options(**json.loads(row[\"options_json\"])),\n ),\n stream=False,\n )\n response.id = row[\"id\"]\n response._prompt_json = json.loads(row[\"prompt_json\"] or \"null\")\n response.response_json = json.loads(row[\"response_json\"] or \"null\")\n response._done = True\n response._chunks = [row[\"response\"]]\n # Attachments\n response.attachments = [\n Attachment.from_row(arow)\n for arow in db.query(\n \"\"\"\n select attachments.* from attachments\n join prompt_attachments on attachments.id = prompt_attachments.attachment_id\n where prompt_attachments.response_id = ?\n order by prompt_attachments.\"order\"\n \"\"\",\n [row[\"id\"]],\n )\n ]\n # Annotations\n response._annotations = [\n Annotation.from_row(arow)\n for arow in db.query(\n \"\"\"\n select id, start_index, end_index, data\n from response_annotations\n where response_id = ?\n order by start_index\n \"\"\",\n [row[\"id\"]],\n )\n ]\n return response\n\n # iterates over chunks of text, so an iterator of Chunk\n def chunks_from_text(self, text) -> Iterator[Chunk]:\n annotations = sorted(self.annotations, key=lambda a: a.start_index)\n\n current_index = 0\n\n for annotation in annotations:\n # If there's a gap before this annotation, yield a gap chunk\n if current_index < annotation.start_index:\n gap_text = text[current_index : annotation.start_index]\n yield Chunk(\n text=gap_text,\n 
annotation={},\n start_index=current_index,\n end_index=annotation.start_index,\n )\n\n # Yield the chunk for this annotation\n chunk_text = text[annotation.start_index : annotation.end_index]\n yield Chunk(\n text=chunk_text,\n annotation=annotation.data,\n start_index=annotation.start_index,\n end_index=annotation.end_index,\n )\n\n current_index = annotation.end_index\n\n # If there's text after the last annotation, yield a final gap chunk\n if current_index < len(text):\n yield Chunk(\n text=text[current_index:],\n annotation={},\n start_index=current_index,\n end_index=len(text),\n )\n\n def token_usage(self) -> str:\n return token_usage_string(\n self.input_tokens, self.output_tokens, self.token_details\n )\n\n def log_to_db(self, db):\n conversation = self.conversation\n if not conversation:\n conversation = Conversation(model=self.model)\n db[\"conversations\"].insert(\n {\n \"id\": conversation.id,\n \"name\": _conversation_name(\n self.prompt.prompt or self.prompt.system or \"\"\n ),\n \"model\": conversation.model.model_id,\n },\n ignore=True,\n )\n schema_id = None\n if self.prompt.schema:\n schema_id, schema_json = make_schema_id(self.prompt.schema)\n db[\"schemas\"].insert({\"id\": schema_id, \"content\": schema_json}, ignore=True)\n\n response_id = str(ULID()).lower()\n response = {\n \"id\": response_id,\n \"model\": self.model.model_id,\n \"prompt\": self.prompt.prompt,\n \"system\": self.prompt.system,\n \"prompt_json\": self._prompt_json,\n \"options_json\": {\n key: value\n for key, value in dict(self.prompt.options).items()\n if value is not None\n },\n \"response\": self.text_or_raise(),\n \"response_json\": self.json(),\n \"conversation_id\": conversation.id,\n \"duration_ms\": self.duration_ms(),\n \"datetime_utc\": self.datetime_utc(),\n \"input_tokens\": self.input_tokens,\n \"output_tokens\": self.output_tokens,\n \"token_details\": (\n json.dumps(self.token_details) if self.token_details else None\n ),\n \"schema_id\": schema_id,\n }\n db[\"responses\"].insert(response)\n\n if self.annotations:\n db[\"response_annotations\"].insert_all(\n {\n \"response_id\": response_id,\n \"start_index\": annotation.start_index,\n \"end_index\": annotation.end_index,\n \"data\": json.dumps(annotation.data),\n }\n for annotation in self.annotations\n )\n\n # Persist any attachments - loop through with index\n for index, attachment in enumerate(self.prompt.attachments):\n attachment_id = attachment.id()\n db[\"attachments\"].insert(\n {\n \"id\": attachment_id,\n \"type\": attachment.resolve_type(),\n \"path\": attachment.path,\n \"url\": attachment.url,\n \"content\": attachment.content,\n },\n replace=True,\n )\n db[\"prompt_attachments\"].insert(\n {\n \"response_id\": response_id,\n \"attachment_id\": attachment_id,\n \"order\": index,\n },\n )"} | |
{"id": "llm/models.py:270", "code": " def __init__(\n self,\n prompt: Prompt,\n model: \"_BaseModel\",\n stream: bool,\n conversation: Optional[_BaseConversation] = None,\n key: Optional[str] = None,\n ):\n self.prompt = prompt\n self._prompt_json = None\n self.model = model\n self.stream = stream\n self._key = key\n self._annotations: List[Annotation] = []\n self._chunks: List[Union[Chunk, str]] = []\n self._done = False\n self.response_json = None\n self.conversation = conversation\n self.attachments: List[Attachment] = []\n self._start: Optional[float] = None\n self._end: Optional[float] = None\n self._start_utcnow: Optional[datetime.datetime] = None\n self.input_tokens: Optional[int] = None\n self.output_tokens: Optional[int] = None\n self.token_details: Optional[dict] = None\n self.done_callbacks: List[Callable] = []\n\n if self.prompt.schema and not self.model.supports_schema:\n raise ValueError(f\"{self.model} does not support schemas\")"} | |
{"id": "llm/models.py:300", "code": " def set_usage(\n self,\n *,\n input: Optional[int] = None,\n output: Optional[int] = None,\n details: Optional[dict] = None,\n ):\n self.input_tokens = input\n self.output_tokens = output\n self.token_details = details"} | |
{"id": "llm/models.py:311", "code": " def add_annotations(self, annotations: List[Annotation]):\n self._annotations.extend(annotations)"} | |
{"id": "llm/models.py:314", "code": " @property\n def annotations(self):\n return self._annotations or []"} | |
{"id": "llm/models.py:318", "code": " @classmethod\n def from_row(cls, db, row, _async=False):\n from llm import get_model, get_async_model\n\n if _async:\n model = get_async_model(row[\"model\"])\n else:\n model = get_model(row[\"model\"])\n\n # Schema\n schema = None\n schema_id = row.get(\"schema_id\")\n if schema_id:\n schema = json.loads(db[\"schemas\"].get(row[\"schema_id\"])[\"content\"])\n\n response = cls(\n model=model,\n prompt=Prompt(\n prompt=row[\"prompt\"],\n model=model,\n attachments=[],\n system=row[\"system\"],\n schema=schema,\n options=model.Options(**json.loads(row[\"options_json\"])),\n ),\n stream=False,\n )\n response.id = row[\"id\"]\n response._prompt_json = json.loads(row[\"prompt_json\"] or \"null\")\n response.response_json = json.loads(row[\"response_json\"] or \"null\")\n response._done = True\n response._chunks = [row[\"response\"]]\n # Attachments\n response.attachments = [\n Attachment.from_row(arow)\n for arow in db.query(\n \"\"\"\n select attachments.* from attachments\n join prompt_attachments on attachments.id = prompt_attachments.attachment_id\n where prompt_attachments.response_id = ?\n order by prompt_attachments.\"order\"\n \"\"\",\n [row[\"id\"]],\n )\n ]\n # Annotations\n response._annotations = [\n Annotation.from_row(arow)\n for arow in db.query(\n \"\"\"\n select id, start_index, end_index, data\n from response_annotations\n where response_id = ?\n order by start_index\n \"\"\",\n [row[\"id\"]],\n )\n ]\n return response"} | |
{"id": "llm/models.py:379", "code": " def chunks_from_text(self, text) -> Iterator[Chunk]:\n annotations = sorted(self.annotations, key=lambda a: a.start_index)\n\n current_index = 0\n\n for annotation in annotations:\n # If there's a gap before this annotation, yield a gap chunk\n if current_index < annotation.start_index:\n gap_text = text[current_index : annotation.start_index]\n yield Chunk(\n text=gap_text,\n annotation={},\n start_index=current_index,\n end_index=annotation.start_index,\n )\n\n # Yield the chunk for this annotation\n chunk_text = text[annotation.start_index : annotation.end_index]\n yield Chunk(\n text=chunk_text,\n annotation=annotation.data,\n start_index=annotation.start_index,\n end_index=annotation.end_index,\n )\n\n current_index = annotation.end_index\n\n # If there's text after the last annotation, yield a final gap chunk\n if current_index < len(text):\n yield Chunk(\n text=text[current_index:],\n annotation={},\n start_index=current_index,\n end_index=len(text),\n )"} | |
{"id": "llm/models.py:415", "code": " def token_usage(self) -> str:\n return token_usage_string(\n self.input_tokens, self.output_tokens, self.token_details\n )"} | |
{"id": "llm/models.py:420", "code": " def log_to_db(self, db):\n conversation = self.conversation\n if not conversation:\n conversation = Conversation(model=self.model)\n db[\"conversations\"].insert(\n {\n \"id\": conversation.id,\n \"name\": _conversation_name(\n self.prompt.prompt or self.prompt.system or \"\"\n ),\n \"model\": conversation.model.model_id,\n },\n ignore=True,\n )\n schema_id = None\n if self.prompt.schema:\n schema_id, schema_json = make_schema_id(self.prompt.schema)\n db[\"schemas\"].insert({\"id\": schema_id, \"content\": schema_json}, ignore=True)\n\n response_id = str(ULID()).lower()\n response = {\n \"id\": response_id,\n \"model\": self.model.model_id,\n \"prompt\": self.prompt.prompt,\n \"system\": self.prompt.system,\n \"prompt_json\": self._prompt_json,\n \"options_json\": {\n key: value\n for key, value in dict(self.prompt.options).items()\n if value is not None\n },\n \"response\": self.text_or_raise(),\n \"response_json\": self.json(),\n \"conversation_id\": conversation.id,\n \"duration_ms\": self.duration_ms(),\n \"datetime_utc\": self.datetime_utc(),\n \"input_tokens\": self.input_tokens,\n \"output_tokens\": self.output_tokens,\n \"token_details\": (\n json.dumps(self.token_details) if self.token_details else None\n ),\n \"schema_id\": schema_id,\n }\n db[\"responses\"].insert(response)\n\n if self.annotations:\n db[\"response_annotations\"].insert_all(\n {\n \"response_id\": response_id,\n \"start_index\": annotation.start_index,\n \"end_index\": annotation.end_index,\n \"data\": json.dumps(annotation.data),\n }\n for annotation in self.annotations\n )\n\n # Persist any attachments - loop through with index\n for index, attachment in enumerate(self.prompt.attachments):\n attachment_id = attachment.id()\n db[\"attachments\"].insert(\n {\n \"id\": attachment_id,\n \"type\": attachment.resolve_type(),\n \"path\": attachment.path,\n \"url\": attachment.url,\n \"content\": attachment.content,\n },\n replace=True,\n )\n db[\"prompt_attachments\"].insert(\n {\n \"response_id\": response_id,\n \"attachment_id\": attachment_id,\n \"order\": index,\n },\n )"} | |
{"id": "llm/models.py:498", "code": "class Response(_BaseResponse):\n model: \"Model\"\n conversation: Optional[\"Conversation\"] = None\n\n def chunks(self) -> Iterator[Chunk]:\n return self.chunks_from_text(self.text())\n\n def on_done(self, callback):\n if not self._done:\n self.done_callbacks.append(callback)\n else:\n callback(self)\n\n def _on_done(self):\n for callback in self.done_callbacks:\n callback(self)\n\n def __str__(self) -> str:\n return self.text()\n\n def _force(self):\n if not self._done:\n list(self)\n\n def text(self) -> str:\n self._force()\n return \"\".join(map(str, self._chunks))\n\n def text_or_raise(self) -> str:\n return self.text()\n\n def json(self) -> Optional[Dict[str, Any]]:\n self._force()\n return self.response_json\n\n def duration_ms(self) -> int:\n self._force()\n return int(((self._end or 0) - (self._start or 0)) * 1000)\n\n def datetime_utc(self) -> str:\n self._force()\n return self._start_utcnow.isoformat() if self._start_utcnow else \"\"\n\n def usage(self) -> Usage:\n self._force()\n return Usage(\n input=self.input_tokens,\n output=self.output_tokens,\n details=self.token_details,\n )\n\n def __iter__(self) -> Iterator[Union[Chunk, str]]:\n self._start = time.monotonic()\n self._start_utcnow = datetime.datetime.now(datetime.timezone.utc)\n if self._done:\n yield from self._chunks\n return\n\n if isinstance(self.model, Model):\n chunk_iter = self.model.execute(\n self.prompt,\n stream=self.stream,\n response=self,\n conversation=self.conversation,\n )\n elif isinstance(self.model, KeyModel):\n chunk_iter = self.model.execute(\n self.prompt,\n stream=self.stream,\n response=self,\n conversation=self.conversation,\n key=self.model.get_key(self._key),\n )\n else:\n raise Exception(\"self.model must be a Model or KeyModel\")\n index = 0\n for chunk in chunk_iter:\n if isinstance(chunk, Chunk):\n chunk.start_index = index\n index += len(chunk.text)\n chunk.end_index = index\n else:\n index += len(chunk)\n yield chunk\n self._chunks.append(chunk)\n\n if self.conversation:\n self.conversation.responses.append(self)\n self._end = time.monotonic()\n self._done = True\n self._on_done()\n\n def __repr__(self):\n text = \"... not yet done ...\"\n if self._done:\n text = \"\".join(self._chunks)\n return \"<Response prompt='{}' text='{}'>\".format(self.prompt.prompt, text)"} | |
{"id": "llm/models.py:502", "code": " def chunks(self) -> Iterator[Chunk]:\n return self.chunks_from_text(self.text())"} | |
{"id": "llm/models.py:505", "code": " def on_done(self, callback):\n if not self._done:\n self.done_callbacks.append(callback)\n else:\n callback(self)"} | |
{"id": "llm/models.py:511", "code": " def _on_done(self):\n for callback in self.done_callbacks:\n callback(self)"} | |
{"id": "llm/models.py:515", "code": " def __str__(self) -> str:\n return self.text()"} | |
{"id": "llm/models.py:518", "code": " def _force(self):\n if not self._done:\n list(self)"} | |
{"id": "llm/models.py:522", "code": " def text(self) -> str:\n self._force()\n return \"\".join(map(str, self._chunks))"} | |
{"id": "llm/models.py:526", "code": " def text_or_raise(self) -> str:\n return self.text()"} | |
{"id": "llm/models.py:529", "code": " def json(self) -> Optional[Dict[str, Any]]:\n self._force()\n return self.response_json"} | |
{"id": "llm/models.py:533", "code": " def duration_ms(self) -> int:\n self._force()\n return int(((self._end or 0) - (self._start or 0)) * 1000)"} | |
{"id": "llm/models.py:537", "code": " def datetime_utc(self) -> str:\n self._force()\n return self._start_utcnow.isoformat() if self._start_utcnow else \"\""} | |
{"id": "llm/models.py:541", "code": " def usage(self) -> Usage:\n self._force()\n return Usage(\n input=self.input_tokens,\n output=self.output_tokens,\n details=self.token_details,\n )"} | |
{"id": "llm/models.py:549", "code": " def __iter__(self) -> Iterator[Union[Chunk, str]]:\n self._start = time.monotonic()\n self._start_utcnow = datetime.datetime.now(datetime.timezone.utc)\n if self._done:\n yield from self._chunks\n return\n\n if isinstance(self.model, Model):\n chunk_iter = self.model.execute(\n self.prompt,\n stream=self.stream,\n response=self,\n conversation=self.conversation,\n )\n elif isinstance(self.model, KeyModel):\n chunk_iter = self.model.execute(\n self.prompt,\n stream=self.stream,\n response=self,\n conversation=self.conversation,\n key=self.model.get_key(self._key),\n )\n else:\n raise Exception(\"self.model must be a Model or KeyModel\")\n index = 0\n for chunk in chunk_iter:\n if isinstance(chunk, Chunk):\n chunk.start_index = index\n index += len(chunk.text)\n chunk.end_index = index\n else:\n index += len(chunk)\n yield chunk\n self._chunks.append(chunk)\n\n if self.conversation:\n self.conversation.responses.append(self)\n self._end = time.monotonic()\n self._done = True\n self._on_done()"} | |
{"id": "llm/models.py:590", "code": " def __repr__(self):\n text = \"... not yet done ...\"\n if self._done:\n text = \"\".join(self._chunks)\n return \"<Response prompt='{}' text='{}'>\".format(self.prompt.prompt, text)"} | |
{"id": "llm/models.py:597", "code": "class AsyncResponse(_BaseResponse):\n model: \"AsyncModel\"\n conversation: Optional[\"AsyncConversation\"] = None\n\n async def chunks(self) -> Iterator[Chunk]:\n return self.chunks_from_text(await self.text())\n\n @classmethod\n def from_row(cls, db, row, _async=False):\n return super().from_row(db, row, _async=True)\n\n async def on_done(self, callback):\n if not self._done:\n self.done_callbacks.append(callback)\n else:\n if callable(callback):\n callback = callback(self)\n if asyncio.iscoroutine(callback):\n await callback\n\n async def _on_done(self):\n for callback in self.done_callbacks:\n if callable(callback):\n callback = callback(self)\n if asyncio.iscoroutine(callback):\n await callback\n\n def __aiter__(self):\n self._start = time.monotonic()\n self._start_utcnow = datetime.datetime.now(datetime.timezone.utc)\n self._generator_index = 0\n return self\n\n async def __anext__(self) -> Union[Chunk, str]:\n if self._done:\n if not self._chunks:\n raise StopAsyncIteration\n chunk = self._chunks.pop(0)\n if not self._chunks:\n raise StopAsyncIteration\n return chunk\n\n if not hasattr(self, \"_generator\"):\n if isinstance(self.model, AsyncModel):\n self._generator = self.model.execute(\n self.prompt,\n stream=self.stream,\n response=self,\n conversation=self.conversation,\n )\n elif isinstance(self.model, AsyncKeyModel):\n self._generator = self.model.execute(\n self.prompt,\n stream=self.stream,\n response=self,\n conversation=self.conversation,\n key=self.model.get_key(self._key),\n )\n else:\n raise ValueError(\"self.model must be an AsyncModel or AsyncKeyModel\")\n\n try:\n chunk = await self._generator.__anext__()\n if isinstance(chunk, Chunk):\n chunk.start_index = self._generator_index\n self._generator_index += len(chunk.text)\n chunk.end_index = self._generator_index\n else:\n self._generator_index += len(chunk)\n self._chunks.append(chunk)\n return chunk\n except StopAsyncIteration:\n if self.conversation:\n self.conversation.responses.append(self)\n self._end = time.monotonic()\n self._done = True\n await self._on_done()\n raise\n\n async def _force(self):\n if not self._done:\n async for _ in self:\n pass\n return self\n\n def text_or_raise(self) -> str:\n if not self._done:\n raise ValueError(\"Response not yet awaited\")\n return \"\".join(map(str, self._chunks))\n\n async def text(self) -> str:\n await self._force()\n return \"\".join(map(str, self._chunks))\n\n async def json(self) -> Optional[Dict[str, Any]]:\n await self._force()\n return self.response_json\n\n async def duration_ms(self) -> int:\n await self._force()\n return int(((self._end or 0) - (self._start or 0)) * 1000)\n\n async def datetime_utc(self) -> str:\n await self._force()\n return self._start_utcnow.isoformat() if self._start_utcnow else \"\"\n\n async def usage(self) -> Usage:\n await self._force()\n return Usage(\n input=self.input_tokens,\n output=self.output_tokens,\n details=self.token_details,\n )\n\n def __await__(self):\n return self._force().__await__()\n\n async def to_sync_response(self) -> Response:\n await self._force()\n response = Response(\n self.prompt,\n self.model,\n self.stream,\n conversation=self.conversation,\n )\n response._chunks = self._chunks\n response._done = True\n response._end = self._end\n response._start = self._start\n response._start_utcnow = self._start_utcnow\n response.input_tokens = self.input_tokens\n response.output_tokens = self.output_tokens\n response.token_details = self.token_details\n return response\n\n 
@classmethod\n def fake(\n cls,\n model: \"AsyncModel\",\n prompt: str,\n *attachments: List[Attachment],\n system: str,\n response: str,\n ):\n \"Utility method to help with writing tests\"\n response_obj = cls(\n model=model,\n prompt=Prompt(\n prompt,\n model=model,\n attachments=attachments,\n system=system,\n ),\n stream=False,\n )\n response_obj._done = True\n response_obj._chunks = [response]\n return response_obj\n\n def __repr__(self):\n text = \"... not yet awaited ...\"\n if self._done:\n text = \"\".join(self._chunks)\n return \"<AsyncResponse prompt='{}' text='{}'>\".format(self.prompt.prompt, text)"} | |
{"id": "llm/models.py:601", "code": " async def chunks(self) -> Iterator[Chunk]:\n return self.chunks_from_text(await self.text())"} | |
{"id": "llm/models.py:604", "code": " @classmethod\n def from_row(cls, db, row, _async=False):\n return super().from_row(db, row, _async=True)"} | |
{"id": "llm/models.py:608", "code": " async def on_done(self, callback):\n if not self._done:\n self.done_callbacks.append(callback)\n else:\n if callable(callback):\n callback = callback(self)\n if asyncio.iscoroutine(callback):\n await callback"} | |
{"id": "llm/models.py:617", "code": " async def _on_done(self):\n for callback in self.done_callbacks:\n if callable(callback):\n callback = callback(self)\n if asyncio.iscoroutine(callback):\n await callback"} | |
{"id": "llm/models.py:624", "code": " def __aiter__(self):\n self._start = time.monotonic()\n self._start_utcnow = datetime.datetime.now(datetime.timezone.utc)\n self._generator_index = 0\n return self"} | |
{"id": "llm/models.py:630", "code": " async def __anext__(self) -> Union[Chunk, str]:\n if self._done:\n if not self._chunks:\n raise StopAsyncIteration\n chunk = self._chunks.pop(0)\n if not self._chunks:\n raise StopAsyncIteration\n return chunk\n\n if not hasattr(self, \"_generator\"):\n if isinstance(self.model, AsyncModel):\n self._generator = self.model.execute(\n self.prompt,\n stream=self.stream,\n response=self,\n conversation=self.conversation,\n )\n elif isinstance(self.model, AsyncKeyModel):\n self._generator = self.model.execute(\n self.prompt,\n stream=self.stream,\n response=self,\n conversation=self.conversation,\n key=self.model.get_key(self._key),\n )\n else:\n raise ValueError(\"self.model must be an AsyncModel or AsyncKeyModel\")\n\n try:\n chunk = await self._generator.__anext__()\n if isinstance(chunk, Chunk):\n chunk.start_index = self._generator_index\n self._generator_index += len(chunk.text)\n chunk.end_index = self._generator_index\n else:\n self._generator_index += len(chunk)\n self._chunks.append(chunk)\n return chunk\n except StopAsyncIteration:\n if self.conversation:\n self.conversation.responses.append(self)\n self._end = time.monotonic()\n self._done = True\n await self._on_done()\n raise"} | |
{"id": "llm/models.py:676", "code": " async def _force(self):\n if not self._done:\n async for _ in self:\n pass\n return self"} | |
{"id": "llm/models.py:682", "code": " def text_or_raise(self) -> str:\n if not self._done:\n raise ValueError(\"Response not yet awaited\")\n return \"\".join(map(str, self._chunks))"} | |
{"id": "llm/models.py:687", "code": " async def text(self) -> str:\n await self._force()\n return \"\".join(map(str, self._chunks))"} | |
{"id": "llm/models.py:691", "code": " async def json(self) -> Optional[Dict[str, Any]]:\n await self._force()\n return self.response_json"} | |
{"id": "llm/models.py:695", "code": " async def duration_ms(self) -> int:\n await self._force()\n return int(((self._end or 0) - (self._start or 0)) * 1000)"} | |
{"id": "llm/models.py:699", "code": " async def datetime_utc(self) -> str:\n await self._force()\n return self._start_utcnow.isoformat() if self._start_utcnow else \"\""} | |
{"id": "llm/models.py:703", "code": " async def usage(self) -> Usage:\n await self._force()\n return Usage(\n input=self.input_tokens,\n output=self.output_tokens,\n details=self.token_details,\n )"} | |
{"id": "llm/models.py:711", "code": " def __await__(self):\n return self._force().__await__()"} | |
{"id": "llm/models.py:714", "code": " async def to_sync_response(self) -> Response:\n await self._force()\n response = Response(\n self.prompt,\n self.model,\n self.stream,\n conversation=self.conversation,\n )\n response._chunks = self._chunks\n response._done = True\n response._end = self._end\n response._start = self._start\n response._start_utcnow = self._start_utcnow\n response.input_tokens = self.input_tokens\n response.output_tokens = self.output_tokens\n response.token_details = self.token_details\n return response"} | |
{"id": "llm/models.py:732", "code": " @classmethod\n def fake(\n cls,\n model: \"AsyncModel\",\n prompt: str,\n *attachments: List[Attachment],\n system: str,\n response: str,\n ):\n \"Utility method to help with writing tests\"\n response_obj = cls(\n model=model,\n prompt=Prompt(\n prompt,\n model=model,\n attachments=attachments,\n system=system,\n ),\n stream=False,\n )\n response_obj._done = True\n response_obj._chunks = [response]\n return response_obj"} | |
{"id": "llm/models.py:756", "code": " def __repr__(self):\n text = \"... not yet awaited ...\"\n if self._done:\n text = \"\".join(self._chunks)\n return \"<AsyncResponse prompt='{}' text='{}'>\".format(self.prompt.prompt, text)"} | |
{"id": "llm/models.py:763", "code": "class Options(BaseModel):\n model_config = ConfigDict(extra=\"forbid\")"} | |
{"id": "llm/models.py:770", "code": "class _get_key_mixin:\n needs_key: Optional[str] = None\n key: Optional[str] = None\n key_env_var: Optional[str] = None\n\n def get_key(self, explicit_key: Optional[str] = None) -> Optional[str]:\n from llm import get_key\n\n if self.needs_key is None:\n # This model doesn't use an API key\n return None\n\n if self.key is not None:\n # Someone already set model.key='...'\n return self.key\n\n # Attempt to load a key using llm.get_key()\n key = get_key(\n explicit_key=explicit_key,\n key_alias=self.needs_key,\n env_var=self.key_env_var,\n )\n if key:\n return key\n\n # Show a useful error message\n message = \"No key found - add one using 'llm keys set {}'\".format(\n self.needs_key\n )\n if self.key_env_var:\n message += \" or set the {} environment variable\".format(self.key_env_var)\n raise NeedsKeyException(message)"} | |
{"id": "llm/models.py:775", "code": " def get_key(self, explicit_key: Optional[str] = None) -> Optional[str]:\n from llm import get_key\n\n if self.needs_key is None:\n # This model doesn't use an API key\n return None\n\n if self.key is not None:\n # Someone already set model.key='...'\n return self.key\n\n # Attempt to load a key using llm.get_key()\n key = get_key(\n explicit_key=explicit_key,\n key_alias=self.needs_key,\n env_var=self.key_env_var,\n )\n if key:\n return key\n\n # Show a useful error message\n message = \"No key found - add one using 'llm keys set {}'\".format(\n self.needs_key\n )\n if self.key_env_var:\n message += \" or set the {} environment variable\".format(self.key_env_var)\n raise NeedsKeyException(message)"} | |
{"id": "llm/models.py:804", "code": "class _BaseModel(ABC, _get_key_mixin):\n model_id: str\n can_stream: bool = False\n attachment_types: Set = set()\n\n supports_schema = False\n\n class Options(_Options):\n pass\n\n def _validate_attachments(\n self, attachments: Optional[List[Attachment]] = None\n ) -> None:\n if attachments and not self.attachment_types:\n raise ValueError(\"This model does not support attachments\")\n for attachment in attachments or []:\n attachment_type = attachment.resolve_type()\n if attachment_type not in self.attachment_types:\n raise ValueError(\n f\"This model does not support attachments of type '{attachment_type}', \"\n f\"only {', '.join(self.attachment_types)}\"\n )\n\n def __str__(self) -> str:\n return \"{}{}: {}\".format(\n self.__class__.__name__,\n \" (async)\" if isinstance(self, (AsyncModel, AsyncKeyModel)) else \"\",\n self.model_id,\n )\n\n def __repr__(self) -> str:\n return f\"<{str(self)}>\""} | |
{"id": "llm/models.py:814", "code": " def _validate_attachments(\n self, attachments: Optional[List[Attachment]] = None\n ) -> None:\n if attachments and not self.attachment_types:\n raise ValueError(\"This model does not support attachments\")\n for attachment in attachments or []:\n attachment_type = attachment.resolve_type()\n if attachment_type not in self.attachment_types:\n raise ValueError(\n f\"This model does not support attachments of type '{attachment_type}', \"\n f\"only {', '.join(self.attachment_types)}\"\n )"} | |
{"id": "llm/models.py:827", "code": " def __str__(self) -> str:\n return \"{}{}: {}\".format(\n self.__class__.__name__,\n \" (async)\" if isinstance(self, (AsyncModel, AsyncKeyModel)) else \"\",\n self.model_id,\n )"} | |
{"id": "llm/models.py:834", "code": " def __repr__(self) -> str:\n return f\"<{str(self)}>\""} | |
{"id": "llm/models.py:838", "code": "class _Model(_BaseModel):\n def conversation(self) -> Conversation:\n return Conversation(model=self)\n\n def prompt(\n self,\n prompt: Optional[str] = None,\n *,\n attachments: Optional[List[Attachment]] = None,\n system: Optional[str] = None,\n stream: bool = True,\n schema: Optional[Union[dict, type[BaseModel]]] = None,\n **options,\n ) -> Response:\n key = options.pop(\"key\", None)\n self._validate_attachments(attachments)\n return Response(\n Prompt(\n prompt,\n attachments=attachments,\n system=system,\n schema=schema,\n model=self,\n options=self.Options(**options),\n ),\n self,\n stream,\n key=key,\n )"} | |
{"id": "llm/models.py:839", "code": " def conversation(self) -> Conversation:\n return Conversation(model=self)"} | |
{"id": "llm/models.py:842", "code": " def prompt(\n self,\n prompt: Optional[str] = None,\n *,\n attachments: Optional[List[Attachment]] = None,\n system: Optional[str] = None,\n stream: bool = True,\n schema: Optional[Union[dict, type[BaseModel]]] = None,\n **options,\n ) -> Response:\n key = options.pop(\"key\", None)\n self._validate_attachments(attachments)\n return Response(\n Prompt(\n prompt,\n attachments=attachments,\n system=system,\n schema=schema,\n model=self,\n options=self.Options(**options),\n ),\n self,\n stream,\n key=key,\n )"} | |
{"id": "llm/models.py:869", "code": "class Model(_Model):\n @abstractmethod\n def execute(\n self,\n prompt: Prompt,\n stream: bool,\n response: Response,\n conversation: Optional[Conversation],\n ) -> Iterator[Union[str, Chunk]]:\n pass"} | |
{"id": "llm/models.py:870", "code": " @abstractmethod\n def execute(\n self,\n prompt: Prompt,\n stream: bool,\n response: Response,\n conversation: Optional[Conversation],\n ) -> Iterator[Union[str, Chunk]]:\n pass"} | |
{"id": "llm/models.py:881", "code": "class KeyModel(_Model):\n @abstractmethod\n def execute(\n self,\n prompt: Prompt,\n stream: bool,\n response: Response,\n conversation: Optional[Conversation],\n key: Optional[str],\n ) -> Iterator[Union[str, Chunk]]:\n pass"} | |
{"id": "llm/models.py:882", "code": " @abstractmethod\n def execute(\n self,\n prompt: Prompt,\n stream: bool,\n response: Response,\n conversation: Optional[Conversation],\n key: Optional[str],\n ) -> Iterator[Union[str, Chunk]]:\n pass"} | |
{"id": "llm/models.py:894", "code": "class _AsyncModel(_BaseModel):\n def conversation(self) -> AsyncConversation:\n return AsyncConversation(model=self)\n\n def prompt(\n self,\n prompt: Optional[str] = None,\n *,\n attachments: Optional[List[Attachment]] = None,\n system: Optional[str] = None,\n schema: Optional[Union[dict, type[BaseModel]]] = None,\n stream: bool = True,\n **options,\n ) -> AsyncResponse:\n key = options.pop(\"key\", None)\n self._validate_attachments(attachments)\n return AsyncResponse(\n Prompt(\n prompt,\n attachments=attachments,\n system=system,\n schema=schema,\n model=self,\n options=self.Options(**options),\n ),\n self,\n stream,\n key=key,\n )"} | |
{"id": "llm/models.py:895", "code": " def conversation(self) -> AsyncConversation:\n return AsyncConversation(model=self)"} | |
{"id": "llm/models.py:898", "code": " def prompt(\n self,\n prompt: Optional[str] = None,\n *,\n attachments: Optional[List[Attachment]] = None,\n system: Optional[str] = None,\n schema: Optional[Union[dict, type[BaseModel]]] = None,\n stream: bool = True,\n **options,\n ) -> AsyncResponse:\n key = options.pop(\"key\", None)\n self._validate_attachments(attachments)\n return AsyncResponse(\n Prompt(\n prompt,\n attachments=attachments,\n system=system,\n schema=schema,\n model=self,\n options=self.Options(**options),\n ),\n self,\n stream,\n key=key,\n )"} | |
{"id": "llm/models.py:925", "code": "class AsyncModel(_AsyncModel):\n @abstractmethod\n async def execute(\n self,\n prompt: Prompt,\n stream: bool,\n response: AsyncResponse,\n conversation: Optional[AsyncConversation],\n ) -> AsyncGenerator[Union[str, Chunk], None]:\n yield \"\""} | |
{"id": "llm/models.py:926", "code": " @abstractmethod\n async def execute(\n self,\n prompt: Prompt,\n stream: bool,\n response: AsyncResponse,\n conversation: Optional[AsyncConversation],\n ) -> AsyncGenerator[Union[str, Chunk], None]:\n yield \"\""} | |
{"id": "llm/models.py:937", "code": "class AsyncKeyModel(_AsyncModel):\n @abstractmethod\n async def execute(\n self,\n prompt: Prompt,\n stream: bool,\n response: AsyncResponse,\n conversation: Optional[AsyncConversation],\n key: Optional[str],\n ) -> AsyncGenerator[Union[str, Chunk], None]:\n yield \"\""} | |
{"id": "llm/models.py:938", "code": " @abstractmethod\n async def execute(\n self,\n prompt: Prompt,\n stream: bool,\n response: AsyncResponse,\n conversation: Optional[AsyncConversation],\n key: Optional[str],\n ) -> AsyncGenerator[Union[str, Chunk], None]:\n yield \"\""} | |
{"id": "llm/models.py:950", "code": "class EmbeddingModel(ABC, _get_key_mixin):\n model_id: str\n key: Optional[str] = None\n needs_key: Optional[str] = None\n key_env_var: Optional[str] = None\n supports_text: bool = True\n supports_binary: bool = False\n batch_size: Optional[int] = None\n\n def _check(self, item: Union[str, bytes]):\n if not self.supports_binary and isinstance(item, bytes):\n raise ValueError(\n \"This model does not support binary data, only text strings\"\n )\n if not self.supports_text and isinstance(item, str):\n raise ValueError(\n \"This model does not support text strings, only binary data\"\n )\n\n def embed(self, item: Union[str, bytes]) -> List[float]:\n \"Embed a single text string or binary blob, return a list of floats\"\n self._check(item)\n return next(iter(self.embed_batch([item])))\n\n def embed_multi(\n self, items: Iterable[Union[str, bytes]], batch_size: Optional[int] = None\n ) -> Iterator[List[float]]:\n \"Embed multiple items in batches according to the model batch_size\"\n iter_items = iter(items)\n batch_size = self.batch_size if batch_size is None else batch_size\n if (not self.supports_binary) or (not self.supports_text):\n\n def checking_iter(items):\n for item in items:\n self._check(item)\n yield item\n\n iter_items = checking_iter(items)\n if batch_size is None:\n yield from self.embed_batch(iter_items)\n return\n while True:\n batch_items = list(islice(iter_items, batch_size))\n if not batch_items:\n break\n yield from self.embed_batch(batch_items)\n\n @abstractmethod\n def embed_batch(self, items: Iterable[Union[str, bytes]]) -> Iterator[List[float]]:\n \"\"\"\n Embed a batch of strings or blobs, return a list of lists of floats\n \"\"\"\n pass\n\n def __str__(self) -> str:\n return \"{}: {}\".format(self.__class__.__name__, self.model_id)\n\n def __repr__(self) -> str:\n return f\"<{str(self)}>\""} | |
{"id": "llm/models.py:959", "code": " def _check(self, item: Union[str, bytes]):\n if not self.supports_binary and isinstance(item, bytes):\n raise ValueError(\n \"This model does not support binary data, only text strings\"\n )\n if not self.supports_text and isinstance(item, str):\n raise ValueError(\n \"This model does not support text strings, only binary data\"\n )"} | |
{"id": "llm/models.py:969", "code": " def embed(self, item: Union[str, bytes]) -> List[float]:\n \"Embed a single text string or binary blob, return a list of floats\"\n self._check(item)\n return next(iter(self.embed_batch([item])))"} | |
{"id": "llm/models.py:974", "code": " def embed_multi(\n self, items: Iterable[Union[str, bytes]], batch_size: Optional[int] = None\n ) -> Iterator[List[float]]:\n \"Embed multiple items in batches according to the model batch_size\"\n iter_items = iter(items)\n batch_size = self.batch_size if batch_size is None else batch_size\n if (not self.supports_binary) or (not self.supports_text):\n\n def checking_iter(items):\n for item in items:\n self._check(item)\n yield item\n\n iter_items = checking_iter(items)\n if batch_size is None:\n yield from self.embed_batch(iter_items)\n return\n while True:\n batch_items = list(islice(iter_items, batch_size))\n if not batch_items:\n break\n yield from self.embed_batch(batch_items)"} | |
{"id": "llm/models.py:997", "code": " @abstractmethod\n def embed_batch(self, items: Iterable[Union[str, bytes]]) -> Iterator[List[float]]:\n \"\"\"\n Embed a batch of strings or blobs, return a list of lists of floats\n \"\"\"\n pass"} | |
{"id": "llm/models.py:1004", "code": " def __str__(self) -> str:\n return \"{}: {}\".format(self.__class__.__name__, self.model_id)"} | |
{"id": "llm/models.py:1007", "code": " def __repr__(self) -> str:\n return f\"<{str(self)}>\""} | |
{"id": "llm/models.py:1011", "code": "@dataclass\nclass ModelWithAliases:\n model: Model\n async_model: AsyncModel\n aliases: Set[str]\n\n def matches(self, query: str) -> bool:\n query = query.lower()\n all_strings: List[str] = []\n all_strings.extend(self.aliases)\n if self.model:\n all_strings.append(str(self.model))\n if self.async_model:\n all_strings.append(str(self.async_model.model_id))\n return any(query in alias.lower() for alias in all_strings)"} | |
{"id": "llm/models.py:1017", "code": " def matches(self, query: str) -> bool:\n query = query.lower()\n all_strings: List[str] = []\n all_strings.extend(self.aliases)\n if self.model:\n all_strings.append(str(self.model))\n if self.async_model:\n all_strings.append(str(self.async_model.model_id))\n return any(query in alias.lower() for alias in all_strings)"} | |
{"id": "llm/models.py:1028", "code": "@dataclass\nclass EmbeddingModelWithAliases:\n model: EmbeddingModel\n aliases: Set[str]\n\n def matches(self, query: str) -> bool:\n query = query.lower()\n all_strings: List[str] = []\n all_strings.extend(self.aliases)\n all_strings.append(str(self.model))\n return any(query in alias.lower() for alias in all_strings)"} | |
{"id": "llm/models.py:1033", "code": " def matches(self, query: str) -> bool:\n query = query.lower()\n all_strings: List[str] = []\n all_strings.extend(self.aliases)\n all_strings.append(str(self.model))\n return any(query in alias.lower() for alias in all_strings)"} | |
{"id": "llm/models.py:1041", "code": "def _conversation_name(text):\n # Collapse whitespace, including newlines\n text = re.sub(r\"\\s+\", \" \", text)\n if len(text) <= CONVERSATION_NAME_LENGTH:\n return text\n return text[: CONVERSATION_NAME_LENGTH - 1] + \"\u2026\""} | |
{"id": "llm/__init__.py:62", "code": "def get_plugins(all=False):\n plugins = []\n plugin_to_distinfo = dict(pm.list_plugin_distinfo())\n for plugin in pm.get_plugins():\n if not all and plugin.__name__.startswith(\"llm.default_plugins.\"):\n continue\n plugin_info = {\n \"name\": plugin.__name__,\n \"hooks\": [h.name for h in pm.get_hookcallers(plugin)],\n }\n distinfo = plugin_to_distinfo.get(plugin)\n if distinfo:\n plugin_info[\"version\"] = distinfo.version\n plugin_info[\"name\"] = (\n getattr(distinfo, \"name\", None) or distinfo.project_name\n )\n plugins.append(plugin_info)\n return plugins"} | |
{"id": "llm/__init__.py:82", "code": "def get_models_with_aliases() -> List[\"ModelWithAliases\"]:\n model_aliases = []\n\n # Include aliases from aliases.json\n aliases_path = user_dir() / \"aliases.json\"\n extra_model_aliases: Dict[str, list] = {}\n if aliases_path.exists():\n configured_aliases = json.loads(aliases_path.read_text())\n for alias, model_id in configured_aliases.items():\n extra_model_aliases.setdefault(model_id, []).append(alias)\n\n def register(model, async_model=None, aliases=None):\n alias_list = list(aliases or [])\n if model.model_id in extra_model_aliases:\n alias_list.extend(extra_model_aliases[model.model_id])\n model_aliases.append(ModelWithAliases(model, async_model, alias_list))\n\n load_plugins()\n pm.hook.register_models(register=register)\n\n return model_aliases"} | |
{"id": "llm/__init__.py:105", "code": "def get_template_loaders() -> Dict[str, Callable[[str], Template]]:\n load_plugins()\n loaders = {}\n\n def register(prefix, loader):\n suffix = 0\n prefix_to_try = prefix\n while prefix_to_try in loaders:\n suffix += 1\n prefix_to_try = f\"{prefix}_{suffix}\"\n loaders[prefix_to_try] = loader\n\n pm.hook.register_template_loaders(register=register)\n return loaders"} | |
{"id": "llm/__init__.py:121", "code": "def get_embedding_models_with_aliases() -> List[\"EmbeddingModelWithAliases\"]:\n model_aliases = []\n\n # Include aliases from aliases.json\n aliases_path = user_dir() / \"aliases.json\"\n extra_model_aliases: Dict[str, list] = {}\n if aliases_path.exists():\n configured_aliases = json.loads(aliases_path.read_text())\n for alias, model_id in configured_aliases.items():\n extra_model_aliases.setdefault(model_id, []).append(alias)\n\n def register(model, aliases=None):\n alias_list = list(aliases or [])\n if model.model_id in extra_model_aliases:\n alias_list.extend(extra_model_aliases[model.model_id])\n model_aliases.append(EmbeddingModelWithAliases(model, alias_list))\n\n load_plugins()\n pm.hook.register_embedding_models(register=register)\n\n return model_aliases"} | |
{"id": "llm/__init__.py:144", "code": "def get_embedding_models():\n models = []\n\n def register(model, aliases=None):\n models.append(model)\n\n load_plugins()\n pm.hook.register_embedding_models(register=register)\n return models"} | |
{"id": "llm/__init__.py:155", "code": "def get_embedding_model(name):\n aliases = get_embedding_model_aliases()\n try:\n return aliases[name]\n except KeyError:\n raise UnknownModelError(\"Unknown model: \" + str(name))"} | |
{"id": "llm/__init__.py:163", "code": "def get_embedding_model_aliases() -> Dict[str, EmbeddingModel]:\n model_aliases = {}\n for model_with_aliases in get_embedding_models_with_aliases():\n for alias in model_with_aliases.aliases:\n model_aliases[alias] = model_with_aliases.model\n model_aliases[model_with_aliases.model.model_id] = model_with_aliases.model\n return model_aliases"} | |
{"id": "llm/__init__.py:172", "code": "def get_async_model_aliases() -> Dict[str, AsyncModel]:\n async_model_aliases = {}\n for model_with_aliases in get_models_with_aliases():\n if model_with_aliases.async_model:\n for alias in model_with_aliases.aliases:\n async_model_aliases[alias] = model_with_aliases.async_model\n async_model_aliases[model_with_aliases.model.model_id] = (\n model_with_aliases.async_model\n )\n return async_model_aliases"} | |
{"id": "llm/__init__.py:184", "code": "def get_model_aliases() -> Dict[str, Model]:\n model_aliases = {}\n for model_with_aliases in get_models_with_aliases():\n if model_with_aliases.model:\n for alias in model_with_aliases.aliases:\n model_aliases[alias] = model_with_aliases.model\n model_aliases[model_with_aliases.model.model_id] = model_with_aliases.model\n return model_aliases"} | |
{"id": "llm/__init__.py:194", "code": "class UnknownModelError(KeyError):\n pass"} | |
{"id": "llm/__init__.py:198", "code": "def get_models() -> List[Model]:\n \"Get all registered models\"\n models_with_aliases = get_models_with_aliases()\n return [mwa.model for mwa in models_with_aliases if mwa.model]"} | |
{"id": "llm/__init__.py:204", "code": "def get_async_models() -> List[AsyncModel]:\n \"Get all registered async models\"\n models_with_aliases = get_models_with_aliases()\n return [mwa.async_model for mwa in models_with_aliases if mwa.async_model]"} | |
{"id": "llm/__init__.py:210", "code": "def get_async_model(name: Optional[str] = None) -> AsyncModel:\n \"Get an async model by name or alias\"\n aliases = get_async_model_aliases()\n name = name or get_default_model()\n try:\n return aliases[name]\n except KeyError:\n # Does a sync model exist?\n sync_model = None\n try:\n sync_model = get_model(name, _skip_async=True)\n except UnknownModelError:\n pass\n if sync_model:\n raise UnknownModelError(\"Unknown async model (sync model exists): \" + name)\n else:\n raise UnknownModelError(\"Unknown model: \" + name)"} | |
{"id": "llm/__init__.py:229", "code": "def get_model(name: Optional[str] = None, _skip_async: bool = False) -> Model:\n \"Get a model by name or alias\"\n aliases = get_model_aliases()\n name = name or get_default_model()\n try:\n return aliases[name]\n except KeyError:\n # Does an async model exist?\n if _skip_async:\n raise UnknownModelError(\"Unknown model: \" + name)\n async_model = None\n try:\n async_model = get_async_model(name)\n except UnknownModelError:\n pass\n if async_model:\n raise UnknownModelError(\"Unknown model (async model exists): \" + name)\n else:\n raise UnknownModelError(\"Unknown model: \" + name)"} | |
{"id": "llm/__init__.py:250", "code": "def get_key(\n explicit_key: Optional[str], key_alias: str, env_var: Optional[str] = None\n) -> Optional[str]:\n \"\"\"\n Return an API key based on a hierarchy of potential sources.\n\n :param provided_key: A key provided by the user. This may be the key, or an alias of a key in keys.json.\n :param key_alias: The alias used to retrieve the key from the keys.json file.\n :param env_var: Name of the environment variable to check for the key.\n \"\"\"\n stored_keys = load_keys()\n # If user specified an alias, use the key stored for that alias\n if explicit_key in stored_keys:\n return stored_keys[explicit_key]\n if explicit_key:\n # User specified a key that's not an alias, use that\n return explicit_key\n # Stored key over-rides environment variables over-ride the default key\n if key_alias in stored_keys:\n return stored_keys[key_alias]\n # Finally try environment variable\n if env_var and os.environ.get(env_var):\n return os.environ[env_var]\n # Couldn't find it\n return None"} | |
{"id": "llm/__init__.py:277", "code": "def load_keys():\n path = user_dir() / \"keys.json\"\n if path.exists():\n return json.loads(path.read_text())\n else:\n return {}"} | |
{"id": "llm/__init__.py:285", "code": "def user_dir():\n llm_user_path = os.environ.get(\"LLM_USER_PATH\")\n if llm_user_path:\n path = pathlib.Path(llm_user_path)\n else:\n path = pathlib.Path(click.get_app_dir(\"io.datasette.llm\"))\n path.mkdir(exist_ok=True, parents=True)\n return path"} | |
{"id": "llm/__init__.py:295", "code": "def set_alias(alias, model_id_or_alias):\n \"\"\"\n Set an alias to point to the specified model.\n \"\"\"\n path = user_dir() / \"aliases.json\"\n path.parent.mkdir(parents=True, exist_ok=True)\n if not path.exists():\n path.write_text(\"{}\\n\")\n try:\n current = json.loads(path.read_text())\n except json.decoder.JSONDecodeError:\n # We're going to write a valid JSON file in a moment:\n current = {}\n # Resolve model_id_or_alias to a model_id\n try:\n model = get_model(model_id_or_alias)\n model_id = model.model_id\n except UnknownModelError:\n # Try to resolve it to an embedding model\n try:\n model = get_embedding_model(model_id_or_alias)\n model_id = model.model_id\n except UnknownModelError:\n # Set the alias to the exact string they provided instead\n model_id = model_id_or_alias\n current[alias] = model_id\n path.write_text(json.dumps(current, indent=4) + \"\\n\")"} | |
{"id": "llm/__init__.py:324", "code": "def remove_alias(alias):\n \"\"\"\n Remove an alias.\n \"\"\"\n path = user_dir() / \"aliases.json\"\n if not path.exists():\n raise KeyError(\"No aliases.json file exists\")\n try:\n current = json.loads(path.read_text())\n except json.decoder.JSONDecodeError:\n raise KeyError(\"aliases.json file is not valid JSON\")\n if alias not in current:\n raise KeyError(\"No such alias: {}\".format(alias))\n del current[alias]\n path.write_text(json.dumps(current, indent=4) + \"\\n\")"} | |
{"id": "llm/__init__.py:341", "code": "def encode(values):\n return struct.pack(\"<\" + \"f\" * len(values), *values)"} | |
{"id": "llm/__init__.py:345", "code": "def decode(binary):\n return struct.unpack(\"<\" + \"f\" * (len(binary) // 4), binary)"} | |
{"id": "llm/__init__.py:349", "code": "def cosine_similarity(a, b):\n dot_product = sum(x * y for x, y in zip(a, b))\n magnitude_a = sum(x * x for x in a) ** 0.5\n magnitude_b = sum(x * x for x in b) ** 0.5\n return dot_product / (magnitude_a * magnitude_b)"} | |
{"id": "llm/__init__.py:356", "code": "def get_default_model(filename=\"default_model.txt\", default=DEFAULT_MODEL):\n path = user_dir() / filename\n if path.exists():\n return path.read_text().strip()\n else:\n return default"} | |
{"id": "llm/__init__.py:364", "code": "def set_default_model(model, filename=\"default_model.txt\"):\n path = user_dir() / filename\n if model is None and path.exists():\n path.unlink()\n else:\n path.write_text(model)"} | |
{"id": "llm/__init__.py:372", "code": "def get_default_embedding_model():\n return get_default_model(\"default_embedding_model.txt\", None)"} | |
{"id": "llm/__init__.py:376", "code": "def set_default_embedding_model(model):\n set_default_model(model, \"default_embedding_model.txt\")"} | |
{"id": "llm/serve.py:5", "code": "async def error(send, status_code: int, message: str):\n await send(\n {\n \"type\": \"http.response.start\",\n \"status\": status_code,\n \"headers\": [(b\"content-type\", b\"application/json\")],\n }\n )\n await send(\n {\n \"type\": \"http.response.body\",\n \"body\": json.dumps({\"error\": message}).encode(\"utf-8\"),\n \"more_body\": False,\n }\n )\n return"} | |
{"id": "llm/serve.py:23", "code": "async def read_request_body(receive):\n \"\"\"\n Reads and concatenates all HTTP request body chunks into a single bytes object.\n \"\"\"\n body = b\"\"\n more_body = True\n while more_body:\n message = await receive()\n if message[\"type\"] == \"http.request\":\n body += message.get(\"body\", b\"\")\n more_body = message.get(\"more_body\", False)\n return body"} | |
{"id": "llm/serve.py:37", "code": "async def handle_completions_request(scope, receive, send):\n \"\"\"\n Handle POST /v1/completions with possible streaming (SSE) or non-streamed JSON output.\n \"\"\"\n # Read and parse JSON payload\n raw_body = await read_request_body(receive)\n try:\n data = json.loads(raw_body.decode(\"utf-8\"))\n print(data)\n except json.JSONDecodeError:\n await error(send, 400, \"Invalid JSON\")\n return\n\n prompt = data.get(\"prompt\", \"\")\n is_stream = data.get(\"stream\", False)\n\n try:\n model = llm.get_async_model(data.get(\"model\"))\n except llm.UnknownModelError:\n await error(send, 400, \"Unknown model\")\n return\n\n if is_stream:\n # Streamed SSE response\n await send(\n {\n \"type\": \"http.response.start\",\n \"status\": 200,\n \"headers\": [\n (b\"content-type\", b\"text/event-stream\"),\n (b\"cache-control\", b\"no-cache\"),\n (b\"connection\", b\"keep-alive\"),\n ],\n }\n )\n\n # Each chunk from the model is sent as an SSE \"data: ...\" line\n async for chunk in model.prompt(prompt):\n # For OpenAI-compatible SSE, each chunk is typically wrapped in JSON\n # The \"choices\" list can hold partial text, e.g. chunk, in \"text\"\n chunk_data = {\n \"id\": \"cmpl-xxx\",\n \"object\": \"text_completion\",\n \"created\": 1234567890,\n \"model\": \"gpt-4\",\n \"choices\": [\n {\"text\": chunk, \"index\": 0, \"logprobs\": None, \"finish_reason\": None}\n ],\n }\n sse_line = f\"data: {json.dumps(chunk_data)}\\n\\n\"\n await send(\n {\n \"type\": \"http.response.body\",\n \"body\": sse_line.encode(\"utf-8\"),\n \"more_body\": True,\n }\n )\n\n # Signal that the stream is complete\n await send(\n {\n \"type\": \"http.response.body\",\n \"body\": b\"data: [DONE]\\n\\n\",\n \"more_body\": False,\n }\n )\n else:\n # Non-streamed JSON response: collect all chunks first\n full_output = []\n async for chunk in model.prompt(prompt):\n full_output.append(chunk)\n concatenated = \"\".join(full_output)\n\n # Build an OpenAI-like JSON response\n response_body = {\n \"id\": \"cmpl-xxx\",\n \"object\": \"text_completion\",\n \"created\": 1234567890,\n \"model\": \"gpt-4\",\n \"choices\": [\n {\n \"text\": concatenated,\n \"index\": 0,\n \"logprobs\": None,\n \"finish_reason\": \"stop\",\n }\n ],\n # \"usage\" field omitted for brevity\n }\n\n # Send JSON response\n await send(\n {\n \"type\": \"http.response.start\",\n \"status\": 200,\n \"headers\": [(b\"content-type\", b\"application/json\")],\n }\n )\n await send(\n {\n \"type\": \"http.response.body\",\n \"body\": json.dumps(response_body).encode(\"utf-8\"),\n \"more_body\": False,\n }\n )"} | |
{"id": "llm/serve.py:144", "code": "async def app(scope, receive, send):\n \"\"\"\n A simple ASGI application that routes /v1/completions to our OpenAI-compatible handler.\n \"\"\"\n if scope[\"type\"] == \"http\":\n path = scope.get(\"path\", \"\")\n method = scope.get(\"method\", \"\").upper()\n\n # Route to /v1/completions\n if path == \"/v1/completions\":\n if method == \"POST\":\n await handle_completions_request(scope, receive, send)\n else:\n await error(send, 405, \"Method not allowed\")\n return\n else:\n # Handle unrecognized paths or methods with a simple 404\n await error(send, 404, \"Not found\")\n return\n else:\n pass"} | |
{"id": "llm/templates.py:6", "code": "class Template(BaseModel):\n name: str\n prompt: Optional[str] = None\n system: Optional[str] = None\n model: Optional[str] = None\n defaults: Optional[Dict[str, Any]] = None\n options: Optional[Dict[str, Any]] = None\n # Should a fenced code block be extracted?\n extract: Optional[bool] = None\n extract_last: Optional[bool] = None\n schema_object: Optional[dict] = None\n\n model_config = ConfigDict(extra=\"forbid\")\n\n class MissingVariables(Exception):\n pass\n\n def evaluate(\n self, input: str, params: Optional[Dict[str, Any]] = None\n ) -> Tuple[Optional[str], Optional[str]]:\n params = params or {}\n params[\"input\"] = input\n if self.defaults:\n for k, v in self.defaults.items():\n if k not in params:\n params[k] = v\n prompt: Optional[str] = None\n system: Optional[str] = None\n if not self.prompt:\n system = self.interpolate(self.system, params)\n prompt = input\n else:\n prompt = self.interpolate(self.prompt, params)\n system = self.interpolate(self.system, params)\n return prompt, system\n\n def vars(self) -> set:\n all_vars = set()\n for text in [self.prompt, self.system]:\n if not text:\n continue\n all_vars.update(self.extract_vars(string.Template(text)))\n return all_vars\n\n @classmethod\n def interpolate(cls, text: Optional[str], params: Dict[str, Any]) -> Optional[str]:\n if not text:\n return text\n # Confirm all variables in text are provided\n string_template = string.Template(text)\n vars = cls.extract_vars(string_template)\n missing = [p for p in vars if p not in params]\n if missing:\n raise cls.MissingVariables(\n \"Missing variables: {}\".format(\", \".join(missing))\n )\n return string_template.substitute(**params)\n\n @staticmethod\n def extract_vars(string_template: string.Template) -> List[str]:\n return [\n match.group(\"named\")\n for match in string_template.pattern.finditer(string_template.template)\n ]"} | |
{"id": "llm/templates.py:23", "code": " def evaluate(\n self, input: str, params: Optional[Dict[str, Any]] = None\n ) -> Tuple[Optional[str], Optional[str]]:\n params = params or {}\n params[\"input\"] = input\n if self.defaults:\n for k, v in self.defaults.items():\n if k not in params:\n params[k] = v\n prompt: Optional[str] = None\n system: Optional[str] = None\n if not self.prompt:\n system = self.interpolate(self.system, params)\n prompt = input\n else:\n prompt = self.interpolate(self.prompt, params)\n system = self.interpolate(self.system, params)\n return prompt, system"} | |
{"id": "llm/templates.py:42", "code": " def vars(self) -> set:\n all_vars = set()\n for text in [self.prompt, self.system]:\n if not text:\n continue\n all_vars.update(self.extract_vars(string.Template(text)))\n return all_vars"} | |
{"id": "llm/templates.py:50", "code": " @classmethod\n def interpolate(cls, text: Optional[str], params: Dict[str, Any]) -> Optional[str]:\n if not text:\n return text\n # Confirm all variables in text are provided\n string_template = string.Template(text)\n vars = cls.extract_vars(string_template)\n missing = [p for p in vars if p not in params]\n if missing:\n raise cls.MissingVariables(\n \"Missing variables: {}\".format(\", \".join(missing))\n )\n return string_template.substitute(**params)"} | |
{"id": "llm/templates.py:64", "code": " @staticmethod\n def extract_vars(string_template: string.Template) -> List[str]:\n return [\n match.group(\"named\")\n for match in string_template.pattern.finditer(string_template.template)\n ]"} | |
{"id": "llm/embeddings.py:13", "code": "@dataclass\nclass Entry:\n id: str\n score: Optional[float]\n content: Optional[str] = None\n metadata: Optional[Dict[str, Any]] = None"} | |
{"id": "llm/embeddings.py:21", "code": "class Collection:\n class DoesNotExist(Exception):\n pass\n\n def __init__(\n self,\n name: str,\n db: Optional[Database] = None,\n *,\n model: Optional[EmbeddingModel] = None,\n model_id: Optional[str] = None,\n create: bool = True,\n ) -> None:\n \"\"\"\n A collection of embeddings\n\n Returns the collection with the given name, creating it if it does not exist.\n\n If you set create=False a Collection.DoesNotExist exception will be raised if the\n collection does not already exist.\n\n Args:\n db (sqlite_utils.Database): Database to store the collection in\n name (str): Name of the collection\n model (llm.models.EmbeddingModel, optional): Embedding model to use\n model_id (str, optional): Alternatively, ID of the embedding model to use\n create (bool, optional): Whether to create the collection if it does not exist\n \"\"\"\n import llm\n\n self.db = db or Database(memory=True)\n self.name = name\n self._model = model\n\n embeddings_migrations.apply(self.db)\n\n rows = list(self.db[\"collections\"].rows_where(\"name = ?\", [self.name]))\n if rows:\n row = rows[0]\n self.id = row[\"id\"]\n self.model_id = row[\"model\"]\n else:\n if create:\n # Collection does not exist, so model or model_id is required\n if not model and not model_id:\n raise ValueError(\n \"Either model= or model_id= must be provided when creating a new collection\"\n )\n # Create it\n if model_id:\n # Resolve alias\n model = llm.get_embedding_model(model_id)\n self._model = model\n model_id = cast(EmbeddingModel, model).model_id\n self.id = (\n cast(Table, self.db[\"collections\"])\n .insert(\n {\n \"name\": self.name,\n \"model\": model_id,\n }\n )\n .last_pk\n )\n else:\n raise self.DoesNotExist(f\"Collection '{name}' does not exist\")\n\n def model(self) -> EmbeddingModel:\n \"Return the embedding model used by this collection\"\n import llm\n\n if self._model is None:\n self._model = llm.get_embedding_model(self.model_id)\n\n return cast(EmbeddingModel, self._model)\n\n def count(self) -> int:\n \"\"\"\n Count the number of items in the collection.\n\n Returns:\n int: Number of items in the collection\n \"\"\"\n return next(\n self.db.query(\n \"\"\"\n select count(*) as c from embeddings where collection_id = (\n select id from collections where name = ?\n )\n \"\"\",\n (self.name,),\n )\n )[\"c\"]\n\n def embed(\n self,\n id: str,\n value: Union[str, bytes],\n metadata: Optional[Dict[str, Any]] = None,\n store: bool = False,\n ) -> None:\n \"\"\"\n Embed value and store it in the collection with a given ID.\n\n Args:\n id (str): ID for the value\n value (str or bytes): value to be embedded\n metadata (dict, optional): Metadata to be stored\n store (bool, optional): Whether to store the value in the content or content_blob column\n \"\"\"\n from llm import encode\n\n content_hash = self.content_hash(value)\n if self.db[\"embeddings\"].count_where(\n \"content_hash = ? 
and collection_id = ?\", [content_hash, self.id]\n ):\n return\n embedding = self.model().embed(value)\n cast(Table, self.db[\"embeddings\"]).insert(\n {\n \"collection_id\": self.id,\n \"id\": id,\n \"embedding\": encode(embedding),\n \"content\": value if (store and isinstance(value, str)) else None,\n \"content_blob\": value if (store and isinstance(value, bytes)) else None,\n \"content_hash\": content_hash,\n \"metadata\": json.dumps(metadata) if metadata else None,\n \"updated\": int(time.time()),\n },\n replace=True,\n )\n\n def embed_multi(\n self,\n entries: Iterable[Tuple[str, Union[str, bytes]]],\n store: bool = False,\n batch_size: int = 100,\n ) -> None:\n \"\"\"\n Embed multiple texts and store them in the collection with given IDs.\n\n Args:\n entries (iterable): Iterable of (id: str, text: str) tuples\n store (bool, optional): Whether to store the text in the content column\n batch_size (int, optional): custom maximum batch size to use\n \"\"\"\n self.embed_multi_with_metadata(\n ((id, value, None) for id, value in entries),\n store=store,\n batch_size=batch_size,\n )\n\n def embed_multi_with_metadata(\n self,\n entries: Iterable[Tuple[str, Union[str, bytes], Optional[Dict[str, Any]]]],\n store: bool = False,\n batch_size: int = 100,\n ) -> None:\n \"\"\"\n Embed multiple values along with metadata and store them in the collection with given IDs.\n\n Args:\n entries (iterable): Iterable of (id: str, value: str or bytes, metadata: None or dict)\n store (bool, optional): Whether to store the value in the content or content_blob column\n batch_size (int, optional): custom maximum batch size to use\n \"\"\"\n import llm\n\n batch_size = min(batch_size, (self.model().batch_size or batch_size))\n iterator = iter(entries)\n collection_id = self.id\n while True:\n batch = list(islice(iterator, batch_size))\n if not batch:\n break\n # Calculate hashes first\n items_and_hashes = [(item, self.content_hash(item[1])) for item in batch]\n # Any of those hashes already exist?\n existing_ids = [\n row[\"id\"]\n for row in self.db.query(\n \"\"\"\n select id from embeddings\n where collection_id = ? 
and content_hash in ({})\n \"\"\".format(\n \",\".join(\"?\" for _ in items_and_hashes)\n ),\n [collection_id]\n + [item_and_hash[1] for item_and_hash in items_and_hashes],\n )\n ]\n filtered_batch = [item for item in batch if item[0] not in existing_ids]\n embeddings = list(\n self.model().embed_multi(item[1] for item in filtered_batch)\n )\n with self.db.conn:\n cast(Table, self.db[\"embeddings\"]).insert_all(\n (\n {\n \"collection_id\": collection_id,\n \"id\": id,\n \"embedding\": llm.encode(embedding),\n \"content\": (\n value if (store and isinstance(value, str)) else None\n ),\n \"content_blob\": (\n value if (store and isinstance(value, bytes)) else None\n ),\n \"content_hash\": self.content_hash(value),\n \"metadata\": json.dumps(metadata) if metadata else None,\n \"updated\": int(time.time()),\n }\n for (embedding, (id, value, metadata)) in zip(\n embeddings, filtered_batch\n )\n ),\n replace=True,\n )\n\n def similar_by_vector(\n self, vector: List[float], number: int = 10, skip_id: Optional[str] = None\n ) -> List[Entry]:\n \"\"\"\n Find similar items in the collection by a given vector.\n\n Args:\n vector (list): Vector to search by\n number (int, optional): Number of similar items to return\n\n Returns:\n list: List of Entry objects\n \"\"\"\n import llm\n\n def distance_score(other_encoded):\n other_vector = llm.decode(other_encoded)\n return llm.cosine_similarity(other_vector, vector)\n\n self.db.register_function(distance_score, replace=True)\n\n where_bits = [\"collection_id = ?\"]\n where_args = [str(self.id)]\n\n if skip_id:\n where_bits.append(\"id != ?\")\n where_args.append(skip_id)\n\n return [\n Entry(\n id=row[\"id\"],\n score=row[\"score\"],\n content=row[\"content\"],\n metadata=json.loads(row[\"metadata\"]) if row[\"metadata\"] else None,\n )\n for row in self.db.query(\n \"\"\"\n select id, content, metadata, distance_score(embedding) as score\n from embeddings\n where {where}\n order by score desc limit {number}\n \"\"\".format(\n where=\" and \".join(where_bits),\n number=number,\n ),\n where_args,\n )\n ]\n\n def similar_by_id(self, id: str, number: int = 10) -> List[Entry]:\n \"\"\"\n Find similar items in the collection by a given ID.\n\n Args:\n id (str): ID to search by\n number (int, optional): Number of similar items to return\n\n Returns:\n list: List of Entry objects\n \"\"\"\n import llm\n\n matches = list(\n self.db[\"embeddings\"].rows_where(\n \"collection_id = ? 
and id = ?\", (self.id, id)\n )\n )\n if not matches:\n raise self.DoesNotExist(\"ID not found\")\n embedding = matches[0][\"embedding\"]\n comparison_vector = llm.decode(embedding)\n return self.similar_by_vector(comparison_vector, number, skip_id=id)\n\n def similar(self, value: Union[str, bytes], number: int = 10) -> List[Entry]:\n \"\"\"\n Find similar items in the collection by a given value.\n\n Args:\n value (str or bytes): value to search by\n number (int, optional): Number of similar items to return\n\n Returns:\n list: List of Entry objects\n \"\"\"\n comparison_vector = self.model().embed(value)\n return self.similar_by_vector(comparison_vector, number)\n\n @classmethod\n def exists(cls, db: Database, name: str) -> bool:\n \"\"\"\n Does this collection exist in the database?\n\n Args:\n name (str): Name of the collection\n \"\"\"\n rows = list(db[\"collections\"].rows_where(\"name = ?\", [name]))\n return bool(rows)\n\n def delete(self):\n \"\"\"\n Delete the collection and its embeddings from the database\n \"\"\"\n with self.db.conn:\n self.db.execute(\"delete from embeddings where collection_id = ?\", [self.id])\n self.db.execute(\"delete from collections where id = ?\", [self.id])\n\n @staticmethod\n def content_hash(input: Union[str, bytes]) -> bytes:\n \"Hash content for deduplication. Override to change hashing behavior.\"\n if isinstance(input, str):\n input = input.encode(\"utf8\")\n return hashlib.md5(input).digest()"} | |
{"id": "llm/embeddings.py:25", "code": " def __init__(\n self,\n name: str,\n db: Optional[Database] = None,\n *,\n model: Optional[EmbeddingModel] = None,\n model_id: Optional[str] = None,\n create: bool = True,\n ) -> None:\n \"\"\"\n A collection of embeddings\n\n Returns the collection with the given name, creating it if it does not exist.\n\n If you set create=False a Collection.DoesNotExist exception will be raised if the\n collection does not already exist.\n\n Args:\n db (sqlite_utils.Database): Database to store the collection in\n name (str): Name of the collection\n model (llm.models.EmbeddingModel, optional): Embedding model to use\n model_id (str, optional): Alternatively, ID of the embedding model to use\n create (bool, optional): Whether to create the collection if it does not exist\n \"\"\"\n import llm\n\n self.db = db or Database(memory=True)\n self.name = name\n self._model = model\n\n embeddings_migrations.apply(self.db)\n\n rows = list(self.db[\"collections\"].rows_where(\"name = ?\", [self.name]))\n if rows:\n row = rows[0]\n self.id = row[\"id\"]\n self.model_id = row[\"model\"]\n else:\n if create:\n # Collection does not exist, so model or model_id is required\n if not model and not model_id:\n raise ValueError(\n \"Either model= or model_id= must be provided when creating a new collection\"\n )\n # Create it\n if model_id:\n # Resolve alias\n model = llm.get_embedding_model(model_id)\n self._model = model\n model_id = cast(EmbeddingModel, model).model_id\n self.id = (\n cast(Table, self.db[\"collections\"])\n .insert(\n {\n \"name\": self.name,\n \"model\": model_id,\n }\n )\n .last_pk\n )\n else:\n raise self.DoesNotExist(f\"Collection '{name}' does not exist\")"} | |
{"id": "llm/embeddings.py:88", "code": " def model(self) -> EmbeddingModel:\n \"Return the embedding model used by this collection\"\n import llm\n\n if self._model is None:\n self._model = llm.get_embedding_model(self.model_id)\n\n return cast(EmbeddingModel, self._model)"} | |
{"id": "llm/embeddings.py:97", "code": " def count(self) -> int:\n \"\"\"\n Count the number of items in the collection.\n\n Returns:\n int: Number of items in the collection\n \"\"\"\n return next(\n self.db.query(\n \"\"\"\n select count(*) as c from embeddings where collection_id = (\n select id from collections where name = ?\n )\n \"\"\",\n (self.name,),\n )\n )[\"c\"]"} | |
{"id": "llm/embeddings.py:115", "code": " def embed(\n self,\n id: str,\n value: Union[str, bytes],\n metadata: Optional[Dict[str, Any]] = None,\n store: bool = False,\n ) -> None:\n \"\"\"\n Embed value and store it in the collection with a given ID.\n\n Args:\n id (str): ID for the value\n value (str or bytes): value to be embedded\n metadata (dict, optional): Metadata to be stored\n store (bool, optional): Whether to store the value in the content or content_blob column\n \"\"\"\n from llm import encode\n\n content_hash = self.content_hash(value)\n if self.db[\"embeddings\"].count_where(\n \"content_hash = ? and collection_id = ?\", [content_hash, self.id]\n ):\n return\n embedding = self.model().embed(value)\n cast(Table, self.db[\"embeddings\"]).insert(\n {\n \"collection_id\": self.id,\n \"id\": id,\n \"embedding\": encode(embedding),\n \"content\": value if (store and isinstance(value, str)) else None,\n \"content_blob\": value if (store and isinstance(value, bytes)) else None,\n \"content_hash\": content_hash,\n \"metadata\": json.dumps(metadata) if metadata else None,\n \"updated\": int(time.time()),\n },\n replace=True,\n )"} | |
{"id": "llm/embeddings.py:153", "code": " def embed_multi(\n self,\n entries: Iterable[Tuple[str, Union[str, bytes]]],\n store: bool = False,\n batch_size: int = 100,\n ) -> None:\n \"\"\"\n Embed multiple texts and store them in the collection with given IDs.\n\n Args:\n entries (iterable): Iterable of (id: str, text: str) tuples\n store (bool, optional): Whether to store the text in the content column\n batch_size (int, optional): custom maximum batch size to use\n \"\"\"\n self.embed_multi_with_metadata(\n ((id, value, None) for id, value in entries),\n store=store,\n batch_size=batch_size,\n )"} | |
{"id": "llm/embeddings.py:173", "code": " def embed_multi_with_metadata(\n self,\n entries: Iterable[Tuple[str, Union[str, bytes], Optional[Dict[str, Any]]]],\n store: bool = False,\n batch_size: int = 100,\n ) -> None:\n \"\"\"\n Embed multiple values along with metadata and store them in the collection with given IDs.\n\n Args:\n entries (iterable): Iterable of (id: str, value: str or bytes, metadata: None or dict)\n store (bool, optional): Whether to store the value in the content or content_blob column\n batch_size (int, optional): custom maximum batch size to use\n \"\"\"\n import llm\n\n batch_size = min(batch_size, (self.model().batch_size or batch_size))\n iterator = iter(entries)\n collection_id = self.id\n while True:\n batch = list(islice(iterator, batch_size))\n if not batch:\n break\n # Calculate hashes first\n items_and_hashes = [(item, self.content_hash(item[1])) for item in batch]\n # Any of those hashes already exist?\n existing_ids = [\n row[\"id\"]\n for row in self.db.query(\n \"\"\"\n select id from embeddings\n where collection_id = ? and content_hash in ({})\n \"\"\".format(\n \",\".join(\"?\" for _ in items_and_hashes)\n ),\n [collection_id]\n + [item_and_hash[1] for item_and_hash in items_and_hashes],\n )\n ]\n filtered_batch = [item for item in batch if item[0] not in existing_ids]\n embeddings = list(\n self.model().embed_multi(item[1] for item in filtered_batch)\n )\n with self.db.conn:\n cast(Table, self.db[\"embeddings\"]).insert_all(\n (\n {\n \"collection_id\": collection_id,\n \"id\": id,\n \"embedding\": llm.encode(embedding),\n \"content\": (\n value if (store and isinstance(value, str)) else None\n ),\n \"content_blob\": (\n value if (store and isinstance(value, bytes)) else None\n ),\n \"content_hash\": self.content_hash(value),\n \"metadata\": json.dumps(metadata) if metadata else None,\n \"updated\": int(time.time()),\n }\n for (embedding, (id, value, metadata)) in zip(\n embeddings, filtered_batch\n )\n ),\n replace=True,\n )"} | |
{"id": "llm/embeddings.py:240", "code": " def similar_by_vector(\n self, vector: List[float], number: int = 10, skip_id: Optional[str] = None\n ) -> List[Entry]:\n \"\"\"\n Find similar items in the collection by a given vector.\n\n Args:\n vector (list): Vector to search by\n number (int, optional): Number of similar items to return\n\n Returns:\n list: List of Entry objects\n \"\"\"\n import llm\n\n def distance_score(other_encoded):\n other_vector = llm.decode(other_encoded)\n return llm.cosine_similarity(other_vector, vector)\n\n self.db.register_function(distance_score, replace=True)\n\n where_bits = [\"collection_id = ?\"]\n where_args = [str(self.id)]\n\n if skip_id:\n where_bits.append(\"id != ?\")\n where_args.append(skip_id)\n\n return [\n Entry(\n id=row[\"id\"],\n score=row[\"score\"],\n content=row[\"content\"],\n metadata=json.loads(row[\"metadata\"]) if row[\"metadata\"] else None,\n )\n for row in self.db.query(\n \"\"\"\n select id, content, metadata, distance_score(embedding) as score\n from embeddings\n where {where}\n order by score desc limit {number}\n \"\"\".format(\n where=\" and \".join(where_bits),\n number=number,\n ),\n where_args,\n )\n ]"} | |
{"id": "llm/embeddings.py:289", "code": " def similar_by_id(self, id: str, number: int = 10) -> List[Entry]:\n \"\"\"\n Find similar items in the collection by a given ID.\n\n Args:\n id (str): ID to search by\n number (int, optional): Number of similar items to return\n\n Returns:\n list: List of Entry objects\n \"\"\"\n import llm\n\n matches = list(\n self.db[\"embeddings\"].rows_where(\n \"collection_id = ? and id = ?\", (self.id, id)\n )\n )\n if not matches:\n raise self.DoesNotExist(\"ID not found\")\n embedding = matches[0][\"embedding\"]\n comparison_vector = llm.decode(embedding)\n return self.similar_by_vector(comparison_vector, number, skip_id=id)"} | |
{"id": "llm/embeddings.py:313", "code": " def similar(self, value: Union[str, bytes], number: int = 10) -> List[Entry]:\n \"\"\"\n Find similar items in the collection by a given value.\n\n Args:\n value (str or bytes): value to search by\n number (int, optional): Number of similar items to return\n\n Returns:\n list: List of Entry objects\n \"\"\"\n comparison_vector = self.model().embed(value)\n return self.similar_by_vector(comparison_vector, number)"} | |
{"id": "llm/embeddings.py:327", "code": " @classmethod\n def exists(cls, db: Database, name: str) -> bool:\n \"\"\"\n Does this collection exist in the database?\n\n Args:\n name (str): Name of the collection\n \"\"\"\n rows = list(db[\"collections\"].rows_where(\"name = ?\", [name]))\n return bool(rows)"} | |
{"id": "llm/embeddings.py:338", "code": " def delete(self):\n \"\"\"\n Delete the collection and its embeddings from the database\n \"\"\"\n with self.db.conn:\n self.db.execute(\"delete from embeddings where collection_id = ?\", [self.id])\n self.db.execute(\"delete from collections where id = ?\", [self.id])"} | |
{"id": "llm/embeddings.py:346", "code": " @staticmethod\n def content_hash(input: Union[str, bytes]) -> bytes:\n \"Hash content for deduplication. Override to change hashing behavior.\"\n if isinstance(input, str):\n input = input.encode(\"utf8\")\n return hashlib.md5(input).digest()"} | |
{"id": "llm/cli.py:76", "code": "class AttachmentType(click.ParamType):\n name = \"attachment\"\n\n def convert(self, value, param, ctx):\n if value == \"-\":\n content = sys.stdin.buffer.read()\n # Try to guess type\n mimetype = mimetype_from_string(content)\n if mimetype is None:\n raise click.BadParameter(\"Could not determine mimetype of stdin\")\n return Attachment(type=mimetype, path=None, url=None, content=content)\n if \"://\" in value:\n # Confirm URL exists and try to guess type\n try:\n response = httpx.head(value)\n response.raise_for_status()\n mimetype = response.headers.get(\"content-type\")\n except httpx.HTTPError as ex:\n raise click.BadParameter(str(ex))\n return Attachment(mimetype, None, value, None)\n # Check that the file exists\n path = pathlib.Path(value)\n if not path.exists():\n self.fail(f\"File {value} does not exist\", param, ctx)\n path = path.resolve()\n # Try to guess type\n mimetype = mimetype_from_path(str(path))\n if mimetype is None:\n raise click.BadParameter(f\"Could not determine mimetype of {value}\")\n return Attachment(type=mimetype, path=str(path), url=None, content=None)"} | |
{"id": "llm/cli.py:79", "code": " def convert(self, value, param, ctx):\n if value == \"-\":\n content = sys.stdin.buffer.read()\n # Try to guess type\n mimetype = mimetype_from_string(content)\n if mimetype is None:\n raise click.BadParameter(\"Could not determine mimetype of stdin\")\n return Attachment(type=mimetype, path=None, url=None, content=content)\n if \"://\" in value:\n # Confirm URL exists and try to guess type\n try:\n response = httpx.head(value)\n response.raise_for_status()\n mimetype = response.headers.get(\"content-type\")\n except httpx.HTTPError as ex:\n raise click.BadParameter(str(ex))\n return Attachment(mimetype, None, value, None)\n # Check that the file exists\n path = pathlib.Path(value)\n if not path.exists():\n self.fail(f\"File {value} does not exist\", param, ctx)\n path = path.resolve()\n # Try to guess type\n mimetype = mimetype_from_path(str(path))\n if mimetype is None:\n raise click.BadParameter(f\"Could not determine mimetype of {value}\")\n return Attachment(type=mimetype, path=str(path), url=None, content=None)"} | |
{"id": "llm/cli.py:108", "code": "def attachment_types_callback(ctx, param, values):\n collected = []\n for value, mimetype in values:\n if \"://\" in value:\n attachment = Attachment(mimetype, None, value, None)\n elif value == \"-\":\n content = sys.stdin.buffer.read()\n attachment = Attachment(mimetype, None, None, content)\n else:\n # Look for file\n path = pathlib.Path(value)\n if not path.exists():\n raise click.BadParameter(f\"File {value} does not exist\")\n path = path.resolve()\n attachment = Attachment(mimetype, str(path), None, None)\n collected.append(attachment)\n return collected"} | |
{"id": "llm/cli.py:127", "code": "def json_validator(object_name):\n def validator(ctx, param, value):\n if value is None:\n return value\n try:\n obj = json.loads(value)\n if not isinstance(obj, dict):\n raise click.BadParameter(f\"{object_name} must be a JSON object\")\n return obj\n except json.JSONDecodeError:\n raise click.BadParameter(f\"{object_name} must be valid JSON\")\n\n return validator"} | |
{"id": "llm/cli.py:142", "code": "def schema_option(fn):\n click.option(\n \"schema_input\",\n \"--schema\",\n help=\"JSON schema, filepath or ID\",\n )(fn)\n return fn"} | |
{"id": "llm/cli.py:151", "code": "@click.group(\n cls=DefaultGroup,\n default=\"prompt\",\n default_if_no_args=True,\n)\n@click.version_option()\ndef cli():\n \"\"\"\n Access Large Language Models from the command-line\n\n Documentation: https://llm.datasette.io/\n\n LLM can run models from many different providers. Consult the\n plugin directory for a list of available models:\n\n https://llm.datasette.io/en/stable/plugins/directory.html\n\n To get started with OpenAI, obtain an API key from them and:\n\n \\b\n $ llm keys set openai\n Enter key: ...\n\n Then execute a prompt like this:\n\n llm 'Five outrageous names for a pet pelican'\n \"\"\""} | |
{"id": "llm/cli.py:180", "code": "@cli.command(name=\"prompt\")\n@click.argument(\"prompt\", required=False)\n@click.option(\"-s\", \"--system\", help=\"System prompt to use\")\n@click.option(\"model_id\", \"-m\", \"--model\", help=\"Model to use\")\n@click.option(\n \"queries\",\n \"-q\",\n \"--query\",\n multiple=True,\n help=\"Use first model matching these strings\",\n)\n@click.option(\n \"attachments\",\n \"-a\",\n \"--attachment\",\n type=AttachmentType(),\n multiple=True,\n help=\"Attachment path or URL or -\",\n)\n@click.option(\n \"attachment_types\",\n \"--at\",\n \"--attachment-type\",\n type=(str, str),\n multiple=True,\n callback=attachment_types_callback,\n help=\"Attachment with explicit mimetype\",\n)\n@click.option(\n \"options\",\n \"-o\",\n \"--option\",\n type=(str, str),\n multiple=True,\n help=\"key/value options for the model\",\n)\n@schema_option\n@click.option(\n \"--schema-multi\",\n help=\"JSON schema to use for multiple results\",\n)\n@click.option(\"-t\", \"--template\", help=\"Template to use\")\n@click.option(\n \"-p\",\n \"--param\",\n multiple=True,\n type=(str, str),\n help=\"Parameters for template\",\n)\n@click.option(\"--no-stream\", is_flag=True, help=\"Do not stream output\")\n@click.option(\"-n\", \"--no-log\", is_flag=True, help=\"Don't log to database\")\n@click.option(\"--log\", is_flag=True, help=\"Log prompt and response to the database\")\n@click.option(\n \"_continue\",\n \"-c\",\n \"--continue\",\n is_flag=True,\n flag_value=-1,\n help=\"Continue the most recent conversation.\",\n)\n@click.option(\n \"conversation_id\",\n \"--cid\",\n \"--conversation\",\n help=\"Continue the conversation with the given ID.\",\n)\n@click.option(\"--key\", help=\"API key to use\")\n@click.option(\"--save\", help=\"Save prompt with this template name\")\n@click.option(\"async_\", \"--async\", is_flag=True, help=\"Run prompt asynchronously\")\n@click.option(\"-u\", \"--usage\", is_flag=True, help=\"Show token usage\")\n@click.option(\"-x\", \"--extract\", is_flag=True, help=\"Extract first fenced code block\")\n@click.option(\n \"extract_last\",\n \"--xl\",\n \"--extract-last\",\n is_flag=True,\n help=\"Extract last fenced code block\",\n)\ndef prompt(\n prompt,\n system,\n model_id,\n queries,\n attachments,\n attachment_types,\n options,\n schema_input,\n schema_multi,\n template,\n param,\n no_stream,\n no_log,\n log,\n _continue,\n conversation_id,\n key,\n save,\n async_,\n usage,\n extract,\n extract_last,\n):\n \"\"\"\n Execute a prompt\n\n Documentation: https://llm.datasette.io/en/stable/usage.html\n\n Examples:\n\n \\b\n llm 'Capital of France?'\n llm 'Capital of France?' -m gpt-4o\n llm 'Capital of France?' -s 'answer in Spanish'\n\n Multi-modal models can be called with attachments like this:\n\n \\b\n llm 'Extract text from this image' -a image.jpg\n llm 'Describe' -a https://static.simonwillison.net/static/2024/pelicans.jpg\n cat image | llm 'describe image' -a -\n # With an explicit mimetype:\n cat image | llm 'describe image' --at - image/jpeg\n\n The -x/--extract option returns just the content of the first ``` fenced code\n block, if one is present. 
If none are present it returns the full response.\n\n \\b\n llm 'JavaScript function for reversing a string' -x\n \"\"\"\n if log and no_log:\n raise click.ClickException(\"--log and --no-log are mutually exclusive\")\n\n log_path = logs_db_path()\n (log_path.parent).mkdir(parents=True, exist_ok=True)\n db = sqlite_utils.Database(log_path)\n migrate(db)\n\n if queries and not model_id:\n # Use -q options to find model with shortest model_id\n matches = []\n for model_with_aliases in get_models_with_aliases():\n if all(model_with_aliases.matches(q) for q in queries):\n matches.append(model_with_aliases.model.model_id)\n if not matches:\n raise click.ClickException(\n \"No model found matching queries {}\".format(\", \".join(queries))\n )\n model_id = min(matches, key=len)\n\n if schema_multi:\n schema_input = schema_multi\n\n schema = resolve_schema_input(db, schema_input, load_template)\n\n if schema_multi:\n # Convert that schema into multiple \"items\" of the same schema\n schema = multi_schema(schema)\n\n model_aliases = get_model_aliases()\n\n def read_prompt():\n nonlocal prompt, schema\n\n # Is there extra prompt available on stdin?\n stdin_prompt = None\n if not sys.stdin.isatty():\n stdin_prompt = sys.stdin.read()\n\n if stdin_prompt:\n bits = [stdin_prompt]\n if prompt:\n bits.append(prompt)\n prompt = \" \".join(bits)\n\n if (\n prompt is None\n and not save\n and sys.stdin.isatty()\n and not attachments\n and not attachment_types\n and not schema\n ):\n # Hang waiting for input to stdin (unless --save)\n prompt = sys.stdin.read()\n return prompt\n\n if save:\n # We are saving their prompt/system/etc to a new template\n # Fields to save: prompt, system, model - and more in the future\n disallowed_options = []\n for option, var in (\n (\"--template\", template),\n (\"--continue\", _continue),\n (\"--cid\", conversation_id),\n ):\n if var:\n disallowed_options.append(option)\n if disallowed_options:\n raise click.ClickException(\n \"--save cannot be used with {}\".format(\", \".join(disallowed_options))\n )\n path = template_dir() / f\"{save}.yaml\"\n to_save = {}\n if model_id:\n try:\n to_save[\"model\"] = model_aliases[model_id].model_id\n except KeyError:\n raise click.ClickException(\"'{}' is not a known model\".format(model_id))\n prompt = read_prompt()\n if prompt:\n to_save[\"prompt\"] = prompt\n if system:\n to_save[\"system\"] = system\n if param:\n to_save[\"defaults\"] = dict(param)\n if extract:\n to_save[\"extract\"] = True\n if extract_last:\n to_save[\"extract_last\"] = True\n if schema:\n to_save[\"schema_object\"] = schema\n if options:\n # Need to validate and convert their types first\n model = get_model(model_id or get_default_model())\n try:\n to_save[\"options\"] = dict(\n (key, value)\n for key, value in model.Options(**dict(options))\n if value is not None\n )\n except pydantic.ValidationError as ex:\n raise click.ClickException(render_errors(ex.errors()))\n path.write_text(\n yaml.dump(\n to_save,\n indent=4,\n default_flow_style=False,\n sort_keys=False,\n ),\n \"utf-8\",\n )\n return\n\n if template:\n params = dict(param)\n # Cannot be used with system\n if system:\n raise click.ClickException(\"Cannot use -t/--template and --system together\")\n template_obj = load_template(template)\n extract = template_obj.extract\n extract_last = template_obj.extract_last\n if template_obj.schema_object:\n schema = template_obj.schema_object\n input_ = \"\"\n if template_obj.options:\n # Make options mutable (they start as a tuple)\n options = list(options)\n # Load 
any options, provided they were not set using -o already\n specified_options = dict(options)\n for option_name, option_value in template_obj.options.items():\n if option_name not in specified_options:\n options.append((option_name, option_value))\n if \"input\" in template_obj.vars():\n input_ = read_prompt()\n try:\n template_prompt, system = template_obj.evaluate(input_, params)\n if template_prompt:\n # Over-ride user prompt only if the template provided one\n prompt = template_prompt\n except Template.MissingVariables as ex:\n raise click.ClickException(str(ex))\n if model_id is None and template_obj.model:\n model_id = template_obj.model\n\n if extract or extract_last:\n no_stream = True\n\n conversation = None\n if conversation_id or _continue:\n # Load the conversation - loads most recent if no ID provided\n try:\n conversation = load_conversation(conversation_id, async_=async_)\n except UnknownModelError as ex:\n raise click.ClickException(str(ex))\n\n # Figure out which model we are using\n if model_id is None:\n if conversation:\n model_id = conversation.model.model_id\n else:\n model_id = get_default_model()\n\n # Now resolve the model\n try:\n if async_:\n model = get_async_model(model_id)\n else:\n model = get_model(model_id)\n except UnknownModelError as ex:\n raise click.ClickException(ex)\n\n if conversation:\n # To ensure it can see the key\n conversation.model = model\n\n # Validate options\n validated_options = {}\n if options:\n # Validate with pydantic\n try:\n validated_options = dict(\n (key, value)\n for key, value in model.Options(**dict(options))\n if value is not None\n )\n except pydantic.ValidationError as ex:\n raise click.ClickException(render_errors(ex.errors()))\n\n # Add on any default model options\n default_options = get_model_options(model_id)\n for key_, value in default_options.items():\n if key_ not in validated_options:\n validated_options[key_] = value\n\n kwargs = {**validated_options}\n\n resolved_attachments = [*attachments, *attachment_types]\n\n should_stream = model.can_stream and not no_stream\n if not should_stream:\n kwargs[\"stream\"] = False\n\n if isinstance(model, (KeyModel, AsyncKeyModel)):\n kwargs[\"key\"] = key\n\n prompt = read_prompt()\n response = None\n\n prompt_method = model.prompt\n if conversation:\n prompt_method = conversation.prompt\n\n try:\n if async_:\n\n async def inner():\n if should_stream:\n response = prompt_method(\n prompt,\n attachments=resolved_attachments,\n system=system,\n schema=schema,\n **kwargs,\n )\n async for chunk in response:\n print(chunk, end=\"\")\n sys.stdout.flush()\n print(\"\")\n else:\n response = prompt_method(\n prompt,\n attachments=resolved_attachments,\n system=system,\n schema=schema,\n **kwargs,\n )\n text = await response.text()\n if extract or extract_last:\n text = (\n extract_fenced_code_block(text, last=extract_last) or text\n )\n print(text)\n return response\n\n response = asyncio.run(inner())\n else:\n response = prompt_method(\n prompt,\n attachments=resolved_attachments,\n system=system,\n schema=schema,\n **kwargs,\n )\n if should_stream:\n for chunk in response:\n if isinstance(chunk, Chunk) and chunk.annotation:\n print(chunk.annotation)\n print(chunk, end=\"\")\n sys.stdout.flush()\n print(\"\")\n else:\n text = response.text()\n if extract or extract_last:\n text = extract_fenced_code_block(text, last=extract_last) or text\n print(text)\n # List of exceptions that should never be raised in pytest:\n except (ValueError, NotImplementedError) as ex:\n raise 
click.ClickException(str(ex))\n except Exception as ex:\n # All other exceptions should raise in pytest, show to user otherwise\n if getattr(sys, \"_called_from_test\", False) or os.environ.get(\n \"LLM_RAISE_ERRORS\", None\n ):\n raise\n raise click.ClickException(str(ex))\n\n if isinstance(response, AsyncResponse):\n response = asyncio.run(response.to_sync_response())\n\n if usage:\n # Show token usage to stderr in yellow\n click.echo(\n click.style(\n \"Token usage: {}\".format(response.token_usage()), fg=\"yellow\", bold=True\n ),\n err=True,\n )\n\n # Log to the database\n if (logs_on() or log) and not no_log:\n response.log_to_db(db)"} | |
{"id": "llm/cli.py:603", "code": "@cli.command()\n@click.option(\"-s\", \"--system\", help=\"System prompt to use\")\n@click.option(\"model_id\", \"-m\", \"--model\", help=\"Model to use\")\n@click.option(\n \"_continue\",\n \"-c\",\n \"--continue\",\n is_flag=True,\n flag_value=-1,\n help=\"Continue the most recent conversation.\",\n)\n@click.option(\n \"conversation_id\",\n \"--cid\",\n \"--conversation\",\n help=\"Continue the conversation with the given ID.\",\n)\n@click.option(\"-t\", \"--template\", help=\"Template to use\")\n@click.option(\n \"-p\",\n \"--param\",\n multiple=True,\n type=(str, str),\n help=\"Parameters for template\",\n)\n@click.option(\n \"options\",\n \"-o\",\n \"--option\",\n type=(str, str),\n multiple=True,\n help=\"key/value options for the model\",\n)\n@click.option(\"--no-stream\", is_flag=True, help=\"Do not stream output\")\n@click.option(\"--key\", help=\"API key to use\")\ndef chat(\n system,\n model_id,\n _continue,\n conversation_id,\n template,\n param,\n options,\n no_stream,\n key,\n):\n \"\"\"\n Hold an ongoing chat with a model.\n \"\"\"\n # Left and right arrow keys to move cursor:\n if sys.platform != \"win32\":\n readline.parse_and_bind(\"\\\\e[D: backward-char\")\n readline.parse_and_bind(\"\\\\e[C: forward-char\")\n else:\n readline.parse_and_bind(\"bind -x '\\\\e[D: backward-char'\")\n readline.parse_and_bind(\"bind -x '\\\\e[C: forward-char'\")\n log_path = logs_db_path()\n (log_path.parent).mkdir(parents=True, exist_ok=True)\n db = sqlite_utils.Database(log_path)\n migrate(db)\n\n conversation = None\n if conversation_id or _continue:\n # Load the conversation - loads most recent if no ID provided\n try:\n conversation = load_conversation(conversation_id)\n except UnknownModelError as ex:\n raise click.ClickException(str(ex))\n\n template_obj = None\n if template:\n params = dict(param)\n # Cannot be used with system\n if system:\n raise click.ClickException(\"Cannot use -t/--template and --system together\")\n template_obj = load_template(template)\n if model_id is None and template_obj.model:\n model_id = template_obj.model\n\n # Figure out which model we are using\n if model_id is None:\n if conversation:\n model_id = conversation.model.model_id\n else:\n model_id = get_default_model()\n\n # Now resolve the model\n try:\n model = get_model(model_id)\n except KeyError:\n raise click.ClickException(\"'{}' is not a known model\".format(model_id))\n\n if conversation is None:\n # Start a fresh conversation for this chat\n conversation = Conversation(model=model)\n else:\n # Ensure it can see the API key\n conversation.model = model\n\n # Validate options\n validated_options = {}\n if options:\n try:\n validated_options = dict(\n (key, value)\n for key, value in model.Options(**dict(options))\n if value is not None\n )\n except pydantic.ValidationError as ex:\n raise click.ClickException(render_errors(ex.errors()))\n\n kwargs = {}\n kwargs.update(validated_options)\n\n should_stream = model.can_stream and not no_stream\n if not should_stream:\n kwargs[\"stream\"] = False\n\n if key and isinstance(model, KeyModel):\n kwargs[\"key\"] = key\n\n click.echo(\"Chatting with {}\".format(model.model_id))\n click.echo(\"Type 'exit' or 'quit' to exit\")\n click.echo(\"Type '!multi' to enter multiple lines, then '!end' to finish\")\n in_multi = False\n accumulated = []\n end_token = \"!end\"\n while True:\n prompt = click.prompt(\"\", prompt_suffix=\"> \" if not in_multi else \"\")\n if prompt.strip().startswith(\"!multi\"):\n in_multi = True\n bits = 
prompt.strip().split()\n if len(bits) > 1:\n end_token = \"!end {}\".format(\" \".join(bits[1:]))\n continue\n if in_multi:\n if prompt.strip() == end_token:\n prompt = \"\\n\".join(accumulated)\n in_multi = False\n accumulated = []\n else:\n accumulated.append(prompt)\n continue\n if template_obj:\n try:\n prompt, system = template_obj.evaluate(prompt, params)\n except Template.MissingVariables as ex:\n raise click.ClickException(str(ex))\n if prompt.strip() in (\"exit\", \"quit\"):\n break\n response = conversation.prompt(prompt, system=system, **kwargs)\n # System prompt only sent for the first message:\n system = None\n for chunk in response:\n print(chunk, end=\"\")\n sys.stdout.flush()\n response.log_to_db(db)\n print(\"\")"} | |
{"id": "llm/cli.py:763", "code": "def load_conversation(\n conversation_id: Optional[str], async_=False\n) -> Optional[_BaseConversation]:\n db = sqlite_utils.Database(logs_db_path())\n migrate(db)\n if conversation_id is None:\n # Return the most recent conversation, or None if there are none\n matches = list(db[\"conversations\"].rows_where(order_by=\"id desc\", limit=1))\n if matches:\n conversation_id = matches[0][\"id\"]\n else:\n return None\n try:\n row = cast(sqlite_utils.db.Table, db[\"conversations\"]).get(conversation_id)\n except sqlite_utils.db.NotFoundError:\n raise click.ClickException(\n \"No conversation found with id={}\".format(conversation_id)\n )\n # Inflate that conversation\n conversation_class = AsyncConversation if async_ else Conversation\n response_class = AsyncResponse if async_ else Response\n conversation = conversation_class.from_row(row)\n for response in db[\"responses\"].rows_where(\n \"conversation_id = ?\", [conversation_id]\n ):\n conversation.responses.append(response_class.from_row(db, response))\n return conversation"} | |
{"id": "llm/cli.py:792", "code": "@cli.group(\n cls=DefaultGroup,\n default=\"list\",\n default_if_no_args=True,\n)\ndef keys():\n \"Manage stored API keys for different models\""} | |
{"id": "llm/cli.py:801", "code": "@keys.command(name=\"list\")\ndef keys_list():\n \"List names of all stored keys\"\n path = user_dir() / \"keys.json\"\n if not path.exists():\n click.echo(\"No keys found\")\n return\n keys = json.loads(path.read_text())\n for key in sorted(keys.keys()):\n if key != \"// Note\":\n click.echo(key)"} | |
{"id": "llm/cli.py:814", "code": "@keys.command(name=\"path\")\ndef keys_path_command():\n \"Output the path to the keys.json file\"\n click.echo(user_dir() / \"keys.json\")"} | |
{"id": "llm/cli.py:820", "code": "@keys.command(name=\"get\")\n@click.argument(\"name\")\ndef keys_get(name):\n \"\"\"\n Return the value of a stored key\n\n Example usage:\n\n \\b\n export OPENAI_API_KEY=$(llm keys get openai)\n \"\"\"\n path = user_dir() / \"keys.json\"\n if not path.exists():\n raise click.ClickException(\"No keys found\")\n keys = json.loads(path.read_text())\n try:\n click.echo(keys[name])\n except KeyError:\n raise click.ClickException(\"No key found with name '{}'\".format(name))"} | |
{"id": "llm/cli.py:841", "code": "@keys.command(name=\"set\")\n@click.argument(\"name\")\n@click.option(\"--value\", prompt=\"Enter key\", hide_input=True, help=\"Value to set\")\ndef keys_set(name, value):\n \"\"\"\n Save a key in the keys.json file\n\n Example usage:\n\n \\b\n $ llm keys set openai\n Enter key: ...\n \"\"\"\n default = {\"// Note\": \"This file stores secret API credentials. Do not share!\"}\n path = user_dir() / \"keys.json\"\n path.parent.mkdir(parents=True, exist_ok=True)\n if not path.exists():\n path.write_text(json.dumps(default))\n path.chmod(0o600)\n try:\n current = json.loads(path.read_text())\n except json.decoder.JSONDecodeError:\n current = default\n current[name] = value\n path.write_text(json.dumps(current, indent=2) + \"\\n\")"} | |
{"id": "llm/cli.py:868", "code": "@cli.group(\n cls=DefaultGroup,\n default=\"list\",\n default_if_no_args=True,\n)\ndef logs():\n \"Tools for exploring logged prompts and responses\""} | |
{"id": "llm/cli.py:877", "code": "@logs.command(name=\"path\")\ndef logs_path():\n \"Output the path to the logs.db file\"\n click.echo(logs_db_path())"} | |
{"id": "llm/cli.py:883", "code": "@logs.command(name=\"status\")\ndef logs_status():\n \"Show current status of database logging\"\n path = logs_db_path()\n if not path.exists():\n click.echo(\"No log database found at {}\".format(path))\n return\n if logs_on():\n click.echo(\"Logging is ON for all prompts\".format())\n else:\n click.echo(\"Logging is OFF\".format())\n db = sqlite_utils.Database(path)\n migrate(db)\n click.echo(\"Found log database at {}\".format(path))\n click.echo(\"Number of conversations logged:\\t{}\".format(db[\"conversations\"].count))\n click.echo(\"Number of responses logged:\\t{}\".format(db[\"responses\"].count))\n click.echo(\n \"Database file size: \\t\\t{}\".format(_human_readable_size(path.stat().st_size))\n )"} | |
{"id": "llm/cli.py:904", "code": "@logs.command(name=\"on\")\ndef logs_turn_on():\n \"Turn on logging for all prompts\"\n path = user_dir() / \"logs-off\"\n if path.exists():\n path.unlink()"} | |
{"id": "llm/cli.py:912", "code": "@logs.command(name=\"off\")\ndef logs_turn_off():\n \"Turn off logging for all prompts\"\n path = user_dir() / \"logs-off\"\n path.touch()"} | |
{"id": "llm/cli.py:974", "code": "@logs.command(name=\"list\")\n@click.option(\n \"-n\",\n \"--count\",\n type=int,\n default=None,\n help=\"Number of entries to show - defaults to 3, use 0 for all\",\n)\n@click.option(\n \"-p\",\n \"--path\",\n type=click.Path(readable=True, exists=True, dir_okay=False),\n help=\"Path to log database\",\n)\n@click.option(\"-m\", \"--model\", help=\"Filter by model or model alias\")\n@click.option(\"-q\", \"--query\", help=\"Search for logs matching this string\")\n@schema_option\n@click.option(\n \"--schema-multi\",\n help=\"JSON schema used for multiple results\",\n)\n@click.option(\n \"--data\", is_flag=True, help=\"Output newline-delimited JSON data for schema\"\n)\n@click.option(\"--data-array\", is_flag=True, help=\"Output JSON array of data for schema\")\n@click.option(\"--data-key\", help=\"Return JSON objects from array in this key\")\n@click.option(\n \"--data-ids\", is_flag=True, help=\"Attach corresponding IDs to JSON objects\"\n)\n@click.option(\"-t\", \"--truncate\", is_flag=True, help=\"Truncate long strings in output\")\n@click.option(\n \"-s\", \"--short\", is_flag=True, help=\"Shorter YAML output with truncated prompts\"\n)\n@click.option(\"-u\", \"--usage\", is_flag=True, help=\"Include token usage\")\n@click.option(\"-r\", \"--response\", is_flag=True, help=\"Just output the last response\")\n@click.option(\"-x\", \"--extract\", is_flag=True, help=\"Extract first fenced code block\")\n@click.option(\n \"extract_last\",\n \"--xl\",\n \"--extract-last\",\n is_flag=True,\n help=\"Extract last fenced code block\",\n)\n@click.option(\n \"current_conversation\",\n \"-c\",\n \"--current\",\n is_flag=True,\n flag_value=-1,\n help=\"Show logs from the current conversation\",\n)\n@click.option(\n \"conversation_id\",\n \"--cid\",\n \"--conversation\",\n help=\"Show logs for this conversation ID\",\n)\n@click.option(\"--id-gt\", help=\"Return responses with ID > this\")\n@click.option(\"--id-gte\", help=\"Return responses with ID >= this\")\n@click.option(\n \"json_output\",\n \"--json\",\n is_flag=True,\n help=\"Output logs as JSON\",\n)\ndef logs_list(\n count,\n path,\n model,\n query,\n schema_input,\n schema_multi,\n data,\n data_array,\n data_key,\n data_ids,\n truncate,\n short,\n usage,\n response,\n extract,\n extract_last,\n current_conversation,\n conversation_id,\n id_gt,\n id_gte,\n json_output,\n):\n \"Show recent logged prompts and their responses\"\n path = pathlib.Path(path or logs_db_path())\n if not path.exists():\n raise click.ClickException(\"No log database found at {}\".format(path))\n db = sqlite_utils.Database(path)\n migrate(db)\n\n if schema_multi:\n schema_input = schema_multi\n schema = resolve_schema_input(db, schema_input, load_template)\n if schema_multi:\n schema = multi_schema(schema)\n\n if short and (json_output or response):\n invalid = \" or \".join(\n [\n flag[0]\n for flag in ((\"--json\", json_output), (\"--response\", response))\n if flag[1]\n ]\n )\n raise click.ClickException(\"Cannot use --short and {} together\".format(invalid))\n\n if response and not current_conversation and not conversation_id:\n current_conversation = True\n\n if current_conversation:\n try:\n conversation_id = next(\n db.query(\n \"select conversation_id from responses order by id desc limit 1\"\n )\n )[\"conversation_id\"]\n except StopIteration:\n # No conversations yet\n raise click.ClickException(\"No conversations found\")\n\n # For --conversation set limit 0, if not explicitly set\n if count is None:\n if conversation_id:\n 
count = 0\n else:\n count = 3\n\n model_id = None\n if model:\n # Resolve alias, if any\n try:\n model_id = get_model(model).model_id\n except UnknownModelError:\n # Maybe they uninstalled a model, use the -m option as-is\n model_id = model\n\n sql = LOGS_SQL\n if query:\n sql = LOGS_SQL_SEARCH\n\n limit = \"\"\n if count is not None and count > 0:\n limit = \" limit {}\".format(count)\n\n sql_format = {\n \"limit\": limit,\n \"columns\": LOGS_COLUMNS,\n \"extra_where\": \"\",\n }\n where_bits = []\n if model_id:\n where_bits.append(\"responses.model = :model\")\n if conversation_id:\n where_bits.append(\"responses.conversation_id = :conversation_id\")\n if id_gt:\n where_bits.append(\"responses.id > :id_gt\")\n if id_gte:\n where_bits.append(\"responses.id >= :id_gte\")\n schema_id = None\n if schema:\n schema_id = make_schema_id(schema)[0]\n where_bits.append(\"responses.schema_id = :schema_id\")\n\n if where_bits:\n where_ = \" and \" if query else \" where \"\n sql_format[\"extra_where\"] = where_ + \" and \".join(where_bits)\n\n final_sql = sql.format(**sql_format)\n rows = list(\n db.query(\n final_sql,\n {\n \"model\": model_id,\n \"query\": query,\n \"conversation_id\": conversation_id,\n \"schema_id\": schema_id,\n \"id_gt\": id_gt,\n \"id_gte\": id_gte,\n },\n )\n )\n\n # Reverse the order - we do this because we 'order by id desc limit 3' to get the\n # 3 most recent results, but we still want to display them in chronological order\n # ... except for searches where we don't do this\n if not query and not data:\n rows.reverse()\n\n # Fetch any attachments\n ids = [row[\"id\"] for row in rows]\n attachments = list(db.query(ATTACHMENTS_SQL.format(\",\".join(\"?\" * len(ids))), ids))\n attachments_by_id = {}\n for attachment in attachments:\n attachments_by_id.setdefault(attachment[\"response_id\"], []).append(attachment)\n\n if data or data_array or data_key or data_ids:\n # Special case for --data to output valid JSON\n to_output = []\n for row in rows:\n response = row[\"response\"] or \"\"\n try:\n decoded = json.loads(response)\n new_items = []\n if (\n isinstance(decoded, dict)\n and (data_key in decoded)\n and all(isinstance(item, dict) for item in decoded[data_key])\n ):\n for item in decoded[data_key]:\n new_items.append(item)\n else:\n new_items.append(decoded)\n if data_ids:\n for item in new_items:\n item[find_unused_key(item, \"response_id\")] = row[\"id\"]\n item[find_unused_key(item, \"conversation_id\")] = row[\"id\"]\n to_output.extend(new_items)\n except ValueError:\n pass\n click.echo(output_rows_as_json(to_output, not data_array))\n return\n\n for row in rows:\n if truncate:\n row[\"prompt\"] = truncate_string(row[\"prompt\"] or \"\")\n row[\"response\"] = truncate_string(row[\"response\"] or \"\")\n # Either decode or remove all JSON keys\n keys = list(row.keys())\n for key in keys:\n if key.endswith(\"_json\") and row[key] is not None:\n if truncate:\n del row[key]\n else:\n row[key] = json.loads(row[key])\n\n output = None\n if json_output:\n # Output as JSON if requested\n for row in rows:\n row[\"attachments\"] = [\n {k: v for k, v in attachment.items() if k != \"response_id\"}\n for attachment in attachments_by_id.get(row[\"id\"], [])\n ]\n output = json.dumps(list(rows), indent=2)\n elif extract or extract_last:\n # Extract and return first code block\n for row in rows:\n output = extract_fenced_code_block(row[\"response\"], last=extract_last)\n if output is not None:\n break\n elif response:\n # Just output the last response\n if rows:\n output = 
rows[-1][\"response\"]\n\n if output is not None:\n click.echo(output)\n else:\n # Output neatly formatted human-readable logs\n current_system = None\n should_show_conversation = True\n for row in rows:\n if short:\n system = truncate_string(\n row[\"system\"] or \"\", 120, normalize_whitespace=True\n )\n prompt = truncate_string(\n row[\"prompt\"] or \"\", 120, normalize_whitespace=True, keep_end=True\n )\n cid = row[\"conversation_id\"]\n attachments = attachments_by_id.get(row[\"id\"])\n obj = {\n \"model\": row[\"model\"],\n \"datetime\": row[\"datetime_utc\"].split(\".\")[0],\n \"conversation\": cid,\n }\n if system:\n obj[\"system\"] = system\n if prompt:\n obj[\"prompt\"] = prompt\n if attachments:\n items = []\n for attachment in attachments:\n details = {\"type\": attachment[\"type\"]}\n if attachment.get(\"path\"):\n details[\"path\"] = attachment[\"path\"]\n if attachment.get(\"url\"):\n details[\"url\"] = attachment[\"url\"]\n items.append(details)\n obj[\"attachments\"] = items\n if usage and (row[\"input_tokens\"] or row[\"output_tokens\"]):\n usage_details = {\n \"input\": row[\"input_tokens\"],\n \"output\": row[\"output_tokens\"],\n }\n if row[\"token_details\"]:\n usage_details[\"details\"] = json.loads(row[\"token_details\"])\n obj[\"usage\"] = usage_details\n click.echo(yaml.dump([obj], sort_keys=False).strip())\n continue\n click.echo(\n \"# {}{}\\n{}\".format(\n row[\"datetime_utc\"].split(\".\")[0],\n (\n \" conversation: {} id: {}\".format(\n row[\"conversation_id\"], row[\"id\"]\n )\n if should_show_conversation\n else \"\"\n ),\n (\n \"\\nModel: **{}**\\n\".format(row[\"model\"])\n if should_show_conversation\n else \"\"\n ),\n )\n )\n # In conversation log mode only show it for the first one\n if conversation_id:\n should_show_conversation = False\n click.echo(\"## Prompt\\n\\n{}\".format(row[\"prompt\"] or \"-- none --\"))\n if row[\"system\"] != current_system:\n if row[\"system\"] is not None:\n click.echo(\"\\n## System\\n\\n{}\".format(row[\"system\"]))\n current_system = row[\"system\"]\n if row[\"schema_json\"]:\n click.echo(\n \"\\n## Schema\\n\\n```json\\n{}\\n```\".format(\n json.dumps(row[\"schema_json\"], indent=2)\n )\n )\n attachments = attachments_by_id.get(row[\"id\"])\n if attachments:\n click.echo(\"\\n### Attachments\\n\")\n for i, attachment in enumerate(attachments, 1):\n if attachment[\"path\"]:\n path = attachment[\"path\"]\n click.echo(\n \"{}. **{}**: `{}`\".format(i, attachment[\"type\"], path)\n )\n elif attachment[\"url\"]:\n click.echo(\n \"{}. **{}**: {}\".format(\n i, attachment[\"type\"], attachment[\"url\"]\n )\n )\n elif attachment[\"content_length\"]:\n click.echo(\n \"{}. **{}**: `<{} bytes>`\".format(\n i,\n attachment[\"type\"],\n f\"{attachment['content_length']:,}\",\n )\n )\n\n # If a schema was provided and the row is valid JSON, pretty print and syntax highlight it\n response = row[\"response\"]\n if row[\"schema_json\"]:\n try:\n parsed = json.loads(response)\n response = \"```json\\n{}\\n```\".format(json.dumps(parsed, indent=2))\n except ValueError:\n pass\n click.echo(\"\\n## Response\\n\\n{}\\n\".format(response))\n if usage:\n token_usage = token_usage_string(\n row[\"input_tokens\"],\n row[\"output_tokens\"],\n json.loads(row[\"token_details\"]) if row[\"token_details\"] else None,\n )\n if token_usage:\n click.echo(\"## Token usage:\\n\\n{}\\n\".format(token_usage))"} | |
{"id": "llm/cli.py:1353", "code": "@cli.group(\n cls=DefaultGroup,\n default=\"list\",\n default_if_no_args=True,\n)\ndef models():\n \"Manage available models\""} | |
{"id": "llm/cli.py:1370", "code": "@models.command(name=\"list\")\n@click.option(\n \"--options\", is_flag=True, help=\"Show options for each model, if available\"\n)\n@click.option(\"async_\", \"--async\", is_flag=True, help=\"List async models\")\n@click.option(\"--schemas\", is_flag=True, help=\"List models that support schemas\")\n@click.option(\n \"-q\",\n \"--query\",\n multiple=True,\n help=\"Search for models matching these strings\",\n)\n@click.option(\"model_ids\", \"-m\", \"--model\", help=\"Specific model IDs\", multiple=True)\ndef models_list(options, async_, schemas, query, model_ids):\n \"List available models\"\n models_that_have_shown_options = set()\n for model_with_aliases in get_models_with_aliases():\n if async_ and not model_with_aliases.async_model:\n continue\n if query:\n # Only show models where every provided query string matches\n if not all(model_with_aliases.matches(q) for q in query):\n continue\n if model_ids:\n ids_and_aliases = set(\n [model_with_aliases.model.model_id] + model_with_aliases.aliases\n )\n if not ids_and_aliases.intersection(model_ids):\n continue\n if schemas and not model_with_aliases.model.supports_schema:\n continue\n extra = \"\"\n if model_with_aliases.aliases:\n extra = \" (aliases: {})\".format(\", \".join(model_with_aliases.aliases))\n model = (\n model_with_aliases.model if not async_ else model_with_aliases.async_model\n )\n output = str(model) + extra\n if options and model.Options.model_json_schema()[\"properties\"]:\n output += \"\\n Options:\"\n for name, field in model.Options.model_json_schema()[\"properties\"].items():\n any_of = field.get(\"anyOf\")\n if any_of is None:\n any_of = [{\"type\": field.get(\"type\", \"str\")}]\n types = \", \".join(\n [\n _type_lookup.get(item.get(\"type\"), item.get(\"type\", \"str\"))\n for item in any_of\n if item.get(\"type\") != \"null\"\n ]\n )\n bits = [\"\\n \", name, \": \", types]\n description = field.get(\"description\", \"\")\n if description and (\n model.__class__ not in models_that_have_shown_options\n ):\n wrapped = textwrap.wrap(description, 70)\n bits.append(\"\\n \")\n bits.extend(\"\\n \".join(wrapped))\n output += \"\".join(bits)\n models_that_have_shown_options.add(model.__class__)\n if options and model.attachment_types:\n attachment_types = \", \".join(sorted(model.attachment_types))\n wrapper = textwrap.TextWrapper(\n width=min(max(shutil.get_terminal_size().columns, 30), 70),\n initial_indent=\" \",\n subsequent_indent=\" \",\n )\n output += \"\\n Attachment types:\\n{}\".format(wrapper.fill(attachment_types))\n features = (\n []\n + ([\"streaming\"] if model.can_stream else [])\n + ([\"schemas\"] if model.supports_schema else [])\n + ([\"async\"] if model_with_aliases.async_model else [])\n )\n if options and features:\n output += \"\\n Features:\\n{}\".format(\n \"\\n\".join(\" - {}\".format(feature) for feature in features)\n )\n click.echo(output)\n if not query and not options and not schemas and not model_ids:\n click.echo(f\"Default: {get_default_model()}\")"} | |
{"id": "llm/cli.py:1454", "code": "@models.command(name=\"default\")\n@click.argument(\"model\", required=False)\ndef models_default(model):\n \"Show or set the default model\"\n if not model:\n click.echo(get_default_model())\n return\n # Validate it is a known model\n try:\n model = get_model(model)\n set_default_model(model.model_id)\n except KeyError:\n raise click.ClickException(\"Unknown model: {}\".format(model))"} | |
{"id": "llm/cli.py:1469", "code": "@cli.group(\n cls=DefaultGroup,\n default=\"list\",\n default_if_no_args=True,\n)\ndef templates():\n \"Manage stored prompt templates\""} | |
{"id": "llm/cli.py:1478", "code": "@templates.command(name=\"list\")\ndef templates_list():\n \"List available prompt templates\"\n path = template_dir()\n pairs = []\n for file in path.glob(\"*.yaml\"):\n name = file.stem\n template = load_template(name)\n text = []\n if template.system:\n text.append(f\"system: {template.system}\")\n if template.prompt:\n text.append(f\" prompt: {template.prompt}\")\n else:\n text = [template.prompt if template.prompt else \"\"]\n pairs.append((name, \"\".join(text).replace(\"\\n\", \" \")))\n try:\n max_name_len = max(len(p[0]) for p in pairs)\n except ValueError:\n return\n else:\n fmt = \"{name:<\" + str(max_name_len) + \"} : {prompt}\"\n for name, prompt in sorted(pairs):\n text = fmt.format(name=name, prompt=prompt)\n click.echo(display_truncated(text))"} | |
{"id": "llm/cli.py:1505", "code": "@templates.command(name=\"show\")\n@click.argument(\"name\")\ndef templates_show(name):\n \"Show the specified prompt template\"\n template = load_template(name)\n click.echo(\n yaml.dump(\n dict((k, v) for k, v in template.model_dump().items() if v is not None),\n indent=4,\n default_flow_style=False,\n )\n )"} | |
{"id": "llm/cli.py:1519", "code": "@templates.command(name=\"edit\")\n@click.argument(\"name\")\ndef templates_edit(name):\n \"Edit the specified prompt template using the default $EDITOR\"\n # First ensure it exists\n path = template_dir() / f\"{name}.yaml\"\n if not path.exists():\n path.write_text(DEFAULT_TEMPLATE, \"utf-8\")\n click.edit(filename=path)\n # Validate that template\n load_template(name)"} | |
{"id": "llm/cli.py:1532", "code": "@templates.command(name=\"path\")\ndef templates_path():\n \"Output the path to the templates directory\"\n click.echo(template_dir())"} | |
{"id": "llm/cli.py:1538", "code": "@templates.command(name=\"loaders\")\ndef templates_loaders():\n \"Show template loaders registered by plugins\"\n found = False\n for prefix, loader in get_template_loaders().items():\n found = True\n docs = \"Undocumented\"\n if loader.__doc__:\n docs = textwrap.dedent(loader.__doc__).strip()\n click.echo(f\"{prefix}:\")\n click.echo(textwrap.indent(docs, \" \"))\n if not found:\n click.echo(\"No template loaders found\")"} | |
{"id": "llm/cli.py:1553", "code": "@cli.group(\n cls=DefaultGroup,\n default=\"list\",\n default_if_no_args=True,\n)\ndef schemas():\n \"Manage stored schemas\""} | |
{"id": "llm/cli.py:1562", "code": "@schemas.command(name=\"list\")\n@click.option(\n \"-p\",\n \"--path\",\n type=click.Path(readable=True, exists=True, dir_okay=False),\n help=\"Path to log database\",\n)\n@click.option(\n \"queries\",\n \"-q\",\n \"--query\",\n multiple=True,\n help=\"Search for schemas matching this string\",\n)\n@click.option(\"--full\", is_flag=True, help=\"Output full schema contents\")\ndef schemas_list(path, queries, full):\n \"List stored schemas\"\n path = pathlib.Path(path or logs_db_path())\n if not path.exists():\n raise click.ClickException(\"No log database found at {}\".format(path))\n db = sqlite_utils.Database(path)\n migrate(db)\n\n params = []\n where_sql = \"\"\n if queries:\n where_bits = [\"schemas.content like ?\" for _ in queries]\n where_sql += \" where {}\".format(\" and \".join(where_bits))\n params.extend(\"%{}%\".format(q) for q in queries)\n\n sql = \"\"\"\n select\n schemas.id,\n schemas.content,\n max(responses.datetime_utc) as recently_used,\n count(*) as times_used\n from schemas\n join responses\n on responses.schema_id = schemas.id\n {} group by responses.schema_id\n order by recently_used\n \"\"\".format(\n where_sql\n )\n rows = db.query(sql, params)\n for row in rows:\n click.echo(\"- id: {}\".format(row[\"id\"]))\n if full:\n click.echo(\n \" schema: |\\n{}\".format(\n textwrap.indent(\n json.dumps(json.loads(row[\"content\"]), indent=2), \" \"\n )\n )\n )\n else:\n click.echo(\n \" summary: |\\n {}\".format(\n schema_summary(json.loads(row[\"content\"]))\n )\n )\n click.echo(\n \" usage: |\\n {} time{}, most recently {}\".format(\n row[\"times_used\"],\n \"s\" if row[\"times_used\"] != 1 else \"\",\n row[\"recently_used\"],\n )\n )"} | |
{"id": "llm/cli.py:1632", "code": "@schemas.command(name=\"show\")\n@click.argument(\"schema_id\")\n@click.option(\n \"-p\",\n \"--path\",\n type=click.Path(readable=True, exists=True, dir_okay=False),\n help=\"Path to log database\",\n)\ndef schemas_show(schema_id, path):\n \"Show a stored schema\"\n path = pathlib.Path(path or logs_db_path())\n if not path.exists():\n raise click.ClickException(\"No log database found at {}\".format(path))\n db = sqlite_utils.Database(path)\n migrate(db)\n\n try:\n row = db[\"schemas\"].get(schema_id)\n except sqlite_utils.db.NotFoundError:\n raise click.ClickException(\"Invalid schema ID\")\n click.echo(json.dumps(json.loads(row[\"content\"]), indent=2))"} | |
{"id": "llm/cli.py:1655", "code": "@schemas.command(name=\"dsl\")\n@click.argument(\"input\")\n@click.option(\"--multi\", is_flag=True, help=\"Wrap in an array\")\ndef schemas_dsl_debug(input, multi):\n \"\"\"\n Convert LLM's schema DSL to a JSON schema\n\n \\b\n llm schema dsl 'name, age int, bio: their bio'\n \"\"\"\n schema = schema_dsl(input, multi)\n click.echo(json.dumps(schema, indent=2))"} | |
{"id": "llm/cli.py:1669", "code": "@cli.group(\n cls=DefaultGroup,\n default=\"list\",\n default_if_no_args=True,\n)\ndef aliases():\n \"Manage model aliases\""} | |
{"id": "llm/cli.py:1678", "code": "@aliases.command(name=\"list\")\n@click.option(\"json_\", \"--json\", is_flag=True, help=\"Output as JSON\")\ndef aliases_list(json_):\n \"List current aliases\"\n to_output = []\n for alias, model in get_model_aliases().items():\n if alias != model.model_id:\n to_output.append((alias, model.model_id, \"\"))\n for alias, embedding_model in get_embedding_model_aliases().items():\n if alias != embedding_model.model_id:\n to_output.append((alias, embedding_model.model_id, \"embedding\"))\n if json_:\n click.echo(\n json.dumps({key: value for key, value, type_ in to_output}, indent=4)\n )\n return\n max_alias_length = max(len(a) for a, _, _ in to_output)\n fmt = \"{alias:<\" + str(max_alias_length) + \"} : {model_id}{type_}\"\n for alias, model_id, type_ in to_output:\n click.echo(\n fmt.format(\n alias=alias, model_id=model_id, type_=f\" ({type_})\" if type_ else \"\"\n )\n )"} | |
{"id": "llm/cli.py:1704", "code": "@aliases.command(name=\"set\")\n@click.argument(\"alias\")\n@click.argument(\"model_id\", required=False)\n@click.option(\n \"-q\",\n \"--query\",\n multiple=True,\n help=\"Set alias for model matching these strings\",\n)\ndef aliases_set(alias, model_id, query):\n \"\"\"\n Set an alias for a model\n\n Example usage:\n\n \\b\n llm aliases set mini gpt-4o-mini\n\n Alternatively you can omit the model ID and specify one or more -q options.\n The first model matching all of those query strings will be used.\n\n \\b\n llm aliases set mini -q 4o -q mini\n \"\"\"\n if not model_id:\n if not query:\n raise click.ClickException(\n \"You must provide a model_id or at least one -q option\"\n )\n # Search for the first model matching all query strings\n found = None\n for model_with_aliases in get_models_with_aliases():\n if all(model_with_aliases.matches(q) for q in query):\n found = model_with_aliases\n break\n if not found:\n raise click.ClickException(\n \"No model found matching query: \" + \", \".join(query)\n )\n model_id = found.model.model_id\n set_alias(alias, model_id)\n click.echo(\n f\"Alias '{alias}' set to model '{model_id}'\",\n err=True,\n )\n else:\n set_alias(alias, model_id)"} | |
{"id": "llm/cli.py:1753", "code": "@aliases.command(name=\"remove\")\n@click.argument(\"alias\")\ndef aliases_remove(alias):\n \"\"\"\n Remove an alias\n\n Example usage:\n\n \\b\n $ llm aliases remove turbo\n \"\"\"\n try:\n remove_alias(alias)\n except KeyError as ex:\n raise click.ClickException(ex.args[0])"} | |
{"id": "llm/cli.py:1770", "code": "@aliases.command(name=\"path\")\ndef aliases_path():\n \"Output the path to the aliases.json file\"\n click.echo(user_dir() / \"aliases.json\")"} | |
{"id": "llm/cli.py:1776", "code": "@cli.command(name=\"plugins\")\n@click.option(\"--all\", help=\"Include built-in default plugins\", is_flag=True)\ndef plugins_list(all):\n \"List installed plugins\"\n click.echo(json.dumps(get_plugins(all), indent=2))"} | |
{"id": "llm/cli.py:1783", "code": "def display_truncated(text):\n console_width = shutil.get_terminal_size()[0]\n if len(text) > console_width:\n return text[: console_width - 3] + \"...\"\n else:\n return text"} | |
{"id": "llm/cli.py:1791", "code": "@cli.command()\n@click.argument(\"packages\", nargs=-1, required=False)\n@click.option(\n \"-U\", \"--upgrade\", is_flag=True, help=\"Upgrade packages to latest version\"\n)\n@click.option(\n \"-e\",\n \"--editable\",\n help=\"Install a project in editable mode from this path\",\n)\n@click.option(\n \"--force-reinstall\",\n is_flag=True,\n help=\"Reinstall all packages even if they are already up-to-date\",\n)\n@click.option(\n \"--no-cache-dir\",\n is_flag=True,\n help=\"Disable the cache\",\n)\ndef install(packages, upgrade, editable, force_reinstall, no_cache_dir):\n \"\"\"Install packages from PyPI into the same environment as LLM\"\"\"\n args = [\"pip\", \"install\"]\n if upgrade:\n args += [\"--upgrade\"]\n if editable:\n args += [\"--editable\", editable]\n if force_reinstall:\n args += [\"--force-reinstall\"]\n if no_cache_dir:\n args += [\"--no-cache-dir\"]\n args += list(packages)\n sys.argv = args\n run_module(\"pip\", run_name=\"__main__\")"} | |
{"id": "llm/cli.py:1827", "code": "@cli.command()\n@click.argument(\"packages\", nargs=-1, required=True)\n@click.option(\"-y\", \"--yes\", is_flag=True, help=\"Don't ask for confirmation\")\ndef uninstall(packages, yes):\n \"\"\"Uninstall Python packages from the LLM environment\"\"\"\n sys.argv = [\"pip\", \"uninstall\"] + list(packages) + ([\"-y\"] if yes else [])\n run_module(\"pip\", run_name=\"__main__\")"} | |
{"id": "llm/cli.py:1836", "code": "@cli.command()\n@click.argument(\"collection\", required=False)\n@click.argument(\"id\", required=False)\n@click.option(\n \"-i\",\n \"--input\",\n type=click.Path(exists=True, readable=True, allow_dash=True),\n help=\"File to embed\",\n)\n@click.option(\"-m\", \"--model\", help=\"Embedding model to use\")\n@click.option(\"--store\", is_flag=True, help=\"Store the text itself in the database\")\n@click.option(\n \"-d\",\n \"--database\",\n type=click.Path(file_okay=True, allow_dash=False, dir_okay=False, writable=True),\n envvar=\"LLM_EMBEDDINGS_DB\",\n)\n@click.option(\n \"-c\",\n \"--content\",\n help=\"Content to embed\",\n)\n@click.option(\"--binary\", is_flag=True, help=\"Treat input as binary data\")\n@click.option(\n \"--metadata\",\n help=\"JSON object metadata to store\",\n callback=json_validator(\"metadata\"),\n)\n@click.option(\n \"format_\",\n \"-f\",\n \"--format\",\n type=click.Choice([\"json\", \"blob\", \"base64\", \"hex\"]),\n help=\"Output format\",\n)\ndef embed(\n collection, id, input, model, store, database, content, binary, metadata, format_\n):\n \"\"\"Embed text and store or return the result\"\"\"\n if collection and not id:\n raise click.ClickException(\"Must provide both collection and id\")\n\n if store and not collection:\n raise click.ClickException(\"Must provide collection when using --store\")\n\n # Lazy load this because we do not need it for -c or -i versions\n def get_db():\n if database:\n return sqlite_utils.Database(database)\n else:\n return sqlite_utils.Database(user_dir() / \"embeddings.db\")\n\n collection_obj = None\n model_obj = None\n if collection:\n db = get_db()\n if Collection.exists(db, collection):\n # Load existing collection and use its model\n collection_obj = Collection(collection, db)\n model_obj = collection_obj.model()\n else:\n # We will create a new one, but that means model is required\n if not model:\n model = get_default_embedding_model()\n if model is None:\n raise click.ClickException(\n \"You need to specify an embedding model (no default model is set)\"\n )\n collection_obj = Collection(collection, db=db, model_id=model)\n model_obj = collection_obj.model()\n\n if model_obj is None:\n if model is None:\n model = get_default_embedding_model()\n try:\n model_obj = get_embedding_model(model)\n except UnknownModelError:\n raise click.ClickException(\n \"You need to specify an embedding model (no default model is set)\"\n )\n\n show_output = True\n if collection and (format_ is None):\n show_output = False\n\n # Resolve input text\n if not content:\n if not input or input == \"-\":\n # Read from stdin\n input_source = sys.stdin.buffer if binary else sys.stdin\n content = input_source.read()\n else:\n mode = \"rb\" if binary else \"r\"\n with open(input, mode) as f:\n content = f.read()\n\n if not content:\n raise click.ClickException(\"No content provided\")\n\n if collection_obj:\n embedding = collection_obj.embed(id, content, metadata=metadata, store=store)\n else:\n embedding = model_obj.embed(content)\n\n if show_output:\n if format_ == \"json\" or format_ is None:\n click.echo(json.dumps(embedding))\n elif format_ == \"blob\":\n click.echo(encode(embedding))\n elif format_ == \"base64\":\n click.echo(base64.b64encode(encode(embedding)).decode(\"ascii\"))\n elif format_ == \"hex\":\n click.echo(encode(embedding).hex())"} | |
{"id": "llm/cli.py:1951", "code": "@cli.command()\n@click.argument(\"collection\")\n@click.argument(\n \"input_path\",\n type=click.Path(exists=True, dir_okay=False, allow_dash=True, readable=True),\n required=False,\n)\n@click.option(\n \"--format\",\n type=click.Choice([\"json\", \"csv\", \"tsv\", \"nl\"]),\n help=\"Format of input file - defaults to auto-detect\",\n)\n@click.option(\n \"--files\",\n type=(click.Path(file_okay=False, dir_okay=True, allow_dash=False), str),\n multiple=True,\n help=\"Embed files in this directory - specify directory and glob pattern\",\n)\n@click.option(\n \"encodings\",\n \"--encoding\",\n help=\"Encodings to try when reading --files\",\n multiple=True,\n)\n@click.option(\"--binary\", is_flag=True, help=\"Treat --files as binary data\")\n@click.option(\"--sql\", help=\"Read input using this SQL query\")\n@click.option(\n \"--attach\",\n type=(str, click.Path(file_okay=True, dir_okay=False, allow_dash=False)),\n multiple=True,\n help=\"Additional databases to attach - specify alias and file path\",\n)\n@click.option(\n \"--batch-size\", type=int, help=\"Batch size to use when running embeddings\"\n)\n@click.option(\"--prefix\", help=\"Prefix to add to the IDs\", default=\"\")\n@click.option(\"-m\", \"--model\", help=\"Embedding model to use\")\n@click.option(\n \"--prepend\",\n help=\"Prepend this string to all content before embedding\",\n)\n@click.option(\"--store\", is_flag=True, help=\"Store the text itself in the database\")\n@click.option(\n \"-d\",\n \"--database\",\n type=click.Path(file_okay=True, allow_dash=False, dir_okay=False, writable=True),\n envvar=\"LLM_EMBEDDINGS_DB\",\n)\ndef embed_multi(\n collection,\n input_path,\n format,\n files,\n encodings,\n binary,\n sql,\n attach,\n batch_size,\n prefix,\n model,\n prepend,\n store,\n database,\n):\n \"\"\"\n Store embeddings for multiple strings at once in the specified collection.\n\n Input data can come from one of three sources:\n\n \\b\n 1. A CSV, TSV, JSON or JSONL file:\n - CSV/TSV: First column is ID, remaining columns concatenated as content\n - JSON: Array of objects with \"id\" field and content fields\n - JSONL: Newline-delimited JSON objects\n\n \\b\n Examples:\n llm embed-multi docs input.csv\n cat data.json | llm embed-multi docs -\n llm embed-multi docs input.json --format json\n\n \\b\n 2. A SQL query against a SQLite database:\n - First column returned is used as ID\n - Other columns concatenated to form content\n\n \\b\n Examples:\n llm embed-multi docs --sql \"SELECT id, title, body FROM posts\"\n llm embed-multi docs --attach blog blog.db --sql \"SELECT id, content FROM blog.posts\"\n\n \\b\n 3. 
Files in directories matching glob patterns:\n - Each file becomes one embedding\n - Relative file paths become IDs\n\n \\b\n Examples:\n llm embed-multi docs --files docs '**/*.md'\n llm embed-multi images --files photos '*.jpg' --binary\n llm embed-multi texts --files texts '*.txt' --encoding utf-8 --encoding latin-1\n \"\"\"\n if binary and not files:\n raise click.UsageError(\"--binary must be used with --files\")\n if binary and encodings:\n raise click.UsageError(\"--binary cannot be used with --encoding\")\n if not input_path and not sql and not files:\n raise click.UsageError(\"Either --sql or input path or --files is required\")\n\n if files:\n if input_path or sql or format:\n raise click.UsageError(\n \"Cannot use --files with --sql, input path or --format\"\n )\n\n if database:\n db = sqlite_utils.Database(database)\n else:\n db = sqlite_utils.Database(user_dir() / \"embeddings.db\")\n\n for alias, attach_path in attach:\n db.attach(alias, attach_path)\n\n try:\n collection_obj = Collection(\n collection, db=db, model_id=model or get_default_embedding_model()\n )\n except ValueError:\n raise click.ClickException(\n \"You need to specify an embedding model (no default model is set)\"\n )\n\n expected_length = None\n if files:\n encodings = encodings or (\"utf-8\", \"latin-1\")\n\n def count_files():\n i = 0\n for directory, pattern in files:\n for path in pathlib.Path(directory).glob(pattern):\n i += 1\n return i\n\n def iterate_files():\n for directory, pattern in files:\n p = pathlib.Path(directory)\n if not p.exists() or not p.is_dir():\n # fixes issue/274 - raise error if directory does not exist\n raise click.UsageError(f\"Invalid directory: {directory}\")\n for path in pathlib.Path(directory).glob(pattern):\n if path.is_dir():\n continue # fixed issue/280 - skip directories\n relative = path.relative_to(directory)\n content = None\n if binary:\n content = path.read_bytes()\n else:\n for encoding in encodings:\n try:\n content = path.read_text(encoding=encoding)\n except UnicodeDecodeError:\n continue\n if content is None:\n # Log to stderr\n click.echo(\n \"Could not decode text in file {}\".format(path),\n err=True,\n )\n else:\n yield {\"id\": str(relative), \"content\": content}\n\n expected_length = count_files()\n rows = iterate_files()\n elif sql:\n rows = db.query(sql)\n count_sql = \"select count(*) as c from ({})\".format(sql)\n expected_length = next(db.query(count_sql))[\"c\"]\n else:\n\n def load_rows(fp):\n return rows_from_file(fp, Format[format.upper()] if format else None)[0]\n\n try:\n if input_path != \"-\":\n # Read the file twice - first time is to get a count\n expected_length = 0\n with open(input_path, \"rb\") as fp:\n for _ in load_rows(fp):\n expected_length += 1\n\n rows = load_rows(\n open(input_path, \"rb\")\n if input_path != \"-\"\n else io.BufferedReader(sys.stdin.buffer)\n )\n except json.JSONDecodeError as ex:\n raise click.ClickException(str(ex))\n\n with click.progressbar(\n rows, label=\"Embedding\", show_percent=True, length=expected_length\n ) as rows:\n\n def tuples() -> Iterable[Tuple[str, Union[bytes, str]]]:\n for row in rows:\n values = list(row.values())\n id: str = prefix + str(values[0])\n content: Optional[Union[bytes, str]] = None\n if binary:\n content = cast(bytes, values[1])\n else:\n content = \" \".join(v or \"\" for v in values[1:])\n if prepend and isinstance(content, str):\n content = prepend + content\n yield id, content or \"\"\n\n embed_kwargs = {\"store\": store}\n if batch_size:\n embed_kwargs[\"batch_size\"] = 
batch_size\n collection_obj.embed_multi(tuples(), **embed_kwargs)"} | |
{"id": "llm/cli.py:2172", "code": "@cli.command()\n@click.argument(\"collection\")\n@click.argument(\"id\", required=False)\n@click.option(\n \"-i\",\n \"--input\",\n type=click.Path(exists=True, readable=True, allow_dash=True),\n help=\"File to embed for comparison\",\n)\n@click.option(\"-c\", \"--content\", help=\"Content to embed for comparison\")\n@click.option(\"--binary\", is_flag=True, help=\"Treat input as binary data\")\n@click.option(\n \"-n\", \"--number\", type=int, default=10, help=\"Number of results to return\"\n)\n@click.option(\n \"-d\",\n \"--database\",\n type=click.Path(file_okay=True, allow_dash=False, dir_okay=False, writable=True),\n envvar=\"LLM_EMBEDDINGS_DB\",\n)\ndef similar(collection, id, input, content, binary, number, database):\n \"\"\"\n Return top N similar IDs from a collection using cosine similarity.\n\n Example usage:\n\n \\b\n llm similar my-collection -c \"I like cats\"\n\n Or to find content similar to a specific stored ID:\n\n \\b\n llm similar my-collection 1234\n \"\"\"\n if not id and not content and not input:\n raise click.ClickException(\"Must provide content or an ID for the comparison\")\n\n if database:\n db = sqlite_utils.Database(database)\n else:\n db = sqlite_utils.Database(user_dir() / \"embeddings.db\")\n\n if not db[\"embeddings\"].exists():\n raise click.ClickException(\"No embeddings table found in database\")\n\n try:\n collection_obj = Collection(collection, db, create=False)\n except Collection.DoesNotExist:\n raise click.ClickException(\"Collection does not exist\")\n\n if id:\n try:\n results = collection_obj.similar_by_id(id, number)\n except Collection.DoesNotExist:\n raise click.ClickException(\"ID not found in collection\")\n else:\n # Resolve input text\n if not content:\n if not input or input == \"-\":\n # Read from stdin\n input_source = sys.stdin.buffer if binary else sys.stdin\n content = input_source.read()\n else:\n mode = \"rb\" if binary else \"r\"\n with open(input, mode) as f:\n content = f.read()\n if not content:\n raise click.ClickException(\"No content provided\")\n results = collection_obj.similar(content, number)\n\n for result in results:\n click.echo(json.dumps(asdict(result)))"} | |
{"id": "llm/cli.py:2246", "code": "@cli.group(\n cls=DefaultGroup,\n default=\"list\",\n default_if_no_args=True,\n)\ndef embed_models():\n \"Manage available embedding models\""} | |
{"id": "llm/cli.py:2255", "code": "@embed_models.command(name=\"list\")\n@click.option(\n \"-q\",\n \"--query\",\n multiple=True,\n help=\"Search for embedding models matching these strings\",\n)\ndef embed_models_list(query):\n \"List available embedding models\"\n output = []\n for model_with_aliases in get_embedding_models_with_aliases():\n if query:\n if not all(model_with_aliases.matches(q) for q in query):\n continue\n s = str(model_with_aliases.model)\n if model_with_aliases.aliases:\n s += \" (aliases: {})\".format(\", \".join(model_with_aliases.aliases))\n output.append(s)\n click.echo(\"\\n\".join(output))"} | |
{"id": "llm/cli.py:2276", "code": "@embed_models.command(name=\"default\")\n@click.argument(\"model\", required=False)\n@click.option(\n \"--remove-default\", is_flag=True, help=\"Reset to specifying no default model\"\n)\ndef embed_models_default(model, remove_default):\n \"Show or set the default embedding model\"\n if not model and not remove_default:\n default = get_default_embedding_model()\n if default is None:\n click.echo(\"<No default embedding model set>\", err=True)\n else:\n click.echo(default)\n return\n # Validate it is a known model\n try:\n if remove_default:\n set_default_embedding_model(None)\n else:\n model = get_embedding_model(model)\n set_default_embedding_model(model.model_id)\n except KeyError:\n raise click.ClickException(\"Unknown embedding model: {}\".format(model))"} | |
{"id": "llm/cli.py:2301", "code": "@cli.group(\n cls=DefaultGroup,\n default=\"list\",\n default_if_no_args=True,\n)\ndef collections():\n \"View and manage collections of embeddings\""} | |
{"id": "llm/cli.py:2310", "code": "@collections.command(name=\"path\")\ndef collections_path():\n \"Output the path to the embeddings database\"\n click.echo(user_dir() / \"embeddings.db\")"} | |
{"id": "llm/cli.py:2316", "code": "@collections.command(name=\"list\")\n@click.option(\n \"-d\",\n \"--database\",\n type=click.Path(file_okay=True, allow_dash=False, dir_okay=False, writable=True),\n envvar=\"LLM_EMBEDDINGS_DB\",\n help=\"Path to embeddings database\",\n)\n@click.option(\"json_\", \"--json\", is_flag=True, help=\"Output as JSON\")\ndef embed_db_collections(database, json_):\n \"View a list of collections\"\n database = database or (user_dir() / \"embeddings.db\")\n db = sqlite_utils.Database(str(database))\n if not db[\"collections\"].exists():\n raise click.ClickException(\"No collections table found in {}\".format(database))\n rows = db.query(\n \"\"\"\n select\n collections.name,\n collections.model,\n count(embeddings.id) as num_embeddings\n from\n collections left join embeddings\n on collections.id = embeddings.collection_id\n group by\n collections.name, collections.model\n \"\"\"\n )\n if json_:\n click.echo(json.dumps(list(rows), indent=4))\n else:\n for row in rows:\n click.echo(\"{}: {}\".format(row[\"name\"], row[\"model\"]))\n click.echo(\n \" {} embedding{}\".format(\n row[\"num_embeddings\"], \"s\" if row[\"num_embeddings\"] != 1 else \"\"\n )\n )"} | |
{"id": "llm/cli.py:2356", "code": "@collections.command(name=\"delete\")\n@click.argument(\"collection\")\n@click.option(\n \"-d\",\n \"--database\",\n type=click.Path(file_okay=True, allow_dash=False, dir_okay=False, writable=True),\n envvar=\"LLM_EMBEDDINGS_DB\",\n help=\"Path to embeddings database\",\n)\ndef collections_delete(collection, database):\n \"\"\"\n Delete the specified collection\n\n Example usage:\n\n \\b\n llm collections delete my-collection\n \"\"\"\n database = database or (user_dir() / \"embeddings.db\")\n db = sqlite_utils.Database(str(database))\n try:\n collection_obj = Collection(collection, db, create=False)\n except Collection.DoesNotExist:\n raise click.ClickException(\"Collection does not exist\")\n collection_obj.delete()"} | |
{"id": "llm/cli.py:2383", "code": "@models.group(\n cls=DefaultGroup,\n default=\"list\",\n default_if_no_args=True,\n)\ndef options():\n \"Manage default options for models\""} | |
{"id": "llm/cli.py:2392", "code": "@options.command(name=\"list\")\ndef options_list():\n \"\"\"\n List default options for all models\n\n Example usage:\n\n \\b\n llm models options list\n \"\"\"\n options = get_all_model_options()\n if not options:\n click.echo(\"No default options set for any models.\", err=True)\n return\n\n for model_id, model_options in options.items():\n click.echo(f\"{model_id}:\")\n for key, value in model_options.items():\n click.echo(f\" {key}: {value}\")"} | |
{"id": "llm/cli.py:2413", "code": "@options.command(name=\"show\")\n@click.argument(\"model\")\ndef options_show(model):\n \"\"\"\n List default options set for a specific model\n\n Example usage:\n\n \\b\n llm models options show gpt-4o\n \"\"\"\n import llm\n\n try:\n # Resolve alias to model ID\n model_obj = llm.get_model(model)\n model_id = model_obj.model_id\n except llm.UnknownModelError:\n # Use as-is if not found\n model_id = model\n\n options = get_model_options(model_id)\n if not options:\n click.echo(f\"No default options set for model '{model_id}'.\", err=True)\n return\n\n for key, value in options.items():\n click.echo(f\"{key}: {value}\")"} | |
{"id": "llm/cli.py:2443", "code": "@options.command(name=\"set\")\n@click.argument(\"model\")\n@click.argument(\"key\")\n@click.argument(\"value\")\ndef options_set(model, key, value):\n \"\"\"\n Set a default option for a model\n\n Example usage:\n\n \\b\n llm models options set gpt-4o temperature 0.5\n \"\"\"\n import llm\n\n try:\n # Resolve alias to model ID\n model_obj = llm.get_model(model)\n model_id = model_obj.model_id\n\n # Validate option against model schema\n try:\n # Create a test Options object to validate\n test_options = {key: value}\n model_obj.Options(**test_options)\n except pydantic.ValidationError as ex:\n raise click.ClickException(render_errors(ex.errors()))\n\n except llm.UnknownModelError:\n # Use as-is if not found\n model_id = model\n\n set_model_option(model_id, key, value)\n click.echo(f\"Set default option {key}={value} for model {model_id}\", err=True)"} | |
{"id": "llm/cli.py:2479", "code": "@options.command(name=\"clear\")\n@click.argument(\"model\")\n@click.argument(\"key\", required=False)\ndef options_clear(model, key):\n \"\"\"\n Clear default option(s) for a model\n\n Example usage:\n\n \\b\n llm models options clear gpt-4o\n # Or for a single option\n llm models options clear gpt-4o temperature\n \"\"\"\n import llm\n\n try:\n # Resolve alias to model ID\n model_obj = llm.get_model(model)\n model_id = model_obj.model_id\n except llm.UnknownModelError:\n # Use as-is if not found\n model_id = model\n\n cleared_keys = []\n if not key:\n cleared_keys = list(get_model_options(model_id).keys())\n for key_ in cleared_keys:\n clear_model_option(model_id, key_)\n else:\n cleared_keys.append(key)\n clear_model_option(model_id, key)\n if cleared_keys:\n if len(cleared_keys) == 1:\n click.echo(f\"Cleared option '{cleared_keys[0]}' for model {model_id}\")\n else:\n click.echo(\n f\"Cleared {', '.join(cleared_keys)} options for model {model_id}\"\n )"} | |
{"id": "llm/cli.py:2520", "code": "def template_dir():\n path = user_dir() / \"templates\"\n path.mkdir(parents=True, exist_ok=True)\n return path"} | |
{"id": "llm/cli.py:2526", "code": "def logs_db_path():\n return user_dir() / \"logs.db\""} | |
{"id": "llm/cli.py:2530", "code": "def load_template(name):\n if \":\" in name:\n prefix, rest = name.split(\":\", 1)\n loaders = get_template_loaders()\n if prefix not in loaders:\n raise click.ClickException(\"Unknown template prefix: {}\".format(prefix))\n loader = loaders[prefix]\n try:\n return loader(rest)\n except Exception as ex:\n raise click.ClickException(\n \"Could not load template {}: {}\".format(name, ex)\n )\n\n path = template_dir() / f\"{name}.yaml\"\n if not path.exists():\n raise click.ClickException(f\"Invalid template: {name}\")\n try:\n loaded = yaml.safe_load(path.read_text())\n except yaml.YAMLError as ex:\n raise click.ClickException(\"Invalid YAML: {}\".format(str(ex)))\n if isinstance(loaded, str):\n return Template(name=name, prompt=loaded)\n loaded[\"name\"] = name\n try:\n return Template(**loaded)\n except pydantic.ValidationError as ex:\n msg = \"A validation error occurred:\\n\"\n msg += render_errors(ex.errors())\n raise click.ClickException(msg)"} | |
{"id": "llm/cli.py:2562", "code": "def get_history(chat_id):\n if chat_id is None:\n return None, []\n log_path = logs_db_path()\n db = sqlite_utils.Database(log_path)\n migrate(db)\n if chat_id == -1:\n # Return the most recent chat\n last_row = list(db[\"logs\"].rows_where(order_by=\"-id\", limit=1))\n if last_row:\n chat_id = last_row[0].get(\"chat_id\") or last_row[0].get(\"id\")\n else: # Database is empty\n return None, []\n rows = db[\"logs\"].rows_where(\n \"id = ? or chat_id = ?\", [chat_id, chat_id], order_by=\"id\"\n )\n return chat_id, rows"} | |
{"id": "llm/cli.py:2581", "code": "def render_errors(errors):\n output = []\n for error in errors:\n output.append(\", \".join(error[\"loc\"]))\n output.append(\" \" + error[\"msg\"])\n return \"\\n\".join(output)"} | |
{"id": "llm/cli.py:2594", "code": "def _human_readable_size(size_bytes):\n if size_bytes == 0:\n return \"0B\"\n\n size_name = (\"B\", \"KB\", \"MB\", \"GB\", \"TB\", \"PB\", \"EB\", \"ZB\", \"YB\")\n i = 0\n\n while size_bytes >= 1024 and i < len(size_name) - 1:\n size_bytes /= 1024.0\n i += 1\n\n return \"{:.2f}{}\".format(size_bytes, size_name[i])"} | |
{"id": "llm/cli.py:2608", "code": "def logs_on():\n return not (user_dir() / \"logs-off\").exists()"} | |
{"id": "llm/cli.py:2612", "code": "def get_all_model_options() -> dict:\n \"\"\"\n Get all default options for all models\n \"\"\"\n path = user_dir() / \"model_options.json\"\n if not path.exists():\n return {}\n\n try:\n options = json.loads(path.read_text())\n except json.JSONDecodeError:\n return {}\n\n return options"} | |
{"id": "llm/cli.py:2628", "code": "def get_model_options(model_id: str) -> dict:\n \"\"\"\n Get default options for a specific model\n\n Args:\n model_id: Return options for model with this ID\n\n Returns:\n A dictionary of model options\n \"\"\"\n path = user_dir() / \"model_options.json\"\n if not path.exists():\n return {}\n\n try:\n options = json.loads(path.read_text())\n except json.JSONDecodeError:\n return {}\n\n return options.get(model_id, {})"} | |
{"id": "llm/cli.py:2650", "code": "def set_model_option(model_id: str, key: str, value: Any) -> None:\n \"\"\"\n Set a default option for a model.\n\n Args:\n model_id: The model ID\n key: The option key\n value: The option value\n \"\"\"\n path = user_dir() / \"model_options.json\"\n if path.exists():\n try:\n options = json.loads(path.read_text())\n except json.JSONDecodeError:\n options = {}\n else:\n options = {}\n\n # Ensure the model has an entry\n if model_id not in options:\n options[model_id] = {}\n\n # Set the option\n options[model_id][key] = value\n\n # Save the options\n path.write_text(json.dumps(options, indent=2))"} | |
{"id": "llm/cli.py:2679", "code": "def clear_model_option(model_id: str, key: str) -> None:\n \"\"\"\n Clear a model option\n\n Args:\n model_id: The model ID\n key: Key to clear\n \"\"\"\n path = user_dir() / \"model_options.json\"\n if not path.exists():\n return\n\n try:\n options = json.loads(path.read_text())\n except json.JSONDecodeError:\n return\n\n if model_id not in options:\n return\n\n if key in options[model_id]:\n del options[model_id][key]\n if not options[model_id]:\n del options[model_id]\n\n path.write_text(json.dumps(options, indent=2))"} | |
{"id": "llm/utils.py:17", "code": "def mimetype_from_string(content) -> Optional[str]:\n try:\n type_ = puremagic.from_string(content, mime=True)\n return MIME_TYPE_FIXES.get(type_, type_)\n except puremagic.PureError:\n return None"} | |
{"id": "llm/utils.py:25", "code": "def mimetype_from_path(path) -> Optional[str]:\n try:\n type_ = puremagic.from_file(path, mime=True)\n return MIME_TYPE_FIXES.get(type_, type_)\n except puremagic.PureError:\n return None"} | |
{"id": "llm/utils.py:33", "code": "def dicts_to_table_string(\n headings: List[str], dicts: List[Dict[str, str]]\n) -> List[str]:\n max_lengths = [len(h) for h in headings]\n\n # Compute maximum length for each column\n for d in dicts:\n for i, h in enumerate(headings):\n if h in d and len(str(d[h])) > max_lengths[i]:\n max_lengths[i] = len(str(d[h]))\n\n # Generate formatted table strings\n res = []\n res.append(\" \".join(h.ljust(max_lengths[i]) for i, h in enumerate(headings)))\n\n for d in dicts:\n row = []\n for i, h in enumerate(headings):\n row.append(str(d.get(h, \"\")).ljust(max_lengths[i]))\n res.append(\" \".join(row))\n\n return res"} | |
{"id": "llm/utils.py:57", "code": "def remove_dict_none_values(d):\n \"\"\"\n Recursively remove keys with value of None or value of a dict that is all values of None\n \"\"\"\n if not isinstance(d, dict):\n return d\n new_dict = {}\n for key, value in d.items():\n if value is not None:\n if isinstance(value, dict):\n nested = remove_dict_none_values(value)\n if nested:\n new_dict[key] = nested\n elif isinstance(value, list):\n new_dict[key] = [remove_dict_none_values(v) for v in value]\n else:\n new_dict[key] = value\n return new_dict"} | |
{"id": "llm/utils.py:77", "code": "class _LogResponse(httpx.Response):\n def iter_bytes(self, *args, **kwargs):\n for chunk in super().iter_bytes(*args, **kwargs):\n click.echo(chunk.decode(), err=True)\n yield chunk"} | |
{"id": "llm/utils.py:78", "code": " def iter_bytes(self, *args, **kwargs):\n for chunk in super().iter_bytes(*args, **kwargs):\n click.echo(chunk.decode(), err=True)\n yield chunk"} | |
{"id": "llm/utils.py:84", "code": "class _LogTransport(httpx.BaseTransport):\n def __init__(self, transport: httpx.BaseTransport):\n self.transport = transport\n\n def handle_request(self, request: httpx.Request) -> httpx.Response:\n response = self.transport.handle_request(request)\n return _LogResponse(\n status_code=response.status_code,\n headers=response.headers,\n stream=response.stream,\n extensions=response.extensions,\n )"} | |
{"id": "llm/utils.py:85", "code": " def __init__(self, transport: httpx.BaseTransport):\n self.transport = transport"} | |
{"id": "llm/utils.py:88", "code": " def handle_request(self, request: httpx.Request) -> httpx.Response:\n response = self.transport.handle_request(request)\n return _LogResponse(\n status_code=response.status_code,\n headers=response.headers,\n stream=response.stream,\n extensions=response.extensions,\n )"} | |
{"id": "llm/utils.py:98", "code": "def _no_accept_encoding(request: httpx.Request):\n request.headers.pop(\"accept-encoding\", None)"} | |
{"id": "llm/utils.py:102", "code": "def _log_response(response: httpx.Response):\n request = response.request\n click.echo(f\"Request: {request.method} {request.url}\", err=True)\n click.echo(\" Headers:\", err=True)\n for key, value in request.headers.items():\n if key.lower() == \"authorization\":\n value = \"[...]\"\n if key.lower() == \"cookie\":\n value = value.split(\"=\")[0] + \"=...\"\n click.echo(f\" {key}: {value}\", err=True)\n click.echo(\" Body:\", err=True)\n try:\n request_body = json.loads(request.content)\n click.echo(\n textwrap.indent(json.dumps(request_body, indent=2), \" \"), err=True\n )\n except json.JSONDecodeError:\n click.echo(textwrap.indent(request.content.decode(), \" \"), err=True)\n click.echo(f\"Response: status_code={response.status_code}\", err=True)\n click.echo(\" Headers:\", err=True)\n for key, value in response.headers.items():\n if key.lower() == \"set-cookie\":\n value = value.split(\"=\")[0] + \"=...\"\n click.echo(f\" {key}: {value}\", err=True)\n click.echo(\" Body:\", err=True)"} | |
{"id": "llm/utils.py:129", "code": "def logging_client() -> httpx.Client:\n return httpx.Client(\n transport=_LogTransport(httpx.HTTPTransport()),\n event_hooks={\"request\": [_no_accept_encoding], \"response\": [_log_response]},\n )"} | |
{"id": "llm/utils.py:136", "code": "def simplify_usage_dict(d):\n # Recursively remove keys with value 0 and empty dictionaries\n def remove_empty_and_zero(obj):\n if isinstance(obj, dict):\n cleaned = {\n k: remove_empty_and_zero(v)\n for k, v in obj.items()\n if v != 0 and v != {}\n }\n return {k: v for k, v in cleaned.items() if v is not None and v != {}}\n return obj\n\n return remove_empty_and_zero(d) or {}"} | |
{"id": "llm/utils.py:151", "code": "def token_usage_string(input_tokens, output_tokens, token_details) -> str:\n bits = []\n if input_tokens is not None:\n bits.append(f\"{format(input_tokens, ',')} input\")\n if output_tokens is not None:\n bits.append(f\"{format(output_tokens, ',')} output\")\n if token_details:\n bits.append(json.dumps(token_details))\n return \", \".join(bits)"} | |
{"id": "llm/utils.py:162", "code": "def extract_fenced_code_block(text: str, last: bool = False) -> Optional[str]:\n \"\"\"\n Extracts and returns Markdown fenced code block found in the given text.\n\n The function handles fenced code blocks that:\n - Use at least three backticks (`).\n - May include a language tag immediately after the opening backticks.\n - Use more than three backticks as long as the closing fence has the same number.\n\n If no fenced code block is found, the function returns None.\n\n Args:\n text (str): The input text to search for a fenced code block.\n last (bool): Extract the last code block if True, otherwise the first.\n\n Returns:\n Optional[str]: The content of the fenced code block, or None if not found.\n \"\"\"\n # Regex pattern to match fenced code blocks\n # - ^ or \\n ensures that the fence is at the start of a line\n # - (`{3,}) captures the opening backticks (at least three)\n # - (\\w+)? optionally captures the language tag\n # - \\n matches the newline after the opening fence\n # - (.*?) non-greedy match for the code block content\n # - (?P=fence) ensures that the closing fence has the same number of backticks\n # - [ ]* allows for optional spaces between the closing fence and newline\n # - (?=\\n|$) ensures that the closing fence is followed by a newline or end of string\n pattern = re.compile(\n r\"\"\"(?m)^(?P<fence>`{3,})(?P<lang>\\w+)?\\n(?P<code>.*?)^(?P=fence)[ ]*(?=\\n|$)\"\"\",\n re.DOTALL,\n )\n matches = list(pattern.finditer(text))\n if matches:\n match = matches[-1] if last else matches[0]\n return match.group(\"code\")\n return None"} | |
{"id": "llm/utils.py:200", "code": "def make_schema_id(schema: dict) -> Tuple[str, str]:\n schema_json = json.dumps(schema, separators=(\",\", \":\"))\n schema_id = hashlib.blake2b(schema_json.encode(), digest_size=16).hexdigest()\n return schema_id, schema_json"} | |
{"id": "llm/utils.py:206", "code": "def output_rows_as_json(rows, nl=False):\n \"\"\"\n Output rows as JSON - either newline-delimited or an array\n\n Parameters:\n - rows: List of dictionaries to output\n - nl: Boolean, if True, use newline-delimited JSON\n\n Returns:\n - String with formatted JSON output\n \"\"\"\n if not rows:\n return \"\" if nl else \"[]\"\n\n lines = []\n end_i = len(rows) - 1\n for i, row in enumerate(rows):\n is_first = i == 0\n is_last = i == end_i\n\n line = \"{firstchar}{serialized}{maybecomma}{lastchar}\".format(\n firstchar=(\"[\" if is_first else \" \") if not nl else \"\",\n serialized=json.dumps(row),\n maybecomma=\",\" if (not nl and not is_last) else \"\",\n lastchar=\"]\" if (is_last and not nl) else \"\",\n )\n lines.append(line)\n\n return \"\\n\".join(lines)"} | |
{"id": "llm/utils.py:237", "code": "def resolve_schema_input(db, schema_input, load_template):\n # schema_input might be JSON or a filepath or an ID or t:name\n if not schema_input:\n return\n if schema_input.strip().startswith(\"t:\"):\n name = schema_input.strip()[2:]\n template = load_template(name)\n if not template.schema_object:\n raise click.ClickException(\"Template '{}' has no schema\".format(name))\n return template.schema_object\n if schema_input.strip().startswith(\"{\"):\n try:\n return json.loads(schema_input)\n except ValueError:\n pass\n if \" \" in schema_input.strip() or \",\" in schema_input:\n # Treat it as schema DSL\n return schema_dsl(schema_input)\n # Is it a file on disk?\n path = pathlib.Path(schema_input)\n if path.exists():\n try:\n return json.loads(path.read_text())\n except ValueError:\n raise click.ClickException(\"Schema file contained invalid JSON\")\n # Last attempt: is it an ID in the DB?\n try:\n row = db[\"schemas\"].get(schema_input)\n return json.loads(row[\"content\"])\n except (sqlite_utils.db.NotFoundError, ValueError):\n raise click.BadParameter(\"Invalid schema\")"} | |
{"id": "llm/utils.py:270", "code": "def schema_summary(schema: dict) -> str:\n \"\"\"\n Extract property names from a JSON schema and format them in a\n concise way that highlights the array/object structure.\n\n Args:\n schema (dict): A JSON schema dictionary\n\n Returns:\n str: A human-friendly summary of the schema structure\n \"\"\"\n if not schema or not isinstance(schema, dict):\n return \"\"\n\n schema_type = schema.get(\"type\", \"\")\n\n if schema_type == \"object\":\n props = schema.get(\"properties\", {})\n prop_summaries = []\n\n for name, prop_schema in props.items():\n prop_type = prop_schema.get(\"type\", \"\")\n\n if prop_type == \"array\":\n items = prop_schema.get(\"items\", {})\n items_summary = schema_summary(items)\n prop_summaries.append(f\"{name}: [{items_summary}]\")\n elif prop_type == \"object\":\n nested_summary = schema_summary(prop_schema)\n prop_summaries.append(f\"{name}: {nested_summary}\")\n else:\n prop_summaries.append(name)\n\n return \"{\" + \", \".join(prop_summaries) + \"}\"\n\n elif schema_type == \"array\":\n items = schema.get(\"items\", {})\n return schema_summary(items)\n\n return \"\""} | |
{"id": "llm/utils.py:312", "code": "def schema_dsl(schema_dsl: str, multi: bool = False) -> Dict[str, Any]:\n \"\"\"\n Build a JSON schema from a concise schema string.\n\n Args:\n schema_dsl: A string representing a schema in the concise format.\n Can be comma-separated or newline-separated.\n multi: Boolean, return a schema for an \"items\" array of these\n\n Returns:\n A dictionary representing the JSON schema.\n \"\"\"\n # Type mapping dictionary\n type_mapping = {\n \"int\": \"integer\",\n \"float\": \"number\",\n \"bool\": \"boolean\",\n \"str\": \"string\",\n }\n\n # Initialize the schema dictionary with required elements\n json_schema: Dict[str, Any] = {\"type\": \"object\", \"properties\": {}, \"required\": []}\n\n # Check if the schema is newline-separated or comma-separated\n if \"\\n\" in schema_dsl:\n fields = [field.strip() for field in schema_dsl.split(\"\\n\") if field.strip()]\n else:\n fields = [field.strip() for field in schema_dsl.split(\",\") if field.strip()]\n\n # Process each field\n for field in fields:\n # Extract field name, type, and description\n if \":\" in field:\n field_info, description = field.split(\":\", 1)\n description = description.strip()\n else:\n field_info = field\n description = \"\"\n\n # Process field name and type\n field_parts = field_info.strip().split()\n field_name = field_parts[0].strip()\n\n # Default type is string\n field_type = \"string\"\n\n # If type is specified, use it\n if len(field_parts) > 1:\n type_indicator = field_parts[1].strip()\n if type_indicator in type_mapping:\n field_type = type_mapping[type_indicator]\n\n # Add field to properties\n json_schema[\"properties\"][field_name] = {\"type\": field_type}\n\n # Add description if provided\n if description:\n json_schema[\"properties\"][field_name][\"description\"] = description\n\n # Add field to required list\n json_schema[\"required\"].append(field_name)\n\n if multi:\n return multi_schema(json_schema)\n else:\n return json_schema"} | |
{"id": "llm/utils.py:380", "code": "def multi_schema(schema: dict) -> dict:\n \"Wrap JSON schema in an 'items': [] array\"\n return {\n \"type\": \"object\",\n \"properties\": {\"items\": {\"type\": \"array\", \"items\": schema}},\n \"required\": [\"items\"],\n }"} | |
{"id": "llm/utils.py:389", "code": "def find_unused_key(item: dict, key: str) -> str:\n 'Return unused key, e.g. for {\"id\": \"1\"} and key \"id\" returns \"id_\"'\n while key in item:\n key += \"_\"\n return key"} | |
{"id": "llm/utils.py:396", "code": "def truncate_string(\n text: str,\n max_length: int = 100,\n normalize_whitespace: bool = False,\n keep_end: bool = False,\n) -> str:\n \"\"\"\n Truncate a string to a maximum length, with options to normalize whitespace and keep both start and end.\n\n Args:\n text: The string to truncate\n max_length: Maximum length of the result string\n normalize_whitespace: If True, replace all whitespace with a single space\n keep_end: If True, keep both beginning and end of string\n\n Returns:\n Truncated string\n \"\"\"\n if not text:\n return text\n\n if normalize_whitespace:\n text = re.sub(r\"\\s+\", \" \", text)\n\n if len(text) <= max_length:\n return text\n\n # Minimum sensible length for keep_end is 9 characters: \"a... z\"\n min_keep_end_length = 9\n\n if keep_end and max_length >= min_keep_end_length:\n # Calculate how much text to keep at each end\n # Subtract 5 for the \"... \" separator\n cutoff = (max_length - 5) // 2\n return text[:cutoff] + \"... \" + text[-cutoff:]\n else:\n # Fall back to simple truncation for very small max_length\n return text[: max_length - 3] + \"...\""} | |
{"id": "llm/examples.py:6", "code": "def build_markov_table(text):\n words = text.split()\n transitions = {}\n # Loop through all but the last word\n for i in range(len(words) - 1):\n word = words[i]\n next_word = words[i + 1]\n transitions.setdefault(word, []).append(next_word)\n return transitions"} | |
{"id": "llm/examples.py:17", "code": "def generate(transitions, length, start_word=None):\n all_words = list(transitions.keys())\n next_word = start_word or random.choice(all_words)\n for i in range(length):\n yield next_word\n options = transitions.get(next_word) or all_words\n next_word = random.choice(options)"} | |
{"id": "llm/examples.py:26", "code": "class Markov(llm.Model):\n model_id = \"markov\"\n\n def execute(self, prompt, stream, response, conversation):\n text = prompt.prompt\n transitions = build_markov_table(text)\n for word in generate(transitions, 20):\n yield word + \" \""} | |
{"id": "llm/examples.py:29", "code": " def execute(self, prompt, stream, response, conversation):\n text = prompt.prompt\n transitions = build_markov_table(text)\n for word in generate(transitions, 20):\n yield word + \" \""} | |
{"id": "llm/examples.py:36", "code": "class AnnotationsModel(llm.Model):\n model_id = \"annotations\"\n can_stream = True\n\n def execute(self, prompt, stream, response, conversation):\n yield \"Here is text before the annotation. \"\n yield llm.Chunk(\n text=\"This is the annotated text. \",\n annotation={\"title\": \"Annotation Title\", \"content\": \"Annotation Content\"},\n )\n yield \"Here is text after the annotation.\""} | |
{"id": "llm/examples.py:40", "code": " def execute(self, prompt, stream, response, conversation):\n yield \"Here is text before the annotation. \"\n yield llm.Chunk(\n text=\"This is the annotated text. \",\n annotation={\"title\": \"Annotation Title\", \"content\": \"Annotation Content\"},\n )\n yield \"Here is text after the annotation.\""} | |
{"id": "llm/examples.py:49", "code": "class AnnotationsModelAsync(llm.AsyncModel):\n model_id = \"annotations\"\n can_stream = True\n\n async def execute(\n self, prompt, stream, response, conversation=None\n ) -> AsyncGenerator[Union[llm.Chunk, str], None]:\n yield \"Here is text before the annotation. \"\n yield llm.Chunk(\n text=\"This is the annotated text. \",\n annotation={\"title\": \"Annotation Title\", \"content\": \"Annotation Content\"},\n )\n yield \"Here is text after the annotation.\""} | |
{"id": "llm/examples.py:53", "code": " async def execute(\n self, prompt, stream, response, conversation=None\n ) -> AsyncGenerator[Union[llm.Chunk, str], None]:\n yield \"Here is text before the annotation. \"\n yield llm.Chunk(\n text=\"This is the annotated text. \",\n annotation={\"title\": \"Annotation Title\", \"content\": \"Annotation Content\"},\n )\n yield \"Here is text after the annotation.\""} | |
{"id": "llm/migrations.py:8", "code": "def migrate(db):\n ensure_migrations_table(db)\n already_applied = {r[\"name\"] for r in db[\"_llm_migrations\"].rows}\n for fn in MIGRATIONS:\n name = fn.__name__\n if name not in already_applied:\n fn(db)\n db[\"_llm_migrations\"].insert(\n {\n \"name\": name,\n \"applied_at\": str(datetime.datetime.now(datetime.timezone.utc)),\n }\n )\n already_applied.add(name)"} | |
{"id": "llm/migrations.py:24", "code": "def ensure_migrations_table(db):\n if not db[\"_llm_migrations\"].exists():\n db[\"_llm_migrations\"].create(\n {\n \"name\": str,\n \"applied_at\": str,\n },\n pk=\"name\",\n )"} | |
{"id": "llm/migrations.py:35", "code": "@migration\ndef m001_initial(db):\n # Ensure the original table design exists, so other migrations can run\n if db[\"log\"].exists():\n # It needs to have the chat_id column\n if \"chat_id\" not in db[\"log\"].columns_dict:\n db[\"log\"].add_column(\"chat_id\")\n return\n db[\"log\"].create(\n {\n \"provider\": str,\n \"system\": str,\n \"prompt\": str,\n \"chat_id\": str,\n \"response\": str,\n \"model\": str,\n \"timestamp\": str,\n }\n )"} | |
{"id": "llm/migrations.py:56", "code": "@migration\ndef m002_id_primary_key(db):\n db[\"log\"].transform(pk=\"id\")"} | |
{"id": "llm/migrations.py:61", "code": "@migration\ndef m003_chat_id_foreign_key(db):\n db[\"log\"].transform(types={\"chat_id\": int})\n db[\"log\"].add_foreign_key(\"chat_id\", \"log\", \"id\")"} | |
{"id": "llm/migrations.py:67", "code": "@migration\ndef m004_column_order(db):\n db[\"log\"].transform(\n column_order=(\n \"id\",\n \"model\",\n \"timestamp\",\n \"prompt\",\n \"system\",\n \"response\",\n \"chat_id\",\n )\n )"} | |
{"id": "llm/migrations.py:82", "code": "@migration\ndef m004_drop_provider(db):\n db[\"log\"].transform(drop=(\"provider\",))"} | |
{"id": "llm/migrations.py:87", "code": "@migration\ndef m005_debug(db):\n db[\"log\"].add_column(\"debug\", str)\n db[\"log\"].add_column(\"duration_ms\", int)"} | |
{"id": "llm/migrations.py:93", "code": "@migration\ndef m006_new_logs_table(db):\n columns = db[\"log\"].columns_dict\n for column, type in (\n (\"options_json\", str),\n (\"prompt_json\", str),\n (\"response_json\", str),\n (\"reply_to_id\", int),\n ):\n # It's possible people running development code like myself\n # might have accidentally created these columns already\n if column not in columns:\n db[\"log\"].add_column(column, type)\n\n # Use .transform() to rename options and timestamp_utc, and set new order\n db[\"log\"].transform(\n column_order=(\n \"id\",\n \"model\",\n \"prompt\",\n \"system\",\n \"prompt_json\",\n \"options_json\",\n \"response\",\n \"response_json\",\n \"reply_to_id\",\n \"chat_id\",\n \"duration_ms\",\n \"timestamp_utc\",\n ),\n rename={\n \"timestamp\": \"timestamp_utc\",\n \"options\": \"options_json\",\n },\n )"} | |
{"id": "llm/migrations.py:130", "code": "@migration\ndef m007_finish_logs_table(db):\n db[\"log\"].transform(\n drop={\"debug\"},\n rename={\"timestamp_utc\": \"datetime_utc\"},\n drop_foreign_keys=(\"chat_id\",),\n )\n with db.conn:\n db.execute(\"alter table log rename to logs\")"} | |
{"id": "llm/migrations.py:141", "code": "@migration\ndef m008_reply_to_id_foreign_key(db):\n db[\"logs\"].add_foreign_key(\"reply_to_id\", \"logs\", \"id\")"} | |
{"id": "llm/migrations.py:146", "code": "@migration\ndef m008_fix_column_order_in_logs(db):\n # reply_to_id ended up at the end after foreign key added\n db[\"logs\"].transform(\n column_order=(\n \"id\",\n \"model\",\n \"prompt\",\n \"system\",\n \"prompt_json\",\n \"options_json\",\n \"response\",\n \"response_json\",\n \"reply_to_id\",\n \"chat_id\",\n \"duration_ms\",\n \"timestamp_utc\",\n ),\n )"} | |
{"id": "llm/migrations.py:167", "code": "@migration\ndef m009_delete_logs_table_if_empty(db):\n # We moved to a new table design, but we don't delete the table\n # if someone has put data in it\n if not db[\"logs\"].count:\n db[\"logs\"].drop()"} | |
{"id": "llm/migrations.py:175", "code": "@migration\ndef m010_create_new_log_tables(db):\n db[\"conversations\"].create(\n {\n \"id\": str,\n \"name\": str,\n \"model\": str,\n },\n pk=\"id\",\n )\n db[\"responses\"].create(\n {\n \"id\": str,\n \"model\": str,\n \"prompt\": str,\n \"system\": str,\n \"prompt_json\": str,\n \"options_json\": str,\n \"response\": str,\n \"response_json\": str,\n \"conversation_id\": str,\n \"duration_ms\": int,\n \"datetime_utc\": str,\n },\n pk=\"id\",\n foreign_keys=((\"conversation_id\", \"conversations\", \"id\"),),\n )"} | |
{"id": "llm/migrations.py:204", "code": "@migration\ndef m011_fts_for_responses(db):\n db[\"responses\"].enable_fts([\"prompt\", \"response\"], create_triggers=True)"} | |
{"id": "llm/migrations.py:209", "code": "@migration\ndef m012_attachments_tables(db):\n db[\"attachments\"].create(\n {\n \"id\": str,\n \"type\": str,\n \"path\": str,\n \"url\": str,\n \"content\": bytes,\n },\n pk=\"id\",\n )\n db[\"prompt_attachments\"].create(\n {\n \"response_id\": str,\n \"attachment_id\": str,\n \"order\": int,\n },\n foreign_keys=(\n (\"response_id\", \"responses\", \"id\"),\n (\"attachment_id\", \"attachments\", \"id\"),\n ),\n pk=(\"response_id\", \"attachment_id\"),\n )"} | |
{"id": "llm/migrations.py:235", "code": "@migration\ndef m013_usage(db):\n db[\"responses\"].add_column(\"input_tokens\", int)\n db[\"responses\"].add_column(\"output_tokens\", int)\n db[\"responses\"].add_column(\"token_details\", str)"} | |
{"id": "llm/migrations.py:242", "code": "@migration\ndef m014_schemas(db):\n db[\"schemas\"].create(\n {\n \"id\": str,\n \"content\": str,\n },\n pk=\"id\",\n )\n db[\"responses\"].add_column(\"schema_id\", str, fk=\"schemas\", fk_col=\"id\")\n # Clean up SQL create table indentation\n db[\"responses\"].transform()\n # These changes may have dropped the FTS configuration, fix that\n db[\"responses\"].enable_fts(\n [\"prompt\", \"response\"], create_triggers=True, replace=True\n )"} | |
{"id": "llm/migrations.py:260", "code": "@migration\ndef m015_response_annotations(db):\n db[\"response_annotations\"].create(\n {\n \"id\": int,\n \"response_id\": str,\n \"start_index\": int,\n \"end_index\": int,\n \"data\": str,\n },\n pk=\"id\",\n foreign_keys=((\"response_id\", \"responses\", \"id\"),),\n )"} | |
{"id": "llm/errors.py:1", "code": "class ModelError(Exception):\n \"Models can raise this error, which will be displayed to the user\""} | |
{"id": "llm/errors.py:5", "code": "class NeedsKeyException(ModelError):\n \"Model needs an API key which has not been provided\""} | |
{"id": "llm/default_plugins/chunkers.py:4", "code": "def lines(text):\n \"Chunk text into lines\"\n for line in text.split(\"\\n\"):\n if line.strip():\n yield line"} | |
{"id": "llm/default_plugins/chunkers.py:11", "code": "@hookimpl\ndef register_chunker_functions(register):\n register(lines, name=\"lines\")"} | |
{"id": "llm/default_plugins/openai_models.py:23", "code": "@hookimpl\ndef register_models(register):\n # GPT-4o\n register(\n Chat(\"gpt-4o\", vision=True, supports_schema=True),\n AsyncChat(\"gpt-4o\", vision=True, supports_schema=True),\n aliases=(\"4o\",),\n )\n register(\n Chat(\"chatgpt-4o-latest\", vision=True),\n AsyncChat(\"chatgpt-4o-latest\", vision=True),\n aliases=(\"chatgpt-4o\",),\n )\n register(\n Chat(\"gpt-4o-mini\", vision=True, supports_schema=True),\n AsyncChat(\"gpt-4o-mini\", vision=True, supports_schema=True),\n aliases=(\"4o-mini\",),\n )\n for audio_model_id in (\n \"gpt-4o-audio-preview\",\n \"gpt-4o-audio-preview-2024-12-17\",\n \"gpt-4o-audio-preview-2024-10-01\",\n \"gpt-4o-mini-audio-preview\",\n \"gpt-4o-mini-audio-preview-2024-12-17\",\n ):\n register(\n Chat(audio_model_id, audio=True),\n AsyncChat(audio_model_id, audio=True),\n )\n # 3.5 and 4\n register(\n Chat(\"gpt-3.5-turbo\"), AsyncChat(\"gpt-3.5-turbo\"), aliases=(\"3.5\", \"chatgpt\")\n )\n register(\n Chat(\"gpt-3.5-turbo-16k\"),\n AsyncChat(\"gpt-3.5-turbo-16k\"),\n aliases=(\"chatgpt-16k\", \"3.5-16k\"),\n )\n register(Chat(\"gpt-4\"), AsyncChat(\"gpt-4\"), aliases=(\"4\", \"gpt4\"))\n register(Chat(\"gpt-4-32k\"), AsyncChat(\"gpt-4-32k\"), aliases=(\"4-32k\",))\n # GPT-4 Turbo models\n register(Chat(\"gpt-4-1106-preview\"), AsyncChat(\"gpt-4-1106-preview\"))\n register(Chat(\"gpt-4-0125-preview\"), AsyncChat(\"gpt-4-0125-preview\"))\n register(Chat(\"gpt-4-turbo-2024-04-09\"), AsyncChat(\"gpt-4-turbo-2024-04-09\"))\n register(\n Chat(\"gpt-4-turbo\"),\n AsyncChat(\"gpt-4-turbo\"),\n aliases=(\"gpt-4-turbo-preview\", \"4-turbo\", \"4t\"),\n )\n # GPT-4.5\n register(\n Chat(\"gpt-4.5-preview-2025-02-27\", vision=True, supports_schema=True),\n AsyncChat(\"gpt-4.5-preview-2025-02-27\", vision=True, supports_schema=True),\n )\n register(\n Chat(\"gpt-4.5-preview\", vision=True, supports_schema=True),\n AsyncChat(\"gpt-4.5-preview\", vision=True, supports_schema=True),\n aliases=(\"gpt-4.5\",),\n )\n # o1\n for model_id in (\"o1\", \"o1-2024-12-17\"):\n register(\n Chat(\n model_id,\n vision=True,\n can_stream=False,\n reasoning=True,\n supports_schema=True,\n ),\n AsyncChat(\n model_id,\n vision=True,\n can_stream=False,\n reasoning=True,\n supports_schema=True,\n ),\n )\n\n register(\n Chat(\"o1-preview\", allows_system_prompt=False),\n AsyncChat(\"o1-preview\", allows_system_prompt=False),\n )\n register(\n Chat(\"o1-mini\", allows_system_prompt=False),\n AsyncChat(\"o1-mini\", allows_system_prompt=False),\n )\n register(\n Chat(\"o3-mini\", reasoning=True, supports_schema=True),\n AsyncChat(\"o3-mini\", reasoning=True, supports_schema=True),\n )\n # The -instruct completion model\n register(\n Completion(\"gpt-3.5-turbo-instruct\", default_max_tokens=256),\n aliases=(\"3.5-instruct\", \"chatgpt-instruct\"),\n )\n\n # Search models\n for model_id in (\"gpt-4o-search-preview\", \"gpt-4o-mini-search-preview\"):\n register(\n Chat(model_id, search_preview=True),\n AsyncChat(model_id, search_preview=True),\n )\n\n # Load extra models\n extra_path = llm.user_dir() / \"extra-openai-models.yaml\"\n if not extra_path.exists():\n return\n with open(extra_path) as f:\n extra_models = yaml.safe_load(f)\n for extra_model in extra_models:\n model_id = extra_model[\"model_id\"]\n aliases = extra_model.get(\"aliases\", [])\n model_name = extra_model[\"model_name\"]\n api_base = extra_model.get(\"api_base\")\n api_type = extra_model.get(\"api_type\")\n api_version = extra_model.get(\"api_version\")\n api_engine = 
extra_model.get(\"api_engine\")\n headers = extra_model.get(\"headers\")\n reasoning = extra_model.get(\"reasoning\")\n kwargs = {}\n if extra_model.get(\"can_stream\") is False:\n kwargs[\"can_stream\"] = False\n if extra_model.get(\"supports_schema\") is True:\n kwargs[\"supports_schema\"] = True\n if extra_model.get(\"vision\") is True:\n kwargs[\"vision\"] = True\n if extra_model.get(\"audio\") is True:\n kwargs[\"audio\"] = True\n if extra_model.get(\"completion\"):\n klass = Completion\n else:\n klass = Chat\n chat_model = klass(\n model_id,\n model_name=model_name,\n api_base=api_base,\n api_type=api_type,\n api_version=api_version,\n api_engine=api_engine,\n headers=headers,\n reasoning=reasoning,\n **kwargs,\n )\n if api_base:\n chat_model.needs_key = None\n if extra_model.get(\"api_key_name\"):\n chat_model.needs_key = extra_model[\"api_key_name\"]\n register(\n chat_model,\n aliases=aliases,\n )"} | |
{"id": "llm/default_plugins/openai_models.py:176", "code": "@hookimpl\ndef register_embedding_models(register):\n register(\n OpenAIEmbeddingModel(\"text-embedding-ada-002\", \"text-embedding-ada-002\"),\n aliases=(\n \"ada\",\n \"ada-002\",\n ),\n )\n register(\n OpenAIEmbeddingModel(\"text-embedding-3-small\", \"text-embedding-3-small\"),\n aliases=(\"3-small\",),\n )\n register(\n OpenAIEmbeddingModel(\"text-embedding-3-large\", \"text-embedding-3-large\"),\n aliases=(\"3-large\",),\n )\n # With varying dimensions\n register(\n OpenAIEmbeddingModel(\n \"text-embedding-3-small-512\", \"text-embedding-3-small\", 512\n ),\n aliases=(\"3-small-512\",),\n )\n register(\n OpenAIEmbeddingModel(\n \"text-embedding-3-large-256\", \"text-embedding-3-large\", 256\n ),\n aliases=(\"3-large-256\",),\n )\n register(\n OpenAIEmbeddingModel(\n \"text-embedding-3-large-1024\", \"text-embedding-3-large\", 1024\n ),\n aliases=(\"3-large-1024\",),\n )"} | |
{"id": "llm/default_plugins/openai_models.py:214", "code": "class OpenAIEmbeddingModel(EmbeddingModel):\n needs_key = \"openai\"\n key_env_var = \"OPENAI_API_KEY\"\n batch_size = 100\n\n def __init__(self, model_id, openai_model_id, dimensions=None):\n self.model_id = model_id\n self.openai_model_id = openai_model_id\n self.dimensions = dimensions\n\n def embed_batch(self, items: Iterable[Union[str, bytes]]) -> Iterator[List[float]]:\n kwargs = {\n \"input\": items,\n \"model\": self.openai_model_id,\n }\n if self.dimensions:\n kwargs[\"dimensions\"] = self.dimensions\n client = openai.OpenAI(api_key=self.get_key())\n results = client.embeddings.create(**kwargs).data\n return ([float(r) for r in result.embedding] for result in results)"} | |
{"id": "llm/default_plugins/openai_models.py:219", "code": " def __init__(self, model_id, openai_model_id, dimensions=None):\n self.model_id = model_id\n self.openai_model_id = openai_model_id\n self.dimensions = dimensions"} | |
{"id": "llm/default_plugins/openai_models.py:224", "code": " def embed_batch(self, items: Iterable[Union[str, bytes]]) -> Iterator[List[float]]:\n kwargs = {\n \"input\": items,\n \"model\": self.openai_model_id,\n }\n if self.dimensions:\n kwargs[\"dimensions\"] = self.dimensions\n client = openai.OpenAI(api_key=self.get_key())\n results = client.embeddings.create(**kwargs).data\n return ([float(r) for r in result.embedding] for result in results)"} | |
{"id": "llm/default_plugins/openai_models.py:236", "code": "@hookimpl\ndef register_commands(cli):\n @cli.group(name=\"openai\")\n def openai_():\n \"Commands for working directly with the OpenAI API\"\n\n @openai_.command()\n @click.option(\"json_\", \"--json\", is_flag=True, help=\"Output as JSON\")\n @click.option(\"--key\", help=\"OpenAI API key\")\n def models(json_, key):\n \"List models available to you from the OpenAI API\"\n from llm import get_key\n\n api_key = get_key(key, \"openai\", \"OPENAI_API_KEY\")\n response = httpx.get(\n \"https://api.openai.com/v1/models\",\n headers={\"Authorization\": f\"Bearer {api_key}\"},\n )\n if response.status_code != 200:\n raise click.ClickException(\n f\"Error {response.status_code} from OpenAI API: {response.text}\"\n )\n models = response.json()[\"data\"]\n if json_:\n click.echo(json.dumps(models, indent=4))\n else:\n to_print = []\n for model in models:\n # Print id, owned_by, root, created as ISO 8601\n created_str = datetime.datetime.fromtimestamp(\n model[\"created\"], datetime.timezone.utc\n ).isoformat()\n to_print.append(\n {\n \"id\": model[\"id\"],\n \"owned_by\": model[\"owned_by\"],\n \"created\": created_str,\n }\n )\n done = dicts_to_table_string(\"id owned_by created\".split(), to_print)\n print(\"\\n\".join(done))"} | |
{"id": "llm/default_plugins/openai_models.py:279", "code": "class SharedOptions(llm.Options):\n temperature: Optional[float] = Field(\n description=(\n \"What sampling temperature to use, between 0 and 2. Higher values like \"\n \"0.8 will make the output more random, while lower values like 0.2 will \"\n \"make it more focused and deterministic.\"\n ),\n ge=0,\n le=2,\n default=None,\n )\n max_tokens: Optional[int] = Field(\n description=\"Maximum number of tokens to generate.\", default=None\n )\n top_p: Optional[float] = Field(\n description=(\n \"An alternative to sampling with temperature, called nucleus sampling, \"\n \"where the model considers the results of the tokens with top_p \"\n \"probability mass. So 0.1 means only the tokens comprising the top \"\n \"10% probability mass are considered. Recommended to use top_p or \"\n \"temperature but not both.\"\n ),\n ge=0,\n le=1,\n default=None,\n )\n frequency_penalty: Optional[float] = Field(\n description=(\n \"Number between -2.0 and 2.0. Positive values penalize new tokens based \"\n \"on their existing frequency in the text so far, decreasing the model's \"\n \"likelihood to repeat the same line verbatim.\"\n ),\n ge=-2,\n le=2,\n default=None,\n )\n presence_penalty: Optional[float] = Field(\n description=(\n \"Number between -2.0 and 2.0. Positive values penalize new tokens based \"\n \"on whether they appear in the text so far, increasing the model's \"\n \"likelihood to talk about new topics.\"\n ),\n ge=-2,\n le=2,\n default=None,\n )\n stop: Optional[str] = Field(\n description=(\"A string where the API will stop generating further tokens.\"),\n default=None,\n )\n logit_bias: Optional[Union[dict, str]] = Field(\n description=(\n \"Modify the likelihood of specified tokens appearing in the completion. \"\n 'Pass a JSON string like \\'{\"1712\":-100, \"892\":-100, \"1489\":-100}\\''\n ),\n default=None,\n )\n seed: Optional[int] = Field(\n description=\"Integer seed to attempt to sample deterministically\",\n default=None,\n )\n\n @field_validator(\"logit_bias\")\n def validate_logit_bias(cls, logit_bias):\n if logit_bias is None:\n return None\n\n if isinstance(logit_bias, str):\n try:\n logit_bias = json.loads(logit_bias)\n except json.JSONDecodeError:\n raise ValueError(\"Invalid JSON in logit_bias string\")\n\n validated_logit_bias = {}\n for key, value in logit_bias.items():\n try:\n int_key = int(key)\n int_value = int(value)\n if -100 <= int_value <= 100:\n validated_logit_bias[int_key] = int_value\n else:\n raise ValueError(\"Value must be between -100 and 100\")\n except ValueError:\n raise ValueError(\"Invalid key-value pair in logit_bias dictionary\")\n\n return validated_logit_bias"} | |
{"id": "llm/default_plugins/openai_models.py:341", "code": " @field_validator(\"logit_bias\")\n def validate_logit_bias(cls, logit_bias):\n if logit_bias is None:\n return None\n\n if isinstance(logit_bias, str):\n try:\n logit_bias = json.loads(logit_bias)\n except json.JSONDecodeError:\n raise ValueError(\"Invalid JSON in logit_bias string\")\n\n validated_logit_bias = {}\n for key, value in logit_bias.items():\n try:\n int_key = int(key)\n int_value = int(value)\n if -100 <= int_value <= 100:\n validated_logit_bias[int_key] = int_value\n else:\n raise ValueError(\"Value must be between -100 and 100\")\n except ValueError:\n raise ValueError(\"Invalid key-value pair in logit_bias dictionary\")\n\n return validated_logit_bias"} | |
{"id": "llm/default_plugins/openai_models.py:367", "code": "class LowMediumHighEnum(str, Enum):\n low = \"low\"\n medium = \"medium\"\n high = \"high\""} | |
{"id": "llm/default_plugins/openai_models.py:373", "code": "class OptionsForReasoning(SharedOptions):\n json_object: Optional[bool] = Field(\n description=\"Output a valid JSON object {...}. Prompt must mention JSON.\",\n default=None,\n )\n reasoning_effort: Optional[LowMediumHighEnum] = Field(\n description=(\n \"Constraints effort on reasoning for reasoning models. Currently supported \"\n \"values are low, medium, and high. Reducing reasoning effort can result in \"\n \"faster responses and fewer tokens used on reasoning in a response.\"\n ),\n default=None,\n )"} | |
{"id": "llm/default_plugins/openai_models.py:388", "code": "class OptionsForSearchPreview(SharedOptions):\n search_context_size: Optional[LowMediumHighEnum] = Field(\n description=(\n \"How much context is retrieved from the web to help the tool formulate a response\"\n ),\n default=None,\n )"} | |
{"id": "llm/default_plugins/openai_models.py:397", "code": "def _attachment(attachment):\n url = attachment.url\n base64_content = \"\"\n if not url or attachment.resolve_type().startswith(\"audio/\"):\n base64_content = attachment.base64_content()\n url = f\"data:{attachment.resolve_type()};base64,{base64_content}\"\n if attachment.resolve_type() == \"application/pdf\":\n if not base64_content:\n base64_content = attachment.base64_content()\n return {\n \"type\": \"file\",\n \"file\": {\n \"filename\": f\"{attachment.id()}.pdf\",\n \"file_data\": f\"data:application/pdf;base64,{base64_content}\",\n },\n }\n if attachment.resolve_type().startswith(\"image/\"):\n return {\"type\": \"image_url\", \"image_url\": {\"url\": url}}\n else:\n format_ = \"wav\" if attachment.resolve_type() == \"audio/wav\" else \"mp3\"\n return {\n \"type\": \"input_audio\",\n \"input_audio\": {\n \"data\": base64_content,\n \"format\": format_,\n },\n }"} | |
{"id": "llm/default_plugins/openai_models.py:426", "code": "class _Shared:\n def __init__(\n self,\n model_id,\n key=None,\n model_name=None,\n api_base=None,\n api_type=None,\n api_version=None,\n api_engine=None,\n headers=None,\n can_stream=True,\n vision=False,\n audio=False,\n reasoning=False,\n supports_schema=False,\n allows_system_prompt=True,\n search_preview=False,\n ):\n self.model_id = model_id\n self.key = key\n self.supports_schema = supports_schema\n self.model_name = model_name\n self.api_base = api_base\n self.api_type = api_type\n self.api_version = api_version\n self.api_engine = api_engine\n self.headers = headers\n self.can_stream = can_stream\n self.vision = vision\n self.allows_system_prompt = allows_system_prompt\n self.search_preview = search_preview\n\n self.attachment_types = set()\n\n if reasoning:\n self.Options = OptionsForReasoning\n\n if search_preview:\n self.Options = OptionsForSearchPreview\n\n if vision:\n self.attachment_types.update(\n {\n \"image/png\",\n \"image/jpeg\",\n \"image/webp\",\n \"image/gif\",\n \"application/pdf\",\n }\n )\n\n if audio:\n self.attachment_types.update(\n {\n \"audio/wav\",\n \"audio/mpeg\",\n }\n )\n\n def __str__(self):\n return \"OpenAI Chat: {}\".format(self.model_id)\n\n def build_messages(self, prompt, conversation):\n messages = []\n current_system = None\n if conversation is not None:\n for prev_response in conversation.responses:\n if (\n prev_response.prompt.system\n and prev_response.prompt.system != current_system\n ):\n messages.append(\n {\"role\": \"system\", \"content\": prev_response.prompt.system}\n )\n current_system = prev_response.prompt.system\n if prev_response.attachments:\n attachment_message = []\n if prev_response.prompt.prompt:\n attachment_message.append(\n {\"type\": \"text\", \"text\": prev_response.prompt.prompt}\n )\n for attachment in prev_response.attachments:\n attachment_message.append(_attachment(attachment))\n messages.append({\"role\": \"user\", \"content\": attachment_message})\n else:\n messages.append(\n {\"role\": \"user\", \"content\": prev_response.prompt.prompt}\n )\n messages.append(\n {\"role\": \"assistant\", \"content\": prev_response.text_or_raise()}\n )\n if prompt.system and prompt.system != current_system:\n messages.append({\"role\": \"system\", \"content\": prompt.system})\n if not prompt.attachments:\n messages.append({\"role\": \"user\", \"content\": prompt.prompt or \"\"})\n else:\n attachment_message = []\n if prompt.prompt:\n attachment_message.append({\"type\": \"text\", \"text\": prompt.prompt})\n for attachment in prompt.attachments:\n attachment_message.append(_attachment(attachment))\n messages.append({\"role\": \"user\", \"content\": attachment_message})\n return messages\n\n def set_usage(self, response, usage):\n if not usage:\n return\n input_tokens = usage.pop(\"prompt_tokens\")\n output_tokens = usage.pop(\"completion_tokens\")\n usage.pop(\"total_tokens\")\n response.set_usage(\n input=input_tokens, output=output_tokens, details=simplify_usage_dict(usage)\n )\n\n def set_annotations(self, response, annotations: list):\n # Annotation(type='url_citation', url_citation=AnnotationURLCitation(\n # end_index=358, start_index=284, title='...', url='https://...'))\n to_add = []\n for annotation in annotations:\n if annotation[\"type\"] == \"url_citation\":\n data = annotation[\"url_citation\"]\n start_index = data.pop(\"start_index\")\n end_index = data.pop(\"end_index\")\n to_add.append(\n llm.Annotation(\n start_index=start_index, end_index=end_index, 
data=data\n )\n )\n response.add_annotations(to_add)\n\n def get_client(self, key, *, async_=False):\n kwargs = {}\n if self.api_base:\n kwargs[\"base_url\"] = self.api_base\n if self.api_type:\n kwargs[\"api_type\"] = self.api_type\n if self.api_version:\n kwargs[\"api_version\"] = self.api_version\n if self.api_engine:\n kwargs[\"engine\"] = self.api_engine\n if self.needs_key:\n kwargs[\"api_key\"] = self.get_key(key)\n else:\n # OpenAI-compatible models don't need a key, but the\n # openai client library requires one\n kwargs[\"api_key\"] = \"DUMMY_KEY\"\n if self.headers:\n kwargs[\"default_headers\"] = self.headers\n if os.environ.get(\"LLM_OPENAI_SHOW_RESPONSES\"):\n kwargs[\"http_client\"] = logging_client()\n if async_:\n return openai.AsyncOpenAI(**kwargs)\n else:\n return openai.OpenAI(**kwargs)\n\n def build_kwargs(self, prompt, stream):\n kwargs = dict(not_nulls(prompt.options))\n json_object = kwargs.pop(\"json_object\", None)\n if \"max_tokens\" not in kwargs and self.default_max_tokens is not None:\n kwargs[\"max_tokens\"] = self.default_max_tokens\n if json_object:\n kwargs[\"response_format\"] = {\"type\": \"json_object\"}\n if prompt.schema:\n kwargs[\"response_format\"] = {\n \"type\": \"json_schema\",\n \"json_schema\": {\"name\": \"output\", \"schema\": prompt.schema},\n }\n if stream:\n kwargs[\"stream_options\"] = {\"include_usage\": True}\n if self.search_preview:\n kwargs[\"web_search_options\"] = {}\n if prompt.options.search_context_size:\n kwargs.pop(\"search_context_size\", None)\n kwargs[\"web_search_options\"][\n \"search_context_size\"\n ] = prompt.options.search_context_size\n return kwargs"} | |
{"id": "llm/default_plugins/openai_models.py:427", "code": " def __init__(\n self,\n model_id,\n key=None,\n model_name=None,\n api_base=None,\n api_type=None,\n api_version=None,\n api_engine=None,\n headers=None,\n can_stream=True,\n vision=False,\n audio=False,\n reasoning=False,\n supports_schema=False,\n allows_system_prompt=True,\n search_preview=False,\n ):\n self.model_id = model_id\n self.key = key\n self.supports_schema = supports_schema\n self.model_name = model_name\n self.api_base = api_base\n self.api_type = api_type\n self.api_version = api_version\n self.api_engine = api_engine\n self.headers = headers\n self.can_stream = can_stream\n self.vision = vision\n self.allows_system_prompt = allows_system_prompt\n self.search_preview = search_preview\n\n self.attachment_types = set()\n\n if reasoning:\n self.Options = OptionsForReasoning\n\n if search_preview:\n self.Options = OptionsForSearchPreview\n\n if vision:\n self.attachment_types.update(\n {\n \"image/png\",\n \"image/jpeg\",\n \"image/webp\",\n \"image/gif\",\n \"application/pdf\",\n }\n )\n\n if audio:\n self.attachment_types.update(\n {\n \"audio/wav\",\n \"audio/mpeg\",\n }\n )"} | |
{"id": "llm/default_plugins/openai_models.py:486", "code": " def __str__(self):\n return \"OpenAI Chat: {}\".format(self.model_id)"} | |
{"id": "llm/default_plugins/openai_models.py:489", "code": " def build_messages(self, prompt, conversation):\n messages = []\n current_system = None\n if conversation is not None:\n for prev_response in conversation.responses:\n if (\n prev_response.prompt.system\n and prev_response.prompt.system != current_system\n ):\n messages.append(\n {\"role\": \"system\", \"content\": prev_response.prompt.system}\n )\n current_system = prev_response.prompt.system\n if prev_response.attachments:\n attachment_message = []\n if prev_response.prompt.prompt:\n attachment_message.append(\n {\"type\": \"text\", \"text\": prev_response.prompt.prompt}\n )\n for attachment in prev_response.attachments:\n attachment_message.append(_attachment(attachment))\n messages.append({\"role\": \"user\", \"content\": attachment_message})\n else:\n messages.append(\n {\"role\": \"user\", \"content\": prev_response.prompt.prompt}\n )\n messages.append(\n {\"role\": \"assistant\", \"content\": prev_response.text_or_raise()}\n )\n if prompt.system and prompt.system != current_system:\n messages.append({\"role\": \"system\", \"content\": prompt.system})\n if not prompt.attachments:\n messages.append({\"role\": \"user\", \"content\": prompt.prompt or \"\"})\n else:\n attachment_message = []\n if prompt.prompt:\n attachment_message.append({\"type\": \"text\", \"text\": prompt.prompt})\n for attachment in prompt.attachments:\n attachment_message.append(_attachment(attachment))\n messages.append({\"role\": \"user\", \"content\": attachment_message})\n return messages"} | |
{"id": "llm/default_plugins/openai_models.py:531", "code": " def set_usage(self, response, usage):\n if not usage:\n return\n input_tokens = usage.pop(\"prompt_tokens\")\n output_tokens = usage.pop(\"completion_tokens\")\n usage.pop(\"total_tokens\")\n response.set_usage(\n input=input_tokens, output=output_tokens, details=simplify_usage_dict(usage)\n )"} | |
{"id": "llm/default_plugins/openai_models.py:541", "code": " def set_annotations(self, response, annotations: list):\n # Annotation(type='url_citation', url_citation=AnnotationURLCitation(\n # end_index=358, start_index=284, title='...', url='https://...'))\n to_add = []\n for annotation in annotations:\n if annotation[\"type\"] == \"url_citation\":\n data = annotation[\"url_citation\"]\n start_index = data.pop(\"start_index\")\n end_index = data.pop(\"end_index\")\n to_add.append(\n llm.Annotation(\n start_index=start_index, end_index=end_index, data=data\n )\n )\n response.add_annotations(to_add)"} | |
{"id": "llm/default_plugins/openai_models.py:557", "code": " def get_client(self, key, *, async_=False):\n kwargs = {}\n if self.api_base:\n kwargs[\"base_url\"] = self.api_base\n if self.api_type:\n kwargs[\"api_type\"] = self.api_type\n if self.api_version:\n kwargs[\"api_version\"] = self.api_version\n if self.api_engine:\n kwargs[\"engine\"] = self.api_engine\n if self.needs_key:\n kwargs[\"api_key\"] = self.get_key(key)\n else:\n # OpenAI-compatible models don't need a key, but the\n # openai client library requires one\n kwargs[\"api_key\"] = \"DUMMY_KEY\"\n if self.headers:\n kwargs[\"default_headers\"] = self.headers\n if os.environ.get(\"LLM_OPENAI_SHOW_RESPONSES\"):\n kwargs[\"http_client\"] = logging_client()\n if async_:\n return openai.AsyncOpenAI(**kwargs)\n else:\n return openai.OpenAI(**kwargs)"} | |
{"id": "llm/default_plugins/openai_models.py:582", "code": " def build_kwargs(self, prompt, stream):\n kwargs = dict(not_nulls(prompt.options))\n json_object = kwargs.pop(\"json_object\", None)\n if \"max_tokens\" not in kwargs and self.default_max_tokens is not None:\n kwargs[\"max_tokens\"] = self.default_max_tokens\n if json_object:\n kwargs[\"response_format\"] = {\"type\": \"json_object\"}\n if prompt.schema:\n kwargs[\"response_format\"] = {\n \"type\": \"json_schema\",\n \"json_schema\": {\"name\": \"output\", \"schema\": prompt.schema},\n }\n if stream:\n kwargs[\"stream_options\"] = {\"include_usage\": True}\n if self.search_preview:\n kwargs[\"web_search_options\"] = {}\n if prompt.options.search_context_size:\n kwargs.pop(\"search_context_size\", None)\n kwargs[\"web_search_options\"][\n \"search_context_size\"\n ] = prompt.options.search_context_size\n return kwargs"} | |
{"id": "llm/default_plugins/openai_models.py:606", "code": "class Chat(_Shared, KeyModel):\n needs_key = \"openai\"\n key_env_var = \"OPENAI_API_KEY\"\n default_max_tokens = None\n\n class Options(SharedOptions):\n json_object: Optional[bool] = Field(\n description=\"Output a valid JSON object {...}. Prompt must mention JSON.\",\n default=None,\n )\n\n def execute(self, prompt, stream, response, conversation=None, key=None):\n if prompt.system and not self.allows_system_prompt:\n raise NotImplementedError(\"Model does not support system prompts\")\n messages = self.build_messages(prompt, conversation)\n kwargs = self.build_kwargs(prompt, stream)\n client = self.get_client(key)\n usage = None\n annotations = []\n if stream:\n completion = client.chat.completions.create(\n model=self.model_name or self.model_id,\n messages=messages,\n stream=True,\n **kwargs,\n )\n chunks = []\n for chunk in completion:\n chunks.append(chunk)\n try:\n annotations.extend(chunk.choices[0].delta.annotations)\n except (AttributeError, IndexError):\n pass\n if chunk.usage:\n usage = chunk.usage.model_dump()\n try:\n content = chunk.choices[0].delta.content\n except IndexError:\n content = None\n if content is not None:\n yield content\n final_json = remove_dict_none_values(combine_chunks(chunks))\n if annotations:\n final_json[\"annotations\"] = annotations\n self.set_annotations(response, annotations)\n response.response_json = final_json\n else:\n completion = client.chat.completions.create(\n model=self.model_name or self.model_id,\n messages=messages,\n stream=False,\n **kwargs,\n )\n usage = completion.usage.model_dump()\n response.response_json = remove_dict_none_values(completion.model_dump())\n yield completion.choices[0].message.content\n try:\n if completion.choices[0].message.annotations:\n self.set_annotations(\n response, completion.choices[0].message.annotations\n )\n except AttributeError:\n pass\n self.set_usage(response, usage)\n response._prompt_json = redact_data({\"messages\": messages})"} | |
{"id": "llm/default_plugins/openai_models.py:617", "code": " def execute(self, prompt, stream, response, conversation=None, key=None):\n if prompt.system and not self.allows_system_prompt:\n raise NotImplementedError(\"Model does not support system prompts\")\n messages = self.build_messages(prompt, conversation)\n kwargs = self.build_kwargs(prompt, stream)\n client = self.get_client(key)\n usage = None\n annotations = []\n if stream:\n completion = client.chat.completions.create(\n model=self.model_name or self.model_id,\n messages=messages,\n stream=True,\n **kwargs,\n )\n chunks = []\n for chunk in completion:\n chunks.append(chunk)\n try:\n annotations.extend(chunk.choices[0].delta.annotations)\n except (AttributeError, IndexError):\n pass\n if chunk.usage:\n usage = chunk.usage.model_dump()\n try:\n content = chunk.choices[0].delta.content\n except IndexError:\n content = None\n if content is not None:\n yield content\n final_json = remove_dict_none_values(combine_chunks(chunks))\n if annotations:\n final_json[\"annotations\"] = annotations\n self.set_annotations(response, annotations)\n response.response_json = final_json\n else:\n completion = client.chat.completions.create(\n model=self.model_name or self.model_id,\n messages=messages,\n stream=False,\n **kwargs,\n )\n usage = completion.usage.model_dump()\n response.response_json = remove_dict_none_values(completion.model_dump())\n yield completion.choices[0].message.content\n try:\n if completion.choices[0].message.annotations:\n self.set_annotations(\n response, completion.choices[0].message.annotations\n )\n except AttributeError:\n pass\n self.set_usage(response, usage)\n response._prompt_json = redact_data({\"messages\": messages})"} | |
{"id": "llm/default_plugins/openai_models.py:673", "code": "class AsyncChat(_Shared, AsyncKeyModel):\n needs_key = \"openai\"\n key_env_var = \"OPENAI_API_KEY\"\n default_max_tokens = None\n\n class Options(SharedOptions):\n json_object: Optional[bool] = Field(\n description=\"Output a valid JSON object {...}. Prompt must mention JSON.\",\n default=None,\n )\n\n async def execute(\n self, prompt, stream, response, conversation=None, key=None\n ) -> AsyncGenerator[str, None]:\n if prompt.system and not self.allows_system_prompt:\n raise NotImplementedError(\"Model does not support system prompts\")\n messages = self.build_messages(prompt, conversation)\n kwargs = self.build_kwargs(prompt, stream)\n client = self.get_client(key, async_=True)\n usage = None\n if stream:\n completion = await client.chat.completions.create(\n model=self.model_name or self.model_id,\n messages=messages,\n stream=True,\n **kwargs,\n )\n chunks = []\n async for chunk in completion:\n if chunk.usage:\n usage = chunk.usage.model_dump()\n chunks.append(chunk)\n try:\n content = chunk.choices[0].delta.content\n except IndexError:\n content = None\n if content is not None:\n yield content\n response.response_json = remove_dict_none_values(combine_chunks(chunks))\n else:\n completion = await client.chat.completions.create(\n model=self.model_name or self.model_id,\n messages=messages,\n stream=False,\n **kwargs,\n )\n response.response_json = remove_dict_none_values(completion.model_dump())\n usage = completion.usage.model_dump()\n yield completion.choices[0].message.content\n self.set_usage(response, usage)\n response._prompt_json = redact_data({\"messages\": messages})"} | |
{"id": "llm/default_plugins/openai_models.py:684", "code": " async def execute(\n self, prompt, stream, response, conversation=None, key=None\n ) -> AsyncGenerator[str, None]:\n if prompt.system and not self.allows_system_prompt:\n raise NotImplementedError(\"Model does not support system prompts\")\n messages = self.build_messages(prompt, conversation)\n kwargs = self.build_kwargs(prompt, stream)\n client = self.get_client(key, async_=True)\n usage = None\n if stream:\n completion = await client.chat.completions.create(\n model=self.model_name or self.model_id,\n messages=messages,\n stream=True,\n **kwargs,\n )\n chunks = []\n async for chunk in completion:\n if chunk.usage:\n usage = chunk.usage.model_dump()\n chunks.append(chunk)\n try:\n content = chunk.choices[0].delta.content\n except IndexError:\n content = None\n if content is not None:\n yield content\n response.response_json = remove_dict_none_values(combine_chunks(chunks))\n else:\n completion = await client.chat.completions.create(\n model=self.model_name or self.model_id,\n messages=messages,\n stream=False,\n **kwargs,\n )\n response.response_json = remove_dict_none_values(completion.model_dump())\n usage = completion.usage.model_dump()\n yield completion.choices[0].message.content\n self.set_usage(response, usage)\n response._prompt_json = redact_data({\"messages\": messages})"} | |
{"id": "llm/default_plugins/openai_models.py:726", "code": "class Completion(Chat):\n class Options(SharedOptions):\n logprobs: Optional[int] = Field(\n description=\"Include the log probabilities of most likely N per token\",\n default=None,\n le=5,\n )\n\n def __init__(self, *args, default_max_tokens=None, **kwargs):\n super().__init__(*args, **kwargs)\n self.default_max_tokens = default_max_tokens\n\n def __str__(self):\n return \"OpenAI Completion: {}\".format(self.model_id)\n\n def execute(self, prompt, stream, response, conversation=None, key=None):\n if prompt.system:\n raise NotImplementedError(\n \"System prompts are not supported for OpenAI completion models\"\n )\n messages = []\n if conversation is not None:\n for prev_response in conversation.responses:\n messages.append(prev_response.prompt.prompt)\n messages.append(prev_response.text())\n messages.append(prompt.prompt)\n kwargs = self.build_kwargs(prompt, stream)\n client = self.get_client(key)\n if stream:\n completion = client.completions.create(\n model=self.model_name or self.model_id,\n prompt=\"\\n\".join(messages),\n stream=True,\n **kwargs,\n )\n chunks = []\n for chunk in completion:\n chunks.append(chunk)\n try:\n content = chunk.choices[0].text\n except IndexError:\n content = None\n if content is not None:\n yield content\n combined = combine_chunks(chunks)\n cleaned = remove_dict_none_values(combined)\n response.response_json = cleaned\n else:\n completion = client.completions.create(\n model=self.model_name or self.model_id,\n prompt=\"\\n\".join(messages),\n stream=False,\n **kwargs,\n )\n response.response_json = remove_dict_none_values(completion.model_dump())\n yield completion.choices[0].text\n response._prompt_json = redact_data({\"messages\": messages})"} | |
{"id": "llm/default_plugins/openai_models.py:734", "code": " def __init__(self, *args, default_max_tokens=None, **kwargs):\n super().__init__(*args, **kwargs)\n self.default_max_tokens = default_max_tokens"} | |
{"id": "llm/default_plugins/openai_models.py:738", "code": " def __str__(self):\n return \"OpenAI Completion: {}\".format(self.model_id)"} | |
{"id": "llm/default_plugins/openai_models.py:741", "code": " def execute(self, prompt, stream, response, conversation=None, key=None):\n if prompt.system:\n raise NotImplementedError(\n \"System prompts are not supported for OpenAI completion models\"\n )\n messages = []\n if conversation is not None:\n for prev_response in conversation.responses:\n messages.append(prev_response.prompt.prompt)\n messages.append(prev_response.text())\n messages.append(prompt.prompt)\n kwargs = self.build_kwargs(prompt, stream)\n client = self.get_client(key)\n if stream:\n completion = client.completions.create(\n model=self.model_name or self.model_id,\n prompt=\"\\n\".join(messages),\n stream=True,\n **kwargs,\n )\n chunks = []\n for chunk in completion:\n chunks.append(chunk)\n try:\n content = chunk.choices[0].text\n except IndexError:\n content = None\n if content is not None:\n yield content\n combined = combine_chunks(chunks)\n cleaned = remove_dict_none_values(combined)\n response.response_json = cleaned\n else:\n completion = client.completions.create(\n model=self.model_name or self.model_id,\n prompt=\"\\n\".join(messages),\n stream=False,\n **kwargs,\n )\n response.response_json = remove_dict_none_values(completion.model_dump())\n yield completion.choices[0].text\n response._prompt_json = redact_data({\"messages\": messages})"} | |
{"id": "llm/default_plugins/openai_models.py:785", "code": "def not_nulls(data) -> dict:\n return {key: value for key, value in data if value is not None}"} | |
{"id": "llm/default_plugins/openai_models.py:789", "code": "def combine_chunks(chunks: List) -> dict:\n content = \"\"\n role = None\n finish_reason = None\n # If any of them have log probability, we're going to persist\n # those later on\n logprobs = []\n usage = {}\n\n for item in chunks:\n if item.usage:\n usage = item.usage.dict()\n for choice in item.choices:\n if choice.logprobs and hasattr(choice.logprobs, \"top_logprobs\"):\n logprobs.append(\n {\n \"text\": choice.text if hasattr(choice, \"text\") else None,\n \"top_logprobs\": choice.logprobs.top_logprobs,\n }\n )\n\n if not hasattr(choice, \"delta\"):\n content += choice.text\n continue\n role = choice.delta.role\n if choice.delta.content is not None:\n content += choice.delta.content\n if choice.finish_reason is not None:\n finish_reason = choice.finish_reason\n\n # Imitations of the OpenAI API may be missing some of these fields\n combined = {\n \"content\": content,\n \"role\": role,\n \"finish_reason\": finish_reason,\n \"usage\": usage,\n }\n if logprobs:\n combined[\"logprobs\"] = logprobs\n if chunks:\n for key in (\"id\", \"object\", \"model\", \"created\", \"index\"):\n value = getattr(chunks[0], key, None)\n if value is not None:\n combined[key] = value\n\n return combined"} | |
{"id": "llm/default_plugins/openai_models.py:837", "code": "def redact_data(input_dict):\n \"\"\"\n Recursively search through the input dictionary for any 'image_url' keys\n and modify the 'url' value to be just 'data:...'.\n\n Also redact input_audio.data keys\n \"\"\"\n if isinstance(input_dict, dict):\n for key, value in input_dict.items():\n if (\n key == \"image_url\"\n and isinstance(value, dict)\n and \"url\" in value\n and value[\"url\"].startswith(\"data:\")\n ):\n value[\"url\"] = \"data:...\"\n elif key == \"input_audio\" and isinstance(value, dict) and \"data\" in value:\n value[\"data\"] = \"...\"\n else:\n redact_data(value)\n elif isinstance(input_dict, list):\n for item in input_dict:\n redact_data(item)\n return input_dict"} | |
{"id": "tests/test_cli_openai_models.py:7", "code": "@pytest.fixture\ndef mocked_models(httpx_mock):\n httpx_mock.add_response(\n method=\"GET\",\n url=\"https://api.openai.com/v1/models\",\n json={\n \"data\": [\n {\n \"id\": \"ada:2020-05-03\",\n \"object\": \"model\",\n \"created\": 1588537600,\n \"owned_by\": \"openai\",\n },\n {\n \"id\": \"babbage:2020-05-03\",\n \"object\": \"model\",\n \"created\": 1588537600,\n \"owned_by\": \"openai\",\n },\n ]\n },\n headers={\"Content-Type\": \"application/json\"},\n )\n return httpx_mock"} | |
{"id": "tests/test_cli_openai_models.py:33", "code": "def test_openai_models(mocked_models):\n runner = CliRunner()\n result = runner.invoke(cli, [\"openai\", \"models\", \"--key\", \"x\"])\n assert result.exit_code == 0\n assert result.output == (\n \"id owned_by created \\n\"\n \"ada:2020-05-03 openai 2020-05-03T20:26:40+00:00\\n\"\n \"babbage:2020-05-03 openai 2020-05-03T20:26:40+00:00\\n\"\n )"} | |
{"id": "tests/test_cli_openai_models.py:44", "code": "def test_openai_options_min_max():\n options = {\n \"temperature\": [0, 2],\n \"top_p\": [0, 1],\n \"frequency_penalty\": [-2, 2],\n \"presence_penalty\": [-2, 2],\n }\n runner = CliRunner()\n\n for option, [min_val, max_val] in options.items():\n result = runner.invoke(cli, [\"-m\", \"chatgpt\", \"-o\", option, \"-10\"])\n assert result.exit_code == 1\n assert f\"greater than or equal to {min_val}\" in result.output\n result2 = runner.invoke(cli, [\"-m\", \"chatgpt\", \"-o\", option, \"10\"])\n assert result2.exit_code == 1\n assert f\"less than or equal to {max_val}\" in result2.output"} | |
{"id": "tests/test_cli_openai_models.py:62", "code": "@pytest.mark.parametrize(\"model\", (\"gpt-4o-mini\", \"gpt-4o-audio-preview\"))\n@pytest.mark.parametrize(\"filetype\", (\"mp3\", \"wav\"))\ndef test_only_gpt4_audio_preview_allows_mp3_or_wav(httpx_mock, model, filetype):\n httpx_mock.add_response(\n method=\"HEAD\",\n url=f\"https://www.example.com/example.{filetype}\",\n content=b\"binary-data\",\n headers={\"Content-Type\": \"audio/mpeg\" if filetype == \"mp3\" else \"audio/wav\"},\n )\n if model == \"gpt-4o-audio-preview\":\n httpx_mock.add_response(\n method=\"POST\",\n # chat completion request\n url=\"https://api.openai.com/v1/chat/completions\",\n json={\n \"id\": \"chatcmpl-AQT9a30kxEaM1bqxRPepQsPlCyGJh\",\n \"object\": \"chat.completion\",\n \"created\": 1730871958,\n \"model\": \"gpt-4o-audio-preview-2024-10-01\",\n \"choices\": [\n {\n \"index\": 0,\n \"message\": {\n \"role\": \"assistant\",\n \"content\": \"Why did the pelican get kicked out of the restaurant?\\n\\nBecause he had a big bill and no way to pay it!\",\n \"refusal\": None,\n },\n \"finish_reason\": \"stop\",\n }\n ],\n \"usage\": {\n \"prompt_tokens\": 55,\n \"completion_tokens\": 25,\n \"total_tokens\": 80,\n \"prompt_tokens_details\": {\n \"cached_tokens\": 0,\n \"audio_tokens\": 44,\n \"text_tokens\": 11,\n \"image_tokens\": 0,\n },\n \"completion_tokens_details\": {\n \"reasoning_tokens\": 0,\n \"audio_tokens\": 0,\n \"text_tokens\": 25,\n \"accepted_prediction_tokens\": 0,\n \"rejected_prediction_tokens\": 0,\n },\n },\n \"system_fingerprint\": \"fp_49254d0e9b\",\n },\n headers={\"Content-Type\": \"application/json\"},\n )\n httpx_mock.add_response(\n method=\"GET\",\n url=f\"https://www.example.com/example.{filetype}\",\n content=b\"binary-data\",\n headers={\n \"Content-Type\": \"audio/mpeg\" if filetype == \"mp3\" else \"audio/wav\"\n },\n )\n runner = CliRunner()\n result = runner.invoke(\n cli,\n [\n \"-m\",\n model,\n \"-a\",\n f\"https://www.example.com/example.{filetype}\",\n \"--no-stream\",\n \"--key\",\n \"x\",\n ],\n )\n if model == \"gpt-4o-audio-preview\":\n assert result.exit_code == 0\n assert result.output == (\n \"Why did the pelican get kicked out of the restaurant?\\n\\n\"\n \"Because he had a big bill and no way to pay it!\\n\"\n )\n else:\n assert result.exit_code == 1\n long = \"audio/mpeg\" if filetype == \"mp3\" else \"audio/wav\"\n assert (\n f\"This model does not support attachments of type '{long}'\" in result.output\n )"} | |
{"id": "tests/test_cli_openai_models.py:149", "code": "@pytest.mark.parametrize(\"async_\", (False, True))\n@pytest.mark.parametrize(\"usage\", (None, \"-u\", \"--usage\"))\ndef test_gpt4o_mini_sync_and_async(monkeypatch, tmpdir, httpx_mock, async_, usage):\n user_path = tmpdir / \"user_dir\"\n log_db = user_path / \"logs.db\"\n monkeypatch.setenv(\"LLM_USER_PATH\", str(user_path))\n assert not log_db.exists()\n httpx_mock.add_response(\n method=\"POST\",\n # chat completion request\n url=\"https://api.openai.com/v1/chat/completions\",\n json={\n \"id\": \"chatcmpl-AQT9a30kxEaM1bqxRPepQsPlCyGJh\",\n \"object\": \"chat.completion\",\n \"created\": 1730871958,\n \"model\": \"gpt-4o-mini\",\n \"choices\": [\n {\n \"index\": 0,\n \"message\": {\n \"role\": \"assistant\",\n \"content\": \"Ho ho ho\",\n \"refusal\": None,\n },\n \"finish_reason\": \"stop\",\n }\n ],\n \"usage\": {\n \"prompt_tokens\": 1000,\n \"completion_tokens\": 2000,\n \"total_tokens\": 12,\n },\n \"system_fingerprint\": \"fp_49254d0e9b\",\n },\n headers={\"Content-Type\": \"application/json\"},\n )\n runner = CliRunner(mix_stderr=False)\n args = [\"-m\", \"gpt-4o-mini\", \"--key\", \"x\", \"--no-stream\"]\n if usage:\n args.append(usage)\n if async_:\n args.append(\"--async\")\n result = runner.invoke(cli, args, catch_exceptions=False)\n assert result.exit_code == 0\n assert result.output == \"Ho ho ho\\n\"\n if usage:\n assert result.stderr == \"Token usage: 1,000 input, 2,000 output\\n\"\n # Confirm it was correctly logged\n assert log_db.exists()\n db = sqlite_utils.Database(str(log_db))\n assert db[\"responses\"].count == 1\n row = next(db[\"responses\"].rows)\n assert row[\"response\"] == \"Ho ho ho\""} | |
{"id": "tests/test_utils.py:10", "code": "@pytest.mark.parametrize(\n \"input_data,expected_output\",\n [\n (\n {\n \"prompt_tokens_details\": {\"cached_tokens\": 0, \"audio_tokens\": 0},\n \"completion_tokens_details\": {\n \"reasoning_tokens\": 0,\n \"audio_tokens\": 1,\n \"accepted_prediction_tokens\": 0,\n \"rejected_prediction_tokens\": 0,\n },\n },\n {\"completion_tokens_details\": {\"audio_tokens\": 1}},\n ),\n (\n {\n \"details\": {\"tokens\": 5, \"audio_tokens\": 2},\n \"more_details\": {\"accepted_tokens\": 3},\n },\n {\n \"details\": {\"tokens\": 5, \"audio_tokens\": 2},\n \"more_details\": {\"accepted_tokens\": 3},\n },\n ),\n ({\"details\": {\"tokens\": 0, \"audio_tokens\": 0}, \"more_details\": {}}, {}),\n ({\"level1\": {\"level2\": {\"value\": 0, \"another_value\": {}}}}, {}),\n (\n {\n \"level1\": {\"level2\": {\"value\": 0, \"another_value\": 1}},\n \"level3\": {\"empty_dict\": {}, \"valid_token\": 10},\n },\n {\"level1\": {\"level2\": {\"another_value\": 1}}, \"level3\": {\"valid_token\": 10}},\n ),\n ],\n)\ndef test_simplify_usage_dict(input_data, expected_output):\n # This utility function is used by at least one plugin - llm-openai-plugin\n assert simplify_usage_dict(input_data) == expected_output"} | |
{"id": "tests/test_utils.py:51", "code": "@pytest.mark.parametrize(\n \"input,last,expected\",\n [\n [\"This is a sample text without any code blocks.\", False, None],\n [\n \"Here is some text.\\n\\n```\\ndef foo():\\n return 'bar'\\n```\\n\\nMore text.\",\n False,\n \"def foo():\\n return 'bar'\\n\",\n ],\n [\n \"Here is some text.\\n\\n```python\\ndef foo():\\n return 'bar'\\n```\\n\\nMore text.\",\n False,\n \"def foo():\\n return 'bar'\\n\",\n ],\n [\n \"Here is some text.\\n\\n````\\ndef foo():\\n return 'bar'\\n````\\n\\nMore text.\",\n False,\n \"def foo():\\n return 'bar'\\n\",\n ],\n [\n \"Here is some text.\\n\\n````javascript\\nfunction foo() {\\n return 'bar';\\n}\\n````\\n\\nMore text.\",\n False,\n \"function foo() {\\n return 'bar';\\n}\\n\",\n ],\n [\n \"Here is some text.\\n\\n```python\\ndef foo():\\n return 'bar'\\n````\\n\\nMore text.\",\n False,\n None,\n ],\n [\n \"First code block:\\n\\n```python\\ndef foo():\\n return 'bar'\\n```\\n\\n\"\n \"Second code block:\\n\\n```javascript\\nfunction foo() {\\n return 'bar';\\n}\\n```\",\n False,\n \"def foo():\\n return 'bar'\\n\",\n ],\n [\n \"First code block:\\n\\n```python\\ndef foo():\\n return 'bar'\\n```\\n\\n\"\n \"Second code block:\\n\\n```javascript\\nfunction foo() {\\n return 'bar';\\n}\\n```\",\n True,\n \"function foo() {\\n return 'bar';\\n}\\n\",\n ],\n [\n \"First code block:\\n\\n```python\\ndef foo():\\n return 'bar'\\n```\\n\\n\"\n # This one has trailing whitespace after the second code block:\n # https://github.com/simonw/llm/pull/718#issuecomment-2613177036\n \"Second code block:\\n\\n```javascript\\nfunction foo() {\\n return 'bar';\\n}\\n``` \",\n True,\n \"function foo() {\\n return 'bar';\\n}\\n\",\n ],\n [\n \"Here is some text.\\n\\n```python\\ndef foo():\\n return `bar`\\n```\\n\\nMore text.\",\n False,\n \"def foo():\\n return `bar`\\n\",\n ],\n ],\n)\ndef test_extract_fenced_code_block(input, last, expected):\n actual = extract_fenced_code_block(input, last=last)\n assert actual == expected"} | |
{"id": "tests/test_utils.py:112", "code": "@pytest.mark.parametrize(\n \"schema, expected\",\n [\n # Test case 1: Basic comma-separated fields, default string type\n (\n \"name, bio\",\n {\n \"type\": \"object\",\n \"properties\": {\"name\": {\"type\": \"string\"}, \"bio\": {\"type\": \"string\"}},\n \"required\": [\"name\", \"bio\"],\n },\n ),\n # Test case 2: Comma-separated fields with types\n (\n \"name, age int, balance float, active bool\",\n {\n \"type\": \"object\",\n \"properties\": {\n \"name\": {\"type\": \"string\"},\n \"age\": {\"type\": \"integer\"},\n \"balance\": {\"type\": \"number\"},\n \"active\": {\"type\": \"boolean\"},\n },\n \"required\": [\"name\", \"age\", \"balance\", \"active\"],\n },\n ),\n # Test case 3: Comma-separated fields with descriptions\n (\n \"name: full name, age int: years old\",\n {\n \"type\": \"object\",\n \"properties\": {\n \"name\": {\"type\": \"string\", \"description\": \"full name\"},\n \"age\": {\"type\": \"integer\", \"description\": \"years old\"},\n },\n \"required\": [\"name\", \"age\"],\n },\n ),\n # Test case 4: Newline-separated fields\n (\n \"\"\"\n name\n bio\n age int\n \"\"\",\n {\n \"type\": \"object\",\n \"properties\": {\n \"name\": {\"type\": \"string\"},\n \"bio\": {\"type\": \"string\"},\n \"age\": {\"type\": \"integer\"},\n },\n \"required\": [\"name\", \"bio\", \"age\"],\n },\n ),\n # Test case 5: Newline-separated with descriptions containing commas\n (\n \"\"\"\n name: the person's name\n age int: their age in years, must be positive\n bio: a short bio, no more than three sentences\n \"\"\",\n {\n \"type\": \"object\",\n \"properties\": {\n \"name\": {\"type\": \"string\", \"description\": \"the person's name\"},\n \"age\": {\n \"type\": \"integer\",\n \"description\": \"their age in years, must be positive\",\n },\n \"bio\": {\n \"type\": \"string\",\n \"description\": \"a short bio, no more than three sentences\",\n },\n },\n \"required\": [\"name\", \"age\", \"bio\"],\n },\n ),\n # Test case 6: Empty schema\n (\"\", {\"type\": \"object\", \"properties\": {}, \"required\": []}),\n # Test case 7: Explicit string type\n (\n \"name str, description str\",\n {\n \"type\": \"object\",\n \"properties\": {\n \"name\": {\"type\": \"string\"},\n \"description\": {\"type\": \"string\"},\n },\n \"required\": [\"name\", \"description\"],\n },\n ),\n # Test case 8: Extra whitespace\n (\n \" name , age int : person's age \",\n {\n \"type\": \"object\",\n \"properties\": {\n \"name\": {\"type\": \"string\"},\n \"age\": {\"type\": \"integer\", \"description\": \"person's age\"},\n },\n \"required\": [\"name\", \"age\"],\n },\n ),\n ],\n)\ndef test_schema_dsl(schema, expected):\n result = schema_dsl(schema)\n assert result == expected"} | |
{"id": "tests/test_utils.py:223", "code": "def test_schema_dsl_multi():\n result = schema_dsl(\"name, age int: The age\", multi=True)\n assert result == {\n \"type\": \"object\",\n \"properties\": {\n \"items\": {\n \"type\": \"array\",\n \"items\": {\n \"type\": \"object\",\n \"properties\": {\n \"name\": {\"type\": \"string\"},\n \"age\": {\"type\": \"integer\", \"description\": \"The age\"},\n },\n \"required\": [\"name\", \"age\"],\n },\n }\n },\n \"required\": [\"items\"],\n }"} | |
{"id": "tests/test_utils.py:244", "code": "@pytest.mark.parametrize(\n \"text, max_length, normalize_whitespace, keep_end, expected\",\n [\n # Basic truncation tests\n (\"Hello, world!\", 100, False, False, \"Hello, world!\"),\n (\"Hello, world!\", 5, False, False, \"He...\"),\n (\"\", 10, False, False, \"\"),\n (None, 10, False, False, None),\n # Normalize whitespace tests\n (\"Hello world!\", 100, True, False, \"Hello world!\"),\n (\"Hello \\n\\t world!\", 100, True, False, \"Hello world!\"),\n (\"Hello world!\", 5, True, False, \"He...\"),\n # Keep end tests\n (\"Hello, world!\", 10, False, True, \"He... d!\"),\n (\"Hello, world!\", 7, False, False, \"Hell...\"), # Now using regular truncation\n (\"1234567890\", 7, False, False, \"1234...\"), # Now using regular truncation\n # Combinations of parameters\n (\"Hello world!\", 10, True, True, \"He... d!\"),\n # Note: After normalization, \"Hello world!\" is exactly 12 chars, so no truncation\n (\"Hello \\n\\t world!\", 12, True, True, \"Hello world!\"),\n # Edge cases\n (\"12345\", 5, False, False, \"12345\"),\n (\"123456\", 5, False, False, \"12...\"),\n (\"12345\", 5, False, True, \"12345\"), # Unchanged for exact fit\n (\"123456\", 5, False, False, \"12...\"), # Regular truncation for small max_length\n # Very long string\n (\"A\" * 200, 10, False, False, \"AAAAAAA...\"),\n (\"A\" * 200, 10, False, True, \"AA... AA\"), # keep_end with adequate length\n # Exact boundary cases\n (\"123456789\", 9, False, False, \"123456789\"), # Exact fit\n (\"1234567890\", 9, False, False, \"123456...\"), # Simple truncation\n (\"123456789\", 9, False, True, \"123456789\"), # Exact fit with keep_end\n (\"1234567890\", 9, False, True, \"12... 90\"), # keep_end truncation\n # Minimum sensible length tests for keep_end\n (\n \"1234567890\",\n 8,\n False,\n True,\n \"12345...\",\n ), # Too small for keep_end, use regular\n (\"1234567890\", 9, False, True, \"12... 90\"), # Just enough for keep_end\n ],\n)\ndef test_truncate_string(text, max_length, normalize_whitespace, keep_end, expected):\n \"\"\"Test the truncate_string function with various inputs and parameters.\"\"\"\n result = truncate_string(\n text=text,\n max_length=max_length,\n normalize_whitespace=normalize_whitespace,\n keep_end=keep_end,\n )\n assert result == expected"} | |
{"id": "tests/test_utils.py:299", "code": "@pytest.mark.parametrize(\n \"text, max_length, keep_end, prefix_len, expected_full\",\n [\n # Test cases when the length is just right (string fits)\n (\"0123456789\", 10, True, None, \"0123456789\"),\n # Test cases with enough room for the ellipsis\n (\"012345678901234\", 14, True, 4, \"0123... 1234\"),\n # Test cases with different cutoffs\n (\"abcdefghijklmnopqrstuvwxyz\", 10, True, 2, \"ab... yz\"),\n (\"abcdefghijklmnopqrstuvwxyz\", 12, True, 3, \"abc... xyz\"),\n # Test cases below minimum threshold\n (\"abcdefghijklmnopqrstuvwxyz\", 8, True, None, \"abcde...\"),\n ],\n)\ndef test_test_truncate_string_keep_end(\n text, max_length, keep_end, prefix_len, expected_full\n):\n \"\"\"Test the specific behavior of the keep_end parameter.\"\"\"\n result = truncate_string(\n text=text,\n max_length=max_length,\n keep_end=keep_end,\n )\n\n assert result == expected_full\n\n # Only check prefix/suffix when we expect truncation with keep_end\n if prefix_len is not None and len(text) > max_length and max_length >= 9:\n assert result[:prefix_len] == text[:prefix_len]\n assert result[-prefix_len:] == text[-prefix_len:]\n assert \"... \" in result"} | |
{"id": "tests/conftest.py:11", "code": "def pytest_configure(config):\n import sys\n\n sys._called_from_test = True"} | |
{"id": "tests/conftest.py:17", "code": "@pytest.fixture\ndef user_path(tmpdir):\n dir = tmpdir / \"llm.datasette.io\"\n dir.mkdir()\n return dir"} | |
{"id": "tests/conftest.py:24", "code": "@pytest.fixture\ndef logs_db(user_path):\n return sqlite_utils.Database(str(user_path / \"logs.db\"))"} | |
{"id": "tests/conftest.py:29", "code": "@pytest.fixture\ndef user_path_with_embeddings(user_path):\n path = str(user_path / \"embeddings.db\")\n db = sqlite_utils.Database(path)\n collection = llm.Collection(\"demo\", db, model_id=\"embed-demo\")\n collection.embed(\"1\", \"hello world\")\n collection.embed(\"2\", \"goodbye world\")"} | |
{"id": "tests/conftest.py:38", "code": "@pytest.fixture\ndef templates_path(user_path):\n dir = user_path / \"templates\"\n dir.mkdir()\n return dir"} | |
{"id": "tests/conftest.py:45", "code": "@pytest.fixture(autouse=True)\ndef env_setup(monkeypatch, user_path):\n monkeypatch.setenv(\"LLM_USER_PATH\", str(user_path))"} | |
{"id": "tests/conftest.py:50", "code": "class MockModel(llm.Model):\n model_id = \"mock\"\n attachment_types = {\"image/png\", \"audio/wav\"}\n supports_schema = True\n\n class Options(llm.Options):\n max_tokens: Optional[int] = Field(\n description=\"Maximum number of tokens to generate.\", default=None\n )\n\n def __init__(self):\n self.history = []\n self._queue = []\n\n def enqueue(self, messages):\n assert isinstance(messages, list)\n self._queue.append(messages)\n\n def execute(self, prompt, stream, response, conversation):\n self.history.append((prompt, stream, response, conversation))\n gathered = []\n while True:\n try:\n messages = self._queue.pop(0)\n for message in messages:\n gathered.append(message)\n yield message\n break\n except IndexError:\n break\n response.set_usage(\n input=len((prompt.prompt or \"\").split()), output=len(gathered)\n )"} | |
{"id": "tests/conftest.py:60", "code": " def __init__(self):\n self.history = []\n self._queue = []"} | |
{"id": "tests/conftest.py:64", "code": " def enqueue(self, messages):\n assert isinstance(messages, list)\n self._queue.append(messages)"} | |
{"id": "tests/conftest.py:68", "code": " def execute(self, prompt, stream, response, conversation):\n self.history.append((prompt, stream, response, conversation))\n gathered = []\n while True:\n try:\n messages = self._queue.pop(0)\n for message in messages:\n gathered.append(message)\n yield message\n break\n except IndexError:\n break\n response.set_usage(\n input=len((prompt.prompt or \"\").split()), output=len(gathered)\n )"} | |
{"id": "tests/conftest.py:85", "code": "class MockKeyModel(llm.KeyModel):\n model_id = \"mock_key\"\n needs_key = \"mock\"\n\n def execute(self, prompt, stream, response, conversation, key):\n return [f\"key: {key}\"]"} | |
{"id": "tests/conftest.py:89", "code": " def execute(self, prompt, stream, response, conversation, key):\n return [f\"key: {key}\"]"} | |
{"id": "tests/conftest.py:93", "code": "class MockAsyncKeyModel(llm.AsyncKeyModel):\n model_id = \"mock_key\"\n needs_key = \"mock\"\n\n async def execute(self, prompt, stream, response, conversation, key):\n yield f\"async, key: {key}\""} | |
{"id": "tests/conftest.py:97", "code": " async def execute(self, prompt, stream, response, conversation, key):\n yield f\"async, key: {key}\""} | |
{"id": "tests/conftest.py:101", "code": "class AsyncMockModel(llm.AsyncModel):\n model_id = \"mock\"\n supports_schema = True\n\n def __init__(self):\n self.history = []\n self._queue = []\n\n def enqueue(self, messages):\n assert isinstance(messages, list)\n self._queue.append(messages)\n\n async def execute(self, prompt, stream, response, conversation):\n self.history.append((prompt, stream, response, conversation))\n gathered = []\n while True:\n try:\n messages = self._queue.pop(0)\n for message in messages:\n gathered.append(message)\n yield message\n break\n except IndexError:\n break\n response.set_usage(\n input=len((prompt.prompt or \"\").split()), output=len(gathered)\n )"} | |
{"id": "tests/conftest.py:105", "code": " def __init__(self):\n self.history = []\n self._queue = []"} | |
{"id": "tests/conftest.py:109", "code": " def enqueue(self, messages):\n assert isinstance(messages, list)\n self._queue.append(messages)"} | |
{"id": "tests/conftest.py:113", "code": " async def execute(self, prompt, stream, response, conversation):\n self.history.append((prompt, stream, response, conversation))\n gathered = []\n while True:\n try:\n messages = self._queue.pop(0)\n for message in messages:\n gathered.append(message)\n yield message\n break\n except IndexError:\n break\n response.set_usage(\n input=len((prompt.prompt or \"\").split()), output=len(gathered)\n )"} | |
{"id": "tests/conftest.py:130", "code": "class EmbedDemo(llm.EmbeddingModel):\n model_id = \"embed-demo\"\n batch_size = 10\n supports_binary = True\n\n def __init__(self):\n self.embedded_content = []\n\n def embed_batch(self, texts):\n if not hasattr(self, \"batch_count\"):\n self.batch_count = 0\n self.batch_count += 1\n for text in texts:\n self.embedded_content.append(text)\n words = text.split()[:16]\n embedding = [len(word) for word in words]\n # Pad with 0 up to 16 words\n embedding += [0] * (16 - len(embedding))\n yield embedding"} | |
{"id": "tests/conftest.py:135", "code": " def __init__(self):\n self.embedded_content = []"} | |
{"id": "tests/conftest.py:138", "code": " def embed_batch(self, texts):\n if not hasattr(self, \"batch_count\"):\n self.batch_count = 0\n self.batch_count += 1\n for text in texts:\n self.embedded_content.append(text)\n words = text.split()[:16]\n embedding = [len(word) for word in words]\n # Pad with 0 up to 16 words\n embedding += [0] * (16 - len(embedding))\n yield embedding"} | |
{"id": "tests/conftest.py:151", "code": "class EmbedBinaryOnly(EmbedDemo):\n model_id = \"embed-binary-only\"\n supports_text = False\n supports_binary = True"} | |
{"id": "tests/conftest.py:157", "code": "class EmbedTextOnly(EmbedDemo):\n model_id = \"embed-text-only\"\n supports_text = True\n supports_binary = False"} | |
{"id": "tests/conftest.py:163", "code": "@pytest.fixture\ndef embed_demo():\n return EmbedDemo()"} | |
{"id": "tests/conftest.py:168", "code": "@pytest.fixture\ndef mock_model():\n return MockModel()"} | |
{"id": "tests/conftest.py:173", "code": "@pytest.fixture\ndef async_mock_model():\n return AsyncMockModel()"} | |
{"id": "tests/conftest.py:178", "code": "@pytest.fixture\ndef mock_key_model():\n return MockKeyModel()"} | |
{"id": "tests/conftest.py:183", "code": "@pytest.fixture\ndef mock_async_key_model():\n return MockAsyncKeyModel()"} | |
{"id": "tests/conftest.py:188", "code": "@pytest.fixture(autouse=True)\ndef register_embed_demo_model(embed_demo, mock_model, async_mock_model):\n class MockModelsPlugin:\n __name__ = \"MockModelsPlugin\"\n\n @llm.hookimpl\n def register_embedding_models(self, register):\n register(embed_demo)\n register(EmbedBinaryOnly())\n register(EmbedTextOnly())\n\n @llm.hookimpl\n def register_models(self, register):\n register(mock_model, async_model=async_mock_model)\n\n pm.register(MockModelsPlugin(), name=\"undo-mock-models-plugin\")\n try:\n yield\n finally:\n pm.unregister(name=\"undo-mock-models-plugin\")"} | |
{"id": "tests/conftest.py:210", "code": "@pytest.fixture\ndef mocked_openai_chat(httpx_mock):\n httpx_mock.add_response(\n method=\"POST\",\n url=\"https://api.openai.com/v1/chat/completions\",\n json={\n \"model\": \"gpt-4o-mini\",\n \"usage\": {},\n \"choices\": [{\"message\": {\"content\": \"Bob, Alice, Eve\"}}],\n },\n headers={\"Content-Type\": \"application/json\"},\n )\n return httpx_mock"} | |
{"id": "tests/conftest.py:225", "code": "@pytest.fixture\ndef mocked_openai_chat_returning_fenced_code(httpx_mock):\n httpx_mock.add_response(\n method=\"POST\",\n url=\"https://api.openai.com/v1/chat/completions\",\n json={\n \"model\": \"gpt-4o-mini\",\n \"usage\": {},\n \"choices\": [\n {\n \"message\": {\n \"content\": \"Code:\\n\\n````javascript\\nfunction foo() {\\n return 'bar';\\n}\\n````\\nDone.\",\n }\n }\n ],\n },\n headers={\"Content-Type\": \"application/json\"},\n )\n return httpx_mock"} | |
{"id": "tests/conftest.py:246", "code": "def stream_events():\n for delta, finish_reason in (\n ({\"role\": \"assistant\", \"content\": \"\"}, None),\n ({\"content\": \"Hi\"}, None),\n ({\"content\": \".\"}, None),\n ({}, \"stop\"),\n ):\n yield \"data: {}\\n\\n\".format(\n json.dumps(\n {\n \"id\": \"chat-1\",\n \"object\": \"chat.completion.chunk\",\n \"created\": 1695096940,\n \"model\": \"gpt-3.5-turbo-0613\",\n \"choices\": [\n {\"index\": 0, \"delta\": delta, \"finish_reason\": finish_reason}\n ],\n }\n )\n ).encode(\"utf-8\")\n yield \"data: [DONE]\\n\\n\".encode(\"utf-8\")"} | |
{"id": "tests/conftest.py:269", "code": "@pytest.fixture\ndef mocked_openai_chat_stream(httpx_mock):\n httpx_mock.add_response(\n method=\"POST\",\n url=\"https://api.openai.com/v1/chat/completions\",\n stream=IteratorStream(stream_events()),\n headers={\"Content-Type\": \"text/event-stream\"},\n )"} | |
{"id": "tests/conftest.py:279", "code": "@pytest.fixture\ndef mocked_openai_completion(httpx_mock):\n httpx_mock.add_response(\n method=\"POST\",\n url=\"https://api.openai.com/v1/completions\",\n json={\n \"id\": \"cmpl-uqkvlQyYK7bGYrRHQ0eXlWi7\",\n \"object\": \"text_completion\",\n \"created\": 1589478378,\n \"model\": \"gpt-3.5-turbo-instruct\",\n \"choices\": [\n {\n \"text\": \"\\n\\nThis is indeed a test\",\n \"index\": 0,\n \"logprobs\": None,\n \"finish_reason\": \"length\",\n }\n ],\n \"usage\": {\"prompt_tokens\": 5, \"completion_tokens\": 7, \"total_tokens\": 12},\n },\n headers={\"Content-Type\": \"application/json\"},\n )\n return httpx_mock"} | |
{"id": "tests/conftest.py:304", "code": "def stream_completion_events():\n choices_chunks = [\n [\n {\n \"text\": \"\\n\\n\",\n \"index\": 0,\n \"logprobs\": {\n \"tokens\": [\"\\n\\n\"],\n \"token_logprobs\": [-0.6],\n \"top_logprobs\": [{\"\\n\\n\": -0.6, \"\\n\": -1.9}],\n \"text_offset\": [16],\n },\n \"finish_reason\": None,\n }\n ],\n [\n {\n \"text\": \"Hi\",\n \"index\": 0,\n \"logprobs\": {\n \"tokens\": [\"Hi\"],\n \"token_logprobs\": [-1.1],\n \"top_logprobs\": [{\"Hi\": -1.1, \"Hello\": -0.7}],\n \"text_offset\": [18],\n },\n \"finish_reason\": None,\n }\n ],\n [\n {\n \"text\": \".\",\n \"index\": 0,\n \"logprobs\": {\n \"tokens\": [\".\"],\n \"token_logprobs\": [-1.1],\n \"top_logprobs\": [{\".\": -1.1, \"!\": -0.9}],\n \"text_offset\": [20],\n },\n \"finish_reason\": None,\n }\n ],\n [\n {\n \"text\": \"\",\n \"index\": 0,\n \"logprobs\": {\n \"tokens\": [],\n \"token_logprobs\": [],\n \"top_logprobs\": [],\n \"text_offset\": [],\n },\n \"finish_reason\": \"stop\",\n }\n ],\n ]\n\n for choices in choices_chunks:\n yield \"data: {}\\n\\n\".format(\n json.dumps(\n {\n \"id\": \"cmpl-80MdSaou7NnPuff5ZyRMysWBmgSPS\",\n \"object\": \"text_completion\",\n \"created\": 1695097702,\n \"choices\": choices,\n \"model\": \"gpt-3.5-turbo-instruct\",\n }\n )\n ).encode(\"utf-8\")\n yield \"data: [DONE]\\n\\n\".encode(\"utf-8\")"} | |
{"id": "tests/conftest.py:375", "code": "@pytest.fixture\ndef mocked_openai_completion_logprobs_stream(httpx_mock):\n httpx_mock.add_response(\n method=\"POST\",\n url=\"https://api.openai.com/v1/completions\",\n stream=IteratorStream(stream_completion_events()),\n headers={\"Content-Type\": \"text/event-stream\"},\n )\n return httpx_mock"} | |
{"id": "tests/conftest.py:386", "code": "@pytest.fixture\ndef mocked_openai_completion_logprobs(httpx_mock):\n httpx_mock.add_response(\n method=\"POST\",\n url=\"https://api.openai.com/v1/completions\",\n json={\n \"id\": \"cmpl-80MeBfKJutM0uMNJkRrebJLeP3bxL\",\n \"object\": \"text_completion\",\n \"created\": 1695097747,\n \"model\": \"gpt-3.5-turbo-instruct\",\n \"choices\": [\n {\n \"text\": \"\\n\\nHi.\",\n \"index\": 0,\n \"logprobs\": {\n \"tokens\": [\"\\n\\n\", \"Hi\", \"1\"],\n \"token_logprobs\": [-0.6, -1.1, -0.9],\n \"top_logprobs\": [\n {\"\\n\\n\": -0.6, \"\\n\": -1.9},\n {\"Hi\": -1.1, \"Hello\": -0.7},\n {\".\": -0.9, \"!\": -1.1},\n ],\n \"text_offset\": [16, 18, 20],\n },\n \"finish_reason\": \"stop\",\n }\n ],\n \"usage\": {\"prompt_tokens\": 5, \"completion_tokens\": 3, \"total_tokens\": 8},\n },\n headers={\"Content-Type\": \"application/json\"},\n )\n return httpx_mock"} | |
{"id": "tests/conftest.py:420", "code": "@pytest.fixture\ndef mocked_localai(httpx_mock):\n httpx_mock.add_response(\n method=\"POST\",\n url=\"http://localai.localhost/chat/completions\",\n json={\n \"model\": \"orca\",\n \"usage\": {},\n \"choices\": [{\"message\": {\"content\": \"Bob, Alice, Eve\"}}],\n },\n headers={\"Content-Type\": \"application/json\"},\n )\n httpx_mock.add_response(\n method=\"POST\",\n url=\"http://localai.localhost/completions\",\n json={\n \"model\": \"completion-babbage\",\n \"usage\": {},\n \"choices\": [{\"text\": \"Hello\"}],\n },\n headers={\"Content-Type\": \"application/json\"},\n )\n return httpx_mock"} | |
{"id": "tests/conftest.py:445", "code": "@pytest.fixture\ndef collection():\n collection = llm.Collection(\"test\", model_id=\"embed-demo\")\n collection.embed(1, \"hello world\")\n collection.embed(2, \"goodbye world\")\n return collection"} | |
{"id": "tests/test_llm_logs.py:18", "code": "@pytest.fixture\ndef log_path(user_path):\n log_path = str(user_path / \"logs.db\")\n db = sqlite_utils.Database(log_path)\n migrate(db)\n start = datetime.datetime.now(datetime.timezone.utc)\n db[\"responses\"].insert_all(\n {\n \"id\": str(ULID()).lower(),\n \"system\": \"system\",\n \"prompt\": \"prompt\",\n \"response\": 'response\\n```python\\nprint(\"hello word\")\\n```',\n \"model\": \"davinci\",\n \"datetime_utc\": (start + datetime.timedelta(seconds=i)).isoformat(),\n \"conversation_id\": \"abc123\",\n \"input_tokens\": 2,\n \"output_tokens\": 5,\n }\n for i in range(100)\n )\n return log_path"} | |
{"id": "tests/test_llm_logs.py:41", "code": "@pytest.fixture\ndef schema_log_path(user_path):\n log_path = str(user_path / \"logs_schema.db\")\n db = sqlite_utils.Database(log_path)\n migrate(db)\n start = datetime.datetime.now(datetime.timezone.utc)\n db[\"schemas\"].insert({\"id\": SINGLE_ID, \"content\": '{\"name\": \"string\"}'})\n db[\"schemas\"].insert({\"id\": MULTI_ID, \"content\": '{\"name\": \"array\"}'})\n for i in range(2):\n db[\"responses\"].insert(\n {\n \"id\": str(ULID.from_timestamp(time.time() + i)).lower(),\n \"system\": \"system\",\n \"prompt\": \"prompt\",\n \"response\": '{\"name\": \"' + str(i) + '\"}',\n \"model\": \"davinci\",\n \"datetime_utc\": (start + datetime.timedelta(seconds=i)).isoformat(),\n \"conversation_id\": \"abc123\",\n \"input_tokens\": 2,\n \"output_tokens\": 5,\n \"schema_id\": SINGLE_ID,\n }\n )\n for j in range(4):\n db[\"responses\"].insert(\n {\n \"id\": str(ULID.from_timestamp(time.time() + j)).lower(),\n \"system\": \"system\",\n \"prompt\": \"prompt\",\n \"response\": '{\"items\": [{\"name\": \"one\"}, {\"name\": \"two\"}]}',\n \"model\": \"davinci\",\n \"datetime_utc\": (start + datetime.timedelta(seconds=i)).isoformat(),\n \"conversation_id\": \"abc456\",\n \"input_tokens\": 2,\n \"output_tokens\": 5,\n \"schema_id\": MULTI_ID,\n }\n )\n\n return log_path"} | |
{"id": "tests/test_llm_logs.py:87", "code": "@pytest.mark.parametrize(\"usage\", (False, True))\ndef test_logs_text(log_path, usage):\n runner = CliRunner()\n args = [\"logs\", \"-p\", str(log_path)]\n if usage:\n args.append(\"-u\")\n result = runner.invoke(cli, args, catch_exceptions=False)\n assert result.exit_code == 0\n output = result.output\n # Replace 2023-08-17T20:53:58 with YYYY-MM-DDTHH:MM:SS\n output = datetime_re.sub(\"YYYY-MM-DDTHH:MM:SS\", output)\n # Replace id: whatever with id: xxx\n output = id_re.sub(\"id: xxx\", output)\n expected = (\n (\n \"# YYYY-MM-DDTHH:MM:SS conversation: abc123 id: xxx\\n\\n\"\n \"Model: **davinci**\\n\\n\"\n \"## Prompt\\n\\n\"\n \"prompt\\n\\n\"\n \"## System\\n\\n\"\n \"system\\n\\n\"\n \"## Response\\n\\n\"\n 'response\\n```python\\nprint(\"hello word\")\\n```\\n\\n'\n )\n + (\"## Token usage:\\n\\n2 input, 5 output\\n\\n\" if usage else \"\")\n + (\n \"# YYYY-MM-DDTHH:MM:SS conversation: abc123 id: xxx\\n\\n\"\n \"Model: **davinci**\\n\\n\"\n \"## Prompt\\n\\n\"\n \"prompt\\n\\n\"\n \"## Response\\n\\n\"\n 'response\\n```python\\nprint(\"hello word\")\\n```\\n\\n'\n )\n + (\"## Token usage:\\n\\n2 input, 5 output\\n\\n\" if usage else \"\")\n + (\n \"# YYYY-MM-DDTHH:MM:SS conversation: abc123 id: xxx\\n\\n\"\n \"Model: **davinci**\\n\\n\"\n \"## Prompt\\n\\n\"\n \"prompt\\n\\n\"\n \"## Response\\n\\n\"\n 'response\\n```python\\nprint(\"hello word\")\\n```\\n\\n'\n )\n + (\"## Token usage:\\n\\n2 input, 5 output\\n\\n\" if usage else \"\")\n )\n assert output == expected"} | |
{"id": "tests/test_llm_logs.py:134", "code": "@pytest.mark.parametrize(\"n\", (None, 0, 2))\ndef test_logs_json(n, log_path):\n \"Test that logs command correctly returns requested -n records\"\n runner = CliRunner()\n args = [\"logs\", \"-p\", str(log_path), \"--json\"]\n if n is not None:\n args.extend([\"-n\", str(n)])\n result = runner.invoke(cli, args, catch_exceptions=False)\n assert result.exit_code == 0\n logs = json.loads(result.output)\n expected_length = 3\n if n is not None:\n if n == 0:\n expected_length = 100\n else:\n expected_length = n\n assert len(logs) == expected_length"} | |
{"id": "tests/test_llm_logs.py:153", "code": "@pytest.mark.parametrize(\n \"args\", ([\"-r\"], [\"--response\"], [\"list\", \"-r\"], [\"list\", \"--response\"])\n)\ndef test_logs_response_only(args, log_path):\n \"Test that logs -r/--response returns just the last response\"\n runner = CliRunner()\n result = runner.invoke(cli, [\"logs\"] + args, catch_exceptions=False)\n assert result.exit_code == 0\n assert result.output == 'response\\n```python\\nprint(\"hello word\")\\n```\\n'"} | |
{"id": "tests/test_llm_logs.py:164", "code": "@pytest.mark.parametrize(\n \"args\",\n (\n [\"-x\"],\n [\"--extract\"],\n [\"list\", \"-x\"],\n [\"list\", \"--extract\"],\n # Using -xr together should have same effect as just -x\n [\"-xr\"],\n [\"-x\", \"-r\"],\n [\"--extract\", \"--response\"],\n ),\n)\ndef test_logs_extract_first_code(args, log_path):\n \"Test that logs -x/--extract returns the first code block\"\n runner = CliRunner()\n result = runner.invoke(cli, [\"logs\"] + args, catch_exceptions=False)\n assert result.exit_code == 0\n assert result.output == 'print(\"hello word\")\\n\\n'"} | |
{"id": "tests/test_llm_logs.py:185", "code": "@pytest.mark.parametrize(\n \"args\",\n (\n [\"--xl\"],\n [\"--extract-last\"],\n [\"list\", \"--xl\"],\n [\"list\", \"--extract-last\"],\n [\"--xl\", \"-r\"],\n [\"-x\", \"--xl\"],\n ),\n)\ndef test_logs_extract_last_code(args, log_path):\n \"Test that logs --xl/--extract-last returns the last code block\"\n runner = CliRunner()\n result = runner.invoke(cli, [\"logs\"] + args, catch_exceptions=False)\n assert result.exit_code == 0\n assert result.output == 'print(\"hello word\")\\n\\n'"} | |
{"id": "tests/test_llm_logs.py:204", "code": "@pytest.mark.parametrize(\"arg\", (\"-s\", \"--short\"))\n@pytest.mark.parametrize(\"usage\", (None, \"-u\", \"--usage\"))\ndef test_logs_short(log_path, arg, usage):\n runner = CliRunner()\n args = [\"logs\", arg, \"-p\", str(log_path)]\n if usage:\n args.append(usage)\n result = runner.invoke(cli, args)\n assert result.exit_code == 0\n output = datetime_re.sub(\"YYYY-MM-DDTHH:MM:SS\", result.output)\n expected_usage = \"\"\n if usage:\n expected_usage = \" usage:\\n input: 2\\n output: 5\\n\"\n expected = (\n \"- model: davinci\\n\"\n \" datetime: 'YYYY-MM-DDTHH:MM:SS'\\n\"\n \" conversation: abc123\\n\"\n \" system: system\\n\"\n f\" prompt: prompt\\n{expected_usage}\"\n \"- model: davinci\\n\"\n \" datetime: 'YYYY-MM-DDTHH:MM:SS'\\n\"\n \" conversation: abc123\\n\"\n \" system: system\\n\"\n f\" prompt: prompt\\n{expected_usage}\"\n \"- model: davinci\\n\"\n \" datetime: 'YYYY-MM-DDTHH:MM:SS'\\n\"\n \" conversation: abc123\\n\"\n \" system: system\\n\"\n f\" prompt: prompt\\n{expected_usage}\"\n )\n assert output == expected"} | |
{"id": "tests/test_llm_logs.py:237", "code": "@pytest.mark.xfail(sys.platform == \"win32\", reason=\"Expected to fail on Windows\")\n@pytest.mark.parametrize(\"env\", ({}, {\"LLM_USER_PATH\": \"/tmp/llm-user-path\"}))\ndef test_logs_path(monkeypatch, env, user_path):\n for key, value in env.items():\n monkeypatch.setenv(key, value)\n runner = CliRunner()\n result = runner.invoke(cli, [\"logs\", \"path\"])\n assert result.exit_code == 0\n if env:\n expected = env[\"LLM_USER_PATH\"] + \"/logs.db\"\n else:\n expected = str(user_path) + \"/logs.db\"\n assert result.output.strip() == expected"} | |
{"id": "tests/test_llm_logs.py:252", "code": "@pytest.mark.parametrize(\"model\", (\"davinci\", \"curie\"))\ndef test_logs_filtered(user_path, model):\n log_path = str(user_path / \"logs.db\")\n db = sqlite_utils.Database(log_path)\n migrate(db)\n db[\"responses\"].insert_all(\n {\n \"id\": str(ULID()).lower(),\n \"system\": \"system\",\n \"prompt\": \"prompt\",\n \"response\": \"response\",\n \"model\": \"davinci\" if i % 2 == 0 else \"curie\",\n }\n for i in range(100)\n )\n runner = CliRunner()\n result = runner.invoke(cli, [\"logs\", \"list\", \"-m\", model, \"--json\"])\n assert result.exit_code == 0\n records = json.loads(result.output.strip())\n assert all(record[\"model\"] == model for record in records)"} | |
{"id": "tests/test_llm_logs.py:274", "code": "@pytest.mark.parametrize(\n \"query,extra_args,expected\",\n (\n # With no search term order should be by datetime\n (\"\", [], [\"doc1\", \"doc2\", \"doc3\"]),\n # With a search it's order by rank instead\n (\"llama\", [], [\"doc1\", \"doc3\"]),\n (\"alpaca\", [], [\"doc2\"]),\n # Model filter should work too\n (\"llama\", [\"-m\", \"davinci\"], [\"doc1\", \"doc3\"]),\n (\"llama\", [\"-m\", \"davinci2\"], []),\n ),\n)\ndef test_logs_search(user_path, query, extra_args, expected):\n log_path = str(user_path / \"logs.db\")\n db = sqlite_utils.Database(log_path)\n migrate(db)\n\n def _insert(id, text):\n db[\"responses\"].insert(\n {\n \"id\": id,\n \"system\": \"system\",\n \"prompt\": text,\n \"response\": \"response\",\n \"model\": \"davinci\",\n }\n )\n\n _insert(\"doc1\", \"llama\")\n _insert(\"doc2\", \"alpaca\")\n _insert(\"doc3\", \"llama llama\")\n runner = CliRunner()\n result = runner.invoke(cli, [\"logs\", \"list\", \"-q\", query, \"--json\"] + extra_args)\n assert result.exit_code == 0\n records = json.loads(result.output.strip())\n assert [record[\"id\"] for record in records] == expected"} | |
{"id": "tests/test_llm_logs.py:313", "code": "@pytest.mark.parametrize(\n \"args,expected\",\n (\n ([\"--data\", \"--schema\", SINGLE_ID], '{\"name\": \"1\"}\\n{\"name\": \"0\"}\\n'),\n (\n [\"--data\", \"--schema\", MULTI_ID],\n (\n '{\"items\": [{\"name\": \"one\"}, {\"name\": \"two\"}]}\\n'\n '{\"items\": [{\"name\": \"one\"}, {\"name\": \"two\"}]}\\n'\n '{\"items\": [{\"name\": \"one\"}, {\"name\": \"two\"}]}\\n'\n '{\"items\": [{\"name\": \"one\"}, {\"name\": \"two\"}]}\\n'\n ),\n ),\n (\n [\"--data-array\", \"--schema\", MULTI_ID],\n (\n '[{\"items\": [{\"name\": \"one\"}, {\"name\": \"two\"}]},\\n'\n ' {\"items\": [{\"name\": \"one\"}, {\"name\": \"two\"}]},\\n'\n ' {\"items\": [{\"name\": \"one\"}, {\"name\": \"two\"}]},\\n'\n ' {\"items\": [{\"name\": \"one\"}, {\"name\": \"two\"}]}]\\n'\n ),\n ),\n (\n [\"--schema\", MULTI_ID, \"--data-key\", \"items\"],\n (\n '{\"name\": \"one\"}\\n'\n '{\"name\": \"two\"}\\n'\n '{\"name\": \"one\"}\\n'\n '{\"name\": \"two\"}\\n'\n '{\"name\": \"one\"}\\n'\n '{\"name\": \"two\"}\\n'\n '{\"name\": \"one\"}\\n'\n '{\"name\": \"two\"}\\n'\n ),\n ),\n ),\n)\ndef test_logs_schema(schema_log_path, args, expected):\n runner = CliRunner()\n result = runner.invoke(\n cli,\n [\"logs\", \"-n\", \"0\", \"-p\", str(schema_log_path)] + args,\n catch_exceptions=False,\n )\n assert result.exit_code == 0\n assert result.output == expected"} | |
{"id": "tests/test_llm_logs.py:361", "code": "def test_logs_schema_data_ids(schema_log_path):\n db = sqlite_utils.Database(schema_log_path)\n ulid = ULID.from_timestamp(time.time() + 100)\n db[\"responses\"].insert(\n {\n \"id\": str(ulid).lower(),\n \"system\": \"system\",\n \"prompt\": \"prompt\",\n \"response\": json.dumps(\n {\n \"name\": \"three\",\n \"response_id\": 1,\n \"conversation_id\": 2,\n \"conversation_id_\": 3,\n }\n ),\n \"model\": \"davinci\",\n \"datetime_utc\": ulid.datetime.isoformat(),\n \"conversation_id\": \"abc123\",\n \"input_tokens\": 2,\n \"output_tokens\": 5,\n \"schema_id\": SINGLE_ID,\n }\n )\n runner = CliRunner()\n result = runner.invoke(\n cli,\n [\n \"logs\",\n \"-n\",\n \"0\",\n \"-p\",\n str(schema_log_path),\n \"--data-ids\",\n \"--data-key\",\n \"items\",\n \"--data-array\",\n ],\n catch_exceptions=False,\n )\n assert result.exit_code == 0\n rows = json.loads(result.output)\n last_row = rows.pop(-1)\n assert set(last_row.keys()) == {\n \"conversation_id_\",\n \"conversation_id\",\n \"response_id\",\n \"response_id_\",\n \"name\",\n \"conversation_id__\",\n }\n for row in rows:\n assert set(row.keys()) == {\"conversation_id\", \"response_id\", \"name\"}"} | |
{"id": "tests/test_plugins.py:8", "code": "def test_register_commands():\n importlib.reload(cli)\n\n def plugin_names():\n return [plugin[\"name\"] for plugin in llm.get_plugins()]\n\n assert \"HelloWorldPlugin\" not in plugin_names()\n\n class HelloWorldPlugin:\n __name__ = \"HelloWorldPlugin\"\n\n @hookimpl\n def register_commands(self, cli):\n @cli.command(name=\"hello-world\")\n def hello_world():\n \"Print hello world\"\n click.echo(\"Hello world!\")\n\n try:\n plugins.pm.register(HelloWorldPlugin(), name=\"HelloWorldPlugin\")\n importlib.reload(cli)\n\n assert \"HelloWorldPlugin\" in plugin_names()\n\n runner = CliRunner()\n result = runner.invoke(cli.cli, [\"hello-world\"])\n assert result.exit_code == 0\n assert result.output == \"Hello world!\\n\"\n\n finally:\n plugins.pm.unregister(name=\"HelloWorldPlugin\")\n importlib.reload(cli)\n assert \"HelloWorldPlugin\" not in plugin_names()"} | |
{"id": "tests/test_plugins.py:43", "code": "def test_register_template_loaders():\n assert get_template_loaders() == {}\n\n def one_loader(template_path):\n return llm.Template(name=\"one:\" + template_path, prompt=template_path)\n\n def two_loader(template_path):\n \"Docs for two\"\n return llm.Template(name=\"two:\" + template_path, prompt=template_path)\n\n def dupe_two_loader(template_path):\n \"Docs for two dupe\"\n return llm.Template(name=\"two:\" + template_path, prompt=template_path)\n\n class TemplateLoadersPlugin:\n __name__ = \"TemplateLoadersPlugin\"\n\n @hookimpl\n def register_template_loaders(self, register):\n register(\"one\", one_loader)\n register(\"two\", two_loader)\n register(\"two\", dupe_two_loader)\n\n try:\n plugins.pm.register(TemplateLoadersPlugin(), name=\"TemplateLoadersPlugin\")\n loaders = get_template_loaders()\n assert loaders == {\n \"one\": one_loader,\n \"two\": two_loader,\n \"two_1\": dupe_two_loader,\n }\n\n # Test the CLI command\n runner = CliRunner()\n result = runner.invoke(cli.cli, [\"templates\", \"loaders\"])\n assert result.exit_code == 0\n assert result.output == (\n \"one:\\n\"\n \" Undocumented\\n\"\n \"two:\\n\"\n \" Docs for two\\n\"\n \"two_1:\\n\"\n \" Docs for two dupe\\n\"\n )\n\n finally:\n plugins.pm.unregister(name=\"TemplateLoadersPlugin\")\n assert get_template_loaders() == {}"} | |
{"id": "tests/test_llm.py:14", "code": "def test_version():\n runner = CliRunner()\n with runner.isolated_filesystem():\n result = runner.invoke(cli, [\"--version\"])\n assert result.exit_code == 0\n assert result.output.startswith(\"cli, version \")"} | |
{"id": "tests/test_llm.py:22", "code": "def test_llm_prompt_creates_log_database(mocked_openai_chat, tmpdir, monkeypatch):\n user_path = tmpdir / \"user\"\n monkeypatch.setenv(\"LLM_USER_PATH\", str(user_path))\n runner = CliRunner()\n result = runner.invoke(\n cli,\n [\"three names \\nfor a pet pelican\", \"--no-stream\", \"--key\", \"x\"],\n catch_exceptions=False,\n )\n assert result.exit_code == 0\n assert result.output == \"Bob, Alice, Eve\\n\"\n # Should have created user_path and put a logs.db in it\n assert (user_path / \"logs.db\").exists()\n assert sqlite_utils.Database(str(user_path / \"logs.db\"))[\"responses\"].count == 1"} | |
{"id": "tests/test_llm.py:38", "code": "@mock.patch.dict(os.environ, {\"OPENAI_API_KEY\": \"X\"})\n@pytest.mark.parametrize(\"use_stdin\", (True, False, \"split\"))\n@pytest.mark.parametrize(\n \"logs_off,logs_args,should_log\",\n (\n (True, [], False),\n (False, [], True),\n (False, [\"--no-log\"], False),\n (False, [\"--log\"], True),\n (True, [\"-n\"], False), # Short for --no-log\n (True, [\"--log\"], True),\n ),\n)\ndef test_llm_default_prompt(\n mocked_openai_chat, use_stdin, user_path, logs_off, logs_args, should_log\n):\n # Reset the log_path database\n log_path = user_path / \"logs.db\"\n log_db = sqlite_utils.Database(str(log_path))\n log_db[\"responses\"].delete_where()\n\n logs_off_path = user_path / \"logs-off\"\n if logs_off:\n # Turn off logging\n assert not logs_off_path.exists()\n CliRunner().invoke(cli, [\"logs\", \"off\"])\n assert logs_off_path.exists()\n else:\n # Turn on logging\n CliRunner().invoke(cli, [\"logs\", \"on\"])\n assert not logs_off_path.exists()\n\n # Run the prompt\n runner = CliRunner()\n prompt = \"three names \\nfor a pet pelican\"\n input = None\n args = [\"--no-stream\"]\n if use_stdin == \"split\":\n input = \"three names\"\n args.append(\"\\nfor a pet pelican\")\n elif use_stdin:\n input = prompt\n else:\n args.append(prompt)\n args += logs_args\n result = runner.invoke(cli, args, input=input, catch_exceptions=False)\n assert result.exit_code == 0\n assert result.output == \"Bob, Alice, Eve\\n\"\n last_request = mocked_openai_chat.get_requests()[-1]\n assert last_request.headers[\"Authorization\"] == \"Bearer X\"\n\n # Was it logged?\n rows = list(log_db[\"responses\"].rows)\n\n if not should_log:\n assert len(rows) == 0\n return\n\n assert len(rows) == 1\n expected = {\n \"model\": \"gpt-4o-mini\",\n \"prompt\": \"three names \\nfor a pet pelican\",\n \"system\": None,\n \"options_json\": \"{}\",\n \"response\": \"Bob, Alice, Eve\",\n }\n row = rows[0]\n assert expected.items() <= row.items()\n assert isinstance(row[\"duration_ms\"], int)\n assert isinstance(row[\"datetime_utc\"], str)\n assert json.loads(row[\"prompt_json\"]) == {\n \"messages\": [{\"role\": \"user\", \"content\": \"three names \\nfor a pet pelican\"}]\n }\n assert json.loads(row[\"response_json\"]) == {\n \"model\": \"gpt-4o-mini\",\n \"choices\": [{\"message\": {\"content\": \"Bob, Alice, Eve\"}}],\n }\n\n # Test \"llm logs\"\n log_result = runner.invoke(\n cli, [\"logs\", \"-n\", \"1\", \"--json\"], catch_exceptions=False\n )\n log_json = json.loads(log_result.output)\n\n # Should have logged correctly:\n assert (\n log_json[0].items()\n >= {\n \"model\": \"gpt-4o-mini\",\n \"prompt\": \"three names \\nfor a pet pelican\",\n \"system\": None,\n \"prompt_json\": {\n \"messages\": [\n {\"role\": \"user\", \"content\": \"three names \\nfor a pet pelican\"}\n ]\n },\n \"options_json\": {},\n \"response\": \"Bob, Alice, Eve\",\n \"response_json\": {\n \"model\": \"gpt-4o-mini\",\n \"choices\": [{\"message\": {\"content\": \"Bob, Alice, Eve\"}}],\n },\n # This doesn't have the \\n after three names:\n \"conversation_name\": \"three names for a pet pelican\",\n \"conversation_model\": \"gpt-4o-mini\",\n }.items()\n )"} | |
{"id": "tests/test_llm.py:147", "code": "@mock.patch.dict(os.environ, {\"OPENAI_API_KEY\": \"X\"})\n@pytest.mark.parametrize(\"async_\", (False, True))\ndef test_llm_prompt_continue(httpx_mock, user_path, async_):\n httpx_mock.add_response(\n method=\"POST\",\n url=\"https://api.openai.com/v1/chat/completions\",\n json={\n \"model\": \"gpt-4o-mini\",\n \"usage\": {},\n \"choices\": [{\"message\": {\"content\": \"Bob, Alice, Eve\"}}],\n },\n headers={\"Content-Type\": \"application/json\"},\n )\n httpx_mock.add_response(\n method=\"POST\",\n url=\"https://api.openai.com/v1/chat/completions\",\n json={\n \"model\": \"gpt-4o-mini\",\n \"usage\": {},\n \"choices\": [{\"message\": {\"content\": \"Terry\"}}],\n },\n headers={\"Content-Type\": \"application/json\"},\n )\n\n log_path = user_path / \"logs.db\"\n log_db = sqlite_utils.Database(str(log_path))\n log_db[\"responses\"].delete_where()\n\n # First prompt\n runner = CliRunner()\n args = [\"three names \\nfor a pet pelican\", \"--no-stream\"] + (\n [\"--async\"] if async_ else []\n )\n result = runner.invoke(cli, args, catch_exceptions=False)\n assert result.exit_code == 0, result.output\n assert result.output == \"Bob, Alice, Eve\\n\"\n\n # Should be logged\n rows = list(log_db[\"responses\"].rows)\n assert len(rows) == 1\n\n # Now ask a follow-up\n args2 = [\"one more\", \"-c\", \"--no-stream\"] + ([\"--async\"] if async_ else [])\n result2 = runner.invoke(cli, args2, catch_exceptions=False)\n assert result2.exit_code == 0, result2.output\n assert result2.output == \"Terry\\n\"\n\n rows = list(log_db[\"responses\"].rows)\n assert len(rows) == 2"} | |
{"id": "tests/test_llm.py:198", "code": "@pytest.mark.parametrize(\n \"args,expect_just_code\",\n (\n ([\"-x\"], True),\n ([\"--extract\"], True),\n ([\"-x\", \"--async\"], True),\n ([\"--extract\", \"--async\"], True),\n # Use --no-stream here to ensure it passes test same as -x/--extract cases\n ([\"--no-stream\"], False),\n ),\n)\ndef test_extract_fenced_code(\n mocked_openai_chat_returning_fenced_code, args, expect_just_code\n):\n runner = CliRunner()\n result = runner.invoke(\n cli,\n [\"-m\", \"gpt-4o-mini\", \"--key\", \"x\", \"Write code\"] + args,\n catch_exceptions=False,\n )\n output = result.output\n if expect_just_code:\n assert \"```\" not in output\n else:\n assert \"```\" in output"} | |
{"id": "tests/test_llm.py:225", "code": "def test_openai_chat_stream(mocked_openai_chat_stream, user_path):\n runner = CliRunner()\n result = runner.invoke(cli, [\"-m\", \"gpt-3.5-turbo\", \"--key\", \"x\", \"Say hi\"])\n assert result.exit_code == 0\n assert result.output == \"Hi.\\n\""} | |
{"id": "tests/test_llm.py:232", "code": "def test_openai_completion(mocked_openai_completion, user_path):\n log_path = user_path / \"logs.db\"\n log_db = sqlite_utils.Database(str(log_path))\n log_db[\"responses\"].delete_where()\n runner = CliRunner()\n result = runner.invoke(\n cli,\n [\n \"-m\",\n \"gpt-3.5-turbo-instruct\",\n \"Say this is a test\",\n \"--no-stream\",\n \"--key\",\n \"x\",\n ],\n catch_exceptions=False,\n )\n assert result.exit_code == 0\n assert result.output == \"\\n\\nThis is indeed a test\\n\"\n\n # Should have requested 256 tokens\n last_request = mocked_openai_completion.get_requests()[-1]\n assert json.loads(last_request.content) == {\n \"model\": \"gpt-3.5-turbo-instruct\",\n \"prompt\": \"Say this is a test\",\n \"stream\": False,\n \"max_tokens\": 256,\n }\n\n # Check it was logged\n rows = list(log_db[\"responses\"].rows)\n assert len(rows) == 1\n expected = {\n \"model\": \"gpt-3.5-turbo-instruct\",\n \"prompt\": \"Say this is a test\",\n \"system\": None,\n \"prompt_json\": '{\"messages\": [\"Say this is a test\"]}',\n \"options_json\": \"{}\",\n \"response\": \"\\n\\nThis is indeed a test\",\n }\n row = rows[0]\n assert expected.items() <= row.items()"} | |
{"id": "tests/test_llm.py:276", "code": "def test_openai_completion_system_prompt_error():\n runner = CliRunner()\n result = runner.invoke(\n cli,\n [\n \"-m\",\n \"gpt-3.5-turbo-instruct\",\n \"Say this is a test\",\n \"--no-stream\",\n \"--key\",\n \"x\",\n \"--system\",\n \"system prompts not allowed\",\n ],\n catch_exceptions=False,\n )\n assert result.exit_code == 1\n assert (\n result.output\n == \"Error: System prompts are not supported for OpenAI completion models\\n\"\n )"} | |
{"id": "tests/test_llm.py:299", "code": "def test_openai_completion_logprobs_stream(\n mocked_openai_completion_logprobs_stream, user_path\n):\n log_path = user_path / \"logs.db\"\n log_db = sqlite_utils.Database(str(log_path))\n log_db[\"responses\"].delete_where()\n runner = CliRunner()\n args = [\n \"-m\",\n \"gpt-3.5-turbo-instruct\",\n \"Say hi\",\n \"-o\",\n \"logprobs\",\n \"2\",\n \"--key\",\n \"x\",\n ]\n result = runner.invoke(cli, args, catch_exceptions=False)\n assert result.exit_code == 0\n assert result.output == \"\\n\\nHi.\\n\"\n rows = list(log_db[\"responses\"].rows)\n assert len(rows) == 1\n row = rows[0]\n assert json.loads(row[\"response_json\"]) == {\n \"content\": \"\\n\\nHi.\",\n \"logprobs\": [\n {\"text\": \"\\n\\n\", \"top_logprobs\": [{\"\\n\\n\": -0.6, \"\\n\": -1.9}]},\n {\"text\": \"Hi\", \"top_logprobs\": [{\"Hi\": -1.1, \"Hello\": -0.7}]},\n {\"text\": \".\", \"top_logprobs\": [{\".\": -1.1, \"!\": -0.9}]},\n {\"text\": \"\", \"top_logprobs\": []},\n ],\n \"id\": \"cmpl-80MdSaou7NnPuff5ZyRMysWBmgSPS\",\n \"object\": \"text_completion\",\n \"model\": \"gpt-3.5-turbo-instruct\",\n \"created\": 1695097702,\n }"} | |
{"id": "tests/test_llm.py:337", "code": "def test_openai_completion_logprobs_nostream(\n mocked_openai_completion_logprobs, user_path\n):\n log_path = user_path / \"logs.db\"\n log_db = sqlite_utils.Database(str(log_path))\n log_db[\"responses\"].delete_where()\n runner = CliRunner()\n args = [\n \"-m\",\n \"gpt-3.5-turbo-instruct\",\n \"Say hi\",\n \"-o\",\n \"logprobs\",\n \"2\",\n \"--key\",\n \"x\",\n \"--no-stream\",\n ]\n result = runner.invoke(cli, args, catch_exceptions=False)\n assert result.exit_code == 0\n assert result.output == \"\\n\\nHi.\\n\"\n rows = list(log_db[\"responses\"].rows)\n assert len(rows) == 1\n row = rows[0]\n assert json.loads(row[\"response_json\"]) == {\n \"choices\": [\n {\n \"finish_reason\": \"stop\",\n \"index\": 0,\n \"logprobs\": {\n \"text_offset\": [16, 18, 20],\n \"token_logprobs\": [-0.6, -1.1, -0.9],\n \"tokens\": [\"\\n\\n\", \"Hi\", \"1\"],\n \"top_logprobs\": [\n {\"\\n\": -1.9, \"\\n\\n\": -0.6},\n {\"Hello\": -0.7, \"Hi\": -1.1},\n {\"!\": -1.1, \".\": -0.9},\n ],\n },\n \"text\": \"\\n\\nHi.\",\n }\n ],\n \"created\": 1695097747,\n \"id\": \"cmpl-80MeBfKJutM0uMNJkRrebJLeP3bxL\",\n \"model\": \"gpt-3.5-turbo-instruct\",\n \"object\": \"text_completion\",\n \"usage\": {\"completion_tokens\": 3, \"prompt_tokens\": 5, \"total_tokens\": 8},\n }"} | |
{"id": "tests/test_llm.py:398", "code": "def test_openai_localai_configuration(mocked_localai, user_path):\n log_path = user_path / \"logs.db\"\n sqlite_utils.Database(str(log_path))\n # Write the configuration file\n config_path = user_path / \"extra-openai-models.yaml\"\n config_path.write_text(EXTRA_MODELS_YAML, \"utf-8\")\n # Run the prompt\n runner = CliRunner()\n prompt = \"three names \\nfor a pet pelican\"\n result = runner.invoke(cli, [\"--no-stream\", \"--model\", \"orca\", prompt])\n assert result.exit_code == 0\n assert result.output == \"Bob, Alice, Eve\\n\"\n last_request = mocked_localai.get_requests()[-1]\n assert json.loads(last_request.content) == {\n \"model\": \"orca-mini-3b\",\n \"messages\": [{\"role\": \"user\", \"content\": \"three names \\nfor a pet pelican\"}],\n \"stream\": False,\n }\n # And check the completion model too\n result2 = runner.invoke(cli, [\"--no-stream\", \"--model\", \"completion-babbage\", \"hi\"])\n assert result2.exit_code == 0\n assert result2.output == \"Hello\\n\"\n last_request2 = mocked_localai.get_requests()[-1]\n assert json.loads(last_request2.content) == {\n \"model\": \"babbage\",\n \"prompt\": \"hi\",\n \"stream\": False,\n }"} | |
{"id": "tests/test_llm.py:428", "code": "@pytest.mark.parametrize(\n \"args,exit_code\",\n (\n ([\"-q\", \"mo\", \"-q\", \"ck\"], 0),\n ([\"-q\", \"mock\"], 0),\n ([\"-q\", \"badmodel\"], 1),\n ([\"-q\", \"mock\", \"-q\", \"badmodel\"], 1),\n ),\n)\ndef test_prompt_select_model_with_queries(mock_model, user_path, args, exit_code):\n runner = CliRunner()\n result = runner.invoke(\n cli,\n args + [\"hello\"],\n catch_exceptions=False,\n )\n assert result.exit_code == exit_code"} | |
{"id": "tests/test_llm.py:484", "code": "def test_llm_models_options(user_path):\n runner = CliRunner()\n result = runner.invoke(cli, [\"models\", \"--options\"], catch_exceptions=False)\n assert result.exit_code == 0\n assert EXPECTED_OPTIONS.strip() in result.output\n assert \"AsyncMockModel (async): mock\" not in result.output"} | |
{"id": "tests/test_llm.py:492", "code": "def test_llm_models_async(user_path):\n runner = CliRunner()\n result = runner.invoke(cli, [\"models\", \"--async\"], catch_exceptions=False)\n assert result.exit_code == 0\n assert \"AsyncMockModel (async): mock\" in result.output"} | |
{"id": "tests/test_llm.py:499", "code": "@pytest.mark.parametrize(\n \"args,expected_model_ids,unexpected_model_ids\",\n (\n ([\"-q\", \"gpt-4o\"], [\"OpenAI Chat: gpt-4o\"], None),\n ([\"-q\", \"mock\"], [\"MockModel: mock\"], None),\n ([\"--query\", \"mock\"], [\"MockModel: mock\"], None),\n (\n [\"-q\", \"4o\", \"-q\", \"mini\"],\n [\"OpenAI Chat: gpt-4o-mini\"],\n [\"OpenAI Chat: gpt-4o \"],\n ),\n (\n [\"-m\", \"gpt-4o-mini\", \"-m\", \"gpt-4.5\"],\n [\"OpenAI Chat: gpt-4o-mini\", \"OpenAI Chat: gpt-4.5\"],\n [\"OpenAI Chat: gpt-4o \"],\n ),\n ),\n)\ndef test_llm_models_filter(user_path, args, expected_model_ids, unexpected_model_ids):\n runner = CliRunner()\n result = runner.invoke(cli, [\"models\"] + args, catch_exceptions=False)\n assert result.exit_code == 0\n if expected_model_ids:\n for expected_model_id in expected_model_ids:\n assert expected_model_id in result.output\n if unexpected_model_ids:\n for unexpected_model_id in unexpected_model_ids:\n assert unexpected_model_id not in result.output"} | |
{"id": "tests/test_llm.py:529", "code": "def test_llm_user_dir(tmpdir, monkeypatch):\n user_dir = str(tmpdir / \"u\")\n monkeypatch.setenv(\"LLM_USER_PATH\", user_dir)\n assert not os.path.exists(user_dir)\n user_dir2 = llm.user_dir()\n assert user_dir == str(user_dir2)\n assert os.path.exists(user_dir)"} | |
{"id": "tests/test_llm.py:538", "code": "def test_model_defaults(tmpdir, monkeypatch):\n user_dir = str(tmpdir / \"u\")\n monkeypatch.setenv(\"LLM_USER_PATH\", user_dir)\n config_path = pathlib.Path(user_dir) / \"default_model.txt\"\n assert not config_path.exists()\n assert llm.get_default_model() == \"gpt-4o-mini\"\n assert llm.get_model().model_id == \"gpt-4o-mini\"\n llm.set_default_model(\"gpt-4o\")\n assert config_path.exists()\n assert llm.get_default_model() == \"gpt-4o\"\n assert llm.get_model().model_id == \"gpt-4o\""} | |
{"id": "tests/test_llm.py:551", "code": "def test_get_models():\n models = llm.get_models()\n assert all(isinstance(model, (llm.Model, llm.KeyModel)) for model in models)\n model_ids = [model.model_id for model in models]\n assert \"gpt-4o-mini\" in model_ids\n # Ensure no model_ids are duplicated\n # https://github.com/simonw/llm/issues/667\n assert len(model_ids) == len(set(model_ids))"} | |
{"id": "tests/test_llm.py:561", "code": "def test_get_async_models():\n models = llm.get_async_models()\n assert all(\n isinstance(model, (llm.AsyncModel, llm.AsyncKeyModel)) for model in models\n )\n model_ids = [model.model_id for model in models]\n assert \"gpt-4o-mini\" in model_ids"} | |
{"id": "tests/test_llm.py:570", "code": "def test_mock_model(mock_model):\n mock_model.enqueue([\"hello world\"])\n mock_model.enqueue([\"second\"])\n model = llm.get_model(\"mock\")\n response = model.prompt(prompt=\"hello\")\n assert response.text() == \"hello world\"\n assert str(response) == \"hello world\"\n assert model.history[0][0].prompt == \"hello\"\n assert response.usage() == Usage(input=1, output=1, details=None)\n response2 = model.prompt(prompt=\"hello again\")\n assert response2.text() == \"second\"\n assert response2.usage() == Usage(input=2, output=1, details=None)"} | |
{"id": "tests/test_llm.py:584", "code": "class Dog(BaseModel):\n name: str\n age: int"} | |
{"id": "tests/test_llm.py:601", "code": "@pytest.mark.parametrize(\"use_pydantic\", (False, True))\ndef test_schema(mock_model, use_pydantic):\n assert dog_schema == Dog.model_json_schema()\n mock_model.enqueue([json.dumps(dog)])\n response = mock_model.prompt(\n \"invent a dog\", schema=Dog if use_pydantic else dog_schema\n )\n assert json.loads(response.text()) == dog\n assert response.prompt.schema == dog_schema"} | |
{"id": "tests/test_llm.py:612", "code": "@pytest.mark.parametrize(\"use_filename\", (True, False))\ndef test_schema_via_cli(mock_model, tmpdir, monkeypatch, use_filename):\n user_path = tmpdir / \"user\"\n schema_path = tmpdir / \"schema.json\"\n mock_model.enqueue([json.dumps(dog)])\n schema_value = '{\"schema\": \"one\"}'\n open(schema_path, \"w\").write(schema_value)\n monkeypatch.setenv(\"LLM_USER_PATH\", str(user_path))\n if use_filename:\n schema_value = str(schema_path)\n runner = CliRunner()\n result = runner.invoke(\n cli,\n [\"--schema\", schema_value, \"prompt\", \"-m\", \"mock\"],\n catch_exceptions=False,\n )\n assert result.exit_code == 0\n assert result.output == '{\"name\": \"Cleo\", \"age\": 10}\\n'\n # Should have created user_path and put a logs.db in it\n assert (user_path / \"logs.db\").exists()\n rows = list(sqlite_utils.Database(str(user_path / \"logs.db\"))[\"schemas\"].rows)\n assert rows == [\n {\"id\": \"9a8ed2c9b17203f6d8905147234475b5\", \"content\": '{\"schema\":\"one\"}'}\n ]\n if use_filename:\n # Run it again to check that the ID option works now it's in the DB\n result2 = runner.invoke(\n cli,\n [\"--schema\", \"9a8ed2c9b17203f6d8905147234475b5\", \"prompt\", \"-m\", \"mock\"],\n catch_exceptions=False,\n )\n assert result2.exit_code == 0"} | |
{"id": "tests/test_llm.py:646", "code": "@pytest.mark.parametrize(\n \"args,expected\",\n (\n (\n [\"--schema\", \"name, age int\"],\n {\n \"type\": \"object\",\n \"properties\": {\"name\": {\"type\": \"string\"}, \"age\": {\"type\": \"integer\"}},\n \"required\": [\"name\", \"age\"],\n },\n ),\n (\n [\"--schema-multi\", \"name, age int\"],\n {\n \"type\": \"object\",\n \"properties\": {\n \"items\": {\n \"type\": \"array\",\n \"items\": {\n \"type\": \"object\",\n \"properties\": {\n \"name\": {\"type\": \"string\"},\n \"age\": {\"type\": \"integer\"},\n },\n \"required\": [\"name\", \"age\"],\n },\n }\n },\n \"required\": [\"items\"],\n },\n ),\n ),\n)\ndef test_schema_using_dsl(mock_model, tmpdir, monkeypatch, args, expected):\n user_path = tmpdir / \"user\"\n mock_model.enqueue([json.dumps(dog)])\n monkeypatch.setenv(\"LLM_USER_PATH\", str(user_path))\n runner = CliRunner()\n result = runner.invoke(\n cli,\n [\"prompt\", \"-m\", \"mock\"] + args,\n catch_exceptions=False,\n )\n assert result.exit_code == 0\n assert result.output == '{\"name\": \"Cleo\", \"age\": 10}\\n'\n rows = list(sqlite_utils.Database(str(user_path / \"logs.db\"))[\"schemas\"].rows)\n assert json.loads(rows[0][\"content\"]) == expected"} | |
{"id": "tests/test_llm.py:695", "code": "@pytest.mark.asyncio\n@pytest.mark.parametrize(\"use_pydantic\", (False, True))\nasync def test_schema_async(async_mock_model, use_pydantic):\n async_mock_model.enqueue([json.dumps(dog)])\n response = async_mock_model.prompt(\n \"invent a dog\", schema=Dog if use_pydantic else dog_schema\n )\n assert json.loads(await response.text()) == dog\n assert response.prompt.schema == dog_schema"} | |
{"id": "tests/test_llm.py:706", "code": "def test_mock_key_model(mock_key_model):\n response = mock_key_model.prompt(prompt=\"hello\", key=\"hi\")\n assert response.text() == \"key: hi\""} | |
{"id": "tests/test_llm.py:711", "code": "@pytest.mark.asyncio\nasync def test_mock_async_key_model(mock_async_key_model):\n response = mock_async_key_model.prompt(prompt=\"hello\", key=\"hi\")\n output = await response.text()\n assert output == \"async, key: hi\""} | |
{"id": "tests/test_llm.py:718", "code": "def test_sync_on_done(mock_model):\n mock_model.enqueue([\"hello world\"])\n model = llm.get_model(\"mock\")\n response = model.prompt(prompt=\"hello\")\n caught = []\n\n def done(response):\n caught.append(response)\n\n response.on_done(done)\n assert len(caught) == 0\n str(response)\n assert len(caught) == 1"} | |
{"id": "tests/test_llm.py:733", "code": "def test_schemas_dsl():\n runner = CliRunner()\n result = runner.invoke(cli, [\"schemas\", \"dsl\", \"name, age int, bio: short bio\"])\n assert result.exit_code == 0\n assert json.loads(result.output) == {\n \"type\": \"object\",\n \"properties\": {\n \"name\": {\"type\": \"string\"},\n \"age\": {\"type\": \"integer\"},\n \"bio\": {\"type\": \"string\", \"description\": \"short bio\"},\n },\n \"required\": [\"name\", \"age\", \"bio\"],\n }\n result2 = runner.invoke(cli, [\"schemas\", \"dsl\", \"name, age int\", \"--multi\"])\n assert result2.exit_code == 0\n assert json.loads(result2.output) == {\n \"type\": \"object\",\n \"properties\": {\n \"items\": {\n \"type\": \"array\",\n \"items\": {\n \"type\": \"object\",\n \"properties\": {\n \"name\": {\"type\": \"string\"},\n \"age\": {\"type\": \"integer\"},\n },\n \"required\": [\"name\", \"age\"],\n },\n }\n },\n \"required\": [\"items\"],\n }"} | |
{"id": "tests/test_templates.py:11", "code": "@pytest.mark.parametrize(\n \"prompt,system,defaults,params,expected_prompt,expected_system,expected_error\",\n (\n (\"S: $input\", None, None, {}, \"S: input\", None, None),\n (\"S: $input\", \"system\", None, {}, \"S: input\", \"system\", None),\n (\"No vars\", None, None, {}, \"No vars\", None, None),\n (\"$one and $two\", None, None, {}, None, None, \"Missing variables: one, two\"),\n (\"$one and $two\", None, None, {\"one\": 1, \"two\": 2}, \"1 and 2\", None, None),\n (\"$one and $two\", None, {\"one\": 1}, {\"two\": 2}, \"1 and 2\", None, None),\n (\n \"$one and $two\",\n None,\n {\"one\": 99},\n {\"one\": 1, \"two\": 2},\n \"1 and 2\",\n None,\n None,\n ),\n ),\n)\ndef test_template_evaluate(\n prompt, system, defaults, params, expected_prompt, expected_system, expected_error\n):\n t = Template(name=\"t\", prompt=prompt, system=system, defaults=defaults)\n if expected_error:\n with pytest.raises(Template.MissingVariables) as ex:\n prompt, system = t.evaluate(\"input\", params)\n assert ex.value.args[0] == expected_error\n else:\n prompt, system = t.evaluate(\"input\", params)\n assert prompt == expected_prompt\n assert system == expected_system"} | |
{"id": "tests/test_templates.py:45", "code": "def test_templates_list_no_templates_found():\n runner = CliRunner()\n result = runner.invoke(cli, [\"templates\", \"list\"])\n assert result.exit_code == 0\n assert result.output == \"\""} | |
{"id": "tests/test_templates.py:52", "code": "@pytest.mark.parametrize(\"args\", ([\"templates\", \"list\"], [\"templates\"]))\ndef test_templates_list(templates_path, args):\n (templates_path / \"one.yaml\").write_text(\"template one\", \"utf-8\")\n (templates_path / \"two.yaml\").write_text(\"template two\", \"utf-8\")\n (templates_path / \"three.yaml\").write_text(\n \"template three is very long \" * 4, \"utf-8\"\n )\n (templates_path / \"four.yaml\").write_text(\n \"'this one\\n\\nhas newlines in it'\", \"utf-8\"\n )\n (templates_path / \"both.yaml\").write_text(\n \"system: summarize this\\nprompt: $input\", \"utf-8\"\n )\n (templates_path / \"sys.yaml\").write_text(\"system: Summarize this\", \"utf-8\")\n runner = CliRunner()\n result = runner.invoke(cli, args)\n assert result.exit_code == 0\n assert result.output == (\n \"both : system: summarize this prompt: $input\\n\"\n \"four : this one has newlines in it\\n\"\n \"one : template one\\n\"\n \"sys : system: Summarize this\\n\"\n \"three : template three is very long template three is very long template thre...\\n\"\n \"two : template two\\n\"\n )"} | |
{"id": "tests/test_templates.py:79", "code": "@pytest.mark.parametrize(\n \"args,expected_prompt,expected_error\",\n (\n ([\"-m\", \"gpt4\", \"hello\"], {\"model\": \"gpt-4\", \"prompt\": \"hello\"}, None),\n ([\"hello $foo\"], {\"prompt\": \"hello $foo\"}, None),\n ([\"--system\", \"system\"], {\"system\": \"system\"}, None),\n ([\"-t\", \"template\"], None, \"--save cannot be used with --template\"),\n ([\"--continue\"], None, \"--save cannot be used with --continue\"),\n ([\"--cid\", \"123\"], None, \"--save cannot be used with --cid\"),\n ([\"--conversation\", \"123\"], None, \"--save cannot be used with --cid\"),\n (\n [\"Say hello as $name\", \"-p\", \"name\", \"default-name\"],\n {\"prompt\": \"Say hello as $name\", \"defaults\": {\"name\": \"default-name\"}},\n None,\n ),\n # Options\n (\n [\"-o\", \"temperature\", \"0.5\", \"--system\", \"in french\"],\n {\"system\": \"in french\", \"options\": {\"temperature\": 0.5}},\n None,\n ),\n # -x/--extract should be persisted:\n (\n [\"--system\", \"write python\", \"--extract\"],\n {\"system\": \"write python\", \"extract\": True},\n None,\n ),\n # So should schemas (and should not sort properties)\n (\n [\n \"--schema\",\n '{\"properties\": {\"b\": {\"type\": \"string\"}, \"a\": {\"type\": \"string\"}}}',\n ],\n {\n \"schema_object\": {\n \"properties\": {\"b\": {\"type\": \"string\"}, \"a\": {\"type\": \"string\"}}\n }\n },\n None,\n ),\n ),\n)\ndef test_templates_prompt_save(templates_path, args, expected_prompt, expected_error):\n assert not (templates_path / \"saved.yaml\").exists()\n runner = CliRunner()\n result = runner.invoke(cli, args + [\"--save\", \"saved\"], catch_exceptions=False)\n if not expected_error:\n assert result.exit_code == 0\n assert (\n yaml.safe_load((templates_path / \"saved.yaml\").read_text(\"utf-8\"))\n == expected_prompt\n )\n else:\n assert result.exit_code == 1\n assert expected_error in result.output"} | |
{"id": "tests/test_templates.py:136", "code": "def test_templates_error_on_missing_schema(templates_path):\n runner = CliRunner()\n runner.invoke(\n cli, [\"the-prompt\", \"--save\", \"prompt_no_schema\"], catch_exceptions=False\n )\n # This should complain about no schema\n result = runner.invoke(\n cli, [\"hi\", \"--schema\", \"t:prompt_no_schema\"], catch_exceptions=False\n )\n assert result.output == \"Error: Template 'prompt_no_schema' has no schema\\n\"\n # And this is just an invalid template\n result2 = runner.invoke(\n cli, [\"hi\", \"--schema\", \"t:bad_template\"], catch_exceptions=False\n )\n assert result2.output == \"Error: Invalid template: bad_template\\n\""} | |
{"id": "tests/test_templates.py:153", "code": "@mock.patch.dict(os.environ, {\"OPENAI_API_KEY\": \"X\"})\n@pytest.mark.parametrize(\n \"template,input_text,extra_args,expected_model,expected_input,expected_error,expected_options\",\n (\n (\n \"'Summarize this: $input'\",\n \"Input text\",\n [],\n \"gpt-4o-mini\",\n \"Summarize this: Input text\",\n None,\n None,\n ),\n (\n \"prompt: 'Summarize this: $input'\\nmodel: gpt-4\",\n \"Input text\",\n [],\n \"gpt-4\",\n \"Summarize this: Input text\",\n None,\n None,\n ),\n (\n \"prompt: 'Summarize this: $input'\",\n \"Input text\",\n [\"-m\", \"4\"],\n \"gpt-4\",\n \"Summarize this: Input text\",\n None,\n None,\n ),\n pytest.param(\n \"boo\",\n \"Input text\",\n [\"-s\", \"s\"],\n None,\n None,\n \"Error: Cannot use -t/--template and --system together\",\n None,\n marks=pytest.mark.httpx_mock(),\n ),\n pytest.param(\n \"prompt: 'Say $hello'\",\n \"Input text\",\n [],\n None,\n None,\n \"Error: Missing variables: hello\",\n None,\n marks=pytest.mark.httpx_mock(),\n ),\n (\n \"prompt: 'Say $hello'\",\n \"Input text\",\n [\"-p\", \"hello\", \"Blah\"],\n \"gpt-4o-mini\",\n \"Say Blah\",\n None,\n None,\n ),\n (\n \"prompt: 'Say pelican'\",\n \"\",\n [],\n \"gpt-4o-mini\",\n \"Say pelican\",\n None,\n None,\n ),\n # Template with just a system prompt\n (\n \"system: 'Summarize this'\",\n \"Input text\",\n [],\n \"gpt-4o-mini\",\n [\n {\"content\": \"Summarize this\", \"role\": \"system\"},\n {\"content\": \"Input text\", \"role\": \"user\"},\n ],\n None,\n None,\n ),\n # Options\n (\n \"prompt: 'Summarize this: $input'\\noptions:\\n temperature: 0.5\",\n \"Input text\",\n [],\n \"gpt-4o-mini\",\n \"Summarize this: Input text\",\n None,\n {\"temperature\": 0.5},\n ),\n # Should be over-ridden by CLI\n (\n \"prompt: 'Summarize this: $input'\\noptions:\\n temperature: 0.5\",\n \"Input text\",\n [\"-o\", \"temperature\", \"0.7\"],\n \"gpt-4o-mini\",\n \"Summarize this: Input text\",\n None,\n {\"temperature\": 0.7},\n ),\n ),\n)\ndef test_execute_prompt_with_a_template(\n templates_path,\n mocked_openai_chat,\n template,\n input_text,\n extra_args,\n expected_model,\n expected_input,\n expected_error,\n expected_options,\n):\n (templates_path / \"template.yaml\").write_text(template, \"utf-8\")\n runner = CliRunner()\n result = runner.invoke(\n cli,\n [\"--no-stream\", \"-t\", \"template\"]\n + ([input_text] if input_text else [])\n + extra_args,\n catch_exceptions=False,\n )\n if isinstance(expected_input, str):\n expected_messages = [{\"role\": \"user\", \"content\": expected_input}]\n else:\n expected_messages = expected_input\n\n if expected_error is None:\n assert result.exit_code == 0\n last_request = mocked_openai_chat.get_requests()[-1]\n expected_data = {\n \"model\": expected_model,\n \"messages\": expected_messages,\n \"stream\": False,\n }\n if expected_options:\n expected_data.update(expected_options)\n assert json.loads(last_request.content) == expected_data\n else:\n assert result.exit_code == 1\n assert result.output.strip() == expected_error\n mocked_openai_chat.reset()"} | |
{"id": "tests/test_async.py:5", "code": "@pytest.mark.asyncio\nasync def test_async_model(async_mock_model):\n gathered = []\n async_mock_model.enqueue([\"hello world\"])\n async for chunk in async_mock_model.prompt(\"hello\"):\n gathered.append(chunk)\n assert gathered == [\"hello world\"]\n # Not as an iterator\n async_mock_model.enqueue([\"hello world\"])\n response = await async_mock_model.prompt(\"hello\")\n text = await response.text()\n assert text == \"hello world\"\n assert isinstance(response, llm.AsyncResponse)\n usage = await response.usage()\n assert usage.input == 1\n assert usage.output == 1\n assert usage.details is None"} | |
{"id": "tests/test_async.py:24", "code": "@pytest.mark.asyncio\nasync def test_async_model_conversation(async_mock_model):\n async_mock_model.enqueue([\"joke 1\"])\n conversation = async_mock_model.conversation()\n response = await conversation.prompt(\"joke\")\n text = await response.text()\n assert text == \"joke 1\"\n async_mock_model.enqueue([\"joke 2\"])\n response2 = await conversation.prompt(\"again\")\n text2 = await response2.text()\n assert text2 == \"joke 2\""} | |
{"id": "tests/test_async.py:37", "code": "@pytest.mark.asyncio\nasync def test_async_on_done(async_mock_model):\n async_mock_model.enqueue([\"hello world\"])\n response = await async_mock_model.prompt(prompt=\"hello\")\n caught = []\n\n def done(response):\n caught.append(response)\n\n assert len(caught) == 0\n await response.on_done(done)\n await response.text()\n assert response._done\n assert len(caught) == 1"} | |
{"id": "tests/test_async.py:53", "code": "@pytest.mark.asyncio\nasync def test_async_conversation(async_mock_model):\n async_mock_model.enqueue([\"one\"])\n conversation = async_mock_model.conversation()\n response1 = await conversation.prompt(\"hi\").text()\n async_mock_model.enqueue([\"two\"])\n response2 = await conversation.prompt(\"hi\").text()\n assert response1 == \"one\"\n assert response2 == \"two\""} | |
{"id": "tests/test_keys.py:9", "code": "@pytest.mark.xfail(sys.platform == \"win32\", reason=\"Expected to fail on Windows\")\n@pytest.mark.parametrize(\"env\", ({}, {\"LLM_USER_PATH\": \"/tmp/llm-keys-test\"}))\ndef test_keys_in_user_path(monkeypatch, env, user_path):\n for key, value in env.items():\n monkeypatch.setenv(key, value)\n runner = CliRunner()\n result = runner.invoke(cli, [\"keys\", \"path\"])\n assert result.exit_code == 0\n if env:\n expected = env[\"LLM_USER_PATH\"] + \"/keys.json\"\n else:\n expected = user_path + \"/keys.json\"\n assert result.output.strip() == expected"} | |
{"id": "tests/test_keys.py:24", "code": "@pytest.mark.xfail(sys.platform == \"win32\", reason=\"Expected to fail on Windows\")\ndef test_keys_set(monkeypatch, tmpdir):\n user_path = tmpdir / \"user/keys\"\n monkeypatch.setenv(\"LLM_USER_PATH\", str(user_path))\n keys_path = user_path / \"keys.json\"\n assert not keys_path.exists()\n runner = CliRunner()\n result = runner.invoke(cli, [\"keys\", \"set\", \"openai\"], input=\"foo\")\n assert result.exit_code == 0\n assert keys_path.exists()\n # Should be chmod 600\n assert oct(keys_path.stat().mode)[-3:] == \"600\"\n content = keys_path.read_text(\"utf-8\")\n assert json.loads(content) == {\n \"// Note\": \"This file stores secret API credentials. Do not share!\",\n \"openai\": \"foo\",\n }"} | |
{"id": "tests/test_keys.py:43", "code": "@pytest.mark.xfail(sys.platform == \"win32\", reason=\"Expected to fail on Windows\")\ndef test_keys_get(monkeypatch, tmpdir):\n user_path = tmpdir / \"user/keys\"\n monkeypatch.setenv(\"LLM_USER_PATH\", str(user_path))\n runner = CliRunner()\n result = runner.invoke(cli, [\"keys\", \"set\", \"openai\"], input=\"fx\")\n assert result.exit_code == 0\n result2 = runner.invoke(cli, [\"keys\", \"get\", \"openai\"])\n assert result2.exit_code == 0\n assert result2.output.strip() == \"fx\""} | |
{"id": "tests/test_keys.py:55", "code": "@pytest.mark.parametrize(\"args\", ([\"keys\", \"list\"], [\"keys\"]))\ndef test_keys_list(monkeypatch, tmpdir, args):\n user_path = str(tmpdir / \"user/keys\")\n monkeypatch.setenv(\"LLM_USER_PATH\", user_path)\n runner = CliRunner()\n result = runner.invoke(cli, [\"keys\", \"set\", \"openai\"], input=\"foo\")\n assert result.exit_code == 0\n result2 = runner.invoke(cli, args)\n assert result2.exit_code == 0\n assert result2.output.strip() == \"openai\""} | |
{"id": "tests/test_keys.py:67", "code": "@pytest.mark.httpx_mock(\n assert_all_requests_were_expected=False, can_send_already_matched_responses=True\n)\ndef test_uses_correct_key(mocked_openai_chat, monkeypatch, tmpdir):\n user_dir = tmpdir / \"user-dir\"\n pathlib.Path(user_dir).mkdir()\n keys_path = user_dir / \"keys.json\"\n KEYS = {\n \"openai\": \"from-keys-file\",\n \"other\": \"other-key\",\n }\n keys_path.write_text(json.dumps(KEYS), \"utf-8\")\n monkeypatch.setenv(\"LLM_USER_PATH\", str(user_dir))\n monkeypatch.setenv(\"OPENAI_API_KEY\", \"from-env\")\n\n def assert_key(key):\n request = mocked_openai_chat.get_requests()[-1]\n assert request.headers[\"Authorization\"] == \"Bearer {}\".format(key)\n\n runner = CliRunner()\n\n # Called without --key uses stored key\n result = runner.invoke(cli, [\"hello\", \"--no-stream\"], catch_exceptions=False)\n assert result.exit_code == 0\n assert_key(\"from-keys-file\")\n\n # Called without --key and without keys.json uses environment variable\n keys_path.write_text(\"{}\", \"utf-8\")\n result2 = runner.invoke(cli, [\"hello\", \"--no-stream\"], catch_exceptions=False)\n assert result2.exit_code == 0\n assert_key(\"from-env\")\n keys_path.write_text(json.dumps(KEYS), \"utf-8\")\n\n # Called with --key name-in-keys.json uses that value\n result3 = runner.invoke(\n cli, [\"hello\", \"--key\", \"other\", \"--no-stream\"], catch_exceptions=False\n )\n assert result3.exit_code == 0\n assert_key(\"other-key\")\n\n # Called with --key something-else uses exactly that\n result4 = runner.invoke(\n cli, [\"hello\", \"--key\", \"custom-key\", \"--no-stream\"], catch_exceptions=False\n )\n assert result4.exit_code == 0\n assert_key(\"custom-key\")"} | |
{"id": "tests/test_encode_decode.py:6", "code": "@pytest.mark.parametrize(\n \"array\",\n (\n (0.0, 1.0, 1.5),\n (3423.0, 222.0, -1234.5),\n ),\n)\ndef test_roundtrip(array):\n encoded = llm.encode(array)\n decoded = llm.decode(encoded)\n assert decoded == array\n # Try with numpy as well\n numpy_decoded = np.frombuffer(encoded, \"<f4\")\n assert tuple(numpy_decoded.tolist()) == array"} | |
{"id": "tests/test_embed_cli.py:12", "code": "@pytest.mark.parametrize(\n \"format_,expected\",\n (\n (\"json\", \"[5, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]\\n\"),\n (\n \"base64\",\n (\n \"AACgQAAAoEAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA\"\n \"AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA==\\n\"\n ),\n ),\n (\n \"hex\",\n (\n \"0000a0400000a04000000000000000000000000000000000000000000\"\n \"000000000000000000000000000000000000000000000000000000000\"\n \"00000000000000\\n\"\n ),\n ),\n (\n \"blob\",\n (\n b\"\\x00\\x00\\xef\\xbf\\xbd@\\x00\\x00\\xef\\xbf\\xbd@\\x00\\x00\\x00\"\n b\"\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\"\n b\"\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\"\n b\"\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\"\n b\"\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\n\"\n ).decode(\"utf-8\"),\n ),\n ),\n)\n@pytest.mark.parametrize(\"scenario\", (\"argument\", \"file\", \"stdin\"))\ndef test_embed_output_format(tmpdir, format_, expected, scenario):\n runner = CliRunner()\n args = [\"embed\", \"--format\", format_, \"-m\", \"embed-demo\"]\n input = None\n if scenario == \"argument\":\n args.extend([\"-c\", \"hello world\"])\n elif scenario == \"file\":\n path = tmpdir / \"input.txt\"\n path.write_text(\"hello world\", \"utf-8\")\n args.extend([\"-i\", str(path)])\n elif scenario == \"stdin\":\n input = \"hello world\"\n args.extend([\"-i\", \"-\"])\n result = runner.invoke(cli, args, input=input)\n assert result.exit_code == 0\n assert result.output == expected"} | |
{"id": "tests/test_embed_cli.py:62", "code": "@pytest.mark.parametrize(\n \"args,expected_error\",\n (([\"-c\", \"Content\", \"stories\"], \"Must provide both collection and id\"),),\n)\ndef test_embed_errors(args, expected_error):\n runner = CliRunner()\n result = runner.invoke(cli, [\"embed\"] + args)\n assert result.exit_code == 1\n assert expected_error in result.output"} | |
{"id": "tests/test_embed_cli.py:73", "code": "@pytest.mark.parametrize(\n \"metadata,metadata_error\",\n (\n (None, None),\n ('{\"foo\": \"bar\"}', None),\n ('{\"foo\": [1, 2, 3]}', None),\n (\"[1, 2, 3]\", \"metadata must be a JSON object\"), # Must be a dictionary\n ('{\"foo\": \"incomplete}', \"metadata must be valid JSON\"),\n ),\n)\ndef test_embed_store(user_path, metadata, metadata_error):\n embeddings_db = user_path / \"embeddings.db\"\n assert not embeddings_db.exists()\n runner = CliRunner()\n result = runner.invoke(cli, [\"embed\", \"-c\", \"hello\", \"-m\", \"embed-demo\"])\n assert result.exit_code == 0\n # Should not have created the table\n assert not embeddings_db.exists()\n # Now run it to store\n args = [\"embed\", \"-c\", \"hello\", \"-m\", \"embed-demo\", \"items\", \"1\"]\n if metadata is not None:\n args.extend((\"--metadata\", metadata))\n result = runner.invoke(cli, args)\n if metadata_error:\n # Should have returned an error message about invalid metadata\n assert result.exit_code == 2\n assert metadata_error in result.output\n return\n # No error, should have succeeded and stored the data\n assert result.exit_code == 0\n assert embeddings_db.exists()\n # Check the contents\n db = sqlite_utils.Database(str(embeddings_db))\n rows = list(db[\"collections\"].rows)\n assert rows == [{\"id\": 1, \"name\": \"items\", \"model\": \"embed-demo\"}]\n expected_metadata = None\n if metadata and not metadata_error:\n expected_metadata = metadata\n rows = list(db[\"embeddings\"].rows)\n assert rows == [\n {\n \"collection_id\": 1,\n \"id\": \"1\",\n \"embedding\": (\n b\"\\x00\\x00\\xa0@\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\"\n b\"\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\"\n b\"\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\"\n b\"\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\"\n b\"\\x00\\x00\\x00\\x00\\x00\\x00\\x00\"\n ),\n \"content\": None,\n \"content_blob\": None,\n \"content_hash\": Collection.content_hash(\"hello\"),\n \"metadata\": expected_metadata,\n \"updated\": ANY,\n }\n ]\n # Should show up in 'llm collections list'\n for is_json in (False, True):\n args = [\"collections\"]\n if is_json:\n args.extend([\"list\", \"--json\"])\n result2 = runner.invoke(cli, args)\n assert result2.exit_code == 0\n if is_json:\n assert json.loads(result2.output) == [\n {\"name\": \"items\", \"model\": \"embed-demo\", \"num_embeddings\": 1}\n ]\n else:\n assert result2.output == \"items: embed-demo\\n 1 embedding\\n\"\n\n # And test deleting it too\n result = runner.invoke(cli, [\"collections\", \"delete\", \"items\"])\n assert result.exit_code == 0\n assert db[\"collections\"].count == 0\n assert db[\"embeddings\"].count == 0"} | |
{"id": "tests/test_embed_cli.py:151", "code": "def test_embed_store_binary(user_path):\n runner = CliRunner()\n args = [\"embed\", \"-m\", \"embed-demo\", \"items\", \"2\", \"--binary\", \"--store\"]\n result = runner.invoke(cli, args, input=b\"\\x00\\x01\\x02\")\n assert result.exit_code == 0\n db = sqlite_utils.Database(str(user_path / \"embeddings.db\"))\n rows = list(db[\"embeddings\"].rows)\n assert rows == [\n {\n \"collection_id\": 1,\n \"id\": \"2\",\n \"embedding\": (\n b\"\\x00\\x00@@\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\"\n b\"\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\"\n b\"\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\"\n b\"\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\"\n ),\n \"content\": None,\n \"content_blob\": b\"\\x00\\x01\\x02\",\n \"content_hash\": b'\\xb9_g\\xf6\\x1e\\xbb\\x03a\\x96\"\\xd7\\x98\\xf4_\\xc2\\xd3',\n \"metadata\": None,\n \"updated\": ANY,\n }\n ]"} | |
{"id": "tests/test_embed_cli.py:177", "code": "def test_collection_delete_errors(user_path):\n db = sqlite_utils.Database(str(user_path / \"embeddings.db\"))\n collection = Collection(\"items\", db, model_id=\"embed-demo\")\n collection.embed(\"1\", \"hello\")\n assert db[\"collections\"].count == 1\n assert db[\"embeddings\"].count == 1\n runner = CliRunner()\n result = runner.invoke(\n cli, [\"collections\", \"delete\", \"does-not-exist\"], catch_exceptions=False\n )\n assert result.exit_code == 1\n assert \"Collection does not exist\" in result.output\n assert db[\"collections\"].count == 1"} | |
{"id": "tests/test_embed_cli.py:192", "code": "@pytest.mark.parametrize(\n \"args,expected_error\",\n (\n ([], \"Missing argument 'COLLECTION'\"),\n ([\"badcollection\", \"-c\", \"content\"], \"Collection does not exist\"),\n ([\"demo\", \"bad-id\"], \"ID not found in collection\"),\n ),\n)\ndef test_similar_errors(args, expected_error, user_path_with_embeddings):\n runner = CliRunner()\n result = runner.invoke(cli, [\"similar\"] + args, catch_exceptions=False)\n assert result.exit_code != 0\n assert expected_error in result.output"} | |
{"id": "tests/test_embed_cli.py:207", "code": "def test_similar_by_id_cli(user_path_with_embeddings):\n runner = CliRunner()\n result = runner.invoke(cli, [\"similar\", \"demo\", \"1\"], catch_exceptions=False)\n assert result.exit_code == 0\n assert json.loads(result.output) == {\n \"id\": \"2\",\n \"score\": pytest.approx(0.9863939238321437),\n \"content\": None,\n \"metadata\": None,\n }"} | |
{"id": "tests/test_embed_cli.py:219", "code": "@pytest.mark.parametrize(\"scenario\", (\"argument\", \"file\", \"stdin\"))\ndef test_similar_by_content_cli(tmpdir, user_path_with_embeddings, scenario):\n runner = CliRunner()\n args = [\"similar\", \"demo\"]\n input = None\n if scenario == \"argument\":\n args.extend([\"-c\", \"hello world\"])\n elif scenario == \"file\":\n path = tmpdir / \"content.txt\"\n path.write_text(\"hello world\", \"utf-8\")\n args.extend([\"-i\", str(path)])\n elif scenario == \"stdin\":\n input = \"hello world\"\n args.extend([\"-i\", \"-\"])\n result = runner.invoke(cli, args, input=input, catch_exceptions=False)\n assert result.exit_code == 0\n lines = [line for line in result.output.splitlines() if line.strip()]\n assert len(lines) == 2\n assert json.loads(lines[0]) == {\n \"id\": \"1\",\n \"score\": pytest.approx(0.9999999999999999),\n \"content\": None,\n \"metadata\": None,\n }\n assert json.loads(lines[1]) == {\n \"id\": \"2\",\n \"score\": pytest.approx(0.9863939238321437),\n \"content\": None,\n \"metadata\": None,\n }"} | |
{"id": "tests/test_embed_cli.py:251", "code": "@pytest.mark.parametrize(\"use_stdin\", (False, True))\n@pytest.mark.parametrize(\"prefix\", (None, \"prefix\"))\n@pytest.mark.parametrize(\"prepend\", (None, \"search_document: \"))\n@pytest.mark.parametrize(\n \"filename,content\",\n (\n (\"phrases.csv\", \"id,phrase\\n1,hello world\\n2,goodbye world\"),\n (\"phrases.tsv\", \"id\\tphrase\\n1\\thello world\\n2\\tgoodbye world\"),\n (\n \"phrases.jsonl\",\n '{\"id\": 1, \"phrase\": \"hello world\"}\\n{\"id\": 2, \"phrase\": \"goodbye world\"}',\n ),\n (\n \"phrases.json\",\n '[{\"id\": 1, \"phrase\": \"hello world\"}, {\"id\": 2, \"phrase\": \"goodbye world\"}]',\n ),\n ),\n)\ndef test_embed_multi_file_input(tmpdir, use_stdin, prefix, prepend, filename, content):\n db_path = tmpdir / \"embeddings.db\"\n args = [\"embed-multi\", \"phrases\", \"-d\", str(db_path), \"-m\", \"embed-demo\"]\n input = None\n if use_stdin:\n input = content\n args.append(\"-\")\n else:\n path = tmpdir / filename\n path.write_text(content, \"utf-8\")\n args.append(str(path))\n if prefix:\n args.extend((\"--prefix\", prefix))\n if prepend:\n args.extend((\"--prepend\", prepend))\n # Auto-detection can't detect JSON-nl, so make that explicit\n if filename.endswith(\".jsonl\"):\n args.extend((\"--format\", \"nl\"))\n runner = CliRunner()\n result = runner.invoke(cli, args, input=input, catch_exceptions=False)\n assert result.exit_code == 0\n # Check that everything was embedded correctly\n db = sqlite_utils.Database(str(db_path))\n assert db[\"embeddings\"].count == 2\n ids = [row[\"id\"] for row in db[\"embeddings\"].rows]\n expected_ids = [\"1\", \"2\"]\n if prefix:\n expected_ids = [\"prefix1\", \"prefix2\"]\n assert ids == expected_ids"} | |
{"id": "tests/test_embed_cli.py:300", "code": "def test_embed_multi_files_binary_store(tmpdir):\n db_path = tmpdir / \"embeddings.db\"\n args = [\"embed-multi\", \"binfiles\", \"-d\", str(db_path), \"-m\", \"embed-demo\"]\n bin_path = tmpdir / \"file.bin\"\n bin_path.write(b\"\\x00\\x01\\x02\")\n args.extend((\"--files\", str(tmpdir), \"*.bin\", \"--store\", \"--binary\"))\n runner = CliRunner()\n result = runner.invoke(cli, args, catch_exceptions=False)\n assert result.exit_code == 0\n db = sqlite_utils.Database(str(db_path))\n assert db[\"embeddings\"].count == 1\n row = list(db[\"embeddings\"].rows)[0]\n assert row == {\n \"collection_id\": 1,\n \"id\": \"file.bin\",\n \"embedding\": (\n b\"\\x00\\x00@@\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\"\n b\"\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\"\n b\"\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\"\n b\"\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\"\n ),\n \"content\": None,\n \"content_blob\": b\"\\x00\\x01\\x02\",\n \"content_hash\": b'\\xb9_g\\xf6\\x1e\\xbb\\x03a\\x96\"\\xd7\\x98\\xf4_\\xc2\\xd3',\n \"metadata\": None,\n \"updated\": ANY,\n }"} | |
{"id": "tests/test_embed_cli.py:329", "code": "@pytest.mark.parametrize(\"use_other_db\", (True, False))\n@pytest.mark.parametrize(\"prefix\", (None, \"prefix\"))\n@pytest.mark.parametrize(\"prepend\", (None, \"search_document: \"))\ndef test_embed_multi_sql(tmpdir, use_other_db, prefix, prepend):\n db_path = str(tmpdir / \"embeddings.db\")\n db = sqlite_utils.Database(db_path)\n extra_args = []\n if use_other_db:\n db_path2 = str(tmpdir / \"other.db\")\n db = sqlite_utils.Database(db_path2)\n extra_args = [\"--attach\", \"other\", db_path2]\n\n if prefix:\n extra_args.extend((\"--prefix\", prefix))\n if prepend:\n extra_args.extend((\"--prepend\", prepend))\n\n db[\"content\"].insert_all(\n [\n {\"id\": 1, \"name\": \"cli\", \"description\": \"Command line interface\"},\n {\"id\": 2, \"name\": \"sql\", \"description\": \"Structured query language\"},\n ],\n pk=\"id\",\n )\n runner = CliRunner()\n result = runner.invoke(\n cli,\n [\n \"embed-multi\",\n \"stuff\",\n \"-d\",\n db_path,\n \"--sql\",\n \"select * from content\",\n \"-m\",\n \"embed-demo\",\n \"--store\",\n ]\n + extra_args,\n )\n assert result.exit_code == 0\n embeddings_db = sqlite_utils.Database(db_path)\n assert embeddings_db[\"embeddings\"].count == 2\n rows = list(embeddings_db.query(\"select id, content from embeddings order by id\"))\n assert rows == [\n {\n \"id\": (prefix or \"\") + \"1\",\n \"content\": (prepend or \"\") + \"cli Command line interface\",\n },\n {\n \"id\": (prefix or \"\") + \"2\",\n \"content\": (prepend or \"\") + \"sql Structured query language\",\n },\n ]"} | |
{"id": "tests/test_embed_cli.py:385", "code": "def test_embed_multi_batch_size(embed_demo, tmpdir):\n db_path = str(tmpdir / \"data.db\")\n runner = CliRunner()\n sql = \"\"\"\n with recursive cte (id) as (\n select 1\n union all\n select id+1 from cte where id < 100\n )\n select id, 'Row ' || cast(id as text) as value from cte\n \"\"\"\n assert getattr(embed_demo, \"batch_count\", 0) == 0\n result = runner.invoke(\n cli,\n [\n \"embed-multi\",\n \"rows\",\n \"--sql\",\n sql,\n \"-d\",\n db_path,\n \"-m\",\n \"embed-demo\",\n \"--store\",\n \"--batch-size\",\n \"8\",\n ],\n )\n assert result.exit_code == 0\n db = sqlite_utils.Database(db_path)\n assert db[\"embeddings\"].count == 100\n assert embed_demo.batch_count == 13"} | |
{"id": "tests/test_embed_cli.py:419", "code": "@pytest.fixture\ndef multi_files(tmpdir):\n db_path = str(tmpdir / \"files.db\")\n files = tmpdir / \"files\"\n for filename, content in (\n (\"file1.txt\", b\"hello world\"),\n (\"file2.txt\", b\"goodbye world\"),\n (\"nested/one.txt\", b\"one\"),\n (\"nested/two.txt\", b\"two\"),\n (\"nested/more/three.txt\", b\"three\"),\n # This tests the fallback to latin-1 encoding:\n (\"nested/more/ignored.ini\", b\"Has weird \\x96 character\"),\n ):\n path = pathlib.Path(files / filename)\n path.parent.mkdir(parents=True, exist_ok=True)\n path.write_bytes(content)\n return db_path, tmpdir / \"files\""} | |
{"id": "tests/test_embed_cli.py:438", "code": "@pytest.mark.xfail(sys.platform == \"win32\", reason=\"Expected to fail on Windows\")\n@pytest.mark.parametrize(\"scenario\", (\"single\", \"multi\"))\n@pytest.mark.parametrize(\"prepend\", (None, \"search_document: \"))\ndef test_embed_multi_files(multi_files, scenario, prepend):\n db_path, files = multi_files\n for filename, content in (\n (\"file1.txt\", b\"hello world\"),\n (\"file2.txt\", b\"goodbye world\"),\n (\"nested/one.txt\", b\"one\"),\n (\"nested/two.txt\", b\"two\"),\n (\"nested/more/three.txt\", b\"three\"),\n # This tests the fallback to latin-1 encoding:\n (\"nested/more.txt/ignored.ini\", b\"Has weird \\x96 character\"),\n ):\n path = pathlib.Path(files / filename)\n path.parent.mkdir(parents=True, exist_ok=True)\n path.write_bytes(content)\n\n extra_args = []\n\n if prepend:\n extra_args.extend((\"--prepend\", prepend))\n if scenario == \"single\":\n extra_args.extend([\"--files\", str(files), \"**/*.txt\"])\n else:\n extra_args.extend(\n [\n \"--files\",\n str(files / \"nested\" / \"more\"),\n \"**/*.ini\",\n \"--files\",\n str(files / \"nested\"),\n \"*.txt\",\n ]\n )\n\n runner = CliRunner()\n result = runner.invoke(\n cli,\n [\n \"embed-multi\",\n \"files\",\n \"-d\",\n db_path,\n \"-m\",\n \"embed-demo\",\n \"--store\",\n ]\n + extra_args,\n )\n assert result.exit_code == 0\n embeddings_db = sqlite_utils.Database(db_path)\n rows = list(embeddings_db.query(\"select id, content from embeddings order by id\"))\n if scenario == \"single\":\n assert rows == [\n {\"id\": \"file1.txt\", \"content\": (prepend or \"\") + \"hello world\"},\n {\"id\": \"file2.txt\", \"content\": (prepend or \"\") + \"goodbye world\"},\n {\"id\": \"nested/more/three.txt\", \"content\": (prepend or \"\") + \"three\"},\n {\"id\": \"nested/one.txt\", \"content\": (prepend or \"\") + \"one\"},\n {\"id\": \"nested/two.txt\", \"content\": (prepend or \"\") + \"two\"},\n ]\n else:\n assert rows == [\n {\n \"id\": \"ignored.ini\",\n \"content\": (prepend or \"\") + \"Has weird \\x96 character\",\n },\n {\"id\": \"one.txt\", \"content\": (prepend or \"\") + \"one\"},\n {\"id\": \"two.txt\", \"content\": (prepend or \"\") + \"two\"},\n ]"} | |
{"id": "tests/test_embed_cli.py:510", "code": "@pytest.mark.parametrize(\n \"args,expected_error\",\n (([\"not-a-dir\", \"*.txt\"], \"Invalid directory: not-a-dir\"),),\n)\ndef test_embed_multi_files_errors(multi_files, args, expected_error):\n runner = CliRunner()\n result = runner.invoke(\n cli,\n [\"embed-multi\", \"files\", \"-m\", \"embed-demo\", \"--files\"] + args,\n )\n assert result.exit_code == 2\n assert expected_error in result.output"} | |
{"id": "tests/test_embed_cli.py:524", "code": "@pytest.mark.parametrize(\n \"extra_args,expected_error\",\n (\n # With no args default utf-8 with latin-1 fallback should work\n ([], None),\n ([\"--encoding\", \"utf-8\"], \"Could not decode text in file\"),\n ([\"--encoding\", \"latin-1\"], None),\n ([\"--encoding\", \"latin-1\", \"--encoding\", \"utf-8\"], None),\n ([\"--encoding\", \"utf-8\", \"--encoding\", \"latin-1\"], None),\n ),\n)\ndef test_embed_multi_files_encoding(multi_files, extra_args, expected_error):\n db_path, files = multi_files\n runner = CliRunner(mix_stderr=False)\n result = runner.invoke(\n cli,\n [\n \"embed-multi\",\n \"files\",\n \"-d\",\n db_path,\n \"-m\",\n \"embed-demo\",\n \"--files\",\n str(files / \"nested\" / \"more\"),\n \"*.ini\",\n \"--store\",\n ]\n + extra_args,\n )\n if expected_error:\n # Should still succeed with 0, but show a warning\n assert result.exit_code == 0\n assert expected_error in result.stderr\n else:\n assert result.exit_code == 0\n assert not result.stderr\n embeddings_db = sqlite_utils.Database(db_path)\n rows = list(\n embeddings_db.query(\"select id, content from embeddings order by id\")\n )\n assert rows == [\n {\"id\": \"ignored.ini\", \"content\": \"Has weird \\x96 character\"},\n ]"} | |
{"id": "tests/test_embed_cli.py:570", "code": "def test_default_embedding_model():\n runner = CliRunner()\n result = runner.invoke(cli, [\"embed-models\", \"default\"])\n assert result.exit_code == 0\n assert result.output == \"<No default embedding model set>\\n\"\n result2 = runner.invoke(cli, [\"embed-models\", \"default\", \"ada-002\"])\n assert result2.exit_code == 0\n result3 = runner.invoke(cli, [\"embed-models\", \"default\"])\n assert result3.exit_code == 0\n assert result3.output == \"text-embedding-ada-002\\n\"\n result4 = runner.invoke(cli, [\"embed-models\", \"default\", \"--remove-default\"])\n assert result4.exit_code == 0\n result5 = runner.invoke(cli, [\"embed-models\", \"default\"])\n assert result5.exit_code == 0\n assert result5.output == \"<No default embedding model set>\\n\"\n # Now set the default and actually use it\n result6 = runner.invoke(cli, [\"embed-models\", \"default\", \"embed-demo\"])\n assert result6.exit_code == 0\n result7 = runner.invoke(cli, [\"embed\", \"-c\", \"hello world\"])\n assert result7.exit_code == 0\n assert result7.output == \"[5, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]\\n\""} | |
{"id": "tests/test_embed_cli.py:593", "code": "@pytest.mark.parametrize(\n \"args,expected_model_id\",\n (\n ([\"-q\", \"text-embedding-3-large\"], \"text-embedding-3-large\"),\n ([\"-q\", \"text\", \"-q\", \"3\"], \"text-embedding-3-large\"),\n ),\n)\ndef test_llm_embed_models_query(user_path, args, expected_model_id):\n runner = CliRunner()\n result = runner.invoke(cli, [\"embed-models\"] + args, catch_exceptions=False)\n assert result.exit_code == 0\n assert expected_model_id in result.output"} | |
{"id": "tests/test_embed_cli.py:607", "code": "@pytest.mark.parametrize(\"default_is_set\", (False, True))\n@pytest.mark.parametrize(\"command\", (\"embed\", \"embed-multi\"))\ndef test_default_embed_model_errors(user_path, default_is_set, command):\n runner = CliRunner()\n if default_is_set:\n (user_path / \"default_embedding_model.txt\").write_text(\n \"embed-demo\", encoding=\"utf8\"\n )\n args = []\n input = None\n if command == \"embed-multi\":\n args = [\"embed-multi\", \"example\", \"-\"]\n input = \"id,name\\n1,hello\"\n else:\n args = [\"embed\", \"example\", \"1\", \"-c\", \"hello world\"]\n result = runner.invoke(cli, args, input=input, catch_exceptions=False)\n if default_is_set:\n assert result.exit_code == 0\n else:\n assert result.exit_code == 1\n assert (\n \"You need to specify an embedding model (no default model is set)\"\n in result.output\n )\n # Now set the default model and try again\n result2 = runner.invoke(cli, [\"embed-models\", \"default\", \"embed-demo\"])\n assert result2.exit_code == 0\n result3 = runner.invoke(cli, args, input=input, catch_exceptions=False)\n assert result3.exit_code == 0\n # At the end of this, there should be 2 embeddings\n db = sqlite_utils.Database(str(user_path / \"embeddings.db\"))\n assert db[\"embeddings\"].count == 1"} | |
{"id": "tests/test_embed_cli.py:641", "code": "def test_duplicate_content_embedded_only_once(embed_demo):\n # content_hash should avoid embedding the same content twice\n # per collection\n db = sqlite_utils.Database(memory=True)\n assert len(embed_demo.embedded_content) == 0\n collection = Collection(\"test\", db, model_id=\"embed-demo\")\n collection.embed(\"1\", \"hello world\")\n assert len(embed_demo.embedded_content) == 1\n collection.embed(\"2\", \"goodbye world\")\n assert db[\"embeddings\"].count == 2\n assert len(embed_demo.embedded_content) == 2\n collection.embed(\"1\", \"hello world\")\n assert db[\"embeddings\"].count == 2\n assert len(embed_demo.embedded_content) == 2\n # The same string in another collection should be embedded\n c2 = Collection(\"test2\", db, model_id=\"embed-demo\")\n c2.embed(\"1\", \"hello world\")\n assert db[\"embeddings\"].count == 3\n assert len(embed_demo.embedded_content) == 3\n\n # Same again for embed_multi\n collection.embed_multi(\n ((\"1\", \"hello world\"), (\"2\", \"goodbye world\"), (\"3\", \"this is new\"))\n )\n # Should have only embedded one more thing\n assert db[\"embeddings\"].count == 4\n assert len(embed_demo.embedded_content) == 4"} | |
{"id": "tests/test_chat.py:8", "code": "@pytest.mark.xfail(sys.platform == \"win32\", reason=\"Expected to fail on Windows\")\ndef test_chat_basic(mock_model, logs_db):\n runner = CliRunner()\n mock_model.enqueue([\"one world\"])\n mock_model.enqueue([\"one again\"])\n result = runner.invoke(\n llm.cli.cli,\n [\"chat\", \"-m\", \"mock\"],\n input=\"Hi\\nHi two\\nquit\\n\",\n catch_exceptions=False,\n )\n assert result.exit_code == 0\n assert result.output == (\n \"Chatting with mock\"\n \"\\nType 'exit' or 'quit' to exit\"\n \"\\nType '!multi' to enter multiple lines, then '!end' to finish\"\n \"\\n> Hi\"\n \"\\none world\"\n \"\\n> Hi two\"\n \"\\none again\"\n \"\\n> quit\"\n \"\\n\"\n )\n # Should have logged\n conversations = list(logs_db[\"conversations\"].rows)\n assert conversations[0] == {\n \"id\": ANY,\n \"name\": \"Hi\",\n \"model\": \"mock\",\n }\n conversation_id = conversations[0][\"id\"]\n responses = list(logs_db[\"responses\"].rows)\n assert responses == [\n {\n \"id\": ANY,\n \"model\": \"mock\",\n \"prompt\": \"Hi\",\n \"system\": None,\n \"prompt_json\": None,\n \"options_json\": \"{}\",\n \"response\": \"one world\",\n \"response_json\": None,\n \"conversation_id\": conversation_id,\n \"duration_ms\": ANY,\n \"datetime_utc\": ANY,\n \"input_tokens\": 1,\n \"output_tokens\": 1,\n \"token_details\": None,\n \"schema_id\": None,\n },\n {\n \"id\": ANY,\n \"model\": \"mock\",\n \"prompt\": \"Hi two\",\n \"system\": None,\n \"prompt_json\": None,\n \"options_json\": \"{}\",\n \"response\": \"one again\",\n \"response_json\": None,\n \"conversation_id\": conversation_id,\n \"duration_ms\": ANY,\n \"datetime_utc\": ANY,\n \"input_tokens\": 2,\n \"output_tokens\": 1,\n \"token_details\": None,\n \"schema_id\": None,\n },\n ]\n # Now continue that conversation\n mock_model.enqueue([\"continued\"])\n result2 = runner.invoke(\n llm.cli.cli,\n [\"chat\", \"-m\", \"mock\", \"-c\"],\n input=\"Continue\\nquit\\n\",\n catch_exceptions=False,\n )\n assert result2.exit_code == 0\n assert result2.output == (\n \"Chatting with mock\"\n \"\\nType 'exit' or 'quit' to exit\"\n \"\\nType '!multi' to enter multiple lines, then '!end' to finish\"\n \"\\n> Continue\"\n \"\\ncontinued\"\n \"\\n> quit\"\n \"\\n\"\n )\n new_responses = list(\n logs_db.query(\n \"select * from responses where id not in ({})\".format(\n \", \".join(\"?\" for _ in responses)\n ),\n [r[\"id\"] for r in responses],\n )\n )\n assert new_responses == [\n {\n \"id\": ANY,\n \"model\": \"mock\",\n \"prompt\": \"Continue\",\n \"system\": None,\n \"prompt_json\": None,\n \"options_json\": \"{}\",\n \"response\": \"continued\",\n \"response_json\": None,\n \"conversation_id\": conversation_id,\n \"duration_ms\": ANY,\n \"datetime_utc\": ANY,\n \"input_tokens\": 1,\n \"output_tokens\": 1,\n \"token_details\": None,\n \"schema_id\": None,\n }\n ]"} | |
{"id": "tests/test_chat.py:123", "code": "@pytest.mark.xfail(sys.platform == \"win32\", reason=\"Expected to fail on Windows\")\ndef test_chat_system(mock_model, logs_db):\n runner = CliRunner()\n mock_model.enqueue([\"I am mean\"])\n result = runner.invoke(\n llm.cli.cli,\n [\"chat\", \"-m\", \"mock\", \"--system\", \"You are mean\"],\n input=\"Hi\\nquit\\n\",\n )\n assert result.exit_code == 0\n assert result.output == (\n \"Chatting with mock\"\n \"\\nType 'exit' or 'quit' to exit\"\n \"\\nType '!multi' to enter multiple lines, then '!end' to finish\"\n \"\\n> Hi\"\n \"\\nI am mean\"\n \"\\n> quit\"\n \"\\n\"\n )\n responses = list(logs_db[\"responses\"].rows)\n assert responses == [\n {\n \"id\": ANY,\n \"model\": \"mock\",\n \"prompt\": \"Hi\",\n \"system\": \"You are mean\",\n \"prompt_json\": None,\n \"options_json\": \"{}\",\n \"response\": \"I am mean\",\n \"response_json\": None,\n \"conversation_id\": ANY,\n \"duration_ms\": ANY,\n \"datetime_utc\": ANY,\n \"input_tokens\": 1,\n \"output_tokens\": 1,\n \"token_details\": None,\n \"schema_id\": None,\n }\n ]"} | |
{"id": "tests/test_chat.py:164", "code": "@pytest.mark.xfail(sys.platform == \"win32\", reason=\"Expected to fail on Windows\")\ndef test_chat_options(mock_model, logs_db):\n runner = CliRunner()\n mock_model.enqueue([\"Some text\"])\n result = runner.invoke(\n llm.cli.cli,\n [\"chat\", \"-m\", \"mock\", \"--option\", \"max_tokens\", \"10\"],\n input=\"Hi\\nquit\\n\",\n )\n assert result.exit_code == 0\n responses = list(logs_db[\"responses\"].rows)\n assert responses == [\n {\n \"id\": ANY,\n \"model\": \"mock\",\n \"prompt\": \"Hi\",\n \"system\": None,\n \"prompt_json\": None,\n \"options_json\": '{\"max_tokens\": 10}',\n \"response\": \"Some text\",\n \"response_json\": None,\n \"conversation_id\": ANY,\n \"duration_ms\": ANY,\n \"datetime_utc\": ANY,\n \"input_tokens\": 1,\n \"output_tokens\": 1,\n \"token_details\": None,\n \"schema_id\": None,\n }\n ]"} | |
{"id": "tests/test_chat.py:196", "code": "@pytest.mark.xfail(sys.platform == \"win32\", reason=\"Expected to fail on Windows\")\n@pytest.mark.parametrize(\n \"input,expected\",\n (\n (\n \"Hi\\n!multi\\nthis is multiple lines\\nuntil the !end\\n!end\\nquit\\n\",\n [\n {\"prompt\": \"Hi\", \"response\": \"One\\n\"},\n {\n \"prompt\": \"this is multiple lines\\nuntil the !end\",\n \"response\": \"Two\\n\",\n },\n ],\n ),\n # quit should not work within !multi\n (\n \"!multi\\nthis is multiple lines\\nquit\\nuntil the !end\\n!end\\nquit\\n\",\n [\n {\n \"prompt\": \"this is multiple lines\\nquit\\nuntil the !end\",\n \"response\": \"One\\n\",\n }\n ],\n ),\n # Try custom delimiter\n (\n \"!multi abc\\nCustom delimiter\\n!end\\n!end 123\\n!end abc\\nquit\\n\",\n [{\"prompt\": \"Custom delimiter\\n!end\\n!end 123\", \"response\": \"One\\n\"}],\n ),\n ),\n)\ndef test_chat_multi(mock_model, logs_db, input, expected):\n runner = CliRunner()\n mock_model.enqueue([\"One\\n\"])\n mock_model.enqueue([\"Two\\n\"])\n mock_model.enqueue([\"Three\\n\"])\n result = runner.invoke(\n llm.cli.cli, [\"chat\", \"-m\", \"mock\", \"--option\", \"max_tokens\", \"10\"], input=input\n )\n assert result.exit_code == 0\n rows = list(logs_db[\"responses\"].rows_where(select=\"prompt, response\"))\n assert rows == expected"} | |
{"id": "tests/test_aliases.py:9", "code": "@pytest.mark.parametrize(\"model_id_or_alias\", (\"gpt-3.5-turbo\", \"chatgpt\"))\ndef test_set_alias(model_id_or_alias):\n with pytest.raises(llm.UnknownModelError):\n llm.get_model(\"this-is-a-new-alias\")\n llm.set_alias(\"this-is-a-new-alias\", model_id_or_alias)\n assert llm.get_model(\"this-is-a-new-alias\").model_id == \"gpt-3.5-turbo\""} | |
{"id": "tests/test_aliases.py:17", "code": "def test_remove_alias():\n with pytest.raises(KeyError):\n llm.remove_alias(\"some-other-alias\")\n llm.set_alias(\"some-other-alias\", \"gpt-3.5-turbo\")\n assert llm.get_model(\"some-other-alias\").model_id == \"gpt-3.5-turbo\"\n llm.remove_alias(\"some-other-alias\")\n with pytest.raises(llm.UnknownModelError):\n llm.get_model(\"some-other-alias\")"} | |
{"id": "tests/test_aliases.py:27", "code": "@pytest.mark.parametrize(\"args\", ([\"aliases\", \"list\"], [\"aliases\"]))\ndef test_cli_aliases_list(args):\n llm.set_alias(\"e-demo\", \"embed-demo\")\n runner = CliRunner()\n result = runner.invoke(cli, args)\n assert result.exit_code == 0\n for line in (\n \"3.5 : gpt-3.5-turbo\\n\"\n \"chatgpt : gpt-3.5-turbo\\n\"\n \"chatgpt-16k : gpt-3.5-turbo-16k\\n\"\n \"3.5-16k : gpt-3.5-turbo-16k\\n\"\n \"4 : gpt-4\\n\"\n \"gpt4 : gpt-4\\n\"\n \"4-32k : gpt-4-32k\\n\"\n \"e-demo : embed-demo (embedding)\\n\"\n \"ada : text-embedding-ada-002 (embedding)\\n\"\n ).split(\"\\n\"):\n line = line.strip()\n if not line:\n continue\n # Turn the whitespace into a regex\n regex = r\"\\s+\".join(re.escape(part) for part in line.split())\n assert re.search(regex, result.output)"} | |
{"id": "tests/test_aliases.py:52", "code": "@pytest.mark.parametrize(\"args\", ([\"aliases\", \"list\"], [\"aliases\"]))\ndef test_cli_aliases_list_json(args):\n llm.set_alias(\"e-demo\", \"embed-demo\")\n runner = CliRunner()\n result = runner.invoke(cli, args + [\"--json\"])\n assert result.exit_code == 0\n assert (\n json.loads(result.output).items()\n >= {\n \"3.5\": \"gpt-3.5-turbo\",\n \"chatgpt\": \"gpt-3.5-turbo\",\n \"chatgpt-16k\": \"gpt-3.5-turbo-16k\",\n \"3.5-16k\": \"gpt-3.5-turbo-16k\",\n \"4\": \"gpt-4\",\n \"gpt4\": \"gpt-4\",\n \"4-32k\": \"gpt-4-32k\",\n \"ada\": \"text-embedding-ada-002\",\n \"e-demo\": \"embed-demo\",\n }.items()\n )"} | |
{"id": "tests/test_aliases.py:74", "code": "@pytest.mark.parametrize(\n \"args,expected,expected_error\",\n (\n ([\"foo\", \"bar\"], {\"foo\": \"bar\"}, None),\n ([\"foo\", \"-q\", \"mo\"], {\"foo\": \"mock\"}, None),\n ([\"foo\", \"-q\", \"mog\"], None, \"No model found matching query: mog\"),\n ),\n)\ndef test_cli_aliases_set(user_path, args, expected, expected_error):\n # Should be not aliases.json at start\n assert not (user_path / \"aliases.json\").exists()\n runner = CliRunner()\n result = runner.invoke(cli, [\"aliases\", \"set\"] + args)\n if not expected_error:\n assert result.exit_code == 0\n assert (user_path / \"aliases.json\").exists()\n assert json.loads((user_path / \"aliases.json\").read_text(\"utf-8\")) == expected\n else:\n assert result.exit_code == 1\n assert result.output.strip() == f\"Error: {expected_error}\""} | |
{"id": "tests/test_aliases.py:96", "code": "def test_cli_aliases_path(user_path):\n runner = CliRunner()\n result = runner.invoke(cli, [\"aliases\", \"path\"])\n assert result.exit_code == 0\n assert result.output.strip() == str(user_path / \"aliases.json\")"} | |
{"id": "tests/test_aliases.py:103", "code": "def test_cli_aliases_remove(user_path):\n (user_path / \"aliases.json\").write_text(json.dumps({\"foo\": \"bar\"}), \"utf-8\")\n runner = CliRunner()\n result = runner.invoke(cli, [\"aliases\", \"remove\", \"foo\"])\n assert result.exit_code == 0\n assert json.loads((user_path / \"aliases.json\").read_text(\"utf-8\")) == {}"} | |
{"id": "tests/test_aliases.py:111", "code": "def test_cli_aliases_remove_invalid(user_path):\n (user_path / \"aliases.json\").write_text(json.dumps({\"foo\": \"bar\"}), \"utf-8\")\n runner = CliRunner()\n result = runner.invoke(cli, [\"aliases\", \"remove\", \"invalid\"])\n assert result.exit_code == 1\n assert result.output == \"Error: No such alias: invalid\\n\""} | |
{"id": "tests/test_aliases.py:119", "code": "@pytest.mark.parametrize(\"args\", ([\"models\"], [\"models\", \"list\"]))\ndef test_cli_aliases_are_registered(user_path, args):\n (user_path / \"aliases.json\").write_text(\n json.dumps({\"foo\": \"bar\", \"turbo\": \"gpt-3.5-turbo\"}), \"utf-8\"\n )\n runner = CliRunner()\n result = runner.invoke(cli, args)\n assert result.exit_code == 0\n assert \"gpt-3.5-turbo (aliases: 3.5, chatgpt, turbo)\" in result.output"} | |
{"id": "tests/test_embed.py:9", "code": "def test_demo_plugin():\n model = llm.get_embedding_model(\"embed-demo\")\n assert model.embed(\"hello world\") == [5, 5] + [0] * 14"} | |
{"id": "tests/test_embed.py:14", "code": "@pytest.mark.parametrize(\n \"batch_size,expected_batches\",\n (\n (None, 100),\n (10, 100),\n ),\n)\ndef test_embed_huge_list(batch_size, expected_batches):\n model = llm.get_embedding_model(\"embed-demo\")\n huge_list = (\"hello {}\".format(i) for i in range(1000))\n kwargs = {}\n if batch_size:\n kwargs[\"batch_size\"] = batch_size\n results = model.embed_multi(huge_list, **kwargs)\n assert repr(type(results)) == \"<class 'generator'>\"\n first_twos = {}\n for result in results:\n key = (result[0], result[1])\n first_twos[key] = first_twos.get(key, 0) + 1\n assert first_twos == {(5, 1): 10, (5, 2): 90, (5, 3): 900}\n assert model.batch_count == expected_batches"} | |
{"id": "tests/test_embed.py:37", "code": "def test_embed_store(collection):\n collection.embed(\"3\", \"hello world again\", store=True)\n assert collection.db[\"embeddings\"].count == 3\n assert (\n next(collection.db[\"embeddings\"].rows_where(\"id = ?\", [\"3\"]))[\"content\"]\n == \"hello world again\"\n )"} | |
{"id": "tests/test_embed.py:46", "code": "def test_embed_metadata(collection):\n collection.embed(\"3\", \"hello yet again\", metadata={\"foo\": \"bar\"}, store=True)\n assert collection.db[\"embeddings\"].count == 3\n assert json.loads(\n next(collection.db[\"embeddings\"].rows_where(\"id = ?\", [\"3\"]))[\"metadata\"]\n ) == {\"foo\": \"bar\"}\n entry = collection.similar(\"hello yet again\")[0]\n assert entry.id == \"3\"\n assert entry.metadata == {\"foo\": \"bar\"}\n assert entry.content == \"hello yet again\""} | |
{"id": "tests/test_embed.py:58", "code": "def test_collection(collection):\n assert collection.id == 1\n assert collection.count() == 2\n # Check that the embeddings are there\n rows = list(collection.db[\"embeddings\"].rows)\n assert rows == [\n {\n \"collection_id\": 1,\n \"id\": \"1\",\n \"embedding\": llm.encode([5, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]),\n \"content\": None,\n \"content_blob\": None,\n \"content_hash\": collection.content_hash(\"hello world\"),\n \"metadata\": None,\n \"updated\": ANY,\n },\n {\n \"collection_id\": 1,\n \"id\": \"2\",\n \"embedding\": llm.encode([7, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]),\n \"content\": None,\n \"content_blob\": None,\n \"content_hash\": collection.content_hash(\"goodbye world\"),\n \"metadata\": None,\n \"updated\": ANY,\n },\n ]\n assert isinstance(rows[0][\"updated\"], int) and rows[0][\"updated\"] > 0"} | |
{"id": "tests/test_embed.py:88", "code": "def test_similar(collection):\n results = list(collection.similar(\"hello world\"))\n assert results == [\n Entry(id=\"1\", score=pytest.approx(0.9999999999999999)),\n Entry(id=\"2\", score=pytest.approx(0.9863939238321437)),\n ]"} | |
{"id": "tests/test_embed.py:96", "code": "def test_similar_by_id(collection):\n results = list(collection.similar_by_id(\"1\"))\n assert results == [\n Entry(id=\"2\", score=pytest.approx(0.9863939238321437)),\n ]"} | |
{"id": "tests/test_embed.py:103", "code": "@pytest.mark.parametrize(\n \"batch_size,expected_batches\",\n (\n (None, 100),\n (5, 200),\n ),\n)\n@pytest.mark.parametrize(\"with_metadata\", (False, True))\ndef test_embed_multi(with_metadata, batch_size, expected_batches):\n db = sqlite_utils.Database(memory=True)\n collection = llm.Collection(\"test\", db, model_id=\"embed-demo\")\n model = collection.model()\n assert getattr(model, \"batch_count\", 0) == 0\n ids_and_texts = ((str(i), \"hello {}\".format(i)) for i in range(1000))\n kwargs = {}\n if batch_size is not None:\n kwargs[\"batch_size\"] = batch_size\n if with_metadata:\n ids_and_texts = ((id, text, {\"meta\": id}) for id, text in ids_and_texts)\n collection.embed_multi_with_metadata(ids_and_texts, **kwargs)\n else:\n # Exercise store=True here too\n collection.embed_multi(ids_and_texts, store=True, **kwargs)\n rows = list(db[\"embeddings\"].rows)\n assert len(rows) == 1000\n rows_with_metadata = [row for row in rows if row[\"metadata\"] is not None]\n rows_with_content = [row for row in rows if row[\"content\"] is not None]\n if with_metadata:\n assert len(rows_with_metadata) == 1000\n assert len(rows_with_content) == 0\n else:\n assert len(rows_with_metadata) == 0\n assert len(rows_with_content) == 1000\n # Every row should have content_hash set\n assert all(row[\"content_hash\"] is not None for row in rows)\n # Check batch count\n assert collection.model().batch_count == expected_batches"} | |
{"id": "tests/test_embed.py:142", "code": "def test_collection_delete(collection):\n db = collection.db\n assert db[\"embeddings\"].count == 2\n assert db[\"collections\"].count == 1\n collection.delete()\n assert db[\"embeddings\"].count == 0\n assert db[\"collections\"].count == 0"} | |
{"id": "tests/test_embed.py:151", "code": "def test_binary_only_and_text_only_embedding_models():\n binary_only = llm.get_embedding_model(\"embed-binary-only\")\n text_only = llm.get_embedding_model(\"embed-text-only\")\n\n assert binary_only.supports_binary\n assert not binary_only.supports_text\n assert not text_only.supports_binary\n assert text_only.supports_text\n\n with pytest.raises(ValueError):\n binary_only.embed(\"hello world\")\n\n binary_only.embed(b\"hello world\")\n\n with pytest.raises(ValueError):\n text_only.embed(b\"hello world\")\n\n text_only.embed(\"hello world\")\n\n # Try the multi versions too\n # Have to call list() on this or the generator is not evaluated\n with pytest.raises(ValueError):\n list(binary_only.embed_multi([\"hello world\"]))\n\n list(binary_only.embed_multi([b\"hello world\"]))\n\n with pytest.raises(ValueError):\n list(text_only.embed_multi([b\"hello world\"]))\n\n list(text_only.embed_multi([\"hello world\"]))"} | |
{"id": "tests/test_migrate.py:27", "code": "def test_migrate_blank():\n db = sqlite_utils.Database(memory=True)\n migrate(db)\n assert set(db.table_names()).issuperset(\n {\"_llm_migrations\", \"conversations\", \"responses\", \"responses_fts\"}\n )\n assert db[\"responses\"].columns_dict == EXPECTED\n\n foreign_keys = db[\"responses\"].foreign_keys\n for expected_fk in (\n sqlite_utils.db.ForeignKey(\n table=\"responses\",\n column=\"conversation_id\",\n other_table=\"conversations\",\n other_column=\"id\",\n ),\n ):\n assert expected_fk in foreign_keys"} | |
{"id": "tests/test_migrate.py:47", "code": "@pytest.mark.parametrize(\"has_record\", [True, False])\ndef test_migrate_from_original_schema(has_record):\n db = sqlite_utils.Database(memory=True)\n if has_record:\n db[\"log\"].insert(\n {\n \"provider\": \"provider\",\n \"system\": \"system\",\n \"prompt\": \"prompt\",\n \"chat_id\": None,\n \"response\": \"response\",\n \"model\": \"model\",\n \"timestamp\": \"timestamp\",\n },\n )\n else:\n # Create empty logs table\n db[\"log\"].create(\n {\n \"provider\": str,\n \"system\": str,\n \"prompt\": str,\n \"chat_id\": str,\n \"response\": str,\n \"model\": str,\n \"timestamp\": str,\n }\n )\n migrate(db)\n expected_tables = {\"_llm_migrations\", \"conversations\", \"responses\", \"responses_fts\"}\n if has_record:\n expected_tables.add(\"logs\")\n assert set(db.table_names()).issuperset(expected_tables)"} | |
{"id": "tests/test_migrate.py:82", "code": "def test_migrations_with_legacy_alter_table():\n # https://github.com/simonw/llm/issues/162\n db = sqlite_utils.Database(memory=True)\n db.execute(\"pragma legacy_alter_table=on\")\n migrate(db)"} | |
{"id": "tests/test_migrate.py:89", "code": "def test_migrations_for_embeddings():\n db = sqlite_utils.Database(memory=True)\n embeddings_migrations.apply(db)\n assert db[\"collections\"].columns_dict == {\"id\": int, \"name\": str, \"model\": str}\n assert db[\"embeddings\"].columns_dict == {\n \"collection_id\": int,\n \"id\": str,\n \"embedding\": bytes,\n \"content\": str,\n \"content_blob\": bytes,\n \"content_hash\": bytes,\n \"metadata\": str,\n \"updated\": int,\n }\n assert db[\"embeddings\"].foreign_keys[0].column == \"collection_id\"\n assert db[\"embeddings\"].foreign_keys[0].other_table == \"collections\""} | |
{"id": "tests/test_migrate.py:107", "code": "def test_backfill_content_hash():\n db = sqlite_utils.Database(memory=True)\n # Run migrations up to but not including m004_store_content_hash\n embeddings_migrations.apply(db, stop_before=\"m004_store_content_hash\")\n assert \"content_hash\" not in db[\"embeddings\"].columns_dict\n # Add some some directly directly because llm.Collection would run migrations\n db[\"embeddings\"].insert_all(\n [\n {\n \"collection_id\": 1,\n \"id\": \"1\",\n \"embedding\": (\n b\"\\x00\\x00\\xa0@\\x00\\x00\\xa0@\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\"\n b\"\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\"\n b\"\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\"\n b\"\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\"\n ),\n \"content\": None,\n \"metadata\": None,\n \"updated\": 1693763088,\n },\n {\n \"collection_id\": 1,\n \"id\": \"2\",\n \"embedding\": (\n b\"\\x00\\x00\\xe0@\\x00\\x00\\xa0@\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\"\n b\"\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\"\n b\"\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\"\n b\"\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\"\n ),\n \"content\": \"goodbye world\",\n \"metadata\": None,\n \"updated\": 1693763088,\n },\n ]\n )\n # Now finish the migrations\n embeddings_migrations.apply(db)\n row1, row2 = db[\"embeddings\"].rows\n # This one should be random:\n assert row1[\"content_hash\"] is not None\n # This should be a hash of 'goodbye world'\n assert row2[\"content_hash\"] == llm.Collection.content_hash(\"goodbye world\")"} | |
{"id": "tests/test_attachments.py:20", "code": "@pytest.mark.parametrize(\n \"attachment_type,attachment_content\",\n [\n (\"image/png\", TINY_PNG),\n (\"audio/wav\", TINY_WAV),\n ],\n)\ndef test_prompt_attachment(mock_model, logs_db, attachment_type, attachment_content):\n runner = CliRunner()\n mock_model.enqueue([\"two boxes\"])\n result = runner.invoke(\n cli.cli,\n [\"prompt\", \"-m\", \"mock\", \"describe file\", \"-a\", \"-\"],\n input=attachment_content,\n catch_exceptions=False,\n )\n assert result.exit_code == 0, result.output\n assert result.output == \"two boxes\\n\"\n assert mock_model.history[0][0].attachments[0] == llm.Attachment(\n type=attachment_type, path=None, url=None, content=attachment_content, _id=ANY\n )\n\n # Check it was logged correctly\n conversations = list(logs_db[\"conversations\"].rows)\n assert len(conversations) == 1\n conversation = conversations[0]\n assert conversation[\"model\"] == \"mock\"\n assert conversation[\"name\"] == \"describe file\"\n response = list(logs_db[\"responses\"].rows)[0]\n attachment = list(logs_db[\"attachments\"].rows)[0]\n assert attachment == {\n \"id\": ANY,\n \"type\": attachment_type,\n \"path\": None,\n \"url\": None,\n \"content\": attachment_content,\n }\n prompt_attachment = list(logs_db[\"prompt_attachments\"].rows)[0]\n assert prompt_attachment[\"attachment_id\"] == attachment[\"id\"]\n assert prompt_attachment[\"response_id\"] == response[\"id\"]"} | |
{"id": "tests/test_cli_options.py:7", "code": "@pytest.mark.parametrize(\n \"args,expected_options,expected_error\",\n (\n (\n [\"gpt-4o-mini\", \"temperature\", \"0.5\"],\n {\"gpt-4o-mini\": {\"temperature\": \"0.5\"}},\n None,\n ),\n (\n [\"gpt-4o-mini\", \"temperature\", \"invalid\"],\n {},\n \"Error: temperature\\n Input should be a valid number\",\n ),\n (\n [\"gpt-4o-mini\", \"not-an-option\", \"invalid\"],\n {},\n \"Extra inputs are not permitted\",\n ),\n ),\n)\ndef test_set_model_default_options(user_path, args, expected_options, expected_error):\n path = user_path / \"model_options.json\"\n assert not path.exists()\n runner = CliRunner()\n result = runner.invoke(cli, [\"models\", \"options\", \"set\"] + args)\n if not expected_error:\n assert result.exit_code == 0\n assert path.exists()\n data = json.loads(path.read_text(\"utf-8\"))\n assert data == expected_options\n else:\n assert result.exit_code == 1\n assert expected_error in result.output"} | |
{"id": "tests/test_cli_options.py:42", "code": "def test_model_options_list_and_show(user_path):\n (user_path / \"model_options.json\").write_text(\n json.dumps(\n {\"gpt-4o-mini\": {\"temperature\": 0.5}, \"gpt-4o\": {\"temperature\": 0.7}}\n ),\n \"utf-8\",\n )\n runner = CliRunner()\n result = runner.invoke(cli, [\"models\", \"options\", \"list\"])\n assert result.exit_code == 0\n assert (\n result.output\n == \"gpt-4o-mini:\\n temperature: 0.5\\ngpt-4o:\\n temperature: 0.7\\n\"\n )\n result = runner.invoke(cli, [\"models\", \"options\", \"show\", \"gpt-4o-mini\"])\n assert result.exit_code == 0\n assert result.output == \"temperature: 0.5\\n\""} | |
{"id": "tests/test_cli_options.py:61", "code": "def test_model_options_clear(user_path):\n path = user_path / \"model_options.json\"\n path.write_text(\n json.dumps(\n {\n \"gpt-4o-mini\": {\"temperature\": 0.5},\n \"gpt-4o\": {\"temperature\": 0.7, \"top_p\": 0.9},\n }\n ),\n \"utf-8\",\n )\n assert path.exists()\n runner = CliRunner()\n # Clear all for gpt-4o-mini\n result = runner.invoke(cli, [\"models\", \"options\", \"clear\", \"gpt-4o-mini\"])\n assert result.exit_code == 0\n # Clear just top_p for gpt-4o\n result2 = runner.invoke(cli, [\"models\", \"options\", \"clear\", \"gpt-4o\", \"top_p\"])\n assert result2.exit_code == 0\n data = json.loads(path.read_text(\"utf-8\"))\n assert data == {\"gpt-4o\": {\"temperature\": 0.7}}"} | |
{"id": "docs/plugins/llm-markov/llm_markov.py:8", "code": "@llm.hookimpl\ndef register_models(register):\n register(Markov())"} | |
{"id": "docs/plugins/llm-markov/llm_markov.py:13", "code": "def build_markov_table(text):\n words = text.split()\n transitions = {}\n # Loop through all but the last word\n for i in range(len(words) - 1):\n word = words[i]\n next_word = words[i + 1]\n transitions.setdefault(word, []).append(next_word)\n return transitions"} | |
{"id": "docs/plugins/llm-markov/llm_markov.py:24", "code": "def generate(transitions, length, start_word=None):\n all_words = list(transitions.keys())\n next_word = start_word or random.choice(all_words)\n for i in range(length):\n yield next_word\n options = transitions.get(next_word) or all_words\n next_word = random.choice(options)"} | |
{"id": "docs/plugins/llm-markov/llm_markov.py:33", "code": "class Markov(llm.Model):\n model_id = \"markov\"\n can_stream = True\n\n class Options(llm.Options):\n length: Optional[int] = Field(\n description=\"Number of words to generate\", default=None\n )\n delay: Optional[float] = Field(\n description=\"Seconds to delay between each token\", default=None\n )\n\n @field_validator(\"length\")\n def validate_length(cls, length):\n if length is None:\n return None\n if length < 2:\n raise ValueError(\"length must be >= 2\")\n return length\n\n @field_validator(\"delay\")\n def validate_delay(cls, delay):\n if delay is None:\n return None\n if not 0 <= delay <= 10:\n raise ValueError(\"delay must be between 0 and 10\")\n return delay\n\n def execute(self, prompt, stream, response, conversation):\n text = prompt.prompt\n transitions = build_markov_table(text)\n length = prompt.options.length or 20\n for word in generate(transitions, length):\n yield word + \" \"\n if prompt.options.delay:\n time.sleep(prompt.options.delay)"} | |
{"id": "docs/plugins/llm-markov/llm_markov.py:61", "code": " def execute(self, prompt, stream, response, conversation):\n text = prompt.prompt\n transitions = build_markov_table(text)\n length = prompt.options.length or 20\n for word in generate(transitions, length):\n yield word + \" \"\n if prompt.options.delay:\n time.sleep(prompt.options.delay)"} |