mirror of
https://github.com/VectifyAI/PageIndex.git
synced 2026-04-24 23:56:21 +02:00
Add PageIndexClient with agent-based retrieval via OpenAI Agents SDK (#125)
* Add PageIndexClient with retrieve, streaming support and litellm integration * Add OpenAI agents demo example * Update README with example agent demo section * Support separate retrieve_model configuration for index and retrieve
This commit is contained in:
parent
2403be8f27
commit
5d4491f3bf
9 changed files with 501 additions and 7 deletions
24
README.md
24
README.md
|
|
@ -147,15 +147,17 @@ You can follow these steps to generate a PageIndex tree from a PDF document.
|
||||||
pip3 install --upgrade -r requirements.txt
|
pip3 install --upgrade -r requirements.txt
|
||||||
```
|
```
|
||||||
|
|
||||||
### 2. Set your OpenAI API key
|
### 2. Set your LLM API key
|
||||||
|
|
||||||
Create a `.env` file in the root directory and add your API key:
|
Create a `.env` file in the root directory with your LLM API key::
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
CHATGPT_API_KEY=your_openai_key_here
|
OPENAI_API_KEY=your_openai_key_here
|
||||||
|
# or
|
||||||
|
CHATGPT_API_KEY=your_openai_key_here # legacy, still supported
|
||||||
```
|
```
|
||||||
|
|
||||||
### 3. Run PageIndex on your PDF
|
### 3. Generate PageIndex structure for your PDF
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
python3 run_pageindex.py --pdf_path /path/to/your/document.pdf
|
python3 run_pageindex.py --pdf_path /path/to/your/document.pdf
|
||||||
|
|
@ -189,6 +191,20 @@ python3 run_pageindex.py --md_path /path/to/your/document.md
|
||||||
> Note: in this function, we use "#" to determine node heading and their levels. For example, "##" is level 2, "###" is level 3, etc. Make sure your markdown file is formatted correctly. If your Markdown file was converted from a PDF or HTML, we don't recommend using this function, since most existing conversion tools cannot preserve the original hierarchy. Instead, use our [PageIndex OCR](https://pageindex.ai/blog/ocr), which is designed to preserve the original hierarchy, to convert the PDF to a markdown file and then use this function.
|
> Note: in this function, we use "#" to determine node heading and their levels. For example, "##" is level 2, "###" is level 3, etc. Make sure your markdown file is formatted correctly. If your Markdown file was converted from a PDF or HTML, we don't recommend using this function, since most existing conversion tools cannot preserve the original hierarchy. Instead, use our [PageIndex OCR](https://pageindex.ai/blog/ocr), which is designed to preserve the original hierarchy, to convert the PDF to a markdown file and then use this function.
|
||||||
</details>
|
</details>
|
||||||
|
|
||||||
|
### A Complete Agentic RAG Example
|
||||||
|
|
||||||
|
For a complete agent-based QA example using the [OpenAI Agents SDK](https://github.com/openai/openai-agents-python), see [`examples/openai_agents_demo.py`](examples/openai_agents_demo.py).
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Install optional dependency
|
||||||
|
pip3 install openai-agents
|
||||||
|
|
||||||
|
# Run the demo
|
||||||
|
python3 examples/openai_agents_demo.py
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
<!--
|
<!--
|
||||||
# ☁️ Improved Tree Generation with PageIndex OCR
|
# ☁️ Improved Tree Generation with PageIndex OCR
|
||||||
|
|
||||||
|
|
|
||||||
173
examples/openai_agents_demo.py
Normal file
173
examples/openai_agents_demo.py
Normal file
|
|
@ -0,0 +1,173 @@
|
||||||
|
"""
|
||||||
|
PageIndex x OpenAI Agents Demo
|
||||||
|
|
||||||
|
Demonstrates how to use PageIndexClient with the OpenAI Agents SDK
|
||||||
|
to build a document QA agent with 3 tools:
|
||||||
|
- get_document()
|
||||||
|
- get_document_structure()
|
||||||
|
- get_page_content()
|
||||||
|
|
||||||
|
Requirements:
|
||||||
|
pip install openai-agents
|
||||||
|
|
||||||
|
Steps:
|
||||||
|
1 — Index PDF and inspect tree structure
|
||||||
|
2 — Inspect document metadata
|
||||||
|
3 — Ask a question (agent auto-calls tools)
|
||||||
|
4 — Reload from workspace and verify persistence
|
||||||
|
"""
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import asyncio
|
||||||
|
import concurrent.futures
|
||||||
|
import requests
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||||
|
|
||||||
|
from agents import Agent, ItemHelpers, Runner, function_tool
|
||||||
|
from agents.stream_events import RawResponsesStreamEvent, RunItemStreamEvent
|
||||||
|
from openai.types.responses import ResponseTextDeltaEvent, ResponseReasoningSummaryTextDeltaEvent # noqa: F401
|
||||||
|
|
||||||
|
from pageindex import PageIndexClient
|
||||||
|
import pageindex.utils as utils
|
||||||
|
|
||||||
|
PDF_URL = "https://arxiv.org/pdf/2501.12948.pdf"
|
||||||
|
PDF_PATH = "tests/pdfs/deepseek-r1.pdf"
|
||||||
|
WORKSPACE = "./pageindex_workspace"
|
||||||
|
|
||||||
|
AGENT_SYSTEM_PROMPT = """
|
||||||
|
You are PageIndex, a document QA assistant.
|
||||||
|
TOOL USE:
|
||||||
|
- Call get_document() first to confirm status and page/line count.
|
||||||
|
- Call get_document_structure() to find relevant page ranges (use node summaries and start_index/end_index).
|
||||||
|
- Call get_page_content(pages="5-7") with tight ranges. Never fetch the whole doc.
|
||||||
|
- When calling tool call, output one short sentence explaining reason.
|
||||||
|
ANSWERING: Answer based only on tool output. Be concise.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
def query_agent(
|
||||||
|
client: PageIndexClient,
|
||||||
|
doc_id: str,
|
||||||
|
prompt: str,
|
||||||
|
verbose: bool = False,
|
||||||
|
) -> str:
|
||||||
|
"""Run a document QA agent using the OpenAI Agents SDK.
|
||||||
|
|
||||||
|
Streams text output token-by-token and returns the full answer string.
|
||||||
|
Tool calls are always printed; verbose=True also prints arguments and output previews.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@function_tool
|
||||||
|
def get_document() -> str:
|
||||||
|
"""Get document metadata: status, page count, name, and description."""
|
||||||
|
return client.get_document(doc_id)
|
||||||
|
|
||||||
|
@function_tool
|
||||||
|
def get_document_structure() -> str:
|
||||||
|
"""Get the document's full tree structure (without text) to find relevant sections."""
|
||||||
|
return client.get_document_structure(doc_id)
|
||||||
|
|
||||||
|
@function_tool
|
||||||
|
def get_page_content(pages: str) -> str:
|
||||||
|
"""
|
||||||
|
Get the text content of specific pages or line numbers.
|
||||||
|
Use tight ranges: e.g. '5-7' for pages 5 to 7, '3,8' for pages 3 and 8, '12' for page 12.
|
||||||
|
For Markdown documents, use line numbers from the structure's line_num field.
|
||||||
|
"""
|
||||||
|
return client.get_page_content(doc_id, pages)
|
||||||
|
|
||||||
|
agent = Agent(
|
||||||
|
name="PageIndex",
|
||||||
|
instructions=AGENT_SYSTEM_PROMPT,
|
||||||
|
tools=[get_document, get_document_structure, get_page_content],
|
||||||
|
model=client.retrieve_model,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _run():
|
||||||
|
collected = []
|
||||||
|
streamed_this_turn = False
|
||||||
|
streamed_run = Runner.run_streamed(agent, prompt)
|
||||||
|
async for event in streamed_run.stream_events():
|
||||||
|
if isinstance(event, RawResponsesStreamEvent):
|
||||||
|
if isinstance(event.data, ResponseReasoningSummaryTextDeltaEvent):
|
||||||
|
print(event.data.delta, end="", flush=True)
|
||||||
|
elif isinstance(event.data, ResponseTextDeltaEvent):
|
||||||
|
delta = event.data.delta
|
||||||
|
print(delta, end="", flush=True)
|
||||||
|
collected.append(delta)
|
||||||
|
streamed_this_turn = True
|
||||||
|
elif isinstance(event, RunItemStreamEvent):
|
||||||
|
item = event.item
|
||||||
|
if item.type == "message_output_item":
|
||||||
|
if not streamed_this_turn:
|
||||||
|
text = ItemHelpers.text_message_output(item)
|
||||||
|
if text:
|
||||||
|
print(f"{text}")
|
||||||
|
streamed_this_turn = False
|
||||||
|
collected.clear()
|
||||||
|
elif item.type == "tool_call_item":
|
||||||
|
if streamed_this_turn:
|
||||||
|
print() # end streaming line before tool call
|
||||||
|
raw = item.raw_item
|
||||||
|
args = getattr(raw, "arguments", "{}")
|
||||||
|
args_str = f"({args})" if verbose else ""
|
||||||
|
print(f"[tool call]: {raw.name}{args_str}")
|
||||||
|
elif item.type == "tool_call_output_item" and verbose:
|
||||||
|
output = str(item.output)
|
||||||
|
preview = output[:200] + "..." if len(output) > 200 else output
|
||||||
|
print(f"[tool output]: {preview}\n")
|
||||||
|
return "".join(collected)
|
||||||
|
|
||||||
|
try:
|
||||||
|
asyncio.get_running_loop()
|
||||||
|
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
|
||||||
|
return pool.submit(asyncio.run, _run()).result()
|
||||||
|
except RuntimeError:
|
||||||
|
return asyncio.run(_run())
|
||||||
|
|
||||||
|
|
||||||
|
# ── Download PDF if needed ─────────────────────────────────────────────────────
|
||||||
|
if not os.path.exists(PDF_PATH):
|
||||||
|
print(f"Downloading {PDF_URL} ...")
|
||||||
|
os.makedirs(os.path.dirname(PDF_PATH), exist_ok=True)
|
||||||
|
with requests.get(PDF_URL, stream=True, timeout=30) as r:
|
||||||
|
r.raise_for_status()
|
||||||
|
with open(PDF_PATH, "wb") as f:
|
||||||
|
for chunk in r.iter_content(chunk_size=8192):
|
||||||
|
if chunk:
|
||||||
|
f.write(chunk)
|
||||||
|
print("Download complete.\n")
|
||||||
|
|
||||||
|
# ── Setup ──────────────────────────────────────────────────────────────────────
|
||||||
|
client = PageIndexClient(workspace=WORKSPACE)
|
||||||
|
|
||||||
|
# ── Step 1: Index + Tree ───────────────────────────────────────────────────────
|
||||||
|
print("=" * 60)
|
||||||
|
print("Step 1: Indexing PDF and inspecting tree structure")
|
||||||
|
print("=" * 60)
|
||||||
|
_id_cache = Path(WORKSPACE).expanduser() / "demo_doc_id.txt"
|
||||||
|
if _id_cache.exists() and (doc_id := _id_cache.read_text().strip()) in client.documents:
|
||||||
|
print(f"\nLoaded cached doc_id: {doc_id}")
|
||||||
|
else:
|
||||||
|
doc_id = client.index(PDF_PATH)
|
||||||
|
_id_cache.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
_id_cache.write_text(doc_id)
|
||||||
|
print(f"\nIndexed. doc_id: {doc_id}")
|
||||||
|
print("\nTree Structure (top-level sections):")
|
||||||
|
utils.print_tree(client.documents[doc_id]["structure"])
|
||||||
|
|
||||||
|
# ── Step 2: Document Metadata ──────────────────────────────────────────────────
|
||||||
|
print("\n" + "=" * 60)
|
||||||
|
print("Step 2: Document Metadata (get_document)")
|
||||||
|
print("=" * 60)
|
||||||
|
print(client.get_document(doc_id))
|
||||||
|
|
||||||
|
# ── Step 3: Agent Query ────────────────────────────────────────────────────────
|
||||||
|
print("\n" + "=" * 60)
|
||||||
|
print("Step 3: Agent Query (auto tool-use)")
|
||||||
|
print("=" * 60)
|
||||||
|
question = "What reward design does DeepSeek-R1-Zero use, and why was it chosen over supervised fine-tuning?"
|
||||||
|
print(f"\nQuestion: '{question}'\n")
|
||||||
|
query_agent(client, doc_id, question, verbose=True)
|
||||||
|
|
@ -1,2 +1,4 @@
|
||||||
from .page_index import *
|
from .page_index import *
|
||||||
from .page_index_md import md_to_tree
|
from .page_index_md import md_to_tree
|
||||||
|
from .retrieve import get_document, get_document_structure, get_page_content
|
||||||
|
from .client import PageIndexClient
|
||||||
|
|
|
||||||
132
pageindex/client.py
Normal file
132
pageindex/client.py
Normal file
|
|
@ -0,0 +1,132 @@
|
||||||
|
import os
|
||||||
|
import uuid
|
||||||
|
import json
|
||||||
|
import asyncio
|
||||||
|
import concurrent.futures
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from .page_index import page_index
|
||||||
|
from .page_index_md import md_to_tree
|
||||||
|
from .retrieve import get_document, get_document_structure, get_page_content
|
||||||
|
from .utils import ConfigLoader
|
||||||
|
|
||||||
|
class PageIndexClient:
|
||||||
|
"""
|
||||||
|
A client for indexing and retrieving document content.
|
||||||
|
Flow: index() -> get_document() / get_document_structure() / get_page_content()
|
||||||
|
|
||||||
|
For agent-based QA, see examples/openai_agents_demo.py.
|
||||||
|
"""
|
||||||
|
def __init__(self, api_key: str = None, model: str = None, retrieve_model: str = None, workspace: str = None):
|
||||||
|
if api_key:
|
||||||
|
os.environ["OPENAI_API_KEY"] = api_key
|
||||||
|
elif not os.getenv("OPENAI_API_KEY") and os.getenv("CHATGPT_API_KEY"):
|
||||||
|
os.environ["OPENAI_API_KEY"] = os.getenv("CHATGPT_API_KEY")
|
||||||
|
self.workspace = Path(workspace).expanduser() if workspace else None
|
||||||
|
overrides = {}
|
||||||
|
if model:
|
||||||
|
overrides["model"] = model
|
||||||
|
if retrieve_model:
|
||||||
|
overrides["retrieve_model"] = retrieve_model
|
||||||
|
opt = ConfigLoader().load(overrides or None)
|
||||||
|
self.model = opt.model
|
||||||
|
self.retrieve_model = opt.retrieve_model or self.model
|
||||||
|
if self.workspace:
|
||||||
|
self.workspace.mkdir(parents=True, exist_ok=True)
|
||||||
|
self.documents = {}
|
||||||
|
if self.workspace:
|
||||||
|
self._load_workspace()
|
||||||
|
|
||||||
|
def index(self, file_path: str, mode: str = "auto") -> str:
|
||||||
|
"""Index a document. Returns a document_id."""
|
||||||
|
if not os.path.exists(file_path):
|
||||||
|
raise FileNotFoundError(f"File not found: {file_path}")
|
||||||
|
|
||||||
|
doc_id = str(uuid.uuid4())
|
||||||
|
ext = os.path.splitext(file_path)[1].lower()
|
||||||
|
|
||||||
|
is_pdf = ext == '.pdf'
|
||||||
|
is_md = ext in ['.md', '.markdown']
|
||||||
|
|
||||||
|
if mode == "pdf" or (mode == "auto" and is_pdf):
|
||||||
|
print(f"Indexing PDF: {file_path}")
|
||||||
|
result = page_index(
|
||||||
|
doc=file_path,
|
||||||
|
model=self.model,
|
||||||
|
if_add_node_summary='yes',
|
||||||
|
if_add_node_text='yes',
|
||||||
|
if_add_node_id='yes',
|
||||||
|
if_add_doc_description='yes'
|
||||||
|
)
|
||||||
|
self.documents[doc_id] = {
|
||||||
|
'id': doc_id,
|
||||||
|
'path': file_path,
|
||||||
|
'type': 'pdf',
|
||||||
|
'structure': result['structure'],
|
||||||
|
'doc_name': result.get('doc_name', ''),
|
||||||
|
'doc_description': result.get('doc_description', '')
|
||||||
|
}
|
||||||
|
|
||||||
|
elif mode == "md" or (mode == "auto" and is_md):
|
||||||
|
print(f"Indexing Markdown: {file_path}")
|
||||||
|
coro = md_to_tree(
|
||||||
|
md_path=file_path,
|
||||||
|
if_thinning=False,
|
||||||
|
if_add_node_summary='yes',
|
||||||
|
summary_token_threshold=200,
|
||||||
|
model=self.model,
|
||||||
|
if_add_doc_description='yes',
|
||||||
|
if_add_node_text='yes',
|
||||||
|
if_add_node_id='yes'
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
asyncio.get_running_loop()
|
||||||
|
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
|
||||||
|
result = pool.submit(asyncio.run, coro).result()
|
||||||
|
except RuntimeError:
|
||||||
|
result = asyncio.run(coro)
|
||||||
|
self.documents[doc_id] = {
|
||||||
|
'id': doc_id,
|
||||||
|
'path': file_path,
|
||||||
|
'type': 'md',
|
||||||
|
'structure': result['structure'],
|
||||||
|
'doc_name': result.get('doc_name', ''),
|
||||||
|
'doc_description': result.get('doc_description', '')
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported file format for: {file_path}")
|
||||||
|
|
||||||
|
print(f"Indexing complete. Document ID: {doc_id}")
|
||||||
|
if self.workspace:
|
||||||
|
self._save_doc(doc_id)
|
||||||
|
return doc_id
|
||||||
|
|
||||||
|
def _save_doc(self, doc_id: str):
|
||||||
|
path = self.workspace / f"{doc_id}.json"
|
||||||
|
with open(path, "w", encoding="utf-8") as f:
|
||||||
|
json.dump(self.documents[doc_id], f, ensure_ascii=False, indent=2)
|
||||||
|
|
||||||
|
def _load_workspace(self):
|
||||||
|
loaded = 0
|
||||||
|
for path in self.workspace.glob("*.json"):
|
||||||
|
try:
|
||||||
|
with open(path, "r", encoding="utf-8") as f:
|
||||||
|
doc = json.load(f)
|
||||||
|
self.documents[path.stem] = doc
|
||||||
|
loaded += 1
|
||||||
|
except (json.JSONDecodeError, OSError) as e:
|
||||||
|
print(f"Warning: skipping corrupt workspace file {path.name}: {e}")
|
||||||
|
if loaded:
|
||||||
|
print(f"Loaded {loaded} document(s) from workspace.")
|
||||||
|
|
||||||
|
def get_document(self, doc_id: str) -> str:
|
||||||
|
"""Return document metadata JSON."""
|
||||||
|
return get_document(self.documents, doc_id)
|
||||||
|
|
||||||
|
def get_document_structure(self, doc_id: str) -> str:
|
||||||
|
"""Return document tree structure JSON (without text fields)."""
|
||||||
|
return get_document_structure(self.documents, doc_id)
|
||||||
|
|
||||||
|
def get_page_content(self, doc_id: str, pages: str) -> str:
|
||||||
|
"""Return page content for the given pages string (e.g. '5-7', '3,8', '12')."""
|
||||||
|
return get_page_content(self.documents, doc_id, pages)
|
||||||
|
|
@ -1,5 +1,6 @@
|
||||||
model: "gpt-4o-2024-11-20"
|
model: "gpt-4o-2024-11-20"
|
||||||
# model: "anthropic/claude-sonnet-4-6"
|
# model: "anthropic/claude-sonnet-4-6"
|
||||||
|
retrieve_model: "gpt-5.4" # defaults to model if not set
|
||||||
toc_check_page_num: 20
|
toc_check_page_num: 20
|
||||||
max_page_num_each_node: 10
|
max_page_num_each_node: 10
|
||||||
max_token_num_each_node: 20000
|
max_token_num_each_node: 20000
|
||||||
|
|
|
||||||
|
|
@ -330,7 +330,7 @@ def toc_transformer(toc_content, model=None):
|
||||||
if_complete = check_if_toc_transformation_is_complete(toc_content, last_complete, model)
|
if_complete = check_if_toc_transformation_is_complete(toc_content, last_complete, model)
|
||||||
|
|
||||||
|
|
||||||
last_complete = json.loads(last_complete)
|
last_complete = extract_json(last_complete)
|
||||||
|
|
||||||
cleaned_response=convert_page_to_int(last_complete['table_of_contents'])
|
cleaned_response=convert_page_to_int(last_complete['table_of_contents'])
|
||||||
return cleaned_response
|
return cleaned_response
|
||||||
|
|
|
||||||
139
pageindex/retrieve.py
Normal file
139
pageindex/retrieve.py
Normal file
|
|
@ -0,0 +1,139 @@
|
||||||
|
import json
|
||||||
|
import PyPDF2
|
||||||
|
|
||||||
|
try:
|
||||||
|
from .utils import get_number_of_pages, remove_fields
|
||||||
|
except ImportError:
|
||||||
|
from utils import get_number_of_pages, remove_fields
|
||||||
|
|
||||||
|
|
||||||
|
# ── Helpers ──────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def _parse_pages(pages: str) -> list[int]:
|
||||||
|
"""Parse a pages string like '5-7', '3,8', or '12' into a sorted list of ints."""
|
||||||
|
result = []
|
||||||
|
for part in pages.split(','):
|
||||||
|
part = part.strip()
|
||||||
|
if '-' in part:
|
||||||
|
start, end = int(part.split('-', 1)[0].strip()), int(part.split('-', 1)[1].strip())
|
||||||
|
if start > end:
|
||||||
|
raise ValueError(f"Invalid range '{part}': start must be <= end")
|
||||||
|
result.extend(range(start, end + 1))
|
||||||
|
else:
|
||||||
|
result.append(int(part))
|
||||||
|
return sorted(set(result))
|
||||||
|
|
||||||
|
|
||||||
|
def _count_pages(doc_info: dict) -> int:
|
||||||
|
"""Return total page count for a document."""
|
||||||
|
if doc_info.get('type') == 'pdf':
|
||||||
|
return get_number_of_pages(doc_info['path'])
|
||||||
|
# For MD, find max line_num across all nodes
|
||||||
|
max_line = 0
|
||||||
|
def _traverse(nodes):
|
||||||
|
nonlocal max_line
|
||||||
|
for node in nodes:
|
||||||
|
ln = node.get('line_num', 0)
|
||||||
|
if ln and ln > max_line:
|
||||||
|
max_line = ln
|
||||||
|
if node.get('nodes'):
|
||||||
|
_traverse(node['nodes'])
|
||||||
|
_traverse(doc_info.get('structure', []))
|
||||||
|
return max_line
|
||||||
|
|
||||||
|
|
||||||
|
def _get_pdf_page_content(doc_info: dict, page_nums: list[int]) -> list[dict]:
|
||||||
|
"""Extract text for specific PDF pages (1-indexed), opening the PDF once."""
|
||||||
|
path = doc_info['path']
|
||||||
|
with open(path, 'rb') as f:
|
||||||
|
pdf_reader = PyPDF2.PdfReader(f)
|
||||||
|
total = len(pdf_reader.pages)
|
||||||
|
valid_pages = [p for p in page_nums if 1 <= p <= total]
|
||||||
|
return [
|
||||||
|
{'page': p, 'content': pdf_reader.pages[p - 1].extract_text() or ''}
|
||||||
|
for p in valid_pages
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def _get_md_page_content(doc_info: dict, page_nums: list[int]) -> list[dict]:
|
||||||
|
"""
|
||||||
|
For Markdown documents, 'pages' are line numbers.
|
||||||
|
Find nodes whose line_num falls within [min(page_nums), max(page_nums)] and return their text.
|
||||||
|
"""
|
||||||
|
min_line, max_line = min(page_nums), max(page_nums)
|
||||||
|
results = []
|
||||||
|
seen = set()
|
||||||
|
|
||||||
|
def _traverse(nodes):
|
||||||
|
for node in nodes:
|
||||||
|
ln = node.get('line_num')
|
||||||
|
if ln and min_line <= ln <= max_line and ln not in seen:
|
||||||
|
seen.add(ln)
|
||||||
|
results.append({'page': ln, 'content': node.get('text', '')})
|
||||||
|
if node.get('nodes'):
|
||||||
|
_traverse(node['nodes'])
|
||||||
|
|
||||||
|
_traverse(doc_info.get('structure', []))
|
||||||
|
results.sort(key=lambda x: x['page'])
|
||||||
|
return results
|
||||||
|
|
||||||
|
|
||||||
|
# ── Tool functions ────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def get_document(documents: dict, doc_id: str) -> str:
|
||||||
|
"""Return JSON with document metadata: doc_id, doc_name, doc_description, type, status, page_count (PDF) or line_count (Markdown)."""
|
||||||
|
doc_info = documents.get(doc_id)
|
||||||
|
if not doc_info:
|
||||||
|
return json.dumps({'error': f'Document {doc_id} not found'})
|
||||||
|
result = {
|
||||||
|
'doc_id': doc_id,
|
||||||
|
'doc_name': doc_info.get('doc_name', ''),
|
||||||
|
'doc_description': doc_info.get('doc_description', ''),
|
||||||
|
'type': doc_info.get('type', ''),
|
||||||
|
'status': 'completed',
|
||||||
|
}
|
||||||
|
if doc_info.get('type') == 'pdf':
|
||||||
|
result['page_count'] = _count_pages(doc_info)
|
||||||
|
else:
|
||||||
|
result['line_count'] = _count_pages(doc_info)
|
||||||
|
return json.dumps(result)
|
||||||
|
|
||||||
|
|
||||||
|
def get_document_structure(documents: dict, doc_id: str) -> str:
|
||||||
|
"""Return tree structure JSON with text fields removed (saves tokens)."""
|
||||||
|
doc_info = documents.get(doc_id)
|
||||||
|
if not doc_info:
|
||||||
|
return json.dumps({'error': f'Document {doc_id} not found'})
|
||||||
|
structure = doc_info.get('structure', [])
|
||||||
|
structure_no_text = remove_fields(structure, fields=['text'])
|
||||||
|
return json.dumps(structure_no_text, ensure_ascii=False)
|
||||||
|
|
||||||
|
|
||||||
|
def get_page_content(documents: dict, doc_id: str, pages: str) -> str:
|
||||||
|
"""
|
||||||
|
Retrieve page content for a document.
|
||||||
|
|
||||||
|
pages format: '5-7', '3,8', or '12'
|
||||||
|
For PDF: pages are physical page numbers (1-indexed).
|
||||||
|
For Markdown: pages are line numbers corresponding to node headers.
|
||||||
|
|
||||||
|
Returns JSON list of {'page': int, 'content': str}.
|
||||||
|
"""
|
||||||
|
doc_info = documents.get(doc_id)
|
||||||
|
if not doc_info:
|
||||||
|
return json.dumps({'error': f'Document {doc_id} not found'})
|
||||||
|
|
||||||
|
try:
|
||||||
|
page_nums = _parse_pages(pages)
|
||||||
|
except (ValueError, AttributeError) as e:
|
||||||
|
return json.dumps({'error': f'Invalid pages format: {pages!r}. Use "5-7", "3,8", or "12". Error: {e}'})
|
||||||
|
|
||||||
|
try:
|
||||||
|
if doc_info.get('type') == 'pdf':
|
||||||
|
content = _get_pdf_page_content(doc_info, page_nums)
|
||||||
|
else:
|
||||||
|
content = _get_md_page_content(doc_info, page_nums)
|
||||||
|
except Exception as e:
|
||||||
|
return json.dumps({'error': f'Failed to read page content: {e}'})
|
||||||
|
|
||||||
|
return json.dumps(content, ensure_ascii=False)
|
||||||
|
|
@ -1,6 +1,7 @@
|
||||||
import litellm
|
import litellm
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
|
import textwrap
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
import time
|
import time
|
||||||
import json
|
import json
|
||||||
|
|
@ -29,6 +30,8 @@ def count_tokens(text, model=None):
|
||||||
|
|
||||||
|
|
||||||
def llm_completion(model, prompt, chat_history=None, return_finish_reason=False):
|
def llm_completion(model, prompt, chat_history=None, return_finish_reason=False):
|
||||||
|
if model:
|
||||||
|
model = model.removeprefix("litellm/")
|
||||||
max_retries = 10
|
max_retries = 10
|
||||||
messages = list(chat_history) + [{"role": "user", "content": prompt}] if chat_history else [{"role": "user", "content": prompt}]
|
messages = list(chat_history) + [{"role": "user", "content": prompt}] if chat_history else [{"role": "user", "content": prompt}]
|
||||||
for i in range(max_retries):
|
for i in range(max_retries):
|
||||||
|
|
@ -57,6 +60,8 @@ def llm_completion(model, prompt, chat_history=None, return_finish_reason=False)
|
||||||
|
|
||||||
|
|
||||||
async def llm_acompletion(model, prompt):
|
async def llm_acompletion(model, prompt):
|
||||||
|
if model:
|
||||||
|
model = model.removeprefix("litellm/")
|
||||||
max_retries = 10
|
max_retries = 10
|
||||||
messages = [{"role": "user", "content": prompt}]
|
messages = [{"role": "user", "content": prompt}]
|
||||||
for i in range(max_retries):
|
for i in range(max_retries):
|
||||||
|
|
@ -678,3 +683,28 @@ class ConfigLoader:
|
||||||
self._validate_keys(user_dict)
|
self._validate_keys(user_dict)
|
||||||
merged = {**self._default_dict, **user_dict}
|
merged = {**self._default_dict, **user_dict}
|
||||||
return config(**merged)
|
return config(**merged)
|
||||||
|
|
||||||
|
def create_node_mapping(tree):
|
||||||
|
"""Create a flat dict mapping node_id to node for quick lookup."""
|
||||||
|
mapping = {}
|
||||||
|
def _traverse(nodes):
|
||||||
|
for node in nodes:
|
||||||
|
if node.get('node_id'):
|
||||||
|
mapping[node['node_id']] = node
|
||||||
|
if node.get('nodes'):
|
||||||
|
_traverse(node['nodes'])
|
||||||
|
_traverse(tree)
|
||||||
|
return mapping
|
||||||
|
|
||||||
|
def print_tree(tree, indent=0):
|
||||||
|
for node in tree:
|
||||||
|
summary = node.get('summary') or node.get('prefix_summary', '')
|
||||||
|
summary_str = f" — {summary[:60]}..." if summary else ""
|
||||||
|
print(' ' * indent + f"[{node.get('node_id', '?')}] {node.get('title', '')}{summary_str}")
|
||||||
|
if node.get('nodes'):
|
||||||
|
print_tree(node['nodes'], indent + 1)
|
||||||
|
|
||||||
|
def print_wrapped(text, width=100):
|
||||||
|
for line in text.splitlines():
|
||||||
|
print(textwrap.fill(line, width=width))
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,5 @@
|
||||||
litellm==1.82.0
|
litellm==1.82.0
|
||||||
|
# openai-agents # optional: required for examples/openai_agents_demo.py
|
||||||
pymupdf==1.26.4
|
pymupdf==1.26.4
|
||||||
PyPDF2==3.0.1
|
PyPDF2==3.0.1
|
||||||
python-dotenv==1.1.0
|
python-dotenv==1.1.0
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue