- doctor: skip LanceDB check when qdrant_storage/ dir detected - topology: use daemon socket instead of local store (avoids lancedb crash) - qdrant_store: add records_as_dataframe() + wire into _TableShim so build_runtime_graph() works with Qdrant (was returning empty)
1369 lines
49 KiB
Python
1369 lines
49 KiB
Python
"""Qdrant-backed persistent memory store (D-01 storage engine, sync write).
|
|
|
|
Replaces LanceDB with Qdrant for remote vector storage.
|
|
|
|
Collections (payload-partitioned, per Qdrant best practices):
|
|
- `records` : MemoryRecord rows (1024-dim cosine vectors)
|
|
All points carry `table: "records"` + `group_id` payload.
|
|
- `metadata` : Payload-only (no vectors) containing edges, events,
|
|
budget_ledger, ratelimit_ledger.
|
|
Each point carries `table` + `group_id` payload.
|
|
|
|
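Both collections get keyword payload indexes on `table` and `group_id` for
filtered access. Tenant co-location (`is_tenant: true`) is the intent, but the
index-creation code in this module does not set that flag; only the `records`
collection configures per-payload HNSW graphs (payload_m=16, m=0).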
|
|
|
|
Encryption-at-rest is identical to LanceDB: AES-256-GCM on
|
|
literal_surface / provenance_json / profile_modulation_gain_json (records)
|
|
and data_json (events). AD = record UUID bytes.
|
|
|
|
The Qdrant client connects to a remote server via QDRANT_URL + QDRANT_API_KEY
|
|
environment variables (overridable via constructor).
|
|
"""
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import base64
|
|
import functools
|
|
import json
|
|
import os
|
|
import re
|
|
import sys
|
|
import threading
|
|
from datetime import datetime, timezone
|
|
from pathlib import Path
|
|
from typing import Callable
|
|
from uuid import UUID
|
|
|
|
from qdrant_client import QdrantClient, models
|
|
from qdrant_client.models import (
|
|
Distance,
|
|
FieldCondition,
|
|
Filter,
|
|
MatchValue,
|
|
PayloadSchemaType,
|
|
PointStruct,
|
|
VectorParams,
|
|
)
|
|
from qdrant_client.http.exceptions import UnexpectedResponse
|
|
|
|
import pandas as pd
|
|
|
|
from cryptography.hazmat.primitives.ciphers.aead import AESGCM
|
|
|
|
from iai_mcp.crypto import (
|
|
CIPHERTEXT_PREFIX,
|
|
NONCE_BYTES,
|
|
CryptoKey,
|
|
encrypt_field,
|
|
is_encrypted,
|
|
)
|
|
from iai_mcp.types import (
|
|
DEFAULT_EMBED_DIM,
|
|
EMBED_DIM,
|
|
SCHEMA_VERSION_CURRENT,
|
|
MemoryRecord,
|
|
TIER_ENUM,
|
|
)
|
|
|
|
# --------------------------------------------------------------------------- env
# Qdrant connection settings, read once at import time. QdrantStore's
# constructor arguments take precedence over these when they are supplied.
QDRANT_URL = os.environ.get("QDRANT_URL")
QDRANT_API_KEY = os.environ.get("QDRANT_API_KEY")

# --------------------------------------------------------------------------- tables
# Logical table names. `records` and `metadata` are also the names of the two
# Qdrant collections; the remaining names are values of the `table` payload
# field stored inside the metadata collection.
RECORDS_TABLE = "records"
EDGES_TABLE = "edges"
EVENTS_TABLE = "events"
BUDGET_TABLE = "budget_ledger"
RATELIMIT_TABLE = "ratelimit_ledger"
METADATA_TABLE = "metadata"

# Embedding dimension for the records collection.
# Defaults to 1024 for bge-m3 (remote embedder). Overridable via the
# IAI_MCP_EMBED_DIM env var. A malformed value must not abort the import:
# _resolve_embed_dim() already tolerates one, and that tolerance is only
# reachable if this module can be loaded in the first place.
try:
    _RECORDS_DIM = int(os.environ.get("IAI_MCP_EMBED_DIM", "1024"))
except ValueError:
    _RECORDS_DIM = 1024

# Metadata tables that live in the metadata collection.
_METADATA_TABLES = frozenset({EDGES_TABLE, EVENTS_TABLE, BUDGET_TABLE, RATELIMIT_TABLE})

# Edge type enum: the only values boost_edges() accepts for `edge_type`.
EDGE_TYPES: frozenset[str] = frozenset({
    "hebbian",
    "contradicts",
    "consolidated_from",
    "schema_instance_of",
    "temporal_next",
    "invariant_anchor",
    "curiosity_bridge",
    "profile_modulates",
    "hebbian_structure",
})
|
|
|
|
# RFC-4122 canonical UUID regex (lowercase hex, hyphenated 8-4-4-4-12).
_UUID_STR_RE = re.compile(
    r"^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$"
)


def _uuid_literal(value: UUID | str) -> str:
    """Return *value* as a validated, lowercase canonical UUID string.

    Parameters
    ----------
    value:
        A ``UUID`` or a string; it is stringified and lowercased before the
        check.

    Returns
    -------
    str
        The lowercase 8-4-4-4-12 textual form.

    Raises
    ------
    ValueError
        If the lowercased text is not exactly a canonical UUID.
    """
    s = str(value).lower()
    # fullmatch, not match: "$" also matches just before a trailing newline,
    # which would let "<uuid>\n" through and return a non-canonical string.
    if _UUID_STR_RE.fullmatch(s) is None:
        raise ValueError(f"not a canonical UUID: {value!r}")
    return s
|
|
|
|
|
|
def _resolve_embed_dim() -> int:
    """Choose the vector size for the records collection.

    Resolution order:

    1. ``IAI_MCP_EMBED_DIM`` when it is set and parses as an integer.
    2. 1024 when ``IAI_MCP_EMBED_MODEL`` is exactly ``"bge-m3"``.
    3. The module-level ``_RECORDS_DIM`` default.
    """
    configured = os.environ.get("IAI_MCP_EMBED_DIM")
    if configured:
        try:
            return int(configured)
        except ValueError:
            # Unparsable override: ignore it and use the model-based choice.
            pass
    model_name = os.environ.get("IAI_MCP_EMBED_MODEL")
    # bge-m3 -> 1024
    return 1024 if model_name == "bge-m3" else _RECORDS_DIM
|
|
|
|
|
|
class QdrantStore:
|
|
"""Qdrant-backed memory store.
|
|
|
|
Mirrors the MemoryStore API so all existing callers (daemon, MCP wrapper,
|
|
capture, retrieve, sleep, etc.) work unchanged.
|
|
"""
|
|
|
|
def __init__(
    self,
    path: Path | str | None = None,
    user_id: str = "default",
    read_consistency_interval: float | None = None,
    url: str | None = None,
    api_key: str | None = None,
) -> None:
    """Connect to Qdrant and make sure the backing collections exist.

    Parameters
    ----------
    path:
        Accepted in the signature but not read by this constructor.
    user_id:
        Scopes the encryption key and doubles as the ``group_id`` payload
        value written on every point.
    read_consistency_interval:
        Stored on the instance only; reserved for future consistency tuning
        (not used with Qdrant).
    url:
        Qdrant server URL. Falls back to the ``QDRANT_URL`` env var.
    api_key:
        Qdrant API key. Falls back to the ``QDRANT_API_KEY`` env var.
    """
    # Identity and sizing.
    self._embed_dim: int = _resolve_embed_dim()
    self._user_id: str = user_id
    self._group_id: str = user_id  # payload partition key
    self._read_consistency_interval = read_consistency_interval

    # Remote client: explicit arguments win over the module-level env values.
    endpoint = url or QDRANT_URL
    secret = api_key or QDRANT_API_KEY
    self._qdrant_url = endpoint
    self._qdrant_api_key = secret
    self._client = QdrantClient(url=endpoint, api_key=secret, timeout=30)

    # Encryption: the raw key bytes are fetched lazily by _key().
    self._crypto_key_wrapper = CryptoKey(user_id=user_id)
    self._crypto_key: bytes | None = None

    # Optional callback mirroring writes into the runtime graph.
    self._graph_sync_hook: Callable[[str, MemoryRecord], None] | None = None

    # Async write machinery stays unset until enable_async_writes() runs.
    self._write_queue = None
    self._async_loop: asyncio.AbstractEventLoop | None = None
    self._async_thread: threading.Thread | None = None
    self._provenance_queue = None

    # Create collections and payload indexes if they are missing.
    self._ensure_collections()
|
|
|
|
# ------------------------------------------------------------------ schema
|
|
|
|
def _ensure_collections(self) -> None:
    """Create the `records` and `metadata` collections when absent.

    - `records`: cosine-distance vectors of ``self._embed_dim`` dimensions,
      created with ``payload_m=16`` / ``m=0`` HNSW settings.
    - `metadata`: created without any vector configuration (payload-only).

    A 409 response from Qdrant (collection already exists) is ignored; every
    other ``UnexpectedResponse`` is re-raised. Payload indexes are ensured
    afterwards via ``_ensure_indexes``.
    """
    # Collection 1: records (vectors)
    if not self._collection_exists(RECORDS_TABLE):
        dense = VectorParams(size=self._embed_dim, distance=Distance.COSINE)
        graph = models.HnswConfigDiff(
            payload_m=16,  # per-table HNSW graph
            m=0,  # no global index (all data partitioned by table)
        )
        try:
            self._client.create_collection(
                collection_name=RECORDS_TABLE,
                vectors_config=dense,
                hnsw_config=graph,
            )
        except UnexpectedResponse as err:
            # 409 = already exists (e.g. lost a creation race); anything
            # else, including 401, is a real failure.
            if err.status_code != 409:
                raise

    # Collection 2: metadata (payload-only, no vectors_config)
    if not self._collection_exists(METADATA_TABLE):
        try:
            self._client.create_collection(collection_name=METADATA_TABLE)
        except UnexpectedResponse as err:
            if err.status_code != 409:
                raise

    # Ensure payload indexes for filtered queries
    self._ensure_indexes()
|
|
|
|
def _collection_exists(self, name: str) -> bool:
    """Report whether the collection *name* can be fetched from Qdrant.

    Returns ``True`` when ``get_collection`` yields a non-``None`` result and
    ``False`` when it returns ``None`` or raises, with one exception: an
    ``UnexpectedResponse`` carrying status 401 is converted into a
    ``RuntimeError`` that points at the API key.
    """
    try:
        return self._client.get_collection(name) is not None
    except UnexpectedResponse as err:
        if err.status_code != 401:
            # Collection doesn't exist or other error
            return False
        raise RuntimeError(
            "Qdrant 401: invalid API key. Set QDRANT_API_KEY env var."
        ) from err
    except Exception:
        return False
|
|
|
|
def _ensure_indexes(self) -> None:
    """Request keyword payload indexes on the fields used for filtering.

    Index creation is best-effort: any exception raised for an individual
    index is swallowed and the remaining indexes are still attempted.

    Only plain keyword indexes are requested here; no tenant flag is passed.
    NOTE(review): newer qdrant-client releases appear to accept
    ``models.KeywordIndexParams(is_tenant=True)`` as ``field_schema`` --
    confirm against the pinned client version before relying on tenant
    co-location.
    """
    keyword = PayloadSchemaType.KEYWORD

    # (collection, field, schema) in the order the requests are issued:
    # partition fields for both collections first, then the per-collection
    # filter fields.
    plan = []
    for collection_name in (RECORDS_TABLE, METADATA_TABLE):
        plan.append((collection_name, "table", keyword))
        plan.append((collection_name, "group_id", keyword))
    plan.extend(
        (RECORDS_TABLE, field_name, keyword)
        for field_name in ("id", "tier", "community_id")
    )
    plan.extend(
        (METADATA_TABLE, field_name, keyword)
        for field_name in ("edge_type", "kind", "session_id")
    )

    for collection_name, field_name, schema_type in plan:
        try:
            self._client.create_payload_index(
                collection_name=collection_name,
                field_name=field_name,
                field_schema=schema_type,
            )
        except Exception:
            pass
|
|
|
|
@property
def embed_dim(self) -> int:
    """Embedding dimension this store instance was initialised with."""
    width: int = self._embed_dim
    return width
|
|
|
|
@property
def user_id(self) -> str:
    """The user_id this store was opened for (scopes the encryption key)."""
    owner: str = self._user_id
    return owner
|
|
|
|
# ------------------------------------------------------------------ encryption
|
|
|
|
def _key(self) -> bytes:
    """Return the encryption key, fetching it on first use.

    The key is obtained from the key wrapper's ``get_or_create()`` once and
    then kept on the instance for later calls.
    """
    cached = self._crypto_key
    if cached is not None:
        return cached
    cached = self._crypto_key_wrapper.get_or_create()
    self._crypto_key = cached
    return cached
|
|
|
|
def _ad(self, record_id: UUID | str) -> bytes:
    """Build the AEAD associated data for a record.

    The value is the ASCII encoding of the canonical UUID string, so
    ``_uuid_literal`` raises ``ValueError`` for a non-UUID *record_id*.
    """
    canonical = _uuid_literal(record_id)
    return canonical.encode("ascii")
|
|
|
|
def _encrypt_for_record(self, record_id: UUID, value: str) -> str:
    """Encrypt *value* bound to *record_id*, unless it is already ciphertext.

    Values that ``is_encrypted`` recognises are returned unchanged, so the
    call is safe to repeat on the same field.
    """
    if not is_encrypted(value):
        return encrypt_field(
            value, self._key(), associated_data=self._ad(record_id)
        )
    return value
|
|
|
|
@functools.cached_property
def _cached_aesgcm(self) -> AESGCM:
    """AESGCM cipher built from the store key; computed once per instance."""
    key_bytes = self._key()
    return AESGCM(key_bytes)
|
|
|
|
def _decrypt_for_record(self, record_id: UUID, value: str) -> str:
    """Decrypt a per-record field; plaintext input is returned unchanged.

    The text after ``CIPHERTEXT_PREFIX`` is base64 for ``nonce || sealed``,
    where ``sealed`` is what ``AESGCM.decrypt`` expects. The record's
    canonical UUID bytes are supplied as associated data.

    Raises
    ------
    ValueError
        If the value is flagged as encrypted but lacks the expected prefix,
        or if the decoded payload is shorter than the nonce plus 16 bytes
        (presumably the GCM tag length -- confirm against iai_mcp.crypto).
    """
    if not is_encrypted(value):
        return value
    if not value.startswith(CIPHERTEXT_PREFIX):
        raise ValueError("field is not iai:enc:v1:-prefixed ciphertext")

    blob = base64.b64decode(value[len(CIPHERTEXT_PREFIX):])
    if len(blob) < NONCE_BYTES + 16:
        raise ValueError("ciphertext payload too short")

    nonce, sealed = blob[:NONCE_BYTES], blob[NONCE_BYTES:]
    aad = self._ad(record_id)
    clear = self._cached_aesgcm.decrypt(nonce, sealed, aad or None)
    return clear.decode("utf-8")
|
|
|
|
# ------------------------------------------------------------------ I/O
|
|
|
|
def register_graph_sync_hook(
    self, hook: Callable[[str, MemoryRecord], None] | None
) -> None:
    """Install (or, with ``None``, remove) the graph-sync callback.

    The callback is later invoked as ``hook(op, record)`` by
    ``_fire_graph_sync_hook``. Any previously registered hook is replaced.
    """
    self._graph_sync_hook = hook
|
|
|
|
def _fire_graph_sync_hook(self, op: str, record: MemoryRecord) -> None:
    """Invoke the registered graph-sync hook with ``(op, record)``.

    Does nothing when no hook is registered. An exception raised by the hook
    never propagates: a single JSON line describing the failure is written to
    stderr, and a failure while writing that line is ignored as well.
    """
    callback = self._graph_sync_hook
    if callback is None:
        return
    try:
        callback(op, record)
    except Exception as exc:
        try:
            failure = {
                "event": "graph_sync_failed",
                "op": op,
                "record_id": str(getattr(record, "id", "")),
                "error": str(exc),
                "ts": datetime.now(timezone.utc).isoformat(),
            }
            sys.stderr.write(json.dumps(failure) + "\n")
        except Exception:
            # Reporting is best-effort; never let it mask the store write.
            pass
|
|
|
|
# ------------------------------------------------------- conversions
|
|
|
|
def _to_point(self, r: MemoryRecord) -> PointStruct:
    """Serialise a MemoryRecord into a Qdrant point for `records`.

    ``literal_surface``, the JSON-encoded provenance and the JSON-encoded
    profile modulation gain are passed through ``_encrypt_for_record``; every
    other field is stored in the payload as a plain value. Timestamps are
    written as ISO-8601 strings (or ``None``) and ``structure_hv`` as base64
    text.
    """
    # Sensitive fields first, each bound to the record id.
    sealed_literal = self._encrypt_for_record(r.id, r.literal_surface)
    sealed_provenance = self._encrypt_for_record(r.id, json.dumps(r.provenance))
    sealed_gain = self._encrypt_for_record(
        r.id, json.dumps(r.profile_modulation_gain or {})
    )

    point_id = str(r.id)
    dense = [float(component) for component in r.embedding]

    def _iso(moment):
        # None (or any falsy value) stays None; otherwise ISO-8601 text.
        return moment.isoformat() if moment else None

    fields = {
        "table": RECORDS_TABLE,
        "group_id": self._group_id,
        "tier": r.tier,
        "literal_surface": sealed_literal,
        "aaak_index": r.aaak_index or "",
        "structure_hv": base64.b64encode(
            bytes(r.structure_hv or b"")
        ).decode("ascii"),
        "community_id": str(r.community_id) if r.community_id else "",
        "centrality": float(r.centrality),
        "detail_level": int(r.detail_level),
        "pinned": bool(r.pinned),
        "stability": float(r.stability),
        "difficulty": float(r.difficulty),
        "last_reviewed": _iso(r.last_reviewed),
        "never_decay": bool(r.never_decay),
        "never_merge": bool(r.never_merge),
        "provenance_json": sealed_provenance,
        "created_at": _iso(r.created_at),
        "updated_at": _iso(r.updated_at),
        "tags_json": json.dumps(r.tags),
        "language": str(r.language),
        "s5_trust_score": float(r.s5_trust_score),
        "profile_modulation_gain_json": sealed_gain,
        "schema_version": int(r.schema_version),
    }
    return PointStruct(id=point_id, vector=dense, payload=fields)
|
|
|
|
def _from_point(self, point: PointStruct) -> MemoryRecord:
    """Rebuild a MemoryRecord from a Qdrant point of the records collection.

    Encrypted payload fields (``literal_surface``, ``provenance_json``,
    ``profile_modulation_gain_json``) are decrypted with the point id as
    associated data. Malformed JSON in any of the ``*_json`` fields falls
    back to an empty container instead of failing the read.

    The embedding is taken from ``point.vector``; it is an empty list when
    the point was fetched without vectors.

    Raises
    ------
    ValueError
        If the point carries no payload, or its id / ``community_id`` is not
        a valid UUID string.
    """
    payload = point.payload
    if not payload:
        raise ValueError(f"point {point.id} has no payload")

    row_uuid = UUID(str(point.id))

    # structure_hv: base64 text -> bytes; undecodable data becomes b"".
    structure_raw = payload.get("structure_hv")
    if structure_raw:
        try:
            structure_hv = base64.b64decode(structure_raw)
        except Exception:
            structure_hv = b""
    else:
        structure_hv = b""

    community_raw = payload.get("community_id") or ""
    community_id = UUID(community_raw) if community_raw else None

    # schema_version: a missing or unparsable value counts as current.
    version_raw = payload.get("schema_version")
    try:
        schema_version = (
            int(version_raw) if version_raw is not None else SCHEMA_VERSION_CURRENT
        )
    except (TypeError, ValueError):
        schema_version = SCHEMA_VERSION_CURRENT

    # language back-compat: schema v1 rows may legitimately carry an empty
    # language. A placeholder is passed to the constructor and swapped back
    # to "" afterwards; newer rows with an empty language default to "en".
    lang_raw = payload.get("language")
    is_empty_language = lang_raw is None or (
        isinstance(lang_raw, str) and lang_raw == ""
    )
    if is_empty_language and schema_version == 1:
        language = "__LEGACY_EMPTY__"
    elif is_empty_language:
        language = "en"
    else:
        language = str(lang_raw)

    s5_raw = payload.get("s5_trust_score")
    s5_trust_score = float(s5_raw) if s5_raw is not None else 0.5

    # decrypt profile_modulation_gain_json
    gain_raw = payload.get("profile_modulation_gain_json") or "{}"
    if is_encrypted(gain_raw):
        gain_raw = self._decrypt_for_record(row_uuid, gain_raw)
    try:
        profile_modulation_gain = json.loads(gain_raw) or {}
    except (TypeError, json.JSONDecodeError):
        profile_modulation_gain = {}

    # decrypt literal_surface + provenance_json
    literal_raw = payload.get("literal_surface", "")
    if is_encrypted(literal_raw):
        literal_raw = self._decrypt_for_record(row_uuid, literal_raw)
    provenance_raw = payload.get("provenance_json") or "[]"
    if is_encrypted(provenance_raw):
        provenance_raw = self._decrypt_for_record(row_uuid, provenance_raw)
    try:
        provenance_list = json.loads(provenance_raw) if provenance_raw else []
    except (TypeError, json.JSONDecodeError):
        provenance_list = []

    # tags_json gets the same tolerance as the other JSON columns; one bad
    # row must not abort a whole scan.
    try:
        tags = json.loads(payload.get("tags_json") or "[]")
    except (TypeError, json.JSONDecodeError):
        tags = []

    # A stored null for detail_level falls back to the default of 1 rather
    # than raising TypeError from int(None).
    detail_raw = payload.get("detail_level")
    detail_level = int(detail_raw) if detail_raw is not None else 1

    def _parse_ts(val):
        """ISO-8601 text -> datetime; None or unparsable input -> None."""
        if val is None:
            return None
        try:
            return datetime.fromisoformat(val)
        except (TypeError, ValueError):
            return None

    rec = MemoryRecord(
        id=row_uuid,
        tier=payload.get("tier", "episodic"),
        literal_surface=literal_raw,
        aaak_index=payload.get("aaak_index") or "",
        embedding=list(point.vector) if point.vector else [],
        community_id=community_id,
        centrality=float(payload.get("centrality", 0.0) or 0.0),
        detail_level=detail_level,
        pinned=bool(payload.get("pinned", False)),
        stability=float(payload.get("stability") or 0.0),
        difficulty=float(payload.get("difficulty") or 0.0),
        last_reviewed=_parse_ts(payload.get("last_reviewed")),
        never_decay=bool(payload.get("never_decay", False)),
        never_merge=bool(payload.get("never_merge", False)),
        provenance=provenance_list,
        created_at=_parse_ts(payload.get("created_at")) or datetime.now(timezone.utc),
        updated_at=_parse_ts(payload.get("updated_at")) or datetime.now(timezone.utc),
        tags=tags,
        language=language,
        s5_trust_score=s5_trust_score,
        profile_modulation_gain=profile_modulation_gain,
        schema_version=schema_version,
        structure_hv=structure_hv,
    )
    if language == "__LEGACY_EMPTY__":
        rec.language = ""
    return rec
|
|
|
|
# ------------------------------------------------------- point operations
|
|
|
|
def _make_filter(self, id_str: str) -> Filter:
    """Create a filter that selects the single point whose ID is *id_str*.

    Points written by ``_to_point`` carry the record UUID as the point ID and
    have no ``id`` key in their payload, so a payload-field match on ``id``
    would never select anything. A has-id condition targets the point ID
    itself.
    """
    return Filter(must=[models.HasIdCondition(has_id=[id_str])])
|
|
|
|
def _make_tier_filter(self, tier: str) -> Filter:
    """Build a filter matching points whose `tier` payload equals *tier*."""
    tier_is = FieldCondition(key="tier", match=MatchValue(value=tier))
    return Filter(must=[tier_is])
|
|
|
|
def _make_combined_filter(self, *conditions: FieldCondition) -> Filter:
    """AND together the given field conditions in one ``must`` filter."""
    required = [*conditions]
    return Filter(must=required)
|
|
|
|
# ------------------------------------------------------- writes
|
|
|
|
def insert(self, record: MemoryRecord) -> None:
    """Write one record to the records collection, verbatim.

    The record is upserted synchronously through the Qdrant client, then the
    graph-sync hook is fired with op ``"insert"``. A record without a
    ``structure_hv`` gets one from ``iai_mcp.tem.bind_structure`` first
    (mutating *record*).

    Raises
    ------
    ValueError
        If the tier is not in ``TIER_ENUM`` or the embedding length differs
        from the store's embedding dimension.
    """
    if record.tier not in TIER_ENUM:
        raise ValueError(f"invalid tier {record.tier!r}")
    got = len(record.embedding)
    if got != self._embed_dim:
        raise ValueError(f"embedding must be {self._embed_dim}d, got {got}")

    # lazy structure_hv fill (import deferred until it is actually needed)
    if not record.structure_hv:
        from iai_mcp.tem import bind_structure

        record.structure_hv = bind_structure(record)

    self._client.upsert(
        collection_name=RECORDS_TABLE,
        points=[self._to_point(record)],
    )
    self._fire_graph_sync_hook("insert", record)
|
|
|
|
def update(self, record: MemoryRecord) -> None:
    """Rewrite an existing record in full.

    Silently does nothing when no point with the record's id exists. On
    success the graph-sync hook is fired with op ``"update"``.

    Raises
    ------
    ValueError
        If the embedding length differs from the store's embedding dimension.
    """
    got = len(record.embedding)
    if got != self._embed_dim:
        raise ValueError(f"embedding must be {self._embed_dim}d, got {got}")

    # Existence check: update must never create a new point.
    found = self._client.retrieve(
        collection_name=RECORDS_TABLE,
        ids=[str(record.id)],
    )
    if found:
        self._client.upsert(
            collection_name=RECORDS_TABLE,
            points=[self._to_point(record)],
        )
        self._fire_graph_sync_hook("update", record)
|
|
|
|
def delete(self, record_id: UUID) -> None:
    """Delete the record with *record_id* from the records collection.

    Any exception from the delete call is swallowed and the method returns
    without notifying the graph-sync hook. After a successful call the hook
    is fired with op ``"delete"`` and a minimal object exposing only ``.id``.
    """

    class _DeleteShim:
        """Stand-in for a MemoryRecord: carries just the deleted id."""

        def __init__(self, rid):
            self.id = rid

    try:
        selector = models.PointIdsList(points=[str(record_id)])
        self._client.delete(
            collection_name=RECORDS_TABLE,
            points_selector=selector,
        )
    except Exception:
        return

    self._fire_graph_sync_hook("delete", _DeleteShim(record_id))
|
|
|
|
def get(self, record_id: UUID) -> MemoryRecord | None:
    """Point read by ID.

    Returns the decoded record, or ``None`` when no point has that id.

    The vector is requested explicitly: the client omits vectors unless
    asked, which would produce a record with an empty embedding that
    ``update()`` then rejects on its dimension check.
    """
    points = self._client.retrieve(
        collection_name=RECORDS_TABLE,
        ids=[str(record_id)],
        with_payload=True,
        with_vectors=True,
    )
    if not points:
        return None
    return self._from_point(points[0])
|
|
|
|
def all_records(self) -> list[MemoryRecord]:
    """Full scan of the records collection (points with ``table=records``).

    Pages through the collection 1000 points at a time and decodes every
    point once the scan has finished. Vectors are requested explicitly so the
    returned records carry their embeddings; the client omits vectors unless
    asked.
    """
    table_filter = Filter(must=[FieldCondition(
        key="table", match=MatchValue(value=RECORDS_TABLE),
    )])
    all_points = []
    offset = None
    while True:
        points, next_offset = self._client.scroll(
            collection_name=RECORDS_TABLE,
            limit=1000,
            offset=offset,
            scroll_filter=table_filter,
            with_payload=True,
            with_vectors=True,
        )
        all_points.extend(points)
        if next_offset is None:
            break
        offset = next_offset
    return [self._from_point(p) for p in all_points]
|
|
|
|
def iter_records(
    self,
    *,
    columns: list[str] | None = None,
    batch_size: int = 1024,
    where: str | None = None,
):
    """Stream MemoryRecords from the records collection page by page.

    Parameters
    ----------
    columns:
        Accepted in the signature but not used; full records are yielded.
    batch_size:
        Number of points fetched per scroll request.
    where:
        Only clauses starting with ``tier = `` are understood, e.g.
        ``tier = 'episodic'`` (the value may also be unquoted). Any other
        clause is ignored and the scan is unfiltered apart from
        ``table=records``.

    Yields
    ------
    MemoryRecord
        Decoded records including their embedding vectors.
    """
    # The filter does not change between pages, so build it once.
    conditions = [FieldCondition(key="table", match=MatchValue(value=RECORDS_TABLE))]
    if where and where.startswith("tier = "):
        _, quote, remainder = where.partition("'")
        if quote:
            # Text between the first pair of single quotes.
            tier = remainder.partition("'")[0]
        else:
            # Unquoted value: take the rest of the clause instead of
            # failing with IndexError.
            tier = where[len("tier = "):].strip()
        conditions.append(FieldCondition(key="tier", match=MatchValue(value=tier)))
    qdrant_filter = Filter(must=conditions)

    offset = None
    while True:
        points, next_offset = self._client.scroll(
            collection_name=RECORDS_TABLE,
            limit=batch_size,
            offset=offset,
            scroll_filter=qdrant_filter,
            with_payload=True,
            # Without this the client omits vectors and every yielded
            # record would have an empty embedding.
            with_vectors=True,
        )
        for point in points:
            yield self._from_point(point)

        if next_offset is None:
            break
        offset = next_offset
|
|
|
|
def iter_record_columns(
    self,
    columns: list[str],
    *,
    batch_size: int = 1024,
    where: str | None = None,
):
    """Projection-only iteration; no MemoryRecord, no decrypt.

    Yields one dict per point in ``table=records`` holding the requested
    payload keys that are present on the point, plus ``"id"`` (the point ID),
    which is always included. Encrypted fields are yielded as stored.

    Parameters
    ----------
    columns:
        Payload keys to project; must be non-empty.
    batch_size:
        Number of points fetched per scroll request.
    where:
        Only clauses starting with ``tier = `` are understood, e.g.
        ``tier = 'episodic'`` (the value may also be unquoted). Any other
        clause is ignored.

    Raises
    ------
    ValueError
        If *columns* is empty (raised when iteration starts).
    """
    if not columns:
        raise ValueError("iter_record_columns requires a non-empty columns list")

    wanted = set(columns)  # O(1) membership for the per-key projection

    # The filter does not change between pages, so build it once.
    conditions = [FieldCondition(key="table", match=MatchValue(value=RECORDS_TABLE))]
    if where and where.startswith("tier = "):
        _, quote, remainder = where.partition("'")
        if quote:
            # Text between the first pair of single quotes.
            tier = remainder.partition("'")[0]
        else:
            # Unquoted value: take the rest of the clause instead of
            # failing with IndexError.
            tier = where[len("tier = "):].strip()
        conditions.append(FieldCondition(key="tier", match=MatchValue(value=tier)))
    qdrant_filter = Filter(must=conditions)

    offset = None
    while True:
        points, next_offset = self._client.scroll(
            collection_name=RECORDS_TABLE,
            limit=batch_size,
            offset=offset,
            scroll_filter=qdrant_filter,
            with_payload=True,
            with_vectors=False,
        )
        for point in points:
            # Filter to requested columns
            row = {k: v for k, v in point.payload.items() if k in wanted}
            row["id"] = point.id
            yield row

        if next_offset is None:
            break
        offset = next_offset
|
|
|
|
def query_similar(
    self,
    vec: list[float],
    k: int = 10,
    tier: str | None = None,
) -> list[tuple[MemoryRecord, float]]:
    """Nearest-neighbour search over the records collection.

    Returns up to *k* ``(record, score)`` pairs in the order Qdrant returns
    them, restricted to ``table=records`` and, when given, to *tier*. The
    score is Qdrant's similarity value (1.0 = identical, 0.0 = orthogonal);
    a missing or zero score is reported as ``0.0``. Points are requested
    without vectors.

    An empty list is returned when the collection has no points or when the
    collection info lookup fails for any reason.

    Raises
    ------
    ValueError
        If *tier* is given but is not a member of ``TIER_ENUM``.
    """
    if tier is not None and tier not in TIER_ENUM:
        raise ValueError(
            f"invalid tier {tier!r}; must be one of {sorted(TIER_ENUM)}"
        )

    # Query filter: records partition, optionally narrowed by tier.
    must = [FieldCondition(key="table", match=MatchValue(value=RECORDS_TABLE))]
    if tier is not None:
        must.append(FieldCondition(key="tier", match=MatchValue(value=tier)))
    qdrant_filter = Filter(must=must)

    # Short-circuit on an empty (or unreachable) collection.
    try:
        if self._client.get_collection(RECORDS_TABLE).points_count == 0:
            return []
    except Exception:
        return []

    response = self._client.query_points(
        collection_name=RECORDS_TABLE,
        query=vec,
        limit=k,
        query_filter=qdrant_filter,
        with_payload=True,
        with_vectors=False,
    )
    return [
        (self._from_point(hit), float(hit.score) if hit.score else 0.0)
        for hit in response.points
    ]
|
|
|
|
def update_record(self, record: MemoryRecord) -> None:
    """Persist FSRS-relevant columns back to the records collection.

    Patches ``stability``, ``difficulty`` and ``last_reviewed`` and stamps
    ``updated_at`` with the current UTC time. All other payload fields and
    the vector are left untouched. Does nothing when the point does not
    exist.
    """
    point_id = str(record.id)

    # Existence check only: no payload or vector is needed for it.
    points = self._client.retrieve(
        collection_name=RECORDS_TABLE,
        ids=[point_id],
        with_payload=False,
        with_vectors=False,
    )
    if not points:
        return

    # Update only FSRS columns via patch. The client's selector argument is
    # `points`; it has no `payload_selector` parameter.
    self._client.set_payload(
        collection_name=RECORDS_TABLE,
        payload={
            "stability": float(record.stability),
            "difficulty": float(record.difficulty),
            "last_reviewed": record.last_reviewed.isoformat() if record.last_reviewed else None,
            "updated_at": datetime.now(timezone.utc).isoformat(),
        },
        points=[point_id],
    )
|
|
|
|
# -------------------------------------------------------- reconsolidation
|
|
|
|
def append_provenance(self, record_id: UUID, entry: dict) -> None:
    """Append one provenance entry to a stored record.

    The stored provenance JSON is read (and decrypted if needed), *entry* is
    appended, and the re-encrypted list is written back together with a
    fresh ``updated_at``. Unparsable stored JSON is treated as an empty
    history. Does nothing when the point does not exist.

    This is a read-modify-write sequence of two client calls, not an atomic
    operation.
    """
    point_id = str(record_id)

    # Retrieve existing point
    points = self._client.retrieve(
        collection_name=RECORDS_TABLE,
        ids=[point_id],
        with_payload=True,
    )
    if not points:
        return

    payload = points[0].payload
    raw = payload.get("provenance_json") or "[]"
    if is_encrypted(raw):
        raw = self._decrypt_for_record(record_id, raw)
    try:
        existing = json.loads(raw)
    except (TypeError, json.JSONDecodeError):
        existing = []
    existing.append(entry)
    new_ct = self._encrypt_for_record(record_id, json.dumps(existing))

    # Update provenance_json and updated_at. The client's selector argument
    # is `points`; it has no `payload_selector` parameter.
    self._client.set_payload(
        collection_name=RECORDS_TABLE,
        payload={
            "provenance_json": new_ct,
            "updated_at": datetime.now(timezone.utc).isoformat(),
        },
        points=[point_id],
    )
|
|
|
|
def append_provenance_batch(
    self, pairs: "list[tuple[UUID, dict]]",
    records_cache: "dict | None" = None,
) -> None:
    """Append provenance entries to many records in one pass.

    Parameters
    ----------
    pairs:
        ``(record_id, entry)`` tuples. Entries for the same record are
        appended in the order given.
    records_cache:
        Optional mapping from record id (``UUID`` or its string form) to an
        object with a ``provenance`` attribute. When given, the existing
        history is taken from the cache and no read is issued to Qdrant;
        ids missing from the cache or not canonical UUIDs are skipped.
        Without a cache the affected points are fetched in one ``retrieve``
        call; if that call fails the whole batch is dropped silently.

    Each affected record receives one ``set_payload`` call that writes the
    re-encrypted provenance and a shared ``updated_at`` timestamp. A failing
    write is skipped so the remaining records are still attempted.
    """
    if not pairs:
        return

    # Group entries by record_id
    from collections import defaultdict

    grouped: dict[str, list[dict]] = defaultdict(list)
    for rid, entry in pairs:
        grouped[str(rid)].append(entry)

    now = datetime.now(timezone.utc).isoformat()
    updates: dict[str, str] = {}

    if records_cache is not None:
        for rid_str, entries in grouped.items():
            try:
                canonical = _uuid_literal(rid_str)
            except ValueError:
                continue
            # The cache may be keyed by UUID or by string; try both.
            try:
                rec = records_cache.get(UUID(rid_str))
            except (TypeError, ValueError):
                rec = None
            if rec is None:
                rec = records_cache.get(rid_str)
            if rec is None:
                continue
            merged = list(rec.provenance or [])
            merged.extend(entries)
            updates[canonical] = self._encrypt_for_record(
                UUID(rid_str), json.dumps(merged)
            )
    else:
        # Retrieve all affected points
        try:
            points = self._client.retrieve(
                collection_name=RECORDS_TABLE,
                ids=list(grouped.keys()),
                with_payload=True,
            )
        except Exception:
            return

        for point in points:
            rid_str = point.id
            if rid_str not in grouped:
                continue
            entries = grouped[rid_str]
            raw_prov = point.payload.get("provenance_json") or "[]"
            if is_encrypted(raw_prov):
                try:
                    raw_prov = self._decrypt_for_record(UUID(rid_str), raw_prov)
                except Exception:
                    # The stored history cannot be read. Skip this record
                    # rather than overwrite it with a list that holds only
                    # the new entries, which would destroy the old data.
                    continue
            try:
                existing = json.loads(raw_prov)
            except (TypeError, json.JSONDecodeError):
                existing = []
            existing.extend(entries)
            updates[rid_str] = self._encrypt_for_record(
                UUID(rid_str), json.dumps(existing)
            )

    if not updates:
        return

    # Qdrant has no merge_insert, so issue one set_payload per record. The
    # client's selector argument is `points`; it has no `payload_selector`
    # parameter.
    for rid_str, new_ct in updates.items():
        try:
            self._client.set_payload(
                collection_name=RECORDS_TABLE,
                payload={
                    "provenance_json": new_ct,
                    "updated_at": now,
                },
                points=[rid_str],
            )
        except Exception:
            continue
|
|
|
|
# ------------------------------------------------------------------ edges
|
|
|
|
def boost_edges(
    self,
    pairs: list[tuple[UUID, UUID]],
    delta: float | list[float] = 0.1,
    edge_type: str = "hebbian",
) -> dict[tuple[str, str], float]:
    """Pairwise edge boost in the metadata collection (table=edges).

    Each pair is treated as undirected: the two ids are stringified and
    sorted, and deltas of duplicate pairs are summed before writing. The new
    weight is the existing weight of the ``(src, dst, edge_type)`` edge, if
    any, plus the accumulated delta.

    Parameters
    ----------
    pairs:
        Record id pairs to reinforce.
    delta:
        A single increment applied to every pair, or one increment per pair.
    edge_type:
        Must be a member of ``EDGE_TYPES``.

    Returns
    -------
    dict
        Maps each canonical ``(src, dst)`` string pair to its new weight;
        empty when *pairs* is empty.

    Raises
    ------
    ValueError
        If *edge_type* is unknown or the number of deltas does not match the
        number of pairs.
    """
    # Local import: only `UUID` is imported at module level.
    from uuid import NAMESPACE_URL, uuid5

    if edge_type not in EDGE_TYPES:
        raise ValueError(
            f"invalid edge_type {edge_type!r}; must be one of {sorted(EDGE_TYPES)}"
        )

    if isinstance(delta, (int, float)):
        deltas = [float(delta)] * len(pairs)
    else:
        deltas = [float(d) for d in delta]
        if len(deltas) != len(pairs):
            raise ValueError(
                f"deltas length {len(deltas)} != pairs length {len(pairs)}"
            )

    if not pairs:
        return {}

    # Coalesce duplicate canonical keys
    coalesced: dict[tuple[str, str], float] = {}
    for (a, b), d in zip(pairs, deltas):
        canonical = tuple(sorted((str(a), str(b))))
        coalesced[canonical] = coalesced.get(canonical, 0.0) + d

    if not coalesced:
        return {}

    # Fetch existing edges from the metadata collection
    existing_map: dict[tuple[str, str, str], float] = {}
    for point in self._scroll_all(METADATA_TABLE, table_filter=EDGES_TABLE):
        p = point.payload
        edge_key = (p.get("src", ""), p.get("dst", ""), p.get("edge_type", ""))
        existing_map[edge_key] = float(p.get("weight", 0.0))

    now = datetime.now(timezone.utc).isoformat()
    points_to_upsert: list[PointStruct] = []
    new_weights: dict[tuple[str, str], float] = {}

    for (src_str, dst_str), accum_delta in coalesced.items():
        nw = existing_map.get((src_str, dst_str, edge_type), 0.0) + accum_delta

        points_to_upsert.append(PointStruct(
            # Qdrant accepts only unsigned integers or UUIDs as point IDs,
            # so derive a deterministic UUID from the edge key. The same
            # edge then always maps to the same point.
            id=str(uuid5(NAMESPACE_URL, f"{src_str}:{dst_str}:{edge_type}")),
            # Payload-only point: an empty vector mapping, not None.
            vector={},
            payload={
                "table": EDGES_TABLE,
                "group_id": self._group_id,
                "src": src_str,
                "dst": dst_str,
                "edge_type": edge_type,
                "weight": nw,
                "updated_at": now,
            },
        ))
        new_weights[(src_str, dst_str)] = nw

    if points_to_upsert:
        self._client.upsert(
            collection_name=METADATA_TABLE,
            points=points_to_upsert,
        )

    return new_weights
|
|
|
|
def reinforce_record(
    self,
    record_id: UUID,
    anchor_id: UUID | None = None,
    edge_type: str = "hebbian",
    delta: float = 0.1,
) -> dict[tuple[str, str], float]:
    """Reinforce a single record through ``boost_edges``.

    With an *anchor_id* the boosted pair is ``(anchor_id, record_id)``;
    without one the record is paired with itself. The return value is
    whatever ``boost_edges`` returns for that one pair.
    """
    source = record_id if anchor_id is None else anchor_id
    return self.boost_edges(
        [(source, record_id)], delta=delta, edge_type=edge_type
    )
|
|
|
|
def add_contradicts_edge(self, original: UUID, new_id: UUID) -> None:
    """Add a contradicts edge in the metadata collection (table=edges).

    The edge is directed from *original* to *new_id* and written with weight
    1.0. Calling this again for the same pair overwrites the same point.
    """
    # Local import: only `UUID` is imported at module level.
    from uuid import NAMESPACE_URL, uuid5

    src, dst = str(original), str(new_id)
    self._client.upsert(
        collection_name=METADATA_TABLE,
        points=[PointStruct(
            # Qdrant accepts only unsigned integers or UUIDs as point IDs,
            # so derive a deterministic UUID from the edge key.
            id=str(uuid5(NAMESPACE_URL, f"{src}:{dst}:contradicts")),
            # Payload-only point: an empty vector mapping, not None.
            vector={},
            payload={
                "table": EDGES_TABLE,
                "group_id": self._group_id,
                "src": src,
                "dst": dst,
                "edge_type": "contradicts",
                "weight": 1.0,
                "updated_at": datetime.now(timezone.utc).isoformat(),
            },
        )],
    )
|
|
|
|
# ------------------------------------------------------- async writes
|
|
|
|
async def enable_async_writes(
    self,
    coalesce_ms: int = 100,
    max_batch: int = 128,
    max_queue_size: int = 4096,
) -> None:
    """Switch insert() onto the coalescing AsyncWriteQueue.

    Starts a daemon thread that runs a dedicated event loop, starts an
    ``AsyncWriteQueue`` on that loop, stores loop/thread/queue on the
    instance and finally enables the provenance queue. Calling this while a
    write queue is already installed is a no-op.

    Note that the call blocks the calling thread (via ``Event.wait`` and
    ``Future.result``) until the background loop exists and the queue has
    started.

    Args:
        coalesce_ms: forwarded to ``AsyncWriteQueue``.
        max_batch: forwarded to ``AsyncWriteQueue``.
        max_queue_size: forwarded to ``AsyncWriteQueue``.

    Raises:
        Whatever constructing or starting the queue raises. In that case
        the background loop is stopped and the thread joined before the
        exception propagates, and no instance state is changed.
    """
    if self._write_queue is not None:
        return

    from iai_mcp.write_queue import AsyncWriteQueue

    ready = threading.Event()
    loop_holder: dict = {}

    def _run() -> None:
        # Thread target: create the loop, publish it, then serve until
        # someone calls loop.stop().
        loop = asyncio.new_event_loop()
        loop_holder["loop"] = loop
        asyncio.set_event_loop(loop)
        ready.set()
        try:
            loop.run_forever()
        finally:
            loop.close()

    thread = threading.Thread(
        target=_run, name="iai-mcp-async-writes", daemon=True,
    )
    thread.start()
    ready.wait()
    bg_loop: asyncio.AbstractEventLoop = loop_holder["loop"]

    to_point = self._to_point
    fire_hook = self._fire_graph_sync_hook

    class _RecordPointAdapter:
        """Object handed to the queue; turns records into points and upserts them."""

        def __init__(self, client) -> None:
            self._client = client

        async def add(self, records: list) -> None:
            # presumably called by the queue with one coalesced batch
            points = [to_point(r) for r in records]
            self._client.upsert(RECORDS_TABLE, points=points)

    adapter = _RecordPointAdapter(self._client)

    def _on_flushed(batch: list) -> None:
        # Passed as on_flushed=; fires the graph-sync hook once per record.
        for rec in batch:
            fire_hook("insert", rec)

    try:
        queue = AsyncWriteQueue(
            adapter,
            coalesce_ms=coalesce_ms,
            max_batch=max_batch,
            max_queue_size=max_queue_size,
            on_flushed=_on_flushed,
        )
        asyncio.run_coroutine_threadsafe(queue.start(), bg_loop).result()
    except BaseException:
        # Start-up failed: without this the loop would keep running forever
        # on an orphaned thread. Same teardown as disable_async_writes.
        bg_loop.call_soon_threadsafe(bg_loop.stop)
        thread.join(timeout=5.0)
        raise

    self._async_loop = bg_loop
    self._async_thread = thread
    self._write_queue = queue
    self.enable_provenance_queue()
|
|
|
|
async def disable_async_writes(self) -> None:
    """Stop the async write path and release its background resources.

    The provenance queue is always disabled. If a write queue is installed,
    its ``stop()`` coroutine is run to completion on the background loop;
    afterwards (even if that raised) the loop is asked to stop, the thread
    is joined for at most five seconds, and the three handles are cleared.
    """
    # Sample the state before touching the provenance queue, then disable
    # the provenance queue on both paths.
    nothing_installed = self._write_queue is None
    self.disable_provenance_queue()
    if nothing_installed:
        return

    loop, pending = self._async_loop, self._write_queue
    try:
        stop_future = asyncio.run_coroutine_threadsafe(pending.stop(), loop)
        stop_future.result()
    finally:
        if loop is not None:
            loop.call_soon_threadsafe(loop.stop)
        worker = self._async_thread
        if worker is not None:
            worker.join(timeout=5.0)
        self._write_queue = self._async_loop = self._async_thread = None
|
|
|
|
# -------------------------------------------------- provenance queue
|
|
|
|
def enable_provenance_queue(self, *, coalesce_ms: int = 50) -> None:
    """Install and start a ``ProvenanceWriteQueue`` for this store.

    Does nothing when a provenance queue is already installed. Otherwise a
    queue is created with this store and ``coalesce_ms``, started, and only
    then stored on the instance.
    """
    if self._provenance_queue is None:
        from iai_mcp.provenance_queue import ProvenanceWriteQueue

        worker = ProvenanceWriteQueue(self, coalesce_ms=coalesce_ms)
        worker.start()
        self._provenance_queue = worker
|
|
|
|
def disable_provenance_queue(self) -> None:
    """Flush, stop and detach the provenance queue, if one is installed.

    Both shutdown steps are best-effort: an exception from ``flush`` does
    not prevent ``stop`` from being attempted, and neither prevents the
    queue reference from being cleared.
    """
    active = self._provenance_queue
    if active is None:
        return

    shutdown_steps = (
        lambda: active.flush(timeout=2.0),
        lambda: active.stop(),
    )
    for step in shutdown_steps:
        try:
            step()
        except Exception:
            # Deliberately ignored so teardown always runs to the end.
            pass

    self._provenance_queue = None
|
|
|
|
def queue_provenance_batch(
    self, pairs: "list[tuple[UUID, dict]]"
) -> None:
    """Hand a provenance batch to the queue, or write it directly.

    An empty ``pairs`` is ignored. With a provenance queue installed the
    batch is enqueued; without one it is written synchronously through
    ``append_provenance_batch`` with ``records_cache=None``.
    """
    if pairs:
        sink = self._provenance_queue
        if sink is None:
            self.append_provenance_batch(pairs, records_cache=None)
        else:
            sink.enqueue(pairs)
|
|
|
|
# ------------------------------------------------------- helpers
|
|
|
|
def _scroll_all(self, collection_name: str, batch_size: int = 1000,
|
|
table_filter: str | None = None) -> list:
|
|
"""Scroll through all points in a collection.
|
|
|
|
When table_filter is provided, adds a must condition on the `table`
|
|
payload field to select a specific logical table within the collection.
|
|
"""
|
|
offset = None
|
|
all_points = []
|
|
scroll_filter = None
|
|
if table_filter:
|
|
scroll_filter = Filter(must=[FieldCondition(
|
|
key="table", match=MatchValue(value=table_filter),
|
|
)])
|
|
while True:
|
|
points, next_offset = self._client.scroll(
|
|
collection_name=collection_name,
|
|
limit=batch_size,
|
|
offset=offset,
|
|
scroll_filter=scroll_filter,
|
|
with_payload=True,
|
|
with_vectors=False,
|
|
)
|
|
all_points.extend(points)
|
|
if next_offset is None:
|
|
break
|
|
offset = next_offset
|
|
return all_points
|
|
|
|
# ---------------------------------------------------------- events
|
|
|
|
def events_add(self, event_id: UUID, kind: str, severity: str, domain: str,
               ts: datetime, data_json: str, session_id: str, source_ids_json: str) -> None:
    """Upsert one event into the metadata collection (table=events).

    The event is stored as a point with an empty vector mapping whose id is
    ``str(event_id)``; ``ts`` is saved as its ISO-8601 string and the two
    ``*_json`` arguments are stored as given.
    """
    point_id = str(event_id)
    row = {
        "table": EVENTS_TABLE,
        "group_id": self._group_id,
        "kind": kind,
        "severity": severity,
        "domain": domain,
        "ts": ts.isoformat(),
        "data_json": data_json,
        "session_id": session_id,
        "source_ids_json": source_ids_json,
    }
    self._client.upsert(
        collection_name=METADATA_TABLE,
        points=[PointStruct(id=point_id, vector={}, payload=row)],
    )
|
|
|
|
def events_query(self, kind: str | None = None, since: datetime | None = None,
                 severity: str | None = None, limit: int = 100) -> list[dict]:
    """Query events from the metadata collection (table=events), newest first.

    Args:
        kind: keep only events whose ``kind`` payload equals this value.
        since: keep only events with ``ts >= since``; a naive value is
            interpreted as UTC.
        severity: keep only events whose ``severity`` payload equals this
            value.
        limit: maximum number of events returned; ``<= 0`` yields ``[]``.

    Returns:
        At most ``limit`` dicts with the keys ``id``, ``kind``,
        ``severity``, ``domain``, ``ts``, ``data``, ``session_id`` and
        ``source_ids``, ordered by timestamp descending. ``ts`` is the
        parsed stored timestamp, or the current UTC time when the stored
        value cannot be parsed. Malformed ``data_json`` / ``source_ids_json``
        fall back to ``{}`` / ``[]``.
    """
    if limit <= 0:
        return []

    # Always filter by table=events; kind/severity are exact-match extras.
    conditions = [FieldCondition(key="table", match=MatchValue(value=EVENTS_TABLE))]
    if kind is not None:
        conditions.append(FieldCondition(key="kind", match=MatchValue(value=kind)))
    if severity is not None:
        conditions.append(FieldCondition(key="severity", match=MatchValue(value=severity)))
    event_filter = Filter(must=conditions)

    since_cmp = None
    if since is not None:
        since_cmp = since if since.tzinfo else since.replace(tzinfo=timezone.utc)

    # (sort key, row) pairs. The key is always timezone-aware so rows with
    # naive and aware stored timestamps can be compared and sorted together;
    # the row keeps the timestamp exactly as parsed.
    keyed: list[tuple[datetime, dict]] = []
    offset = None
    while True:
        # Scroll does not return points in timestamp order, so every
        # matching page has to be read before "newest first" can be decided.
        points, offset = self._client.scroll(
            collection_name=METADATA_TABLE,
            limit=1000,
            offset=offset,
            with_payload=True,
            with_vectors=False,
            scroll_filter=event_filter,
        )
        for pt in points:
            p = pt.payload or {}
            try:
                ts_dt = datetime.fromisoformat(p.get("ts", ""))
            except (ValueError, TypeError):
                ts_dt = datetime.now(timezone.utc)
            ts_key = ts_dt if ts_dt.tzinfo else ts_dt.replace(tzinfo=timezone.utc)

            if since_cmp is not None and ts_key < since_cmp:
                continue

            try:
                data = json.loads(p.get("data_json") or "{}")
            except (TypeError, json.JSONDecodeError):
                data = {}
            try:
                source_ids = json.loads(p.get("source_ids_json") or "[]")
            except (TypeError, json.JSONDecodeError):
                source_ids = []

            keyed.append((ts_key, {
                "id": p.get("id", str(pt.id)),
                "kind": p.get("kind", ""),
                "severity": p.get("severity") or None,
                "domain": p.get("domain") or None,
                "ts": ts_dt,
                "data": data,
                "session_id": p.get("session_id", "-"),
                "source_ids": source_ids,
            }))
        if offset is None:
            break

    keyed.sort(key=lambda item: item[0], reverse=True)
    return [row for _, row in keyed[:limit]]
|
|
|
|
# --------------------------------------------------------- count_rows
|
|
|
|
def count_rows(self, table_name: str) -> int:
    """Return the number of points stored for a logical table.

    - ``records``: the ``points_count`` reported for the records
      collection (the whole collection, no payload filter).
    - a name in ``_METADATA_TABLES``: an exact server-side count of the
      metadata-collection points whose ``table`` payload equals the name.
    - anything else: 0.

    Any error raised by the client is swallowed and reported as 0, so the
    result is best-effort.
    """
    if table_name == RECORDS_TABLE:
        try:
            info = self._client.get_collection(RECORDS_TABLE)
            return info.points_count or 0
        except Exception:
            return 0

    if table_name in _METADATA_TABLES:
        try:
            # Count on the server instead of downloading every point and
            # its payload just to take len().
            result = self._client.count(
                collection_name=METADATA_TABLE,
                count_filter=Filter(must=[FieldCondition(
                    key="table", match=MatchValue(value=table_name),
                )]),
                exact=True,
            )
            return result.count
        except Exception:
            return 0

    return 0
|
|
|
|
# ------------------------------------------------------- edges_as_df
|
|
|
|
def edges_as_dataframe(self) -> "pd.DataFrame":
    """Load every edge of the metadata collection into a DataFrame.

    The frame has the columns src, dst, edge_type, weight and updated_at.
    When there are no edges, or anything goes wrong while reading or
    converting them, an empty frame with those columns is returned.
    """
    column_names = ["src", "dst", "edge_type", "weight", "updated_at"]
    try:
        edge_points = self._scroll_all(
            METADATA_TABLE, table_filter=EDGES_TABLE, batch_size=1000,
        )
        if not edge_points:
            return pd.DataFrame(columns=column_names)
        payloads = (point.payload for point in edge_points)
        return pd.DataFrame([
            {
                "src": payload.get("src", ""),
                "dst": payload.get("dst", ""),
                "edge_type": payload.get("edge_type", ""),
                "weight": float(payload.get("weight", 0.0)),
                "updated_at": payload.get("updated_at", ""),
            }
            for payload in payloads
        ])
    except Exception:
        # Best-effort read: any failure degrades to "no edges".
        return pd.DataFrame(columns=column_names)
|
|
|
|
def records_as_dataframe(self) -> "pd.DataFrame":
    """Load every record returned by ``all_records()`` into a DataFrame.

    Ids are stringified, ``structure_hv`` is hex-encoded (empty string when
    falsy), and ``community_id`` / ``last_reviewed`` become strings or
    ``None``. When there are no records, or anything goes wrong while
    reading or converting them, an empty frame with the same columns is
    returned.
    """
    column_names = [
        "id", "tier", "literal_surface", "embedding",
        "community_id", "centrality", "pinned",
        "tags_json", "language", "aaak_index",
        "stability", "difficulty", "last_reviewed",
        "never_decay", "never_merge", "detail_level",
        "s5_trust_score", "structure_hv",
    ]
    try:
        loaded = self.all_records()
        if not loaded:
            return pd.DataFrame(columns=column_names)
        return pd.DataFrame([
            {
                "id": str(rec.id),
                "tier": rec.tier,
                "literal_surface": rec.literal_surface,
                "embedding": rec.embedding,
                "community_id": str(rec.community_id) if rec.community_id else None,
                "centrality": rec.centrality,
                "pinned": rec.pinned,
                # Records without a tags_json attribute get an empty JSON list.
                "tags_json": getattr(rec, "tags_json", "[]"),
                "language": rec.language,
                "aaak_index": rec.aaak_index,
                "stability": rec.stability,
                "difficulty": rec.difficulty,
                "last_reviewed": str(rec.last_reviewed) if rec.last_reviewed else None,
                "never_decay": rec.never_decay,
                "never_merge": rec.never_merge,
                "detail_level": rec.detail_level,
                "s5_trust_score": rec.s5_trust_score,
                "structure_hv": rec.structure_hv.hex() if rec.structure_hv else "",
            }
            for rec in loaded
        ])
    except Exception:
        # Best-effort read: any failure degrades to "no records".
        return pd.DataFrame(columns=column_names)
|
|
|
|
# ------------------------------------------------------------------ db shim
|
|
|
|
class _DbShim:
|
|
"""LanceDB-compatible shim: store.db.open_table(name) -> table-like object."""
|
|
|
|
def __init__(self, store: "QdrantStore") -> None:
|
|
self._store = store
|
|
|
|
def open_table(self, name: str) -> "QdrantStore._TableShim":
|
|
return self._TableShim(self._store, name)
|
|
|
|
class _TableShim:
|
|
def __init__(self, store: "QdrantStore", name: str) -> None:
|
|
self._store = store
|
|
self._name = name
|
|
|
|
def count_rows(self) -> int:
|
|
return self._store.count_rows(self._name)
|
|
|
|
def to_pandas(self) -> pd.DataFrame:
|
|
if self._name == EDGES_TABLE:
|
|
return self._store.edges_as_dataframe()
|
|
elif self._name == RECORDS_TABLE:
|
|
return self._store.records_as_dataframe()
|
|
elif self._name == EVENTS_TABLE:
|
|
return pd.DataFrame()
|
|
return pd.DataFrame()
|
|
|
|
@property
def db(self) -> "QdrantStore._DbShim":
    """Expose a ``_DbShim`` around this store so ``store.db.open_table()`` works.

    A new shim object is built on every access.
    """
    shim_factory = self._DbShim
    return shim_factory(self)
|