Created
March 31, 2026 12:02
-
-
Save ElliotRoe/54dbb816cbdc3a5df29b331eb59ea1d6 to your computer and use it in GitHub Desktop.
Lit Lake Core Files
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| from __future__ import annotations | |
| import argparse | |
| import hashlib | |
| import json | |
| import sqlite3 | |
| import struct | |
| import os | |
| import sys | |
| from dataclasses import asdict, dataclass | |
| from pathlib import Path | |
| from typing import Any, Sequence | |
| from docling.document_converter import DocumentConverter | |
| from docling_core.transforms.chunker.hybrid_chunker import HybridChunker | |
@dataclass(frozen=True)
class UpsertReferenceInput:
    """Input payload for upsert_reference.

    (source_system, source_id) is the natural key used to locate existing
    reference_items / document_files rows.
    """

    db_path: str | Path
    source_system: str
    source_id: str
    title: str
    # Optional bibliographic metadata; blank/None values are normalized to NULL.
    authors: str | None = None
    year: str | None = None
    abstract: str | None = None
    # Defaults to "" at construction time; upsert_reference resolves it and
    # _validate_input raises FileNotFoundError when the file does not exist.
    pdf_path: str | Path = ""
@dataclass(frozen=True)
class UpsertReferenceResult:
    """Summary of one upsert: row ids, change flag, and write counts."""

    reference_id: int
    document_file_id: int
    # True when the PDF's SHA-256 differed from the stored hash (chunks re-ingested).
    pdf_changed: bool
    chunks_written: int
    embeddings_written: int
    sha256: str

    def to_dict(self) -> dict[str, object]:
        """Return a JSON-serializable dict of all fields."""
        return asdict(self)
@dataclass(frozen=True)
class ChunkArtifact:
    """One non-empty text chunk produced by the Docling chunker."""

    content: str
# Embedding model used for all texts; DEFAULT_DIM must match the float[384]
# declaration of the vec_documents virtual table in _init_schema.
DEFAULT_MODEL = "BAAI/bge-small-en-v1.5"
DEFAULT_DIM = 384
# Shared on-disk cache for Docling / FastEmbed model artifacts.
CACHE_DIR = Path("~/.cache/docling/models").expanduser().resolve()
def connect_db(path: Path | str, *, read_only: bool = False) -> sqlite3.Connection:
    """Open a SQLite connection with WAL, FK enforcement, and sqlite-vec loaded.

    When read_only is True the database is opened through a mode=ro URI so no
    write lock is ever taken.
    """
    target = Path(path)
    if read_only:
        conn = sqlite3.connect(f"file:{target}?mode=ro", uri=True, check_same_thread=False)
    else:
        conn = sqlite3.connect(str(target), check_same_thread=False)
    conn.row_factory = sqlite3.Row
    for pragma in ("journal_mode=WAL", "busy_timeout=5000", "foreign_keys=ON"):
        conn.execute(f"PRAGMA {pragma};")
    _load_sqlite_vec(conn)
    return conn
def _load_sqlite_vec(conn: sqlite3.Connection) -> None:
    """Load the sqlite-vec extension (vec0 virtual tables) into *conn*.

    Extension loading is enabled only for the duration of the load and is
    switched back off in all cases, including load failure.
    """
    import sqlite_vec

    conn.enable_load_extension(True)
    try:
        sqlite_vec.load(conn)
    finally:
        conn.enable_load_extension(False)
def init_db(db_path: str | Path) -> None:
    """Create the full schema at *db_path* (idempotent), then close the connection."""
    resolved = Path(db_path).expanduser().resolve()
    conn = connect_db(resolved)
    try:
        _init_schema(conn)
    finally:
        conn.close()
def _init_schema(conn: sqlite3.Connection) -> None:
    """Create every table and index if missing; idempotent, commits at the end.

    Layout:
      reference_items -- one row per bibliographic reference
      document_files  -- one PDF per (source_system, source_id)
      documents       -- title/abstract/fulltext_chunk text rows
      documents_fts   -- FTS5 mirror of documents (rowid == documents.id)
      vec_documents   -- sqlite-vec embeddings (rowid == documents.id)
    """
    conn.execute(
        """
        CREATE TABLE IF NOT EXISTS reference_items (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            title TEXT,
            authors TEXT,
            year TEXT,
            source_system TEXT NOT NULL,
            source_id TEXT NOT NULL,
            created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
            updated_at DATETIME DEFAULT CURRENT_TIMESTAMP,
            UNIQUE(source_system, source_id)
        )
        """
    )
    conn.execute(
        """
        CREATE TABLE IF NOT EXISTS document_files (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            reference_id INTEGER NOT NULL,
            file_path TEXT NOT NULL,
            mime_type TEXT NOT NULL DEFAULT 'application/pdf',
            file_hash_sha256 TEXT NOT NULL,
            extracted_text TEXT,
            source_system TEXT NOT NULL,
            source_id TEXT NOT NULL,
            created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
            updated_at DATETIME DEFAULT CURRENT_TIMESTAMP,
            FOREIGN KEY(reference_id) REFERENCES reference_items(id) ON DELETE CASCADE,
            UNIQUE(source_system, source_id)
        )
        """
    )
    conn.execute(
        """
        CREATE TABLE IF NOT EXISTS documents (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            reference_id INTEGER NOT NULL,
            document_file_id INTEGER,
            kind TEXT NOT NULL,
            content TEXT,
            chunk_index INTEGER,
            created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
            updated_at DATETIME DEFAULT CURRENT_TIMESTAMP,
            FOREIGN KEY(reference_id) REFERENCES reference_items(id) ON DELETE CASCADE,
            FOREIGN KEY(document_file_id) REFERENCES document_files(id) ON DELETE CASCADE
        )
        """
    )
    conn.execute(
        """
        CREATE INDEX IF NOT EXISTS idx_documents_reference_id
        ON documents(reference_id)
        """
    )
    conn.execute(
        """
        CREATE INDEX IF NOT EXISTS idx_documents_document_file_id
        ON documents(document_file_id)
        """
    )
    conn.execute(
        """
        CREATE INDEX IF NOT EXISTS idx_documents_kind
        ON documents(kind)
        """
    )
    # External-content-free FTS5 table; kind/reference_id are stored but not indexed.
    conn.execute(
        """
        CREATE VIRTUAL TABLE IF NOT EXISTS documents_fts USING fts5(
            content,
            kind UNINDEXED,
            reference_id UNINDEXED,
            tokenize='unicode61'
        )
        """
    )
    # Dimension must stay in sync with DEFAULT_DIM / the embedding model.
    conn.execute(
        "CREATE VIRTUAL TABLE IF NOT EXISTS vec_documents USING vec0(embedding float[384] distance_metric=cosine);"
    )
    conn.commit()
def delete_fts_rows(conn: sqlite3.Connection, doc_ids: list[int]) -> None:
    """Remove the FTS mirror rows for *doc_ids* (documents_fts.rowid == documents.id).

    Uses executemany so one prepared statement covers the whole batch instead
    of issuing a separate execute call per id.
    """
    conn.executemany(
        "DELETE FROM documents_fts WHERE rowid = ?",
        ((doc_id,) for doc_id in doc_ids),
    )
def upsert_fts_row(
    conn: sqlite3.Connection,
    *,
    doc_id: int,
    content: str,
    kind: str,
    reference_id: int,
) -> None:
    """Replace the FTS row for *doc_id*; delete-then-insert keeps rowids aligned with documents.id."""
    conn.execute("DELETE FROM documents_fts WHERE rowid = ?", (doc_id,))
    row = (doc_id, content, kind, reference_id)
    conn.execute(
        "INSERT INTO documents_fts(rowid, content, kind, reference_id) VALUES (?, ?, ?, ?)",
        row,
    )
def search_by_vec(
    conn: sqlite3.Connection,
    query_vector: list[float],
    k: int,
) -> list[tuple[int, str, str, int, float]]:
    """Run KNN over vec_documents and return (doc_id, content, kind, reference_id, distance)."""
    if k <= 0 or not query_vector:
        return []
    blob = sqlite3.Binary(serialize_vector(query_vector))
    cursor = conn.execute(
        """
        SELECT v.rowid AS doc_id, v.distance,
               d.content, d.kind, d.reference_id
        FROM vec_documents v
        JOIN documents d ON d.id = v.rowid
        WHERE v.embedding MATCH ? AND k = ?
        """,
        (blob, k),
    )
    results: list[tuple[int, str, str, int, float]] = []
    for row in cursor.fetchall():
        results.append(
            (
                int(row["doc_id"]),
                row["content"] or "",
                row["kind"],
                int(row["reference_id"]),
                float(row["distance"]),
            )
        )
    return results
def search_by_string(
    conn: sqlite3.Connection,
    query: str,
    k: int,
) -> list[tuple[int, str, str, int, float]]:
    """Embed the query string and run semantic search. Returns (doc_id, content, kind, reference_id, distance)."""
    cleaned = (query or "").strip()
    if not cleaned:
        return []
    vectors = embed_texts([cleaned])
    return search_by_vec(conn, vectors[0], k) if vectors else []
def delete_vec_rows(conn: sqlite3.Connection, doc_ids: list[int]) -> None:
    """Remove embedding rows for *doc_ids* (vec_documents.rowid == documents.id).

    Uses executemany for a single prepared statement over the whole batch,
    matching delete_fts_rows.
    """
    conn.executemany(
        "DELETE FROM vec_documents WHERE rowid = ?",
        ((doc_id,) for doc_id in doc_ids),
    )
def upsert_vec_row(conn: sqlite3.Connection, *, doc_id: int, vector_bytes: bytes) -> None:
    """Replace the stored embedding for *doc_id* with *vector_bytes* (packed float32)."""
    conn.execute("DELETE FROM vec_documents WHERE rowid = ?", (doc_id,))
    params = (doc_id, sqlite3.Binary(vector_bytes))
    conn.execute("INSERT INTO vec_documents(rowid, embedding) VALUES (?, ?)", params)
def delete_documents(conn: sqlite3.Connection, doc_ids: list[int]) -> None:
    """Delete documents rows plus their FTS and vector index entries."""
    if not doc_ids:
        return
    # Remove the index mirrors first so no orphaned rowids survive the base delete.
    delete_fts_rows(conn, doc_ids)
    delete_vec_rows(conn, doc_ids)
    marks = ",".join(["?"] * len(doc_ids))
    conn.execute(f"DELETE FROM documents WHERE id IN ({marks})", doc_ids)
def ingest_pdf_chunks(pdf_path: Path, converter: DocumentConverter, chunker: HybridChunker) -> list[ChunkArtifact]:
    """Convert *pdf_path* with Docling and return its non-empty text chunks.

    Raises ValueError when chunking yields no usable text at all.
    """
    converted = converter.convert(str(pdf_path))
    document = getattr(converted, "document", converted)
    texts = (_chunk_content(chunker, chunk).strip() for chunk in chunker.chunk(document))
    artifacts = [ChunkArtifact(content=text) for text in texts if text]
    if not artifacts:
        raise ValueError(f"Docling produced no chunks for {pdf_path}")
    return artifacts
def _chunk_content(chunker: Any, chunk: Any) -> str:
    """Best-effort extraction of a chunk's text, trying several accessors in order.

    Tries chunk.text (string or zero-argument callable), then
    chunker.serialize(chunk), then chunk.content, then str(chunk).
    """
    text = getattr(chunk, "text", None)
    if isinstance(text, str):
        return text
    if callable(text):
        candidate = text()
        if isinstance(candidate, str):
            return candidate
    serializer = getattr(chunker, "serialize", None)
    if callable(serializer):
        candidate = serializer(chunk)
        if isinstance(candidate, str):
            return candidate
    fallback = getattr(chunk, "content", None)
    if isinstance(fallback, str):
        return fallback
    return str(chunk)
# Process-wide cache of instantiated embedding models, keyed by model name,
# so batch workloads do not reload model weights on every call.
_EMBEDDING_MODELS = {}


def embed_texts(
    texts: list[str]
) -> list[list[float]]:
    """Embed *texts* with the default FastEmbed model.

    Returns one vector (list of floats) per input text. An empty input
    returns [] without importing fastembed or loading the model.
    """
    if not texts:
        return []
    from fastembed import TextEmbedding

    model = _EMBEDDING_MODELS.get(DEFAULT_MODEL)
    if model is None:
        # Instantiating TextEmbedding loads model weights; do it once per process.
        model_cache = CACHE_DIR / "FastEmbed"
        model = TextEmbedding(model_name=DEFAULT_MODEL, cache_dir=str(model_cache))
        _EMBEDDING_MODELS[DEFAULT_MODEL] = model
    return [list(vector) for vector in model.embed(texts)]
def serialize_vector(vector: list[float]) -> bytes:
    """Pack *vector* into the raw little-endian float32 byte layout used for vec_documents."""
    packer = struct.Struct(f"{len(vector)}f")
    return packer.pack(*vector)
def upsert_reference(
    payload: UpsertReferenceInput,
    converter,
    chunker
) -> UpsertReferenceResult:
    """Idempotently sync one reference + PDF into the database.

    Phases:
      1. Read current rows and detect what changed (metadata text and/or
         the PDF's SHA-256 hash).
      2. Ingest and embed only the changed texts, OUTSIDE any transaction
         (Docling + embedding are the slow parts).
      3. Apply all row/index writes inside one BEGIN IMMEDIATE transaction
         so readers never see a half-updated reference.

    Raises ValueError / FileNotFoundError on invalid input, and ValueError
    if the embedding provider returns the wrong number of vectors.
    """
    db_path = Path(payload.db_path).expanduser().resolve()
    pdf_path = Path(payload.pdf_path).expanduser().resolve()
    _validate_input(payload, pdf_path)
    sha256 = _sha256_file(pdf_path)
    title = _normalize_required(payload.title, field_name="title")
    authors = _normalize_optional(payload.authors)
    year = _normalize_optional(payload.year)
    abstract = _normalize_optional(payload.abstract)
    conn = connect_db(db_path)
    try:
        _init_schema(conn)
        # --- read phase: current reference / file / metadata documents ---
        existing_ref = conn.execute(
            """
            SELECT id, title, authors, year
            FROM reference_items
            WHERE source_system = ? AND source_id = ?
            """,
            (payload.source_system, payload.source_id),
        ).fetchone()
        reference_id = int(existing_ref["id"]) if existing_ref is not None else None
        existing_file = conn.execute(
            """
            SELECT id, reference_id, file_path, file_hash_sha256, extracted_text
            FROM document_files
            WHERE source_system = ? AND source_id = ?
            """,
            (payload.source_system, payload.source_id),
        ).fetchone()
        metadata_docs = {}
        if reference_id is not None:
            # Title/abstract live as file-less documents rows, one per kind.
            rows = conn.execute(
                """
                SELECT id, kind, content
                FROM documents
                WHERE reference_id = ?
                AND kind IN ('title','abstract')
                AND document_file_id IS NULL
                """,
                (reference_id,),
            ).fetchall()
            metadata_docs = {str(row["kind"]): row for row in rows}
        # --- change detection ---
        title_changed = _document_content_changed(metadata_docs.get("title"), title)
        abstract_changed = _document_content_changed(metadata_docs.get("abstract"), abstract)
        old_hash = str(existing_file["file_hash_sha256"]) if existing_file is not None else None
        pdf_changed = old_hash != sha256
        # --- expensive work (conversion + embedding), done before the transaction ---
        chunk_artifacts = []
        extracted_text = None
        if pdf_changed:
            chunk_artifacts = ingest_pdf_chunks(pdf_path, converter, chunker)
            extracted_text = "\n\n".join(chunk.content for chunk in chunk_artifacts)
        # Keys pair each pending embedding with its target document:
        # ("title"/"abstract", None) or ("fulltext_chunk", chunk_index).
        embedding_keys: list[tuple[str, int | None]] = []
        embedding_texts: list[str] = []
        if title_changed and title:
            embedding_keys.append(("title", None))
            embedding_texts.append(title)
        if abstract_changed and abstract:
            embedding_keys.append(("abstract", None))
            embedding_texts.append(abstract)
        if pdf_changed:
            for idx, chunk in enumerate(chunk_artifacts):
                embedding_keys.append(("fulltext_chunk", idx))
                embedding_texts.append(chunk.content)
        embedded_vectors = embed_texts(embedding_texts)
        if len(embedded_vectors) != len(embedding_keys):
            raise ValueError("Embedding provider returned unexpected vector count")
        vectors_by_key = {
            key: embedded_vectors[idx]
            for idx, key in enumerate(embedding_keys)
        }
        # --- write phase: one immediate transaction for all mutations ---
        conn.execute("BEGIN IMMEDIATE")
        try:
            reference_id = _upsert_reference_row(
                conn,
                existing_ref=existing_ref,
                source_system=payload.source_system,
                source_id=payload.source_id,
                title=title,
                authors=authors,
                year=year,
            )
            file_id = _upsert_file_row(
                conn,
                existing_file=existing_file,
                reference_id=reference_id,
                source_system=payload.source_system,
                source_id=payload.source_id,
                pdf_path=pdf_path,
                sha256=sha256,
                extracted_text=extracted_text,
                pdf_changed=pdf_changed,
            )
            embeddings_written = 0
            title_doc_id = _apply_metadata_document(
                conn,
                reference_id=reference_id,
                kind="title",
                content=title,
                existing_row=metadata_docs.get("title"),
            )
            if title_changed and title_doc_id is not None:
                vector = vectors_by_key[("title", None)]
                _index_document(
                    conn,
                    doc_id=title_doc_id,
                    reference_id=reference_id,
                    kind="title",
                    content=title,
                    vector=vector,
                )
                embeddings_written += 1
            abstract_doc_id = _apply_metadata_document(
                conn,
                reference_id=reference_id,
                kind="abstract",
                content=abstract,
                existing_row=metadata_docs.get("abstract"),
            )
            if abstract_changed and abstract_doc_id is not None and abstract is not None:
                vector = vectors_by_key[("abstract", None)]
                _index_document(
                    conn,
                    doc_id=abstract_doc_id,
                    reference_id=reference_id,
                    kind="abstract",
                    content=abstract,
                    vector=vector,
                )
                embeddings_written += 1
            chunks_written = 0
            if pdf_changed:
                # Replace all chunk rows for this file: delete old, insert new.
                old_chunk_rows = conn.execute(
                    """
                    SELECT id
                    FROM documents
                    WHERE document_file_id = ?
                    AND kind = 'fulltext_chunk'
                    ORDER BY chunk_index ASC
                    """,
                    (file_id,),
                ).fetchall()
                delete_documents(conn, [int(row["id"]) for row in old_chunk_rows])
                for idx, chunk in enumerate(chunk_artifacts):
                    doc_id = _insert_document(
                        conn,
                        reference_id=reference_id,
                        document_file_id=file_id,
                        kind="fulltext_chunk",
                        content=chunk.content,
                        chunk_index=idx,
                    )
                    vector = vectors_by_key[("fulltext_chunk", idx)]
                    _index_document(
                        conn,
                        doc_id=doc_id,
                        reference_id=reference_id,
                        kind="fulltext_chunk",
                        content=chunk.content,
                        vector=vector,
                    )
                    embeddings_written += 1
                    chunks_written += 1
            conn.commit()
        except Exception:
            conn.rollback()
            raise
        return UpsertReferenceResult(
            reference_id=reference_id,
            document_file_id=file_id,
            pdf_changed=pdf_changed,
            chunks_written=chunks_written,
            embeddings_written=embeddings_written,
            sha256=sha256,
        )
    finally:
        conn.close()
def _validate_input(payload: UpsertReferenceInput, pdf_path: Path) -> None:
    """Validate required identifiers and the PDF path.

    Raises ValueError when source_system/source_id are missing or blank and
    FileNotFoundError when *pdf_path* is not an existing regular file.
    """
    # _normalize_required already raises ValueError("<field> is required") on
    # None/blank input and never returns a falsy value, so the previous
    # `if not _normalize_required(...)` wrappers were dead code.
    _normalize_required(payload.source_system, field_name="source_system")
    _normalize_required(payload.source_id, field_name="source_id")
    # is_file() is False for missing paths, so a separate exists() check is redundant.
    if not pdf_path.is_file():
        raise FileNotFoundError(f"PDF path not found: {pdf_path}")
def _upsert_reference_row(
    conn: sqlite3.Connection,
    *,
    existing_ref: sqlite3.Row | None,
    source_system: str,
    source_id: str,
    title: str,
    authors: str | None,
    year: str | None,
) -> int:
    """Insert or update the reference_items row; returns its id.

    The UPDATE (and its updated_at bump) is skipped when nothing changed.
    """
    if existing_ref is None:
        cur = conn.execute(
            """
            INSERT INTO reference_items(
                source_system, source_id, title, authors, year
            ) VALUES (?, ?, ?, ?, ?)
            """,
            (source_system, source_id, title, authors, year),
        )
        return int(cur.lastrowid)
    ref_id = int(existing_ref["id"])
    stored = (
        _normalize_optional(existing_ref["title"]),
        _normalize_optional(existing_ref["authors"]),
        _normalize_optional(existing_ref["year"]),
    )
    if stored != (title, authors, year):
        conn.execute(
            """
            UPDATE reference_items
            SET title = ?, authors = ?, year = ?, updated_at = CURRENT_TIMESTAMP
            WHERE id = ?
            """,
            (title, authors, year, ref_id),
        )
    return ref_id
def _upsert_file_row(
    conn: sqlite3.Connection,
    *,
    existing_file: sqlite3.Row | None,
    reference_id: int,
    source_system: str,
    source_id: str,
    pdf_path: Path,
    sha256: str,
    extracted_text: str | None,
    pdf_changed: bool,
) -> int:
    """Insert or update the document_files row for this source id pair; returns its id.

    When the PDF hash is unchanged the previously stored extracted_text is
    preserved rather than being overwritten with None.
    """
    if existing_file is None:
        cur = conn.execute(
            """
            INSERT INTO document_files(
                reference_id,
                file_path,
                mime_type,
                file_hash_sha256,
                extracted_text,
                source_system,
                source_id
            ) VALUES (?, ?, 'application/pdf', ?, ?, ?, ?)
            """,
            (
                reference_id,
                str(pdf_path),
                sha256,
                extracted_text,
                source_system,
                source_id,
            ),
        )
        return int(cur.lastrowid)
    # extracted_text is only recomputed by the caller when the hash changed;
    # keep the old text otherwise so an unchanged PDF does not lose it.
    next_extracted = extracted_text if pdf_changed else existing_file["extracted_text"]
    conn.execute(
        """
        UPDATE document_files
        SET reference_id = ?,
            file_path = ?,
            file_hash_sha256 = ?,
            extracted_text = ?,
            updated_at = CURRENT_TIMESTAMP
        WHERE id = ?
        """,
        (
            reference_id,
            str(pdf_path),
            sha256,
            next_extracted,
            int(existing_file["id"]),
        ),
    )
    return int(existing_file["id"])
def _apply_metadata_document(
    conn: sqlite3.Connection,
    *,
    reference_id: int,
    kind: str,
    content: str | None,
    existing_row: sqlite3.Row | None,
) -> int | None:
    """Sync the single metadata document (kind 'title' or 'abstract') for a reference.

    content is None -> delete any existing row, return None.
    no existing row -> insert a fresh one, return its id.
    content changed -> update in place, return the existing id.
    unchanged       -> return the existing id untouched.
    """
    if content is None:
        if existing_row is not None:
            delete_documents(conn, [int(existing_row["id"])])
        return None
    if existing_row is None:
        return _insert_document(
            conn,
            reference_id=reference_id,
            document_file_id=None,
            kind=kind,
            content=content,
            chunk_index=None,
        )
    row_id = int(existing_row["id"])
    if _normalize_optional(existing_row["content"]) != content:
        conn.execute(
            """
            UPDATE documents
            SET content = ?,
                updated_at = CURRENT_TIMESTAMP
            WHERE id = ?
            """,
            (content, row_id),
        )
    return row_id
def _insert_document(
    conn: sqlite3.Connection,
    *,
    reference_id: int,
    document_file_id: int | None,
    kind: str,
    content: str,
    chunk_index: int | None,
) -> int:
    """Insert one documents row and return its autoincrement id."""
    values = (reference_id, document_file_id, kind, content, chunk_index)
    cur = conn.execute(
        "INSERT INTO documents(reference_id, document_file_id, kind, content, chunk_index)"
        " VALUES (?, ?, ?, ?, ?)",
        values,
    )
    return int(cur.lastrowid)
def _index_document(
    conn: sqlite3.Connection,
    *,
    doc_id: int,
    reference_id: int,
    kind: str,
    content: str,
    vector: list[float],
) -> None:
    """Write both search indexes (FTS text + embedding vector) for one documents row."""
    upsert_fts_row(conn, doc_id=doc_id, content=content, kind=kind, reference_id=reference_id)
    blob = serialize_vector(vector)
    upsert_vec_row(conn, doc_id=doc_id, vector_bytes=blob)
def _normalize_required(value: str | None, *, field_name: str) -> str:
    """Trimmed string form of *value*; raises ValueError when missing or blank."""
    normalized = _normalize_optional(value)
    if normalized is None:
        raise ValueError(f"{field_name} is required")
    return normalized
def _normalize_optional(value: Any) -> str | None:
    """Trimmed string form of *value*, with None and blank strings collapsed to None."""
    if value is None:
        return None
    stripped = str(value).strip()
    return stripped or None
def _sha256_file(path: Path) -> str:
    """Hex SHA-256 digest of *path*, streamed in 1 MiB blocks to bound memory."""
    digest = hashlib.sha256()
    with path.open("rb") as handle:
        while block := handle.read(1024 * 1024):
            digest.update(block)
    return digest.hexdigest()
def _document_content_changed(existing_row: sqlite3.Row | None, new_content: str | None) -> bool:
    """True when *new_content* differs from the stored (normalized) document content."""
    if existing_row is None:
        # No stored row: only a non-None new value counts as a change.
        return new_content is not None
    return _normalize_optional(existing_row["content"]) != new_content
def build_parser() -> argparse.ArgumentParser:
    """Build the CLI parser with upsert / batch_upsert / search subcommands."""
    parser = argparse.ArgumentParser(description="litlake-core headless reference manager")
    subparsers = parser.add_subparsers(dest="command", required=True)

    upsert = subparsers.add_parser("upsert", help="Upsert a single reference + PDF")
    upsert.add_argument("--db", required=True, dest="db_path")
    upsert.add_argument("--source-system", required=True)
    upsert.add_argument("--source-id", required=True)
    upsert.add_argument("--title", required=True)
    # authors/year/abstract are optional in UpsertReferenceInput (default None)
    # and are normalized by _normalize_optional, so the CLI no longer forces
    # callers to supply them. Existing invocations that pass them still work.
    upsert.add_argument("--authors", default=None)
    upsert.add_argument("--year", default=None)
    upsert.add_argument("--abstract", default=None)
    upsert.add_argument("--pdf", required=True, dest="pdf_path")

    batch = subparsers.add_parser("batch_upsert", help="Upsert a batch of references in jsonl format")
    batch.add_argument("--db", required=True, dest="db_path")
    batch.add_argument("--jsonl", required=True, dest="jsonl_path")

    search_parser = subparsers.add_parser("search", help="Semantic search by query string")
    search_parser.add_argument("--db", required=True, dest="db_path")
    search_parser.add_argument("query", help="Search query (embedded and matched against document embeddings)")
    search_parser.add_argument("-k", "--top-k", type=int, default=10, dest="top_k", help="Number of results (default: 10)")
    return parser
def main(argv: Sequence[str] | None = None) -> int:
    """CLI entry point; parses *argv* (or sys.argv) and returns an exit code."""
    from docling.datamodel.base_models import InputFormat
    from docling.datamodel.pipeline_options import PdfPipelineOptions
    from docling.document_converter import DocumentConverter, PdfFormatOption

    def make_converter() -> DocumentConverter:
        # Shared Docling setup for both upsert commands; model artifacts come
        # from CACHE_DIR so nothing is downloaded into the working directory.
        pipeline_options = PdfPipelineOptions(artifacts_path=str(CACHE_DIR))
        return DocumentConverter(
            format_options={
                InputFormat.PDF: PdfFormatOption(pipeline_options=pipeline_options)
            }
        )

    def emit(result: UpsertReferenceResult) -> None:
        # One JSON object per line on stdout.
        json.dump(result.to_dict(), sys.stdout, ensure_ascii=False)
        sys.stdout.write("\n")

    parser = build_parser()
    args = parser.parse_args(list(argv) if argv is not None else None)
    if args.command == "upsert":
        payload = UpsertReferenceInput(
            db_path=args.db_path,
            source_system=args.source_system,
            source_id=args.source_id,
            title=args.title,
            authors=args.authors,
            year=args.year,
            abstract=args.abstract,
            pdf_path=args.pdf_path,
        )
        emit(upsert_reference(payload, make_converter(), HybridChunker()))
        return 0
    if args.command == "batch_upsert":
        # Validate the input path before paying for converter construction.
        if not os.path.exists(args.jsonl_path):
            parser.error(f"Path must exist: {args.jsonl_path}")
        converter = make_converter()
        chunker = HybridChunker()
        with open(args.jsonl_path, "r", encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                reference_args = json.loads(line)
                # Skip records whose PDF is missing (or whose pdf_path key is
                # absent) instead of aborting the whole batch.
                if not os.path.exists(reference_args.get("pdf_path", "")):
                    continue
                payload = UpsertReferenceInput(
                    db_path=args.db_path,
                    source_system=reference_args["source_system"],
                    source_id=reference_args["source_id"],
                    title=reference_args["title"],
                    # Optional metadata may legitimately be absent from a record;
                    # .get avoids the KeyError the old subscripting raised.
                    authors=reference_args.get("authors"),
                    year=reference_args.get("year"),
                    abstract=reference_args.get("abstract"),
                    pdf_path=reference_args["pdf_path"],
                )
                emit(upsert_reference(payload, converter, chunker))
        return 0
    if args.command == "search":
        conn = connect_db(args.db_path, read_only=True)
        try:
            for doc_id, content, kind, reference_id, distance in search_by_string(conn, args.query, args.top_k):
                out = {
                    "doc_id": doc_id,
                    "reference_id": reference_id,
                    "kind": kind,
                    "distance": distance,
                    # Truncate long chunks so search output stays readable.
                    "content": content[:500] + ("..." if len(content) > 500 else ""),
                }
                json.dump(out, sys.stdout, ensure_ascii=False)
                sys.stdout.write("\n")
        finally:
            conn.close()
        return 0
    parser.error(f"Unknown command: {args.command}")
    return 2
if __name__ == "__main__":
    # Propagate main()'s return value to the shell as the process exit code.
    raise SystemExit(main())
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| """Generate per-input JSONL files from PubMed XML with optional filters. | |
| Each input XML/XML.GZ becomes one output JSONL in --output-dir. | |
| If that JSONL already exists, the XML file is skipped. | |
| """ | |
| from __future__ import annotations | |
| import argparse | |
| import gzip | |
| import json | |
| import re | |
| import sys | |
| import xml.etree.ElementTree as ET | |
| from pathlib import Path | |
| from typing import Iterable, Iterator, Optional | |
# A PMC identifier with or without its "PMC" prefix; group(1) is the digits.
PMCID_TOKEN_RE = re.compile(r"^(?:PMC)?(\d+)$", re.IGNORECASE)
DIGITS_ONLY_RE = re.compile(r"^\d+$")
# Plausible publication years: 1800-2099 plus the explicit upper bound 2100.
YEAR_RE = re.compile(r"\b(1[89]\d{2}|20\d{2}|2100)\b")
# Output write buffer size in bytes.
WRITE_BUFFER_SIZE = 1024 * 1024
def open_text(path: Path):
    """Open *path* for text reading, transparently decompressing .gz files.

    Undecodable bytes are replaced rather than raised so malformed input
    does not abort processing.
    """
    opener = gzip.open if path.suffix.lower() == ".gz" else Path.open
    return opener(path, "rt", encoding="utf-8", errors="replace")
def collapse_ws(text: str) -> str:
    """Collapse every whitespace run to a single space and strip the ends."""
    tokens = text.split()
    return " ".join(tokens)
def element_text(elem: Optional[ET.Element]) -> str:
    """All text under *elem* (descendants included), whitespace-normalized; '' for None."""
    if elem is None:
        return ""
    joined = "".join(elem.itertext())
    return collapse_ws(joined)
def normalize_pmcid(value: str) -> Optional[str]:
    """Canonical 'PMC<digits>' form of *value* (leading zeros dropped), or None if unparsable."""
    number = pmcid_to_int(value)
    return None if number is None else f"PMC{number}"
def pmcid_to_int(value: str) -> Optional[int]:
    """Numeric part of a PMCID token ('PMC123' or '123') as int, or None if unparsable."""
    match = PMCID_TOKEN_RE.match(value.strip())
    return int(match.group(1)) if match else None
def extract_id_token(line: str, id_type: str) -> Optional[int]:
    """First parsable id in *line* for *id_type* ('pmcid' or 'pmid'), or None.

    The line is split on whitespace/commas/semicolons/pipes and each token is
    stripped of surrounding quotes before parsing.
    """
    tokens = (t.strip().strip('"\'') for t in re.split(r"[\s,;|]+", line.strip()))
    for token in tokens:
        if not token:
            continue
        if id_type == "pmcid":
            value = pmcid_to_int(token)
            if value is not None:
                return value
        elif id_type == "pmid" and DIGITS_ONLY_RE.match(token):
            return int(token)
    return None
def load_id_filter(id_file: Path, id_type: str) -> set[int]:
    """Read *id_file* (plain or .gz) into a set of numeric ids; '#' lines are comments."""
    ids: set[int] = set()
    with open_text(id_file) as handle:
        for raw in handle:
            stripped = raw.strip()
            if not stripped or stripped.startswith("#"):
                continue
            parsed = extract_id_token(stripped, id_type)
            if parsed is not None:
                ids.add(parsed)
    return ids
def iter_input_files_from_dir(input_dir: Path, patterns: list[str]) -> Iterator[Path]:
    """Yield files under *input_dir* matching *patterns*, resolved and deduplicated.

    When *input_dir* is itself a file it is yielded as-is (unresolved) and
    the patterns are ignored.
    """
    if input_dir.is_file():
        yield input_dir
        return
    seen: set[Path] = set()
    for pattern in patterns:
        for candidate in sorted(input_dir.glob(pattern)):
            if not candidate.is_file():
                continue
            resolved = candidate.resolve()
            if resolved not in seen:
                seen.add(resolved)
                yield resolved
def iter_input_files_from_list(file_list: Path) -> Iterator[Path]:
    """Yield existing files named in *file_list* (one path per line, '#' comments).

    Relative entries are resolved against the list file's directory; missing
    files are reported on stderr and skipped.
    """
    with open_text(file_list) as handle:
        for raw_line in handle:
            entry = raw_line.strip()
            if not entry or entry.startswith("#"):
                continue
            candidate = Path(entry).expanduser()
            if not candidate.is_absolute():
                candidate = (file_list.parent / candidate).resolve()
            if candidate.is_file():
                yield candidate
            else:
                print(f"WARNING: listed file not found, skipping: {entry}", file=sys.stderr)
def extract_abstract(article: ET.Element) -> str:
    """Abstract text for a PubmedArticle, sections joined by blank lines.

    Structured Abstract sections are used first (with their Label prefixed
    as 'LABEL: text' unless the text already starts with it); OtherAbstract
    is the fallback when no structured sections exist.
    """
    sections: list[str] = []
    for node in article.findall("./MedlineCitation/Article/Abstract/AbstractText"):
        text = element_text(node)
        if not text:
            continue
        label = collapse_ws((node.attrib.get("Label") or "").strip())
        if label and not text.lower().startswith(label.lower()):
            text = f"{label}: {text}"
        sections.append(text)
    if not sections:
        sections = [
            text
            for node in article.findall("./MedlineCitation/OtherAbstract/AbstractText")
            if (text := element_text(node))
        ]
    return "\n\n".join(sections)
def extract_publication_year(article: ET.Element) -> Optional[int]:
    """Best-available publication year (1800-2100) for the article, or None.

    Structured Year elements are checked in preference order; the free-text
    MedlineDate is scanned as a last resort.
    """
    candidate_paths = (
        "./MedlineCitation/Article/Journal/JournalIssue/PubDate/Year",
        "./MedlineCitation/Article/ArticleDate/Year",
        "./MedlineCitation/DateCompleted/Year",
        "./MedlineCitation/DateCreated/Year",
    )
    for path in candidate_paths:
        raw = (article.findtext(path) or "").strip()
        # Out-of-range values fall through to the next candidate.
        if raw and DIGITS_ONLY_RE.match(raw) and 1800 <= int(raw) <= 2100:
            return int(raw)
    medline_date = (
        article.findtext("./MedlineCitation/Article/Journal/JournalIssue/PubDate/MedlineDate")
        or ""
    ).strip()
    match = YEAR_RE.search(medline_date) if medline_date else None
    return int(match.group(1)) if match else None
def extract_authors(article: ET.Element) -> list[str]:
    """Display names for the article's authors, in document order.

    Collective names are used verbatim; personal names prefer
    'ForeName LastName', then 'Initials LastName', then whichever part exists.
    """
    authors: list[str] = []
    for node in article.findall("./MedlineCitation/Article/AuthorList/Author"):
        collective = element_text(node.find("./CollectiveName"))
        if collective:
            authors.append(collective)
            continue
        last = element_text(node.find("./LastName"))
        fore = element_text(node.find("./ForeName"))
        initials = element_text(node.find("./Initials"))
        if last and fore:
            name = f"{fore} {last}"
        elif last and initials:
            name = f"{initials} {last}"
        else:
            # Fall back to whichever single component is present, if any.
            name = last or fore
        if name:
            authors.append(name)
    return authors
def iter_pubmed_records(xml_path: Path) -> Iterator[dict[str, object]]:
    """Stream PubmedArticle records from *xml_path* as plain dicts.

    Uses iterparse so large dumps are never fully materialized: every
    processed article element is cleared, and the root is cleared after each
    yield so already-consumed siblings are released too.
    """
    with open_text(xml_path) as handle:
        context = ET.iterparse(handle, events=("start", "end"))
        # First event is the start of the root element; keep a handle for clearing.
        _, root = next(context)
        for event, elem in context:
            if event != "end" or elem.tag != "PubmedArticle":
                continue
            pmid = (elem.findtext("./MedlineCitation/PMID") or "").strip()
            if not pmid:
                # Unidentifiable record: discard it and release its memory.
                elem.clear()
                continue
            title = element_text(elem.find("./MedlineCitation/Article/ArticleTitle"))
            abstract = extract_abstract(elem)
            publication_year = extract_publication_year(elem)
            authors = extract_authors(elem)
            pmcid = ""
            # Take the first non-empty PMC id from the article id list.
            for node in elem.findall("./PubmedData/ArticleIdList/ArticleId"):
                if (node.attrib.get("IdType") or "").lower() == "pmc":
                    pmcid = normalize_pmcid(element_text(node)) or ""
                    if pmcid:
                        break
            yield {
                # Numeric PMIDs are normalized (leading zeros stripped).
                "pmid": str(int(pmid)) if DIGITS_ONLY_RE.match(pmid) else pmid,
                "pmcid": pmcid,
                "publication_year": publication_year,
                "authors": authors,
                "title": title,
                "abstract": abstract,
            }
            elem.clear()
            root.clear()
def skip_reason(
    record: dict[str, object],
    require_pmcid: bool,
    filter_ids: Optional[set[int]],
    id_type: str,
    allow_empty_abstract: bool,
) -> Optional[str]:
    """Why *record* should be skipped ('no_pmcid' / 'not_in_list' / 'no_abstract'), or None to keep it."""
    pmcid = str(record.get("pmcid") or "")
    pmid = str(record.get("pmid") or "")
    abstract = str(record.get("abstract") or "")
    if require_pmcid and not pmcid:
        return "no_pmcid"
    if filter_ids is not None:
        if id_type == "pmcid":
            lookup = pmcid_to_int(pmcid) if pmcid else None
        elif DIGITS_ONLY_RE.match(pmid):
            lookup = int(pmid)
        else:
            lookup = None
        if lookup is None or lookup not in filter_ids:
            return "not_in_list"
    if not abstract and not allow_empty_abstract:
        return "no_abstract"
    return None
def output_path_for_input(xml_file: Path, output_dir: Path) -> Path:
    """Derive the JSONL output path for one input XML/XML.GZ file.

    ``pubmed26n0001.xml.gz`` maps to ``<output_dir>/pubmed26n0001.jsonl``;
    a bare ``.xml`` loses just that suffix; any other name falls back to
    ``Path.stem``.
    """
    name = xml_file.name
    # Check the longer suffix first so ".xml.gz" is not mistaken for ".xml".
    for suffix in (".xml.gz", ".xml"):
        if name.endswith(suffix):
            base = name.removesuffix(suffix)
            break
    else:
        base = xml_file.stem
    return output_dir / (base + ".jsonl")
def parse_args(argv: Optional[Iterable[str]] = None) -> argparse.Namespace:
    """Build the command-line parser and parse *argv*.

    Passing ``None`` lets argparse read ``sys.argv[1:]`` as usual.
    """
    cli = argparse.ArgumentParser(
        description="Extract title/abstract/year/authors from PubMed XML into per-file JSONL outputs."
    )
    cli.add_argument("--input-dir", type=Path, help="Input XML directory or single XML/XML.GZ file")
    cli.add_argument("--input-file-list", type=Path, help="Text file with XML/XML.GZ paths (one per line)")
    cli.add_argument(
        "--input-glob",
        action="append",
        dest="input_globs",
        help=(
            "Glob pattern(s) relative to --input-dir, repeatable. "
            "Examples: '*.xml.gz', 'pubmed26n*.xml.gz', '**/*.xml.gz'. "
            "Defaults to '*.xml.gz' and '*.xml'"
        ),
    )
    cli.add_argument("--output-dir", type=Path, required=True, help="Output directory for per-file JSONL")
    cli.add_argument("--pmcid-list", type=Path, help="Optional text file of PMCID/PMID values to keep")
    cli.add_argument("--require-pmcid", action="store_true", help="Keep only records with PMCID")
    cli.add_argument(
        "--id-type",
        choices=("pmcid", "pmid"),
        default="pmcid",
        help="How to interpret --pmcid-list values (default: pmcid)",
    )
    cli.add_argument("--allow-empty-abstract", action="store_true", help="Keep records even when abstract is missing")
    cli.add_argument(
        "--progress-every",
        type=int,
        default=100_000,
        help="Print global progress every N scanned records (default: 100000)",
    )
    return cli.parse_args(argv)
def main(argv: Optional[Iterable[str]] = None) -> int:
    """CLI entry point: extract PubMed records from XML into per-file JSONL.

    Args:
        argv: Optional argument list (defaults to ``sys.argv[1:]`` via argparse).

    Returns:
        Process exit code: 0 on success, 1 on argument/input errors.

    All progress and summary output goes to stderr; record data is written
    only to the per-input JSONL files under ``--output-dir``.
    """
    args = parse_args(argv)
    # --- Validate inputs: need at least one source of XML files, and any
    # explicitly named paths must exist before work starts.
    if not args.input_dir and not args.input_file_list:
        print("ERROR: provide at least one of --input-dir or --input-file-list", file=sys.stderr)
        return 1
    if args.input_dir and not args.input_dir.exists():
        print(f"ERROR: input path not found: {args.input_dir}", file=sys.stderr)
        return 1
    if args.input_file_list and not args.input_file_list.exists():
        print(f"ERROR: input file list not found: {args.input_file_list}", file=sys.stderr)
        return 1
    args.output_dir.mkdir(parents=True, exist_ok=True)
    # Optional ID allow-list: records whose ID is absent from this set are
    # skipped via skip_reason ("not_in_list").
    filter_ids: Optional[set[int]] = None
    if args.pmcid_list:
        if not args.pmcid_list.exists():
            print(f"ERROR: --pmcid-list file not found: {args.pmcid_list}", file=sys.stderr)
            return 1
        print(f"Loading ID filter from {args.pmcid_list} ...", file=sys.stderr)
        filter_ids = load_id_filter(args.pmcid_list, args.id_type)
        print(f"Loaded {len(filter_ids):,} IDs", file=sys.stderr)
    # --- Collect input files from both sources, de-duplicating by resolved
    # path while preserving first-seen order (file list first, then dir).
    seen_inputs: set[Path] = set()
    input_files: list[Path] = []
    if args.input_file_list:
        for p in iter_input_files_from_list(args.input_file_list.resolve()):
            rp = p.resolve()
            if rp not in seen_inputs:
                seen_inputs.add(rp)
                input_files.append(rp)
    if args.input_dir:
        patterns = args.input_globs or ["*.xml.gz", "*.xml"]
        for p in iter_input_files_from_dir(args.input_dir.resolve(), patterns):
            rp = p.resolve()
            if rp not in seen_inputs:
                seen_inputs.add(rp)
                input_files.append(rp)
    if not input_files:
        print("ERROR: no input XML files found", file=sys.stderr)
        return 1
    # --- Global counters across all input files.
    scanned = 0
    written = 0
    skipped_no_pmcid = 0
    skipped_not_in_list = 0
    skipped_empty_abstract = 0
    skipped_existing_files = 0
    processed_files = 0
    for xml_file in input_files:
        output_path = output_path_for_input(xml_file, args.output_dir)
        # Resume-friendly: an existing output file marks this input as done.
        # NOTE(review): a partial file from an interrupted run would also be
        # skipped — confirm that is acceptable for this pipeline.
        if output_path.exists():
            skipped_existing_files += 1
            print(f"Skipping existing output: {output_path}", file=sys.stderr)
            continue
        print(f"Processing {xml_file}", file=sys.stderr)
        # Per-file counters, reported in the FileSummary line below.
        file_scanned = 0
        file_written = 0
        file_skipped_no_pmcid = 0
        file_skipped_not_in_list = 0
        file_skipped_empty_abstract = 0
        with output_path.open("wt", encoding="utf-8", buffering=WRITE_BUFFER_SIZE) as out:
            for record in iter_pubmed_records(xml_file):
                scanned += 1
                file_scanned += 1
                # Periodic global progress line; disabled when --progress-every <= 0.
                if args.progress_every > 0 and scanned % args.progress_every == 0:
                    print(
                        f"Scanned={scanned:,} Written={written:,} "
                        f"NoPMCID={skipped_no_pmcid:,} NotInList={skipped_not_in_list:,} "
                        f"NoAbstract={skipped_empty_abstract:,}",
                        file=sys.stderr,
                    )
                reason = skip_reason(
                    record,
                    args.require_pmcid,
                    filter_ids,
                    args.id_type,
                    args.allow_empty_abstract,
                )
                if reason == "no_pmcid":
                    skipped_no_pmcid += 1
                    file_skipped_no_pmcid += 1
                    continue
                if reason == "not_in_list":
                    skipped_not_in_list += 1
                    file_skipped_not_in_list += 1
                    continue
                if reason == "no_abstract":
                    skipped_empty_abstract += 1
                    file_skipped_empty_abstract += 1
                    continue
                # Every record ends with '\n' so each JSONL is safe to concatenate via `cat`.
                out.write(json.dumps(record, ensure_ascii=False) + "\n")
                written += 1
                file_written += 1
        processed_files += 1
        file_success_rate = (file_written / file_scanned * 100.0) if file_scanned else 0.0
        print(
            "FileSummary "
            f"File={xml_file} "
            f"Output={output_path} "
            f"Scanned={file_scanned:,} Written={file_written:,} "
            f"SuccessRate={file_success_rate:.2f}% "
            f"NoPMCID={file_skipped_no_pmcid:,} "
            f"NotInList={file_skipped_not_in_list:,} "
            f"NoAbstract={file_skipped_empty_abstract:,}",
            file=sys.stderr,
        )
    # --- Final run-wide summaries.
    print("Finished", file=sys.stderr)
    print(
        f"FilesTotal={len(input_files):,} FilesProcessed={processed_files:,} "
        f"FilesSkippedExisting={skipped_existing_files:,}",
        file=sys.stderr,
    )
    print(
        f"Scanned={scanned:,} Written={written:,} "
        f"NoPMCID={skipped_no_pmcid:,} NotInList={skipped_not_in_list:,} "
        f"NoAbstract={skipped_empty_abstract:,}",
        file=sys.stderr,
    )
    return 0
if __name__ == "__main__":
    # Script entry point: exit the process with main()'s return code.
    # sys.exit raises SystemExit, identical to `raise SystemExit(main())`.
    sys.exit(main())
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment