@ElliotRoe
Created March 31, 2026 12:02
Lit Lake Core Files
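"""Headless reference manager ("litlake-core").

Stores reference metadata, PDF-derived text chunks, FTS5 rows, and
sqlite-vec embeddings in a single SQLite database. PDFs are converted and
chunked with Docling's HybridChunker and embedded with fastembed
(BAAI/bge-small-en-v1.5, 384-dimensional vectors). The CLI defined in
main() exposes `upsert`, `batch_upsert`, and `search` subcommands.
"""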
from __future__ import annotations
import argparse
import hashlib
import json
import sqlite3
import struct
import os
import sys
from dataclasses import asdict, dataclass
from pathlib import Path
from typing import Any, Sequence
from docling.document_converter import DocumentConverter
from docling_core.transforms.chunker.hybrid_chunker import HybridChunker
@dataclass(frozen=True)
class UpsertReferenceInput:
db_path: str | Path
source_system: str
source_id: str
title: str
authors: str | None = None
year: str | None = None
abstract: str | None = None
pdf_path: str | Path = ""
@dataclass(frozen=True)
class UpsertReferenceResult:
reference_id: int
document_file_id: int
pdf_changed: bool
chunks_written: int
embeddings_written: int
sha256: str
def to_dict(self) -> dict[str, object]:
return asdict(self)
@dataclass(frozen=True)
class ChunkArtifact:
content: str
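# Embedding configuration. DEFAULT_DIM (384) matches the float[384] column
# declared for the vec_documents virtual table in _init_schema below.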
DEFAULT_MODEL = "BAAI/bge-small-en-v1.5"
DEFAULT_DIM = 384
CACHE_DIR = Path("~/.cache/docling/models").expanduser().resolve()
def connect_db(path: Path | str, *, read_only: bool = False) -> sqlite3.Connection:
if read_only:
conn = sqlite3.connect(f"file:{Path(path)}?mode=ro", uri=True, check_same_thread=False)
else:
conn = sqlite3.connect(str(path), check_same_thread=False)
conn.row_factory = sqlite3.Row
conn.execute("PRAGMA journal_mode=WAL;")
conn.execute("PRAGMA busy_timeout=5000;")
conn.execute("PRAGMA foreign_keys=ON;")
_load_sqlite_vec(conn)
return conn
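# sqlite-vec ships as a loadable extension: enable extension loading just
# long enough to register it, then disable it again.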
def _load_sqlite_vec(conn: sqlite3.Connection) -> None:
import sqlite_vec
conn.enable_load_extension(True)
try:
sqlite_vec.load(conn)
finally:
conn.enable_load_extension(False)
def init_db(db_path: str | Path) -> None:
conn = connect_db(Path(db_path).expanduser().resolve())
try:
_init_schema(conn)
finally:
conn.close()
def _init_schema(conn: sqlite3.Connection) -> None:
conn.execute(
"""
CREATE TABLE IF NOT EXISTS reference_items (
id INTEGER PRIMARY KEY AUTOINCREMENT,
title TEXT,
authors TEXT,
year TEXT,
source_system TEXT NOT NULL,
source_id TEXT NOT NULL,
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
updated_at DATETIME DEFAULT CURRENT_TIMESTAMP,
UNIQUE(source_system, source_id)
)
"""
)
conn.execute(
"""
CREATE TABLE IF NOT EXISTS document_files (
id INTEGER PRIMARY KEY AUTOINCREMENT,
reference_id INTEGER NOT NULL,
file_path TEXT NOT NULL,
mime_type TEXT NOT NULL DEFAULT 'application/pdf',
file_hash_sha256 TEXT NOT NULL,
extracted_text TEXT,
source_system TEXT NOT NULL,
source_id TEXT NOT NULL,
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
updated_at DATETIME DEFAULT CURRENT_TIMESTAMP,
FOREIGN KEY(reference_id) REFERENCES reference_items(id) ON DELETE CASCADE,
UNIQUE(source_system, source_id)
)
"""
)
conn.execute(
"""
CREATE TABLE IF NOT EXISTS documents (
id INTEGER PRIMARY KEY AUTOINCREMENT,
reference_id INTEGER NOT NULL,
document_file_id INTEGER,
kind TEXT NOT NULL,
content TEXT,
chunk_index INTEGER,
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
updated_at DATETIME DEFAULT CURRENT_TIMESTAMP,
FOREIGN KEY(reference_id) REFERENCES reference_items(id) ON DELETE CASCADE,
FOREIGN KEY(document_file_id) REFERENCES document_files(id) ON DELETE CASCADE
)
"""
)
conn.execute(
"""
CREATE INDEX IF NOT EXISTS idx_documents_reference_id
ON documents(reference_id)
"""
)
conn.execute(
"""
CREATE INDEX IF NOT EXISTS idx_documents_document_file_id
ON documents(document_file_id)
"""
)
conn.execute(
"""
CREATE INDEX IF NOT EXISTS idx_documents_kind
ON documents(kind)
"""
)
conn.execute(
"""
CREATE VIRTUAL TABLE IF NOT EXISTS documents_fts USING fts5(
content,
kind UNINDEXED,
reference_id UNINDEXED,
tokenize='unicode61'
)
"""
)
conn.execute(
"CREATE VIRTUAL TABLE IF NOT EXISTS vec_documents USING vec0(embedding float[384] distance_metric=cosine);"
)
conn.commit()
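# The FTS5 and vec0 rows mirror the documents table and share its rowid
# (documents.id). Upserts are implemented as delete-then-insert so stale
# index entries never outlive their document row.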
def delete_fts_rows(conn: sqlite3.Connection, doc_ids: list[int]) -> None:
for doc_id in doc_ids:
conn.execute("DELETE FROM documents_fts WHERE rowid = ?", (doc_id,))
def upsert_fts_row(
conn: sqlite3.Connection,
*,
doc_id: int,
content: str,
kind: str,
reference_id: int,
) -> None:
conn.execute("DELETE FROM documents_fts WHERE rowid = ?", (doc_id,))
conn.execute(
"""
INSERT INTO documents_fts(rowid, content, kind, reference_id)
VALUES (?, ?, ?, ?)
""",
(doc_id, content, kind, reference_id),
)
def search_by_vec(
conn: sqlite3.Connection,
query_vector: list[float],
k: int,
) -> list[tuple[int, str, str, int, float]]:
"""Run KNN over vec_documents and return (doc_id, content, kind, reference_id, distance)."""
if not query_vector or k <= 0:
return []
vector_blob = sqlite3.Binary(serialize_vector(query_vector))
rows = conn.execute(
"""
SELECT v.rowid AS doc_id, v.distance,
d.content, d.kind, d.reference_id
FROM vec_documents v
JOIN documents d ON d.id = v.rowid
WHERE v.embedding MATCH ? AND k = ?
""",
(vector_blob, k),
).fetchall()
return [
(int(r["doc_id"]), r["content"] or "", r["kind"], int(r["reference_id"]), float(r["distance"]))
for r in rows
]
def search_by_string(
conn: sqlite3.Connection,
query: str,
k: int,
) -> list[tuple[int, str, str, int, float]]:
"""Embed the query string and run semantic search. Returns (doc_id, content, kind, reference_id, distance)."""
query = (query or "").strip()
if not query:
return []
vectors = embed_texts([query])
if not vectors:
return []
return search_by_vec(conn, vectors[0], k)
def delete_vec_rows(conn: sqlite3.Connection, doc_ids: list[int]) -> None:
for doc_id in doc_ids:
conn.execute("DELETE FROM vec_documents WHERE rowid = ?", (doc_id,))
def upsert_vec_row(conn: sqlite3.Connection, *, doc_id: int, vector_bytes: bytes) -> None:
conn.execute("DELETE FROM vec_documents WHERE rowid = ?", (doc_id,))
conn.execute(
"INSERT INTO vec_documents(rowid, embedding) VALUES (?, ?)",
(doc_id, sqlite3.Binary(vector_bytes)),
)
def delete_documents(conn: sqlite3.Connection, doc_ids: list[int]) -> None:
if not doc_ids:
return
delete_fts_rows(conn, doc_ids)
delete_vec_rows(conn, doc_ids)
placeholders = ",".join("?" for _ in doc_ids)
conn.execute(f"DELETE FROM documents WHERE id IN ({placeholders})", doc_ids)
def ingest_pdf_chunks(pdf_path: Path, converter: DocumentConverter, chunker: HybridChunker) -> list[ChunkArtifact]:
result = converter.convert(str(pdf_path))
document = getattr(result, "document", result)
raw_chunks = list(chunker.chunk(document))
artifacts: list[ChunkArtifact] = []
for chunk in raw_chunks:
content = _chunk_content(chunker, chunk).strip()
if content:
artifacts.append(ChunkArtifact(content=content))
if not artifacts:
raise ValueError(f"Docling produced no chunks for {pdf_path}")
return artifacts
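# _chunk_content probes the several shapes a chunk object may take (a .text
# attribute or callable, a chunker.serialize() hook, a .content attribute)
# and falls back to str(chunk) if none of them yields a string.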
def _chunk_content(chunker: Any, chunk: Any) -> str:
text = getattr(chunk, "text", None)
if isinstance(text, str):
return text
if callable(text):
out = text()
if isinstance(out, str):
return out
serialize = getattr(chunker, "serialize", None)
if callable(serialize):
out = serialize(chunk)
if isinstance(out, str):
return out
content = getattr(chunk, "content", None)
if isinstance(content, str):
return content
return str(chunk)
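# fastembed is imported lazily inside embed_texts so that code paths that
# never embed (for example schema initialization) do not pay the model
# download and load cost.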
def embed_texts(
texts: list[str]
) -> list[list[float]]:
if not texts:
return []
from fastembed import TextEmbedding
model_cache = CACHE_DIR / "FastEmbed"
model = TextEmbedding(model_name=DEFAULT_MODEL, cache_dir=str(model_cache))
vectors = list(model.embed(texts))
out = [list(vector) for vector in vectors]
return out
def serialize_vector(vector: list[float]) -> bytes:
return struct.pack(f"{len(vector)}f", *vector)
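# upsert_reference is idempotent: it compares the incoming title, abstract,
# and PDF hash against what is already stored, embeds only the text that
# actually changed, and applies all writes inside a single BEGIN IMMEDIATE
# transaction so a failure rolls back cleanly.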
def upsert_reference(
payload: UpsertReferenceInput,
converter,
chunker
) -> UpsertReferenceResult:
db_path = Path(payload.db_path).expanduser().resolve()
pdf_path = Path(payload.pdf_path).expanduser().resolve()
_validate_input(payload, pdf_path)
sha256 = _sha256_file(pdf_path)
title = _normalize_required(payload.title, field_name="title")
authors = _normalize_optional(payload.authors)
year = _normalize_optional(payload.year)
abstract = _normalize_optional(payload.abstract)
conn = connect_db(db_path)
try:
_init_schema(conn)
existing_ref = conn.execute(
"""
SELECT id, title, authors, year
FROM reference_items
WHERE source_system = ? AND source_id = ?
""",
(payload.source_system, payload.source_id),
).fetchone()
reference_id = int(existing_ref["id"]) if existing_ref is not None else None
existing_file = conn.execute(
"""
SELECT id, reference_id, file_path, file_hash_sha256, extracted_text
FROM document_files
WHERE source_system = ? AND source_id = ?
""",
(payload.source_system, payload.source_id),
).fetchone()
metadata_docs = {}
if reference_id is not None:
rows = conn.execute(
"""
SELECT id, kind, content
FROM documents
WHERE reference_id = ?
AND kind IN ('title','abstract')
AND document_file_id IS NULL
""",
(reference_id,),
).fetchall()
metadata_docs = {str(row["kind"]): row for row in rows}
title_changed = _document_content_changed(metadata_docs.get("title"), title)
abstract_changed = _document_content_changed(metadata_docs.get("abstract"), abstract)
old_hash = str(existing_file["file_hash_sha256"]) if existing_file is not None else None
pdf_changed = old_hash != sha256
chunk_artifacts = []
extracted_text = None
if pdf_changed:
chunk_artifacts = ingest_pdf_chunks(pdf_path, converter, chunker)
extracted_text = "\n\n".join(chunk.content for chunk in chunk_artifacts)
embedding_keys: list[tuple[str, int | None]] = []
embedding_texts: list[str] = []
if title_changed and title:
embedding_keys.append(("title", None))
embedding_texts.append(title)
if abstract_changed and abstract:
embedding_keys.append(("abstract", None))
embedding_texts.append(abstract)
if pdf_changed:
for idx, chunk in enumerate(chunk_artifacts):
embedding_keys.append(("fulltext_chunk", idx))
embedding_texts.append(chunk.content)
embedded_vectors = embed_texts(embedding_texts)
if len(embedded_vectors) != len(embedding_keys):
raise ValueError("Embedding provider returned unexpected vector count")
vectors_by_key = {
key: embedded_vectors[idx]
for idx, key in enumerate(embedding_keys)
}
conn.execute("BEGIN IMMEDIATE")
try:
reference_id = _upsert_reference_row(
conn,
existing_ref=existing_ref,
source_system=payload.source_system,
source_id=payload.source_id,
title=title,
authors=authors,
year=year,
)
file_id = _upsert_file_row(
conn,
existing_file=existing_file,
reference_id=reference_id,
source_system=payload.source_system,
source_id=payload.source_id,
pdf_path=pdf_path,
sha256=sha256,
extracted_text=extracted_text,
pdf_changed=pdf_changed,
)
embeddings_written = 0
title_doc_id = _apply_metadata_document(
conn,
reference_id=reference_id,
kind="title",
content=title,
existing_row=metadata_docs.get("title"),
)
if title_changed and title_doc_id is not None:
vector = vectors_by_key[("title", None)]
_index_document(
conn,
doc_id=title_doc_id,
reference_id=reference_id,
kind="title",
content=title,
vector=vector,
)
embeddings_written += 1
abstract_doc_id = _apply_metadata_document(
conn,
reference_id=reference_id,
kind="abstract",
content=abstract,
existing_row=metadata_docs.get("abstract"),
)
if abstract_changed and abstract_doc_id is not None and abstract is not None:
vector = vectors_by_key[("abstract", None)]
_index_document(
conn,
doc_id=abstract_doc_id,
reference_id=reference_id,
kind="abstract",
content=abstract,
vector=vector,
)
embeddings_written += 1
chunks_written = 0
if pdf_changed:
old_chunk_rows = conn.execute(
"""
SELECT id
FROM documents
WHERE document_file_id = ?
AND kind = 'fulltext_chunk'
ORDER BY chunk_index ASC
""",
(file_id,),
).fetchall()
delete_documents(conn, [int(row["id"]) for row in old_chunk_rows])
for idx, chunk in enumerate(chunk_artifacts):
doc_id = _insert_document(
conn,
reference_id=reference_id,
document_file_id=file_id,
kind="fulltext_chunk",
content=chunk.content,
chunk_index=idx,
)
vector = vectors_by_key[("fulltext_chunk", idx)]
_index_document(
conn,
doc_id=doc_id,
reference_id=reference_id,
kind="fulltext_chunk",
content=chunk.content,
vector=vector,
)
embeddings_written += 1
chunks_written += 1
conn.commit()
except Exception:
conn.rollback()
raise
return UpsertReferenceResult(
reference_id=reference_id,
document_file_id=file_id,
pdf_changed=pdf_changed,
chunks_written=chunks_written,
embeddings_written=embeddings_written,
sha256=sha256,
)
finally:
conn.close()
def _validate_input(payload: UpsertReferenceInput, pdf_path: Path) -> None:
if not _normalize_required(payload.source_system, field_name="source_system"):
raise ValueError("source_system is required")
if not _normalize_required(payload.source_id, field_name="source_id"):
raise ValueError("source_id is required")
if not pdf_path.exists() or not pdf_path.is_file():
raise FileNotFoundError(f"PDF path not found: {pdf_path}")
def _upsert_reference_row(
conn: sqlite3.Connection,
*,
existing_ref: sqlite3.Row | None,
source_system: str,
source_id: str,
title: str,
authors: str | None,
year: str | None,
) -> int:
if existing_ref is None:
cur = conn.execute(
"""
INSERT INTO reference_items(
source_system, source_id, title, authors, year
) VALUES (?, ?, ?, ?, ?)
""",
(source_system, source_id, title, authors, year),
)
return int(cur.lastrowid)
changed = (
_normalize_optional(existing_ref["title"]) != title
or _normalize_optional(existing_ref["authors"]) != authors
or _normalize_optional(existing_ref["year"]) != year
)
if changed:
conn.execute(
"""
UPDATE reference_items
SET title = ?, authors = ?, year = ?, updated_at = CURRENT_TIMESTAMP
WHERE id = ?
""",
(title, authors, year, int(existing_ref["id"])),
)
return int(existing_ref["id"])
def _upsert_file_row(
conn: sqlite3.Connection,
*,
existing_file: sqlite3.Row | None,
reference_id: int,
source_system: str,
source_id: str,
pdf_path: Path,
sha256: str,
extracted_text: str | None,
pdf_changed: bool,
) -> int:
if existing_file is None:
cur = conn.execute(
"""
INSERT INTO document_files(
reference_id,
file_path,
mime_type,
file_hash_sha256,
extracted_text,
source_system,
source_id
) VALUES (?, ?, 'application/pdf', ?, ?, ?, ?)
""",
(
reference_id,
str(pdf_path),
sha256,
extracted_text,
source_system,
source_id,
),
)
return int(cur.lastrowid)
next_extracted = extracted_text if pdf_changed else existing_file["extracted_text"]
conn.execute(
"""
UPDATE document_files
SET reference_id = ?,
file_path = ?,
file_hash_sha256 = ?,
extracted_text = ?,
updated_at = CURRENT_TIMESTAMP
WHERE id = ?
""",
(
reference_id,
str(pdf_path),
sha256,
next_extracted,
int(existing_file["id"]),
),
)
return int(existing_file["id"])
def _apply_metadata_document(
conn: sqlite3.Connection,
*,
reference_id: int,
kind: str,
content: str | None,
existing_row: sqlite3.Row | None,
) -> int | None:
if content is None:
if existing_row is not None:
delete_documents(conn, [int(existing_row["id"])])
return None
if existing_row is None:
return _insert_document(
conn,
reference_id=reference_id,
document_file_id=None,
kind=kind,
content=content,
chunk_index=None,
)
old_content = _normalize_optional(existing_row["content"])
if old_content == content:
return int(existing_row["id"])
conn.execute(
"""
UPDATE documents
SET content = ?,
updated_at = CURRENT_TIMESTAMP
WHERE id = ?
""",
(content, int(existing_row["id"])),
)
return int(existing_row["id"])
def _insert_document(
conn: sqlite3.Connection,
*,
reference_id: int,
document_file_id: int | None,
kind: str,
content: str,
chunk_index: int | None,
) -> int:
cur = conn.execute(
"""
INSERT INTO documents(
reference_id,
document_file_id,
kind,
content,
chunk_index
) VALUES (?, ?, ?, ?, ?)
""",
(reference_id, document_file_id, kind, content, chunk_index),
)
return int(cur.lastrowid)
def _index_document(
conn: sqlite3.Connection,
*,
doc_id: int,
reference_id: int,
kind: str,
content: str,
vector: list[float],
) -> None:
upsert_fts_row(
conn,
doc_id=doc_id,
content=content,
kind=kind,
reference_id=reference_id,
)
upsert_vec_row(conn, doc_id=doc_id, vector_bytes=serialize_vector(vector))
def _normalize_required(value: str | None, *, field_name: str) -> str:
out = _normalize_optional(value)
if out is None:
raise ValueError(f"{field_name} is required")
return out
def _normalize_optional(value: Any) -> str | None:
if value is None:
return None
text = str(value).strip()
return text if text else None
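# Hash the PDF in 1 MiB blocks so large files never need to be read into
# memory at once.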
def _sha256_file(path: Path) -> str:
digest = hashlib.sha256()
with path.open("rb") as handle:
while True:
chunk = handle.read(1024 * 1024)
if not chunk:
break
digest.update(chunk)
return digest.hexdigest()
def _document_content_changed(existing_row: sqlite3.Row | None, new_content: str | None) -> bool:
if existing_row is None:
return new_content is not None
old_content = _normalize_optional(existing_row["content"])
return old_content != new_content
def build_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser(description="litlake-core headless reference manager")
subparsers = parser.add_subparsers(dest="command", required=True)
upsert = subparsers.add_parser("upsert", help="Upsert a single reference + PDF")
upsert.add_argument("--db", required=True, dest="db_path")
upsert.add_argument("--source-system", required=True)
upsert.add_argument("--source-id", required=True)
upsert.add_argument("--title", required=True)
upsert.add_argument("--authors", required=True)
upsert.add_argument("--year", required=True)
upsert.add_argument("--abstract", required=True)
upsert.add_argument("--pdf", required=True, dest="pdf_path")
upsert = subparsers.add_parser("batch_upsert", help="Upsert a batch of references in jsonl format")
upsert.add_argument("--db", required=True, dest="db_path")
upsert.add_argument("--jsonl", required=True, dest="jsonl_path")
search_parser = subparsers.add_parser("search", help="Semantic search by query string")
search_parser.add_argument("--db", required=True, dest="db_path")
search_parser.add_argument("query", help="Search query (embedded and matched against document embeddings)")
search_parser.add_argument("-k", "--top-k", type=int, default=10, dest="top_k", help="Number of results (default: 10)")
return parser
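# Docling's pipeline classes are imported inside main() so that importing
# this module alone does not pull in the full PDF-processing stack.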
def main(argv: Sequence[str] | None = None) -> int:
from docling.datamodel.base_models import InputFormat
from docling.datamodel.pipeline_options import EasyOcrOptions, PdfPipelineOptions
from docling.document_converter import DocumentConverter, PdfFormatOption
parser = build_parser()
args = parser.parse_args(list(argv) if argv is not None else None)
if args.command == "upsert":
pipeline_options = PdfPipelineOptions(artifacts_path=str(CACHE_DIR))
converter = DocumentConverter(
format_options={
InputFormat.PDF: PdfFormatOption(pipeline_options=pipeline_options)
}
)
chunker = HybridChunker()
payload = UpsertReferenceInput(
db_path=args.db_path,
source_system=args.source_system,
source_id=args.source_id,
title=args.title,
authors=args.authors,
year=args.year,
abstract=args.abstract,
pdf_path=args.pdf_path,
)
result = upsert_reference(payload, converter, chunker)
json.dump(result.to_dict(), sys.stdout, ensure_ascii=False)
sys.stdout.write("\n")
return 0
if args.command == "batch_upsert":
pipeline_options = PdfPipelineOptions(artifacts_path=str(CACHE_DIR))
converter = DocumentConverter(
format_options={
InputFormat.PDF: PdfFormatOption(pipeline_options=pipeline_options)
}
)
chunker = HybridChunker()
if not os.path.exists(args.jsonl_path):
parser.error(f"Path must exist: {args.jsonl_path}")
with open(args.jsonl_path, "r", encoding="utf-8") as f:
for line in f:
line = line.strip()
if not line:
continue
reference_args = json.loads(line)
if not os.path.exists(reference_args['pdf_path']):
continue
payload = UpsertReferenceInput(
db_path=args.db_path,
source_system=reference_args['source_system'],
source_id=reference_args['source_id'],
title=reference_args['title'],
authors=reference_args['authors'],
year=reference_args['year'],
abstract=reference_args['abstract'],
pdf_path=reference_args['pdf_path'],
)
result = upsert_reference(payload, converter, chunker)
json.dump(result.to_dict(), sys.stdout, ensure_ascii=False)
sys.stdout.write("\n")
return 0
if args.command == "search":
conn = connect_db(args.db_path, read_only=True)
try:
results = search_by_string(conn, args.query, args.top_k)
for doc_id, content, kind, reference_id, distance in results:
out = {
"doc_id": doc_id,
"reference_id": reference_id,
"kind": kind,
"distance": distance,
"content": content[:500] + ("..." if len(content) > 500 else ""),
}
json.dump(out, sys.stdout, ensure_ascii=False)
sys.stdout.write("\n")
finally:
conn.close()
return 0
parser.error(f"Unknown command: {args.command}")
return 2
if __name__ == "__main__":
raise SystemExit(main())
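# Example invocations (a sketch: the module filename "litlake_core.py" and
# all paths below are illustrative, not part of this gist):
#   python litlake_core.py upsert --db ./litlake.db \
#       --source-system pubmed --source-id 12345 \
#       --title "Example title" --authors "Doe J" --year 2021 \
#       --abstract "Example abstract" --pdf ./papers/12345.pdf
#   python litlake_core.py search --db ./litlake.db "transformer attention" -k 5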
"""Generate per-input JSONL files from PubMed XML with optional filters.
Each input XML/XML.GZ becomes one output JSONL in --output-dir.
If that JSONL already exists, the XML file is skipped.
"""
from __future__ import annotations
import argparse
import gzip
import json
import re
import sys
import xml.etree.ElementTree as ET
from pathlib import Path
from typing import Iterable, Iterator, Optional
PMCID_TOKEN_RE = re.compile(r"^(?:PMC)?(\d+)$", re.IGNORECASE)
DIGITS_ONLY_RE = re.compile(r"^\d+$")
YEAR_RE = re.compile(r"\b(1[89]\d{2}|20\d{2}|2100)\b")
WRITE_BUFFER_SIZE = 1024 * 1024
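# open_text transparently handles both plain .xml and gzip-compressed
# .xml.gz inputs; decoding errors are replaced rather than raised so one
# bad byte cannot abort an entire file.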
def open_text(path: Path):
if path.suffix.lower() == ".gz":
return gzip.open(path, "rt", encoding="utf-8", errors="replace")
return path.open("rt", encoding="utf-8", errors="replace")
def collapse_ws(text: str) -> str:
return " ".join(text.split())
def element_text(elem: Optional[ET.Element]) -> str:
if elem is None:
return ""
return collapse_ws("".join(elem.itertext()))
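# ID helpers: PMCIDs are accepted with or without the "PMC" prefix and in
# any case. normalize_pmcid canonicalizes them to "PMC<digits>" for output,
# while pmcid_to_int strips the prefix so filter lookups compare plain
# integers.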
def normalize_pmcid(value: str) -> Optional[str]:
match = PMCID_TOKEN_RE.match(value.strip())
if not match:
return None
return f"PMC{int(match.group(1))}"
def pmcid_to_int(value: str) -> Optional[int]:
match = PMCID_TOKEN_RE.match(value.strip())
if not match:
return None
return int(match.group(1))
def extract_id_token(line: str, id_type: str) -> Optional[int]:
for token in re.split(r"[\s,;|]+", line.strip()):
if not token:
continue
token = token.strip().strip('"\'')
if id_type == "pmcid":
normalized = pmcid_to_int(token)
if normalized is not None:
return normalized
elif id_type == "pmid" and DIGITS_ONLY_RE.match(token):
return int(token)
return None
def load_id_filter(id_file: Path, id_type: str) -> set[int]:
ids: set[int] = set()
with open_text(id_file) as handle:
for line in handle:
line = line.strip()
if not line or line.startswith("#"):
continue
token = extract_id_token(line, id_type)
if token is not None:
ids.add(token)
return ids
def iter_input_files_from_dir(input_dir: Path, patterns: list[str]) -> Iterator[Path]:
if input_dir.is_file():
yield input_dir
return
seen: set[Path] = set()
for pattern in patterns:
for file_path in sorted(input_dir.glob(pattern)):
if not file_path.is_file():
continue
resolved = file_path.resolve()
if resolved in seen:
continue
seen.add(resolved)
yield resolved
def iter_input_files_from_list(file_list: Path) -> Iterator[Path]:
with open_text(file_list) as handle:
for raw_line in handle:
line = raw_line.strip()
if not line or line.startswith("#"):
continue
p = Path(line).expanduser()
if not p.is_absolute():
p = (file_list.parent / p).resolve()
if p.is_file():
yield p
else:
print(f"WARNING: listed file not found, skipping: {line}", file=sys.stderr)
def extract_abstract(article: ET.Element) -> str:
sections: list[str] = []
for node in article.findall("./MedlineCitation/Article/Abstract/AbstractText"):
text = element_text(node)
if not text:
continue
label = collapse_ws((node.attrib.get("Label") or "").strip())
if label and not text.lower().startswith(label.lower()):
text = f"{label}: {text}"
sections.append(text)
if not sections:
for node in article.findall("./MedlineCitation/OtherAbstract/AbstractText"):
text = element_text(node)
if text:
sections.append(text)
return "\n\n".join(sections)
def extract_publication_year(article: ET.Element) -> Optional[int]:
candidate_paths = (
"./MedlineCitation/Article/Journal/JournalIssue/PubDate/Year",
"./MedlineCitation/Article/ArticleDate/Year",
"./MedlineCitation/DateCompleted/Year",
"./MedlineCitation/DateCreated/Year",
)
for path in candidate_paths:
value = (article.findtext(path) or "").strip()
if value and DIGITS_ONLY_RE.match(value):
year = int(value)
if 1800 <= year <= 2100:
return year
medline_date = (
article.findtext("./MedlineCitation/Article/Journal/JournalIssue/PubDate/MedlineDate")
or ""
).strip()
if medline_date:
match = YEAR_RE.search(medline_date)
if match:
return int(match.group(1))
return None
def extract_authors(article: ET.Element) -> list[str]:
authors: list[str] = []
for node in article.findall("./MedlineCitation/Article/AuthorList/Author"):
collective = element_text(node.find("./CollectiveName"))
if collective:
authors.append(collective)
continue
last = element_text(node.find("./LastName"))
fore = element_text(node.find("./ForeName"))
initials = element_text(node.find("./Initials"))
if fore and last:
authors.append(f"{fore} {last}")
elif initials and last:
authors.append(f"{initials} {last}")
elif last:
authors.append(last)
elif fore:
authors.append(fore)
return authors
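# Stream the XML with iterparse and clear each <PubmedArticle> element (and
# the root) after it is yielded, so memory stays bounded regardless of
# input file size.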
def iter_pubmed_records(xml_path: Path) -> Iterator[dict[str, object]]:
with open_text(xml_path) as handle:
context = ET.iterparse(handle, events=("start", "end"))
_, root = next(context)
for event, elem in context:
if event != "end" or elem.tag != "PubmedArticle":
continue
pmid = (elem.findtext("./MedlineCitation/PMID") or "").strip()
if not pmid:
elem.clear()
continue
title = element_text(elem.find("./MedlineCitation/Article/ArticleTitle"))
abstract = extract_abstract(elem)
publication_year = extract_publication_year(elem)
authors = extract_authors(elem)
pmcid = ""
for node in elem.findall("./PubmedData/ArticleIdList/ArticleId"):
if (node.attrib.get("IdType") or "").lower() == "pmc":
pmcid = normalize_pmcid(element_text(node)) or ""
if pmcid:
break
yield {
"pmid": str(int(pmid)) if DIGITS_ONLY_RE.match(pmid) else pmid,
"pmcid": pmcid,
"publication_year": publication_year,
"authors": authors,
"title": title,
"abstract": abstract,
}
elem.clear()
root.clear()
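# Filters are applied in order: --require-pmcid first, then the optional ID
# allow-list, then the empty-abstract check; the first failing check is
# returned as the skip reason.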
def skip_reason(
record: dict[str, object],
require_pmcid: bool,
filter_ids: Optional[set[int]],
id_type: str,
allow_empty_abstract: bool,
) -> Optional[str]:
pmcid = str(record.get("pmcid") or "")
pmid = str(record.get("pmid") or "")
abstract = str(record.get("abstract") or "")
if require_pmcid and not pmcid:
return "no_pmcid"
if filter_ids is not None:
if id_type == "pmcid":
lookup_value = pmcid_to_int(pmcid) if pmcid else None
else:
lookup_value = int(pmid) if DIGITS_ONLY_RE.match(pmid) else None
if lookup_value is None or lookup_value not in filter_ids:
return "not_in_list"
if not allow_empty_abstract and not abstract:
return "no_abstract"
return None
def output_path_for_input(xml_file: Path, output_dir: Path) -> Path:
name = xml_file.name
if name.endswith(".xml.gz"):
stem = name[: -len(".xml.gz")]
elif name.endswith(".xml"):
stem = name[: -len(".xml")]
else:
stem = xml_file.stem
return output_dir / f"{stem}.jsonl"
def parse_args(argv: Optional[Iterable[str]] = None) -> argparse.Namespace:
parser = argparse.ArgumentParser(
description="Extract title/abstract/year/authors from PubMed XML into per-file JSONL outputs."
)
parser.add_argument(
"--input-dir",
type=Path,
help="Input XML directory or single XML/XML.GZ file",
)
parser.add_argument(
"--input-file-list",
type=Path,
help="Text file with XML/XML.GZ paths (one per line)",
)
parser.add_argument(
"--input-glob",
action="append",
dest="input_globs",
help=(
"Glob pattern(s) relative to --input-dir, repeatable. "
"Examples: '*.xml.gz', 'pubmed26n*.xml.gz', '**/*.xml.gz'. "
"Defaults to '*.xml.gz' and '*.xml'"
),
)
parser.add_argument("--output-dir", type=Path, required=True, help="Output directory for per-file JSONL")
parser.add_argument(
"--pmcid-list",
type=Path,
help="Optional text file of PMCID/PMID values to keep",
)
parser.add_argument(
"--require-pmcid",
action="store_true",
help="Keep only records with PMCID",
)
parser.add_argument(
"--id-type",
choices=("pmcid", "pmid"),
default="pmcid",
help="How to interpret --pmcid-list values (default: pmcid)",
)
parser.add_argument(
"--allow-empty-abstract",
action="store_true",
help="Keep records even when abstract is missing",
)
parser.add_argument(
"--progress-every",
type=int,
default=100_000,
help="Print global progress every N scanned records (default: 100000)",
)
return parser.parse_args(argv)
def main(argv: Optional[Iterable[str]] = None) -> int:
args = parse_args(argv)
if not args.input_dir and not args.input_file_list:
print("ERROR: provide at least one of --input-dir or --input-file-list", file=sys.stderr)
return 1
if args.input_dir and not args.input_dir.exists():
print(f"ERROR: input path not found: {args.input_dir}", file=sys.stderr)
return 1
if args.input_file_list and not args.input_file_list.exists():
print(f"ERROR: input file list not found: {args.input_file_list}", file=sys.stderr)
return 1
args.output_dir.mkdir(parents=True, exist_ok=True)
filter_ids: Optional[set[int]] = None
if args.pmcid_list:
if not args.pmcid_list.exists():
print(f"ERROR: --pmcid-list file not found: {args.pmcid_list}", file=sys.stderr)
return 1
print(f"Loading ID filter from {args.pmcid_list} ...", file=sys.stderr)
filter_ids = load_id_filter(args.pmcid_list, args.id_type)
print(f"Loaded {len(filter_ids):,} IDs", file=sys.stderr)
seen_inputs: set[Path] = set()
input_files: list[Path] = []
if args.input_file_list:
for p in iter_input_files_from_list(args.input_file_list.resolve()):
rp = p.resolve()
if rp not in seen_inputs:
seen_inputs.add(rp)
input_files.append(rp)
if args.input_dir:
patterns = args.input_globs or ["*.xml.gz", "*.xml"]
for p in iter_input_files_from_dir(args.input_dir.resolve(), patterns):
rp = p.resolve()
if rp not in seen_inputs:
seen_inputs.add(rp)
input_files.append(rp)
if not input_files:
print("ERROR: no input XML files found", file=sys.stderr)
return 1
scanned = 0
written = 0
skipped_no_pmcid = 0
skipped_not_in_list = 0
skipped_empty_abstract = 0
skipped_existing_files = 0
processed_files = 0
for xml_file in input_files:
output_path = output_path_for_input(xml_file, args.output_dir)
if output_path.exists():
skipped_existing_files += 1
print(f"Skipping existing output: {output_path}", file=sys.stderr)
continue
print(f"Processing {xml_file}", file=sys.stderr)
file_scanned = 0
file_written = 0
file_skipped_no_pmcid = 0
file_skipped_not_in_list = 0
file_skipped_empty_abstract = 0
with output_path.open("wt", encoding="utf-8", buffering=WRITE_BUFFER_SIZE) as out:
for record in iter_pubmed_records(xml_file):
scanned += 1
file_scanned += 1
if args.progress_every > 0 and scanned % args.progress_every == 0:
print(
f"Scanned={scanned:,} Written={written:,} "
f"NoPMCID={skipped_no_pmcid:,} NotInList={skipped_not_in_list:,} "
f"NoAbstract={skipped_empty_abstract:,}",
file=sys.stderr,
)
reason = skip_reason(
record,
args.require_pmcid,
filter_ids,
args.id_type,
args.allow_empty_abstract,
)
if reason == "no_pmcid":
skipped_no_pmcid += 1
file_skipped_no_pmcid += 1
continue
if reason == "not_in_list":
skipped_not_in_list += 1
file_skipped_not_in_list += 1
continue
if reason == "no_abstract":
skipped_empty_abstract += 1
file_skipped_empty_abstract += 1
continue
# Every record ends with '\n' so each JSONL is safe to concatenate via `cat`.
out.write(json.dumps(record, ensure_ascii=False) + "\n")
written += 1
file_written += 1
processed_files += 1
file_success_rate = (file_written / file_scanned * 100.0) if file_scanned else 0.0
print(
"FileSummary "
f"File={xml_file} "
f"Output={output_path} "
f"Scanned={file_scanned:,} Written={file_written:,} "
f"SuccessRate={file_success_rate:.2f}% "
f"NoPMCID={file_skipped_no_pmcid:,} "
f"NotInList={file_skipped_not_in_list:,} "
f"NoAbstract={file_skipped_empty_abstract:,}",
file=sys.stderr,
)
print("Finished", file=sys.stderr)
print(
f"FilesTotal={len(input_files):,} FilesProcessed={processed_files:,} "
f"FilesSkippedExisting={skipped_existing_files:,}",
file=sys.stderr,
)
print(
f"Scanned={scanned:,} Written={written:,} "
f"NoPMCID={skipped_no_pmcid:,} NotInList={skipped_not_in_list:,} "
f"NoAbstract={skipped_empty_abstract:,}",
file=sys.stderr,
)
return 0
if __name__ == "__main__":
raise SystemExit(main())
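# Example invocation (a sketch: the script filename and paths below are
# illustrative, not part of this gist):
#   python pubmed_xml_to_jsonl.py --input-dir ./pubmed/baseline \
#       --output-dir ./jsonl --require-pmcid --pmcid-list ./oa_pmcids.txt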