from __future__ import annotations import json import os import re import sqlite3 import struct import time from dataclasses import dataclass from pathlib import Path from typing import Any from .semantic_index import SQLiteVecSemanticIndex, SemanticIndexError, SemanticSearchResult INDEX_BY_CHANNEL = { "metadata": "metadata_composite_vector", "summary": "summary_only_vector", "entity": "entity_vectors", "constraint": "constraint_vectors", "relation": "relation_vectors", } HYBRID_ENTITY_RELATION_CHANNELS = ("metadata", "entity", "constraint", "relation") SEMANTIC_TOOL_CHANNELS = ("summary", "entity", "relation") HYBRID_ENTITY_RELATION_WEIGHTS = { "metadata": 0.25, "entity": 0.25, "relation": 0.30, "constraint": 0.20, } @dataclass(frozen=True) class QueryProjection: entities: list[str] relations: list[str] constraints: list[str] expected_answer_type: str = "" @dataclass(frozen=True) class HybridProjectionCandidate: document_id: str score: float sources: list[dict[str, Any]] source_type: str source_path: str title: str metadata: dict[str, Any] snippet: str class HybridProjectionSearchBackend: """Hybrid entity/relation/vector retrieval over rebuildable projection indexes. The SQLite catalog remains the source of truth. This backend only reads external sqlite-vec projection indexes and returns candidate document ids for the catalog to resolve and filter. """ def __init__( self, index_dir: str | Path, *, embedder: Any, embedding_provider: str, embedding_model: str, embedding_dimensions: int = 256, embedding_cache_path: str | Path | None = None, per_channel_limit: int = 100, fetch_multiplier: int = 100, ) -> None: self.index_dir = Path(index_dir).expanduser() self.embedder = embedder self.embedding_provider = embedding_provider self.embedding_model = embedding_model self.embedding_dimensions = embedding_dimensions self.cache_model = embedding_cache_model_key(embedding_model, embedding_dimensions) self.embedding_cache = EmbeddingCache( Path(embedding_cache_path).expanduser() if embedding_cache_path is not None else self.index_dir / "embedding_cache.sqlite" ) self.per_channel_limit = per_channel_limit self.fetch_multiplier = fetch_multiplier self.indexes = { channel: SQLiteVecSemanticIndex(self.index_dir / f"{index_name}.sqlite") for channel, index_name in INDEX_BY_CHANNEL.items() } @classmethod def from_provider( cls, index_dir: str | Path, *, embedding_provider: str = "openai", embedding_model: str = "text-embedding-3-small", embedding_dimensions: int = 256, embedding_timeout: float = 60, **kwargs: Any, ) -> "HybridProjectionSearchBackend": return cls( index_dir, embedder=make_embedder( embedding_provider, embedding_model, dimensions=embedding_dimensions, timeout=embedding_timeout, ), embedding_provider=embedding_provider, embedding_model=embedding_model, embedding_dimensions=embedding_dimensions, **kwargs, ) def search( self, query: str, *, limit: int = 10, filters: dict[str, Any] | None = None, ) -> list[HybridProjectionCandidate]: query = normalize_text(query) if not query: return [] projection = heuristic_query_projection(query) channels = tuple( channel for channel in HYBRID_ENTITY_RELATION_CHANNELS if self._channel_document_count(channel) > 0 ) if not channels: if self._channel_document_count("summary") > 0: return self.search_channel("summary", query, limit=limit, filters=filters) return [] channel_hits = self._search_channels( query=query, projection=projection, limit=max(limit, self.per_channel_limit), filters=filters, channels=channels, ) return aggregate_hybrid_entity_relation(channel_hits, projection)[:limit] def search_channel( self, channel: str, query: str, *, limit: int = 10, filters: dict[str, Any] | None = None, ) -> list[HybridProjectionCandidate]: if channel not in SEMANTIC_TOOL_CHANNELS: raise ValueError(f"unsupported semantic channel: {channel}") if channel not in self.available_channels(): return [] query = normalize_text(query) if not query: return [] projection = heuristic_query_projection(query) vector = self.embedding_cache.embed_texts( [query_text_for_channel(channel, query, projection)], provider=self.embedding_provider, model=self.cache_model, embedder=self.embedder, batch_size=1, )[0] results = self.indexes[channel].search( vector, limit=limit, filters=filters, fetch_multiplier=self.fetch_multiplier, ) return rank_single_semantic_channel(channel, results) def available_channels(self) -> tuple[str, ...]: return tuple( channel for channel in SEMANTIC_TOOL_CHANNELS if self._channel_document_count(channel) > 0 ) def info(self) -> dict[str, Any]: return { "index_dir": str(self.index_dir), "embedding_provider": self.embedding_provider, "embedding_model": self.embedding_model, "embedding_dimensions": self.embedding_dimensions, "strategy": "hybrid_entity_relation_vector", "available_channels": list(self.available_channels()), "channels": { channel: self._safe_channel_info(channel) for channel in self.indexes }, } def _channel_document_count(self, channel: str) -> int: info = self._safe_channel_info(channel) if not info.get("available"): return 0 return int(info.get("document_count") or 0) def _safe_channel_info(self, channel: str) -> dict[str, Any]: index = self.indexes[channel] if not index.db_path.exists(): return { "db_path": str(index.db_path), "available": False, "document_count": 0, "error": "index file is missing", } try: info = index.info() except (OSError, sqlite3.Error, SemanticIndexError) as exc: return { "db_path": str(index.db_path), "available": False, "document_count": 0, "error": str(exc), } return {**info, "available": int(info.get("document_count") or 0) > 0} def _search_channels( self, *, query: str, projection: QueryProjection, limit: int, filters: dict[str, Any] | None, channels: tuple[str, ...], ) -> dict[str, list[SemanticSearchResult]]: query_texts = { channel: query_text_for_channel(channel, query, projection) for channel in channels } vectors = self.embedding_cache.embed_texts( [query_texts[channel] for channel in channels], provider=self.embedding_provider, model=self.cache_model, embedder=self.embedder, batch_size=1, ) return { channel: self.indexes[channel].search( vector, limit=limit, filters=filters, fetch_multiplier=self.fetch_multiplier, ) for channel, vector in zip(channels, vectors) } class EmbeddingCache: def __init__(self, db_path: Path): self.db_path = db_path self.db_path.parent.mkdir(parents=True, exist_ok=True) with self.connect() as conn: conn.execute( """ CREATE TABLE IF NOT EXISTS embedding_cache ( provider TEXT NOT NULL, model TEXT NOT NULL, text_hash TEXT NOT NULL, dimension INTEGER NOT NULL, vector_blob BLOB, vector_json TEXT, created_at TEXT DEFAULT CURRENT_TIMESTAMP, PRIMARY KEY(provider, model, text_hash) ) """ ) conn.commit() def connect(self) -> sqlite3.Connection: conn = sqlite3.connect(self.db_path) conn.row_factory = sqlite3.Row return conn def embed_texts( self, texts: list[str], *, provider: str, model: str, embedder: Any, batch_size: int, ) -> list[list[float]]: hashes = [SQLiteVecSemanticIndex.text_hash(text) for text in texts] cached: dict[str, list[float]] = {} with self.connect() as conn: for text_hash in sorted(set(hashes)): row = conn.execute( """ SELECT vector_blob, vector_json FROM embedding_cache WHERE provider = ? AND model = ? AND text_hash = ? """, (provider, model, text_hash), ).fetchone() if row is not None: cached[text_hash] = decode_vector(row["vector_blob"], row["vector_json"]) missing_positions = [ index for index, text_hash in enumerate(hashes) if text_hash not in cached ] for start in range(0, len(missing_positions), max(1, batch_size)): positions = missing_positions[start : start + max(1, batch_size)] batch_texts = [texts[index] for index in positions] vectors = embed_with_retry(embedder, batch_texts) with self.connect() as conn: conn.executemany( """ INSERT OR REPLACE INTO embedding_cache( provider, model, text_hash, dimension, vector_blob, vector_json ) VALUES (?, ?, ?, ?, ?, '') """, [ ( provider, model, hashes[index], len(vector), encode_vector(vector), ) for index, vector in zip(positions, vectors) ], ) conn.commit() for index, vector in zip(positions, vectors): cached[hashes[index]] = vector return [cached[text_hash] for text_hash in hashes] class EmbeddingClient: def __init__(self, *, provider: str, model: str, dimensions: int, timeout: float): self.provider = provider.lower() self.model = model self.dimensions = dimensions if self.provider != "openai": raise ValueError(f"unknown embedding provider: {provider}") from openai import OpenAI api_key = os.environ.get("PIFS_EMBEDDING_API_KEY") or os.environ.get("OPENAI_API_KEY") base_url = os.environ.get("PIFS_EMBEDDING_BASE_URL") or os.environ.get("OPENAI_BASE_URL") if not api_key: raise ValueError( "PIFS_EMBEDDING_API_KEY or OPENAI_API_KEY is required for PIFS embeddings" ) self.client = OpenAI(api_key=api_key, base_url=base_url or None, timeout=timeout) def embed(self, texts: list[str]) -> list[list[float]]: kwargs: dict[str, Any] = {"model": self.model, "input": texts} if self.dimensions > 0: kwargs["dimensions"] = self.dimensions response = self.client.embeddings.create(**kwargs) return [list(item.embedding) for item in sorted(response.data, key=lambda item: item.index)] def make_embedder(provider: str, model: str, *, dimensions: int, timeout: float) -> Any: return EmbeddingClient( provider=provider, model=model, dimensions=dimensions, timeout=timeout, ) def query_text_for_channel(channel: str, query: str, projection: QueryProjection) -> str: if channel in {"metadata", "summary"}: return query if channel == "entity": return compact_join(projection.entities, limit=24) or query if channel == "constraint": return compact_join(projection.constraints, limit=24) or query if channel == "relation": return "\n".join(projection.relations) or query raise ValueError(f"unknown semantic channel: {channel}") def rank_single_semantic_channel( channel: str, results: list[SemanticSearchResult], ) -> list[HybridProjectionCandidate]: rows: list[HybridProjectionCandidate] = [] seen: set[str] = set() for rank, result in enumerate(results, 1): doc_id = str(result.external_id or result.file_ref) if doc_id in seen: continue seen.add(doc_id) rows.append( HybridProjectionCandidate( document_id=doc_id, score=1 / (60 + rank), sources=[{"channel": channel, "rank": rank, "distance": result.distance}], source_type=result.source_type, source_path=result.source_path, title=result.title, metadata=result.metadata, snippet=f"{channel}_vector rank={rank}", ) ) return rows def aggregate_hybrid_entity_relation( channel_hits: dict[str, list[SemanticSearchResult]], projection: QueryProjection, ) -> list[HybridProjectionCandidate]: by_doc: dict[str, dict[str, Any]] = {} for channel, results in channel_hits.items(): weight = HYBRID_ENTITY_RELATION_WEIGHTS[channel] seen_in_channel = set() for rank, result in enumerate(results, 1): doc_id = str(result.external_id or result.file_ref) if doc_id in seen_in_channel: continue seen_in_channel.add(doc_id) item = by_doc.setdefault( doc_id, { "document_id": doc_id, "score": 0.0, "sources": [], "source_type": result.source_type, "source_path": result.source_path, "title": result.title, "metadata": result.metadata, }, ) item["score"] += weight * (1 / (60 + rank)) item["sources"].append({"channel": channel, "rank": rank, "distance": result.distance}) candidates = [] for item in by_doc.values(): item["score"] += exact_match_bonus(item, projection) candidates.append( HybridProjectionCandidate( document_id=item["document_id"], score=float(item["score"]), sources=item["sources"], source_type=item["source_type"], source_path=item["source_path"], title=item["title"], metadata=item["metadata"], snippet=hybrid_snippet(item), ) ) return sorted( candidates, key=lambda item: ( -item.score, min(source["rank"] for source in item.sources), item.document_id, ), ) def exact_match_bonus(item: dict[str, Any], projection: QueryProjection) -> float: haystack = json.dumps( { "title": item.get("title", ""), "source_path": item.get("source_path", ""), "metadata": item.get("metadata", {}), }, ensure_ascii=False, ).lower() terms = [*projection.entities[:8], *projection.constraints[:6]] matched = 0 for term in terms: normalized = str(term).lower().strip() if len(normalized) >= 3 and normalized in haystack: matched += 1 return min(0.02, matched * 0.004) def hybrid_snippet(item: dict[str, Any]) -> str: channels = ", ".join( f"{source['channel']}@{source['rank']}" for source in item.get("sources", [])[:4] ) topic = str((item.get("metadata") or {}).get("topic") or "").strip() parts = [f"hybrid_entity_relation_vector {channels}"] if topic: parts.append(f"topic: {topic}") return "; ".join(parts) def heuristic_query_projection(question: str) -> QueryProjection: entities = dedupe( [ *identifier_terms(question), *keyword_terms(question)[:16], ] )[:16] constraints = dedupe( [ *extract_constraint_terms(question), *numeric_terms(question), ] )[:12] predicate = infer_query_predicate(question) subject = entities[0] if entities else "question" return QueryProjection( entities=entities, relations=[f"{subject} | {predicate} | {question}"], constraints=constraints, expected_answer_type=infer_answer_type(question), ) def compact_join(values: list[str], *, limit: int) -> str: return " | ".join(values[:limit]) def identifier_terms(text: str) -> list[str]: patterns = [ r"\b[A-Z]{2,12}-\d{2,}\b", r"\b[A-Za-z_][A-Za-z0-9_]{2,}\b\s*(?:=|:)\s*[A-Za-z0-9_.:/-]+", r"\b[A-Za-z][A-Za-z0-9_+-]+(?:[-_+][A-Za-z0-9]+)+\b", r"\b[A-Z]{2,}[A-Za-z0-9_-]*\b", ] found: list[str] = [] for pattern in patterns: found.extend(match.strip() for match in re.findall(pattern, text)) return found def keyword_terms(text: str) -> list[str]: stopwords = { "about", "after", "also", "and", "are", "for", "from", "how", "into", "the", "this", "that", "what", "when", "where", "which", "with", } terms = [ term.lower() for term in re.findall(r"[A-Za-z][A-Za-z0-9_+-]{2,}", text) if term.lower() not in stopwords ] return dedupe(terms) def extract_constraint_terms(text: str) -> list[str]: constraints = [] for pattern in [ r"\b(?:must|should|required|requires?|default(?:s)?|limit(?:s)?|maximum|minimum)\b[^.!?\n]{0,120}", r"\b[A-Za-z_][A-Za-z0-9_]{2,}\s*(?:=|:)\s*[A-Za-z0-9_.:/-]+", ]: constraints.extend(match.strip() for match in re.findall(pattern, text, flags=re.IGNORECASE)) return dedupe(constraints) def numeric_terms(text: str) -> list[str]: return re.findall( r"\b\d+(?:\.\d+)?\s*(?:MiB|GiB|MB|GB|ms|sec|seconds|minutes|hours|days|%|tokens?|req/s|rps)\b", text, flags=re.IGNORECASE, ) def infer_query_predicate(question: str) -> str: lowered = question.lower() rules = [ ("asks_default", ["default", "defaults"]), ("asks_limit", ["limit", "maximum", "minimum", "size"]), ("asks_cause", ["caused", "cause", "why"]), ("asks_owner", ["who", "owner", "assigned"]), ("asks_deadline", ["when", "deadline", "date"]), ("asks_status", ["status", "state"]), ("asks_requirement", ["required", "requirement", "must"]), ] for predicate, needles in rules: if any(needle in lowered for needle in needles): return predicate return "asks_about" def infer_answer_type(question: str) -> str: lowered = question.lower() if "how many" in lowered or "limit" in lowered or "size" in lowered: return "number_or_limit" if lowered.startswith("who"): return "person_or_team" if lowered.startswith("when"): return "date_or_time" if "why" in lowered or "caused" in lowered: return "cause" return "fact" def dedupe(values: Any) -> list[str]: seen = set() result = [] for value in values: normalized = re.sub(r"\s+", " ", str(value)).strip() key = normalized.lower() if not normalized or key in seen: continue seen.add(key) result.append(normalized) return result def normalize_text(text: str) -> str: return re.sub(r"\s+", " ", str(text or "")).strip() def embedding_cache_model_key(model: str, dimensions: int) -> str: return f"{model}:dimensions={dimensions}" if dimensions > 0 else model def embed_with_retry(embedder: Any, texts: list[str], *, max_attempts: int = 8) -> list[list[float]]: for attempt in range(1, max_attempts + 1): try: return embedder.embed(texts) except Exception: if attempt >= max_attempts: raise time.sleep(min(120.0, 2.0 ** (attempt - 1))) raise RuntimeError("unreachable embedding retry state") def encode_vector(vector: list[float]) -> bytes: return struct.pack(f"<{len(vector)}f", *vector) def decode_vector(blob: bytes | None, vector_json: str | None) -> list[float]: if blob: if len(blob) % 4 != 0: raise ValueError("invalid cached vector blob length") return list(struct.unpack(f"<{len(blob) // 4}f", blob)) if vector_json: value = json.loads(vector_json) if isinstance(value, list): return [float(item) for item in value] raise ValueError("cached embedding row does not contain a vector")