diff --git a/pageindex/filesystem/core.py b/pageindex/filesystem/core.py index c20cccf..78c1cec 100644 --- a/pageindex/filesystem/core.py +++ b/pageindex/filesystem/core.py @@ -371,15 +371,25 @@ class PageIndexFileSystem: ) parsed_filter = self.metadata.parse_filter(metadata_filter) scope = {"folder_path": path, "recursive": recursive} + scope_file_refs = self.store.file_refs_for_scope( + scope=scope, + metadata_filter=parsed_filter, + ) offset = (page - 1) * page_size needed = offset + page_size + 1 - fetch_limit = max(needed * 10, 50) - candidates = search_channel( - space, - query_text, - limit=fetch_limit, - filters=self._semantic_filters_for_scope(scope), + semantic_filters = self._semantic_filters_for_scope(scope) + semantic_filters["file_ref"] = scope_file_refs + candidates = ( + search_channel( + space, + query_text, + limit=needed, + filters=semantic_filters, + ) + if scope_file_refs + else [] ) + scope_file_ref_set = set(scope_file_refs) rows: list[dict[str, Any]] = [] seen: set[str] = set() for candidate in candidates: @@ -389,6 +399,8 @@ class PageIndexFileSystem: continue if file_ref in seen: continue + if file_ref not in scope_file_ref_set: + continue if not self.store.file_matches( file_ref, scope=scope, diff --git a/pageindex/filesystem/semantic_index.py b/pageindex/filesystem/semantic_index.py index 2453e1f..4a29551 100644 --- a/pageindex/filesystem/semantic_index.py +++ b/pageindex/filesystem/semantic_index.py @@ -159,15 +159,29 @@ class SQLiteVecSemanticIndex: raise SemanticIndexError( f"query vector dimension mismatch: expected {dimension}, got {len(vector)}" ) - fetch_k = min(4096, max(limit, limit * max(fetch_multiplier, 1))) - source_types = _source_type_filters(filters or {}) + raw_filters = filters or {} + source_types = _source_type_filters(raw_filters) + file_refs = _file_ref_filters(raw_filters) + if file_refs == []: + return [] with self.connect() as conn: + if file_refs is not None: + _install_file_ref_filter_table(conn, file_refs) rows = [] if source_types: for source_type in source_types: + fetch_k = self._search_fetch_k( + conn, + limit, + fetch_multiplier, + exact_file_ref_filter=file_refs is not None, + source_type=source_type, + ) + if fetch_k <= 0: + continue rows.extend( conn.execute( - """ + f""" SELECT d.file_ref, d.external_id, @@ -180,6 +194,7 @@ class SQLiteVecSemanticIndex: FROM semantic_index_vec v JOIN semantic_index_docs d ON d.rowid = v.rowid WHERE v.embedding MATCH ? AND k = ? AND v.source_type = ? + {_file_ref_filter_sql(file_refs)} ORDER BY v.distance """, (sqlite_vec.serialize_float32(vector), fetch_k, source_type), @@ -187,8 +202,16 @@ class SQLiteVecSemanticIndex: ) rows.sort(key=lambda row: float(row["distance"])) else: + fetch_k = self._search_fetch_k( + conn, + limit, + fetch_multiplier, + exact_file_ref_filter=file_refs is not None, + ) + if fetch_k <= 0: + return [] rows = conn.execute( - """ + f""" SELECT d.file_ref, d.external_id, @@ -201,6 +224,7 @@ class SQLiteVecSemanticIndex: FROM semantic_index_vec v JOIN semantic_index_docs d ON d.rowid = v.rowid WHERE v.embedding MATCH ? AND k = ? + {_file_ref_filter_sql(file_refs)} ORDER BY v.distance """, (sqlite_vec.serialize_float32(vector), fetch_k), @@ -226,6 +250,30 @@ class SQLiteVecSemanticIndex: break return results + @staticmethod + def _search_fetch_k( + conn: sqlite3.Connection, + limit: int, + fetch_multiplier: int, + *, + exact_file_ref_filter: bool, + source_type: str | None = None, + ) -> int: + if exact_file_ref_filter: + where = [] + params: list[Any] = [] + if source_type is not None: + where.append("source_type = ?") + params.append(source_type) + where_sql = "WHERE " + " AND ".join(where) if where else "" + return int( + conn.execute( + f"SELECT COUNT(*) FROM semantic_index_docs {where_sql}", + params, + ).fetchone()[0] + ) + return min(4096, max(limit, limit * max(fetch_multiplier, 1))) + def info(self) -> dict[str, Any]: with self.connect() as conn: config = { @@ -344,7 +392,8 @@ def _matches_filters( filters: dict[str, Any], ) -> bool: for key, expected in filters.items(): - actual = row[key] if key in row.keys() else metadata.get(key) + actual_key = "file_ref" if key == "file_refs" else key + actual = row[actual_key] if actual_key in row.keys() else metadata.get(actual_key) if isinstance(expected, list): if str(actual) not in {str(item) for item in expected}: return False @@ -360,3 +409,41 @@ def _source_type_filters(filters: dict[str, Any]) -> list[str]: if isinstance(value, list): return [str(item) for item in value if str(item)] return [str(value)] if str(value) else [] + + +def _file_ref_filters(filters: dict[str, Any]) -> list[str] | None: + if "file_ref" in filters: + value = filters.get("file_ref") + elif "file_refs" in filters: + value = filters.get("file_refs") + else: + return None + if isinstance(value, list): + return [str(item) for item in value if str(item)] + return [str(value)] if str(value) else [] + + +def _install_file_ref_filter_table(conn: sqlite3.Connection, file_refs: list[str]) -> None: + conn.execute( + """ + CREATE TEMP TABLE IF NOT EXISTS semantic_index_filter_file_refs ( + file_ref TEXT PRIMARY KEY + ) + """ + ) + conn.execute("DELETE FROM semantic_index_filter_file_refs") + conn.executemany( + "INSERT OR IGNORE INTO semantic_index_filter_file_refs(file_ref) VALUES (?)", + [(file_ref,) for file_ref in file_refs], + ) + + +def _file_ref_filter_sql(file_refs: list[str] | None) -> str: + if file_refs is None: + return "" + return ( + "AND EXISTS (" + "SELECT 1 FROM semantic_index_filter_file_refs scope_refs " + "WHERE scope_refs.file_ref = d.file_ref" + ")" + ) diff --git a/pageindex/filesystem/store.py b/pageindex/filesystem/store.py index 7517d70..30a7d32 100644 --- a/pageindex/filesystem/store.py +++ b/pageindex/filesystem/store.py @@ -753,6 +753,33 @@ class SQLiteFileSystemStore: return results return results + def file_refs_for_scope( + self, + *, + scope: Optional[dict[str, Any]] = None, + metadata_filter: Optional[dict[str, Any]] = None, + ) -> list[str]: + where = ["f.deleted_at IS NULL"] + params: list[Any] = [] + scope_sql, scope_params = self._scope_sql(scope) + if scope_sql: + where.append(scope_sql) + params.extend(scope_params) + metadata_sql, metadata_params = self._metadata_filter_sql(metadata_filter) + where.extend(metadata_sql) + params.extend(metadata_params) + with self.connect() as conn: + rows = conn.execute( + f""" + SELECT DISTINCT f.file_ref + FROM files f + WHERE {" AND ".join(where)} + ORDER BY f.file_ref + """, + params, + ).fetchall() + return [row["file_ref"] for row in rows] + def _search_once( self, match_query: str | None, diff --git a/tests/test_pageindex_filesystem_scope.py b/tests/test_pageindex_filesystem_scope.py index b18270c..04ce084 100644 --- a/tests/test_pageindex_filesystem_scope.py +++ b/tests/test_pageindex_filesystem_scope.py @@ -70,9 +70,10 @@ class ChannelBackend: class BrowseBackend: - def __init__(self, document_ids, channels=("summary",)): + def __init__(self, document_ids, channels=("summary",), file_refs_by_document_id=None): self.document_ids = list(document_ids) self.channels = channels + self.file_refs_by_document_id = dict(file_refs_by_document_id or {}) self.calls = [] def available_channels(self): @@ -80,6 +81,20 @@ class BrowseBackend: def search_channel(self, channel, query, *, limit=10, filters=None): self.calls.append((channel, query, limit, filters)) + file_ref_filter = set() + if isinstance(filters, dict): + raw_file_refs = filters.get("file_ref") or filters.get("file_refs") or [] + if isinstance(raw_file_refs, str): + file_ref_filter = {raw_file_refs} + else: + file_ref_filter = {str(item) for item in raw_file_refs} + document_ids = self.document_ids + if file_ref_filter and self.file_refs_by_document_id: + document_ids = [ + document_id + for document_id in document_ids + if self.file_refs_by_document_id.get(document_id) in file_ref_filter + ] return [ SimpleNamespace( document_id=document_id, @@ -87,7 +102,7 @@ class BrowseBackend: score=1.0 - rank * 0.01, sources=[{"channel": channel, "rank": rank, "distance": rank / 10}], ) - for rank, document_id in enumerate(self.document_ids[:limit], 1) + for rank, document_id in enumerate(document_ids[:limit], 1) ] @@ -108,11 +123,11 @@ def _register_browse_file(filesystem, external_id, folder_path, *, department="o filesystem.metadata_generator = SummaryGenerator() return filesystem.register_file( - storage_uri=f"file:///tmp/{external_id}.pdf", - source_path=f"documents/{external_id}.pdf", + storage_uri=f"file:///tmp/{external_id}.txt", + source_path=f"documents/{external_id}.txt", folder_path=folder_path, external_id=external_id, - title=f"{external_id}.pdf", + title=f"{external_id}.txt", content=f"{external_id} discusses vector databases and retrieval.", metadata={"department": department}, metadata_policy={ @@ -262,6 +277,49 @@ def test_browse_supports_fixed_size_one_based_pagination_and_metadata_filter(tmp assert filtered["data"][0]["summary"] == "summary for doc_10" +def test_browse_scopes_semantic_search_before_candidate_limit(tmp_path): + import json + + from pageindex.filesystem import PIFSCommandExecutor, PageIndexFileSystem + + filesystem = PageIndexFileSystem(workspace=tmp_path / "workspace") + file_refs_by_document_id = {} + candidate_ids = [] + for index in range(150): + external_id = f"off_scope_{index:02d}" + candidate_ids.append(external_id) + file_refs_by_document_id[external_id] = _register_browse_file( + filesystem, + external_id, + "/other", + ) + file_refs_by_document_id["doc_deep"] = _register_browse_file( + filesystem, + "doc_deep", + "/documents/reports", + ) + file_refs_by_document_id["doc_direct"] = _register_browse_file( + filesystem, + "doc_direct", + "/documents", + ) + backend = BrowseBackend( + [*candidate_ids, "doc_deep", "doc_direct"], + file_refs_by_document_id=file_refs_by_document_id, + ) + filesystem.semantic_retrieval_backend = backend + executor = PIFSCommandExecutor(filesystem, json_output=True) + + direct = json.loads(executor.execute('browse /documents "vector database"'))["data"] + assert [item["document_id"] for item in direct["data"]] == ["doc_direct"] + + recursive = json.loads(executor.execute('browse -R /documents "vector database"'))["data"] + assert [item["document_id"] for item in recursive["data"]] == [ + "doc_deep", + "doc_direct", + ] + + def test_semantic_search_scope_keeps_ordinary_folders_out_of_source_type_filters(tmp_path): from pageindex.filesystem import PIFSCommandExecutor, PageIndexFileSystem from pageindex.filesystem.metadata_generation import MetadataGenerationResult diff --git a/tests/test_semantic_index.py b/tests/test_semantic_index.py index 324ead7..d4263e1 100644 --- a/tests/test_semantic_index.py +++ b/tests/test_semantic_index.py @@ -55,6 +55,44 @@ def test_sqlite_vec_semantic_index_round_trip(tmp_path): assert [item.external_id for item in filtered] == ["doc_b"] +def test_sqlite_vec_semantic_index_file_ref_filter_not_limited_by_global_rank(tmp_path): + index = SQLiteVecSemanticIndex(tmp_path / "semantic.sqlite") + index.reset(dimension=2, metadata={"field_mode": "summary"}) + + records = [ + SemanticIndexRecord( + file_ref=f"file_off_{item:02d}", + external_id=f"doc_off_{item:02d}", + source_type="documents", + source_path=f"other/{item:02d}.pdf", + title=f"Off scope {item:02d}", + text="off scope", + vector=[1.0, 0.0], + ) + for item in range(30) + ] + records.append( + SemanticIndexRecord( + file_ref="file_in_scope", + external_id="doc_in_scope", + source_type="documents", + source_path="documents/in-scope.pdf", + title="In scope", + text="in scope", + vector=[0.0, 1.0], + ) + ) + index.upsert_many(records) + + results = index.search( + [1.0, 0.0], + limit=1, + filters={"file_ref": ["file_in_scope"]}, + ) + + assert [item.file_ref for item in results] == ["file_in_scope"] + + def test_summary_projection_indexes_unified_metadata_summary(tmp_path): from pageindex.filesystem.projection_indexing import SummaryProjectionIndexer