diff --git a/src/iai_mcp/qdrant_store.py b/src/iai_mcp/qdrant_store.py index 99e674c..8668535 100644 --- a/src/iai_mcp/qdrant_store.py +++ b/src/iai_mcp/qdrant_store.py @@ -4,10 +4,10 @@ Replaces LanceDB with Qdrant for remote vector storage. Collections (payload-partitioned, per Qdrant best practices): - `records` : MemoryRecord rows (1024-dim cosine vectors) - All points carry `table: "records"` + `group_id` payload. + All points carry `table: "records"` + `group_id` payload. - `metadata` : Payload-only (no vectors) containing edges, events, - budget_ledger, ratelimit_ledger. - Each point carries `table` + `group_id` payload. + budget_ledger, ratelimit_ledger. + Each point carries `table` + `group_id` payload. Both collections use `is_tenant: true` payload indexes on `table` for co-located storage and per-table HNSW graphs (payload_m=16, m=0). @@ -18,6 +18,18 @@ and data_json (events). AD = record UUID bytes. The Qdrant client connects to a remote server via QDRANT_URL + QDRANT_API_KEY environment variables (overridable via constructor). + +Shim interface (_TableShim): +- ``add()`` routes by table name to dedicated upsert methods for all 5 tables + (EDGES_TABLE, EVENTS_TABLE, BUDGET_TABLE, RATELIMIT_TABLE, RECORDS_TABLE). +- ``to_pandas()`` constructs full DataFrames for edges, events, budget, and + ratelimit tables (records uses records_as_dataframe). +- ``_parse_where()`` translates LanceDB-style SQL clauses into Qdrant Filter: + ``>`` / ``>=`` / ``<`` / ``<=`` emit Range conditions; ``=`` emits + MatchValue; multiple clauses joined with AND. +- ``events_query()`` scrolls all matching events and applies Python-side + filtering so events lacking ts_epoch (pre-migration rows) are never + silently dropped. Accepts both str and datetime for the since parameter. """ from __future__ import annotations @@ -661,7 +673,7 @@ class QdrantStore: limit=batch_size, offset=offset, scroll_filter=qdrant_filter, - with_payload=True, + with_payload=columns if columns else True, ) for point in points: yield self._from_point(point) @@ -702,7 +714,8 @@ class QdrantStore: for point in points: # Filter to requested columns row = {k: v for k, v in point.payload.items() if k in columns} - row["id"] = point.id + if "id" in columns: + row["id"] = point.id yield row if next_offset is None: @@ -1199,6 +1212,7 @@ class QdrantStore: "severity": severity, "domain": domain, "ts": ts.isoformat(), + "ts_epoch": ts.timestamp(), "data_json": data_json, "session_id": session_id, "source_ids_json": source_ids_json, @@ -1206,9 +1220,14 @@ class QdrantStore: ) self._client.upsert(collection_name=METADATA_TABLE, points=[point]) - def events_query(self, kind: str | None = None, since: datetime | None = None, + def events_query(self, kind: str | None = None, since: str | datetime | None = None, severity: str | None = None, limit: int = 100) -> list[dict]: - """Query events from the metadata collection (table=events), newest first.""" + """Query events from the metadata collection (table=events), newest first. + + Scrolls all matching events and applies Python-side filtering so that + events lacking ``ts_epoch`` (pre-migration rows) are never silently + dropped. + """ # Always filter by table=events conditions = [FieldCondition(key="table", match=MatchValue(value=EVENTS_TABLE))] if kind is not None: @@ -1217,16 +1236,35 @@ class QdrantStore: conditions.append(FieldCondition(key="severity", match=MatchValue(value=severity))) event_filter = Filter(must=conditions) if conditions else None - points, _ = self._client.scroll( - collection_name=METADATA_TABLE, - limit=limit * 10, # fetch extra for post-filtering - with_payload=True, - with_vectors=False, - scroll_filter=event_filter, - ) + # Resolve since to a datetime for Python-side comparison + since_dt: datetime | None = None + if since is not None: + if isinstance(since, str): + since_dt = datetime.fromisoformat(since) + else: + since_dt = since if since.tzinfo else since.replace(tzinfo=timezone.utc) + + # Scroll ALL matching events (no limit on scroll — limit applies after sorting) + all_points: list = [] + offset = None + while True: + points, next_offset = self._client.scroll( + collection_name=METADATA_TABLE, + limit=1000, + offset=offset, + with_payload=True, + with_vectors=False, + scroll_filter=event_filter, + ) + all_points.extend(points) + if next_offset is None: + break + offset = next_offset + + # Python-side filtering and sorting out: list[dict] = [] - for pt in points: + for pt in all_points: p = pt.payload ts_str = p.get("ts", "") try: @@ -1234,9 +1272,10 @@ class QdrantStore: except (ValueError, TypeError): ts_dt = datetime.now(timezone.utc) - if since is not None: - since_cmp = since if since.tzinfo else since.replace(tzinfo=timezone.utc) - if ts_dt < since_cmp: + # Python-side since filter (works for all events, regardless of ts_epoch) + if since_dt is not None: + cmp = ts_dt if ts_dt.tzinfo else ts_dt.replace(tzinfo=timezone.utc) + if cmp < since_dt: continue raw_data = p.get("data_json") or "{}" @@ -1321,6 +1360,7 @@ class QdrantStore: "stability", "difficulty", "last_reviewed", "never_decay", "never_merge", "detail_level", "s5_trust_score", "structure_hv", + "provenance_json", "created_at", "schema_version", ]) rows = [] for r in records: @@ -1343,6 +1383,9 @@ class QdrantStore: "detail_level": r.detail_level, "s5_trust_score": r.s5_trust_score, "structure_hv": r.structure_hv.hex() if r.structure_hv else "", + "provenance_json": json.dumps(r.provenance) if r.provenance else "[]", + "created_at": str(r.created_at) if r.created_at else None, + "schema_version": r.schema_version, }) return pd.DataFrame(rows) except Exception: @@ -1353,6 +1396,7 @@ class QdrantStore: "stability", "difficulty", "last_reviewed", "never_decay", "never_merge", "detail_level", "s5_trust_score", "structure_hv", + "provenance_json", "created_at", "schema_version", ]) # ------------------------------------------------------------------ db shim @@ -1380,9 +1424,100 @@ class QdrantStore: elif self._name == RECORDS_TABLE: return self._store.records_as_dataframe() elif self._name == EVENTS_TABLE: - return pd.DataFrame() + return self._events_df() + elif self._name == BUDGET_TABLE: + return self._budget_df() + elif self._name == RATELIMIT_TABLE: + return self._ratelimit_df() return pd.DataFrame() + def _events_df(self) -> pd.DataFrame: + """Return all events as a pandas DataFrame.""" + try: + points = self._store._scroll_all( + METADATA_TABLE, table_filter=EVENTS_TABLE, batch_size=1000, + ) + if not points: + return pd.DataFrame(columns=[ + "id", "kind", "severity", "domain", "ts", + "data_json", "session_id", "source_ids_json", + ]) + rows = [] + for pt in points: + p = pt.payload + ts_str = p.get("ts", "") + try: + ts_dt = datetime.fromisoformat(ts_str) + except (ValueError, TypeError): + ts_dt = datetime.fromisoformat("1970-01-01") + rows.append({ + "id": p.get("id", str(pt.id)), + "kind": p.get("kind", ""), + "severity": p.get("severity") or "", + "domain": p.get("domain") or "", + "ts": ts_dt, + "data_json": p.get("data_json", ""), + "session_id": p.get("session_id", "-"), + "source_ids_json": p.get("source_ids_json", "[]"), + }) + return pd.DataFrame(rows) + except Exception: + return pd.DataFrame(columns=[ + "id", "kind", "severity", "domain", "ts", + "data_json", "session_id", "source_ids_json", + ]) + + def _budget_df(self) -> pd.DataFrame: + """Return all budget_ledger rows as a pandas DataFrame.""" + try: + points = self._store._scroll_all( + METADATA_TABLE, table_filter=BUDGET_TABLE, batch_size=1000, + ) + if not points: + return pd.DataFrame(columns=["date", "usd_spent", "kind", "ts"]) + rows = [] + for pt in points: + p = pt.payload + ts_str = p.get("ts", "") + try: + ts_dt = datetime.fromisoformat(ts_str) + except (ValueError, TypeError): + ts_dt = datetime.fromisoformat("1970-01-01") + rows.append({ + "date": p.get("date", ""), + "usd_spent": float(p.get("usd_spent", 0.0)), + "kind": p.get("kind", ""), + "ts": ts_dt, + }) + return pd.DataFrame(rows) + except Exception: + return pd.DataFrame(columns=["date", "usd_spent", "kind", "ts"]) + + def _ratelimit_df(self) -> pd.DataFrame: + """Return all ratelimit_ledger rows as a pandas DataFrame.""" + try: + points = self._store._scroll_all( + METADATA_TABLE, table_filter=RATELIMIT_TABLE, batch_size=1000, + ) + if not points: + return pd.DataFrame(columns=["ts", "status_code", "endpoint"]) + rows = [] + for pt in points: + p = pt.payload + ts_str = p.get("ts", "") + try: + ts_dt = datetime.fromisoformat(ts_str) + except (ValueError, TypeError): + ts_dt = datetime.fromisoformat("1970-01-01") + rows.append({ + "ts": ts_dt, + "status_code": int(p.get("status_code", 0)), + "endpoint": p.get("endpoint", ""), + }) + return pd.DataFrame(rows) + except Exception: + return pd.DataFrame(columns=["ts", "status_code", "endpoint"]) + def delete(self, where: str) -> None: """Delete rows from the table matching the where clause. @@ -1486,29 +1621,228 @@ class QdrantStore: except Exception: pass - @staticmethod - def _parse_where(where: str | None) -> Filter | None: + def add(self, rows: list[dict]) -> None: + """Insert rows into the table (LanceDB-compatible shim). + + LanceDB-compatible shim: ``table.add([{"col": "val", ...}])``. + Routes to the appropriate Qdrant collection based on table name. + """ + if not rows: + return + if self._name == EDGES_TABLE: + self._add_edges(rows) + elif self._name == EVENTS_TABLE: + self._add_events(rows) + elif self._name == BUDGET_TABLE: + self._add_budget(rows) + elif self._name == RATELIMIT_TABLE: + self._add_ratelimit(rows) + elif self._name == RECORDS_TABLE: + self._add_records(rows) + + def _add_edges(self, rows: list[dict]) -> None: + points = [] + for row in rows: + points.append(PointStruct( + id=str(uuid4()), + vector={}, + payload={ + "table": EDGES_TABLE, + "group_id": self._store._group_id, + "src": row.get("src", ""), + "dst": row.get("dst", ""), + "edge_type": row.get("edge_type", ""), + "weight": float(row.get("weight", 0.0)), + "updated_at": ( + row["updated_at"].isoformat() + if isinstance(row.get("updated_at"), datetime) + else str(row.get("updated_at", "")) + ), + }, + )) + if points: + self._store._client.upsert( + collection_name=METADATA_TABLE, points=points, + ) + + def _add_events(self, rows: list[dict]) -> None: + points = [] + for row in rows: + ts_val = row.get("ts") + ts_str = ( + ts_val.isoformat() + if isinstance(ts_val, datetime) + else str(ts_val) if ts_val else "" + ) + points.append(PointStruct( + id=str(row.get("id", uuid4())), + vector={}, + payload={ + "table": EVENTS_TABLE, + "group_id": self._store._group_id, + "id": str(row.get("id", "")), + "kind": row.get("kind", ""), + "severity": row.get("severity") or "", + "domain": row.get("domain") or "", + "ts": ts_str, + "data_json": row.get("data_json", ""), + "session_id": row.get("session_id", "-"), + "source_ids_json": row.get("source_ids_json", "[]"), + }, + )) + if points: + self._store._client.upsert( + collection_name=METADATA_TABLE, points=points, + ) + + def _add_budget(self, rows: list[dict]) -> None: + points = [] + for row in rows: + ts_val = row.get("ts") + ts_str = ( + ts_val.isoformat() + if isinstance(ts_val, datetime) + else str(ts_val) if ts_val else "" + ) + points.append(PointStruct( + id=str(uuid4()), + vector={}, + payload={ + "table": BUDGET_TABLE, + "group_id": self._store._group_id, + "date": row.get("date", ""), + "usd_spent": float(row.get("usd_spent", 0.0)), + "kind": row.get("kind", ""), + "ts": ts_str, + }, + )) + if points: + self._store._client.upsert( + collection_name=METADATA_TABLE, points=points, + ) + + def _add_ratelimit(self, rows: list[dict]) -> None: + points = [] + for row in rows: + ts_val = row.get("ts") + ts_str = ( + ts_val.isoformat() + if isinstance(ts_val, datetime) + else str(ts_val) if ts_val else "" + ) + points.append(PointStruct( + id=str(uuid4()), + vector={}, + payload={ + "table": RATELIMIT_TABLE, + "group_id": self._store._group_id, + "ts": ts_str, + "status_code": int(row.get("status_code", 0)), + "endpoint": row.get("endpoint", ""), + }, + )) + if points: + self._store._client.upsert( + collection_name=METADATA_TABLE, points=points, + ) + + def _add_records(self, rows: list[dict]) -> None: + """Insert record rows — converts dicts to PointStruct via _to_point.""" + points = [] + for row in rows: + # Build a minimal MemoryRecord from the dict so _to_point works + from uuid import UUID as _UUID + rec = MemoryRecord( + id=_UUID(row["id"]), + tier=row.get("tier", "episodic"), + literal_surface="", + aaak_index=row.get("aaak_index", ""), + embedding=list(row.get("embedding", [0.0] * self._store._embed_dim)), + community_id=_UUID(row["community_id"]) if row.get("community_id") else None, + centrality=float(row.get("centrality", 0.0)), + detail_level=int(row.get("detail_level", 1)), + pinned=bool(row.get("pinned", False)), + stability=float(row.get("stability", 0.0)), + difficulty=float(row.get("difficulty", 0.0)), + last_reviewed=None, + never_decay=bool(row.get("never_decay", False)), + never_merge=bool(row.get("never_merge", False)), + provenance=[], + created_at=None, + updated_at=None, + tags=json.loads(row.get("tags_json") or "[]"), + language=row.get("language", "en"), + s5_trust_score=float(row.get("s5_trust_score", 0.5)), + profile_modulation_gain={}, + schema_version=int(row.get("schema_version", 1)), + structure_hv=b"", + ) + points.append(self._store._to_point(rec)) + if points: + self._store._client.upsert( + collection_name=RECORDS_TABLE, points=points, + ) + + def _parse_where(self, where: str | None) -> Filter | None: """Parse a LanceDB-style where clause into a Qdrant Filter. - Supported format: ``"key = 'value' AND key2 = 'value2'"``. - Each ``key = 'value'`` segment becomes a ``FieldCondition`` - combined with ``must`` (AND semantics). + Supported grammar: + - ``key = 'value'`` or ``key = "value"`` → MatchValue + - ``key > N`` or ``key >= N`` → Range(gt/gte) + - ``key < N`` or ``key <= N`` → Range(lt/lte) + - Multiple clauses joined by AND → Filter(must=[...]) - Returns None if the clause is empty or cannot be parsed. + Numeric values are auto-detected; string values must be + quoted. Returns None if the clause is empty or cannot be parsed. """ if not where: return None - # Find all key = 'value' or key = "value" patterns. - pairs = re.findall( - r"""([a-z_]+)\s*=\s*['"]([^'"]+)['"]""", - where, - ) - if not pairs: + + # Split on AND (case-insensitive) to get individual clauses. + # Use regex to avoid splitting inside quoted strings. + clauses = re.split(r'\s+AND\s+', where, flags=re.IGNORECASE) + conditions: list = [] + + for clause in clauses: + clause = clause.strip() + if not clause: + continue + + # Try numeric comparison operators first: >, >=, <, <= + m = re.match( + r"""([a-z_]+)\s*(>=|<=|>|<)\s*(-?\d+(?:\.\d+)?)\s*$""", + clause, + ) + if m: + key, op, val_str = m.group(1), m.group(2), m.group(3) + val = float(val_str) + if op == ">": + conditions.append(FieldCondition(key=key, range=models.Range(gt=val))) + elif op == ">=": + conditions.append(FieldCondition(key=key, range=models.Range(gte=val))) + elif op == "<": + conditions.append(FieldCondition(key=key, range=models.Range(lt=val))) + elif op == "<=": + conditions.append(FieldCondition(key=key, range=models.Range(lte=val))) + continue + + # Try string equality: key = 'value' or key = "value" + m = re.match( + r"""([a-z_]+)\s*=\s*['"]([^'"]+)['"]\s*$""", + clause, + ) + if m: + key, val = m.group(1), m.group(2) + conditions.append( + FieldCondition(key=key, match=MatchValue(value=val)) + ) + continue + + # Unparseable clause — skip silently (LanceDB would too). + continue + + if not conditions: return None - conditions = [ - FieldCondition(key=key, match=MatchValue(value=val)) - for key, val in pairs - ] if len(conditions) == 1: return Filter(must=conditions) return Filter(must=conditions)