mirror of
https://github.com/VectifyAI/PageIndex.git
synced 2026-06-12 19:55:17 +02:00
fix(pifs): scope browse vector search before paging
This commit is contained in:
parent
ba821a70b9
commit
3562d47fdb
5 changed files with 238 additions and 16 deletions
|
|
@ -371,15 +371,25 @@ class PageIndexFileSystem:
|
|||
)
|
||||
parsed_filter = self.metadata.parse_filter(metadata_filter)
|
||||
scope = {"folder_path": path, "recursive": recursive}
|
||||
scope_file_refs = self.store.file_refs_for_scope(
|
||||
scope=scope,
|
||||
metadata_filter=parsed_filter,
|
||||
)
|
||||
offset = (page - 1) * page_size
|
||||
needed = offset + page_size + 1
|
||||
fetch_limit = max(needed * 10, 50)
|
||||
candidates = search_channel(
|
||||
space,
|
||||
query_text,
|
||||
limit=fetch_limit,
|
||||
filters=self._semantic_filters_for_scope(scope),
|
||||
semantic_filters = self._semantic_filters_for_scope(scope)
|
||||
semantic_filters["file_ref"] = scope_file_refs
|
||||
candidates = (
|
||||
search_channel(
|
||||
space,
|
||||
query_text,
|
||||
limit=needed,
|
||||
filters=semantic_filters,
|
||||
)
|
||||
if scope_file_refs
|
||||
else []
|
||||
)
|
||||
scope_file_ref_set = set(scope_file_refs)
|
||||
rows: list[dict[str, Any]] = []
|
||||
seen: set[str] = set()
|
||||
for candidate in candidates:
|
||||
|
|
@ -389,6 +399,8 @@ class PageIndexFileSystem:
|
|||
continue
|
||||
if file_ref in seen:
|
||||
continue
|
||||
if file_ref not in scope_file_ref_set:
|
||||
continue
|
||||
if not self.store.file_matches(
|
||||
file_ref,
|
||||
scope=scope,
|
||||
|
|
|
|||
|
|
@ -159,15 +159,29 @@ class SQLiteVecSemanticIndex:
|
|||
raise SemanticIndexError(
|
||||
f"query vector dimension mismatch: expected {dimension}, got {len(vector)}"
|
||||
)
|
||||
fetch_k = min(4096, max(limit, limit * max(fetch_multiplier, 1)))
|
||||
source_types = _source_type_filters(filters or {})
|
||||
raw_filters = filters or {}
|
||||
source_types = _source_type_filters(raw_filters)
|
||||
file_refs = _file_ref_filters(raw_filters)
|
||||
if file_refs == []:
|
||||
return []
|
||||
with self.connect() as conn:
|
||||
if file_refs is not None:
|
||||
_install_file_ref_filter_table(conn, file_refs)
|
||||
rows = []
|
||||
if source_types:
|
||||
for source_type in source_types:
|
||||
fetch_k = self._search_fetch_k(
|
||||
conn,
|
||||
limit,
|
||||
fetch_multiplier,
|
||||
exact_file_ref_filter=file_refs is not None,
|
||||
source_type=source_type,
|
||||
)
|
||||
if fetch_k <= 0:
|
||||
continue
|
||||
rows.extend(
|
||||
conn.execute(
|
||||
"""
|
||||
f"""
|
||||
SELECT
|
||||
d.file_ref,
|
||||
d.external_id,
|
||||
|
|
@ -180,6 +194,7 @@ class SQLiteVecSemanticIndex:
|
|||
FROM semantic_index_vec v
|
||||
JOIN semantic_index_docs d ON d.rowid = v.rowid
|
||||
WHERE v.embedding MATCH ? AND k = ? AND v.source_type = ?
|
||||
{_file_ref_filter_sql(file_refs)}
|
||||
ORDER BY v.distance
|
||||
""",
|
||||
(sqlite_vec.serialize_float32(vector), fetch_k, source_type),
|
||||
|
|
@ -187,8 +202,16 @@ class SQLiteVecSemanticIndex:
|
|||
)
|
||||
rows.sort(key=lambda row: float(row["distance"]))
|
||||
else:
|
||||
fetch_k = self._search_fetch_k(
|
||||
conn,
|
||||
limit,
|
||||
fetch_multiplier,
|
||||
exact_file_ref_filter=file_refs is not None,
|
||||
)
|
||||
if fetch_k <= 0:
|
||||
return []
|
||||
rows = conn.execute(
|
||||
"""
|
||||
f"""
|
||||
SELECT
|
||||
d.file_ref,
|
||||
d.external_id,
|
||||
|
|
@ -201,6 +224,7 @@ class SQLiteVecSemanticIndex:
|
|||
FROM semantic_index_vec v
|
||||
JOIN semantic_index_docs d ON d.rowid = v.rowid
|
||||
WHERE v.embedding MATCH ? AND k = ?
|
||||
{_file_ref_filter_sql(file_refs)}
|
||||
ORDER BY v.distance
|
||||
""",
|
||||
(sqlite_vec.serialize_float32(vector), fetch_k),
|
||||
|
|
@ -226,6 +250,30 @@ class SQLiteVecSemanticIndex:
|
|||
break
|
||||
return results
|
||||
|
||||
@staticmethod
|
||||
def _search_fetch_k(
|
||||
conn: sqlite3.Connection,
|
||||
limit: int,
|
||||
fetch_multiplier: int,
|
||||
*,
|
||||
exact_file_ref_filter: bool,
|
||||
source_type: str | None = None,
|
||||
) -> int:
|
||||
if exact_file_ref_filter:
|
||||
where = []
|
||||
params: list[Any] = []
|
||||
if source_type is not None:
|
||||
where.append("source_type = ?")
|
||||
params.append(source_type)
|
||||
where_sql = "WHERE " + " AND ".join(where) if where else ""
|
||||
return int(
|
||||
conn.execute(
|
||||
f"SELECT COUNT(*) FROM semantic_index_docs {where_sql}",
|
||||
params,
|
||||
).fetchone()[0]
|
||||
)
|
||||
return min(4096, max(limit, limit * max(fetch_multiplier, 1)))
|
||||
|
||||
def info(self) -> dict[str, Any]:
|
||||
with self.connect() as conn:
|
||||
config = {
|
||||
|
|
@ -344,7 +392,8 @@ def _matches_filters(
|
|||
filters: dict[str, Any],
|
||||
) -> bool:
|
||||
for key, expected in filters.items():
|
||||
actual = row[key] if key in row.keys() else metadata.get(key)
|
||||
actual_key = "file_ref" if key == "file_refs" else key
|
||||
actual = row[actual_key] if actual_key in row.keys() else metadata.get(actual_key)
|
||||
if isinstance(expected, list):
|
||||
if str(actual) not in {str(item) for item in expected}:
|
||||
return False
|
||||
|
|
@ -360,3 +409,41 @@ def _source_type_filters(filters: dict[str, Any]) -> list[str]:
|
|||
if isinstance(value, list):
|
||||
return [str(item) for item in value if str(item)]
|
||||
return [str(value)] if str(value) else []
|
||||
|
||||
|
||||
def _file_ref_filters(filters: dict[str, Any]) -> list[str] | None:
|
||||
if "file_ref" in filters:
|
||||
value = filters.get("file_ref")
|
||||
elif "file_refs" in filters:
|
||||
value = filters.get("file_refs")
|
||||
else:
|
||||
return None
|
||||
if isinstance(value, list):
|
||||
return [str(item) for item in value if str(item)]
|
||||
return [str(value)] if str(value) else []
|
||||
|
||||
|
||||
def _install_file_ref_filter_table(conn: sqlite3.Connection, file_refs: list[str]) -> None:
|
||||
conn.execute(
|
||||
"""
|
||||
CREATE TEMP TABLE IF NOT EXISTS semantic_index_filter_file_refs (
|
||||
file_ref TEXT PRIMARY KEY
|
||||
)
|
||||
"""
|
||||
)
|
||||
conn.execute("DELETE FROM semantic_index_filter_file_refs")
|
||||
conn.executemany(
|
||||
"INSERT OR IGNORE INTO semantic_index_filter_file_refs(file_ref) VALUES (?)",
|
||||
[(file_ref,) for file_ref in file_refs],
|
||||
)
|
||||
|
||||
|
||||
def _file_ref_filter_sql(file_refs: list[str] | None) -> str:
|
||||
if file_refs is None:
|
||||
return ""
|
||||
return (
|
||||
"AND EXISTS ("
|
||||
"SELECT 1 FROM semantic_index_filter_file_refs scope_refs "
|
||||
"WHERE scope_refs.file_ref = d.file_ref"
|
||||
")"
|
||||
)
|
||||
|
|
|
|||
|
|
@ -753,6 +753,33 @@ class SQLiteFileSystemStore:
|
|||
return results
|
||||
return results
|
||||
|
||||
def file_refs_for_scope(
    self,
    *,
    scope: Optional[dict[str, Any]] = None,
    metadata_filter: Optional[dict[str, Any]] = None,
) -> list[str]:
    """List the refs of non-deleted files matching a scope and metadata filter.

    Args:
        scope: Passed to ``self._scope_sql``; adds a condition only when that
            helper returns a non-empty SQL fragment.
        metadata_filter: Passed to ``self._metadata_filter_sql``; every
            fragment it returns is ANDed into the query.

    Returns:
        Distinct ``file_ref`` values from ``files`` where ``deleted_at`` is
        NULL and all conditions hold, ordered by ``file_ref``.
    """
    scope_clause, scope_args = self._scope_sql(scope)
    metadata_clauses, metadata_args = self._metadata_filter_sql(metadata_filter)

    conditions = ["f.deleted_at IS NULL"]
    bindings: list[Any] = []
    if scope_clause:
        conditions.append(scope_clause)
        bindings.extend(scope_args)
    conditions.extend(metadata_clauses)
    bindings.extend(metadata_args)

    predicate = " AND ".join(conditions)
    query = f"""
        SELECT DISTINCT f.file_ref
        FROM files f
        WHERE {predicate}
        ORDER BY f.file_ref
    """
    with self.connect() as conn:
        records = conn.execute(query, bindings).fetchall()
    return [record["file_ref"] for record in records]
|
||||
|
||||
def _search_once(
|
||||
self,
|
||||
match_query: str | None,
|
||||
|
|
|
|||
|
|
@ -70,9 +70,10 @@ class ChannelBackend:
|
|||
|
||||
|
||||
class BrowseBackend:
|
||||
def __init__(self, document_ids, channels=("summary",)):
|
||||
def __init__(self, document_ids, channels=("summary",), file_refs_by_document_id=None):
|
||||
self.document_ids = list(document_ids)
|
||||
self.channels = channels
|
||||
self.file_refs_by_document_id = dict(file_refs_by_document_id or {})
|
||||
self.calls = []
|
||||
|
||||
def available_channels(self):
|
||||
|
|
@ -80,6 +81,20 @@ class BrowseBackend:
|
|||
|
||||
def search_channel(self, channel, query, *, limit=10, filters=None):
|
||||
self.calls.append((channel, query, limit, filters))
|
||||
file_ref_filter = set()
|
||||
if isinstance(filters, dict):
|
||||
raw_file_refs = filters.get("file_ref") or filters.get("file_refs") or []
|
||||
if isinstance(raw_file_refs, str):
|
||||
file_ref_filter = {raw_file_refs}
|
||||
else:
|
||||
file_ref_filter = {str(item) for item in raw_file_refs}
|
||||
document_ids = self.document_ids
|
||||
if file_ref_filter and self.file_refs_by_document_id:
|
||||
document_ids = [
|
||||
document_id
|
||||
for document_id in document_ids
|
||||
if self.file_refs_by_document_id.get(document_id) in file_ref_filter
|
||||
]
|
||||
return [
|
||||
SimpleNamespace(
|
||||
document_id=document_id,
|
||||
|
|
@ -87,7 +102,7 @@ class BrowseBackend:
|
|||
score=1.0 - rank * 0.01,
|
||||
sources=[{"channel": channel, "rank": rank, "distance": rank / 10}],
|
||||
)
|
||||
for rank, document_id in enumerate(self.document_ids[:limit], 1)
|
||||
for rank, document_id in enumerate(document_ids[:limit], 1)
|
||||
]
|
||||
|
||||
|
||||
|
|
@ -108,11 +123,11 @@ def _register_browse_file(filesystem, external_id, folder_path, *, department="o
|
|||
|
||||
filesystem.metadata_generator = SummaryGenerator()
|
||||
return filesystem.register_file(
|
||||
storage_uri=f"file:///tmp/{external_id}.pdf",
|
||||
source_path=f"documents/{external_id}.pdf",
|
||||
storage_uri=f"file:///tmp/{external_id}.txt",
|
||||
source_path=f"documents/{external_id}.txt",
|
||||
folder_path=folder_path,
|
||||
external_id=external_id,
|
||||
title=f"{external_id}.pdf",
|
||||
title=f"{external_id}.txt",
|
||||
content=f"{external_id} discusses vector databases and retrieval.",
|
||||
metadata={"department": department},
|
||||
metadata_policy={
|
||||
|
|
@ -262,6 +277,49 @@ def test_browse_supports_fixed_size_one_based_pagination_and_metadata_filter(tmp
|
|||
assert filtered["data"][0]["summary"] == "summary for doc_10"
|
||||
|
||||
|
||||
def test_browse_scopes_semantic_search_before_candidate_limit(tmp_path):
    """Browse must still find in-scope files ranked behind 150 out-of-scope ones."""
    import json

    from pageindex.filesystem import PIFSCommandExecutor, PageIndexFileSystem

    filesystem = PageIndexFileSystem(workspace=tmp_path / "workspace")

    # Out-of-scope files come first, so the backend ranks them ahead of the
    # two files that live under /documents.
    candidate_ids = [f"off_scope_{index:02d}" for index in range(150)]
    placements = [(external_id, "/other") for external_id in candidate_ids]
    placements.append(("doc_deep", "/documents/reports"))
    placements.append(("doc_direct", "/documents"))

    file_refs_by_document_id = {}
    for external_id, folder_path in placements:
        file_refs_by_document_id[external_id] = _register_browse_file(
            filesystem,
            external_id,
            folder_path,
        )

    filesystem.semantic_retrieval_backend = BrowseBackend(
        [*candidate_ids, "doc_deep", "doc_direct"],
        file_refs_by_document_id=file_refs_by_document_id,
    )
    executor = PIFSCommandExecutor(filesystem, json_output=True)

    def browse_document_ids(command):
        payload = json.loads(executor.execute(command))["data"]
        return [item["document_id"] for item in payload["data"]]

    assert browse_document_ids('browse /documents "vector database"') == ["doc_direct"]
    assert browse_document_ids('browse -R /documents "vector database"') == [
        "doc_deep",
        "doc_direct",
    ]
|
||||
|
||||
|
||||
def test_semantic_search_scope_keeps_ordinary_folders_out_of_source_type_filters(tmp_path):
|
||||
from pageindex.filesystem import PIFSCommandExecutor, PageIndexFileSystem
|
||||
from pageindex.filesystem.metadata_generation import MetadataGenerationResult
|
||||
|
|
|
|||
|
|
@ -55,6 +55,44 @@ def test_sqlite_vec_semantic_index_round_trip(tmp_path):
|
|||
assert [item.external_id for item in filtered] == ["doc_b"]
|
||||
|
||||
|
||||
def test_sqlite_vec_semantic_index_file_ref_filter_not_limited_by_global_rank(tmp_path):
    """A file-ref filter must surface its row even when 30 others rank closer."""
    index = SQLiteVecSemanticIndex(tmp_path / "semantic.sqlite")
    index.reset(dimension=2, metadata={"field_mode": "summary"})

    # Thirty out-of-scope records share the query vector exactly.
    records = []
    for item in range(30):
        records.append(
            SemanticIndexRecord(
                file_ref=f"file_off_{item:02d}",
                external_id=f"doc_off_{item:02d}",
                source_type="documents",
                source_path=f"other/{item:02d}.pdf",
                title=f"Off scope {item:02d}",
                text="off scope",
                vector=[1.0, 0.0],
            )
        )
    # The single in-scope record uses a different vector from the query.
    in_scope = SemanticIndexRecord(
        file_ref="file_in_scope",
        external_id="doc_in_scope",
        source_type="documents",
        source_path="documents/in-scope.pdf",
        title="In scope",
        text="in scope",
        vector=[0.0, 1.0],
    )
    records.append(in_scope)
    index.upsert_many(records)

    results = index.search(
        [1.0, 0.0],
        limit=1,
        filters={"file_ref": ["file_in_scope"]},
    )

    found_refs = [item.file_ref for item in results]
    assert found_refs == ["file_in_scope"]
|
||||
|
||||
|
||||
def test_summary_projection_indexes_unified_metadata_summary(tmp_path):
|
||||
from pageindex.filesystem.projection_indexing import SummaryProjectionIndexer
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue