Co-Authored-By: Claude <noreply@anthropic.com> Co-Authored-By: XNLLLLH <XNLLLLH@users.noreply.github.com>
403 lines
13 KiB
Python
403 lines
13 KiB
Python
"""Tests for — the 7th unix-socket message type `session_open`.
|
|
|
|
Covers:
|
|
- Valid session_open message is accepted; reply = {"ok": True, "reason": "session_open_queued"}.
|
|
- Missing session_id is tolerated (optional field per spec).
|
|
- Wrong-typed session_id is rejected at validation.
|
|
- After a valid session_open, state contains:
|
|
* first_turn_pending[session_id] = non-empty ISO-8601 timestamp string
|
|
* hippea_cascade_request with pending=True
|
|
- The 6 prior message types still work (no regression).
|
|
|
|
Uses a real `serve_control_socket(store, lock, state, shutdown)` behind a
|
|
threaded background event-loop so asyncio.run() calls in the test body don't
|
|
tear the server down between requests.
|
|
"""
|
|
from __future__ import annotations

import asyncio
import json
import shutil
import tempfile
import threading
from collections.abc import Iterator
from pathlib import Path
from typing import Any
from unittest.mock import MagicMock

import pytest

from iai_mcp import concurrency, daemon_state
from iai_mcp.concurrency import (
    ProcessLock,
    _dispatch_socket_request,
    _validate_socket_message,
    serve_control_socket,
)
|
|
|
|
|
|
# ---------------------------------------------------------------- fixtures
|
|
|
|
|
|
@pytest.fixture
def tmp_socket(tmp_path: Path) -> Iterator[Path]:
    """Short unique unix-socket path (macOS ~104-byte limit).

    Prefers ``tmp_path / "d.sock"``. When that string is longer than 100
    characters, a directory created with ``tempfile.mkdtemp`` is used
    instead and removed again once the test has finished, so repeated runs
    do not accumulate ``iai-sock-*`` directories.

    Yields:
        The socket path. The socket file itself is not created here.
    """
    candidate = tmp_path / "d.sock"
    if len(str(candidate)) <= 100:
        yield candidate
        return

    fallback_dir = Path(tempfile.mkdtemp(prefix="iai-sock-"))
    try:
        yield fallback_dir / "d.sock"
    finally:
        # mkdtemp() never cleans up after itself; ignore_errors keeps the
        # teardown best-effort if the directory is already gone.
        shutil.rmtree(fallback_dir, ignore_errors=True)
|
|
|
|
|
|
@pytest.fixture
def tmp_state(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> Path:
    """Point ``daemon_state.STATE_PATH`` at a per-test file under ``tmp_path``.

    The attribute is replaced through ``monkeypatch``. The path is returned
    so tests can inspect whatever gets written there; the fixture does not
    create the file itself.
    """
    state_file = tmp_path.joinpath(".daemon-state.json")
    monkeypatch.setattr(daemon_state, "STATE_PATH", state_file)
    return state_file
|
|
|
|
|
|
# ---------------------------------------------------------------- unit tests
|
|
|
|
|
|
def test_validate_session_open_accepts_valid_message() -> None:
    """A fully populated ``session_open`` message passes validation."""
    message = {
        "type": "session_open",
        "session_id": "s1",
        "ts": "2026-04-19T00:00:00Z",
    }
    accepted, error = _validate_socket_message(message)
    assert accepted is True
    assert error is None
|
|
|
|
|
|
def test_validate_session_open_accepts_missing_session_id() -> None:
    """A ``session_open`` message without ``session_id`` still validates.

    The field is optional per spec, so its absence must not be an error.
    """
    message = {"type": "session_open"}
    accepted, error = _validate_socket_message(message)
    assert accepted is True
    assert error is None
|
|
|
|
|
|
def test_validate_session_open_rejects_non_string_session_id() -> None:
    """An integer ``session_id`` is rejected and the error names the field."""
    message = {"type": "session_open", "session_id": 123, "ts": "x"}
    accepted, error = _validate_socket_message(message)
    assert accepted is False
    assert error is not None
    assert "session_id" in error
|
|
|
|
|
|
def test_validate_session_open_rejects_non_string_ts() -> None:
    """An integer ``ts`` is rejected and the error names the field."""
    message = {"type": "session_open", "session_id": "s1", "ts": 42}
    accepted, error = _validate_socket_message(message)
    assert accepted is False
    assert error is not None
    assert "ts" in error
|
|
|
|
|
|
# ---------------------------------------------------------------- dispatcher unit
|
|
|
|
|
|
def _make_fake_store() -> Any:
    """Return a fresh ``MagicMock`` standing in for the store argument."""
    fake_store = MagicMock()
    return fake_store
|
|
|
|
|
|
def _make_fake_lock() -> Any:
    """Return a fresh ``MagicMock`` specced against ``ProcessLock``."""
    fake_lock = MagicMock(spec=ProcessLock)
    return fake_lock
|
|
|
|
|
|
# We call asyncio.run() directly in tests below; no asyncio marker needed.
|
|
|
|
|
|
def test_dispatch_session_open_queues_first_turn_and_cascade(
    tmp_state: Path,
) -> None:
    """session_open handler queues the first-turn hook and the cascade.

    Checks that the dispatcher replies ``session_open_queued``, stores a
    non-empty string stamp under ``first_turn_pending[session_id]``, sets
    ``hippea_cascade_request`` with ``pending=True`` for the same session,
    echoes the session in ``last_session_open``, and writes the state file.
    """
    session_id = "sess-abc"
    state: dict = {"fsm_state": "WAKE"}
    request = {
        "type": "session_open",
        "session_id": session_id,
        "ts": "2026-04-19T12:00:00Z",
    }

    reply = asyncio.run(
        _dispatch_socket_request(
            request, _make_fake_store(), _make_fake_lock(), state
        )
    )
    assert reply == {"ok": True, "reason": "session_open_queued"}

    # First-turn hook flag: a non-empty string (ISO-8601 timestamp, post-fix).
    first_turn = state.get("first_turn_pending")
    assert isinstance(first_turn, dict)
    stamp = first_turn.get(session_id)
    assert isinstance(stamp, str) and stamp

    # Cascade task flag.
    cascade_request = state.get("hippea_cascade_request")
    assert isinstance(cascade_request, dict)
    assert cascade_request.get("pending") is True
    assert cascade_request.get("session_id") == session_id

    # Echo kept for introspection.
    last_open = state.get("last_session_open")
    assert isinstance(last_open, dict)
    assert last_open.get("session_id") == session_id

    # The state must also have been written to the redirected state file.
    assert tmp_state.exists()
    persisted = json.loads(tmp_state.read_text())
    assert persisted.get("hippea_cascade_request", {}).get("pending") is True
|
|
|
|
|
|
def test_dispatch_session_open_missing_session_id_ok(tmp_state: Path) -> None:
    """A ``session_open`` request without ``session_id`` is still queued."""
    state: dict = {"fsm_state": "WAKE"}
    request = {"type": "session_open", "ts": "2026-04-19T12:00:00Z"}
    reply = asyncio.run(
        _dispatch_socket_request(
            request, _make_fake_store(), _make_fake_lock(), state
        )
    )
    assert reply.get("ok") is True
    assert reply.get("reason") == "session_open_queued"
|
|
|
|
|
|
def test_dispatch_session_open_clips_long_session_id(tmp_state: Path) -> None:
    """session_id is clipped to 128 chars (ASVS V5 output hardening).

    Sends a 1000-character id and checks the value echoed in
    ``state["last_session_open"]["session_id"]`` is a string of at most
    128 characters.
    """
    state: dict = {"fsm_state": "WAKE"}
    long_id = "a" * 1000
    req = {"type": "session_open", "session_id": long_id, "ts": "x"}
    resp = asyncio.run(
        _dispatch_socket_request(req, _make_fake_store(), _make_fake_lock(), state)
    )
    assert resp["ok"] is True

    # Require the echo to be present: falling back to {} / "" would let the
    # length check pass even if the handler had recorded nothing at all.
    last = state.get("last_session_open")
    assert isinstance(last, dict)
    clipped = last.get("session_id")
    assert isinstance(clipped, str)
    assert len(clipped) <= 128
|
|
|
|
|
|
# ---------------------------------------------------------------- no-regression
|
|
|
|
|
|
def test_dispatch_force_wake_still_works(tmp_state: Path) -> None:
    """``force_wake`` is still answered with ``wake_queued``."""
    state: dict = {"fsm_state": "WAKE"}
    request = {"type": "force_wake", "ts": "x"}
    reply = asyncio.run(
        _dispatch_socket_request(
            request, _make_fake_store(), _make_fake_lock(), state
        )
    )
    assert reply == {"ok": True, "reason": "wake_queued"}
|
|
|
|
|
|
def test_dispatch_force_rem_still_works(tmp_state: Path) -> None:
    """``force_rem`` is still answered with ``rem_queued``."""
    state: dict = {"fsm_state": "WAKE"}
    request = {"type": "force_rem", "ts": "x"}
    reply = asyncio.run(
        _dispatch_socket_request(
            request, _make_fake_store(), _make_fake_lock(), state
        )
    )
    assert reply == {"ok": True, "reason": "rem_queued"}
|
|
|
|
|
|
def test_dispatch_pause_still_works(tmp_state: Path) -> None:
    """``pause`` replies ``paused: True`` and sets ``scheduler_paused``."""
    state: dict = {"fsm_state": "WAKE"}
    request = {"type": "pause"}
    reply = asyncio.run(
        _dispatch_socket_request(
            request, _make_fake_store(), _make_fake_lock(), state
        )
    )
    assert reply == {"ok": True, "paused": True}
    assert state["scheduler_paused"] is True
|
|
|
|
|
|
def test_dispatch_resume_still_works(tmp_state: Path) -> None:
    """``resume`` replies ``paused: False`` and clears ``scheduler_paused``."""
    state: dict = {"fsm_state": "WAKE", "scheduler_paused": True}
    request = {"type": "resume"}
    reply = asyncio.run(
        _dispatch_socket_request(
            request, _make_fake_store(), _make_fake_lock(), state
        )
    )
    assert reply == {"ok": True, "paused": False}
    assert state["scheduler_paused"] is False
|
|
|
|
|
|
def test_dispatch_user_initiated_sleep_still_works(tmp_state: Path) -> None:
    """``user_initiated_sleep`` from WAKE replies ok with state TRANSITIONING."""
    state: dict = {"fsm_state": "WAKE"}
    request = {"type": "user_initiated_sleep", "reason": "night", "ts": "x"}
    reply = asyncio.run(
        _dispatch_socket_request(
            request, _make_fake_store(), _make_fake_lock(), state
        )
    )
    assert reply.get("ok") is True
    assert reply.get("state") == "TRANSITIONING"
|
|
|
|
|
|
def test_dispatch_status_still_works(tmp_state: Path) -> None:
    """``status`` replies ok, reports the FSM state, and carries a version."""
    state: dict = {"fsm_state": "WAKE"}
    request = {"type": "status"}
    reply = asyncio.run(
        _dispatch_socket_request(
            request, _make_fake_store(), _make_fake_lock(), state
        )
    )
    assert reply.get("ok") is True
    assert reply.get("state") == "WAKE"
    # Version echoed per Plan 04-gap-1.
    assert "version" in reply
|
|
|
|
|
|
# ---------------------------------------------------------------- round-trip
|
|
|
|
|
|
class _ThreadedDaemon:
    """Real serve_control_socket on background thread + event loop.

    Reuses the pattern from tests/test_core_bedtime_inject.py but drives the
    production _dispatch_socket_request so we exercise the real 7th-message
    path end-to-end.

    The server runs on its own daemon thread with a private event loop, so
    the separate ``asyncio.run()`` calls made by the test body do not tear
    it down between requests.
    """

    def __init__(self, path: Path, state: dict) -> None:
        """Record the socket path and shared state; nothing is started yet.

        Args:
            path: Passed to ``serve_control_socket`` as ``socket_path``.
            state: Dict handed to the dispatcher; the test keeps its own
                reference so it can observe mutations after a reply.
        """
        self.path = path
        self.state = state
        self.lock = MagicMock(spec=ProcessLock)
        self.store = MagicMock()
        self.shutdown = None  # populated on the loop thread
        # Exception that ended the server loop, if any (None otherwise).
        self.error: BaseException | None = None
        self._loop: asyncio.AbstractEventLoop | None = None
        self._thread: threading.Thread | None = None
        self._ready = threading.Event()

    def start(self) -> None:
        """Start the server thread and block until it signals readiness.

        Raises:
            AssertionError: Readiness was not signalled within 5 seconds.
            RuntimeError: The server loop failed before becoming ready; the
                original exception is chained as the cause.
        """

        def _run() -> None:
            loop = asyncio.new_event_loop()
            self._loop = loop
            asyncio.set_event_loop(loop)
            self.shutdown = asyncio.Event()

            async def _serve() -> None:
                # Hand the real dispatcher the state we own.
                async def _dispatcher(req: dict) -> dict:
                    return await _dispatch_socket_request(
                        req, self.store, self.lock, self.state
                    )

                task = asyncio.create_task(
                    serve_control_socket(
                        self.store,
                        self.lock,
                        self.state,
                        self.shutdown,  # type: ignore[arg-type]
                        dispatcher=_dispatcher,
                        socket_path=self.path,
                    )
                )
                # Give the server a moment to bind before signalling ready.
                await asyncio.sleep(0.1)
                if task.done():
                    # The server already exited; re-raise its exception (if
                    # any) before anyone is told the daemon is ready.
                    await task
                self._ready.set()
                await task

            try:
                loop.run_until_complete(_serve())
            except Exception as exc:
                # Keep the failure so start() can report it instead of the
                # test later failing with an unrelated connection error.
                self.error = exc
            finally:
                # Always release start(), even when the loop died early.
                self._ready.set()
                try:
                    loop.close()
                except Exception:
                    # Best-effort cleanup on a thread nobody can catch from.
                    pass

        self._thread = threading.Thread(target=_run, daemon=True)
        self._thread.start()
        assert self._ready.wait(timeout=5.0), "threaded daemon failed to start"
        if self.error is not None:
            raise RuntimeError("threaded daemon failed to start") from self.error

    def stop(self) -> None:
        """Request shutdown and wait up to 5 seconds for the server thread.

        Safe to call when ``start()`` never ran or when the loop has already
        closed, so it can sit in a ``finally`` without masking the test's
        own failure.
        """
        if self._loop is None:
            return
        if self.shutdown is not None:
            try:
                self._loop.call_soon_threadsafe(self.shutdown.set)
            except RuntimeError:
                # Loop already closed (server exited on its own); there is
                # nothing left to signal.
                pass
        if self._thread is not None:
            self._thread.join(timeout=5.0)
|
|
|
|
|
|
async def _send(path: Path, msg: dict, *, timeout: float = 5.0) -> dict:
    """Send *msg* as one JSON line to the unix socket and return the reply.

    Connects to ``str(path)``, writes the JSON-encoded message followed by a
    newline, then reads a single line and decodes it as JSON. Only the wait
    for the reply line is bounded by *timeout*. The connection is closed on
    the way out, and errors during that close are ignored.
    """
    reader, writer = await asyncio.open_unix_connection(str(path))
    try:
        payload = f"{json.dumps(msg)}\n".encode("utf-8")
        writer.write(payload)
        await writer.drain()
        reply_line = await asyncio.wait_for(reader.readline(), timeout=timeout)
        return json.loads(reply_line)
    finally:
        # Best-effort close: a failure here must not replace the result or
        # the exception from the exchange above.
        try:
            writer.close()
            await writer.wait_closed()
        except Exception:
            pass
|
|
|
|
|
|
def test_session_open_end_to_end_round_trip(
    tmp_socket: Path, tmp_state: Path,
) -> None:
    """Real NDJSON round-trip over a unix socket — the 7th message type.

    Starts a ``_ThreadedDaemon`` on ``tmp_socket``, sends one
    ``session_open`` message, and checks both the reply and the mutations
    made to the shared state dict.
    """
    session_id = "e2e-sess-1"
    state: dict = {"fsm_state": "WAKE"}
    message = {
        "type": "session_open",
        "session_id": session_id,
        "ts": "2026-04-19T12:00:00Z",
    }

    daemon = _ThreadedDaemon(tmp_socket, state)
    daemon.start()
    try:
        reply = asyncio.run(_send(tmp_socket, message))
        assert reply == {"ok": True, "reason": "session_open_queued"}

        # State mutations visible to the test after the reply.
        first_turn = state.get("first_turn_pending")
        assert isinstance(first_turn, dict)
        stamp = first_turn.get(session_id)
        # Non-empty string (ISO-8601 timestamp, post-fix).
        assert isinstance(stamp, str) and stamp

        cascade_request = state.get("hippea_cascade_request")
        assert isinstance(cascade_request, dict)
        assert cascade_request.get("pending") is True
    finally:
        daemon.stop()
|
|
|
|
|
|
def test_session_open_does_not_regress_other_6_types(
    tmp_socket: Path, tmp_state: Path,
) -> None:
    """The six earlier message types still succeed end-to-end.

    Sends force_wake, force_rem, pause, resume, status and
    user_initiated_sleep, in that order, to one running daemon. The first
    two are compared against their exact reply; the rest only need
    ``ok`` to be True.
    """
    exact_replies = (
        ({"type": "force_wake", "ts": "x"}, {"ok": True, "reason": "wake_queued"}),
        ({"type": "force_rem", "ts": "x"}, {"ok": True, "reason": "rem_queued"}),
    )
    ok_only = (
        {"type": "pause"},
        {"type": "resume"},
        {"type": "status"},
        # state is WAKE so this transitions
        {"type": "user_initiated_sleep", "reason": "night", "ts": "x"},
    )

    state: dict = {"fsm_state": "WAKE"}
    daemon = _ThreadedDaemon(tmp_socket, state)
    daemon.start()
    try:
        for message, expected in exact_replies:
            reply = asyncio.run(_send(tmp_socket, message))
            assert reply == expected
        for message in ok_only:
            reply = asyncio.run(_send(tmp_socket, message))
            assert reply.get("ok") is True
    finally:
        daemon.stop()
|