PageIndex/pageindex/filesystem/semantic_index.py

362 lines
13 KiB
Python

from __future__ import annotations

import hashlib
import json
import sqlite3
from collections.abc import Iterator
from contextlib import contextmanager
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Protocol

import sqlite_vec
class SemanticIndexError(RuntimeError):
    """Error raised by the semantic index for invalid dimensions or an uninitialized index."""
@dataclass(frozen=True)
class SemanticIndexRecord:
    """Immutable input record describing one document to index.

    When written by ``SQLiteVecSemanticIndex.upsert_many`` the record is keyed
    by ``file_ref``; only a SHA-256 hash and the character count of ``text``
    are persisted, not the text itself.
    """

    # Unique key of the document inside the index.
    file_ref: str
    # Embedding; its length must equal the index's configured dimension.
    vector: list[float]
    # Source text the embedding was computed from (hashed, not stored).
    text: str
    # Optional identifier from an external system.
    external_id: str | None = None
    # Category label; also used as the vec0 partition key.
    source_type: str = ""
    # Where the document came from.
    source_path: str = ""
    # Human-readable title.
    title: str = ""
    # Extra JSON-serialisable attributes; None is stored as an empty object.
    metadata: dict[str, Any] | None = None
@dataclass(frozen=True)
class SemanticSearchResult:
    """Immutable hit returned by a semantic search, one per matching document."""

    # Unique key of the matched document.
    file_ref: str
    # Distance reported by the vector search (smaller is closer).
    distance: float
    # External identifier, if one was stored.
    external_id: str | None
    # Category / partition value of the document.
    source_type: str
    # Where the document came from.
    source_path: str
    # Human-readable title.
    title: str
    # SHA-256 hex digest of the indexed text.
    text_hash: str
    # Decoded metadata object (empty dict when none / unparsable).
    metadata: dict[str, Any]
class RebuildableSemanticIndex(Protocol):
    """Structural interface of a semantic index that can be wiped and rebuilt.

    ``SQLiteVecSemanticIndex`` in this module provides all of these methods.
    """

    def reset(self, *, dimension: int, metadata: dict[str, Any] | None = None) -> None:
        """Discard all indexed content and prepare storage for ``dimension``-length vectors."""

    def upsert_many(self, records: list[SemanticIndexRecord]) -> int:
        """Insert or replace ``records``; return how many were written."""

    def search(
        self,
        vector: list[float],
        *,
        limit: int = 10,
        filters: dict[str, Any] | None = None,
        fetch_multiplier: int = 20,
    ) -> list[SemanticSearchResult]:
        """Return up to ``limit`` nearest documents to ``vector`` that satisfy ``filters``."""

    def info(self) -> dict[str, Any]:
        """Return a summary dictionary describing the index."""
class SQLiteVecSemanticIndex:
    """Rebuildable local semantic index backed by sqlite-vec.

    This is intentionally separate from the PIFS catalog tables. The catalog
    remains the source of truth; this file is a rebuildable recall index.

    Tables created by :meth:`reset`:

    * ``semantic_index_config`` -- key/value settings (``dimension``,
      ``adapter``, ``adapter_version``, ``metadata``).
    * ``semantic_index_docs`` -- one row per unique ``file_ref`` holding the
      descriptive columns plus a SHA-256 hash and length of the source text.
    * ``semantic_index_vec`` -- a ``vec0`` virtual table holding embeddings,
      partitioned by ``source_type`` and sharing its rowid with
      ``semantic_index_docs``.

    Every public operation opens its own connection and closes it before
    returning.
    """

    # Upper bound applied to the KNN ``k`` parameter (sqlite-vec's default
    # maximum k for vec0 KNN queries).
    _MAX_KNN_K = 4096

    def __init__(self, db_path: str | Path):
        """Store ``db_path`` (``~`` expanded) and create its parent directory if needed."""
        self.db_path = Path(db_path).expanduser()
        self.db_path.parent.mkdir(parents=True, exist_ok=True)

    def reset(self, *, dimension: int, metadata: dict[str, Any] | None = None) -> None:
        """Drop and recreate every index table for vectors of ``dimension`` floats.

        Args:
            dimension: Embedding length; must be positive.
            metadata: Optional JSON-serialisable dict stored in the config table
                and reported back by :meth:`info`.

        Raises:
            SemanticIndexError: If ``dimension`` is not positive.
        """
        if dimension <= 0:
            raise SemanticIndexError("semantic index dimension must be positive")
        with self._session() as conn:
            # NOTE(review): the UNIQUE constraint on file_ref already creates an
            # index in SQLite; the explicit file_ref index below is redundant
            # but kept so the on-disk schema does not change.
            conn.executescript(
                """
                DROP TABLE IF EXISTS semantic_index_vec;
                DROP TABLE IF EXISTS semantic_index_docs;
                DROP TABLE IF EXISTS semantic_index_config;
                CREATE TABLE semantic_index_config (
                    key TEXT PRIMARY KEY,
                    value TEXT NOT NULL
                );
                CREATE TABLE semantic_index_docs (
                    rowid INTEGER PRIMARY KEY,
                    file_ref TEXT NOT NULL UNIQUE,
                    external_id TEXT,
                    source_type TEXT NOT NULL DEFAULT '',
                    source_path TEXT NOT NULL DEFAULT '',
                    title TEXT NOT NULL DEFAULT '',
                    text_hash TEXT NOT NULL,
                    text_chars INTEGER NOT NULL DEFAULT 0,
                    metadata_json TEXT NOT NULL DEFAULT '{}',
                    created_at TEXT DEFAULT CURRENT_TIMESTAMP,
                    updated_at TEXT DEFAULT CURRENT_TIMESTAMP
                );
                CREATE INDEX idx_semantic_index_docs_file_ref
                    ON semantic_index_docs(file_ref);
                CREATE INDEX idx_semantic_index_docs_external_id
                    ON semantic_index_docs(external_id);
                CREATE INDEX idx_semantic_index_docs_source_type
                    ON semantic_index_docs(source_type);
                """
            )
            # The vector width is part of the column type in the DDL, so it
            # cannot be a bound parameter; it was checked to be positive above.
            conn.execute(
                "CREATE VIRTUAL TABLE semantic_index_vec USING "
                f"vec0(source_type TEXT partition key, embedding float[{dimension}])"
            )
            config = {
                "dimension": str(dimension),
                "adapter": "sqlite-vec",
                "adapter_version": sqlite_vec.__version__,
                "metadata": json.dumps(metadata or {}, ensure_ascii=False, sort_keys=True),
            }
            conn.executemany(
                "INSERT INTO semantic_index_config(key, value) VALUES (?, ?)",
                sorted(config.items()),
            )

    def upsert_many(self, records: list[SemanticIndexRecord]) -> int:
        """Insert or replace documents and their embeddings, keyed by ``file_ref``.

        Every vector is validated before anything is written, and all writes
        happen in one transaction that is rolled back on error.

        Returns:
            The number of records written (``len(records)``; 0 for an empty list).

        Raises:
            SemanticIndexError: If the index is not initialized or any vector's
                length differs from the configured dimension.
        """
        if not records:
            return 0
        dimension = self.dimension()
        for record in records:
            if len(record.vector) != dimension:
                raise SemanticIndexError(
                    f"vector dimension mismatch for {record.file_ref}: "
                    f"expected {dimension}, got {len(record.vector)}"
                )
        with self._session() as conn:
            for record in records:
                rowid = self._upsert_doc(conn, record)
                # Replace any existing embedding for this rowid: delete, then insert.
                conn.execute("DELETE FROM semantic_index_vec WHERE rowid = ?", (rowid,))
                conn.execute(
                    "INSERT INTO semantic_index_vec(rowid, source_type, embedding) VALUES (?, ?, ?)",
                    (
                        rowid,
                        record.source_type,
                        sqlite_vec.serialize_float32(record.vector),
                    ),
                )
        return len(records)

    def search(
        self,
        vector: list[float],
        *,
        limit: int = 10,
        filters: dict[str, Any] | None = None,
        fetch_multiplier: int = 20,
    ) -> list[SemanticSearchResult]:
        """Return up to ``limit`` nearest documents to ``vector``, closest first.

        Up to ``limit * fetch_multiplier`` candidates (capped at
        ``_MAX_KNN_K``) are fetched per queried ``source_type`` partition and
        then post-filtered with ``filters``. Fewer than ``limit`` results may
        therefore be returned when the filters reject many candidates.

        Args:
            vector: Query embedding; its length must equal the index dimension.
            limit: Maximum number of results; ``<= 0`` yields an empty list.
            filters: Exact-match filters on document columns or metadata keys.
                A list/tuple/set value matches any member. A ``source_type``
                filter also restricts which vec0 partitions are searched.
            fetch_multiplier: Over-fetch factor applied before post-filtering
                (values below 1 are treated as 1).

        Raises:
            SemanticIndexError: If the index is not initialized or ``vector``
                has the wrong length.
        """
        dimension = self.dimension()
        if len(vector) != dimension:
            raise SemanticIndexError(
                f"query vector dimension mismatch: expected {dimension}, got {len(vector)}"
            )
        if limit <= 0:
            return []
        filters = filters or {}
        fetch_k = min(self._MAX_KNN_K, limit * max(fetch_multiplier, 1))
        query_blob = sqlite_vec.serialize_float32(vector)
        source_types = _source_type_filters(filters)

        with self._session() as conn:
            if source_types:
                rows: list[sqlite3.Row] = []
                for source_type in source_types:
                    rows.extend(self._knn_rows(conn, query_blob, fetch_k, source_type))
                # Each partition is ordered on its own; merge them by distance.
                rows.sort(key=lambda row: float(row["distance"]))
            else:
                rows = self._knn_rows(conn, query_blob, fetch_k)

        results: list[SemanticSearchResult] = []
        for row in rows:
            metadata = _json_obj(row["metadata_json"])
            if not _matches_filters(row, metadata, filters):
                continue
            results.append(
                SemanticSearchResult(
                    file_ref=row["file_ref"],
                    distance=float(row["distance"]),
                    external_id=row["external_id"],
                    source_type=row["source_type"],
                    source_path=row["source_path"],
                    title=row["title"],
                    text_hash=row["text_hash"],
                    metadata=metadata,
                )
            )
            if len(results) >= limit:
                break
        return results

    def info(self) -> dict[str, Any]:
        """Summarise the index: path, adapter, dimension, document count and metadata.

        A database on which :meth:`reset` was never called reports the default
        values (dimension 0, document count 0, empty metadata) instead of raising.
        """
        with self._session() as conn:
            config: dict[str, str] = {}
            if self._table_exists(conn, "semantic_index_config"):
                config = {
                    row["key"]: row["value"]
                    for row in conn.execute(
                        "SELECT key, value FROM semantic_index_config ORDER BY key"
                    ).fetchall()
                }
            count = 0
            if self._table_exists(conn, "semantic_index_docs"):
                count = conn.execute("SELECT COUNT(*) FROM semantic_index_docs").fetchone()[0]
        return {
            "db_path": str(self.db_path),
            "adapter": config.get("adapter", "sqlite-vec"),
            "adapter_version": config.get("adapter_version", ""),
            "dimension": int(config.get("dimension", "0") or 0),
            "document_count": count,
            "metadata": _json_obj(config.get("metadata")),
        }

    def dimension(self) -> int:
        """Return the configured embedding dimension.

        Raises:
            SemanticIndexError: If :meth:`reset` has not been called on this
                database file.
        """
        with self._session() as conn:
            row = None
            # A fresh database has no config table at all; querying it would
            # raise sqlite3.OperationalError instead of the intended error.
            if self._table_exists(conn, "semantic_index_config"):
                row = conn.execute(
                    "SELECT value FROM semantic_index_config WHERE key = 'dimension'"
                ).fetchone()
        if row is None:
            raise SemanticIndexError(
                f"semantic index is not initialized; call reset() first: {self.db_path}"
            )
        return int(row["value"])

    def connect(self) -> sqlite3.Connection:
        """Open a new connection with ``sqlite3.Row`` rows and sqlite-vec loaded.

        The caller owns the returned connection and must close it; using the
        connection as a context manager only commits/rolls back, it does not close.
        """
        conn = sqlite3.connect(self.db_path)
        try:
            conn.row_factory = sqlite3.Row
            conn.enable_load_extension(True)
            sqlite_vec.load(conn)
            # Turn extension loading back off once sqlite-vec is loaded.
            conn.enable_load_extension(False)
        except BaseException:
            conn.close()
            raise
        return conn

    @contextmanager
    def _session(self) -> Iterator[sqlite3.Connection]:
        """Yield a fresh connection inside a transaction, then always close it.

        The transaction commits when the body completes and rolls back when it raises.
        """
        conn = self.connect()
        try:
            with conn:
                yield conn
        finally:
            conn.close()

    @staticmethod
    def _table_exists(conn: sqlite3.Connection, name: str) -> bool:
        """Return True if a table called ``name`` exists in ``conn``'s main schema."""
        row = conn.execute(
            "SELECT 1 FROM sqlite_master WHERE type = 'table' AND name = ?",
            (name,),
        ).fetchone()
        return row is not None

    @staticmethod
    def _knn_rows(
        conn: sqlite3.Connection,
        query_blob: bytes,
        k: int,
        source_type: str | None = None,
    ) -> list[sqlite3.Row]:
        """Run one KNN query and join the hits with their document rows.

        When ``source_type`` is given the search is restricted to that vec0
        partition. Rows are ordered by ascending distance.
        """
        # Only constant SQL fragments are concatenated; all values are bound.
        sql = """
            SELECT
                d.file_ref,
                d.external_id,
                d.source_type,
                d.source_path,
                d.title,
                d.text_hash,
                d.metadata_json,
                v.distance
            FROM semantic_index_vec v
            JOIN semantic_index_docs d ON d.rowid = v.rowid
            WHERE v.embedding MATCH ? AND k = ?
        """
        params: tuple[Any, ...] = (query_blob, k)
        if source_type is not None:
            sql += " AND v.source_type = ?"
            params += (source_type,)
        sql += " ORDER BY v.distance"
        return conn.execute(sql, params).fetchall()

    @staticmethod
    def text_hash(text: str) -> str:
        """Return the SHA-256 hex digest of ``text`` encoded as UTF-8."""
        return hashlib.sha256(text.encode("utf-8")).hexdigest()

    @staticmethod
    def _upsert_doc(conn: sqlite3.Connection, record: SemanticIndexRecord) -> int:
        """Insert or update the ``semantic_index_docs`` row for ``record.file_ref``.

        Returns:
            The rowid of the document row (reused for the vector row).
        """
        existing = conn.execute(
            "SELECT rowid FROM semantic_index_docs WHERE file_ref = ?",
            (record.file_ref,),
        ).fetchone()
        metadata_json = json.dumps(record.metadata or {}, ensure_ascii=False, sort_keys=True)
        text_hash = SQLiteVecSemanticIndex.text_hash(record.text)
        if existing is None:
            cursor = conn.execute(
                """
                INSERT INTO semantic_index_docs(
                    file_ref, external_id, source_type, source_path, title,
                    text_hash, text_chars, metadata_json
                ) VALUES (?, ?, ?, ?, ?, ?, ?, ?)
                """,
                (
                    record.file_ref,
                    record.external_id,
                    record.source_type,
                    record.source_path,
                    record.title,
                    text_hash,
                    len(record.text),
                    metadata_json,
                ),
            )
            return int(cursor.lastrowid)
        rowid = int(existing["rowid"])
        # Keep the rowid stable so the vector row can be replaced in place.
        conn.execute(
            """
            UPDATE semantic_index_docs
            SET external_id = ?,
                source_type = ?,
                source_path = ?,
                title = ?,
                text_hash = ?,
                text_chars = ?,
                metadata_json = ?,
                updated_at = CURRENT_TIMESTAMP
            WHERE rowid = ?
            """,
            (
                record.external_id,
                record.source_type,
                record.source_path,
                record.title,
                text_hash,
                len(record.text),
                metadata_json,
                rowid,
            ),
        )
        return rowid
def _json_obj(text: str | None) -> dict[str, Any]:
    """Decode ``text`` as a JSON object, falling back to an empty dict.

    Empty/None input, malformed JSON, and JSON values that are not objects
    all yield ``{}``.
    """
    if text:
        try:
            parsed = json.loads(text)
        except json.JSONDecodeError:
            parsed = None
        if isinstance(parsed, dict):
            return parsed
    return {}
def _matches_filters(
    row: sqlite3.Row,
    metadata: dict[str, Any],
    filters: dict[str, Any],
) -> bool:
    """Return True when ``row``/``metadata`` satisfy every entry in ``filters``.

    Each filter key is looked up first among the row's columns and otherwise
    in ``metadata``; a key found in neither compares as ``None``. Values are
    compared as strings. When the expected value is a list, tuple, set or
    frozenset, the actual value must equal any member; any other expected
    value must match exactly. An empty ``filters`` dict matches everything.
    """
    # Column names are the same for every filter key; compute them once.
    columns = set(row.keys())
    for key, expected in filters.items():
        actual = row[key] if key in columns else metadata.get(key)
        actual_text = str(actual)
        if isinstance(expected, (list, tuple, set, frozenset)):
            if actual_text not in {str(item) for item in expected}:
                return False
        elif actual_text != str(expected):
            return False
    return True
def _source_type_filters(filters: dict[str, Any]) -> list[str]:
    """Return the distinct, non-empty ``source_type`` values requested by ``filters``.

    A scalar value becomes a one-element list. A list or tuple keeps its order,
    and a set or frozenset is sorted so the result is deterministic. Values are
    converted to ``str``, and empty strings and duplicates are dropped. A missing
    or ``None`` entry yields ``[]``.

    Deduplication matters because the caller runs one KNN query per value and
    concatenates the rows, so a repeated value would return repeated hits.
    """
    value = filters.get("source_type")
    if value is None:
        return []
    if isinstance(value, (set, frozenset)):
        candidates = sorted(str(item) for item in value)
    elif isinstance(value, (list, tuple)):
        candidates = [str(item) for item in value]
    else:
        candidates = [str(value)]
    # dict.fromkeys drops duplicates while keeping first-seen order.
    return list(dict.fromkeys(item for item in candidates if item))