feat: add PageIndex SDK with local/cloud dual-mode support (#207)

This commit is contained in:
Kylin 2026-04-06 22:51:04 +08:00 committed by Ray
parent f2dcffc0b7
commit c7fe93bb56
45 changed files with 4225 additions and 274 deletions

4
.gitignore vendored
View file

@@ -4,3 +4,7 @@ __pycache__
.env*
.venv/
logs/
pageindex.egg-info/
*.db
venv/
uv.lock

62
examples/cloud_demo.py Normal file
View file

@@ -0,0 +1,62 @@
"""
Agentic Vectorless RAG with PageIndex SDK - Cloud Demo
Uses CloudClient for fully-managed document indexing and QA.
No LLM API key needed — the cloud service handles everything.
Steps:
1. Upload and index a PDF via PageIndex cloud
2. Stream a question with tool call visibility
Requirements:
pip install pageindex
export PAGEINDEX_API_KEY=your-api-key
"""
import asyncio
import os
from pathlib import Path
import requests
from pageindex import CloudClient

_EXAMPLES_DIR = Path(__file__).parent
PDF_URL = "https://arxiv.org/pdf/1706.03762.pdf"
PDF_PATH = _EXAMPLES_DIR / "documents" / "attention.pdf"

# Download PDF if needed
if not PDF_PATH.exists():
    print(f"Downloading {PDF_URL} ...")
    PDF_PATH.parent.mkdir(parents=True, exist_ok=True)
    # Stream the download in 8 KiB chunks so the PDF is never held in memory.
    with requests.get(PDF_URL, stream=True, timeout=30) as r:
        r.raise_for_status()
        with open(PDF_PATH, "wb") as f:
            for chunk in r.iter_content(chunk_size=8192):
                if chunk:
                    f.write(chunk)
    print("Download complete.\n")

# Raises KeyError when PAGEINDEX_API_KEY is unset — fail fast for a demo.
client = CloudClient(api_key=os.environ["PAGEINDEX_API_KEY"])
col = client.collection()
doc_id = col.add(str(PDF_PATH))
print(f"Indexed: {doc_id}\n")

# Streaming query
stream = col.query("What is the main contribution of this paper?", stream=True)

async def main():
    # Tracks whether answer text is mid-line so tool-call lines start fresh.
    streamed_text = False
    async for event in stream:
        if event.type == "answer_delta":
            print(event.data, end="", flush=True)
            streamed_text = True
        elif event.type == "tool_call":
            if streamed_text:
                print()
                streamed_text = False
            args = event.data.get("args", "")
            print(f"[tool call] {event.data['name']}({args})")
        elif event.type == "answer_done":
            print()
            streamed_text = False

asyncio.run(main())

69
examples/local_demo.py Normal file
View file

@@ -0,0 +1,69 @@
"""
Agentic Vectorless RAG with PageIndex SDK - Local Demo
A simple example of using LocalClient for self-hosted document indexing
and agent-based QA. The agent uses OpenAI Agents SDK to reason over
the document's tree structure index.
Steps:
1. Download and index a PDF
2. Stream a question with tool call visibility
Requirements:
pip install pageindex
export OPENAI_API_KEY=your-api-key  # or any LiteLLM-supported provider
"""
import asyncio
from pathlib import Path
import requests
from pageindex import LocalClient

_EXAMPLES_DIR = Path(__file__).parent
PDF_URL = "https://arxiv.org/pdf/1706.03762.pdf"
PDF_PATH = _EXAMPLES_DIR / "documents" / "attention.pdf"
WORKSPACE = _EXAMPLES_DIR / "workspace"
MODEL = "gpt-4o-2024-11-20"  # any LiteLLM-supported model

# Download PDF if needed
if not PDF_PATH.exists():
    print(f"Downloading {PDF_URL} ...")
    PDF_PATH.parent.mkdir(parents=True, exist_ok=True)
    # Stream the download in 8 KiB chunks so the PDF is never held in memory.
    with requests.get(PDF_URL, stream=True, timeout=30) as r:
        r.raise_for_status()
        with open(PDF_PATH, "wb") as f:
            for chunk in r.iter_content(chunk_size=8192):
                if chunk:
                    f.write(chunk)
    print("Download complete.\n")

client = LocalClient(model=MODEL, storage_path=str(WORKSPACE))
col = client.collection()
doc_id = col.add(str(PDF_PATH))
print(f"Indexed: {doc_id}\n")

# Streaming query
stream = col.query(
    "What is the main architecture proposed in this paper and how does self-attention work?",
    stream=True,
)

async def main():
    # Tracks whether answer text is mid-line so tool lines start fresh.
    streamed_text = False
    async for event in stream:
        if event.type == "answer_delta":
            print(event.data, end="", flush=True)
            streamed_text = True
        elif event.type == "tool_call":
            if streamed_text:
                print()
                streamed_text = False
            print(f"[tool call] {event.data['name']}")
        elif event.type == "tool_result":
            # The conditional binds looser than "+": long payloads become a
            # truncated string with "...", short ones pass through unchanged
            # (and may not be a str — print() handles either).
            preview = str(event.data)[:200] + "..." if len(str(event.data)) > 200 else event.data
            print(f"[tool output] {preview}")
        elif event.type == "answer_done":
            print()
            streamed_text = False

asyncio.run(main())

View file

@@ -1,4 +1,40 @@
# pageindex/__init__.py
# Upstream exports (backward compatibility)
from .page_index import *
from .page_index_md import md_to_tree
from .retrieve import get_document, get_document_structure, get_page_content
# SDK exports (PageIndexClient was previously imported twice from .client;
# the duplicate import is merged into this single line).
from .client import PageIndexClient, LocalClient, CloudClient
from .config import IndexConfig
from .collection import Collection
from .parser.protocol import ContentNode, ParsedDocument, DocumentParser
from .storage.protocol import StorageEngine
from .events import QueryEvent
from .errors import (
    PageIndexError,
    CollectionNotFoundError,
    DocumentNotFoundError,
    IndexingError,
    CloudAPIError,
    FileTypeError,
)

# Explicit public API surface of the package.
__all__ = [
    "PageIndexClient",
    "LocalClient",
    "CloudClient",
    "IndexConfig",
    "Collection",
    "ContentNode",
    "ParsedDocument",
    "DocumentParser",
    "StorageEngine",
    "QueryEvent",
    "PageIndexError",
    "CollectionNotFoundError",
    "DocumentNotFoundError",
    "IndexingError",
    "CloudAPIError",
    "FileTypeError",
]

93
pageindex/agent.py Normal file
View file

@@ -0,0 +1,93 @@
# pageindex/agent.py
from __future__ import annotations
from typing import AsyncIterator
from .events import QueryEvent
from .backend.protocol import AgentTools
# Instructions sent to the QA agent on every query. The prompt constrains
# tool use (tight page ranges, a one-sentence rationale before each call)
# and requires image references to be preserved so the UI can render them.
# This string is runtime behavior — edit with care.
SYSTEM_PROMPT = """
You are PageIndex, a document QA assistant.
TOOL USE:
- Call list_documents() to see available documents.
- Call get_document(doc_id) to confirm status and page/line count.
- Call get_document_structure(doc_id) to identify relevant page ranges.
- Call get_page_content(doc_id, pages="5-7") with tight ranges; never fetch the whole document.
- Before each tool call, output one short sentence explaining the reason.
IMAGES:
- Page content may contain image references like ![image](path). Always preserve these in your answer so the downstream UI can render them.
- Place images near the relevant context in your answer.
Answer based only on tool output. Be concise.
"""
class QueryStream:
    """Streaming query result, similar to OpenAI's RunResultStreaming.
    Usage:
        stream = col.query("question", stream=True)
        async for event in stream:
            if event.type == "answer_delta":
                print(event.data, end="", flush=True)
    """
    def __init__(self, tools: AgentTools, question: str, model: str | None = None):
        # Imported lazily so the agents SDK is only required when querying.
        from agents import Agent
        from agents.model_settings import ModelSettings
        self._agent = Agent(
            name="PageIndex",
            instructions=SYSTEM_PROMPT,
            tools=tools.function_tools,
            mcp_servers=tools.mcp_servers,
            model=model,
            # Sequential tool calls keep the emitted event stream ordered.
            model_settings=ModelSettings(parallel_tool_calls=False),
        )
        self._question = question

    async def stream_events(self) -> AsyncIterator[QueryEvent]:
        """Async generator yielding QueryEvent as they arrive.

        Maps OpenAI Agents SDK stream events onto SDK-neutral QueryEvents:
        answer_delta (text token), tool_call, tool_result, answer_done.
        """
        from agents import Runner, ItemHelpers
        from agents.stream_events import RawResponsesStreamEvent, RunItemStreamEvent
        from openai.types.responses import ResponseTextDeltaEvent
        streamed_run = Runner.run_streamed(self._agent, self._question)
        async for event in streamed_run.stream_events():
            if isinstance(event, RawResponsesStreamEvent):
                # Raw model output: forward only incremental text deltas.
                if isinstance(event.data, ResponseTextDeltaEvent):
                    yield QueryEvent(type="answer_delta", data=event.data.delta)
            elif isinstance(event, RunItemStreamEvent):
                item = event.item
                if item.type == "tool_call_item":
                    raw = item.raw_item
                    # Some raw items carry no "arguments"; default to "{}".
                    yield QueryEvent(type="tool_call", data={
                        "name": raw.name, "args": getattr(raw, "arguments", "{}"),
                    })
                elif item.type == "tool_call_output_item":
                    yield QueryEvent(type="tool_result", data=str(item.output))
                elif item.type == "message_output_item":
                    text = ItemHelpers.text_message_output(item)
                    if text:
                        yield QueryEvent(type="answer_done", data=text)

    def __aiter__(self):
        # Allows `async for event in stream` directly on the QueryStream.
        return self.stream_events()
class AgentRunner:
    """Runs a single blocking (non-streaming) agent query over a tool set."""

    def __init__(self, tools: AgentTools, model: str = None):
        self._tools = tools
        self._model = model

    def run(self, question: str) -> str:
        """Sync non-streaming query. Returns answer string."""
        from agents import Agent, Runner
        from agents.model_settings import ModelSettings
        # Force tool calls to execute one at a time so the reasoning trace
        # stays sequential and easy to follow.
        settings = ModelSettings(parallel_tool_calls=False)
        qa_agent = Agent(
            name="PageIndex",
            instructions=SYSTEM_PROMPT,
            tools=self._tools.function_tools,
            mcp_servers=self._tools.mcp_servers,
            model=self._model,
            model_settings=settings,
        )
        return Runner.run_sync(qa_agent, question).final_output

View file

352
pageindex/backend/cloud.py Normal file
View file

@@ -0,0 +1,352 @@
# pageindex/backend/cloud.py
"""CloudBackend — connects to PageIndex cloud service (api.pageindex.ai).
API reference: https://github.com/VectifyAI/pageindex_sdk
"""
from __future__ import annotations
import json
import logging
import os
import re
import time
import urllib.parse
import requests
from typing import AsyncIterator
from .protocol import AgentTools
from ..errors import CloudAPIError, PageIndexError
from ..events import QueryEvent
logger = logging.getLogger(__name__)
API_BASE = "https://api.pageindex.ai"
_INTERNAL_TOOLS = frozenset({"ToolSearch", "Read", "Grep", "Glob", "Bash", "Edit", "Write"})
class CloudBackend:
    """Backend that proxies all operations to the PageIndex cloud service.

    Collections are mapped onto cloud "folders". On plans without folder
    support (API returns 403), all documents share one global space and
    collection names are ignored — a one-time warning is logged.
    """

    def __init__(self, api_key: str):
        self._api_key = api_key
        # The cloud API authenticates via a custom "api_key" header.
        self._headers = {"api_key": api_key}
        # Maps collection name -> folder id, or None when folders are unavailable.
        self._folder_id_cache: dict[str, str | None] = {}
        self._folder_warning_shown = False

    # ── HTTP helpers ──────────────────────────────────────────────────────
    def _warn_folder_upgrade(self) -> None:
        # Warn only once per backend instance to avoid log spam.
        if not self._folder_warning_shown:
            logger.warning(
                "Folders (collections) require a Max plan. "
                "All documents are stored in a single global space — collection names are ignored. "
                "Upgrade at https://dash.pageindex.ai/subscription"
            )
            self._folder_warning_shown = True

    def _request(self, method: str, path: str, **kwargs) -> dict:
        """Issue one API request with up to 3 attempts and exponential backoff.

        Transient statuses (429/500/502/503) and network errors are retried;
        any other non-200 response raises CloudAPIError immediately.
        """
        url = f"{API_BASE}{path}"
        for attempt in range(3):
            try:
                resp = requests.request(method, url, headers=self._headers, timeout=30, **kwargs)
                if resp.status_code in (429, 500, 502, 503):
                    logger.warning("Cloud API %s %s returned %d, retrying...", method, path, resp.status_code)
                    time.sleep(2 ** attempt)
                    continue
                if resp.status_code != 200:
                    body = resp.text[:500] if resp.text else ""
                    raise CloudAPIError(f"Cloud API error {resp.status_code}: {body}")
                return resp.json() if resp.content else {}
            except requests.RequestException as e:
                if attempt == 2:
                    raise CloudAPIError(f"Cloud API request failed: {e}") from e
                time.sleep(2 ** attempt)
        # Reached only when every attempt hit a retryable status code.
        raise CloudAPIError("Max retries exceeded")

    @staticmethod
    def _validate_collection_name(name: str) -> None:
        # Same charset rule as the local backend's _COLLECTION_NAME_RE.
        if not re.match(r'^[a-zA-Z0-9_-]{1,128}$', name):
            raise PageIndexError(
                f"Invalid collection name: {name!r}. "
                "Must be 1-128 chars of [a-zA-Z0-9_-]."
            )

    @staticmethod
    def _enc(value: str) -> str:
        # Percent-encode URL path segments (safe="" also encodes "/").
        return urllib.parse.quote(value, safe="")

    # ── Collection management (mapped to folders) ─────────────────────────
    def create_collection(self, name: str) -> None:
        """Create a cloud folder for this collection; degrades on 403."""
        self._validate_collection_name(name)
        try:
            resp = self._request("POST", "/folder/", json={"name": name})
            self._folder_id_cache[name] = resp.get("folder", {}).get("id")
        except CloudAPIError as e:
            # 403 means the plan has no folder support — cache None and warn.
            if "403" in str(e):
                self._warn_folder_upgrade()
                self._folder_id_cache[name] = None
            else:
                raise

    def get_or_create_collection(self, name: str) -> None:
        """Reuse an existing folder with this name, or create a new one."""
        self._validate_collection_name(name)
        try:
            data = self._request("GET", "/folders/")
            for folder in data.get("folders", []):
                if folder.get("name") == name:
                    self._folder_id_cache[name] = folder["id"]
                    return
            resp = self._request("POST", "/folder/", json={"name": name})
            self._folder_id_cache[name] = resp.get("folder", {}).get("id")
        except CloudAPIError as e:
            if "403" in str(e):
                self._warn_folder_upgrade()
                self._folder_id_cache[name] = None
            else:
                raise

    def _get_folder_id(self, name: str) -> str | None:
        """Resolve collection name to folder ID. Returns None if folders not available."""
        if name in self._folder_id_cache:
            return self._folder_id_cache.get(name)
        try:
            data = self._request("GET", "/folders/")
            for folder in data.get("folders", []):
                if folder.get("name") == name:
                    self._folder_id_cache[name] = folder["id"]
                    return folder["id"]
        except CloudAPIError:
            # Treat lookup failure as "folders unavailable" and fall through.
            pass
        self._folder_id_cache[name] = None
        return None

    def list_collections(self) -> list[str]:
        data = self._request("GET", "/folders/")
        return [f["name"] for f in data.get("folders", [])]

    def delete_collection(self, name: str) -> None:
        # Silently a no-op when the folder cannot be resolved.
        folder_id = self._get_folder_id(name)
        if folder_id:
            self._request("DELETE", f"/folder/{self._enc(folder_id)}/")

    # ── Document management ───────────────────────────────────────────────
    def add_document(self, collection: str, file_path: str) -> str:
        """Upload a file and block until cloud indexing is retrieval-ready.

        Raises CloudAPIError if indexing fails or does not complete within
        the polling window (120 polls at 5 s each).
        """
        folder_id = self._get_folder_id(collection)
        data = {"if_retrieval": "true"}
        if folder_id:
            data["folder_id"] = folder_id
        with open(file_path, "rb") as f:
            resp = self._request("POST", "/doc/", files={"file": f}, data=data)
        doc_id = resp["doc_id"]
        # Poll until retrieval-ready
        for _ in range(120):  # 10 min max
            tree_resp = self._request("GET", f"/doc/{self._enc(doc_id)}/", params={"type": "tree"})
            if tree_resp.get("retrieval_ready"):
                return doc_id
            status = tree_resp.get("status", "")
            if status == "failed":
                raise CloudAPIError(f"Document {doc_id} indexing failed")
            time.sleep(5)
        raise CloudAPIError(f"Document {doc_id} indexing timed out")

    def get_document(self, collection: str, doc_id: str, include_text: bool = False) -> dict:
        """Fetch metadata plus normalized tree structure for one document.

        NOTE(review): include_text exists for interface parity with the local
        backend but is not applied here.
        """
        resp = self._request("GET", f"/doc/{self._enc(doc_id)}/metadata/")
        # Fetch the tree via a second request to the document endpoint.
        tree_resp = self._request("GET", f"/doc/{self._enc(doc_id)}/",
                                  params={"type": "tree", "summary": "true"})
        # The tree key has varied across API responses; try each known name.
        raw_tree = tree_resp.get("tree", tree_resp.get("structure", tree_resp.get("result", [])))
        return {
            "doc_id": resp.get("id", doc_id),
            "doc_name": resp.get("name", ""),
            "doc_description": resp.get("description", ""),
            "doc_type": "pdf",
            "status": resp.get("status", ""),
            "structure": self._normalize_tree(raw_tree),
        }

    def get_document_structure(self, collection: str, doc_id: str) -> list:
        resp = self._request("GET", f"/doc/{self._enc(doc_id)}/", params={"type": "tree", "summary": "true"})
        raw_tree = resp.get("tree", resp.get("structure", resp.get("result", [])))
        return self._normalize_tree(raw_tree)

    def get_page_content(self, collection: str, doc_id: str, pages: str) -> list:
        """Fetch OCR page content and filter to the requested page spec."""
        resp = self._request("GET", f"/doc/{self._enc(doc_id)}/", params={"type": "ocr", "format": "page"})
        # Filter to requested pages
        from ..index.utils import parse_pages
        page_nums = set(parse_pages(pages))
        all_pages = resp.get("pages", resp.get("ocr", resp.get("result", [])))
        if isinstance(all_pages, list):
            return [
                {"page": p.get("page", p.get("page_index")),
                 "content": p.get("content", p.get("markdown", ""))}
                for p in all_pages
                if p.get("page", p.get("page_index")) in page_nums
            ]
        return []

    @staticmethod
    def _normalize_tree(nodes: list) -> list:
        """Normalize cloud tree nodes to match local schema."""
        result = []
        for node in nodes:
            normalized = {
                "title": node.get("title", ""),
                "node_id": node.get("node_id", ""),
                "summary": node.get("summary", node.get("prefix_summary", "")),
                "start_index": node.get("start_index", node.get("page_index")),
                "end_index": node.get("end_index", node.get("page_index")),
            }
            if "text" in node:
                normalized["text"] = node["text"]
            children = node.get("nodes", [])
            if children:
                # Recurse to normalize the whole subtree.
                normalized["nodes"] = CloudBackend._normalize_tree(children)
            result.append(normalized)
        return result

    def list_documents(self, collection: str) -> list[dict]:
        folder_id = self._get_folder_id(collection)
        params = {"limit": 100}
        if folder_id:
            params["folder_id"] = folder_id
        data = self._request("GET", "/docs/", params=params)
        return [
            {"doc_id": d.get("id", ""), "doc_name": d.get("name", ""), "doc_type": "pdf"}
            for d in data.get("documents", [])
        ]

    def delete_document(self, collection: str, doc_id: str) -> None:
        self._request("DELETE", f"/doc/{self._enc(doc_id)}/")

    # ── Query (uses cloud chat/completions, no LLM key needed) ────────────
    def query(self, collection: str, question: str, doc_ids: list[str] | None = None) -> str:
        """Non-streaming query via cloud chat/completions."""
        # When no explicit doc_ids are given, query across the whole collection.
        doc_id = doc_ids if doc_ids else self._get_all_doc_ids(collection)
        resp = self._request("POST", "/chat/completions/", json={
            "messages": [{"role": "user", "content": question}],
            "doc_id": doc_id,
            "stream": False,
        })
        # Extract answer from response
        choices = resp.get("choices", [])
        if choices:
            return choices[0].get("message", {}).get("content", "")
        return resp.get("content", resp.get("answer", ""))

    async def query_stream(self, collection: str, question: str,
                           doc_ids: list[str] | None = None) -> AsyncIterator[QueryEvent]:
        """Streaming query via cloud chat/completions SSE.
        Events are yielded in real-time as they arrive from the server.
        A background thread handles the blocking HTTP stream and pushes
        events through an asyncio.Queue for true async streaming.
        """
        import asyncio
        import threading
        doc_id = doc_ids if doc_ids else self._get_all_doc_ids(collection)
        headers = self._headers
        queue: asyncio.Queue[QueryEvent | None] = asyncio.Queue()
        loop = asyncio.get_event_loop()

        def _stream():
            """Background thread: read SSE and push events to queue."""
            resp = requests.post(
                f"{API_BASE}/chat/completions/",
                headers=headers,
                json={
                    "messages": [{"role": "user", "content": question}],
                    "doc_id": doc_id,
                    "stream": True,
                    "stream_metadata": True,
                },
                stream=True,
                timeout=120,
            )
            try:
                if resp.status_code != 200:
                    body = resp.text[:500] if resp.text else ""
                    # Surface HTTP errors as a terminal answer_done event.
                    loop.call_soon_threadsafe(
                        queue.put_nowait,
                        QueryEvent(type="answer_done",
                                   data=f"Cloud streaming error {resp.status_code}: {body}"),
                    )
                    return
                # Tool-call arguments arrive fragmented between start/stop
                # markers; buffer them and emit one tool_call event at stop.
                current_tool_name = None
                current_tool_args: list[str] = []
                for line in resp.iter_lines(decode_unicode=True):
                    if not line or not line.startswith("data: "):
                        continue
                    data_str = line[6:]
                    if data_str.strip() == "[DONE]":
                        break
                    try:
                        chunk = json.loads(data_str)
                    except json.JSONDecodeError:
                        # Skip malformed SSE payloads rather than aborting.
                        continue
                    meta = chunk.get("block_metadata", {})
                    block_type = meta.get("type", "")
                    choices = chunk.get("choices", [])
                    delta = choices[0].get("delta", {}) if choices else {}
                    content = delta.get("content", "")
                    if block_type == "mcp_tool_use_start":
                        current_tool_name = meta.get("tool_name", "")
                        current_tool_args = []
                    elif block_type == "tool_use":
                        if content:
                            current_tool_args.append(content)
                    elif block_type == "tool_use_stop":
                        # Hide the service's internal tools from callers.
                        if current_tool_name and current_tool_name not in _INTERNAL_TOOLS:
                            args_str = "".join(current_tool_args)
                            loop.call_soon_threadsafe(
                                queue.put_nowait,
                                QueryEvent(type="tool_call", data={
                                    "name": current_tool_name,
                                    "args": args_str,
                                }),
                            )
                        current_tool_name = None
                        current_tool_args = []
                    elif block_type == "text" and content:
                        loop.call_soon_threadsafe(
                            queue.put_nowait,
                            QueryEvent(type="answer_delta", data=content),
                        )
            finally:
                resp.close()
                loop.call_soon_threadsafe(queue.put_nowait, None)  # sentinel

        thread = threading.Thread(target=_stream, daemon=True)
        thread.start()
        # Consume until the sentinel (None) signals end-of-stream.
        while True:
            event = await queue.get()
            if event is None:
                break
            yield event
        thread.join(timeout=5)

    def _get_all_doc_ids(self, collection: str) -> list[str]:
        """Get all document IDs in a collection."""
        docs = self.list_documents(collection)
        return [d["doc_id"] for d in docs]

    # ── Not used in cloud mode ────────────────────────────────────────────
    def get_agent_tools(self, collection: str, doc_ids: list[str] | None = None) -> AgentTools:
        """Not used in cloud mode — query goes through chat/completions."""
        return AgentTools()

245
pageindex/backend/local.py Normal file
View file

@@ -0,0 +1,245 @@
# pageindex/backend/local.py
import hashlib
import os
import re
import uuid
import shutil
from pathlib import Path
from ..parser.protocol import DocumentParser, ParsedDocument
from ..parser.pdf import PdfParser
from ..parser.markdown import MarkdownParser
from ..storage.protocol import StorageEngine
from ..index.pipeline import build_index
from ..index.utils import parse_pages, get_pdf_page_content, get_md_page_content, remove_fields
from ..backend.protocol import AgentTools
from ..errors import FileTypeError, DocumentNotFoundError, IndexingError, PageIndexError
# Collection names are restricted to a filesystem- and URL-safe charset,
# since they become directory names under files_dir.
_COLLECTION_NAME_RE = re.compile(r'^[a-zA-Z0-9_-]{1,128}$')
class LocalBackend:
    """Self-hosted backend: parses, indexes, and stores documents locally.

    Files are copied under files_dir/{collection}/ and all metadata lives in
    the injected StorageEngine. Queries run an agent over the stored index.
    """

    def __init__(self, storage: StorageEngine, files_dir: str, model: str = None,
                 retrieve_model: str = None, index_config=None):
        self._storage = storage
        self._files_dir = Path(files_dir)
        self._model = model
        # The agent QA model defaults to the indexing model when not given.
        self._retrieve_model = retrieve_model or model
        self._index_config = index_config
        self._parsers: list[DocumentParser] = [PdfParser(), MarkdownParser()]

    def register_parser(self, parser: DocumentParser) -> None:
        self._parsers.insert(0, parser)  # user parsers checked first

    def get_retrieve_model(self) -> str | None:
        return self._retrieve_model

    def _resolve_parser(self, file_path: str) -> DocumentParser:
        """Pick the first registered parser that supports the file extension."""
        ext = os.path.splitext(file_path)[1].lower()
        for parser in self._parsers:
            if ext in parser.supported_extensions():
                return parser
        raise FileTypeError(f"No parser for extension: {ext}")

    # Collection management
    def _validate_collection_name(self, name: str) -> None:
        if not _COLLECTION_NAME_RE.match(name):
            raise PageIndexError(f"Invalid collection name: {name!r}. Must be 1-128 chars of [a-zA-Z0-9_-].")

    def create_collection(self, name: str) -> None:
        self._validate_collection_name(name)
        self._storage.create_collection(name)

    def get_or_create_collection(self, name: str) -> None:
        self._validate_collection_name(name)
        self._storage.get_or_create_collection(name)

    def list_collections(self) -> list[str]:
        return self._storage.list_collections()

    def delete_collection(self, name: str) -> None:
        self._storage.delete_collection(name)
        # Also remove the collection's managed files on disk.
        col_dir = self._files_dir / name
        if col_dir.exists():
            shutil.rmtree(col_dir)

    @staticmethod
    def _file_hash(file_path: str) -> str:
        """Compute SHA-256 hash of a file."""
        h = hashlib.sha256()
        with open(file_path, "rb") as f:
            # Read in 64 KiB chunks to bound memory for large files.
            for chunk in iter(lambda: f.read(65536), b""):
                h.update(chunk)
        return h.hexdigest()

    # Document management
    def add_document(self, collection: str, file_path: str) -> str:
        """Parse, index, and store a document; returns its doc_id.

        Re-adding a byte-identical file to the same collection is a no-op
        that returns the existing document's id. On any indexing failure the
        managed copy and extracted images are rolled back and IndexingError
        is raised.
        """
        file_path = os.path.realpath(file_path)
        if not os.path.isfile(file_path):
            raise FileTypeError(f"Not a regular file: {file_path}")
        parser = self._resolve_parser(file_path)
        # Dedup: skip if same file already indexed in this collection
        file_hash = self._file_hash(file_path)
        existing_id = self._storage.find_document_by_hash(collection, file_hash)
        if existing_id:
            return existing_id
        doc_id = str(uuid.uuid4())
        # Copy file to managed directory
        ext = os.path.splitext(file_path)[1]
        col_dir = self._files_dir / collection
        col_dir.mkdir(parents=True, exist_ok=True)
        managed_path = col_dir / f"{doc_id}{ext}"
        shutil.copy2(file_path, managed_path)
        try:
            # Store images alongside the document: files/{collection}/{doc_id}/images/
            images_dir = str(col_dir / doc_id / "images")
            parsed = parser.parse(file_path, model=self._model, images_dir=images_dir)
            result = build_index(parsed, model=self._model, opt=self._index_config)
            # Cache page text for fast retrieval (avoids re-reading files)
            pages = [{"page": n.index, "content": n.content,
                      **({"images": n.images} if n.images else {})}
                     for n in parsed.nodes if n.content]
            # Strip text from structure to save storage space (PDF only;
            # markdown needs text in structure for fallback retrieval)
            doc_type = ext.lstrip(".")
            if doc_type == "pdf":
                clean_structure = remove_fields(result["structure"], fields=["text"])
            else:
                clean_structure = result["structure"]
            self._storage.save_document(collection, doc_id, {
                "doc_name": parsed.doc_name,
                "doc_description": result.get("doc_description", ""),
                "file_path": str(managed_path),
                "file_hash": file_hash,
                "doc_type": doc_type,
                "structure": clean_structure,
                "pages": pages,
            })
        except Exception as e:
            # Roll back the managed copy and any extracted images on failure.
            managed_path.unlink(missing_ok=True)
            doc_dir = col_dir / doc_id
            if doc_dir.exists():
                shutil.rmtree(doc_dir)
            raise IndexingError(f"Failed to index {file_path}: {e}") from e
        return doc_id

    def get_document(self, collection: str, doc_id: str, include_text: bool = False) -> dict:
        """Get document metadata with structure.
        Args:
            include_text: If True, populate each structure node's 'text' field
                from cached page content. WARNING: may be very large; do NOT
                use in agent/LLM contexts as it can exhaust the context window.
        """
        doc = self._storage.get_document(collection, doc_id)
        if not doc:
            # Unknown doc_id yields an empty dict rather than raising.
            return {}
        doc["structure"] = self._storage.get_document_structure(collection, doc_id)
        if include_text:
            pages = self._storage.get_pages(collection, doc_id) or []
            page_map = {p["page"]: p["content"] for p in pages}
            self._fill_node_text(doc["structure"], page_map)
        return doc

    @staticmethod
    def _fill_node_text(nodes: list, page_map: dict) -> None:
        """Recursively fill 'text' on structure nodes from cached page content."""
        for node in nodes:
            start = node.get("start_index")
            end = node.get("end_index")
            if start is not None and end is not None:
                # Missing pages contribute empty strings rather than failing.
                node["text"] = "\n".join(
                    page_map.get(p, "") for p in range(start, end + 1)
                )
            if "nodes" in node:
                LocalBackend._fill_node_text(node["nodes"], page_map)

    def get_document_structure(self, collection: str, doc_id: str) -> list:
        return self._storage.get_document_structure(collection, doc_id)

    def get_page_content(self, collection: str, doc_id: str, pages: str) -> list:
        """Return page dicts for a page spec like '5-7', '3,8', or '12'."""
        doc = self._storage.get_document(collection, doc_id)
        if not doc:
            raise DocumentNotFoundError(f"Document {doc_id} not found")
        page_nums = parse_pages(pages)
        # Try cached pages first (fast, no file I/O)
        cached_pages = self._storage.get_pages(collection, doc_id)
        if cached_pages:
            return [p for p in cached_pages if p["page"] in page_nums]
        # Fallback to reading from file
        if doc["doc_type"] == "pdf":
            return get_pdf_page_content(doc["file_path"], page_nums)
        else:
            structure = self._storage.get_document_structure(collection, doc_id)
            return get_md_page_content(structure, page_nums)

    def list_documents(self, collection: str) -> list[dict]:
        return self._storage.list_documents(collection)

    def delete_document(self, collection: str, doc_id: str) -> None:
        """Remove the managed file, extracted images, and stored metadata."""
        doc = self._storage.get_document(collection, doc_id)
        if doc and doc.get("file_path"):
            Path(doc["file_path"]).unlink(missing_ok=True)
        # Clean up images directory: files/{collection}/{doc_id}/
        doc_dir = self._files_dir / collection / doc_id
        if doc_dir.exists():
            shutil.rmtree(doc_dir)
        self._storage.delete_document(collection, doc_id)

    def get_agent_tools(self, collection: str, doc_ids: list[str] | None = None) -> AgentTools:
        """Build the function tools the QA agent uses on this collection.

        When doc_ids is given, list_documents is filtered to those ids only.
        The closures capture storage/collection so the tools need no state.
        """
        from agents import function_tool
        import json
        storage = self._storage
        col_name = collection
        backend = self
        filter_ids = doc_ids

        @function_tool
        def list_documents() -> str:
            """List all documents in the collection."""
            docs = storage.list_documents(col_name)
            if filter_ids:
                docs = [d for d in docs if d["doc_id"] in filter_ids]
            return json.dumps(docs)

        @function_tool
        def get_document(doc_id: str) -> str:
            """Get document metadata."""
            return json.dumps(storage.get_document(col_name, doc_id))

        @function_tool
        def get_document_structure(doc_id: str) -> str:
            """Get document tree structure (without text)."""
            structure = storage.get_document_structure(col_name, doc_id)
            return json.dumps(remove_fields(structure, fields=["text"]), ensure_ascii=False)

        @function_tool
        def get_page_content(doc_id: str, pages: str) -> str:
            """Get page content. Use tight ranges: '5-7', '3,8', '12'."""
            result = backend.get_page_content(col_name, doc_id, pages)
            return json.dumps(result, ensure_ascii=False)

        return AgentTools(function_tools=[list_documents, get_document, get_document_structure, get_page_content])

    def query(self, collection: str, question: str, doc_ids: list[str] | None = None) -> str:
        """Run a blocking agent query over the collection; returns answer text."""
        from ..agent import AgentRunner
        tools = self.get_agent_tools(collection, doc_ids)
        return AgentRunner(tools=tools, model=self._retrieve_model).run(question)

    async def query_stream(self, collection: str, question: str,
                           doc_ids: list[str] | None = None):
        """Async generator of QueryEvent for a streaming agent query."""
        from ..agent import QueryStream
        tools = self.get_agent_tools(collection, doc_ids)
        stream = QueryStream(tools=tools, question=question, model=self._retrieve_model)
        async for event in stream:
            yield event

View file

@@ -0,0 +1,34 @@
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Protocol, Any, AsyncIterator, runtime_checkable
from ..events import QueryEvent
@dataclass
class AgentTools:
    """Structured container for agent tool configuration (local mode only)."""
    # Function tools passed to the agent (e.g. @function_tool-decorated callables).
    function_tools: list[Any] = field(default_factory=list)
    # MCP server handles forwarded to the Agent constructor.
    mcp_servers: list[Any] = field(default_factory=list)
@runtime_checkable
class Backend(Protocol):
    """Structural interface implemented by both LocalBackend and CloudBackend.

    A backend manages named collections of indexed documents and answers
    questions over them, either synchronously (query) or as an async stream
    of QueryEvent (query_stream).
    """
    # Collection management
    def create_collection(self, name: str) -> None: ...
    def get_or_create_collection(self, name: str) -> None: ...
    def list_collections(self) -> list[str]: ...
    def delete_collection(self, name: str) -> None: ...
    # Document management
    def add_document(self, collection: str, file_path: str) -> str: ...
    def get_document(self, collection: str, doc_id: str, include_text: bool = False) -> dict: ...
    def get_document_structure(self, collection: str, doc_id: str) -> list: ...
    def get_page_content(self, collection: str, doc_id: str, pages: str) -> list: ...
    def list_documents(self, collection: str) -> list[dict]: ...
    def delete_document(self, collection: str, doc_id: str) -> None: ...
    # Query
    def query(self, collection: str, question: str, doc_ids: list[str] | None = None) -> str: ...
    async def query_stream(self, collection: str, question: str,
                           doc_ids: list[str] | None = None) -> AsyncIterator[QueryEvent]: ...

View file

@@ -1,18 +1,9 @@
import os
import uuid
import json
import asyncio
import concurrent.futures
# pageindex/client.py
from __future__ import annotations
from pathlib import Path
import PyPDF2
from .page_index import page_index
from .page_index_md import md_to_tree
from .retrieve import get_document, get_document_structure, get_page_content
from .utils import ConfigLoader, remove_fields
META_INDEX = "_meta.json"
from .collection import Collection
from .config import IndexConfig
from .parser.protocol import DocumentParser
def _normalize_retrieve_model(model: str) -> str:
@@ -26,209 +17,145 @@ def _normalize_retrieve_model(model: str) -> str:
class PageIndexClient:
"""
A client for indexing and retrieving document content.
Flow: index() -> get_document() / get_document_structure() / get_page_content()
"""PageIndex client — supports both local and cloud modes.
For agent-based QA, see examples/agentic_vectorless_rag_demo.py.
Args:
api_key: PageIndex cloud API key. When provided, cloud mode is used
and local-only params (model, storage_path, index_config, ) are ignored.
model: LLM model for indexing (local mode only, default: gpt-4o-2024-11-20).
retrieve_model: LLM model for agent QA (local mode only, default: same as model).
storage_path: Directory for SQLite DB and files (local mode only, default: ./.pageindex).
storage: Custom StorageEngine instance (local mode only).
index_config: Advanced indexing parameters (local mode only, optional).
Pass an IndexConfig instance or a dict. Defaults are sensible for most use cases.
Usage:
# Local mode (auto-detected when no api_key)
client = PageIndexClient(model="gpt-5.4")
# Cloud mode (auto-detected when api_key provided)
client = PageIndexClient(api_key="your-api-key")
# Or use LocalClient / CloudClient for explicit mode selection
"""
def __init__(self, api_key: str = None, model: str = None, retrieve_model: str = None, workspace: str = None):
def __init__(self, api_key: str = None, model: str = None,
retrieve_model: str = None, storage_path: str = None,
storage=None, index_config: IndexConfig | dict = None):
if api_key:
os.environ["OPENAI_API_KEY"] = api_key
elif not os.getenv("OPENAI_API_KEY") and os.getenv("CHATGPT_API_KEY"):
os.environ["OPENAI_API_KEY"] = os.getenv("CHATGPT_API_KEY")
self.workspace = Path(workspace).expanduser() if workspace else None
self._init_cloud(api_key)
else:
self._init_local(model, retrieve_model, storage_path, storage, index_config)
def _init_cloud(self, api_key: str):
from .backend.cloud import CloudBackend
self._backend = CloudBackend(api_key=api_key)
def _init_local(self, model: str = None, retrieve_model: str = None,
storage_path: str = None, storage=None,
index_config: IndexConfig | dict = None):
# Build IndexConfig: merge model/retrieve_model with index_config
overrides = {}
if model:
overrides["model"] = model
if retrieve_model:
overrides["retrieve_model"] = retrieve_model
opt = ConfigLoader().load(overrides or None)
self.model = opt.model
self.retrieve_model = _normalize_retrieve_model(opt.retrieve_model or self.model)
if self.workspace:
self.workspace.mkdir(parents=True, exist_ok=True)
self.documents = {}
if self.workspace:
self._load_workspace()
def index(self, file_path: str, mode: str = "auto") -> str:
    """Index a document. Returns a document_id.

    Supported inputs: PDF ('.pdf') and Markdown ('.md'/'.markdown');
    mode='auto' dispatches on the file extension.
    """
    # Persist a canonical absolute path so workspace reloads do not
    # reinterpret caller-relative paths against the workspace directory.
    file_path = os.path.abspath(os.path.expanduser(file_path))
    if not os.path.exists(file_path):
        raise FileNotFoundError(f"File not found: {file_path}")
    doc_id = str(uuid.uuid4())
    ext = os.path.splitext(file_path)[1].lower()
    is_pdf = ext == '.pdf'
    is_md = ext in ['.md', '.markdown']
    if mode == "pdf" or (mode == "auto" and is_pdf):
        print(f"Indexing PDF: {file_path}")
        result = page_index(
            doc=file_path,
            model=self.model,
            if_add_node_summary='yes',
            if_add_node_text='yes',
            if_add_node_id='yes',
            if_add_doc_description='yes'
        )
        # Extract per-page text so queries don't need the original PDF
        pages = []
        with open(file_path, 'rb') as f:
            pdf_reader = PyPDF2.PdfReader(f)
            for i, page in enumerate(pdf_reader.pages, 1):
                # extract_text() may return None for image-only pages.
                pages.append({'page': i, 'content': page.extract_text() or ''})
        self.documents[doc_id] = {
            'id': doc_id,
            'type': 'pdf',
            'path': file_path,
            'doc_name': result.get('doc_name', ''),
            'doc_description': result.get('doc_description', ''),
            'page_count': len(pages),
            'structure': result['structure'],
            'pages': pages,
        }
    elif mode == "md" or (mode == "auto" and is_md):
        print(f"Indexing Markdown: {file_path}")
        coro = md_to_tree(
            md_path=file_path,
            if_thinning=False,
            if_add_node_summary='yes',
            summary_token_threshold=200,
            model=self.model,
            if_add_doc_description='yes',
            if_add_node_text='yes',
            if_add_node_id='yes'
        )
        # md_to_tree is async; run it on a worker thread when this method is
        # itself invoked from inside a running event loop.
        try:
            asyncio.get_running_loop()
            with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
                result = pool.submit(asyncio.run, coro).result()
        except RuntimeError:
            result = asyncio.run(coro)
        self.documents[doc_id] = {
            'id': doc_id,
            'type': 'md',
            'path': file_path,
            'doc_name': result.get('doc_name', ''),
            'doc_description': result.get('doc_description', ''),
            'line_count': result.get('line_count', 0),
            'structure': result['structure'],
        }
    # NOTE(review): the isinstance(index_config, ...) branches below look like
    # _init_local's config-merge logic interleaved here by a bad merge;
    # `index_config`, `overrides` and `opt` are not defined in this method,
    # while the `raise ValueError` is the unsupported-extension branch of the
    # mode dispatch above. Confirm against the original files.
    if isinstance(index_config, IndexConfig):
        opt = index_config.model_copy(update=overrides)
    elif isinstance(index_config, dict):
        merged = {**index_config, **overrides}  # explicit model/retrieve_model win
        opt = IndexConfig(**merged)
    else:
        raise ValueError(f"Unsupported file format for: {file_path}")
    opt = IndexConfig(**overrides) if overrides else IndexConfig()
    print(f"Indexing complete. Document ID: {doc_id}")
    if self.workspace:
        self._save_doc(doc_id)
    return doc_id
# NOTE(review): this fragment (provider validation + storage/backend setup)
# reads `opt`, `storage_path` and `storage`, so it appears to be the tail of
# _init_local displaced here by a bad merge -- confirm placement.
self._validate_llm_provider(opt.model)
# Default on-disk layout: ./.pageindex/pageindex.db plus a files/ directory.
storage_path = Path(storage_path or ".pageindex").resolve()
storage_path.mkdir(parents=True, exist_ok=True)
from .storage.sqlite import SQLiteStorage
from .backend.local import LocalBackend
# A caller-supplied StorageEngine wins over the SQLite default.
storage_engine = storage or SQLiteStorage(str(storage_path / "pageindex.db"))
self._backend = LocalBackend(
    storage=storage_engine,
    files_dir=str(storage_path / "files"),
    model=opt.model,
    retrieve_model=_normalize_retrieve_model(opt.retrieve_model or opt.model),
    index_config=opt,
)
@staticmethod
def _make_meta_entry(doc: dict) -> dict:
    """Build a lightweight meta entry from a document dict.

    Copies the identifying fields and, depending on the document type,
    the matching size field (page_count for PDFs, line_count for markdown).
    Missing identifying fields default to ''.
    """
    meta = {key: doc.get(key, '') for key in ('type', 'doc_name', 'doc_description', 'path')}
    doc_type = doc.get('type')
    # The size metric differs by document kind; absent values stay None.
    if doc_type == 'pdf':
        meta['page_count'] = doc.get('page_count')
    elif doc_type == 'md':
        meta['line_count'] = doc.get('line_count')
    return meta
@staticmethod
def _read_json(path) -> dict | None:
    """Read a JSON file, returning None on any error."""
def _validate_llm_provider(model: str) -> None:
    """Validate model and check API key via litellm. Warns if key seems missing."""
    # NOTE(review): two method headers followed by one body -- a merge
    # artifact. The try/except below reads `path` and parses JSON, so it is
    # the body of _read_json; _validate_llm_provider's body appears displaced
    # further down the file. Confirm against the original.
    try:
        with open(path, "r", encoding="utf-8") as f:
            return json.load(f)
    except (json.JSONDecodeError, OSError) as e:
        # A corrupt or unreadable file is reported but not fatal.
        print(f"Warning: corrupt {Path(path).name}: {e}")
        return None
def _save_doc(self, doc_id: str):
    """Persist one document to <workspace>/<doc_id>.json and update the meta index.

    After saving, heavy fields (structure, pages) are dropped from the
    in-memory entry; _ensure_doc_loaded reloads them lazily on demand.
    """
    doc = self.documents[doc_id].copy()
    # Strip text from structure nodes — redundant with pages (PDF only)
    if doc.get('structure') and doc.get('type') == 'pdf':
        doc['structure'] = remove_fields(doc['structure'], fields=['text'])
    path = self.workspace / f"{doc_id}.json"
    with open(path, "w", encoding="utf-8") as f:
        json.dump(doc, f, ensure_ascii=False, indent=2)
    self._save_meta(doc_id, self._make_meta_entry(doc))
    # Drop heavy fields; will lazy-load on demand
    self.documents[doc_id].pop('structure', None)
    self.documents[doc_id].pop('pages', None)
def _rebuild_meta(self) -> dict:
    """Reconstruct the meta index by scanning per-document JSON files.

    Skips the meta index file itself and any file that is missing,
    corrupt, or whose payload is not a (non-empty) JSON object.
    """
    rebuilt = {}
    candidates = (p for p in self.workspace.glob("*.json") if p.name != META_INDEX)
    for doc_path in candidates:
        payload = self._read_json(doc_path)
        if isinstance(payload, dict) and payload:
            rebuilt[doc_path.stem] = self._make_meta_entry(payload)
    return rebuilt
def _read_meta(self) -> dict | None:
    """Load _meta.json; a non-object payload counts as corruption (None)."""
    payload = self._read_json(self.workspace / META_INDEX)
    if payload is None or isinstance(payload, dict):
        return payload
    # Valid JSON but wrong shape (e.g. a list): ignore it entirely.
    print(f"Warning: {META_INDEX} is not a JSON object, ignoring")
    return None
def _save_meta(self, doc_id: str, entry: dict):
    """Upsert one entry into _meta.json, rebuilding the index if it is unreadable."""
    meta = self._read_meta() or self._rebuild_meta()
    meta[doc_id] = entry
    meta_path = self.workspace / META_INDEX
    with open(meta_path, "w", encoding="utf-8") as f:
        json.dump(meta, f, ensure_ascii=False, indent=2)
def _load_workspace(self):
    """Populate self.documents with lightweight entries from the meta index.

    Falls back to scanning individual doc files when _meta.json is missing
    or corrupt. Only metadata is loaded here; structure/pages load lazily.
    """
    meta = self._read_meta()
    if meta is None:
        meta = self._rebuild_meta()
    if meta:
        print(f"Loaded {len(meta)} document(s) from workspace (legacy mode).")
    for doc_id, entry in meta.items():
        doc = dict(entry, id=doc_id)
        # Older workspaces stored relative paths; anchor them to the workspace.
        if doc.get('path') and not os.path.isabs(doc['path']):
            doc['path'] = str((self.workspace / doc['path']).resolve())
        self.documents[doc_id] = doc
def _ensure_doc_loaded(self, doc_id: str):
    """Load full document JSON on demand (structure, pages, etc.)."""
    doc = self.documents.get(doc_id)
    if not doc or doc.get('structure') is not None:
        # NOTE(review): the litellm lines below read `model`, which is
        # undefined here -- they belong to _validate_llm_provider and were
        # interleaved by a bad merge (note the orphan `except` with no `try`).
        # This branch should simply return when the doc is unknown or already
        # loaded. Confirm against the original.
        import litellm
        # Disabling the remote cost-map fetch keeps validation offline.
        litellm.model_cost_map_url = ""
        _, provider, _, _ = litellm.get_llm_provider(model=model)
    except Exception:
        return
    full = self._read_json(self.workspace / f"{doc_id}.json")
    if not full:
        return
    doc['structure'] = full.get('structure', [])
    if full.get('pages'):
        doc['pages'] = full['pages']
def get_document(self, doc_id: str) -> str:
    """Return document metadata JSON."""
    return get_document(self.documents, doc_id)
# NOTE(review): the block below reads `provider`/`model` and raises when no
# API key is configured -- it is the displaced body of _validate_llm_provider,
# not part of get_document. Confirm against the original file.
key = litellm.get_api_key(llm_provider=provider, dynamic_api_key=None)
if not key:
    import os
    # Heuristic: most providers use a <PROVIDER>_API_KEY environment variable.
    common_var = f"{provider.upper()}_API_KEY"
    if not os.getenv(common_var):
        from .errors import PageIndexError
        raise PageIndexError(
            f"API key not configured for provider '{provider}' (model: {model}). "
            f"Set the {common_var} environment variable."
        )
def get_document_structure(self, doc_id: str) -> str:
    """Return document tree structure JSON (without text fields)."""
    # Legacy workspace mode lazily reloads structure from disk first.
    if self.workspace:
        self._ensure_doc_loaded(doc_id)
    return get_document_structure(self.documents, doc_id)
def collection(self, name: str = "default") -> Collection:
    """Get or create a collection. Defaults to 'default'."""
    # Creation is idempotent on the backend side.
    self._backend.get_or_create_collection(name)
    return Collection(name=name, backend=self._backend)
def get_page_content(self, doc_id: str, pages: str) -> str:
    """Return page content for the given pages string (e.g. '5-7', '3,8', '12')."""
    if self.workspace:
        self._ensure_doc_loaded(doc_id)
    return get_page_content(self.documents, doc_id, pages)
def list_collections(self) -> list[str]:
    """Names of all collections known to the backend."""
    return self._backend.list_collections()
def delete_collection(self, name: str) -> None:
    """Delete a collection (and its documents) via the backend."""
    self._backend.delete_collection(name)
def register_parser(self, parser: DocumentParser) -> None:
    """Register a custom document parser. Only available in local mode."""
    # The cloud backend exposes no register_parser hook; fail with a clear error.
    if not hasattr(self._backend, 'register_parser'):
        from .errors import PageIndexError
        raise PageIndexError("Custom parsers are not supported in cloud mode")
    self._backend.register_parser(parser)
class LocalClient(PageIndexClient):
    """Local mode — indexes and queries documents on your machine.
    Args:
        model: LLM model for indexing (default: gpt-4o-2024-11-20)
        retrieve_model: LLM model for agent QA (default: same as model)
        storage_path: Directory for SQLite DB and files (default: ./.pageindex)
        storage: Custom StorageEngine instance (default: SQLiteStorage)
        index_config: Advanced indexing parameters. Pass an IndexConfig instance
            or a dict. All fields have sensible defaults — most users don't need this.
    Example::
        # Simple — defaults are fine
        client = LocalClient(model="gpt-5.4")
        # Advanced — tune indexing parameters
        from pageindex.config import IndexConfig
        client = LocalClient(
            model="gpt-5.4",
            index_config=IndexConfig(toc_check_page_num=30),
        )
    """
    def __init__(self, model: str = None, retrieve_model: str = None,
                 storage_path: str = None, storage=None,
                 index_config: IndexConfig | dict = None):
        # Skips the base class's api_key auto-detection: always local mode.
        self._init_local(model, retrieve_model, storage_path, storage, index_config)
class CloudClient(PageIndexClient):
    """Cloud mode — fully managed by PageIndex cloud service. No LLM key needed."""
    def __init__(self, api_key: str):
        # Skips the base class's auto-detection: always cloud mode.
        self._init_cloud(api_key)

69
pageindex/collection.py Normal file
View file

@ -0,0 +1,69 @@
# pageindex/collection.py
from __future__ import annotations
from typing import AsyncIterator
from .events import QueryEvent
from .backend.protocol import Backend
class QueryStream:
    """Async-iterable view over backend.query_stream().

    Instances are lazy: nothing is sent to the backend until iteration
    begins, so a stream can be constructed synchronously and consumed later
    with ``async for``.
    """

    def __init__(self, backend: Backend, collection: str, question: str,
                 doc_ids: list[str] | None = None):
        # Captured verbatim; forwarded to the backend on first iteration.
        self._backend = backend
        self._collection = collection
        self._question = question
        self._doc_ids = doc_ids

    async def stream_events(self) -> AsyncIterator[QueryEvent]:
        """Yield QueryEvent objects from the backend as they arrive."""
        source = self._backend.query_stream(self._collection, self._question, self._doc_ids)
        async for item in source:
            yield item

    def __aiter__(self):
        # Lets callers write `async for event in stream` on the object itself.
        return self.stream_events()
class Collection:
    """Named group of documents; a thin facade over a Backend instance."""

    def __init__(self, name: str, backend: Backend):
        self._name = name
        self._backend = backend

    @property
    def name(self) -> str:
        """Collection name as known to the backend."""
        return self._name

    def add(self, file_path: str) -> str:
        """Index a file into this collection and return its document id."""
        return self._backend.add_document(self._name, file_path)

    def list_documents(self) -> list[dict]:
        """Return metadata dicts for every document in the collection."""
        return self._backend.list_documents(self._name)

    def get_document(self, doc_id: str, include_text: bool = False) -> dict:
        """Fetch one document's metadata; include node text when requested."""
        return self._backend.get_document(self._name, doc_id, include_text=include_text)

    def get_document_structure(self, doc_id: str) -> list:
        """Return the document's tree structure."""
        return self._backend.get_document_structure(self._name, doc_id)

    def get_page_content(self, doc_id: str, pages: str) -> list:
        """Return content for a pages spec such as '5-7', '3,8' or '12'."""
        return self._backend.get_page_content(self._name, doc_id, pages)

    def delete_document(self, doc_id: str) -> None:
        """Remove a document from the collection."""
        self._backend.delete_document(self._name, doc_id)

    def query(self, question: str, doc_ids: list[str] | None = None,
              stream: bool = False) -> str | QueryStream:
        """Query documents in this collection.

        With stream=False (default) the backend is called synchronously and
        the answer string is returned. With stream=True a lazy QueryStream
        is returned; iterate it with ``async for`` to receive QueryEvents.
        """
        if not stream:
            return self._backend.query(self._name, question, doc_ids)
        return QueryStream(self._backend, self._name, question, doc_ids)

22
pageindex/config.py Normal file
View file

@ -0,0 +1,22 @@
# pageindex/config.py
from __future__ import annotations
from pydantic import BaseModel
class IndexConfig(BaseModel):
    """Configuration for the PageIndex indexing pipeline.
    All fields have sensible defaults. Advanced users can override
    via LocalClient(index_config=IndexConfig(...)) or a dict.
    """
    # Reject unknown keys so typos in dict-style configs fail loudly.
    model_config = {"extra": "forbid"}
    # LLM used for indexing (summaries, descriptions, structure detection).
    model: str = "gpt-4o-2024-11-20"
    # LLM used for agent QA; None means "fall back to `model`".
    retrieve_model: str | None = None
    # Presumably the number of leading pages scanned for a table of
    # contents -- TODO confirm against the PDF pipeline.
    toc_check_page_num: int = 20
    # Upper bounds on the size of a single tree node.
    max_page_num_each_node: int = 10
    max_token_num_each_node: int = 20000
    # Output-enrichment toggles (node ids, per-node summaries, doc
    # description, raw node text).
    if_add_node_id: bool = True
    if_add_node_summary: bool = True
    if_add_doc_description: bool = True
    if_add_node_text: bool = False

View file

@ -1,10 +0,0 @@
model: "gpt-4o-2024-11-20"
# model: "anthropic/claude-sonnet-4-6"
retrieve_model: "gpt-5.4" # defaults to `model` if not set
toc_check_page_num: 20
max_page_num_each_node: 10
max_token_num_each_node: 20000
if_add_node_id: "yes"
if_add_node_summary: "yes"
if_add_doc_description: "no"
if_add_node_text: "no"

28
pageindex/errors.py Normal file
View file

@ -0,0 +1,28 @@
class PageIndexError(Exception):
    """Base exception for all PageIndex SDK errors; catch this to handle any SDK failure."""
    pass
class CollectionNotFoundError(PageIndexError):
    """Raised when a referenced collection does not exist."""
    pass
class DocumentNotFoundError(PageIndexError):
    """Raised when a document ID cannot be found."""
    pass
class IndexingError(PageIndexError):
    """Raised when the indexing pipeline fails for a document."""
    pass
class CloudAPIError(PageIndexError):
    """Raised when the PageIndex cloud API returns an error response."""
    pass
class FileTypeError(PageIndexError):
    """Raised for unsupported input file types."""
    pass

9
pageindex/events.py Normal file
View file

@ -0,0 +1,9 @@
from dataclasses import dataclass
from typing import Literal, Any
@dataclass
class QueryEvent:
    """Event emitted during streaming query."""
    # Discriminator for which stage of the agent loop produced this event.
    type: Literal["reasoning", "tool_call", "tool_result", "answer_delta", "answer_done"]
    # Payload; shape depends on `type` (per the cloud demo: a text chunk for
    # "answer_delta", a dict with 'name'/'args' for "tool_call").
    data: Any

View file

View file

@ -0,0 +1,2 @@
# Re-export from the original utils.py for backward compatibility
from ..utils import *

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,341 @@
import asyncio
import json
import re
import os
# Prefer the package-relative import; fall back to a plain import so this
# module also works when executed as a standalone script (no package context).
# Narrowed from a bare `except:` which would also swallow KeyboardInterrupt
# and SystemExit.
try:
    from .legacy_utils import *
except ImportError:
    from legacy_utils import *
async def get_node_summary(node, summary_token_threshold=200, model=None):
    """Return the node's own text when it is short, otherwise an LLM summary.

    Nodes under `summary_token_threshold` tokens are already concise enough
    to serve as their own summary, so no LLM call is made for them.
    """
    text = node.get('text')
    if count_tokens(text, model=model) >= summary_token_threshold:
        return await generate_node_summary(node, model=model)
    return text
async def generate_summaries_for_structure_md(structure, summary_token_threshold, model=None):
    """Attach summaries to every node of a markdown tree, concurrently.

    Leaf nodes receive a 'summary' key; nodes with children receive
    'prefix_summary' (covering their own text, distinct from children).
    Mutates the nodes in place and returns the same structure.
    """
    flat_nodes = structure_to_list(structure)
    summaries = await asyncio.gather(*(
        get_node_summary(node, summary_token_threshold=summary_token_threshold, model=model)
        for node in flat_nodes
    ))
    for node, summary in zip(flat_nodes, summaries):
        key = 'prefix_summary' if node.get('nodes') else 'summary'
        node[key] = summary
    return structure
def extract_nodes_from_markdown(markdown_content):
    """Scan markdown text for ATX headers (# .. ######) outside code fences.

    Returns (node_list, lines): node_list holds dicts with 'node_title' and
    the 1-based 'line_num' of each header; lines is the raw text split on
    newlines (used downstream to slice out node bodies).
    """
    lines = markdown_content.split('\n')
    headers = []
    inside_fence = False
    header_re = re.compile(r'^(#{1,6})\s+(.+)$')
    for line_num, raw in enumerate(lines, 1):
        stripped = raw.strip()
        # A triple-backtick line toggles fence state and is never a header.
        if stripped.startswith('```'):
            inside_fence = not inside_fence
            continue
        if not stripped or inside_fence:
            continue
        matched = header_re.match(stripped)
        if matched:
            headers.append({'node_title': matched.group(2).strip(), 'line_num': line_num})
    return headers, lines
def extract_node_text_content(node_list, markdown_lines):
    """Attach heading level and body text to each header node.

    A node's body runs from its own header line up to (not including) the
    next surviving header's line; the last node runs to end of file. Nodes
    whose raw line no longer starts with '#' (e.g. the header was indented)
    are dropped with a warning.
    """
    enriched = []
    for node in node_list:
        raw_line = markdown_lines[node['line_num'] - 1]
        hashes = re.match(r'^(#{1,6})', raw_line)
        if hashes is None:
            print(f"Warning: Line {node['line_num']} does not contain a valid header: '{raw_line}'")
            continue
        enriched.append({
            'title': node['node_title'],
            'line_num': node['line_num'],
            'level': len(hashes.group(1)),
        })
    for idx, node in enumerate(enriched):
        start = node['line_num'] - 1
        stop = enriched[idx + 1]['line_num'] - 1 if idx + 1 < len(enriched) else len(markdown_lines)
        node['text'] = '\n'.join(markdown_lines[start:stop]).strip()
    return enriched
def update_node_list_with_text_token_count(node_list, model=None):
    """Annotate each node with 'text_token_count' covering itself plus descendants.

    A node's count includes the text of every deeper-level node that follows
    it until the next node at the same or shallower level. Nodes are
    processed bottom-up.
    """
    def find_all_children(parent_index, parent_level, node_list):
        """Find all direct and indirect children of a parent node"""
        children_indices = []
        # Look for children after the parent
        for i in range(parent_index + 1, len(node_list)):
            current_level = node_list[i]['level']
            # If we hit a node at same or higher level than parent, stop
            if current_level <= parent_level:
                break
            # This is a descendant
            children_indices.append(i)
        return children_indices
    # Make a copy to avoid modifying the original
    # NOTE(review): list.copy() is shallow -- the node dicts are shared, so
    # 'text_token_count' is written onto the input's dicts too; confirm intent.
    result_list = node_list.copy()
    # Process nodes from end to beginning to ensure children are processed before parents
    for i in range(len(result_list) - 1, -1, -1):
        current_node = result_list[i]
        current_level = current_node['level']
        # Get all children of this node
        children_indices = find_all_children(i, current_level, result_list)
        # Start with the node's own text
        node_text = current_node.get('text', '')
        total_text = node_text
        # Add all children's text
        for child_index in children_indices:
            child_text = result_list[child_index].get('text', '')
            if child_text:
                total_text += '\n' + child_text
        # Calculate token count for combined text
        result_list[i]['text_token_count'] = count_tokens(total_text, model=model)
    return result_list
def tree_thinning_for_index(node_list, min_node_token=None, model=None):
    """Merge undersized nodes with their descendants to thin the tree.

    Nodes whose pre-computed 'text_token_count' falls below `min_node_token`
    absorb the text of all their descendants, which are then removed from
    the list. Runs bottom-up so children are merged before their parents.

    Args:
        node_list: flat node dicts carrying 'level', 'text' and
            'text_token_count' (see update_node_list_with_text_token_count).
        min_node_token: merge threshold. None disables thinning and returns
            a shallow copy unchanged (previously the None default crashed
            with TypeError on the `total_tokens < min_node_token` compare).
        model: tokenizer model forwarded to count_tokens for merged text.

    Returns:
        A new list (shallow copy; surviving dicts are the originals).
    """
    # Fix: Python 3 forbids ordering comparisons against None, so the old
    # default argument made any call without a threshold raise TypeError.
    if min_node_token is None:
        return node_list.copy()
    def find_all_children(parent_index, parent_level, node_list):
        """Indices of all direct and indirect descendants of node_list[parent_index]."""
        children_indices = []
        for i in range(parent_index + 1, len(node_list)):
            current_level = node_list[i]['level']
            if current_level <= parent_level:
                break
            children_indices.append(i)
        return children_indices
    result_list = node_list.copy()
    nodes_to_remove = set()
    # Bottom-up pass: deeper nodes are considered before their ancestors.
    for i in range(len(result_list) - 1, -1, -1):
        if i in nodes_to_remove:
            continue
        current_node = result_list[i]
        current_level = current_node['level']
        total_tokens = current_node.get('text_token_count', 0)
        if total_tokens < min_node_token:
            children_indices = find_all_children(i, current_level, result_list)
            children_texts = []
            for child_index in sorted(children_indices):
                if child_index not in nodes_to_remove:
                    child_text = result_list[child_index].get('text', '')
                    if child_text.strip():
                        children_texts.append(child_text)
                    # Children are absorbed even when their text is blank.
                    nodes_to_remove.add(child_index)
            if children_texts:
                parent_text = current_node.get('text', '')
                merged_text = parent_text
                for child_text in children_texts:
                    # Separate merged sections with a blank line.
                    if merged_text and not merged_text.endswith('\n'):
                        merged_text += '\n\n'
                    merged_text += child_text
                result_list[i]['text'] = merged_text
                result_list[i]['text_token_count'] = count_tokens(merged_text, model=model)
    for index in sorted(nodes_to_remove, reverse=True):
        result_list.pop(index)
    return result_list
def build_tree_from_nodes(node_list):
    """Fold a flat, document-ordered node list into a forest using levels.

    Each output node carries a sequential 4-digit node_id (starting at
    '0001'), the node's text and line number, and its children in 'nodes'.
    A node becomes a child of the nearest preceding node with a strictly
    smaller level; otherwise it is a root.
    """
    roots = []
    ancestors = []  # stack of (tree_node, level), innermost last
    for counter, flat in enumerate(node_list, 1):
        level = flat['level']
        branch = {
            'title': flat['title'],
            'node_id': f"{counter:04d}",
            'text': flat['text'],
            'line_num': flat['line_num'],
            'nodes': [],
        }
        # Pop everything at the same or deeper level; the parent is what's left.
        while ancestors and ancestors[-1][1] >= level:
            ancestors.pop()
        target = ancestors[-1][0]['nodes'] if ancestors else roots
        target.append(branch)
        ancestors.append((branch, level))
    return roots
def clean_tree_for_output(tree_nodes):
    """Copy a tree keeping only title/node_id/text/line_num on each node.

    Child lists are preserved recursively, but an empty 'nodes' list is
    omitted entirely, so leaves carry no 'nodes' key.
    """
    slimmed = []
    for node in tree_nodes:
        keep = {field: node[field] for field in ('title', 'node_id', 'text', 'line_num')}
        children = node['nodes']
        if children:
            keep['nodes'] = clean_tree_for_output(children)
        slimmed.append(keep)
    return slimmed
async def md_to_tree(md_path, if_thinning=False, min_token_threshold=None, if_add_node_summary=False, summary_token_threshold=None, model=None, if_add_doc_description=False, if_add_node_text=False, if_add_node_id=True):
    """Parse a markdown file into a PageIndex tree structure.

    Pipeline: extract headers -> slice body text -> (optional) thin small
    nodes -> build tree -> (optional) node ids, per-node summaries, and a
    document description.

    Args:
        md_path: path to the markdown file.
        if_thinning: merge nodes below min_token_threshold into parents.
        min_token_threshold: token floor used when if_thinning is set.
        if_add_node_summary: truthy to add summaries. Callers elsewhere pass
            'yes'/'no' strings; note 'no' is a truthy string, so only
            booleans/empty values actually disable this.
        summary_token_threshold: nodes shorter than this are their own summary.
        model: LLM used for summaries and the doc description.
        if_add_doc_description: truthy to add a one-sentence description
            (only takes effect on the summary path below -- confirm intent).
        if_add_node_text: truthy to keep raw text on each node.
        if_add_node_id: assign sequential node ids.

    Returns:
        dict with doc_name, line_count, structure, plus doc_description when
        generated.
    """
    with open(md_path, 'r', encoding='utf-8') as f:
        markdown_content = f.read()
    line_count = markdown_content.count('\n') + 1
    print(f"Extracting nodes from markdown...")
    node_list, markdown_lines = extract_nodes_from_markdown(markdown_content)
    print(f"Extracting text content from nodes...")
    nodes_with_content = extract_node_text_content(node_list, markdown_lines)
    if if_thinning:
        # Token counts are only needed for the thinning pass.
        nodes_with_content = update_node_list_with_text_token_count(nodes_with_content, model=model)
        print(f"Thinning nodes...")
        nodes_with_content = tree_thinning_for_index(nodes_with_content, min_token_threshold, model=model)
    print(f"Building tree from nodes...")
    tree_structure = build_tree_from_nodes(nodes_with_content)
    if if_add_node_id:
        write_node_id(tree_structure)
    print(f"Formatting tree structure...")
    if if_add_node_summary:
        # Always include text for summary generation
        tree_structure = format_structure(tree_structure, order = ['title', 'node_id', 'line_num', 'summary', 'prefix_summary', 'text', 'nodes'])
        print(f"Generating summaries for each node...")
        tree_structure = await generate_summaries_for_structure_md(tree_structure, summary_token_threshold=summary_token_threshold, model=model)
        if not if_add_node_text:
            # Remove text after summary generation if not requested
            tree_structure = format_structure(tree_structure, order = ['title', 'node_id', 'line_num', 'summary', 'prefix_summary', 'nodes'])
        if if_add_doc_description:
            print(f"Generating document description...")
            clean_structure = create_clean_structure_for_description(tree_structure)
            doc_description = generate_doc_description(clean_structure, model=model)
            return {
                'doc_name': os.path.splitext(os.path.basename(md_path))[0],
                'doc_description': doc_description,
                'line_count': line_count,
                'structure': tree_structure,
            }
    else:
        # No summaries needed, format based on text preference
        if if_add_node_text:
            tree_structure = format_structure(tree_structure, order = ['title', 'node_id', 'line_num', 'summary', 'prefix_summary', 'text', 'nodes'])
        else:
            tree_structure = format_structure(tree_structure, order = ['title', 'node_id', 'line_num', 'summary', 'prefix_summary', 'nodes'])
    return {
        'doc_name': os.path.splitext(os.path.basename(md_path))[0],
        'line_count': line_count,
        'structure': tree_structure,
    }
if __name__ == "__main__":
    # Manual smoke test: build a tree for an example document and dump it.
    import os
    import json
    # MD_NAME = 'Detect-Order-Construct'
    MD_NAME = 'cognitive-load'
    MD_PATH = os.path.join(os.path.dirname(__file__), '..', 'examples/documents/', f'{MD_NAME}.md')
    MODEL="gpt-4.1"
    IF_THINNING=False
    THINNING_THRESHOLD=5000
    SUMMARY_TOKEN_THRESHOLD=200
    IF_SUMMARY=True
    # NOTE(review): the string 'no' below is truthy, so summaries are
    # generated even when IF_SUMMARY is False -- confirm intent.
    tree_structure = asyncio.run(md_to_tree(
        md_path=MD_PATH,
        if_thinning=IF_THINNING,
        min_token_threshold=THINNING_THRESHOLD,
        if_add_node_summary='yes' if IF_SUMMARY else 'no',
        summary_token_threshold=SUMMARY_TOKEN_THRESHOLD,
        model=MODEL))
    print('\n' + '='*60)
    print('TREE STRUCTURE')
    print('='*60)
    print_json(tree_structure)
    print('\n' + '='*60)
    print('TABLE OF CONTENTS')
    print('='*60)
    print_toc(tree_structure['structure'])
    # Persist the result under results/ for inspection.
    output_path = os.path.join(os.path.dirname(__file__), '..', 'results', f'{MD_NAME}_structure.json')
    os.makedirs(os.path.dirname(output_path), exist_ok=True)
    with open(output_path, 'w', encoding='utf-8') as f:
        json.dump(tree_structure, f, indent=2, ensure_ascii=False)
    print(f"\nTree structure saved to: {output_path}")

122
pageindex/index/pipeline.py Normal file
View file

@ -0,0 +1,122 @@
# pageindex/index/pipeline.py
from __future__ import annotations
from ..parser.protocol import ContentNode, ParsedDocument
def detect_strategy(nodes: list[ContentNode]) -> str:
    """Pick an indexing strategy from the parsed nodes.

    An explicit heading level on any node means the document already
    encodes its hierarchy ("level_based"); otherwise the structure must be
    inferred from the content ("content_based").
    """
    has_levels = any(node.level is not None for node in nodes)
    return "level_based" if has_levels else "content_based"
def build_tree_from_levels(nodes: list[ContentNode]) -> list[dict]:
    """Strategy 0: materialize a tree directly from explicit heading levels.

    Mirrors build_tree_from_nodes in pageindex/page_index_md.py: each node
    attaches to the nearest preceding node of strictly smaller level, with
    missing titles defaulting to "" and missing levels to 1. Returns the
    list of root dicts with children nested under 'nodes'.
    """
    roots: list[dict] = []
    open_branches = []  # (tree_node, level) stack, innermost last
    for source in nodes:
        entry = {
            "title": source.title or "",
            "text": source.content,
            "line_num": source.index,
            "nodes": [],
        }
        depth = source.level or 1
        # Unwind to the closest strictly-shallower ancestor.
        while open_branches and open_branches[-1][1] >= depth:
            open_branches.pop()
        if open_branches:
            open_branches[-1][0]["nodes"].append(entry)
        else:
            roots.append(entry)
        open_branches.append((entry, depth))
    return roots
def _run_async(coro):
"""Run an async coroutine, handling the case where an event loop is already running."""
import asyncio
import concurrent.futures
try:
asyncio.get_running_loop()
# Already inside an event loop -- run in a separate thread
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
return pool.submit(asyncio.run, coro).result()
except RuntimeError:
return asyncio.run(coro)
def build_index(parsed: ParsedDocument, model: str = None, opt=None) -> dict:
    """Main entry point: ParsedDocument -> tree structure dict.
    Routes to the appropriate strategy and runs enhancement.

    Args:
        parsed: parser output providing .nodes and .doc_name.
        model: shorthand for IndexConfig(model=...) when opt is None.
        opt: full IndexConfig; takes precedence over `model`.

    Returns:
        dict with doc_name and structure, plus doc_description when
        opt.if_add_doc_description is set.
    """
    # Imported here (not at module top) -- presumably to avoid import cycles;
    # confirm before hoisting.
    from .utils import (write_node_id, add_node_text, remove_structure_text,
                        generate_summaries_for_structure, generate_doc_description,
                        create_clean_structure_for_description)
    from ..config import IndexConfig
    if opt is None:
        opt = IndexConfig(model=model) if model else IndexConfig()
    nodes = parsed.nodes
    strategy = detect_strategy(nodes)
    if strategy == "level_based":
        structure = build_tree_from_levels(nodes)
        # For level-based, text is already in the tree nodes
    else:
        # Strategies 1-3: convert ContentNode list to page_list format for existing pipeline
        page_list = [(n.content, n.tokens) for n in nodes]
        structure = _run_async(_content_based_pipeline(page_list, opt))
    # Unified enhancement
    if opt.if_add_node_id:
        write_node_id(structure)
    if strategy != "level_based":
        # Summaries need node text attached even if the caller did not ask
        # for text in the final output.
        if opt.if_add_node_text or opt.if_add_node_summary:
            add_node_text(structure, page_list)
        if opt.if_add_node_summary:
            _run_async(generate_summaries_for_structure(structure, model=opt.model))
    if not opt.if_add_node_text and strategy != "level_based":
        remove_structure_text(structure)
    result = {
        "doc_name": parsed.doc_name,
        "structure": structure,
    }
    if opt.if_add_doc_description:
        clean_structure = create_clean_structure_for_description(structure)
        result["doc_description"] = generate_doc_description(
            clean_structure, model=opt.model
        )
    return result
class _NullLogger:
"""Minimal logger that satisfies the tree_parser interface without writing files."""
def info(self, message, **kwargs): pass
def error(self, message, **kwargs): pass
def debug(self, message, **kwargs): pass
async def _content_based_pipeline(page_list, opt):
    """Strategies 1-3: delegates to the existing PDF pipeline from pageindex/page_index.py.
    The page_list is already in the format expected by tree_parser:
    [(page_text, token_count), ...]
    """
    from .page_index import tree_parser
    # Null logger keeps the library pipeline from writing log files.
    logger = _NullLogger()
    structure = await tree_parser(page_list, opt, doc=None, logger=logger)
    return structure

431
pageindex/index/utils.py Normal file
View file

@ -0,0 +1,431 @@
import litellm
import logging
import time
import json
import copy
import re
import asyncio
import PyPDF2
logger = logging.getLogger(__name__)
def count_tokens(text, model=None):
    """Token count of `text` for `model` via litellm; falsy/empty text is 0."""
    return litellm.token_counter(model=model, text=text) if text else 0
def llm_completion(model, prompt, chat_history=None, return_finish_reason=False):
    """Synchronous chat completion via litellm with a simple retry loop.

    Retries up to 10 times with a 1s pause between attempts, then raises
    RuntimeError chained to the last error.

    Args:
        model: litellm model string; a "litellm/" prefix is stripped.
        prompt: user message, appended after chat_history.
        chat_history: optional prior messages (role/content dicts).
        return_finish_reason: when True, also return "max_output_reached"
            (finish reason "length") or "finished".
    """
    if model:
        model = model.removeprefix("litellm/")
    max_retries = 10
    messages = list(chat_history) + [{"role": "user", "content": prompt}] if chat_history else [{"role": "user", "content": prompt}]
    for i in range(max_retries):
        try:
            # drop_params lets litellm discard args a provider doesn't support.
            litellm.drop_params = True
            response = litellm.completion(
                model=model,
                messages=messages,
                temperature=0,
            )
            content = response.choices[0].message.content
            if return_finish_reason:
                finish_reason = "max_output_reached" if response.choices[0].finish_reason == "length" else "finished"
                return content, finish_reason
            return content
        except Exception as e:
            logger.warning("Retrying LLM completion (%d/%d)", i + 1, max_retries)
            logger.error(f"Error: {e}")
            if i < max_retries - 1:
                time.sleep(1)
            else:
                logger.error('Max retries reached for prompt: ' + prompt)
                raise RuntimeError(f"LLM call failed after {max_retries} retries") from e
async def llm_acompletion(model, prompt):
    """Async chat completion via litellm with a simple retry loop.

    Mirrors llm_completion: up to 10 attempts with a 1s async pause, then
    raises RuntimeError chained to the last error. Returns the message
    content string.
    """
    if model:
        model = model.removeprefix("litellm/")
    max_retries = 10
    messages = [{"role": "user", "content": prompt}]
    for i in range(max_retries):
        try:
            # drop_params lets litellm discard args a provider doesn't support.
            litellm.drop_params = True
            response = await litellm.acompletion(
                model=model,
                messages=messages,
                temperature=0,
            )
            return response.choices[0].message.content
        except Exception as e:
            logger.warning("Retrying async LLM completion (%d/%d)", i + 1, max_retries)
            logger.error(f"Error: {e}")
            if i < max_retries - 1:
                await asyncio.sleep(1)
            else:
                logger.error('Max retries reached for prompt: ' + prompt)
                raise RuntimeError(f"LLM call failed after {max_retries} retries") from e
def extract_json(content):
    """Best-effort extraction of a JSON object from an LLM response.

    Handles responses wrapped in ```json fences and falls back to
    increasingly aggressive cleanup (Python `None` -> `null`, newline and
    whitespace normalization, trailing-comma removal). Returns {} when
    nothing parses.

    Fix: previously the lossy cleanup ran unconditionally, so valid JSON
    containing the substring 'None' inside a string value (e.g.
    "Nonetheless") was corrupted before parsing. We now try a strict parse
    first and only fall back to cleanup when it fails.
    """
    try:
        # First, try to extract JSON enclosed within ```json and ```
        start_idx = content.find("```json")
        if start_idx != -1:
            start_idx += 7  # Adjust index to start after the delimiter
            end_idx = content.rfind("```")
            json_content = content[start_idx:end_idx].strip()
        else:
            # If no delimiters, assume entire content could be JSON
            json_content = content.strip()
        # Fast path: parse the payload as-is before any lossy cleanup.
        try:
            return json.loads(json_content)
        except json.JSONDecodeError:
            pass
        # Clean up common issues that might cause parsing errors
        json_content = json_content.replace('None', 'null')  # Replace Python None with JSON null
        json_content = json_content.replace('\n', ' ').replace('\r', ' ')  # Remove newlines
        json_content = ' '.join(json_content.split())  # Normalize whitespace
        # Attempt to parse and return the JSON object
        return json.loads(json_content)
    except json.JSONDecodeError as e:
        logging.error(f"Failed to extract JSON: {e}")
        # Try to clean up the content further if initial parsing fails
        try:
            # Remove any trailing commas before closing brackets/braces
            json_content = json_content.replace(',]', ']').replace(',}', '}')
            return json.loads(json_content)
        except Exception:
            logging.error("Failed to parse JSON even after cleanup")
            return {}
    except Exception as e:
        logging.error(f"Unexpected error while extracting JSON: {e}")
        return {}
def get_json_content(response):
    """Strip a ```json ... ``` fence from `response`, returning the payload.

    Tolerates a missing closing fence; without an opening fence the trimmed
    response is returned unchanged. No parsing is attempted here.
    """
    marker = response.find("```json")
    if marker != -1:
        response = response[marker + 7:]
        closing = response.rfind("```")
        if closing != -1:
            response = response[:closing]
    return response.strip()
def write_node_id(data, node_id=0):
    """Assign zero-padded sequential 'node_id' values in pre-order.

    A dict receives the next id, then every key containing 'nodes' is
    recursed into; a list is traversed element by element. Returns the next
    unused id so sibling subtrees continue the sequence.
    """
    if isinstance(data, list):
        for child in data:
            node_id = write_node_id(child, node_id)
    elif isinstance(data, dict):
        data['node_id'] = f"{node_id:04d}"
        node_id += 1
        # Substring match covers 'nodes' plus variants like 'prefix_nodes'.
        child_keys = [k for k in data.keys() if 'nodes' in k]
        for key in child_keys:
            node_id = write_node_id(data[key], node_id)
    return node_id
def remove_fields(data, fields=None):
    """Return a copy of `data` with the given dict keys dropped everywhere.

    Defaults to stripping 'text'. Recurses through nested dicts and lists;
    scalar values pass through unchanged.
    """
    drop = fields or ["text"]
    if isinstance(data, list):
        return [remove_fields(entry, drop) for entry in data]
    if isinstance(data, dict):
        return {key: remove_fields(value, drop)
                for key, value in data.items() if key not in drop}
    return data
def structure_to_list(structure):
    """Flatten a tree (dict or list of dicts) into a pre-order node list.

    A parent dict appears before the contents of its 'nodes' key. The
    returned dicts are the originals, not copies.
    """
    if isinstance(structure, dict):
        flat = [structure]
        if 'nodes' in structure:
            flat.extend(structure_to_list(structure['nodes']))
        return flat
    if isinstance(structure, list):
        flat = []
        for element in structure:
            flat.extend(structure_to_list(element))
        return flat
def get_nodes(structure):
    """Flatten a tree into pre-order deep copies with the 'nodes' key removed.

    Unlike structure_to_list, each returned dict is a detached deep copy
    (only the exact key 'nodes' is popped from the copy), so mutating the
    results leaves the original tree intact. Any key containing 'nodes' is
    still recursed into.
    """
    if isinstance(structure, dict):
        detached = copy.deepcopy(structure)
        detached.pop('nodes', None)
        collected = [detached]
        for key in list(structure.keys()):
            if 'nodes' in key:
                collected.extend(get_nodes(structure[key]))
        return collected
    if isinstance(structure, list):
        collected = []
        for element in structure:
            collected.extend(get_nodes(element))
        return collected
def get_leaf_nodes(structure):
    """Collect detached copies of all leaf nodes (empty 'nodes') in pre-order.

    Every dict is expected to carry a 'nodes' key (a missing key raises
    KeyError, matching the original contract); each returned leaf is a deep
    copy with 'nodes' removed.
    """
    if isinstance(structure, list):
        leaves = []
        for element in structure:
            leaves.extend(get_leaf_nodes(element))
        return leaves
    if isinstance(structure, dict):
        if structure['nodes']:
            leaves = []
            for key in list(structure.keys()):
                if 'nodes' in key:
                    leaves.extend(get_leaf_nodes(structure[key]))
            return leaves
        detached = copy.deepcopy(structure)
        detached.pop('nodes', None)
        return [detached]
async def generate_node_summary(node, model=None):
    """Summarize node['text'] with one LLM call; returns the raw response text."""
    prompt = f"""You are given a part of a document, your task is to generate a description of the partial document about what are main points covered in the partial document.
Partial Document Text: {node['text']}
Directly return the description, do not include any other text.
"""
    response = await llm_acompletion(model, prompt)
    return response
async def generate_summaries_for_structure(structure, model=None):
    """Generate an LLM summary for every node in the tree, concurrently.

    Flattens the tree, fires one generate_node_summary call per node via
    asyncio.gather, writes each result to node['summary'] in place, and
    returns the same structure.
    """
    flat = structure_to_list(structure)
    results = await asyncio.gather(*(generate_node_summary(n, model=model) for n in flat))
    for node, text in zip(flat, results):
        node['summary'] = text
    return structure
def generate_doc_description(structure, model=None):
    """Produce a one-sentence, LLM-written description of the whole document.

    The structure is interpolated verbatim into the prompt; the raw LLM
    response is returned.
    """
    # NOTE(review): "Your are" is a typo in the prompt text; kept as-is so the
    # prompt bytes sent to the model are unchanged.
    prompt = f"""Your are an expert in generating descriptions for a document.
    You are given a structure of a document. Your task is to generate a one-sentence description for the document, which makes it easy to distinguish the document from other documents.
    Document Structure: {structure}
    Directly return the description, do not include any other text.
    """
    return llm_completion(model, prompt)
def list_to_tree(data):
    """Nest a flat list of sections (each keyed by a dotted 'structure' code
    such as '2.1.3') into a tree linked through 'nodes' child lists.

    A section attaches to the node whose code is its dotted prefix; if that
    parent has not been seen yet — or the code has no dot — it becomes a
    root. Leaves have their empty 'nodes' list removed.
    """
    def parent_code(code):
        """Return the dotted prefix of a structure code ('2.1.3' -> '2.1'),
        or None for falsy or top-level codes."""
        if not code:
            return None
        pieces = str(code).split('.')
        return '.'.join(pieces[:-1]) if len(pieces) > 1 else None

    by_code = {}
    roots = []
    for entry in data:
        code = entry.get('structure')
        built = {
            'title': entry.get('title'),
            'start_index': entry.get('start_index'),
            'end_index': entry.get('end_index'),
            'nodes': [],
        }
        by_code[code] = built
        parent = parent_code(code)
        if parent and parent in by_code:
            by_code[parent]['nodes'].append(built)
        else:
            # Orphans (parent code never seen) are promoted to roots too.
            roots.append(built)

    def strip_empty(node):
        """Delete empty 'nodes' lists, recursing into non-empty ones."""
        if node['nodes']:
            for child in node['nodes']:
                strip_empty(child)
        else:
            del node['nodes']
        return node

    return [strip_empty(root) for root in roots]
def post_processing(structure, end_physical_index):
    """Derive start_index/end_index for each flat section, then nest them.

    start_index is the section's own physical_index. end_index is the next
    section's physical_index — minus one when the next section starts at the
    top of its page ('appear_start' == 'yes'), since that page then belongs
    entirely to the next section. The last section ends at
    end_physical_index. If nesting produces an empty tree, the flat list is
    returned instead with the bookkeeping keys stripped.
    """
    last = len(structure) - 1
    for pos, section in enumerate(structure):
        section['start_index'] = section.get('physical_index')
        if pos == last:
            section['end_index'] = end_physical_index
        else:
            following = structure[pos + 1]
            if following.get('appear_start') == 'yes':
                section['end_index'] = following['physical_index'] - 1
            else:
                section['end_index'] = following['physical_index']
    tree = list_to_tree(structure)
    if tree:
        return tree
    # Fallback: keep the flat list but drop intermediate bookkeeping keys.
    for section in structure:
        section.pop('appear_start', None)
        section.pop('physical_index', None)
    return structure
def reorder_dict(data, key_order):
    """Return a copy of *data* with keys arranged per *key_order*.

    Keys named in *key_order* but absent from *data* are skipped; keys of
    *data* not named in *key_order* are dropped. A falsy *key_order* returns
    *data* itself, unchanged.
    """
    if not key_order:
        return data
    reordered = {}
    for key in key_order:
        if key in data:
            reordered[key] = data[key]
    return reordered
def format_structure(structure, order=None):
    """Recursively reorder each node's keys per *order*, pruning 'nodes'
    lists that end up empty.

    With no *order* the structure is returned untouched. Dict nodes are
    mutated in place before being replaced by their reordered copy.
    """
    if not order:
        return structure
    if isinstance(structure, list):
        return [format_structure(element, order) for element in structure]
    if isinstance(structure, dict):
        if 'nodes' in structure:
            structure['nodes'] = format_structure(structure['nodes'], order)
            if not structure.get('nodes'):
                structure.pop('nodes', None)
        return reorder_dict(structure, order)
    return structure
def create_clean_structure_for_description(structure):
    """Strip a tree down to the fields useful for doc-description prompts.

    Keeps only 'title', 'node_id', 'summary' and 'prefix_summary' on each
    node (notably dropping 'text'), recursing into non-empty 'nodes' lists.
    Non-dict, non-list values pass through unchanged.
    """
    if isinstance(structure, list):
        return [create_clean_structure_for_description(entry) for entry in structure]
    if not isinstance(structure, dict):
        return structure
    kept = {field: structure[field]
            for field in ('title', 'node_id', 'summary', 'prefix_summary')
            if field in structure}
    if structure.get('nodes'):
        kept['nodes'] = create_clean_structure_for_description(structure['nodes'])
    return kept
def _get_text_of_pages(page_list, start_page, end_page):
"""Concatenate text from page_list for pages [start_page, end_page] (1-indexed)."""
text = ""
for page_num in range(start_page - 1, end_page):
text += page_list[page_num][0]
return text
def add_node_text(node, page_list):
    """Recursively populate each node's 'text' field from page_list content.

    Nodes carry 1-indexed 'start_index'/'end_index' page numbers; page_list
    is [(page_text, token_count), ...]. Nodes missing either bound are left
    without a 'text' field; children under 'nodes' are processed either way.
    """
    if isinstance(node, list):
        for child in node:
            add_node_text(child, page_list)
        return
    if isinstance(node, dict):
        first, last = node.get('start_index'), node.get('end_index')
        if first is not None and last is not None:
            node['text'] = _get_text_of_pages(page_list, first, last)
        if 'nodes' in node:
            add_node_text(node['nodes'], page_list)
def remove_structure_text(data):
    """Recursively strip the 'text' field from every node, in place.

    Returns the same object for call-chaining convenience.
    """
    if isinstance(data, list):
        for entry in data:
            remove_structure_text(entry)
    elif isinstance(data, dict):
        data.pop('text', None)
        if 'nodes' in data:
            remove_structure_text(data['nodes'])
    return data
# ── Functions migrated from retrieve.py ──────────────────────────────────────
def parse_pages(pages: str) -> list[int]:
    """Parse a pages string like '5-7', '3,8', or '12' into a sorted list of ints.

    Pages below 1 are dropped, duplicates removed, and the result sorted.

    Raises:
        ValueError: for a reversed range ('7-5'), a non-numeric token, or a
            result of more than 1000 pages.
    """
    result = []
    for part in pages.split(','):
        part = part.strip()
        if '-' in part:
            # Split once; the previous version split the same token twice.
            lo_str, hi_str = part.split('-', 1)
            lo, hi = int(lo_str.strip()), int(hi_str.strip())
            if lo > hi:
                raise ValueError(f"Invalid range '{part}': start must be <= end")
            result.extend(range(lo, hi + 1))
        else:
            result.append(int(part))
    result = sorted({p for p in result if p >= 1})
    if len(result) > 1000:
        raise ValueError(f"Page range too large: {len(result)} pages (max 1000)")
    return result
def get_pdf_page_content(file_path: str, page_nums: list[int]) -> list[dict]:
    """Extract text for specific PDF pages (1-indexed), opening the PDF once.

    Out-of-range page numbers are silently dropped; pages with no
    extractable text yield an empty 'content' string.
    """
    with open(file_path, 'rb') as fh:
        reader = PyPDF2.PdfReader(fh)
        page_count = len(reader.pages)
        contents = []
        for num in page_nums:
            if 1 <= num <= page_count:
                contents.append({
                    'page': num,
                    'content': reader.pages[num - 1].extract_text() or '',
                })
        return contents
def get_md_page_content(structure: list, page_nums: list[int]) -> list[dict]:
    """For Markdown documents, 'pages' are line numbers.

    Walk the tree and return {'page': line_num, 'content': text} for every
    node whose line_num falls within [min(page_nums), max(page_nums)],
    deduplicated by line number and sorted ascending.
    """
    if not page_nums:
        return []
    lo, hi = min(page_nums), max(page_nums)
    collected = []
    visited = set()

    def walk(nodes):
        for entry in nodes:
            line = entry.get('line_num')
            # A falsy line_num (missing or 0) is treated as "no line".
            if line and lo <= line <= hi and line not in visited:
                visited.add(line)
                collected.append({'page': line, 'content': entry.get('text', '')})
            children = entry.get('nodes')
            if children:
                walk(children)

    walk(structure)
    return sorted(collected, key=lambda item: item['page'])

View file

@ -1113,11 +1113,12 @@ def page_index_main(doc, opt=None):
def page_index(doc, model=None, toc_check_page_num=None, max_page_num_each_node=None, max_token_num_each_node=None,
if_add_node_id=None, if_add_node_summary=None, if_add_doc_description=None, if_add_node_text=None):
from .config import IndexConfig
user_opt = {
arg: value for arg, value in locals().items()
if arg != "doc" and value is not None
}
opt = ConfigLoader().load(user_opt)
opt = IndexConfig(**user_opt)
return page_index_main(doc, opt)

View file

View file

@ -0,0 +1,59 @@
import re
from pathlib import Path
from .protocol import ContentNode, ParsedDocument
from ..index.utils import count_tokens
class MarkdownParser:
    """Parse a Markdown file into one ContentNode per ATX-header section."""

    def supported_extensions(self) -> list[str]:
        return [".md", ".markdown"]

    def parse(self, file_path: str, **kwargs) -> ParsedDocument:
        """Read *file_path* (UTF-8) and return its header sections.

        kwargs:
            model: tokenizer model name forwarded to count_tokens.
        """
        path = Path(file_path)
        model = kwargs.get("model")
        with open(path, "r", encoding="utf-8") as f:
            lines = f.read().split("\n")
        headers = self._extract_headers(lines)
        return ParsedDocument(
            doc_name=path.stem,
            nodes=self._build_nodes(headers, lines, model),
        )

    def _extract_headers(self, lines: list[str]) -> list[dict]:
        """Locate ATX headers (# .. ######), skipping fenced code blocks."""
        found = []
        inside_fence = False
        for line_num, raw in enumerate(lines, 1):
            stripped = raw.strip()
            if re.match(r"^```", stripped):
                # Fence delimiter: toggle state, never a header itself.
                inside_fence = not inside_fence
                continue
            if inside_fence or not stripped:
                continue
            m = re.match(r"^(#{1,6})\s+(.+)$", stripped)
            if m:
                found.append({
                    "title": m.group(2).strip(),
                    "level": len(m.group(1)),
                    "line_num": line_num,
                })
        return found

    def _build_nodes(self, headers: list[dict], lines: list[str], model: str | None) -> list[ContentNode]:
        """Slice the document between consecutive headers into ContentNodes."""
        built = []
        for pos, header in enumerate(headers):
            lo = header["line_num"] - 1
            hi = headers[pos + 1]["line_num"] - 1 if pos + 1 < len(headers) else len(lines)
            section = "\n".join(lines[lo:hi]).strip()
            built.append(ContentNode(
                content=section,
                tokens=count_tokens(section, model=model),
                title=header["title"],
                index=header["line_num"],
                level=header["level"],
            ))
        return built

101
pageindex/parser/pdf.py Normal file
View file

@ -0,0 +1,101 @@
import pymupdf
from pathlib import Path
from .protocol import ContentNode, ParsedDocument
from ..index.utils import count_tokens
# Minimum image dimension to keep (skip icons/artifacts)
_MIN_IMAGE_SIZE = 32
class PdfParser:
    """Parse a PDF into one ContentNode per page, optionally extracting images."""

    def supported_extensions(self) -> list[str]:
        return [".pdf"]

    def parse(self, file_path: str, **kwargs) -> ParsedDocument:
        """Parse *file_path* page by page.

        kwargs:
            model: tokenizer model name forwarded to count_tokens.
            images_dir: when set, page images are saved there and inline
                placeholders embedded in the page text.
        """
        path = Path(file_path)
        model = kwargs.get("model")
        images_dir = kwargs.get("images_dir")
        nodes = []
        with pymupdf.open(str(path)) as doc:
            for i, page in enumerate(doc):
                page_num = i + 1
                if images_dir:
                    content, images = self._extract_page_with_images(
                        doc, page, page_num, images_dir)
                else:
                    content = page.get_text()
                    images = None
                tokens = count_tokens(content, model=model)
                nodes.append(ContentNode(
                    content=content or "",
                    tokens=tokens,
                    index=page_num,
                    images=images if images else None,
                ))
        return ParsedDocument(doc_name=path.stem, nodes=nodes)

    @staticmethod
    def _extract_page_with_images(doc, page, page_num: int,
                                  images_dir: str) -> tuple[str, list[dict]]:
        """Extract text and images from a page, preserving their relative order.

        Uses get_text("dict") to iterate blocks in reading order.
        Text blocks become text; image blocks are saved to disk and replaced
        with an inline placeholder: ![image](path)

        Returns:
            (page_text_with_placeholders, [{"path", "width", "height"}, ...])
        """
        images_path = Path(images_dir)
        images_path.mkdir(parents=True, exist_ok=True)
        # Use path relative to cwd so downstream consumers can access directly
        try:
            rel_images_path = images_path.relative_to(Path.cwd())
        except ValueError:
            rel_images_path = images_path
        parts: list[str] = []
        images: list[dict] = []
        img_idx = 0
        for block in page.get_text("dict")["blocks"]:
            if block["type"] == 0:  # text block
                lines = []
                for line in block["lines"]:
                    spans_text = "".join(span["text"] for span in line["spans"])
                    lines.append(spans_text)
                parts.append("\n".join(lines))
            elif block["type"] == 1:  # image block
                width = block.get("width", 0)
                height = block.get("height", 0)
                if width < _MIN_IMAGE_SIZE or height < _MIN_IMAGE_SIZE:
                    continue  # skip icons/artifacts
                image_bytes = block.get("image")
                if not image_bytes:
                    continue
                # NOTE: images are always re-encoded to PNG below, so the
                # block's embedded "ext" is irrelevant (the previous version
                # read it into an unused local).
                try:
                    pix = pymupdf.Pixmap(image_bytes)
                    # Convert exotic colorspaces (e.g. CMYK) to RGB for PNG.
                    if pix.n > 4:
                        pix = pymupdf.Pixmap(pymupdf.csRGB, pix)
                    filename = f"p{page_num}_img{img_idx}.png"
                    save_path = images_path / filename
                    pix.save(str(save_path))
                    pix = None
                except Exception:
                    continue  # best-effort: a bad image never fails the page
                rel_path = str(rel_images_path / filename)
                images.append({
                    "path": rel_path,
                    "width": width,
                    "height": height,
                })
                parts.append(f"![image]({rel_path})")
                img_idx += 1
        content = "\n".join(parts)
        return content, images

View file

@ -0,0 +1,28 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import Protocol, runtime_checkable
@dataclass
class ContentNode:
    """Universal content unit produced by parsers."""
    content: str  # raw text of the unit (a PDF page or a Markdown section)
    tokens: int  # token count of `content`
    title: str | None = None  # section title, when the parser knows one
    index: int | None = None  # 1-based page number (PDF) or header line number (Markdown)
    level: int | None = None  # Markdown header depth (1-6); None for PDF pages
    images: list[dict] | None = None  # [{"path": str, "width": int, "height": int}, ...]
@dataclass
class ParsedDocument:
    """Unified parser output. Always a flat list of ContentNode."""
    doc_name: str  # document name, typically the source file's stem
    nodes: list[ContentNode]  # content units in document order
    metadata: dict | None = None  # optional parser-specific extras
@runtime_checkable
class DocumentParser(Protocol):
    """Structural (duck-typed) interface every document parser implements."""
    # Extensions (including the dot, e.g. ".pdf") this parser handles.
    def supported_extensions(self) -> list[str]: ...
    # Parse the file at file_path into a ParsedDocument; kwargs are
    # parser-specific (e.g. model, images_dir).
    def parse(self, file_path: str, **kwargs) -> ParsedDocument: ...

View file

View file

@ -0,0 +1,18 @@
from __future__ import annotations
from typing import Protocol, runtime_checkable
@runtime_checkable
class StorageEngine(Protocol):
    """Structural interface for persistence backends.

    Collections group documents; a document carries metadata plus its parsed
    structure and, optionally, cached page content.
    """
    # --- collection lifecycle ---
    def create_collection(self, name: str) -> None: ...
    def get_or_create_collection(self, name: str) -> None: ...
    def list_collections(self) -> list[str]: ...
    def delete_collection(self, name: str) -> None: ...
    # --- document CRUD ---
    def save_document(self, collection: str, doc_id: str, doc: dict) -> None: ...
    def find_document_by_hash(self, collection: str, file_hash: str) -> str | None: ...
    def get_document(self, collection: str, doc_id: str) -> dict: ...
    def get_document_structure(self, collection: str, doc_id: str) -> list: ...
    def get_pages(self, collection: str, doc_id: str) -> list | None: ...
    def list_documents(self, collection: str) -> list[dict]: ...
    def delete_document(self, collection: str, doc_id: str) -> None: ...
    # --- resource management ---
    def close(self) -> None: ...

164
pageindex/storage/sqlite.py Normal file
View file

@ -0,0 +1,164 @@
import json
import sqlite3
import threading
from pathlib import Path
class SQLiteStorage:
def __init__(self, db_path: str):
self._db_path = Path(db_path).expanduser()
self._db_path.parent.mkdir(parents=True, exist_ok=True)
self._local = threading.local()
self._connections: list[sqlite3.Connection] = []
self._conn_lock = threading.Lock()
self._init_schema()
def _get_conn(self) -> sqlite3.Connection:
"""Return a thread-local SQLite connection."""
if not hasattr(self._local, "conn"):
conn = sqlite3.connect(str(self._db_path))
conn.execute("PRAGMA journal_mode=WAL")
conn.execute("PRAGMA foreign_keys=ON")
self._local.conn = conn
with self._conn_lock:
self._connections.append(conn)
return self._local.conn
def _init_schema(self):
conn = self._get_conn()
conn.execute("PRAGMA user_version = 1")
conn.executescript("""
CREATE TABLE IF NOT EXISTS collections (
name TEXT PRIMARY KEY CHECK(length(name) <= 128 AND name GLOB '[a-zA-Z0-9_-]*'),
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
);
CREATE TABLE IF NOT EXISTS documents (
doc_id TEXT PRIMARY KEY,
collection_name TEXT NOT NULL REFERENCES collections(name) ON DELETE CASCADE,
doc_name TEXT,
doc_description TEXT,
file_path TEXT,
file_hash TEXT,
doc_type TEXT NOT NULL,
structure JSON,
pages JSON,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
);
CREATE INDEX IF NOT EXISTS idx_docs_collection ON documents(collection_name);
CREATE INDEX IF NOT EXISTS idx_docs_hash ON documents(collection_name, file_hash);
""")
conn.commit()
def create_collection(self, name: str) -> None:
conn = self._get_conn()
conn.execute("INSERT INTO collections (name) VALUES (?)", (name,))
conn.commit()
def get_or_create_collection(self, name: str) -> None:
conn = self._get_conn()
conn.execute("INSERT OR IGNORE INTO collections (name) VALUES (?)", (name,))
conn.commit()
def list_collections(self) -> list[str]:
conn = self._get_conn()
rows = conn.execute("SELECT name FROM collections ORDER BY name").fetchall()
return [r[0] for r in rows]
def delete_collection(self, name: str) -> None:
conn = self._get_conn()
conn.execute("DELETE FROM collections WHERE name = ?", (name,))
conn.commit()
def save_document(self, collection: str, doc_id: str, doc: dict) -> None:
conn = self._get_conn()
conn.execute(
"""INSERT OR REPLACE INTO documents
(doc_id, collection_name, doc_name, doc_description, file_path, file_hash, doc_type, structure, pages)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)""",
(doc_id, collection, doc.get("doc_name"), doc.get("doc_description"),
doc.get("file_path"), doc.get("file_hash"), doc["doc_type"],
json.dumps(doc.get("structure", [])),
json.dumps(doc.get("pages")) if doc.get("pages") else None),
)
conn.commit()
def find_document_by_hash(self, collection: str, file_hash: str) -> str | None:
conn = self._get_conn()
row = conn.execute(
"SELECT doc_id FROM documents WHERE collection_name = ? AND file_hash = ?",
(collection, file_hash),
).fetchone()
return row[0] if row else None
def get_document(self, collection: str, doc_id: str) -> dict:
conn = self._get_conn()
row = conn.execute(
"SELECT doc_id, doc_name, doc_description, file_path, doc_type FROM documents WHERE doc_id = ? AND collection_name = ?",
(doc_id, collection),
).fetchone()
if not row:
return {}
return {"doc_id": row[0], "doc_name": row[1], "doc_description": row[2],
"file_path": row[3], "doc_type": row[4]}
def get_document_structure(self, collection: str, doc_id: str) -> list:
conn = self._get_conn()
row = conn.execute(
"SELECT structure FROM documents WHERE doc_id = ? AND collection_name = ?",
(doc_id, collection),
).fetchone()
if not row:
return []
return json.loads(row[0])
def get_pages(self, collection: str, doc_id: str) -> list | None:
"""Return cached page content, or None if not cached."""
conn = self._get_conn()
row = conn.execute(
"SELECT pages FROM documents WHERE doc_id = ? AND collection_name = ?",
(doc_id, collection),
).fetchone()
if not row or not row[0]:
return None
return json.loads(row[0])
def list_documents(self, collection: str) -> list[dict]:
conn = self._get_conn()
rows = conn.execute(
"SELECT doc_id, doc_name, doc_type FROM documents WHERE collection_name = ? ORDER BY created_at",
(collection,),
).fetchall()
return [{"doc_id": r[0], "doc_name": r[1], "doc_type": r[2]} for r in rows]
def delete_document(self, collection: str, doc_id: str) -> None:
conn = self._get_conn()
conn.execute(
"DELETE FROM documents WHERE doc_id = ? AND collection_name = ?",
(doc_id, collection),
)
conn.commit()
def __enter__(self):
return self
def __exit__(self, exc_type, exc_val, exc_tb):
self.close()
return False
def close(self) -> None:
"""Close all tracked SQLite connections across all threads."""
with self._conn_lock:
for conn in self._connections:
try:
conn.close()
except Exception:
pass
self._connections.clear()
if hasattr(self._local, "conn"):
del self._local.conn
def __del__(self):
try:
self.close()
except Exception:
pass

48
pyproject.toml Normal file
View file

@ -0,0 +1,48 @@
[build-system]
requires = ["setuptools>=68.0"]
build-backend = "setuptools.build_meta"
[project]
name = "pageindex"
version = "0.3.0"
description = "Python SDK for PageIndex"
readme = "README.md"
license = {text = "MIT"}
requires-python = ">=3.10"
authors = [
{name = "Ray", email = "ray@vectify.ai"},
]
classifiers = [
"Development Status :: 3 - Alpha",
"Intended Audience :: Developers",
"License :: OSI Approved :: MIT License",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
"Programming Language :: Python :: 3.13",
"Topic :: Scientific/Engineering :: Artificial Intelligence",
]
keywords = ["rag", "document", "retrieval", "llm", "pageindex"]
dependencies = [
"litellm>=1.82.0",
"pymupdf>=1.26.0",
"PyPDF2>=3.0.0",
"python-dotenv>=1.0.0",
"pyyaml>=6.0",
"openai-agents>=0.1.0",
"requests>=2.28.0",
"httpx[socks]>=0.28.1",
]
[project.optional-dependencies]
dev = ["pytest>=8.0", "pytest-asyncio>=0.23"]
[project.urls]
Homepage = "https://github.com/VectifyAI/PageIndex"
Documentation = "https://docs.pageindex.ai"
Repository = "https://github.com/VectifyAI/PageIndex"
Issues = "https://github.com/VectifyAI/PageIndex/issues"
[tool.setuptools.packages.find]
include = ["pageindex*"]

View file

@ -1,9 +1,9 @@
import argparse
import os
import json
from pageindex import *
from pageindex.page_index_md import md_to_tree
from pageindex.utils import ConfigLoader
from pageindex.index.page_index import *
from pageindex.index.page_index_md import md_to_tree
from pageindex.config import IndexConfig
if __name__ == "__main__":
# Set up argument parser
@ -11,7 +11,7 @@ if __name__ == "__main__":
parser.add_argument('--pdf_path', type=str, help='Path to the PDF file')
parser.add_argument('--md_path', type=str, help='Path to the Markdown file')
parser.add_argument('--model', type=str, default=None, help='Model to use (overrides config.yaml)')
parser.add_argument('--model', type=str, default=None, help='Model to use')
parser.add_argument('--toc-check-pages', type=int, default=None,
help='Number of pages to check for table of contents (PDF only)')
@ -20,15 +20,15 @@ if __name__ == "__main__":
parser.add_argument('--max-tokens-per-node', type=int, default=None,
help='Maximum number of tokens per node (PDF only)')
parser.add_argument('--if-add-node-id', type=str, default=None,
help='Whether to add node id to the node')
parser.add_argument('--if-add-node-summary', type=str, default=None,
help='Whether to add summary to the node')
parser.add_argument('--if-add-doc-description', type=str, default=None,
help='Whether to add doc description to the doc')
parser.add_argument('--if-add-node-text', type=str, default=None,
help='Whether to add text to the node')
parser.add_argument('--if-add-node-id', action='store_true', default=None,
help='Add node id to the node')
parser.add_argument('--if-add-node-summary', action='store_true', default=None,
help='Add summary to the node')
parser.add_argument('--if-add-doc-description', action='store_true', default=None,
help='Add doc description to the doc')
parser.add_argument('--if-add-node-text', action='store_true', default=None,
help='Add text to the node')
# Markdown specific arguments
parser.add_argument('--if-thinning', type=str, default='no',
help='Whether to apply tree thinning for markdown (markdown only)')
@ -37,77 +37,61 @@ if __name__ == "__main__":
parser.add_argument('--summary-token-threshold', type=int, default=200,
help='Token threshold for generating summaries (markdown only)')
args = parser.parse_args()
# Validate that exactly one file type is specified
if not args.pdf_path and not args.md_path:
raise ValueError("Either --pdf_path or --md_path must be specified")
if args.pdf_path and args.md_path:
raise ValueError("Only one of --pdf_path or --md_path can be specified")
# Build IndexConfig from CLI args (None values use defaults)
config_overrides = {
k: v for k, v in {
"model": args.model,
"toc_check_page_num": args.toc_check_pages,
"max_page_num_each_node": args.max_pages_per_node,
"max_token_num_each_node": args.max_tokens_per_node,
"if_add_node_id": args.if_add_node_id,
"if_add_node_summary": args.if_add_node_summary,
"if_add_doc_description": args.if_add_doc_description,
"if_add_node_text": args.if_add_node_text,
}.items() if v is not None
}
opt = IndexConfig(**config_overrides)
if args.pdf_path:
# Validate PDF file
if not args.pdf_path.lower().endswith('.pdf'):
raise ValueError("PDF file must have .pdf extension")
if not os.path.isfile(args.pdf_path):
raise ValueError(f"PDF file not found: {args.pdf_path}")
# Process PDF file
user_opt = {
'model': args.model,
'toc_check_page_num': args.toc_check_pages,
'max_page_num_each_node': args.max_pages_per_node,
'max_token_num_each_node': args.max_tokens_per_node,
'if_add_node_id': args.if_add_node_id,
'if_add_node_summary': args.if_add_node_summary,
'if_add_doc_description': args.if_add_doc_description,
'if_add_node_text': args.if_add_node_text,
}
opt = ConfigLoader().load({k: v for k, v in user_opt.items() if v is not None})
# Process the PDF
toc_with_page_number = page_index_main(args.pdf_path, opt)
print('Parsing done, saving to file...')
# Save results
pdf_name = os.path.splitext(os.path.basename(args.pdf_path))[0]
pdf_name = os.path.splitext(os.path.basename(args.pdf_path))[0]
output_dir = './results'
output_file = f'{output_dir}/{pdf_name}_structure.json'
os.makedirs(output_dir, exist_ok=True)
with open(output_file, 'w', encoding='utf-8') as f:
json.dump(toc_with_page_number, f, indent=2)
print(f'Tree structure saved to: {output_file}')
elif args.md_path:
# Validate Markdown file
if not args.md_path.lower().endswith(('.md', '.markdown')):
raise ValueError("Markdown file must have .md or .markdown extension")
if not os.path.isfile(args.md_path):
raise ValueError(f"Markdown file not found: {args.md_path}")
# Process markdown file
print('Processing markdown file...')
# Process the markdown
import asyncio
# Use ConfigLoader to get consistent defaults (matching PDF behavior)
from pageindex.utils import ConfigLoader
config_loader = ConfigLoader()
# Create options dict with user args
user_opt = {
'model': args.model,
'if_add_node_summary': args.if_add_node_summary,
'if_add_doc_description': args.if_add_doc_description,
'if_add_node_text': args.if_add_node_text,
'if_add_node_id': args.if_add_node_id
}
# Load config with defaults from config.yaml
opt = config_loader.load(user_opt)
toc_with_page_number = asyncio.run(md_to_tree(
md_path=args.md_path,
if_thinning=args.if_thinning.lower() == 'yes',
@ -119,16 +103,16 @@ if __name__ == "__main__":
if_add_node_text=opt.if_add_node_text,
if_add_node_id=opt.if_add_node_id
))
print('Parsing done, saving to file...')
# Save results
md_name = os.path.splitext(os.path.basename(args.md_path))[0]
md_name = os.path.splitext(os.path.basename(args.md_path))[0]
output_dir = './results'
output_file = f'{output_dir}/{md_name}_structure.json'
os.makedirs(output_dir, exist_ok=True)
with open(output_file, 'w', encoding='utf-8') as f:
json.dump(toc_with_page_number, f, indent=2, ensure_ascii=False)
print(f'Tree structure saved to: {output_file}')
print(f'Tree structure saved to: {output_file}')

14
tests/test_agent.py Normal file
View file

@ -0,0 +1,14 @@
from pageindex.agent import AgentRunner, SYSTEM_PROMPT
from pageindex.backend.protocol import AgentTools
def test_agent_runner_init():
    """AgentRunner stores the configured model at construction."""
    tools = AgentTools(function_tools=["mock_tool"])
    runner = AgentRunner(tools=tools, model="gpt-4o")
    assert runner._model == "gpt-4o"
def test_system_prompt_has_tool_instructions():
    """The system prompt mentions every built-in tool by name."""
    assert "list_documents" in SYSTEM_PROMPT
    assert "get_document_structure" in SYSTEM_PROMPT
    assert "get_page_content" in SYSTEM_PROMPT

51
tests/test_client.py Normal file
View file

@ -0,0 +1,51 @@
# tests/sdk/test_client.py
import pytest
from pageindex.client import PageIndexClient, LocalClient, CloudClient
def test_local_client_is_pageindex_client(tmp_path):
    """LocalClient is a PageIndexClient subtype."""
    client = LocalClient(model="gpt-4o", storage_path=str(tmp_path / "pi"))
    assert isinstance(client, PageIndexClient)
def test_cloud_client_is_pageindex_client():
    """CloudClient is a PageIndexClient subtype."""
    client = CloudClient(api_key="pi-test")
    assert isinstance(client, PageIndexClient)
def test_collection_default_name(tmp_path):
    """collection() with no argument yields the 'default' collection."""
    client = LocalClient(model="gpt-4o", storage_path=str(tmp_path / "pi"))
    col = client.collection()
    assert col.name == "default"
def test_collection_custom_name(tmp_path):
    """collection(name) yields a collection with that name."""
    client = LocalClient(model="gpt-4o", storage_path=str(tmp_path / "pi"))
    col = client.collection("papers")
    assert col.name == "papers"
def test_list_collections_empty(tmp_path):
    """A fresh client has no collections."""
    client = LocalClient(model="gpt-4o", storage_path=str(tmp_path / "pi"))
    assert client.list_collections() == []
def test_list_collections_after_create(tmp_path):
    """Accessing a named collection creates and lists it."""
    client = LocalClient(model="gpt-4o", storage_path=str(tmp_path / "pi"))
    client.collection("papers")
    assert "papers" in client.list_collections()
def test_delete_collection(tmp_path):
    """delete_collection removes a previously created collection."""
    client = LocalClient(model="gpt-4o", storage_path=str(tmp_path / "pi"))
    client.collection("papers")
    client.delete_collection("papers")
    assert "papers" not in client.list_collections()
def test_register_parser(tmp_path):
    """register_parser accepts any object matching the DocumentParser protocol."""
    client = LocalClient(model="gpt-4o", storage_path=str(tmp_path / "pi"))
    class FakeParser:
        def supported_extensions(self): return [".txt"]
        def parse(self, file_path, **kwargs): pass
    # Smoke test: registration must not raise.
    client.register_parser(FakeParser())

View file

@ -0,0 +1,16 @@
from pageindex.backend.cloud import CloudBackend, API_BASE
def test_cloud_backend_init():
    """CloudBackend stores the API key and puts it in request headers."""
    backend = CloudBackend(api_key="pi-test")
    assert backend._api_key == "pi-test"
    assert backend._headers["api_key"] == "pi-test"
def test_api_base_url():
    """The default API base points at the pageindex.ai domain."""
    assert "pageindex.ai" in API_BASE
def test_get_retrieve_model_is_none():
    """Cloud backend exposes no local function tools (the cloud runs them).

    NOTE(review): the test name mentions retrieve_model but the body asserts
    on function_tools — consider renaming for clarity.
    """
    backend = CloudBackend(api_key="pi-test")
    assert backend.get_agent_tools("col").function_tools == []

41
tests/test_collection.py Normal file
View file

@ -0,0 +1,41 @@
# tests/sdk/test_collection.py
import pytest
from unittest.mock import MagicMock
from pageindex.collection import Collection
@pytest.fixture
def col():
    """Collection wired to a mocked backend with one canned document."""
    backend = MagicMock()
    backend.list_documents.return_value = [
        {"doc_id": "d1", "doc_name": "paper.pdf", "doc_type": "pdf"}
    ]
    backend.get_document.return_value = {"doc_id": "d1", "doc_name": "paper.pdf"}
    backend.add_document.return_value = "d1"
    return Collection(name="papers", backend=backend)
def test_add(col):
    """add() forwards to the backend with the collection name and returns its id."""
    doc_id = col.add("paper.pdf")
    assert doc_id == "d1"
    col._backend.add_document.assert_called_once_with("papers", "paper.pdf")
def test_list_documents(col):
    """list_documents() returns the backend's document dicts."""
    docs = col.list_documents()
    assert len(docs) == 1
    assert docs[0]["doc_id"] == "d1"
def test_get_document(col):
    """get_document() returns the backend's metadata dict."""
    doc = col.get_document("d1")
    assert doc["doc_name"] == "paper.pdf"
def test_delete_document(col):
    """delete_document() forwards collection name and doc id to the backend."""
    col.delete_document("d1")
    col._backend.delete_document.assert_called_once_with("papers", "d1")
def test_name_property(col):
    """The name property reflects the constructor argument."""
    assert col.name == "papers"

28
tests/test_config.py Normal file
View file

@ -0,0 +1,28 @@
# tests/test_config.py
import pytest
from pageindex.config import IndexConfig
def test_defaults():
    """IndexConfig ships with the documented default values."""
    config = IndexConfig()
    assert config.model == "gpt-4o-2024-11-20"
    assert config.retrieve_model is None
    assert config.toc_check_page_num == 20
def test_overrides():
    """Constructor keyword arguments override the defaults."""
    config = IndexConfig(model="gpt-5.4", retrieve_model="claude-sonnet")
    assert config.model == "gpt-5.4"
    assert config.retrieve_model == "claude-sonnet"
def test_unknown_key_raises():
    """Unknown config keys are rejected at construction."""
    with pytest.raises(Exception):
        IndexConfig(nonexistent_key="value")
def test_model_copy_with_update():
    """model_copy(update=...) changes only the named field."""
    config = IndexConfig(toc_check_page_num=30)
    updated = config.model_copy(update={"model": "gpt-5.4"})
    assert updated.model == "gpt-5.4"
    assert updated.toc_check_page_num == 30

View file

@ -0,0 +1,45 @@
from pageindex.parser.protocol import ContentNode, ParsedDocument, DocumentParser
def test_content_node_required_fields():
    """Only content and tokens are required; the rest default to None."""
    node = ContentNode(content="hello", tokens=5)
    assert node.content == "hello"
    assert node.tokens == 5
    assert node.title is None
    assert node.index is None
    assert node.level is None
def test_content_node_all_fields():
    """Optional fields round-trip when supplied."""
    node = ContentNode(content="# Intro", tokens=10, title="Intro", index=1, level=1)
    assert node.title == "Intro"
    assert node.index == 1
    assert node.level == 1
def test_parsed_document():
    """ParsedDocument holds a name, a node list, and optional metadata."""
    nodes = [ContentNode(content="page1", tokens=100, index=1)]
    doc = ParsedDocument(doc_name="test.pdf", nodes=nodes)
    assert doc.doc_name == "test.pdf"
    assert len(doc.nodes) == 1
    assert doc.metadata is None
def test_parsed_document_with_metadata():
    """Metadata dicts are stored as-is."""
    nodes = [ContentNode(content="page1", tokens=100)]
    doc = ParsedDocument(doc_name="test.pdf", nodes=nodes, metadata={"author": "John"})
    assert doc.metadata["author"] == "John"
def test_document_parser_protocol():
    """Verify a class implementing DocumentParser is structurally compatible."""
    class MyParser:
        def supported_extensions(self) -> list[str]:
            return [".txt"]
        def parse(self, file_path: str, **kwargs) -> ParsedDocument:
            return ParsedDocument(doc_name="test", nodes=[])
    parser = MyParser()
    assert parser.supported_extensions() == [".txt"]
    result = parser.parse("test.txt")
    assert isinstance(result, ParsedDocument)

27
tests/test_errors.py Normal file
View file

@ -0,0 +1,27 @@
from pageindex.errors import (
PageIndexError,
CollectionNotFoundError,
DocumentNotFoundError,
IndexingError,
CloudAPIError,
FileTypeError,
)
def test_all_errors_inherit_from_base():
    """Every package error derives from PageIndexError (and Exception)."""
    for cls in [CollectionNotFoundError, DocumentNotFoundError, IndexingError, CloudAPIError, FileTypeError]:
        assert issubclass(cls, PageIndexError)
        assert issubclass(cls, Exception)
def test_error_message():
    """str() of an error is the message it was constructed with."""
    err = FileTypeError("Unsupported: .docx")
    assert str(err) == "Unsupported: .docx"
def test_catch_base_catches_all():
    """An `except PageIndexError` handler catches every subclass."""
    for cls in [CollectionNotFoundError, DocumentNotFoundError, IndexingError, CloudAPIError, FileTypeError]:
        try:
            raise cls("test")
        except PageIndexError:
            pass  # expected

26
tests/test_events.py Normal file
View file

@ -0,0 +1,26 @@
from pageindex.events import QueryEvent
from pageindex.backend.protocol import AgentTools
def test_query_event():
    """QueryEvent stores its type and payload verbatim."""
    event = QueryEvent(type="answer_delta", data="hello")
    assert event.type == "answer_delta"
    assert event.data == "hello"
def test_query_event_types():
    """All streaming event type strings are accepted."""
    for t in ["reasoning", "tool_call", "tool_result", "answer_delta", "answer_done"]:
        event = QueryEvent(type=t, data="test")
        assert event.type == t
def test_agent_tools_default_empty():
    """AgentTools defaults to empty tool and MCP-server lists."""
    tools = AgentTools()
    assert tools.function_tools == []
    assert tools.mcp_servers == []
def test_agent_tools_with_values():
    """Explicitly supplied tool lists are stored as given."""
    configured = AgentTools(function_tools=["tool1"], mcp_servers=["server1"])
    assert len(configured.function_tools) == 1
    assert len(configured.mcp_servers) == 1

View file

@ -0,0 +1,50 @@
# tests/sdk/test_local_backend.py
import pytest
from pathlib import Path
from pageindex.backend.local import LocalBackend
from pageindex.storage.sqlite import SQLiteStorage
from pageindex.errors import FileTypeError
@pytest.fixture
def backend(tmp_path):
    """LocalBackend wired to a throwaway SQLite database under tmp_path."""
    db_path = tmp_path / "test.db"
    return LocalBackend(
        storage=SQLiteStorage(str(db_path)),
        files_dir=str(tmp_path / "files"),
        model="gpt-4o",
    )
def test_collection_lifecycle(backend):
    """A collection appears after creation and disappears after deletion."""
    backend.get_or_create_collection("papers")
    assert "papers" in backend.list_collections()

    backend.delete_collection("papers")
    assert "papers" not in backend.list_collections()
def test_list_documents_empty(backend):
    """A freshly created collection has no documents."""
    backend.get_or_create_collection("papers")
    docs = backend.list_documents("papers")
    assert docs == []
def test_unsupported_file_type_raises(backend, tmp_path):
    """Adding a file with an unknown extension raises FileTypeError."""
    backend.get_or_create_collection("papers")
    unknown = tmp_path / "test.xyz"
    unknown.write_text("hello")

    with pytest.raises(FileTypeError):
        backend.add_document("papers", str(unknown))
def test_register_custom_parser(backend):
    """After register_parser, the backend resolves a parser for .txt files."""
    from pageindex.parser.protocol import ParsedDocument, ContentNode

    class PlainTextParser:
        def supported_extensions(self):
            return [".txt"]

        def parse(self, file_path, **kwargs):
            body = Path(file_path).read_text()
            node = ContentNode(
                content=body,
                tokens=len(body.split()),
                title="Content",
                index=1,
                level=1,
            )
            return ParsedDocument(doc_name="test", nodes=[node])

    backend.register_parser(PlainTextParser())
    # .txt is now supported: resolution succeeds instead of raising FileTypeError.
    assert backend._resolve_parser("test.txt") is not None

View file

@ -0,0 +1,55 @@
import pytest
from pathlib import Path
from pageindex.parser.markdown import MarkdownParser
from pageindex.parser.protocol import ContentNode, ParsedDocument
@pytest.fixture
def sample_md(tmp_path):
    """Write a small two-chapter markdown document and return its path."""
    target = tmp_path / "test.md"
    target.write_text("""# Chapter 1
Some intro text.
## Section 1.1
Details here.
## Section 1.2
More details.
# Chapter 2
Another chapter.
""")
    return str(target)
def test_supported_extensions():
    """Both common markdown extensions are advertised by the parser."""
    extensions = MarkdownParser().supported_extensions()
    assert {".md", ".markdown"} <= set(extensions)
def test_parse_returns_parsed_document(sample_md):
    """parse() yields a ParsedDocument named after the file stem."""
    parsed = MarkdownParser().parse(sample_md)
    assert isinstance(parsed, ParsedDocument)
    assert parsed.doc_name == "test"
def test_parse_nodes_have_level(sample_md):
    """Heading depth maps onto node.level; order follows the document."""
    nodes = MarkdownParser().parse(sample_md).nodes
    assert len(nodes) == 4
    # (level, title) pairs for the headings we care about.
    assert (nodes[0].level, nodes[0].title) == (1, "Chapter 1")
    assert (nodes[1].level, nodes[1].title) == (2, "Section 1.1")
    assert nodes[3].level == 1
def test_parse_nodes_have_content(sample_md):
    """Body text under each heading lands in that node's content."""
    nodes = MarkdownParser().parse(sample_md).nodes
    assert "Some intro text" in nodes[0].content
    assert "Details here" in nodes[1].content
def test_parse_nodes_have_index(sample_md):
    """Every parsed node carries a non-None index."""
    parsed = MarkdownParser().parse(sample_md)
    assert all(node.index is not None for node in parsed.nodes)

29
tests/test_pdf_parser.py Normal file
View file

@ -0,0 +1,29 @@
import pytest
from pathlib import Path
from pageindex.parser.pdf import PdfParser
from pageindex.parser.protocol import ContentNode, ParsedDocument
# Real PDF fixture; the parser tests below skip themselves when it is absent.
TEST_PDF = Path("tests/pdfs/deepseek-r1.pdf")
def test_supported_extensions():
    """The PDF parser advertises the .pdf extension."""
    extensions = PdfParser().supported_extensions()
    assert ".pdf" in extensions
@pytest.mark.skipif(not TEST_PDF.exists(), reason="Test PDF not available")
def test_parse_returns_parsed_document():
    """Parsing a real PDF yields a non-empty, named ParsedDocument."""
    parsed = PdfParser().parse(str(TEST_PDF))
    assert isinstance(parsed, ParsedDocument)
    assert parsed.nodes  # at least one node
    assert parsed.doc_name != ""
@pytest.mark.skipif(not TEST_PDF.exists(), reason="Test PDF not available")
def test_parse_nodes_are_flat_without_level():
    """PDF nodes are flat: indexed, with content and token counts, but no level."""
    parsed = PdfParser().parse(str(TEST_PDF))
    for node in parsed.nodes:
        assert isinstance(node, ContentNode)
        assert node.content is not None
        assert node.tokens >= 0
        assert node.index is not None
        # Flat extraction: hierarchy is inferred later, so level stays unset.
        assert node.level is None

95
tests/test_pipeline.py Normal file
View file

@ -0,0 +1,95 @@
# tests/sdk/test_pipeline.py
import asyncio
from unittest.mock import patch, AsyncMock
from pageindex.parser.protocol import ContentNode, ParsedDocument
from pageindex.index.pipeline import (
detect_strategy, build_tree_from_levels, build_index,
_content_based_pipeline, _NullLogger,
)
def test_detect_strategy_with_level():
    """Nodes carrying heading levels select the level-based strategy."""
    leveled = [
        ContentNode(content="# Intro", tokens=10, title="Intro", index=1, level=1),
        ContentNode(content="## Details", tokens=10, title="Details", index=5, level=2),
    ]
    assert detect_strategy(leveled) == "level_based"
def test_detect_strategy_without_level():
    """Flat, level-less nodes select the content-based strategy."""
    flat = [
        ContentNode(content="Page 1 text", tokens=100, index=1),
        ContentNode(content="Page 2 text", tokens=100, index=2),
    ]
    assert detect_strategy(flat) == "content_based"
def test_build_tree_from_levels():
    """Level-2 nodes nest under the preceding level-1 node; roots stay ordered."""
    flat_nodes = [
        ContentNode(content="ch1 text", tokens=10, title="Chapter 1", index=1, level=1),
        ContentNode(content="s1.1 text", tokens=10, title="Section 1.1", index=5, level=2),
        ContentNode(content="s1.2 text", tokens=10, title="Section 1.2", index=10, level=2),
        ContentNode(content="ch2 text", tokens=10, title="Chapter 2", index=20, level=1),
    ]

    tree = build_tree_from_levels(flat_nodes)

    # Two chapters at the root.
    assert len(tree) == 2
    chapter1, chapter2 = tree
    assert chapter1["title"] == "Chapter 1"
    # Both sections are children of chapter 1, in document order.
    assert [child["title"] for child in chapter1["nodes"]] == ["Section 1.1", "Section 1.2"]
    assert chapter2["title"] == "Chapter 2"
    assert chapter2["nodes"] == []
def test_build_tree_from_levels_single_level():
    """All same-level nodes become siblings at the root."""
    siblings = [
        ContentNode(content="a", tokens=5, title="A", index=1, level=1),
        ContentNode(content="b", tokens=5, title="B", index=2, level=1),
    ]
    tree = build_tree_from_levels(siblings)
    assert [root["title"] for root in tree] == ["A", "B"]
def test_build_tree_from_levels_deep_nesting():
    """Strictly increasing levels produce a single chain H1 -> H2 -> H3."""
    chain = [
        ContentNode(content="h1", tokens=5, title="H1", index=1, level=1),
        ContentNode(content="h2", tokens=5, title="H2", index=2, level=2),
        ContentNode(content="h3", tokens=5, title="H3", index=3, level=3),
    ]

    tree = build_tree_from_levels(chain)

    assert len(tree) == 1
    # Walk the chain, checking one child per depth.
    node = tree[0]
    for expected_title in ("H1", "H2", "H3"):
        assert node["title"] == expected_title
        if expected_title != "H3":
            assert len(node["nodes"]) == 1
            node = node["nodes"][0]
def test_content_based_pipeline_does_not_raise():
    """_content_based_pipeline should delegate to tree_parser, not raise NotImplementedError."""
    from types import SimpleNamespace

    expected = [{"title": "Intro", "start_index": 1, "end_index": 2, "nodes": []}]

    async def stub_tree_parser(page_list, opt, doc=None, logger=None):
        return expected

    pages = [("Page 1 text", 50), ("Page 2 text", 60)]
    options = SimpleNamespace(model="test-model")

    with patch("pageindex.index.page_index.tree_parser", new=stub_tree_parser):
        assert asyncio.run(_content_based_pipeline(pages, options)) == expected
def test_null_logger_methods():
    """NullLogger accepts info/error/debug with any payload without raising."""
    sink = _NullLogger()
    for emit in (sink.info, sink.error, sink.debug):
        emit("test message")
    # Non-string payloads must also be swallowed silently.
    sink.info({"key": "value"})

View file

@ -0,0 +1,61 @@
import pytest
from pageindex.storage.sqlite import SQLiteStorage
@pytest.fixture
def storage(tmp_path):
    """SQLiteStorage backed by a throwaway database file under tmp_path."""
    db_file = tmp_path / "test.db"
    return SQLiteStorage(str(db_file))
def test_create_and_list_collections(storage):
    """A created collection shows up in list_collections."""
    storage.create_collection("papers")
    names = storage.list_collections()
    assert "papers" in names
def test_get_or_create_collection_idempotent(storage):
    """Repeated get_or_create calls never duplicate a collection."""
    for _ in range(2):
        storage.get_or_create_collection("papers")
    assert storage.list_collections().count("papers") == 1
def test_delete_collection(storage):
    """A deleted collection no longer appears in list_collections."""
    storage.create_collection("papers")
    storage.delete_collection("papers")
    assert "papers" not in storage.list_collections()
def test_save_and_get_document(storage):
    """A saved document round-trips through get_document."""
    storage.create_collection("papers")
    payload = {
        "doc_name": "test.pdf",
        "doc_description": "A test",
        "file_path": "/tmp/test.pdf",
        "doc_type": "pdf",
        "structure": [{"title": "Intro", "node_id": "0001"}],
    }
    storage.save_document("papers", "doc-1", payload)

    fetched = storage.get_document("papers", "doc-1")
    assert fetched["doc_name"] == "test.pdf"
    assert fetched["doc_type"] == "pdf"
def test_get_document_structure(storage):
    """get_document_structure returns the structure exactly as saved."""
    storage.create_collection("papers")
    tree = [{"title": "Ch1", "node_id": "0001", "nodes": []}]
    storage.save_document("papers", "doc-1", {
        "doc_name": "test.pdf",
        "doc_type": "pdf",
        "file_path": "/tmp/test.pdf",
        "structure": tree,
    })

    fetched = storage.get_document_structure("papers", "doc-1")
    assert fetched[0]["title"] == "Ch1"
def test_list_documents(storage):
    """list_documents reports every document saved into the collection."""
    storage.create_collection("papers")
    for doc_id, name in (("doc-1", "p1.pdf"), ("doc-2", "p2.pdf")):
        storage.save_document("papers", doc_id, {
            "doc_name": name,
            "doc_type": "pdf",
            "file_path": f"/tmp/{name}",
            "structure": [],
        })
    assert len(storage.list_documents("papers")) == 2
def test_delete_document(storage):
    """Deleting the only document leaves the collection empty."""
    storage.create_collection("papers")
    storage.save_document("papers", "doc-1", {
        "doc_name": "test.pdf",
        "doc_type": "pdf",
        "file_path": "/tmp/test.pdf",
        "structure": [],
    })
    storage.delete_document("papers", "doc-1")
    assert storage.list_documents("papers") == []
def test_delete_collection_cascades_documents(storage):
    """Deleting a collection must also remove its documents.

    The previous version only asserted that the collection name was gone,
    which would pass even if document rows were orphaned. Re-creating a
    collection under the same name verifies the cascade actually happened:
    a leaked document would resurface in list_documents.
    """
    storage.create_collection("papers")
    storage.save_document("papers", "doc-1", {
        "doc_name": "test.pdf",
        "doc_type": "pdf",
        "file_path": "/tmp/test.pdf",
        "structure": [],
    })

    storage.delete_collection("papers")
    assert "papers" not in storage.list_collections()

    # Re-create under the same name: a failed cascade would leak doc-1 back.
    storage.create_collection("papers")
    assert storage.list_documents("papers") == []

View file

@ -0,0 +1,19 @@
from pageindex.storage.protocol import StorageEngine
def test_storage_engine_is_protocol():
    """A structurally complete class satisfies the StorageEngine protocol at runtime."""
    class InMemoryStub:
        def create_collection(self, name: str) -> None:
            pass

        def get_or_create_collection(self, name: str) -> None:
            pass

        def list_collections(self) -> list[str]:
            return []

        def delete_collection(self, name: str) -> None:
            pass

        def save_document(self, collection: str, doc_id: str, doc: dict) -> None:
            pass

        def find_document_by_hash(self, collection: str, file_hash: str) -> str | None:
            return None

        def get_document(self, collection: str, doc_id: str) -> dict:
            return {}

        def get_document_structure(self, collection: str, doc_id: str) -> dict:
            return {}

        def get_pages(self, collection: str, doc_id: str) -> list | None:
            return None

        def list_documents(self, collection: str) -> list[dict]:
            return []

        def delete_document(self, collection: str, doc_id: str) -> None:
            pass

        def close(self) -> None:
            pass

    assert isinstance(InMemoryStub(), StorageEngine)