PageIndex/pageindex/filesystem/hybrid_projection.py
Bukely_ ad45f96dfa fix(filesystem): use summary projection for default semantic search
Route default semantic search to the summary projection when summary is the only populated semantic channel.
2026-05-27 02:12:34 +08:00

649 lines
21 KiB
Python

from __future__ import annotations
import json
import os
import re
import sqlite3
import struct
import time
from dataclasses import dataclass
from pathlib import Path
from typing import Any
from .semantic_index import SQLiteVecSemanticIndex, SemanticIndexError, SemanticSearchResult
# Maps each semantic channel to the base name of its sqlite-vec index file;
# the search backend appends ".sqlite" and resolves it under its index dir.
INDEX_BY_CHANNEL = {
    "metadata": "metadata_composite_vector",
    "summary": "summary_only_vector",
    "entity": "entity_vectors",
    "constraint": "constraint_vectors",
    "relation": "relation_vectors",
}

# Channels fused together by the default hybrid search.
HYBRID_ENTITY_RELATION_CHANNELS = (
    "metadata",
    "entity",
    "constraint",
    "relation",
)

# Channels that may be queried individually through search_channel().
SEMANTIC_TOOL_CHANNELS = (
    "summary",
    "entity",
    "relation",
)

# Per-channel multipliers applied to each reciprocal-rank contribution when
# hybrid hits are aggregated.
HYBRID_ENTITY_RELATION_WEIGHTS = {
    "metadata": 0.25,
    "entity": 0.25,
    "relation": 0.30,
    "constraint": 0.20,
}
@dataclass(frozen=True)
class QueryProjection:
    """Immutable, per-channel decomposition of a search query.

    Each list supplies the text embedded for the matching semantic channel
    (see ``query_text_for_channel``).
    """

    # Identifier-like and keyword terms extracted from the query.
    entities: list[str]
    # Relation strings; joined with newlines for the relation channel.
    relations: list[str]
    # Constraint phrases and numeric terms extracted from the query.
    constraints: list[str]
    # Coarse label for the kind of answer the query expects; empty if unset.
    expected_answer_type: str = ""
@dataclass(frozen=True)
class HybridProjectionCandidate:
    """Immutable search hit produced by the projection search backend."""

    # Identifier of the matched document (external id, else file ref).
    document_id: str
    # Ranking score; larger values rank first.
    score: float
    # One {"channel", "rank", "distance"} entry per contributing channel hit.
    sources: list[dict[str, Any]]
    # The following four fields are copied from the underlying index result.
    source_type: str
    source_path: str
    title: str
    metadata: dict[str, Any]
    # Short human-readable description of why the document matched.
    snippet: str
class HybridProjectionSearchBackend:
    """Hybrid entity/relation/vector retrieval over rebuildable projection indexes.

    The SQLite catalog remains the source of truth. This backend only reads
    external sqlite-vec projection indexes and returns candidate document ids
    for the catalog to resolve and filter.
    """

    def __init__(
        self,
        index_dir: str | Path,
        *,
        embedder: Any,
        embedding_provider: str,
        embedding_model: str,
        embedding_dimensions: int = 256,
        embedding_cache_path: str | Path | None = None,
        per_channel_limit: int = 100,
        fetch_multiplier: int = 100,
    ) -> None:
        """Bind one index per channel under ``index_dir`` and open the embedding cache.

        Args:
            index_dir: Directory holding the ``<index_name>.sqlite`` files
                named in ``INDEX_BY_CHANNEL``; ``~`` is expanded.
            embedder: Object used to embed query text on cache misses.
            embedding_provider: Provider label; part of the cache key.
            embedding_model: Model name; combined with
                ``embedding_dimensions`` to form the cache's model key.
            embedding_dimensions: Requested vector size.
            embedding_cache_path: Cache database location; defaults to
                ``<index_dir>/embedding_cache.sqlite``.
            per_channel_limit: Lower bound on the number of hits requested
                from each channel during hybrid search.
            fetch_multiplier: Forwarded unchanged to each index's ``search``.
        """
        self.index_dir = Path(index_dir).expanduser()
        self.embedder = embedder
        self.embedding_provider = embedding_provider
        self.embedding_model = embedding_model
        self.embedding_dimensions = embedding_dimensions
        self.cache_model = embedding_cache_model_key(embedding_model, embedding_dimensions)
        self.embedding_cache = EmbeddingCache(
            Path(embedding_cache_path).expanduser()
            if embedding_cache_path is not None
            else self.index_dir / "embedding_cache.sqlite"
        )
        self.per_channel_limit = per_channel_limit
        self.fetch_multiplier = fetch_multiplier
        self.indexes = {
            channel: SQLiteVecSemanticIndex(self.index_dir / f"{index_name}.sqlite")
            for channel, index_name in INDEX_BY_CHANNEL.items()
        }

    @classmethod
    def from_provider(
        cls,
        index_dir: str | Path,
        *,
        embedding_provider: str = "openai",
        embedding_model: str = "text-embedding-3-small",
        embedding_dimensions: int = 256,
        embedding_timeout: float = 60,
        **kwargs: Any,
    ) -> "HybridProjectionSearchBackend":
        """Build a backend whose embedder is created by ``make_embedder``.

        Extra keyword arguments are forwarded to ``__init__``.
        """
        return cls(
            index_dir,
            embedder=make_embedder(
                embedding_provider,
                embedding_model,
                dimensions=embedding_dimensions,
                timeout=embedding_timeout,
            ),
            embedding_provider=embedding_provider,
            embedding_model=embedding_model,
            embedding_dimensions=embedding_dimensions,
            **kwargs,
        )

    def search(
        self,
        query: str,
        *,
        limit: int = 10,
        filters: dict[str, Any] | None = None,
    ) -> list[HybridProjectionCandidate]:
        """Run the default semantic search and return at most ``limit`` candidates.

        Populated hybrid channels (metadata/entity/constraint/relation) are
        searched and fused. When none of them is populated but the summary
        projection is, the query is routed to the summary channel instead.

        Returns an empty list for a blank query, a non-positive ``limit``, or
        when no usable channel exists.
        """
        # A non-positive limit can never yield results; bail out before any
        # embedding or index work (a negative value would otherwise make the
        # final slice silently drop results from the tail).
        if limit <= 0:
            return []
        query = normalize_text(query)
        if not query:
            return []
        channels = tuple(
            channel
            for channel in HYBRID_ENTITY_RELATION_CHANNELS
            if self._channel_document_count(channel) > 0
        )
        if not channels:
            if self._channel_document_count("summary") > 0:
                return self.search_channel("summary", query, limit=limit, filters=filters)
            return []
        projection = heuristic_query_projection(query)
        channel_hits = self._search_channels(
            query=query,
            projection=projection,
            # Over-fetch per channel so fusion has enough candidates to rank.
            limit=max(limit, self.per_channel_limit),
            filters=filters,
            channels=channels,
        )
        return aggregate_hybrid_entity_relation(channel_hits, projection)[:limit]

    def search_channel(
        self,
        channel: str,
        query: str,
        *,
        limit: int = 10,
        filters: dict[str, Any] | None = None,
    ) -> list[HybridProjectionCandidate]:
        """Search a single semantic channel.

        Raises:
            ValueError: If ``channel`` is not one of ``SEMANTIC_TOOL_CHANNELS``.

        Returns an empty list when the channel has no documents, the query is
        blank, or ``limit`` is not positive.
        """
        if channel not in SEMANTIC_TOOL_CHANNELS:
            raise ValueError(f"unsupported semantic channel: {channel}")
        if limit <= 0:
            return []
        # Probe only the requested channel rather than every tool channel.
        if self._channel_document_count(channel) <= 0:
            return []
        query = normalize_text(query)
        if not query:
            return []
        projection = heuristic_query_projection(query)
        vector = self.embedding_cache.embed_texts(
            [query_text_for_channel(channel, query, projection)],
            provider=self.embedding_provider,
            model=self.cache_model,
            embedder=self.embedder,
            batch_size=1,
        )[0]
        results = self.indexes[channel].search(
            vector,
            limit=limit,
            filters=filters,
            fetch_multiplier=self.fetch_multiplier,
        )
        return rank_single_semantic_channel(channel, results)

    def available_channels(self) -> tuple[str, ...]:
        """Return the tool channels whose index currently holds documents."""
        return tuple(
            channel
            for channel in SEMANTIC_TOOL_CHANNELS
            if self._channel_document_count(channel) > 0
        )

    def info(self) -> dict[str, Any]:
        """Describe the backend configuration and the state of every channel index."""
        return {
            "index_dir": str(self.index_dir),
            "embedding_provider": self.embedding_provider,
            "embedding_model": self.embedding_model,
            "embedding_dimensions": self.embedding_dimensions,
            "strategy": "hybrid_entity_relation_vector",
            "available_channels": list(self.available_channels()),
            "channels": {
                channel: self._safe_channel_info(channel)
                for channel in self.indexes
            },
        }

    def _channel_document_count(self, channel: str) -> int:
        """Return the channel's document count, or 0 when the index is unusable."""
        info = self._safe_channel_info(channel)
        if not info.get("available"):
            return 0
        return int(info.get("document_count") or 0)

    def _safe_channel_info(self, channel: str) -> dict[str, Any]:
        """Return index info for ``channel`` without raising on a broken index.

        A missing file or a failing ``info()`` call yields a dict with
        ``available=False``, ``document_count=0`` and an ``error`` message.
        Otherwise the index's own info is returned with ``available`` set to
        whether it holds at least one document.
        """
        index = self.indexes[channel]
        if not index.db_path.exists():
            return {
                "db_path": str(index.db_path),
                "available": False,
                "document_count": 0,
                "error": "index file is missing",
            }
        try:
            info = index.info()
        except (OSError, sqlite3.Error, SemanticIndexError) as exc:
            return {
                "db_path": str(index.db_path),
                "available": False,
                "document_count": 0,
                "error": str(exc),
            }
        return {**info, "available": int(info.get("document_count") or 0) > 0}

    def _search_channels(
        self,
        *,
        query: str,
        projection: QueryProjection,
        limit: int,
        filters: dict[str, Any] | None,
        channels: tuple[str, ...],
    ) -> dict[str, list[SemanticSearchResult]]:
        """Embed the per-channel query texts and search each channel's index.

        Returns a mapping of channel name to that channel's raw index hits,
        in the order given by ``channels``.
        """
        query_texts = {
            channel: query_text_for_channel(channel, query, projection)
            for channel in channels
        }
        vectors = self.embedding_cache.embed_texts(
            [query_texts[channel] for channel in channels],
            provider=self.embedding_provider,
            model=self.cache_model,
            embedder=self.embedder,
            # Embed every uncached channel text in one request instead of one
            # round trip per channel.
            batch_size=max(1, len(channels)),
        )
        return {
            channel: self.indexes[channel].search(
                vector,
                limit=limit,
                filters=filters,
                fetch_multiplier=self.fetch_multiplier,
            )
            for channel, vector in zip(channels, vectors)
        }
class EmbeddingCache:
    """SQLite-backed cache of embedding vectors.

    Rows are keyed by ``(provider, model, text_hash)``. Vectors are written
    as packed little-endian float32 blobs; the ``vector_json`` column is
    written empty and only consulted when a row has no blob.
    """

    def __init__(self, db_path: Path):
        """Create the parent directory and the cache table if they are missing."""
        self.db_path = db_path
        self.db_path.parent.mkdir(parents=True, exist_ok=True)
        conn = self.connect()
        try:
            # ``with conn`` commits on success / rolls back on error, but it
            # does not close the connection, hence the explicit ``finally``.
            with conn:
                conn.execute(
                    """
                    CREATE TABLE IF NOT EXISTS embedding_cache (
                        provider TEXT NOT NULL,
                        model TEXT NOT NULL,
                        text_hash TEXT NOT NULL,
                        dimension INTEGER NOT NULL,
                        vector_blob BLOB,
                        vector_json TEXT,
                        created_at TEXT DEFAULT CURRENT_TIMESTAMP,
                        PRIMARY KEY(provider, model, text_hash)
                    )
                    """
                )
        finally:
            conn.close()

    def connect(self) -> sqlite3.Connection:
        """Open a new connection to the cache database with name-addressable rows.

        The caller owns the returned connection and must close it.
        """
        conn = sqlite3.connect(self.db_path)
        conn.row_factory = sqlite3.Row
        return conn

    def embed_texts(
        self,
        texts: list[str],
        *,
        provider: str,
        model: str,
        embedder: Any,
        batch_size: int,
    ) -> list[list[float]]:
        """Return one vector per input text, embedding only uncached texts.

        Args:
            texts: Texts to embed; duplicates are allowed.
            provider: Provider part of the cache key.
            model: Model part of the cache key.
            embedder: Passed to ``embed_with_retry`` for cache misses.
            batch_size: Maximum texts per embedder call (values below 1 are
                treated as 1).

        Returns:
            Vectors aligned with ``texts``. Identical texts share one vector.
        """
        hashes = [SQLiteVecSemanticIndex.text_hash(text) for text in texts]
        cached: dict[str, list[float]] = {}
        conn = self.connect()
        try:
            for text_hash in sorted(set(hashes)):
                row = conn.execute(
                    """
                    SELECT vector_blob, vector_json
                    FROM embedding_cache
                    WHERE provider = ? AND model = ? AND text_hash = ?
                    """,
                    (provider, model, text_hash),
                ).fetchone()
                if row is not None:
                    cached[text_hash] = decode_vector(row["vector_blob"], row["vector_json"])
        finally:
            conn.close()

        # Keep only the first position of each uncached hash so a text that
        # appears several times is sent to the embedder exactly once.
        missing_positions: list[int] = []
        queued: set[str] = set()
        for index, text_hash in enumerate(hashes):
            if text_hash in cached or text_hash in queued:
                continue
            queued.add(text_hash)
            missing_positions.append(index)

        step = max(1, batch_size)
        for start in range(0, len(missing_positions), step):
            positions = missing_positions[start : start + step]
            batch_texts = [texts[index] for index in positions]
            vectors = embed_with_retry(embedder, batch_texts)
            conn = self.connect()
            try:
                with conn:
                    conn.executemany(
                        """
                        INSERT OR REPLACE INTO embedding_cache(
                            provider, model, text_hash, dimension, vector_blob, vector_json
                        )
                        VALUES (?, ?, ?, ?, ?, '')
                        """,
                        [
                            (
                                provider,
                                model,
                                hashes[index],
                                len(vector),
                                encode_vector(vector),
                            )
                            for index, vector in zip(positions, vectors)
                        ],
                    )
            finally:
                conn.close()
            for index, vector in zip(positions, vectors):
                cached[hashes[index]] = vector
        return [cached[text_hash] for text_hash in hashes]
class EmbeddingClient:
    """Thin wrapper around the OpenAI embeddings endpoint."""

    def __init__(self, *, provider: str, model: str, dimensions: int, timeout: float):
        """Validate the provider and build the OpenAI client from the environment.

        The API key is read from ``PIFS_EMBEDDING_API_KEY`` (falling back to
        ``OPENAI_API_KEY``) and the base URL from ``PIFS_EMBEDDING_BASE_URL``
        (falling back to ``OPENAI_BASE_URL``).

        Raises:
            ValueError: For any provider other than "openai" (case-insensitive)
                or when no API key is configured.
        """
        self.provider = provider.lower()
        self.model = model
        self.dimensions = dimensions
        if self.provider != "openai":
            raise ValueError(f"unknown embedding provider: {provider}")

        # Imported lazily so the provider check above runs before the SDK
        # is needed.
        from openai import OpenAI

        env = os.environ
        api_key = env.get("PIFS_EMBEDDING_API_KEY") or env.get("OPENAI_API_KEY")
        if not api_key:
            raise ValueError(
                "PIFS_EMBEDDING_API_KEY or OPENAI_API_KEY is required for PIFS embeddings"
            )
        base_url = env.get("PIFS_EMBEDDING_BASE_URL") or env.get("OPENAI_BASE_URL")
        self.client = OpenAI(api_key=api_key, base_url=base_url or None, timeout=timeout)

    def embed(self, texts: list[str]) -> list[list[float]]:
        """Embed ``texts`` and return the vectors ordered by response index.

        ``dimensions`` is only sent with the request when it is positive.
        """
        request: dict[str, Any] = {"model": self.model, "input": texts}
        if self.dimensions > 0:
            request["dimensions"] = self.dimensions
        response = self.client.embeddings.create(**request)
        ordered = sorted(response.data, key=lambda item: item.index)
        return [list(item.embedding) for item in ordered]
def make_embedder(provider: str, model: str, *, dimensions: int, timeout: float) -> Any:
    """Create the embedding client for ``provider`` / ``model``.

    All arguments are forwarded to ``EmbeddingClient`` by keyword.
    """
    client_options = {
        "provider": provider,
        "model": model,
        "dimensions": dimensions,
        "timeout": timeout,
    }
    return EmbeddingClient(**client_options)
def query_text_for_channel(channel: str, query: str, projection: QueryProjection) -> str:
    """Pick the text to embed for ``channel``.

    The metadata and summary channels embed the raw query. The entity,
    constraint and relation channels embed the projected terms, falling back
    to the raw query when the projection has nothing for that channel.

    Raises:
        ValueError: If ``channel`` is not a known semantic channel.
    """
    if channel in {"metadata", "summary"}:
        return query

    if channel == "entity":
        projected = compact_join(projection.entities, limit=24)
    elif channel == "constraint":
        projected = compact_join(projection.constraints, limit=24)
    elif channel == "relation":
        projected = "\n".join(projection.relations)
    else:
        raise ValueError(f"unknown semantic channel: {channel}")

    # An empty projection string means "nothing extracted": use the query.
    return projected or query
def rank_single_semantic_channel(
    channel: str,
    results: list[SemanticSearchResult],
) -> list[HybridProjectionCandidate]:
    """Convert one channel's raw hits into candidates, one per document.

    Only the first (best-ranked) hit of each document is kept. The score is
    ``1 / (60 + rank)`` where ``rank`` is the hit's 1-based position in
    ``results`` (positions skipped as duplicates still count).
    """
    # Dicts preserve insertion order, so this keeps documents in the order
    # of their first appearance while remembering that first rank.
    first_hits: dict[str, tuple[int, Any]] = {}
    for position, hit in enumerate(results, start=1):
        document_id = str(hit.external_id or hit.file_ref)
        first_hits.setdefault(document_id, (position, hit))

    return [
        HybridProjectionCandidate(
            document_id=document_id,
            score=1 / (60 + position),
            sources=[{"channel": channel, "rank": position, "distance": hit.distance}],
            source_type=hit.source_type,
            source_path=hit.source_path,
            title=hit.title,
            metadata=hit.metadata,
            snippet=f"{channel}_vector rank={position}",
        )
        for document_id, (position, hit) in first_hits.items()
    ]
def aggregate_hybrid_entity_relation(
    channel_hits: dict[str, list[SemanticSearchResult]],
    projection: QueryProjection,
) -> list[HybridProjectionCandidate]:
    """Fuse per-channel hits into one ranked candidate list.

    Every document receives ``weight * 1 / (60 + rank)`` from each channel it
    appears in (first hit per channel only), plus an exact-match bonus.
    Candidates are ordered by descending score, then best rank across
    channels, then document id.
    """
    merged: dict[str, dict[str, Any]] = {}
    for channel, results in channel_hits.items():
        channel_weight = HYBRID_ENTITY_RELATION_WEIGHTS[channel]
        counted: set[str] = set()
        for rank, result in enumerate(results, start=1):
            doc_id = str(result.external_id or result.file_ref)
            if doc_id in counted:
                # A document contributes at most once per channel.
                continue
            counted.add(doc_id)
            if doc_id not in merged:
                # Descriptive fields come from the first hit seen for the doc.
                merged[doc_id] = {
                    "document_id": doc_id,
                    "score": 0.0,
                    "sources": [],
                    "source_type": result.source_type,
                    "source_path": result.source_path,
                    "title": result.title,
                    "metadata": result.metadata,
                }
            entry = merged[doc_id]
            entry["score"] += channel_weight * (1 / (60 + rank))
            entry["sources"].append(
                {"channel": channel, "rank": rank, "distance": result.distance}
            )

    def to_candidate(entry: dict[str, Any]) -> HybridProjectionCandidate:
        entry["score"] += exact_match_bonus(entry, projection)
        return HybridProjectionCandidate(
            document_id=entry["document_id"],
            score=float(entry["score"]),
            sources=entry["sources"],
            source_type=entry["source_type"],
            source_path=entry["source_path"],
            title=entry["title"],
            metadata=entry["metadata"],
            snippet=hybrid_snippet(entry),
        )

    ranked = [to_candidate(entry) for entry in merged.values()]
    ranked.sort(
        key=lambda candidate: (
            -candidate.score,
            min(source["rank"] for source in candidate.sources),
            candidate.document_id,
        )
    )
    return ranked
def exact_match_bonus(item: dict[str, Any], projection: QueryProjection) -> float:
    """Return a small score bonus for query terms found verbatim in the item.

    The first 8 entities and first 6 constraints are lower-cased, stripped
    and looked up in the lower-cased JSON dump of the item's title, source
    path and metadata. Terms shorter than 3 characters are ignored. Each
    match is worth 0.004, capped at 0.02.
    """
    searchable = {
        "title": item.get("title", ""),
        "source_path": item.get("source_path", ""),
        "metadata": item.get("metadata", {}),
    }
    haystack = json.dumps(searchable, ensure_ascii=False).lower()

    terms = [*projection.entities[:8], *projection.constraints[:6]]
    needles = (str(term).lower().strip() for term in terms)
    matched = sum(1 for needle in needles if len(needle) >= 3 and needle in haystack)
    return min(0.02, matched * 0.004)
def hybrid_snippet(item: dict[str, Any]) -> str:
    """Summarise which channels matched the item, plus its topic when set.

    Lists up to the first four sources as ``channel@rank`` and appends
    ``topic: <topic>`` when the item's metadata has a non-blank topic.
    """
    leading_sources = item.get("sources", [])[:4]
    labels = [f"{source['channel']}@{source['rank']}" for source in leading_sources]
    snippet = f"hybrid_entity_relation_vector {', '.join(labels)}"

    metadata = item.get("metadata") or {}
    topic = str(metadata.get("topic") or "").strip()
    if not topic:
        return snippet
    return f"{snippet}; topic: {topic}"
def heuristic_query_projection(question: str) -> QueryProjection:
    """Derive a ``QueryProjection`` from ``question`` using regex heuristics.

    Entities are identifier-like terms followed by up to 16 keywords
    (deduplicated, at most 16 total). Constraints are constraint phrases
    followed by numeric terms (deduplicated, at most 12). A single relation
    string ``"<subject> | <predicate> | <question>"`` is produced, where the
    subject is the first entity or the literal "question".
    """
    entity_candidates: list[str] = []
    entity_candidates.extend(identifier_terms(question))
    entity_candidates.extend(keyword_terms(question)[:16])
    entities = dedupe(entity_candidates)[:16]

    constraint_candidates: list[str] = []
    constraint_candidates.extend(extract_constraint_terms(question))
    constraint_candidates.extend(numeric_terms(question))
    constraints = dedupe(constraint_candidates)[:12]

    predicate = infer_query_predicate(question)
    subject = entities[0] if entities else "question"
    relation = f"{subject} | {predicate} | {question}"

    return QueryProjection(
        entities=entities,
        relations=[relation],
        constraints=constraints,
        expected_answer_type=infer_answer_type(question),
    )
def compact_join(values: list[str], *, limit: int) -> str:
    """Join the first ``limit`` values with ``" | "``."""
    head = values[:limit]
    separator = " | "
    return separator.join(head)
def identifier_terms(text: str) -> list[str]:
    """Extract identifier-like substrings from ``text``.

    Patterns are applied one after another and their matches concatenated,
    so the same substring may appear more than once in the result.
    """
    # 2-12 uppercase letters, a dash, then at least two digits.
    dashed_number = r"\b[A-Z]{2,12}-\d{2,}\b"
    # name = value / name: value pairs.
    assignment = r"\b[A-Za-z_][A-Za-z0-9_]{2,}\b\s*(?:=|:)\s*[A-Za-z0-9_.:/-]+"
    # Words made of parts joined by "-", "_" or "+".
    compound = r"\b[A-Za-z][A-Za-z0-9_+-]+(?:[-_+][A-Za-z0-9]+)+\b"
    # Words starting with at least two uppercase letters.
    uppercase_led = r"\b[A-Z]{2,}[A-Za-z0-9_-]*\b"

    return [
        match.strip()
        for pattern in (dashed_number, assignment, compound, uppercase_led)
        for match in re.findall(pattern, text)
    ]
def keyword_terms(text: str) -> list[str]:
    """Return the lower-cased, deduplicated non-stopword tokens of ``text``.

    A token starts with a letter and has at least three characters drawn
    from letters, digits, ``_``, ``+`` and ``-``.
    """
    ignored = frozenset(
        (
            "about",
            "after",
            "also",
            "and",
            "are",
            "for",
            "from",
            "how",
            "into",
            "the",
            "this",
            "that",
            "what",
            "when",
            "where",
            "which",
            "with",
        )
    )
    kept: list[str] = []
    for token in re.findall(r"[A-Za-z][A-Za-z0-9_+-]{2,}", text):
        lowered = token.lower()
        if lowered in ignored:
            continue
        kept.append(lowered)
    return dedupe(kept)
def extract_constraint_terms(text: str) -> list[str]:
    """Extract constraint-like phrases from ``text``, deduplicated.

    Two case-insensitive patterns are applied in order: a modal/limit keyword
    followed by up to 120 characters of the same sentence, and
    ``name = value`` / ``name: value`` pairs.
    """
    keyword_clause = (
        r"\b(?:must|should|required|requires?|default(?:s)?|limit(?:s)?|maximum|minimum)\b[^.!?\n]{0,120}"
    )
    assignment = r"\b[A-Za-z_][A-Za-z0-9_]{2,}\s*(?:=|:)\s*[A-Za-z0-9_.:/-]+"

    found = [
        match.strip()
        for pattern in (keyword_clause, assignment)
        for match in re.findall(pattern, text, flags=re.IGNORECASE)
    ]
    return dedupe(found)
def numeric_terms(text: str) -> list[str]:
    """Return numbers followed by a recognised unit, e.g. ``"10 MiB"`` or ``"50%"``.

    Matching is case-insensitive. The number may have a decimal part and may
    be separated from the unit by whitespace.
    """
    return re.findall(
        r"\b\d+(?:\.\d+)?\s*"
        # "%" is not a word character, so a trailing \b after it would only
        # match when a letter/digit follows (never for "50%" or "50% of").
        # The word boundary is therefore applied to the alphabetic units only.
        r"(?:%|(?:MiB|GiB|MB|GB|ms|sec|seconds|minutes|hours|days|tokens?|req/s|rps)\b)",
        text,
        flags=re.IGNORECASE,
    )
def infer_query_predicate(question: str) -> str:
    """Classify what the question asks for via case-insensitive substring rules.

    Rules are checked in order and the first one with a matching needle
    wins; ``"asks_about"`` is returned when none match.
    """
    text = question.lower()
    ordered_rules = (
        ("asks_default", ("default", "defaults")),
        ("asks_limit", ("limit", "maximum", "minimum", "size")),
        ("asks_cause", ("caused", "cause", "why")),
        ("asks_owner", ("who", "owner", "assigned")),
        ("asks_deadline", ("when", "deadline", "date")),
        ("asks_status", ("status", "state")),
        ("asks_requirement", ("required", "requirement", "must")),
    )
    matching = (
        predicate
        for predicate, needles in ordered_rules
        if any(needle in text for needle in needles)
    )
    return next(matching, "asks_about")
def infer_answer_type(question: str) -> str:
    """Guess the expected answer type from case-insensitive cues in the question.

    Checks run in order: quantity cues anywhere, then a leading "who", then a
    leading "when", then cause cues anywhere; otherwise ``"fact"``.
    """
    text = question.lower()
    if any(cue in text for cue in ("how many", "limit", "size")):
        return "number_or_limit"
    if text.startswith("who"):
        return "person_or_team"
    if text.startswith("when"):
        return "date_or_time"
    if any(cue in text for cue in ("why", "caused")):
        return "cause"
    return "fact"
def dedupe(values: Any) -> list[str]:
    """Return whitespace-normalised strings from ``values`` without duplicates.

    Each value is stringified, has whitespace runs collapsed to one space and
    is stripped. Empty results are dropped. Duplicates are detected
    case-insensitively and the first spelling encountered is kept, in
    original order.
    """
    # Keyed by lower-cased text; dict insertion order preserves first-seen order.
    first_spelling: dict[str, str] = {}
    for value in values:
        collapsed = re.sub(r"\s+", " ", str(value)).strip()
        if collapsed:
            first_spelling.setdefault(collapsed.lower(), collapsed)
    return list(first_spelling.values())
def normalize_text(text: str) -> str:
    """Collapse whitespace runs to single spaces and strip; falsy input gives ``""``."""
    raw = str(text or "")
    collapsed = re.sub(r"\s+", " ", raw)
    return collapsed.strip()
def embedding_cache_model_key(model: str, dimensions: int) -> str:
    """Return the cache key for a model, qualified by ``dimensions`` when positive."""
    if dimensions > 0:
        return f"{model}:dimensions={dimensions}"
    return model
def embed_with_retry(embedder: Any, texts: list[str], *, max_attempts: int = 8) -> list[list[float]]:
    """Call ``embedder.embed(texts)``, retrying on any exception.

    Between attempts the function sleeps ``2 ** (attempt - 1)`` seconds,
    capped at 120. The exception from the final attempt is re-raised.

    Raises:
        RuntimeError: If the loop finishes without an attempt, which happens
            when ``max_attempts`` is less than 1.
    """
    for attempt in range(1, max_attempts + 1):
        try:
            vectors = embedder.embed(texts)
        except Exception:
            if attempt < max_attempts:
                backoff_seconds = min(120.0, 2.0 ** (attempt - 1))
                time.sleep(backoff_seconds)
                continue
            raise
        else:
            return vectors
    raise RuntimeError("unreachable embedding retry state")
def encode_vector(vector: list[float]) -> bytes:
    """Pack ``vector`` as consecutive little-endian 32-bit floats."""
    layout = struct.Struct(f"<{len(vector)}f")
    return layout.pack(*vector)
def decode_vector(blob: bytes | None, vector_json: str | None) -> list[float]:
    """Decode a cached vector from its blob, falling back to its JSON form.

    A non-empty ``blob`` is read as little-endian 32-bit floats. Otherwise a
    non-empty ``vector_json`` holding a JSON list is converted to floats.

    Raises:
        ValueError: If the blob length is not a multiple of 4, or neither
            input yields a vector.
    """
    if blob:
        float_count, leftover = divmod(len(blob), 4)
        if leftover:
            raise ValueError("invalid cached vector blob length")
        return list(struct.unpack(f"<{float_count}f", blob))

    if vector_json:
        parsed = json.loads(vector_json)
        if isinstance(parsed, list):
            return [float(component) for component in parsed]

    raise ValueError("cached embedding row does not contain a vector")