mirror of
https://github.com/VectifyAI/PageIndex.git
synced 2026-07-03 20:41:02 +02:00
refactor(filesystem): make pifs providers configurable
This commit is contained in:
parent
7c021a7dd0
commit
de1992def1
7 changed files with 154 additions and 61 deletions
|
|
@ -39,7 +39,7 @@ sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||||
os.environ.setdefault("LITELLM_LOCAL_MODEL_COST_MAP", "true")
|
os.environ.setdefault("LITELLM_LOCAL_MODEL_COST_MAP", "true")
|
||||||
|
|
||||||
from pageindex import PageIndexClient
|
from pageindex import PageIndexClient
|
||||||
from pageindex.filesystem import OpenAIMetadataGenerator, PageIndexFileSystem, PIFSCommandExecutor
|
from pageindex.filesystem import MetadataGenerator, PageIndexFileSystem, PIFSCommandExecutor
|
||||||
from pageindex.filesystem.agent import run_pifs_agent
|
from pageindex.filesystem.agent import run_pifs_agent
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -47,6 +47,12 @@ EXAMPLES_DIR = Path(__file__).parent
|
||||||
DOCUMENTS_DIR = EXAMPLES_DIR / "documents"
|
DOCUMENTS_DIR = EXAMPLES_DIR / "documents"
|
||||||
WORKSPACE = EXAMPLES_DIR / "pifs_workspace"
|
WORKSPACE = EXAMPLES_DIR / "pifs_workspace"
|
||||||
DEFAULT_MODEL = os.environ.get("PIFS_DEMO_MODEL", "gpt-5.4-mini")
|
DEFAULT_MODEL = os.environ.get("PIFS_DEMO_MODEL", "gpt-5.4-mini")
|
||||||
|
DEFAULT_METADATA_PROVIDER = os.environ.get("PIFS_DEMO_METADATA_PROVIDER") or os.environ.get(
|
||||||
|
"PIFS_METADATA_PROVIDER", "openai"
|
||||||
|
)
|
||||||
|
DEFAULT_EMBEDDING_PROVIDER = os.environ.get("PIFS_DEMO_EMBEDDING_PROVIDER") or os.environ.get(
|
||||||
|
"PIFS_EMBEDDING_PROVIDER", "openai"
|
||||||
|
)
|
||||||
DEFAULT_QUESTION = (
|
DEFAULT_QUESTION = (
|
||||||
"Use the PIFS workspace to find the Federal Reserve annual report. "
|
"Use the PIFS workspace to find the Federal Reserve annual report. "
|
||||||
"Which section covers supervision and regulation, and what page range "
|
"Which section covers supervision and regulation, and what page range "
|
||||||
|
|
@ -110,10 +116,15 @@ def parse_args() -> argparse.Namespace:
|
||||||
)
|
)
|
||||||
parser.add_argument("--question", default=DEFAULT_QUESTION)
|
parser.add_argument("--question", default=DEFAULT_QUESTION)
|
||||||
parser.add_argument("--model", default=DEFAULT_MODEL)
|
parser.add_argument("--model", default=DEFAULT_MODEL)
|
||||||
|
parser.add_argument(
|
||||||
|
"--metadata-provider",
|
||||||
|
default=DEFAULT_METADATA_PROVIDER,
|
||||||
|
help="Provider used for register-time metadata generation.",
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--metadata-model",
|
"--metadata-model",
|
||||||
default=os.environ.get("PIFS_METADATA_MODEL", "gpt-5-nano"),
|
default=os.environ.get("PIFS_METADATA_MODEL", "gpt-5-nano"),
|
||||||
help="OpenAI or OpenAI-compatible model used for register-time metadata.",
|
help="Model used for register-time metadata generation.",
|
||||||
)
|
)
|
||||||
parser.add_argument("--stream-mode", default="all", choices=["off", "tools", "model", "all"])
|
parser.add_argument("--stream-mode", default="all", choices=["off", "tools", "model", "all"])
|
||||||
parser.add_argument("--verbose", action="store_true")
|
parser.add_argument("--verbose", action="store_true")
|
||||||
|
|
@ -121,23 +132,40 @@ def parse_args() -> argparse.Namespace:
|
||||||
parser.add_argument("--max-seconds", type=float, default=90)
|
parser.add_argument("--max-seconds", type=float, default=90)
|
||||||
parser.add_argument("--reasoning-effort", default=None)
|
parser.add_argument("--reasoning-effort", default=None)
|
||||||
parser.add_argument("--reasoning-summary", default="auto")
|
parser.add_argument("--reasoning-summary", default="auto")
|
||||||
|
parser.add_argument(
|
||||||
|
"--embedding-provider",
|
||||||
|
default=DEFAULT_EMBEDDING_PROVIDER,
|
||||||
|
help="Provider used for register-time summary projection embeddings.",
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--embedding-model",
|
"--embedding-model",
|
||||||
default=os.environ.get("PIFS_DEMO_EMBEDDING_MODEL", "text-embedding-3-small"),
|
default=os.environ.get("PIFS_DEMO_EMBEDDING_MODEL", "text-embedding-3-small"),
|
||||||
help="OpenAI embedding model used for register-time summary projection.",
|
help="Embedding model used for register-time summary projection.",
|
||||||
)
|
)
|
||||||
parser.add_argument("--embedding-dimensions", type=int, default=256)
|
parser.add_argument("--embedding-dimensions", type=int, default=256)
|
||||||
return parser.parse_args()
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
def require_openai_environment() -> None:
|
def require_runtime_environment(*, metadata_provider: str, embedding_provider: str) -> None:
|
||||||
if os.environ.get("OPENAI_API_KEY"):
|
metadata_provider = metadata_provider.lower()
|
||||||
return
|
embedding_provider = embedding_provider.lower()
|
||||||
raise RuntimeError(
|
missing: list[str] = []
|
||||||
"OPENAI_API_KEY is required for this demo: register() generates real "
|
if not os.environ.get("OPENAI_API_KEY"):
|
||||||
"PIFS metadata and the agent uses the OpenAI Agents SDK. Source your "
|
missing.append("OPENAI_API_KEY for the OpenAI Agents SDK demo agent")
|
||||||
".env or export OPENAI_API_KEY before running."
|
if metadata_provider == "openai" and not (
|
||||||
)
|
os.environ.get("PIFS_METADATA_API_KEY") or os.environ.get("OPENAI_API_KEY")
|
||||||
|
):
|
||||||
|
missing.append("PIFS_METADATA_API_KEY or OPENAI_API_KEY for metadata generation")
|
||||||
|
if embedding_provider == "openai" and not (
|
||||||
|
os.environ.get("PIFS_EMBEDDING_API_KEY") or os.environ.get("OPENAI_API_KEY")
|
||||||
|
):
|
||||||
|
missing.append("PIFS_EMBEDDING_API_KEY or OPENAI_API_KEY for summary embeddings")
|
||||||
|
if missing:
|
||||||
|
raise RuntimeError(
|
||||||
|
"Missing required environment variable(s): "
|
||||||
|
+ "; ".join(missing)
|
||||||
|
+ ". Source your .env or export the required key before running."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def discover_cached_documents(documents_dir: Path) -> list[Path]:
|
def discover_cached_documents(documents_dir: Path) -> list[Path]:
|
||||||
|
|
@ -294,6 +322,7 @@ def backfill_registered_metadata_values(filesystem: PageIndexFileSystem, file_re
|
||||||
def configure_summary_projection_backend(
|
def configure_summary_projection_backend(
|
||||||
filesystem: PageIndexFileSystem,
|
filesystem: PageIndexFileSystem,
|
||||||
*,
|
*,
|
||||||
|
embedding_provider: str,
|
||||||
embedding_model: str,
|
embedding_model: str,
|
||||||
embedding_dimensions: int,
|
embedding_dimensions: int,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
|
@ -301,7 +330,7 @@ def configure_summary_projection_backend(
|
||||||
return
|
return
|
||||||
filesystem.configure_hybrid_projection_retrieval(
|
filesystem.configure_hybrid_projection_retrieval(
|
||||||
filesystem.summary_projection_index_dir,
|
filesystem.summary_projection_index_dir,
|
||||||
embedding_provider="openai",
|
embedding_provider=embedding_provider,
|
||||||
embedding_model=embedding_model,
|
embedding_model=embedding_model,
|
||||||
embedding_dimensions=embedding_dimensions,
|
embedding_dimensions=embedding_dimensions,
|
||||||
)
|
)
|
||||||
|
|
@ -690,7 +719,10 @@ def run_smoke_commands(
|
||||||
|
|
||||||
def main() -> None:
|
def main() -> None:
|
||||||
args = parse_args()
|
args = parse_args()
|
||||||
require_openai_environment()
|
require_runtime_environment(
|
||||||
|
metadata_provider=args.metadata_provider,
|
||||||
|
embedding_provider=args.embedding_provider,
|
||||||
|
)
|
||||||
workspace = args.workspace.expanduser()
|
workspace = args.workspace.expanduser()
|
||||||
documents_dir = args.documents_dir.expanduser()
|
documents_dir = args.documents_dir.expanduser()
|
||||||
if args.reset and workspace.exists():
|
if args.reset and workspace.exists():
|
||||||
|
|
@ -705,8 +737,11 @@ def main() -> None:
|
||||||
|
|
||||||
filesystem = PageIndexFileSystem(
|
filesystem = PageIndexFileSystem(
|
||||||
workspace,
|
workspace,
|
||||||
metadata_generator=OpenAIMetadataGenerator(model=args.metadata_model),
|
metadata_generator=MetadataGenerator(
|
||||||
summary_projection_embedding_provider="openai",
|
provider=args.metadata_provider,
|
||||||
|
model=args.metadata_model,
|
||||||
|
),
|
||||||
|
summary_projection_embedding_provider=args.embedding_provider,
|
||||||
summary_projection_embedding_model=args.embedding_model,
|
summary_projection_embedding_model=args.embedding_model,
|
||||||
summary_projection_embedding_dimensions=args.embedding_dimensions,
|
summary_projection_embedding_dimensions=args.embedding_dimensions,
|
||||||
)
|
)
|
||||||
|
|
@ -718,6 +753,7 @@ def main() -> None:
|
||||||
registered = register_documents(filesystem, documents, documents_dir=documents_dir)
|
registered = register_documents(filesystem, documents, documents_dir=documents_dir)
|
||||||
configure_summary_projection_backend(
|
configure_summary_projection_backend(
|
||||||
filesystem,
|
filesystem,
|
||||||
|
embedding_provider=args.embedding_provider,
|
||||||
embedding_model=args.embedding_model,
|
embedding_model=args.embedding_model,
|
||||||
embedding_dimensions=args.embedding_dimensions,
|
embedding_dimensions=args.embedding_dimensions,
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -2,11 +2,11 @@ from .commands import PIFSCommandExecutor
|
||||||
from .core import PageIndexFileSystem
|
from .core import PageIndexFileSystem
|
||||||
from .hybrid_projection import HybridProjectionSearchBackend
|
from .hybrid_projection import HybridProjectionSearchBackend
|
||||||
from .metadata_generation import (
|
from .metadata_generation import (
|
||||||
|
MetadataGenerationBackend,
|
||||||
MetadataGenerationError,
|
MetadataGenerationError,
|
||||||
MetadataGenerationInput,
|
MetadataGenerationInput,
|
||||||
MetadataGenerationResult,
|
MetadataGenerationResult,
|
||||||
MetadataGenerator,
|
MetadataGenerator,
|
||||||
OpenAIMetadataGenerator,
|
|
||||||
)
|
)
|
||||||
from .projection_indexing import SummaryProjectionIndexer
|
from .projection_indexing import SummaryProjectionIndexer
|
||||||
from .semantic_index import (
|
from .semantic_index import (
|
||||||
|
|
@ -20,11 +20,11 @@ from .types import OpenResult, SearchResult
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"OpenResult",
|
"OpenResult",
|
||||||
"HybridProjectionSearchBackend",
|
"HybridProjectionSearchBackend",
|
||||||
|
"MetadataGenerationBackend",
|
||||||
"MetadataGenerationError",
|
"MetadataGenerationError",
|
||||||
"MetadataGenerationInput",
|
"MetadataGenerationInput",
|
||||||
"MetadataGenerationResult",
|
"MetadataGenerationResult",
|
||||||
"MetadataGenerator",
|
"MetadataGenerator",
|
||||||
"OpenAIMetadataGenerator",
|
|
||||||
"PIFSCommandExecutor",
|
"PIFSCommandExecutor",
|
||||||
"PageIndexFileSystem",
|
"PageIndexFileSystem",
|
||||||
"RebuildableSemanticIndex",
|
"RebuildableSemanticIndex",
|
||||||
|
|
|
||||||
|
|
@ -9,11 +9,11 @@ from urllib.parse import unquote, urlparse
|
||||||
from ..client import PageIndexClient
|
from ..client import PageIndexClient
|
||||||
from .metadata import MetadataQueryEngine
|
from .metadata import MetadataQueryEngine
|
||||||
from .metadata_generation import (
|
from .metadata_generation import (
|
||||||
|
MetadataGenerationBackend,
|
||||||
MetadataGenerationError,
|
MetadataGenerationError,
|
||||||
MetadataGenerationInput,
|
MetadataGenerationInput,
|
||||||
MetadataGenerationResult,
|
MetadataGenerationResult,
|
||||||
MetadataGenerator,
|
MetadataGenerator,
|
||||||
OpenAIMetadataGenerator,
|
|
||||||
)
|
)
|
||||||
from .projection_indexing import SummaryProjectionIndexer
|
from .projection_indexing import SummaryProjectionIndexer
|
||||||
from .semantic_folder_policy import (
|
from .semantic_folder_policy import (
|
||||||
|
|
@ -91,7 +91,11 @@ class PageIndexFileSystem:
|
||||||
workspace: Union[str, Path],
|
workspace: Union[str, Path],
|
||||||
*,
|
*,
|
||||||
semantic_retrieval_backend: Any | None = None,
|
semantic_retrieval_backend: Any | None = None,
|
||||||
metadata_generator: MetadataGenerator | None = None,
|
metadata_generator: MetadataGenerationBackend | None = None,
|
||||||
|
metadata_provider: str = "openai",
|
||||||
|
metadata_model: str | None = None,
|
||||||
|
metadata_base_url: str | None = None,
|
||||||
|
metadata_max_text_chars: int = 24000,
|
||||||
summary_projection_indexer: SummaryProjectionIndexer | None = None,
|
summary_projection_indexer: SummaryProjectionIndexer | None = None,
|
||||||
summary_projection_index: bool = True,
|
summary_projection_index: bool = True,
|
||||||
summary_projection_index_dir: Union[str, Path, None] = None,
|
summary_projection_index_dir: Union[str, Path, None] = None,
|
||||||
|
|
@ -105,6 +109,10 @@ class PageIndexFileSystem:
|
||||||
self.metadata = MetadataQueryEngine(self.store)
|
self.metadata = MetadataQueryEngine(self.store)
|
||||||
self.semantic_retrieval_backend = semantic_retrieval_backend
|
self.semantic_retrieval_backend = semantic_retrieval_backend
|
||||||
self.metadata_generator = metadata_generator
|
self.metadata_generator = metadata_generator
|
||||||
|
self.metadata_provider = metadata_provider
|
||||||
|
self.metadata_model = metadata_model
|
||||||
|
self.metadata_base_url = metadata_base_url
|
||||||
|
self.metadata_max_text_chars = metadata_max_text_chars
|
||||||
self.summary_projection_indexer = summary_projection_indexer
|
self.summary_projection_indexer = summary_projection_indexer
|
||||||
self.summary_projection_index = summary_projection_index
|
self.summary_projection_index = summary_projection_index
|
||||||
self.summary_projection_index_dir = (
|
self.summary_projection_index_dir = (
|
||||||
|
|
@ -199,7 +207,12 @@ class PageIndexFileSystem:
|
||||||
|
|
||||||
def _ensure_register_completion_defaults(self) -> None:
|
def _ensure_register_completion_defaults(self) -> None:
|
||||||
if self.metadata_generator is None:
|
if self.metadata_generator is None:
|
||||||
self.metadata_generator = OpenAIMetadataGenerator()
|
self.metadata_generator = MetadataGenerator(
|
||||||
|
provider=self.metadata_provider,
|
||||||
|
model=self.metadata_model,
|
||||||
|
base_url=self.metadata_base_url,
|
||||||
|
max_text_chars=self.metadata_max_text_chars,
|
||||||
|
)
|
||||||
if self.summary_projection_index and self.summary_projection_indexer is None:
|
if self.summary_projection_index and self.summary_projection_indexer is None:
|
||||||
self.summary_projection_indexer = SummaryProjectionIndexer.from_provider(
|
self.summary_projection_indexer = SummaryProjectionIndexer.from_provider(
|
||||||
self.summary_projection_index_dir,
|
self.summary_projection_index_dir,
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,5 @@
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import hashlib
|
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
|
|
@ -331,17 +330,22 @@ class EmbeddingCache:
|
||||||
return [cached[text_hash] for text_hash in hashes]
|
return [cached[text_hash] for text_hash in hashes]
|
||||||
|
|
||||||
|
|
||||||
class OpenAIEmbeddingClient:
|
class EmbeddingClient:
|
||||||
def __init__(self, model: str, *, dimensions: int, timeout: float):
|
def __init__(self, *, provider: str, model: str, dimensions: int, timeout: float):
|
||||||
from openai import OpenAI
|
self.provider = provider.lower()
|
||||||
|
|
||||||
self.model = model
|
self.model = model
|
||||||
self.dimensions = dimensions
|
self.dimensions = dimensions
|
||||||
self.client = OpenAI(
|
if self.provider != "openai":
|
||||||
api_key=os.environ.get("OPENAI_API_KEY"),
|
raise ValueError(f"unknown embedding provider: {provider}")
|
||||||
base_url=os.environ.get("OPENAI_BASE_URL") or None,
|
from openai import OpenAI
|
||||||
timeout=timeout,
|
|
||||||
)
|
api_key = os.environ.get("PIFS_EMBEDDING_API_KEY") or os.environ.get("OPENAI_API_KEY")
|
||||||
|
base_url = os.environ.get("PIFS_EMBEDDING_BASE_URL") or os.environ.get("OPENAI_BASE_URL")
|
||||||
|
if not api_key:
|
||||||
|
raise ValueError(
|
||||||
|
"PIFS_EMBEDDING_API_KEY or OPENAI_API_KEY is required for PIFS embeddings"
|
||||||
|
)
|
||||||
|
self.client = OpenAI(api_key=api_key, base_url=base_url or None, timeout=timeout)
|
||||||
|
|
||||||
def embed(self, texts: list[str]) -> list[list[float]]:
|
def embed(self, texts: list[str]) -> list[list[float]]:
|
||||||
kwargs: dict[str, Any] = {"model": self.model, "input": texts}
|
kwargs: dict[str, Any] = {"model": self.model, "input": texts}
|
||||||
|
|
@ -351,32 +355,13 @@ class OpenAIEmbeddingClient:
|
||||||
return [list(item.embedding) for item in sorted(response.data, key=lambda item: item.index)]
|
return [list(item.embedding) for item in sorted(response.data, key=lambda item: item.index)]
|
||||||
|
|
||||||
|
|
||||||
class HashEmbeddingClient:
|
|
||||||
def __init__(self, dimensions: int = 256):
|
|
||||||
self.dimensions = dimensions
|
|
||||||
|
|
||||||
def embed(self, texts: list[str]) -> list[list[float]]:
|
|
||||||
return [self._embed_one(text) for text in texts]
|
|
||||||
|
|
||||||
def _embed_one(self, text: str) -> list[float]:
|
|
||||||
vector = [0.0] * self.dimensions
|
|
||||||
for term in keyword_terms(text)[:256]:
|
|
||||||
digest = hashlib.blake2b(term.encode("utf-8"), digest_size=8).digest()
|
|
||||||
bucket = int.from_bytes(digest[:4], "little") % self.dimensions
|
|
||||||
sign = 1.0 if digest[4] % 2 == 0 else -1.0
|
|
||||||
vector[bucket] += sign
|
|
||||||
norm = sum(value * value for value in vector) ** 0.5
|
|
||||||
if norm:
|
|
||||||
vector = [value / norm for value in vector]
|
|
||||||
return vector
|
|
||||||
|
|
||||||
|
|
||||||
def make_embedder(provider: str, model: str, *, dimensions: int, timeout: float) -> Any:
|
def make_embedder(provider: str, model: str, *, dimensions: int, timeout: float) -> Any:
|
||||||
if provider == "openai":
|
return EmbeddingClient(
|
||||||
return OpenAIEmbeddingClient(model, dimensions=dimensions, timeout=timeout)
|
provider=provider,
|
||||||
if provider == "hash":
|
model=model,
|
||||||
return HashEmbeddingClient(dimensions=dimensions if dimensions > 0 else 256)
|
dimensions=dimensions,
|
||||||
raise ValueError(f"unknown embedding provider: {provider}")
|
timeout=timeout,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def query_text_for_channel(channel: str, query: str, projection: QueryProjection) -> str:
|
def query_text_for_channel(channel: str, query: str, projection: QueryProjection) -> str:
|
||||||
|
|
|
||||||
|
|
@ -32,7 +32,7 @@ class MetadataGenerationResult:
|
||||||
failures: dict[str, str] = field(default_factory=dict)
|
failures: dict[str, str] = field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
class MetadataGenerator(Protocol):
|
class MetadataGenerationBackend(Protocol):
|
||||||
def generate(
|
def generate(
|
||||||
self,
|
self,
|
||||||
request: MetadataGenerationInput,
|
request: MetadataGenerationInput,
|
||||||
|
|
@ -42,23 +42,31 @@ class MetadataGenerator(Protocol):
|
||||||
...
|
...
|
||||||
|
|
||||||
|
|
||||||
class OpenAIMetadataGenerator:
|
class MetadataGenerator:
|
||||||
"""Default product generator for retrieval metadata.
|
"""Default product generator for retrieval metadata.
|
||||||
|
|
||||||
This intentionally lives under pageindex.filesystem instead of benchmark
|
This intentionally lives under pageindex.filesystem instead of benchmark
|
||||||
paths. It uses registered text today; callers can pass PageIndex-extracted
|
paths. It uses registered text today; callers can pass PageIndex-extracted
|
||||||
text through the same MetadataGenerationInput without changing the API.
|
text through the same MetadataGenerationInput without changing the API.
|
||||||
|
Provider selection is an instance parameter rather than a provider-specific
|
||||||
|
public class name.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
*,
|
*,
|
||||||
|
provider: str | None = None,
|
||||||
model: str | None = None,
|
model: str | None = None,
|
||||||
base_url: str | None = None,
|
base_url: str | None = None,
|
||||||
max_text_chars: int = 24000,
|
max_text_chars: int = 24000,
|
||||||
):
|
):
|
||||||
|
self.provider = (provider or os.environ.get("PIFS_METADATA_PROVIDER", "openai")).lower()
|
||||||
self.model = model or os.environ.get("PIFS_METADATA_MODEL", "gpt-5-nano")
|
self.model = model or os.environ.get("PIFS_METADATA_MODEL", "gpt-5-nano")
|
||||||
self.base_url = base_url if base_url is not None else os.environ.get("OPENAI_BASE_URL")
|
self.base_url = (
|
||||||
|
base_url
|
||||||
|
if base_url is not None
|
||||||
|
else os.environ.get("PIFS_METADATA_BASE_URL") or os.environ.get("OPENAI_BASE_URL")
|
||||||
|
)
|
||||||
self.max_text_chars = max_text_chars
|
self.max_text_chars = max_text_chars
|
||||||
|
|
||||||
def generate(
|
def generate(
|
||||||
|
|
@ -67,9 +75,21 @@ class OpenAIMetadataGenerator:
|
||||||
*,
|
*,
|
||||||
fields: list[str],
|
fields: list[str],
|
||||||
) -> MetadataGenerationResult:
|
) -> MetadataGenerationResult:
|
||||||
api_key = os.environ.get("OPENAI_API_KEY")
|
if self.provider != "openai":
|
||||||
|
raise MetadataGenerationError(f"unsupported metadata provider: {self.provider}")
|
||||||
|
return self._generate_openai(request, fields=fields)
|
||||||
|
|
||||||
|
def _generate_openai(
|
||||||
|
self,
|
||||||
|
request: MetadataGenerationInput,
|
||||||
|
*,
|
||||||
|
fields: list[str],
|
||||||
|
) -> MetadataGenerationResult:
|
||||||
|
api_key = os.environ.get("PIFS_METADATA_API_KEY") or os.environ.get("OPENAI_API_KEY")
|
||||||
if not api_key:
|
if not api_key:
|
||||||
raise MetadataGenerationError("OPENAI_API_KEY is required for PIFS metadata generation")
|
raise MetadataGenerationError(
|
||||||
|
"PIFS_METADATA_API_KEY or OPENAI_API_KEY is required for PIFS metadata generation"
|
||||||
|
)
|
||||||
|
|
||||||
from openai import OpenAI
|
from openai import OpenAI
|
||||||
|
|
||||||
|
|
@ -122,7 +142,7 @@ class OpenAIMetadataGenerator:
|
||||||
properties[field] = {"type": "string"}
|
properties[field] = {"type": "string"}
|
||||||
else:
|
else:
|
||||||
raise MetadataGenerationError(
|
raise MetadataGenerationError(
|
||||||
f"OpenAIMetadataGenerator does not support generated metadata field: {field}"
|
f"MetadataGenerator does not support generated metadata field: {field}"
|
||||||
)
|
)
|
||||||
return {
|
return {
|
||||||
"type": "json_schema",
|
"type": "json_schema",
|
||||||
|
|
|
||||||
30
tests/test_metadata_generation.py
Normal file
30
tests/test_metadata_generation.py
Normal file
|
|
@ -0,0 +1,30 @@
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
REPO_ROOT = Path(__file__).resolve().parents[1]
|
||||||
|
if str(REPO_ROOT) not in sys.path:
|
||||||
|
sys.path.insert(0, str(REPO_ROOT))
|
||||||
|
|
||||||
|
|
||||||
|
def test_metadata_generator_uses_provider_parameter():
|
||||||
|
from pageindex.filesystem.metadata_generation import (
|
||||||
|
MetadataGenerationError,
|
||||||
|
MetadataGenerationInput,
|
||||||
|
MetadataGenerator,
|
||||||
|
)
|
||||||
|
|
||||||
|
generator = MetadataGenerator(provider="unsupported", model="unused")
|
||||||
|
request = MetadataGenerationInput(
|
||||||
|
file_ref="file_a",
|
||||||
|
external_id="doc_a",
|
||||||
|
title="A",
|
||||||
|
source_path="docs/a.txt",
|
||||||
|
content_type="text/plain",
|
||||||
|
source_type=None,
|
||||||
|
text="hello",
|
||||||
|
)
|
||||||
|
|
||||||
|
with pytest.raises(MetadataGenerationError, match="unsupported metadata provider: unsupported"):
|
||||||
|
generator.generate(request, fields=["summary"])
|
||||||
|
|
@ -1,6 +1,8 @@
|
||||||
import sys
|
import sys
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
REPO_ROOT = Path(__file__).resolve().parents[1]
|
REPO_ROOT = Path(__file__).resolve().parents[1]
|
||||||
if str(REPO_ROOT) not in sys.path:
|
if str(REPO_ROOT) not in sys.path:
|
||||||
sys.path.insert(0, str(REPO_ROOT))
|
sys.path.insert(0, str(REPO_ROOT))
|
||||||
|
|
@ -87,3 +89,10 @@ def test_summary_projection_indexes_unified_metadata_summary(tmp_path):
|
||||||
assert hits[0].external_id == "doc_a"
|
assert hits[0].external_id == "doc_a"
|
||||||
assert hits[0].metadata["summary"] == "Unified metadata summary."
|
assert hits[0].metadata["summary"] == "Unified metadata summary."
|
||||||
assert hits[0].metadata["department"] == "ops"
|
assert hits[0].metadata["department"] == "ops"
|
||||||
|
|
||||||
|
|
||||||
|
def test_hash_embedding_provider_is_not_available():
|
||||||
|
from pageindex.filesystem.hybrid_projection import make_embedder
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="unknown embedding provider: hash"):
|
||||||
|
make_embedder("hash", "unused", dimensions=256, timeout=1)
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue