Add PageIndexClient with agent-based retrieval via OpenAI Agents SDK (#125)

* Add PageIndexClient with retrieve, streaming support and litellm integration
* Add OpenAI agents demo example
* Update README with example agent demo section
* Support separate retrieve_model configuration for index and retrieve
This commit is contained in:
Kylin 2026-03-26 23:19:50 +08:00 committed by GitHub
parent 2403be8f27
commit 5d4491f3bf
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
9 changed files with 501 additions and 7 deletions

View file

@ -1,2 +1,4 @@
from .page_index import *
from .page_index_md import md_to_tree
from .page_index_md import md_to_tree
from .retrieve import get_document, get_document_structure, get_page_content
from .client import PageIndexClient

132
pageindex/client.py Normal file
View file

@ -0,0 +1,132 @@
import os
import uuid
import json
import asyncio
import concurrent.futures
from pathlib import Path
from .page_index import page_index
from .page_index_md import md_to_tree
from .retrieve import get_document, get_document_structure, get_page_content
from .utils import ConfigLoader
class PageIndexClient:
"""
A client for indexing and retrieving document content.
Flow: index() -> get_document() / get_document_structure() / get_page_content()
For agent-based QA, see examples/openai_agents_demo.py.
"""
def __init__(self, api_key: str = None, model: str = None, retrieve_model: str = None, workspace: str = None):
if api_key:
os.environ["OPENAI_API_KEY"] = api_key
elif not os.getenv("OPENAI_API_KEY") and os.getenv("CHATGPT_API_KEY"):
os.environ["OPENAI_API_KEY"] = os.getenv("CHATGPT_API_KEY")
self.workspace = Path(workspace).expanduser() if workspace else None
overrides = {}
if model:
overrides["model"] = model
if retrieve_model:
overrides["retrieve_model"] = retrieve_model
opt = ConfigLoader().load(overrides or None)
self.model = opt.model
self.retrieve_model = opt.retrieve_model or self.model
if self.workspace:
self.workspace.mkdir(parents=True, exist_ok=True)
self.documents = {}
if self.workspace:
self._load_workspace()
def index(self, file_path: str, mode: str = "auto") -> str:
"""Index a document. Returns a document_id."""
if not os.path.exists(file_path):
raise FileNotFoundError(f"File not found: {file_path}")
doc_id = str(uuid.uuid4())
ext = os.path.splitext(file_path)[1].lower()
is_pdf = ext == '.pdf'
is_md = ext in ['.md', '.markdown']
if mode == "pdf" or (mode == "auto" and is_pdf):
print(f"Indexing PDF: {file_path}")
result = page_index(
doc=file_path,
model=self.model,
if_add_node_summary='yes',
if_add_node_text='yes',
if_add_node_id='yes',
if_add_doc_description='yes'
)
self.documents[doc_id] = {
'id': doc_id,
'path': file_path,
'type': 'pdf',
'structure': result['structure'],
'doc_name': result.get('doc_name', ''),
'doc_description': result.get('doc_description', '')
}
elif mode == "md" or (mode == "auto" and is_md):
print(f"Indexing Markdown: {file_path}")
coro = md_to_tree(
md_path=file_path,
if_thinning=False,
if_add_node_summary='yes',
summary_token_threshold=200,
model=self.model,
if_add_doc_description='yes',
if_add_node_text='yes',
if_add_node_id='yes'
)
try:
asyncio.get_running_loop()
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
result = pool.submit(asyncio.run, coro).result()
except RuntimeError:
result = asyncio.run(coro)
self.documents[doc_id] = {
'id': doc_id,
'path': file_path,
'type': 'md',
'structure': result['structure'],
'doc_name': result.get('doc_name', ''),
'doc_description': result.get('doc_description', '')
}
else:
raise ValueError(f"Unsupported file format for: {file_path}")
print(f"Indexing complete. Document ID: {doc_id}")
if self.workspace:
self._save_doc(doc_id)
return doc_id
def _save_doc(self, doc_id: str):
path = self.workspace / f"{doc_id}.json"
with open(path, "w", encoding="utf-8") as f:
json.dump(self.documents[doc_id], f, ensure_ascii=False, indent=2)
def _load_workspace(self):
loaded = 0
for path in self.workspace.glob("*.json"):
try:
with open(path, "r", encoding="utf-8") as f:
doc = json.load(f)
self.documents[path.stem] = doc
loaded += 1
except (json.JSONDecodeError, OSError) as e:
print(f"Warning: skipping corrupt workspace file {path.name}: {e}")
if loaded:
print(f"Loaded {loaded} document(s) from workspace.")
def get_document(self, doc_id: str) -> str:
"""Return document metadata JSON."""
return get_document(self.documents, doc_id)
def get_document_structure(self, doc_id: str) -> str:
"""Return document tree structure JSON (without text fields)."""
return get_document_structure(self.documents, doc_id)
def get_page_content(self, doc_id: str, pages: str) -> str:
"""Return page content for the given pages string (e.g. '5-7', '3,8', '12')."""
return get_page_content(self.documents, doc_id, pages)

View file

@ -1,5 +1,6 @@
model: "gpt-4o-2024-11-20"
# model: "anthropic/claude-sonnet-4-6"
retrieve_model: "gpt-5.4" # defaults to model if not set
toc_check_page_num: 20
max_page_num_each_node: 10
max_token_num_each_node: 20000

View file

@ -330,7 +330,7 @@ def toc_transformer(toc_content, model=None):
if_complete = check_if_toc_transformation_is_complete(toc_content, last_complete, model)
last_complete = json.loads(last_complete)
last_complete = extract_json(last_complete)
cleaned_response=convert_page_to_int(last_complete['table_of_contents'])
return cleaned_response

139
pageindex/retrieve.py Normal file
View file

@ -0,0 +1,139 @@
import json
import PyPDF2
try:
from .utils import get_number_of_pages, remove_fields
except ImportError:
from utils import get_number_of_pages, remove_fields
# ── Helpers ──────────────────────────────────────────────────────────────────
def _parse_pages(pages: str) -> list[int]:
"""Parse a pages string like '5-7', '3,8', or '12' into a sorted list of ints."""
result = []
for part in pages.split(','):
part = part.strip()
if '-' in part:
start, end = int(part.split('-', 1)[0].strip()), int(part.split('-', 1)[1].strip())
if start > end:
raise ValueError(f"Invalid range '{part}': start must be <= end")
result.extend(range(start, end + 1))
else:
result.append(int(part))
return sorted(set(result))
def _count_pages(doc_info: dict) -> int:
"""Return total page count for a document."""
if doc_info.get('type') == 'pdf':
return get_number_of_pages(doc_info['path'])
# For MD, find max line_num across all nodes
max_line = 0
def _traverse(nodes):
nonlocal max_line
for node in nodes:
ln = node.get('line_num', 0)
if ln and ln > max_line:
max_line = ln
if node.get('nodes'):
_traverse(node['nodes'])
_traverse(doc_info.get('structure', []))
return max_line
def _get_pdf_page_content(doc_info: dict, page_nums: list[int]) -> list[dict]:
    """Extract text for specific PDF pages (1-indexed), opening the PDF once."""
    with open(doc_info['path'], 'rb') as handle:
        reader = PyPDF2.PdfReader(handle)
        page_total = len(reader.pages)
        extracted = []
        for num in page_nums:
            # Silently skip page numbers outside the document's range.
            if not 1 <= num <= page_total:
                continue
            # extract_text() may return None (e.g. image-only pages).
            text = reader.pages[num - 1].extract_text() or ''
            extracted.append({'page': num, 'content': text})
        return extracted
def _get_md_page_content(doc_info: dict, page_nums: list[int]) -> list[dict]:
"""
For Markdown documents, 'pages' are line numbers.
Find nodes whose line_num falls within [min(page_nums), max(page_nums)] and return their text.
"""
min_line, max_line = min(page_nums), max(page_nums)
results = []
seen = set()
def _traverse(nodes):
for node in nodes:
ln = node.get('line_num')
if ln and min_line <= ln <= max_line and ln not in seen:
seen.add(ln)
results.append({'page': ln, 'content': node.get('text', '')})
if node.get('nodes'):
_traverse(node['nodes'])
_traverse(doc_info.get('structure', []))
results.sort(key=lambda x: x['page'])
return results
# ── Tool functions ────────────────────────────────────────────────────────────
def get_document(documents: dict, doc_id: str) -> str:
    """Return JSON with document metadata: doc_id, doc_name, doc_description, type, status, page_count (PDF) or line_count (Markdown)."""
    doc_info = documents.get(doc_id)
    if not doc_info:
        return json.dumps({'error': f'Document {doc_id} not found'})
    doc_type = doc_info.get('type', '')
    meta = {
        'doc_id': doc_id,
        'doc_name': doc_info.get('doc_name', ''),
        'doc_description': doc_info.get('doc_description', ''),
        'type': doc_type,
        'status': 'completed',
    }
    # PDFs expose a page count; everything else (Markdown) a line count.
    size_key = 'page_count' if doc_type == 'pdf' else 'line_count'
    meta[size_key] = _count_pages(doc_info)
    return json.dumps(meta)
def get_document_structure(documents: dict, doc_id: str) -> str:
    """Return tree structure JSON with text fields removed (saves tokens)."""
    doc_info = documents.get(doc_id)
    if not doc_info:
        return json.dumps({'error': f'Document {doc_id} not found'})
    # Strip bulky 'text' fields before serializing to keep the payload small.
    stripped = remove_fields(doc_info.get('structure', []), fields=['text'])
    return json.dumps(stripped, ensure_ascii=False)
def get_page_content(documents: dict, doc_id: str, pages: str) -> str:
    """
    Retrieve page content for a document.
    pages format: '5-7', '3,8', or '12'
    For PDF: pages are physical page numbers (1-indexed).
    For Markdown: pages are line numbers corresponding to node headers.
    Returns JSON list of {'page': int, 'content': str}.
    """
    doc_info = documents.get(doc_id)
    if not doc_info:
        return json.dumps({'error': f'Document {doc_id} not found'})
    try:
        page_nums = _parse_pages(pages)
    except (ValueError, AttributeError) as e:
        return json.dumps({'error': f'Invalid pages format: {pages!r}. Use "5-7", "3,8", or "12". Error: {e}'})
    # Dispatch to the extractor matching the document type.
    extractor = _get_pdf_page_content if doc_info.get('type') == 'pdf' else _get_md_page_content
    try:
        content = extractor(doc_info, page_nums)
    except Exception as e:
        # Tool boundary: surface any read failure as a JSON error payload.
        return json.dumps({'error': f'Failed to read page content: {e}'})
    return json.dumps(content, ensure_ascii=False)

View file

@ -1,6 +1,7 @@
import litellm
import logging
import os
import textwrap
from datetime import datetime
import time
import json
@ -29,6 +30,8 @@ def count_tokens(text, model=None):
def llm_completion(model, prompt, chat_history=None, return_finish_reason=False):
if model:
model = model.removeprefix("litellm/")
max_retries = 10
messages = list(chat_history) + [{"role": "user", "content": prompt}] if chat_history else [{"role": "user", "content": prompt}]
for i in range(max_retries):
@ -57,6 +60,8 @@ def llm_completion(model, prompt, chat_history=None, return_finish_reason=False)
async def llm_acompletion(model, prompt):
if model:
model = model.removeprefix("litellm/")
max_retries = 10
messages = [{"role": "user", "content": prompt}]
for i in range(max_retries):
@ -678,3 +683,28 @@ class ConfigLoader:
self._validate_keys(user_dict)
merged = {**self._default_dict, **user_dict}
return config(**merged)
def create_node_mapping(tree):
    """Create a flat dict mapping node_id to node for quick lookup."""
    mapping = {}
    # Explicit pre-order walk (node before its children, siblings in order),
    # matching the recursive traversal's insertion order.
    stack = list(tree)[::-1]
    while stack:
        node = stack.pop()
        if node.get('node_id'):
            mapping[node['node_id']] = node
        children = node.get('nodes')
        if children:
            stack.extend(reversed(children))
    return mapping
def print_tree(tree, indent=0):
    """Pretty-print a node tree, one '[node_id] title' line per node.

    Each depth level adds one space of indentation; the first 60 characters
    of the node summary (or prefix_summary) are appended with an ellipsis.
    """
    pad = ' ' * indent
    for node in tree:
        blurb = node.get('summary') or node.get('prefix_summary', '')
        tail = f"{blurb[:60]}..." if blurb else ""
        print(f"{pad}[{node.get('node_id', '?')}] {node.get('title', '')}{tail}")
        children = node.get('nodes')
        if children:
            print_tree(children, indent + 1)
def print_wrapped(text, width=100):
    """Print `text`, wrapping each original line to at most `width` columns."""
    # Wrap line-by-line so existing line breaks are preserved.
    for chunk in (textwrap.fill(segment, width=width) for segment in text.splitlines()):
        print(chunk)