diff --git a/deploy/systemd/iai-mcp-daemon.service b/deploy/systemd/iai-mcp-daemon.service index a307181..31572ce 100644 --- a/deploy/systemd/iai-mcp-daemon.service +++ b/deploy/systemd/iai-mcp-daemon.service @@ -16,13 +16,16 @@ After=default.target [Service] Type=simple -ExecStart=/usr/bin/python3 -m iai_mcp.daemon +ExecStart=/home/andreas/.venv/iai-mcp/bin/python -m iai_mcp.daemon Restart=on-failure RestartSec=30 StartLimitIntervalSec=60 StartLimitBurst=3 Environment="IAI_MCP_STORE=%h/.iai-mcp" +Environment="QDRANT_URL=http://192.168.0.22:6333" +Environment="QDRANT_API_KEY=1CerixWX$3zdlj" +Environment="IAI_MCP_EMBED_MODEL=bge-m3" Environment="LANG=en_US.UTF-8" StandardOutput=journal diff --git a/mcp-wrapper/src/tools.ts b/mcp-wrapper/src/tools.ts index 9fe5aad..a43ed71 100644 --- a/mcp-wrapper/src/tools.ts +++ b/mcp-wrapper/src/tools.ts @@ -250,8 +250,18 @@ export const toolSchemas: Record = { properties: { kind: { type: "string", + enum: [ + "s4_contradiction", + "trajectory_metric", + "schema_induction_run", + "llm_health", + "curiosity_silent_log", + "curiosity_question", + "cls_consolidation_run", + "crypto_key_rotated", + ], description: - "Event kind. Must be in the whitelist (see tool description).", + "Event kind — must be one of the enum values above.", }, since: { type: "string", diff --git a/scripts/install.sh b/scripts/install.sh index eec85ca..54b2969 100755 --- a/scripts/install.sh +++ b/scripts/install.sh @@ -31,13 +31,18 @@ die() { printf '\n\033[0;31m✗ %s\033[0m\n' "$*" >&2; exit 1; } # IAI_TEST_SKIP_BUILD=1 short-circuits the whole bootstrap so the LaunchAgent # section (6) can be exercised in isolation by tests/test_install_uninstall.py # (Plan 07.1-03 Task 3) without spending ~30s on venv + npm. +# +# VENV_BASE: all venvs live in ~/.venv/ for central backup. # --------------------------------------------------------------------------- +VENV_BASE="${HOME}/.venv" +VENV_PATH="${VENV_BASE}/iai-mcp" + if [[ "${IAI_TEST_SKIP_BUILD:-0}" == "1" ]]; then step "build skip (IAI_TEST_SKIP_BUILD=1)" ok "skipping sections 1-4 (venv/pip/npm/symlink) — test mode" else # ----------------------------------------------------------------------- - # 1. venv + # 1. venv (central location: ~/.venv/iai-mcp) # ----------------------------------------------------------------------- step "python venv" # iai-mcp requires Python 3.11 or 3.12 (torch + lancedb on 3.13/3.14 @@ -60,19 +65,19 @@ else fi [ -n "$PY" ] || die "Python 3.11 or 3.12 not found. macOS: brew install python@3.12 | Linux: apt install python3.12 (or use pyenv)" ok "using $PY ($($PY --version))" - if [ ! -d .venv ]; then - "$PY" -m venv .venv - ok ".venv created" + if [ ! -d "${VENV_PATH}" ]; then + "$PY" -m venv "${VENV_PATH}" + ok "venv created at ${VENV_PATH}" else - ok ".venv already exists" + ok "venv already exists at ${VENV_PATH}" fi # ----------------------------------------------------------------------- # 2. editable install # ----------------------------------------------------------------------- step "editable install (pip -e .)" - .venv/bin/pip install --quiet --upgrade pip - .venv/bin/pip install --quiet -e . + "${VENV_PATH}"/bin/pip install --quiet --upgrade pip + "${VENV_PATH}"/bin/pip install --quiet -e . ok "iai-mcp python package installed into venv" # ----------------------------------------------------------------------- @@ -99,7 +104,7 @@ else step "global CLI symlink" LOCAL_BIN="${HOME}/.local/bin" LINK_PATH="${LOCAL_BIN}/iai-mcp" - TARGET="${REPO_ROOT}/.venv/bin/iai-mcp" + TARGET="${VENV_PATH}/bin/iai-mcp" [ -x "${TARGET}" ] || die "venv entry point not found at ${TARGET}" @@ -114,7 +119,7 @@ else ok "${LINK_PATH} -> ${TARGET}" # PATH sanity check using python (grep is hook-blocked in this dev env). - PATH_HAS_LOCAL_BIN="$(.venv/bin/python - < Removing old venv" +rm -rf "${VENV_PATH}" + +echo "==> Creating fresh venv" +/usr/bin/python3.12 -m venv "${VENV_PATH}" + +echo "==> Upgrading pip" +"${VENV_PATH}"/bin/pip install --quiet --upgrade pip + +echo "==> Installing torch CPU-only (no AVX needed)" +"${VENV_PATH}"/bin/pip install --quiet \ + --extra-index-url https://download.pytorch.org/whl/cpu \ + torch torchvision torchaudio + +echo "==> Installing iai-mcp dependencies" +"${VENV_PATH}"/bin/pip install --quiet -e . + +echo "==> Building TS wrapper" +if [ -d mcp-wrapper ]; then + pushd mcp-wrapper >/dev/null + npm ci --silent --no-audit --no-fund + npm run build --silent + popd >/dev/null + echo " ✓ mcp-wrapper/dist built" +else + echo " ! mcp-wrapper/ missing" +fi + +echo "==> Done" +echo " venv: ${VENV_PATH}" +echo " test: ${VENV_PATH}/bin/python -c 'import torch; print(torch.__version__)'" diff --git a/src/iai_mcp/core.py b/src/iai_mcp/core.py index 27d0819..ebaa714 100644 --- a/src/iai_mcp/core.py +++ b/src/iai_mcp/core.py @@ -222,7 +222,7 @@ def dispatch(store: MemoryStore, method: str, params: dict) -> dict: # Plan 02 dispatch: non-empty store -> 5-stage pipeline; # empty store -> baseline cosine recall (Plan 01 fallback). - records_count = store.db.open_table("records").count_rows() + records_count = store.count_rows("records") if records_count == 0: cue_embedding = params.get("cue_embedding") or [0.0] * EMBED_DIM resp = retrieve.recall( @@ -675,7 +675,7 @@ def dispatch(store: MemoryStore, method: str, params: dict) -> dict: if method == "topology": from iai_mcp import sigma as sigma_mod - records_count = store.db.open_table("records").count_rows() + records_count = store.count_rows("records") if records_count == 0: return { "N": 0, "C": 0.0, "L": 0.0, "sigma": None, @@ -741,7 +741,7 @@ def dispatch(store: MemoryStore, method: str, params: dict) -> dict: # wake_depth knob reaches the assembler. from iai_mcp.session import assemble_session_start, SessionStartPayload sid = params.get("session_id", "-") - records_count = store.db.open_table("records").count_rows() + records_count = store.count_rows("records") if records_count == 0: empty = SessionStartPayload( l0="", @@ -810,7 +810,7 @@ def _schema_list_dispatch(store: MemoryStore, params: dict) -> dict: records = store.all_records() schema_records = [r for r in records if "schema" in (r.tags or [])] - edges_df = store.db.open_table("edges").to_pandas() + edges_df = store.edges_as_dataframe() if not edges_df.empty: schema_edges = edges_df[edges_df["edge_type"] == "schema_instance_of"] else: diff --git a/src/iai_mcp/daemon.py b/src/iai_mcp/daemon.py index c7d3121..422868c 100644 --- a/src/iai_mcp/daemon.py +++ b/src/iai_mcp/daemon.py @@ -51,7 +51,7 @@ from iai_mcp.quiet_window import ( should_relearn, ) from iai_mcp.socket_server import SocketServer -from iai_mcp.store import MemoryStore +from iai_mcp.store import MemoryStore, get_store, _use_qdrant from iai_mcp.tz import load_user_tz # --------------------------------------------------------------------------- @@ -1076,7 +1076,7 @@ async def main() -> int: # consistency — each read re-checks the latest committed version at # negligible cost (one manifest stat per query) and restores the # tick body's ability to see work. - store = MemoryStore(read_consistency_interval=timedelta(seconds=0)) + store = get_store(read_consistency_interval=timedelta(seconds=0)) try: from iai_mcp.crypto_key_watch import check_crypto_key_file_rotation_event @@ -1098,7 +1098,10 @@ async def main() -> int: # - partial_swap_inconsistent -> STOP daemon; surface remediation prompt # (manual recovery; no rollback anchor). from iai_mcp.migrate import detect_partial_migration - _migration_state = detect_partial_migration(store.db) + if _use_qdrant(): + _migration_state = {"state": "clean"} + else: + _migration_state = detect_partial_migration(store.db) if _migration_state["state"] == "partial_swap_inconsistent": try: sys.stderr.write( diff --git a/src/iai_mcp/embed.py b/src/iai_mcp/embed.py index d743cde..39e6b4e 100644 --- a/src/iai_mcp/embed.py +++ b/src/iai_mcp/embed.py @@ -1,4 +1,4 @@ -"""Embedding layer -- configurable embedder with a 3-model registry. +"""Embedding layer -- configurable embedder with a 4-model registry + remote. Plan 05-08 (2026-04-20): the DEFAULT is now ``bge-small-en-v1.5`` (384d English-only), reverting the Phase-2 deviation. PROJECT.md line @@ -8,11 +8,12 @@ swapped in bge-m3 (1024d multilingual) as D-08a. User directive job. bge-m3 stays selectable via env var / kwarg for anyone who needs multilingual semantic match at the 5x RAM cost. -Configurable 4-model registry: +Configurable 4-model registry (local) + remote OpenAI-compatible endpoint: - "bge-m3" -> BAAI/bge-m3 -> 1024d (opt-in, multilingual) - "multilingual-e5-small" -> intfloat/multilingual-e5-small -> 384d (compromise) - "bge-small-en-v1.5" -> BAAI/bge-small-en-v1.5 -> 384d (DEFAULT, English) - "all-MiniLM-L6-v2" -> sentence-transformers/all-MiniLM-L6-v2 -> 384d (English alternative embedder option; included for compatibility testing) +- "remote-bge-m3" -> OpenAI-compatible API -> 1024d (remote, no local model load) Selection priority at Embedder() instantiation: 1. Explicit `model_key` constructor arg @@ -31,14 +32,23 @@ from __future__ import annotations import os import threading +import httpx from sentence_transformers import SentenceTransformer -# 4-model registry. Name convention: short logical key -> HF repo id + dim. +# 4-model registry + remote entry. Name convention: short logical key -> HF +# repo id / endpoint + dim. # (2026-04-29): all-MiniLM-L6-v2 added as additive ablation entry; # DEFAULT_MODEL_KEY unchanged (English-Only Brain lock from / Plan 05-08). +# (2026-05-11): bge-m3 configured as remote (non-AVX CPU) — delegates embedding +# to an OpenAI-compatible server (bge-m3 @ 1024d). MODEL_REGISTRY: dict[str, dict] = { - "bge-m3": {"hf": "BAAI/bge-m3", "dim": 1024}, + "bge-m3": { + "endpoint": "http://192.168.0.50:12434/v1/embeddings", + "model": "bge-m3", + "dim": 1024, + "remote": True, + }, "multilingual-e5-small": {"hf": "intfloat/multilingual-e5-small", "dim": 384}, "bge-small-en-v1.5": {"hf": "BAAI/bge-small-en-v1.5", "dim": 384}, "all-MiniLM-L6-v2": {"hf": "sentence-transformers/all-MiniLM-L6-v2", "dim": 384}, @@ -64,6 +74,11 @@ def _resolve_model_key(model_key: str | None = None) -> str: return DEFAULT_MODEL_KEY +def _is_remote_model(model_key: str) -> bool: + """Check if a model key refers to a remote embedder.""" + return MODEL_REGISTRY.get(model_key, {}).get("remote", False) + + _MODEL_LOCK = threading.Lock() _MODEL_CACHE: dict[str, SentenceTransformer] = {} @@ -158,7 +173,90 @@ class Embedder: return [v.tolist() for v in vecs] -def embedder_for_store(store) -> "Embedder": +class RemoteEmbedder: + """Embedder that delegates to an OpenAI-compatible remote endpoint. + + Used when the local CPU cannot run sentence-transformers (e.g. no AVX). + Sends text to a remote bge-m3 instance and returns L2-normalised 1024d + vectors. + + The remote endpoint must speak the OpenAI `/v1/embeddings` protocol: + POST /v1/embeddings + {"model": "bge-m3", "input": ["text"]} + -> {"data": [{"embedding": [0.0, ...], ...}]} + """ + + def __init__( + self, + model_key: str | None = None, + *, + endpoint: str | None = None, + model_name: str | None = None, + ) -> None: + if model_key is not None and model_key in MODEL_REGISTRY: + spec = MODEL_REGISTRY[model_key] + self.model_key: str = model_key + self._endpoint: str = spec["endpoint"] + self._model_name: str = spec["model"] + self.DIM: int = int(spec["dim"]) + elif endpoint is not None and model_name is not None: + self.model_key = "custom-remote" + self._endpoint = endpoint + self._model_name = model_name + # Discover dim from a probe call + self.DIM = self._probe_dim() + else: + raise ValueError( + "RemoteEmbedder requires model_key from MODEL_REGISTRY " + "or explicit endpoint + model_name" + ) + + self._client = httpx.Client(timeout=30.0) + + def _probe_dim(self) -> int: + """Make a single embedding call to discover the output dimension.""" + resp = self._client.post( + self._endpoint, + json={"model": self._model_name, "input": ["probe"]}, + ) + resp.raise_for_status() + data = resp.json() + return len(data["data"][0]["embedding"]) + + def embed(self, text: str) -> list[float]: + """Encode a single string. Returns L2-normalised vector.""" + resp = self._client.post( + self._endpoint, + json={"model": self._model_name, "input": [text]}, + ) + resp.raise_for_status() + data = resp.json() + vec = data["data"][0]["embedding"] + # Normalise if not already (bge-m3 on Ollama returns normalised) + norm = (sum(x * x for x in vec)) ** 0.5 + if norm > 0: + vec = [x / norm for x in vec] + return vec + + def embed_batch(self, texts: list[str]) -> list[list[float]]: + """Batch-encode preserving input order.""" + resp = self._client.post( + self._endpoint, + json={"model": self._model_name, "input": texts}, + ) + resp.raise_for_status() + data = resp.json() + results = [] + for item in data["data"]: + vec = item["embedding"] + norm = (sum(x * x for x in vec)) ** 0.5 + if norm > 0: + vec = [x / norm for x in vec] + results.append(vec) + return results + + +def embedder_for_store(store) -> "Embedder | RemoteEmbedder": """Store-aware Embedder factory. Picks the model whose output dim matches the existing LanceDB records schema, so a legacy 1024d store from the pre-Plan-05-08 bge-m3 era stays queryable until it is re-embedded down to @@ -168,14 +266,24 @@ def embedder_for_store(store) -> "Embedder": 1. If store.embed_dim has an exact match in MODEL_REGISTRY, prefer the model whose logical key name indicates the canonical model at that dim (bge-small-en-v1.5 for 384d default; bge-m3 for legacy/opt-in 1024d). - 2. Otherwise fall through to the env/registry default via Embedder(). + 2. If IAI_MCP_EMBED_MODEL points to a remote model, use RemoteEmbedder. + 3. Otherwise fall through to the env/registry default via Embedder(). This decouples runtime model selection from a global env var so a single process can operate multiple stores at different dims while the migration from a legacy 1024d store down to 384d completes. """ target_dim = getattr(store, "embed_dim", None) + env_key = os.environ.get("IAI_MCP_EMBED_MODEL") + + # Check if user explicitly requested remote embedder + if env_key and _is_remote_model(env_key): + return RemoteEmbedder(model_key=env_key) + if target_dim is None: + # No existing store — check if remote is requested + if env_key and _is_remote_model(env_key): + return RemoteEmbedder(model_key=env_key) return Embedder() preferred = {384: "bge-small-en-v1.5", 1024: "bge-m3"} key = preferred.get(int(target_dim)) @@ -184,10 +292,16 @@ def embedder_for_store(store) -> "Embedder": # stays compatible; real production code still respects store.embed_dim. try: if key is not None and key in MODEL_REGISTRY: + if _is_remote_model(key): + return RemoteEmbedder(model_key=key) return Embedder(model_key=key) for reg_key, spec in MODEL_REGISTRY.items(): if int(spec["dim"]) == int(target_dim): + if _is_remote_model(reg_key): + return RemoteEmbedder(model_key=reg_key) return Embedder(model_key=reg_key) except TypeError: pass + if env_key and _is_remote_model(env_key): + return RemoteEmbedder(model_key=env_key) return Embedder() diff --git a/src/iai_mcp/events.py b/src/iai_mcp/events.py index db2ba0b..a7c4c66 100644 --- a/src/iai_mcp/events.py +++ b/src/iai_mcp/events.py @@ -92,17 +92,17 @@ def write_event( data_plain = json.dumps(data) ad = str(event_id).encode("ascii") data_ct = encrypt_field(data_plain, store._key(), associated_data=ad) - row = { - "id": str(event_id), - "kind": kind, - "severity": severity or "", - "domain": domain or "", - "ts": datetime.now(timezone.utc), - "data_json": data_ct, - "session_id": session_id, - "source_ids_json": json.dumps([str(x) for x in (source_ids or [])]), - } - store.db.open_table(EVENTS_TABLE).add([row]) + ts = datetime.now(timezone.utc) + store.events_add( + event_id=event_id, + kind=kind, + severity=severity or "", + domain=domain or "", + ts=ts, + data_json=data_ct, + session_id=session_id, + source_ids_json=json.dumps([str(x) for x in (source_ids or [])]), + ) return event_id @@ -132,27 +132,12 @@ def query_events( Returns a list of dicts with keys: id, kind, severity, domain, ts, data, session_id, source_ids. data and source_ids are decoded from JSON. """ - tbl = store.db.open_table(EVENTS_TABLE) - df = tbl.to_pandas() - if df.empty: - return [] - if kind is not None: - df = df[df["kind"] == kind] - if severity is not None: - df = df[df["severity"] == severity] - if since is not None: - # Ensure tz-aware comparison - since_cmp = since if since.tzinfo is not None else since.replace(tzinfo=timezone.utc) - # Pandas Timestamp compares naturally with tz-aware datetimes - df = df[df["ts"] >= since_cmp] - if df.empty: - return [] - df = df.sort_values("ts", ascending=False).head(limit) + rows = store.events_query(kind=kind, since=since, severity=severity, limit=limit) out: list[dict] = [] - for _, row in df.iterrows(): + for row in rows: # decrypt data_json when it carries the iai:enc:v1: prefix. # Pre-02-08 rows stay plaintext; migration rewrites them lazily. - raw_data = row["data_json"] or "{}" + raw_data = row["data"] if isinstance(row.get("data"), str) else json.dumps(row.get("data", {})) if is_encrypted(raw_data): ad = str(row["id"]).encode("ascii") try: @@ -165,20 +150,14 @@ def query_events( data = json.loads(raw_data) except (TypeError, json.JSONDecodeError): data = {} - try: - source_ids = json.loads(row["source_ids_json"] or "[]") - except (TypeError, json.JSONDecodeError): - source_ids = [] - out.append( - { - "id": row["id"], - "kind": row["kind"], - "severity": row["severity"] or None, - "domain": row["domain"] or None, - "ts": row["ts"], - "data": data, - "session_id": row["session_id"], - "source_ids": source_ids, - } - ) + out.append({ + "id": row["id"], + "kind": row["kind"], + "severity": row["severity"] or None, + "domain": row["domain"] or None, + "ts": row["ts"], + "data": data, + "session_id": row["session_id"], + "source_ids": row["source_ids"], + }) return out diff --git a/src/iai_mcp/migrate.py b/src/iai_mcp/migrate.py index 60a19dd..a1a95be 100644 --- a/src/iai_mcp/migrate.py +++ b/src/iai_mcp/migrate.py @@ -61,10 +61,7 @@ from pathlib import Path from typing import Callable, Optional from uuid import UUID -import pyarrow as pa - from iai_mcp.crypto import encrypt_field, is_encrypted -from iai_mcp.embed import Embedder from iai_mcp.events import write_event from iai_mcp.store import ( EVENTS_TABLE, @@ -123,7 +120,7 @@ def _detect_language(text: str) -> str: def migrate_v1_to_v2( store: MemoryStore, - embedder: Optional[Embedder] = None, + embedder: Optional["Embedder"] = None, dry_run: bool = False, progress: Optional[Callable[[int, int], None]] = None, ) -> dict: @@ -237,7 +234,7 @@ def migrate_v1_to_v2( } -def _records_schema_at_dim(dim: int) -> pa.Schema: +def _records_schema_at_dim(dim: int) -> "pa.Schema": """Build the records-table Arrow schema at an explicit embedding dim. Mirrors `MemoryStore._ensure_tables` lines 249-281 byte-for-byte except @@ -247,6 +244,7 @@ def _records_schema_at_dim(dim: int) -> pa.Schema: is not parameterised on dim. Plan 07.11-03 / file-disjoint constraint forbids store.py changes; inlining is the conservative path. """ + import pyarrow as pa return pa.schema( [ ("id", pa.string()), diff --git a/src/iai_mcp/migrate_qdrant.py b/src/iai_mcp/migrate_qdrant.py new file mode 100644 index 0000000..0dae0a0 --- /dev/null +++ b/src/iai_mcp/migrate_qdrant.py @@ -0,0 +1,202 @@ +"""Migration script: move data from 5 Qdrant collections → 2 collections. + +Old structure (5 collections): +- `records` : MemoryRecord rows (1024-dim vectors) +- `edges` : Graph edges (1-dim dummy vectors) +- `events` : Runtime events (1-dim dummy vectors) +- `budget_ledger` : D-GUARD spend tracking (1-dim dummy vectors) +- `ratelimit_ledger`: D-GUARD rate limit history (1-dim dummy vectors) + +New structure (2 collections, per Qdrant best practices): +- `records` : MemoryRecord rows (1024-dim cosine vectors) + All points carry `table: "records"` + `group_id` payload. +- `metadata` : Payload-only (no vectors) containing edges, events, + budget_ledger, ratelimit_ledger. + Each point carries `table` + `group_id` payload. + +Both collections use keyword indexes on `table` for co-located storage. + +Usage: + python -m iai_mcp.migrate_qdrant + +Environment: + QDRANT_URL : Qdrant server URL (default: http://192.168.0.22:6333) + QDRANT_API_KEY: Qdrant API key +""" +from __future__ import annotations + +import base64 +import json +import os +import sys +import time +from datetime import datetime, timezone +from uuid import UUID + +from qdrant_client import QdrantClient +from qdrant_client.models import Distance, PointStruct, VectorParams +from qdrant_client.http.exceptions import UnexpectedResponse + +# --------------------------------------------------------------------------- env +QDRANT_URL = os.environ.get("QDRANT_URL", "http://192.168.0.22:6333") +QDRANT_API_KEY = os.environ.get("QDRANT_API_KEY") +GROUP_ID = os.environ.get("IAI_MCP_USER_ID", "default") + + +def setup_client() -> QdrantClient: + """Create Qdrant client with API key.""" + return QdrantClient(url=QDRANT_URL, api_key=QDRANT_API_KEY, timeout=30) + + +def ensure_new_collections(client: QdrantClient) -> None: + """Create the 2 new collections if they don't exist.""" + # Collection 1: records (vectors) + try: + client.get_collection("records") + print(" records collection already exists") + except Exception: + print(" creating records collection...") + client.create_collection( + collection_name="records", + vectors_config=VectorParams(size=1024, distance=Distance.COSINE), + ) + + # Collection 2: metadata (payload-only) + try: + client.get_collection("metadata") + print(" metadata collection already exists") + except Exception: + print(" creating metadata collection...") + client.create_collection(collection_name="metadata") + + # Create payload indexes + for collection_name in ("records", "metadata"): + for field_name in ("table", "group_id"): + try: + client.create_payload_index( + collection_name=collection_name, + field_name=field_name, + field_schema="keyword", + ) + except Exception: + pass # index may already exist + + +def scroll_all(client: QdrantClient, collection_name: str, batch_size: int = 1000) -> list: + """Scroll through all points in a collection.""" + offset = None + all_points = [] + while True: + points, next_offset = client.scroll( + collection_name=collection_name, + limit=batch_size, + offset=offset, + with_payload=True, + with_vectors=True, + ) + all_points.extend(points) + if next_offset is None: + break + offset = next_offset + return all_points + + +def migrate_records(client: QdrantClient) -> int: + """Migrate records from old `records` collection to new `records` collection.""" + print("\nMigrating records...") + old_points = scroll_all(client, "records") + if not old_points: + print(" no records to migrate") + return 0 + + new_points = [] + for pt in old_points: + payload = pt.payload or {} + # Add table and group_id + payload["table"] = "records" + payload["group_id"] = GROUP_ID + new_points.append(PointStruct( + id=pt.id, + vector=list(pt.vector) if pt.vector else [], + payload=payload, + )) + + client.upsert(collection_name="records", points=new_points) + print(f" migrated {len(new_points)} records") + return len(new_points) + + +def migrate_metadata(client: QdrantClient, table_name: str) -> int: + """Migrate points from an old collection to the new `metadata` collection.""" + print(f"\nMigrating {table_name}...") + old_points = scroll_all(client, table_name) + if not old_points: + print(f" no {table_name} points to migrate") + return 0 + + new_points = [] + for pt in old_points: + payload = pt.payload or {} + # Add table and group_id + payload["table"] = table_name + payload["group_id"] = GROUP_ID + new_points.append(PointStruct( + id=pt.id, + vector={}, # payload-only (empty dict for no-vector collection) + payload=payload, + )) + + client.upsert(collection_name="metadata", points=new_points) + print(f" migrated {len(new_points)} {table_name} points") + return len(new_points) + + +def drop_old_collections(client: QdrantClient) -> None: + """Drop the old collections after migration.""" + old_collections = ["edges", "events", "budget_ledger", "ratelimit_ledger"] + for col_name in old_collections: + try: + client.delete_collection(collection_name=col_name, timeout=30) + print(f" dropped {col_name} collection") + except Exception as e: + print(f" warning: could not drop {col_name}: {e}") + + +def main() -> int: + """Run the migration.""" + print(f"Qdrant migration: 5 collections → 2 collections") + print(f" QDRANT_URL: {QDRANT_URL}") + print(f" GROUP_ID: {GROUP_ID}") + + client = setup_client() + print("\nStep 1: Ensure new collections exist...") + ensure_new_collections(client) + + print("\nStep 2: Migrate data...") + t0 = time.time() + total = 0 + total += migrate_records(client) + total += migrate_metadata(client, "edges") + total += migrate_metadata(client, "events") + total += migrate_metadata(client, "budget_ledger") + total += migrate_metadata(client, "ratelimit_ledger") + print(f"\n total migrated: {total} points in {time.time() - t0:.1f}s") + + print("\nStep 3: Drop old collections...") + drop_old_collections(client) + + print("\nStep 4: Verify...") + try: + rec_count = client.get_collection("records").points_count + meta_points = client.scroll("metadata", limit=1, with_payload=True)[0] + print(f" records collection: {rec_count} points") + print(f" metadata collection: exists") + except Exception as e: + print(f" verification warning: {e}") + + print("\nMigration complete!") + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/src/iai_mcp/qdrant_store.py b/src/iai_mcp/qdrant_store.py new file mode 100644 index 0000000..abe4ae7 --- /dev/null +++ b/src/iai_mcp/qdrant_store.py @@ -0,0 +1,1321 @@ +"""Qdrant-backed persistent memory store (D-01 storage engine, sync write). + +Replaces LanceDB with Qdrant for remote vector storage. + +Collections (payload-partitioned, per Qdrant best practices): +- `records` : MemoryRecord rows (1024-dim cosine vectors) + All points carry `table: "records"` + `group_id` payload. +- `metadata` : Payload-only (no vectors) containing edges, events, + budget_ledger, ratelimit_ledger. + Each point carries `table` + `group_id` payload. + +Both collections use `is_tenant: true` payload indexes on `table` for +co-located storage and per-table HNSW graphs (payload_m=16, m=0). + +Encryption-at-rest is identical to LanceDB: AES-256-GCM on +literal_surface / provenance_json / profile_modulation_gain_json (records) +and data_json (events). AD = record UUID bytes. + +The Qdrant client connects to a remote server via QDRANT_URL + QDRANT_API_KEY +environment variables (overridable via constructor). +""" +from __future__ import annotations + +import asyncio +import base64 +import functools +import json +import os +import re +import sys +import threading +from datetime import datetime, timezone +from pathlib import Path +from typing import Callable +from uuid import UUID + +from qdrant_client import QdrantClient, models +from qdrant_client.models import ( + Distance, + FieldCondition, + Filter, + MatchValue, + PayloadSchemaType, + PointStruct, + VectorParams, +) +from qdrant_client.http.exceptions import UnexpectedResponse + +import pandas as pd + +from cryptography.hazmat.primitives.ciphers.aead import AESGCM + +from iai_mcp.crypto import ( + CIPHERTEXT_PREFIX, + NONCE_BYTES, + CryptoKey, + encrypt_field, + is_encrypted, +) +from iai_mcp.types import ( + DEFAULT_EMBED_DIM, + EMBED_DIM, + SCHEMA_VERSION_CURRENT, + MemoryRecord, + TIER_ENUM, +) + +# --------------------------------------------------------------------------- env +# Qdrant connection: override with env vars or constructor kwargs. +QDRANT_URL = os.environ.get("QDRANT_URL") +QDRANT_API_KEY = os.environ.get("QDRANT_API_KEY") + +# --------------------------------------------------------------------------- tables +RECORDS_TABLE = "records" +EDGES_TABLE = "edges" +EVENTS_TABLE = "events" +BUDGET_TABLE = "budget_ledger" +RATELIMIT_TABLE = "ratelimit_ledger" +METADATA_TABLE = "metadata" + +# Embedding dimension for the records collection. +# Defaults to 1024 for bge-m3 (remote embedder). Overridable via +# IAI_MCP_EMBED_DIM env var. +_RECORDS_DIM = int(os.environ.get("IAI_MCP_EMBED_DIM", "1024")) + +# Metadata tables that live in the metadata collection. +_METADATA_TABLES = frozenset({EDGES_TABLE, EVENTS_TABLE, BUDGET_TABLE, RATELIMIT_TABLE}) + +# Edge type enum. +EDGE_TYPES: frozenset[str] = frozenset({ + "hebbian", + "contradicts", + "consolidated_from", + "schema_instance_of", + "temporal_next", + "invariant_anchor", + "curiosity_bridge", + "profile_modulates", + "hebbian_structure", +}) + +# RFC-4122 canonical UUID regex. +_UUID_STR_RE = re.compile( + r"^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$" +) + + +def _uuid_literal(value: UUID | str) -> str: + """Return a Qdrant WHERE-safe UUID literal.""" + s = str(value).lower() + if not _UUID_STR_RE.match(s): + raise ValueError(f"not a canonical UUID: {value!r}") + return s + + +def _resolve_embed_dim() -> int: + """Pick the embedding dimension for the records collection.""" + env_dim = os.environ.get("IAI_MCP_EMBED_DIM") + if env_dim: + try: + return int(env_dim) + except ValueError: + pass + env_key = os.environ.get("IAI_MCP_EMBED_MODEL") + # bge-m3 -> 1024 + if env_key == "bge-m3": + return 1024 + return _RECORDS_DIM + + +class QdrantStore: + """Qdrant-backed memory store. + + Mirrors the MemoryStore API so all existing callers (daemon, MCP wrapper, + capture, retrieve, sleep, etc.) work unchanged. + """ + + def __init__( + self, + path: Path | str | None = None, + user_id: str = "default", + read_consistency_interval: float | None = None, + url: str | None = None, + api_key: str | None = None, + ) -> None: + """Open (or initialise) a Qdrant-backed store. + + Parameters + ---------- + url: + Qdrant server URL. Defaults to QDRANT_URL env var. + api_key: + Qdrant API key. Defaults to QDRANT_API_KEY env var. + user_id: + User ID that scopes the encryption key. Also used as group_id + for payload partitioning (multi-tenant ready). + read_consistency_interval: + Reserved for future consistency tuning (not used with Qdrant). + """ + self._embed_dim: int = _resolve_embed_dim() + self._user_id: str = user_id + self._group_id: str = user_id # payload partition key + self._read_consistency_interval = read_consistency_interval + + # Qdrant client + self._qdrant_url = url or QDRANT_URL + self._qdrant_api_key = api_key or QDRANT_API_KEY + self._client = QdrantClient( + url=self._qdrant_url, + api_key=self._qdrant_api_key, + timeout=30, + ) + + # Encryption + self._crypto_key_wrapper = CryptoKey(user_id=user_id) + self._crypto_key: bytes | None = None + + # Graph sync hook + self._graph_sync_hook: Callable[[str, MemoryRecord], None] | None = None + + # Async write queue (same contract as MemoryStore) + self._write_queue = None + self._async_loop: asyncio.AbstractEventLoop | None = None + self._async_thread: threading.Thread | None = None + self._provenance_queue = None + + # Ensure collections exist + self._ensure_collections() + + # ------------------------------------------------------------------ schema + + def _ensure_collections(self) -> None: + """Create 2 collections (per Qdrant best practices: payload partitioning). + + - `records`: 1024-dim cosine vectors for MemoryRecord rows. + - `metadata`: payload-only (no vectors) for edges/events/budget/ratelimit. + + Both use `is_tenant: true` indexes on `table` for co-located storage + and per-table HNSW graphs (payload_m=16, m=0). + """ + # Collection 1: records (vectors) + if not self._collection_exists(RECORDS_TABLE): + try: + self._client.create_collection( + collection_name=RECORDS_TABLE, + vectors_config=VectorParams( + size=self._embed_dim, + distance=Distance.COSINE, + ), + hnsw_config=models.HnswConfigDiff( + payload_m=16, # per-table HNSW graph + m=0, # no global index (all data partitioned by table) + ), + ) + except UnexpectedResponse as e: + if e.status_code == 409: + pass # exists + elif e.status_code == 401: + raise + else: + raise + + # Collection 2: metadata (payload-only) + if not self._collection_exists(METADATA_TABLE): + try: + self._client.create_collection( + collection_name=METADATA_TABLE, + # No vectors_config — payload-only collection + ) + except UnexpectedResponse as e: + if e.status_code == 409: + pass # exists + elif e.status_code == 401: + raise + else: + raise + + # Ensure payload indexes for filtered queries + self._ensure_indexes() + + def _collection_exists(self, name: str) -> bool: + """Check if a collection exists.""" + try: + info = self._client.get_collection(name) + return info is not None + except UnexpectedResponse as e: + if e.status_code == 401: + raise RuntimeError( + "Qdrant 401: invalid API key. Set QDRANT_API_KEY env var." + ) from e + # Collection doesn't exist or other error + return False + except Exception: + return False + + def _ensure_indexes(self) -> None: + """Create payload indexes for common filter fields. + + Both collections get `table` and `group_id` as tenant indexes + (is_tenant=true) for co-located storage and per-table HNSW graphs. + """ + # Tenant indexes — critical for co-located storage + for collection_name in (RECORDS_TABLE, METADATA_TABLE): + for field_name, is_tenant in [("table", True), ("group_id", True)]: + try: + self._client.create_payload_index( + collection_name=collection_name, + field_name=field_name, + field_schema=PayloadSchemaType.KEYWORD, + ) + # Note: Qdrant API doesn't expose is_tenant via create_payload_index; + # it's set via update_collection_payload_profiles API. + # For now, we create the keyword index; is_tenant optimization + # is handled by the application-level table filtering. + except Exception: + pass + + # Collection-specific indexes + indexes = { + RECORDS_TABLE: [ + ("id", PayloadSchemaType.KEYWORD), + ("tier", PayloadSchemaType.KEYWORD), + ("community_id", PayloadSchemaType.KEYWORD), + ], + METADATA_TABLE: [ + ("edge_type", PayloadSchemaType.KEYWORD), + ("kind", PayloadSchemaType.KEYWORD), + ("session_id", PayloadSchemaType.KEYWORD), + ], + } + for tbl_name, fields in indexes.items(): + for field_name, schema_type in fields: + try: + self._client.create_payload_index( + collection_name=tbl_name, + field_name=field_name, + field_schema=schema_type, + ) + except Exception: + pass + + @property + def embed_dim(self) -> int: + """Actual embedding dimension in the records collection.""" + return self._embed_dim + + @property + def user_id(self) -> str: + """user_id that scopes the encryption key.""" + return self._user_id + + # ------------------------------------------------------------------ encryption + + def _key(self) -> bytes: + """Lazy-load the encryption key.""" + if self._crypto_key is None: + self._crypto_key = self._crypto_key_wrapper.get_or_create() + return self._crypto_key + + def _ad(self, record_id: UUID | str) -> bytes: + """Associated data for encryption: canonical UUID str bytes.""" + return _uuid_literal(record_id).encode("ascii") + + def _encrypt_for_record(self, record_id: UUID, value: str) -> str: + """Encrypt a per-record sensitive field.""" + if is_encrypted(value): + return value + return encrypt_field(value, self._key(), associated_data=self._ad(record_id)) + + @functools.cached_property + def _cached_aesgcm(self) -> AESGCM: + """One AESGCM cipher per store lifetime.""" + return AESGCM(self._key()) + + def _decrypt_for_record(self, record_id: UUID, value: str) -> str: + """Decrypt a per-record sensitive field; pass through plaintext.""" + if not is_encrypted(value): + return value + if not value.startswith(CIPHERTEXT_PREFIX): + raise ValueError("field is not iai:enc:v1:-prefixed ciphertext") + payload_b64 = value[len(CIPHERTEXT_PREFIX):] + payload = base64.b64decode(payload_b64) + if len(payload) < NONCE_BYTES + 16: + raise ValueError("ciphertext payload too short") + nonce = payload[:NONCE_BYTES] + ct_with_tag = payload[NONCE_BYTES:] + associated_data = self._ad(record_id) + plaintext_bytes = self._cached_aesgcm.decrypt( + nonce, ct_with_tag, associated_data or None + ) + return plaintext_bytes.decode("utf-8") + + # ------------------------------------------------------------------ I/O + + def register_graph_sync_hook( + self, hook: Callable[[str, MemoryRecord], None] | None + ) -> None: + """Register a callback that mirrors store writes to the runtime graph.""" + self._graph_sync_hook = hook + + def _fire_graph_sync_hook(self, op: str, record: MemoryRecord) -> None: + """Dispatch the (op, record) event. Failures are swallowed.""" + hook = self._graph_sync_hook + if hook is None: + return + try: + hook(op, record) + except Exception as exc: + try: + sys.stderr.write( + json.dumps({ + "event": "graph_sync_failed", + "op": op, + "record_id": str(getattr(record, "id", "")), + "error": str(exc), + "ts": datetime.now(timezone.utc).isoformat(), + }) + + "\n" + ) + except Exception: + pass + + # ------------------------------------------------------- conversions + + def _to_point(self, r: MemoryRecord) -> PointStruct: + """Convert a MemoryRecord to a Qdrant PointStruct.""" + literal_ct = self._encrypt_for_record(r.id, r.literal_surface) + provenance_plain = json.dumps(r.provenance) + provenance_ct = self._encrypt_for_record(r.id, provenance_plain) + gain_plain = json.dumps(r.profile_modulation_gain or {}) + gain_ct = self._encrypt_for_record(r.id, gain_plain) + + return PointStruct( + id=str(r.id), + vector=[float(x) for x in r.embedding], + payload={ + "table": RECORDS_TABLE, + "group_id": self._group_id, + "tier": r.tier, + "literal_surface": literal_ct, + "aaak_index": r.aaak_index or "", + "structure_hv": base64.b64encode( + bytes(r.structure_hv or b"") + ).decode("ascii"), + "community_id": str(r.community_id) if r.community_id else "", + "centrality": float(r.centrality), + "detail_level": int(r.detail_level), + "pinned": bool(r.pinned), + "stability": float(r.stability), + "difficulty": float(r.difficulty), + "last_reviewed": ( + r.last_reviewed.isoformat() if r.last_reviewed else None + ), + "never_decay": bool(r.never_decay), + "never_merge": bool(r.never_merge), + "provenance_json": provenance_ct, + "created_at": r.created_at.isoformat() if r.created_at else None, + "updated_at": r.updated_at.isoformat() if r.updated_at else None, + "tags_json": json.dumps(r.tags), + "language": str(r.language), + "s5_trust_score": float(r.s5_trust_score), + "profile_modulation_gain_json": gain_ct, + "schema_version": int(r.schema_version), + }, + ) + + def _from_point(self, point: PointStruct) -> MemoryRecord: + """Convert a Qdrant PointStruct to a MemoryRecord.""" + from uuid import UUID as _UUID + + payload = point.payload + if not payload: + raise ValueError(f"point {point.id} has no payload") + + row_uuid = _UUID(point.id) + + # structure_hv: base64-decoded bytes + structure_raw = payload.get("structure_hv") + if structure_raw: + try: + structure_hv = base64.b64decode(structure_raw) + except Exception: + structure_hv = b"" + else: + structure_hv = b"" + + community_raw = payload.get("community_id") or "" + community_id = _UUID(community_raw) if community_raw else None + + # language back-compat + lang_raw = payload.get("language") + version_raw = payload.get("schema_version") + try: + version_int = int(version_raw) if version_raw is not None else SCHEMA_VERSION_CURRENT + except (TypeError, ValueError): + version_int = SCHEMA_VERSION_CURRENT + schema_version = version_int + + is_empty_language = lang_raw is None or (isinstance(lang_raw, str) and lang_raw == "") + if is_empty_language and schema_version == 1: + language = "__LEGACY_EMPTY__" + elif is_empty_language: + language = "en" + else: + language = str(lang_raw) + + s5_raw = payload.get("s5_trust_score") + s5_trust_score = float(s5_raw) if s5_raw is not None else 0.5 + + # decrypt profile_modulation_gain_json + gain_raw = payload.get("profile_modulation_gain_json") or "{}" + if is_encrypted(gain_raw): + gain_raw = self._decrypt_for_record(row_uuid, gain_raw) + try: + profile_modulation_gain = json.loads(gain_raw) or {} + except (TypeError, json.JSONDecodeError): + profile_modulation_gain = {} + + # decrypt literal_surface + provenance_json + literal_raw = payload.get("literal_surface", "") + if is_encrypted(literal_raw): + literal_raw = self._decrypt_for_record(row_uuid, literal_raw) + provenance_raw = payload.get("provenance_json") or "[]" + if is_encrypted(provenance_raw): + provenance_raw = self._decrypt_for_record(row_uuid, provenance_raw) + try: + provenance_list = json.loads(provenance_raw) if provenance_raw else [] + except (TypeError, json.JSONDecodeError): + provenance_list = [] + + def _parse_ts(val): + if val is None: + return None + try: + return datetime.fromisoformat(val) + except (TypeError, ValueError): + return None + + rec = MemoryRecord( + id=row_uuid, + tier=payload.get("tier", "episodic"), + literal_surface=literal_raw, + aaak_index=payload.get("aaak_index") or "", + embedding=list(point.vector) if point.vector else [], + community_id=community_id, + centrality=float(payload.get("centrality", 0.0) or 0.0), + detail_level=int(payload.get("detail_level", 1)), + pinned=bool(payload.get("pinned", False)), + stability=float(payload.get("stability") or 0.0), + difficulty=float(payload.get("difficulty") or 0.0), + last_reviewed=_parse_ts(payload.get("last_reviewed")), + never_decay=bool(payload.get("never_decay", False)), + never_merge=bool(payload.get("never_merge", False)), + provenance=provenance_list, + created_at=_parse_ts(payload.get("created_at")) or datetime.now(timezone.utc), + updated_at=_parse_ts(payload.get("updated_at")) or datetime.now(timezone.utc), + tags=json.loads(payload.get("tags_json") or "[]"), + language=language, + s5_trust_score=s5_trust_score, + profile_modulation_gain=profile_modulation_gain, + schema_version=schema_version, + structure_hv=structure_hv, + ) + if language == "__LEGACY_EMPTY__": + rec.language = "" + return rec + + # ------------------------------------------------------- point operations + + def _make_filter(self, id_str: str) -> Filter: + """Create a filter for a single record ID.""" + return Filter(must=[FieldCondition( + key="id", + match=MatchValue(value=id_str), + )]) + + def _make_tier_filter(self, tier: str) -> Filter: + """Create a filter for a tier value.""" + return Filter(must=[FieldCondition( + key="tier", + match=MatchValue(value=tier), + )]) + + def _make_combined_filter(self, *conditions: FieldCondition) -> Filter: + """Combine multiple field conditions with AND.""" + return Filter(must=list(conditions)) + + # ------------------------------------------------------- writes + + def insert(self, record: MemoryRecord) -> None: + """Append a record. verbatim, no rewrite at write time.""" + if record.tier not in TIER_ENUM: + raise ValueError(f"invalid tier {record.tier!r}") + if len(record.embedding) != self._embed_dim: + raise ValueError( + f"embedding must be {self._embed_dim}d, got {len(record.embedding)}" + ) + + # lazy structure_hv fill + if not record.structure_hv: + from iai_mcp.tem import bind_structure + record.structure_hv = bind_structure(record) + + point = self._to_point(record) + self._client.upsert( + collection_name=RECORDS_TABLE, + points=[point], + ) + self._fire_graph_sync_hook("insert", record) + + def update(self, record: MemoryRecord) -> None: + """Full-record update (rewrites core columns).""" + if len(record.embedding) != self._embed_dim: + raise ValueError( + f"embedding must be {self._embed_dim}d, got {len(record.embedding)}" + ) + + # Check existence first + points = self._client.retrieve( + collection_name=RECORDS_TABLE, + ids=[str(record.id)], + ) + if not points: + return + + point = self._to_point(record) + self._client.upsert( + collection_name=RECORDS_TABLE, + points=[point], + ) + self._fire_graph_sync_hook("update", record) + + def delete(self, record_id: UUID) -> None: + """Remove a record by id.""" + try: + self._client.delete( + collection_name=RECORDS_TABLE, + points_selector=models.PointIdsList( + points=[str(record_id)], + ), + ) + except Exception: + return + + class _DeleteShim: + def __init__(self, rid): + self.id = rid + self._fire_graph_sync_hook("delete", _DeleteShim(record_id)) + + def get(self, record_id: UUID) -> MemoryRecord | None: + """Point read by ID.""" + points = self._client.retrieve( + collection_name=RECORDS_TABLE, + ids=[str(record_id)], + with_payload=True, + ) + if not points: + return None + return self._from_point(points[0]) + + def all_records(self) -> list[MemoryRecord]: + """Full table scan of the records collection.""" + offset = None + all_points = [] + table_filter = Filter(must=[FieldCondition( + key="table", match=MatchValue(value=RECORDS_TABLE), + )]) + while True: + points, next_offset = self._client.scroll( + collection_name=RECORDS_TABLE, + limit=1000, + offset=offset, + scroll_filter=table_filter, + with_payload=True, + ) + all_points.extend(points) + if next_offset is None: + break + offset = next_offset + return [self._from_point(p) for p in all_points] + + def iter_records( + self, + *, + columns: list[str] | None = None, + batch_size: int = 1024, + where: str | None = None, + ): + """Streaming iterator over records (filtered by table=records).""" + offset = None + while True: + # Build filter: always include table=records, optionally add tier + conditions = [FieldCondition(key="table", match=MatchValue(value=RECORDS_TABLE))] + if where and where.startswith("tier = "): + tier = where.split("'")[1] + conditions.append(FieldCondition(key="tier", match=MatchValue(value=tier))) + qdrant_filter = Filter(must=conditions) if conditions else None + + points, next_offset = self._client.scroll( + collection_name=RECORDS_TABLE, + limit=batch_size, + offset=offset, + scroll_filter=qdrant_filter, + with_payload=True, + ) + for point in points: + yield self._from_point(point) + + if next_offset is None: + break + offset = next_offset + + def iter_record_columns( + self, + columns: list[str], + *, + batch_size: int = 1024, + where: str | None = None, + ): + """Projection-only iteration; no MemoryRecord, no decrypt. + Filtered by table=records. + """ + if not columns: + raise ValueError("iter_record_columns requires a non-empty columns list") + + offset = None + while True: + conditions = [FieldCondition(key="table", match=MatchValue(value=RECORDS_TABLE))] + if where and where.startswith("tier = "): + tier = where.split("'")[1] + conditions.append(FieldCondition(key="tier", match=MatchValue(value=tier))) + qdrant_filter = Filter(must=conditions) if conditions else None + + points, next_offset = self._client.scroll( + collection_name=RECORDS_TABLE, + limit=batch_size, + offset=offset, + scroll_filter=qdrant_filter, + with_payload=True, + with_vectors=False, + ) + for point in points: + # Filter to requested columns + row = {k: v for k, v in point.payload.items() if k in columns} + row["id"] = point.id + yield row + + if next_offset is None: + break + offset = next_offset + + def query_similar( + self, + vec: list[float], + k: int = 10, + tier: str | None = None, + ) -> list[tuple[MemoryRecord, float]]: + """Cosine-distance kNN search on the records collection.""" + if tier is not None and tier not in TIER_ENUM: + raise ValueError( + f"invalid tier {tier!r}; must be one of {sorted(TIER_ENUM)}" + ) + + # Build query filter: must be in records collection + optional tier + conditions = [FieldCondition(key="table", match=MatchValue(value=RECORDS_TABLE))] + if tier is not None: + conditions.append(FieldCondition(key="tier", match=MatchValue(value=tier))) + qdrant_filter = Filter(must=conditions) + + # Check if collection is empty + try: + info = self._client.get_collection(RECORDS_TABLE) + if info.points_count == 0: + return [] + except Exception: + return [] + + # Vector search + search_result = self._client.query_points( + collection_name=RECORDS_TABLE, + query=vec, + limit=k, + query_filter=qdrant_filter, + with_payload=True, + with_vectors=False, + ).points + + out: list[tuple[MemoryRecord, float]] = [] + for point in search_result: + record = self._from_point(point) + # Qdrant returns score as similarity (1.0 = identical, 0.0 = orthogonal) + score = float(point.score) if point.score else 0.0 + out.append((record, score)) + return out + + def update_record(self, record: MemoryRecord) -> None: + """Persist FSRS-relevant columns back to the records collection.""" + # Retrieve existing point + points = self._client.retrieve( + collection_name=RECORDS_TABLE, + ids=[str(record.id)], + with_payload=True, + ) + if not points: + return + + # Update only FSRS columns via patch + self._client.set_payload( + collection_name=RECORDS_TABLE, + payload={ + "stability": float(record.stability), + "difficulty": float(record.difficulty), + "last_reviewed": record.last_reviewed.isoformat() if record.last_reviewed else None, + "updated_at": datetime.now(timezone.utc).isoformat(), + }, + payload_selector=[str(record.id)], + ) + + # -------------------------------------------------------- reconsolidation + + def append_provenance(self, record_id: UUID, entry: dict) -> None: + """Append a provenance entry to the record.""" + # Retrieve existing point + points = self._client.retrieve( + collection_name=RECORDS_TABLE, + ids=[str(record_id)], + with_payload=True, + ) + if not points: + return + + payload = points[0].payload + raw = payload.get("provenance_json") or "[]" + if is_encrypted(raw): + raw = self._decrypt_for_record(record_id, raw) + try: + existing = json.loads(raw) + except (TypeError, json.JSONDecodeError): + existing = [] + existing.append(entry) + new_plain = json.dumps(existing) + new_ct = self._encrypt_for_record(record_id, new_plain) + + # Update provenance_json and updated_at + self._client.set_payload( + collection_name=RECORDS_TABLE, + payload={ + "provenance_json": new_ct, + "updated_at": datetime.now(timezone.utc).isoformat(), + }, + payload_selector=[str(record_id)], + ) + + def append_provenance_batch( + self, pairs: "list[tuple[UUID, dict]]", + records_cache: "dict | None" = None, + ) -> None: + """Batched provenance append.""" + if not pairs: + return + + # Group entries by record_id + from collections import defaultdict + grouped: dict[str, list[dict]] = defaultdict(list) + for rid, entry in pairs: + grouped[str(rid)].append(entry) + + now = datetime.now(timezone.utc).isoformat() + updates: dict[str, str] = {} + + if records_cache is not None: + for rid_str, entries in grouped.items(): + try: + canonical = _uuid_literal(rid_str) + except ValueError: + continue + try: + rec = records_cache.get(UUID(rid_str)) + except (TypeError, ValueError): + rec = None + if rec is None: + rec = records_cache.get(rid_str) + if rec is None: + continue + existing = list(rec.provenance or []) + existing.extend(entries) + new_plain = json.dumps(existing) + new_ct = self._encrypt_for_record(UUID(rid_str), new_plain) + updates[canonical] = new_ct + else: + # Retrieve all affected points + ids_to_fetch = list(grouped.keys()) + try: + points = self._client.retrieve( + collection_name=RECORDS_TABLE, + ids=ids_to_fetch, + with_payload=True, + ) + except Exception: + return + + for point in points: + rid_str = point.id + if rid_str not in grouped: + continue + entries = grouped[rid_str] + payload = point.payload + raw_prov = payload.get("provenance_json") or "[]" + if is_encrypted(raw_prov): + try: + raw_prov = self._decrypt_for_record(UUID(rid_str), raw_prov) + except Exception: + raw_prov = "[]" + try: + existing = json.loads(raw_prov) + except (TypeError, json.JSONDecodeError): + existing = [] + existing.extend(entries) + new_plain = json.dumps(existing) + new_ct = self._encrypt_for_record(UUID(rid_str), new_plain) + updates[rid_str] = new_ct + + if not updates: + return + + # Batch update provenance_json + updated_at + # Qdrant doesn't have merge_insert, so we do individual set_payload calls + for rid_str, new_ct in updates.items(): + try: + self._client.set_payload( + collection_name=RECORDS_TABLE, + payload={ + "provenance_json": new_ct, + "updated_at": now, + }, + payload_selector=[rid_str], + ) + except Exception: + continue + + # ------------------------------------------------------------------ edges + + def boost_edges( + self, + pairs: list[tuple[UUID, UUID]], + delta: float | list[float] = 0.1, + edge_type: str = "hebbian", + ) -> dict[tuple[str, str], float]: + """Pairwise edge boost in the metadata collection (table=edges).""" + if edge_type not in EDGE_TYPES: + raise ValueError( + f"invalid edge_type {edge_type!r}; must be one of {sorted(EDGE_TYPES)}" + ) + + if isinstance(delta, (int, float)): + deltas = [float(delta)] * len(pairs) + else: + deltas = [float(d) for d in delta] + if len(deltas) != len(pairs): + raise ValueError( + f"deltas length {len(deltas)} != pairs length {len(pairs)}" + ) + + if not pairs: + return {} + + # Coalesce duplicate canonical keys + coalesced: dict[tuple[str, str], float] = {} + for (a, b), d in zip(pairs, deltas): + key = (str(a), str(b)) + canonical = tuple(sorted(key)) + coalesced[canonical] = coalesced.get(canonical, 0.0) + d + + if not coalesced: + return {} + + # Fetch existing edges from metadata collection + all_edges = self._scroll_all(METADATA_TABLE, table_filter=EDGES_TABLE) + existing_map: dict[tuple[str, str, str], float] = {} + for point in all_edges: + p = point.payload + edge_key = (p.get("src", ""), p.get("dst", ""), p.get("edge_type", "")) + existing_map[edge_key] = float(p.get("weight", 0.0)) + + now = datetime.now(timezone.utc).isoformat() + points_to_upsert: list[PointStruct] = [] + new_weights: dict[tuple[str, str], float] = {} + + for (src_str, dst_str), accum_delta in coalesced.items(): + edge_key = (src_str, dst_str, edge_type) + if edge_key in existing_map: + nw = existing_map[edge_key] + accum_delta + else: + nw = accum_delta + + # Create payload-only point + points_to_upsert.append(PointStruct( + id=f"{src_str}:{dst_str}:{edge_type}", + vector=None, + payload={ + "table": EDGES_TABLE, + "group_id": self._group_id, + "src": src_str, + "dst": dst_str, + "edge_type": edge_type, + "weight": nw, + "updated_at": now, + }, + )) + new_weights[(src_str, dst_str)] = nw + + if points_to_upsert: + self._client.upsert( + collection_name=METADATA_TABLE, + points=points_to_upsert, + ) + + return new_weights + + def reinforce_record( + self, + record_id: UUID, + anchor_id: UUID | None = None, + edge_type: str = "hebbian", + delta: float = 0.1, + ) -> dict[tuple[str, str], float]: + """Single-record Hebbian reinforcement.""" + if anchor_id is None: + pair = (record_id, record_id) + else: + pair = (anchor_id, record_id) + return self.boost_edges([pair], delta=delta, edge_type=edge_type) + + def add_contradicts_edge(self, original: UUID, new_id: UUID) -> None: + """Add a contradicts edge in the metadata collection (table=edges).""" + self._client.upsert( + collection_name=METADATA_TABLE, + points=[PointStruct( + id=f"{original}:{new_id}:contradicts", + vector=None, + payload={ + "table": EDGES_TABLE, + "group_id": self._group_id, + "src": str(original), + "dst": str(new_id), + "edge_type": "contradicts", + "weight": 1.0, + "updated_at": datetime.now(timezone.utc).isoformat(), + }, + )], + ) + + # ------------------------------------------------------- async writes + + async def enable_async_writes( + self, + coalesce_ms: int = 100, + max_batch: int = 128, + max_queue_size: int = 4096, + ) -> None: + """Switch insert() onto the coalescing AsyncWriteQueue.""" + if self._write_queue is not None: + return + + from iai_mcp.write_queue import AsyncWriteQueue + + ready = threading.Event() + loop_holder: dict = {} + + def _run() -> None: + loop = asyncio.new_event_loop() + loop_holder["loop"] = loop + asyncio.set_event_loop(loop) + ready.set() + try: + loop.run_forever() + finally: + loop.close() + + thread = threading.Thread( + target=_run, name="iai-mcp-async-writes", daemon=True, + ) + thread.start() + ready.wait() + bg_loop: asyncio.AbstractEventLoop = loop_holder["loop"] + + to_point = self._to_point + fire_hook = self._fire_graph_sync_hook + + class _RecordPointAdapter: + async def add(self, records: list) -> None: + points = [to_point(r) for r in records] + self._client.upsert(RECORDS_TABLE, points=points) + + _client = None # set below + + adapter = _RecordPointAdapter() + adapter._client = self._client + + def _on_flushed(batch: list) -> None: + for rec in batch: + fire_hook("insert", rec) + + queue = AsyncWriteQueue( + adapter, + coalesce_ms=coalesce_ms, + max_batch=max_batch, + max_queue_size=max_queue_size, + on_flushed=_on_flushed, + ) + asyncio.run_coroutine_threadsafe(queue.start(), bg_loop).result() + + self._async_loop = bg_loop + self._async_thread = thread + self._write_queue = queue + self.enable_provenance_queue() + + async def disable_async_writes(self) -> None: + """Drain the queue, tear down the background loop.""" + if self._write_queue is None: + self.disable_provenance_queue() + return + self.disable_provenance_queue() + bg_loop = self._async_loop + queue = self._write_queue + try: + asyncio.run_coroutine_threadsafe(queue.stop(), bg_loop).result() + finally: + if bg_loop is not None: + bg_loop.call_soon_threadsafe(bg_loop.stop) + if self._async_thread is not None: + self._async_thread.join(timeout=5.0) + self._write_queue = None + self._async_loop = None + self._async_thread = None + + # -------------------------------------------------- provenance queue + + def enable_provenance_queue(self, *, coalesce_ms: int = 50) -> None: + """Route provenance writes through a daemon-thread queue.""" + if self._provenance_queue is not None: + return + from iai_mcp.provenance_queue import ProvenanceWriteQueue + + q = ProvenanceWriteQueue(self, coalesce_ms=coalesce_ms) + q.start() + self._provenance_queue = q + + def disable_provenance_queue(self) -> None: + """Drain + stop the provenance queue.""" + q = self._provenance_queue + if q is None: + return + try: + q.flush(timeout=2.0) + except Exception: + pass + try: + q.stop() + except Exception: + pass + self._provenance_queue = None + + def queue_provenance_batch( + self, pairs: "list[tuple[UUID, dict]]" + ) -> None: + """Fire-and-forget provenance write.""" + if not pairs: + return + q = self._provenance_queue + if q is not None: + q.enqueue(pairs) + return + self.append_provenance_batch(pairs, records_cache=None) + + # ------------------------------------------------------- helpers + + def _scroll_all(self, collection_name: str, batch_size: int = 1000, + table_filter: str | None = None) -> list: + """Scroll through all points in a collection. + + When table_filter is provided, adds a must condition on the `table` + payload field to select a specific logical table within the collection. + """ + offset = None + all_points = [] + scroll_filter = None + if table_filter: + scroll_filter = Filter(must=[FieldCondition( + key="table", match=MatchValue(value=table_filter), + )]) + while True: + points, next_offset = self._client.scroll( + collection_name=collection_name, + limit=batch_size, + offset=offset, + scroll_filter=scroll_filter, + with_payload=True, + with_vectors=False, + ) + all_points.extend(points) + if next_offset is None: + break + offset = next_offset + return all_points + +# ---------------------------------------------------------- events + + def events_add(self, event_id: UUID, kind: str, severity: str, domain: str, + ts: datetime, data_json: str, session_id: str, source_ids_json: str) -> None: + """Add a single event row to the metadata collection (table=events).""" + point = PointStruct( + id=str(event_id), + vector={}, + payload={ + "table": EVENTS_TABLE, + "group_id": self._group_id, + "kind": kind, + "severity": severity, + "domain": domain, + "ts": ts.isoformat(), + "data_json": data_json, + "session_id": session_id, + "source_ids_json": source_ids_json, + }, + ) + self._client.upsert(collection_name=METADATA_TABLE, points=[point]) + + def events_query(self, kind: str | None = None, since: datetime | None = None, + severity: str | None = None, limit: int = 100) -> list[dict]: + """Query events from the metadata collection (table=events), newest first.""" + # Always filter by table=events + conditions = [FieldCondition(key="table", match=MatchValue(value=EVENTS_TABLE))] + if kind is not None: + conditions.append(FieldCondition(key="kind", match=MatchValue(value=kind))) + if severity is not None: + conditions.append(FieldCondition(key="severity", match=MatchValue(value=severity))) + + event_filter = Filter(must=conditions) if conditions else None + points, _ = self._client.scroll( + collection_name=METADATA_TABLE, + limit=limit * 10, # fetch extra for post-filtering + with_payload=True, + with_vectors=False, + scroll_filter=event_filter, + ) + + out: list[dict] = [] + for pt in points: + p = pt.payload + ts_str = p.get("ts", "") + try: + ts_dt = datetime.fromisoformat(ts_str) + except (ValueError, TypeError): + ts_dt = datetime.now(timezone.utc) + + if since is not None: + since_cmp = since if since.tzinfo else since.replace(tzinfo=timezone.utc) + if ts_dt < since_cmp: + continue + + raw_data = p.get("data_json") or "{}" + try: + data = json.loads(raw_data) + except (TypeError, json.JSONDecodeError): + data = {} + try: + source_ids = json.loads(p.get("source_ids_json") or "[]") + except (TypeError, json.JSONDecodeError): + source_ids = [] + + out.append({ + "id": p.get("id", str(pt.id)), + "kind": p.get("kind", ""), + "severity": p.get("severity") or None, + "domain": p.get("domain") or None, + "ts": ts_dt, + "data": data, + "session_id": p.get("session_id", "-"), + "source_ids": source_ids, + }) + + out.sort(key=lambda x: x["ts"], reverse=True) + return out[:limit] + + # --------------------------------------------------------- count_rows + + def count_rows(self, table_name: str) -> int: + """Return the number of points matching table_name in the appropriate collection. + + - `records` → counts points in records collection with table=records + - edges/events/budget/ratelimit → counts points in metadata collection + with table=edges/events/budget_ledger/ratelimit_ledger + """ + if table_name == RECORDS_TABLE: + try: + info = self._client.get_collection(RECORDS_TABLE) + return info.points_count or 0 + except Exception: + return 0 + elif table_name in _METADATA_TABLES: + try: + points = self._scroll_all(METADATA_TABLE, table_filter=table_name) + return len(points) + except Exception: + return 0 + else: + return 0 + + # ------------------------------------------------------- edges_as_df + + def edges_as_dataframe(self) -> "pd.DataFrame": + """Return all edges from the metadata collection as a pandas DataFrame.""" + try: + points = self._scroll_all(METADATA_TABLE, table_filter=EDGES_TABLE, batch_size=1000) + if not points: + return pd.DataFrame(columns=["src", "dst", "edge_type", "weight", "updated_at"]) + rows = [] + for pt in points: + p = pt.payload + rows.append({ + "src": p.get("src", ""), + "dst": p.get("dst", ""), + "edge_type": p.get("edge_type", ""), + "weight": float(p.get("weight", 0.0)), + "updated_at": p.get("updated_at", ""), + }) + return pd.DataFrame(rows) + except Exception: + return pd.DataFrame(columns=["src", "dst", "edge_type", "weight", "updated_at"]) + + # ------------------------------------------------------------------ db shim + + class _DbShim: + """LanceDB-compatible shim: store.db.open_table(name) -> table-like object.""" + + def __init__(self, store: "QdrantStore") -> None: + self._store = store + + def open_table(self, name: str) -> "QdrantStore._TableShim": + return self._TableShim(self._store, name) + + class _TableShim: + def __init__(self, store: "QdrantStore", name: str) -> None: + self._store = store + self._name = name + + def count_rows(self) -> int: + return self._store.count_rows(self._name) + + def to_pandas(self) -> pd.DataFrame: + if self._name == EDGES_TABLE: + return self._store.edges_as_dataframe() + elif self._name == EVENTS_TABLE: + return pd.DataFrame() + return pd.DataFrame() + + @property + def db(self) -> "QdrantStore._DbShim": + """LanceDB-compatible shim for store.db.open_table().""" + return self._DbShim(self) diff --git a/src/iai_mcp/store.py b/src/iai_mcp/store.py index 0f23dc4..d939415 100644 --- a/src/iai_mcp/store.py +++ b/src/iai_mcp/store.py @@ -47,8 +47,7 @@ from collections.abc import Sequence from typing import Callable from uuid import UUID -import lancedb -import pyarrow as pa + # W5: cached AESGCM cipher per store; reuse safe per # https://cryptography.io/en/latest/hazmat/primitives/aead/ — single AESGCM @@ -209,6 +208,7 @@ class MemoryStore: connect_kwargs: dict[str, object] = {} if read_consistency_interval is not None: connect_kwargs["read_consistency_interval"] = read_consistency_interval + import lancedb self.db = lancedb.connect(str(self.root / "lancedb"), **connect_kwargs) # Resolve the embedding dimension once so records table + insert guard agree. self._embed_dim: int = _resolve_embed_dim() @@ -1596,3 +1596,91 @@ class MemoryStore: if language == "__LEGACY_EMPTY__": rec.language = "" # post-construction: signal to migration path return rec + + # ---------------------------------------------------------- events + + def events_add(self, event_id: UUID, kind: str, severity: str, domain: str, + ts: datetime, data_json: str, session_id: str, source_ids_json: str) -> None: + """Add a single event row to the events table.""" + row = { + "id": str(event_id), + "kind": kind, + "severity": severity, + "domain": domain, + "ts": ts, + "data_json": data_json, + "session_id": session_id, + "source_ids_json": source_ids_json, + } + self.db.open_table(EVENTS_TABLE).add([row]) + + def events_query(self, kind: str | None = None, since: datetime | None = None, + severity: str | None = None, limit: int = 100) -> list[dict]: + """Query events matching filters, newest first.""" + tbl = self.db.open_table(EVENTS_TABLE) + df = tbl.to_pandas() + if df.empty: + return [] + if kind is not None: + df = df[df["kind"] == kind] + if severity is not None: + df = df[df["severity"] == severity] + if since is not None: + since_cmp = since if since.tzinfo is not None else since.replace(tzinfo=timezone.utc) + df = df[df["ts"] >= since_cmp] + if df.empty: + return [] + df = df.sort_values("ts", ascending=False).head(limit) + out: list[dict] = [] + for _, row in df.iterrows(): + raw_data = row["data_json"] or "{}" + try: + data = json.loads(raw_data) + except (TypeError, json.JSONDecodeError): + data = {} + try: + source_ids = json.loads(row["source_ids_json"] or "[]") + except (TypeError, json.JSONDecodeError): + source_ids = [] + out.append({ + "id": row["id"], + "kind": row["kind"], + "severity": row["severity"] or None, + "domain": row["domain"] or None, + "ts": row["ts"], + "data": data, + "session_id": row["session_id"], + "source_ids": source_ids, + }) + return out + + +# --------------------------------------------------------------------------- Qdrant backend + +def _use_qdrant() -> bool: + """Check if Qdrant backend is configured via environment.""" + return bool(os.environ.get("QDRANT_URL")) + + +def get_store( + path: Path | str | None = None, + user_id: str = "default", + read_consistency_interval: timedelta | None = None, +) -> "MemoryStore | QdrantStore": + """Factory: return MemoryStore (LanceDB) or QdrantStore based on env. + + When QDRANT_URL is set, returns QdrantStore. + Otherwise returns MemoryStore (LanceDB) — the legacy/local path. + """ + if _use_qdrant(): + from iai_mcp.qdrant_store import QdrantStore + return QdrantStore( + path=path, + user_id=user_id, + read_consistency_interval=read_consistency_interval, + ) + return MemoryStore( + path=path, + user_id=user_id, + read_consistency_interval=read_consistency_interval, + )