"""Qdrant-backed persistent memory store (D-01 storage engine, sync write). Replaces LanceDB with Qdrant for remote vector storage. Collections (payload-partitioned, per Qdrant best practices): - `records` : MemoryRecord rows (1024-dim cosine vectors) All points carry `table: "records"` + `group_id` payload. - `metadata` : Payload-only (no vectors) containing edges, events, budget_ledger, ratelimit_ledger. Each point carries `table` + `group_id` payload. Both collections use `is_tenant: true` payload indexes on `table` for co-located storage and per-table HNSW graphs (payload_m=16, m=0). Encryption-at-rest is identical to LanceDB: AES-256-GCM on literal_surface / provenance_json / profile_modulation_gain_json (records) and data_json (events). AD = record UUID bytes. The Qdrant client connects to a remote server via QDRANT_URL + QDRANT_API_KEY environment variables (overridable via constructor). """ from __future__ import annotations import asyncio import base64 import functools import json import os import re import sys import threading from datetime import datetime, timezone from pathlib import Path from typing import Callable from uuid import UUID from qdrant_client import QdrantClient, models from qdrant_client.models import ( Distance, FieldCondition, Filter, MatchValue, PayloadSchemaType, PointStruct, VectorParams, ) from qdrant_client.http.exceptions import UnexpectedResponse import pandas as pd from cryptography.hazmat.primitives.ciphers.aead import AESGCM from iai_mcp.crypto import ( CIPHERTEXT_PREFIX, NONCE_BYTES, CryptoKey, encrypt_field, is_encrypted, ) from iai_mcp.types import ( DEFAULT_EMBED_DIM, EMBED_DIM, SCHEMA_VERSION_CURRENT, MemoryRecord, TIER_ENUM, ) # --------------------------------------------------------------------------- env # Qdrant connection: override with env vars or constructor kwargs. 
QDRANT_URL = os.environ.get("QDRANT_URL")
QDRANT_API_KEY = os.environ.get("QDRANT_API_KEY")

# --------------------------------------------------------------------------- tables
RECORDS_TABLE = "records"
EDGES_TABLE = "edges"
EVENTS_TABLE = "events"
BUDGET_TABLE = "budget_ledger"
RATELIMIT_TABLE = "ratelimit_ledger"
METADATA_TABLE = "metadata"


def _positive_int(raw: str | None) -> int | None:
    """Parse *raw* as a strictly positive integer; return None if it is not one."""
    if not raw:
        return None
    try:
        value = int(raw)
    except ValueError:
        return None
    return value if value > 0 else None


# Embedding dimension for the records collection.
# Defaults to 1024 for bge-m3 (remote embedder). Overridable via the
# IAI_MCP_EMBED_DIM env var. A malformed value falls back to the default
# instead of raising at import time.
_RECORDS_DIM = _positive_int(os.environ.get("IAI_MCP_EMBED_DIM")) or 1024

# Metadata tables that live in the metadata collection.
_METADATA_TABLES = frozenset({EDGES_TABLE, EVENTS_TABLE, BUDGET_TABLE, RATELIMIT_TABLE})

# Edge type enum.
EDGE_TYPES: frozenset[str] = frozenset({
    "hebbian",
    "contradicts",
    "consolidated_from",
    "schema_instance_of",
    "temporal_next",
    "invariant_anchor",
    "curiosity_bridge",
    "profile_modulates",
    "hebbian_structure",
})

# RFC-4122 canonical UUID regex.
_UUID_STR_RE = re.compile(
    r"^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$"
)

# The only LanceDB-style ``where`` clause this store understands:
# ``tier = '<value>'``. Group 2 is the tier, group 3 any trailing text.
_TIER_WHERE_RE = re.compile(r"^\s*tier\s*=\s*(['\"])(.*?)\1\s*(.*)$", re.DOTALL)


def _log_event(event: str, **fields: object) -> None:
    """Write one JSON diagnostic line to stderr. Never raises."""
    try:
        sys.stderr.write(
            json.dumps(
                {"event": event, **fields, "ts": datetime.now(timezone.utc).isoformat()},
                default=str,
            )
            + "\n"
        )
    except Exception:
        pass


def _uuid_literal(value: UUID | str) -> str:
    """Return the lower-case canonical UUID string for *value*.

    Raises
    ------
    ValueError
        If *value* is not a canonical 8-4-4-4-12 hex UUID.
    """
    s = str(value).lower()
    # fullmatch: ``match`` + ``$`` would also accept a trailing newline.
    if not _UUID_STR_RE.fullmatch(s):
        raise ValueError(f"not a canonical UUID: {value!r}")
    return s


def _resolve_embed_dim() -> int:
    """Pick the embedding dimension for the records collection.

    Order: a valid positive IAI_MCP_EMBED_DIM, then 1024 when
    IAI_MCP_EMBED_MODEL is ``bge-m3``, then the module default.
    """
    env_dim = _positive_int(os.environ.get("IAI_MCP_EMBED_DIM"))
    if env_dim is not None:
        return env_dim
    if os.environ.get("IAI_MCP_EMBED_MODEL") == "bge-m3":
        return 1024
    return _RECORDS_DIM


def _tier_from_where(where: str | None) -> str | None:
    """Extract the tier from a ``tier = '<value>'`` clause, else None.

    Unsupported clauses are reported on stderr and ignored (the scan is then
    unfiltered, as before) instead of raising IndexError on missing quotes.
    """
    if not where:
        return None
    match = _TIER_WHERE_RE.match(where)
    if match is None:
        _log_event("unsupported_where_clause", where=where)
        return None
    if match.group(3):
        _log_event("where_clause_partially_ignored", where=where, ignored=match.group(3))
    return match.group(2)


class QdrantStore:
    """Qdrant-backed memory store.

    Mirrors the MemoryStore API so all existing callers (daemon, MCP wrapper,
    capture, retrieve, sleep, etc.) work unchanged.

    NOTE(review): every point is written with a ``group_id`` payload, but no
    read path in this class filters on it. Confirm whether reads should be
    tenant-scoped before sharing one Qdrant server between users.
    """

    def __init__(
        self,
        path: Path | str | None = None,
        user_id: str = "default",
        read_consistency_interval: float | None = None,
        url: str | None = None,
        api_key: str | None = None,
    ) -> None:
        """Open (or initialise) a Qdrant-backed store.

        Parameters
        ----------
        path:
            Accepted for signature compatibility; not used by this class.
        url:
            Qdrant server URL. Defaults to QDRANT_URL env var.
        api_key:
            Qdrant API key. Defaults to QDRANT_API_KEY env var.
        user_id:
            User ID that scopes the encryption key. Also written as the
            ``group_id`` payload of every point.
        read_consistency_interval:
            Stored but not otherwise used with Qdrant.
        """
        self._embed_dim: int = _resolve_embed_dim()
        self._user_id: str = user_id
        self._group_id: str = user_id  # payload partition key
        self._read_consistency_interval = read_consistency_interval

        # Qdrant client
        self._qdrant_url = url or QDRANT_URL
        self._qdrant_api_key = api_key or QDRANT_API_KEY
        self._client = QdrantClient(
            url=self._qdrant_url,
            api_key=self._qdrant_api_key,
            timeout=30,
        )

        # Encryption (key material is loaded lazily by _key()).
        self._crypto_key_wrapper = CryptoKey(user_id=user_id)
        self._crypto_key: bytes | None = None

        # Graph sync hook
        self._graph_sync_hook: Callable[[str, MemoryRecord], None] | None = None

        # Async write queue state
        self._write_queue = None
        self._async_loop: asyncio.AbstractEventLoop | None = None
        self._async_thread: threading.Thread | None = None
        self._provenance_queue = None

        # Ensure collections exist
        self._ensure_collections()

    # ------------------------------------------------------------------ schema

    def _ensure_collections(self) -> None:
        """Create the ``records`` and ``metadata`` collections if missing.

        - ``records``: cosine vectors of ``embed_dim`` dimensions, with
          ``payload_m=16`` and ``m=0`` in the HNSW config.
        - ``metadata``: created without a vectors config (payload-only).
        """
        if not self._collection_exists(RECORDS_TABLE):
            self._create_collection(
                RECORDS_TABLE,
                vectors_config=VectorParams(
                    size=self._embed_dim,
                    distance=Distance.COSINE,
                ),
                hnsw_config=models.HnswConfigDiff(payload_m=16, m=0),
            )
        if not self._collection_exists(METADATA_TABLE):
            self._create_collection(METADATA_TABLE)
        self._ensure_indexes()

    def _create_collection(self, name: str, **config) -> None:
        """Create one collection; a 409 (already exists) is not an error."""
        try:
            self._client.create_collection(collection_name=name, **config)
        except UnexpectedResponse as exc:
            if exc.status_code != 409:
                raise

    def _collection_exists(self, name: str) -> bool:
        """Return True when ``get_collection(name)`` succeeds.

        Raises
        ------
        RuntimeError
            On HTTP 401. Every other failure is reported as "does not exist".
        """
        try:
            return self._client.get_collection(name) is not None
        except UnexpectedResponse as exc:
            if exc.status_code == 401:
                raise RuntimeError(
                    "Qdrant 401: invalid API key. Set QDRANT_API_KEY env var."
                ) from exc
            return False
        except Exception:
            return False

    def _ensure_indexes(self) -> None:
        """Create keyword payload indexes for the fields used in filters.

        Best-effort: a failed index creation is logged to stderr and skipped.

        NOTE(review): these are plain keyword indexes. Recent qdrant-client
        releases appear to accept ``models.KeywordIndexParams(is_tenant=True)``
        as ``field_schema``; confirm client/server versions before switching.
        """
        keyword = PayloadSchemaType.KEYWORD
        wanted = [
            (RECORDS_TABLE, "table", keyword),
            (RECORDS_TABLE, "group_id", keyword),
            (METADATA_TABLE, "table", keyword),
            (METADATA_TABLE, "group_id", keyword),
            (RECORDS_TABLE, "id", keyword),
            (RECORDS_TABLE, "tier", keyword),
            (RECORDS_TABLE, "community_id", keyword),
            (METADATA_TABLE, "edge_type", keyword),
            (METADATA_TABLE, "kind", keyword),
            (METADATA_TABLE, "session_id", keyword),
        ]
        for collection_name, field_name, schema in wanted:
            try:
                self._client.create_payload_index(
                    collection_name=collection_name,
                    field_name=field_name,
                    field_schema=schema,
                )
            except Exception as exc:
                _log_event(
                    "payload_index_failed",
                    collection=collection_name,
                    field=field_name,
                    error=str(exc),
                )

    @property
    def embed_dim(self) -> int:
        """Embedding dimension this store validates writes against."""
        return self._embed_dim

    @property
    def user_id(self) -> str:
        """user_id that scopes the encryption key."""
        return self._user_id

    # ------------------------------------------------------------------ encryption

    def _key(self) -> bytes:
        """Lazy-load the encryption key."""
        if self._crypto_key is None:
            self._crypto_key = self._crypto_key_wrapper.get_or_create()
        return self._crypto_key

    def _ad(self, record_id: UUID | str) -> bytes:
        """Associated data for encryption: canonical UUID str bytes."""
        return _uuid_literal(record_id).encode("ascii")

    def _encrypt_for_record(self, record_id: UUID, value: str) -> str:
        """Encrypt a per-record field; already-encrypted values pass through."""
        if is_encrypted(value):
            return value
        return encrypt_field(value, self._key(), associated_data=self._ad(record_id))

    @functools.cached_property
    def _cached_aesgcm(self) -> AESGCM:
        """One AESGCM cipher per store lifetime."""
        return AESGCM(self._key())

    def _decrypt_for_record(self, record_id: UUID, value: str) -> str:
        """Decrypt a per-record sensitive field; pass through plaintext.

        Raises
        ------
        ValueError
            If the value lacks the ciphertext prefix or is too short to hold
            a nonce plus a 16-byte tag.
        """
        if not is_encrypted(value):
            return value
        if not value.startswith(CIPHERTEXT_PREFIX):
            raise ValueError("field is not iai:enc:v1:-prefixed ciphertext")
        blob = base64.b64decode(value[len(CIPHERTEXT_PREFIX):])
        if len(blob) < NONCE_BYTES + 16:
            raise ValueError("ciphertext payload too short")
        nonce, sealed = blob[:NONCE_BYTES], blob[NONCE_BYTES:]
        plaintext = self._cached_aesgcm.decrypt(nonce, sealed, self._ad(record_id) or None)
        return plaintext.decode("utf-8")

    # ------------------------------------------------------------------ I/O

    def register_graph_sync_hook(
        self, hook: Callable[[str, MemoryRecord], None] | None
    ) -> None:
        """Register a callback that mirrors store writes to the runtime graph."""
        self._graph_sync_hook = hook

    def _fire_graph_sync_hook(self, op: str, record: MemoryRecord) -> None:
        """Dispatch the (op, record) event. Failures are logged and swallowed."""
        hook = self._graph_sync_hook
        if hook is None:
            return
        try:
            hook(op, record)
        except Exception as exc:
            _log_event(
                "graph_sync_failed",
                op=op,
                record_id=str(getattr(record, "id", "")),
                error=str(exc),
            )

    # ------------------------------------------------------- conversions

    def _to_point(self, r: MemoryRecord) -> PointStruct:
        """Convert a MemoryRecord to a Qdrant PointStruct.

        literal_surface, provenance and profile_modulation_gain are encrypted
        with the record id as associated data before being put in the payload.
        """
        literal_ct = self._encrypt_for_record(r.id, r.literal_surface)
        provenance_ct = self._encrypt_for_record(r.id, json.dumps(r.provenance))
        gain_ct = self._encrypt_for_record(
            r.id, json.dumps(r.profile_modulation_gain or {})
        )
        return PointStruct(
            id=str(r.id),
            vector=[float(x) for x in r.embedding],
            payload={
                "table": RECORDS_TABLE,
                "group_id": self._group_id,
                "tier": r.tier,
                "literal_surface": literal_ct,
                "aaak_index": r.aaak_index or "",
                "structure_hv": base64.b64encode(
                    bytes(r.structure_hv or b"")
                ).decode("ascii"),
                "community_id": str(r.community_id) if r.community_id else "",
                "centrality": float(r.centrality),
                "detail_level": int(r.detail_level),
                "pinned": bool(r.pinned),
                "stability": float(r.stability),
                "difficulty": float(r.difficulty),
                "last_reviewed": (
                    r.last_reviewed.isoformat() if r.last_reviewed else None
                ),
                "never_decay": bool(r.never_decay),
                "never_merge": bool(r.never_merge),
                "provenance_json": provenance_ct,
                "created_at": r.created_at.isoformat() if r.created_at else None,
                "updated_at": r.updated_at.isoformat() if r.updated_at else None,
                "tags_json": json.dumps(r.tags),
                "language": str(r.language),
                "s5_trust_score": float(r.s5_trust_score),
                "profile_modulation_gain_json": gain_ct,
                "schema_version": int(r.schema_version),
            },
        )

    def _from_point(self, point: PointStruct) -> MemoryRecord:
        """Convert a retrieved/scrolled/scored Qdrant point to a MemoryRecord.

        Raises
        ------
        ValueError
            If the point carries no payload.
        """
        payload = point.payload
        if not payload:
            raise ValueError(f"point {point.id} has no payload")
        row_uuid = UUID(point.id)

        def _json_or(raw, fallback):
            # Malformed JSON degrades to the fallback instead of failing the read.
            try:
                return json.loads(raw)
            except (TypeError, json.JSONDecodeError):
                return fallback

        def _parse_ts(val):
            if val is None:
                return None
            try:
                return datetime.fromisoformat(val)
            except (TypeError, ValueError):
                return None

        # structure_hv: base64-decoded bytes
        structure_raw = payload.get("structure_hv")
        structure_hv = b""
        if structure_raw:
            try:
                structure_hv = base64.b64decode(structure_raw)
            except Exception:
                structure_hv = b""

        community_raw = payload.get("community_id") or ""
        community_id = UUID(community_raw) if community_raw else None

        version_raw = payload.get("schema_version")
        try:
            schema_version = (
                int(version_raw) if version_raw is not None else SCHEMA_VERSION_CURRENT
            )
        except (TypeError, ValueError):
            schema_version = SCHEMA_VERSION_CURRENT

        # language back-compat: schema v1 rows with an empty language keep
        # it empty; newer empty rows default to "en".
        lang_raw = payload.get("language")
        if lang_raw is None or (isinstance(lang_raw, str) and lang_raw == ""):
            language = "__LEGACY_EMPTY__" if schema_version == 1 else "en"
        else:
            language = str(lang_raw)

        s5_raw = payload.get("s5_trust_score")
        s5_trust_score = float(s5_raw) if s5_raw is not None else 0.5

        gain_raw = payload.get("profile_modulation_gain_json") or "{}"
        if is_encrypted(gain_raw):
            gain_raw = self._decrypt_for_record(row_uuid, gain_raw)
        profile_modulation_gain = _json_or(gain_raw, {}) or {}

        literal_raw = payload.get("literal_surface", "")
        if is_encrypted(literal_raw):
            literal_raw = self._decrypt_for_record(row_uuid, literal_raw)

        provenance_raw = payload.get("provenance_json") or "[]"
        if is_encrypted(provenance_raw):
            provenance_raw = self._decrypt_for_record(row_uuid, provenance_raw)
        provenance_list = _json_or(provenance_raw, []) if provenance_raw else []

        # Explicit None check so a stored detail_level of 0 is preserved.
        detail_raw = payload.get("detail_level")
        detail_level = int(detail_raw) if detail_raw is not None else 1

        rec = MemoryRecord(
            id=row_uuid,
            tier=payload.get("tier", "episodic"),
            literal_surface=literal_raw,
            aaak_index=payload.get("aaak_index") or "",
            embedding=list(point.vector) if point.vector else [],
            community_id=community_id,
            centrality=float(payload.get("centrality", 0.0) or 0.0),
            detail_level=detail_level,
            pinned=bool(payload.get("pinned", False)),
            stability=float(payload.get("stability") or 0.0),
            difficulty=float(payload.get("difficulty") or 0.0),
            last_reviewed=_parse_ts(payload.get("last_reviewed")),
            never_decay=bool(payload.get("never_decay", False)),
            never_merge=bool(payload.get("never_merge", False)),
            provenance=provenance_list,
            created_at=_parse_ts(payload.get("created_at")) or datetime.now(timezone.utc),
            updated_at=_parse_ts(payload.get("updated_at")) or datetime.now(timezone.utc),
            tags=_json_or(payload.get("tags_json") or "[]", []),
            language=language,
            s5_trust_score=s5_trust_score,
            profile_modulation_gain=profile_modulation_gain,
            schema_version=schema_version,
            structure_hv=structure_hv,
        )
        # The sentinel is swapped back to "" only after construction.
        if language == "__LEGACY_EMPTY__":
            rec.language = ""
        return rec

    # ------------------------------------------------------- point operations

    def _make_filter(self, id_str: str) -> Filter:
        """Create a filter matching a single point by its id.

        Uses a has-id condition: the payload written by ``_to_point`` has no
        ``id`` key, so a payload match on ``"id"`` could never hit.
        """
        return Filter(must=[models.HasIdCondition(has_id=[id_str])])

    def _make_tier_filter(self, tier: str) -> Filter:
        """Create a filter for a tier value."""
        return Filter(must=[FieldCondition(key="tier", match=MatchValue(value=tier))])

    def _make_combined_filter(self, *conditions: FieldCondition) -> Filter:
        """Combine multiple field conditions with AND."""
        return Filter(must=list(conditions))

    def _records_filter(self, where: str | None = None, tier: str | None = None) -> Filter:
        """Filter for table=records, optionally narrowed to one tier."""
        must = [FieldCondition(key="table", match=MatchValue(value=RECORDS_TABLE))]
        wanted_tier = tier if tier is not None else _tier_from_where(where)
        if wanted_tier is not None:
            must.append(FieldCondition(key="tier", match=MatchValue(value=wanted_tier)))
        return Filter(must=must)

    def _iter_points(
        self,
        collection_name: str,
        scroll_filter: Filter | None,
        *,
        batch_size: int,
        with_vectors: bool = False,
    ):
        """Yield every point matching *scroll_filter*, one scroll page at a time."""
        offset = None
        while True:
            points, offset = self._client.scroll(
                collection_name=collection_name,
                limit=batch_size,
                offset=offset,
                scroll_filter=scroll_filter,
                with_payload=True,
                with_vectors=with_vectors,
            )
            yield from points
            if offset is None:
                break

    # ------------------------------------------------------- writes

    def _check_embedding(self, record: MemoryRecord) -> None:
        """Raise ValueError when the embedding length differs from embed_dim."""
        if len(record.embedding) != self._embed_dim:
            raise ValueError(
                f"embedding must be {self._embed_dim}d, got {len(record.embedding)}"
            )

    def insert(self, record: MemoryRecord) -> None:
        """Upsert a record verbatim and fire the ``insert`` graph hook.

        Raises
        ------
        ValueError
            On an unknown tier or a wrong embedding dimension.
        """
        if record.tier not in TIER_ENUM:
            raise ValueError(f"invalid tier {record.tier!r}")
        self._check_embedding(record)
        # lazy structure_hv fill
        if not record.structure_hv:
            from iai_mcp.tem import bind_structure
            record.structure_hv = bind_structure(record)
        self._client.upsert(
            collection_name=RECORDS_TABLE,
            points=[self._to_point(record)],
        )
        self._fire_graph_sync_hook("insert", record)

    def update(self, record: MemoryRecord) -> None:
        """Full-record rewrite; a no-op when the record id does not exist.

        Raises
        ------
        ValueError
            On an unknown tier (same rule as ``insert``) or a wrong
            embedding dimension.
        """
        if record.tier not in TIER_ENUM:
            raise ValueError(f"invalid tier {record.tier!r}")
        self._check_embedding(record)
        # Existence probe only: no payload or vector needed.
        found = self._client.retrieve(
            collection_name=RECORDS_TABLE,
            ids=[str(record.id)],
            with_payload=False,
            with_vectors=False,
        )
        if not found:
            return
        self._client.upsert(
            collection_name=RECORDS_TABLE,
            points=[self._to_point(record)],
        )
        self._fire_graph_sync_hook("update", record)

    class _DeleteShim:
        """Minimal stand-in passed to the graph hook on delete (only ``id``)."""

        def __init__(self, rid) -> None:
            self.id = rid

    def delete(self, record_id: UUID) -> None:
        """Remove a record by id (best-effort); the hook fires only on success."""
        try:
            self._client.delete(
                collection_name=RECORDS_TABLE,
                points_selector=models.PointIdsList(points=[str(record_id)]),
            )
        except Exception as exc:
            _log_event("record_delete_failed", record_id=str(record_id), error=str(exc))
            return
        self._fire_graph_sync_hook("delete", self._DeleteShim(record_id))

    def get(self, record_id: UUID) -> MemoryRecord | None:
        """Point read by ID, including the vector so the record round-trips."""
        points = self._client.retrieve(
            collection_name=RECORDS_TABLE,
            ids=[str(record_id)],
            with_payload=True,
            with_vectors=True,
        )
        if not points:
            return None
        return self._from_point(points[0])

    def all_records(self) -> list[MemoryRecord]:
        """Full table scan of the records collection."""
        return list(self.iter_records(batch_size=1000))

    def iter_records(
        self,
        *,
        columns: list[str] | None = None,
        batch_size: int = 1024,
        where: str | None = None,
    ):
        """Streaming iterator over records (filtered by table=records).

        Only ``where="tier = '<tier>'"`` is understood. ``columns`` is
        accepted for signature compatibility and ignored.
        """
        scroll_filter = self._records_filter(where)  # built once, not per page
        for point in self._iter_points(
            RECORDS_TABLE, scroll_filter, batch_size=batch_size, with_vectors=True
        ):
            yield self._from_point(point)

    def iter_record_columns(
        self,
        columns: list[str],
        *,
        batch_size: int = 1024,
        where: str | None = None,
    ):
        """Projection-only iteration; no MemoryRecord, no decrypt.

        Yields dicts holding the requested payload keys plus ``id``.

        Raises
        ------
        ValueError
            If *columns* is empty (raised on first iteration).
        """
        if not columns:
            raise ValueError("iter_record_columns requires a non-empty columns list")
        wanted = set(columns)
        scroll_filter = self._records_filter(where)
        for point in self._iter_points(RECORDS_TABLE, scroll_filter, batch_size=batch_size):
            row = {k: v for k, v in (point.payload or {}).items() if k in wanted}
            row["id"] = point.id
            yield row

    def query_similar(
        self,
        vec: list[float],
        k: int = 10,
        tier: str | None = None,
    ) -> list[tuple[MemoryRecord, float]]:
        """kNN search on the records collection.

        Returns ``(record, score)`` pairs; the score is Qdrant's similarity
        score, not a distance. Records come back without their vector
        (``with_vectors=False``), so ``record.embedding`` is empty.

        Raises
        ------
        ValueError
            If *tier* is given and not in TIER_ENUM.
        """
        if tier is not None and tier not in TIER_ENUM:
            raise ValueError(
                f"invalid tier {tier!r}; must be one of {sorted(TIER_ENUM)}"
            )
        if k <= 0:
            return []
        # Empty or unreachable collection short-circuits to no results.
        try:
            if self._client.get_collection(RECORDS_TABLE).points_count == 0:
                return []
        except Exception:
            return []
        hits = self._client.query_points(
            collection_name=RECORDS_TABLE,
            query=vec,
            limit=k,
            query_filter=self._records_filter(tier=tier),
            with_payload=True,
            with_vectors=False,
        ).points
        return [
            (self._from_point(hit), float(hit.score) if hit.score else 0.0)
            for hit in hits
        ]

    def update_record(self, record: MemoryRecord) -> None:
        """Persist the FSRS columns (stability, difficulty, last_reviewed).

        A no-op when the record id does not exist.
        """
        record_id = str(record.id)
        found = self._client.retrieve(
            collection_name=RECORDS_TABLE,
            ids=[record_id],
            with_payload=False,
            with_vectors=False,
        )
        if not found:
            return
        # ``points`` is the selector argument of set_payload.
        self._client.set_payload(
            collection_name=RECORDS_TABLE,
            payload={
                "stability": float(record.stability),
                "difficulty": float(record.difficulty),
                "last_reviewed": (
                    record.last_reviewed.isoformat() if record.last_reviewed else None
                ),
                "updated_at": datetime.now(timezone.utc).isoformat(),
            },
            points=[record_id],
        )

    # -------------------------------------------------------- reconsolidation

    @staticmethod
    def _provenance_list(raw: str) -> list:
        """Parse stored provenance JSON; anything unparseable becomes []."""
        try:
            parsed = json.loads(raw)
        except (TypeError, json.JSONDecodeError):
            return []
        return parsed if isinstance(parsed, list) else []

    def _write_provenance(self, record_id: str, ciphertext: str, updated_at: str) -> None:
        """Patch provenance_json and updated_at on one record."""
        self._client.set_payload(
            collection_name=RECORDS_TABLE,
            payload={"provenance_json": ciphertext, "updated_at": updated_at},
            points=[record_id],
        )

    def append_provenance(self, record_id: UUID, entry: dict) -> None:
        """Append one provenance entry; a no-op when the id does not exist.

        Decryption errors propagate so stored history is never overwritten.
        """
        points = self._client.retrieve(
            collection_name=RECORDS_TABLE,
            ids=[str(record_id)],
            with_payload=True,
        )
        if not points:
            return
        raw = (points[0].payload or {}).get("provenance_json") or "[]"
        if is_encrypted(raw):
            raw = self._decrypt_for_record(record_id, raw)
        history = self._provenance_list(raw)
        history.append(entry)
        self._write_provenance(
            str(record_id),
            self._encrypt_for_record(record_id, json.dumps(history)),
            datetime.now(timezone.utc).isoformat(),
        )

    def append_provenance_batch(
        self,
        pairs: "list[tuple[UUID, dict]]",
        records_cache: "dict | None" = None,
    ) -> None:
        """Batched provenance append (best-effort per record).

        With *records_cache*, existing provenance is taken from the cached
        records (looked up by UUID, then by canonical string); otherwise the
        affected points are fetched in one call. Records with a non-canonical
        id, missing from the cache/store, or whose stored provenance cannot be
        decrypted are skipped and logged, never overwritten.
        """
        if not pairs:
            return
        grouped: dict[str, list[dict]] = {}
        for rid, entry in pairs:
            try:
                canonical = _uuid_literal(rid)
            except ValueError:
                _log_event("provenance_bad_id", record_id=str(rid))
                continue
            grouped.setdefault(canonical, []).append(entry)
        if not grouped:
            return

        updates: dict[str, str] = {}
        if records_cache is not None:
            for rid_str, entries in grouped.items():
                rec = records_cache.get(UUID(rid_str))
                if rec is None:
                    rec = records_cache.get(rid_str)
                if rec is None:
                    continue
                history = list(rec.provenance or [])
                history.extend(entries)
                updates[rid_str] = self._encrypt_for_record(
                    UUID(rid_str), json.dumps(history)
                )
        else:
            try:
                points = self._client.retrieve(
                    collection_name=RECORDS_TABLE,
                    ids=list(grouped),
                    with_payload=True,
                )
            except Exception as exc:
                _log_event("provenance_fetch_failed", error=str(exc))
                return
            for point in points:
                rid_str = str(point.id)
                entries = grouped.get(rid_str)
                if entries is None:
                    continue
                raw = (point.payload or {}).get("provenance_json") or "[]"
                if is_encrypted(raw):
                    try:
                        raw = self._decrypt_for_record(UUID(rid_str), raw)
                    except Exception as exc:
                        # Treating this as "[]" would erase history we cannot read.
                        _log_event(
                            "provenance_decrypt_failed",
                            record_id=rid_str,
                            error=str(exc),
                        )
                        continue
                history = self._provenance_list(raw)
                history.extend(entries)
                updates[rid_str] = self._encrypt_for_record(
                    UUID(rid_str), json.dumps(history)
                )

        if not updates:
            return
        now = datetime.now(timezone.utc).isoformat()
        # No bulk per-point payload patch here, so one set_payload per record.
        for rid_str, ciphertext in updates.items():
            try:
                self._write_provenance(rid_str, ciphertext, now)
            except Exception as exc:
                _log_event("provenance_write_failed", record_id=rid_str, error=str(exc))

    # ------------------------------------------------------------------ edges

    def _edge_point_id(self, src: str, dst: str, edge_type: str) -> str:
        """Deterministic UUID for an edge.

        Qdrant point ids must be unsigned integers or UUIDs, so the composite
        ``group:src:dst:type`` key is hashed into a UUIDv5.
        """
        from uuid import NAMESPACE_URL, uuid5

        return str(
            uuid5(NAMESPACE_URL, f"iai-mcp:edge:{self._group_id}:{src}:{dst}:{edge_type}")
        )

    def _edge_point(
        self, src: str, dst: str, edge_type: str, weight: float, updated_at: str
    ) -> PointStruct:
        """Build the payload-only point for one edge (empty vector mapping)."""
        return PointStruct(
            id=self._edge_point_id(src, dst, edge_type),
            vector={},
            payload={
                "table": EDGES_TABLE,
                "group_id": self._group_id,
                "src": src,
                "dst": dst,
                "edge_type": edge_type,
                "weight": weight,
                "updated_at": updated_at,
            },
        )

    def boost_edges(
        self,
        pairs: list[tuple[UUID, UUID]],
        delta: float | list[float] = 0.1,
        edge_type: str = "hebbian",
    ) -> dict[tuple[str, str], float]:
        """Pairwise edge boost in the metadata collection (table=edges).

        Pairs are treated as undirected: each is stored under its sorted
        ``(src, dst)`` key and duplicates within one call are summed. Returns
        the new weight per sorted key.

        Raises
        ------
        ValueError
            On an unknown *edge_type* or when a delta list does not match
            ``len(pairs)``.
        """
        if edge_type not in EDGE_TYPES:
            raise ValueError(
                f"invalid edge_type {edge_type!r}; must be one of {sorted(EDGE_TYPES)}"
            )
        if isinstance(delta, (int, float)):
            deltas = [float(delta)] * len(pairs)
        else:
            deltas = [float(d) for d in delta]
        if len(deltas) != len(pairs):
            raise ValueError(
                f"deltas length {len(deltas)} != pairs length {len(pairs)}"
            )
        if not pairs:
            return {}

        # Coalesce duplicate canonical keys.
        coalesced: dict[tuple[str, str], float] = {}
        for (a, b), d in zip(pairs, deltas):
            canonical = tuple(sorted((str(a), str(b))))
            coalesced[canonical] = coalesced.get(canonical, 0.0) + d

        # Fetch only the edges touched by this call, by deterministic id.
        id_by_key = {
            key: self._edge_point_id(key[0], key[1], edge_type) for key in coalesced
        }
        existing = self._client.retrieve(
            collection_name=METADATA_TABLE,
            ids=list(id_by_key.values()),
            with_payload=True,
        )
        current = {
            str(pt.id): float((pt.payload or {}).get("weight", 0.0) or 0.0)
            for pt in existing
        }

        now = datetime.now(timezone.utc).isoformat()
        new_weights: dict[tuple[str, str], float] = {}
        points_to_upsert: list[PointStruct] = []
        for key, accum_delta in coalesced.items():
            weight = current.get(id_by_key[key], 0.0) + accum_delta
            new_weights[key] = weight
            points_to_upsert.append(
                self._edge_point(key[0], key[1], edge_type, weight, now)
            )
        self._client.upsert(
            collection_name=METADATA_TABLE,
            points=points_to_upsert,
        )
        return new_weights

    def reinforce_record(
        self,
        record_id: UUID,
        anchor_id: UUID | None = None,
        edge_type: str = "hebbian",
        delta: float = 0.1,
    ) -> dict[tuple[str, str], float]:
        """Single-record reinforcement; without an anchor, a self-edge is boosted."""
        pair = (record_id if anchor_id is None else anchor_id, record_id)
        return self.boost_edges([pair], delta=delta, edge_type=edge_type)

    def add_contradicts_edge(self, original: UUID, new_id: UUID) -> None:
        """Add a directed ``contradicts`` edge (original -> new_id), weight 1.0."""
        self._client.upsert(
            collection_name=METADATA_TABLE,
            points=[
                self._edge_point(
                    str(original),
                    str(new_id),
                    "contradicts",
                    1.0,
                    datetime.now(timezone.utc).isoformat(),
                )
            ],
        )

    # ------------------------------------------------------- async writes

    async def enable_async_writes(
        self,
        coalesce_ms: int = 100,
        max_batch: int = 128,
        max_queue_size: int = 4096,
    ) -> None:
        """Start the background AsyncWriteQueue and the provenance queue.

        Spawns a daemon thread running its own event loop, starts the queue on
        it, and blocks until the queue has started. Idempotent.

        NOTE(review): ``insert()`` in this class still writes synchronously;
        nothing visible here enqueues onto the write queue. Confirm intent.
        """
        if self._write_queue is not None:
            return
        from iai_mcp.write_queue import AsyncWriteQueue

        ready = threading.Event()
        loop_holder: dict = {}

        def _run() -> None:
            loop = asyncio.new_event_loop()
            loop_holder["loop"] = loop
            asyncio.set_event_loop(loop)
            ready.set()
            try:
                loop.run_forever()
            finally:
                loop.close()

        thread = threading.Thread(
            target=_run,
            name="iai-mcp-async-writes",
            daemon=True,
        )
        thread.start()
        ready.wait()
        bg_loop: asyncio.AbstractEventLoop = loop_holder["loop"]

        client = self._client
        to_point = self._to_point
        fire_hook = self._fire_graph_sync_hook

        class _RecordPointAdapter:
            """Sink handed to AsyncWriteQueue: upserts one batch of records."""

            async def add(self, records: list) -> None:
                client.upsert(RECORDS_TABLE, points=[to_point(r) for r in records])

        def _on_flushed(batch: list) -> None:
            for rec in batch:
                fire_hook("insert", rec)

        queue = AsyncWriteQueue(
            _RecordPointAdapter(),
            coalesce_ms=coalesce_ms,
            max_batch=max_batch,
            max_queue_size=max_queue_size,
            on_flushed=_on_flushed,
        )
        try:
            asyncio.run_coroutine_threadsafe(queue.start(), bg_loop).result()
        except BaseException:
            # Do not leak the loop thread when the queue fails to start.
            bg_loop.call_soon_threadsafe(bg_loop.stop)
            thread.join(timeout=5.0)
            raise

        self._async_loop = bg_loop
        self._async_thread = thread
        self._write_queue = queue
        self.enable_provenance_queue()

    async def disable_async_writes(self) -> None:
        """Drain the queue, tear down the background loop."""
        self.disable_provenance_queue()
        queue = self._write_queue
        if queue is None:
            return
        bg_loop = self._async_loop
        try:
            if bg_loop is not None:
                asyncio.run_coroutine_threadsafe(queue.stop(), bg_loop).result()
        finally:
            if bg_loop is not None:
                bg_loop.call_soon_threadsafe(bg_loop.stop)
            if self._async_thread is not None:
                self._async_thread.join(timeout=5.0)
            self._write_queue = None
            self._async_loop = None
            self._async_thread = None

    # -------------------------------------------------- provenance queue

    def enable_provenance_queue(self, *, coalesce_ms: int = 50) -> None:
        """Start a ProvenanceWriteQueue for provenance writes. Idempotent."""
        if self._provenance_queue is not None:
            return
        from iai_mcp.provenance_queue import ProvenanceWriteQueue

        q = ProvenanceWriteQueue(self, coalesce_ms=coalesce_ms)
        q.start()
        self._provenance_queue = q

    def disable_provenance_queue(self) -> None:
        """Flush (2 s timeout) and stop the provenance queue; errors are logged."""
        q = self._provenance_queue
        if q is None:
            return
        try:
            q.flush(timeout=2.0)
        except Exception as exc:
            _log_event("provenance_queue_flush_failed", error=str(exc))
        try:
            q.stop()
        except Exception as exc:
            _log_event("provenance_queue_stop_failed", error=str(exc))
        self._provenance_queue = None

    def queue_provenance_batch(
        self, pairs: "list[tuple[UUID, dict]]"
    ) -> None:
        """Enqueue a provenance write, or write synchronously when no queue runs."""
        if not pairs:
            return
        q = self._provenance_queue
        if q is not None:
            q.enqueue(pairs)
            return
        self.append_provenance_batch(pairs, records_cache=None)

    # ------------------------------------------------------- helpers

    def _scroll_all(
        self,
        collection_name: str,
        batch_size: int = 1000,
        table_filter: str | None = None,
        conditions: list | None = None,
    ) -> list:
        """Collect every matching point of a collection (payload, no vectors).

        *table_filter* adds a must-condition on the ``table`` payload field;
        *conditions* are extra must-conditions ANDed with it.
        """
        must = list(conditions or [])
        if table_filter:
            must.insert(
                0,
                FieldCondition(key="table", match=MatchValue(value=table_filter)),
            )
        scroll_filter = Filter(must=must) if must else None
        return list(
            self._iter_points(collection_name, scroll_filter, batch_size=batch_size)
        )

    # ---------------------------------------------------------- events

    def events_add(self, event_id: UUID, kind: str, severity: str, domain: str,
                   ts: datetime, data_json: str, session_id: str,
                   source_ids_json: str) -> None:
        """Add a single event row to the metadata collection (table=events).

        ``data_json`` is stored exactly as given.
        """
        point = PointStruct(
            id=str(event_id),
            vector={},
            payload={
                "table": EVENTS_TABLE,
                "group_id": self._group_id,
                "kind": kind,
                "severity": severity,
                "domain": domain,
                "ts": ts.isoformat(),
                "data_json": data_json,
                "session_id": session_id,
                "source_ids_json": source_ids_json,
            },
        )
        self._client.upsert(collection_name=METADATA_TABLE, points=[point])

    def events_query(self, kind: str | None = None, since: datetime | None = None,
                     severity: str | None = None, limit: int = 100) -> list[dict]:
        """Query events from the metadata collection, newest first.

        Every matching event is scanned before sorting: scroll order is not
        time order, so a capped page could miss the newest rows. Naive
        timestamps (stored or *since*) are interpreted as UTC; an unparseable
        stored timestamp is replaced by the current time.
        """
        if limit <= 0:
            return []
        extra = []
        if kind is not None:
            extra.append(FieldCondition(key="kind", match=MatchValue(value=kind)))
        if severity is not None:
            extra.append(FieldCondition(key="severity", match=MatchValue(value=severity)))
        points = self._scroll_all(
            METADATA_TABLE, table_filter=EVENTS_TABLE, conditions=extra
        )

        since_cmp = None
        if since is not None:
            since_cmp = since if since.tzinfo else since.replace(tzinfo=timezone.utc)

        out: list[dict] = []
        for pt in points:
            p = pt.payload or {}
            try:
                ts_dt = datetime.fromisoformat(p.get("ts", ""))
            except (ValueError, TypeError):
                ts_dt = datetime.now(timezone.utc)
            if ts_dt.tzinfo is None:
                # Keep comparisons and the final sort aware-vs-aware.
                ts_dt = ts_dt.replace(tzinfo=timezone.utc)
            if since_cmp is not None and ts_dt < since_cmp:
                continue

            raw_data = p.get("data_json") or "{}"
            if is_encrypted(raw_data):
                # NOTE(review): assumes event ciphertext uses the event id as
                # associated data, as records do -- confirm against the writer.
                try:
                    raw_data = self._decrypt_for_record(str(pt.id), raw_data)
                except Exception:
                    raw_data = "{}"
            try:
                data = json.loads(raw_data)
            except (TypeError, json.JSONDecodeError):
                data = {}
            try:
                source_ids = json.loads(p.get("source_ids_json") or "[]")
            except (TypeError, json.JSONDecodeError):
                source_ids = []

            out.append({
                "id": p.get("id", str(pt.id)),
                "kind": p.get("kind", ""),
                "severity": p.get("severity") or None,
                "domain": p.get("domain") or None,
                "ts": ts_dt,
                "data": data,
                "session_id": p.get("session_id", "-"),
                "source_ids": source_ids,
            })
        out.sort(key=lambda row: row["ts"], reverse=True)
        return out[:limit]

    # --------------------------------------------------------- count_rows

    def count_rows(self, table_name: str) -> int:
        """Return the number of points in a logical table.

        - ``records``: ``points_count`` of the records collection.
        - edges / events / budget_ledger / ratelimit_ledger: exact count of
          metadata points whose ``table`` payload matches.
        - anything else, or any client error: 0.
        """
        if table_name == RECORDS_TABLE:
            try:
                return self._client.get_collection(RECORDS_TABLE).points_count or 0
            except Exception:
                return 0
        if table_name in _METADATA_TABLES:
            try:
                # Count server-side instead of downloading every point.
                result = self._client.count(
                    collection_name=METADATA_TABLE,
                    count_filter=Filter(must=[FieldCondition(
                        key="table", match=MatchValue(value=table_name),
                    )]),
                    exact=True,
                )
                return int(result.count)
            except Exception:
                return 0
        return 0

    # ------------------------------------------------------- edges_as_df

    def edges_as_dataframe(self) -> "pd.DataFrame":
        """Return all edges as a DataFrame; empty (same columns) on error."""
        columns = ["src", "dst", "edge_type", "weight", "updated_at"]
        try:
            points = self._scroll_all(
                METADATA_TABLE, table_filter=EDGES_TABLE, batch_size=1000
            )
            if not points:
                return pd.DataFrame(columns=columns)
            rows = []
            for pt in points:
                p = pt.payload or {}
                rows.append({
                    "src": p.get("src", ""),
                    "dst": p.get("dst", ""),
                    "edge_type": p.get("edge_type", ""),
                    "weight": float(p.get("weight", 0.0)),
                    "updated_at": p.get("updated_at", ""),
                })
            return pd.DataFrame(rows)
        except Exception as exc:
            _log_event("edges_as_dataframe_failed", error=str(exc))
            return pd.DataFrame(columns=columns)

    # ------------------------------------------------------------------ db shim

    class _DbShim:
        """LanceDB-compatible shim: store.db.open_table(name) -> table-like object."""

        def __init__(self, store: "QdrantStore") -> None:
            self._store = store

        def open_table(self, name: str) -> "QdrantStore._TableShim":
            # _TableShim is an attribute of QdrantStore, not of this shim.
            return QdrantStore._TableShim(self._store, name)

    class _TableShim:
        """Table-like view exposing ``count_rows()`` and ``to_pandas()``."""

        def __init__(self, store: "QdrantStore", name: str) -> None:
            self._store = store
            self._name = name

        def count_rows(self) -> int:
            return self._store.count_rows(self._name)

        def to_pandas(self) -> pd.DataFrame:
            # Only the edges table is materialised; others yield an empty frame.
            if self._name == EDGES_TABLE:
                return self._store.edges_as_dataframe()
            return pd.DataFrame()

    @property
    def db(self) -> "QdrantStore._DbShim":
        """LanceDB-compatible shim for store.db.open_table()."""
        return self._DbShim(self)