mirror of
https://github.com/VectifyAI/PageIndex.git
synced 2026-05-19 18:35:16 +02:00
feat(collection): scoped query mode and experimental multi-doc warning
- get_agent_tools branches on doc_ids:
- scoped (doc_ids=[...]): drops list_documents and hard-enforces a
whitelist on the remaining tools; system prompt switches to
SCOPED_SYSTEM_PROMPT (no list_documents instruction); doc list +
summaries are prepended to the user message via wrap_with_doc_context.
- open (doc_ids=None): unchanged 4-tool agent loop.
- list_documents now exposes doc_description (sqlite + cloud).
- Collection.query emits UserWarning when doc_ids is None and the
collection holds >1 documents; PAGEINDEX_EXPERIMENTAL_MULTIDOC=1
silences it. Single-doc collections skip the warning; empty
collections raise ValueError.
- Agents SDK tracing upload disabled by default (avoids SSL timeouts);
PAGEINDEX_AGENTS_TRACING=1 re-enables it.
- README: new SDK Usage section covering local/cloud quick start,
streaming, multi-doc as experimental, and runnable examples.
This commit is contained in:
parent
cbea31d1a2
commit
d7b36aaf3f
8 changed files with 348 additions and 25 deletions
72
README.md
72
README.md
|
|
@ -139,6 +139,78 @@ You can generate the PageIndex tree structure with this open-source repo, or use
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
|
# 🚀 SDK Usage
|
||||||
|
|
||||||
|
A unified `PageIndexClient` powers both local self-hosted and cloud-managed modes. Mode is auto-detected by whether you pass an `api_key`.
|
||||||
|
|
||||||
|
### Install
|
||||||
|
|
||||||
|
```bash
|
||||||
|
pip install pageindex
|
||||||
|
```
|
||||||
|
|
||||||
|
### Quick start
|
||||||
|
|
||||||
|
```python
|
||||||
|
from pageindex import PageIndexClient
|
||||||
|
|
||||||
|
# Local mode — uses your LLM key (e.g. OPENAI_API_KEY in env).
|
||||||
|
client = PageIndexClient(model="gpt-4o-2024-11-20")
|
||||||
|
|
||||||
|
col = client.collection()
|
||||||
|
doc_id = col.add("path/to/your.pdf")
|
||||||
|
|
||||||
|
print(col.query("What is the main contribution?", doc_ids=[doc_id]))
|
||||||
|
|
||||||
|
# Cloud mode — fully managed, no LLM key needed:
|
||||||
|
# client = PageIndexClient(api_key="your-pageindex-api-key")
|
||||||
|
```
|
||||||
|
|
||||||
|
`col.query(...)` returns the answer string by default. Always pass `doc_ids` for reliable single-document QA — omitting it queries the entire collection, which is experimental (see below).
|
||||||
|
|
||||||
|
### Streaming queries
|
||||||
|
|
||||||
|
```python
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
async def main():
|
||||||
|
async for ev in col.query("Explain multi-head attention", stream=True):
|
||||||
|
if ev.type == "answer_delta":
|
||||||
|
print(ev.data, end="", flush=True)
|
||||||
|
elif ev.type == "tool_call":
|
||||||
|
print(f"\n[tool] {ev.data['name']}")
|
||||||
|
|
||||||
|
asyncio.run(main())
|
||||||
|
```
|
||||||
|
|
||||||
|
`ev.type` is one of: `tool_call`, `tool_result`, `answer_delta`, `answer_done`.
|
||||||
|
|
||||||
|
### Multi-document collections (experimental)
|
||||||
|
|
||||||
|
Passing `doc_ids` scopes the query to a specific subset of documents — this is the recommended path:
|
||||||
|
|
||||||
|
```python
|
||||||
|
col.query("Compare these two papers", doc_ids=[doc1, doc2])
|
||||||
|
```
|
||||||
|
|
||||||
|
Omitting `doc_ids` queries the **entire collection** and lets the agent pick which docs to read. This is an **experimental** feature with a naive first implementation — we're actively working on better cross-document retrieval. A `UserWarning` is emitted; set `PAGEINDEX_EXPERIMENTAL_MULTIDOC=1` to silence it.
|
||||||
|
|
||||||
|
### Environment variables
|
||||||
|
|
||||||
|
| Variable | Effect |
|
||||||
|
|---|---|
|
||||||
|
| `OPENAI_API_KEY` (or any LiteLLM `<PROVIDER>_API_KEY`) | LLM provider key — local mode |
|
||||||
|
| `PAGEINDEX_API_KEY` | PageIndex cloud key — cloud mode |
|
||||||
|
| `PAGEINDEX_EXPERIMENTAL_MULTIDOC` | Set to `1` to silence the warning when calling `col.query(...)` without `doc_ids` |
|
||||||
|
|
||||||
|
### Runnable examples
|
||||||
|
|
||||||
|
- [`examples/local_demo.py`](examples/local_demo.py) — local mode end-to-end (index a PDF + streaming QA)
|
||||||
|
- [`examples/cloud_demo.py`](examples/cloud_demo.py) — cloud mode end-to-end
|
||||||
|
- [`examples/agentic_vectorless_rag_demo.py`](examples/agentic_vectorless_rag_demo.py) — lower-level integration with the OpenAI Agents SDK
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
# ⚙️ Package Usage
|
# ⚙️ Package Usage
|
||||||
|
|
||||||
You can follow these steps to generate a PageIndex tree from a PDF document.
|
You can follow these steps to generate a PageIndex tree from a PDF document.
|
||||||
|
|
|
||||||
|
|
@ -1,14 +1,25 @@
|
||||||
# pageindex/agent.py
|
# pageindex/agent.py
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
import os
|
||||||
from typing import AsyncIterator
|
from typing import AsyncIterator
|
||||||
from .events import QueryEvent
|
from .events import QueryEvent
|
||||||
from .backend.protocol import AgentTools
|
from .backend.protocol import AgentTools
|
||||||
|
|
||||||
|
# Disable Agents SDK tracing upload by default — it posts to OpenAI's tracing
|
||||||
|
# endpoint and can fail with SSL timeouts in restricted networks. Opt back in
|
||||||
|
# with PAGEINDEX_AGENTS_TRACING=1.
|
||||||
|
if os.getenv("PAGEINDEX_AGENTS_TRACING", "").lower() not in ("1", "true", "yes"):
|
||||||
|
try:
|
||||||
|
from agents import set_tracing_disabled
|
||||||
|
set_tracing_disabled(True)
|
||||||
|
except ImportError:
|
||||||
|
pass
|
||||||
|
|
||||||
SYSTEM_PROMPT = """
|
|
||||||
|
OPEN_SYSTEM_PROMPT = """
|
||||||
You are PageIndex, a document QA assistant.
|
You are PageIndex, a document QA assistant.
|
||||||
TOOL USE:
|
TOOL USE:
|
||||||
- Call list_documents() to see available documents.
|
- Call list_documents() to see available documents; use doc_name and doc_description to pick which doc(s) are relevant.
|
||||||
- Call get_document(doc_id) to confirm status and page/line count.
|
- Call get_document(doc_id) to confirm status and page/line count.
|
||||||
- Call get_document_structure(doc_id) to identify relevant page ranges.
|
- Call get_document_structure(doc_id) to identify relevant page ranges.
|
||||||
- Call get_page_content(doc_id, pages="5-7") with tight ranges; never fetch the whole document.
|
- Call get_page_content(doc_id, pages="5-7") with tight ranges; never fetch the whole document.
|
||||||
|
|
@ -19,6 +30,42 @@ IMAGES:
|
||||||
Answer based only on tool output. Be concise.
|
Answer based only on tool output. Be concise.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
SCOPED_SYSTEM_PROMPT = """
|
||||||
|
You are PageIndex, a document QA assistant.
|
||||||
|
TOOL USE:
|
||||||
|
- Call get_document(doc_id) to confirm status and page/line count.
|
||||||
|
- Call get_document_structure(doc_id) to identify relevant page ranges.
|
||||||
|
- Call get_page_content(doc_id, pages="5-7") with tight ranges; never fetch the whole document.
|
||||||
|
- Before each tool call, output one short sentence explaining the reason.
|
||||||
|
IMAGES:
|
||||||
|
- Page content may contain image references like . Always preserve these in your answer so the downstream UI can render them.
|
||||||
|
- Place images near the relevant context in your answer.
|
||||||
|
Answer based only on tool output. Be concise.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
def wrap_with_doc_context(docs: list[dict], question: str) -> str:
|
||||||
|
"""Prepend a doc-context block to the user question for scoped queries."""
|
||||||
|
lines = []
|
||||||
|
for d in docs:
|
||||||
|
line = f"- {d['doc_id']}: {d.get('doc_name', '')}"
|
||||||
|
desc = d.get("doc_description") or ""
|
||||||
|
if desc:
|
||||||
|
line += f" — {desc}"
|
||||||
|
lines.append(line)
|
||||||
|
label = "document" if len(docs) == 1 else "documents"
|
||||||
|
return (
|
||||||
|
f"The user has specified the following {label}:\n"
|
||||||
|
+ "\n".join(lines)
|
||||||
|
+ f"\n\nUse the doc_id(s) above directly with get_document_structure() "
|
||||||
|
f"and get_page_content() — do not look for other documents.\n\n"
|
||||||
|
f"User question: {question}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# Backwards-compatible alias (open mode is the historical default).
|
||||||
|
SYSTEM_PROMPT = OPEN_SYSTEM_PROMPT
|
||||||
|
|
||||||
|
|
||||||
class QueryStream:
|
class QueryStream:
|
||||||
"""Streaming query result, similar to OpenAI's RunResultStreaming.
|
"""Streaming query result, similar to OpenAI's RunResultStreaming.
|
||||||
|
|
@ -30,12 +77,13 @@ class QueryStream:
|
||||||
print(event.data, end="", flush=True)
|
print(event.data, end="", flush=True)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, tools: AgentTools, question: str, model: str = None):
|
def __init__(self, tools: AgentTools, question: str, model: str = None,
|
||||||
|
instructions: str | None = None):
|
||||||
from agents import Agent
|
from agents import Agent
|
||||||
from agents.model_settings import ModelSettings
|
from agents.model_settings import ModelSettings
|
||||||
self._agent = Agent(
|
self._agent = Agent(
|
||||||
name="PageIndex",
|
name="PageIndex",
|
||||||
instructions=SYSTEM_PROMPT,
|
instructions=instructions or OPEN_SYSTEM_PROMPT,
|
||||||
tools=tools.function_tools,
|
tools=tools.function_tools,
|
||||||
mcp_servers=tools.mcp_servers,
|
mcp_servers=tools.mcp_servers,
|
||||||
model=model,
|
model=model,
|
||||||
|
|
@ -73,9 +121,11 @@ class QueryStream:
|
||||||
|
|
||||||
|
|
||||||
class AgentRunner:
|
class AgentRunner:
|
||||||
def __init__(self, tools: AgentTools, model: str = None):
|
def __init__(self, tools: AgentTools, model: str = None,
|
||||||
|
instructions: str | None = None):
|
||||||
self._tools = tools
|
self._tools = tools
|
||||||
self._model = model
|
self._model = model
|
||||||
|
self._instructions = instructions or OPEN_SYSTEM_PROMPT
|
||||||
|
|
||||||
def run(self, question: str) -> str:
|
def run(self, question: str) -> str:
|
||||||
"""Sync non-streaming query. Returns answer string."""
|
"""Sync non-streaming query. Returns answer string."""
|
||||||
|
|
@ -83,7 +133,7 @@ class AgentRunner:
|
||||||
from agents.model_settings import ModelSettings
|
from agents.model_settings import ModelSettings
|
||||||
agent = Agent(
|
agent = Agent(
|
||||||
name="PageIndex",
|
name="PageIndex",
|
||||||
instructions=SYSTEM_PROMPT,
|
instructions=self._instructions,
|
||||||
tools=self._tools.function_tools,
|
tools=self._tools.function_tools,
|
||||||
mcp_servers=self._tools.mcp_servers,
|
mcp_servers=self._tools.mcp_servers,
|
||||||
model=self._model,
|
model=self._model,
|
||||||
|
|
|
||||||
|
|
@ -216,7 +216,12 @@ class CloudBackend:
|
||||||
params["folder_id"] = folder_id
|
params["folder_id"] = folder_id
|
||||||
data = self._request("GET", "/docs/", params=params)
|
data = self._request("GET", "/docs/", params=params)
|
||||||
return [
|
return [
|
||||||
{"doc_id": d.get("id", ""), "doc_name": d.get("name", ""), "doc_type": "pdf"}
|
{
|
||||||
|
"doc_id": d.get("id", ""),
|
||||||
|
"doc_name": d.get("name", ""),
|
||||||
|
"doc_description": d.get("description", ""),
|
||||||
|
"doc_type": "pdf",
|
||||||
|
}
|
||||||
for d in data.get("documents", [])
|
for d in data.get("documents", [])
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -197,49 +197,95 @@ class LocalBackend:
|
||||||
self._storage.delete_document(collection, doc_id)
|
self._storage.delete_document(collection, doc_id)
|
||||||
|
|
||||||
def get_agent_tools(self, collection: str, doc_ids: list[str] | None = None) -> AgentTools:
|
def get_agent_tools(self, collection: str, doc_ids: list[str] | None = None) -> AgentTools:
|
||||||
|
"""Build agent tools.
|
||||||
|
|
||||||
|
- doc_ids=None (open mode): includes ``list_documents``; agent picks docs itself.
|
||||||
|
- doc_ids=[...] (scoped mode): no ``list_documents``; the other tools
|
||||||
|
hard-enforce the whitelist and reject out-of-scope doc_ids.
|
||||||
|
"""
|
||||||
from agents import function_tool
|
from agents import function_tool
|
||||||
import json
|
import json
|
||||||
storage = self._storage
|
storage = self._storage
|
||||||
col_name = collection
|
col_name = collection
|
||||||
backend = self
|
backend = self
|
||||||
filter_ids = doc_ids
|
scope = set(doc_ids) if doc_ids else None
|
||||||
|
|
||||||
@function_tool
|
def _reject(doc_id: str) -> str | None:
|
||||||
def list_documents() -> str:
|
if scope is not None and doc_id not in scope:
|
||||||
"""List all documents in the collection."""
|
return json.dumps({
|
||||||
docs = storage.list_documents(col_name)
|
"error": f"doc_id '{doc_id}' is not in scope.",
|
||||||
if filter_ids:
|
"allowed_doc_ids": sorted(scope),
|
||||||
docs = [d for d in docs if d["doc_id"] in filter_ids]
|
})
|
||||||
return json.dumps(docs)
|
return None
|
||||||
|
|
||||||
@function_tool
|
@function_tool
|
||||||
def get_document(doc_id: str) -> str:
|
def get_document(doc_id: str) -> str:
|
||||||
"""Get document metadata."""
|
"""Get document metadata."""
|
||||||
|
rejection = _reject(doc_id)
|
||||||
|
if rejection:
|
||||||
|
return rejection
|
||||||
return json.dumps(storage.get_document(col_name, doc_id))
|
return json.dumps(storage.get_document(col_name, doc_id))
|
||||||
|
|
||||||
@function_tool
|
@function_tool
|
||||||
def get_document_structure(doc_id: str) -> str:
|
def get_document_structure(doc_id: str) -> str:
|
||||||
"""Get document tree structure (without text)."""
|
"""Get document tree structure (without text)."""
|
||||||
|
rejection = _reject(doc_id)
|
||||||
|
if rejection:
|
||||||
|
return rejection
|
||||||
structure = storage.get_document_structure(col_name, doc_id)
|
structure = storage.get_document_structure(col_name, doc_id)
|
||||||
return json.dumps(remove_fields(structure, fields=["text"]), ensure_ascii=False)
|
return json.dumps(remove_fields(structure, fields=["text"]), ensure_ascii=False)
|
||||||
|
|
||||||
@function_tool
|
@function_tool
|
||||||
def get_page_content(doc_id: str, pages: str) -> str:
|
def get_page_content(doc_id: str, pages: str) -> str:
|
||||||
"""Get page content. Use tight ranges: '5-7', '3,8', '12'."""
|
"""Get page content. Use tight ranges: '5-7', '3,8', '12'."""
|
||||||
|
rejection = _reject(doc_id)
|
||||||
|
if rejection:
|
||||||
|
return rejection
|
||||||
result = backend.get_page_content(col_name, doc_id, pages)
|
result = backend.get_page_content(col_name, doc_id, pages)
|
||||||
return json.dumps(result, ensure_ascii=False)
|
return json.dumps(result, ensure_ascii=False)
|
||||||
|
|
||||||
return AgentTools(function_tools=[list_documents, get_document, get_document_structure, get_page_content])
|
tools = [get_document, get_document_structure, get_page_content]
|
||||||
|
|
||||||
|
if scope is None:
|
||||||
|
@function_tool
|
||||||
|
def list_documents() -> str:
|
||||||
|
"""List all documents in the collection."""
|
||||||
|
return json.dumps(storage.list_documents(col_name))
|
||||||
|
tools.insert(0, list_documents)
|
||||||
|
|
||||||
|
return AgentTools(function_tools=tools)
|
||||||
|
|
||||||
|
def _scoped_docs(self, collection: str, doc_ids: list[str]) -> list[dict]:
|
||||||
|
"""Fetch metadata for the docs in scope; raise if any are missing."""
|
||||||
|
by_id = {d["doc_id"]: d for d in self._storage.list_documents(collection)}
|
||||||
|
missing = [did for did in doc_ids if did not in by_id]
|
||||||
|
if missing:
|
||||||
|
raise DocumentNotFoundError(
|
||||||
|
f"doc_ids not found in collection '{collection}': {missing}"
|
||||||
|
)
|
||||||
|
return [by_id[did] for did in doc_ids]
|
||||||
|
|
||||||
def query(self, collection: str, question: str, doc_ids: list[str] | None = None) -> str:
|
def query(self, collection: str, question: str, doc_ids: list[str] | None = None) -> str:
|
||||||
from ..agent import AgentRunner
|
from ..agent import AgentRunner, SCOPED_SYSTEM_PROMPT, wrap_with_doc_context
|
||||||
tools = self.get_agent_tools(collection, doc_ids)
|
tools = self.get_agent_tools(collection, doc_ids)
|
||||||
return AgentRunner(tools=tools, model=self._retrieve_model).run(question)
|
instructions = None
|
||||||
|
if doc_ids:
|
||||||
|
docs = self._scoped_docs(collection, doc_ids)
|
||||||
|
question = wrap_with_doc_context(docs, question)
|
||||||
|
instructions = SCOPED_SYSTEM_PROMPT
|
||||||
|
return AgentRunner(tools=tools, model=self._retrieve_model,
|
||||||
|
instructions=instructions).run(question)
|
||||||
|
|
||||||
async def query_stream(self, collection: str, question: str,
|
async def query_stream(self, collection: str, question: str,
|
||||||
doc_ids: list[str] | None = None):
|
doc_ids: list[str] | None = None):
|
||||||
from ..agent import QueryStream
|
from ..agent import QueryStream, SCOPED_SYSTEM_PROMPT, wrap_with_doc_context
|
||||||
tools = self.get_agent_tools(collection, doc_ids)
|
tools = self.get_agent_tools(collection, doc_ids)
|
||||||
stream = QueryStream(tools=tools, question=question, model=self._retrieve_model)
|
instructions = None
|
||||||
|
if doc_ids:
|
||||||
|
docs = self._scoped_docs(collection, doc_ids)
|
||||||
|
question = wrap_with_doc_context(docs, question)
|
||||||
|
instructions = SCOPED_SYSTEM_PROMPT
|
||||||
|
stream = QueryStream(tools=tools, question=question,
|
||||||
|
model=self._retrieve_model, instructions=instructions)
|
||||||
async for event in stream:
|
async for event in stream:
|
||||||
yield event
|
yield event
|
||||||
|
|
|
||||||
|
|
@ -1,10 +1,24 @@
|
||||||
# pageindex/collection.py
|
# pageindex/collection.py
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
import os
|
||||||
|
import warnings
|
||||||
from typing import AsyncIterator
|
from typing import AsyncIterator
|
||||||
from .events import QueryEvent
|
from .events import QueryEvent
|
||||||
from .backend.protocol import Backend
|
from .backend.protocol import Backend
|
||||||
|
|
||||||
|
|
||||||
|
def _multidoc_acked() -> bool:
|
||||||
|
return os.getenv("PAGEINDEX_EXPERIMENTAL_MULTIDOC", "").lower() in ("1", "true", "yes")
|
||||||
|
|
||||||
|
|
||||||
|
_MULTIDOC_WARNING = (
|
||||||
|
"Querying the entire collection (no doc_ids) is experimental — selection "
|
||||||
|
"accuracy depends on auto-generated doc descriptions. Pass doc_ids=[...] "
|
||||||
|
"for reliable results, or set PAGEINDEX_EXPERIMENTAL_MULTIDOC=1 to silence "
|
||||||
|
"this warning."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class QueryStream:
|
class QueryStream:
|
||||||
"""Wraps backend.query_stream() as an async iterable object."""
|
"""Wraps backend.query_stream() as an async iterable object."""
|
||||||
|
|
||||||
|
|
@ -60,10 +74,23 @@ class Collection:
|
||||||
- stream=True: returns async iterable of QueryEvent
|
- stream=True: returns async iterable of QueryEvent
|
||||||
|
|
||||||
Usage:
|
Usage:
|
||||||
answer = col.query("question")
|
answer = col.query("question", doc_ids=[doc_id])
|
||||||
async for event in col.query("question", stream=True):
|
async for event in col.query("question", doc_ids=[doc_id], stream=True):
|
||||||
...
|
...
|
||||||
|
|
||||||
|
Passing doc_ids=None queries the entire collection — this is
|
||||||
|
experimental; emits a UserWarning unless PAGEINDEX_EXPERIMENTAL_MULTIDOC
|
||||||
|
is set.
|
||||||
"""
|
"""
|
||||||
|
if doc_ids is None and not _multidoc_acked():
|
||||||
|
docs = self._backend.list_documents(self._name)
|
||||||
|
if not docs:
|
||||||
|
raise ValueError(
|
||||||
|
f"Cannot query collection '{self._name}': it is empty. "
|
||||||
|
"Add documents with col.add(...) first."
|
||||||
|
)
|
||||||
|
if len(docs) > 1:
|
||||||
|
warnings.warn(_MULTIDOC_WARNING, UserWarning, stacklevel=2)
|
||||||
if stream:
|
if stream:
|
||||||
return QueryStream(self._backend, self._name, question, doc_ids)
|
return QueryStream(self._backend, self._name, question, doc_ids)
|
||||||
return self._backend.query(self._name, question, doc_ids)
|
return self._backend.query(self._name, question, doc_ids)
|
||||||
|
|
|
||||||
|
|
@ -125,10 +125,10 @@ class SQLiteStorage:
|
||||||
def list_documents(self, collection: str) -> list[dict]:
|
def list_documents(self, collection: str) -> list[dict]:
|
||||||
conn = self._get_conn()
|
conn = self._get_conn()
|
||||||
rows = conn.execute(
|
rows = conn.execute(
|
||||||
"SELECT doc_id, doc_name, doc_type FROM documents WHERE collection_name = ? ORDER BY created_at",
|
"SELECT doc_id, doc_name, doc_description, doc_type FROM documents WHERE collection_name = ? ORDER BY created_at",
|
||||||
(collection,),
|
(collection,),
|
||||||
).fetchall()
|
).fetchall()
|
||||||
return [{"doc_id": r[0], "doc_name": r[1], "doc_type": r[2]} for r in rows]
|
return [{"doc_id": r[0], "doc_name": r[1], "doc_description": r[2] or "", "doc_type": r[3]} for r in rows]
|
||||||
|
|
||||||
def delete_document(self, collection: str, doc_id: str) -> None:
|
def delete_document(self, collection: str, doc_id: str) -> None:
|
||||||
conn = self._get_conn()
|
conn = self._get_conn()
|
||||||
|
|
|
||||||
|
|
@ -39,3 +39,46 @@ def test_delete_document(col):
|
||||||
|
|
||||||
def test_name_property(col):
|
def test_name_property(col):
|
||||||
assert col.name == "papers"
|
assert col.name == "papers"
|
||||||
|
|
||||||
|
|
||||||
|
def test_query_without_doc_ids_warns_when_multidoc(col, monkeypatch):
|
||||||
|
monkeypatch.delenv("PAGEINDEX_EXPERIMENTAL_MULTIDOC", raising=False)
|
||||||
|
col._backend.list_documents.return_value = [
|
||||||
|
{"doc_id": "d1", "doc_name": "a.pdf", "doc_type": "pdf"},
|
||||||
|
{"doc_id": "d2", "doc_name": "b.pdf", "doc_type": "pdf"},
|
||||||
|
]
|
||||||
|
col._backend.query.return_value = "answer"
|
||||||
|
with pytest.warns(UserWarning, match="experimental"):
|
||||||
|
result = col.query("what?")
|
||||||
|
assert result == "answer"
|
||||||
|
|
||||||
|
|
||||||
|
def test_query_without_doc_ids_no_warning_when_single_doc(col, monkeypatch, recwarn):
|
||||||
|
monkeypatch.delenv("PAGEINDEX_EXPERIMENTAL_MULTIDOC", raising=False)
|
||||||
|
col._backend.query.return_value = "answer"
|
||||||
|
col.query("what?")
|
||||||
|
assert not any(issubclass(w.category, UserWarning) for w in recwarn)
|
||||||
|
|
||||||
|
|
||||||
|
def test_query_empty_collection_raises(col, monkeypatch):
|
||||||
|
monkeypatch.delenv("PAGEINDEX_EXPERIMENTAL_MULTIDOC", raising=False)
|
||||||
|
col._backend.list_documents.return_value = []
|
||||||
|
with pytest.raises(ValueError, match="empty"):
|
||||||
|
col.query("what?")
|
||||||
|
|
||||||
|
|
||||||
|
def test_query_with_doc_ids_no_warning(col, recwarn):
|
||||||
|
col._backend.query.return_value = "answer"
|
||||||
|
col.query("what?", doc_ids=["d1"])
|
||||||
|
assert not any(issubclass(w.category, UserWarning) for w in recwarn)
|
||||||
|
|
||||||
|
|
||||||
|
def test_query_env_var_silences_warning(col, monkeypatch, recwarn):
|
||||||
|
monkeypatch.setenv("PAGEINDEX_EXPERIMENTAL_MULTIDOC", "1")
|
||||||
|
col._backend.list_documents.return_value = [
|
||||||
|
{"doc_id": "d1", "doc_name": "a.pdf", "doc_type": "pdf"},
|
||||||
|
{"doc_id": "d2", "doc_name": "b.pdf", "doc_type": "pdf"},
|
||||||
|
]
|
||||||
|
col._backend.query.return_value = "answer"
|
||||||
|
col.query("what?")
|
||||||
|
assert not any(issubclass(w.category, UserWarning) for w in recwarn)
|
||||||
|
|
|
||||||
|
|
@ -1,9 +1,11 @@
|
||||||
# tests/sdk/test_local_backend.py
|
# tests/sdk/test_local_backend.py
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
import pytest
|
import pytest
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from pageindex.backend.local import LocalBackend
|
from pageindex.backend.local import LocalBackend
|
||||||
from pageindex.storage.sqlite import SQLiteStorage
|
from pageindex.storage.sqlite import SQLiteStorage
|
||||||
from pageindex.errors import FileTypeError
|
from pageindex.errors import FileTypeError, DocumentNotFoundError
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
|
|
@ -48,3 +50,81 @@ def test_register_custom_parser(backend):
|
||||||
backend.register_parser(TxtParser())
|
backend.register_parser(TxtParser())
|
||||||
# Now .txt should be supported (won't raise FileTypeError)
|
# Now .txt should be supported (won't raise FileTypeError)
|
||||||
assert backend._resolve_parser("test.txt") is not None
|
assert backend._resolve_parser("test.txt") is not None
|
||||||
|
|
||||||
|
|
||||||
|
# ── Scoped-mode agent tools ──────────────────────────────────────────────────
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def populated_backend(backend):
|
||||||
|
"""Backend with a 'papers' collection containing two stub docs."""
|
||||||
|
backend.get_or_create_collection("papers")
|
||||||
|
for did, name, desc in [
|
||||||
|
("d1", "alpha.pdf", "About alpha."),
|
||||||
|
("d2", "beta.pdf", "About beta."),
|
||||||
|
]:
|
||||||
|
backend._storage.save_document("papers", did, {
|
||||||
|
"doc_name": name, "doc_description": desc,
|
||||||
|
"doc_type": "pdf", "file_path": f"/tmp/{name}", "structure": [],
|
||||||
|
})
|
||||||
|
return backend
|
||||||
|
|
||||||
|
|
||||||
|
def _invoke_tool(tool, args: dict) -> str:
|
||||||
|
"""Run a FunctionTool synchronously with a minimal ToolContext."""
|
||||||
|
from agents.tool_context import ToolContext
|
||||||
|
ctx = ToolContext(context=None, tool_name=tool.name,
|
||||||
|
tool_call_id="test", tool_arguments=json.dumps(args))
|
||||||
|
return asyncio.run(tool.on_invoke_tool(ctx, json.dumps(args)))
|
||||||
|
|
||||||
|
|
||||||
|
def test_open_mode_includes_list_documents(populated_backend):
|
||||||
|
tools = populated_backend.get_agent_tools("papers", doc_ids=None)
|
||||||
|
names = {t.name for t in tools.function_tools}
|
||||||
|
assert names == {"list_documents", "get_document", "get_document_structure", "get_page_content"}
|
||||||
|
|
||||||
|
|
||||||
|
def test_scoped_mode_excludes_list_documents(populated_backend):
|
||||||
|
tools = populated_backend.get_agent_tools("papers", doc_ids=["d1"])
|
||||||
|
names = {t.name for t in tools.function_tools}
|
||||||
|
assert "list_documents" not in names
|
||||||
|
assert names == {"get_document", "get_document_structure", "get_page_content"}
|
||||||
|
|
||||||
|
|
||||||
|
def test_scoped_mode_rejects_out_of_scope_doc_id(populated_backend):
|
||||||
|
tools = populated_backend.get_agent_tools("papers", doc_ids=["d1"])
|
||||||
|
by_name = {t.name: t for t in tools.function_tools}
|
||||||
|
out = json.loads(_invoke_tool(by_name["get_document"], {"doc_id": "d2"}))
|
||||||
|
assert "error" in out
|
||||||
|
assert "not in scope" in out["error"]
|
||||||
|
assert out["allowed_doc_ids"] == ["d1"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_scoped_mode_allows_in_scope_doc_id(populated_backend):
|
||||||
|
tools = populated_backend.get_agent_tools("papers", doc_ids=["d1"])
|
||||||
|
by_name = {t.name: t for t in tools.function_tools}
|
||||||
|
out = json.loads(_invoke_tool(by_name["get_document"], {"doc_id": "d1"}))
|
||||||
|
assert out.get("doc_name") == "alpha.pdf"
|
||||||
|
|
||||||
|
|
||||||
|
def test_wrap_with_doc_context_single(populated_backend):
|
||||||
|
from pageindex.agent import wrap_with_doc_context
|
||||||
|
docs = populated_backend._scoped_docs("papers", ["d1"])
|
||||||
|
wrapped = wrap_with_doc_context(docs, "what is this?")
|
||||||
|
assert "d1: alpha.pdf — About alpha." in wrapped
|
||||||
|
assert "specified the following document:" in wrapped
|
||||||
|
assert "User question: what is this?" in wrapped
|
||||||
|
|
||||||
|
|
||||||
|
def test_wrap_with_doc_context_multi(populated_backend):
|
||||||
|
from pageindex.agent import wrap_with_doc_context
|
||||||
|
docs = populated_backend._scoped_docs("papers", ["d1", "d2"])
|
||||||
|
wrapped = wrap_with_doc_context(docs, "compare them")
|
||||||
|
assert "d1: alpha.pdf — About alpha." in wrapped
|
||||||
|
assert "d2: beta.pdf — About beta." in wrapped
|
||||||
|
assert "specified the following documents:" in wrapped
|
||||||
|
assert "User question: compare them" in wrapped
|
||||||
|
|
||||||
|
|
||||||
|
def test_scoped_docs_raises_on_missing(populated_backend):
|
||||||
|
with pytest.raises(DocumentNotFoundError, match="nonexistent"):
|
||||||
|
populated_backend._scoped_docs("papers", ["d1", "nonexistent"])
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue