iai-mcp-opencode/src/iai_mcp/http_server.py

327 lines
14 KiB
Python
Raw Normal View History

"""Tier-1 localhost HTTP adapter over core.dispatch (PoC).
Sibling transport to socket_server.py: both are thin adapters around the
single transport-agnostic entry point ``core.dispatch(store, method, params)``
(D7-08 -- one dispatch function per method, no transport branching). This one
speaks HTTP/1.0 on 127.0.0.1 so an MCP host that has *no* server->client push
mechanism (e.g. OpenCode, which ignores the MCP ``instructions`` field) can
still pull session memory into its system prompt by fetching a URL.
PoC SCOPE / NON-GOALS (deliberately minimal -- Tier 1):
- localhost-only bind. NEVER 0.0.0.0. The unix socket's security boundary is
filesystem perms (chmod 0o600, T-04-07); a TCP port has none, so we limit
blast radius to loopback and ship NO auth here. Token auth is the Tier-2
follow-up before this is anything but a local convenience.
- one request per connection (Connection: close). No keep-alive, no chunked
encoding, no HTTP/1.1 pipelining. Simplest robust thing for fetch() clients.
- disabled unless IAI_DAEMON_HTTP_PORT is set (see daemon.main wiring). Absent
=> no TCP surface at all, identical to today.
Constitutional guards (mirror socket_server.py):
- C-DISPATCHER-FSM-ISOLATION: this adapter calls core.dispatch ONLY; it never
touches the daemon FSM. FSM transitions stay owned by daemon.py's tick.
- C3 ZERO API COST: stdlib + core.dispatch only; no SDK references.
- C5 LITERAL PRESERVATION: transport-only; zero record mutation here.
- R3 head-of-line: dispatch is sync + 50-500 ms, so it runs on
asyncio.to_thread exactly like socket_server.py:257.
- R5 fail-loud: dispatch raises map to HTTP status codes (see _STATUS_FOR).
"""
from __future__ import annotations
import asyncio
import json
import os
import sys
import time
from pathlib import Path
from typing import Any
from urllib.parse import parse_qs, urlsplit
from iai_mcp.concurrency import SOCKET_PATH
from iai_mcp.core import UnknownMethodError, dispatch
# Port discovery: the OpenCode plugin (and any other consumer) reads the live
# TCP port from this file. It lives beside SOCKET_PATH (~/.iai-mcp/.daemon.sock)
# so the unix-socket and HTTP transports share a single well-known directory.
HTTP_PORT_FILE: Path = SOCKET_PATH.parent.joinpath(".http.port")

# HTTP status codes that dispatch exceptions are translated into; this mirrors
# socket_server.py's JSON-RPC code mapping (ERR_METHOD_NOT_FOUND/-32601,
# ERR_INVALID_PARAMS/-32602, ...).
_BAD_REQUEST, _NOT_FOUND, _METHOD_NOT_ALLOWED, _INTERNAL = 400, 404, 405, 500

# Largest accepted POST body (1 MiB); a bigger Content-Length is refused with 400.
_MAX_BODY_BYTES = 1024 * 1024
def _payload_to_text(result: Any) -> str:
"""Flatten a session_start_payload dict into a system-prompt-ready block.
Best-effort and forgiving: any missing/empty key is skipped, so this is
correct at every wake_depth:
- standard/deep populate the human-readable content layers
(l0/l1/rich_club/l2) -- the useful case for system-prompt injection.
- minimal (the default) leaves those empty and emits only the compact
pointer/handle fields (~30 tok). Those are opaque references, not
content; we still surface them on a trailing line so the block is
never silently empty, but a host that wants real memory text in the
prompt must run wake_depth>=standard.
Token-accounting ints and breakpoint_marker are dropped.
"""
if not isinstance(result, dict):
return str(result)
parts: list[str] = []
headers = {"l0": "## L0 Identity", "l1": "## L1 Recent Summary", "rich_club": "## Rich Club"}
for key in ("l0", "l1", "rich_club"):
val = result.get(key)
if isinstance(val, str) and val.strip():
parts.append(f"{headers[key]}\n{val.strip()}")
l2 = result.get("l2")
if isinstance(l2, list):
for item in l2:
text = item if isinstance(item, str) else json.dumps(item, ensure_ascii=False)
if text.strip():
parts.append(f"## L2 Episode\n{text.strip()}")
# Compact handles (populated at minimal; also present at standard/deep).
handles = [
result.get(k)
for k in ("identity_pointer", "compact_handle", "brain_handle", "topic_cluster_hint")
]
handles = [h for h in handles if isinstance(h, str) and h.strip()]
if handles:
parts.append(f"## Handles\n{' '.join(handles)}")
return "\n\n".join(parts)
class HttpServer:
    """Per-connection HTTP/1.0 server routing to core.dispatch.

    Routes:
      - GET  /healthz                -> {"ok": true} (no dispatch)
      - GET  /memory/session-context -> dispatch("session_start_payload",
                                        {session_id?, wake_depth?});
                                        ?format=text returns a flattened
                                        text/plain block.
      - POST /rpc {method, params}   -> dispatch(method, params); generic
                                        passthrough mirroring the socket.

    Constructor args mirror SocketServer: a shared MemoryStore singleton plus
    bind knobs. Setting ``shutdown_event`` makes ``serve`` stop listening and
    remove the port file.

    Once a request starts arriving it always gets exactly one response:
    malformed input maps to 4xx, a client that stalls while sending its request
    gets 408 after ``read_timeout`` seconds, and any unexpected failure while
    routing becomes a JSON 500 instead of escaping the connection callback
    with nothing written.
    """

    # Upper bound on header lines per request. The port is unauthenticated,
    # so an endless header stream must not grow the headers dict forever.
    _MAX_HEADER_LINES = 100
    _REQUEST_TIMEOUT = 408

    def __init__(
        self,
        store: Any,
        *,
        host: str = "127.0.0.1",
        port: int = 0,
        port_file: Path | None = None,
        read_timeout: float | None = 30.0,
    ) -> None:
        self.store = store
        self.host = host
        self.port = port  # 0 => OS-assigned ephemeral port (written to port_file)
        self.port_file = port_file if port_file is not None else HTTP_PORT_FILE
        # Seconds allowed to receive one whole request (line + headers + body);
        # None disables the limit. Without a limit an idle client would hold
        # its connection, and active_connections, open indefinitely.
        self.read_timeout = read_timeout
        self.bound_port: int | None = None
        self.last_activity_ts: float = time.monotonic()
        self.active_connections: int = 0
        self.shutdown_event: asyncio.Event = asyncio.Event()

    # -- request parsing ----------------------------------------------------

    @staticmethod
    async def _readline(reader: asyncio.StreamReader) -> bytes:
        """Read one line, mapping an over-long line to a 400.

        StreamReader.readline raises ValueError once a line exceeds the
        reader's buffer limit; left unmapped it escaped ``handle`` and the
        client got no response at all.
        """
        try:
            return await reader.readline()
        except ValueError:
            raise _HttpError(_BAD_REQUEST, "request line or header too long") from None

    async def _read_request(
        self, reader: asyncio.StreamReader
    ) -> tuple[str, str, dict[str, str], bytes] | None:
        """Parse one HTTP request.

        Returns:
            (METHOD upper-cased, target, lower-cased headers, body), or None
            when the client closed the connection before sending anything.

        Raises:
            _HttpError: 400 for a malformed request line, over-long lines, too
                many headers, or a bad/oversized Content-Length.
            asyncio.IncompleteReadError: the body ended before Content-Length.
        """
        request_line = await self._readline(reader)
        if not request_line:
            return None
        try:
            method, target, _version = (
                request_line.decode("latin-1").rstrip("\r\n").split(" ", 2)
            )
        except ValueError:
            raise _HttpError(_BAD_REQUEST, "malformed request line") from None

        headers: dict[str, str] = {}
        # One extra iteration leaves room for the blank terminator line.
        for _ in range(self._MAX_HEADER_LINES + 1):
            line = await self._readline(reader)
            if line in (b"\r\n", b"\n", b""):
                break
            raw = line.decode("latin-1").rstrip("\r\n")
            if ":" in raw:
                name, _, value = raw.partition(":")
                headers[name.strip().lower()] = value.strip()
        else:
            raise _HttpError(_BAD_REQUEST, "too many headers")

        body = b""
        length = headers.get("content-length")
        if length is not None:
            try:
                n = int(length)
            except ValueError:
                raise _HttpError(_BAD_REQUEST, "bad content-length") from None
            if n < 0:
                raise _HttpError(_BAD_REQUEST, "bad content-length")
            if n > _MAX_BODY_BYTES:
                raise _HttpError(_BAD_REQUEST, "body too large")
            if n > 0:
                body = await reader.readexactly(n)
        return method.upper(), target, headers, body

    async def _read_request_timed(
        self, reader: asyncio.StreamReader
    ) -> tuple[str, str, dict[str, str], bytes] | None:
        """_read_request bounded by ``read_timeout``; a stall maps to 408."""
        try:
            return await asyncio.wait_for(self._read_request(reader), self.read_timeout)
        except asyncio.TimeoutError:
            raise _HttpError(self._REQUEST_TIMEOUT, "request timeout") from None

    # -- routing ------------------------------------------------------------

    async def _route(self, method: str, target: str, body: bytes) -> tuple[int, str, str]:
        """Return (status, content_type, text). Raises _HttpError for 4xx/5xx shapes."""
        split = urlsplit(target)
        path = split.path
        query = parse_qs(split.query)

        if path == "/healthz":
            if method != "GET":
                raise _HttpError(_METHOD_NOT_ALLOWED, "GET only")
            return 200, "application/json", json.dumps({"ok": True})

        if path == "/memory/session-context":
            if method != "GET":
                raise _HttpError(_METHOD_NOT_ALLOWED, "GET only")
            params: dict = {}
            session_id = (query.get("session_id") or [None])[0]
            if session_id:
                params["session_id"] = session_id
            # Per-call wake_depth override (minimal|standard|deep). standard/deep
            # populate l0/l1/l2/rich_club -- the content worth injecting into a
            # system prompt; validation of the value is left to core.
            wake_depth = (query.get("wake_depth") or [None])[0]
            if wake_depth:
                params["wake_depth"] = wake_depth
            result = await self._dispatch("session_start_payload", params)
            if (query.get("format") or ["json"])[0] == "text":
                return 200, "text/plain; charset=utf-8", _payload_to_text(result)
            return 200, "application/json", json.dumps(result, ensure_ascii=False)

        if path == "/rpc":
            if method != "POST":
                raise _HttpError(_METHOD_NOT_ALLOWED, "POST only")
            try:
                envelope = json.loads(body or b"{}")
            except ValueError as e:
                # JSONDecodeError and, for non-UTF-8 bytes, UnicodeDecodeError
                # both subclass ValueError.
                raise _HttpError(_BAD_REQUEST, f"invalid json: {e}") from None
            if not isinstance(envelope, dict) or not isinstance(envelope.get("method"), str):
                raise _HttpError(_BAD_REQUEST, "expected {method, params}")
            # Falsy params (absent, null, [], ...) still mean "no params".
            params = envelope.get("params") or {}
            if not isinstance(params, dict):
                raise _HttpError(_BAD_REQUEST, "params must be an object")
            result = await self._dispatch(envelope["method"], params)
            return 200, "application/json", json.dumps(result, ensure_ascii=False)

        raise _HttpError(_NOT_FOUND, f"no route for {path}")

    async def _dispatch(self, method: str, params: dict) -> Any:
        """Run core.dispatch off-loop (R3) and map raises to _HttpError (R5)."""
        try:
            # CRITICAL R3: dispatch is sync + 50-500 ms; to_thread prevents
            # head-of-line blocking across connections (same as socket_server).
            return await asyncio.to_thread(dispatch, self.store, method, params)
        except UnknownMethodError as e:
            name = e.args[0] if e.args else method
            raise _HttpError(_NOT_FOUND, f"unknown method '{name}'")
        except KeyError as e:
            key = e.args[0] if e.args else None
            raise _HttpError(_BAD_REQUEST, f"missing required param: {key!r}")
        except TypeError as e:
            raise _HttpError(_BAD_REQUEST, str(e))
        except Exception as e:  # noqa: BLE001 -- HTTP must never crash daemon
            raise _HttpError(_INTERNAL, str(e))

    # -- connection handler -------------------------------------------------

    async def handle(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter) -> None:
        """One coroutine per connection. Reads one request, responds, closes.

        Any exception escaping the routing step would skip the response and
        leave the client with a bare closed connection, so every failure is
        turned into an HTTP error response here. Client disconnects are
        swallowed; the writer is always closed.
        """
        self.active_connections += 1
        t0 = time.monotonic()
        method, target = "-", "-"
        try:
            self.last_activity_ts = t0
            try:
                parsed = await self._read_request_timed(reader)
                if parsed is None:
                    return
                method, target, _headers, body = parsed
                status, content_type, text = await self._route(method, target, body)
            except _HttpError as e:
                status, content_type, text = e.status, "application/json", json.dumps(
                    {"error": e.message}
                )
            except asyncio.IncompleteReadError:
                status, content_type, text = _BAD_REQUEST, "application/json", json.dumps(
                    {"error": "truncated body"}
                )
            except (ConnectionResetError, BrokenPipeError, ConnectionAbortedError):
                raise  # client went away; handled by the outer except
            except Exception as e:  # noqa: BLE001 -- HTTP must never crash daemon
                status, content_type, text = _INTERNAL, "application/json", json.dumps(
                    {"error": f"{type(e).__name__}: {e}"}
                )
                print(
                    json.dumps({
                        "event": "http_internal_error",
                        "method": method,
                        "target": target,
                        "error": repr(e),
                    }),
                    file=sys.stderr,
                    flush=True,
                )

            self._write_response(writer, status, content_type, text)
            await writer.drain()
            # Per-request access log -> stderr (journal). Matches the daemon's
            # existing {"event":...} JSON-line style so it greps cleanly.
            print(
                json.dumps({
                    "event": "http_request",
                    "method": method,
                    "target": target,
                    "status": status,
                    "bytes": len(text.encode("utf-8")),
                    "ms": round((time.monotonic() - t0) * 1000),
                }),
                file=sys.stderr,
                flush=True,
            )
        except (ConnectionResetError, BrokenPipeError, ConnectionAbortedError):
            # Client hung up mid-request or mid-write (fetch aborted, host
            # killed). Expected; not a daemon fault.
            pass
        finally:
            self.active_connections -= 1
            try:
                writer.close()
                await writer.wait_closed()
            except Exception:
                pass

    @staticmethod
    def _write_response(
        writer: asyncio.StreamWriter, status: int, content_type: str, text: str
    ) -> None:
        """Queue a complete HTTP/1.0 response; the caller drains the writer.

        The body is UTF-8 encoded and Content-Length is its byte length.
        """
        body = text.encode("utf-8")
        reason = {
            200: "OK",
            400: "Bad Request",
            404: "Not Found",
            405: "Method Not Allowed",
            408: "Request Timeout",
            500: "Internal Server Error",
        }.get(status, "OK")
        head = (
            f"HTTP/1.0 {status} {reason}\r\n"
            f"Content-Type: {content_type}\r\n"
            f"Content-Length: {len(body)}\r\n"
            "Connection: close\r\n"
            "\r\n"
        )
        writer.write(head.encode("latin-1") + body)

    # -- lifecycle ----------------------------------------------------------

    async def serve(self) -> None:
        """Bind host:port, write the port-file, serve until shutdown_event is set."""
        server = await asyncio.start_server(self.handle, host=self.host, port=self.port)
        self.bound_port = server.sockets[0].getsockname()[1]
        # Discovery: publish the live port so consumers needn't guess. Best
        # effort -- a failed write just means callers must be told the port.
        try:
            self.port_file.parent.mkdir(parents=True, exist_ok=True)
            self.port_file.write_text(str(self.bound_port), encoding="utf-8")
            os.chmod(self.port_file, 0o600)
        except OSError:
            pass
        try:
            async with server:
                await self.shutdown_event.wait()
                server.close()
                await server.wait_closed()
        finally:
            try:
                self.port_file.unlink()
            except OSError:  # includes FileNotFoundError
                pass
class _HttpError(Exception):
"""Internal control-flow exception carrying an HTTP status + message."""
def __init__(self, status: int, message: str) -> None:
super().__init__(message)
self.status = status
self.message = message