# iai-mcp-opencode/src/iai_mcp/qdrant_store.py
"""Qdrant-backed persistent memory store (D-01 storage engine, sync write).
Replaces LanceDB with Qdrant for remote vector storage.
Collections (payload-partitioned, per Qdrant best practices):
- `records` : MemoryRecord rows (1024-dim cosine vectors)
All points carry `table: "records"` + `group_id` payload.
- `metadata` : Payload-only (no vectors) containing edges, events,
budget_ledger, ratelimit_ledger.
Each point carries `table` + `group_id` payload.
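Both collections carry keyword payload indexes on `table` and `group_id`
(the `is_tenant` index flag is not set). The `records` collection disables
the global HNSW graph (m=0) and builds per-payload-value subgraphs instead
(payload_m=16); the `metadata` collection has no vectors.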
Encryption-at-rest is identical to LanceDB: AES-256-GCM on
literal_surface / provenance_json / profile_modulation_gain_json (records)
and data_json (events). AD = record UUID bytes.
The Qdrant client connects to a remote server via QDRANT_URL + QDRANT_API_KEY
environment variables (overridable via constructor).
"""
from __future__ import annotations
import asyncio
import base64
import functools
import json
import os
import re
import sys
import threading
from datetime import datetime, timezone
from pathlib import Path
from typing import Callable
from uuid import UUID
from qdrant_client import QdrantClient, models
from qdrant_client.models import (
Distance,
FieldCondition,
Filter,
MatchValue,
PayloadSchemaType,
PointStruct,
VectorParams,
)
from qdrant_client.http.exceptions import UnexpectedResponse
import pandas as pd
from cryptography.hazmat.primitives.ciphers.aead import AESGCM
from iai_mcp.crypto import (
CIPHERTEXT_PREFIX,
NONCE_BYTES,
CryptoKey,
encrypt_field,
is_encrypted,
)
from iai_mcp.types import (
DEFAULT_EMBED_DIM,
EMBED_DIM,
SCHEMA_VERSION_CURRENT,
MemoryRecord,
TIER_ENUM,
)
# --------------------------------------------------------------------------- env
# Qdrant connection: override with env vars or constructor kwargs.
QDRANT_URL = os.environ.get("QDRANT_URL")
QDRANT_API_KEY = os.environ.get("QDRANT_API_KEY")

# --------------------------------------------------------------------------- tables
RECORDS_TABLE = "records"
EDGES_TABLE = "edges"
EVENTS_TABLE = "events"
BUDGET_TABLE = "budget_ledger"
RATELIMIT_TABLE = "ratelimit_ledger"
METADATA_TABLE = "metadata"

# Default embedding dimension for the records collection: 1024 for bge-m3
# (remote embedder), overridable via the IAI_MCP_EMBED_DIM env var.
# A malformed or non-positive override falls back to the default here rather
# than raising at import time; _resolve_embed_dim() re-reads the env var when
# a store is opened.
_DEFAULT_RECORDS_DIM = 1024
try:
    _RECORDS_DIM = int(os.environ.get("IAI_MCP_EMBED_DIM", str(_DEFAULT_RECORDS_DIM)))
except ValueError:
    _RECORDS_DIM = _DEFAULT_RECORDS_DIM
if _RECORDS_DIM <= 0:
    _RECORDS_DIM = _DEFAULT_RECORDS_DIM

# Logical tables that live (as payload-only points) in the metadata collection.
_METADATA_TABLES = frozenset({EDGES_TABLE, EVENTS_TABLE, BUDGET_TABLE, RATELIMIT_TABLE})

# Edge type enum (validated by QdrantStore.boost_edges).
EDGE_TYPES: frozenset[str] = frozenset({
    "hebbian",
    "contradicts",
    "consolidated_from",
    "schema_instance_of",
    "temporal_next",
    "invariant_anchor",
    "curiosity_bridge",
    "profile_modulates",
    "hebbian_structure",
})

# RFC-4122 canonical UUID regex (lower-case hex, hyphenated 8-4-4-4-12).
_UUID_STR_RE = re.compile(
    r"^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$"
)
def _uuid_literal(value: UUID | str) -> str:
    """Return the canonical (lower-case, hyphenated) string form of a UUID.

    Used to canonicalise record ids and as the encryption associated data.

    Raises
    ------
    ValueError
        If ``str(value).lower()`` is not exactly a canonical UUID string.
    """
    s = str(value).lower()
    # fullmatch rather than match: with match, the pattern's ``$`` also
    # accepts a trailing "\n", which would then leak into the returned
    # literal (and hence into the AD bytes).
    if not _UUID_STR_RE.fullmatch(s):
        raise ValueError(f"not a canonical UUID: {value!r}")
    return s
def _resolve_embed_dim() -> int:
    """Pick the embedding dimension for the records collection.

    Resolution order:

    1. ``IAI_MCP_EMBED_DIM`` if it parses as a positive integer;
    2. 1024 when ``IAI_MCP_EMBED_MODEL`` is ``"bge-m3"``;
    3. the module default ``_RECORDS_DIM``.
    """
    env_dim = os.environ.get("IAI_MCP_EMBED_DIM")
    if env_dim:
        try:
            dim = int(env_dim)
        except ValueError:
            dim = 0
        # A zero/negative dimension can never back a vector collection;
        # treat it like a malformed value and fall through.
        if dim > 0:
            return dim
    # bge-m3 -> 1024
    if os.environ.get("IAI_MCP_EMBED_MODEL") == "bge-m3":
        return 1024
    return _RECORDS_DIM
class QdrantStore:
    """Qdrant-backed memory store.

    Mirrors the MemoryStore API so existing callers (daemon, MCP wrapper,
    capture, retrieve, sleep, etc.) can use it unchanged.

    Layout:
    - ``records`` collection: one point per MemoryRecord, id = record UUID,
      unnamed vector = embedding.
    - ``metadata`` collection: payload-only points for edges and events,
      distinguished by the ``table`` payload field.
    """

    def __init__(
        self,
        path: Path | str | None = None,
        user_id: str = "default",
        read_consistency_interval: float | None = None,
        url: str | None = None,
        api_key: str | None = None,
    ) -> None:
        """Open (or initialise) a Qdrant-backed store.

        Parameters
        ----------
        path:
            Accepted for MemoryStore signature compatibility; not used here.
        user_id:
            User ID that scopes the encryption key. Also used as group_id
            for payload partitioning (multi-tenant ready).
        read_consistency_interval:
            Reserved for future consistency tuning (not used with Qdrant).
        url:
            Qdrant server URL. Defaults to the QDRANT_URL env var.
        api_key:
            Qdrant API key. Defaults to the QDRANT_API_KEY env var.
        """
        self._embed_dim: int = _resolve_embed_dim()
        self._user_id: str = user_id
        self._group_id: str = user_id  # payload partition key
        self._read_consistency_interval = read_consistency_interval

        # Qdrant client
        self._qdrant_url = url or QDRANT_URL
        self._qdrant_api_key = api_key or QDRANT_API_KEY
        self._client = QdrantClient(
            url=self._qdrant_url,
            api_key=self._qdrant_api_key,
            timeout=30,
        )

        # Encryption key is loaded lazily on first use (see _key()).
        self._crypto_key_wrapper = CryptoKey(user_id=user_id)
        self._crypto_key: bytes | None = None

        # Graph sync hook
        self._graph_sync_hook: Callable[[str, MemoryRecord], None] | None = None

        # Async write queue (same contract as MemoryStore)
        self._write_queue = None
        self._async_loop: asyncio.AbstractEventLoop | None = None
        self._async_thread: threading.Thread | None = None
        self._provenance_queue = None

        self._ensure_collections()

    # ------------------------------------------------------------------ schema

    def _ensure_collections(self) -> None:
        """Create the ``records`` and ``metadata`` collections if missing.

        - ``records``: cosine vectors of ``self._embed_dim`` dimensions; the
          global HNSW graph is disabled (``m=0``) and per-payload-value
          subgraphs are built instead (``payload_m=16``).
        - ``metadata``: payload-only (no vectors) for edges/events/ledgers.

        Then creates the payload indexes.
        """
        self._create_collection_if_missing(
            RECORDS_TABLE,
            vectors_config=VectorParams(
                size=self._embed_dim,
                distance=Distance.COSINE,
            ),
            hnsw_config=models.HnswConfigDiff(
                payload_m=16,  # per-payload-value HNSW subgraphs
                m=0,  # no global index (all data partitioned by payload)
            ),
        )
        # No vectors_config: payload-only collection.
        self._create_collection_if_missing(METADATA_TABLE)
        self._ensure_indexes()

    def _create_collection_if_missing(self, name: str, **create_kwargs) -> None:
        """Create collection ``name`` unless it already exists.

        A 409 from ``create_collection`` (created concurrently) is tolerated;
        every other ``UnexpectedResponse`` is re-raised.
        """
        if self._collection_exists(name):
            return
        try:
            self._client.create_collection(collection_name=name, **create_kwargs)
        except UnexpectedResponse as e:
            if e.status_code != 409:
                raise

    def _collection_exists(self, name: str) -> bool:
        """Return True if ``get_collection(name)`` succeeds.

        Raises RuntimeError on HTTP 401. Any other failure is reported as
        "does not exist" so that the subsequent create surfaces the real error.
        """
        try:
            info = self._client.get_collection(name)
            return info is not None
        except UnexpectedResponse as e:
            if e.status_code == 401:
                raise RuntimeError(
                    "Qdrant 401: invalid API key. Set QDRANT_API_KEY env var."
                ) from e
            return False
        except Exception:
            return False

    def _ensure_indexes(self) -> None:
        """Create keyword payload indexes used by filtered queries (best effort).

        Both collections get ``table`` and ``group_id`` indexes plus
        collection-specific ones. Plain keyword indexes are created; the
        ``is_tenant`` flag is not set. Failures are ignored so that opening
        a store never fails on index setup.
        """
        keyword = PayloadSchemaType.KEYWORD
        indexes: dict[str, tuple[str, ...]] = {
            RECORDS_TABLE: ("table", "group_id", "id", "tier", "community_id"),
            METADATA_TABLE: ("table", "group_id", "edge_type", "kind", "session_id"),
        }
        for collection_name, field_names in indexes.items():
            for field_name in field_names:
                try:
                    self._client.create_payload_index(
                        collection_name=collection_name,
                        field_name=field_name,
                        field_schema=keyword,
                    )
                except Exception:
                    pass

    @property
    def embed_dim(self) -> int:
        """Embedding dimension of the records collection."""
        return self._embed_dim

    @property
    def user_id(self) -> str:
        """user_id that scopes the encryption key."""
        return self._user_id

    # ------------------------------------------------------------------ encryption

    def _key(self) -> bytes:
        """Lazy-load the encryption key via ``CryptoKey.get_or_create()``."""
        if self._crypto_key is None:
            self._crypto_key = self._crypto_key_wrapper.get_or_create()
        return self._crypto_key

    def _ad(self, record_id: UUID | str) -> bytes:
        """Associated data for encryption: canonical UUID string as ASCII bytes."""
        return _uuid_literal(record_id).encode("ascii")

    def _encrypt_for_record(self, record_id: UUID | str, value: str) -> str:
        """Encrypt a per-record sensitive field; already-encrypted values pass through."""
        if is_encrypted(value):
            return value
        return encrypt_field(value, self._key(), associated_data=self._ad(record_id))

    @functools.cached_property
    def _cached_aesgcm(self) -> AESGCM:
        """One AESGCM cipher per store lifetime."""
        return AESGCM(self._key())

    def _decrypt_for_record(self, record_id: UUID | str, value: str) -> str:
        """Decrypt a per-record sensitive field; plaintext passes through.

        Expects ``CIPHERTEXT_PREFIX`` + base64(nonce || ciphertext || 16-byte tag).
        Raises ValueError on a malformed payload; AESGCM raises on a tag or
        AD mismatch.
        """
        if not is_encrypted(value):
            return value
        if not value.startswith(CIPHERTEXT_PREFIX):
            raise ValueError("field is not iai:enc:v1:-prefixed ciphertext")
        payload = base64.b64decode(value[len(CIPHERTEXT_PREFIX):])
        if len(payload) < NONCE_BYTES + 16:
            raise ValueError("ciphertext payload too short")
        nonce = payload[:NONCE_BYTES]
        ct_with_tag = payload[NONCE_BYTES:]
        associated_data = self._ad(record_id)
        plaintext_bytes = self._cached_aesgcm.decrypt(
            nonce, ct_with_tag, associated_data or None
        )
        return plaintext_bytes.decode("utf-8")

    # ------------------------------------------------------------------ I/O

    def register_graph_sync_hook(
        self, hook: Callable[[str, MemoryRecord], None] | None
    ) -> None:
        """Register a callback that mirrors store writes to the runtime graph."""
        self._graph_sync_hook = hook

    def _fire_graph_sync_hook(self, op: str, record: MemoryRecord) -> None:
        """Dispatch the (op, record) event.

        Hook failures are swallowed after a best-effort JSON line on stderr.
        """
        hook = self._graph_sync_hook
        if hook is None:
            return
        try:
            hook(op, record)
        except Exception as exc:
            try:
                sys.stderr.write(
                    json.dumps({
                        "event": "graph_sync_failed",
                        "op": op,
                        "record_id": str(getattr(record, "id", "")),
                        "error": str(exc),
                        "ts": datetime.now(timezone.utc).isoformat(),
                    })
                    + "\n"
                )
            except Exception:
                pass

    # ------------------------------------------------------- conversions

    def _to_point(self, r: MemoryRecord) -> PointStruct:
        """Convert a MemoryRecord to a Qdrant PointStruct.

        literal_surface, provenance_json and profile_modulation_gain_json are
        encrypted with AD = record id. The id is also written to the ``id``
        payload field so that ``_make_filter`` and the ``id`` index can match it.
        """
        literal_ct = self._encrypt_for_record(r.id, r.literal_surface)
        provenance_ct = self._encrypt_for_record(r.id, json.dumps(r.provenance))
        gain_ct = self._encrypt_for_record(
            r.id, json.dumps(r.profile_modulation_gain or {})
        )
        return PointStruct(
            id=str(r.id),
            vector=[float(x) for x in r.embedding],
            payload={
                "table": RECORDS_TABLE,
                "group_id": self._group_id,
                "id": str(r.id),
                "tier": r.tier,
                "literal_surface": literal_ct,
                "aaak_index": r.aaak_index or "",
                "structure_hv": base64.b64encode(
                    bytes(r.structure_hv or b"")
                ).decode("ascii"),
                "community_id": str(r.community_id) if r.community_id else "",
                "centrality": float(r.centrality),
                "detail_level": int(r.detail_level),
                "pinned": bool(r.pinned),
                "stability": float(r.stability),
                "difficulty": float(r.difficulty),
                "last_reviewed": (
                    r.last_reviewed.isoformat() if r.last_reviewed else None
                ),
                "never_decay": bool(r.never_decay),
                "never_merge": bool(r.never_merge),
                "provenance_json": provenance_ct,
                "created_at": r.created_at.isoformat() if r.created_at else None,
                "updated_at": r.updated_at.isoformat() if r.updated_at else None,
                "tags_json": json.dumps(r.tags),
                "language": str(r.language),
                "s5_trust_score": float(r.s5_trust_score),
                "profile_modulation_gain_json": gain_ct,
                "schema_version": int(r.schema_version),
            },
        )

    def _from_point(self, point: PointStruct) -> MemoryRecord:
        """Convert a retrieved/scrolled/scored Qdrant point to a MemoryRecord.

        Decrypts the encrypted payload fields. ``embedding`` is empty when
        the point was fetched without vectors. Raises ValueError if the point
        has no payload.
        """
        payload = point.payload
        if not payload:
            raise ValueError(f"point {point.id} has no payload")
        row_uuid = UUID(str(point.id))

        # structure_hv: base64-decoded bytes (empty on missing/corrupt data)
        structure_raw = payload.get("structure_hv")
        structure_hv = b""
        if structure_raw:
            try:
                structure_hv = base64.b64decode(structure_raw)
            except Exception:
                structure_hv = b""

        community_raw = payload.get("community_id") or ""
        community_id = UUID(community_raw) if community_raw else None

        version_raw = payload.get("schema_version")
        try:
            schema_version = (
                int(version_raw) if version_raw is not None else SCHEMA_VERSION_CURRENT
            )
        except (TypeError, ValueError):
            schema_version = SCHEMA_VERSION_CURRENT

        # language back-compat: empty language on schema v1 rows stays "",
        # on newer rows it defaults to "en".
        lang_raw = payload.get("language")
        is_empty_language = lang_raw is None or (isinstance(lang_raw, str) and lang_raw == "")
        if is_empty_language and schema_version == 1:
            language = "__LEGACY_EMPTY__"
        elif is_empty_language:
            language = "en"
        else:
            language = str(lang_raw)

        s5_raw = payload.get("s5_trust_score")
        s5_trust_score = float(s5_raw) if s5_raw is not None else 0.5

        gain_raw = payload.get("profile_modulation_gain_json") or "{}"
        if is_encrypted(gain_raw):
            gain_raw = self._decrypt_for_record(row_uuid, gain_raw)
        try:
            profile_modulation_gain = json.loads(gain_raw) or {}
        except (TypeError, json.JSONDecodeError):
            profile_modulation_gain = {}

        literal_raw = payload.get("literal_surface", "")
        if is_encrypted(literal_raw):
            literal_raw = self._decrypt_for_record(row_uuid, literal_raw)

        provenance_raw = payload.get("provenance_json") or "[]"
        if is_encrypted(provenance_raw):
            provenance_raw = self._decrypt_for_record(row_uuid, provenance_raw)
        try:
            provenance_list = json.loads(provenance_raw) if provenance_raw else []
        except (TypeError, json.JSONDecodeError):
            provenance_list = []

        try:
            tags = json.loads(payload.get("tags_json") or "[]")
        except (TypeError, json.JSONDecodeError):
            tags = []

        def _parse_ts(val):
            if val is None:
                return None
            try:
                return datetime.fromisoformat(val)
            except (TypeError, ValueError):
                return None

        rec = MemoryRecord(
            id=row_uuid,
            tier=payload.get("tier", "episodic"),
            literal_surface=literal_raw,
            aaak_index=payload.get("aaak_index") or "",
            embedding=list(point.vector) if point.vector else [],
            community_id=community_id,
            centrality=float(payload.get("centrality", 0.0) or 0.0),
            detail_level=int(payload.get("detail_level", 1)),
            pinned=bool(payload.get("pinned", False)),
            stability=float(payload.get("stability") or 0.0),
            difficulty=float(payload.get("difficulty") or 0.0),
            last_reviewed=_parse_ts(payload.get("last_reviewed")),
            never_decay=bool(payload.get("never_decay", False)),
            never_merge=bool(payload.get("never_merge", False)),
            provenance=provenance_list,
            created_at=_parse_ts(payload.get("created_at")) or datetime.now(timezone.utc),
            updated_at=_parse_ts(payload.get("updated_at")) or datetime.now(timezone.utc),
            tags=tags,
            language=language,
            s5_trust_score=s5_trust_score,
            profile_modulation_gain=profile_modulation_gain,
            schema_version=schema_version,
            structure_hv=structure_hv,
        )
        # Sentinel avoids passing "" through the constructor; restore it here.
        if language == "__LEGACY_EMPTY__":
            rec.language = ""
        return rec

    # ------------------------------------------------------- point operations

    def _make_filter(self, id_str: str) -> Filter:
        """Filter matching the ``id`` payload field (written by ``_to_point``)."""
        return Filter(must=[FieldCondition(
            key="id",
            match=MatchValue(value=id_str),
        )])

    def _make_tier_filter(self, tier: str) -> Filter:
        """Filter matching a tier value."""
        return Filter(must=[FieldCondition(
            key="tier",
            match=MatchValue(value=tier),
        )])

    def _make_combined_filter(self, *conditions: FieldCondition) -> Filter:
        """Combine multiple field conditions with AND."""
        return Filter(must=list(conditions))

    @staticmethod
    def _records_filter(where: str | None) -> Filter:
        """Build the records-table scroll filter from a LanceDB-style ``where``.

        Only ``"tier = '<value>'"`` is translated into a tier condition; any
        other clause is ignored. Raises ValueError if a ``tier = `` clause has
        no quoted value.
        """
        conditions = [FieldCondition(key="table", match=MatchValue(value=RECORDS_TABLE))]
        if where and where.startswith("tier = "):
            parts = where.split("'")
            if len(parts) < 2:
                raise ValueError(f"unsupported where clause: {where!r}")
            conditions.append(FieldCondition(key="tier", match=MatchValue(value=parts[1])))
        return Filter(must=conditions)

    @staticmethod
    def _edge_point_id(src: str, dst: str, edge_type: str) -> str:
        """Deterministic point id for the edge (src, dst, edge_type).

        Qdrant point ids must be unsigned integers or UUIDs, so the natural
        key ``"src:dst:edge_type"`` is hashed into a name-based UUIDv5.
        """
        import uuid

        return str(uuid.uuid5(uuid.NAMESPACE_URL, f"iai-mcp:edge:{src}:{dst}:{edge_type}"))

    # ------------------------------------------------------- writes

    def insert(self, record: MemoryRecord) -> None:
        """Upsert a record verbatim (no rewrite at write time).

        Raises ValueError on an unknown tier or a wrong embedding length.
        Fills ``structure_hv`` lazily when it is empty.

        NOTE(review): this writes synchronously even after
        ``enable_async_writes()``; it never consults ``self._write_queue``.
        """
        if record.tier not in TIER_ENUM:
            raise ValueError(f"invalid tier {record.tier!r}")
        if len(record.embedding) != self._embed_dim:
            raise ValueError(
                f"embedding must be {self._embed_dim}d, got {len(record.embedding)}"
            )
        if not record.structure_hv:
            from iai_mcp.tem import bind_structure
            record.structure_hv = bind_structure(record)
        self._client.upsert(
            collection_name=RECORDS_TABLE,
            points=[self._to_point(record)],
        )
        self._fire_graph_sync_hook("insert", record)

    def update(self, record: MemoryRecord) -> None:
        """Rewrite an existing record in full; no-op if it does not exist.

        Raises ValueError on a wrong embedding length.
        """
        if len(record.embedding) != self._embed_dim:
            raise ValueError(
                f"embedding must be {self._embed_dim}d, got {len(record.embedding)}"
            )
        # Existence check only -- no payload needed.
        existing = self._client.retrieve(
            collection_name=RECORDS_TABLE,
            ids=[str(record.id)],
            with_payload=False,
        )
        if not existing:
            return
        self._client.upsert(
            collection_name=RECORDS_TABLE,
            points=[self._to_point(record)],
        )
        self._fire_graph_sync_hook("update", record)

    def delete(self, record_id: UUID) -> None:
        """Remove a record by id.

        If the delete call raises, the error is swallowed and no hook fires.
        """
        try:
            self._client.delete(
                collection_name=RECORDS_TABLE,
                points_selector=models.PointIdsList(
                    points=[str(record_id)],
                ),
            )
        except Exception:
            return

        class _DeleteShim:
            # Minimal stand-in: the hook only reads ``.id`` for deletes.
            def __init__(self, rid):
                self.id = rid

        self._fire_graph_sync_hook("delete", _DeleteShim(record_id))

    def get(self, record_id: UUID) -> MemoryRecord | None:
        """Fetch one record by id, including its embedding; None if absent."""
        points = self._client.retrieve(
            collection_name=RECORDS_TABLE,
            ids=[str(record_id)],
            with_payload=True,
            with_vectors=True,
        )
        return self._from_point(points[0]) if points else None

    def all_records(self) -> list[MemoryRecord]:
        """Return every record (with embeddings) from the records collection."""
        return list(self.iter_records(batch_size=1000))

    def iter_records(
        self,
        *,
        columns: list[str] | None = None,
        batch_size: int = 1024,
        where: str | None = None,
    ):
        """Stream MemoryRecords (table=records), one scroll page at a time.

        ``where`` supports only ``"tier = '<value>'"``. Embeddings are fetched
        unless ``columns`` is given and omits ``"embedding"``, in which case
        the yielded records have an empty embedding.
        """
        qdrant_filter = self._records_filter(where)
        with_vectors = columns is None or "embedding" in columns
        offset = None
        while True:
            points, next_offset = self._client.scroll(
                collection_name=RECORDS_TABLE,
                limit=batch_size,
                offset=offset,
                scroll_filter=qdrant_filter,
                with_payload=True,
                with_vectors=with_vectors,
            )
            for point in points:
                yield self._from_point(point)
            if next_offset is None:
                break
            offset = next_offset

    def iter_record_columns(
        self,
        columns: list[str],
        *,
        batch_size: int = 1024,
        where: str | None = None,
    ):
        """Projection-only iteration over table=records; no MemoryRecord, no decrypt.

        Yields dicts with the requested payload keys (encrypted fields stay
        encrypted) plus ``"id"`` = point id. Raises ValueError on empty
        ``columns``.
        """
        if not columns:
            raise ValueError("iter_record_columns requires a non-empty columns list")
        wanted = set(columns)
        qdrant_filter = self._records_filter(where)
        offset = None
        while True:
            points, next_offset = self._client.scroll(
                collection_name=RECORDS_TABLE,
                limit=batch_size,
                offset=offset,
                scroll_filter=qdrant_filter,
                with_payload=True,
                with_vectors=False,
            )
            for point in points:
                row = {k: v for k, v in (point.payload or {}).items() if k in wanted}
                row["id"] = point.id
                yield row
            if next_offset is None:
                break
            offset = next_offset

    def query_similar(
        self,
        vec: list[float],
        k: int = 10,
        tier: str | None = None,
    ) -> list[tuple[MemoryRecord, float]]:
        """Cosine kNN search on the records collection.

        Returns up to ``k`` ``(record, score)`` pairs, where ``score`` is
        Qdrant's cosine similarity. Returned records are fetched without
        vectors, so their embedding is empty. Returns [] if the collection
        is empty or its info cannot be read. Raises ValueError on an unknown
        tier.
        """
        if tier is not None and tier not in TIER_ENUM:
            raise ValueError(
                f"invalid tier {tier!r}; must be one of {sorted(TIER_ENUM)}"
            )
        conditions = [FieldCondition(key="table", match=MatchValue(value=RECORDS_TABLE))]
        if tier is not None:
            conditions.append(FieldCondition(key="tier", match=MatchValue(value=tier)))
        qdrant_filter = Filter(must=conditions)

        try:
            info = self._client.get_collection(RECORDS_TABLE)
            if info.points_count == 0:
                return []
        except Exception:
            return []

        search_result = self._client.query_points(
            collection_name=RECORDS_TABLE,
            query=vec,
            limit=k,
            query_filter=qdrant_filter,
            with_payload=True,
            with_vectors=False,
        ).points

        out: list[tuple[MemoryRecord, float]] = []
        for point in search_result:
            score = float(point.score) if point.score else 0.0
            out.append((self._from_point(point), score))
        return out

    def update_record(self, record: MemoryRecord) -> None:
        """Persist the FSRS columns (stability/difficulty/last_reviewed).

        Also bumps ``updated_at``. No-op if the record does not exist.
        """
        rid = str(record.id)
        existing = self._client.retrieve(
            collection_name=RECORDS_TABLE,
            ids=[rid],
            with_payload=False,
        )
        if not existing:
            return
        self._client.set_payload(
            collection_name=RECORDS_TABLE,
            payload={
                "stability": float(record.stability),
                "difficulty": float(record.difficulty),
                "last_reviewed": record.last_reviewed.isoformat() if record.last_reviewed else None,
                "updated_at": datetime.now(timezone.utc).isoformat(),
            },
            points=[rid],
        )

    # -------------------------------------------------------- reconsolidation

    def append_provenance(self, record_id: UUID, entry: dict) -> None:
        """Append one provenance entry to a record; no-op if the record is absent.

        Decryption errors on the stored provenance propagate.
        """
        rid = str(record_id)
        points = self._client.retrieve(
            collection_name=RECORDS_TABLE,
            ids=[rid],
            with_payload=True,
        )
        if not points:
            return
        raw = (points[0].payload or {}).get("provenance_json") or "[]"
        if is_encrypted(raw):
            raw = self._decrypt_for_record(record_id, raw)
        try:
            existing = json.loads(raw)
        except (TypeError, json.JSONDecodeError):
            existing = []
        existing.append(entry)
        new_ct = self._encrypt_for_record(record_id, json.dumps(existing))
        self._client.set_payload(
            collection_name=RECORDS_TABLE,
            payload={
                "provenance_json": new_ct,
                "updated_at": datetime.now(timezone.utc).isoformat(),
            },
            points=[rid],
        )

    def append_provenance_batch(
        self, pairs: "list[tuple[UUID, dict]]",
        records_cache: "dict | None" = None,
    ) -> None:
        """Append provenance entries for many records (best effort).

        Entries are grouped per canonical record id; ids that are not
        canonical UUIDs are skipped. With ``records_cache``, the existing
        provenance is taken from the cached records (looked up by UUID, then
        by string); otherwise it is fetched from Qdrant and decrypted.
        Records whose stored provenance cannot be decrypted are skipped
        rather than overwritten. Per-record write failures are ignored.
        """
        if not pairs:
            return
        from collections import defaultdict

        grouped: dict[str, list[dict]] = defaultdict(list)
        for rid, entry in pairs:
            try:
                grouped[_uuid_literal(rid)].append(entry)
            except ValueError:
                continue

        updates: dict[str, str] = {}
        if records_cache is not None:
            for rid_str, entries in grouped.items():
                rec = records_cache.get(UUID(rid_str))
                if rec is None:
                    rec = records_cache.get(rid_str)
                if rec is None:
                    continue
                existing = list(rec.provenance or [])
                existing.extend(entries)
                updates[rid_str] = self._encrypt_for_record(rid_str, json.dumps(existing))
        elif grouped:
            try:
                points = self._client.retrieve(
                    collection_name=RECORDS_TABLE,
                    ids=list(grouped),
                    with_payload=True,
                )
            except Exception:
                return
            for point in points:
                rid_str = str(point.id)
                entries = grouped.get(rid_str)
                if entries is None:
                    continue
                raw_prov = (point.payload or {}).get("provenance_json") or "[]"
                if is_encrypted(raw_prov):
                    try:
                        raw_prov = self._decrypt_for_record(rid_str, raw_prov)
                    except Exception:
                        # Writing only the new entries would destroy the
                        # existing history; leave this record untouched.
                        continue
                try:
                    existing = json.loads(raw_prov)
                except (TypeError, json.JSONDecodeError):
                    existing = []
                existing.extend(entries)
                updates[rid_str] = self._encrypt_for_record(rid_str, json.dumps(existing))

        if not updates:
            return
        now = datetime.now(timezone.utc).isoformat()
        # No merge_insert equivalent: one set_payload per record.
        for rid_str, new_ct in updates.items():
            try:
                self._client.set_payload(
                    collection_name=RECORDS_TABLE,
                    payload={
                        "provenance_json": new_ct,
                        "updated_at": now,
                    },
                    points=[rid_str],
                )
            except Exception:
                continue

    # ------------------------------------------------------------------ edges

    def boost_edges(
        self,
        pairs: list[tuple[UUID, UUID]],
        delta: float | list[float] = 0.1,
        edge_type: str = "hebbian",
    ) -> dict[tuple[str, str], float]:
        """Add ``delta`` to the weight of each edge in ``pairs`` (table=edges).

        Each pair is canonicalised by sorting its two id strings, and deltas
        of duplicate pairs are summed. Missing edges are created with the
        accumulated delta as weight.

        Returns ``{(src, dst): new_weight}`` keyed by canonical pair.
        Raises ValueError on an unknown ``edge_type`` or when a ``delta``
        list's length differs from ``pairs``.
        """
        if edge_type not in EDGE_TYPES:
            raise ValueError(
                f"invalid edge_type {edge_type!r}; must be one of {sorted(EDGE_TYPES)}"
            )
        if isinstance(delta, (int, float)):
            deltas = [float(delta)] * len(pairs)
        else:
            deltas = [float(d) for d in delta]
        if len(deltas) != len(pairs):
            raise ValueError(
                f"deltas length {len(deltas)} != pairs length {len(pairs)}"
            )
        if not pairs:
            return {}

        coalesced: dict[tuple[str, str], float] = {}
        for (a, b), d in zip(pairs, deltas):
            canonical = tuple(sorted((str(a), str(b))))
            coalesced[canonical] = coalesced.get(canonical, 0.0) + d

        # Deterministic ids let us fetch only the touched edges instead of
        # scanning the whole edges table.
        point_ids = {
            pair: self._edge_point_id(pair[0], pair[1], edge_type) for pair in coalesced
        }
        existing = self._client.retrieve(
            collection_name=METADATA_TABLE,
            ids=list(point_ids.values()),
            with_payload=True,
            with_vectors=False,
        )
        current = {
            str(p.id): float((p.payload or {}).get("weight", 0.0)) for p in existing
        }

        now = datetime.now(timezone.utc).isoformat()
        points_to_upsert: list[PointStruct] = []
        new_weights: dict[tuple[str, str], float] = {}
        for (src_str, dst_str), accum_delta in coalesced.items():
            pid = point_ids[(src_str, dst_str)]
            nw = current.get(pid, 0.0) + accum_delta
            points_to_upsert.append(PointStruct(
                id=pid,
                vector={},  # payload-only collection
                payload={
                    "table": EDGES_TABLE,
                    "group_id": self._group_id,
                    "src": src_str,
                    "dst": dst_str,
                    "edge_type": edge_type,
                    "weight": nw,
                    "updated_at": now,
                },
            ))
            new_weights[(src_str, dst_str)] = nw

        self._client.upsert(
            collection_name=METADATA_TABLE,
            points=points_to_upsert,
        )
        return new_weights

    def reinforce_record(
        self,
        record_id: UUID,
        anchor_id: UUID | None = None,
        edge_type: str = "hebbian",
        delta: float = 0.1,
    ) -> dict[tuple[str, str], float]:
        """Single-record Hebbian reinforcement (self-edge when no anchor is given)."""
        pair = (record_id, record_id) if anchor_id is None else (anchor_id, record_id)
        return self.boost_edges([pair], delta=delta, edge_type=edge_type)

    def add_contradicts_edge(self, original: UUID, new_id: UUID) -> None:
        """Upsert a weight-1.0 ``contradicts`` edge original -> new_id (table=edges).

        Unlike ``boost_edges`` the pair is stored in the given direction.
        """
        src, dst = str(original), str(new_id)
        self._client.upsert(
            collection_name=METADATA_TABLE,
            points=[PointStruct(
                id=self._edge_point_id(src, dst, "contradicts"),
                vector={},  # payload-only collection
                payload={
                    "table": EDGES_TABLE,
                    "group_id": self._group_id,
                    "src": src,
                    "dst": dst,
                    "edge_type": "contradicts",
                    "weight": 1.0,
                    "updated_at": datetime.now(timezone.utc).isoformat(),
                },
            )],
        )

    # ------------------------------------------------------- async writes

    async def enable_async_writes(
        self,
        coalesce_ms: int = 100,
        max_batch: int = 128,
        max_queue_size: int = 4096,
    ) -> None:
        """Start a background event-loop thread hosting a coalescing AsyncWriteQueue.

        Flushed batches are upserted into the records collection and fire
        the ``"insert"`` graph-sync hook per record. Also enables the
        provenance queue. No-op if already enabled.

        NOTE(review): ``insert()`` does not route through this queue; confirm
        how callers are expected to enqueue records.
        """
        if self._write_queue is not None:
            return
        from iai_mcp.write_queue import AsyncWriteQueue

        ready = threading.Event()
        loop_holder: dict = {}

        def _run() -> None:
            loop = asyncio.new_event_loop()
            loop_holder["loop"] = loop
            asyncio.set_event_loop(loop)
            ready.set()
            try:
                loop.run_forever()
            finally:
                loop.close()

        thread = threading.Thread(
            target=_run, name="iai-mcp-async-writes", daemon=True,
        )
        thread.start()
        ready.wait()
        bg_loop: asyncio.AbstractEventLoop = loop_holder["loop"]

        to_point = self._to_point
        fire_hook = self._fire_graph_sync_hook

        class _RecordPointAdapter:
            """Table-like sink: the queue calls ``add(records)`` per flush."""

            def __init__(self, client) -> None:
                self._client = client

            async def add(self, records: list) -> None:
                points = [to_point(r) for r in records]
                if not points:
                    return
                self._client.upsert(collection_name=RECORDS_TABLE, points=points)

        def _on_flushed(batch: list) -> None:
            for rec in batch:
                fire_hook("insert", rec)

        queue = AsyncWriteQueue(
            _RecordPointAdapter(self._client),
            coalesce_ms=coalesce_ms,
            max_batch=max_batch,
            max_queue_size=max_queue_size,
            on_flushed=_on_flushed,
        )
        asyncio.run_coroutine_threadsafe(queue.start(), bg_loop).result()
        self._async_loop = bg_loop
        self._async_thread = thread
        self._write_queue = queue
        self.enable_provenance_queue()

    async def disable_async_writes(self) -> None:
        """Drain the write queue and tear down its background loop thread.

        The provenance queue is drained first. Instance state is reset even
        if stopping the queue raises, so a later enable starts fresh.
        """
        self.disable_provenance_queue()
        if self._write_queue is None:
            return
        bg_loop = self._async_loop
        queue = self._write_queue
        try:
            asyncio.run_coroutine_threadsafe(queue.stop(), bg_loop).result()
        finally:
            if bg_loop is not None:
                bg_loop.call_soon_threadsafe(bg_loop.stop)
            if self._async_thread is not None:
                self._async_thread.join(timeout=5.0)
            self._write_queue = None
            self._async_loop = None
            self._async_thread = None

    # -------------------------------------------------- provenance queue

    def enable_provenance_queue(self, *, coalesce_ms: int = 50) -> None:
        """Route provenance writes through a ProvenanceWriteQueue (idempotent)."""
        if self._provenance_queue is not None:
            return
        from iai_mcp.provenance_queue import ProvenanceWriteQueue
        q = ProvenanceWriteQueue(self, coalesce_ms=coalesce_ms)
        q.start()
        self._provenance_queue = q

    def disable_provenance_queue(self) -> None:
        """Flush (2 s timeout) and stop the provenance queue; errors are ignored."""
        q = self._provenance_queue
        if q is None:
            return
        try:
            q.flush(timeout=2.0)
        except Exception:
            pass
        try:
            q.stop()
        except Exception:
            pass
        self._provenance_queue = None

    def queue_provenance_batch(
        self, pairs: "list[tuple[UUID, dict]]"
    ) -> None:
        """Fire-and-forget provenance write; falls back to a synchronous batch."""
        if not pairs:
            return
        q = self._provenance_queue
        if q is not None:
            q.enqueue(pairs)
            return
        self.append_provenance_batch(pairs, records_cache=None)

    # ------------------------------------------------------- helpers

    def _scroll_all(
        self,
        collection_name: str,
        batch_size: int = 1000,
        table_filter: str | None = None,
        *,
        extra_conditions: list[FieldCondition] | None = None,
    ) -> list:
        """Scroll through all matching points of a collection (no vectors).

        ``table_filter`` restricts to one logical table via the ``table``
        payload field; ``extra_conditions`` are ANDed on top.
        """
        conditions: list[FieldCondition] = []
        if table_filter:
            conditions.append(FieldCondition(
                key="table", match=MatchValue(value=table_filter),
            ))
        if extra_conditions:
            conditions.extend(extra_conditions)
        scroll_filter = Filter(must=conditions) if conditions else None

        offset = None
        all_points: list = []
        while True:
            points, next_offset = self._client.scroll(
                collection_name=collection_name,
                limit=batch_size,
                offset=offset,
                scroll_filter=scroll_filter,
                with_payload=True,
                with_vectors=False,
            )
            all_points.extend(points)
            if next_offset is None:
                break
            offset = next_offset
        return all_points

    # ---------------------------------------------------------- events

    def events_add(self, event_id: UUID, kind: str, severity: str, domain: str,
                   ts: datetime, data_json: str, session_id: str, source_ids_json: str) -> None:
        """Add one event to the metadata collection (table=events).

        ``data_json`` is encrypted at rest with AD = event id, like the
        sensitive record fields; already-encrypted input passes through.
        """
        point = PointStruct(
            id=str(event_id),
            vector={},  # payload-only collection
            payload={
                "table": EVENTS_TABLE,
                "group_id": self._group_id,
                "kind": kind,
                "severity": severity,
                "domain": domain,
                "ts": ts.isoformat(),
                "data_json": self._encrypt_for_record(event_id, data_json),
                "session_id": session_id,
                "source_ids_json": source_ids_json,
            },
        )
        self._client.upsert(collection_name=METADATA_TABLE, points=[point])

    def _decode_event_data(self, event_id, raw: str):
        """Decrypt (if encrypted) and JSON-decode an event's data_json; {} on failure."""
        if is_encrypted(raw):
            try:
                raw = self._decrypt_for_record(event_id, raw)
            except Exception:
                return {}
        try:
            return json.loads(raw)
        except (TypeError, json.JSONDecodeError):
            return {}

    def events_query(self, kind: str | None = None, since: datetime | None = None,
                     severity: str | None = None, limit: int = 100) -> list[dict]:
        """Return up to ``limit`` events (table=events), newest first.

        ``kind`` and ``severity`` filter server-side; ``since`` is applied
        client-side. Every matching event is scanned because scroll order
        follows point ids, not timestamps; a bounded first page could miss
        the newest events. Naive timestamps are treated as UTC.
        """
        conditions: list[FieldCondition] = []
        if kind is not None:
            conditions.append(FieldCondition(key="kind", match=MatchValue(value=kind)))
        if severity is not None:
            conditions.append(FieldCondition(key="severity", match=MatchValue(value=severity)))
        points = self._scroll_all(
            METADATA_TABLE, table_filter=EVENTS_TABLE, extra_conditions=conditions,
        )

        since_cmp = None
        if since is not None:
            since_cmp = since if since.tzinfo else since.replace(tzinfo=timezone.utc)

        out: list[dict] = []
        for pt in points:
            p = pt.payload or {}
            try:
                ts_dt = datetime.fromisoformat(p.get("ts", ""))
            except (ValueError, TypeError):
                ts_dt = datetime.now(timezone.utc)
            if ts_dt.tzinfo is None:
                # Mixing naive and aware datetimes raises TypeError on compare/sort.
                ts_dt = ts_dt.replace(tzinfo=timezone.utc)
            if since_cmp is not None and ts_dt < since_cmp:
                continue
            try:
                source_ids = json.loads(p.get("source_ids_json") or "[]")
            except (TypeError, json.JSONDecodeError):
                source_ids = []
            out.append({
                "id": p.get("id", str(pt.id)),
                "kind": p.get("kind", ""),
                "severity": p.get("severity") or None,
                "domain": p.get("domain") or None,
                "ts": ts_dt,
                "data": self._decode_event_data(pt.id, p.get("data_json") or "{}"),
                "session_id": p.get("session_id", "-"),
                "source_ids": source_ids,
            })
        out.sort(key=lambda x: x["ts"], reverse=True)
        return out[:limit]

    # --------------------------------------------------------- count_rows

    def count_rows(self, table_name: str) -> int:
        """Exact number of points whose ``table`` payload equals ``table_name``.

        ``records`` is counted in the records collection; edges/events/
        budget_ledger/ratelimit_ledger in the metadata collection. Unknown
        names and any Qdrant error yield 0.
        """
        if table_name == RECORDS_TABLE:
            collection = RECORDS_TABLE
        elif table_name in _METADATA_TABLES:
            collection = METADATA_TABLE
        else:
            return 0
        try:
            result = self._client.count(
                collection_name=collection,
                count_filter=Filter(must=[FieldCondition(
                    key="table", match=MatchValue(value=table_name),
                )]),
                exact=True,
            )
            return int(result.count)
        except Exception:
            return 0

    # ------------------------------------------------------- edges_as_df

    def edges_as_dataframe(self) -> "pd.DataFrame":
        """All edges (table=edges) as a DataFrame; empty frame on any error."""
        columns = ["src", "dst", "edge_type", "weight", "updated_at"]
        try:
            points = self._scroll_all(METADATA_TABLE, table_filter=EDGES_TABLE, batch_size=1000)
            if not points:
                return pd.DataFrame(columns=columns)
            rows = []
            for pt in points:
                p = pt.payload or {}
                rows.append({
                    "src": p.get("src", ""),
                    "dst": p.get("dst", ""),
                    "edge_type": p.get("edge_type", ""),
                    "weight": float(p.get("weight", 0.0)),
                    "updated_at": p.get("updated_at", ""),
                })
            return pd.DataFrame(rows)
        except Exception:
            return pd.DataFrame(columns=columns)

    # ------------------------------------------------------------------ db shim

    class _DbShim:
        """LanceDB-compatible shim: store.db.open_table(name) -> table-like object."""

        def __init__(self, store: "QdrantStore") -> None:
            self._store = store

        def open_table(self, name: str) -> "QdrantStore._DbShim._TableShim":
            return self._TableShim(self._store, name)

        class _TableShim:
            """Table facade supporting count_rows() and to_pandas() (edges only)."""

            def __init__(self, store: "QdrantStore", name: str) -> None:
                self._store = store
                self._name = name

            def count_rows(self) -> int:
                return self._store.count_rows(self._name)

            def to_pandas(self) -> pd.DataFrame:
                if self._name == EDGES_TABLE:
                    return self._store.edges_as_dataframe()
                # Other tables (including events) are not materialised.
                return pd.DataFrame()

    @property
    def db(self) -> "QdrantStore._DbShim":
        """LanceDB-compatible shim for store.db.open_table()."""
        return self._DbShim(self)