mirror of
https://github.com/VectifyAI/PageIndex.git
synced 2026-04-24 23:56:21 +02:00
feat: add PageIndex SDK with local/cloud dual-mode support (#207)
This commit is contained in:
parent
f2dcffc0b7
commit
c7fe93bb56
45 changed files with 4225 additions and 274 deletions
4
.gitignore
vendored
4
.gitignore
vendored
|
|
@ -4,3 +4,7 @@ __pycache__
|
|||
.env*
|
||||
.venv/
|
||||
logs/
|
||||
pageindex.egg-info/
|
||||
*.db
|
||||
venv/
|
||||
uv.lock
|
||||
|
|
|
|||
62
examples/cloud_demo.py
Normal file
62
examples/cloud_demo.py
Normal file
|
|
@ -0,0 +1,62 @@
|
|||
"""
|
||||
Agentic Vectorless RAG with PageIndex SDK - Cloud Demo
|
||||
|
||||
Uses CloudClient for fully-managed document indexing and QA.
|
||||
No LLM API key needed — the cloud service handles everything.
|
||||
|
||||
Steps:
|
||||
1 — Upload and index a PDF via PageIndex cloud
|
||||
2 — Stream a question with tool call visibility
|
||||
|
||||
Requirements:
|
||||
pip install pageindex
|
||||
export PAGEINDEX_API_KEY=your-api-key
|
||||
"""
|
||||
import asyncio
|
||||
import os
|
||||
from pathlib import Path
|
||||
import requests
|
||||
from pageindex import CloudClient
|
||||
|
||||
_EXAMPLES_DIR = Path(__file__).parent
|
||||
PDF_URL = "https://arxiv.org/pdf/1706.03762.pdf"
|
||||
PDF_PATH = _EXAMPLES_DIR / "documents" / "attention.pdf"
|
||||
|
||||
# Download PDF if needed
|
||||
if not PDF_PATH.exists():
|
||||
print(f"Downloading {PDF_URL} ...")
|
||||
PDF_PATH.parent.mkdir(parents=True, exist_ok=True)
|
||||
with requests.get(PDF_URL, stream=True, timeout=30) as r:
|
||||
r.raise_for_status()
|
||||
with open(PDF_PATH, "wb") as f:
|
||||
for chunk in r.iter_content(chunk_size=8192):
|
||||
if chunk:
|
||||
f.write(chunk)
|
||||
print("Download complete.\n")
|
||||
|
||||
client = CloudClient(api_key=os.environ["PAGEINDEX_API_KEY"])
|
||||
col = client.collection()
|
||||
|
||||
doc_id = col.add(str(PDF_PATH))
|
||||
print(f"Indexed: {doc_id}\n")
|
||||
|
||||
# Streaming query
|
||||
stream = col.query("What is the main contribution of this paper?", stream=True)
|
||||
|
||||
async def main():
|
||||
streamed_text = False
|
||||
async for event in stream:
|
||||
if event.type == "answer_delta":
|
||||
print(event.data, end="", flush=True)
|
||||
streamed_text = True
|
||||
elif event.type == "tool_call":
|
||||
if streamed_text:
|
||||
print()
|
||||
streamed_text = False
|
||||
args = event.data.get("args", "")
|
||||
print(f"[tool call] {event.data['name']}({args})")
|
||||
elif event.type == "answer_done":
|
||||
print()
|
||||
streamed_text = False
|
||||
|
||||
asyncio.run(main())
|
||||
69
examples/local_demo.py
Normal file
69
examples/local_demo.py
Normal file
|
|
@ -0,0 +1,69 @@
|
|||
"""
|
||||
Agentic Vectorless RAG with PageIndex SDK - Local Demo
|
||||
|
||||
A simple example of using LocalClient for self-hosted document indexing
|
||||
and agent-based QA. The agent uses OpenAI Agents SDK to reason over
|
||||
the document's tree structure index.
|
||||
|
||||
Steps:
|
||||
1 — Download and index a PDF
|
||||
2 — Stream a question with tool call visibility
|
||||
|
||||
Requirements:
|
||||
pip install pageindex
|
||||
export OPENAI_API_KEY=your-api-key # or any LiteLLM-supported provider
|
||||
"""
|
||||
import asyncio
|
||||
from pathlib import Path
|
||||
import requests
|
||||
from pageindex import LocalClient
|
||||
|
||||
_EXAMPLES_DIR = Path(__file__).parent
|
||||
PDF_URL = "https://arxiv.org/pdf/1706.03762.pdf"
|
||||
PDF_PATH = _EXAMPLES_DIR / "documents" / "attention.pdf"
|
||||
WORKSPACE = _EXAMPLES_DIR / "workspace"
|
||||
MODEL = "gpt-4o-2024-11-20" # any LiteLLM-supported model
|
||||
|
||||
# Download PDF if needed
|
||||
if not PDF_PATH.exists():
|
||||
print(f"Downloading {PDF_URL} ...")
|
||||
PDF_PATH.parent.mkdir(parents=True, exist_ok=True)
|
||||
with requests.get(PDF_URL, stream=True, timeout=30) as r:
|
||||
r.raise_for_status()
|
||||
with open(PDF_PATH, "wb") as f:
|
||||
for chunk in r.iter_content(chunk_size=8192):
|
||||
if chunk:
|
||||
f.write(chunk)
|
||||
print("Download complete.\n")
|
||||
|
||||
client = LocalClient(model=MODEL, storage_path=str(WORKSPACE))
|
||||
col = client.collection()
|
||||
|
||||
doc_id = col.add(str(PDF_PATH))
|
||||
print(f"Indexed: {doc_id}\n")
|
||||
|
||||
# Streaming query
|
||||
stream = col.query(
|
||||
"What is the main architecture proposed in this paper and how does self-attention work?",
|
||||
stream=True,
|
||||
)
|
||||
|
||||
async def main():
|
||||
streamed_text = False
|
||||
async for event in stream:
|
||||
if event.type == "answer_delta":
|
||||
print(event.data, end="", flush=True)
|
||||
streamed_text = True
|
||||
elif event.type == "tool_call":
|
||||
if streamed_text:
|
||||
print()
|
||||
streamed_text = False
|
||||
print(f"[tool call] {event.data['name']}")
|
||||
elif event.type == "tool_result":
|
||||
preview = str(event.data)[:200] + "..." if len(str(event.data)) > 200 else event.data
|
||||
print(f"[tool output] {preview}")
|
||||
elif event.type == "answer_done":
|
||||
print()
|
||||
streamed_text = False
|
||||
|
||||
asyncio.run(main())
|
||||
|
|
@ -1,4 +1,40 @@
|
|||
# pageindex/__init__.py
# Upstream exports (backward compatibility)
from .page_index import *
from .page_index_md import md_to_tree
from .retrieve import get_document, get_document_structure, get_page_content

# SDK exports
# (PageIndexClient is imported here once; the original had a redundant
# separate `from .client import PageIndexClient` in the upstream section.)
from .client import PageIndexClient, LocalClient, CloudClient
from .config import IndexConfig
from .collection import Collection
from .parser.protocol import ContentNode, ParsedDocument, DocumentParser
from .storage.protocol import StorageEngine
from .events import QueryEvent
from .errors import (
    PageIndexError,
    CollectionNotFoundError,
    DocumentNotFoundError,
    IndexingError,
    CloudAPIError,
    FileTypeError,
)

# Explicit public API of the SDK surface.
__all__ = [
    "PageIndexClient",
    "LocalClient",
    "CloudClient",
    "IndexConfig",
    "Collection",
    "ContentNode",
    "ParsedDocument",
    "DocumentParser",
    "StorageEngine",
    "QueryEvent",
    "PageIndexError",
    "CollectionNotFoundError",
    "DocumentNotFoundError",
    "IndexingError",
    "CloudAPIError",
    "FileTypeError",
]
|
||||
|
|
|
|||
93
pageindex/agent.py
Normal file
93
pageindex/agent.py
Normal file
|
|
@ -0,0 +1,93 @@
|
|||
# pageindex/agent.py
|
||||
from __future__ import annotations
|
||||
from typing import AsyncIterator
|
||||
from .events import QueryEvent
|
||||
from .backend.protocol import AgentTools
|
||||
|
||||
|
||||
SYSTEM_PROMPT = """
|
||||
You are PageIndex, a document QA assistant.
|
||||
TOOL USE:
|
||||
- Call list_documents() to see available documents.
|
||||
- Call get_document(doc_id) to confirm status and page/line count.
|
||||
- Call get_document_structure(doc_id) to identify relevant page ranges.
|
||||
- Call get_page_content(doc_id, pages="5-7") with tight ranges; never fetch the whole document.
|
||||
- Before each tool call, output one short sentence explaining the reason.
|
||||
IMAGES:
|
||||
- Page content may contain image references like . Always preserve these in your answer so the downstream UI can render them.
|
||||
- Place images near the relevant context in your answer.
|
||||
Answer based only on tool output. Be concise.
|
||||
"""
|
||||
|
||||
|
||||
class QueryStream:
    """Streaming query result, similar to OpenAI's RunResultStreaming.

    Usage:
        stream = col.query("question", stream=True)
        async for event in stream:
            if event.type == "answer_delta":
                print(event.data, end="", flush=True)
    """

    def __init__(self, tools: AgentTools, question: str, model: str | None = None):
        # Fix: annotation was `model: str = None`; the default is None so the
        # correct type is `str | None` (module uses `from __future__ import annotations`).
        # Lazy import: the agents dependency is only needed when a query runs.
        from agents import Agent
        from agents.model_settings import ModelSettings
        self._agent = Agent(
            name="PageIndex",
            instructions=SYSTEM_PROMPT,
            tools=tools.function_tools,
            mcp_servers=tools.mcp_servers,
            model=model,
            # Sequential tool calls keep the event trace readable.
            model_settings=ModelSettings(parallel_tool_calls=False),
        )
        self._question = question

    async def stream_events(self) -> AsyncIterator[QueryEvent]:
        """Async generator yielding QueryEvent as they arrive."""
        from agents import Runner, ItemHelpers
        from agents.stream_events import RawResponsesStreamEvent, RunItemStreamEvent
        from openai.types.responses import ResponseTextDeltaEvent

        streamed_run = Runner.run_streamed(self._agent, self._question)
        async for event in streamed_run.stream_events():
            if isinstance(event, RawResponsesStreamEvent):
                # Token-level deltas of the answer text.
                if isinstance(event.data, ResponseTextDeltaEvent):
                    yield QueryEvent(type="answer_delta", data=event.data.delta)
            elif isinstance(event, RunItemStreamEvent):
                item = event.item
                if item.type == "tool_call_item":
                    raw = item.raw_item
                    yield QueryEvent(type="tool_call", data={
                        "name": raw.name, "args": getattr(raw, "arguments", "{}"),
                    })
                elif item.type == "tool_call_output_item":
                    yield QueryEvent(type="tool_result", data=str(item.output))
                elif item.type == "message_output_item":
                    # Full assembled answer, emitted once at the end.
                    text = ItemHelpers.text_message_output(item)
                    if text:
                        yield QueryEvent(type="answer_done", data=text)

    def __aiter__(self):
        # Allows `async for event in stream` directly on the object.
        return self.stream_events()
|
||||
|
||||
|
||||
class AgentRunner:
    """One-shot, synchronous (non-streaming) agent query runner."""

    def __init__(self, tools: AgentTools, model: str | None = None):
        # Fix: annotation was `model: str = None` — default is None, so `str | None`.
        self._tools = tools
        self._model = model

    def run(self, question: str) -> str:
        """Sync non-streaming query. Returns answer string."""
        # Lazy import: the agents dependency is only needed at query time.
        from agents import Agent, Runner
        from agents.model_settings import ModelSettings
        agent = Agent(
            name="PageIndex",
            instructions=SYSTEM_PROMPT,
            tools=self._tools.function_tools,
            mcp_servers=self._tools.mcp_servers,
            model=self._model,
            # One tool call at a time keeps the run deterministic and traceable.
            model_settings=ModelSettings(parallel_tool_calls=False),
        )
        result = Runner.run_sync(agent, question)
        return result.final_output
|
||||
0
pageindex/backend/__init__.py
Normal file
0
pageindex/backend/__init__.py
Normal file
352
pageindex/backend/cloud.py
Normal file
352
pageindex/backend/cloud.py
Normal file
|
|
@ -0,0 +1,352 @@
|
|||
# pageindex/backend/cloud.py
|
||||
"""CloudBackend — connects to PageIndex cloud service (api.pageindex.ai).
|
||||
|
||||
API reference: https://github.com/VectifyAI/pageindex_sdk
|
||||
"""
|
||||
from __future__ import annotations
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import time
|
||||
import urllib.parse
|
||||
import requests
|
||||
from typing import AsyncIterator
|
||||
|
||||
from .protocol import AgentTools
|
||||
from ..errors import CloudAPIError, PageIndexError
|
||||
from ..events import QueryEvent
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
API_BASE = "https://api.pageindex.ai"
|
||||
|
||||
_INTERNAL_TOOLS = frozenset({"ToolSearch", "Read", "Grep", "Glob", "Bash", "Edit", "Write"})
|
||||
|
||||
|
||||
class CloudBackend:
    """Backend that proxies all operations to the PageIndex cloud API.

    Collections are mapped onto cloud "folders" (a Max-plan feature). When
    the folders endpoint returns 403, a one-time warning is logged and all
    documents live in a single global space.
    """

    def __init__(self, api_key: str):
        self._api_key = api_key
        self._headers = {"api_key": api_key}
        # collection name -> folder id; None means folders unavailable/unknown.
        self._folder_id_cache: dict[str, str | None] = {}
        self._folder_warning_shown = False

    # ── HTTP helpers ──────────────────────────────────────────────────────

    def _warn_folder_upgrade(self) -> None:
        """Log the folders-require-upgrade warning once per backend instance."""
        if not self._folder_warning_shown:
            logger.warning(
                "Folders (collections) require a Max plan. "
                "All documents are stored in a single global space — collection names are ignored. "
                "Upgrade at https://dash.pageindex.ai/subscription"
            )
            self._folder_warning_shown = True

    def _request(self, method: str, path: str, **kwargs) -> dict:
        """Issue one API request with up to 3 attempts (exponential backoff).

        Retries on connection errors and on 429/500/502/503. Raises
        CloudAPIError on any other non-200 status or when retries run out.
        Returns the parsed JSON body ({} for an empty response).
        """
        url = f"{API_BASE}{path}"
        for attempt in range(3):
            try:
                resp = requests.request(method, url, headers=self._headers, timeout=30, **kwargs)
                if resp.status_code in (429, 500, 502, 503):
                    logger.warning("Cloud API %s %s returned %d, retrying...", method, path, resp.status_code)
                    time.sleep(2 ** attempt)
                    continue
                if resp.status_code != 200:
                    body = resp.text[:500] if resp.text else ""
                    raise CloudAPIError(f"Cloud API error {resp.status_code}: {body}")
                return resp.json() if resp.content else {}
            except requests.RequestException as e:
                if attempt == 2:
                    raise CloudAPIError(f"Cloud API request failed: {e}") from e
                time.sleep(2 ** attempt)
        raise CloudAPIError("Max retries exceeded")

    @staticmethod
    def _validate_collection_name(name: str) -> None:
        """Raise PageIndexError unless name is 1-128 chars of [a-zA-Z0-9_-]."""
        if not re.match(r'^[a-zA-Z0-9_-]{1,128}$', name):
            raise PageIndexError(
                f"Invalid collection name: {name!r}. "
                "Must be 1-128 chars of [a-zA-Z0-9_-]."
            )

    @staticmethod
    def _enc(value: str) -> str:
        """Percent-encode a path segment (no characters left unescaped)."""
        return urllib.parse.quote(value, safe="")

    # ── Collection management (mapped to folders) ─────────────────────────

    def create_collection(self, name: str) -> None:
        """Create a cloud folder for the collection; degrade gracefully on 403."""
        self._validate_collection_name(name)
        try:
            resp = self._request("POST", "/folder/", json={"name": name})
            self._folder_id_cache[name] = resp.get("folder", {}).get("id")
        except CloudAPIError as e:
            if "403" in str(e):
                # Folders not available on this plan — fall back to global space.
                self._warn_folder_upgrade()
                self._folder_id_cache[name] = None
            else:
                raise

    def get_or_create_collection(self, name: str) -> None:
        """Reuse an existing folder with this name, or create one."""
        self._validate_collection_name(name)
        try:
            data = self._request("GET", "/folders/")
            for folder in data.get("folders", []):
                if folder.get("name") == name:
                    self._folder_id_cache[name] = folder["id"]
                    return
            resp = self._request("POST", "/folder/", json={"name": name})
            self._folder_id_cache[name] = resp.get("folder", {}).get("id")
        except CloudAPIError as e:
            if "403" in str(e):
                self._warn_folder_upgrade()
                self._folder_id_cache[name] = None
            else:
                raise

    def _get_folder_id(self, name: str) -> str | None:
        """Resolve collection name to folder ID. Returns None if folders not available."""
        if name in self._folder_id_cache:
            return self._folder_id_cache.get(name)
        try:
            data = self._request("GET", "/folders/")
            for folder in data.get("folders", []):
                if folder.get("name") == name:
                    self._folder_id_cache[name] = folder["id"]
                    return folder["id"]
        except CloudAPIError:
            # Best-effort: treat any API failure as "folders unavailable".
            pass
        self._folder_id_cache[name] = None
        return None

    def list_collections(self) -> list[str]:
        """Return the names of all cloud folders."""
        data = self._request("GET", "/folders/")
        return [f["name"] for f in data.get("folders", [])]

    def delete_collection(self, name: str) -> None:
        """Delete the folder backing this collection, if one exists."""
        folder_id = self._get_folder_id(name)
        if folder_id:
            self._request("DELETE", f"/folder/{self._enc(folder_id)}/")

    # ── Document management ───────────────────────────────────────────────

    def add_document(self, collection: str, file_path: str) -> str:
        """Upload a document and block until the cloud index is retrieval-ready.

        Polls every 5 s for up to 10 minutes; raises CloudAPIError on
        indexing failure or timeout. Returns the cloud doc_id.
        """
        folder_id = self._get_folder_id(collection)
        data = {"if_retrieval": "true"}
        if folder_id:
            data["folder_id"] = folder_id

        with open(file_path, "rb") as f:
            resp = self._request("POST", "/doc/", files={"file": f}, data=data)

        doc_id = resp["doc_id"]

        # Poll until retrieval-ready
        for _ in range(120):  # 10 min max
            tree_resp = self._request("GET", f"/doc/{self._enc(doc_id)}/", params={"type": "tree"})
            if tree_resp.get("retrieval_ready"):
                return doc_id
            status = tree_resp.get("status", "")
            if status == "failed":
                raise CloudAPIError(f"Document {doc_id} indexing failed")
            time.sleep(5)

        raise CloudAPIError(f"Document {doc_id} indexing timed out")

    def get_document(self, collection: str, doc_id: str, include_text: bool = False) -> dict:
        """Return metadata plus the normalized tree structure for a document."""
        resp = self._request("GET", f"/doc/{self._enc(doc_id)}/metadata/")
        # Fetch structure in the same call via tree endpoint
        tree_resp = self._request("GET", f"/doc/{self._enc(doc_id)}/",
                                  params={"type": "tree", "summary": "true"})
        raw_tree = tree_resp.get("tree", tree_resp.get("structure", tree_resp.get("result", [])))
        return {
            "doc_id": resp.get("id", doc_id),
            "doc_name": resp.get("name", ""),
            "doc_description": resp.get("description", ""),
            "doc_type": "pdf",
            "status": resp.get("status", ""),
            "structure": self._normalize_tree(raw_tree),
        }

    def get_document_structure(self, collection: str, doc_id: str) -> list:
        """Return the document's tree structure normalized to the local schema."""
        resp = self._request("GET", f"/doc/{self._enc(doc_id)}/", params={"type": "tree", "summary": "true"})
        raw_tree = resp.get("tree", resp.get("structure", resp.get("result", [])))
        return self._normalize_tree(raw_tree)

    def get_page_content(self, collection: str, doc_id: str, pages: str) -> list:
        """Return [{'page', 'content'}] for the pages named in the spec string."""
        resp = self._request("GET", f"/doc/{self._enc(doc_id)}/", params={"type": "ocr", "format": "page"})
        # Filter to requested pages
        from ..index.utils import parse_pages
        page_nums = set(parse_pages(pages))
        all_pages = resp.get("pages", resp.get("ocr", resp.get("result", [])))
        if isinstance(all_pages, list):
            return [
                {"page": p.get("page", p.get("page_index")),
                 "content": p.get("content", p.get("markdown", ""))}
                for p in all_pages
                if p.get("page", p.get("page_index")) in page_nums
            ]
        return []

    @staticmethod
    def _normalize_tree(nodes: list) -> list:
        """Normalize cloud tree nodes to match local schema."""
        result = []
        for node in nodes:
            normalized = {
                "title": node.get("title", ""),
                "node_id": node.get("node_id", ""),
                # Cloud may report either 'summary' or 'prefix_summary'.
                "summary": node.get("summary", node.get("prefix_summary", "")),
                # Single-page cloud nodes carry only 'page_index'.
                "start_index": node.get("start_index", node.get("page_index")),
                "end_index": node.get("end_index", node.get("page_index")),
            }
            if "text" in node:
                normalized["text"] = node["text"]
            children = node.get("nodes", [])
            if children:
                normalized["nodes"] = CloudBackend._normalize_tree(children)
            result.append(normalized)
        return result

    def list_documents(self, collection: str) -> list[dict]:
        """List documents (id/name/type), scoped to the folder when available."""
        folder_id = self._get_folder_id(collection)
        params = {"limit": 100}
        if folder_id:
            params["folder_id"] = folder_id
        data = self._request("GET", "/docs/", params=params)
        return [
            {"doc_id": d.get("id", ""), "doc_name": d.get("name", ""), "doc_type": "pdf"}
            for d in data.get("documents", [])
        ]

    def delete_document(self, collection: str, doc_id: str) -> None:
        """Delete a document from the cloud."""
        self._request("DELETE", f"/doc/{self._enc(doc_id)}/")

    # ── Query (uses cloud chat/completions, no LLM key needed) ────────────

    def query(self, collection: str, question: str, doc_ids: list[str] | None = None) -> str:
        """Non-streaming query via cloud chat/completions."""
        doc_id = doc_ids if doc_ids else self._get_all_doc_ids(collection)
        resp = self._request("POST", "/chat/completions/", json={
            "messages": [{"role": "user", "content": question}],
            "doc_id": doc_id,
            "stream": False,
        })
        # Extract answer from response
        choices = resp.get("choices", [])
        if choices:
            return choices[0].get("message", {}).get("content", "")
        return resp.get("content", resp.get("answer", ""))

    async def query_stream(self, collection: str, question: str,
                           doc_ids: list[str] | None = None) -> AsyncIterator[QueryEvent]:
        """Streaming query via cloud chat/completions SSE.

        Events are yielded in real-time as they arrive from the server.
        A background thread handles the blocking HTTP stream and pushes
        events through an asyncio.Queue for true async streaming.
        """
        import asyncio
        import threading

        doc_id = doc_ids if doc_ids else self._get_all_doc_ids(collection)
        headers = self._headers
        queue: asyncio.Queue[QueryEvent | None] = asyncio.Queue()
        # Fix: get_event_loop() is deprecated here; this async generator only
        # runs under an active loop, so get_running_loop() is the correct call.
        loop = asyncio.get_running_loop()

        def _stream():
            """Background thread: read SSE and push events to queue."""
            resp = requests.post(
                f"{API_BASE}/chat/completions/",
                headers=headers,
                json={
                    "messages": [{"role": "user", "content": question}],
                    "doc_id": doc_id,
                    "stream": True,
                    "stream_metadata": True,
                },
                stream=True,
                timeout=120,
            )
            try:
                if resp.status_code != 200:
                    body = resp.text[:500] if resp.text else ""
                    loop.call_soon_threadsafe(
                        queue.put_nowait,
                        QueryEvent(type="answer_done",
                                   data=f"Cloud streaming error {resp.status_code}: {body}"),
                    )
                    return

                # Tool-call name arrives first; its argument text streams in
                # fragments until the matching *_stop block.
                current_tool_name = None
                current_tool_args: list[str] = []

                for line in resp.iter_lines(decode_unicode=True):
                    if not line or not line.startswith("data: "):
                        continue
                    data_str = line[6:]
                    if data_str.strip() == "[DONE]":
                        break
                    try:
                        chunk = json.loads(data_str)
                    except json.JSONDecodeError:
                        continue

                    meta = chunk.get("block_metadata", {})
                    block_type = meta.get("type", "")
                    choices = chunk.get("choices", [])
                    delta = choices[0].get("delta", {}) if choices else {}
                    content = delta.get("content", "")

                    if block_type == "mcp_tool_use_start":
                        current_tool_name = meta.get("tool_name", "")
                        current_tool_args = []

                    elif block_type == "tool_use":
                        if content:
                            current_tool_args.append(content)

                    elif block_type == "tool_use_stop":
                        # Surface only user-meaningful tools, not internal ones.
                        if current_tool_name and current_tool_name not in _INTERNAL_TOOLS:
                            args_str = "".join(current_tool_args)
                            loop.call_soon_threadsafe(
                                queue.put_nowait,
                                QueryEvent(type="tool_call", data={
                                    "name": current_tool_name,
                                    "args": args_str,
                                }),
                            )
                        current_tool_name = None
                        current_tool_args = []

                    elif block_type == "text" and content:
                        loop.call_soon_threadsafe(
                            queue.put_nowait,
                            QueryEvent(type="answer_delta", data=content),
                        )

            finally:
                resp.close()
                loop.call_soon_threadsafe(queue.put_nowait, None)  # sentinel

        thread = threading.Thread(target=_stream, daemon=True)
        thread.start()

        while True:
            event = await queue.get()
            if event is None:
                break
            yield event

        thread.join(timeout=5)

    def _get_all_doc_ids(self, collection: str) -> list[str]:
        """Get all document IDs in a collection."""
        docs = self.list_documents(collection)
        return [d["doc_id"] for d in docs]

    # ── Not used in cloud mode ────────────────────────────────────────────

    def get_agent_tools(self, collection: str, doc_ids: list[str] | None = None) -> AgentTools:
        """Not used in cloud mode — query goes through chat/completions."""
        return AgentTools()
|
||||
245
pageindex/backend/local.py
Normal file
245
pageindex/backend/local.py
Normal file
|
|
@ -0,0 +1,245 @@
|
|||
# pageindex/backend/local.py
|
||||
import hashlib
|
||||
import os
|
||||
import re
|
||||
import uuid
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
|
||||
from ..parser.protocol import DocumentParser, ParsedDocument
|
||||
from ..parser.pdf import PdfParser
|
||||
from ..parser.markdown import MarkdownParser
|
||||
from ..storage.protocol import StorageEngine
|
||||
from ..index.pipeline import build_index
|
||||
from ..index.utils import parse_pages, get_pdf_page_content, get_md_page_content, remove_fields
|
||||
from ..backend.protocol import AgentTools
|
||||
from ..errors import FileTypeError, DocumentNotFoundError, IndexingError, PageIndexError
|
||||
|
||||
_COLLECTION_NAME_RE = re.compile(r'^[a-zA-Z0-9_-]{1,128}$')
|
||||
|
||||
|
||||
class LocalBackend:
|
||||
def __init__(self, storage: StorageEngine, files_dir: str, model: str = None,
|
||||
retrieve_model: str = None, index_config=None):
|
||||
self._storage = storage
|
||||
self._files_dir = Path(files_dir)
|
||||
self._model = model
|
||||
self._retrieve_model = retrieve_model or model
|
||||
self._index_config = index_config
|
||||
self._parsers: list[DocumentParser] = [PdfParser(), MarkdownParser()]
|
||||
|
||||
    def register_parser(self, parser: DocumentParser) -> None:
        """Register a custom parser; it takes precedence over the built-ins."""
        self._parsers.insert(0, parser)  # user parsers checked first
|
||||
|
||||
    def get_retrieve_model(self) -> str | None:
        """Return the model used for retrieval/agent queries (may be None)."""
        return self._retrieve_model
|
||||
|
||||
def _resolve_parser(self, file_path: str) -> DocumentParser:
|
||||
ext = os.path.splitext(file_path)[1].lower()
|
||||
for parser in self._parsers:
|
||||
if ext in parser.supported_extensions():
|
||||
return parser
|
||||
raise FileTypeError(f"No parser for extension: {ext}")
|
||||
|
||||
    # Collection management
    def _validate_collection_name(self, name: str) -> None:
        """Raise PageIndexError unless name is 1-128 chars of [a-zA-Z0-9_-]."""
        if not _COLLECTION_NAME_RE.match(name):
            raise PageIndexError(f"Invalid collection name: {name!r}. Must be 1-128 chars of [a-zA-Z0-9_-].")
|
||||
|
||||
    def create_collection(self, name: str) -> None:
        """Validate the name and create the collection in storage."""
        self._validate_collection_name(name)
        self._storage.create_collection(name)
|
||||
|
||||
    def get_or_create_collection(self, name: str) -> None:
        """Validate the name and fetch or create the collection in storage."""
        self._validate_collection_name(name)
        self._storage.get_or_create_collection(name)
|
||||
|
||||
    def list_collections(self) -> list[str]:
        """Return the names of all collections known to storage."""
        return self._storage.list_collections()
|
||||
|
||||
    def delete_collection(self, name: str) -> None:
        """Delete a collection's storage records and its managed files directory."""
        self._storage.delete_collection(name)
        col_dir = self._files_dir / name
        if col_dir.exists():
            shutil.rmtree(col_dir)
|
||||
|
||||
@staticmethod
|
||||
def _file_hash(file_path: str) -> str:
|
||||
"""Compute SHA-256 hash of a file."""
|
||||
h = hashlib.sha256()
|
||||
with open(file_path, "rb") as f:
|
||||
for chunk in iter(lambda: f.read(65536), b""):
|
||||
h.update(chunk)
|
||||
return h.hexdigest()
|
||||
|
||||
    # Document management
    def add_document(self, collection: str, file_path: str) -> str:
        """Parse, index, and persist one document; return its doc_id.

        Deduplicates by SHA-256 of the file contents: re-adding an identical
        file returns the existing doc_id without re-indexing. On any failure
        during parse/index/save, the managed copy and any extracted images
        are removed before IndexingError is raised.
        """
        # realpath first so symlinks can't point the isfile check elsewhere.
        file_path = os.path.realpath(file_path)
        if not os.path.isfile(file_path):
            raise FileTypeError(f"Not a regular file: {file_path}")
        parser = self._resolve_parser(file_path)

        # Dedup: skip if same file already indexed in this collection
        file_hash = self._file_hash(file_path)
        existing_id = self._storage.find_document_by_hash(collection, file_hash)
        if existing_id:
            return existing_id

        doc_id = str(uuid.uuid4())

        # Copy file to managed directory
        ext = os.path.splitext(file_path)[1]
        col_dir = self._files_dir / collection
        col_dir.mkdir(parents=True, exist_ok=True)
        managed_path = col_dir / f"{doc_id}{ext}"
        shutil.copy2(file_path, managed_path)

        try:
            # Store images alongside the document: files/{collection}/{doc_id}/images/
            images_dir = str(col_dir / doc_id / "images")
            parsed = parser.parse(file_path, model=self._model, images_dir=images_dir)
            result = build_index(parsed, model=self._model, opt=self._index_config)

            # Cache page text for fast retrieval (avoids re-reading files)
            pages = [{"page": n.index, "content": n.content,
                      **({"images": n.images} if n.images else {})}
                     for n in parsed.nodes if n.content]

            # Strip text from structure to save storage space (PDF only;
            # markdown needs text in structure for fallback retrieval)
            doc_type = ext.lstrip(".")
            if doc_type == "pdf":
                clean_structure = remove_fields(result["structure"], fields=["text"])
            else:
                clean_structure = result["structure"]

            self._storage.save_document(collection, doc_id, {
                "doc_name": parsed.doc_name,
                "doc_description": result.get("doc_description", ""),
                "file_path": str(managed_path),
                "file_hash": file_hash,
                "doc_type": doc_type,
                "structure": clean_structure,
                "pages": pages,
            })
        except Exception as e:
            # Roll back the managed copy and extracted images before re-raising.
            managed_path.unlink(missing_ok=True)
            doc_dir = col_dir / doc_id
            if doc_dir.exists():
                shutil.rmtree(doc_dir)
            raise IndexingError(f"Failed to index {file_path}: {e}") from e

        return doc_id
|
||||
|
||||
    def get_document(self, collection: str, doc_id: str, include_text: bool = False) -> dict:
        """Get document metadata with structure.

        Args:
            include_text: If True, populate each structure node's 'text' field
                from cached page content. WARNING: may be very large — do NOT
                use in agent/LLM contexts as it can exhaust the context window.

        Returns an empty dict when the document is unknown.
        """
        doc = self._storage.get_document(collection, doc_id)
        if not doc:
            return {}
        doc["structure"] = self._storage.get_document_structure(collection, doc_id)
        if include_text:
            pages = self._storage.get_pages(collection, doc_id) or []
            page_map = {p["page"]: p["content"] for p in pages}
            self._fill_node_text(doc["structure"], page_map)
        return doc
|
||||
|
||||
@staticmethod
|
||||
def _fill_node_text(nodes: list, page_map: dict) -> None:
|
||||
"""Recursively fill 'text' on structure nodes from cached page content."""
|
||||
for node in nodes:
|
||||
start = node.get("start_index")
|
||||
end = node.get("end_index")
|
||||
if start is not None and end is not None:
|
||||
node["text"] = "\n".join(
|
||||
page_map.get(p, "") for p in range(start, end + 1)
|
||||
)
|
||||
if "nodes" in node:
|
||||
LocalBackend._fill_node_text(node["nodes"], page_map)
|
||||
|
||||
    def get_document_structure(self, collection: str, doc_id: str) -> list:
        """Return the stored tree structure for a document."""
        return self._storage.get_document_structure(collection, doc_id)
|
||||
|
||||
def get_page_content(self, collection: str, doc_id: str, pages: str) -> list:
|
||||
doc = self._storage.get_document(collection, doc_id)
|
||||
if not doc:
|
||||
raise DocumentNotFoundError(f"Document {doc_id} not found")
|
||||
page_nums = parse_pages(pages)
|
||||
|
||||
# Try cached pages first (fast, no file I/O)
|
||||
cached_pages = self._storage.get_pages(collection, doc_id)
|
||||
if cached_pages:
|
||||
return [p for p in cached_pages if p["page"] in page_nums]
|
||||
|
||||
# Fallback to reading from file
|
||||
if doc["doc_type"] == "pdf":
|
||||
return get_pdf_page_content(doc["file_path"], page_nums)
|
||||
else:
|
||||
structure = self._storage.get_document_structure(collection, doc_id)
|
||||
return get_md_page_content(structure, page_nums)
|
||||
|
||||
    def list_documents(self, collection: str) -> list[dict]:
        """Return summary records for every document in the collection."""
        return self._storage.list_documents(collection)
|
||||
|
||||
    def delete_document(self, collection: str, doc_id: str) -> None:
        """Delete a document's managed file, images directory, and storage records."""
        doc = self._storage.get_document(collection, doc_id)
        if doc and doc.get("file_path"):
            Path(doc["file_path"]).unlink(missing_ok=True)
        # Clean up images directory: files/{collection}/{doc_id}/
        doc_dir = self._files_dir / collection / doc_id
        if doc_dir.exists():
            shutil.rmtree(doc_dir)
        self._storage.delete_document(collection, doc_id)
|
||||
|
||||
def get_agent_tools(self, collection: str, doc_ids: list[str] | None = None) -> AgentTools:
    """Build the function-tool set the QA agent uses to browse this collection.

    Args:
        collection: collection name the tools operate on.
        doc_ids: optional whitelist; when given, list_documents only
            reports these documents.
    """
    import json

    from agents import function_tool

    storage_ref = self._storage
    target_col = collection
    backend_ref = self
    allowed_ids = doc_ids

    @function_tool
    def list_documents() -> str:
        """List all documents in the collection."""
        entries = storage_ref.list_documents(target_col)
        if allowed_ids:
            entries = [e for e in entries if e["doc_id"] in allowed_ids]
        return json.dumps(entries)

    @function_tool
    def get_document(doc_id: str) -> str:
        """Get document metadata."""
        return json.dumps(storage_ref.get_document(target_col, doc_id))

    @function_tool
    def get_document_structure(doc_id: str) -> str:
        """Get document tree structure (without text)."""
        tree = storage_ref.get_document_structure(target_col, doc_id)
        return json.dumps(remove_fields(tree, fields=["text"]), ensure_ascii=False)

    @function_tool
    def get_page_content(doc_id: str, pages: str) -> str:
        """Get page content. Use tight ranges: '5-7', '3,8', '12'."""
        content = backend_ref.get_page_content(target_col, doc_id, pages)
        return json.dumps(content, ensure_ascii=False)

    return AgentTools(
        function_tools=[list_documents, get_document, get_document_structure, get_page_content]
    )
|
||||
|
||||
def query(self, collection: str, question: str, doc_ids: list[str] | None = None) -> str:
    """Answer *question* over the collection with a synchronous agent run."""
    from ..agent import AgentRunner

    agent_tools = self.get_agent_tools(collection, doc_ids)
    runner = AgentRunner(tools=agent_tools, model=self._retrieve_model)
    return runner.run(question)
|
||||
|
||||
async def query_stream(self, collection: str, question: str,
                       doc_ids: list[str] | None = None):
    """Async-generate QueryEvents from an agent run over the collection."""
    from ..agent import QueryStream

    agent_tools = self.get_agent_tools(collection, doc_ids)
    event_source = QueryStream(tools=agent_tools, question=question, model=self._retrieve_model)
    async for item in event_source:
        yield item
|
||||
34
pageindex/backend/protocol.py
Normal file
34
pageindex/backend/protocol.py
Normal file
|
|
@ -0,0 +1,34 @@
|
|||
from __future__ import annotations
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Protocol, Any, AsyncIterator, runtime_checkable
|
||||
|
||||
from ..events import QueryEvent
|
||||
|
||||
|
||||
@dataclass
class AgentTools:
    """Structured container for agent tool configuration (local mode only)."""

    # Function tools exposed to the QA agent.
    function_tools: list[Any] = field(default_factory=list)
    # Optional MCP servers the agent may attach to.
    mcp_servers: list[Any] = field(default_factory=list)
|
||||
|
||||
|
||||
@runtime_checkable
class Backend(Protocol):
    """Structural interface implemented by both local and cloud backends."""

    # --- Collection management ---
    def create_collection(self, name: str) -> None: ...
    def get_or_create_collection(self, name: str) -> None: ...
    def list_collections(self) -> list[str]: ...
    def delete_collection(self, name: str) -> None: ...

    # --- Document management ---
    def add_document(self, collection: str, file_path: str) -> str: ...
    def get_document(self, collection: str, doc_id: str, include_text: bool = False) -> dict: ...
    def get_document_structure(self, collection: str, doc_id: str) -> list: ...
    def get_page_content(self, collection: str, doc_id: str, pages: str) -> list: ...
    def list_documents(self, collection: str) -> list[dict]: ...
    def delete_document(self, collection: str, doc_id: str) -> None: ...

    # --- Query ---
    def query(self, collection: str, question: str, doc_ids: list[str] | None = None) -> str: ...
    async def query_stream(self, collection: str, question: str,
                           doc_ids: list[str] | None = None) -> AsyncIterator[QueryEvent]: ...
|
||||
|
|
@ -1,18 +1,9 @@
|
|||
import os
|
||||
import uuid
|
||||
import json
|
||||
import asyncio
|
||||
import concurrent.futures
|
||||
# pageindex/client.py
|
||||
from __future__ import annotations
|
||||
from pathlib import Path
|
||||
|
||||
import PyPDF2
|
||||
|
||||
from .page_index import page_index
|
||||
from .page_index_md import md_to_tree
|
||||
from .retrieve import get_document, get_document_structure, get_page_content
|
||||
from .utils import ConfigLoader, remove_fields
|
||||
|
||||
META_INDEX = "_meta.json"
|
||||
from .collection import Collection
|
||||
from .config import IndexConfig
|
||||
from .parser.protocol import DocumentParser
|
||||
|
||||
|
||||
def _normalize_retrieve_model(model: str) -> str:
|
||||
|
|
@ -26,209 +17,145 @@ def _normalize_retrieve_model(model: str) -> str:
|
|||
|
||||
|
||||
class PageIndexClient:
|
||||
"""
|
||||
A client for indexing and retrieving document content.
|
||||
Flow: index() -> get_document() / get_document_structure() / get_page_content()
|
||||
"""PageIndex client — supports both local and cloud modes.
|
||||
|
||||
For agent-based QA, see examples/agentic_vectorless_rag_demo.py.
|
||||
Args:
|
||||
api_key: PageIndex cloud API key. When provided, cloud mode is used
|
||||
and local-only params (model, storage_path, index_config, …) are ignored.
|
||||
model: LLM model for indexing (local mode only, default: gpt-4o-2024-11-20).
|
||||
retrieve_model: LLM model for agent QA (local mode only, default: same as model).
|
||||
storage_path: Directory for SQLite DB and files (local mode only, default: ./.pageindex).
|
||||
storage: Custom StorageEngine instance (local mode only).
|
||||
index_config: Advanced indexing parameters (local mode only, optional).
|
||||
Pass an IndexConfig instance or a dict. Defaults are sensible for most use cases.
|
||||
|
||||
Usage:
|
||||
# Local mode (auto-detected when no api_key)
|
||||
client = PageIndexClient(model="gpt-5.4")
|
||||
|
||||
# Cloud mode (auto-detected when api_key provided)
|
||||
client = PageIndexClient(api_key="your-api-key")
|
||||
|
||||
# Or use LocalClient / CloudClient for explicit mode selection
|
||||
"""
|
||||
def __init__(self, api_key: str = None, model: str = None, retrieve_model: str = None, workspace: str = None):
|
||||
|
||||
def __init__(self, api_key: str = None, model: str = None,
|
||||
retrieve_model: str = None, storage_path: str = None,
|
||||
storage=None, index_config: IndexConfig | dict = None):
|
||||
if api_key:
|
||||
os.environ["OPENAI_API_KEY"] = api_key
|
||||
elif not os.getenv("OPENAI_API_KEY") and os.getenv("CHATGPT_API_KEY"):
|
||||
os.environ["OPENAI_API_KEY"] = os.getenv("CHATGPT_API_KEY")
|
||||
self.workspace = Path(workspace).expanduser() if workspace else None
|
||||
self._init_cloud(api_key)
|
||||
else:
|
||||
self._init_local(model, retrieve_model, storage_path, storage, index_config)
|
||||
|
||||
def _init_cloud(self, api_key: str):
|
||||
from .backend.cloud import CloudBackend
|
||||
self._backend = CloudBackend(api_key=api_key)
|
||||
|
||||
def _init_local(self, model: str = None, retrieve_model: str = None,
|
||||
storage_path: str = None, storage=None,
|
||||
index_config: IndexConfig | dict = None):
|
||||
# Build IndexConfig: merge model/retrieve_model with index_config
|
||||
overrides = {}
|
||||
if model:
|
||||
overrides["model"] = model
|
||||
if retrieve_model:
|
||||
overrides["retrieve_model"] = retrieve_model
|
||||
opt = ConfigLoader().load(overrides or None)
|
||||
self.model = opt.model
|
||||
self.retrieve_model = _normalize_retrieve_model(opt.retrieve_model or self.model)
|
||||
if self.workspace:
|
||||
self.workspace.mkdir(parents=True, exist_ok=True)
|
||||
self.documents = {}
|
||||
if self.workspace:
|
||||
self._load_workspace()
|
||||
|
||||
def index(self, file_path: str, mode: str = "auto") -> str:
|
||||
"""Index a document. Returns a document_id."""
|
||||
# Persist a canonical absolute path so workspace reloads do not
|
||||
# reinterpret caller-relative paths against the workspace directory.
|
||||
file_path = os.path.abspath(os.path.expanduser(file_path))
|
||||
if not os.path.exists(file_path):
|
||||
raise FileNotFoundError(f"File not found: {file_path}")
|
||||
|
||||
doc_id = str(uuid.uuid4())
|
||||
ext = os.path.splitext(file_path)[1].lower()
|
||||
|
||||
is_pdf = ext == '.pdf'
|
||||
is_md = ext in ['.md', '.markdown']
|
||||
|
||||
if mode == "pdf" or (mode == "auto" and is_pdf):
|
||||
print(f"Indexing PDF: {file_path}")
|
||||
result = page_index(
|
||||
doc=file_path,
|
||||
model=self.model,
|
||||
if_add_node_summary='yes',
|
||||
if_add_node_text='yes',
|
||||
if_add_node_id='yes',
|
||||
if_add_doc_description='yes'
|
||||
)
|
||||
# Extract per-page text so queries don't need the original PDF
|
||||
pages = []
|
||||
with open(file_path, 'rb') as f:
|
||||
pdf_reader = PyPDF2.PdfReader(f)
|
||||
for i, page in enumerate(pdf_reader.pages, 1):
|
||||
pages.append({'page': i, 'content': page.extract_text() or ''})
|
||||
|
||||
self.documents[doc_id] = {
|
||||
'id': doc_id,
|
||||
'type': 'pdf',
|
||||
'path': file_path,
|
||||
'doc_name': result.get('doc_name', ''),
|
||||
'doc_description': result.get('doc_description', ''),
|
||||
'page_count': len(pages),
|
||||
'structure': result['structure'],
|
||||
'pages': pages,
|
||||
}
|
||||
|
||||
elif mode == "md" or (mode == "auto" and is_md):
|
||||
print(f"Indexing Markdown: {file_path}")
|
||||
coro = md_to_tree(
|
||||
md_path=file_path,
|
||||
if_thinning=False,
|
||||
if_add_node_summary='yes',
|
||||
summary_token_threshold=200,
|
||||
model=self.model,
|
||||
if_add_doc_description='yes',
|
||||
if_add_node_text='yes',
|
||||
if_add_node_id='yes'
|
||||
)
|
||||
try:
|
||||
asyncio.get_running_loop()
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
|
||||
result = pool.submit(asyncio.run, coro).result()
|
||||
except RuntimeError:
|
||||
result = asyncio.run(coro)
|
||||
self.documents[doc_id] = {
|
||||
'id': doc_id,
|
||||
'type': 'md',
|
||||
'path': file_path,
|
||||
'doc_name': result.get('doc_name', ''),
|
||||
'doc_description': result.get('doc_description', ''),
|
||||
'line_count': result.get('line_count', 0),
|
||||
'structure': result['structure'],
|
||||
}
|
||||
if isinstance(index_config, IndexConfig):
|
||||
opt = index_config.model_copy(update=overrides)
|
||||
elif isinstance(index_config, dict):
|
||||
merged = {**index_config, **overrides} # explicit model/retrieve_model win
|
||||
opt = IndexConfig(**merged)
|
||||
else:
|
||||
raise ValueError(f"Unsupported file format for: {file_path}")
|
||||
opt = IndexConfig(**overrides) if overrides else IndexConfig()
|
||||
|
||||
print(f"Indexing complete. Document ID: {doc_id}")
|
||||
if self.workspace:
|
||||
self._save_doc(doc_id)
|
||||
return doc_id
|
||||
self._validate_llm_provider(opt.model)
|
||||
|
||||
storage_path = Path(storage_path or ".pageindex").resolve()
|
||||
storage_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
from .storage.sqlite import SQLiteStorage
|
||||
from .backend.local import LocalBackend
|
||||
storage_engine = storage or SQLiteStorage(str(storage_path / "pageindex.db"))
|
||||
self._backend = LocalBackend(
|
||||
storage=storage_engine,
|
||||
files_dir=str(storage_path / "files"),
|
||||
model=opt.model,
|
||||
retrieve_model=_normalize_retrieve_model(opt.retrieve_model or opt.model),
|
||||
index_config=opt,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _make_meta_entry(doc: dict) -> dict:
|
||||
"""Build a lightweight meta entry from a document dict."""
|
||||
entry = {
|
||||
'type': doc.get('type', ''),
|
||||
'doc_name': doc.get('doc_name', ''),
|
||||
'doc_description': doc.get('doc_description', ''),
|
||||
'path': doc.get('path', ''),
|
||||
}
|
||||
if doc.get('type') == 'pdf':
|
||||
entry['page_count'] = doc.get('page_count')
|
||||
elif doc.get('type') == 'md':
|
||||
entry['line_count'] = doc.get('line_count')
|
||||
return entry
|
||||
|
||||
@staticmethod
|
||||
def _read_json(path) -> dict | None:
|
||||
"""Read a JSON file, returning None on any error."""
|
||||
def _validate_llm_provider(model: str) -> None:
|
||||
"""Validate model and check API key via litellm. Warns if key seems missing."""
|
||||
try:
|
||||
with open(path, "r", encoding="utf-8") as f:
|
||||
return json.load(f)
|
||||
except (json.JSONDecodeError, OSError) as e:
|
||||
print(f"Warning: corrupt {Path(path).name}: {e}")
|
||||
return None
|
||||
|
||||
def _save_doc(self, doc_id: str):
|
||||
doc = self.documents[doc_id].copy()
|
||||
# Strip text from structure nodes — redundant with pages (PDF only)
|
||||
if doc.get('structure') and doc.get('type') == 'pdf':
|
||||
doc['structure'] = remove_fields(doc['structure'], fields=['text'])
|
||||
path = self.workspace / f"{doc_id}.json"
|
||||
with open(path, "w", encoding="utf-8") as f:
|
||||
json.dump(doc, f, ensure_ascii=False, indent=2)
|
||||
self._save_meta(doc_id, self._make_meta_entry(doc))
|
||||
# Drop heavy fields; will lazy-load on demand
|
||||
self.documents[doc_id].pop('structure', None)
|
||||
self.documents[doc_id].pop('pages', None)
|
||||
|
||||
def _rebuild_meta(self) -> dict:
|
||||
"""Scan individual doc JSON files and return a meta dict."""
|
||||
meta = {}
|
||||
for path in self.workspace.glob("*.json"):
|
||||
if path.name == META_INDEX:
|
||||
continue
|
||||
doc = self._read_json(path)
|
||||
if doc and isinstance(doc, dict):
|
||||
meta[path.stem] = self._make_meta_entry(doc)
|
||||
return meta
|
||||
|
||||
def _read_meta(self) -> dict | None:
|
||||
"""Read and validate _meta.json, returning None on any corruption."""
|
||||
meta = self._read_json(self.workspace / META_INDEX)
|
||||
if meta is not None and not isinstance(meta, dict):
|
||||
print(f"Warning: {META_INDEX} is not a JSON object, ignoring")
|
||||
return None
|
||||
return meta
|
||||
|
||||
def _save_meta(self, doc_id: str, entry: dict):
|
||||
meta = self._read_meta() or self._rebuild_meta()
|
||||
meta[doc_id] = entry
|
||||
meta_path = self.workspace / META_INDEX
|
||||
with open(meta_path, "w", encoding="utf-8") as f:
|
||||
json.dump(meta, f, ensure_ascii=False, indent=2)
|
||||
|
||||
def _load_workspace(self):
|
||||
meta = self._read_meta()
|
||||
if meta is None:
|
||||
meta = self._rebuild_meta()
|
||||
if meta:
|
||||
print(f"Loaded {len(meta)} document(s) from workspace (legacy mode).")
|
||||
for doc_id, entry in meta.items():
|
||||
doc = dict(entry, id=doc_id)
|
||||
if doc.get('path') and not os.path.isabs(doc['path']):
|
||||
doc['path'] = str((self.workspace / doc['path']).resolve())
|
||||
self.documents[doc_id] = doc
|
||||
|
||||
def _ensure_doc_loaded(self, doc_id: str):
|
||||
"""Load full document JSON on demand (structure, pages, etc.)."""
|
||||
doc = self.documents.get(doc_id)
|
||||
if not doc or doc.get('structure') is not None:
|
||||
import litellm
|
||||
litellm.model_cost_map_url = ""
|
||||
_, provider, _, _ = litellm.get_llm_provider(model=model)
|
||||
except Exception:
|
||||
return
|
||||
full = self._read_json(self.workspace / f"{doc_id}.json")
|
||||
if not full:
|
||||
return
|
||||
doc['structure'] = full.get('structure', [])
|
||||
if full.get('pages'):
|
||||
doc['pages'] = full['pages']
|
||||
|
||||
def get_document(self, doc_id: str) -> str:
|
||||
"""Return document metadata JSON."""
|
||||
return get_document(self.documents, doc_id)
|
||||
key = litellm.get_api_key(llm_provider=provider, dynamic_api_key=None)
|
||||
if not key:
|
||||
import os
|
||||
common_var = f"{provider.upper()}_API_KEY"
|
||||
if not os.getenv(common_var):
|
||||
from .errors import PageIndexError
|
||||
raise PageIndexError(
|
||||
f"API key not configured for provider '{provider}' (model: {model}). "
|
||||
f"Set the {common_var} environment variable."
|
||||
)
|
||||
|
||||
def get_document_structure(self, doc_id: str) -> str:
|
||||
"""Return document tree structure JSON (without text fields)."""
|
||||
if self.workspace:
|
||||
self._ensure_doc_loaded(doc_id)
|
||||
return get_document_structure(self.documents, doc_id)
|
||||
def collection(self, name: str = "default") -> Collection:
|
||||
"""Get or create a collection. Defaults to 'default'."""
|
||||
self._backend.get_or_create_collection(name)
|
||||
return Collection(name=name, backend=self._backend)
|
||||
|
||||
def get_page_content(self, doc_id: str, pages: str) -> str:
|
||||
"""Return page content for the given pages string (e.g. '5-7', '3,8', '12')."""
|
||||
if self.workspace:
|
||||
self._ensure_doc_loaded(doc_id)
|
||||
return get_page_content(self.documents, doc_id, pages)
|
||||
def list_collections(self) -> list[str]:
|
||||
return self._backend.list_collections()
|
||||
|
||||
def delete_collection(self, name: str) -> None:
|
||||
self._backend.delete_collection(name)
|
||||
|
||||
def register_parser(self, parser: DocumentParser) -> None:
|
||||
"""Register a custom document parser. Only available in local mode."""
|
||||
if not hasattr(self._backend, 'register_parser'):
|
||||
from .errors import PageIndexError
|
||||
raise PageIndexError("Custom parsers are not supported in cloud mode")
|
||||
self._backend.register_parser(parser)
|
||||
|
||||
|
||||
class LocalClient(PageIndexClient):
|
||||
"""Local mode — indexes and queries documents on your machine.
|
||||
|
||||
Args:
|
||||
model: LLM model for indexing (default: gpt-4o-2024-11-20)
|
||||
retrieve_model: LLM model for agent QA (default: same as model)
|
||||
storage_path: Directory for SQLite DB and files (default: ./.pageindex)
|
||||
storage: Custom StorageEngine instance (default: SQLiteStorage)
|
||||
index_config: Advanced indexing parameters. Pass an IndexConfig instance
|
||||
or a dict. All fields have sensible defaults — most users don't need this.
|
||||
|
||||
Example::
|
||||
|
||||
# Simple — defaults are fine
|
||||
client = LocalClient(model="gpt-5.4")
|
||||
|
||||
# Advanced — tune indexing parameters
|
||||
from pageindex.config import IndexConfig
|
||||
client = LocalClient(
|
||||
model="gpt-5.4",
|
||||
index_config=IndexConfig(toc_check_page_num=30),
|
||||
)
|
||||
"""
|
||||
|
||||
def __init__(self, model: str = None, retrieve_model: str = None,
|
||||
storage_path: str = None, storage=None,
|
||||
index_config: IndexConfig | dict = None):
|
||||
self._init_local(model, retrieve_model, storage_path, storage, index_config)
|
||||
|
||||
|
||||
class CloudClient(PageIndexClient):
|
||||
"""Cloud mode — fully managed by PageIndex cloud service. No LLM key needed."""
|
||||
|
||||
def __init__(self, api_key: str):
|
||||
self._init_cloud(api_key)
|
||||
|
|
|
|||
69
pageindex/collection.py
Normal file
69
pageindex/collection.py
Normal file
|
|
@ -0,0 +1,69 @@
|
|||
# pageindex/collection.py
|
||||
from __future__ import annotations
|
||||
from typing import AsyncIterator
|
||||
from .events import QueryEvent
|
||||
from .backend.protocol import Backend
|
||||
|
||||
|
||||
class QueryStream:
    """Async-iterable wrapper around ``backend.query_stream()``."""

    def __init__(self, backend: Backend, collection: str, question: str,
                 doc_ids: list[str] | None = None):
        self._backend = backend
        self._collection = collection
        self._question = question
        self._doc_ids = doc_ids

    async def stream_events(self) -> AsyncIterator[QueryEvent]:
        """Yield every QueryEvent produced by the backend for this query."""
        source = self._backend.query_stream(
            self._collection, self._question, self._doc_ids
        )
        async for item in source:
            yield item

    def __aiter__(self):
        # Lets callers write `async for event in QueryStream(...)` directly.
        return self.stream_events()
|
||||
|
||||
|
||||
class Collection:
    """A named set of indexed documents plus query operations over them."""

    def __init__(self, name: str, backend: Backend):
        self._name = name
        self._backend = backend

    @property
    def name(self) -> str:
        """Collection name."""
        return self._name

    def add(self, file_path: str) -> str:
        """Index a file into this collection; returns the new document id."""
        return self._backend.add_document(self._name, file_path)

    def list_documents(self) -> list[dict]:
        """Metadata for every document in the collection."""
        return self._backend.list_documents(self._name)

    def get_document(self, doc_id: str, include_text: bool = False) -> dict:
        """Fetch one document's metadata (optionally including node text)."""
        return self._backend.get_document(self._name, doc_id, include_text=include_text)

    def get_document_structure(self, doc_id: str) -> list:
        """Fetch the document's tree structure."""
        return self._backend.get_document_structure(self._name, doc_id)

    def get_page_content(self, doc_id: str, pages: str) -> list:
        """Fetch page content for a page spec such as '5-7', '3,8', '12'."""
        return self._backend.get_page_content(self._name, doc_id, pages)

    def delete_document(self, doc_id: str) -> None:
        """Remove a document from the collection."""
        self._backend.delete_document(self._name, doc_id)

    def query(self, question: str, doc_ids: list[str] | None = None,
              stream: bool = False) -> str | QueryStream:
        """Query documents in this collection.

        - stream=False: returns answer string (sync)
        - stream=True: returns async iterable of QueryEvent

        Usage:
            answer = col.query("question")
            async for event in col.query("question", stream=True):
                ...
        """
        if not stream:
            return self._backend.query(self._name, question, doc_ids)
        return QueryStream(self._backend, self._name, question, doc_ids)
|
||||
22
pageindex/config.py
Normal file
22
pageindex/config.py
Normal file
|
|
@ -0,0 +1,22 @@
|
|||
# pageindex/config.py
|
||||
from __future__ import annotations
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class IndexConfig(BaseModel):
    """Configuration for the PageIndex indexing pipeline.

    Every field has a sensible default; override selectively via
    ``LocalClient(index_config=IndexConfig(...))`` or a plain dict.
    """
    # Reject unknown keys so a typo in a dict config fails loudly.
    model_config = {"extra": "forbid"}

    model: str = "gpt-4o-2024-11-20"   # LLM used for indexing
    retrieve_model: str | None = None  # LLM for agent QA; None -> same as `model`
    toc_check_page_num: int = 20       # pages scanned when probing for a TOC
    max_page_num_each_node: int = 10
    max_token_num_each_node: int = 20000
    if_add_node_id: bool = True
    if_add_node_summary: bool = True
    if_add_doc_description: bool = True
    if_add_node_text: bool = False
|
||||
|
|
@ -1,10 +0,0 @@
|
|||
model: "gpt-4o-2024-11-20"
|
||||
# model: "anthropic/claude-sonnet-4-6"
|
||||
retrieve_model: "gpt-5.4" # defaults to `model` if not set
|
||||
toc_check_page_num: 20
|
||||
max_page_num_each_node: 10
|
||||
max_token_num_each_node: 20000
|
||||
if_add_node_id: "yes"
|
||||
if_add_node_summary: "yes"
|
||||
if_add_doc_description: "no"
|
||||
if_add_node_text: "no"
|
||||
28
pageindex/errors.py
Normal file
28
pageindex/errors.py
Normal file
|
|
@ -0,0 +1,28 @@
|
|||
class PageIndexError(Exception):
    """Base exception for all PageIndex SDK errors."""


class CollectionNotFoundError(PageIndexError):
    """Raised when a named collection does not exist."""


class DocumentNotFoundError(PageIndexError):
    """Raised when a document id cannot be found."""


class IndexingError(PageIndexError):
    """Raised when the indexing pipeline fails."""


class CloudAPIError(PageIndexError):
    """Raised when the cloud API returns an error."""


class FileTypeError(PageIndexError):
    """Raised for unsupported file types."""
||||
9
pageindex/events.py
Normal file
9
pageindex/events.py
Normal file
|
|
@ -0,0 +1,9 @@
|
|||
from dataclasses import dataclass
|
||||
from typing import Literal, Any
|
||||
|
||||
|
||||
@dataclass
class QueryEvent:
    """A single event emitted while streaming a query answer."""

    # Event category: agent reasoning, tool invocation/result, or answer chunks.
    type: Literal["reasoning", "tool_call", "tool_result", "answer_delta", "answer_done"]
    # Payload; its shape depends on `type`.
    data: Any
|
||||
0
pageindex/index/__init__.py
Normal file
0
pageindex/index/__init__.py
Normal file
2
pageindex/index/legacy_utils.py
Normal file
2
pageindex/index/legacy_utils.py
Normal file
|
|
@ -0,0 +1,2 @@
|
|||
# Re-export from the original utils.py for backward compatibility
|
||||
from ..utils import *
|
||||
1155
pageindex/index/page_index.py
Normal file
1155
pageindex/index/page_index.py
Normal file
File diff suppressed because it is too large
Load diff
341
pageindex/index/page_index_md.py
Normal file
341
pageindex/index/page_index_md.py
Normal file
|
|
@ -0,0 +1,341 @@
|
|||
import asyncio
|
||||
import json
|
||||
import re
|
||||
import os
|
||||
try:
|
||||
from .legacy_utils import *
|
||||
except:
|
||||
from legacy_utils import *
|
||||
|
||||
async def get_node_summary(node, summary_token_threshold=200, model=None):
    """Return the node's text verbatim when short, else an LLM summary."""
    content = node.get('text')
    if count_tokens(content, model=model) < summary_token_threshold:
        return content
    return await generate_node_summary(node, model=model)
|
||||
|
||||
|
||||
async def generate_summaries_for_structure_md(structure, summary_token_threshold, model=None):
    """Attach summaries to every node of the tree, concurrently.

    Leaf nodes get 'summary'; internal nodes get 'prefix_summary'
    (covering only their own text, not their children's).
    """
    flat_nodes = structure_to_list(structure)
    summaries = await asyncio.gather(*(
        get_node_summary(n, summary_token_threshold=summary_token_threshold, model=model)
        for n in flat_nodes
    ))

    for node, summary in zip(flat_nodes, summaries):
        key = 'prefix_summary' if node.get('nodes') else 'summary'
        node[key] = summary
    return structure
|
||||
|
||||
|
||||
def extract_nodes_from_markdown(markdown_content):
    """Scan markdown for ATX headers, ignoring anything inside fenced code blocks.

    Returns:
        (node_list, lines): node_list holds dicts with 'node_title' and a
        1-based 'line_num'; lines is the raw newline-split of the input.
    """
    header_re = re.compile(r'^(#{1,6})\s+(.+)$')
    fence_re = re.compile(r'^```')

    lines = markdown_content.split('\n')
    headers = []
    inside_fence = False

    for idx, raw in enumerate(lines, 1):
        text = raw.strip()

        # Toggle fence state on ``` delimiters; the fence line itself is
        # never a header.
        if fence_re.match(text):
            inside_fence = not inside_fence
            continue

        if not text:
            continue

        if inside_fence:
            continue

        m = header_re.match(text)
        if m:
            headers.append({'node_title': m.group(2).strip(), 'line_num': idx})

    return headers, lines
|
||||
|
||||
|
||||
def extract_node_text_content(node_list, markdown_lines):
    """Resolve each header node's level and the raw text of its section.

    A node's text runs from its own header line up to (but excluding) the
    next header, or to end of file for the last node.
    """
    resolved = []
    for node in node_list:
        raw_line = markdown_lines[node['line_num'] - 1]
        hashes = re.match(r'^(#{1,6})', raw_line)

        if hashes is None:
            # Guard against stale line numbers (extractor strips lines,
            # so an indented header would fail this raw-line match).
            print(f"Warning: Line {node['line_num']} does not contain a valid header: '{raw_line}'")
            continue

        resolved.append({
            'title': node['node_title'],
            'line_num': node['line_num'],
            'level': len(hashes.group(1)),
        })

    for pos, node in enumerate(resolved):
        begin = node['line_num'] - 1
        if pos + 1 < len(resolved):
            end = resolved[pos + 1]['line_num'] - 1
        else:
            end = len(markdown_lines)
        node['text'] = '\n'.join(markdown_lines[begin:end]).strip()
    return resolved
|
||||
|
||||
def update_node_list_with_text_token_count(node_list, model=None):
    """Annotate each node with 'text_token_count' covering it plus all descendants.

    Operates on the flat header list (entries carry 'level'): a node's
    descendants are the following entries with strictly greater level, up
    to the first entry at the same or a higher level.
    """

    def descendants_of(start, level, nodes):
        """Indices of every direct and indirect child of nodes[start]."""
        found = []
        for j in range(start + 1, len(nodes)):
            if nodes[j]['level'] <= level:
                break
            found.append(j)
        return found

    annotated = node_list.copy()

    # Walk back-to-front so children are visited before their parents.
    for i in range(len(annotated) - 1, -1, -1):
        node = annotated[i]
        combined = node.get('text', '')
        for child_idx in descendants_of(i, node['level'], annotated):
            child_text = annotated[child_idx].get('text', '')
            if child_text:
                combined += '\n' + child_text
        annotated[i]['text_token_count'] = count_tokens(combined, model=model)

    return annotated
|
||||
|
||||
|
||||
def tree_thinning_for_index(node_list, min_node_token=None, model=None):
    """Merge under-sized nodes in the flat header list.

    A node whose 'text_token_count' falls below *min_node_token* absorbs
    the text of all its descendants; the absorbed descendants are dropped
    from the returned list and the parent's token count is recomputed.
    """

    def descendants_of(start, level, nodes):
        """Indices of every direct and indirect child of nodes[start]."""
        found = []
        for j in range(start + 1, len(nodes)):
            if nodes[j]['level'] <= level:
                break
            found.append(j)
        return found

    thinned = node_list.copy()
    removed = set()

    # Back-to-front so deeper nodes are folded before their ancestors.
    for i in range(len(thinned) - 1, -1, -1):
        if i in removed:
            continue

        node = thinned[i]
        if node.get('text_token_count', 0) >= min_node_token:
            continue

        # Collect the surviving descendants' text (ascending order keeps
        # document order) and mark them for removal.
        absorbed_texts = []
        for child_idx in sorted(descendants_of(i, node['level'], thinned)):
            if child_idx in removed:
                continue
            child_text = thinned[child_idx].get('text', '')
            if child_text.strip():
                absorbed_texts.append(child_text)
                removed.add(child_idx)

        if absorbed_texts:
            merged = node.get('text', '')
            for piece in absorbed_texts:
                # Separate sections with a blank line unless the running
                # text is empty or already ends with a newline.
                if merged and not merged.endswith('\n'):
                    merged += '\n\n'
                merged += piece
            thinned[i]['text'] = merged
            thinned[i]['text_token_count'] = count_tokens(merged, model=model)

    for idx in sorted(removed, reverse=True):
        thinned.pop(idx)

    return thinned
|
||||
|
||||
|
||||
def build_tree_from_nodes(node_list):
    """Convert the flat header list (with 'level') into a nested tree.

    Uses a stack of (node, level) pairs: each new node pops entries at the
    same or deeper level, then attaches to whatever remains on top (or
    becomes a root). Node ids are sequential, zero-padded to 4 digits.
    """
    if not node_list:
        return []

    roots = []
    open_chain = []  # stack of (tree_node, level)

    for counter, flat in enumerate(node_list, 1):
        level = flat['level']
        branch = {
            'title': flat['title'],
            'node_id': str(counter).zfill(4),
            'text': flat['text'],
            'line_num': flat['line_num'],
            'nodes': [],
        }

        # Unwind to the nearest strictly-shallower ancestor.
        while open_chain and open_chain[-1][1] >= level:
            open_chain.pop()

        if open_chain:
            open_chain[-1][0]['nodes'].append(branch)
        else:
            roots.append(branch)

        open_chain.append((branch, level))

    return roots
|
||||
|
||||
|
||||
def clean_tree_for_output(tree_nodes):
    """Copy the tree keeping only title/node_id/text/line_num (and children).

    Leaf nodes get no 'nodes' key at all, which keeps JSON output compact.
    """
    out = []
    for node in tree_nodes:
        slim = {
            'title': node['title'],
            'node_id': node['node_id'],
            'text': node['text'],
            'line_num': node['line_num'],
        }
        if node['nodes']:
            slim['nodes'] = clean_tree_for_output(node['nodes'])
        out.append(slim)
    return out
|
||||
|
||||
|
||||
async def md_to_tree(md_path, if_thinning=False, min_token_threshold=None, if_add_node_summary=False, summary_token_threshold=None, model=None, if_add_doc_description=False, if_add_node_text=False, if_add_node_id=True):
    """Build a hierarchical tree structure from a Markdown file.

    Args:
        md_path: Path to the Markdown file.
        if_thinning: Merge nodes below min_token_threshold into their parents.
        min_token_threshold: Token threshold used by thinning.
        if_add_node_summary: Generate an LLM summary per node.
            NOTE(review): this flag is truth-tested, so any non-empty string
            (even 'no') enables summaries -- callers should pass real booleans.
        summary_token_threshold: Token limit forwarded to summary generation.
        model: LLM/tokenizer model name.
        if_add_doc_description: Also generate a one-line document description.
        if_add_node_text: Keep each node's raw text in the output.
        if_add_node_id: Assign sequential node ids to the tree.

    Returns:
        dict with 'doc_name', 'line_count', 'structure' and, when requested,
        'doc_description'.
    """
    with open(md_path, 'r', encoding='utf-8') as f:
        markdown_content = f.read()
    line_count = markdown_content.count('\n') + 1

    print(f"Extracting nodes from markdown...")
    node_list, markdown_lines = extract_nodes_from_markdown(markdown_content)

    print(f"Extracting text content from nodes...")
    nodes_with_content = extract_node_text_content(node_list, markdown_lines)

    if if_thinning:
        # Thinning needs per-node token counts to decide what to merge.
        nodes_with_content = update_node_list_with_text_token_count(nodes_with_content, model=model)
        print(f"Thinning nodes...")
        nodes_with_content = tree_thinning_for_index(nodes_with_content, min_token_threshold, model=model)

    print(f"Building tree from nodes...")
    tree_structure = build_tree_from_nodes(nodes_with_content)

    if if_add_node_id:
        write_node_id(tree_structure)

    print(f"Formatting tree structure...")

    if if_add_node_summary:
        # Always include text for summary generation
        tree_structure = format_structure(tree_structure, order = ['title', 'node_id', 'line_num', 'summary', 'prefix_summary', 'text', 'nodes'])

        print(f"Generating summaries for each node...")
        tree_structure = await generate_summaries_for_structure_md(tree_structure, summary_token_threshold=summary_token_threshold, model=model)

        if not if_add_node_text:
            # Remove text after summary generation if not requested
            tree_structure = format_structure(tree_structure, order = ['title', 'node_id', 'line_num', 'summary', 'prefix_summary', 'nodes'])

    if if_add_doc_description:
        print(f"Generating document description...")
        clean_structure = create_clean_structure_for_description(tree_structure)
        doc_description = generate_doc_description(clean_structure, model=model)
        return {
            'doc_name': os.path.splitext(os.path.basename(md_path))[0],
            'doc_description': doc_description,
            'line_count': line_count,
            'structure': tree_structure,
        }
    else:
        # No summaries needed, format based on text preference
        if if_add_node_text:
            tree_structure = format_structure(tree_structure, order = ['title', 'node_id', 'line_num', 'summary', 'prefix_summary', 'text', 'nodes'])
        else:
            tree_structure = format_structure(tree_structure, order = ['title', 'node_id', 'line_num', 'summary', 'prefix_summary', 'nodes'])

        return {
            'doc_name': os.path.splitext(os.path.basename(md_path))[0],
            'line_count': line_count,
            'structure': tree_structure,
        }
|
||||
|
||||
|
||||
if __name__ == "__main__":
    import os
    import json

    # MD_NAME = 'Detect-Order-Construct'
    MD_NAME = 'cognitive-load'
    MD_PATH = os.path.join(os.path.dirname(__file__), '..', 'examples/documents/', f'{MD_NAME}.md')

    MODEL = "gpt-4.1"
    IF_THINNING = False
    THINNING_THRESHOLD = 5000
    SUMMARY_TOKEN_THRESHOLD = 200
    IF_SUMMARY = True

    # Pass the flag as a real boolean. The previous
    # "'yes' if IF_SUMMARY else 'no'" produced strings, and BOTH are truthy
    # under md_to_tree's `if if_add_node_summary:` check -- summaries were
    # generated even when IF_SUMMARY was False.
    tree_structure = asyncio.run(md_to_tree(
        md_path=MD_PATH,
        if_thinning=IF_THINNING,
        min_token_threshold=THINNING_THRESHOLD,
        if_add_node_summary=IF_SUMMARY,
        summary_token_threshold=SUMMARY_TOKEN_THRESHOLD,
        model=MODEL))

    print('\n' + '='*60)
    print('TREE STRUCTURE')
    print('='*60)
    print_json(tree_structure)

    print('\n' + '='*60)
    print('TABLE OF CONTENTS')
    print('='*60)
    print_toc(tree_structure['structure'])

    # Persist the result next to the repo under results/.
    output_path = os.path.join(os.path.dirname(__file__), '..', 'results', f'{MD_NAME}_structure.json')
    os.makedirs(os.path.dirname(output_path), exist_ok=True)

    with open(output_path, 'w', encoding='utf-8') as f:
        json.dump(tree_structure, f, indent=2, ensure_ascii=False)

    print(f"\nTree structure saved to: {output_path}")
|
||||
122
pageindex/index/pipeline.py
Normal file
122
pageindex/index/pipeline.py
Normal file
|
|
@ -0,0 +1,122 @@
|
|||
# pageindex/index/pipeline.py
|
||||
from __future__ import annotations
|
||||
from ..parser.protocol import ContentNode, ParsedDocument
|
||||
|
||||
|
||||
def detect_strategy(nodes: list[ContentNode]) -> str:
    """Pick the indexing strategy for a parsed document.

    Returns "level_based" when at least one node carries explicit heading
    level information (e.g. Markdown headers), otherwise "content_based".
    """
    for node in nodes:
        if node.level is not None:
            return "level_based"
    return "content_based"
|
||||
|
||||
|
||||
def build_tree_from_levels(nodes: list[ContentNode]) -> list[dict]:
    """Strategy 0: build a nested tree from nodes carrying explicit levels.

    Mirrors pageindex/page_index_md.py:build_tree_from_nodes. A node with a
    missing level defaults to level 1; each node attaches to the nearest
    earlier node of strictly smaller level, otherwise it becomes a root.
    """
    roots: list[dict] = []
    open_sections: list[tuple[dict, int]] = []  # (tree_node, level) stack

    for source in nodes:
        entry = {
            "title": source.title or "",
            "text": source.content,
            "line_num": source.index,
            "nodes": [],
        }
        level = source.level or 1

        # Pop sections that cannot contain this node.
        while open_sections and open_sections[-1][1] >= level:
            open_sections.pop()

        if open_sections:
            open_sections[-1][0]["nodes"].append(entry)
        else:
            roots.append(entry)

        open_sections.append((entry, level))

    return roots
|
||||
|
||||
|
||||
def _run_async(coro):
|
||||
"""Run an async coroutine, handling the case where an event loop is already running."""
|
||||
import asyncio
|
||||
import concurrent.futures
|
||||
try:
|
||||
asyncio.get_running_loop()
|
||||
# Already inside an event loop -- run in a separate thread
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
|
||||
return pool.submit(asyncio.run, coro).result()
|
||||
except RuntimeError:
|
||||
return asyncio.run(coro)
|
||||
|
||||
|
||||
def build_index(parsed: ParsedDocument, model: str = None, opt=None) -> dict:
    """Main entry point: ParsedDocument -> tree structure dict.

    Routes to the appropriate strategy and runs enhancement.

    Args:
        parsed: Parser output holding the flat list of ContentNodes.
        model: Optional model name; used only when ``opt`` is not given.
        opt: Optional IndexConfig; when None, one is built from ``model``.

    Returns:
        dict with 'doc_name', 'structure' and, when configured,
        'doc_description'.
    """
    from .utils import (write_node_id, add_node_text, remove_structure_text,
                        generate_summaries_for_structure, generate_doc_description,
                        create_clean_structure_for_description)
    from ..config import IndexConfig

    if opt is None:
        opt = IndexConfig(model=model) if model else IndexConfig()

    nodes = parsed.nodes
    strategy = detect_strategy(nodes)

    if strategy == "level_based":
        structure = build_tree_from_levels(nodes)
        # For level-based, text is already in the tree nodes
    else:
        # Strategies 1-3: convert ContentNode list to page_list format for existing pipeline
        page_list = [(n.content, n.tokens) for n in nodes]
        structure = _run_async(_content_based_pipeline(page_list, opt))

    # Unified enhancement
    if opt.if_add_node_id:
        write_node_id(structure)

    # Content-based trees reference pages by index, so their text must be
    # attached before summaries can be generated.
    if strategy != "level_based":
        if opt.if_add_node_text or opt.if_add_node_summary:
            add_node_text(structure, page_list)

    if opt.if_add_node_summary:
        _run_async(generate_summaries_for_structure(structure, model=opt.model))

    # Text may have been attached only to support summary generation; strip
    # it when the caller did not ask for it. Level-based trees keep their
    # text regardless -- NOTE(review): confirm this asymmetry is intended.
    if not opt.if_add_node_text and strategy != "level_based":
        remove_structure_text(structure)

    result = {
        "doc_name": parsed.doc_name,
        "structure": structure,
    }

    if opt.if_add_doc_description:
        clean_structure = create_clean_structure_for_description(structure)
        result["doc_description"] = generate_doc_description(
            clean_structure, model=opt.model
        )

    return result
|
||||
|
||||
|
||||
class _NullLogger:
|
||||
"""Minimal logger that satisfies the tree_parser interface without writing files."""
|
||||
def info(self, message, **kwargs): pass
|
||||
def error(self, message, **kwargs): pass
|
||||
def debug(self, message, **kwargs): pass
|
||||
|
||||
|
||||
async def _content_based_pipeline(page_list, opt):
    """Strategies 1-3: delegates to the existing PDF pipeline from pageindex/page_index.py.

    The page_list is already in the format expected by tree_parser:
        [(page_text, token_count), ...]

    Args:
        page_list: Per-page (text, token_count) tuples.
        opt: IndexConfig controlling tree parsing behaviour.

    Returns:
        The raw tree structure produced by tree_parser.
    """
    from .page_index import tree_parser

    # tree_parser only needs the logger interface; no file logging here.
    logger = _NullLogger()
    structure = await tree_parser(page_list, opt, doc=None, logger=logger)
    return structure
|
||||
431
pageindex/index/utils.py
Normal file
431
pageindex/index/utils.py
Normal file
|
|
@ -0,0 +1,431 @@
|
|||
import litellm
|
||||
import logging
|
||||
import time
|
||||
import json
|
||||
import copy
|
||||
import re
|
||||
import asyncio
|
||||
import PyPDF2
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def count_tokens(text, model=None):
    """Return the token count of *text* for *model*; 0 for empty/None input."""
    return litellm.token_counter(model=model, text=text) if text else 0
|
||||
|
||||
|
||||
def llm_completion(model, prompt, chat_history=None, return_finish_reason=False):
    """Call the LLM synchronously via litellm, retrying on any failure.

    Args:
        model: Model name; a leading "litellm/" prefix is stripped.
        prompt: User prompt, appended as the final message.
        chat_history: Optional list of prior messages to prepend.
        return_finish_reason: When True, also return "max_output_reached"
            (finish_reason == "length") or "finished".

    Returns:
        The completion text, or (text, finish_reason) when requested.

    Raises:
        RuntimeError: after 10 failed attempts, chained to the last error.
    """
    if model:
        model = model.removeprefix("litellm/")
    max_retries = 10
    # The conditional covers the whole expression: with history the prior
    # messages are copied and the prompt appended; otherwise just the prompt.
    messages = list(chat_history) + [{"role": "user", "content": prompt}] if chat_history else [{"role": "user", "content": prompt}]
    for i in range(max_retries):
        try:
            litellm.drop_params = True  # silently drop provider-unsupported params
            response = litellm.completion(
                model=model,
                messages=messages,
                temperature=0,
            )
            content = response.choices[0].message.content
            if return_finish_reason:
                finish_reason = "max_output_reached" if response.choices[0].finish_reason == "length" else "finished"
                return content, finish_reason
            return content
        except Exception as e:
            logger.warning("Retrying LLM completion (%d/%d)", i + 1, max_retries)
            logger.error(f"Error: {e}")
            if i < max_retries - 1:
                time.sleep(1)
            else:
                logger.error('Max retries reached for prompt: ' + prompt)
                raise RuntimeError(f"LLM call failed after {max_retries} retries") from e
|
||||
|
||||
|
||||
|
||||
async def llm_acompletion(model, prompt):
    """Call the LLM asynchronously via litellm, retrying on any failure.

    Args:
        model: Model name; a leading "litellm/" prefix is stripped.
        prompt: Single user prompt (no chat history supported here).

    Returns:
        The completion text.

    Raises:
        RuntimeError: after 10 failed attempts, chained to the last error.
    """
    if model:
        model = model.removeprefix("litellm/")
    max_retries = 10
    messages = [{"role": "user", "content": prompt}]
    for i in range(max_retries):
        try:
            litellm.drop_params = True  # silently drop provider-unsupported params
            response = await litellm.acompletion(
                model=model,
                messages=messages,
                temperature=0,
            )
            return response.choices[0].message.content
        except Exception as e:
            logger.warning("Retrying async LLM completion (%d/%d)", i + 1, max_retries)
            logger.error(f"Error: {e}")
            if i < max_retries - 1:
                await asyncio.sleep(1)
            else:
                logger.error('Max retries reached for prompt: ' + prompt)
                raise RuntimeError(f"LLM call failed after {max_retries} retries") from e
|
||||
|
||||
|
||||
def extract_json(content):
    """Extract and parse a JSON object from an LLM response string.

    Accepts either raw JSON or JSON wrapped in ```json ... ``` fences.
    The content is first parsed as-is; only if that fails are lossy
    cleanups (Python None -> null, newline/whitespace normalization,
    trailing-comma removal) applied. The previous implementation applied
    the cleanups unconditionally, which corrupted valid JSON whose string
    values contained the substring "None" or embedded newlines.

    Returns:
        The parsed object, or {} when nothing parseable is found.
    """
    try:
        # Prefer JSON enclosed within ```json fences; otherwise assume the
        # entire content could be JSON.
        start_idx = content.find("```json")
        if start_idx != -1:
            start_idx += 7  # skip past the opening delimiter
            end_idx = content.rfind("```")
            json_content = content[start_idx:end_idx].strip()
        else:
            json_content = content.strip()

        try:
            # Fast path: the content is already valid JSON -- no cleanup.
            return json.loads(json_content)
        except json.JSONDecodeError:
            pass

        # Lossy cleanups for almost-JSON output.
        json_content = json_content.replace('None', 'null')  # Python None -> JSON null
        json_content = json_content.replace('\n', ' ').replace('\r', ' ')
        json_content = ' '.join(json_content.split())  # normalize whitespace

        try:
            return json.loads(json_content)
        except json.JSONDecodeError:
            # Last resort: drop trailing commas before closing brackets/braces.
            json_content = json_content.replace(',]', ']').replace(',}', '}')
            return json.loads(json_content)
    except json.JSONDecodeError as e:
        logging.error(f"Failed to extract JSON: {e}")
        return {}
    except Exception as e:
        logging.error(f"Unexpected error while extracting JSON: {e}")
        return {}
|
||||
|
||||
|
||||
def get_json_content(response):
    """Strip optional ```json ... ``` fences from *response*, returning the inner text."""
    open_pos = response.find("```json")
    if open_pos != -1:
        response = response[open_pos + 7:]

    close_pos = response.rfind("```")
    if close_pos != -1:
        response = response[:close_pos]

    return response.strip()
|
||||
|
||||
|
||||
def write_node_id(data, node_id=0):
    """Assign zero-padded sequential 'node_id' values depth-first, in place.

    Any dict key whose name contains the substring 'nodes' is treated as a
    child container. Returns the next unused id so recursive calls can
    continue the sequence.
    """
    if isinstance(data, list):
        for element in data:
            node_id = write_node_id(element, node_id)
    elif isinstance(data, dict):
        data['node_id'] = str(node_id).zfill(4)
        node_id += 1
        for key in list(data.keys()):
            # Recurse into every child-container key ('nodes', 'sub_nodes', ...).
            if 'nodes' in key:
                node_id = write_node_id(data[key], node_id)
    return node_id
|
||||
|
||||
|
||||
def remove_fields(data, fields=None):
    """Return a deep copy of *data* with the given keys dropped (default: 'text').

    Works recursively on nested dicts and lists; scalars pass through.
    A falsy *fields* argument falls back to the default, matching the
    original `fields or ["text"]` behaviour.
    """
    fields = fields or ["text"]
    if isinstance(data, list):
        return [remove_fields(entry, fields) for entry in data]
    if isinstance(data, dict):
        return {key: remove_fields(value, fields)
                for key, value in data.items()
                if key not in fields}
    return data
|
||||
|
||||
|
||||
def structure_to_list(structure):
    """Flatten a tree (dict or list of dicts) into a pre-order list of node dicts.

    Nodes are returned by reference, not copied, so mutating an entry in the
    returned list mutates the original tree.
    """
    if isinstance(structure, dict):
        flat = [structure]
        if 'nodes' in structure:
            flat += structure_to_list(structure['nodes'])
        return flat
    elif isinstance(structure, list):
        flat = []
        for child in structure:
            flat += structure_to_list(child)
        return flat
|
||||
|
||||
|
||||
def get_nodes(structure):
    """Flatten a tree into a pre-order list of node deep-copies.

    Unlike structure_to_list, each returned node is a deep copy with its
    'nodes' key removed; child containers (any key containing 'nodes') are
    still traversed on the original structure.
    """
    if isinstance(structure, list):
        collected = []
        for entry in structure:
            collected.extend(get_nodes(entry))
        return collected
    elif isinstance(structure, dict):
        detached = copy.deepcopy(structure)
        detached.pop('nodes', None)
        collected = [detached]
        for key in list(structure.keys()):
            if 'nodes' in key:
                collected.extend(get_nodes(structure[key]))
        return collected
|
||||
|
||||
|
||||
def get_leaf_nodes(structure):
    """Collect deep copies of all leaf nodes, with their 'nodes' key removed.

    Every dict is expected to carry a 'nodes' key; a leaf is a node whose
    'nodes' container is empty. Non-leaf nodes are traversed via every key
    containing 'nodes'.
    """
    if isinstance(structure, list):
        leaves = []
        for entry in structure:
            leaves.extend(get_leaf_nodes(entry))
        return leaves
    elif isinstance(structure, dict):
        if not structure['nodes']:
            leaf = copy.deepcopy(structure)
            leaf.pop('nodes', None)
            return [leaf]
        leaves = []
        for key in list(structure.keys()):
            if 'nodes' in key:
                leaves.extend(get_leaf_nodes(structure[key]))
        return leaves
|
||||
|
||||
|
||||
async def generate_node_summary(node, model=None):
    """Ask the LLM for a short description of a single node's text.

    Args:
        node: Node dict; only its 'text' field is used.
        model: Model name forwarded to llm_acompletion.

    Returns:
        The generated description string.
    """
    prompt = f"""You are given a part of a document, your task is to generate a description of the partial document about what are main points covered in the partial document.

Partial Document Text: {node['text']}

Directly return the description, do not include any other text.
"""
    response = await llm_acompletion(model, prompt)
    return response
|
||||
|
||||
|
||||
async def generate_summaries_for_structure(structure, model=None):
    """Generate an LLM summary for every node in the tree, in place.

    All nodes are summarized concurrently via asyncio.gather; each node
    dict gains a 'summary' key. Returns the same structure for convenience.
    """
    nodes = structure_to_list(structure)
    # structure_to_list returns node references, so writing to each entry
    # mutates the original tree.
    tasks = [generate_node_summary(node, model=model) for node in nodes]
    summaries = await asyncio.gather(*tasks)

    for node, summary in zip(nodes, summaries):
        node['summary'] = summary
    return structure
|
||||
|
||||
|
||||
def generate_doc_description(structure, model=None):
    """Generate a one-sentence description of a document from its structure.

    Args:
        structure: Slimmed tree (see create_clean_structure_for_description),
            interpolated directly into the prompt.
        model: Model name forwarded to llm_completion.

    Returns:
        The generated description string.
    """
    prompt = f"""Your are an expert in generating descriptions for a document.
You are given a structure of a document. Your task is to generate a one-sentence description for the document, which makes it easy to distinguish the document from other documents.

Document Structure: {structure}

Directly return the description, do not include any other text.
"""
    response = llm_completion(model, prompt)
    return response
|
||||
|
||||
|
||||
def list_to_tree(data):
    """Build a forest from a flat list of sections with dotted 'structure' codes.

    A section with code '1.2.3' becomes a child of '1.2' when that parent
    appeared earlier in the list; otherwise it is promoted to a root.
    Empty 'nodes' containers are removed from the result.
    """
    def parent_code_of(code):
        """Return the parent structure code ('1.2.3' -> '1.2'), or None."""
        if not code:
            return None
        segments = str(code).split('.')
        return '.'.join(segments[:-1]) if len(segments) > 1 else None

    # First pass: create nodes and wire up parent-child relationships.
    by_code = {}
    roots = []

    for item in data:
        code = item.get('structure')
        entry = {
            'title': item.get('title'),
            'start_index': item.get('start_index'),
            'end_index': item.get('end_index'),
            'nodes': [],
        }
        by_code[code] = entry

        parent_code = parent_code_of(code)
        if parent_code and parent_code in by_code:
            by_code[parent_code]['nodes'].append(entry)
        else:
            # No parent seen yet (or top-level code): treat as a root.
            roots.append(entry)

    def prune(entry):
        """Drop empty 'nodes' containers; recurse into non-empty ones."""
        if entry['nodes']:
            for child in entry['nodes']:
                prune(child)
        else:
            del entry['nodes']
        return entry

    return [prune(root) for root in roots]
|
||||
|
||||
|
||||
def post_processing(structure, end_physical_index):
    """Convert flat TOC items with physical page indices into a nested tree.

    Each item's 'start_index' becomes its own physical_index. Its
    'end_index' is derived from the next item: exclusive of the next page
    when the next section starts at the top of a page (appear_start ==
    'yes'), otherwise the two sections share that page. The last item ends
    at *end_physical_index*. Mutates *structure* in place.

    Returns:
        The tree from list_to_tree, or -- when no tree could be built --
        the flat list with the internal 'appear_start' and 'physical_index'
        keys stripped.
    """
    # First convert page_number to start_index in flat list
    for i, item in enumerate(structure):
        item['start_index'] = item.get('physical_index')
        if i < len(structure) - 1:
            # Next section starting at the top of its page means this one
            # ended on the previous page; otherwise they share a page.
            if structure[i + 1].get('appear_start') == 'yes':
                item['end_index'] = structure[i + 1]['physical_index']-1
            else:
                item['end_index'] = structure[i + 1]['physical_index']
        else:
            item['end_index'] = end_physical_index
    tree = list_to_tree(structure)
    if len(tree)!=0:
        return tree
    else:
        ### remove appear_start
        for node in structure:
            node.pop('appear_start', None)
            node.pop('physical_index', None)
        return structure
|
||||
|
||||
|
||||
def reorder_dict(data, key_order):
    """Return a copy of *data* whose keys follow *key_order* (missing keys skipped).

    A falsy *key_order* returns *data* unchanged (same object).
    """
    if not key_order:
        return data
    ordered = {}
    for key in key_order:
        if key in data:
            ordered[key] = data[key]
    return ordered


def format_structure(structure, order=None):
    """Reorder the keys of every node in the tree according to *order*.

    Child containers are processed recursively; a 'nodes' key that ends up
    empty is dropped. Dicts are rebuilt via reorder_dict, so keys absent
    from *order* are discarded. With a falsy *order* the structure is
    returned as-is.
    """
    if not order:
        return structure
    if isinstance(structure, list):
        return [format_structure(entry, order) for entry in structure]
    if isinstance(structure, dict):
        if 'nodes' in structure:
            structure['nodes'] = format_structure(structure['nodes'], order)
            # An empty child list carries no information -- drop it.
            if not structure.get('nodes'):
                structure.pop('nodes', None)
        return reorder_dict(structure, order)
    return structure
|
||||
|
||||
|
||||
def create_clean_structure_for_description(structure):
    """Produce a slim copy of the tree for doc-description prompts.

    Keeps only 'title', 'node_id', 'summary' and 'prefix_summary' (plus
    non-empty child lists); bulky fields such as 'text' are excluded.
    Non-dict, non-list values pass through unchanged.
    """
    if isinstance(structure, list):
        return [create_clean_structure_for_description(entry) for entry in structure]
    if not isinstance(structure, dict):
        return structure

    slim = {key: structure[key]
            for key in ('title', 'node_id', 'summary', 'prefix_summary')
            if key in structure}

    children = structure.get('nodes')
    if children:
        slim['nodes'] = create_clean_structure_for_description(children)

    return slim
|
||||
|
||||
|
||||
def _get_text_of_pages(page_list, start_page, end_page):
|
||||
"""Concatenate text from page_list for pages [start_page, end_page] (1-indexed)."""
|
||||
text = ""
|
||||
for page_num in range(start_page - 1, end_page):
|
||||
text += page_list[page_num][0]
|
||||
return text
|
||||
|
||||
|
||||
def add_node_text(node, page_list):
|
||||
"""Recursively add 'text' field to each node from page_list content.
|
||||
|
||||
Each node must have 'start_index' and 'end_index' (1-indexed page numbers).
|
||||
page_list is [(page_text, token_count), ...].
|
||||
"""
|
||||
if isinstance(node, dict):
|
||||
start_page = node.get('start_index')
|
||||
end_page = node.get('end_index')
|
||||
if start_page is not None and end_page is not None:
|
||||
node['text'] = _get_text_of_pages(page_list, start_page, end_page)
|
||||
if 'nodes' in node:
|
||||
add_node_text(node['nodes'], page_list)
|
||||
elif isinstance(node, list):
|
||||
for item in node:
|
||||
add_node_text(item, page_list)
|
||||
|
||||
|
||||
def remove_structure_text(data):
    """Strip the 'text' field from every node in the tree, in place.

    Returns the same object for chaining.
    """
    if isinstance(data, list):
        for entry in data:
            remove_structure_text(entry)
    elif isinstance(data, dict):
        data.pop('text', None)
        if 'nodes' in data:
            remove_structure_text(data['nodes'])
    return data
|
||||
|
||||
|
||||
# ── Functions migrated from retrieve.py ──────────────────────────────────────
|
||||
|
||||
def parse_pages(pages: str) -> list[int]:
    """Parse a pages string like '5-7', '3,8', or '12' into a sorted list of ints.

    Ranges are inclusive. Duplicates and values below 1 are dropped. Empty
    segments (e.g. a trailing comma, '3,,4') are now skipped rather than
    raising ValueError from int('').

    Raises:
        ValueError: for a descending range, a non-numeric segment, or more
            than 1000 resulting pages.
    """
    result = []
    for part in pages.split(','):
        part = part.strip()
        if not part:
            # Tolerate stray commas such as '3,,4' or '5,'.
            continue
        if '-' in part:
            lo_str, hi_str = part.split('-', 1)
            start, end = int(lo_str.strip()), int(hi_str.strip())
            if start > end:
                raise ValueError(f"Invalid range '{part}': start must be <= end")
            result.extend(range(start, end + 1))
        else:
            result.append(int(part))
    result = [p for p in result if p >= 1]
    result = sorted(set(result))
    if len(result) > 1000:
        raise ValueError(f"Page range too large: {len(result)} pages (max 1000)")
    return result
|
||||
|
||||
|
||||
def get_pdf_page_content(file_path: str, page_nums: list[int]) -> list[dict]:
    """Extract text for specific PDF pages (1-indexed), opening the PDF once.

    Args:
        file_path: Path to the PDF file.
        page_nums: 1-indexed page numbers; out-of-range values are
            silently dropped.

    Returns:
        [{'page': n, 'content': text}, ...] in the order given, with ''
        for pages whose text extraction yields nothing.
    """
    with open(file_path, 'rb') as f:
        pdf_reader = PyPDF2.PdfReader(f)
        total = len(pdf_reader.pages)
        # Keep only page numbers inside [1, total].
        valid_pages = [p for p in page_nums if 1 <= p <= total]
        return [
            {'page': p, 'content': pdf_reader.pages[p - 1].extract_text() or ''}
            for p in valid_pages
        ]
|
||||
|
||||
|
||||
def get_md_page_content(structure: list, page_nums: list[int]) -> list[dict]:
    """
    For Markdown documents, 'pages' are line numbers.
    Return the text of nodes whose line_num falls within
    [min(page_nums), max(page_nums)], sorted by line number.
    """
    if not page_nums:
        return []
    lo, hi = min(page_nums), max(page_nums)
    matches = []
    visited_lines = set()

    def walk(nodes):
        for entry in nodes:
            line = entry.get('line_num')
            # Falsy line numbers and duplicates are skipped.
            if line and lo <= line <= hi and line not in visited_lines:
                visited_lines.add(line)
                matches.append({'page': line, 'content': entry.get('text', '')})
            children = entry.get('nodes')
            if children:
                walk(children)

    walk(structure)
    return sorted(matches, key=lambda item: item['page'])
|
||||
|
|
@ -1113,11 +1113,12 @@ def page_index_main(doc, opt=None):
|
|||
def page_index(doc, model=None, toc_check_page_num=None, max_page_num_each_node=None, max_token_num_each_node=None,
|
||||
if_add_node_id=None, if_add_node_summary=None, if_add_doc_description=None, if_add_node_text=None):
|
||||
|
||||
from .config import IndexConfig
|
||||
user_opt = {
|
||||
arg: value for arg, value in locals().items()
|
||||
if arg != "doc" and value is not None
|
||||
}
|
||||
opt = ConfigLoader().load(user_opt)
|
||||
opt = IndexConfig(**user_opt)
|
||||
return page_index_main(doc, opt)
|
||||
|
||||
|
||||
|
|
|
|||
0
pageindex/parser/__init__.py
Normal file
0
pageindex/parser/__init__.py
Normal file
59
pageindex/parser/markdown.py
Normal file
59
pageindex/parser/markdown.py
Normal file
|
|
@ -0,0 +1,59 @@
|
|||
import re
|
||||
from pathlib import Path
|
||||
from .protocol import ContentNode, ParsedDocument
|
||||
from ..index.utils import count_tokens
|
||||
|
||||
|
||||
class MarkdownParser:
    """Parse a Markdown file into a flat list of header-delimited ContentNodes."""

    def supported_extensions(self) -> list[str]:
        return [".md", ".markdown"]

    def parse(self, file_path: str, **kwargs) -> ParsedDocument:
        """Read *file_path* and return one ContentNode per ATX header section."""
        source = Path(file_path)
        model = kwargs.get("model")

        text = source.read_text(encoding="utf-8")

        lines = text.split("\n")
        headers = self._extract_headers(lines)
        return ParsedDocument(doc_name=source.stem,
                              nodes=self._build_nodes(headers, lines, model))

    def _extract_headers(self, lines: list[str]) -> list[dict]:
        """Scan for ATX headers (# .. ######), ignoring fenced code blocks."""
        found = []
        inside_fence = False

        for line_num, raw in enumerate(lines, 1):
            stripped = raw.strip()
            if re.match(r"^```", stripped):
                # A fence line toggles code-block state and is never a header.
                inside_fence = not inside_fence
                continue
            if inside_fence or not stripped:
                continue
            hit = re.match(r"^(#{1,6})\s+(.+)$", stripped)
            if hit:
                found.append({
                    "title": hit.group(2).strip(),
                    "level": len(hit.group(1)),
                    "line_num": line_num,
                })
        return found

    def _build_nodes(self, headers: list[dict], lines: list[str], model: str | None) -> list[ContentNode]:
        """Slice the document at each header and wrap the slices as ContentNodes."""
        result = []
        for pos, header in enumerate(headers):
            begin = header["line_num"] - 1
            if pos + 1 < len(headers):
                stop = headers[pos + 1]["line_num"] - 1
            else:
                stop = len(lines)
            section = "\n".join(lines[begin:stop]).strip()
            result.append(ContentNode(
                content=section,
                tokens=count_tokens(section, model=model),
                title=header["title"],
                index=header["line_num"],
                level=header["level"],
            ))
        return result
|
||||
101
pageindex/parser/pdf.py
Normal file
101
pageindex/parser/pdf.py
Normal file
|
|
@ -0,0 +1,101 @@
|
|||
import pymupdf
|
||||
from pathlib import Path
|
||||
from .protocol import ContentNode, ParsedDocument
|
||||
from ..index.utils import count_tokens
|
||||
|
||||
# Minimum image dimension to keep (skip icons/artifacts)
|
||||
_MIN_IMAGE_SIZE = 32
|
||||
|
||||
|
||||
class PdfParser:
    """Parse a PDF into one ContentNode per page using PyMuPDF (pymupdf)."""

    def supported_extensions(self) -> list[str]:
        return [".pdf"]

    def parse(self, file_path: str, **kwargs) -> ParsedDocument:
        """Extract every page's text (and optionally images) as ContentNodes.

        kwargs:
            model: tokenizer model name forwarded to count_tokens.
            images_dir: when given, page images are saved there and inline
                placeholders are embedded in the page text.
        """
        path = Path(file_path)
        model = kwargs.get("model")
        images_dir = kwargs.get("images_dir")
        nodes = []

        with pymupdf.open(str(path)) as doc:
            for i, page in enumerate(doc):
                page_num = i + 1  # 1-indexed page numbering
                if images_dir:
                    content, images = self._extract_page_with_images(
                        doc, page, page_num, images_dir)
                else:
                    content = page.get_text()
                    images = None

                tokens = count_tokens(content, model=model)
                nodes.append(ContentNode(
                    content=content or "",
                    tokens=tokens,
                    index=page_num,
                    images=images if images else None,  # [] collapses to None
                ))

        return ParsedDocument(doc_name=path.stem, nodes=nodes)

    @staticmethod
    def _extract_page_with_images(doc, page, page_num: int,
                                  images_dir: str) -> tuple[str, list[dict]]:
        """Extract text and images from a page, preserving their relative order.

        Uses get_text("dict") to iterate blocks in reading order.
        Text blocks become text; image blocks are saved to disk and replaced
        with an inline placeholder: ![image](relative/path.png)
        """
        images_path = Path(images_dir)
        images_path.mkdir(parents=True, exist_ok=True)
        # Use path relative to cwd so downstream consumers can access directly
        try:
            rel_images_path = images_path.relative_to(Path.cwd())
        except ValueError:
            # images_dir lies outside the cwd; fall back to the given path.
            rel_images_path = images_path

        parts: list[str] = []
        images: list[dict] = []
        img_idx = 0

        for block in page.get_text("dict")["blocks"]:
            if block["type"] == 0:  # text block
                lines = []
                for line in block["lines"]:
                    spans_text = "".join(span["text"] for span in line["spans"])
                    lines.append(spans_text)
                parts.append("\n".join(lines))

            elif block["type"] == 1:  # image block
                width = block.get("width", 0)
                height = block.get("height", 0)
                # Skip tiny images (icons/artifacts) below _MIN_IMAGE_SIZE.
                if width < _MIN_IMAGE_SIZE or height < _MIN_IMAGE_SIZE:
                    continue

                image_bytes = block.get("image")
                ext = block.get("ext", "png")  # NOTE(review): unused; images always saved as .png
                if not image_bytes:
                    continue

                try:
                    pix = pymupdf.Pixmap(image_bytes)
                    # Convert exotic colorspaces (e.g. CMYK) to RGB before saving.
                    if pix.n > 4:
                        pix = pymupdf.Pixmap(pymupdf.csRGB, pix)
                    filename = f"p{page_num}_img{img_idx}.png"
                    save_path = images_path / filename
                    pix.save(str(save_path))
                    pix = None  # release the pixmap buffer promptly
                except Exception:
                    # Undecodable image data -- skip rather than fail the page.
                    continue

                rel_path = str(rel_images_path / filename)
                images.append({
                    "path": rel_path,
                    "width": width,
                    "height": height,
                })
                # NOTE(review): the placeholder literal was lost in the diff
                # rendering; reconstructed as Markdown image syntax -- confirm
                # against the upstream file.
                parts.append(f"![image]({rel_path})")
                img_idx += 1

        content = "\n".join(parts)
        return content, images
|
||||
28
pageindex/parser/protocol.py
Normal file
28
pageindex/parser/protocol.py
Normal file
|
|
@ -0,0 +1,28 @@
|
|||
from __future__ import annotations
|
||||
from dataclasses import dataclass
|
||||
from typing import Protocol, runtime_checkable
|
||||
|
||||
|
||||
@dataclass
class ContentNode:
    """Universal content unit produced by parsers."""
    content: str                      # raw text of the unit (one page or section)
    tokens: int                       # token count of `content`
    title: str | None = None          # heading title, when the unit is a section
    index: int | None = None          # 1-indexed page number (PDF) or line number (Markdown)
    level: int | None = None          # heading depth, when known (e.g. Markdown '#' count)
    images: list[dict] | None = None  # [{"path": str, "width": int, "height": int}, ...]
|
||||
|
||||
|
||||
@dataclass
class ParsedDocument:
    """Unified parser output. Always a flat list of ContentNode."""
    doc_name: str                 # source file name without extension
    nodes: list[ContentNode]      # document content in reading order
    metadata: dict | None = None  # optional parser-specific extras
|
||||
|
||||
|
||||
@runtime_checkable
class DocumentParser(Protocol):
    """Structural interface every parser implements (isinstance-checkable)."""

    def supported_extensions(self) -> list[str]: ...
    def parse(self, file_path: str, **kwargs) -> ParsedDocument: ...
|
||||
0
pageindex/storage/__init__.py
Normal file
0
pageindex/storage/__init__.py
Normal file
18
pageindex/storage/protocol.py
Normal file
18
pageindex/storage/protocol.py
Normal file
|
|
@ -0,0 +1,18 @@
|
|||
from __future__ import annotations
|
||||
from typing import Protocol, runtime_checkable
|
||||
|
||||
|
||||
@runtime_checkable
class StorageEngine(Protocol):
    """Structural interface for document storage backends (isinstance-checkable).

    Implementations group documents into named collections; documents are
    addressed by doc_id and may also be located by their file hash.
    """

    def create_collection(self, name: str) -> None: ...
    def get_or_create_collection(self, name: str) -> None: ...
    def list_collections(self) -> list[str]: ...
    def delete_collection(self, name: str) -> None: ...
    def save_document(self, collection: str, doc_id: str, doc: dict) -> None: ...
    def find_document_by_hash(self, collection: str, file_hash: str) -> str | None: ...
    def get_document(self, collection: str, doc_id: str) -> dict: ...
    def get_document_structure(self, collection: str, doc_id: str) -> list: ...
    def get_pages(self, collection: str, doc_id: str) -> list | None: ...
    def list_documents(self, collection: str) -> list[dict]: ...
    def delete_document(self, collection: str, doc_id: str) -> None: ...
    def close(self) -> None: ...
||||
164
pageindex/storage/sqlite.py
Normal file
164
pageindex/storage/sqlite.py
Normal file
|
|
@ -0,0 +1,164 @@
|
|||
import json
|
||||
import sqlite3
|
||||
import threading
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
class SQLiteStorage:
    """SQLite-backed storage engine with per-thread connections.

    sqlite3 connections must not be shared across threads, so each thread
    lazily opens its own connection (kept in a threading.local); every
    connection is also tracked in a shared list so close() can release
    them all, regardless of which thread calls it.
    """

    def __init__(self, db_path: str):
        # Expand "~" and make sure the parent directory exists before
        # sqlite3.connect tries to create the database file.
        self._db_path = Path(db_path).expanduser()
        self._db_path.parent.mkdir(parents=True, exist_ok=True)
        self._local = threading.local()  # holds the calling thread's connection
        # Every connection ever opened, across all threads, for close().
        self._connections: list[sqlite3.Connection] = []
        self._conn_lock = threading.Lock()  # guards _connections
        self._init_schema()

    def _get_conn(self) -> sqlite3.Connection:
        """Return a thread-local SQLite connection, opening it on first use."""
        if not hasattr(self._local, "conn"):
            conn = sqlite3.connect(str(self._db_path))
            # WAL lets readers proceed concurrently with a single writer.
            conn.execute("PRAGMA journal_mode=WAL")
            # Required for documents.collection_name ON DELETE CASCADE.
            conn.execute("PRAGMA foreign_keys=ON")
            self._local.conn = conn
            with self._conn_lock:
                self._connections.append(conn)
        return self._local.conn

    def _init_schema(self):
        """Create tables and indexes if they do not already exist (idempotent)."""
        conn = self._get_conn()
        # Stamp the schema revision for future migrations.
        conn.execute("PRAGMA user_version = 1")
        conn.executescript("""
            CREATE TABLE IF NOT EXISTS collections (
                name TEXT PRIMARY KEY CHECK(length(name) <= 128 AND name GLOB '[a-zA-Z0-9_-]*'),
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
            );
            CREATE TABLE IF NOT EXISTS documents (
                doc_id TEXT PRIMARY KEY,
                collection_name TEXT NOT NULL REFERENCES collections(name) ON DELETE CASCADE,
                doc_name TEXT,
                doc_description TEXT,
                file_path TEXT,
                file_hash TEXT,
                doc_type TEXT NOT NULL,
                structure JSON,
                pages JSON,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
            );
            CREATE INDEX IF NOT EXISTS idx_docs_collection ON documents(collection_name);
            CREATE INDEX IF NOT EXISTS idx_docs_hash ON documents(collection_name, file_hash);
        """)
        conn.commit()

    def create_collection(self, name: str) -> None:
        """Create a collection.

        Raises sqlite3.IntegrityError when the name already exists or
        violates the schema's CHECK constraint (charset / length).
        """
        conn = self._get_conn()
        conn.execute("INSERT INTO collections (name) VALUES (?)", (name,))
        conn.commit()

    def get_or_create_collection(self, name: str) -> None:
        """Create the collection if missing; no-op when it already exists."""
        conn = self._get_conn()
        conn.execute("INSERT OR IGNORE INTO collections (name) VALUES (?)", (name,))
        conn.commit()

    def list_collections(self) -> list[str]:
        """Return all collection names, sorted alphabetically."""
        conn = self._get_conn()
        rows = conn.execute("SELECT name FROM collections ORDER BY name").fetchall()
        return [r[0] for r in rows]

    def delete_collection(self, name: str) -> None:
        """Delete a collection; its documents go too via ON DELETE CASCADE."""
        conn = self._get_conn()
        conn.execute("DELETE FROM collections WHERE name = ?", (name,))
        conn.commit()

    def save_document(self, collection: str, doc_id: str, doc: dict) -> None:
        """Insert or overwrite (upsert keyed on doc_id) a document record.

        `doc` must contain "doc_type"; all other keys are optional.
        `structure` is always JSON-encoded (defaulting to []); `pages` is
        JSON-encoded only when truthy, otherwise stored as NULL.
        """
        conn = self._get_conn()
        conn.execute(
            """INSERT OR REPLACE INTO documents
               (doc_id, collection_name, doc_name, doc_description, file_path, file_hash, doc_type, structure, pages)
               VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)""",
            (doc_id, collection, doc.get("doc_name"), doc.get("doc_description"),
             doc.get("file_path"), doc.get("file_hash"), doc["doc_type"],
             json.dumps(doc.get("structure", [])),
             json.dumps(doc.get("pages")) if doc.get("pages") else None),
        )
        conn.commit()

    def find_document_by_hash(self, collection: str, file_hash: str) -> str | None:
        """Return the doc_id with this file hash in the collection, or None."""
        conn = self._get_conn()
        row = conn.execute(
            "SELECT doc_id FROM documents WHERE collection_name = ? AND file_hash = ?",
            (collection, file_hash),
        ).fetchone()
        return row[0] if row else None

    def get_document(self, collection: str, doc_id: str) -> dict:
        """Return a document's metadata, or {} when it does not exist."""
        conn = self._get_conn()
        row = conn.execute(
            "SELECT doc_id, doc_name, doc_description, file_path, doc_type FROM documents WHERE doc_id = ? AND collection_name = ?",
            (doc_id, collection),
        ).fetchone()
        if not row:
            return {}
        return {"doc_id": row[0], "doc_name": row[1], "doc_description": row[2],
                "file_path": row[3], "doc_type": row[4]}

    def get_document_structure(self, collection: str, doc_id: str) -> list:
        """Return the decoded structure JSON, or [] for a missing document."""
        conn = self._get_conn()
        row = conn.execute(
            "SELECT structure FROM documents WHERE doc_id = ? AND collection_name = ?",
            (doc_id, collection),
        ).fetchone()
        if not row:
            return []
        # NOTE(review): assumes structure is non-NULL. save_document always
        # writes JSON here, but a NULL value would make json.loads raise.
        return json.loads(row[0])

    def get_pages(self, collection: str, doc_id: str) -> list | None:
        """Return cached page content, or None if not cached."""
        conn = self._get_conn()
        row = conn.execute(
            "SELECT pages FROM documents WHERE doc_id = ? AND collection_name = ?",
            (doc_id, collection),
        ).fetchone()
        if not row or not row[0]:
            return None
        return json.loads(row[0])

    def list_documents(self, collection: str) -> list[dict]:
        """Return id/name/type for every document in the collection, oldest first."""
        conn = self._get_conn()
        rows = conn.execute(
            "SELECT doc_id, doc_name, doc_type FROM documents WHERE collection_name = ? ORDER BY created_at",
            (collection,),
        ).fetchall()
        return [{"doc_id": r[0], "doc_name": r[1], "doc_type": r[2]} for r in rows]

    def delete_document(self, collection: str, doc_id: str) -> None:
        """Delete one document from the collection."""
        conn = self._get_conn()
        conn.execute(
            "DELETE FROM documents WHERE doc_id = ? AND collection_name = ?",
            (doc_id, collection),
        )
        conn.commit()

    def __enter__(self):
        # Context-manager support: `with SQLiteStorage(path) as storage: ...`
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.close()
        return False  # never suppress exceptions from the with-body

    def close(self) -> None:
        """Close all tracked SQLite connections across all threads."""
        with self._conn_lock:
            for conn in self._connections:
                try:
                    conn.close()
                except Exception:
                    pass  # best-effort shutdown; ignore close failures
            self._connections.clear()
        # Only the *current* thread's cached handle is dropped here; other
        # threads still hold a now-closed connection in their threading.local,
        # so using this storage from those threads after close() will fail.
        if hasattr(self._local, "conn"):
            del self._local.conn

    def __del__(self):
        # Last-resort cleanup; errors during interpreter teardown are ignored.
        try:
            self.close()
        except Exception:
            pass
|
||||
48
pyproject.toml
Normal file
48
pyproject.toml
Normal file
|
|
@ -0,0 +1,48 @@
|
|||
[build-system]
|
||||
requires = ["setuptools>=68.0"]
|
||||
build-backend = "setuptools.build_meta"
|
||||
|
||||
[project]
|
||||
name = "pageindex"
|
||||
version = "0.3.0"
|
||||
description = "Python SDK for PageIndex"
|
||||
readme = "README.md"
|
||||
license = {text = "MIT"}
|
||||
requires-python = ">=3.10"
|
||||
authors = [
|
||||
{name = "Ray", email = "ray@vectify.ai"},
|
||||
]
|
||||
classifiers = [
|
||||
"Development Status :: 3 - Alpha",
|
||||
"Intended Audience :: Developers",
|
||||
"License :: OSI Approved :: MIT License",
|
||||
"Programming Language :: Python :: 3",
|
||||
"Programming Language :: Python :: 3.10",
|
||||
"Programming Language :: Python :: 3.11",
|
||||
"Programming Language :: Python :: 3.12",
|
||||
"Programming Language :: Python :: 3.13",
|
||||
"Topic :: Scientific/Engineering :: Artificial Intelligence",
|
||||
]
|
||||
keywords = ["rag", "document", "retrieval", "llm", "pageindex"]
|
||||
dependencies = [
|
||||
"litellm>=1.82.0",
|
||||
"pymupdf>=1.26.0",
|
||||
"PyPDF2>=3.0.0",
|
||||
"python-dotenv>=1.0.0",
|
||||
"pyyaml>=6.0",
|
||||
"openai-agents>=0.1.0",
|
||||
"requests>=2.28.0",
|
||||
"httpx[socks]>=0.28.1",
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
dev = ["pytest>=8.0", "pytest-asyncio>=0.23"]
|
||||
|
||||
[project.urls]
|
||||
Homepage = "https://github.com/VectifyAI/PageIndex"
|
||||
Documentation = "https://docs.pageindex.ai"
|
||||
Repository = "https://github.com/VectifyAI/PageIndex"
|
||||
Issues = "https://github.com/VectifyAI/PageIndex/issues"
|
||||
|
||||
[tool.setuptools.packages.find]
|
||||
include = ["pageindex*"]
|
||||
102
run_pageindex.py
102
run_pageindex.py
|
|
@ -1,9 +1,9 @@
|
|||
import argparse
|
||||
import os
|
||||
import json
|
||||
from pageindex import *
|
||||
from pageindex.page_index_md import md_to_tree
|
||||
from pageindex.utils import ConfigLoader
|
||||
from pageindex.index.page_index import *
|
||||
from pageindex.index.page_index_md import md_to_tree
|
||||
from pageindex.config import IndexConfig
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Set up argument parser
|
||||
|
|
@ -11,7 +11,7 @@ if __name__ == "__main__":
|
|||
parser.add_argument('--pdf_path', type=str, help='Path to the PDF file')
|
||||
parser.add_argument('--md_path', type=str, help='Path to the Markdown file')
|
||||
|
||||
parser.add_argument('--model', type=str, default=None, help='Model to use (overrides config.yaml)')
|
||||
parser.add_argument('--model', type=str, default=None, help='Model to use')
|
||||
|
||||
parser.add_argument('--toc-check-pages', type=int, default=None,
|
||||
help='Number of pages to check for table of contents (PDF only)')
|
||||
|
|
@ -20,15 +20,15 @@ if __name__ == "__main__":
|
|||
parser.add_argument('--max-tokens-per-node', type=int, default=None,
|
||||
help='Maximum number of tokens per node (PDF only)')
|
||||
|
||||
parser.add_argument('--if-add-node-id', type=str, default=None,
|
||||
help='Whether to add node id to the node')
|
||||
parser.add_argument('--if-add-node-summary', type=str, default=None,
|
||||
help='Whether to add summary to the node')
|
||||
parser.add_argument('--if-add-doc-description', type=str, default=None,
|
||||
help='Whether to add doc description to the doc')
|
||||
parser.add_argument('--if-add-node-text', type=str, default=None,
|
||||
help='Whether to add text to the node')
|
||||
|
||||
parser.add_argument('--if-add-node-id', action='store_true', default=None,
|
||||
help='Add node id to the node')
|
||||
parser.add_argument('--if-add-node-summary', action='store_true', default=None,
|
||||
help='Add summary to the node')
|
||||
parser.add_argument('--if-add-doc-description', action='store_true', default=None,
|
||||
help='Add doc description to the doc')
|
||||
parser.add_argument('--if-add-node-text', action='store_true', default=None,
|
||||
help='Add text to the node')
|
||||
|
||||
# Markdown specific arguments
|
||||
parser.add_argument('--if-thinning', type=str, default='no',
|
||||
help='Whether to apply tree thinning for markdown (markdown only)')
|
||||
|
|
@ -37,77 +37,61 @@ if __name__ == "__main__":
|
|||
parser.add_argument('--summary-token-threshold', type=int, default=200,
|
||||
help='Token threshold for generating summaries (markdown only)')
|
||||
args = parser.parse_args()
|
||||
|
||||
|
||||
# Validate that exactly one file type is specified
|
||||
if not args.pdf_path and not args.md_path:
|
||||
raise ValueError("Either --pdf_path or --md_path must be specified")
|
||||
if args.pdf_path and args.md_path:
|
||||
raise ValueError("Only one of --pdf_path or --md_path can be specified")
|
||||
|
||||
|
||||
# Build IndexConfig from CLI args (None values use defaults)
|
||||
config_overrides = {
|
||||
k: v for k, v in {
|
||||
"model": args.model,
|
||||
"toc_check_page_num": args.toc_check_pages,
|
||||
"max_page_num_each_node": args.max_pages_per_node,
|
||||
"max_token_num_each_node": args.max_tokens_per_node,
|
||||
"if_add_node_id": args.if_add_node_id,
|
||||
"if_add_node_summary": args.if_add_node_summary,
|
||||
"if_add_doc_description": args.if_add_doc_description,
|
||||
"if_add_node_text": args.if_add_node_text,
|
||||
}.items() if v is not None
|
||||
}
|
||||
opt = IndexConfig(**config_overrides)
|
||||
|
||||
if args.pdf_path:
|
||||
# Validate PDF file
|
||||
if not args.pdf_path.lower().endswith('.pdf'):
|
||||
raise ValueError("PDF file must have .pdf extension")
|
||||
if not os.path.isfile(args.pdf_path):
|
||||
raise ValueError(f"PDF file not found: {args.pdf_path}")
|
||||
|
||||
# Process PDF file
|
||||
user_opt = {
|
||||
'model': args.model,
|
||||
'toc_check_page_num': args.toc_check_pages,
|
||||
'max_page_num_each_node': args.max_pages_per_node,
|
||||
'max_token_num_each_node': args.max_tokens_per_node,
|
||||
'if_add_node_id': args.if_add_node_id,
|
||||
'if_add_node_summary': args.if_add_node_summary,
|
||||
'if_add_doc_description': args.if_add_doc_description,
|
||||
'if_add_node_text': args.if_add_node_text,
|
||||
}
|
||||
opt = ConfigLoader().load({k: v for k, v in user_opt.items() if v is not None})
|
||||
|
||||
# Process the PDF
|
||||
toc_with_page_number = page_index_main(args.pdf_path, opt)
|
||||
print('Parsing done, saving to file...')
|
||||
|
||||
|
||||
# Save results
|
||||
pdf_name = os.path.splitext(os.path.basename(args.pdf_path))[0]
|
||||
pdf_name = os.path.splitext(os.path.basename(args.pdf_path))[0]
|
||||
output_dir = './results'
|
||||
output_file = f'{output_dir}/{pdf_name}_structure.json'
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
|
||||
|
||||
with open(output_file, 'w', encoding='utf-8') as f:
|
||||
json.dump(toc_with_page_number, f, indent=2)
|
||||
|
||||
|
||||
print(f'Tree structure saved to: {output_file}')
|
||||
|
||||
|
||||
elif args.md_path:
|
||||
# Validate Markdown file
|
||||
if not args.md_path.lower().endswith(('.md', '.markdown')):
|
||||
raise ValueError("Markdown file must have .md or .markdown extension")
|
||||
if not os.path.isfile(args.md_path):
|
||||
raise ValueError(f"Markdown file not found: {args.md_path}")
|
||||
|
||||
|
||||
# Process markdown file
|
||||
print('Processing markdown file...')
|
||||
|
||||
# Process the markdown
|
||||
import asyncio
|
||||
|
||||
# Use ConfigLoader to get consistent defaults (matching PDF behavior)
|
||||
from pageindex.utils import ConfigLoader
|
||||
config_loader = ConfigLoader()
|
||||
|
||||
# Create options dict with user args
|
||||
user_opt = {
|
||||
'model': args.model,
|
||||
'if_add_node_summary': args.if_add_node_summary,
|
||||
'if_add_doc_description': args.if_add_doc_description,
|
||||
'if_add_node_text': args.if_add_node_text,
|
||||
'if_add_node_id': args.if_add_node_id
|
||||
}
|
||||
|
||||
# Load config with defaults from config.yaml
|
||||
opt = config_loader.load(user_opt)
|
||||
|
||||
|
||||
toc_with_page_number = asyncio.run(md_to_tree(
|
||||
md_path=args.md_path,
|
||||
if_thinning=args.if_thinning.lower() == 'yes',
|
||||
|
|
@ -119,16 +103,16 @@ if __name__ == "__main__":
|
|||
if_add_node_text=opt.if_add_node_text,
|
||||
if_add_node_id=opt.if_add_node_id
|
||||
))
|
||||
|
||||
|
||||
print('Parsing done, saving to file...')
|
||||
|
||||
|
||||
# Save results
|
||||
md_name = os.path.splitext(os.path.basename(args.md_path))[0]
|
||||
md_name = os.path.splitext(os.path.basename(args.md_path))[0]
|
||||
output_dir = './results'
|
||||
output_file = f'{output_dir}/{md_name}_structure.json'
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
|
||||
|
||||
with open(output_file, 'w', encoding='utf-8') as f:
|
||||
json.dump(toc_with_page_number, f, indent=2, ensure_ascii=False)
|
||||
|
||||
print(f'Tree structure saved to: {output_file}')
|
||||
|
||||
print(f'Tree structure saved to: {output_file}')
|
||||
|
|
|
|||
14
tests/test_agent.py
Normal file
14
tests/test_agent.py
Normal file
|
|
@ -0,0 +1,14 @@
|
|||
from pageindex.agent import AgentRunner, SYSTEM_PROMPT
|
||||
from pageindex.backend.protocol import AgentTools
|
||||
|
||||
|
||||
def test_agent_runner_init():
|
||||
tools = AgentTools(function_tools=["mock_tool"])
|
||||
runner = AgentRunner(tools=tools, model="gpt-4o")
|
||||
assert runner._model == "gpt-4o"
|
||||
|
||||
|
||||
def test_system_prompt_has_tool_instructions():
|
||||
assert "list_documents" in SYSTEM_PROMPT
|
||||
assert "get_document_structure" in SYSTEM_PROMPT
|
||||
assert "get_page_content" in SYSTEM_PROMPT
|
||||
51
tests/test_client.py
Normal file
51
tests/test_client.py
Normal file
|
|
@ -0,0 +1,51 @@
|
|||
# tests/sdk/test_client.py
|
||||
import pytest
|
||||
from pageindex.client import PageIndexClient, LocalClient, CloudClient
|
||||
|
||||
|
||||
def test_local_client_is_pageindex_client(tmp_path):
    """LocalClient is a PageIndexClient subtype."""
    local = LocalClient(model="gpt-4o", storage_path=str(tmp_path / "pi"))
    assert isinstance(local, PageIndexClient)
|
||||
|
||||
|
||||
def test_cloud_client_is_pageindex_client():
    """CloudClient is a PageIndexClient subtype."""
    cloud = CloudClient(api_key="pi-test")
    assert isinstance(cloud, PageIndexClient)
|
||||
|
||||
|
||||
def test_collection_default_name(tmp_path):
    """collection() with no argument falls back to the "default" name."""
    c = LocalClient(model="gpt-4o", storage_path=str(tmp_path / "pi"))
    assert c.collection().name == "default"
|
||||
|
||||
|
||||
def test_collection_custom_name(tmp_path):
    """collection(name) propagates the given name to the Collection."""
    c = LocalClient(model="gpt-4o", storage_path=str(tmp_path / "pi"))
    assert c.collection("papers").name == "papers"
|
||||
|
||||
|
||||
def test_list_collections_empty(tmp_path):
    """A freshly created client starts with no collections."""
    fresh = LocalClient(model="gpt-4o", storage_path=str(tmp_path / "pi"))
    assert fresh.list_collections() == []
|
||||
|
||||
|
||||
def test_list_collections_after_create(tmp_path):
    """Accessing a named collection implicitly creates it."""
    c = LocalClient(model="gpt-4o", storage_path=str(tmp_path / "pi"))
    c.collection("papers")
    assert "papers" in c.list_collections()
|
||||
|
||||
|
||||
def test_delete_collection(tmp_path):
    """delete_collection removes a previously created collection."""
    c = LocalClient(model="gpt-4o", storage_path=str(tmp_path / "pi"))
    c.collection("papers")
    c.delete_collection("papers")
    assert "papers" not in c.list_collections()
|
||||
|
||||
|
||||
def test_register_parser(tmp_path):
    """register_parser accepts any object satisfying the parser protocol."""
    c = LocalClient(model="gpt-4o", storage_path=str(tmp_path / "pi"))

    class FakeParser:
        def supported_extensions(self):
            return [".txt"]

        def parse(self, file_path, **kwargs):
            pass

    c.register_parser(FakeParser())
|
||||
16
tests/test_cloud_backend.py
Normal file
16
tests/test_cloud_backend.py
Normal file
|
|
@ -0,0 +1,16 @@
|
|||
from pageindex.backend.cloud import CloudBackend, API_BASE
|
||||
|
||||
|
||||
def test_cloud_backend_init():
    """The API key is stored and echoed into the request headers."""
    cb = CloudBackend(api_key="pi-test")
    assert cb._api_key == "pi-test"
    assert cb._headers["api_key"] == "pi-test"
|
||||
|
||||
|
||||
def test_api_base_url():
    """The default API base points at the hosted pageindex.ai service."""
    assert "pageindex.ai" in API_BASE
|
||||
|
||||
|
||||
def test_get_retrieve_model_is_none():
    # NOTE(review): the name says "retrieve_model" but the assertion checks
    # get_agent_tools(); consider renaming for clarity.
    backend = CloudBackend(api_key="pi-test")
    assert backend.get_agent_tools("col").function_tools == []
|
||||
41
tests/test_collection.py
Normal file
41
tests/test_collection.py
Normal file
|
|
@ -0,0 +1,41 @@
|
|||
# tests/sdk/test_collection.py
|
||||
import pytest
|
||||
from unittest.mock import MagicMock
|
||||
from pageindex.collection import Collection
|
||||
|
||||
|
||||
@pytest.fixture
def col():
    # Collection wired to a fully mocked backend that knows one document "d1".
    backend = MagicMock()
    backend.list_documents.return_value = [
        {"doc_id": "d1", "doc_name": "paper.pdf", "doc_type": "pdf"}
    ]
    backend.get_document.return_value = {"doc_id": "d1", "doc_name": "paper.pdf"}
    backend.add_document.return_value = "d1"
    return Collection(name="papers", backend=backend)
|
||||
|
||||
|
||||
def test_add(col):
    """add() returns the backend's doc id and forwards collection + path."""
    assert col.add("paper.pdf") == "d1"
    col._backend.add_document.assert_called_once_with("papers", "paper.pdf")
|
||||
|
||||
|
||||
def test_list_documents(col):
    """list_documents() surfaces the backend's document records."""
    documents = col.list_documents()
    assert len(documents) == 1
    assert documents[0]["doc_id"] == "d1"
|
||||
|
||||
|
||||
def test_get_document(col):
    """get_document() returns the backend's record for the given id."""
    assert col.get_document("d1")["doc_name"] == "paper.pdf"
|
||||
|
||||
|
||||
def test_delete_document(col):
    """delete_document() forwards collection name and doc id to the backend."""
    col.delete_document("d1")
    col._backend.delete_document.assert_called_once_with("papers", "d1")
|
||||
|
||||
|
||||
def test_name_property(col):
    """The collection exposes its configured name."""
    assert col.name == "papers"
|
||||
28
tests/test_config.py
Normal file
28
tests/test_config.py
Normal file
|
|
@ -0,0 +1,28 @@
|
|||
# tests/test_config.py
|
||||
import pytest
|
||||
from pageindex.config import IndexConfig
|
||||
|
||||
|
||||
def test_defaults():
    """IndexConfig defaults match the documented baseline."""
    cfg = IndexConfig()
    assert cfg.model == "gpt-4o-2024-11-20"
    assert cfg.retrieve_model is None
    assert cfg.toc_check_page_num == 20
|
||||
|
||||
|
||||
def test_overrides():
    """Constructor keyword arguments override the defaults."""
    cfg = IndexConfig(model="gpt-5.4", retrieve_model="claude-sonnet")
    assert cfg.model == "gpt-5.4"
    assert cfg.retrieve_model == "claude-sonnet"
|
||||
|
||||
|
||||
def test_unknown_key_raises():
    """Unknown config keys are rejected instead of silently ignored."""
    with pytest.raises(Exception):
        IndexConfig(nonexistent_key="value")
|
||||
|
||||
|
||||
def test_model_copy_with_update():
    """model_copy(update=...) changes only the named fields."""
    base = IndexConfig(toc_check_page_num=30)
    copied = base.model_copy(update={"model": "gpt-5.4"})
    assert copied.model == "gpt-5.4"
    assert copied.toc_check_page_num == 30
|
||||
45
tests/test_content_node.py
Normal file
45
tests/test_content_node.py
Normal file
|
|
@ -0,0 +1,45 @@
|
|||
from pageindex.parser.protocol import ContentNode, ParsedDocument, DocumentParser
|
||||
|
||||
|
||||
def test_content_node_required_fields():
    """Only content and tokens are required; the rest default to None."""
    node = ContentNode(content="hello", tokens=5)
    assert (node.content, node.tokens) == ("hello", 5)
    assert node.title is None
    assert node.index is None
    assert node.level is None
|
||||
|
||||
|
||||
def test_content_node_all_fields():
    """Optional fields are stored when supplied."""
    node = ContentNode(content="# Intro", tokens=10, title="Intro", index=1, level=1)
    assert (node.title, node.index, node.level) == ("Intro", 1, 1)
|
||||
|
||||
|
||||
def test_parsed_document():
    """ParsedDocument holds a name and nodes; metadata defaults to None."""
    doc = ParsedDocument(
        doc_name="test.pdf",
        nodes=[ContentNode(content="page1", tokens=100, index=1)],
    )
    assert doc.doc_name == "test.pdf"
    assert len(doc.nodes) == 1
    assert doc.metadata is None
|
||||
|
||||
|
||||
def test_parsed_document_with_metadata():
    """Arbitrary metadata dicts are stored as-is."""
    doc = ParsedDocument(
        doc_name="test.pdf",
        nodes=[ContentNode(content="page1", tokens=100)],
        metadata={"author": "John"},
    )
    assert doc.metadata["author"] == "John"
|
||||
|
||||
|
||||
def test_document_parser_protocol():
    """Verify a class implementing DocumentParser is structurally compatible.

    DocumentParser is @runtime_checkable, so isinstance() performs the
    structural check the docstring promises; the methods are exercised too.
    """
    class MyParser:
        def supported_extensions(self) -> list[str]:
            return [".txt"]

        def parse(self, file_path: str, **kwargs) -> ParsedDocument:
            return ParsedDocument(doc_name="test", nodes=[])

    parser = MyParser()
    # Previously the imported protocol was never used; this is the actual
    # structural-compatibility assertion (checks method presence).
    assert isinstance(parser, DocumentParser)
    assert parser.supported_extensions() == [".txt"]
    result = parser.parse("test.txt")
    assert isinstance(result, ParsedDocument)
|
||||
27
tests/test_errors.py
Normal file
27
tests/test_errors.py
Normal file
|
|
@ -0,0 +1,27 @@
|
|||
from pageindex.errors import (
|
||||
PageIndexError,
|
||||
CollectionNotFoundError,
|
||||
DocumentNotFoundError,
|
||||
IndexingError,
|
||||
CloudAPIError,
|
||||
FileTypeError,
|
||||
)
|
||||
|
||||
|
||||
def test_all_errors_inherit_from_base():
    """Every concrete error derives from PageIndexError (and Exception)."""
    concrete = (CollectionNotFoundError, DocumentNotFoundError,
                IndexingError, CloudAPIError, FileTypeError)
    for err_cls in concrete:
        assert issubclass(err_cls, PageIndexError)
        assert issubclass(err_cls, Exception)
|
||||
|
||||
|
||||
def test_error_message():
    """str() of an error is exactly the message it was raised with."""
    assert str(FileTypeError("Unsupported: .docx")) == "Unsupported: .docx"
|
||||
|
||||
|
||||
def test_catch_base_catches_all():
    """An `except PageIndexError` clause traps every concrete error type."""
    for err_cls in (CollectionNotFoundError, DocumentNotFoundError,
                    IndexingError, CloudAPIError, FileTypeError):
        try:
            raise err_cls("test")
        except PageIndexError:
            pass  # expected
|
||||
26
tests/test_events.py
Normal file
26
tests/test_events.py
Normal file
|
|
@ -0,0 +1,26 @@
|
|||
from pageindex.events import QueryEvent
|
||||
from pageindex.backend.protocol import AgentTools
|
||||
|
||||
|
||||
def test_query_event():
    """QueryEvent stores type and payload verbatim."""
    ev = QueryEvent(type="answer_delta", data="hello")
    assert (ev.type, ev.data) == ("answer_delta", "hello")
|
||||
|
||||
|
||||
def test_query_event_types():
    """Every known event type round-trips through the constructor."""
    for event_type in ("reasoning", "tool_call", "tool_result",
                       "answer_delta", "answer_done"):
        assert QueryEvent(type=event_type, data="test").type == event_type
|
||||
|
||||
|
||||
def test_agent_tools_default_empty():
    """AgentTools defaults to empty tool and MCP-server lists."""
    defaults = AgentTools()
    assert defaults.function_tools == []
    assert defaults.mcp_servers == []
|
||||
|
||||
|
||||
def test_agent_tools_with_values():
    """Supplied tool and server lists are kept."""
    configured = AgentTools(function_tools=["tool1"], mcp_servers=["server1"])
    assert len(configured.function_tools) == 1
    assert len(configured.mcp_servers) == 1
|
||||
50
tests/test_local_backend.py
Normal file
50
tests/test_local_backend.py
Normal file
|
|
@ -0,0 +1,50 @@
|
|||
# tests/sdk/test_local_backend.py
|
||||
import pytest
|
||||
from pathlib import Path
|
||||
from pageindex.backend.local import LocalBackend
|
||||
from pageindex.storage.sqlite import SQLiteStorage
|
||||
from pageindex.errors import FileTypeError
|
||||
|
||||
|
||||
@pytest.fixture
def backend(tmp_path):
    # LocalBackend over a throwaway SQLite db and files dir under tmp_path.
    storage = SQLiteStorage(str(tmp_path / "test.db"))
    files_dir = tmp_path / "files"
    return LocalBackend(storage=storage, files_dir=str(files_dir), model="gpt-4o")
|
||||
|
||||
|
||||
def test_collection_lifecycle(backend):
    """Create then delete a collection; list_collections tracks both steps."""
    backend.get_or_create_collection("papers")
    assert "papers" in backend.list_collections()
    backend.delete_collection("papers")
    assert "papers" not in backend.list_collections()
|
||||
|
||||
|
||||
def test_list_documents_empty(backend):
    """A brand-new collection contains no documents."""
    backend.get_or_create_collection("papers")
    assert backend.list_documents("papers") == []
|
||||
|
||||
|
||||
def test_unsupported_file_type_raises(backend, tmp_path):
    """Adding a file with an unknown extension raises FileTypeError."""
    backend.get_or_create_collection("papers")
    unknown = tmp_path / "test.xyz"
    unknown.write_text("hello")
    with pytest.raises(FileTypeError):
        backend.add_document("papers", str(unknown))
|
||||
|
||||
|
||||
def test_register_custom_parser(backend):
    """Registering a parser makes its extensions resolvable by the backend."""
    from pageindex.parser.protocol import ParsedDocument, ContentNode

    class TxtParser:
        def supported_extensions(self):
            return [".txt"]

        def parse(self, file_path, **kwargs):
            text = Path(file_path).read_text()
            node = ContentNode(content=text, tokens=len(text.split()),
                               title="Content", index=1, level=1)
            return ParsedDocument(doc_name="test", nodes=[node])

    backend.register_parser(TxtParser())
    # .txt is now supported, so parser resolution succeeds (no FileTypeError).
    assert backend._resolve_parser("test.txt") is not None
|
||||
55
tests/test_markdown_parser.py
Normal file
55
tests/test_markdown_parser.py
Normal file
|
|
@ -0,0 +1,55 @@
|
|||
import pytest
|
||||
from pathlib import Path
|
||||
from pageindex.parser.markdown import MarkdownParser
|
||||
from pageindex.parser.protocol import ContentNode, ParsedDocument
|
||||
|
||||
@pytest.fixture
def sample_md(tmp_path):
    # Two chapters; chapter 1 has two subsections -> 4 heading nodes total.
    md = tmp_path / "test.md"
    md.write_text("""# Chapter 1
Some intro text.

## Section 1.1
Details here.

## Section 1.2
More details.

# Chapter 2
Another chapter.
""")
    return str(md)
|
||||
|
||||
def test_supported_extensions():
    """Both .md and .markdown extensions are claimed."""
    supported = MarkdownParser().supported_extensions()
    assert ".md" in supported
    assert ".markdown" in supported
|
||||
|
||||
def test_parse_returns_parsed_document(sample_md):
    """parse() yields a ParsedDocument named after the file stem."""
    parsed = MarkdownParser().parse(sample_md)
    assert isinstance(parsed, ParsedDocument)
    assert parsed.doc_name == "test"
|
||||
|
||||
def test_parse_nodes_have_level(sample_md):
    """Heading depth maps to node.level; one node per heading."""
    nodes = MarkdownParser().parse(sample_md).nodes
    assert len(nodes) == 4
    assert (nodes[0].level, nodes[0].title) == (1, "Chapter 1")
    assert (nodes[1].level, nodes[1].title) == (2, "Section 1.1")
    assert nodes[3].level == 1
|
||||
|
||||
def test_parse_nodes_have_content(sample_md):
    """Body text under each heading lands in that node's content."""
    nodes = MarkdownParser().parse(sample_md).nodes
    assert "Some intro text" in nodes[0].content
    assert "Details here" in nodes[1].content
|
||||
|
||||
def test_parse_nodes_have_index(sample_md):
    """Every parsed node carries a position index."""
    parsed = MarkdownParser().parse(sample_md)
    assert all(node.index is not None for node in parsed.nodes)
|
||||
29
tests/test_pdf_parser.py
Normal file
29
tests/test_pdf_parser.py
Normal file
|
|
@ -0,0 +1,29 @@
|
|||
import pytest
|
||||
from pathlib import Path
|
||||
from pageindex.parser.pdf import PdfParser
|
||||
from pageindex.parser.protocol import ContentNode, ParsedDocument
|
||||
|
||||
TEST_PDF = Path("tests/pdfs/deepseek-r1.pdf")
|
||||
|
||||
def test_supported_extensions():
    """PdfParser claims the .pdf extension."""
    assert ".pdf" in PdfParser().supported_extensions()
|
||||
|
||||
@pytest.mark.skipif(not TEST_PDF.exists(), reason="Test PDF not available")
def test_parse_returns_parsed_document():
    """Parsing the sample PDF yields a non-empty, named ParsedDocument."""
    parsed = PdfParser().parse(str(TEST_PDF))
    assert isinstance(parsed, ParsedDocument)
    assert len(parsed.nodes) > 0
    assert parsed.doc_name != ""
|
||||
|
||||
@pytest.mark.skipif(not TEST_PDF.exists(), reason="Test PDF not available")
def test_parse_nodes_are_flat_without_level():
    """PDF pages come back as flat ContentNodes: indexed, with no heading level."""
    for node in PdfParser().parse(str(TEST_PDF)).nodes:
        assert isinstance(node, ContentNode)
        assert node.content is not None
        assert node.tokens >= 0
        assert node.index is not None
        assert node.level is None
|
||||
95
tests/test_pipeline.py
Normal file
95
tests/test_pipeline.py
Normal file
|
|
@ -0,0 +1,95 @@
|
|||
# tests/sdk/test_pipeline.py
|
||||
import asyncio
|
||||
from unittest.mock import patch, AsyncMock
|
||||
|
||||
from pageindex.parser.protocol import ContentNode, ParsedDocument
|
||||
from pageindex.index.pipeline import (
|
||||
detect_strategy, build_tree_from_levels, build_index,
|
||||
_content_based_pipeline, _NullLogger,
|
||||
)
|
||||
|
||||
|
||||
def test_detect_strategy_with_level():
    """Nodes carrying heading levels trigger the level-based strategy."""
    leveled = [
        ContentNode(content="# Intro", tokens=10, title="Intro", index=1, level=1),
        ContentNode(content="## Details", tokens=10, title="Details", index=5, level=2),
    ]
    assert detect_strategy(leveled) == "level_based"
|
||||
|
||||
|
||||
def test_detect_strategy_without_level():
    """Flat nodes (no levels) fall back to the content-based strategy."""
    flat = [
        ContentNode(content="Page 1 text", tokens=100, index=1),
        ContentNode(content="Page 2 text", tokens=100, index=2),
    ]
    assert detect_strategy(flat) == "content_based"
|
||||
|
||||
|
||||
def test_build_tree_from_levels():
    """Level-2 nodes nest under their preceding level-1 chapter."""
    nodes = [
        ContentNode(content="ch1 text", tokens=10, title="Chapter 1", index=1, level=1),
        ContentNode(content="s1.1 text", tokens=10, title="Section 1.1", index=5, level=2),
        ContentNode(content="s1.2 text", tokens=10, title="Section 1.2", index=10, level=2),
        ContentNode(content="ch2 text", tokens=10, title="Chapter 2", index=20, level=1),
    ]
    tree = build_tree_from_levels(nodes)
    assert [root["title"] for root in tree] == ["Chapter 1", "Chapter 2"]
    assert [child["title"] for child in tree[0]["nodes"]] == ["Section 1.1", "Section 1.2"]
    assert len(tree[1]["nodes"]) == 0
|
||||
|
||||
|
||||
def test_build_tree_from_levels_single_level():
    """All level-1 nodes become siblings at the root."""
    siblings = [
        ContentNode(content="a", tokens=5, title="A", index=1, level=1),
        ContentNode(content="b", tokens=5, title="B", index=2, level=1),
    ]
    tree = build_tree_from_levels(siblings)
    assert len(tree) == 2
    assert [node["title"] for node in tree] == ["A", "B"]
|
||||
|
||||
|
||||
def test_build_tree_from_levels_deep_nesting():
    """Strictly increasing levels nest each node under the previous one."""
    nodes = [
        ContentNode(content=f"h{lvl}", tokens=5, title=f"H{lvl}", index=lvl, level=lvl)
        for lvl in (1, 2, 3)
    ]
    tree = build_tree_from_levels(nodes)

    # Walk one level at a time: each layer holds exactly one node.
    layer = tree
    for expected_title in ("H1", "H2", "H3"):
        assert len(layer) == 1
        assert layer[0]["title"] == expected_title
        layer = layer[0]["nodes"]
|
||||
|
||||
|
||||
def test_content_based_pipeline_does_not_raise():
    """_content_based_pipeline should delegate to tree_parser, not raise NotImplementedError."""
    from types import SimpleNamespace

    expected_tree = [{"title": "Intro", "start_index": 1, "end_index": 2, "nodes": []}]

    async def stub_tree_parser(page_list, opt, doc=None, logger=None):
        # Stand-in for the real parser: hand back a canned tree.
        return expected_tree

    pages = [("Page 1 text", 50), ("Page 2 text", 60)]
    options = SimpleNamespace(model="test-model")

    with patch("pageindex.index.page_index.tree_parser", new=stub_tree_parser):
        result = asyncio.run(_content_based_pipeline(pages, options))

    assert result == expected_tree
|
||||
|
||||
|
||||
def test_null_logger_methods():
    """NullLogger should have info/error/debug and not raise."""
    sink = _NullLogger()
    for emit, message in (
        (sink.info, "test message"),
        (sink.error, "test error"),
        (sink.debug, "test debug"),
    ):
        emit(message)
    # Non-string payloads must be accepted too.
    sink.info({"key": "value"})
|
||||
61
tests/test_sqlite_storage.py
Normal file
61
tests/test_sqlite_storage.py
Normal file
|
|
@ -0,0 +1,61 @@
|
|||
import pytest
|
||||
from pageindex.storage.sqlite import SQLiteStorage
|
||||
|
||||
@pytest.fixture
def storage(tmp_path):
    """Provide a fresh SQLiteStorage backed by a database file in a temp directory."""
    db_path = tmp_path / "test.db"
    return SQLiteStorage(str(db_path))
|
||||
|
||||
def test_create_and_list_collections(storage):
    """A newly created collection shows up in the listing."""
    storage.create_collection("papers")
    names = storage.list_collections()
    assert "papers" in names
|
||||
|
||||
def test_get_or_create_collection_idempotent(storage):
    """Calling get_or_create twice must not produce a duplicate entry."""
    for _ in range(2):
        storage.get_or_create_collection("papers")
    assert storage.list_collections().count("papers") == 1
|
||||
|
||||
def test_delete_collection(storage):
    """Deleting a collection removes it from the listing."""
    storage.create_collection("papers")
    storage.delete_collection("papers")
    remaining = storage.list_collections()
    assert "papers" not in remaining
|
||||
|
||||
def test_save_and_get_document(storage):
    """A saved document round-trips through get_document."""
    storage.create_collection("papers")
    document = {
        "doc_name": "test.pdf",
        "doc_description": "A test",
        "file_path": "/tmp/test.pdf",
        "doc_type": "pdf",
        "structure": [{"title": "Intro", "node_id": "0001"}],
    }
    storage.save_document("papers", "doc-1", document)

    fetched = storage.get_document("papers", "doc-1")
    assert fetched["doc_name"] == "test.pdf"
    assert fetched["doc_type"] == "pdf"
|
||||
|
||||
def test_get_document_structure(storage):
    """get_document_structure returns the structure saved with the document."""
    storage.create_collection("papers")
    outline = [{"title": "Ch1", "node_id": "0001", "nodes": []}]
    record = {
        "doc_name": "test.pdf",
        "doc_type": "pdf",
        "file_path": "/tmp/test.pdf",
        "structure": outline,
    }
    storage.save_document("papers", "doc-1", record)

    fetched = storage.get_document_structure("papers", "doc-1")
    assert fetched[0]["title"] == "Ch1"
|
||||
|
||||
def test_list_documents(storage):
    """All saved documents are returned by list_documents."""
    storage.create_collection("papers")
    for doc_id, name, path in (
        ("doc-1", "p1.pdf", "/tmp/p1.pdf"),
        ("doc-2", "p2.pdf", "/tmp/p2.pdf"),
    ):
        storage.save_document("papers", doc_id, {
            "doc_name": name,
            "doc_type": "pdf",
            "file_path": path,
            "structure": [],
        })
    assert len(storage.list_documents("papers")) == 2
|
||||
|
||||
def test_delete_document(storage):
    """A deleted document disappears from the collection's listing."""
    storage.create_collection("papers")
    record = {
        "doc_name": "test.pdf",
        "doc_type": "pdf",
        "file_path": "/tmp/test.pdf",
        "structure": [],
    }
    storage.save_document("papers", "doc-1", record)
    storage.delete_document("papers", "doc-1")
    assert len(storage.list_documents("papers")) == 0
|
||||
|
||||
def test_delete_collection_cascades_documents(storage):
    """Deleting a collection that still holds documents must succeed."""
    storage.create_collection("papers")
    record = {
        "doc_name": "test.pdf",
        "doc_type": "pdf",
        "file_path": "/tmp/test.pdf",
        "structure": [],
    }
    storage.save_document("papers", "doc-1", record)
    storage.delete_collection("papers")
    # NOTE(review): this only checks the collection is gone; it does not verify
    # the document rows were actually cascaded — consider asserting that too.
    assert "papers" not in storage.list_collections()
|
||||
19
tests/test_storage_protocol.py
Normal file
19
tests/test_storage_protocol.py
Normal file
|
|
@ -0,0 +1,19 @@
|
|||
from pageindex.storage.protocol import StorageEngine
|
||||
|
||||
def test_storage_engine_is_protocol():
    """Any class implementing the full method set satisfies StorageEngine structurally."""

    class FakeStorage:
        def create_collection(self, name: str) -> None:
            pass

        def get_or_create_collection(self, name: str) -> None:
            pass

        def list_collections(self) -> list[str]:
            return []

        def delete_collection(self, name: str) -> None:
            pass

        def save_document(self, collection: str, doc_id: str, doc: dict) -> None:
            pass

        def find_document_by_hash(self, collection: str, file_hash: str) -> str | None:
            return None

        def get_document(self, collection: str, doc_id: str) -> dict:
            return {}

        def get_document_structure(self, collection: str, doc_id: str) -> dict:
            return {}

        def get_pages(self, collection: str, doc_id: str) -> list | None:
            return None

        def list_documents(self, collection: str) -> list[dict]:
            return []

        def delete_document(self, collection: str, doc_id: str) -> None:
            pass

        def close(self) -> None:
            pass

    impl = FakeStorage()
    assert isinstance(impl, StorageEngine)
|
||||
Loading…
Add table
Add a link
Reference in a new issue