PageIndex/pageindex/filesystem/semantic_index.py
2026-06-01 01:40:44 +08:00

469 lines
16 KiB
Python

from __future__ import annotations
import hashlib
import json
import sqlite3
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Protocol
import sqlite_vec
class SemanticIndexError(RuntimeError):
pass
@dataclass(frozen=True)
class SemanticIndexRecord:
file_ref: str
vector: list[float]
text: str
external_id: str | None = None
source_type: str = ""
title: str = ""
metadata: dict[str, Any] | None = None
@dataclass(frozen=True)
class SemanticSearchResult:
file_ref: str
distance: float
external_id: str | None
source_type: str
title: str
text_hash: str
metadata: dict[str, Any]
class RebuildableSemanticIndex(Protocol):
def reset(self, *, dimension: int, metadata: dict[str, Any] | None = None) -> None:
...
def upsert_many(self, records: list[SemanticIndexRecord]) -> int:
...
def search(
self,
vector: list[float],
*,
limit: int = 10,
filters: dict[str, Any] | None = None,
fetch_multiplier: int = 20,
) -> list[SemanticSearchResult]:
...
def info(self) -> dict[str, Any]:
...
class SQLiteVecSemanticIndex:
"""Rebuildable local semantic index backed by sqlite-vec.
This is intentionally separate from the PIFS catalog tables. The catalog
remains source of truth; this file is a rebuildable recall index.
"""
def __init__(self, db_path: str | Path):
self.db_path = Path(db_path).expanduser()
self.db_path.parent.mkdir(parents=True, exist_ok=True)
def reset(self, *, dimension: int, metadata: dict[str, Any] | None = None) -> None:
if dimension <= 0:
raise SemanticIndexError("semantic index dimension must be positive")
with self.connect() as conn:
conn.executescript(
"""
DROP TABLE IF EXISTS semantic_index_vec;
DROP TABLE IF EXISTS semantic_index_docs;
DROP TABLE IF EXISTS semantic_index_config;
CREATE TABLE semantic_index_config (
key TEXT PRIMARY KEY,
value TEXT NOT NULL
);
CREATE TABLE semantic_index_docs (
rowid INTEGER PRIMARY KEY,
file_ref TEXT NOT NULL UNIQUE,
external_id TEXT,
source_type TEXT NOT NULL DEFAULT '',
title TEXT NOT NULL DEFAULT '',
text_hash TEXT NOT NULL,
text_chars INTEGER NOT NULL DEFAULT 0,
metadata_json TEXT NOT NULL DEFAULT '{}',
created_at TEXT DEFAULT CURRENT_TIMESTAMP,
updated_at TEXT DEFAULT CURRENT_TIMESTAMP
);
CREATE INDEX idx_semantic_index_docs_file_ref
ON semantic_index_docs(file_ref);
CREATE INDEX idx_semantic_index_docs_external_id
ON semantic_index_docs(external_id);
CREATE INDEX idx_semantic_index_docs_source_type
ON semantic_index_docs(source_type);
"""
)
conn.execute(
"CREATE VIRTUAL TABLE semantic_index_vec USING "
f"vec0(source_type TEXT partition key, embedding float[{dimension}])"
)
config = {
"dimension": str(dimension),
"adapter": "sqlite-vec",
"adapter_version": sqlite_vec.__version__,
"metadata": json.dumps(metadata or {}, ensure_ascii=False, sort_keys=True),
}
conn.executemany(
"INSERT INTO semantic_index_config(key, value) VALUES (?, ?)",
sorted(config.items()),
)
conn.commit()
def upsert_many(self, records: list[SemanticIndexRecord]) -> int:
if not records:
return 0
dimension = self.dimension()
with self.connect() as conn:
inserted = 0
for record in records:
if len(record.vector) != dimension:
raise SemanticIndexError(
f"vector dimension mismatch for {record.file_ref}: "
f"expected {dimension}, got {len(record.vector)}"
)
rowid = self._upsert_doc(conn, record)
conn.execute("DELETE FROM semantic_index_vec WHERE rowid = ?", (rowid,))
conn.execute(
"INSERT INTO semantic_index_vec(rowid, source_type, embedding) VALUES (?, ?, ?)",
(
rowid,
record.source_type,
sqlite_vec.serialize_float32(record.vector),
),
)
inserted += 1
conn.commit()
return inserted
def delete_file_refs(self, file_refs: list[str]) -> int:
refs = [str(file_ref) for file_ref in file_refs if str(file_ref)]
if not refs:
return 0
placeholders = ", ".join("?" for _ in refs)
with self.connect() as conn:
rows = conn.execute(
f"""
SELECT rowid
FROM semantic_index_docs
WHERE file_ref IN ({placeholders})
""",
refs,
).fetchall()
rowids = [int(row["rowid"]) for row in rows]
if not rowids:
return 0
rowid_placeholders = ", ".join("?" for _ in rowids)
conn.execute(
f"DELETE FROM semantic_index_vec WHERE rowid IN ({rowid_placeholders})",
rowids,
)
conn.execute(
f"DELETE FROM semantic_index_docs WHERE rowid IN ({rowid_placeholders})",
rowids,
)
conn.commit()
return len(rowids)
def search(
self,
vector: list[float],
*,
limit: int = 10,
filters: dict[str, Any] | None = None,
fetch_multiplier: int = 20,
) -> list[SemanticSearchResult]:
dimension = self.dimension()
if len(vector) != dimension:
raise SemanticIndexError(
f"query vector dimension mismatch: expected {dimension}, got {len(vector)}"
)
raw_filters = filters or {}
source_types = _source_type_filters(raw_filters)
file_refs = _file_ref_filters(raw_filters)
if file_refs == []:
return []
with self.connect() as conn:
if file_refs is not None:
_install_file_ref_filter_table(conn, file_refs)
rows = []
if source_types:
for source_type in source_types:
fetch_k = self._search_fetch_k(
conn,
limit,
fetch_multiplier,
exact_file_ref_filter=file_refs is not None,
source_type=source_type,
)
if fetch_k <= 0:
continue
rows.extend(
conn.execute(
f"""
SELECT
d.file_ref,
d.external_id,
d.source_type,
d.title,
d.text_hash,
d.metadata_json,
v.distance
FROM semantic_index_vec v
JOIN semantic_index_docs d ON d.rowid = v.rowid
WHERE v.embedding MATCH ? AND k = ? AND v.source_type = ?
{_file_ref_filter_sql(file_refs)}
ORDER BY v.distance
""",
(sqlite_vec.serialize_float32(vector), fetch_k, source_type),
).fetchall()
)
rows.sort(key=lambda row: float(row["distance"]))
else:
fetch_k = self._search_fetch_k(
conn,
limit,
fetch_multiplier,
exact_file_ref_filter=file_refs is not None,
)
if fetch_k <= 0:
return []
rows = conn.execute(
f"""
SELECT
d.file_ref,
d.external_id,
d.source_type,
d.title,
d.text_hash,
d.metadata_json,
v.distance
FROM semantic_index_vec v
JOIN semantic_index_docs d ON d.rowid = v.rowid
WHERE v.embedding MATCH ? AND k = ?
{_file_ref_filter_sql(file_refs)}
ORDER BY v.distance
""",
(sqlite_vec.serialize_float32(vector), fetch_k),
).fetchall()
results: list[SemanticSearchResult] = []
for row in rows:
metadata = _json_obj(row["metadata_json"])
if not _matches_filters(row, metadata, filters or {}):
continue
results.append(
SemanticSearchResult(
file_ref=row["file_ref"],
distance=float(row["distance"]),
external_id=row["external_id"],
source_type=row["source_type"],
title=row["title"],
text_hash=row["text_hash"],
metadata=metadata,
)
)
if len(results) >= limit:
break
return results
@staticmethod
def _search_fetch_k(
conn: sqlite3.Connection,
limit: int,
fetch_multiplier: int,
*,
exact_file_ref_filter: bool,
source_type: str | None = None,
) -> int:
if exact_file_ref_filter:
where = []
params: list[Any] = []
if source_type is not None:
where.append("source_type = ?")
params.append(source_type)
where_sql = "WHERE " + " AND ".join(where) if where else ""
return int(
conn.execute(
f"SELECT COUNT(*) FROM semantic_index_docs {where_sql}",
params,
).fetchone()[0]
)
return min(4096, max(limit, limit * max(fetch_multiplier, 1)))
def info(self) -> dict[str, Any]:
with self.connect() as conn:
config = {
row["key"]: row["value"]
for row in conn.execute(
"SELECT key, value FROM semantic_index_config ORDER BY key"
).fetchall()
}
count = conn.execute("SELECT COUNT(*) FROM semantic_index_docs").fetchone()[0]
parsed_metadata: dict[str, Any]
try:
parsed_metadata = json.loads(config.get("metadata", "{}"))
except json.JSONDecodeError:
parsed_metadata = {}
return {
"db_path": str(self.db_path),
"adapter": config.get("adapter", "sqlite-vec"),
"adapter_version": config.get("adapter_version", ""),
"dimension": int(config.get("dimension", "0") or 0),
"document_count": count,
"metadata": parsed_metadata,
}
def dimension(self) -> int:
with self.connect() as conn:
row = conn.execute(
"SELECT value FROM semantic_index_config WHERE key = 'dimension'"
).fetchone()
if row is None:
raise SemanticIndexError(
f"semantic index is not initialized; call reset() first: {self.db_path}"
)
return int(row["value"])
def connect(self) -> sqlite3.Connection:
conn = sqlite3.connect(self.db_path)
conn.row_factory = sqlite3.Row
conn.enable_load_extension(True)
sqlite_vec.load(conn)
conn.enable_load_extension(False)
return conn
@staticmethod
def text_hash(text: str) -> str:
return hashlib.sha256(text.encode("utf-8")).hexdigest()
@staticmethod
def _upsert_doc(conn: sqlite3.Connection, record: SemanticIndexRecord) -> int:
existing = conn.execute(
"SELECT rowid FROM semantic_index_docs WHERE file_ref = ?",
(record.file_ref,),
).fetchone()
metadata_json = json.dumps(record.metadata or {}, ensure_ascii=False, sort_keys=True)
text_hash = SQLiteVecSemanticIndex.text_hash(record.text)
if existing is None:
cursor = conn.execute(
"""
INSERT INTO semantic_index_docs(
file_ref, external_id, source_type, title,
text_hash, text_chars, metadata_json
) VALUES (?, ?, ?, ?, ?, ?, ?)
""",
(
record.file_ref,
record.external_id,
record.source_type,
record.title,
text_hash,
len(record.text),
metadata_json,
),
)
return int(cursor.lastrowid)
rowid = int(existing["rowid"])
conn.execute(
"""
UPDATE semantic_index_docs
SET external_id = ?,
source_type = ?,
title = ?,
text_hash = ?,
text_chars = ?,
metadata_json = ?,
updated_at = CURRENT_TIMESTAMP
WHERE rowid = ?
""",
(
record.external_id,
record.source_type,
record.title,
text_hash,
len(record.text),
metadata_json,
rowid,
),
)
return rowid
def _json_obj(text: str | None) -> dict[str, Any]:
if not text:
return {}
try:
value = json.loads(text)
except json.JSONDecodeError:
return {}
return value if isinstance(value, dict) else {}
def _matches_filters(
row: sqlite3.Row,
metadata: dict[str, Any],
filters: dict[str, Any],
) -> bool:
for key, expected in filters.items():
actual_key = "file_ref" if key == "file_refs" else key
actual = row[actual_key] if actual_key in row.keys() else metadata.get(actual_key)
if isinstance(expected, list):
if str(actual) not in {str(item) for item in expected}:
return False
elif str(actual) != str(expected):
return False
return True
def _source_type_filters(filters: dict[str, Any]) -> list[str]:
value = filters.get("source_type")
if value is None:
return []
if isinstance(value, list):
return [str(item) for item in value if str(item)]
return [str(value)] if str(value) else []
def _file_ref_filters(filters: dict[str, Any]) -> list[str] | None:
if "file_ref" in filters:
value = filters.get("file_ref")
elif "file_refs" in filters:
value = filters.get("file_refs")
else:
return None
if isinstance(value, list):
return [str(item) for item in value if str(item)]
return [str(value)] if str(value) else []
def _install_file_ref_filter_table(conn: sqlite3.Connection, file_refs: list[str]) -> None:
conn.execute(
"""
CREATE TEMP TABLE IF NOT EXISTS semantic_index_filter_file_refs (
file_ref TEXT PRIMARY KEY
)
"""
)
conn.execute("DELETE FROM semantic_index_filter_file_refs")
conn.executemany(
"INSERT OR IGNORE INTO semantic_index_filter_file_refs(file_ref) VALUES (?)",
[(file_ref,) for file_ref in file_refs],
)
def _file_ref_filter_sql(file_refs: list[str] | None) -> str:
if file_refs is None:
return ""
return (
"AND EXISTS ("
"SELECT 1 FROM semantic_index_filter_file_refs scope_refs "
"WHERE scope_refs.file_ref = d.file_ref"
")"
)