Mirror of https://github.com/VectifyAI/PageIndex.git
Synced 2026-05-13 00:32:36 +02:00
feat: compatible with PageIndex SDK (#238)
* feat: compatible with PageIndex SDK
* fix: corner cases
* fix: mock the behavior of the old SDK
* fix: close streaming response and warn on empty api_key
- LegacyCloudAPI: close the response in `finally` in both _stream_chat_response
  variants, so abandoned iterators no longer leak the TCP connection.
- PageIndexClient: emit a warning instead of silently falling back to local mode
  when api_key is the empty string, surfacing the common unset-env-var misconfiguration.
- FakeResponse: add close()/closed to match the real requests.Response API.
- Add unit coverage for stream close (both paths) and for the empty-api_key warning.
- Add scripts/e2e_legacy_sdk.py to smoke-test the legacy SDK contract end to end
  against api.pageindex.ai.
* chore: mark legacy SDK methods with @deprecated and docstring pointers
- Decorate the 12 PageIndexClient cloud-SDK compat methods with
  @typing_extensions.deprecated(..., category=PendingDeprecationWarning), so that:
  - IDEs and type checkers render them with a strikethrough hint
  - runtime warnings stay silent by default (no spam for existing callers) and
    can be surfaced via `python -W default::PendingDeprecationWarning`
  (a sketch of the mechanism follows this list)
- Add a one-line docstring to each method pointing to its Collection-based equivalent.
- Promote typing-extensions to a direct dependency (it was previously transitive via litellm).
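A minimal sketch (not part of the commit) of the deprecation mechanism described above, assuming only typing-extensions >= 4.9; the function body is a stand-in:

    import warnings
    from typing_extensions import deprecated

    _legacy = deprecated(
        "Legacy compatibility - prefer the Collection-based API.",
        category=PendingDeprecationWarning,
    )

    @_legacy
    def submit_document(file_path: str) -> dict:
        return {"doc_id": "stub"}

    # Silent under default warning filters; surface with
    #   python -W default::PendingDeprecationWarning
    # or programmatically:
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always", PendingDeprecationWarning)
        submit_document("paper.pdf")
    assert any(issubclass(w.category, PendingDeprecationWarning) for w in caught)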
---------
Co-authored-by: XinyanZhou <xinyanzhou@XinyanZhoudeMacBook-Pro.local>
Co-authored-by: saccharin98 <xinyanzhou938@gmail.com>
Co-authored-by: mountain <kose2livs@gmail.com>
This commit is contained in:
parent 6d29886892
commit 595895cf28

10 changed files with 1030 additions and 20 deletions
pageindex/__init__.py
@@ -13,6 +13,7 @@ from .storage.protocol import StorageEngine
 from .events import QueryEvent
 from .errors import (
     PageIndexError,
+    PageIndexAPIError,
     CollectionNotFoundError,
     DocumentNotFoundError,
     IndexingError,
@@ -32,6 +33,7 @@ __all__ = [
     "StorageEngine",
     "QueryEvent",
     "PageIndexError",
+    "PageIndexAPIError",
     "CollectionNotFoundError",
     "DocumentNotFoundError",
     "IndexingError",
pageindex/client.py
@@ -1,10 +1,21 @@
 # pageindex/client.py
 from __future__ import annotations
 from pathlib import Path
 from typing import Any, Iterator

+from typing_extensions import deprecated
+
 from .collection import Collection
 from .config import IndexConfig
+from .errors import PageIndexAPIError
 from .parser.protocol import DocumentParser

+_LEGACY_SDK_MSG = (
+    "Legacy compatibility — new code should prefer the Collection-based API "
+    "(PageIndexClient.collection(...))."
+)
+_legacy_sdk = deprecated(_LEGACY_SDK_MSG, category=PendingDeprecationWarning)


 def _normalize_retrieve_model(model: str) -> str:
     """Preserve supported Agents SDK prefixes and route other provider paths via LiteLLM."""
@@ -39,21 +50,34 @@ class PageIndexClient:
     # Or use LocalClient / CloudClient for explicit mode selection
     """

-    def __init__(self, api_key: str = None, model: str = None,
+    BASE_URL = "https://api.pageindex.ai"
+
+    def __init__(self, api_key: str | None = None, model: str = None,
                  retrieve_model: str = None, storage_path: str = None,
                  storage=None, index_config: IndexConfig | dict = None):
-        if api_key:
+        if api_key == "":
+            import logging
+            logging.getLogger(__name__).warning(
+                "PageIndexClient received an empty api_key; falling back to local mode. "
+                "Pass api_key=None to silence this warning, or provide a real key for cloud mode."
+            )
+            api_key = None
+        if api_key is not None:
             self._init_cloud(api_key)
         else:
             self._init_local(model, retrieve_model, storage_path, storage, index_config)

     def _init_cloud(self, api_key: str):
         from .backend.cloud import CloudBackend
+        from .cloud_api import LegacyCloudAPI
         self._backend = CloudBackend(api_key=api_key)
+        self._legacy_cloud_api = LegacyCloudAPI(api_key=api_key, base_url=self.BASE_URL)

     def _init_local(self, model: str = None, retrieve_model: str = None,
                     storage_path: str = None, storage=None,
                     index_config: IndexConfig | dict = None):
+        self._legacy_cloud_api = None
+
         # Build IndexConfig: merge model/retrieve_model with index_config
         overrides = {}
         if model:
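A quick usage sketch of the mode dispatch above (key and path are placeholders; local mode may additionally need an LLM key in the environment, as the tests suggest):

    from pageindex import PageIndexClient

    cloud = PageIndexClient(api_key="pi-...")        # cloud mode; legacy SDK methods available
    local = PageIndexClient(storage_path="./index")  # local mode; legacy methods raise PageIndexAPIError
    noisy = PageIndexClient(api_key="")              # warns once, then behaves like local mode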
@@ -123,6 +147,124 @@ class PageIndexClient:
             raise PageIndexError("Custom parsers are not supported in cloud mode")
         self._backend.register_parser(parser)

     def _require_cloud_api(self):
         if self._legacy_cloud_api is None:
             from .errors import PageIndexAPIError
             raise PageIndexAPIError(
                 "This method is part of the pageindex 0.2.x cloud SDK API. "
                 "Initialize with api_key to use it."
             )
         return self._legacy_cloud_api

     # ── pageindex 0.2.x cloud SDK compatibility (prefer Collection API for new code) ──
     @_legacy_sdk
     def submit_document(
         self,
         file_path: str,
         mode: str | None = None,
         beta_headers: list[str] | None = None,
         folder_id: str | None = None,
     ) -> dict[str, Any]:
         """Legacy SDK compatibility — prefer ``client.collection(...).add(path)``."""
         return self._require_cloud_api().submit_document(
             file_path=file_path,
             mode=mode,
             beta_headers=beta_headers,
             folder_id=folder_id,
         )

     @_legacy_sdk
     def get_ocr(self, doc_id: str, format: str = "page") -> dict[str, Any]:
         """Legacy SDK compatibility — prefer ``collection.get_page_content(doc_id, pages)``."""
         return self._require_cloud_api().get_ocr(doc_id=doc_id, format=format)

     @_legacy_sdk
     def get_tree(self, doc_id: str, node_summary: bool = False) -> dict[str, Any]:
         """Legacy SDK compatibility — prefer ``collection.get_document_structure(doc_id)``."""
         return self._require_cloud_api().get_tree(doc_id=doc_id, node_summary=node_summary)

     @_legacy_sdk
     def is_retrieval_ready(self, doc_id: str) -> bool:
         """Legacy SDK compatibility — Collection API handles readiness internally."""
         return self._require_cloud_api().is_retrieval_ready(doc_id=doc_id)

     @_legacy_sdk
     def submit_query(self, doc_id: str, query: str, thinking: bool = False) -> dict[str, Any]:
         """Legacy SDK compatibility — prefer ``collection.query(question, doc_ids=[doc_id])``."""
         return self._require_cloud_api().submit_query(
             doc_id=doc_id,
             query=query,
             thinking=thinking,
         )

     @_legacy_sdk
     def get_retrieval(self, retrieval_id: str) -> dict[str, Any]:
         """Legacy SDK compatibility — Collection API returns answers synchronously."""
         return self._require_cloud_api().get_retrieval(retrieval_id=retrieval_id)

     @_legacy_sdk
     def chat_completions(
         self,
         messages: list[dict[str, str]],
         stream: bool = False,
         doc_id: str | list[str] | None = None,
         temperature: float | None = None,
         stream_metadata: bool = False,
         enable_citations: bool = False,
     ) -> dict[str, Any] | Iterator[str] | Iterator[dict[str, Any]]:
         """Legacy SDK compatibility — prefer ``collection.query(...)``."""
         return self._require_cloud_api().chat_completions(
             messages=messages,
             stream=stream,
             doc_id=doc_id,
             temperature=temperature,
             stream_metadata=stream_metadata,
             enable_citations=enable_citations,
         )

     @_legacy_sdk
     def get_document(self, doc_id: str) -> dict[str, Any]:
         """Legacy SDK compatibility — prefer ``collection.get_document(doc_id)``."""
         return self._require_cloud_api().get_document(doc_id=doc_id)

     @_legacy_sdk
     def delete_document(self, doc_id: str) -> dict[str, Any]:
         """Legacy SDK compatibility — prefer ``collection.delete_document(doc_id)``."""
         return self._require_cloud_api().delete_document(doc_id=doc_id)

     @_legacy_sdk
     def list_documents(
         self,
         limit: int = 50,
         offset: int = 0,
         folder_id: str | None = None,
     ) -> dict[str, Any]:
         """Legacy SDK compatibility — prefer ``collection.list_documents()``."""
         return self._require_cloud_api().list_documents(
             limit=limit,
             offset=offset,
             folder_id=folder_id,
         )

     @_legacy_sdk
     def create_folder(
         self,
         name: str,
         description: str | None = None,
         parent_folder_id: str | None = None,
     ) -> dict[str, Any]:
         """Legacy SDK compatibility — prefer ``client.collection(name)`` (auto-creates)."""
         return self._require_cloud_api().create_folder(
             name=name,
             description=description,
             parent_folder_id=parent_folder_id,
         )

     @_legacy_sdk
     def list_folders(self, parent_folder_id: str | None = None) -> dict[str, Any]:
         """Legacy SDK compatibility — prefer ``client.list_collections()``."""
         return self._require_cloud_api().list_folders(parent_folder_id=parent_folder_id)


 class LocalClient(PageIndexClient):
     """Local mode — indexes and queries documents on your machine.
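A hedged migration sketch based on the docstring pointers above; the Collection-side method names come from those docstrings and are not otherwise verified here:

    client = PageIndexClient(api_key="pi-...")

    # legacy 0.2.x style: still works, flagged with PendingDeprecationWarning
    doc = client.submit_document("paper.pdf")
    hit = client.submit_query(doc["doc_id"], "What is the main result?")

    # preferred Collection-based style
    col = client.collection("papers")   # auto-creates, per the create_folder docstring
    col.add("paper.pdf")
    answer = col.query("What is the main result?")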
pageindex/cloud_api.py (new file, 265 lines)
@@ -0,0 +1,265 @@
from __future__ import annotations

import json
from typing import Any, Iterator

import requests

from .errors import PageIndexAPIError


class LegacyCloudAPI:
    """Compatibility layer for the pageindex 0.2.x cloud SDK API."""

    BASE_URL = "https://api.pageindex.ai"

    def __init__(self, api_key: str, base_url: str | None = None):
        self.api_key = api_key
        self.base_url = base_url or self.BASE_URL

    def _headers(self) -> dict[str, str]:
        return {"api_key": self.api_key}

    def _request(self, method: str, path: str, error_prefix: str, **kwargs) -> requests.Response:
        try:
            response = requests.request(
                method,
                f"{self.base_url}{path}",
                headers=self._headers(),
                **kwargs,
            )
        except requests.RequestException as e:
            raise PageIndexAPIError(f"{error_prefix}: {e}") from e

        if response.status_code != 200:
            raise PageIndexAPIError(f"{error_prefix}: {response.text}")
        return response
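Every endpoint below funnels through _request, so callers see a single error type; a small sketch of the observable behavior (key, doc id, and message are illustrative):

    api = LegacyCloudAPI(api_key="pi-...")
    try:
        api.get_document("missing-doc")
    except PageIndexAPIError as e:
        # non-200 bodies and requests.RequestException both land here,
        # prefixed with the per-call error message
        print(e)  # e.g. "Failed to get document metadata: ..."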
    def submit_document(
        self,
        file_path: str,
        mode: str | None = None,
        beta_headers: list[str] | None = None,
        folder_id: str | None = None,
    ) -> dict[str, Any]:
        data: dict[str, Any] = {"if_retrieval": True}
        if mode is not None:
            data["mode"] = mode
        if beta_headers is not None:
            data["beta_headers"] = json.dumps(beta_headers)
        if folder_id is not None:
            data["folder_id"] = folder_id

        with open(file_path, "rb") as f:
            response = self._request(
                "POST",
                "/doc/",
                "Failed to submit document",
                files={"file": f},
                data=data,
            )

        return response.json()

    def get_ocr(self, doc_id: str, format: str = "page") -> dict[str, Any]:
        if format not in ["page", "node", "raw"]:
            raise ValueError("Format parameter must be 'page', 'node', or 'raw'")

        response = self._request(
            "GET",
            f"/doc/{doc_id}/?type=ocr&format={format}",
            "Failed to get OCR result",
        )
        return response.json()

    def get_tree(self, doc_id: str, node_summary: bool = False) -> dict[str, Any]:
        response = self._request(
            "GET",
            f"/doc/{doc_id}/?type=tree&summary={node_summary}",
            "Failed to get tree result",
        )
        return response.json()

    def is_retrieval_ready(self, doc_id: str) -> bool:
        try:
            result = self.get_tree(doc_id)
            return result.get("retrieval_ready", False)
        except PageIndexAPIError:
            return False

    def submit_query(self, doc_id: str, query: str, thinking: bool = False) -> dict[str, Any]:
        payload = {
            "doc_id": doc_id,
            "query": query,
            "thinking": thinking,
        }
        response = self._request(
            "POST",
            "/retrieval/",
            "Failed to submit retrieval",
            json=payload,
        )
        return response.json()

    def get_retrieval(self, retrieval_id: str) -> dict[str, Any]:
        response = self._request(
            "GET",
            f"/retrieval/{retrieval_id}/",
            "Failed to get retrieval result",
        )
        return response.json()

    def chat_completions(
        self,
        messages: list[dict[str, str]],
        stream: bool = False,
        doc_id: str | list[str] | None = None,
        temperature: float | None = None,
        stream_metadata: bool = False,
        enable_citations: bool = False,
    ) -> dict[str, Any] | Iterator[str] | Iterator[dict[str, Any]]:
        payload: dict[str, Any] = {
            "messages": messages,
            "stream": stream,
        }

        if doc_id is not None:
            payload["doc_id"] = doc_id
        if temperature is not None:
            payload["temperature"] = temperature
        if enable_citations:
            payload["enable_citations"] = enable_citations

        response = self._request(
            "POST",
            "/chat/completions/",
            "Failed to get chat completion",
            json=payload,
            stream=stream,
        )

        if stream:
            if stream_metadata:
                return self._stream_chat_response_raw(response)
            return self._stream_chat_response(response)
        return response.json()

    def _stream_chat_response(self, response: requests.Response) -> Iterator[str]:
        try:
            for line in response.iter_lines():
                if not line:
                    continue
                line = line.decode("utf-8")
                if not line.startswith("data: "):
                    continue
                data = line[6:]
                if data == "[DONE]":
                    break

                try:
                    chunk = json.loads(data)
                except json.JSONDecodeError:
                    continue
                choices = chunk.get("choices") or []
                if not choices:
                    continue
                content = choices[0].get("delta", {}).get("content", "")
                if content:
                    yield content
        except requests.RequestException as e:
            raise PageIndexAPIError(f"Failed to stream chat completion: {e}") from e
        finally:
            response.close()

    def _stream_chat_response_raw(self, response: requests.Response) -> Iterator[dict[str, Any]]:
        try:
            for line in response.iter_lines():
                if not line:
                    continue
                line = line.decode("utf-8")
                if not line.startswith("data: "):
                    continue
                data = line[6:]
                if data == "[DONE]":
                    break

                try:
                    yield json.loads(data)
                except json.JSONDecodeError:
                    continue
        except requests.RequestException as e:
            raise PageIndexAPIError(f"Failed to stream chat completion: {e}") from e
        finally:
            response.close()
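A minimal consumption sketch (not part of the diff; client and doc id are placeholders) of the close-on-abandon behavior these two helpers guarantee; closing the generator, or letting it be garbage-collected, raises GeneratorExit at the yield point, so the `finally` runs and the response is closed:

    stream = client.chat_completions(
        messages=[{"role": "user", "content": "Summarize the document."}],
        doc_id="doc-1",
        stream=True,
    )
    for chunk in stream:
        print(chunk, end="")
        break          # abandoning the iterator early...
    stream.close()     # ...triggers GeneratorExit -> finally -> response.close()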
    def get_document(self, doc_id: str) -> dict[str, Any]:
        response = self._request(
            "GET",
            f"/doc/{doc_id}/metadata/",
            "Failed to get document metadata",
        )
        return response.json()

    def delete_document(self, doc_id: str) -> dict[str, Any]:
        response = self._request(
            "DELETE",
            f"/doc/{doc_id}/",
            "Failed to delete document",
        )
        return response.json()

    def list_documents(
        self,
        limit: int = 50,
        offset: int = 0,
        folder_id: str | None = None,
    ) -> dict[str, Any]:
        if limit < 1 or limit > 100:
            raise ValueError("limit must be between 1 and 100")
        if offset < 0:
            raise ValueError("offset must be non-negative")

        params: dict[str, Any] = {"limit": limit, "offset": offset}
        if folder_id is not None:
            params["folder_id"] = folder_id

        response = self._request(
            "GET",
            "/docs/",
            "Failed to list documents",
            params=params,
        )
        return response.json()

    def create_folder(
        self,
        name: str,
        description: str | None = None,
        parent_folder_id: str | None = None,
    ) -> dict[str, Any]:
        payload: dict[str, Any] = {"name": name}
        if description is not None:
            payload["description"] = description
        if parent_folder_id is not None:
            payload["parent_folder_id"] = parent_folder_id

        response = self._request(
            "POST",
            "/folder/",
            "Failed to create folder",
            json=payload,
        )
        return response.json()

    def list_folders(self, parent_folder_id: str | None = None) -> dict[str, Any]:
        params = {}
        if parent_folder_id is not None:
            params["parent_folder_id"] = parent_folder_id

        response = self._request(
            "GET",
            "/folders/",
            "Failed to list folders",
            params=params,
        )
        return response.json()
pageindex/errors.py
@@ -18,7 +18,15 @@ class IndexingError(PageIndexError):
     pass


-class CloudAPIError(PageIndexError):
+class PageIndexAPIError(PageIndexError):
+    """PageIndex cloud API returned an error.
+
+    Kept for compatibility with the pageindex 0.2.x cloud SDK.
+    """
+    pass
+
+
+class CloudAPIError(PageIndexAPIError):
     """Cloud API returned error."""
     pass
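With this rename, 0.2.x code that caught CloudAPIError keeps working and new code can catch the broader base; a small illustrative check:

    from pageindex.errors import PageIndexError, PageIndexAPIError, CloudAPIError

    try:
        raise CloudAPIError("boom")
    except PageIndexAPIError:   # the new intermediate base catches the old subclass
        pass
    assert issubclass(CloudAPIError, PageIndexAPIError)
    assert issubclass(PageIndexAPIError, PageIndexError)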
pageindex/utils.py
@@ -15,6 +15,7 @@ load_dotenv()
 import logging
 import yaml
 from pathlib import Path
 from pprint import pprint
 from types import SimpleNamespace as config

 # Backward compatibility: support CHATGPT_API_KEY as alias for OPENAI_API_KEY
@@ -23,6 +24,22 @@ if not os.getenv("OPENAI_API_KEY") and os.getenv("CHATGPT_API_KEY"):

 litellm.drop_params = True

+async def call_llm(prompt, api_key, model="gpt-4.1", temperature=0):
+    """Call an LLM to generate a response to a prompt.
+
+    Kept for compatibility with the pageindex 0.2.x SDK utility API.
+    """
+    import openai
+
+    client = openai.AsyncOpenAI(api_key=api_key)
+    response = await client.chat.completions.create(
+        model=model,
+        messages=[{"role": "user", "content": prompt}],
+        temperature=temperature,
+    )
+    return response.choices[0].message.content.strip()
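A usage sketch of the restored helper (the key is a placeholder); it is a coroutine, so sync callers run it with asyncio:

    import asyncio
    from pageindex.utils import call_llm

    answer = asyncio.run(call_llm("Say hi in one word.", api_key="sk-...", model="gpt-4.1"))
    print(answer)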

 def count_tokens(text, model=None):
     if not text:
         return 0
@@ -463,12 +480,14 @@ def clean_structure_post(data):
         clean_structure_post(section)
     return data

-def remove_fields(data, fields=['text']):
+def remove_fields(data, fields=['text'], max_len=None):
     if isinstance(data, dict):
-        return {k: remove_fields(v, fields)
+        return {k: remove_fields(v, fields, max_len)
                 for k, v in data.items() if k not in fields}
     elif isinstance(data, list):
-        return [remove_fields(item, fields) for item in data]
+        return [remove_fields(item, fields, max_len) for item in data]
+    elif isinstance(data, str):
+        return data[:max_len] + '...' if max_len is not None and len(data) > max_len else data
     return data
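The new max_len truncates string values recursively while fields are still stripped; a tiny sketch, mirroring the unit test below:

    from pageindex.utils import remove_fields

    remove_fields({"title": "A long title", "text": "hidden"}, fields=["text"], max_len=5)
    # -> {'title': 'A lon...'}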
 def print_toc(tree, indent=0):

@@ -684,27 +703,72 @@ class ConfigLoader:
         merged = {**self._default_dict, **user_dict}
         return config(**merged)

-def create_node_mapping(tree):
-    """Create a flat dict mapping node_id to node for quick lookup."""
+def create_node_mapping(tree, include_page_ranges=False, max_page=None):
+    """Create a mapping of node_id to node for quick lookup.
+
+    The optional page-range arguments are kept for compatibility with the
+    pageindex 0.2.x SDK utility API.
+    """
+    def get_all_nodes(nodes):
+        if isinstance(nodes, dict):
+            return [nodes] + [
+                child_node
+                for child in nodes.get('nodes', [])
+                for child_node in get_all_nodes(child)
+            ]
+        elif isinstance(nodes, list):
+            return [
+                child_node
+                for item in nodes
+                for child_node in get_all_nodes(item)
+            ]
+        return []
+
+    all_nodes = get_all_nodes(tree)
+
+    if not include_page_ranges:
+        return {node["node_id"]: node for node in all_nodes if node.get("node_id")}
+
     mapping = {}
-    def _traverse(nodes):
-        for node in nodes:
-            if node.get('node_id'):
-                mapping[node['node_id']] = node
-            if node.get('nodes'):
-                _traverse(node['nodes'])
-    _traverse(tree)
+    for i, node in enumerate(all_nodes):
+        if not node.get("node_id"):
+            continue
+        start_page = node.get("page_index", node.get("start_index"))
+        if node.get("end_index") is not None:
+            end_page = node.get("end_index")
+        elif i + 1 < len(all_nodes):
+            next_node = all_nodes[i + 1]
+            end_page = next_node.get("page_index", next_node.get("start_index"))
+        else:
+            end_page = max_page
+
+        mapping[node["node_id"]] = {
+            "node": node,
+            "start_index": start_page,
+            "end_index": end_page,
+        }
+
     return mapping
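A usage sketch of the two shapes returned above, with the same values the contract test checks:

    from pageindex.utils import create_node_mapping

    tree = [{"node_id": "0001", "title": "Root", "page_index": 1,
             "nodes": [{"node_id": "0002", "title": "Child", "page_index": 3, "nodes": []}]}]

    plain = create_node_mapping(tree)   # {"0001": <node>, "0002": <node>}
    ranged = create_node_mapping(tree, include_page_ranges=True, max_page=8)
    # ranged["0001"] -> {"node": <node>, "start_index": 1, "end_index": 3}
    # ranged["0002"] -> {"node": <node>, "start_index": 3, "end_index": 8}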
-def print_tree(tree, indent=0):
+def print_tree(tree, exclude_fields=None, indent=None):
+    if exclude_fields is None:
+        exclude_fields = ['text', 'page_index']
+    if isinstance(exclude_fields, int):
+        indent = exclude_fields
+        exclude_fields = None
+    if indent is None and exclude_fields is not None:
+        cleaned_tree = remove_fields(copy.deepcopy(tree), exclude_fields, max_len=40)
+        pprint(cleaned_tree, sort_dicts=False, width=100)
+        return
+
+    indent = indent or 0
     for node in tree:
         summary = node.get('summary') or node.get('prefix_summary', '')
         summary_str = f" — {summary[:60]}..." if summary else ""
         print(' ' * indent + f"[{node.get('node_id', '?')}] {node.get('title', '')}{summary_str}")
         if node.get('nodes'):
-            print_tree(node['nodes'], indent + 1)
+            print_tree(node['nodes'], exclude_fields=exclude_fields, indent=indent + 1)

 def print_wrapped(text, width=100):
     for line in text.splitlines():
         print(textwrap.fill(line, width=width))
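The int-check above is a shim for the old positional call style; a sketch of the call forms it accepts:

    print_tree(tree)        # new default: pprint, with 'text'/'page_index' stripped
                            # and long strings truncated to 40 chars
    print_tree(tree, 2)     # legacy positional call: the int is treated as indent=2
    print_tree(tree, exclude_fields=None, indent=0)   # explicit legacy indented listing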
pyproject.toml
@@ -26,9 +26,11 @@ pymupdf = ">=1.26.0"
 PyPDF2 = ">=3.0.0"
 python-dotenv = ">=1.0.0"
 pyyaml = ">=6.0"
 openai = ">=1.70.0"
 openai-agents = ">=0.1.0"
+requests = ">=2.28.0"
 httpx = {extras = ["socks"], version = ">=0.28.1"}
+typing-extensions = ">=4.9.0"

 [tool.poetry.urls]
 Repository = "https://github.com/VectifyAI/PageIndex"
scripts/e2e_legacy_sdk.py (new file, 94 lines)
@@ -0,0 +1,94 @@
"""End-to-end smoke test of the legacy SDK compatibility layer against the real cloud API.

Run: PAGEINDEX_API_KEY=... uv run python scripts/e2e_legacy_sdk.py
"""
from __future__ import annotations
import os
import sys
import time
from pathlib import Path

from dotenv import load_dotenv

load_dotenv()

from pageindex import PageIndexClient


def log(step: str, detail: str = "") -> None:
    print(f"[e2e] {step}" + (f" — {detail}" if detail else ""), flush=True)


def main() -> int:
    api_key = os.environ.get("PAGEINDEX_API_KEY")
    if not api_key:
        print("PAGEINDEX_API_KEY not set", file=sys.stderr)
        return 1

    pdf = Path("examples/documents/attention-residuals.pdf")
    if not pdf.exists():
        print(f"Test PDF missing: {pdf}", file=sys.stderr)
        return 1

    client = PageIndexClient(api_key=api_key)
    log("init", f"cloud mode (key={api_key[:6]}…)")

    # 1) submit_document (legacy SDK signature — fire-and-forget)
    submit_resp = client.submit_document(file_path=str(pdf))
    doc_id = submit_resp["doc_id"]
    log("submit_document", f"doc_id={doc_id}")

    try:
        # 2) poll is_retrieval_ready (with hard timeout)
        deadline = time.time() + 600  # 10 min
        while time.time() < deadline:
            if client.is_retrieval_ready(doc_id):
                log("is_retrieval_ready", "True")
                break
            time.sleep(8)
        else:
            log("is_retrieval_ready", "TIMEOUT")
            return 2

        # 3) get_tree
        tree = client.get_tree(doc_id)
        node_count = len(tree.get("result") or tree.get("tree") or [])
        log("get_tree", f"top-level nodes={node_count}, status={tree.get('status')}")

        # 4) get_document (metadata)
        meta = client.get_document(doc_id)
        log("get_document", f"name={meta.get('name')!r} pages={meta.get('pageNum')} status={meta.get('status')}")

        # 5) chat_completions (non-stream)
        chat = client.chat_completions(
            messages=[{"role": "user", "content": "What is this paper about? Answer in one sentence."}],
            doc_id=doc_id,
        )
        answer = (chat.get("choices") or [{}])[0].get("message", {}).get("content", "")
        log("chat_completions", f"answer={answer[:120]!r}")

        # 6) chat_completions (stream) — full consumption
        log("chat_completions stream", "starting…")
        print("[stream] ", end="", flush=True)
        chunk_count = 0
        for chunk in client.chat_completions(
            messages=[{"role": "user", "content": "List 3 keywords from this paper."}],
            doc_id=doc_id,
            stream=True,
        ):
            print(chunk, end="", flush=True)
            chunk_count += 1
        print()  # newline after streaming
        log("chat_completions stream", f"chunks received={chunk_count}")

    finally:
        # 7) delete_document
        del_resp = client.delete_document(doc_id)
        log("delete_document", f"resp={del_resp}")

    log("done", "all steps OK")
    return 0


if __name__ == "__main__":
    sys.exit(main())
@@ -1,5 +1,6 @@
 from pageindex.errors import (
     PageIndexError,
+    PageIndexAPIError,
     CollectionNotFoundError,
     DocumentNotFoundError,
     IndexingError,
@@ -9,9 +10,10 @@ from pageindex.errors import (


 def test_all_errors_inherit_from_base():
-    for cls in [CollectionNotFoundError, DocumentNotFoundError, IndexingError, CloudAPIError, FileTypeError]:
+    for cls in [PageIndexAPIError, CollectionNotFoundError, DocumentNotFoundError, IndexingError, CloudAPIError, FileTypeError]:
         assert issubclass(cls, PageIndexError)
         assert issubclass(cls, Exception)
+    assert issubclass(CloudAPIError, PageIndexAPIError)


 def test_error_message():
@@ -20,7 +22,7 @@ def test_error_message():


 def test_catch_base_catches_all():
-    for cls in [CollectionNotFoundError, DocumentNotFoundError, IndexingError, CloudAPIError, FileTypeError]:
+    for cls in [PageIndexAPIError, CollectionNotFoundError, DocumentNotFoundError, IndexingError, CloudAPIError, FileTypeError]:
         try:
             raise cls("test")
         except PageIndexError:
tests/test_legacy_sdk_contract.py (new file, 325 lines)
@@ -0,0 +1,325 @@
import pytest
import requests

from pageindex.client import PageIndexAPIError as ClientPageIndexAPIError
from pageindex import PageIndexAPIError, PageIndexClient
from pageindex.client import CloudClient


class FakeResponse:
    def __init__(self, status_code=200, payload=None, text="ok", lines=None):
        self.status_code = status_code
        self._payload = payload or {}
        self.text = text
        self._lines = lines or []
        self.closed = False

    def json(self):
        return self._payload

    def iter_lines(self):
        return iter(self._lines)

    def close(self):
        self.closed = True


class StreamingErrorResponse(FakeResponse):
    def iter_lines(self):
        raise requests.ReadTimeout("stream stalled")


def test_legacy_imports_and_initializers():
    positional = PageIndexClient("pi-test")
    keyword = PageIndexClient(api_key="pi-test")
    cloud = CloudClient(api_key="pi-test")

    assert positional._legacy_cloud_api.api_key == "pi-test"
    assert keyword._legacy_cloud_api.api_key == "pi-test"
    assert cloud._legacy_cloud_api.api_key == "pi-test"
    assert issubclass(PageIndexAPIError, Exception)
    assert ClientPageIndexAPIError is PageIndexAPIError


def test_legacy_methods_exist():
    client = PageIndexClient("pi-test")
    for method_name in [
        "submit_document",
        "get_ocr",
        "get_tree",
        "is_retrieval_ready",
        "submit_query",
        "get_retrieval",
        "chat_completions",
        "get_document",
        "delete_document",
        "list_documents",
        "create_folder",
        "list_folders",
    ]:
        assert callable(getattr(client, method_name))


def test_legacy_base_url_can_be_overridden_from_client(monkeypatch):
    calls = []

    def fake_request(method, url, headers=None, **kwargs):
        calls.append({"method": method, "url": url, "headers": headers})
        return FakeResponse(payload={"id": "doc-1"})

    monkeypatch.setattr("pageindex.cloud_api.requests.request", fake_request)
    monkeypatch.setattr(PageIndexClient, "BASE_URL", "https://staging.pageindex.test")

    result = PageIndexClient("pi-test").get_document("doc-1")

    assert result == {"id": "doc-1"}
    assert calls[0]["method"] == "GET"
    assert calls[0]["url"] == "https://staging.pageindex.test/doc/doc-1/metadata/"
    assert calls[0]["headers"] == {"api_key": "pi-test"}


def test_submit_document_uses_legacy_endpoint(monkeypatch, tmp_path):
    calls = []

    def fake_request(method, url, headers=None, files=None, data=None, **kwargs):
        calls.append({
            "method": method,
            "url": url,
            "headers": headers,
            "data": data,
            "files": files,
            "kwargs": kwargs,
        })
        return FakeResponse(payload={"doc_id": "doc-1"})

    monkeypatch.setattr("pageindex.cloud_api.requests.request", fake_request)

    pdf = tmp_path / "doc.pdf"
    pdf.write_bytes(b"%PDF-1.4")
    result = PageIndexClient("pi-test").submit_document(
        str(pdf),
        mode="mcp",
        beta_headers=["block_reference"],
        folder_id="folder-1",
    )

    assert result == {"doc_id": "doc-1"}
    assert calls[0]["method"] == "POST"
    assert calls[0]["url"] == "https://api.pageindex.ai/doc/"
    assert calls[0]["headers"] == {"api_key": "pi-test"}
    assert "timeout" not in calls[0]["kwargs"]
    assert calls[0]["data"]["if_retrieval"] is True
    assert calls[0]["data"]["mode"] == "mcp"
    assert calls[0]["data"]["beta_headers"] == '["block_reference"]'
    assert calls[0]["data"]["folder_id"] == "folder-1"


def test_get_ocr_and_tree_use_legacy_urls(monkeypatch):
    get_calls = []

    def fake_request(method, url, headers=None, **kwargs):
        get_calls.append({"method": method, "url": url, "headers": headers})
        return FakeResponse(payload={"status": "completed", "retrieval_ready": True})

    monkeypatch.setattr("pageindex.cloud_api.requests.request", fake_request)
    client = PageIndexClient("pi-test")

    assert client.get_ocr("doc-1", format="page")["status"] == "completed"
    assert client.get_tree("doc-1", node_summary=True)["retrieval_ready"] is True

    assert get_calls[0]["method"] == "GET"
    assert get_calls[0]["url"] == "https://api.pageindex.ai/doc/doc-1/?type=ocr&format=page"
    assert get_calls[1]["url"] == "https://api.pageindex.ai/doc/doc-1/?type=tree&summary=True"


def test_get_ocr_rejects_invalid_format():
    with pytest.raises(ValueError, match="Format parameter must be"):
        PageIndexClient("pi-test").get_ocr("doc-1", format="bad")


def test_submit_query_uses_legacy_payload(monkeypatch):
    calls = []

    def fake_request(method, url, headers=None, json=None, **kwargs):
        calls.append({"method": method, "url": url, "headers": headers, "json": json})
        return FakeResponse(payload={"retrieval_id": "ret-1"})

    monkeypatch.setattr("pageindex.cloud_api.requests.request", fake_request)

    result = PageIndexClient("pi-test").submit_query("doc-1", "What changed?", thinking=True)

    assert result == {"retrieval_id": "ret-1"}
    assert calls[0]["method"] == "POST"
    assert calls[0]["url"] == "https://api.pageindex.ai/retrieval/"
    assert calls[0]["json"] == {
        "doc_id": "doc-1",
        "query": "What changed?",
        "thinking": True,
    }


def test_chat_completions_non_stream_returns_json(monkeypatch):
    calls = []
    payload = {"choices": [{"message": {"content": "answer"}}]}

    def fake_request(method, url, headers=None, json=None, stream=False, **kwargs):
        calls.append({
            "method": method,
            "url": url,
            "headers": headers,
            "json": json,
            "stream": stream,
        })
        return FakeResponse(payload=payload)

    monkeypatch.setattr("pageindex.cloud_api.requests.request", fake_request)

    result = PageIndexClient("pi-test").chat_completions(
        [{"role": "user", "content": "hi"}],
        doc_id=["doc-1"],
        temperature=0.1,
        enable_citations=True,
    )

    assert result == payload
    assert calls[0]["method"] == "POST"
    assert calls[0]["url"] == "https://api.pageindex.ai/chat/completions/"
    assert calls[0]["stream"] is False
    assert calls[0]["json"] == {
        "messages": [{"role": "user", "content": "hi"}],
        "stream": False,
        "doc_id": ["doc-1"],
        "temperature": 0.1,
        "enable_citations": True,
    }


def test_chat_completions_stream_parses_text_chunks(monkeypatch):
    calls = []
    lines = [
        b'data: {"choices":[{"delta":{"content":"hel"}}]}',
        b'data: {"choices":[{"delta":{"content":"lo"}}]}',
        b"data: [DONE]",
    ]

    def fake_request(method, url, **kwargs):
        calls.append({"method": method, "url": url, "kwargs": kwargs})
        return FakeResponse(lines=lines)

    monkeypatch.setattr("pageindex.cloud_api.requests.request", fake_request)

    chunks = list(PageIndexClient("pi-test").chat_completions(
        [{"role": "user", "content": "hi"}],
        stream=True,
    ))

    assert chunks == ["hel", "lo"]
    assert "timeout" not in calls[0]["kwargs"]


def test_chat_completions_stream_metadata_returns_raw_chunks(monkeypatch):
    calls = []
    lines = [
        b'data: {"object":"chat.completion.chunk"}',
        b"data: [DONE]",
    ]

    def fake_request(method, url, **kwargs):
        calls.append({"method": method, "url": url, "json": kwargs.get("json")})
        return FakeResponse(lines=lines)

    monkeypatch.setattr("pageindex.cloud_api.requests.request", fake_request)

    chunks = list(PageIndexClient("pi-test").chat_completions(
        [{"role": "user", "content": "hi"}],
        stream=True,
        stream_metadata=True,
    ))

    assert chunks == [{"object": "chat.completion.chunk"}]
    assert "stream_metadata" not in calls[0]["json"]


def test_chat_completions_stream_errors_are_pageindex_api_error(monkeypatch):
    def fake_request(*args, **kwargs):
        return StreamingErrorResponse()

    monkeypatch.setattr("pageindex.cloud_api.requests.request", fake_request)

    stream = PageIndexClient("pi-test").chat_completions(
        [{"role": "user", "content": "hi"}],
        stream=True,
    )

    with pytest.raises(PageIndexAPIError, match="Failed to stream chat completion: stream stalled"):
        list(stream)


def test_api_errors_are_pageindex_api_error(monkeypatch):
    def fake_request(*args, **kwargs):
        return FakeResponse(status_code=500, text="server error")

    monkeypatch.setattr("pageindex.cloud_api.requests.request", fake_request)

    with pytest.raises(PageIndexAPIError, match="Failed to get document metadata"):
        PageIndexClient("pi-test").get_document("doc-1")


def test_network_errors_are_wrapped_as_pageindex_api_error(monkeypatch):
    def fake_request(*args, **kwargs):
        raise requests.Timeout("slow network")

    monkeypatch.setattr("pageindex.cloud_api.requests.request", fake_request)

    with pytest.raises(PageIndexAPIError, match="Failed to get document metadata: slow network"):
        PageIndexClient("pi-test").get_document("doc-1")


def test_list_documents_validates_legacy_pagination():
    client = PageIndexClient("pi-test")

    with pytest.raises(ValueError, match="limit must be between 1 and 100"):
        client.list_documents(limit=0)
    with pytest.raises(ValueError, match="offset must be non-negative"):
        client.list_documents(offset=-1)


def test_chat_completions_stream_closes_response_after_done(monkeypatch):
    fake = FakeResponse(lines=[
        b'data: {"choices":[{"delta":{"content":"hi"}}]}',
        b"data: [DONE]",
    ])
    monkeypatch.setattr("pageindex.cloud_api.requests.request",
                        lambda *a, **kw: fake)

    list(PageIndexClient("pi-test").chat_completions(
        [{"role": "user", "content": "x"}], stream=True,
    ))
    assert fake.closed is True


def test_chat_completions_stream_closes_response_on_early_abandon(monkeypatch):
    fake = FakeResponse(lines=[
        b'data: {"choices":[{"delta":{"content":"a"}}]}',
        b'data: {"choices":[{"delta":{"content":"b"}}]}',
        b"data: [DONE]",
    ])
    monkeypatch.setattr("pageindex.cloud_api.requests.request",
                        lambda *a, **kw: fake)

    gen = PageIndexClient("pi-test").chat_completions(
        [{"role": "user", "content": "x"}], stream=True,
    )
    next(gen)
    gen.close()
    assert fake.closed is True


def test_empty_api_key_warns_and_falls_back_to_local(caplog, tmp_path, monkeypatch):
    import logging
    monkeypatch.setenv("OPENAI_API_KEY", "sk-test")
    with caplog.at_level(logging.WARNING, logger="pageindex.client"):
        client = PageIndexClient(api_key="", storage_path=str(tmp_path))

    assert any("empty api_key" in r.message for r in caplog.records)
    assert client._legacy_cloud_api is None
tests/test_legacy_utils_contract.py (new file, 106 lines)
@@ -0,0 +1,106 @@
import sys
import asyncio
from types import SimpleNamespace

from pageindex import utils


def test_remove_fields_keeps_legacy_max_len():
    data = {
        "title": "A long title",
        "text": "hidden",
        "nodes": [{"summary": "abcdefghijklmnopqrstuvwxyz"}],
    }

    result = utils.remove_fields(data, fields=["text"], max_len=5)

    assert "text" not in result
    assert result["title"] == "A lon..."
    assert result["nodes"][0]["summary"] == "abcde..."


def test_create_node_mapping_keeps_legacy_page_ranges():
    tree = [
        {
            "node_id": "0001",
            "title": "Root",
            "page_index": 1,
            "nodes": [
                {"node_id": "0002", "title": "Child", "page_index": 3, "nodes": []},
            ],
        }
    ]

    plain = utils.create_node_mapping(tree)
    ranged = utils.create_node_mapping(tree, include_page_ranges=True, max_page=8)

    assert plain["0001"]["title"] == "Root"
    assert ranged["0001"]["start_index"] == 1
    assert ranged["0001"]["end_index"] == 3
    assert ranged["0002"]["start_index"] == 3
    assert ranged["0002"]["end_index"] == 8


def test_create_node_mapping_prefers_existing_start_end_ranges():
    tree = [
        {
            "node_id": "0001",
            "title": "Root",
            "start_index": 1,
            "end_index": 10,
            "nodes": [
                {"node_id": "0002", "title": "Child", "start_index": 3, "end_index": 5},
            ],
        }
    ]

    ranged = utils.create_node_mapping(tree, include_page_ranges=True, max_page=12)

    assert ranged["0001"]["start_index"] == 1
    assert ranged["0001"]["end_index"] == 10
    assert ranged["0002"]["start_index"] == 3
    assert ranged["0002"]["end_index"] == 5


def test_print_tree_keeps_legacy_exclude_fields(capsys):
    tree = [{"node_id": "0001", "title": "Root", "text": "hidden", "page_index": 1}]

    utils.print_tree(tree)

    out = capsys.readouterr().out
    assert "Root" in out
    assert "hidden" not in out
    assert "page_index" not in out


def test_call_llm_keeps_legacy_async_openai_contract(monkeypatch):
    calls = []

    class FakeCompletions:
        async def create(self, **kwargs):
            calls.append(kwargs)
            message = SimpleNamespace(content=" answer ")
            choice = SimpleNamespace(message=message)
            return SimpleNamespace(choices=[choice])

    class FakeAsyncOpenAI:
        def __init__(self, api_key):
            self.api_key = api_key
            self.chat = SimpleNamespace(completions=FakeCompletions())

    fake_openai = SimpleNamespace(AsyncOpenAI=FakeAsyncOpenAI)
    monkeypatch.setitem(sys.modules, "openai", fake_openai)

    result = asyncio.run(utils.call_llm(
        "hello",
        api_key="sk-test",
        model="gpt-test",
        temperature=0.2,
    ))

    assert result == "answer"
    assert calls == [{
        "model": "gpt-test",
        "messages": [{"role": "user", "content": "hello"}],
        "temperature": 0.2,
    }]