From de1992def10b0cc70fb76a519a5eb1f76cf930bf Mon Sep 17 00:00:00 2001 From: BukeLy Date: Tue, 26 May 2026 17:21:44 +0800 Subject: [PATCH] refactor(filesystem): make pifs providers configurable --- examples/pifs_demo.py | 66 ++++++++++++++++----- pageindex/filesystem/__init__.py | 4 +- pageindex/filesystem/core.py | 19 +++++- pageindex/filesystem/hybrid_projection.py | 55 +++++++---------- pageindex/filesystem/metadata_generation.py | 32 ++++++++-- tests/test_metadata_generation.py | 30 ++++++++++ tests/test_semantic_index.py | 9 +++ 7 files changed, 154 insertions(+), 61 deletions(-) create mode 100644 tests/test_metadata_generation.py diff --git a/examples/pifs_demo.py b/examples/pifs_demo.py index f6f9b51..cf5bf82 100644 --- a/examples/pifs_demo.py +++ b/examples/pifs_demo.py @@ -39,7 +39,7 @@ sys.path.insert(0, str(Path(__file__).parent.parent)) os.environ.setdefault("LITELLM_LOCAL_MODEL_COST_MAP", "true") from pageindex import PageIndexClient -from pageindex.filesystem import OpenAIMetadataGenerator, PageIndexFileSystem, PIFSCommandExecutor +from pageindex.filesystem import MetadataGenerator, PageIndexFileSystem, PIFSCommandExecutor from pageindex.filesystem.agent import run_pifs_agent @@ -47,6 +47,12 @@ EXAMPLES_DIR = Path(__file__).parent DOCUMENTS_DIR = EXAMPLES_DIR / "documents" WORKSPACE = EXAMPLES_DIR / "pifs_workspace" DEFAULT_MODEL = os.environ.get("PIFS_DEMO_MODEL", "gpt-5.4-mini") +DEFAULT_METADATA_PROVIDER = os.environ.get("PIFS_DEMO_METADATA_PROVIDER") or os.environ.get( + "PIFS_METADATA_PROVIDER", "openai" +) +DEFAULT_EMBEDDING_PROVIDER = os.environ.get("PIFS_DEMO_EMBEDDING_PROVIDER") or os.environ.get( + "PIFS_EMBEDDING_PROVIDER", "openai" +) DEFAULT_QUESTION = ( "Use the PIFS workspace to find the Federal Reserve annual report. " "Which section covers supervision and regulation, and what page range " @@ -110,10 +116,15 @@ def parse_args() -> argparse.Namespace: ) parser.add_argument("--question", default=DEFAULT_QUESTION) parser.add_argument("--model", default=DEFAULT_MODEL) + parser.add_argument( + "--metadata-provider", + default=DEFAULT_METADATA_PROVIDER, + help="Provider used for register-time metadata generation.", + ) parser.add_argument( "--metadata-model", default=os.environ.get("PIFS_METADATA_MODEL", "gpt-5-nano"), - help="OpenAI or OpenAI-compatible model used for register-time metadata.", + help="Model used for register-time metadata generation.", ) parser.add_argument("--stream-mode", default="all", choices=["off", "tools", "model", "all"]) parser.add_argument("--verbose", action="store_true") @@ -121,23 +132,40 @@ def parse_args() -> argparse.Namespace: parser.add_argument("--max-seconds", type=float, default=90) parser.add_argument("--reasoning-effort", default=None) parser.add_argument("--reasoning-summary", default="auto") + parser.add_argument( + "--embedding-provider", + default=DEFAULT_EMBEDDING_PROVIDER, + help="Provider used for register-time summary projection embeddings.", + ) parser.add_argument( "--embedding-model", default=os.environ.get("PIFS_DEMO_EMBEDDING_MODEL", "text-embedding-3-small"), - help="OpenAI embedding model used for register-time summary projection.", + help="Embedding model used for register-time summary projection.", ) parser.add_argument("--embedding-dimensions", type=int, default=256) return parser.parse_args() -def require_openai_environment() -> None: - if os.environ.get("OPENAI_API_KEY"): - return - raise RuntimeError( - "OPENAI_API_KEY is required for this demo: register() generates real " - "PIFS metadata and the agent uses the OpenAI Agents SDK. Source your " - ".env or export OPENAI_API_KEY before running." - ) +def require_runtime_environment(*, metadata_provider: str, embedding_provider: str) -> None: + metadata_provider = metadata_provider.lower() + embedding_provider = embedding_provider.lower() + missing: list[str] = [] + if not os.environ.get("OPENAI_API_KEY"): + missing.append("OPENAI_API_KEY for the OpenAI Agents SDK demo agent") + if metadata_provider == "openai" and not ( + os.environ.get("PIFS_METADATA_API_KEY") or os.environ.get("OPENAI_API_KEY") + ): + missing.append("PIFS_METADATA_API_KEY or OPENAI_API_KEY for metadata generation") + if embedding_provider == "openai" and not ( + os.environ.get("PIFS_EMBEDDING_API_KEY") or os.environ.get("OPENAI_API_KEY") + ): + missing.append("PIFS_EMBEDDING_API_KEY or OPENAI_API_KEY for summary embeddings") + if missing: + raise RuntimeError( + "Missing required environment variable(s): " + + "; ".join(missing) + + ". Source your .env or export the required key before running." + ) def discover_cached_documents(documents_dir: Path) -> list[Path]: @@ -294,6 +322,7 @@ def backfill_registered_metadata_values(filesystem: PageIndexFileSystem, file_re def configure_summary_projection_backend( filesystem: PageIndexFileSystem, *, + embedding_provider: str, embedding_model: str, embedding_dimensions: int, ) -> None: @@ -301,7 +330,7 @@ def configure_summary_projection_backend( return filesystem.configure_hybrid_projection_retrieval( filesystem.summary_projection_index_dir, - embedding_provider="openai", + embedding_provider=embedding_provider, embedding_model=embedding_model, embedding_dimensions=embedding_dimensions, ) @@ -690,7 +719,10 @@ def run_smoke_commands( def main() -> None: args = parse_args() - require_openai_environment() + require_runtime_environment( + metadata_provider=args.metadata_provider, + embedding_provider=args.embedding_provider, + ) workspace = args.workspace.expanduser() documents_dir = args.documents_dir.expanduser() if args.reset and workspace.exists(): @@ -705,8 +737,11 @@ def main() -> None: filesystem = PageIndexFileSystem( workspace, - metadata_generator=OpenAIMetadataGenerator(model=args.metadata_model), - summary_projection_embedding_provider="openai", + metadata_generator=MetadataGenerator( + provider=args.metadata_provider, + model=args.metadata_model, + ), + summary_projection_embedding_provider=args.embedding_provider, summary_projection_embedding_model=args.embedding_model, summary_projection_embedding_dimensions=args.embedding_dimensions, ) @@ -718,6 +753,7 @@ def main() -> None: registered = register_documents(filesystem, documents, documents_dir=documents_dir) configure_summary_projection_backend( filesystem, + embedding_provider=args.embedding_provider, embedding_model=args.embedding_model, embedding_dimensions=args.embedding_dimensions, ) diff --git a/pageindex/filesystem/__init__.py b/pageindex/filesystem/__init__.py index 2ad1c84..a6cde16 100644 --- a/pageindex/filesystem/__init__.py +++ b/pageindex/filesystem/__init__.py @@ -2,11 +2,11 @@ from .commands import PIFSCommandExecutor from .core import PageIndexFileSystem from .hybrid_projection import HybridProjectionSearchBackend from .metadata_generation import ( + MetadataGenerationBackend, MetadataGenerationError, MetadataGenerationInput, MetadataGenerationResult, MetadataGenerator, - OpenAIMetadataGenerator, ) from .projection_indexing import SummaryProjectionIndexer from .semantic_index import ( @@ -20,11 +20,11 @@ from .types import OpenResult, SearchResult __all__ = [ "OpenResult", "HybridProjectionSearchBackend", + "MetadataGenerationBackend", "MetadataGenerationError", "MetadataGenerationInput", "MetadataGenerationResult", "MetadataGenerator", - "OpenAIMetadataGenerator", "PIFSCommandExecutor", "PageIndexFileSystem", "RebuildableSemanticIndex", diff --git a/pageindex/filesystem/core.py b/pageindex/filesystem/core.py index bbf81f1..fc096e3 100644 --- a/pageindex/filesystem/core.py +++ b/pageindex/filesystem/core.py @@ -9,11 +9,11 @@ from urllib.parse import unquote, urlparse from ..client import PageIndexClient from .metadata import MetadataQueryEngine from .metadata_generation import ( + MetadataGenerationBackend, MetadataGenerationError, MetadataGenerationInput, MetadataGenerationResult, MetadataGenerator, - OpenAIMetadataGenerator, ) from .projection_indexing import SummaryProjectionIndexer from .semantic_folder_policy import ( @@ -91,7 +91,11 @@ class PageIndexFileSystem: workspace: Union[str, Path], *, semantic_retrieval_backend: Any | None = None, - metadata_generator: MetadataGenerator | None = None, + metadata_generator: MetadataGenerationBackend | None = None, + metadata_provider: str = "openai", + metadata_model: str | None = None, + metadata_base_url: str | None = None, + metadata_max_text_chars: int = 24000, summary_projection_indexer: SummaryProjectionIndexer | None = None, summary_projection_index: bool = True, summary_projection_index_dir: Union[str, Path, None] = None, @@ -105,6 +109,10 @@ class PageIndexFileSystem: self.metadata = MetadataQueryEngine(self.store) self.semantic_retrieval_backend = semantic_retrieval_backend self.metadata_generator = metadata_generator + self.metadata_provider = metadata_provider + self.metadata_model = metadata_model + self.metadata_base_url = metadata_base_url + self.metadata_max_text_chars = metadata_max_text_chars self.summary_projection_indexer = summary_projection_indexer self.summary_projection_index = summary_projection_index self.summary_projection_index_dir = ( @@ -199,7 +207,12 @@ class PageIndexFileSystem: def _ensure_register_completion_defaults(self) -> None: if self.metadata_generator is None: - self.metadata_generator = OpenAIMetadataGenerator() + self.metadata_generator = MetadataGenerator( + provider=self.metadata_provider, + model=self.metadata_model, + base_url=self.metadata_base_url, + max_text_chars=self.metadata_max_text_chars, + ) if self.summary_projection_index and self.summary_projection_indexer is None: self.summary_projection_indexer = SummaryProjectionIndexer.from_provider( self.summary_projection_index_dir, diff --git a/pageindex/filesystem/hybrid_projection.py b/pageindex/filesystem/hybrid_projection.py index 30df591..e802ab3 100644 --- a/pageindex/filesystem/hybrid_projection.py +++ b/pageindex/filesystem/hybrid_projection.py @@ -1,6 +1,5 @@ from __future__ import annotations -import hashlib import json import os import re @@ -331,17 +330,22 @@ class EmbeddingCache: return [cached[text_hash] for text_hash in hashes] -class OpenAIEmbeddingClient: - def __init__(self, model: str, *, dimensions: int, timeout: float): - from openai import OpenAI - +class EmbeddingClient: + def __init__(self, *, provider: str, model: str, dimensions: int, timeout: float): + self.provider = provider.lower() self.model = model self.dimensions = dimensions - self.client = OpenAI( - api_key=os.environ.get("OPENAI_API_KEY"), - base_url=os.environ.get("OPENAI_BASE_URL") or None, - timeout=timeout, - ) + if self.provider != "openai": + raise ValueError(f"unknown embedding provider: {provider}") + from openai import OpenAI + + api_key = os.environ.get("PIFS_EMBEDDING_API_KEY") or os.environ.get("OPENAI_API_KEY") + base_url = os.environ.get("PIFS_EMBEDDING_BASE_URL") or os.environ.get("OPENAI_BASE_URL") + if not api_key: + raise ValueError( + "PIFS_EMBEDDING_API_KEY or OPENAI_API_KEY is required for PIFS embeddings" + ) + self.client = OpenAI(api_key=api_key, base_url=base_url or None, timeout=timeout) def embed(self, texts: list[str]) -> list[list[float]]: kwargs: dict[str, Any] = {"model": self.model, "input": texts} @@ -351,32 +355,13 @@ class OpenAIEmbeddingClient: return [list(item.embedding) for item in sorted(response.data, key=lambda item: item.index)] -class HashEmbeddingClient: - def __init__(self, dimensions: int = 256): - self.dimensions = dimensions - - def embed(self, texts: list[str]) -> list[list[float]]: - return [self._embed_one(text) for text in texts] - - def _embed_one(self, text: str) -> list[float]: - vector = [0.0] * self.dimensions - for term in keyword_terms(text)[:256]: - digest = hashlib.blake2b(term.encode("utf-8"), digest_size=8).digest() - bucket = int.from_bytes(digest[:4], "little") % self.dimensions - sign = 1.0 if digest[4] % 2 == 0 else -1.0 - vector[bucket] += sign - norm = sum(value * value for value in vector) ** 0.5 - if norm: - vector = [value / norm for value in vector] - return vector - - def make_embedder(provider: str, model: str, *, dimensions: int, timeout: float) -> Any: - if provider == "openai": - return OpenAIEmbeddingClient(model, dimensions=dimensions, timeout=timeout) - if provider == "hash": - return HashEmbeddingClient(dimensions=dimensions if dimensions > 0 else 256) - raise ValueError(f"unknown embedding provider: {provider}") + return EmbeddingClient( + provider=provider, + model=model, + dimensions=dimensions, + timeout=timeout, + ) def query_text_for_channel(channel: str, query: str, projection: QueryProjection) -> str: diff --git a/pageindex/filesystem/metadata_generation.py b/pageindex/filesystem/metadata_generation.py index 1935455..86b2ac6 100644 --- a/pageindex/filesystem/metadata_generation.py +++ b/pageindex/filesystem/metadata_generation.py @@ -32,7 +32,7 @@ class MetadataGenerationResult: failures: dict[str, str] = field(default_factory=dict) -class MetadataGenerator(Protocol): +class MetadataGenerationBackend(Protocol): def generate( self, request: MetadataGenerationInput, @@ -42,23 +42,31 @@ class MetadataGenerator(Protocol): ... -class OpenAIMetadataGenerator: +class MetadataGenerator: """Default product generator for retrieval metadata. This intentionally lives under pageindex.filesystem instead of benchmark paths. It uses registered text today; callers can pass PageIndex-extracted text through the same MetadataGenerationInput without changing the API. + Provider selection is an instance parameter rather than a provider-specific + public class name. """ def __init__( self, *, + provider: str | None = None, model: str | None = None, base_url: str | None = None, max_text_chars: int = 24000, ): + self.provider = (provider or os.environ.get("PIFS_METADATA_PROVIDER", "openai")).lower() self.model = model or os.environ.get("PIFS_METADATA_MODEL", "gpt-5-nano") - self.base_url = base_url if base_url is not None else os.environ.get("OPENAI_BASE_URL") + self.base_url = ( + base_url + if base_url is not None + else os.environ.get("PIFS_METADATA_BASE_URL") or os.environ.get("OPENAI_BASE_URL") + ) self.max_text_chars = max_text_chars def generate( @@ -67,9 +75,21 @@ class OpenAIMetadataGenerator: *, fields: list[str], ) -> MetadataGenerationResult: - api_key = os.environ.get("OPENAI_API_KEY") + if self.provider != "openai": + raise MetadataGenerationError(f"unsupported metadata provider: {self.provider}") + return self._generate_openai(request, fields=fields) + + def _generate_openai( + self, + request: MetadataGenerationInput, + *, + fields: list[str], + ) -> MetadataGenerationResult: + api_key = os.environ.get("PIFS_METADATA_API_KEY") or os.environ.get("OPENAI_API_KEY") if not api_key: - raise MetadataGenerationError("OPENAI_API_KEY is required for PIFS metadata generation") + raise MetadataGenerationError( + "PIFS_METADATA_API_KEY or OPENAI_API_KEY is required for PIFS metadata generation" + ) from openai import OpenAI @@ -122,7 +142,7 @@ class OpenAIMetadataGenerator: properties[field] = {"type": "string"} else: raise MetadataGenerationError( - f"OpenAIMetadataGenerator does not support generated metadata field: {field}" + f"MetadataGenerator does not support generated metadata field: {field}" ) return { "type": "json_schema", diff --git a/tests/test_metadata_generation.py b/tests/test_metadata_generation.py new file mode 100644 index 0000000..3e64a4b --- /dev/null +++ b/tests/test_metadata_generation.py @@ -0,0 +1,30 @@ +import sys +from pathlib import Path + +import pytest + +REPO_ROOT = Path(__file__).resolve().parents[1] +if str(REPO_ROOT) not in sys.path: + sys.path.insert(0, str(REPO_ROOT)) + + +def test_metadata_generator_uses_provider_parameter(): + from pageindex.filesystem.metadata_generation import ( + MetadataGenerationError, + MetadataGenerationInput, + MetadataGenerator, + ) + + generator = MetadataGenerator(provider="unsupported", model="unused") + request = MetadataGenerationInput( + file_ref="file_a", + external_id="doc_a", + title="A", + source_path="docs/a.txt", + content_type="text/plain", + source_type=None, + text="hello", + ) + + with pytest.raises(MetadataGenerationError, match="unsupported metadata provider: unsupported"): + generator.generate(request, fields=["summary"]) diff --git a/tests/test_semantic_index.py b/tests/test_semantic_index.py index a06641e..4bd9085 100644 --- a/tests/test_semantic_index.py +++ b/tests/test_semantic_index.py @@ -1,6 +1,8 @@ import sys from pathlib import Path +import pytest + REPO_ROOT = Path(__file__).resolve().parents[1] if str(REPO_ROOT) not in sys.path: sys.path.insert(0, str(REPO_ROOT)) @@ -87,3 +89,10 @@ def test_summary_projection_indexes_unified_metadata_summary(tmp_path): assert hits[0].external_id == "doc_a" assert hits[0].metadata["summary"] == "Unified metadata summary." assert hits[0].metadata["department"] == "ops" + + +def test_hash_embedding_provider_is_not_available(): + from pageindex.filesystem.hybrid_projection import make_embedder + + with pytest.raises(ValueError, match="unknown embedding provider: hash"): + make_embedder("hash", "unused", dimensions=256, timeout=1)