diff --git a/deploy/opencode/iai-mcp-memory-inject.js b/deploy/opencode/iai-mcp-memory-inject.js new file mode 100644 index 0000000..a1c47a3 --- /dev/null +++ b/deploy/opencode/iai-mcp-memory-inject.js @@ -0,0 +1,111 @@ +/** + * iai-mcp memory-injection plugin for OpenCode (approach A). + * + * Pulls session-start memory from the iai-mcp daemon's localhost HTTP adapter + * and injects it directly into the model's system prompt via the + * `experimental.chat.system.transform` hook. The memory CONTENT is placed in + * context — no tool call, no injected "INIT" user turn, so the session title + * is generated from the user's real first message (clean). + * + * REPLACES iai-mcp-session-init.js — do NOT run both. session-init forces a + * tool call via a phantom turn (hijacks the title); this plugin needs neither. + * + * Requires the daemon's HTTP listener: + * - set IAI_DAEMON_HTTP_PORT in the daemon's systemd unit (e.g. "0" for an + * OS-assigned port) and restart it. The daemon writes the live port to + * ~/.iai-mcp/.http.port, which this plugin reads. + * + * wake_depth=standard is requested so l0/l1/l2/rich_club carry real content + * (minimal mode returns only opaque handles — nothing worth injecting). + * Override via IAI_MCP_WAKE_DEPTH (minimal|standard|deep). + * + * Fail-safe: any error is swallowed; system-prompt assembly must never break. + */ + +const HOME = process.env.HOME || process.cwd(); +const PORT_FILE = `${HOME}/.iai-mcp/.http.port`; +const WAKE_DEPTH = process.env.IAI_MCP_WAKE_DEPTH || "standard"; +const FETCH_TIMEOUT_MS = 10000; // a cold runtime-graph build can exceed 5s +const MAX_ATTEMPTS = 3; // give up after this many failed fetches per session + +const memo = new Map(); // sessionID -> injected text (cached on success) +const attempts = new Map(); // sessionID -> failed-fetch count +const inflight = new Set(); // sessionIDs with a fetch in flight (dedupe warm+inject) + +async function readPort() { + const fs = await import("node:fs"); + try { + const port = parseInt(fs.readFileSync(PORT_FILE, "utf8").trim(), 10); + return Number.isInteger(port) && port > 0 ? port : null; + } catch { + return null; // daemon HTTP not enabled / not up yet + } +} + +async function fetchMemory(sessionId) { + const port = await readPort(); + if (!port) return ""; + const url = + `http://127.0.0.1:${port}/memory/session-context` + + `?session_id=${encodeURIComponent(sessionId)}` + + `&wake_depth=${encodeURIComponent(WAKE_DEPTH)}&format=text`; + try { + const res = await fetch(url, { signal: AbortSignal.timeout(FETCH_TIMEOUT_MS) }); + if (!res.ok) return ""; + return (await res.text()).trim(); + } catch { + return ""; + } +} + +// Centralised fetch+cache so the warm-on-create event and the per-turn +// transform hook share ONE in-flight fetch, one cache, and one attempt budget. +// This dedupes concurrent calls so a session never emits duplicate daemon-side +// session_started events. +async function ensureMemory(sessionId) { + const cached = memo.get(sessionId); + if (cached !== undefined) return cached; + if (inflight.has(sessionId)) return ""; // a fetch is already running + if ((attempts.get(sessionId) || 0) >= MAX_ATTEMPTS) return ""; + inflight.add(sessionId); + try { + const text = await fetchMemory(sessionId); + if (text) memo.set(sessionId, text); + else attempts.set(sessionId, (attempts.get(sessionId) || 0) + 1); + return text; + } finally { + inflight.delete(sessionId); + } +} + +export const IaiMcpMemoryInject = async () => { + return { + // Warm-on-create: prime the daemon graph cache (and our memo) as soon as a + // session appears — before the user's first turn — so the first transform + // injects immediately instead of paying the cold-build latency. + event: async ({ event }) => { + if (event.type !== "session.updated") return; + const sid = event.properties?.info?.id; + if (!sid) return; + try { + await ensureMemory(sid); + } catch { + // never throw from an event handler + } + }, + + "experimental.chat.system.transform": async (input, output) => { + try { + const sid = input?.sessionID; + if (!sid || !output || !Array.isArray(output.system)) return; + const text = await ensureMemory(sid); + if (text) { + output.system.push(`# iai-mcp memory (session start)\n${text}`); + } + } catch (err) { + // NEVER throw — a plugin error must not break system-prompt assembly. + console.error(`[iai-mcp] memory inject failed: ${err.message}`); + } + }, + }; +}; diff --git a/src/iai_mcp/core.py b/src/iai_mcp/core.py index ebaa714..3dfdfb9 100644 --- a/src/iai_mcp/core.py +++ b/src/iai_mcp/core.py @@ -741,6 +741,15 @@ def dispatch(store: MemoryStore, method: str, params: dict) -> dict: # wake_depth knob reaches the assembler. from iai_mcp.session import assemble_session_start, SessionStartPayload sid = params.get("session_id", "-") + # D5-02 per-call wake_depth override: a caller (e.g. the HTTP adapter + # serving session-context for system-prompt injection) can request + # standard/deep content WITHOUT mutating the global profile knob. + # Invalid/absent values fall back to the per-process profile state. + profile_state = _profile_state + wd_override = params.get("wake_depth") + if wd_override in ("minimal", "standard", "deep"): + profile_state = dict(_profile_state or {}) + profile_state["wake_depth"] = wd_override records_count = store.count_rows("records") if records_count == 0: empty = SessionStartPayload( @@ -756,7 +765,7 @@ def dispatch(store: MemoryStore, method: str, params: dict) -> dict: payload = assemble_session_start( store, assignment, rc, session_id=sid, - profile_state=_profile_state, + profile_state=profile_state, ) return _payload_to_json(payload) diff --git a/src/iai_mcp/daemon.py b/src/iai_mcp/daemon.py index 8c7fc82..24ebafb 100644 --- a/src/iai_mcp/daemon.py +++ b/src/iai_mcp/daemon.py @@ -1403,6 +1403,27 @@ async def main() -> int: mcp_socket = SocketServer(store, lock=lock, state=state) mcp_socket_task = asyncio.create_task(mcp_socket.serve()) + # Tier-1 PoC: optional localhost HTTP adapter over the same core.dispatch. + # OFF by default -- enabled only when IAI_DAEMON_HTTP_PORT is set (value "0" + # means OS-assigned ephemeral port, written to ~/.iai-mcp/.http.port for + # discovery). Lets MCP hosts that lack a server->client push mechanism + # (e.g. OpenCode) fetch session memory by URL. NO auth in this tier -- + # 127.0.0.1 bind is the only boundary; see http_server.py header. + http_server = None + http_task = None + _http_port_env = os.environ.get("IAI_DAEMON_HTTP_PORT") + if _http_port_env is not None: + try: + from iai_mcp.http_server import HttpServer + http_server = HttpServer(store, port=int(_http_port_env)) + http_task = asyncio.create_task(http_server.serve()) + except (ValueError, OSError): + # Bad port value or bind failure must NOT crash daemon boot + # (mirrors the maintenance try/except above). The unix socket + # and all lifecycle tasks come up regardless. + http_server = None + http_task = None + # Plan 10.6-01 Task 1.4: REMOVED `_propagate_idle_shutdown` # bridge task. The socket-side `idle_watcher` (which set # mcp_socket.shutdown_event after IDLE_CHECK_INTERVAL_SECS of @@ -1661,12 +1682,20 @@ async def main() -> int: mcp_socket.shutdown_event.set() except Exception: pass + # Drain the optional HTTP adapter the same way before cancellation. + if http_server is not None: + try: + http_server.shutdown_event.set() + except Exception: + pass _cancel_targets = [ tick_task, audit_task, s4_task, cascade_task, mcp_socket_task, cpu_watchdog_task, lifecycle_tick_task, ] + if http_task is not None: + _cancel_targets.append(http_task) for t in _cancel_targets: t.cancel() # Drain task exceptions silently: we're shutting down. diff --git a/src/iai_mcp/http_server.py b/src/iai_mcp/http_server.py new file mode 100644 index 0000000..19a72ed --- /dev/null +++ b/src/iai_mcp/http_server.py @@ -0,0 +1,325 @@ +"""Tier-1 localhost HTTP adapter over core.dispatch (PoC). + +Sibling transport to socket_server.py: both are thin adapters around the +single transport-agnostic entry point ``core.dispatch(store, method, params)`` +(D7-08 -- one dispatch function per method, no transport branching). This one +speaks HTTP/1.0 on 127.0.0.1 so an MCP host that has *no* server->client push +mechanism (e.g. OpenCode, which ignores the MCP ``instructions`` field) can +still pull session memory into its system prompt by fetching a URL. + +PoC SCOPE / NON-GOALS (deliberately minimal -- Tier 1): + - localhost-only bind. NEVER 0.0.0.0. The unix socket's security boundary is + filesystem perms (chmod 0o600, T-04-07); a TCP port has none, so we limit + blast radius to loopback and ship NO auth here. Token auth is the Tier-2 + follow-up before this is anything but a local convenience. + - one request per connection (Connection: close). No keep-alive, no chunked + encoding, no HTTP/1.1 pipelining. Simplest robust thing for fetch() clients. + - disabled unless IAI_DAEMON_HTTP_PORT is set (see daemon.main wiring). Absent + => no TCP surface at all, identical to today. + +Constitutional guards (mirror socket_server.py): + - C-DISPATCHER-FSM-ISOLATION: this adapter calls core.dispatch ONLY; it never + touches the daemon FSM. FSM transitions stay owned by daemon.py's tick. + - C3 ZERO API COST: stdlib + core.dispatch only; no SDK references. + - C5 LITERAL PRESERVATION: transport-only; zero record mutation here. + - R3 head-of-line: dispatch is sync + 50-500 ms, so it runs on + asyncio.to_thread exactly like socket_server.py:257. + - R5 fail-loud: dispatch raises map to HTTP status codes (see _STATUS_FOR). +""" + +from __future__ import annotations + +import asyncio +import json +import os +import sys +import time +from pathlib import Path +from typing import Any +from urllib.parse import parse_qs, urlsplit + +from iai_mcp.concurrency import SOCKET_PATH +from iai_mcp.core import UnknownMethodError, dispatch + +# Discovery file: consumers (OpenCode plugin) read the live port from here, +# mirroring the SOCKET_PATH convention (~/.iai-mcp/.daemon.sock). Sits next to +# the socket so both transports share the one well-known dir. +HTTP_PORT_FILE: Path = SOCKET_PATH.parent / ".http.port" + +# dispatch-exception -> HTTP status, mirroring socket_server.py's JSON-RPC code +# mapping (ERR_METHOD_NOT_FOUND/-32601, ERR_INVALID_PARAMS/-32602, ...). +_BAD_REQUEST = 400 +_NOT_FOUND = 404 +_METHOD_NOT_ALLOWED = 405 +_INTERNAL = 500 + +_MAX_BODY_BYTES = 1 << 20 # 1 MiB cap on POST bodies; refuse larger with 400. + + +def _payload_to_text(result: Any) -> str: + """Flatten a session_start_payload dict into a system-prompt-ready block. + + Best-effort and forgiving: any missing/empty key is skipped, so this is + correct at every wake_depth: + - standard/deep populate the human-readable content layers + (l0/l1/rich_club/l2) -- the useful case for system-prompt injection. + - minimal (the default) leaves those empty and emits only the compact + pointer/handle fields (~30 tok). Those are opaque references, not + content; we still surface them on a trailing line so the block is + never silently empty, but a host that wants real memory text in the + prompt must run wake_depth>=standard. + Token-accounting ints and breakpoint_marker are dropped. + """ + if not isinstance(result, dict): + return str(result) + parts: list[str] = [] + for key in ("l0", "l1", "rich_club"): + val = result.get(key) + if isinstance(val, str) and val.strip(): + parts.append(val.strip()) + l2 = result.get("l2") + if isinstance(l2, list): + for item in l2: + text = item if isinstance(item, str) else json.dumps(item, ensure_ascii=False) + if text.strip(): + parts.append(text.strip()) + # Compact handles (populated at minimal; also present at standard/deep). + handles = [ + result.get(k) + for k in ("identity_pointer", "compact_handle", "brain_handle", "topic_cluster_hint") + ] + handles = [h for h in handles if isinstance(h, str) and h.strip()] + if handles: + parts.append(" ".join(handles)) + return "\n\n".join(parts) + + +class HttpServer: + """Per-connection HTTP/1.0 server routing to core.dispatch. + + Routes (all GET unless noted): + - GET /healthz -> {"ok": true} (no dispatch) + - GET /memory/session-context -> dispatch("session_start_payload", + {session_id?}); ?format=text returns + a flattened text/plain block. + - POST /rpc {method, params} -> dispatch(method, params); generic + passthrough mirroring the socket. + + Constructor args mirror SocketServer: a shared MemoryStore singleton plus + bind knobs. shutdown_event lets daemon.main drain gracefully. + """ + + def __init__( + self, + store: Any, + *, + host: str = "127.0.0.1", + port: int = 0, + port_file: Path | None = None, + ) -> None: + self.store = store + self.host = host + self.port = port # 0 => OS-assigned ephemeral port (written to port_file) + self.port_file = port_file if port_file is not None else HTTP_PORT_FILE + self.bound_port: int | None = None + self.last_activity_ts: float = time.monotonic() + self.active_connections: int = 0 + self.shutdown_event: asyncio.Event = asyncio.Event() + + # -- request parsing ---------------------------------------------------- + + async def _read_request( + self, reader: asyncio.StreamReader + ) -> tuple[str, str, dict[str, str], bytes] | None: + """Parse one HTTP request. Returns (method, target, headers, body) or None on EOF.""" + request_line = await reader.readline() + if not request_line: + return None + try: + method, target, _version = request_line.decode("latin-1").rstrip("\r\n").split(" ", 2) + except ValueError: + raise _HttpError(_BAD_REQUEST, "malformed request line") + headers: dict[str, str] = {} + while True: + line = await reader.readline() + if line in (b"\r\n", b"\n", b""): + break + raw = line.decode("latin-1").rstrip("\r\n") + if ":" in raw: + name, _, value = raw.partition(":") + headers[name.strip().lower()] = value.strip() + body = b"" + length = headers.get("content-length") + if length is not None: + try: + n = int(length) + except ValueError: + raise _HttpError(_BAD_REQUEST, "bad content-length") + if n > _MAX_BODY_BYTES: + raise _HttpError(_BAD_REQUEST, "body too large") + if n > 0: + body = await reader.readexactly(n) + return method.upper(), target, headers, body + + # -- routing ------------------------------------------------------------ + + async def _route(self, method: str, target: str, body: bytes) -> tuple[int, str, str]: + """Return (status, content_type, text). Raises _HttpError for 4xx/5xx shapes.""" + split = urlsplit(target) + path = split.path + query = parse_qs(split.query) + + if path == "/healthz": + if method != "GET": + raise _HttpError(_METHOD_NOT_ALLOWED, "GET only") + return 200, "application/json", json.dumps({"ok": True}) + + if path == "/memory/session-context": + if method != "GET": + raise _HttpError(_METHOD_NOT_ALLOWED, "GET only") + session_id = (query.get("session_id") or [None])[0] + params: dict = {} + if session_id: + params["session_id"] = session_id + # Per-call wake_depth override (minimal|standard|deep). standard/deep + # populate l0/l1/l2/rich_club -- the content worth injecting into a + # system prompt; core.py validates and ignores junk values. + wake_depth = (query.get("wake_depth") or [None])[0] + if wake_depth: + params["wake_depth"] = wake_depth + result = await self._dispatch("session_start_payload", params) + if (query.get("format") or ["json"])[0] == "text": + return 200, "text/plain; charset=utf-8", _payload_to_text(result) + return 200, "application/json", json.dumps(result, ensure_ascii=False) + + if path == "/rpc": + if method != "POST": + raise _HttpError(_METHOD_NOT_ALLOWED, "POST only") + try: + envelope = json.loads(body or b"{}") + except json.JSONDecodeError as e: + raise _HttpError(_BAD_REQUEST, f"invalid json: {e}") + if not isinstance(envelope, dict) or not isinstance(envelope.get("method"), str): + raise _HttpError(_BAD_REQUEST, "expected {method, params}") + result = await self._dispatch(envelope["method"], envelope.get("params") or {}) + return 200, "application/json", json.dumps(result, ensure_ascii=False) + + raise _HttpError(_NOT_FOUND, f"no route for {path}") + + async def _dispatch(self, method: str, params: dict) -> Any: + """Run core.dispatch off-loop (R3) and map raises to _HttpError (R5).""" + try: + # CRITICAL R3: dispatch is sync + 50-500 ms; to_thread prevents + # head-of-line blocking across connections (same as socket_server). + return await asyncio.to_thread(dispatch, self.store, method, params) + except UnknownMethodError as e: + raise _HttpError(_NOT_FOUND, f"unknown method '{e.args[0]}'") + except KeyError as e: + raise _HttpError(_BAD_REQUEST, f"missing required param: {e.args[0]!r}") + except TypeError as e: + raise _HttpError(_BAD_REQUEST, str(e)) + except Exception as e: # noqa: BLE001 -- HTTP must never crash daemon + raise _HttpError(_INTERNAL, str(e)) + + # -- connection handler ------------------------------------------------- + + async def handle(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter) -> None: + """One coroutine per connection. Reads one request, responds, closes.""" + self.active_connections += 1 + t0 = time.monotonic() + method, target = "-", "-" + try: + self.last_activity_ts = t0 + try: + parsed = await self._read_request(reader) + if parsed is None: + return + method, target, _headers, body = parsed + status, content_type, text = await self._route(method, target, body) + except _HttpError as e: + status, content_type, text = e.status, "application/json", json.dumps( + {"error": e.message} + ) + except asyncio.IncompleteReadError: + status, content_type, text = _BAD_REQUEST, "application/json", json.dumps( + {"error": "truncated body"} + ) + self._write_response(writer, status, content_type, text) + await writer.drain() + # Per-request access log -> stderr (journal). Matches the daemon's + # existing {"event":...} JSON-line style so it greps cleanly. + print( + json.dumps({ + "event": "http_request", + "method": method, + "target": target, + "status": status, + "bytes": len(text.encode("utf-8")), + "ms": round((time.monotonic() - t0) * 1000), + }), + file=sys.stderr, + flush=True, + ) + except (ConnectionResetError, BrokenPipeError, ConnectionAbortedError): + # Client hung up mid-write (fetch aborted, host killed). Expected; + # not a daemon fault -- mirrors socket_server.py's handling. + pass + finally: + self.active_connections -= 1 + try: + writer.close() + await writer.wait_closed() + except Exception: + pass + + @staticmethod + def _write_response( + writer: asyncio.StreamWriter, status: int, content_type: str, text: str + ) -> None: + body = text.encode("utf-8") + reason = { + 200: "OK", 400: "Bad Request", 404: "Not Found", + 405: "Method Not Allowed", 500: "Internal Server Error", + }.get(status, "OK") + head = ( + f"HTTP/1.0 {status} {reason}\r\n" + f"Content-Type: {content_type}\r\n" + f"Content-Length: {len(body)}\r\n" + "Connection: close\r\n" + "\r\n" + ) + writer.write(head.encode("latin-1") + body) + + # -- lifecycle ---------------------------------------------------------- + + async def serve(self) -> None: + """Bind 127.0.0.1:port, write the port-file, serve until shutdown_event.""" + server = await asyncio.start_server(self.handle, host=self.host, port=self.port) + self.bound_port = server.sockets[0].getsockname()[1] + # Discovery: publish the live port so consumers needn't guess. Best + # effort -- a failed write just means callers must be told the port. + try: + self.port_file.parent.mkdir(parents=True, exist_ok=True) + self.port_file.write_text(str(self.bound_port), encoding="utf-8") + os.chmod(self.port_file, 0o600) + except OSError: + pass + try: + async with server: + await self.shutdown_event.wait() + server.close() + await server.wait_closed() + finally: + try: + self.port_file.unlink() + except (FileNotFoundError, OSError): + pass + + +class _HttpError(Exception): + """Internal control-flow exception carrying an HTTP status + message.""" + + def __init__(self, status: int, message: str) -> None: + super().__init__(message) + self.status = status + self.message = message diff --git a/tests/test_http_server.py b/tests/test_http_server.py new file mode 100644 index 0000000..06da13c --- /dev/null +++ b/tests/test_http_server.py @@ -0,0 +1,247 @@ +"""End-to-end tests for the Tier-1 localhost HTTP adapter (http_server.py). + +Mirrors tests/test_daemon_dispatcher.py: boot the REAL HttpServer on an +ephemeral 127.0.0.1 port and drive it with raw HTTP/1.0 over a real TCP +socket. core.dispatch is monkeypatched to a fake so these tests exercise the +*transport adapter* (routing, error->status mapping, port-file discovery, +graceful shutdown) without needing a real MemoryStore / LanceDB. +""" +from __future__ import annotations + +import asyncio +import json + +import pytest + +from iai_mcp.core import UnknownMethodError + + +def _fake_dispatch(store, method, params): + """Stand-in for core.dispatch: deterministic, no store needed.""" + if method == "session_start_payload": + return { + "l0": "pinned identity", + "l1": "recent summary", + "l2": ["episode one", "episode two"], + "rich_club": "hub concepts", + "total_cached_tokens": 0, + "total_dynamic_tokens": 1000, + # echo the override so the transport-threading test can assert it. + "wake_depth": params.get("wake_depth", "minimal"), + } + if method == "minimal_payload": + # wake_depth=minimal: content layers empty, only opaque handles set. + return { + "l0": "", "l1": "", "l2": [], "rich_club": "", + "identity_pointer": "", + "brain_handle": "", + "topic_cluster_hint": "", + "compact_handle": "", + "wake_depth": "minimal", + } + if method == "needs_param": + # Simulate core.py's params["cue"] KeyError path. + return {"cue": params["cue"]} + if method == "boom": + raise RuntimeError("kaboom") + raise UnknownMethodError(method) + + +async def _http_request(port, method, path, body=None, *, timeout=5.0): + """Send one HTTP/1.0 request, read the full response to EOF, parse it.""" + reader, writer = await asyncio.wait_for( + asyncio.open_connection("127.0.0.1", port), timeout=timeout + ) + try: + body_bytes = body.encode("utf-8") if isinstance(body, str) else (body or b"") + head = f"{method} {path} HTTP/1.0\r\n" + if body_bytes: + head += f"Content-Length: {len(body_bytes)}\r\n" + head += "Connection: close\r\n\r\n" + writer.write(head.encode("latin-1") + body_bytes) + await writer.drain() + raw = await asyncio.wait_for(reader.read(-1), timeout=timeout) + finally: + writer.close() + try: + await writer.wait_closed() + except Exception: + pass + head_bytes, _, payload = raw.partition(b"\r\n\r\n") + lines = head_bytes.split(b"\r\n") + status = int(lines[0].decode("latin-1").split(" ")[1]) + headers = {} + for line in lines[1:]: + name, _, val = line.decode("latin-1").partition(":") + headers[name.strip().lower()] = val.strip() + return status, headers, payload.decode("utf-8") + + +async def _with_http_server(coro_fn, *, port_file, monkeypatch): + """Boot the real HttpServer with _fake_dispatch on an ephemeral port.""" + from iai_mcp import http_server as hs + + monkeypatch.setattr(hs, "dispatch", _fake_dispatch) + server = hs.HttpServer(store=object(), host="127.0.0.1", port=0, port_file=port_file) + task = asyncio.create_task(server.serve()) + for _ in range(250): # wait for bind (bound_port set inside serve()) + if server.bound_port is not None: + break + await asyncio.sleep(0.01) + if server.bound_port is None: + server.shutdown_event.set() + await asyncio.wait_for(task, timeout=5) + raise AssertionError("server never bound") + try: + return await coro_fn(server) + finally: + server.shutdown_event.set() + try: + await asyncio.wait_for(task, timeout=5) + except Exception: + pass + + +def test_healthz_returns_ok(tmp_path, monkeypatch): + async def _runner(server): + return await _http_request(server.bound_port, "GET", "/healthz") + + status, headers, body = asyncio.run( + _with_http_server(_runner, port_file=tmp_path / ".http.port", monkeypatch=monkeypatch) + ) + assert status == 200 + assert headers["content-type"] == "application/json" + assert json.loads(body) == {"ok": True} + + +def test_session_context_json(tmp_path, monkeypatch): + async def _runner(server): + return await _http_request( + server.bound_port, "GET", "/memory/session-context?session_id=abc" + ) + + status, headers, body = asyncio.run( + _with_http_server(_runner, port_file=tmp_path / ".http.port", monkeypatch=monkeypatch) + ) + assert status == 200 + assert headers["content-type"] == "application/json" + payload = json.loads(body) + assert payload["l0"] == "pinned identity" + assert payload["l2"] == ["episode one", "episode two"] + + +def test_session_context_threads_wake_depth(tmp_path, monkeypatch): + async def _runner(server): + return await _http_request( + server.bound_port, + "GET", + "/memory/session-context?session_id=abc&wake_depth=standard", + ) + + status, _headers, body = asyncio.run( + _with_http_server(_runner, port_file=tmp_path / ".http.port", monkeypatch=monkeypatch) + ) + assert status == 200 + assert json.loads(body)["wake_depth"] == "standard" + + +def test_session_context_text_format(tmp_path, monkeypatch): + async def _runner(server): + return await _http_request( + server.bound_port, "GET", "/memory/session-context?format=text" + ) + + status, headers, body = asyncio.run( + _with_http_server(_runner, port_file=tmp_path / ".http.port", monkeypatch=monkeypatch) + ) + assert status == 200 + assert headers["content-type"].startswith("text/plain") + # Flattened block keeps the human layers in order, drops token ints. + assert body == ( + "pinned identity\n\nrecent summary\n\nhub concepts\n\nepisode one\n\nepisode two" + ) + + +def test_payload_to_text_renders_handles_at_minimal(): + """At minimal wake_depth, _payload_to_text surfaces the compact handles + (content layers empty) so the injected block is never silently empty.""" + from iai_mcp.http_server import _payload_to_text + + text = _payload_to_text(_fake_dispatch(None, "minimal_payload", {})) + assert text == " " + + +def test_rpc_post_passthrough(tmp_path, monkeypatch): + async def _runner(server): + return await _http_request( + server.bound_port, + "POST", + "/rpc", + json.dumps({"method": "session_start_payload", "params": {"session_id": "z"}}), + ) + + status, _headers, body = asyncio.run( + _with_http_server(_runner, port_file=tmp_path / ".http.port", monkeypatch=monkeypatch) + ) + assert status == 200 + assert json.loads(body)["rich_club"] == "hub concepts" + + +def test_unknown_route_404(tmp_path, monkeypatch): + async def _runner(server): + return await _http_request(server.bound_port, "GET", "/nope") + + status, _headers, body = asyncio.run( + _with_http_server(_runner, port_file=tmp_path / ".http.port", monkeypatch=monkeypatch) + ) + assert status == 404 + assert "no route" in json.loads(body)["error"] + + +def test_unknown_method_maps_to_404(tmp_path, monkeypatch): + async def _runner(server): + return await _http_request( + server.bound_port, "POST", "/rpc", json.dumps({"method": "ghost"}) + ) + + status, _headers, body = asyncio.run( + _with_http_server(_runner, port_file=tmp_path / ".http.port", monkeypatch=monkeypatch) + ) + assert status == 404 + assert "unknown method 'ghost'" in json.loads(body)["error"] + + +def test_internal_error_maps_to_500(tmp_path, monkeypatch): + async def _runner(server): + return await _http_request( + server.bound_port, "POST", "/rpc", json.dumps({"method": "boom"}) + ) + + status, _headers, body = asyncio.run( + _with_http_server(_runner, port_file=tmp_path / ".http.port", monkeypatch=monkeypatch) + ) + assert status == 500 + assert "kaboom" in json.loads(body)["error"] + + +def test_wrong_verb_405(tmp_path, monkeypatch): + async def _runner(server): + return await _http_request(server.bound_port, "POST", "/healthz") + + status, _headers, _body = asyncio.run( + _with_http_server(_runner, port_file=tmp_path / ".http.port", monkeypatch=monkeypatch) + ) + assert status == 405 + + +def test_port_file_written_and_cleaned(tmp_path, monkeypatch): + port_file = tmp_path / ".http.port" + + async def _runner(server): + # While serving, the port-file holds the live bound port. + assert port_file.read_text() == str(server.bound_port) + return server.bound_port + + asyncio.run(_with_http_server(_runner, port_file=port_file, monkeypatch=monkeypatch)) + # After graceful shutdown, the discovery file is removed. + assert not port_file.exists() diff --git a/tests/test_session_start_wake_depth.py b/tests/test_session_start_wake_depth.py new file mode 100644 index 0000000..4cdef16 --- /dev/null +++ b/tests/test_session_start_wake_depth.py @@ -0,0 +1,58 @@ +"""Unit tests for the per-call wake_depth override in core.dispatch. + +The HTTP adapter (and any caller) can request standard/deep session-start +content without flipping the global profile knob. These tests verify the +override is threaded into assemble_session_start's profile_state, that invalid +values fall back to the per-process state, and that the global _profile_state is +never mutated. assemble_session_start + build_runtime_graph are monkeypatched +so no real store/embedder is needed. +""" +from __future__ import annotations + +import iai_mcp.core as core +import iai_mcp.retrieve as retrieve +import iai_mcp.session as session + + +class _FakeStore: + def count_rows(self, table): # non-empty -> takes the assemble branch + return 5 + + +def _patch(monkeypatch): + captured = {} + + def fake_assemble(store, assignment, rc, *, session_id, profile_state): + captured["profile_state"] = profile_state + captured["session_id"] = session_id + return session.SessionStartPayload(l0="content") + + monkeypatch.setattr(session, "assemble_session_start", fake_assemble) + monkeypatch.setattr(retrieve, "build_runtime_graph", lambda store: (None, {}, None)) + monkeypatch.setattr(core, "_profile_state", {"wake_depth": "minimal", "literal_preservation": 0.5}) + return captured + + +def test_wake_depth_override_threaded(monkeypatch): + captured = _patch(monkeypatch) + core.dispatch(_FakeStore(), "session_start_payload", {"session_id": "s", "wake_depth": "standard"}) + ps = captured["profile_state"] + assert ps["wake_depth"] == "standard" + # other profile knobs are preserved in the override copy. + assert ps["literal_preservation"] == 0.5 + # the module global must NOT be mutated by the override. + assert core._profile_state["wake_depth"] == "minimal" + + +def test_invalid_wake_depth_falls_back_to_profile(monkeypatch): + captured = _patch(monkeypatch) + core.dispatch(_FakeStore(), "session_start_payload", {"session_id": "s", "wake_depth": "ultra"}) + # junk value ignored -> uses the per-process profile state (identity, minimal). + assert captured["profile_state"] is core._profile_state + assert captured["profile_state"]["wake_depth"] == "minimal" + + +def test_absent_wake_depth_uses_profile(monkeypatch): + captured = _patch(monkeypatch) + core.dispatch(_FakeStore(), "session_start_payload", {"session_id": "s"}) + assert captured["profile_state"] is core._profile_state