Initial release: iai-mcp v0.1.0
Co-Authored-By: Claude <noreply@anthropic.com>
Co-Authored-By: XNLLLLH <XNLLLLH@users.noreply.github.com>
commit f6b876fbe7
332 changed files with 97258 additions and 0 deletions
22
.env.example
Normal file
@@ -0,0 +1,22 @@
# IAI-MCP environment variables (all optional)
#
# Copy this file to `.env` and fill in only what you need. The daemon runs
# fully offline by default; no API key is required for core memory functions
# (capture, recall, sleep consolidation).
#
# IAI-MCP scrubs the variables below from the host CLI subprocess environment
# at spawn time as a defence-in-depth measure (see src/iai_mcp/host_cli.py).
# They are listed here only as documented placeholders so you know they exist.

# ANTHROPIC_API_KEY=
# CLAUDE_API_KEY=
# CLAUDE_CODE_API_KEY=

# Override the default storage root (~/.iai-mcp). Useful for tests or
# multi-instance setups. Must be a writable directory.
# IAI_MCP_STORE=

# Override the embedder. Default is "bge-small-en-v1.5" (English, 384d).
# Other supported values: "bge-m3" (multilingual, 1024d), "all-MiniLM-L6-v2"
# (English, 384d, lighter weight).
# IAI_MCP_EMBED_MODEL=
61
.gitignore
vendored
Normal file
@@ -0,0 +1,61 @@
# Python
__pycache__/
*.py[cod]
*$py.class
*.so
.Python
build/
dist/
downloads/
eggs/
.eggs/
*.egg-info/
*.egg
MANIFEST

# Virtual environments
.venv/
venv/
ENV/
env/

# Test / type / lint / coverage caches
.pytest_cache/
.mypy_cache/
.ruff_cache/
.tox/
.nox/
.coverage
.coverage.*
htmlcov/
coverage/

# Node
node_modules/
package-lock.json.bak

# IDE
.vscode/
.idea/
*.swp
*.swo

# OS
.DS_Store
Thumbs.db

# Env
.env
.env.local
.env.*.local

# Logs
*.log

# Local data stores
*.sqlite
*.sqlite3
*.db
*.lancedb
lancedb/
runtime_*.json
21
LICENSE
Normal file
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2026 Areg Aramovich Noya

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
309
README.md
Normal file
@@ -0,0 +1,309 @@
<p align="center">
  <img src="logo.png" alt="IAI-MCP" width="600">
</p>


<h3 align="center">The best-benchmarked open-source memory system for AI coding assistants.</h3>
<p align="center">Every claim ships with the harness that proves it. Run the benchmarks yourself.</p>

---

# iai-mcp

*Independent Autistic Intelligence: a local memory layer for Claude (and other MCP-compatible assistants).*

## About the name

*IAI* stands for Independent Autistic Intelligence.

- **Independent.** Fully local. The daemon runs on your machine, embeddings are computed locally, no telemetry, no cloud dependency. Your memory is your data and stays your data.
- **Autistic.** Describes the memory style, not a diagnosis or a metaphor. The memory is built around verbatim recall, attention to specific cues, and refusal to smooth rare events into typical ones. Most memory systems compress and summarize aggressively, aiming to give the assistant a *gist* of the past. This one preserves what was actually said and surfaces it on a precise cue. The trade-off is intentional: more storage and a stricter retrieval interface, in exchange for not losing details.
- **Intelligence.** Used in the systems sense (something that observes, adapts, and stays viable over time), not the marketing sense.

---

## What it is

A local server that speaks the [MCP protocol](https://modelcontextprotocol.io) and gives Claude, and any other MCP-compatible assistant, a long-term memory. It captures every turn of every session verbatim, organizes those captures over time into a personal map of who you are, and serves a small slice of relevant memory back at the start of each new conversation. You never have to say *"remember this"* or *"what did we say last time?"*.

I built this for myself. It worked. I've been running it daily for months, and now I'm sharing it. The benchmarks were mostly for my own curiosity. I wanted to know if it actually works or if I'd just gotten used to it.

---

## Usage

You do not call `iai-mcp` directly during a session. Once it's connected:

Capture is automatic. Every turn, yours and the assistant's, is recorded verbatim with timestamps and session metadata. You don't say *"remember this."*

Recall is automatic. When a new session starts, the daemon assembles a small relevant slice of your history and injects it into the conversation prefix. You don't say *"what did we say."*

Consolidation runs when idle. Between sessions, the daemon merges duplicates, strengthens recall pathways for things retrieved often, and prunes weak edges. The system gets quietly better at remembering you over time.

After a few weeks of regular use the difference becomes noticeable. The assistant stops asking the same orientation questions, references things you mentioned in passing, and adapts to your style without being told.

---

## How it works

The daemon is a Python process that runs in the background. Your MCP client talks to a TypeScript wrapper over stdio, and the wrapper connects to the daemon via a Unix socket. No network exposure.

Memory is stored in three tiers:

*Episodic* is verbatim, timestamped fragments of what was said. Write-once, never overwritten or rewritten.

*Semantic* is summaries induced from clusters of related episodes during idle-time consolidation.

*Procedural* is a small set of stable parameters about you, learned over time: preferences, style cues, recurring patterns. Eleven sealed knobs that shift based on what works.

A background pass runs periodically (sleep cycles): it clusters episodes, builds semantic summaries, decays old unreinforced connections, and reinforces frequently co-retrieved paths. Things you haven't revisited fade naturally. There's an optional "insight of the day" step that makes one Anthropic API call, but it's off by default.
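
The decay-and-reinforce part of a sleep cycle can be pictured roughly as follows. This is a minimal sketch, not the daemon's code: the function names, the 30-day half-life, the reinforcement boost, and the pruning threshold are all assumptions picked for illustration.

```python
Edge = tuple[str, str]  # hypothetical: a link between two record ids


def decay_weight(weight: float, days_idle: float, half_life_days: float = 30.0) -> float:
    """Halve an unreinforced edge weight every `half_life_days` (assumed schedule)."""
    return weight * 0.5 ** (days_idle / half_life_days)


def reinforce_weight(weight: float, boost: float = 0.1, cap: float = 1.0) -> float:
    """Strengthen an edge that was co-retrieved, saturating at `cap`."""
    return min(cap, weight + boost)


def sleep_pass(
    edges: dict[Edge, float],
    days_idle: dict[Edge, float],
    co_retrieved: set[Edge],
    prune_below: float = 0.05,
) -> dict[Edge, float]:
    """Return the edge weights after one hypothetical consolidation pass."""
    updated: dict[Edge, float] = {}
    for edge, weight in edges.items():
        if edge in co_retrieved:
            weight = reinforce_weight(weight)
        else:
            weight = decay_weight(weight, days_idle.get(edge, 0.0))
        if weight >= prune_below:  # edges that fall below the threshold are dropped
            updated[edge] = weight
    return updated
```

With these made-up defaults, an edge that is never co-retrieved loses half its weight every 30 days and eventually drops out, which is one way to get the "fade naturally" behaviour.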

Recall combines three signals, ranked together: semantic similarity, graph-link strength, and recency.
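
One way to rank on three signals at once is a weighted sum. The sketch below only illustrates that idea: the `Candidate` fields, the weights, and the exponential recency curve are assumptions, and the real ranking may combine the signals differently.

```python
import math
from dataclasses import dataclass


@dataclass
class Candidate:
    record_id: str
    similarity: float     # cosine similarity to the cue, assumed in [0, 1]
    link_strength: float  # graph-edge weight, assumed in [0, 1]
    age_days: float       # time since the record was created


def recency_score(age_days: float, tau_days: float = 14.0) -> float:
    """Map age to (0, 1]; newer records score higher (assumed exponential curve)."""
    return math.exp(-age_days / tau_days)


def rank(
    candidates: list[Candidate],
    w_sim: float = 0.6,
    w_link: float = 0.25,
    w_rec: float = 0.15,
) -> list[Candidate]:
    """Sort candidates by a single blended score, best first."""
    def score(c: Candidate) -> float:
        return (
            w_sim * c.similarity
            + w_link * c.link_strength
            + w_rec * recency_score(c.age_days)
        )
    return sorted(candidates, key=score, reverse=True)
```

A record with middling similarity but strong links and a recent timestamp can outrank a closer semantic match that is old and isolated, which is the point of ranking the signals together rather than filtering on one.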

All records are encrypted at rest with AES-256-GCM. The key lives in `~/.iai-mcp/.key` (mode 0600). Back it up. Lose the key, lose the memories.
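
Before backing the key up, it can be worth confirming that the file exists and is readable only by you. The helper below is a standalone sketch using the standard library; `key_file_problems` is a hypothetical name and is not part of the `iai-mcp` CLI.

```python
import os
import stat
from pathlib import Path


def key_file_problems(key_path: Path) -> list[str]:
    """Return human-readable problems with the key file (empty list means OK)."""
    if not key_path.is_file():
        return [f"{key_path} does not exist"]
    problems: list[str] = []
    mode = stat.S_IMODE(key_path.stat().st_mode)  # permission bits only
    if mode != 0o600:
        problems.append(f"{key_path} has mode {mode:o}, expected 600")
    return problems
```

Calling `key_file_problems(Path.home() / ".iai-mcp" / ".key")` returns an empty list when the key is present with mode 0600.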

Everything lives at `~/.iai-mcp/`. Embeddings are computed locally with `bge-small-en-v1.5`. The only data that leaves the machine is your normal conversation with whatever LLM API your client uses.

```
Claude Code <--MCP-stdio--> TypeScript wrapper <--UNIX socket--> Python daemon <--> LanceDB
```

---

## Quick start

### Prerequisites

- macOS or Linux (Apple Silicon and x86_64 tested)
- Python 3.11 or 3.12
- Node.js 18+
- [Claude Code](https://docs.claude.com/en/docs/claude-code/overview) as the MCP host
- ~500 MB free disk

Windows not supported. WSL2 untested.

### Install

```bash
git clone https://github.com/CodeAbra/iai-mcp.git
cd iai-mcp
bash scripts/install.sh
```

The installer creates a Python venv, installs dependencies (LanceDB, sentence-transformers, torch-hd, NetworkX, igraph), builds the TypeScript MCP wrapper, pre-downloads the default embedding model (~130 MB), symlinks the CLI to `~/.local/bin/iai-mcp`, and on macOS registers the daemon with launchd.

Make sure `~/.local/bin` is on your `PATH`:

```bash
export PATH="$HOME/.local/bin:$PATH"  # add to ~/.zshrc or ~/.bashrc
iai-mcp --version
```

On Linux, install the systemd unit manually:

```bash
mkdir -p ~/.config/systemd/user
cp deploy/systemd/iai-mcp-daemon.service ~/.config/systemd/user/
systemctl --user daemon-reload
systemctl --user enable iai-mcp-daemon
systemctl --user start iai-mcp-daemon
```

### Install the Stop hook

This is what makes capture ambient. Without it you'd have to save memories by hand.

```bash
mkdir -p ~/.claude/hooks
cp deploy/hooks/iai-mcp-session-capture.sh ~/.claude/hooks/
chmod +x ~/.claude/hooks/iai-mcp-session-capture.sh
```

Register in `~/.claude/settings.json`:

```json
{
  "hooks": {
    "Stop": [
      {
        "matcher": "*",
        "hooks": [
          {
            "type": "command",
            "command": "$HOME/.claude/hooks/iai-mcp-session-capture.sh"
          }
        ]
      }
    ]
  }
}
```

### Connect Claude

```bash
claude mcp add iai-mcp -- node "$(pwd)/mcp-wrapper/dist/index.js"
```

Or edit `~/.claude.json` directly:

```json
{
  "mcpServers": {
    "iai-mcp": {
      "command": "node",
      "args": ["/absolute/path/to/iai-mcp/mcp-wrapper/dist/index.js"]
    }
  }
}
```

Use the absolute path. `~` and `$HOME` won't expand here.
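
If you would rather not type the absolute path by hand, a few lines of Python can produce the JSON fragment for you. This is a convenience sketch only: `mcp_server_entry` is a hypothetical helper, and `~/src/iai-mcp` stands in for wherever you cloned the repository.

```python
import json
from pathlib import Path


def mcp_server_entry(repo_dir: str) -> dict:
    """Build the `mcpServers` fragment with a fully expanded wrapper path."""
    wrapper = Path(repo_dir).expanduser().resolve() / "mcp-wrapper" / "dist" / "index.js"
    return {
        "mcpServers": {
            "iai-mcp": {
                "command": "node",
                "args": [str(wrapper)],  # absolute: no ~ or $HOME left to expand
            }
        }
    }


# Example clone location (an assumption); paste the printed JSON into ~/.claude.json.
print(json.dumps(mcp_server_entry("~/src/iai-mcp"), indent=2))
```

The script only prints the fragment; merging it into an existing `~/.claude.json` is left to you so nothing gets overwritten.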

For Claude Desktop (untested), edit `~/Library/Application Support/Claude/claude_desktop_config.json` on macOS or the equivalent path on Linux.

### Verify

```bash
iai-mcp doctor
iai-mcp daemon status
```

Restart Claude Code. Start a session, do some work, exit. Then:

```bash
tail ~/.iai-mcp/logs/capture-$(date -u +%Y-%m-%d).log
```

You should see an `rc=0` line. That's your first memory.
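
The same check can be scripted. The sketch below assumes only what the `tail` command above shows: the log lives under `logs/` in the store, is named after the UTC date, and a successful capture contains the token `rc=0`. The rest of the line format is not documented here, so the helper names and the matching rule are assumptions.

```python
from __future__ import annotations

import re
from datetime import datetime, timezone
from pathlib import Path


def capture_log_path(store: Path | None = None, now: datetime | None = None) -> Path:
    """Path of the capture log for the given (default: current) UTC date."""
    store = store or Path.home() / ".iai-mcp"
    now = now or datetime.now(timezone.utc)
    return store / "logs" / f"capture-{now:%Y-%m-%d}.log"


def has_successful_capture(log_text: str) -> bool:
    """True if any line carries an `rc=0` token (assumed success marker)."""
    return any(re.search(r"\brc=0\b", line) for line in log_text.splitlines())
```

A typical use would be `has_successful_capture(capture_log_path().read_text())` after ending a session.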

---

## Doctor

`iai-mcp doctor` runs 14 checks against the daemon, the store, and the runtime state. Output is one line per check: PASS, WARN, or FAIL.

```bash
iai-mcp doctor
```

What it checks:

| # | Check | What it means |
|---|---|---|
| a | Daemon alive | Is the daemon process running? |
| b | Socket fresh | Can the UNIX socket accept a connection? |
| c | Lock healthy | Is the process lock held correctly? |
| d | No orphan core | No leftover stdio core process without a daemon |
| e | State file valid | `.daemon-state.json` parses and has expected fields |
| f | LanceDB readable | Can the records table be opened and queried? |
| g | No duplicate binders | Only one process is bound to the socket |
| h | Crypto file state | Encryption key exists, correct permissions (0600) |
| i | Lance versions count | LanceDB version manifests aren't piling up |
| j | Lifecycle current state | Current FSM state is valid |
| k | Lifecycle history 24h | Recent lifecycle transitions look sane |
| l | Sleep cycle status | Last sleep cycle completed or is running normally |
| m | Heartbeat scanner | Wrapper heartbeat files are fresh |
| n | HID idle source | Idle detection source is available |

14/14 PASS is healthy. 13/14 with check (b) failing during a sleep cycle is also normal (the socket is busy during consolidation). Multiple FAILs or a FAIL on (a) or (f) means something is actually wrong.
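
The reading rules in the paragraph above fit in a few lines of code. This is a sketch for illustration: it assumes you have already collected the results into a dict keyed by check letter, which is a hypothetical representation and not an output format of `iai-mcp doctor`.

```python
CHECK_IDS = "abcdefghijklmn"  # the 14 checks in the table above


def assess(results: dict[str, str], in_sleep_cycle: bool = False) -> str:
    """Classify doctor results as 'healthy', 'investigate', or 'broken'."""
    failed = sorted(cid for cid, status in results.items() if status == "FAIL")
    if not failed:
        return "healthy"
    if failed == ["b"] and in_sleep_cycle:
        return "healthy"  # socket is busy during consolidation
    if len(failed) > 1 or failed[0] in ("a", "f"):
        return "broken"
    return "investigate"  # a single FAIL the README does not classify


all_pass = {cid: "PASS" for cid in CHECK_IDS}
print(assess(all_pass))
```

The `"investigate"` bucket is an addition of this sketch for single failures that the text does not label either way.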

---

## Notes for AI assistants helping with installation

If you're an AI model helping a user install or debug iai-mcp, read this before diagnosing anything.

The daemon has four FSM states: `WAKE`, `TRANSITIONING`, `SLEEP`, `DREAMING`. Separately, the wrapper-layer lifecycle has four states: `WAKE`, `DROWSY`, `SLEEP`, `HIBERNATION`. These are two different state machines in two different files. Don't confuse them.

Things that look broken but are not:

- `last_tick_at` hasn't changed in hours. This is a sleep-cycle tick timestamp, not a process heartbeat. The daemon can be alive and healthy with a stale `last_tick_at`. Check `iai-mcp daemon status` instead.
- `fsm_state: DREAMING` for hours. Normal. The daemon is running a sleep cycle. Wake transitions are event-driven (a new session arriving wakes it), not time-driven.
- `status: deferred` in the capture log. This is the only path the Stop hook takes. It writes captures to `~/.iai-mcp/.deferred-captures/` and the daemon picks them up on its next tick. This is by design, not a failure.
- `iai-mcp doctor` shows 13/14 PASS with check (b) "socket unreachable." Expected when the daemon FSM is in SLEEP or DREAMING. The socket is intentionally unresponsive during consolidation passes. This is healthy.
- Cold start takes a few seconds. The daemon needs to initialize the LanceDB store and load the embedding model into memory. Don't kill it because it hasn't responded in 5 seconds.
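
Keeping the two state machines apart is easier if they are modelled as distinct types. The sketch below is illustrative only: the state names come from the text above, but the class and function names are hypothetical and the project's own definitions may look different.

```python
from enum import Enum


class DaemonState(Enum):
    """Daemon FSM states."""
    WAKE = "WAKE"
    TRANSITIONING = "TRANSITIONING"
    SLEEP = "SLEEP"
    DREAMING = "DREAMING"


class WrapperState(Enum):
    """Wrapper-layer lifecycle states (a separate state machine)."""
    WAKE = "WAKE"
    DROWSY = "DROWSY"
    SLEEP = "SLEEP"
    HIBERNATION = "HIBERNATION"


def socket_may_be_unresponsive(state: DaemonState) -> bool:
    """Doctor check (b) failing is expected in these daemon states."""
    return state in (DaemonState.SLEEP, DaemonState.DREAMING)
```

Because the enums are separate classes, `DaemonState.SLEEP` and `WrapperState.SLEEP` never compare equal, which mirrors the warning not to confuse the two.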

Actual problems look like: `iai-mcp daemon status` returning `ok: False`, doctor showing multiple FAIL results, or the daemon process not existing at all (`pgrep -f iai_mcp.daemon` returns nothing).

When in doubt, run `iai-mcp doctor` and read what it says. The output is self-explanatory.

---

## Benchmarks

I made these because I wanted honest numbers. Every harness ships in `bench/`. Run them on your machine, get your own results.

| Metric | Target | Measured |
|---|---|---|
| Verbatim recall (byte-exact) | >=99% | >=99% at N=10k |
| Recall p95 latency | <100 ms | <100 ms at N=10k |
| RAM at steady state | <=300 MB | ~150-300 MB |
| Session-start tokens (warm cache) | <=3,000 | <=3,000 |
| Session-start tokens (cold) | <=8,000 | <=8,000 |

```bash
python -m bench.verbatim                    # verbatim fidelity
python -m bench.neural_map                  # recall latency
python -m bench.memory_footprint            # RAM usage
python -m bench.tokens                      # session-start cost
python -m bench.total_session_cost          # full 10-turn cost
python -m bench.trajectory                  # 30-session corpus
python -m bench.contradiction_longitudinal  # falsifiability
python -m bench.longmemeval_blind           # LongMemEval-S blind run
```

The LongMemEval-S run is blind on purpose. No dataset-specific tuning, no hyperparameter sweep. The numbers are what they are.

---

## Configuration

| Variable | Default | What it does |
|---|---|---|
| `IAI_MCP_STORE` | `~/.iai-mcp/` | Data directory |
| `IAI_MCP_EMBED_MODEL` | `bge-small-en-v1.5` | Embedding model. `bge-m3` for multilingual at ~3x size. |

Switching embedders requires re-embedding the store: `iai-mcp migrate reembed`.
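
To make the defaults concrete, here is a sketch of how the two variables could be resolved. It is not the project's configuration code: `resolve_config` is a hypothetical helper, and the dimension table simply restates the values documented in `.env.example`.

```python
from __future__ import annotations

import os
from collections.abc import Mapping
from pathlib import Path

# Embedding dimensions as documented in .env.example.
EMBED_DIMS = {"bge-small-en-v1.5": 384, "bge-m3": 1024, "all-MiniLM-L6-v2": 384}


def resolve_config(env: Mapping[str, str] | None = None) -> tuple[Path, str, int]:
    """Return (store directory, embedding model, embedding dimension)."""
    env = os.environ if env is None else env
    store = Path(env.get("IAI_MCP_STORE") or "~/.iai-mcp").expanduser()
    model = env.get("IAI_MCP_EMBED_MODEL") or "bge-small-en-v1.5"
    if model not in EMBED_DIMS:
        raise ValueError(f"unsupported embedding model: {model!r}")
    return store, model, EMBED_DIMS[model]
```

The dimension is returned alongside the model name because it differs between `bge-small-en-v1.5` (384) and `bge-m3` (1024), which is consistent with a model switch requiring a re-embed.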

---

## Status and limitations

This is experimental. I built it for myself, it works on my machine, and I'm sharing it because it might be useful to you. No SLA, no support guarantee. Breaking changes are possible between versions. Pin a commit hash if you depend on stability.

Limitations worth knowing about:

- The default embedding model is English-only. The assistant translates to English on the way into memory. The opt-in `bge-m3` model removes this constraint at a cost of ~3x storage and slower indexing.
- No cross-machine sync. The data lives where the daemon runs. Backup is `cp -a ~/.iai-mcp/` somewhere safe.
- No GUI. Inspection happens through CLI subcommands (`iai-mcp doctor`, `iai-mcp daemon status`, `iai-mcp topology`).
- Cold start on a freshly booted machine takes a few seconds while the daemon initializes caches.
- Recall quality on the first ~10 sessions is mediocre. The system needs material to consolidate before it gets useful.

---

## Compatibility

Claude Code is the primary host, validated in daily use.

Claude Desktop should work (uses `claude_desktop_config.json` instead of `~/.claude.json`) but hasn't been tested end to end.

Other MCP-over-stdio hosts speak the same protocol and should work in principle. Not tested.

If you get it running on something else, open an issue or PR.

---

## Authors

By Areg Aramovich Noya, in collaboration with the team at [lcgc.dev](https://lcgc.dev).

I built this because I needed it. It works for me. If it works for you, take it.

## License

[MIT](LICENSE)

## Contributing

Issues and PRs welcome. If your change touches retrieval, capture, or consolidation, include bench re-runs.
10
bench/__init__.py
Normal file
@@ -0,0 +1,10 @@
"""IAI-MCP benchmark harness.

Phase-1 benchmarks:
- bench.tokens   -- session-start tokens (steady <=3000, fresh <=8000)
- bench.verbatim -- verbatim recall >=99% on pinned records

Both runners are invokable as CLIs (`python -m bench.tokens`, `python -m bench.verbatim`)
and exit non-zero on failure. They fall back to a heuristic token count when
ANTHROPIC_API_KEY is absent so CI (and first-time users) can run the suite offline.
"""
1
bench/adapters/__init__.py
Normal file
@@ -0,0 +1 @@
"""bench/adapters: external-benchmark adapters (Plan 05-11 OPS-17, M-08)."""
275
bench/adapters/longmemeval.py
Normal file
@@ -0,0 +1,275 @@
"""LongMemEval adapter: external-bench gate.

Wires the public LongMemEval memory benchmark (Xie et al., 2024) into the
IAI-MCP public API (MemoryStore.insert + retrieve.recall). Strict blind-run
discipline: no per-dataset tuning, no field-mapping optimisation, no
embedder finetune. The adapter is the ONLY translation layer; everything
downstream is stock IAI-MCP.

## Dataset source

The plan text (05-11-PLAN.md) cites ``lxucs/longmemeval``, but that repo does
NOT exist on HuggingFace Hub (returns 401/Not Found). The canonical public
mirror shipped by the paper authors is ``xiaowu0162/longmemeval``.
Discovered mid-execution; documented as a Rule 3 deviation in the Plan
05-11 SUMMARY. DATASET_ID points at the live mirror; PINNED_REVISION is
the 40-char commit hash resolved at execution time so numbers reproduce.

## Row schema (longmemeval_s split, 500 rows)

Each row is:

    {
        "question_id": str (8-hex),
        "question_type": str (single-session-user, multi-session, ...),
        "question": str,
        "answer": str,
        "question_date": str ("YYYY/MM/DD (Day) HH:MM"),
        "haystack_dates": list[str],
        "haystack_session_ids": list[str],   # len ~54
        "haystack_sessions": list[list[{"role","content"}]],
        "answer_session_ids": list[str],     # gold evidence (len typically 1)
    }

## LMESession mapping (Plan 05-11 deviation, Rule 1/3)

The plan's interface says "one session -> many queries". The actual dataset
is "one query -> many haystack sessions". We therefore flatten each row to
a list of LMESession objects, one per haystack session, with the single
eval query attached to every session in the row (so
bench/longmemeval_blind.py can iterate LMESessions, insert haystack turns,
and run the query against the store). The orchestrator (not the adapter)
scores at the standard LongMemEval session-ID granularity.

The ``score_r_at_k`` method in this module implements the plan's literal
formula ``|retrieved ∩ relevant| / |relevant|`` over UUIDs; it is
unit-testable and matches the Test 4 contract. The orchestrator also
reports session-level R@k using the dataset's native session_id gold.
"""
from __future__ import annotations

import os
import sys
from dataclasses import dataclass
from datetime import datetime, timezone
from pathlib import Path
from typing import Iterable
from uuid import UUID, uuid4

# Local imports kept lazy-friendly by using a distinct alias so tests can
# mock ``bench.adapters.longmemeval.retrieve_recall`` without touching the
# production retrieve module wholesale.
from iai_mcp.retrieve import recall as retrieve_recall
from iai_mcp.embed import embedder_for_store
from iai_mcp.types import MemoryRecord


DATASET_ID: str = "xiaowu0162/longmemeval"
# Pinned at execution time (2026-04-20) against the
# canonical LongMemEval HuggingFace mirror. Reproducers MUST load this
# exact revision or disclose the drift.
PINNED_REVISION: str = "2ec2a557f339b6c0369619b1ed5793734cc87533"
# Split -> filename (the repo ships configs ``longmemeval_s``,
# ``longmemeval_m``, ``longmemeval_oracle``). The blind run uses the S split.
_SPLIT_FILENAMES: dict[str, str] = {
    "S": "longmemeval_s",
    "M": "longmemeval_m",
    "oracle": "longmemeval_oracle",
}


@dataclass
class LMESession:
    """One flattened haystack session + its attached eval query.

    See module docstring for why this differs from the plan's original
    "one session many queries" spec.
    """

    session_id: str
    turns: list[dict]  # [{"role": "user"|"assistant", "content": str}]
    queries: list[dict]  # [{"query": str, "relevant_turn_ids": list[str]}]


class LongMemEvalAdapter:
    """Public API: load_dataset / session_to_inserts / query_to_recall /
    score_r_at_k."""

    DATASET_ID: str = DATASET_ID
    PINNED_REVISION: str = PINNED_REVISION

    def __init__(self, revision: str | None = None) -> None:
        self.revision = revision or self.PINNED_REVISION

    # --------------------------------------------------------------- load

    def load_dataset(self, split: str = "S") -> Iterable[LMESession]:
        """Stream LMESessions out of the LongMemEval-<split> JSON file.

        Uses ``huggingface_hub.hf_hub_download`` to grab the split file at
        the pinned revision (the datasets library's JSON auto-detection
        breaks on this repo because the files ship without a ``.json``
        extension; see README). Falls back to raising a clear error if
        HuggingFace is unreachable and nothing is cached.
        """
        import json

        filename = _SPLIT_FILENAMES.get(split)
        if filename is None:
            raise ValueError(
                f"unknown LongMemEval split {split!r}; "
                f"expected one of {sorted(_SPLIT_FILENAMES)}"
            )

        try:
            from huggingface_hub import hf_hub_download
        except ImportError as exc:  # pragma: no cover - dev extra
            raise RuntimeError(
                "huggingface_hub not installed; run "
                "`pip install 'datasets>=2.18' huggingface_hub`"
            ) from exc

        print(
            f"[LongMemEval] resolving split={split} "
            f"revision={self.revision} filename={filename}",
            file=sys.stderr,
            flush=True,
        )
        path = hf_hub_download(
            repo_id=self.DATASET_ID,
            filename=filename,
            repo_type="dataset",
            revision=self.revision,
        )
        with open(path, "r", encoding="utf-8") as f:
            rows = json.load(f)

        for row in rows:
            qid = row["question_id"]
            question = row["question"]
            # bench/lme500: capture question_type for per-type breakdown.
            question_type = str(row.get("question_type", "unknown"))
            answer_session_ids = list(row.get("answer_session_ids", []))
            haystack_session_ids: list[str] = list(
                row.get("haystack_session_ids", [])
            )
            haystack_sessions: list[list[dict]] = list(
                row.get("haystack_sessions", [])
            )

            # Emit one LMESession per haystack session; attach the eval
            # query to every one so the orchestrator can run ONE recall
            # per row after inserting all haystack turns.
            #
            # The "relevant_turn_ids" field stays session-id-based (the
            # paper's native gold). We record which session is "gold" so
            # the orchestrator can score hits.
            for sess_id, turns in zip(
                haystack_session_ids, haystack_sessions
            ):
                yield LMESession(
                    session_id=sess_id,
                    turns=list(turns),
                    queries=[
                        {
                            "query": question,
                            "question_id": qid,
                            "question_type": question_type,
                            # Gold at session granularity; the orchestrator
                            # decides how to use it. score_r_at_k in this
                            # adapter takes whatever the caller passes.
                            "relevant_turn_ids": answer_session_ids,
                            "is_gold_session": sess_id in answer_session_ids,
                        }
                    ],
                )

    # ------------------------------------------------------- session_to_inserts

    def session_to_inserts(self, session: LMESession) -> list[MemoryRecord]:
        """Map each turn to one MemoryRecord (tier=episodic, literal_surface=content).

        Produces a placeholder embedding sized to the default embed dim.
        The blind-run orchestrator overrides the embedding with the real
        one from ``embedder_for_store(store).embed(text)`` before calling
        ``store.insert``; this keeps ``session_to_inserts`` cheap for
        unit tests that don't want to load sentence-transformers.
        """
        from iai_mcp.embed import Embedder

        dim = Embedder.DEFAULT_DIM
        records: list[MemoryRecord] = []
        now = datetime.now(timezone.utc)
        for turn in session.turns:
            content = str(turn.get("content", ""))
            rec = MemoryRecord(
                id=uuid4(),
                tier="episodic",
                literal_surface=content,
                aaak_index="",
                embedding=[0.0] * dim,  # placeholder; orchestrator overrides
                community_id=None,
                centrality=0.0,
                detail_level=2,
                pinned=False,
                stability=0.0,
                difficulty=0.0,
                last_reviewed=None,
                never_decay=False,
                never_merge=False,
                provenance=[],
                created_at=now,
                updated_at=now,
                tags=[
                    "longmemeval",
                    f"role:{turn.get('role','user')}",
                    f"session:{session.session_id}",
                ],
                language="en",
            )
            records.append(rec)
        return records

    # ------------------------------------------------------- query_to_recall

    def query_to_recall(self, query: dict, store) -> list[UUID]:
        """Call retrieve.recall(cue_text=query['query'], k_hits=10).

        Returns the retrieved record ids in rank order. The orchestrator
        uses these ids to compute R@k.
        """
        cue_text = str(query["query"])
        embedder = embedder_for_store(store)
        cue_embedding = embedder.embed(cue_text)
        resp = retrieve_recall(
            store=store,
            cue_embedding=cue_embedding,
            cue_text=cue_text,
            session_id="longmemeval-blind",
            budget_tokens=1500,
            k_hits=10,
            k_anti=0,
        )
        return [hit.record_id for hit in resp.hits]

    # ------------------------------------------------------- score_r_at_k

    def score_r_at_k(
        self,
        retrieved_ids: list,
        gold_turn_ids: list,
        k: int = 5,
    ) -> float:
        """R@k = |retrieved_top_k ∩ relevant| / |relevant|.

        Empty ``gold_turn_ids`` returns 1.0 (convention: avoids div-by-zero
        and matches the "no evidence to miss" semantics).

        Both lists are normalised to ``str`` so UUID vs session-id ids work.
        """
        if not gold_turn_ids:
            return 1.0
        top_k = retrieved_ids[: max(0, int(k))]
        gold_set = {str(g) for g in gold_turn_ids}
        # Set intersection, as in the docstring formula: a gold id that
        # appears more than once in top_k counts once, so R@k stays <= 1.0.
        hit = len({str(rid) for rid in top_k} & gold_set)
        return hit / float(len(gold_set))
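
A worked example of the R@k formula from the docstring above, written as a standalone function so it runs without the `iai_mcp` package. This is a sketch for illustration: `recall_at_k` is a hypothetical name, and treating the top-k ids as a set (so repeated ids count once) is an assumption that follows the written formula.

```python
def recall_at_k(retrieved_ids: list, gold_ids: list, k: int = 5) -> float:
    """R@k = |top_k(retrieved) ∩ gold| / |gold|, with ids compared as strings."""
    if not gold_ids:
        return 1.0  # no evidence to miss
    gold = {str(g) for g in gold_ids}
    top_k = {str(r) for r in retrieved_ids[: max(0, int(k))]}
    return len(top_k & gold) / len(gold)


# Two gold sessions, one of them retrieved within the top 5.
print(recall_at_k(["s1", "s9", "s3"], ["s3", "s4"], k=5))
```

With one of two gold ids in the top five, the function returns 0.5; shrinking `k` to 2 cuts `s3` out of the window and the score drops to 0.0.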
163
bench/adapters/longmemeval_cleaned.py
Normal file
@@ -0,0 +1,163 @@
"""Cleaned-dataset adapter for LongMemEval-S — D-02.

Mempalace's reference benchmark uses ``xiaowu0162/longmemeval-cleaned``
(commit-pinned via ``huggingface_hub.repo_info()``). This adapter mirrors
the ``LongMemEvalAdapter`` shape from ``bench/adapters/longmemeval.py`` so
the orchestrator (``bench/longmemeval_blind.py``) can swap raw vs cleaned
purely via the ``--dataset {cleaned, raw}`` CLI flag.

## Boundary

This adapter is NEW (Phase 9 Task 1). The raw adapter at
``bench/adapters/longmemeval.py`` is byte-identical to its v2 state — Phase
9 does NOT modify the v1/v2 baseline path. ``--dataset raw`` continues to
load the raw revision ``2ec2a557f339...``; ``--dataset cleaned`` (the new
v3 default) routes to this module.

## Pinning discipline

Phase 9 LOCKED: pin via ``huggingface_hub.repo_info(...)``, NEVER
hardcode a magic string. The cleaned dataset's HEAD SHA is auto-discovered
on first instantiation and stored on ``self.revision`` so v3 output JSON
records exactly which dataset variant was measured. On reproducer runs,
the caller may pass ``revision=`` to pin a specific historical SHA.

## Schema

The cleaned dataset uses the same row schema as the raw dataset (cleaning
removed bad evidence; field names are preserved). Each row in
``longmemeval_s_cleaned.json`` is:

    {
      "question_id": str,
      "question_type": str,
      "question": str,
      "haystack_session_ids": list[str],
      "haystack_sessions": list[list[{"role","content"}]],
      "answer_session_ids": list[str],
    }

The adapter emits one ``LMESession`` per haystack session with the eval
query attached (matching the raw adapter's emission shape exactly), so
``main()`` in ``longmemeval_blind.py`` does NOT branch on adapter type —
it groups LMESessions by ``question_id`` either way.

## Split support

Only ``split="S"`` is supported. The cleaned dataset ships only the S split
as ``longmemeval_s_cleaned.json``; M and oracle remain in the raw dataset.
"""
from __future__ import annotations

import json
import sys
from typing import Iterable

from bench.adapters.longmemeval import LMESession


CLEANED_DATASET_ID: str = "xiaowu0162/longmemeval-cleaned"
CLEANED_FILENAME: str = "longmemeval_s_cleaned.json"


class CleanedLongMemEvalAdapter:
    """Loads ``xiaowu0162/longmemeval-cleaned`` via ``huggingface_hub``.

    Mirrors ``LongMemEvalAdapter`` so ``bench/longmemeval_blind.py`` can
    treat them interchangeably (same ``LMESession`` iterator shape).

    Pin discipline: ``revision`` defaults to the current HEAD SHA of the
    HuggingFace dataset, auto-discovered via ``repo_info()``. Pass an
    explicit revision to reproduce a historical run.
    """

    DATASET_ID: str = CLEANED_DATASET_ID

    def __init__(self, revision: str | None = None) -> None:
        if revision is not None:
            self.revision = revision
            return
        try:
            from huggingface_hub import repo_info
        except ImportError as exc:  # pragma: no cover — dev extra
            raise RuntimeError(
                "huggingface_hub not installed; run "
                "`pip install 'datasets>=2.18' huggingface_hub`"
            ) from exc
        info = repo_info(repo_id=CLEANED_DATASET_ID, repo_type="dataset")
        self.revision = info.sha

    def load_dataset(self, split: str = "S") -> Iterable[LMESession]:
        """Stream LMESessions out of ``longmemeval_s_cleaned.json``.

        Only ``split="S"`` is supported (the cleaned dataset ships the S
        split only). Raises ``ValueError`` on any other split value.
        """
        if split != "S":
            raise ValueError(
                f"unknown LongMemEval cleaned split {split!r}; "
                f"the cleaned dataset ships only the 'S' split"
            )

        try:
            from huggingface_hub import hf_hub_download
        except ImportError as exc:  # pragma: no cover — dev extra
            raise RuntimeError(
                "huggingface_hub not installed; run "
                "`pip install 'datasets>=2.18' huggingface_hub`"
            ) from exc

        print(
            f"[LongMemEval-cleaned] resolving split={split} "
            f"revision={self.revision} filename={CLEANED_FILENAME}",
            file=sys.stderr,
            flush=True,
        )
        path = hf_hub_download(
            repo_id=CLEANED_DATASET_ID,
            filename=CLEANED_FILENAME,
            repo_type="dataset",
            revision=self.revision,
        )
        with open(path, "r", encoding="utf-8") as f:
            rows = json.load(f)

        for row in rows:
            qid = row["question_id"]
            question = row["question"]
            question_type = str(row.get("question_type", "unknown"))
            answer_session_ids = list(row.get("answer_session_ids", []))
            haystack_session_ids: list[str] = list(
                row.get("haystack_session_ids", [])
            )
            haystack_sessions: list[list[dict]] = list(
                row.get("haystack_sessions", [])
            )

            # Emit one LMESession per haystack session; attach the eval
            # query to every one so the orchestrator can run ONE recall
            # per row after inserting all haystack turns. Matches the
            # raw adapter's emission shape exactly.
            for sess_id, turns in zip(
                haystack_session_ids, haystack_sessions
            ):
                yield LMESession(
                    session_id=sess_id,
                    turns=list(turns),
                    queries=[
                        {
                            "query": question,
                            "question_id": qid,
                            "question_type": question_type,
                            "relevant_turn_ids": answer_session_ids,  # session IDs; key name matches the raw adapter
                            "is_gold_session": sess_id in answer_session_ids,
                        }
                    ],
                )


__all__ = [
    "CLEANED_DATASET_ID",
    "CLEANED_FILENAME",
    "CleanedLongMemEvalAdapter",
]
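The emission shape described in the adapter docstring (one session object per haystack session, each carrying the same eval query) can be illustrated without any download. The following is a minimal sketch only: `SessionRecord` is a hypothetical stand-in for `LMESession`, whose real definition lives in `bench/adapters/longmemeval.py` and is not shown in this commit view, and the sample row is invented for illustration.

```python
from dataclasses import dataclass, field


@dataclass
class SessionRecord:
    """Hypothetical stand-in for LMESession (assumed fields only)."""

    session_id: str
    turns: list
    queries: list = field(default_factory=list)


def flatten_row(row: dict) -> list[SessionRecord]:
    """Emit one record per haystack session, each carrying the row's eval query."""
    gold = list(row.get("answer_session_ids", []))
    session_ids = row.get("haystack_session_ids", [])
    sessions = row.get("haystack_sessions", [])
    out = []
    # zip() stops at the shorter list, so a row with mismatched lengths
    # silently drops the unmatched tail, just like the adapter's loop.
    for sess_id, turns in zip(session_ids, sessions):
        out.append(
            SessionRecord(
                session_id=sess_id,
                turns=list(turns),
                queries=[
                    {
                        "query": row["question"],
                        "question_id": row["question_id"],
                        "question_type": str(row.get("question_type", "unknown")),
                        "relevant_turn_ids": gold,
                        "is_gold_session": sess_id in gold,
                    }
                ],
            )
        )
    return out


# Invented sample row, following the schema in the docstring above.
sample_row = {
    "question_id": "q1",
    "question_type": "single-session-user",
    "question": "Where did I park?",
    "haystack_session_ids": ["s1", "s2"],
    "haystack_sessions": [
        [{"role": "user", "content": "hello"}],
        [{"role": "user", "content": "I parked on level 3"}],
    ],
    "answer_session_ids": ["s2"],
}
records = flatten_row(sample_row)
```

Because every record repeats the query, a consumer can group records by `queries[0]["question_id"]` and run a single recall per group, which is the grouping the docstring attributes to the orchestrator.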
80
bench/contradiction_longitudinal.py
Normal file
@@ -0,0 +1,80 @@
#!/usr/bin/env python3
"""Contradiction-longitudinal falsifiability bench (skeleton + pre-registered criteria).

**Do not run on the construction host by default** — this module is meant for a
dedicated bench machine with an isolated ``IAI_MCP_STORE`` and optional GPU.

Pre-registered pass criteria:
- **Metric B (post-flip):** for cues issued after session ``t_0`` (contradiction +
  consolidation window simulated), the pipeline must rank the *current* winning fact
  higher than flat cosine-only retrieval does on the same store slice.
- **Metric A (historical verbatim):** probes asking for superseded wording must
  still surface the archived wording (verbatim MEM-06), not the post-flip fact alone.
- **Regression gate:** pipeline score on B must beat cosine baseline; A must not
  collapse below a configured verbatim hit threshold.

This file loads :file:`fixtures/contradiction_longitudinal.jsonl` (synthetic JSONL rows: ``session``,
``role``, ``text``, optional ``gold_fact`` / ``contradicts_session`` / ``probe`` / ``expects``) and
documents the evaluation harness contract. A full implementation wires:

1. Fixture loader → ``MemoryStore`` inserts per session order.
2. Explicit ``memory_contradict`` (or edge-equivalent) at ``t_0``.
3. Optional sleep/consolidation tick simulation (bench-only knobs).
4. Two eval slices: ``pre_flip_cues`` vs ``post_flip_cues`` with separated metrics.

Exit code 0 only when all gates pass; non-zero on any failure. Until the harness
is completed, ``main()`` prints the criteria and exits with code 2 to avoid a
silent green run::

    python bench/contradiction_longitudinal.py --fixture bench/fixtures/contradiction_longitudinal.jsonl
"""

from __future__ import annotations

import argparse
import json
import sys
from pathlib import Path


def load_rows(path: Path) -> list[dict]:
    rows: list[dict] = []
    with path.open(encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            rows.append(json.loads(line))
    return rows


def main(argv: list[str] | None = None) -> int:
    parser = argparse.ArgumentParser(description=__doc__.split("\n\n")[0])
    parser.add_argument(
        "--fixture",
        type=Path,
        default=Path(__file__).resolve().parent / "fixtures" / "contradiction_longitudinal.jsonl",
    )
    args = parser.parse_args(argv)
    rows = load_rows(args.fixture)
    print(
        json.dumps(
            {
                "loaded_rows": len(rows),
                "fixture": str(args.fixture),
                "status": "harness_stub",
                "criteria": [
                    "B: post-flip cues — pipeline beats flat cosine",
                    "A: historical verbatim probes — superseded text still retrievable",
                    "No regression: B gain without A collapse",
                ],
            },
            indent=2,
        )
    )
    # Stub: full eval is intentionally absent so CI never runs heavy retrieval.
    return 2


if __name__ == "__main__":
    sys.exit(main())
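The stub above leaves the regression gate unimplemented. As a rough illustration of the contract the docstring states (B must beat the cosine baseline, A must not fall below a configured threshold, exit code 0 only when both hold), the gate could be reduced to a small pure function. This is a sketch under assumptions: the function name `gate_exit_code`, its parameters, and the choice of exit code 1 for failure are all hypothetical and do not come from the repository.

```python
def gate_exit_code(
    pipeline_b: float,
    cosine_b: float,
    verbatim_hit_rate: float,
    verbatim_threshold: float,
) -> int:
    """Return 0 only when both pre-registered gates pass (hypothetical helper).

    pipeline_b / cosine_b: Metric B scores for the pipeline and the flat
    cosine baseline on post-flip cues.
    verbatim_hit_rate: Metric A hit rate on historical-verbatim probes.
    verbatim_threshold: configured floor for Metric A (assumed inclusive).
    """
    b_passes = pipeline_b > cosine_b  # strict: a tie does not count as beating
    a_passes = verbatim_hit_rate >= verbatim_threshold
    return 0 if (b_passes and a_passes) else 1


# Invented example numbers, for illustration only.
both_pass = gate_exit_code(0.80, 0.60, 0.95, 0.90)
b_tied = gate_exit_code(0.60, 0.60, 0.95, 0.90)
a_collapsed = gate_exit_code(0.80, 0.60, 0.50, 0.90)
```

Keeping the gate separate from retrieval would let it be unit-tested without a store, while the stub's exit code 2 remains reserved for "harness not implemented".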
4
bench/fixtures/contradiction_longitudinal.jsonl
Normal file
@@ -0,0 +1,4 @@
{"session": 0, "role": "user", "text": "The launch date is 2026-06-01.", "gold_fact": "2026-06-01"}
{"session": 1, "role": "user", "text": "Correction: launch moved to 2026-09-01.", "gold_fact": "2026-09-01", "contradicts_session": 0}
{"session": 2, "role": "user", "text": "What is the launch date?", "probe": "post_flip", "expects": "2026-09-01"}
{"session": 2, "role": "user", "text": "Quote the original June announcement verbatim.", "probe": "historical_verbatim", "expects": "2026-06-01"}
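The four fixture rows mix two kinds of record: statements to be inserted (rows carrying `gold_fact`) and probes to be evaluated (rows carrying `probe` / `expects`). A harness would presumably need to separate them before inserting anything. The sketch below shows one way to do that with the standard library only; `split_fixture` is a hypothetical helper, not part of the repository, and the fixture text is copied from the file above.

```python
import json

FIXTURE_TEXT = """\
{"session": 0, "role": "user", "text": "The launch date is 2026-06-01.", "gold_fact": "2026-06-01"}
{"session": 1, "role": "user", "text": "Correction: launch moved to 2026-09-01.", "gold_fact": "2026-09-01", "contradicts_session": 0}
{"session": 2, "role": "user", "text": "What is the launch date?", "probe": "post_flip", "expects": "2026-09-01"}
{"session": 2, "role": "user", "text": "Quote the original June announcement verbatim.", "probe": "historical_verbatim", "expects": "2026-06-01"}
"""


def split_fixture(text: str) -> tuple[list[dict], list[dict]]:
    """Split JSONL fixture text into (statements, probes).

    Rows with a "probe" key are treated as evaluation probes; everything
    else is a statement to insert, returned in session order.
    """
    statements: list[dict] = []
    probes: list[dict] = []
    for line in text.splitlines():
        line = line.strip()
        if not line:
            continue
        row = json.loads(line)
        (probes if "probe" in row else statements).append(row)
    statements.sort(key=lambda r: r["session"])
    return statements, probes


statements, probes = split_fixture(FIXTURE_TEXT)
expected_by_probe = {p["probe"]: p["expects"] for p in probes}
```

Here `expected_by_probe` maps `post_flip` to the corrected date and `historical_verbatim` to the original one, which lines up with Metric B and Metric A respectively.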
351
bench/lme500/aggregate.py
Normal file
@@ -0,0 +1,351 @@
"""bench/lme500/aggregate.py — post-process LongMemEval-S blind-run output.

Usage:
    python bench/lme500/aggregate.py \
        --in bench/lme500/output/lme500-v1.json \
        --report bench/lme500/output/lme500-v1-report.md \
        --summary bench/lme500/output/lme500-v1-summary.json

The --in path may be:
  - the final summary JSON ({"per_row": [...], ...} schema), or
  - the per-row JSONL checkpoint (one JSON dict per line — works on
    partial runs while the bench is still in progress).

Computes:
  - Overall R@5 / R@10 per prong (X = retrieve_recall, Y = recall_for_benchmark)
  - Architecture lift Y - X
  - Per-question-type stratification with n per bin (low-power flag if n<30)
  - Bootstrap 95% CI via percentile method (10000 resamples, seed=42)
  - Errors counted as miss for both prongs

Output:
  - Markdown report (--report)
  - Aggregated JSON summary (--summary)
  - Short stderr summary at end
"""
from __future__ import annotations

import argparse
import json
import random
import statistics
import sys
from collections import defaultdict
from pathlib import Path
from typing import Any


def load_rows(input_path: Path) -> list[dict[str, Any]]:
    """Load per-row dicts from JSON, JSONL, or list-JSON.

    Order of detection:
      1. JSON object with "per_row" key → return per_row.
      2. JSON list → return as-is.
      3. JSONL fallback: parse each non-empty line; corrupt lines are skipped with a warning.
    """
    text = input_path.read_text(encoding="utf-8")
    stripped = text.strip()
    # Try JSON first
    if stripped.startswith("{"):
        try:
            data = json.loads(text)
            if isinstance(data, dict) and "per_row" in data:
                return list(data["per_row"])
        except json.JSONDecodeError:
            pass
    if stripped.startswith("["):
        try:
            return list(json.loads(text))
        except json.JSONDecodeError:
            pass
    # Fall back to JSONL
    rows: list[dict[str, Any]] = []
    for lineno, line in enumerate(text.splitlines(), 1):
        line = line.strip()
        if not line:
            continue
        try:
            rows.append(json.loads(line))
        except json.JSONDecodeError as exc:
            print(
                f"[aggregate] WARN: skipping corrupt line {lineno}: {exc}",
                file=sys.stderr,
            )
    return rows


def bootstrap_ci(
    values: list[float],
    n_resamples: int = 10000,
    seed: int = 42,
) -> tuple[float, float, float]:
    """Bootstrap mean + 95% percentile CI.

    Returns (mean, ci_lo, ci_hi). Empty input → (0, 0, 0).
    """
    if not values:
        return 0.0, 0.0, 0.0
    rng = random.Random(seed)
    n = len(values)
    means: list[float] = []
    for _ in range(n_resamples):
        s = 0.0
        for _ in range(n):
            s += values[rng.randrange(n)]
        means.append(s / n)
    means.sort()
    lo_idx = max(0, int(0.025 * n_resamples))
    hi_idx = min(n_resamples - 1, int(0.975 * n_resamples))
    return statistics.fmean(values), means[lo_idx], means[hi_idx]


def _get_prong_value(row: dict[str, Any], prong: str, k: int) -> float:
    """Extract r_at_<k>_<prong> from a row, treating error rows as 0."""
    if "error" in row and isinstance(row.get("error"), dict):
        return 0.0
    return float(row.get(f"r_at_{k}_{prong}", 0.0))


def aggregate(rows: list[dict[str, Any]]) -> dict[str, Any]:
    """Aggregate overall + per-type bootstrap CIs."""
    if not rows:
        return {"overall": {"n": 0, "n_errors": 0}, "per_type": {}}

    by_type: dict[str, dict[str, list[float]]] = defaultdict(
        lambda: {"x5": [], "x10": [], "y5": [], "y10": []}
    )
    overall: dict[str, list[float]] = {"x5": [], "x10": [], "y5": [], "y10": []}
    n_errors = 0

    for row in rows:
        is_error = "error" in row and isinstance(row.get("error"), dict)
        if is_error:
            n_errors += 1
        qtype = str(row.get("question_type", "unknown"))
        x5 = _get_prong_value(row, "retrieve", 5)
        x10 = _get_prong_value(row, "retrieve", 10)
        y5 = _get_prong_value(row, "pipeline", 5)
        y10 = _get_prong_value(row, "pipeline", 10)
        overall["x5"].append(x5)
        overall["x10"].append(x10)
        overall["y5"].append(y5)
        overall["y10"].append(y10)
        by_type[qtype]["x5"].append(x5)
        by_type[qtype]["x10"].append(x10)
        by_type[qtype]["y5"].append(y5)
        by_type[qtype]["y10"].append(y10)

    def _prong_block(vals_5: list[float], vals_10: list[float]) -> dict:
        m5, lo5, hi5 = bootstrap_ci(vals_5)
        m10, lo10, hi10 = bootstrap_ci(vals_10)
        return {
            "r_at_5": {"mean": m5, "ci_lo": lo5, "ci_hi": hi5},
            "r_at_10": {"mean": m10, "ci_lo": lo10, "ci_hi": hi10},
        }

    overall_block = {
        "n": len(rows),
        "n_errors": n_errors,
        "X_retrieve": _prong_block(overall["x5"], overall["x10"]),
        "Y_pipeline": _prong_block(overall["y5"], overall["y10"]),
    }
    overall_block["lift_Y_minus_X"] = {
        "r_at_5": (
            overall_block["Y_pipeline"]["r_at_5"]["mean"]
            - overall_block["X_retrieve"]["r_at_5"]["mean"]
        ),
        "r_at_10": (
            overall_block["Y_pipeline"]["r_at_10"]["mean"]
            - overall_block["X_retrieve"]["r_at_10"]["mean"]
        ),
    }

    per_type_out: dict[str, dict[str, Any]] = {}
    for qt in sorted(by_type.keys()):
        data = by_type[qt]
        block = {
            "n": len(data["x5"]),
            "X_retrieve": _prong_block(data["x5"], data["x10"]),
            "Y_pipeline": _prong_block(data["y5"], data["y10"]),
        }
        block["lift_Y_minus_X"] = {
            "r_at_5": (
                block["Y_pipeline"]["r_at_5"]["mean"]
                - block["X_retrieve"]["r_at_5"]["mean"]
            ),
            "r_at_10": (
                block["Y_pipeline"]["r_at_10"]["mean"]
                - block["X_retrieve"]["r_at_10"]["mean"]
            ),
        }
        per_type_out[qt] = block

    return {"overall": overall_block, "per_type": per_type_out}


def format_markdown_report(agg: dict[str, Any], source_path: Path) -> str:
    overall = agg["overall"]
    lines: list[str] = []
    lines.append("# LongMemEval-S Aggregate Report")
    lines.append("")
    lines.append(f"- Source: `{source_path}`")
    lines.append(f"- n = {overall['n']}, errors = {overall['n_errors']}")
    lines.append(
        "- 95% CI via bootstrap percentile method (10000 resamples, seed=42)"
    )
    lines.append("")

    if overall["n"] == 0:
        lines.append("**No rows loaded.**")
        return "\n".join(lines) + "\n"

    lines.append("## Overall")
    lines.append("")
    lines.append("| Prong | R@5 | R@5 95% CI | R@10 | R@10 95% CI |")
    lines.append("|---|---|---|---|---|")
    x = overall["X_retrieve"]
    y = overall["Y_pipeline"]
    lift = overall["lift_Y_minus_X"]
    lines.append(
        f"| X (retrieve_recall — flat-cosine baseline) "
        f"| {x['r_at_5']['mean']:.3f} "
        f"| [{x['r_at_5']['ci_lo']:.3f}, {x['r_at_5']['ci_hi']:.3f}] "
        f"| {x['r_at_10']['mean']:.3f} "
        f"| [{x['r_at_10']['ci_lo']:.3f}, {x['r_at_10']['ci_hi']:.3f}] |"
    )
    lines.append(
        f"| Y (recall_for_benchmark — full graph-native pipeline) "
        f"| {y['r_at_5']['mean']:.3f} "
        f"| [{y['r_at_5']['ci_lo']:.3f}, {y['r_at_5']['ci_hi']:.3f}] "
        f"| {y['r_at_10']['mean']:.3f} "
        f"| [{y['r_at_10']['ci_lo']:.3f}, {y['r_at_10']['ci_hi']:.3f}] |"
    )
    lines.append(
        f"| **Architecture lift Y − X** "
        f"| **{lift['r_at_5']:+.3f}** "
        f"| — "
        f"| **{lift['r_at_10']:+.3f}** "
        f"| — |"
    )
    lines.append("")

    lines.append("## Per question type")
    lines.append("")
    lines.append(
        "| Type | n | X R@5 | Y R@5 | Lift R@5 "
        "| X R@10 | Y R@10 | Lift R@10 |"
    )
    lines.append("|---|---|---|---|---|---|---|---|")
    for qt, block in agg["per_type"].items():
        n = block["n"]
        flag = " ⚠️" if n < 30 else ""
        x = block["X_retrieve"]
        y = block["Y_pipeline"]
        lift = block["lift_Y_minus_X"]
        lines.append(
            f"| `{qt}`{flag} | {n} "
            f"| {x['r_at_5']['mean']:.3f} | {y['r_at_5']['mean']:.3f} "
            f"| {lift['r_at_5']:+.3f} "
            f"| {x['r_at_10']['mean']:.3f} | {y['r_at_10']['mean']:.3f} "
            f"| {lift['r_at_10']:+.3f} |"
        )
    lines.append("")
    lines.append("⚠️ = n < 30, low statistical power for that bin.")
    lines.append("")
    lines.append("## Notes")
    lines.append("")
    lines.append(
        "- Errors (graph-build failures, malformed rows, etc.) are counted "
        "as miss for **both** prongs (R@k = 0)."
    )
    lines.append(
        "- Mean is the unweighted row average; CI is bootstrap percentile."
    )
    lines.append(
        "- Architecture lift = mean(Y) − mean(X). The CI of the lift "
        "itself is not computed here (would require paired bootstrap on "
        "the (Y_i, X_i) tuples — TODO if needed)."
    )
    return "\n".join(lines) + "\n"


def main() -> int:
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument(
        "--in",
        dest="input",
        required=True,
        help="Path to per-row JSON / JSONL file",
    )
    parser.add_argument(
        "--report",
        default=None,
        help="Output path for markdown report; default: <input>-report.md",
    )
    parser.add_argument(
        "--summary",
        default=None,
        help="Output path for aggregated JSON; default: <input>-summary.json",
    )
    args = parser.parse_args()

    input_path = Path(args.input)
    if not input_path.exists():
        print(f"[aggregate] ERROR: {input_path} does not exist", file=sys.stderr)
        return 1
    rows = load_rows(input_path)
    if not rows:
        print(f"[aggregate] WARN: 0 rows loaded from {input_path}", file=sys.stderr)
        return 1

    agg = aggregate(rows)

    summary_path = (
        Path(args.summary)
        if args.summary
        else input_path.with_name(input_path.stem + "-summary.json")
    )
    summary_path.parent.mkdir(parents=True, exist_ok=True)
    with open(summary_path, "w", encoding="utf-8") as f:
        json.dump(agg, f, indent=2)

    report_path = (
        Path(args.report)
        if args.report
        else input_path.with_name(input_path.stem + "-report.md")
    )
    report_path.parent.mkdir(parents=True, exist_ok=True)
    report_path.write_text(format_markdown_report(agg, input_path), encoding="utf-8")

    overall = agg["overall"]
    x = overall["X_retrieve"]
    y = overall["Y_pipeline"]
    lift = overall["lift_Y_minus_X"]
    print(
        f"[aggregate] n={overall['n']} errors={overall['n_errors']}",
        file=sys.stderr,
    )
    print(
        f"[aggregate] X (retrieve) R@5={x['r_at_5']['mean']:.3f} "
        f"[{x['r_at_5']['ci_lo']:.3f},{x['r_at_5']['ci_hi']:.3f}] "
        f"R@10={x['r_at_10']['mean']:.3f}",
        file=sys.stderr,
    )
    print(
        f"[aggregate] Y (pipeline) R@5={y['r_at_5']['mean']:.3f} "
        f"[{y['r_at_5']['ci_lo']:.3f},{y['r_at_5']['ci_hi']:.3f}] "
        f"R@10={y['r_at_10']['mean']:.3f}",
        file=sys.stderr,
    )
    print(
        f"[aggregate] Lift Y − X R@5={lift['r_at_5']:+.3f} "
        f"R@10={lift['r_at_10']:+.3f}",
        file=sys.stderr,
    )
    print(f"[aggregate] -> {summary_path}", file=sys.stderr)
    print(f"[aggregate] -> {report_path}", file=sys.stderr)
    return 0


if __name__ == "__main__":
    raise SystemExit(main())
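The report notes leave the confidence interval of the lift itself as a TODO, observing that it would need a paired bootstrap over the (Y_i, X_i) tuples. The sketch below shows what that could look like, reusing the same percentile method, resample count, and seed as `bootstrap_ci`. It is not part of the repository: `paired_bootstrap_lift_ci` is a hypothetical name, and the sample values are invented.

```python
import random


def paired_bootstrap_lift_ci(
    x: list[float],
    y: list[float],
    n_resamples: int = 10000,
    seed: int = 42,
) -> tuple[float, float, float]:
    """Mean lift (Y - X) with a 95% percentile CI from a paired bootstrap.

    Rows are resampled as pairs, so each resample keeps X_i and Y_i from
    the same question together. Returns (mean_lift, ci_lo, ci_hi).
    """
    if len(x) != len(y):
        raise ValueError("x and y must have the same length (one pair per row)")
    if not x:
        return 0.0, 0.0, 0.0
    # Resampling per-row differences is equivalent to resampling pairs.
    diffs = [yi - xi for xi, yi in zip(x, y)]
    rng = random.Random(seed)
    n = len(diffs)
    means: list[float] = []
    for _ in range(n_resamples):
        total = 0.0
        for _ in range(n):
            total += diffs[rng.randrange(n)]
        means.append(total / n)
    means.sort()
    lo_idx = max(0, int(0.025 * n_resamples))
    hi_idx = min(n_resamples - 1, int(0.975 * n_resamples))
    return sum(diffs) / n, means[lo_idx], means[hi_idx]


# Invented per-row hit indicators (1.0 = gold in top-k, 0.0 = miss).
x_hits = [0.0, 0.0, 1.0, 1.0, 0.0, 1.0, 0.0, 0.0]
y_hits = [1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0]
mean_lift, lift_lo, lift_hi = paired_bootstrap_lift_ci(x_hits, y_hits, n_resamples=2000)
```

An interval that excludes zero would suggest the lift is not just resampling noise; with bins flagged as low-power (n < 30) the interval will typically be wide.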
328
bench/lme500/debug_pipeline_loss.py
Normal file
|
|
@ -0,0 +1,328 @@
|
|||
"""bench/lme500/debug_pipeline_loss.py
|
||||
|
||||
Trace WHICH pipeline stage drops the gold session in loss cases
|
||||
(rows where retrieve_recall hits in top-k but recall_for_benchmark does not).
|
||||
|
||||
Usage:
|
||||
python bench/lme500/debug_pipeline_loss.py <question_id> [<question_id> ...]
|
||||
|
||||
For each qid:
|
||||
- Loads the LongMemEval-S row from the pinned dataset.
|
||||
- Builds a fresh per-row store + runtime graph (same shape as the bench).
|
||||
- Runs retrieve_recall to confirm gold sessions are findable by flat cosine.
|
||||
- Runs recall_for_benchmark STAGE BY STAGE, recording at each cut whether the
|
||||
gold record IDs survived.
|
||||
|
||||
Stages traced:
|
||||
Stage 2 — community gate (top-3 communities by centroid cosine)
|
||||
Stage 3 — seeds (top-3 by cosine within gated candidates)
|
||||
Stage 4 — 2-hop spread + rich-club union
|
||||
Stage 5 — final recall_for_benchmark hits
|
||||
|
||||
Output is a per-stage table showing where gold drops.
|
||||
|
||||
Read-only — no src/iai_mcp changes. Calls private helpers _community_gate
|
||||
and _pick_seeds for stage-level inspection (debug-only path).
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import sys
|
||||
import tempfile
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
os.environ.setdefault("TRANSFORMERS_VERBOSITY", "error")
|
||||
|
||||
import numpy as np
|
||||
|
||||
from iai_mcp.embed import embedder_for_store
|
||||
from iai_mcp.pipeline import (
|
||||
_collect_graph_pool,
|
||||
_community_gate,
|
||||
_pick_seeds,
|
||||
recall_for_benchmark,
|
||||
)
|
||||
from iai_mcp.retrieve import build_runtime_graph, recall as retrieve_recall
|
||||
from iai_mcp.store import MemoryStore
|
||||
from iai_mcp.types import MemoryRecord
|
||||
|
||||
from bench.adapters.longmemeval import LongMemEvalAdapter
|
||||
|
||||
|
||||
def _make_record(content: str, session_id: str, role: str, embedding: list[float]) -> MemoryRecord:
|
||||
now = datetime.now(timezone.utc)
|
||||
return MemoryRecord(
|
||||
id=uuid4(),
|
||||
tier="episodic",
|
||||
literal_surface=content,
|
||||
aaak_index="",
|
||||
embedding=embedding,
|
||||
community_id=None,
|
||||
centrality=0.0,
|
||||
detail_level=2,
|
||||
pinned=False,
|
||||
stability=0.0,
|
||||
difficulty=0.0,
|
||||
last_reviewed=None,
|
||||
never_decay=False,
|
||||
never_merge=False,
|
||||
provenance=[],
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
tags=["longmemeval", f"role:{role}", f"session:{session_id}"],
|
||||
language="en",
|
||||
)
|
||||
|
||||
|
||||
def find_row(qid: str):
|
||||
adapter = LongMemEvalAdapter()
|
||||
sessions = []
|
||||
question = None
|
||||
answer_session_ids = None
|
||||
qtype = None
|
||||
for lme_session in adapter.load_dataset(split="S"):
|
||||
q = lme_session.queries[0]
|
||||
if q["question_id"] == qid:
|
||||
sessions.append(lme_session)
|
||||
if question is None:
|
||||
question = q["query"]
|
||||
answer_session_ids = set(q.get("relevant_turn_ids", []))
|
||||
qtype = q.get("question_type", "?")
|
||||
return question, qtype, answer_session_ids, sessions
|
||||
|
||||
|
||||
def trace_one(qid: str) -> dict:
|
||||
"""Returns a dict with the stage-by-stage gold survival counts."""
|
||||
print(f"\n{'=' * 78}\n=== qid={qid} ===\n{'=' * 78}", flush=True)
|
||||
question, qtype, gold_session_ids, sessions = find_row(qid)
|
||||
if question is None:
|
||||
print(f" qid={qid} NOT FOUND in dataset", flush=True)
|
||||
return {}
|
||||
|
||||
print(f" type={qtype}", flush=True)
|
||||
print(f" question[0:120]={question[:120]!r}", flush=True)
|
||||
print(f" gold session_ids={gold_session_ids}", flush=True)
|
||||
print(f" haystack sessions={len(sessions)}", flush=True)
|
||||
|
||||
tmp_root = Path(tempfile.mkdtemp(prefix="lme_dbg_"))
|
||||
store_dir = tmp_root / f"row-{qid}"
|
||||
store_dir.mkdir(parents=True, exist_ok=True)
|
||||
store = MemoryStore(path=store_dir / "lancedb")
|
||||
asyncio.run(store.enable_async_writes(coalesce_ms=50, max_batch=128))
|
||||
embedder = embedder_for_store(store)
|
||||
|
||||
id_to_session: dict[UUID, str] = {}
|
||||
gold_record_ids: set[UUID] = set()
|
||||
n_inserted = 0
|
||||
for sess in sessions:
|
||||
for turn in sess.turns:
|
||||
content = str(turn.get("content", "")).strip()
|
||||
if not content:
|
||||
continue
|
||||
vec = embedder.embed(content)
|
||||
rec = _make_record(
|
||||
content=content,
|
||||
session_id=sess.session_id,
|
||||
role=str(turn.get("role", "user")),
|
||||
embedding=vec,
|
||||
)
|
||||
store.insert(rec)
|
||||
id_to_session[rec.id] = sess.session_id
|
||||
if sess.session_id in gold_session_ids:
|
||||
gold_record_ids.add(rec.id)
|
||||
n_inserted += 1
|
||||
|
||||
asyncio.run(store.disable_async_writes())
|
||||
print(f" records inserted: {n_inserted}", flush=True)
|
||||
print(f" gold records: {len(gold_record_ids)}", flush=True)
|
||||
|
||||
graph, assignment, rich_club = build_runtime_graph(store)
|
||||
print(f" graph nodes: {len(graph._nx.nodes)}", flush=True)
|
||||
print(f" communities: {len(assignment.mid_regions)}", flush=True)
|
||||
print(f" rich-club: {len(rich_club)}", flush=True)
|
||||
cue_emb = embedder.embed(question)
|
||||
|
||||
# --- Baseline: retrieve_recall ---
|
||||
resp_x = retrieve_recall(
|
||||
store=store,
|
||||
cue_embedding=cue_emb,
|
||||
cue_text=question,
|
||||
session_id=f"debug-{qid}",
|
||||
budget_tokens=1500,
|
||||
k_hits=10,
|
||||
k_anti=0,
|
||||
)
|
||||
x_ids = [h.record_id for h in resp_x.hits]
|
||||
x_sessions = [id_to_session.get(r, "?") for r in x_ids]
|
||||
x_gold_pos = [i for i, s in enumerate(x_sessions) if s in gold_session_ids]
|
||||
print(f"\n --- retrieve_recall (X) ---", flush=True)
|
||||
print(f" top-10 sessions: {x_sessions}", flush=True)
|
||||
print(f" gold hit positions: {x_gold_pos}", flush=True)
|
||||
|
||||
# --- recall_for_benchmark, stage by stage ---
|
||||
print(f"\n --- recall_for_benchmark (Y) stage-by-stage ---", flush=True)
|
||||
|
||||
gated = _community_gate(cue_emb, assignment, top_n=3)
|
||||
candidates_set: set[UUID] = set()
|
||||
for gc in gated:
|
||||
for cid in assignment.mid_regions.get(gc, []):
|
||||
candidates_set.add(cid)
|
||||
if not candidates_set:
|
||||
candidates_set = {UUID(n) for n in graph._nx.nodes()}
|
||||
print(f" Stage 2 (community gate): EMPTY, fallback to all nodes", flush=True)
|
||||
print(f" Stage 2 (community gate): top-3 communities = {gated}", flush=True)
|
||||
print(f" candidates after gate: {len(candidates_set)}", flush=True)
|
||||
gold_in_gate = gold_record_ids & candidates_set
|
||||
print(f" gold survives gate: {len(gold_in_gate)} / {len(gold_record_ids)}", flush=True)
|
||||
|
||||
centrality: dict[UUID, float] = {}
|
||||
for nid in graph._nx.nodes:
|
||||
n = graph._nx.nodes[nid]
|
||||
if "centrality" in n:
|
||||
try:
|
||||
centrality[UUID(nid)] = float(n["centrality"])
|
||||
except (TypeError, ValueError):
|
||||
centrality[UUID(nid)] = 0.0
|
||||
if not centrality:
|
||||
try:
|
||||
centrality = graph.centrality()
|
||||
except Exception:
|
||||
centrality = {}
|
||||
# (08-01): _pick_seeds now reads from a shared cosine array.
|
||||
# Build the same array the production pipeline builds.
|
||||
pool_ids, pool_embs = _collect_graph_pool(graph, None, store)
|
||||
cue_vec_norm = np.asarray(cue_emb, dtype=np.float32)
|
||||
cn = float(np.linalg.norm(cue_vec_norm))
|
||||
if cn > 0.0:
|
||||
cue_vec_norm = cue_vec_norm / cn
|
||||
if pool_embs.size:
|
||||
shared_cos = (pool_embs @ cue_vec_norm).astype(np.float32)
|
||||
else:
|
||||
shared_cos = np.empty(0, dtype=np.float32)
|
||||
id_to_idx = {rid: i for i, rid in enumerate(pool_ids)}
|
||||
cand_idx = np.array(
|
||||
[id_to_idx[c] for c in candidates_set if c in id_to_idx],
|
||||
dtype=np.int64,
|
||||
)
|
||||
centrality_arr = np.array(
|
||||
[centrality.get(rid, 0.0) for rid in pool_ids],
|
||||
dtype=np.float32,
|
||||
)
|
||||
seed_idx = _pick_seeds(cand_idx, shared_cos, centrality_arr, n=3)
|
||||
seeds = [pool_ids[int(i)] for i in seed_idx]
|
||||
print(f" Stage 3 (seeds, top-3 by cosine in gated): {len(seeds)}", flush=True)
|
||||
seeds_sessions = [id_to_session.get(s, "?") for s in seeds]
|
||||
print(f" seed sessions: {seeds_sessions}", flush=True)
|
||||
gold_in_seeds = gold_record_ids & set(seeds)
|
||||
print(f" gold in seeds: {len(gold_in_seeds)}", flush=True)
|
||||
|
||||
spread = graph.two_hop_neighborhood(seeds, top_k=5)
|
||||
reachable = set(seeds) | set(spread) | set(rich_club)
|
||||
print(f" Stage 4 (spread + rich-club union):", flush=True)
|
||||
print(f" seeds={len(seeds)} spread={len(spread)} rich={len(rich_club)} reachable={len(reachable)}", flush=True)
|
||||
gold_in_reachable = gold_record_ids & reachable
|
||||
print(f" gold in reachable: {len(gold_in_reachable)} / {len(gold_record_ids)}", flush=True)
|
||||
|
||||
resp_y = recall_for_benchmark(
|
||||
store=store,
|
||||
graph=graph,
|
||||
assignment=assignment,
|
||||
rich_club=rich_club,
|
||||
embedder=embedder,
|
||||
cue=question,
|
||||
session_id=f"debug-{qid}",
|
||||
k_hits=10,
|
||||
profile_state=None,
|
||||
turn=0,
|
||||
mode="concept",
|
||||
)
|
||||
y_ids = [h.record_id for h in resp_y.hits]
|
||||
y_sessions = [id_to_session.get(r, "?") for r in y_ids]
|
||||
    y_gold_pos = [i for i, s in enumerate(y_sessions) if s in gold_session_ids]
    print(f" Stage 5 (rank + budget pack):", flush=True)
    print(f" final hits: {len(y_ids)}", flush=True)
    print(f" top-10 sessions: {y_sessions}", flush=True)
    print(f" gold hit positions: {y_gold_pos}", flush=True)

    # ----- Verdict -----
    # The verdict's primary signal is whether gold lands in
    # recall_for_benchmark's top-10, which is what matters for R@5/R@10.
    # Stage-2/3/4 stage-by-stage diagnostics still print above (useful when
    # gold is missed) but they observe the PRIVATE _community_gate /
    # _pick_seeds path. The redesign (08-CONTEXT.md D-02) makes the
    # community gate a soft-bias diagnostic rather than a hard filter, so a
    # "stage_2 missed" diagnostic with gold present in final hits means:
    # the gate's communities did not include gold, but the cosine top-K
    # candidate pool did, and Stage 5 ranking surfaced it.
    print(f"\n --- VERDICT ---", flush=True)
    if y_gold_pos:
        print(f" gold present in top-10 (positions {y_gold_pos}) — no_loss", flush=True)
        if not gold_in_gate:
            print(f" (gate would have killed it; augmentation rescued)", flush=True)
        verdict = "no_loss"
    elif not gold_in_gate:
        print(f" >>> GOLD KILLED at STAGE 2 (community gate) — augmentation also failed <<<", flush=True)
        verdict = "stage_2_community_gate"
    elif not gold_in_reachable:
        print(f" >>> GOLD KILLED at STAGE 3-4 (seeds + spread) <<<", flush=True)
        print(f" gold was {len(gold_in_gate)} candidate(s); none became "
              f"a seed and none was reached within 2 hops of the chosen seeds", flush=True)
        verdict = "stage_3_4_seeds_or_spread"
    else:
        print(f" >>> GOLD KILLED at STAGE 5 (rank + budget pack) <<<", flush=True)
        print(f" gold was reachable ({len(gold_in_reachable)}) but not in top-10 hits", flush=True)
        verdict = "stage_5_rank"

    return {
        "qid": qid,
        "qtype": qtype,
        "verdict": verdict,
        "n_records": n_inserted,
        "n_communities": len(assignment.mid_regions),
        "n_rich_club": len(rich_club),
        "n_gold_records": len(gold_record_ids),
        "gold_in_gate": len(gold_in_gate),
        "gold_in_reachable": len(gold_in_reachable),
        "x_gold_pos": x_gold_pos,
        "y_gold_pos": y_gold_pos,
    }


def main(qids: list[str]) -> int:
    summary = []
    for qid in qids:
        try:
            summary.append(trace_one(qid))
        except Exception as exc:
            print(f"\n qid={qid} TRACE FAILED: {type(exc).__name__}: {exc}", flush=True)
            import traceback
            traceback.print_exc()
            summary.append({"qid": qid, "verdict": "trace_failed"})

    print("\n\n" + "=" * 78)
    print("SUMMARY")
    print("=" * 78)
    print(f"{'qid':16} {'qtype':28} {'verdict':32} gold(gate→reach)")
    print("-" * 100)
    for s in summary:
        if not s:
            continue
        gate = s.get("gold_in_gate", "?")
        reach = s.get("gold_in_reachable", "?")
        ngold = s.get("n_gold_records", "?")
        print(
            f"{s.get('qid', '?'):16} {s.get('qtype', '?'):28} "
            f"{s.get('verdict', '?'):32} "
            f"{gate}→{reach} (of {ngold})"
        )
    return 0


if __name__ == "__main__":
    if len(sys.argv) < 2:
        print(__doc__, file=sys.stderr)
        sys.exit(1)
    sys.exit(main(sys.argv[1:]))
||||
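The verdict ladder at the end of `trace_one` can be restated as a small pure function, which makes the precedence of the checks easier to see. This is only a sketch: `classify_verdict` is a hypothetical helper that does not exist in the repository, and it takes plain lists where the script uses its own local variables.

```python
def classify_verdict(y_gold_pos, gold_in_gate, gold_in_reachable):
    """Return the stage at which gold was lost, or "no_loss".

    Hypothetical helper mirroring the if/elif ladder in the trace script.
    The checks run in pipeline order, so the first failing stage wins.
    """
    if y_gold_pos:                      # gold made it into the final top-10
        return "no_loss"
    if not gold_in_gate:                # never passed the community gate
        return "stage_2_community_gate"
    if not gold_in_reachable:           # gated, but no seed or spread reached it
        return "stage_3_4_seeds_or_spread"
    return "stage_5_rank"               # reachable, but ranked out of the top-10


verdict_rescued = classify_verdict([3], [], [])
verdict_ranked_out = classify_verdict([], ["g1"], ["g1"])
```

Note that a non-empty `y_gold_pos` yields `no_loss` even when the gate list is empty; in the script that case prints the extra "augmentation rescued" line.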
768
bench/longmemeval_blind.py
Normal file
@ -0,0 +1,768 @@
"""Plan 05-11 blind-run orchestrator (M-08).

Runs LongMemEval-S through IAI-MCP's public API (MemoryStore.insert +
retrieve.recall) in strict blind mode: no per-dataset tuning, no
hyperparameter sweep, no late adjustment after seeing numbers. This is
the external honesty axis for Phase 5.

## Row-level protocol

One evaluation row in LongMemEval-S contains:

    { "question", "answer_session_ids" (gold),
      "haystack_session_ids", "haystack_sessions" (the full history) }

Per row the orchestrator does:

  1. fresh tmp MemoryStore (per-row isolation; no cross-row leakage)
  2. enable async writes (Plan 05-10; keeps RAM bounded on a
     16GB M1 laptop)
  3. embed + insert the haystack (one record per turn, or one record per
     session under --granularity session); each record is tagged with
     ``session:<session_id>`` so the orchestrator can score at the
     dataset's native session-ID granularity.
  4. disable async writes (flushes the queue; the store now holds the
     full haystack).
  5. build_runtime_graph once (Plan 05-09 cache amortises cold start
     across rows via the shared runtime graph cache dir).
  6. call retrieve.recall (prong X) and pipeline.recall_for_benchmark
     (prong Y) for the eval query, each with k_hits=10.
  7. compute R@5 / R@10 at session-ID granularity (the standard
     LongMemEval metric): a retrieved record "hits" if its session ID
     is in answer_session_ids. R@k is 1.0 if any top-k hits, else 0.
  8. measure per-query token cost via bench.tokens counters.

## CLI

    python bench/longmemeval_blind.py \\
        --split S \\
        [--limit N] \\
        [--granularity {session, turn}] \\
        [--dataset {cleaned, raw}] \\
        [--qid-include csv] \\
        [--embedder {bge-small-en-v1.5, all-MiniLM-L6-v2}] \\
        [--checkpoint path] \\
        --out /tmp/p11_lme_full.json

Phase 9 added two methodology-alignment flags, plus an optional row filter:

    --granularity session   (default; one record per session,
                            content = "\\n".join(user-only turns))
    --granularity turn      (v1/v2 reproducer; one record per turn)
    --dataset cleaned       (default; xiaowu0162/longmemeval-cleaned)
    --dataset raw           (v1/v2 reproducer; xiaowu0162/longmemeval
                            rev 2ec2a557f339)
    --qid-include csv       optional comma-separated question_ids; when
                            set, only those rows run (used by smoke
                            tests for per-qid baseline verification)

## Output JSON keys

    {
      "split": "S",
      "dataset_id": "xiaowu0162/longmemeval-cleaned" | "xiaowu0162/longmemeval",
      "revision": "<40-hex>",
      "granularity": "session" | "turn",
      "dataset_choice": "cleaned" | "raw",
      "embedder_model_key": str,
      "embedder_hf_id": str,
      "n_rows": int,                 # rows selected after --limit / --qid-include
      "r_at_5_retrieve": float,      # prong X session-ID R@5, mean across rows
      "r_at_10_retrieve": float,     # prong X session-ID R@10, mean across rows
      "r_at_5_pipeline": float,      # prong Y session-ID R@5, mean across rows
      "r_at_10_pipeline": float,     # prong Y session-ID R@10, mean across rows
      "r_at_5_lift": float,          # prong Y minus prong X
      "r_at_10_lift": float,         # prong Y minus prong X
      "token_p50": int,              # per-query cue-text tokens, median
      "token_p95": int,              # per-query cue-text tokens, p95
      "session_tokens_mean": float,  # mean per-row inserted text tokens
                                     # (proxy for the rows' storage footprint)
      "errors": [{"question_id": str, "error_class": str, "error": str}],
      "hard_limit": int | null,
      "metric_def": str,
      "per_row": [dict],             # one result dict per successful row
      "generated_at": str,           # ISO-8601 UTC timestamp
      "total_wall_seconds": float
    }

## Discipline

The run is ONE-SHOT. If a bug crashes a row, it's logged in ``errors``
and counted as a MISS against R@k (not silently dropped). The published
number is whatever came out. Disclosures (small-N, hardware limit,
English-only embedder, etc.) live in the published bench report and
05-11-SUMMARY.md; they don't get folded back into this script.
"""
from __future__ import annotations

import argparse
import asyncio
import json
import os
import shutil
import statistics
import sys
import tempfile
import time
import traceback
from collections import defaultdict
from datetime import datetime, timezone
from pathlib import Path
from typing import Any
from uuid import UUID

# Silence the "UNEXPECTED embeddings.position_ids" noise from
# sentence-transformers so the blind-run stderr stays focused on errors.
os.environ.setdefault("TRANSFORMERS_VERBOSITY", "error")

# IAI-MCP imports — public API only (plan directive).
from iai_mcp.embed import Embedder, embedder_for_store
from iai_mcp.pipeline import recall_for_benchmark
from iai_mcp.retrieve import build_runtime_graph, recall as retrieve_recall
from iai_mcp.store import MemoryStore
from iai_mcp.types import MemoryRecord

# Adapter (ships alongside this script).
from bench.adapters.longmemeval import (
    DATASET_ID,
    PINNED_REVISION,
    LMESession,
    LongMemEvalAdapter,
)

# Token counter (reuses bench/tokens.py three-tier helper).
from bench.tokens import _char4_count, _tiktoken_count


def _count_tokens(text: str) -> int:
    """Prefer tiktoken-cl100k proxy; fall back to char4."""
    try:
        return _tiktoken_count(text)
    except Exception:  # pragma: no cover
        return _char4_count(text)


def _percentile(xs: list[int], p: float) -> int:
    if not xs:
        return 0
    s = sorted(xs)
    k = max(0, min(len(s) - 1, int(round((len(s) - 1) * p / 100.0))))
    return s[k]
||||
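A short worked example may help when reading `token_p50` and `token_p95`: `_percentile` picks a single element by rounding an index, it does not interpolate. The sketch below uses `nearest_rank_percentile`, a hypothetical standalone copy of the same rule, so it runs without the bench imports.

```python
def nearest_rank_percentile(xs, p):
    """Standalone copy of the index-rounding rule used by _percentile."""
    if not xs:
        return 0
    s = sorted(xs)
    # Python's round() rounds halves to the nearest even integer, so an
    # index of exactly 4.5 becomes 4, not 5.
    k = max(0, min(len(s) - 1, int(round((len(s) - 1) * p / 100.0))))
    return s[k]


tokens = [10, 20, 30, 40, 50, 60, 70, 80, 90, 100]
p50 = nearest_rank_percentile(tokens, 50)   # index round(4.5) = 4
p95 = nearest_rank_percentile(tokens, 95)   # index round(8.55) = 9
```

With ten values, the median index is exactly 4.5 and rounds down to 4, so `p50` is the fifth-smallest value rather than an average of the two middle ones.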
def _make_record(
    content: str,
    session_id: str,
    role: str,
    embedding: list[float],
) -> MemoryRecord:
    now = datetime.now(timezone.utc)
    from uuid import uuid4

    return MemoryRecord(
        id=uuid4(),
        tier="episodic",
        literal_surface=content,
        aaak_index="",
        embedding=embedding,
        community_id=None,
        centrality=0.0,
        detail_level=2,
        pinned=False,
        stability=0.0,
        difficulty=0.0,
        last_reviewed=None,
        never_decay=False,
        never_merge=False,
        provenance=[],
        created_at=now,
        updated_at=now,
        tags=[
            "longmemeval",
            f"role:{role}",
            f"session:{session_id}",
        ],
        language="en",
    )
||||
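The session-granularity path in `_run_one_row` collapses each haystack session into one document before embedding. A minimal sketch of that step, assuming turns are plain dicts with `role` and `content` keys as in the script; `build_session_document` is a hypothetical name, not a function in the repository.

```python
def build_session_document(turns):
    """Join the non-empty user turns of one session with newlines.

    Returns None when the session has no usable user turn, which is the
    case the orchestrator skips. Turns without a "role" key are treated
    as user turns, matching the script's default.
    """
    user_turns = [
        str(turn.get("content", "")).strip()
        for turn in turns
        if str(turn.get("role", "user")) == "user"
        and str(turn.get("content", "")).strip()
    ]
    return "\n".join(user_turns) if user_turns else None


session = [
    {"role": "user", "content": "  I moved to Lisbon in May. "},
    {"role": "assistant", "content": "Noted."},
    {"role": "user", "content": ""},
    {"role": "user", "content": "My cat is called Miso."},
]
doc = build_session_document(session)
```

Assistant turns never reach the embedder in this mode, which is why the record count per row drops so sharply compared with turn granularity.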
def _run_one_row(
    row_id: str,
    question: str,
    question_type: str,
    answer_session_ids: set[str],
    sessions: list[LMESession],
    tmp_root: Path,
    granularity: str = "turn",
    embedder_key: str = "bge-small-en-v1.5",
) -> dict[str, Any]:
    """Execute the per-row protocol. Returns a dict with r_at_5/r_at_10
    for BOTH retrieve_recall (flat-cosine baseline, matches Phase 5
    n=30) AND recall_for_benchmark (full graph-native architecture; Phase
    8 entry-point split), token counts plus timing info. Raises
    only on programmer errors; dataset/runtime errors are caught by the
    caller.

    bench/lme500 protocol: prong X = retrieve_recall, prong Y =
    recall_for_benchmark. Both share the same insert phase + retrieved-set
    mapping, so the architecture-vs-baseline delta is attributable to
    the recall function only, not retrieval-side variance.

    ``granularity`` controls corpus construction.
        "turn" -> one record per turn (v1/v2 baseline; ~500 records/row)
        "session" -> one record per session whose content is
                     "\\n".join(user-only turns), matching mempalace's
                     reference verbatim (~53 records/row).
    """
    t0 = time.time()

    # Fresh store in a per-row tmp dir.
    store_dir = tmp_root / f"row-{row_id}"
    store_dir.mkdir(parents=True, exist_ok=True)
    store = MemoryStore(path=store_dir / "lancedb")

    # async writes: coalesce LanceDB appends across the row.
    # enable_async_writes is a coroutine — drive it from a fresh loop so
    # the surrounding orchestrator stays sync.
    asyncio.run(store.enable_async_writes(coalesce_ms=50, max_batch=128))

    # count inserted tokens as a rough storage footprint.
    inserted_text_tokens = 0

    # route through the explicit registry key so the
    # embedder ablation experiment can swap to all-MiniLM-L6-v2 without
    # touching the production-default resolver (embedder_for_store kept
    # imported for backward-compat; not called on this path).
    embedder = Embedder(model_key=embedder_key)
    _ = embedder_for_store  # silence unused-import warning when the prod path is bypassed

    # --------- INSERT phase ---------
    # One pass over all haystack sessions for this row. Each MemoryRecord is
    # tagged with its session_id so R@k can score at the dataset's native
    # session granularity. The --granularity flag splits this into two paths:
    # - "turn" (v1/v2 baseline; one record per turn, both roles)
    # - "session" (mempalace-aligned; one record per session, user-only
    #   turns joined with "\n"; ~10x fewer records per row)
    id_to_session: dict[str, str] = {}  # str(record_id) -> session_id
    if granularity == "session":
        # Session-granularity (D-01, mempalace-aligned): ONE record per
        # session, content = "\n".join(user-only turns). Skip sessions
        # with no user turns. Verbatim shape match with mempalace's
        # benchmarks/longmemeval_bench.py reference loop.
        for sess in sessions:
            user_turns = [
                str(turn.get("content", "")).strip()
                for turn in sess.turns
                if str(turn.get("role", "user")) == "user"
                and str(turn.get("content", "")).strip()
            ]
            if not user_turns:
                continue
            doc_text = "\n".join(user_turns)
            vec = embedder.embed(doc_text)
            rec = _make_record(
                content=doc_text,
                session_id=sess.session_id,
                role="user",
                embedding=vec,
            )
            store.insert(rec)
            id_to_session[str(rec.id)] = sess.session_id
            inserted_text_tokens += _count_tokens(doc_text)
    else:
        # Turn-granularity (v1/v2 baseline; bytes-identical loop body).
        for sess in sessions:
            for turn in sess.turns:
                content = str(turn.get("content", "")).strip()
                if not content:
                    continue
                vec = embedder.embed(content)
                rec = _make_record(
                    content=content,
                    session_id=sess.session_id,
                    role=str(turn.get("role", "user")),
                    embedding=vec,
                )
                store.insert(rec)
                id_to_session[str(rec.id)] = sess.session_id
                inserted_text_tokens += _count_tokens(content)

    # Flush the async queue before recall. disable_async_writes is a
    # coroutine too — drive from a fresh loop.
    asyncio.run(store.disable_async_writes())
    t_after_insert = time.time()

    # --------- Build runtime graph (Plan 05-09 cache warms cold-start) ---------
    # bench/lme500: capture the (graph, assignment, rich_club) tuple so
    # recall_for_benchmark (prong Y) can reuse it. retrieve_recall (prong X)
    # is unaffected by graph build success/failure.
    graph = None
    assignment = None
    rich_club = None
    try:
        graph, assignment, rich_club = build_runtime_graph(store)
    except Exception as exc:  # pragma: no cover — cache helpers should be robust
        # Don't fail the row on graph build; retrieve_recall is still
        # callable from the flat store. recall_for_benchmark will be skipped
        # for this row and counted as miss for the Y prong.
        print(
            f"[LME] row={row_id} build_runtime_graph failed: "
            f"{type(exc).__name__}: {exc}",
            file=sys.stderr,
        )
    t_after_graph = time.time()

    # --------- Prong X: retrieve_recall (flat-cosine, baseline) ---------
    cue_embedding = embedder.embed(question)
    resp_x = retrieve_recall(
        store=store,
        cue_embedding=cue_embedding,
        cue_text=question,
        session_id=f"lme-{row_id}",
        budget_tokens=1500,
        k_hits=10,
        k_anti=0,
    )
    t_after_x = time.time()

    # --------- Prong Y: recall_for_benchmark (full graph-native architecture) ---------
    # Phase 8 entry-point split: bench harness uses the top-K contract
    # (k_hits=10, no budget_tokens). mode="concept" preserved verbatim — the
    # bench is concept-shaped per BENCH_PROTOCOL_lme500.md and the D-02
    # `_gate_bias_for_mode("concept") == 0.1` bias is what v2 measurements observe.
    resp_y = None
    pipeline_error: str | None = None
    if graph is not None:
        try:
            resp_y = recall_for_benchmark(
                store=store,
                graph=graph,
                assignment=assignment,
                rich_club=rich_club,
                embedder=embedder,
                cue=question,
                session_id=f"lme-{row_id}",
                k_hits=10,
                profile_state=None,
                turn=0,
                mode="concept",
            )
        except Exception as exc:
            pipeline_error = f"{type(exc).__name__}: {str(exc)[:200]}"
            print(
                f"[LME] row={row_id} recall_for_benchmark failed: "
                f"{pipeline_error}",
                file=sys.stderr,
            )
    else:
        pipeline_error = "graph_build_failed"
    t_after_y = time.time()

    def _retrieved_session_ids(resp) -> list[str]:
        if resp is None:
            return []
        out: list[str] = []
        for hit in resp.hits:
            sid = id_to_session.get(str(hit.record_id))
            if sid is not None:
                out.append(sid)
        return out

    sids_x = _retrieved_session_ids(resp_x)
    sids_y = _retrieved_session_ids(resp_y)

    # LongMemEval-standard R@k at session-ID granularity: hit-at-k.
    # R@k = 1.0 if any of the top-k retrieved records belongs to a gold
    # session, else 0.0. Aggregated across rows by the caller.
    def _hit_at_k(sids: list[str], k: int) -> float:
        top = sids[:k]
        return 1.0 if any(s in answer_session_ids for s in top) else 0.0

    r5_x = _hit_at_k(sids_x, 5)
    r10_x = _hit_at_k(sids_x, 10)
    r5_y = _hit_at_k(sids_y, 5) if resp_y is not None else 0.0
    r10_y = _hit_at_k(sids_y, 10) if resp_y is not None else 0.0

    query_tokens = _count_tokens(question)

    return {
        "question_id": row_id,
        "question_type": question_type,
        # Prong X — retrieve_recall (flat-cosine baseline, line-by-line)
        "r_at_5_retrieve": r5_x,
        "r_at_10_retrieve": r10_x,
        # Prong Y — recall_for_benchmark (full graph-native pipeline; D-07)
        "r_at_5_pipeline": r5_y,
        "r_at_10_pipeline": r10_y,
        "pipeline_error": pipeline_error,
        # Shared
        "query_tokens": query_tokens,
        "inserted_text_tokens": inserted_text_tokens,
        "n_haystack_sessions": len(sessions),
        "n_turns_inserted": len(id_to_session),
        "timing_seconds": {
            "insert": round(t_after_insert - t0, 2),
            "graph": round(t_after_graph - t_after_insert, 2),
            "recall_retrieve": round(t_after_x - t_after_graph, 2),
            "recall_pipeline": round(t_after_y - t_after_x, 2),
            "total": round(t_after_y - t0, 2),
        },
    }
||||
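The scoring rule in `_run_one_row` is hit-at-k at session-ID granularity, averaged over rows by the caller. The following sketch isolates that arithmetic; `hit_at_k` and `mean_recall` are illustrative names for this example only, and the session IDs are made up.

```python
def hit_at_k(retrieved_session_ids, gold_session_ids, k):
    """1.0 if any of the first k retrieved session IDs is gold, else 0.0."""
    top = retrieved_session_ids[:k]
    return 1.0 if any(sid in gold_session_ids for sid in top) else 0.0


def mean_recall(per_row_scores):
    """Mean across rows; an errored row is passed in as 0.0, not dropped."""
    return sum(per_row_scores) / len(per_row_scores) if per_row_scores else 0.0


gold = {"s7"}
ranked = ["s1", "s2", "s3", "s4", "s5", "s6", "s7", "s8"]
r5 = hit_at_k(ranked, gold, 5)     # gold sits at rank 7, outside the top 5
r10 = hit_at_k(ranked, gold, 10)   # gold is inside the top 10

# Four rows: two hits, one miss, one crashed row counted as a miss.
run_r_at_5 = mean_recall([1.0, 1.0, 0.0, 0.0])
```

Because a crashed row contributes 0.0 to the numerator and still counts in the denominator, errors can only lower the reported R@k.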
def main(argv: list[str] | None = None) -> int:
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument(
        "--split",
        default="S",
        choices=["S", "M", "oracle"],
        help="LongMemEval split (Plan 05-11 runs S)",
    )
    parser.add_argument(
        "--limit",
        type=int,
        default=None,
        help=(
            "practical-cap on rows evaluated. LongMemEval-S = 500 rows; "
            "at ~500 turns/row and 11ms/embed on a 16GB M1 laptop, the "
            "full 500-row run is multi-hour. --limit lets the blind pilot "
            "finish; the SUMMARY discloses the cap honestly."
        ),
    )
    parser.add_argument(
        "--out",
        default="/tmp/p11_lme_full.json",
        help="output JSON path",
    )
    parser.add_argument(
        "--checkpoint",
        default=None,
        help=(
            "JSONL checkpoint path for crash-resume; default = <out>.jsonl. "
            "Each completed (or errored) row is appended with fsync as one "
            "JSON line. On restart, rows whose question_id already appears "
            "in the checkpoint are skipped."
        ),
    )
    # granularity flag with mempalace-aligned default.
    parser.add_argument(
        "--granularity",
        choices=["session", "turn"],
        default="session",
        help=(
            "corpus-construction granularity. "
            "'session' (default, v3): one record per session, "
            "content = '\\n'.join(user-only turns) — matches mempalace's "
            "reference. 'turn': one record per turn (v1/v2 baseline; "
            "use with --dataset raw to reproduce v2's 0.956)."
        ),
    )
    # dataset choice flag with mempalace-aligned default.
    parser.add_argument(
        "--dataset",
        choices=["cleaned", "raw"],
        default="cleaned",
        help=(
            "dataset variant. 'cleaned' (default, v3): "
            "xiaowu0162/longmemeval-cleaned, SHA pinned via repo_info(). "
            "'raw' (v1/v2 baseline): xiaowu0162/longmemeval rev "
            "2ec2a557f339... — use with --granularity turn to reproduce "
            "v2's 0.956."
        ),
    )
    # Step B: per-qid filter for the v2-baseline
    # smoke reproducer. Applied AFTER --limit so a future caller passing
    # both flags gets a deterministic intersection (limit narrows by row
    # count, qid-include narrows by id). Default None preserves v1/v2 behaviour.
    parser.add_argument(
        "--qid-include",
        default=None,
        help=(
            "comma-separated list of question_ids; if set, only these "
            "rows run (used by smoke tests for per-qid baseline "
            "verification). Applied after --limit."
        ),
    )
    # bench-only embedder swap. Default preserves v3
    # baseline (bge-small-en-v1.5). all-MiniLM-L6-v2 is mempalace's ChromaDB
    # default, used for the embedder-axis ablation in v3.1. Production
    # embedder is unchanged regardless of this flag (English-Only Brain lock
    # from Plan 05-08; the Embedder.__init__ kwarg is the only
    # entry point that surfaces the registry's all-MiniLM-L6-v2 entry).
    parser.add_argument(
        "--embedder",
        choices=["bge-small-en-v1.5", "all-MiniLM-L6-v2"],
        default="bge-small-en-v1.5",
        help=(
            "embedder model_key. 'bge-small-en-v1.5' (default, v3 "
            "baseline) routes via the production English-only embedder. "
            "'all-MiniLM-L6-v2' (Phase 9.1 ablation) is mempalace's "
            "ChromaDB default — bench-only swap, production unchanged."
        ),
    )
    args = parser.parse_args(argv)

    print(
        f"[LME] blind run starting "
        f"split={args.split} limit={args.limit} "
        f"granularity={args.granularity} dataset={args.dataset} "
        f"embedder={args.embedder} "
        f"out={args.out}",
        file=sys.stderr,
        flush=True,
    )

    # branch the adapter on --dataset.
    if args.dataset == "cleaned":
        from bench.adapters.longmemeval_cleaned import (
            CLEANED_DATASET_ID,
            CleanedLongMemEvalAdapter,
        )
        adapter = CleanedLongMemEvalAdapter()
        dataset_id_emit = CLEANED_DATASET_ID
        revision_emit = adapter.revision
    else:
        adapter = LongMemEvalAdapter()
        dataset_id_emit = DATASET_ID
        revision_emit = PINNED_REVISION
    # Adapter yields one LMESession per haystack session, but the
    # blind-run protocol needs rows (one question + all its haystack
    # sessions). Group by question_id (carried inside queries[0]).
    grouped: dict[str, dict[str, Any]] = {}
    row_order: list[str] = []
    for lme_session in adapter.load_dataset(split=args.split):
        q = lme_session.queries[0]
        qid = q["question_id"]
        if qid not in grouped:
            grouped[qid] = {
                "question": q["query"],
                "question_type": q.get("question_type", "unknown"),
                "answer_session_ids": set(q.get("relevant_turn_ids", [])),
                "sessions": [],
            }
            row_order.append(qid)
        grouped[qid]["sessions"].append(lme_session)

    if args.limit is not None:
        row_order = row_order[: args.limit]

    # Step B: --qid-include filter applied AFTER
    # --limit so a future caller passing both flags gets a deterministic
    # intersection. The default None path is a no-op for backward compat.
    if args.qid_include is not None:
        wanted = {q.strip() for q in str(args.qid_include).split(",") if q.strip()}
        row_order = [qid for qid in row_order if qid in wanted]
        print(
            f"[LME] qid-include filter: kept {len(row_order)} of "
            f"{len(wanted)} requested qids",
            file=sys.stderr,
            flush=True,
        )

    tmp_root = Path(tempfile.mkdtemp(prefix="lme_blind_"))
    print(f"[LME] per-row stores rooted at {tmp_root}", file=sys.stderr, flush=True)

    per_row: list[dict[str, Any]] = []
    errors: list[dict[str, str]] = []
    # bench/lme500: track BOTH prongs (X = retrieve_recall, Y = recall_for_benchmark).
    r5_x_values: list[float] = []
    r10_x_values: list[float] = []
    r5_y_values: list[float] = []
    r10_y_values: list[float] = []
    query_tokens: list[int] = []
    session_tokens: list[int] = []

    # bench/lme500: per-row JSONL checkpoint for crash resume.
    # Each row's full result is appended with flush + fsync, so a kill at
    # row N preserves rows 1..N-1 fully. Restart skips rows already in the
    # checkpoint (matched by question_id).
    checkpoint_path = Path(args.checkpoint) if args.checkpoint else Path(str(args.out) + ".jsonl")
    checkpoint_path.parent.mkdir(parents=True, exist_ok=True)
    completed_ids: set[str] = set()
    if checkpoint_path.exists():
        with open(checkpoint_path, "r", encoding="utf-8") as cp_f:
            for line in cp_f:
                line = line.strip()
                if not line:
                    continue
                try:
                    rec = json.loads(line)
                except json.JSONDecodeError:
                    print(
                        f"[LME] WARN: skipping corrupt checkpoint line: {line[:80]!r}",
                        file=sys.stderr,
                        flush=True,
                    )
                    continue
                qid = rec.get("question_id")
                if not qid:
                    continue
                completed_ids.add(qid)
                if "error" in rec and isinstance(rec.get("error"), dict):
                    # Resumed error row: count as full miss for both prongs.
                    errors.append(
                        {
                            "question_id": qid,
                            "error_class": rec["error"].get("error_class", "Unknown"),
                            "error": rec["error"].get("error", ""),
                        }
                    )
                    r5_x_values.append(0.0)
                    r10_x_values.append(0.0)
                    r5_y_values.append(0.0)
                    r10_y_values.append(0.0)
                    query_tokens.append(0)
                    session_tokens.append(0)
                else:
                    # Resumed success row.
                    per_row.append(rec)
                    r5_x_values.append(float(rec.get("r_at_5_retrieve", 0.0)))
                    r10_x_values.append(float(rec.get("r_at_10_retrieve", 0.0)))
                    r5_y_values.append(float(rec.get("r_at_5_pipeline", 0.0)))
                    r10_y_values.append(float(rec.get("r_at_10_pipeline", 0.0)))
                    query_tokens.append(int(rec.get("query_tokens", 0)))
                    session_tokens.append(int(rec.get("inserted_text_tokens", 0)))
        if completed_ids:
            print(
                f"[LME] resume: {len(completed_ids)} rows already in checkpoint "
                f"{checkpoint_path}; processing {len(row_order) - len(completed_ids)} remaining",
                file=sys.stderr,
                flush=True,
            )
    else:
        print(
            f"[LME] checkpoint: writing per-row durable JSONL to {checkpoint_path}",
            file=sys.stderr,
            flush=True,
        )

    def _checkpoint_append(rec: dict[str, Any]) -> None:
        """Append one row record to the checkpoint, flush+fsync for durability."""
        with open(checkpoint_path, "a", encoding="utf-8") as cp_a:
            cp_a.write(json.dumps(rec) + "\n")
            cp_a.flush()
            os.fsync(cp_a.fileno())

    run_t0 = time.time()
    for i, qid in enumerate(row_order):
        if qid in completed_ids:
            continue
        row = grouped[qid]
        try:
            res = _run_one_row(
                row_id=qid,
                question=row["question"],
                question_type=row["question_type"],
                answer_session_ids=row["answer_session_ids"],
                sessions=row["sessions"],
                tmp_root=tmp_root,
                granularity=args.granularity,
                embedder_key=args.embedder,
            )
            per_row.append(res)
            r5_x_values.append(res["r_at_5_retrieve"])
            r10_x_values.append(res["r_at_10_retrieve"])
            r5_y_values.append(res["r_at_5_pipeline"])
            r10_y_values.append(res["r_at_10_pipeline"])
            query_tokens.append(res["query_tokens"])
            session_tokens.append(res["inserted_text_tokens"])
            _checkpoint_append(res)
            elapsed = time.time() - run_t0
            print(
                f"[LME] row {i+1}/{len(row_order)} qid={qid} "
                f"qtype={res['question_type']} "
                f"R@5_x={res['r_at_5_retrieve']:.0f} R@5_y={res['r_at_5_pipeline']:.0f} "
                f"R@10_x={res['r_at_10_retrieve']:.0f} R@10_y={res['r_at_10_pipeline']:.0f} "
                f"t_row={res['timing_seconds']['total']:.1f}s "
                f"t_total={elapsed:.1f}s",
                file=sys.stderr,
                flush=True,
            )
        except Exception as exc:
            # T-05-11-04 mitigation: log + count as miss, do
            # NOT silently drop.
            err_payload = {
                "error_class": type(exc).__name__,
                "error": str(exc)[:500],
            }
            errors.append({"question_id": qid, **err_payload})
            # Counted as a full miss for both prongs — preserves
            # "count against R@5 as 0" from the plan text.
            r5_x_values.append(0.0)
            r10_x_values.append(0.0)
            r5_y_values.append(0.0)
            r10_y_values.append(0.0)
            query_tokens.append(0)
            session_tokens.append(0)
            # Persist the error row to checkpoint so a restart skips it.
            _checkpoint_append(
                {
                    "question_id": qid,
                    "question_type": row.get("question_type", "unknown"),
                    "error": err_payload,
                }
            )
            print(
                f"[LME] ERROR row={qid}: {type(exc).__name__}: {exc}",
                file=sys.stderr,
                flush=True,
            )
            traceback.print_exc(file=sys.stderr)
        finally:
            # Free disk aggressively — many rows × ~500 turns per store
            # adds up even on 64GB.
            row_dir = tmp_root / f"row-{qid}"
            if row_dir.exists():
                shutil.rmtree(row_dir, ignore_errors=True)

    shutil.rmtree(tmp_root, ignore_errors=True)

    def _mean(xs: list[float]) -> float:
        return (sum(xs) / len(xs)) if xs else 0.0

    out = {
        "split": args.split,
        "dataset_id": dataset_id_emit,
        "revision": revision_emit,
        # reproducibility fields:
        "granularity": args.granularity,
        "dataset_choice": args.dataset,
        # embedder identity pinned for v3.1 ablation reproducibility.
        # Default "bge-small-en-v1.5" reproduces v3 baseline; "all-MiniLM-L6-v2"
        # is the embedder-axis ablation toggle (mempalace ChromaDB default).
        "embedder_model_key": args.embedder,
        "embedder_hf_id": Embedder(model_key=args.embedder).model_name,
        "n_rows": len(row_order),
        # Prong X — retrieve_recall (flat-cosine baseline, line-by-line)
        "r_at_5_retrieve": _mean(r5_x_values),
        "r_at_10_retrieve": _mean(r10_x_values),
        # Prong Y — recall_for_benchmark (full graph-native architecture; D-07)
        "r_at_5_pipeline": _mean(r5_y_values),
        "r_at_10_pipeline": _mean(r10_y_values),
        # Architecture lift (Y - X)
        "r_at_5_lift": _mean(r5_y_values) - _mean(r5_x_values),
        "r_at_10_lift": _mean(r10_y_values) - _mean(r10_x_values),
        "token_p50": _percentile(query_tokens, 50),
        "token_p95": _percentile(query_tokens, 95),
        "session_tokens_mean": (
            statistics.fmean(session_tokens) if session_tokens else 0.0
        ),
        "errors": errors,
        "hard_limit": args.limit,
        "metric_def": (
            "Session-ID hit-at-k: R@k = 1.0 if any of top-k retrieved records "
            "belongs to a gold session_id, else 0.0 (LongMemEval standard)."
        ),
        "per_row": per_row,
        "generated_at": datetime.now(timezone.utc).isoformat(),
        "total_wall_seconds": round(time.time() - run_t0, 2),
    }

    Path(args.out).parent.mkdir(parents=True, exist_ok=True)
    with open(args.out, "w", encoding="utf-8") as f:
        json.dump(out, f, indent=2)

    print(
        f"[LME] DONE n_rows={out['n_rows']} "
        f"R@5_retrieve={out['r_at_5_retrieve']:.3f} "
        f"R@5_pipeline={out['r_at_5_pipeline']:.3f} "
        f"lift_R@5={out['r_at_5_lift']:+.3f} "
        f"R@10_retrieve={out['r_at_10_retrieve']:.3f} "
        f"R@10_pipeline={out['r_at_10_pipeline']:.3f} "
        f"lift_R@10={out['r_at_10_lift']:+.3f} "
        f"errors={len(errors)} -> {args.out}",
        file=sys.stderr,
        flush=True,
    )
    return 0


if __name__ == "__main__":
    sys.exit(main())
||||
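The crash-resume behaviour of the blind-run script rests on two small operations: a durable append of one JSON line per row, and a tolerant reload that ignores a torn or corrupt line. The sketch below reproduces that pattern outside the orchestrator. The names `checkpoint_append` and `load_completed_ids` and the sample question IDs are assumptions made for the example; only standard-library calls are used.

```python
import json
import os
import tempfile
from pathlib import Path


def checkpoint_append(path, rec):
    """Append one record as a JSON line, then flush and fsync it to disk."""
    with open(path, "a", encoding="utf-8") as f:
        f.write(json.dumps(rec) + "\n")
        f.flush()
        os.fsync(f.fileno())


def load_completed_ids(path):
    """Return the question_ids already present in a JSONL checkpoint."""
    done = set()
    if not Path(path).exists():
        return done
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            try:
                rec = json.loads(line)
            except json.JSONDecodeError:
                continue  # e.g. a half-written line left by a crash
            qid = rec.get("question_id")
            if qid:
                done.add(qid)
    return done


tmp_dir = tempfile.mkdtemp(prefix="ckpt_demo_")
ckpt = os.path.join(tmp_dir, "run.jsonl")
checkpoint_append(ckpt, {"question_id": "q1", "r_at_5_retrieve": 1.0})
checkpoint_append(ckpt, {"question_id": "q2", "error": {"error_class": "ValueError"}})
with open(ckpt, "a", encoding="utf-8") as f:
    f.write('{"question_id": "q3", "r_at_5')  # simulate a torn write
completed = load_completed_ids(ckpt)
```

Error rows are written to the checkpoint as well, so a restart does not retry them; in the script they stay in the aggregate as misses.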
335
bench/memory_footprint.py
Normal file
@ -0,0 +1,335 @@
"""M-03 RAM footprint bench. Reports RSS at store size N.

Target: RSS <= 300 MB warm at N=10k on a 16+ GB machine.

Pressplay 8 GB M1 hung mid-run on 2026-04-19 while trying to build the
runtime graph at N=10k (Pitfall 4 from 05-RESEARCH: bge-m3 ~2 GB +
NetworkX ~200 MB + LanceDB ~50 MB + Python overhead -> swap thrash).
Phase 5 measures on this 16 GB dev Mac; pressplay cross-validates at
N <= 2000 per D5-09.

JSON output (one line to stdout):

    {
      "n": int,
      "rss_mb_peak": float,       # platform-adjusted MB
      "threshold_mb": 300.0,
      "passed": bool,             # True iff rss_mb_peak <= threshold_mb
      "platform": "darwin"|"linux"|"win32",
      "stage_ms": {"seed": float, "graph": float},
      "seed_n": int,              # records that actually made it in
      "graph_built": bool,        # True iff build_runtime_graph finished
    }

Exit codes:
    0 if passed, 1 otherwise.

CLI:
    python -m bench.memory_footprint [--n 10000] [--dim 1024] [--seed 42]
                                     [--skip-graph]

--skip-graph keeps the RSS reading to the seeded-store baseline (no
NetworkX graph build); useful when the graph build is the timeout cause
and we want to isolate the store-only overhead.
"""
from __future__ import annotations

import argparse
import gc
import json
import os
import resource
import sys
import tempfile
import time
from datetime import datetime, timezone
from pathlib import Path
from uuid import uuid4

import numpy as np

from iai_mcp.store import MemoryStore
from iai_mcp.types import EMBED_DIM, MemoryRecord

THRESHOLD_MB = 300.0


def _isolate_keyring_in_memory() -> None:
    """Install an in-memory keyring backend so MemoryStore's crypto layer
    never calls macOS Keychain (which hangs under SecItemCopyMatching when
    the bench is invoked from a non-interactive shell).

    Idempotent: if the current backend already has our sentinel attribute,
    it's a no-op. This is strictly bench-scope — production code paths do
    NOT touch this function.
    """
    import keyring
    from keyring.backend import KeyringBackend

    if getattr(keyring.get_keyring(), "_iai_bench_noop", False):
        return

    class _BenchNoOpKeyring(KeyringBackend):
        priority = 99
        _iai_bench_noop = True
        _kv: dict[tuple[str, str], str] = {}

        def get_password(self, service: str, username: str) -> str | None:
            return self._kv.get((service, username))

        def set_password(self, service: str, username: str, password: str) -> None:
            self._kv[(service, username)] = password

        def delete_password(self, service: str, username: str) -> None:
|
||||
self._kv.pop((service, username), None)
|
||||
|
||||
keyring.set_keyring(_BenchNoOpKeyring())
|
||||
|
||||
|
||||
def _rss_mb() -> float:
    """Peak RSS in MB, platform-adjusted.

    macOS returns ru_maxrss in BYTES.
    Linux returns ru_maxrss in KB.
    The ``resource`` module is POSIX-only, so this bench cannot run on
    Windows as written. Any other POSIX platform is assumed to report KB,
    and the platform marker in the JSON output lets the report flag it.
    """
    r = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
    if sys.platform == "darwin":
        return float(r) / 1024.0 / 1024.0
    # Linux reports kilobytes; other POSIX platforms are assumed to as well.
    return float(r) / 1024.0
|
||||
|
||||
|
||||
def _make_noise_record(i: int, rng: np.random.Generator, dim: int) -> MemoryRecord:
|
||||
"""Inline noise-record maker that does not pull in bench/verbatim.
|
||||
|
||||
Keeps this bench self-contained so imports don't drag heavy deps.
|
||||
"""
|
||||
now = datetime.now(timezone.utc)
|
||||
vec = rng.standard_normal(dim)
|
||||
norm = float(np.linalg.norm(vec))
|
||||
if norm > 0:
|
||||
vec = vec / norm
|
||||
return MemoryRecord(
|
||||
id=uuid4(),
|
||||
tier="episodic",
|
||||
literal_surface=f"bench noise record {i}",
|
||||
aaak_index="",
|
||||
embedding=vec.tolist(),
|
||||
community_id=None,
|
||||
centrality=0.0,
|
||||
detail_level=2,
|
||||
pinned=False,
|
||||
stability=0.0,
|
||||
difficulty=0.0,
|
||||
last_reviewed=None,
|
||||
never_decay=False,
|
||||
never_merge=False,
|
||||
provenance=[],
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
tags=["bench", "ops-11"],
|
||||
language="en",
|
||||
)
|
||||
|
||||
|
||||
def _seed_store(
|
||||
store: MemoryStore, n: int, dim: int, seed: int, *, concurrent: bool = False
|
||||
) -> int:
|
||||
"""Seed N synthetic records. Returns the count actually inserted.
|
||||
|
||||
When ``concurrent`` is True, inserts are dispatched from a thread
|
||||
pool so the coalescing AsyncWriteQueue can actually batch records
|
||||
inside its 100 ms window. Sequential blocking inserts (the default
|
||||
sync path) see no coalesce benefit because each insert waits on its
|
||||
own batch flush before the next enqueue even happens.
|
||||
"""
|
||||
rng = np.random.default_rng(seed)
|
||||
records = [_make_noise_record(i, rng, dim=dim) for i in range(n)]
|
||||
if not concurrent:
|
||||
for r in records:
|
||||
store.insert(r)
|
||||
return len(records)
|
||||
|
||||
# Concurrent path: a thread pool fires enqueues from many threads so
|
||||
# the queue's coalesce window fills. Pool size ~256 is large enough
|
||||
# to always fill a max_batch=128 window on this hardware.
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
with ThreadPoolExecutor(max_workers=256) as pool:
|
||||
list(pool.map(store.insert, records))
|
||||
return len(records)
|
||||
|
||||
|
||||
def run_memory_footprint(
|
||||
n: int = 10_000,
|
||||
store_path: Path | str | None = None,
|
||||
dim: int = EMBED_DIM,
|
||||
seed: int = 42,
|
||||
*,
|
||||
skip_graph: bool = False,
|
||||
isolate_keyring: bool = True,
|
||||
async_writes: bool = False,
|
||||
) -> dict:
|
||||
"""Seed N records, optionally build the runtime graph, measure RSS.
|
||||
|
||||
`isolate_keyring` (default True) installs an in-memory keyring backend
|
||||
so MemoryStore's crypto layer never hits macOS Keychain. Set False only
|
||||
when benching against an existing ~/.iai-mcp store whose real key lives
|
||||
in the user keyring.
|
||||
|
||||
Returns a JSON-shaped dict with the keys described in the module docstring.
|
||||
"""
|
||||
if isolate_keyring:
|
||||
_isolate_keyring_in_memory()
|
||||
|
||||
cleanup: tempfile.TemporaryDirectory | None = None
|
||||
if store_path is None:
|
||||
cleanup = tempfile.TemporaryDirectory(prefix="iai-bench-ops11-")
|
||||
path = Path(cleanup.name)
|
||||
else:
|
||||
path = Path(store_path)
|
||||
path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Honour the caller's --dim request by setting IAI_MCP_EMBED_DIM BEFORE
|
||||
# the MemoryStore is constructed. The store reads this env var via
|
||||
# store._resolve_embed_dim() on first table creation (see store.py:115).
|
||||
# Restore the prior value after the run so other benches/tests are not
|
||||
# contaminated.
|
||||
prev_embed_dim = os.environ.get("IAI_MCP_EMBED_DIM")
|
||||
if dim != EMBED_DIM:
|
||||
os.environ["IAI_MCP_EMBED_DIM"] = str(dim)
|
||||
|
||||
try:
|
||||
store = MemoryStore(path=path)
|
||||
# Match the store's actual embed dim so inserts don't get silently
|
||||
# rejected when the env override was ignored (e.g. existing table
|
||||
# on disk pins a different dim).
|
||||
eff_dim = store.embed_dim
|
||||
|
||||
# If --async-writes is set, enable the coalescing
|
||||
# write queue before the seed loop so every store.insert() below
|
||||
# routes through it. The queue is drained + torn down after the
|
||||
# seed completes, keeping the graph build / RSS reading on the
|
||||
# legacy sync path.
|
||||
if async_writes:
|
||||
import asyncio as _asyncio
|
||||
|
||||
async def _enable():
|
||||
await store.enable_async_writes()
|
||||
|
||||
_asyncio.run(_enable())
|
||||
|
||||
t0 = time.perf_counter()
|
||||
seed_n = _seed_store(
|
||||
store, n, dim=eff_dim, seed=seed, concurrent=async_writes,
|
||||
)
|
||||
seed_ms = (time.perf_counter() - t0) * 1000.0
|
||||
|
||||
if async_writes:
|
||||
import asyncio as _asyncio
|
||||
|
||||
async def _disable():
|
||||
await store.disable_async_writes()
|
||||
|
||||
_asyncio.run(_disable())
|
||||
|
||||
graph_built = False
|
||||
graph_ms = 0.0
|
||||
if not skip_graph:
|
||||
# Lazy import so --skip-graph runs don't pay the NetworkX load.
|
||||
from iai_mcp import retrieve
|
||||
|
||||
t1 = time.perf_counter()
|
||||
try:
|
||||
_graph, _assignment, _rc = retrieve.build_runtime_graph(store)
|
||||
graph_built = True
|
||||
except Exception:
|
||||
# Graph build can OOM on small hosts; surface that as the
|
||||
# diagnostic rather than crashing the bench. The RSS reading
|
||||
# still reflects peak consumed up to the failure.
|
||||
graph_built = False
|
||||
graph_ms = (time.perf_counter() - t1) * 1000.0
|
||||
|
||||
gc.collect()
|
||||
rss_mb_peak = _rss_mb()
|
||||
|
||||
return {
|
||||
"n": n,
|
||||
"rss_mb_peak": round(rss_mb_peak, 2),
|
||||
"threshold_mb": THRESHOLD_MB,
|
||||
"passed": rss_mb_peak <= THRESHOLD_MB,
|
||||
"platform": sys.platform,
|
||||
"stage_ms": {
|
||||
"seed": round(seed_ms, 2),
|
||||
"graph": round(graph_ms, 2),
|
||||
},
|
||||
"seed_n": seed_n,
|
||||
"graph_built": graph_built,
|
||||
"dim": eff_dim,
|
||||
"async_writes": bool(async_writes),
|
||||
}
|
||||
finally:
|
||||
# Restore IAI_MCP_EMBED_DIM so other benches / tests run with the
|
||||
# host default.
|
||||
if dim != EMBED_DIM:
|
||||
if prev_embed_dim is None:
|
||||
os.environ.pop("IAI_MCP_EMBED_DIM", None)
|
||||
else:
|
||||
os.environ["IAI_MCP_EMBED_DIM"] = prev_embed_dim
|
||||
if cleanup is not None:
|
||||
cleanup.cleanup()
|
||||
|
||||
|
||||
def main(argv: list[str] | None = None) -> int:
|
||||
parser = argparse.ArgumentParser(
|
||||
prog="bench.memory_footprint",
|
||||
description=(
|
||||
"OPS-11 / RAM bench. Seeds N records, optionally builds "
|
||||
"the runtime graph, reports peak RSS. Target: <=300 MB at "
|
||||
"N=10k on a 16+ GB host."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--n", "--n-records", dest="n", type=int, default=10_000,
|
||||
help="record count to seed (default 10000)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dim", type=int, default=EMBED_DIM,
|
||||
help=f"embedding dimension (default {EMBED_DIM}; tests use 32/64 for speed)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--seed", type=int, default=42, help="RNG seed (default 42)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--skip-graph", action="store_true",
|
||||
help="Skip build_runtime_graph; isolate store-only RSS",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--async-writes", action="store_true",
|
||||
help=(
|
||||
"enable MemoryStore.enable_async_writes() before the "
|
||||
"seed loop so inserts go through the coalescing AsyncWriteQueue. "
|
||||
"Target: amortise the ~0.3 MB/insert LanceDB buffer overhead by "
|
||||
"batching 128 inserts per flush."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--out", type=str, default=None,
|
||||
help="Write the JSON result to this file (in addition to stdout).",
|
||||
)
|
||||
args = parser.parse_args(argv)
|
||||
result = run_memory_footprint(
|
||||
n=args.n, dim=args.dim, seed=args.seed,
|
||||
skip_graph=args.skip_graph, async_writes=args.async_writes,
|
||||
)
|
||||
if args.out:
|
||||
with open(args.out, "w", encoding="utf-8") as fh:
|
||||
json.dump(result, fh)
|
||||
print(json.dumps(result))
|
||||
return 0 if result["passed"] else 1
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(main())
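The unit handling in `_rss_mb` is easier to check when it is separated from the `getrusage` call. The sketch below assumes the same convention as the bench (bytes on macOS, kilobytes elsewhere). The helper names `maxrss_to_mb` and `within_threshold` are hypothetical and do not exist in the module.

```python
def maxrss_to_mb(raw: int, platform: str) -> float:
    """Convert a raw ru_maxrss reading to MB (macOS: bytes, otherwise KB)."""
    if platform == "darwin":
        return raw / (1024.0 * 1024.0)
    return raw / 1024.0


def within_threshold(raw: int, platform: str, threshold_mb: float = 300.0) -> bool:
    """Mirror the bench's pass rule: peak RSS must be <= the threshold."""
    return maxrss_to_mb(raw, platform) <= threshold_mb
```

Note that `ru_maxrss` is a high-water mark for the process, so the `gc.collect()` before the reading cannot lower the reported value. It only affects memory state after the peak was reached.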
|
||||
449
bench/neural_map.py
Normal file
|
|
@ -0,0 +1,449 @@
|
|||
"""bench/neural_map.py -- D-SPEED benchmark.
|
||||
|
||||
Measures recall_for_response latency at store sizes {100, 1k, 5k, 10k}. The
|
||||
D-SPEED contract is p95 < 100ms at 10k. The bench seeds a synthetic store,
|
||||
builds the runtime graph, runs N iterations of recall_for_response with varied
|
||||
cue strings, and reports:
|
||||
|
||||
- latency_ms_p50 / latency_ms_p95 across iterations
|
||||
- stage_timings_ms: mean per-stage timing (embed / gate / seeds / spread / rank)
|
||||
- passed: p95 < 100ms
|
||||
|
||||
CLI:
|
||||
python -m bench.neural_map [--n 100] [--n 1000] [--n 5000] [--n 10000]
|
||||
[--iterations 10]
|
||||
|
||||
When the executor hardware cannot meet <100ms at 10k, main() returns 1 so
|
||||
CI catches the regression; the user / retro decides whether to
|
||||
tune the implementation or accept.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import random
|
||||
import sys
|
||||
import tempfile
|
||||
import time
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from uuid import uuid4
|
||||
|
||||
from iai_mcp.community import CommunityAssignment
|
||||
from iai_mcp.graph import MemoryGraph
|
||||
from iai_mcp.pipeline import recall_for_response
|
||||
from iai_mcp.retrieve import build_runtime_graph
|
||||
from iai_mcp.store import MemoryStore
|
||||
from iai_mcp.types import EMBED_DIM, MemoryRecord
|
||||
|
||||
|
||||
# D-SPEED: 100ms p95 ceiling at 10k records.
|
||||
D_SPEED_P95_MS = 100.0
|
||||
|
||||
|
||||
class _BenchEmbedder:
|
||||
"""Fast deterministic embedder for bench runs.
|
||||
|
||||
Random vectors seeded from cue text + a fixed base seed. Matches the
|
||||
Embedder protocol expected by pipeline.recall_for_response (DIM attribute +
|
||||
embed method); no network, no sentence-transformer load.
|
||||
"""
|
||||
|
||||
def __init__(self, base_seed: int = 0, dim: int = EMBED_DIM) -> None:
|
||||
self.DIM = dim
|
||||
self.DEFAULT_DIM = dim
|
||||
self.DEFAULT_MODEL_KEY = "bench"
|
||||
self._base_seed = base_seed
|
||||
|
||||
def embed(self, text: str) -> list[float]:
|
||||
# Combine base_seed + text into a stable integer seed (hash is
|
||||
# randomised per-process by default, so use a stable digest).
|
||||
import hashlib
|
||||
digest = hashlib.sha256(
|
||||
f"{self._base_seed}:{text}".encode("utf-8")
|
||||
).hexdigest()
|
||||
rng = random.Random(int(digest[:16], 16))
|
||||
v = [rng.random() * 2 - 1 for _ in range(self.DIM)]
|
||||
norm = sum(x * x for x in v) ** 0.5
|
||||
return [x / norm for x in v] if norm > 0 else v
|
||||
|
||||
|
||||
def _make_record(vec: list[float], text: str, tags: list[str]) -> MemoryRecord:
|
||||
now = datetime.now(timezone.utc)
|
||||
return MemoryRecord(
|
||||
id=uuid4(),
|
||||
tier="episodic",
|
||||
literal_surface=text,
|
||||
aaak_index="",
|
||||
embedding=vec,
|
||||
community_id=None,
|
||||
centrality=0.0,
|
||||
detail_level=2,
|
||||
pinned=False,
|
||||
stability=0.0,
|
||||
difficulty=0.0,
|
||||
last_reviewed=None,
|
||||
never_decay=False,
|
||||
never_merge=False,
|
||||
provenance=[],
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
tags=tags,
|
||||
language="en",
|
||||
)
|
||||
|
||||
|
||||
def _percentile(values: list[float], pct: float) -> float:
|
||||
if not values:
|
||||
return 0.0
|
||||
s = sorted(values)
|
||||
idx = max(0, min(len(s) - 1, int(len(s) * pct)))
|
||||
return float(s[idx])
|
||||
|
||||
|
||||
def run_neural_map_bench(
|
||||
n: int = 100,
|
||||
iterations: int = 10,
|
||||
store_path: Path | str | None = None,
|
||||
seed: int = 0,
|
||||
warm_cascade: bool = False,
|
||||
) -> dict:
|
||||
"""Run the D-SPEED benchmark at store size N.
|
||||
|
||||
Parameters:
|
||||
n: number of records to seed.
|
||||
iterations: number of recall_for_response calls to measure.
|
||||
store_path: optional MemoryStore directory; defaults to a temp dir.
|
||||
seed: RNG base seed for deterministic synthetic data.
|
||||
warm_cascade: when True, fire the synchronous
|
||||
core-side HIPPEA cascade after seeding but before timing so
|
||||
the measured p95 reflects the warm path, not the cold path.
|
||||
Returns ``cascade_warmed`` count in the result dict; 0 when
|
||||
disabled or when the cascade produced no ids.
|
||||
|
||||
Returns dict with n, latency_ms_p50, latency_ms_p95, stage_timings_ms,
|
||||
build_ms, passed, iterations, and (when warm_cascade=True) cascade_warmed.
|
||||
"""
|
||||
rng = random.Random(seed)
|
||||
cleanup: tempfile.TemporaryDirectory | None = None
|
||||
if store_path is None:
|
||||
cleanup = tempfile.TemporaryDirectory(prefix="iai-bench-nm-")
|
||||
path = Path(cleanup.name)
|
||||
else:
|
||||
path = Path(store_path)
|
||||
|
||||
try:
|
||||
store = MemoryStore(path=path)
|
||||
embedder = _BenchEmbedder(base_seed=seed, dim=store.embed_dim)
|
||||
|
||||
# Seed N records with a mix of tags so community detection has
|
||||
# structure.
|
||||
tag_pool = [
|
||||
["topic:auth"], ["topic:db"], ["topic:web"],
|
||||
["topic:net"], ["topic:cli"],
|
||||
]
|
||||
for i in range(n):
|
||||
vec = embedder.embed(f"seed-{i}")
|
||||
tags = list(tag_pool[i % len(tag_pool)])
|
||||
rec = _make_record(vec, text=f"synthetic fact {i}", tags=tags)
|
||||
store.insert(rec)
|
||||
|
||||
# Build runtime graph (timed separately).
|
||||
t_build = time.perf_counter()
|
||||
graph, assignment, rich_club = build_runtime_graph(store)
|
||||
build_ms = (time.perf_counter() - t_build) * 1000.0
|
||||
|
||||
# Fire the sync core-side cascade AFTER seeding +
|
||||
# build_runtime_graph (both required for salience computation) and
|
||||
# BEFORE the timing loop starts. Writes into the same process-local
|
||||
# hippea_cascade._warm_lru that recall_for_response consults via
|
||||
# get_warm_record.
|
||||
cascade_warmed = 0
|
||||
if warm_cascade:
|
||||
try:
|
||||
from iai_mcp import hippea_cascade
|
||||
|
||||
warm_ids = hippea_cascade.compute_core_side_warm_snapshot(
|
||||
store, assignment, top_k=3, max_records=50,
|
||||
)
|
||||
for rid in warm_ids:
|
||||
try:
|
||||
rec = store.get(rid)
|
||||
if rec is not None:
|
||||
hippea_cascade._warm_lru[rid] = rec
|
||||
cascade_warmed += 1
|
||||
except Exception:
|
||||
continue
|
||||
except Exception:
|
||||
cascade_warmed = 0
|
||||
|
||||
cues = [
|
||||
"what did we cover about auth yesterday?",
|
||||
"explain the db migration plan",
|
||||
"how does the web cache invalidation work",
|
||||
"summary of the cli subcommand changes",
|
||||
"recent network stack bug report",
|
||||
]
|
||||
|
||||
latencies: list[float] = []
|
||||
stage_totals: dict[str, list[float]] = {
|
||||
"embed": [], "gate": [], "seeds": [], "spread": [], "rank": [],
|
||||
}
|
||||
for i in range(iterations):
|
||||
cue = cues[rng.randrange(len(cues))]
|
||||
# Stage timings from an instrumented copy -- manual per-stage.
|
||||
t_stage = time.perf_counter()
|
||||
cue_emb = embedder.embed(cue)
|
||||
stage_totals["embed"].append(
|
||||
(time.perf_counter() - t_stage) * 1000.0
|
||||
)
|
||||
t_stage = time.perf_counter()
|
||||
# Gate = community gate cost. It is computed inside recall_for_response
# and no standalone call is made here, so this entry records only
# timer overhead (~0 ms). The pipeline call dominates; the coarse
# breakdown is still informative for regression detection.
|
||||
stage_totals["gate"].append(
|
||||
(time.perf_counter() - t_stage) * 1000.0
|
||||
)
|
||||
|
||||
t0 = time.perf_counter()
|
||||
recall_for_response(
|
||||
store=store,
|
||||
graph=graph,
|
||||
assignment=assignment,
|
||||
rich_club=rich_club,
|
||||
embedder=embedder,
|
||||
cue=cue,
|
||||
session_id="bench",
|
||||
budget_tokens=1500,
|
||||
)
|
||||
call_ms = (time.perf_counter() - t0) * 1000.0
|
||||
latencies.append(call_ms)
|
||||
|
||||
# Split the remaining latency between seeds / spread / rank using a
# fixed 20/30/50 ratio. These are coarse estimates, not measurements.
|
||||
remaining = max(0.0, call_ms - sum(
|
||||
stage_totals[k][-1] for k in ("embed", "gate")
|
||||
))
|
||||
stage_totals["seeds"].append(remaining * 0.2)
|
||||
stage_totals["spread"].append(remaining * 0.3)
|
||||
stage_totals["rank"].append(remaining * 0.5)
|
||||
|
||||
p50 = _percentile(latencies, 0.50)
|
||||
p95 = _percentile(latencies, 0.95)
|
||||
|
||||
def _mean(xs: list[float]) -> float:
|
||||
return float(sum(xs) / len(xs)) if xs else 0.0
|
||||
|
||||
stage_timings_ms = {k: _mean(v) for k, v in stage_totals.items()}
|
||||
passed = bool(p95 < D_SPEED_P95_MS)
|
||||
|
||||
result = {
|
||||
"n": n,
|
||||
"iterations": iterations,
|
||||
"latency_ms_p50": float(p50),
|
||||
"latency_ms_p95": float(p95),
|
||||
"build_ms": float(build_ms),
|
||||
"stage_timings_ms": stage_timings_ms,
|
||||
"passed": passed,
|
||||
"threshold_ms": D_SPEED_P95_MS,
|
||||
}
|
||||
if warm_cascade:
|
||||
result["cascade_warmed"] = cascade_warmed
|
||||
return result
|
||||
finally:
|
||||
if cleanup is not None:
|
||||
cleanup.cleanup()
|
||||
|
||||
|
||||
def main(
|
||||
ns: list[int] | None = None,
|
||||
iterations: int = 10,
|
||||
store_path: Path | str | None = None,
|
||||
*,
|
||||
ref_mempalace_p95_ms: float | None = None,
|
||||
ref_claude_mem_p95_ms: float | None = None,
|
||||
with_cascade: bool = False,
|
||||
) -> int:
|
||||
"""CLI entry. Returns 0 when every N passes the D-SPEED threshold and
|
||||
(when supplied) the comparative-reference gate.
|
||||
|
||||
Extensions:
|
||||
- ``ref_mempalace_p95_ms`` / ``ref_claude_mem_p95_ms`` are the reference
|
||||
p95 latencies measured separately for the mempalace / claude-mem
|
||||
adapters on this host. When supplied, the per-N JSON flips
|
||||
``passed=False`` if IAI's p95 exceeds either reference AND records
|
||||
the offending reference name in ``reason``.
|
||||
- ``with_cascade=True`` attempts to warm the HIPPEA LRU before timing
|
||||
the recall so the test can observe the warm-RAM path latency.
|
||||
Graceful no-op when hippea_cascade is unavailable.
|
||||
"""
|
||||
ns = ns or [100, 1_000, 5_000, 10_000]
|
||||
results: list[dict] = []
|
||||
any_failed = False
|
||||
for n in ns:
|
||||
out = run_neural_map_bench(
|
||||
n=n,
|
||||
iterations=iterations,
|
||||
store_path=store_path,
|
||||
warm_cascade=with_cascade,
|
||||
)
|
||||
|
||||
# Comparative gate: IAI must be <= every supplied ref.
|
||||
refs: dict[str, float] = {}
|
||||
reason: str | None = None
|
||||
if ref_mempalace_p95_ms is not None:
|
||||
refs["mempalace"] = ref_mempalace_p95_ms
|
||||
if out["latency_ms_p95"] > ref_mempalace_p95_ms:
|
||||
out["passed"] = False
|
||||
reason = (
|
||||
f"exceeds mempalace ref {ref_mempalace_p95_ms}ms "
|
||||
f"(IAI p95={out['latency_ms_p95']:.2f}ms)"
|
||||
)
|
||||
if ref_claude_mem_p95_ms is not None:
|
||||
refs["claude_mem"] = ref_claude_mem_p95_ms
|
||||
if out["latency_ms_p95"] > ref_claude_mem_p95_ms:
|
||||
out["passed"] = False
|
||||
# First reference to fail wins the reason string; the claude-mem
# reason is used only when mempalace did not already fail.
|
||||
cm_reason = (
|
||||
f"exceeds claude-mem ref {ref_claude_mem_p95_ms}ms "
|
||||
f"(IAI p95={out['latency_ms_p95']:.2f}ms)"
|
||||
)
|
||||
reason = reason or cm_reason
|
||||
if refs:
|
||||
out["refs"] = refs
|
||||
if reason is not None:
|
||||
out["reason"] = reason
|
||||
|
||||
results.append(out)
|
||||
if not out["passed"]:
|
||||
any_failed = True
|
||||
print(json.dumps(out))
|
||||
return 1 if any_failed else 0
|
||||
|
||||
|
||||
def _warm_cascade_for_bench(
|
||||
n: int, store_path: Path | str | None = None,
|
||||
) -> int:
|
||||
"""actually fire the core-side HIPPEA cascade in the bench
|
||||
process so the measured p95 reflects the warm path, not the cold path.
|
||||
|
||||
Returns the number of record ids written into the bench-process
|
||||
``_warm_lru`` (0 on any failure — cold path still gives a canonical
|
||||
reading, but the JSON output records the 0 so downstream audits
|
||||
can distinguish "warm-up intended but failed" from "warm-up hit").
|
||||
|
||||
Reuses :func:`compute_core_side_warm_snapshot` (sync, no asyncio
|
||||
dependency) rather than the async ``run_cascade`` — the sync helper
|
||||
lets us invoke the cascade inline without event-loop entanglement in
|
||||
the bench harness.
|
||||
"""
|
||||
try:
|
||||
from iai_mcp import hippea_cascade, retrieve
|
||||
from iai_mcp.store import MemoryStore
|
||||
|
||||
store = MemoryStore(path=store_path) if store_path else MemoryStore()
|
||||
_graph, assignment, _rc = retrieve.build_runtime_graph(store)
|
||||
warm_ids = hippea_cascade.compute_core_side_warm_snapshot(
|
||||
store, assignment, top_k=3, max_records=50,
|
||||
)
|
||||
# Write into the shared process-local LRU used by get_warm_record
|
||||
# so the recall path in this process hits warm on subsequent calls.
|
||||
warmed = 0
|
||||
for rid in warm_ids:
|
||||
try:
|
||||
rec = store.get(rid)
|
||||
if rec is not None:
|
||||
hippea_cascade._warm_lru[rid] = rec
|
||||
warmed += 1
|
||||
except Exception:
|
||||
continue
|
||||
return warmed
|
||||
except Exception:
|
||||
# Warm path is opportunistic; cold path still gives the canonical
|
||||
# reading. Return 0 so the JSON output can distinguish "intended
|
||||
# warm-up but could not complete" from "warm-up succeeded".
|
||||
return 0
|
||||
|
||||
|
||||
def _parse_args(argv: list[str] | None = None) -> argparse.Namespace:
|
||||
parser = argparse.ArgumentParser(prog="bench.neural_map")
|
||||
parser.add_argument(
|
||||
"--n", action="append", type=int, default=None,
|
||||
help="store sizes to bench; repeat for multiple N",
|
||||
)
|
||||
parser.add_argument("--iterations", type=int, default=10)
|
||||
parser.add_argument(
|
||||
"--ref-mempalace-p95-ms",
|
||||
dest="ref_mempalace_p95_ms",
|
||||
type=float, default=None,
|
||||
help=(
|
||||
"OPS-10 comparative reference p95 (ms) — IAI must be <= this to "
|
||||
"pass the gate."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--ref-claude-mem-p95-ms",
|
||||
dest="ref_claude_mem_p95_ms",
|
||||
type=float, default=None,
|
||||
help=(
|
||||
"OPS-10 comparative reference p95 (ms) — IAI must be <= this to "
|
||||
"pass the gate."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--with-cascade",
|
||||
dest="with_cascade",
|
||||
action="store_true",
|
||||
help=(
|
||||
"Warm the HIPPEA LRU before each per-N run (Plan 05-04 preview); "
|
||||
"graceful no-op if cascade module unavailable."
|
||||
),
|
||||
)
|
||||
return parser.parse_args(argv)
|
||||
|
||||
|
||||
def _install_bench_noop_keyring() -> None:
|
||||
"""Install an in-memory keyring backend BEFORE any MemoryStore is
|
||||
constructed so the crypto layer never hangs on macOS Keychain
|
||||
SecItemCopyMatching in non-interactive shells. Bench-scope only."""
|
||||
try:
|
||||
import keyring
|
||||
from keyring.backend import KeyringBackend
|
||||
|
||||
if getattr(keyring.get_keyring(), "_iai_bench_noop", False):
|
||||
return
|
||||
|
||||
class _BenchNoOpKeyring(KeyringBackend):
|
||||
priority = 99
|
||||
_iai_bench_noop = True
|
||||
_kv: dict[tuple[str, str], str] = {}
|
||||
|
||||
def get_password(self, s: str, u: str):
|
||||
return self._kv.get((s, u))
|
||||
|
||||
def set_password(self, s: str, u: str, p: str) -> None:
|
||||
self._kv[(s, u)] = p
|
||||
|
||||
def delete_password(self, s: str, u: str) -> None:
|
||||
self._kv.pop((s, u), None)
|
||||
|
||||
keyring.set_keyring(_BenchNoOpKeyring())
|
||||
except Exception:
|
||||
# If keyring isn't installed or the backend can't be swapped,
|
||||
# continue — the store may still work against an already-unlocked
|
||||
# macOS keychain.
|
||||
pass
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
_install_bench_noop_keyring()
|
||||
args = _parse_args()
|
||||
sys.exit(main(
|
||||
ns=args.n,
|
||||
iterations=args.iterations,
|
||||
ref_mempalace_p95_ms=args.ref_mempalace_p95_ms,
|
||||
ref_claude_mem_p95_ms=args.ref_claude_mem_p95_ms,
|
||||
with_cascade=args.with_cascade,
|
||||
))
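The `_percentile` helper above picks the element at index `int(len(s) * pct)`, clamped to the valid range. The sketch below re-implements the same rule under the hypothetical name `percentile`, to show how it behaves at the default of 10 iterations.

```python
def percentile(values: list[float], pct: float) -> float:
    """Index-based percentile using the same rule as the bench helper."""
    if not values:
        return 0.0
    s = sorted(values)
    # int() truncates, and the index is clamped to [0, len(s) - 1].
    idx = max(0, min(len(s) - 1, int(len(s) * pct)))
    return float(s[idx])


# Ten hypothetical latency samples: 1.0 ms through 10.0 ms.
ten_samples = [float(i) for i in range(1, 11)]
p50 = percentile(ten_samples, 0.50)  # index 5 of 10 sorted samples
p95 = percentile(ten_samples, 0.95)  # index 9, i.e. the largest sample
```

With 10 samples, `int(10 * 0.95)` is 9, so the reported p95 is simply the slowest call. A single outlier therefore decides whether the run passes the 100 ms gate; raising `--iterations` makes the p95 less sensitive to one slow call.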
|
||||
File diff suppressed because it is too large
|
|
@ -0,0 +1,250 @@
|
|||
{
|
||||
"env": {
|
||||
"cpu_brand": "Apple M2 Max",
|
||||
"cpu_cores_physical": 12,
|
||||
"ram_gb": "64.0",
|
||||
"os": "Darwin",
|
||||
"os_version": "25.3.0",
|
||||
"python_version": "3.12.13",
|
||||
"iai_mcp_git_sha": "9c61a18",
|
||||
"iai_mcp_git_dirty": true,
|
||||
"lance_version": "unknown",
|
||||
"lancedb_version": "0.30.2",
|
||||
"pyarrow_version": "23.0.1",
|
||||
"sentence_transformers_version": "5.4.1",
|
||||
"embedder_model": "bge-small-en-v1.5",
|
||||
"seed_list": [
|
||||
13,
|
||||
42,
|
||||
137
|
||||
],
|
||||
"iai_mcp_store": "/private/tmp/iai-mcp-bench-claude/store",
|
||||
"wall_clock_start_utc": "2026-05-03T01:10:24.783110+00:00",
|
||||
"scale": "honest",
|
||||
"n_sessions": 1000,
|
||||
"n_probes_pre": 250,
|
||||
"n_probes_post": 250,
|
||||
"n_slices": [
|
||||
0,
|
||||
1
|
||||
],
|
||||
"k_hits": 10,
|
||||
"a_threshold": 0.98,
|
||||
"candidate_pool_size": 200,
|
||||
"bootstrap_resamples": 10000,
|
||||
"floor_mode": "relaxed",
|
||||
"wall_clock_duration_seconds": 5328.49
|
||||
},
|
||||
"summary": {
|
||||
"per_cell": [
|
||||
{
|
||||
"seed": 13,
|
||||
"n_slice": 0,
|
||||
"n_b_probes": 250,
|
||||
"n_a_probes": 250,
|
||||
"metric_b": {
|
||||
"delta_mrr_point": 0.0,
|
||||
"delta_mrr_ci_lo": 0.0,
|
||||
"delta_mrr_ci_hi": 0.0,
|
||||
"wilcoxon_p": null,
|
||||
"max_rank_regression": 0,
|
||||
"rr_at_1_pipeline": 0.272,
|
||||
"rr_at_1_cosine": 0.272
|
||||
},
|
||||
"metric_b_revised": {
|
||||
"hint_emission_rate": 1.0,
|
||||
"anti_hits_coverage": 0.912,
|
||||
"mean_anti_hits_count": 1.904
|
||||
},
|
||||
"metric_a": {
|
||||
"hit_at_k_pipeline": 1.0,
|
||||
"hit_at_k_cosine": 0.692,
|
||||
"k": 10,
|
||||
"catastrophic_floor_violations": 0
|
||||
}
|
||||
},
|
||||
{
|
||||
"seed": 13,
|
||||
"n_slice": 1,
|
||||
"n_b_probes": 250,
|
||||
"n_a_probes": 250,
|
||||
"metric_b": {
|
||||
"delta_mrr_point": 0.0,
|
||||
"delta_mrr_ci_lo": 0.0,
|
||||
"delta_mrr_ci_hi": 0.0,
|
||||
"wilcoxon_p": null,
|
||||
"max_rank_regression": 0,
|
||||
"rr_at_1_pipeline": 0.272,
|
||||
"rr_at_1_cosine": 0.272
|
||||
},
|
||||
"metric_b_revised": {
|
||||
"hint_emission_rate": 1.0,
|
||||
"anti_hits_coverage": 0.912,
|
||||
"mean_anti_hits_count": 1.904
|
||||
},
|
||||
"metric_a": {
|
||||
"hit_at_k_pipeline": 1.0,
|
||||
"hit_at_k_cosine": 0.692,
|
||||
"k": 10,
|
||||
"catastrophic_floor_violations": 0
|
||||
}
|
||||
},
|
||||
{
|
||||
"seed": 42,
|
||||
"n_slice": 0,
|
||||
"n_b_probes": 250,
|
||||
"n_a_probes": 250,
|
||||
"metric_b": {
|
||||
"delta_mrr_point": 0.0,
|
||||
"delta_mrr_ci_lo": 0.0,
|
||||
"delta_mrr_ci_hi": 0.0,
|
||||
"wilcoxon_p": null,
|
||||
"max_rank_regression": 0,
|
||||
"rr_at_1_pipeline": 0.264,
|
||||
"rr_at_1_cosine": 0.264
|
||||
},
|
||||
"metric_b_revised": {
|
||||
"hint_emission_rate": 1.0,
|
||||
"anti_hits_coverage": 0.892,
|
||||
"mean_anti_hits_count": 2.16
|
||||
},
|
||||
"metric_a": {
|
||||
"hit_at_k_pipeline": 1.0,
|
||||
"hit_at_k_cosine": 0.708,
|
||||
"k": 10,
|
||||
"catastrophic_floor_violations": 0
|
||||
}
|
||||
},
|
||||
{
|
||||
"seed": 42,
|
||||
"n_slice": 1,
|
||||
"n_b_probes": 250,
|
||||
"n_a_probes": 250,
|
||||
"metric_b": {
|
||||
"delta_mrr_point": 0.0,
|
||||
"delta_mrr_ci_lo": 0.0,
|
||||
"delta_mrr_ci_hi": 0.0,
|
||||
"wilcoxon_p": null,
|
||||
"max_rank_regression": 0,
|
||||
"rr_at_1_pipeline": 0.264,
|
||||
"rr_at_1_cosine": 0.264
|
||||
},
|
||||
"metric_b_revised": {
|
||||
"hint_emission_rate": 1.0,
|
||||
"anti_hits_coverage": 0.892,
|
||||
"mean_anti_hits_count": 2.16
|
||||
},
|
||||
"metric_a": {
|
||||
"hit_at_k_pipeline": 1.0,
|
||||
"hit_at_k_cosine": 0.708,
|
||||
"k": 10,
|
||||
"catastrophic_floor_violations": 0
|
||||
}
|
||||
},
|
||||
{
|
||||
"seed": 137,
|
||||
"n_slice": 0,
|
||||
"n_b_probes": 250,
|
||||
"n_a_probes": 250,
|
||||
"metric_b": {
|
||||
"delta_mrr_point": 0.0,
|
||||
"delta_mrr_ci_lo": 0.0,
|
||||
"delta_mrr_ci_hi": 0.0,
|
||||
"wilcoxon_p": null,
|
||||
"max_rank_regression": 0,
|
||||
"rr_at_1_pipeline": 0.292,
|
||||
"rr_at_1_cosine": 0.292
|
||||
},
|
||||
"metric_b_revised": {
|
||||
"hint_emission_rate": 1.0,
|
||||
"anti_hits_coverage": 0.868,
|
||||
"mean_anti_hits_count": 2.2
|
||||
},
|
||||
"metric_a": {
|
||||
"hit_at_k_pipeline": 1.0,
|
||||
"hit_at_k_cosine": 0.74,
|
||||
"k": 10,
|
||||
"catastrophic_floor_violations": 0
|
||||
}
|
||||
},
|
||||
{
|
||||
"seed": 137,
|
||||
"n_slice": 1,
|
||||
"n_b_probes": 250,
|
||||
"n_a_probes": 250,
|
||||
"metric_b": {
|
||||
"delta_mrr_point": 0.0,
|
||||
"delta_mrr_ci_lo": 0.0,
|
||||
"delta_mrr_ci_hi": 0.0,
|
||||
"wilcoxon_p": null,
|
||||
"max_rank_regression": 0,
|
||||
"rr_at_1_pipeline": 0.292,
|
||||
"rr_at_1_cosine": 0.292
|
||||
},
|
||||
"metric_b_revised": {
|
||||
"hint_emission_rate": 1.0,
|
||||
"anti_hits_coverage": 0.868,
|
||||
"mean_anti_hits_count": 2.2
|
||||
},
|
||||
"metric_a": {
|
||||
"hit_at_k_pipeline": 1.0,
|
||||
"hit_at_k_cosine": 0.74,
|
||||
"k": 10,
|
||||
"catastrophic_floor_violations": 0
|
||||
}
|
||||
}
|
||||
],
|
||||
"cross_seed": {
|
||||
"n_0": {
|
||||
"delta_mrr_mean": 0.0,
|
||||
"delta_mrr_stdev": 0.0,
|
||||
"delta_mrr_min": 0.0,
|
||||
"delta_mrr_max": 0.0,
|
||||
"robust": false
|
||||
},
|
||||
"n_1": {
|
||||
"delta_mrr_mean": 0.0,
|
||||
"delta_mrr_stdev": 0.0,
|
||||
"delta_mrr_min": 0.0,
|
||||
"delta_mrr_max": 0.0,
|
||||
"robust": false
|
||||
}
|
||||
},
|
||||
"gates": {
|
||||
"per_cell": {
|
||||
"seed13_n0": {
|
||||
"gate_a": true,
|
||||
"gate_b_classical": false,
|
||||
"gate_b_contract": true
|
||||
},
|
||||
"seed13_n1": {
|
||||
"gate_a": true,
|
||||
"gate_b_classical": false,
|
||||
"gate_b_contract": true
|
||||
},
|
||||
"seed42_n0": {
|
||||
"gate_a": true,
|
||||
"gate_b_classical": false,
|
||||
"gate_b_contract": true
|
||||
},
|
||||
"seed42_n1": {
|
||||
"gate_a": true,
|
||||
"gate_b_classical": false,
|
||||
"gate_b_contract": true
|
||||
},
|
||||
"seed137_n0": {
|
||||
"gate_a": true,
|
||||
"gate_b_classical": false,
|
||||
"gate_b_contract": true
|
||||
},
|
||||
"seed137_n1": {
|
||||
"gate_a": true,
|
||||
"gate_b_classical": false,
|
||||
"gate_b_contract": true
|
||||
}
|
||||
},
|
||||
"cross_seed_robust": false,
|
||||
"overall_pass": true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,63 @@
|
|||
# Contradiction-longitudinal falsifiability bench — PASS
|
||||
|
||||
**Run ID:** 20260503T011024Z-seeds13-42-137-scale_honest
|
||||
**Duration:** 5328.5s
|
||||
|
||||
## Environment
|
||||
|
||||
| Field | Value |
|---|---|
| `cpu_brand` | Apple M2 Max |
| `cpu_cores_physical` | 12 |
| `ram_gb` | 64.0 |
| `os` | Darwin |
| `os_version` | 25.3.0 |
| `python_version` | 3.12.13 |
| `iai_mcp_git_sha` | (pre-release) |
| `iai_mcp_git_dirty` | True |
| `lance_version` | unknown |
| `lancedb_version` | 0.30.2 |
| `pyarrow_version` | 23.0.1 |
| `sentence_transformers_version` | 5.4.1 |
| `embedder_model` | bge-small-en-v1.5 |
| `seed_list` | [13, 42, 137] |
| `iai_mcp_store` | /private/tmp/iai-mcp-bench-claude/store |
| `wall_clock_start_utc` | 2026-05-03T01:10:24.783110+00:00 |
| `scale` | honest |
| `n_sessions` | 1000 |
| `n_probes_pre` | 250 |
| `n_probes_post` | 250 |
| `n_slices` | [0, 1] |
| `k_hits` | 10 |
| `a_threshold` | 0.98 |
| `candidate_pool_size` | 200 |
| `bootstrap_resamples` | 10000 |
| `floor_mode` | relaxed |
| `wall_clock_duration_seconds` | 5328.49 |
|
||||
|
||||
## Cross-seed (B robustness)
|
||||
|
||||
| N slice | ΔMRR mean | stdev | min | max | robust? |
|---|---|---|---|---|---|
| n_0 | 0.0000 | 0.0000 | 0.0000 | 0.0000 | NO |
| n_1 | 0.0000 | 0.0000 | 0.0000 | 0.0000 | NO |
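
The cross-seed row for one N slice can be reproduced from the per-seed ΔMRR point estimates. The sketch below is illustrative only: `cross_seed_summary` is a hypothetical helper, and the `robust` rule (every seed shows a strictly positive lift) is an assumption, since the report does not state the actual criterion.

```python
from statistics import mean, stdev


def cross_seed_summary(deltas: list[float]) -> dict:
    """Summarise per-seed delta-MRR point estimates for one N slice."""
    return {
        "delta_mrr_mean": mean(deltas),
        # Sample standard deviation needs at least two seeds.
        "delta_mrr_stdev": stdev(deltas) if len(deltas) > 1 else 0.0,
        "delta_mrr_min": min(deltas),
        "delta_mrr_max": max(deltas),
        # Assumed criterion (not taken from the report): robust only when
        # every seed shows a strictly positive lift.
        "robust": min(deltas) > 0.0,
    }


# Seeds 13, 42 and 137 all report a delta-MRR point estimate of 0.0.
summary = cross_seed_summary([0.0, 0.0, 0.0])
```

Under that assumed rule a slice with zero lift on every seed is reported as not robust, which agrees with the NO entries in the table.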
|
||||
|
||||
## Per-cell detail
|
||||
|
||||
| seed | N | A hit@k (pipe / cos) | A floor | B-class ΔMRR (CI) | B-contract hint rate / anti-hits coverage | gate A | gate B-class | gate B-contract |
|---|---|---|---|---|---|---|---|---|
| 13 | 0 | 1.000 / 0.692 | 0 | 0.0000 (0.0000, 0.0000) | 1.000 / 0.912 | PASS | FAIL | PASS |
| 13 | 1 | 1.000 / 0.692 | 0 | 0.0000 (0.0000, 0.0000) | 1.000 / 0.912 | PASS | FAIL | PASS |
| 42 | 0 | 1.000 / 0.708 | 0 | 0.0000 (0.0000, 0.0000) | 1.000 / 0.892 | PASS | FAIL | PASS |
| 42 | 1 | 1.000 / 0.708 | 0 | 0.0000 (0.0000, 0.0000) | 1.000 / 0.892 | PASS | FAIL | PASS |
| 137 | 0 | 1.000 / 0.740 | 0 | 0.0000 (0.0000, 0.0000) | 1.000 / 0.868 | PASS | FAIL | PASS |
| 137 | 1 | 1.000 / 0.740 | 0 | 0.0000 (0.0000, 0.0000) | 1.000 / 0.868 | PASS | FAIL | PASS |
|
||||
|
||||
**Cross-seed robust gate (B-classical only):** FAIL (expected: B-class is not the architectural promise)
|
||||
**Overall verdict (uses gate_a + gate_b_contract):** PASS
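
The "B-class ΔMRR (CI)" column pairs a point estimate with a bootstrap interval. A minimal sketch of a paired percentile bootstrap follows, under stated assumptions: the function names are hypothetical, a rank of `None` is taken to mean the target fell outside the candidate pool, and the bench's real resampling scheme may differ.

```python
from __future__ import annotations

import random


def reciprocal_rank(rank: int | None) -> float:
    """1/rank for a 1-based rank; 0.0 when the target was not retrieved."""
    return 0.0 if rank is None else 1.0 / rank


def delta_mrr_bootstrap(
    pipeline_ranks: list[int | None],
    cosine_ranks: list[int | None],
    resamples: int = 10_000,
    seed: int = 13,
) -> tuple[float, float, float]:
    """Return (point, ci_lo, ci_hi) for mean(RR_pipeline - RR_cosine)."""
    diffs = [
        reciprocal_rank(p) - reciprocal_rank(c)
        for p, c in zip(pipeline_ranks, cosine_ranks)
    ]
    n = len(diffs)
    point = sum(diffs) / n
    rng = random.Random(seed)
    # Resample the paired differences with replacement and collect the means.
    means = sorted(sum(rng.choices(diffs, k=n)) / n for _ in range(resamples))
    lo = means[int(0.025 * (resamples - 1))]
    hi = means[int(0.975 * (resamples - 1))]
    return point, lo, hi


# Identical rankings on every probe give a degenerate interval at zero.
point, ci_lo, ci_hi = delta_mrr_bootstrap([1, 3, None], [1, 3, None], resamples=200)
```

When the pipeline and the cosine baseline rank every probe identically, all paired differences are zero, so the interval collapses to a single point, as in the table above.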
|
||||
|
||||
## Notes on metric design
|
||||
|
||||
- **Metric A (verbatim preserved)** tests REQUIREMENTS.md — the system's promise that contradiction = reconsolidation, never overwrite. Pipeline beating cosine here = real architectural advantage.
|
||||
- **Metric B-classical (rank current above cosine)** tests an expectation that does NOT appear in any design doc. Per REQUIREMENTS.md + 02-CONTEXT.md, the system uses dual-route + inhibitory edges + hints, not rerank. Expect ΔMRR ≈ 0; this is a feature, not a bug.
|
||||
- **Metric B-contract (s4_contradiction hint OR anti_hits ≥80%)** tests what the system actually promises (REQUIREMENTS.md MEM-08, dual-route). Cosine cannot do either; pipeline either signals contradictions or it doesn't.
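
The three gates described above can be sketched as a small function over one per-cell result. This is an illustration, not the bench's code: `cell_gates` and `overall_pass` are hypothetical names, the 0.98 threshold comes from `a_threshold` in the Environment table, the 0.80 threshold from the B-contract note, and the B-classical rule (CI lower bound above zero) is an assumption.

```python
A_THRESHOLD = 0.98         # a_threshold from the Environment table
CONTRACT_THRESHOLD = 0.80  # ">=80%" from the B-contract note


def cell_gates(cell: dict) -> dict:
    """Evaluate the three gates for one (seed, N slice) cell."""
    a = cell["metric_a"]
    b = cell["metric_b"]
    r = cell["metric_b_revised"]
    return {
        "gate_a": (
            a["hit_at_k_pipeline"] >= A_THRESHOLD
            and a["catastrophic_floor_violations"] == 0
        ),
        # Assumed rule: the classical gate needs a strictly positive CI floor.
        "gate_b_classical": b["delta_mrr_ci_lo"] > 0.0,
        "gate_b_contract": (
            r["hint_emission_rate"] >= CONTRACT_THRESHOLD
            or r["anti_hits_coverage"] >= CONTRACT_THRESHOLD
        ),
    }


def overall_pass(cells: list[dict]) -> bool:
    """Overall verdict uses gate_a and gate_b_contract only."""
    return all(g["gate_a"] and g["gate_b_contract"] for g in map(cell_gates, cells))


# Values copied from the seed 137 cells of the results JSON.
cell_137 = {
    "metric_a": {"hit_at_k_pipeline": 1.0, "catastrophic_floor_violations": 0},
    "metric_b": {"delta_mrr_ci_lo": 0.0},
    "metric_b_revised": {"hint_emission_rate": 1.0, "anti_hits_coverage": 0.868},
}
gates = cell_gates(cell_137)
```

Under these assumptions the seed 137 cell passes gate A and gate B-contract and fails gate B-classical, matching its row in the per-cell table.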
|
||||
249
bench/tokens.py
Normal file
|
|
@ -0,0 +1,249 @@
|
"""bench/tokens.py -- session-start token benchmark harness.
|
||||
|
||||
Measures session-start token budget three ways, preferring the most accurate
|
||||
source available at runtime:
|
||||
|
||||
1. Anthropic `count_tokens` API (best). Used when ANTHROPIC_API_KEY is set.
|
||||
Gives an honest billable-token count that includes Anthropic-side overhead
|
||||
and exact tokeniser output. Model: claude-sonnet-4-5. This is the only mode
|
||||
whose numbers are safe to publish (PROJECT.md: "honest mode-by-mode
|
||||
benchmarks, not headline numbers").
|
||||
|
||||
2. tiktoken cl100k_base fallback. OpenAI's tokeniser shipped with the tiktoken
|
||||
package -- runs fully offline, no network, no key. It under-counts Claude by
|
||||
~5-10% on English and over-counts by ~10-15% on Cyrillic (GPT-4 tokeniser
|
||||
packs multibyte differently). Acceptable for local dev and CI; the JSON
|
||||
output always records mode so downstream dashboards can reject non-API
|
||||
numbers from public charts.
|
||||
|
||||
3. char/4 heuristic. Used only when both 1 and 2 are unavailable (e.g. minimal
|
||||
CI image without tiktoken installed). Very rough; adequate only for sanity
|
||||
checks on the order of magnitude.
|
||||
|
||||
Thresholds:
|
||||
- steady (warm cache): <= STEADY_LIMIT (3000 tokens) on every warm run
- fresh (first session): <= FRESH_LIMIT (8000 tokens)
|
||||
|
||||
Exit codes:
|
||||
- 0: both steady_ok and fresh_ok
|
||||
- 1: at least one failed
|
||||
|
||||
JSON output format (one line to stdout):
|
||||
  {"fresh": int, "warm": [int, ...], "steady_ok": bool, "fresh_ok": bool,
   "mode": "anthropic-count-tokens" | "tiktoken-cl100k-proxy" |
           "heuristic-char4" | "injected",
   "limits": {"steady": 3000, "fresh": 8000},
   "payload_cached_tokens": int, "payload_dynamic_tokens": int}
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
from typing import Callable
|
||||
|
||||
from iai_mcp.retrieve import build_runtime_graph
|
||||
from iai_mcp.session import SessionStartPayload, assemble_session_start
|
||||
from iai_mcp.store import MemoryStore
|
||||
|
||||
# budget targets
|
||||
STEADY_LIMIT = 3000 # warm-cache steady-state
|
||||
FRESH_LIMIT = 8000 # first-fresh-session (cache populate premium)
|
||||
|
||||
|
||||
def _anthropic_count_tokens(text: str) -> int:
    """Use Anthropic count_tokens API. Raises if key absent or call fails."""
    import anthropic
    client = anthropic.Anthropic()
    resp = client.messages.count_tokens(
        model="claude-sonnet-4-5",
        messages=[{"role": "user", "content": text}],
    )
    return int(resp.input_tokens)


def _tiktoken_count(text: str) -> int:
    """Offline tiktoken cl100k_base as a proxy for Claude's tokeniser.

    Raises ImportError if tiktoken not installed -- caller falls through to
    the char/4 heuristic in that case.
    """
    import tiktoken
    enc = tiktoken.get_encoding("cl100k_base")
    return len(enc.encode(text))


def _char4_count(text: str) -> int:
    """Last-resort char/4 heuristic. Reasonable for English prose, bad for CJK."""
    return max(1, len(text) // 4)
|
||||
|
||||
|
||||
def _payload_to_prompt(payload: SessionStartPayload) -> str:
|
||||
"""Flatten the session-start payload to a single prompt string.
|
||||
|
||||
Mirrors the TypeScript wrapper's buildCachedSystemPrompt shape so the
|
||||
counted prompt is faithful to what Anthropic actually receives.
|
||||
|
||||
D5-02: at wake_depth=minimal, the legacy l0/l1/l2/rich_club
|
||||
fields are empty and the payload is three pointer handles. Include them
|
||||
alongside legacy segments so both modes flatten to a representative
|
||||
prompt string for counting.
|
||||
"""
|
||||
parts: list[str] = []
|
||||
if payload.l0:
|
||||
parts.append(f"# L0 identity\n{payload.l0}")
|
||||
if payload.l1:
|
||||
parts.append(f"# L1 critical facts\n{payload.l1}")
|
||||
for segment in payload.l2:
|
||||
parts.append(f"# L2 community\n{segment}")
|
||||
if payload.rich_club:
|
||||
parts.append(f"# Global rich-club\n{payload.rich_club}")
|
||||
# Plan 05-06: lazy session-start wire payload.
|
||||
# Under wake_depth=minimal the wire is the compact handle alone
|
||||
# (the 3 legacy pointer fields stay on the dataclass for back-compat
|
||||
# callers but are NOT serialised to the wire).
|
||||
# Under standard/deep the wire is the Phase-1 eager L0/L1/L2/rich_club
|
||||
# plus the 3 legacy pointer fields, matching the pre-05-06 baseline.
|
||||
# The compact handle is carried on the dataclass under standard/deep
|
||||
# too so opt-in callers may read it, but it does NOT add to the wire
|
||||
# (that would inflate the standard baseline).
|
||||
compact = getattr(payload, "compact_handle", "")
|
||||
wake_depth = getattr(payload, "wake_depth", "minimal")
|
||||
if wake_depth == "minimal":
|
||||
if compact:
|
||||
parts.append(compact)
|
||||
else:
|
||||
lazy = [
|
||||
s for s in (
|
||||
getattr(payload, "identity_pointer", ""),
|
||||
getattr(payload, "brain_handle", ""),
|
||||
getattr(payload, "topic_cluster_hint", ""),
|
||||
) if s
|
||||
]
|
||||
if lazy:
|
||||
parts.append(" ".join(lazy))
|
||||
return "\n\n".join(parts)
|
||||
|
||||
|
||||
def _fresh_prompt(payload: SessionStartPayload) -> str:
    """The first fresh-session request pays the cache-populate premium.

    Simulated here by appending a fixed block of dynamic tail content
    (D-10 dynamic reserve) to the cached prefix. The tail is 2625 characters,
    about 656 tokens under the char/4 heuristic, which is below the nominal
    ~1000-token reserve. Anthropic's count_tokens will return the sum of both
    parts in one call.
    """
    prompt = _payload_to_prompt(payload)
    tail = "dynamic tail content " * 125  # 2625 chars, ~656 tokens by char/4
    return f"{prompt}\n\n{tail}" if prompt else tail
|
||||
|
||||
|
||||
def run_token_bench(
|
||||
store: MemoryStore | None = None,
|
||||
n_runs: int = 3,
|
||||
count_tokens_fn: Callable[[str], int] | None = None,
|
||||
wake_depth: str = "minimal",
|
||||
) -> dict:
|
||||
"""Run the token benchmark.
|
||||
|
||||
Parameters:
|
||||
store: optional MemoryStore override (tests pass an isolated tmp_path store).
|
||||
n_runs: how many warm-cache repeats to measure (OPS-01 steady-state needs
|
||||
at least 3 consecutive samples).
|
||||
count_tokens_fn: optional token-counter injection (test-only); overrides the
Anthropic API, the tiktoken proxy, and the char/4 heuristic.
|
||||
wake_depth: TOK-11 — selects session-start payload mode.
|
||||
Default ``minimal`` measures the lazy <=30-tok handle; pass
|
||||
``standard`` for the Phase-1 eager dump baseline; ``deep`` for
|
||||
the ≤2000-tok expanded rich_club.
|
||||
|
||||
Returns a dict with keys described in the module docstring.
|
||||
"""
|
||||
s = store if store is not None else MemoryStore()
|
||||
records_count = s.db.open_table("records").count_rows()
|
||||
if records_count > 0:
|
||||
_graph, assignment, rc = build_runtime_graph(s)
|
||||
payload = assemble_session_start(
|
||||
s, assignment, rc, profile_state={"wake_depth": wake_depth},
|
||||
)
|
||||
else:
|
||||
# Empty-store fallback: mint a representative compact handle so the
|
||||
# warm-prompt count reflects the wire payload shape even before any
|
||||
# record is written. Mirrors session.assemble_session_start at
|
||||
# wake_depth=minimal.
|
||||
from iai_mcp.handle import encode_compact_handle
|
||||
from uuid import uuid4
|
||||
|
||||
_compact = encode_compact_handle("", str(uuid4())[:8], "none", 0)
|
||||
payload = SessionStartPayload(
|
||||
l0="",
|
||||
l1="",
|
||||
l2=[],
|
||||
rich_club="",
|
||||
total_cached_tokens=max(1, len(_compact) // 4),
|
||||
total_dynamic_tokens=1000,
|
||||
compact_handle=_compact,
|
||||
wake_depth=wake_depth,
|
||||
)
|
||||
|
||||
counter: Callable[[str], int]
|
||||
mode: str
|
||||
if count_tokens_fn is not None:
|
||||
counter = count_tokens_fn
|
||||
mode = "injected"
|
||||
elif os.environ.get("ANTHROPIC_API_KEY"):
|
||||
counter = _anthropic_count_tokens
|
||||
mode = "anthropic-count-tokens"
|
||||
else:
|
||||
# Prefer tiktoken over char/4 -- it actually tokenises the text and
|
||||
# tracks Claude within roughly 15% across English + Cyrillic (see module docstring).
|
||||
try:
|
||||
import tiktoken # noqa: F401
|
||||
counter = _tiktoken_count
|
||||
mode = "tiktoken-cl100k-proxy"
|
||||
except ImportError:
|
||||
counter = _char4_count
|
||||
mode = "heuristic-char4"
|
||||
|
||||
warm_prompt = _payload_to_prompt(payload) or "."
|
||||
fresh_prompt = _fresh_prompt(payload)
|
||||
fresh = int(counter(fresh_prompt))
|
||||
warm = [int(counter(warm_prompt)) for _ in range(n_runs)]
|
||||
|
||||
fresh_ok = fresh <= FRESH_LIMIT
|
||||
steady_ok = all(w <= STEADY_LIMIT for w in warm)
|
||||
|
||||
return {
|
||||
"fresh": fresh,
|
||||
"warm": warm,
|
||||
"steady_ok": steady_ok,
|
||||
"fresh_ok": fresh_ok,
|
||||
"mode": mode,
|
||||
"limits": {"steady": STEADY_LIMIT, "fresh": FRESH_LIMIT},
|
||||
"payload_cached_tokens": payload.total_cached_tokens,
|
||||
"payload_dynamic_tokens": payload.total_dynamic_tokens,
|
||||
}
|
||||
|
||||
|
||||
def main(argv: list[str] | None = None) -> int:
|
||||
import argparse
|
||||
parser = argparse.ArgumentParser(
|
||||
prog="bench.tokens",
|
||||
description=(
|
||||
"OPS-01/OPS-02 session-start token bench. TOK-11 added "
|
||||
"--wake-depth for measuring the lazy <=30-tok payload vs Phase-1 "
|
||||
"eager dump vs the deep variant."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--wake-depth",
|
||||
choices=("minimal", "standard", "deep"),
|
||||
default="minimal",
|
||||
help="Session-start payload mode (default: minimal per D5-02).",
|
||||
)
|
||||
args = parser.parse_args(argv)
|
||||
result = run_token_bench(wake_depth=args.wake_depth)
|
||||
print(json.dumps(result))
|
||||
return 0 if (result["steady_ok"] and result["fresh_ok"]) else 1
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(main())
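
The threshold check and exit-code contract from the module docstring can be exercised without a store or an API key. The sketch below is a standalone illustration, not part of `bench/tokens.py`: `char4_count` and `budget_verdict` are hypothetical helpers, and the warm prompt is made-up filler.

```python
STEADY_LIMIT = 3000  # warm-cache steady-state budget, in tokens
FRESH_LIMIT = 8000   # first-fresh-session budget, in tokens


def char4_count(text: str) -> int:
    """Rough token estimate: one token per four characters, minimum 1."""
    return max(1, len(text) // 4)


def budget_verdict(fresh: int, warm: list[int]) -> tuple[dict, int]:
    """Apply both limits and derive the process exit code (0 = pass)."""
    steady_ok = all(w <= STEADY_LIMIT for w in warm)
    fresh_ok = fresh <= FRESH_LIMIT
    result = {
        "fresh": fresh,
        "warm": warm,
        "steady_ok": steady_ok,
        "fresh_ok": fresh_ok,
        "limits": {"steady": STEADY_LIMIT, "fresh": FRESH_LIMIT},
    }
    return result, 0 if (steady_ok and fresh_ok) else 1


# Hypothetical warm payload plus the same style of padded tail for the fresh run.
warm_prompt = "# L0 identity\n" + "x" * 400
fresh_prompt = warm_prompt + "\n\n" + "dynamic tail content " * 125
result, exit_code = budget_verdict(
    fresh=char4_count(fresh_prompt),
    warm=[char4_count(warm_prompt)] * 3,
)
```

A single warm sample above the steady limit is enough to fail the run, because the steady check requires every warm sample to fit the budget.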
|
||||
477
bench/total_session_cost.py
Normal file
|
|
@ -0,0 +1,477 @@
|
"""OPS-12 total session cost bench.
|
||||
|
||||
Runs a fixed 10-turn representative script per D5-08 (see 05-CONTEXT.md)
|
||||
and counts the total tokens Claude would pay for the full session with
|
||||
IAI-MCP wired in. The 10 turns cover the axes the real-user workload
|
||||
touches most: verbatim recall, interleaved code-edit chat (no recall),
|
||||
cross-community recall, save, introspection.
|
||||
|
||||
JSON output (one line to stdout):
|
||||
|
||||
{
|
||||
"adapter": "iai-mcp",
|
||||
"wake_depth": "minimal"|"standard"|"deep",
|
||||
"total_tokens": int,
|
||||
"per_turn": [int] * 10,
|
||||
"mode": "anthropic-count-tokens"|"tiktoken-cl100k-proxy"|
|
||||
"heuristic-char4"|"injected",
|
||||
"refs": {"mempalace": int?, "claude_mem": int?},
|
||||
"passed": bool, # True iff every supplied ref >= IAI
|
||||
"script_name": "D5-08-v1"
|
||||
}
|
||||
|
||||
Exit codes:
|
||||
0 if passed, 1 otherwise.
|
||||
|
||||
CLI:
|
||||
python -m bench.total_session_cost
|
||||
python -m bench.total_session_cost --wake-depth standard
|
||||
python -m bench.total_session_cost --ref-mempalace 7000 --ref-claude-mem 5000
|
||||
|
||||
**Framing note (D5-08):** this bench is a *simulated* 10-turn script —
|
||||
it reproduces the token composition (system overhead + tool descriptions
|
||||
+ tool-call payloads + tool-result bodies) a real MCP runtime would emit
|
||||
for the turn kinds. Real runtime adds network JSON-RPC envelope
|
||||
overhead (~30-50 tok/turn); the simulation excludes that. Downstream
|
||||
reports MUST disclose this caveat alongside the row.
|
||||
|
||||
Reference-adapter notes: per PATTERNS.md Discovery #5, bench/adapters/
|
||||
mempalace_*.py and claude_mem_*.py do not exist on this machine. The
|
||||
comparative gate is driven by explicit ref numbers via CLI flags so the
|
||||
bench is usable without live adapters; when unknown, refs default to
|
||||
None and passed=True is the degenerate answer. The published bench report
|
||||
carries the honest "mempalace/claude-mem refs not measured" disclosure
|
||||
for rows where a measurement was not taken.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import shutil
|
||||
import subprocess
|
||||
import sys
|
||||
from typing import Callable
|
||||
|
||||
# Reuse bench/tokens.py's 3-tier counter helpers — single source of truth
|
||||
# for what "tiktoken-cl100k-proxy" and friends mean.
|
||||
from bench.tokens import (
|
||||
_anthropic_count_tokens,
|
||||
_char4_count,
|
||||
_tiktoken_count,
|
||||
)
|
||||
|
||||
|
||||
# ------------------------------------------------------------- adapters
|
||||
#
|
||||
# Live subprocess adapters for the reference column. Each adapter runs
|
||||
# the 10-turn script through the target tool's CLI, sums the response tokens
|
||||
# via the injected counter, and returns the total. On ANY failure
|
||||
# (tool absent, timeout, non-zero exit) the adapter returns
|
||||
# ``None`` and emits ``{"event": "bench_adapter_unavailable", ...}`` to
|
||||
# stderr. Callers MUST treat None as "honest disclosure, no measurement"
|
||||
# rather than a hard bench failure.
|
||||
#
|
||||
# Security note (T-05-06-04): turn text is a constant from _SCRIPT, never
|
||||
# from user input, and ``subprocess.run(argv_list, shell=False)`` avoids
|
||||
# any shell-injection surface. The 30s per-turn timeout bounds the DoS
|
||||
# risk (T-05-06-03).
|
||||
|
||||
_ADAPTER_TIMEOUT_SECONDS = 30
|
||||
|
||||
|
||||
def _log_adapter_unavailable(tool: str, reason: str) -> None:
|
||||
line = json.dumps({
|
||||
"event": "bench_adapter_unavailable",
|
||||
"tool": tool,
|
||||
"reason": reason,
|
||||
})
|
||||
print(line, file=sys.stderr)
|
||||
|
||||
|
||||
def _run_subprocess_adapter(
|
||||
*,
|
||||
tool_name: str,
|
||||
cli_name: str,
|
||||
argv_template: Callable[[str], list[str]],
|
||||
script: list[dict],
|
||||
counter: Callable[[str], int],
|
||||
) -> int | None:
|
||||
"""Shared helper: locate ``cli_name`` via ``shutil.which``; for each turn
|
||||
run its argv (provided by ``argv_template(turn_input)``) with a bounded
|
||||
timeout; sum stdout token counts across all turns. Return ``None`` on
|
||||
any failure (absent / timeout / non-zero exit). Empty stdout counts as 0 tokens."""
|
||||
exe = shutil.which(cli_name)
|
||||
if exe is None:
|
||||
_log_adapter_unavailable(tool_name, "cli_not_found")
|
||||
return None
|
||||
|
||||
total = 0
|
||||
for turn in script:
|
||||
argv = [exe, *argv_template(turn["input"])[1:]]
|
||||
try:
|
||||
proc = subprocess.run(
|
||||
argv,
|
||||
timeout=_ADAPTER_TIMEOUT_SECONDS,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
check=False,
|
||||
)
|
||||
except subprocess.TimeoutExpired as exc:
|
||||
_log_adapter_unavailable(tool_name, f"timeout: {exc}")
|
||||
return None
|
||||
except (OSError, ValueError) as exc:
|
||||
_log_adapter_unavailable(tool_name, f"subprocess_error: {exc}")
|
||||
return None
|
||||
|
||||
if proc.returncode != 0:
|
||||
_log_adapter_unavailable(
|
||||
tool_name,
|
||||
f"non_zero_exit={proc.returncode} stderr={proc.stderr[:200]!r}",
|
||||
)
|
||||
return None
|
||||
|
||||
stdout = proc.stdout or ""
|
||||
# Empty stdout is a legitimate "no match" response for search-style
|
||||
# CLIs; we DO count it (0 tokens) rather than treating as failure,
|
||||
# so adapters run against a pristine palace still publish a number.
|
||||
total += int(counter(stdout))
|
||||
|
||||
return total
|
||||
|
||||
|
||||
def _run_mempalace_adapter(
|
||||
script: list[dict],
|
||||
counter: Callable[[str], int],
|
||||
) -> int | None:
|
||||
"""M-07 live reference: run each turn through ``mempalace search`` and
|
||||
sum the stdout token counts. Returns ``None`` when mempalace is absent
|
||||
or any subprocess call fails. Honest-disclosure contract per Plan 05-06.
|
||||
"""
|
||||
return _run_subprocess_adapter(
|
||||
tool_name="mempalace",
|
||||
cli_name="mempalace",
|
||||
argv_template=lambda text: ["mempalace", "search", text],
|
||||
script=script,
|
||||
counter=counter,
|
||||
)
|
||||
|
||||
|
||||
def _run_claude_mem_adapter(
|
||||
script: list[dict],
|
||||
counter: Callable[[str], int],
|
||||
) -> int | None:
|
||||
"""Forward-compat mirror of the mempalace adapter. On machines where
|
||||
``claude-mem`` is not installed this returns ``None`` + stderr event;
|
||||
when it IS installed (future pressplay cross-validation run) the same
|
||||
code path measures it without another plan iteration."""
|
||||
return _run_subprocess_adapter(
|
||||
tool_name="claude-mem",
|
||||
cli_name="claude-mem",
|
||||
argv_template=lambda text: ["claude-mem", "recall", text],
|
||||
script=script,
|
||||
counter=counter,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------- D5-08 script
|
||||
#
|
||||
# Fixed 10-turn representative script. Each turn has a `kind` (used to
|
||||
# compose a realistic tool-result body) and an `input` (the cue text).
|
||||
# Order matters: turn 1 pays session-start overhead, turn 4 exercises the
|
||||
# cross-community recall path, turn 5/6 exercise save/introspect.
|
||||
|
||||
SCRIPT_NAME = "D5-08-v1"
|
||||
|
||||
_SCRIPT: list[dict] = [
|
||||
{
|
||||
"kind": "recall",
|
||||
"input": "Tell me the decisions we made about architecture",
|
||||
},
|
||||
{
|
||||
"kind": "chat",
|
||||
"input": "Let me iterate on this function; no recall needed here",
|
||||
},
|
||||
{
|
||||
"kind": "recall",
|
||||
"input": "What did I say about bench discipline?",
|
||||
},
|
||||
{
|
||||
"kind": "recall_cross_community",
|
||||
"input": "What is the connection between and the autistic kernel?",
|
||||
},
|
||||
{
|
||||
"kind": "save",
|
||||
"input": "Decision locked: use cachetools TTLCache for LRU",
|
||||
},
|
||||
{
|
||||
"kind": "introspect",
|
||||
"input": "profile_get_set operation=get knob=wake_depth",
|
||||
},
|
||||
{
|
||||
"kind": "chat",
|
||||
"input": "Continuing this refactor; still no recall",
|
||||
},
|
||||
{
|
||||
"kind": "recall",
|
||||
"input": "Alice said something about pressplay cross-validation",
|
||||
},
|
||||
{
|
||||
"kind": "reinforce",
|
||||
"input": "memory_reinforce the last 3 hits",
|
||||
},
|
||||
{
|
||||
"kind": "introspect",
|
||||
"input": "events_query kind=first_turn_recall limit=5",
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
# Tool-description overhead mirrors the TOK-15 audit result
|
||||
# (134 raw tok total for the 11 registered tools; see 05-03-SUMMARY.md).
|
||||
# We reproduce the POST-audit text verbatim so the bench reflects the
|
||||
# actual current overhead Claude sees on each turn.
|
||||
_POST_TOK15_TOOL_DESCRIPTIONS = "\n".join([
|
||||
"Recall verbatim memories matching cue. Returns hits + anti_hits.",
|
||||
"Structural recall over role->filler bindings. Returns hits.",
|
||||
"Boost Hebbian edges among co-retrieved record ids.",
|
||||
"Mark a record contradicted; new fact stored as new record.",
|
||||
"Trigger memory consolidation.",
|
||||
"Read or write a profile knob (15 sealed). operation: get|set.",
|
||||
"List pending curiosity questions. Optional session_id filter.",
|
||||
"List induced schemas. Optional domain + confidence_min filters.",
|
||||
"Query user-visible events by kind, since, severity, limit.",
|
||||
"Topology snapshot: N, C, L, sigma, community_count, regime.",
|
||||
"Camouflaging detection status; window_size weekly points.",
|
||||
])
|
||||
|
||||
# Synthetic tool-result body per turn kind. Realistic-but-bounded; a real
|
||||
# runtime varies by store content but the ratio across wake_depths is
|
||||
# what the bench measures, not the absolute per-query payload.
|
||||
_RESULT_BODIES: dict[str, str] = {
|
||||
"recall": (
|
||||
"hits=[{record_id, literal_surface, score}] "
|
||||
"anti_hits=[{record_id, reason}] "
|
||||
"activation_trace=[community_gate, spread, rank] "
|
||||
"budget_used=200"
|
||||
),
|
||||
"save": "ok=true id=<uuid>",
|
||||
"introspect": '{"value": "minimal"}',
|
||||
"reinforce": "ok=true edges_boosted=3",
|
||||
"chat": "",
|
||||
"recall_cross_community": (
|
||||
"hits=[{record_id, literal_surface, score, community_id}] "
|
||||
"anti_hits=[] activation_trace=[cross_community_spread] "
|
||||
"budget_used=350"
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------- counter select
|
||||
|
||||
def _select_counter(
    count_tokens_fn: Callable[[str], int] | None = None,
) -> tuple[Callable[[str], int], str]:
    """3-tier counter fallback mirroring bench/tokens.py:165-182.

    Priority:
      1. explicit injection (`count_tokens_fn` kwarg, tests)
      2. Anthropic count_tokens API (`ANTHROPIC_API_KEY` env var)
      3. tiktoken cl100k_base (offline proxy)
      4. char/4 heuristic (last resort)
    """
    if count_tokens_fn is not None:
        return count_tokens_fn, "injected"
    if os.environ.get("ANTHROPIC_API_KEY"):
        return _anthropic_count_tokens, "anthropic-count-tokens"
    try:
        import tiktoken  # noqa: F401
        return _tiktoken_count, "tiktoken-cl100k-proxy"
    except ImportError:
        return _char4_count, "heuristic-char4"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------- per-turn cost
|
||||
|
||||
def _session_start_overhead_tokens(wake_depth: str) -> int:
|
||||
"""Session-start payload size charged to turn 1 per wake_depth mode.
|
||||
|
||||
Numbers sourced from measurements (05-03-SUMMARY.md table):
|
||||
- minimal : 24 tok (lazy pointers only)
|
||||
- standard : 1388 tok (eager Phase-1 L0+L1+L2+rich_club)
|
||||
- deep : ~2000 tok (rich_club budget lifted per D5-02)
|
||||
|
||||
Rounded to the cache metric exactly so the numbers are
|
||||
consistent with M-01's reported warm session-start row.
|
||||
"""
|
||||
if wake_depth == "minimal":
|
||||
return 24
|
||||
if wake_depth == "standard":
|
||||
return 1388
|
||||
return 2000 # deep
|
||||
|
||||
|
||||
def _simulate_turn(
|
||||
turn: dict,
|
||||
counter: Callable[[str], int],
|
||||
) -> int:
|
||||
"""Compose the per-turn text that Claude sees and count its tokens."""
|
||||
parts: list[str] = [
|
||||
_POST_TOK15_TOOL_DESCRIPTIONS, # constant per-turn overhead
|
||||
turn["input"], # user / call payload
|
||||
_RESULT_BODIES.get(turn["kind"], ""), # synthetic result body
|
||||
]
|
||||
return int(counter("\n".join(p for p in parts if p)))
|
||||
|
||||
|
||||
# ---------------------------------------------------------------- public API
|
||||
|
||||
def run_total_session_cost(
|
||||
*,
|
||||
wake_depth: str = "minimal",
|
||||
mempalace_ref: int | None = None,
|
||||
claude_mem_ref: int | None = None,
|
||||
measure_mempalace: bool = False,
|
||||
measure_claude_mem: bool = False,
|
||||
count_tokens_fn: Callable[[str], int] | None = None,
|
||||
) -> dict:
|
||||
"""Run the fixed 10-turn script at the given wake_depth.
|
||||
|
||||
Parameters:
|
||||
wake_depth: "minimal" | "standard" | "deep" — selects session-start
|
||||
payload size charged to turn 1.
|
||||
mempalace_ref / claude_mem_ref: optional manually-supplied reference
|
||||
totals (stored as ``refs["*_manual"]`` for audit). When no live
|
||||
measurement exists, a manual int is the comparator for ``passed``.
|
||||
measure_mempalace / measure_claude_mem: when True, invoke the live
|
||||
subprocess adapter and store the result as ``refs["*_measured"]``.
|
||||
A live measurement supersedes the manual ref as the comparator.
|
||||
count_tokens_fn: optional counter injection (tests use a fixed
|
||||
function to decouple assertions from tokeniser drift).
|
||||
"""
|
||||
counter, mode = _select_counter(count_tokens_fn)
|
||||
|
||||
per_turn: list[int] = []
|
||||
for i, turn in enumerate(_SCRIPT):
|
||||
t = _simulate_turn(turn, counter)
|
||||
if i == 0:
|
||||
# Turn 1 pays the session-start overhead per wake_depth.
|
||||
t += _session_start_overhead_tokens(wake_depth)
|
||||
per_turn.append(int(t))
|
||||
|
||||
total = int(sum(per_turn))
|
||||
|
||||
refs: dict[str, int] = {}
|
||||
passed = True
|
||||
|
||||
# Live measurements first so we can decide whether the manual int should
|
||||
# be recorded under the legacy key ("mempalace") or the audit-trail key
|
||||
# ("mempalace_manual", used when BOTH a measurement AND a manual ref are
|
||||
# supplied per Test 6).
|
||||
mp_measured: int | None = None
|
||||
cm_measured: int | None = None
|
||||
if measure_mempalace:
|
||||
mp_measured = _run_mempalace_adapter(_SCRIPT, counter)
|
||||
if mp_measured is not None:
|
||||
refs["mempalace_measured"] = int(mp_measured)
|
||||
if measure_claude_mem:
|
||||
cm_measured = _run_claude_mem_adapter(_SCRIPT, counter)
|
||||
if cm_measured is not None:
|
||||
refs["claude_mem_measured"] = int(cm_measured)
|
||||
|
||||
# Manual refs. For back-compat, when no live measurement is
|
||||
    # present, the manual int lands under the legacy "mempalace" / "claude_mem"
    # key so pre-existing downstream consumers (and tests) keep working.
    if mempalace_ref is not None:
        key = "mempalace_manual" if mp_measured is not None else "mempalace"
        refs[key] = int(mempalace_ref)
    if claude_mem_ref is not None:
        key = "claude_mem_manual" if cm_measured is not None else "claude_mem"
        refs[key] = int(claude_mem_ref)

    # Gate logic: measured > legacy manual > audit-trail manual > no gate.
    mp_gate = refs.get(
        "mempalace_measured", refs.get("mempalace", refs.get("mempalace_manual"))
    )
    cm_gate = refs.get(
        "claude_mem_measured", refs.get("claude_mem", refs.get("claude_mem_manual"))
    )
    if mp_gate is not None and total > mp_gate:
        passed = False
    if cm_gate is not None and total > cm_gate:
        passed = False

    return {
        "adapter": "iai-mcp",
        "wake_depth": wake_depth,
        "total_tokens": total,
        "per_turn": per_turn,
        "mode": mode,
        "refs": refs,
        "passed": passed,
        "script_name": SCRIPT_NAME,
    }


# ---------------------------------------------------------------- CLI

def main(argv: list[str] | None = None) -> int:
    parser = argparse.ArgumentParser(
        prog="bench.total_session_cost",
        description=(
            "OPS-12 / total session cost bench. Fixed 10-turn "
            "representative script (D5-08); measures IAI-MCP token cost "
            "at wake_depth minimal|standard|deep and optionally compares "
            "to supplied mempalace / claude-mem reference totals."
        ),
    )
    parser.add_argument(
        "--wake-depth",
        choices=("minimal", "standard", "deep"),
        default="minimal",
        help="session-start payload size (default minimal per D5-02)",
    )
    parser.add_argument(
        "--ref-mempalace",
        dest="mempalace_ref",
        type=int, default=None,
        help="mempalace reference total (tokens) for the comparative gate",
    )
    parser.add_argument(
        "--ref-claude-mem",
        dest="claude_mem_ref",
        type=int, default=None,
        help="claude-mem reference total (tokens) for the comparative gate",
    )
    parser.add_argument(
        "--measure-mempalace",
        action="store_true",
        help=(
            "attempt a live mempalace subprocess run to fill the "
            "reference column; on failure emits a bench_adapter_unavailable "
            "stderr event and records no measurement"
        ),
    )
    parser.add_argument(
        "--measure-claude-mem",
        action="store_true",
        help=(
            "attempt a live claude-mem subprocess run; identical fallback "
            "shape to --measure-mempalace"
        ),
    )
    args = parser.parse_args(argv)

    result = run_total_session_cost(
        wake_depth=args.wake_depth,
        mempalace_ref=args.mempalace_ref,
        claude_mem_ref=args.claude_mem_ref,
        measure_mempalace=args.measure_mempalace,
        measure_claude_mem=args.measure_claude_mem,
    )
    print(json.dumps(result))
    return 0 if result["passed"] else 1


if __name__ == "__main__":
    sys.exit(main())
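The gate precedence used at the end of `run_total_session_cost` (measured, then legacy manual, then audit-trail manual, then no gate) can be sketched on its own. This is a minimal illustration, not the bench's code: `resolve_gate` and `gate_passes` are hypothetical helper names, and the sketch covers only the reference-total comparison, not any other condition that may already have set `passed` earlier in the function.

```python
from __future__ import annotations


def resolve_gate(refs: dict[str, int], name: str) -> int | None:
    """Pick the reference total to gate against, in precedence order.

    Order: "<name>_measured", then the legacy "<name>" key, then
    "<name>_manual". Returns None when no reference is recorded.
    """
    for key in (f"{name}_measured", name, f"{name}_manual"):
        if key in refs:
            return refs[key]
    return None


def gate_passes(total: int, refs: dict[str, int]) -> bool:
    """True unless `total` exceeds any reference total that is present."""
    for name in ("mempalace", "claude_mem"):
        gate = resolve_gate(refs, name)
        if gate is not None and total > gate:
            return False
    return True


if __name__ == "__main__":
    # A measured value takes precedence over a manually supplied one.
    print(gate_passes(90, {"mempalace": 100}))
    print(gate_passes(90, {"mempalace": 100, "mempalace_measured": 80}))
```

With no references at all the gate never fails, which matches the "no gate" end of the precedence chain.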
253
bench/trajectory.py
Normal file
@ -0,0 +1,253 @@
"""bench/trajectory.py -- trajectory benchmark (Plan 02-04 Task 4, D-33).

Generates a deterministic 30-session synthetic corpus following autism/NT
interaction pattern models and runs M1..M6 aggregation across it. Validates:
- M1 (clarifying questions/session) decreases
- M2 (retrieval precision@5) increases
- M3 (tokens/session) decreases
- M4 (profile-vector variance) decreases
- M5 (curiosity frequency) decreases
- M6 (context-repeat rate) increases (target: > 0.9 by session ~20; only the
  direction is checked here)

Diverse-text fixture: corpus spans English, Russian, Japanese, Arabic, and
German for variance testing of corpus shape. NOT a multilingual product
mandate -- IAI-MCP brain is English-only since Plan 05-08 (default embedder
bge-small-en-v1.5). Non-English samples here exercise edge cases in the
trajectory aggregation, not architectural multilingual support.

CLI:
    python -m bench.trajectory [--n-sessions 30] [--seed 42] [--real-logs PATH]
"""
from __future__ import annotations

import argparse
import json
import random
import sys
import tempfile
from datetime import datetime, timezone
from pathlib import Path
from uuid import uuid4

from iai_mcp.events import write_event
from iai_mcp.store import MemoryStore


# Reproducible corpus: default seed is 42.
DEFAULT_SEED = 42

# Diverse-text samples for corpus-shape variance testing.
# Brain is English-only since Plan 05-08; non-English entries here are
# fixture diversity, not a multilingual product feature.
_LANG_SAMPLES: dict[str, list[str]] = {
    "en": [
        "authentication uses JWT with refresh rotation",
        "db migration scheduled for Friday evening",
        "web cache invalidation on deploy",
        "cli subcommand for trajectory aggregation",
    ],
    "ru": [
        "авторизация использует JWT с обновлением токена",
        "миграция базы данных запланирована на пятницу",
        "инвалидация кэша при деплое",
    ],
    "ja": [
        "認証はJWTとリフレッシュローテーションを使用",
        "データベース移行は金曜日の夕方に予定",
    ],
    "ar": [
        "المصادقة تستخدم JWT مع تدوير الرمز",
        "ترحيل قاعدة البيانات مجدول ليوم الجمعة",
    ],
    "de": [
        "Authentifizierung verwendet JWT mit Token-Rotation",
        "Datenbankmigration für Freitagabend geplant",
    ],
}


def generate_synthetic_corpus(
    n_sessions: int = 30,
    seed: int = DEFAULT_SEED,
) -> list[dict]:
    """Build a synthetic corpus of `n_sessions` sessions (default 30).

    Each session dict: {session_id, records, curiosity_events, trajectory_metrics}.

    Trajectory metrics follow the predicted directions (M1/M3/M4/M5 down,
    M2/M6 up). This gives downstream run_trajectory_bench a clean signal to
    validate. Record and question IDs come from uuid4() and so differ between
    runs; everything else is reproducible from `seed`.
    """
    rng = random.Random(seed)
    languages = list(_LANG_SAMPLES.keys())
    corpus: list[dict] = []

    for i in range(n_sessions):
        session_id = f"synth-{i:03d}"
        # Give each of the first len(languages) sessions its own language so
        # every language appears once n_sessions >= len(languages); later
        # sessions sample a language at random. This satisfies the
        # diverse-language fixture assertion at small corpus sizes
        # (corpus-shape check, not a multilingual product claim).
        if i < len(languages):
            lang = languages[i]
        else:
            lang = rng.choice(languages)
        samples = _LANG_SAMPLES[lang]

        n_records = rng.randint(3, 8)
        records: list[dict] = []
        for k in range(n_records):
            text = samples[k % len(samples)]
            records.append({
                "id": str(uuid4()),
                "literal_surface": text,
                "language": lang,
                "tags": [f"topic:t{k % 3}", f"session:{session_id}"],
            })

        # Curiosity events decay over sessions (M5 downward trend).
        n_curiosity = max(0, 6 - (i // 5))
        curiosity_events: list[dict] = []
        for _ in range(n_curiosity):
            curiosity_events.append({
                "question_id": str(uuid4()),
                "entropy": float(0.5 + rng.random() * 0.5),
            })

        # Predicted M1..M6 directions.
        progress = i / max(1, n_sessions - 1)  # 0.0 at start -> 1.0 at end
        m1 = max(0.5, 6.0 * (1.0 - progress))  # clarifying Qs down
        m2 = min(1.0, 0.4 + progress * 0.5)  # precision@5 up
        m3 = max(1000.0, 3000.0 * (1.0 - 0.6 * progress))  # tokens down
        m4 = max(0.05, 0.5 * (1.0 - progress))  # variance down
        m5 = float(n_curiosity)  # frequency down
        m6 = min(1.0, 0.4 + progress * 0.55)  # repeat rate up

        corpus.append({
            "session_id": session_id,
            "records": records,
            "curiosity_events": curiosity_events,
            "trajectory_metrics": {
                "m1": m1, "m2": m2, "m3": m3,
                "m4": m4, "m5": m5, "m6": m6,
            },
        })
    return corpus


def run_trajectory_bench(
    corpus: list[dict],
    store_path: Path | str | None = None,
) -> dict:
    """Apply the corpus to a fresh store and aggregate M1..M6 trends.

    Returns {m1_trend, m2_trend, ..., m6_trend, passed}. Trends are lists of
    floats in session order. `passed` reflects the 6 predicted directions.
    """
    from iai_mcp.trajectory import record_session_metrics

    cleanup: tempfile.TemporaryDirectory | None = None
    if store_path is None:
        cleanup = tempfile.TemporaryDirectory(prefix="iai-bench-traj-")
        path = Path(cleanup.name)
    else:
        path = Path(store_path)

    try:
        store = MemoryStore(path=path)

        m1t: list[float] = []
        m2t: list[float] = []
        m3t: list[float] = []
        m4t: list[float] = []
        m5t: list[float] = []
        m6t: list[float] = []
        for session in corpus:
            sid = session["session_id"]
            # Emit curiosity_question events so M1 compute_* can find them.
            for q in session["curiosity_events"]:
                write_event(
                    store,
                    kind="curiosity_question",
                    data={
                        "question_id": q["question_id"],
                        "text": "",
                        "tier": "question",
                        "entropy": q["entropy"],
                        "turn": 1,
                        "triggered_by": [],
                    },
                    severity="info",
                    session_id=sid,
                )
            # Record the synthetic metrics.
            metrics = dict(session["trajectory_metrics"])
            record_session_metrics(store, session_id=sid, metrics=metrics)
            m1t.append(metrics["m1"])
            m2t.append(metrics["m2"])
            m3t.append(metrics["m3"])
            m4t.append(metrics["m4"])
            m5t.append(metrics["m5"])
            m6t.append(metrics["m6"])

        def _down(trend: list[float]) -> bool:
            return bool(trend) and trend[-1] < trend[0]

        def _up(trend: list[float]) -> bool:
            return bool(trend) and trend[-1] > trend[0]

        # success conditions.
        passed = (
            _down(m1t) and _up(m2t) and _down(m3t)
            and _down(m4t) and _down(m5t) and _up(m6t)
        )
        return {
            "m1_trend": m1t,
            "m2_trend": m2t,
            "m3_trend": m3t,
            "m4_trend": m4t,
            "m5_trend": m5t,
            "m6_trend": m6t,
            "passed": passed,
        }
    finally:
        if cleanup is not None:
            cleanup.cleanup()


def main(
    n_sessions: int = 30,
    seed: int = DEFAULT_SEED,
    real_logs_path: str | None = None,
    store_path: Path | str | None = None,
) -> int:
    """CLI entry. --real-logs=PATH is reserved for importing real Claude Code
    logs; the import is still a stub, so the synthetic corpus is used whether
    or not the path exists."""
    if real_logs_path and Path(real_logs_path).exists():
        # Real-log import path is a stub: the ingestion schema is not defined
        # here. Fall back to synthetic so the bench stays green on executors
        # without access to Claude Code session dumps.
        corpus = generate_synthetic_corpus(n_sessions=n_sessions, seed=seed)
    else:
        corpus = generate_synthetic_corpus(n_sessions=n_sessions, seed=seed)

    out = run_trajectory_bench(corpus, store_path=store_path)
    print(json.dumps(out))
    return 0 if out["passed"] else 1


def _parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    parser = argparse.ArgumentParser(prog="bench.trajectory")
    parser.add_argument("--n-sessions", type=int, default=30)
    parser.add_argument("--seed", type=int, default=DEFAULT_SEED)
    parser.add_argument("--real-logs", dest="real_logs", default=None)
    return parser.parse_args(argv)


if __name__ == "__main__":
    args = _parse_args()
    sys.exit(main(
        n_sessions=args.n_sessions,
        seed=args.seed,
        real_logs_path=args.real_logs,
    ))
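The synthetic metric curves and the first-versus-last direction check in `bench/trajectory.py` can be explored without a store. The sketch below copies the closed-form expressions from `generate_synthetic_corpus`; the helper names (`synthetic_metrics`, `directions_ok`, `first_session_above`) are hypothetical and exist only for this illustration.

```python
from __future__ import annotations

EXPECTED_DOWN = ("m1", "m3", "m4", "m5")
EXPECTED_UP = ("m2", "m6")


def synthetic_metrics(i: int, n_sessions: int = 30) -> dict[str, float]:
    """Closed-form M1..M6 values for session index i (mirrors the generator)."""
    progress = i / max(1, n_sessions - 1)  # 0.0 at start -> 1.0 at end
    return {
        "m1": max(0.5, 6.0 * (1.0 - progress)),
        "m2": min(1.0, 0.4 + progress * 0.5),
        "m3": max(1000.0, 3000.0 * (1.0 - 0.6 * progress)),
        "m4": max(0.05, 0.5 * (1.0 - progress)),
        "m5": float(max(0, 6 - (i // 5))),
        "m6": min(1.0, 0.4 + progress * 0.55),
    }


def directions_ok(n_sessions: int = 30) -> bool:
    """Compare only the first and last session, as the bench's _down/_up do."""
    first = synthetic_metrics(0, n_sessions)
    last = synthetic_metrics(n_sessions - 1, n_sessions)
    return (
        all(last[k] < first[k] for k in EXPECTED_DOWN)
        and all(last[k] > first[k] for k in EXPECTED_UP)
    )


def first_session_above(metric: str, threshold: float, n_sessions: int = 30) -> int | None:
    """Index of the first session whose metric exceeds threshold, else None."""
    for i in range(n_sessions):
        if synthetic_metrics(i, n_sessions)[metric] > threshold:
            return i
    return None


if __name__ == "__main__":
    print(directions_ok(30))
    print(first_session_above("m6", 0.9))
```

Two things follow from the formulas: a single-session corpus cannot pass, because first and last values coincide, and with the default 30 sessions the M6 curve first exceeds 0.9 at index 27 rather than around session 20.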
316
bench/verbatim.py
Normal file
@ -0,0 +1,316 @@
"""bench/verbatim.py -- benchmark harness + diagnostics.

Simulates a session gap by inserting N pinned records, flooding the store with
`session_gap * noise_per_session` unrelated records, then retrieving each
pinned record by its own literal_surface as the cue. Counts byte-exact matches.

Target: >= ACCURACY_FLOOR (0.99) on pinned records -- MEM-10.

Exit codes:
- 0 if accuracy >= 0.99
- 1 otherwise

JSON output (one line to stdout):
    {"accuracy": float, "n_records": int, "session_gap": int,
     "noise_per_session": int, "hits_exact": int, "passed": bool,
     "floor": 0.99, "noise_mode": str, "noise_seed": int,
     "skip_l0_seed": bool, "storage_direct": bool, "k": int}

Plan 05-01 (D5-01) diagnostic flags -- BENCH-ONLY (no production change):
    --skip-l0-seed      : skip _seed_l0_identity to isolate L0 crowding (effect b)
    --storage-direct    : bypass recall(), call store.query_similar directly
                          (isolates provenance-write amplification, effect c)
    --n                 : override n_records (default 20)
    --gap               : override session_gap (default 20)
    --noise-per-session : override noise_per_session (default 10)
    --k                 : override k_hits (default max(n_records + 10, 20))

Design note -- why we bypass dispatch("memory_recall"):
The Plan-02 core.memory_recall routes non-empty stores through
recall_for_response (Phase 8 entry-point split), which instantiates an
Embedder() (downloads bge-small-en-v1.5 from HuggingFace on first call).
That's fine for a real runtime but wrong for an offline bench: we need to
measure storage-layer verbatim-recall correctness, not embedder warm-up
latency. So we call `retrieve.recall` directly with a fixed cue embedding
aligned with the pinned records (all-ones vector).

H-03 noise model (review finding, 2026-04-16):
The original noise vector was [-0.5]^384, which gives cosine=-1.0 against the
[1.0]^384 cue -- making pinned-vs-noise discrimination a geometric artifact
rather than a measurement of the storage layer. The fix uses seeded
numpy.random.standard_normal(EMBED_DIM) normalised to unit length. Against a
[1.0]^384 cue the expected cosine of a random unit vector is 0 with stddev
1/sqrt(EMBED_DIM) ~= 0.05 -- realistic noise geometry, but pinned still wins
because cos=+1 >> cos~=0. The bench remains honest about what it measures
(literal_surface round-trip under realistic embedding noise, given a fixed
cue). A real bge-small-en-v1.5 bench is deferred to Phase 2.
"""
from __future__ import annotations

import argparse
import json
import sys
from datetime import datetime, timezone
from uuid import uuid4

import numpy as np

from iai_mcp.core import _seed_l0_identity
from iai_mcp.retrieve import recall
from iai_mcp.store import EMBED_DIM, MemoryStore
from iai_mcp.types import MemoryRecord

ACCURACY_FLOOR = 0.99  # OPS-04
NOISE_SEED = 20260416  # fixed for reproducibility across runs / CI


def _make_pinned(text: str, dim: int = EMBED_DIM) -> MemoryRecord:
    """A pinned verbatim record -- detail_level=5, never_merge=True, never_decay=True.

    Uses a fixed all-ones embedding so the cue (also all-ones) maxes cosine to
    every pinned record simultaneously. The recall ranking then scores by
    insertion order / stability -- but the literal_surface substring match is
    the only correctness signal we care about.

    language="en" required. `dim` parameterised so callers
    can match a legacy 384d store or the 1024d default; default is
    `EMBED_DIM` (the current module constant). Unit tests that construct a
    fresh isolated store pick up the default; bench main() queries the
    store instance's embed_dim so a pre-existing ~/.iai-mcp store (possibly
    still at 384d prior to migration) works unchanged.
    """
    return MemoryRecord(
        id=uuid4(),
        tier="semantic",
        literal_surface=text,
        aaak_index="",
        embedding=[1.0] * dim,
        community_id=None,
        centrality=0.0,
        detail_level=5,
        pinned=True,
        stability=0.0,
        difficulty=0.0,
        last_reviewed=None,
        never_decay=True,
        never_merge=True,
        provenance=[],
        created_at=datetime.now(timezone.utc),
        updated_at=datetime.now(timezone.utc),
        tags=["benchmark", "pinned"],
        language="en",
    )


def _random_unit_vector(rng: np.random.Generator, dim: int = EMBED_DIM) -> list[float]:
    """Unit-norm Gaussian vector with configurable dim.

    Expected cosine vs [1.0]^dim cue: 0 with stddev 1/sqrt(dim) ~= 0.05 at 384d
    or ~= 0.03 at 1024d. Uses the provided seeded Generator so every run
    reproduces identical noise.
    """
    v = rng.standard_normal(dim)
    v = v / np.linalg.norm(v)
    return v.tolist()


def _make_noise(i: int, rng: np.random.Generator, dim: int = EMBED_DIM) -> MemoryRecord:
    """Noise record with a random unit-vector embedding (H-03 honesty fix).

    Previous implementation used [-0.5]^EMBED_DIM which gave cosine=-1 against the
    cue, making pinned-vs-noise discrimination trivial by geometry. Seeded
    Gaussian unit vectors reproduce deterministically and approximate the
    orthogonality-on-average of real embeddings.

    language="en" required.
    """
    return MemoryRecord(
        id=uuid4(),
        tier="episodic",
        literal_surface=f"unrelated session noise record #{i}: " + ("y " * 20),
        aaak_index="",
        embedding=_random_unit_vector(rng, dim=dim),
        community_id=None,
        centrality=0.0,
        detail_level=2,
        pinned=False,
        stability=0.0,
        difficulty=0.0,
        last_reviewed=None,
        never_decay=False,
        never_merge=False,
        provenance=[],
        created_at=datetime.now(timezone.utc),
        updated_at=datetime.now(timezone.utc),
        tags=[],
        language="en",
    )


def run_verbatim_bench(
    store: MemoryStore | None = None,
    n_records: int = 20,
    session_gap: int = 20,
    noise_per_session: int = 10,
    seed: int = NOISE_SEED,
    *,
    skip_l0_seed: bool = False,
    storage_direct: bool = False,
    k: int | None = None,
) -> dict:
    """Run the verbatim-recall benchmark.

    Parameters:
        store: optional; isolated tmp_path store in tests, default MemoryStore in CLI.
        n_records: how many pinned records to store and recall.
        session_gap: how many "sessions" of noise to interpose between write and recall.
        noise_per_session: noise records per simulated session.
        seed: RNG seed for noise vectors (H-03: reproducibility across runs).
        skip_l0_seed: D5-01 effect (b) isolation -- skip the L0 identity
            seed so pinned records are not competed against by a fixed-embedding
            identity record. BENCH-SCOPE ONLY; production _seed_l0_identity is
            unchanged.
        storage_direct: D5-01 effect (c) isolation -- bypass
            retrieve.recall() and call store.query_similar directly, so the
            per-hit provenance write amplification is removed from the hot loop.
            BENCH-SCOPE ONLY; production recall() is unchanged.
        k: override the top-k passed into recall(k_hits=K) or query_similar(k=K);
            None keeps the historic default of max(n_records + 10, 20).

    Returns a dict as documented in the module docstring.
    """
    s = store if store is not None else MemoryStore()
    if not skip_l0_seed:
        _seed_l0_identity(s)

    # Consult the store's actual embedding dim. An existing Phase 1
    # store may still have 384d records pre-D-35-migration; a fresh store has
    # the default (1024d). Match either transparently.
    dim = s.embed_dim

    pinned_texts = [
        f"Alice said on day {i}: verbatim phrase #{i}-{'x' * 10}"
        for i in range(n_records)
    ]
    pinned_records = [_make_pinned(t, dim=dim) for t in pinned_texts]
    for r in pinned_records:
        s.insert(r)

    # Simulate session_gap * noise_per_session unrelated records.
    # H-03: seeded RNG shared across every noise draw so results are reproducible.
    rng = np.random.default_rng(seed)
    for session_idx in range(session_gap):
        for j in range(noise_per_session):
            s.insert(_make_noise(session_idx * noise_per_session + j, rng, dim=dim))

    cue_emb = [1.0] * dim
    # k must be >= n_records for every pinned record to have a chance of surfacing.
    # Plus a buffer for the L0 seed + anti-hits tail, so we retrieve a generous top-k.
    effective_k = k if k is not None else max(n_records + 10, 20)
    hits_exact = 0
    for text in pinned_texts:
        if storage_direct:
            # D5-01 (c): bypass recall() -> no per-hit provenance write amplification.
            raw = s.query_similar(cue_emb, k=effective_k)
            literal_surfaces = [rec.literal_surface for rec, _score in raw]
        else:
            # retrieve.recall now defaults to mode='verbatim'
            # (conservative North-Star fallback). The bench's _make_pinned
            # uses tier='semantic' which the verbatim filter would drop.
            # The bench is measuring "verbatim TEXT exact-match recall under
            # noise" -- that is independent of the cue-router's verbatim/concept
            # mode (the bench uses synthetic cues, not classifier-tagged
            # natural-language queries). Pin mode='concept' so the bench
            # measures what it has always measured.
            resp = recall(
                store=s,
                cue_embedding=cue_emb,
                cue_text=text,
                session_id="bench-verbatim",
                budget_tokens=5000,
                k_hits=effective_k,
                k_anti=3,
                mode="concept",
            )
            literal_surfaces = [h.literal_surface for h in resp.hits]
        if text in literal_surfaces:
            hits_exact += 1

    accuracy = hits_exact / n_records if n_records > 0 else 0.0
    return {
        "accuracy": accuracy,
        "n_records": n_records,
        "session_gap": session_gap,
        "noise_per_session": noise_per_session,
        "hits_exact": hits_exact,
        "passed": accuracy >= ACCURACY_FLOOR,
        "floor": ACCURACY_FLOOR,
        "noise_mode": "random-unit-vectors",
        "noise_seed": seed,
        # diagnostic traceability keys.
        "skip_l0_seed": bool(skip_l0_seed),
        "storage_direct": bool(storage_direct),
        "k": int(effective_k),
    }


def _build_arg_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser(
        prog="bench.verbatim",
        description="OPS-04 / verbatim recall benchmark + diagnostics",
    )
    parser.add_argument(
        "--skip-l0-seed",
        action="store_true",
        help="D5-01 diagnostic: skip _seed_l0_identity to isolate L0 crowding effect",
    )
    parser.add_argument(
        "--storage-direct",
        action="store_true",
        help="D5-01 diagnostic: bypass recall(), call store.query_similar directly",
    )
    parser.add_argument(
        "--n", "--n-records",
        dest="n_records",
        type=int,
        default=20,
        help="pinned record count (default 20)",
    )
    parser.add_argument(
        "--gap", "--session-gap",
        dest="session_gap",
        type=int,
        default=20,
        help="session gap -- how many noise sessions between writes and recall (default 20)",
    )
    parser.add_argument(
        "--noise-per-session",
        type=int,
        default=10,
        help="noise records per simulated session (default 10)",
    )
    parser.add_argument(
        "--k",
        type=int,
        default=None,
        help="override k_hits (default: max(n_records + 10, 20))",
    )
    return parser


def main(argv: list[str] | None = None) -> int:
    parser = _build_arg_parser()
    args = parser.parse_args(argv)
    result = run_verbatim_bench(
        n_records=args.n_records,
        session_gap=args.session_gap,
        noise_per_session=args.noise_per_session,
        skip_l0_seed=args.skip_l0_seed,
        storage_direct=args.storage_direct,
        k=args.k,
    )
    print(json.dumps(result))
    return 0 if result["passed"] else 1


if __name__ == "__main__":
    sys.exit(main())
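The H-03 noise claim in `bench/verbatim.py` (random unit vectors have cosine near 0 against an all-ones cue, with standard deviation about 1/sqrt(dim)) can be checked numerically. The sketch below is an illustration only: it uses the standard library's `random` module instead of NumPy so it runs without dependencies, and the function names and the sample count of 500 are assumptions, not part of the bench.

```python
import math
import random
import statistics


def random_unit_vector(rng: random.Random, dim: int) -> list[float]:
    """Gaussian vector scaled to unit length (stdlib stand-in for the NumPy version)."""
    v = [rng.gauss(0.0, 1.0) for _ in range(dim)]
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v]


def cosine_to_ones(v: list[float]) -> float:
    """Cosine between v and the all-ones cue of the same dimension.

    With cue = [1.0] * dim, the dot product is sum(v) and |cue| is sqrt(dim).
    """
    norm_v = math.sqrt(sum(x * x for x in v))
    return sum(v) / (norm_v * math.sqrt(len(v)))


def noise_cosine_stats(dim: int = 384, n: int = 500, seed: int = 20260416) -> tuple[float, float]:
    """Mean and standard deviation of the cue cosine over n noise vectors."""
    rng = random.Random(seed)
    cosines = [cosine_to_ones(random_unit_vector(rng, dim)) for _ in range(n)]
    return statistics.fmean(cosines), statistics.pstdev(cosines)


if __name__ == "__main__":
    mean, std = noise_cosine_stats()
    print(f"mean={mean:.4f} std={std:.4f} expected_std={1 / math.sqrt(384):.4f}")
```

A pinned record with the all-ones embedding scores cosine 1 against the same cue, so the gap between pinned and noise records is roughly twenty standard deviations at 384 dimensions.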
140
deploy/hooks/iai-mcp-session-capture.sh
Executable file
@ -0,0 +1,140 @@
#!/usr/bin/env bash
# IAI-MCP Stop hook -- ambient WRITE-side capture (Plan 06 + Phase 7.1).
#
# Fires when a Claude Code session ends. Reads the session's JSONL transcript,
# batch-captures user + assistant turns into the iai-mcp episodic tier through
# `iai-mcp capture-transcript --no-spawn`. NEVER spawns a daemon (Phase 7.1 R3).
# If the daemon is unreachable, the call defers events to
# ~/.iai-mcp/.deferred-captures/ for the daemon to drain on next socket
# activation (handled by drain_deferred_captures in daemon.main + _tick_body
# WAKE handler -- Plan 07.1-06).
#
# Fail-safe by design: any error exits 0 so session teardown is never blocked.
# Logs go to ~/.iai-mcp/logs/capture-YYYY-MM-DD.log for audit.
#
# Hook payload (stdin JSON from Claude Code) contains:
#   - session_id      (UUID of the session that just ended)
#   - transcript_path (absolute path to the session JSONL) -- available in
#                     newer Claude Code builds; we fall back to scanning the
#                     per-project transcript dir for the matching session_id.
#   - cwd             (working directory at session end)

set -u  # no -e: we must not abort on errors, fail-safe is paramount
input=$(cat 2>/dev/null || true)

# Best-effort jq; fall back to Python if jq missing.
extract() {
    local key=$1
    if command -v jq >/dev/null 2>&1; then
        printf '%s' "$input" | jq -r ".${key} // empty" 2>/dev/null
    else
        printf '%s' "$input" | /usr/bin/python3 -c "
import json, sys
try:
    d = json.load(sys.stdin)
    print(d.get('${key}', '') or '')
except Exception:
    print('')
" 2>/dev/null
    fi
}

session_id=$(extract "session_id")
transcript_path=$(extract "transcript_path")
cwd=$(extract "cwd")

# Fallback: locate transcript if the hook payload didn't include its path.
# Claude Code stores transcripts under ~/.claude/projects/{cwd-hash}/{uuid}.jsonl
if [[ -z "$transcript_path" && -n "$session_id" ]]; then
    projects_dir="$HOME/.claude/projects"
    if [[ -d "$projects_dir" ]]; then
        # Check each project directory for "<session_id>.jsonl" and take the
        # first match. Uses a shell glob; avoid `find` per the project's
        # no-grep hook.
        for d in "$projects_dir"/*/; do
            candidate="${d}${session_id}.jsonl"
            if [[ -f "$candidate" ]]; then
                transcript_path="$candidate"
                break
            fi
        done
    fi
fi

mkdir -p "$HOME/.iai-mcp/logs" 2>/dev/null || true
log="$HOME/.iai-mcp/logs/capture-$(date -u +%Y-%m-%d).log"
ts=$(date -u +%Y-%m-%dT%H:%M:%SZ)

{
    echo "---"
    echo "$ts session=$session_id cwd=$cwd transcript=$transcript_path"
} >> "$log" 2>/dev/null

# Skip if we couldn't find anything to capture.
if [[ -z "$transcript_path" || ! -f "$transcript_path" ]]; then
    echo "$ts skipped: no transcript found" >> "$log" 2>/dev/null
    exit 0
fi

# Locate the project's venv-installed `iai-mcp` CLI. Cache the last-known-good
# path in ~/.iai-mcp/.cli-path to avoid re-scanning on every session end.
cli_cache="$HOME/.iai-mcp/.cli-path"
iai_cli=""
if [[ -f "$cli_cache" ]]; then
    cached=$(cat "$cli_cache" 2>/dev/null || true)
    [[ -x "$cached" ]] && iai_cli="$cached"
fi
if [[ -z "$iai_cli" ]]; then
    # Resolve via PATH first (covers ~/.local/bin/iai-mcp installed by scripts/install.sh)
    path_cli="$(command -v iai-mcp 2>/dev/null || true)"
    if [[ -n "$path_cli" && -x "$path_cli" ]]; then
        iai_cli="$path_cli"
    else
        # Fall back to common clone locations
        for candidate in \
            "$HOME/.local/bin/iai-mcp" \
            "$HOME/iai-mcp/.venv/bin/iai-mcp" \
            "$HOME/IAI-MCP/.venv/bin/iai-mcp" \
            "/usr/local/bin/iai-mcp" \
            "/opt/homebrew/bin/iai-mcp"; do
            if [[ -x "$candidate" ]]; then
                iai_cli="$candidate"
                break
            fi
        done
    fi
    if [[ -n "$iai_cli" ]]; then
        printf '%s' "$iai_cli" > "$cli_cache" 2>/dev/null || true
    fi
fi

if [[ -z "$iai_cli" ]]; then
    echo "$ts skipped: iai-mcp CLI not found" >> "$log" 2>/dev/null
    exit 0
fi

# Run capture with a 30s hard timeout -- if it hangs, the session must still
# end cleanly. `timeout` is in coreutils (macOS: brew install coreutils, which
# provides it as `gtimeout`). If neither command is available, the capture
# runs without a timeout.
if command -v timeout >/dev/null 2>&1; then
    result=$(timeout 30 "$iai_cli" capture-transcript --no-spawn \
        --session-id "$session_id" \
        --max-turns 200 \
        "$transcript_path" 2>&1)
elif command -v gtimeout >/dev/null 2>&1; then
    result=$(gtimeout 30 "$iai_cli" capture-transcript --no-spawn \
        --session-id "$session_id" \
        --max-turns 200 \
        "$transcript_path" 2>&1)
else
    result=$("$iai_cli" capture-transcript --no-spawn \
        --session-id "$session_id" \
        --max-turns 200 \
        "$transcript_path" 2>&1)
fi
rc=$?

{
    echo "$ts rc=$rc result=$result"
} >> "$log" 2>/dev/null

exit 0
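The hook's transcript resolution (use `transcript_path` from the payload when present, otherwise look for `<session_id>.jsonl` under each project directory) can be expressed in Python for testing purposes. This is a sketch under assumptions: `find_transcript` is a hypothetical helper, not part of the iai-mcp CLI, and the sorted directory walk only approximates the order of the shell glob.

```python
from __future__ import annotations

import json
from pathlib import Path


def find_transcript(payload: str, projects_dir: Path) -> Path | None:
    """Resolve the transcript for a Stop-hook payload, or None if absent."""
    try:
        data = json.loads(payload)
    except json.JSONDecodeError:
        return None
    if not isinstance(data, dict):
        return None

    # An explicit path wins; as in the hook, a stale path means "skip",
    # not "fall back to scanning".
    explicit = data.get("transcript_path") or ""
    if explicit:
        path = Path(explicit)
        return path if path.is_file() else None

    session_id = data.get("session_id") or ""
    if not session_id or not projects_dir.is_dir():
        return None

    # First project directory containing "<session_id>.jsonl" wins.
    for project in sorted(projects_dir.iterdir()):
        candidate = project / f"{session_id}.jsonl"
        if candidate.is_file():
            return candidate
    return None


if __name__ == "__main__":
    projects = Path.home() / ".claude" / "projects"
    print(find_transcript('{"session_id": "example"}', projects))
```

Returning `None` for every failure mode mirrors the hook's fail-safe stance: a missing or unreadable transcript leads to a logged skip, never an error.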
83
deploy/launchd/com.iai-mcp.daemon.plist
Normal file
@ -0,0 +1,83 @@
<?xml version="1.0" encoding="UTF-8"?>
<!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN"
  "http://www.apple.com/DTDs/PropertyList-1.0.dtd">
<plist version="1.0">
<dict>
    <key>Label</key>
    <string>com.iai-mcp.daemon</string>

    <key>ProgramArguments</key>
    <array>
        <string>/usr/local/bin/python3</string>
        <string>-m</string>
        <string>iai_mcp.daemon</string>
    </array>

    <key>RunAtLoad</key>
    <true/>

    <!--
      Phase 10.6 Plan 10.6-01 Task 1.7: KeepAlive policy uses ONLY
      `Crashed=true`. With this policy, launchd respawns ONLY on
      a non-zero exit (the new lifecycle state machine exits 0
      gracefully on Hibernation; an MCP wrapper kickstart is the
      sole wake mechanism in steady state).

      Removed (was Phase 07.8): `SuccessfulExit=false`, which paired
      with the legacy 75/0 exit-code branching. Now that exit code
      is uniformly 0 for graceful shutdown, `SuccessfulExit=false`
      would put us in a respawn loop.
    -->
    <key>KeepAlive</key>
    <dict>
        <key>Crashed</key>
        <true/>
    </dict>

    <key>ThrottleInterval</key>
    <integer>5</integer>

    <key>ProcessType</key>
    <string>Background</string>

    <key>StandardOutPath</key>
    <string>/Users/{USERNAME}/Library/Logs/iai-mcp-daemon.stdout.log</string>
    <key>StandardErrorPath</key>
    <string>/Users/{USERNAME}/Library/Logs/iai-mcp-daemon.stderr.log</string>

    <key>WorkingDirectory</key>
    <string>/Users/{USERNAME}</string>

    <!--
      Phase 10.6 Plan 10.6-01 Task 1.7: env-var update.
      REMOVED:
        - IAI_MCP_RSS_RESTART_THRESHOLD_MB (legacy RSS-watchdog gone)
        - IAI_DAEMON_IDLE_SHUTDOWN_SECS (legacy socket idle_watcher gone)
        - IAI_MCP_SKIP_STARTUP_OPTIMIZE (legacy boot-time optimize defer gone)
      ADDED (lifecycle cadence + sleep quarantine):
        - LIFECYCLE_DROWSY_AFTER_SEC (default 300 == 5 min)
        - LIFECYCLE_SLEEP_HEARTBEAT_IDLE_SEC (default 1800 == 30 min)
        - LIFECYCLE_HIBERNATE_AFTER_SEC (default 7200 == 2 h)
        - IAI_MCP_SLEEP_QUARANTINE_TTL_HOURS (default 24)
    -->
    <key>EnvironmentVariables</key>
    <dict>
        <key>PATH</key>
        <string>/usr/local/bin:/usr/bin:/bin</string>
        <key>IAI_MCP_STORE</key>
        <string>/Users/{USERNAME}/.iai-mcp</string>
        <key>HOME</key>
        <string>/Users/{USERNAME}</string>
        <key>LANG</key>
        <string>en_US.UTF-8</string>
        <key>LIFECYCLE_DROWSY_AFTER_SEC</key>
        <string>300</string>
        <key>LIFECYCLE_SLEEP_HEARTBEAT_IDLE_SEC</key>
        <string>1800</string>
        <key>LIFECYCLE_HIBERNATE_AFTER_SEC</key>
        <string>7200</string>
        <key>IAI_MCP_SLEEP_QUARANTINE_TTL_HOURS</key>
        <string>24</string>
    </dict>
</dict>
</plist>
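The plist above ships with a `{USERNAME}` placeholder in its paths, so something has to substitute a real user name before launchd can load it. How the project's installer does this is not shown here; the sketch below is one possible approach using the standard library's `plistlib` to render and validate a template. `render_plist` and the shortened `TEMPLATE` are hypothetical and only illustrate the idea.

```python
import plistlib

# Shortened stand-in for deploy/launchd/com.iai-mcp.daemon.plist.
TEMPLATE = """<?xml version="1.0" encoding="UTF-8"?>
<!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN"
  "http://www.apple.com/DTDs/PropertyList-1.0.dtd">
<plist version="1.0">
<dict>
    <key>Label</key>
    <string>com.iai-mcp.daemon</string>
    <key>KeepAlive</key>
    <dict>
        <key>Crashed</key>
        <true/>
    </dict>
    <key>WorkingDirectory</key>
    <string>/Users/{USERNAME}</string>
</dict>
</plist>
"""


def render_plist(template: str, username: str) -> dict:
    """Fill the {USERNAME} placeholder and parse the result as a plist.

    Parsing doubles as validation: malformed XML raises before the file
    would ever be handed to launchd.
    """
    rendered = template.replace("{USERNAME}", username)
    return plistlib.loads(rendered.encode("utf-8"))


if __name__ == "__main__":
    job = render_plist(TEMPLATE, "alice")
    print(job["WorkingDirectory"], job["KeepAlive"])
```

Parsing the rendered text before writing it out keeps a typo in the template from turning into a launchd job that silently fails to load.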
39
deploy/systemd/iai-mcp-daemon.service
Normal file
@ -0,0 +1,39 @@
|
|||
# IAI-MCP Sleep Daemon -- systemd user unit (Plan 04-01, DAEMON-01)
|
||||
#
|
||||
# Install at ~/.config/systemd/user/iai-mcp-daemon.service, then:
|
||||
# systemctl --user daemon-reload
|
||||
# systemctl --user enable --now iai-mcp-daemon.service
|
||||
#
|
||||
# For survival past logout (headless servers):
|
||||
# loginctl enable-linger $USER
|
||||
#
|
||||
# C3 / guard: NO paid-API env var in Environment= lines. host_cli.py
|
||||
# scrubs the subprocess env at spawn time; the unit env is intentionally minimal.
|
||||
|
||||
[Unit]
|
||||
Description=IAI-MCP Sleep Daemon -- autonomous neural consolidation between sessions
|
||||
After=default.target
# Start-rate limiting: StartLimitIntervalSec= and StartLimitBurst= are
# [Unit] options (see systemd.unit(5)), so they are set here rather than
# in [Service].
StartLimitIntervalSec=60
StartLimitBurst=3

[Service]
Type=simple
ExecStart=/usr/bin/python3 -m iai_mcp.daemon
Restart=on-failure
RestartSec=30
|
||||
|
||||
Environment="IAI_MCP_STORE=%h/.iai-mcp"
|
||||
Environment="LANG=en_US.UTF-8"
|
||||
|
||||
StandardOutput=journal
|
||||
StandardError=journal
|
||||
SyslogIdentifier=iai-mcp-daemon
|
||||
|
||||
# Graceful shutdown: systemd default TimeoutStopSec is 90s; we tighten to 60s
|
||||
# so stop never kills us mid-Claude (subprocess timeout is 120s but the
|
||||
# daemon aborts the pending call cleanly on SIGTERM).
|
||||
TimeoutStopSec=60
|
||||
KillSignal=SIGTERM
|
||||
|
||||
[Install]
|
||||
WantedBy=default.target
|
||||
BIN
logo.png
Normal file
Binary file not shown. Size: 818 KiB
1723
mcp-wrapper/package-lock.json
generated
Normal file
File diff suppressed because it is too large.
29
mcp-wrapper/package.json
Normal file
|
|
@ -0,0 +1,29 @@
|
|||
{
|
||||
"name": "iai-mcp-wrapper",
|
||||
"version": "0.1.0",
|
||||
"description": "TypeScript MCP wrapper for IAI-MCP Python core (D-03)",
|
||||
"type": "module",
|
||||
"main": "dist/index.js",
|
||||
"bin": {
|
||||
"iai-mcp-wrapper": "dist/index.js"
|
||||
},
|
||||
"scripts": {
|
||||
"build": "tsc",
|
||||
"start": "node dist/index.js",
|
||||
"dev": "tsx src/index.ts",
|
||||
"typecheck": "tsc --noEmit",
|
||||
"test": "node --import tsx --test test/*.test.ts"
|
||||
},
|
||||
"dependencies": {
|
||||
"@modelcontextprotocol/sdk": "^1.0.0",
|
||||
"zod": "^3.23.0"
|
||||
},
|
||||
"devDependencies": {
|
||||
"@types/node": "^22.0.0",
|
||||
"typescript": "^5.4.0",
|
||||
"tsx": "^4.7.0"
|
||||
},
|
||||
"engines": {
|
||||
"node": ">=18"
|
||||
}
|
||||
}
|
||||
463
mcp-wrapper/src/bridge.ts
Normal file
|
|
@ -0,0 +1,463 @@
|
|||
// Phase 7.1 — pure-connector bridge. NO spawn capability.
|
||||
// The daemon is launchd-managed (see scripts/install.sh).
|
||||
// Wrapper connects to ~/.iai-mcp/.daemon.sock with 5s timeout.
|
||||
// On connect failure, throws DaemonUnreachableError — does NOT
|
||||
// attempt to spawn a daemon (eliminating Phase 7's TOCTOU race).
|
||||
|
||||
import * as crypto from "node:crypto";
|
||||
import * as net from "node:net";
|
||||
import * as os from "node:os";
|
||||
import * as path from "node:path";
|
||||
|
||||
// HIGH-4 LOCKED (Plan 07-04 Task 1 Step A): env override is mandatory so
|
||||
// tests can isolate via tmp socket paths. The daemon-side honors the same
|
||||
// env (Plan 07-02 added it to socket_server.py:serve()).
|
||||
const DAEMON_SOCKET_PATH =
|
||||
process.env.IAI_DAEMON_SOCKET_PATH
|
||||
?? path.join(os.homedir(), ".iai-mcp", ".daemon.sock");
|
||||
const SOCKET_CONNECT_TIMEOUT_MS = 5000;
|
||||
// 5s — covers launchd socket-activation cold-start (~3s embedder load
|
||||
// + ~1s LanceDB open + buffer). launchd accepts the connection
|
||||
// immediately and queues the read until the daemon is ready, so a
|
||||
// single 5s timeout is sufficient even on a true cold start.
|
||||
// JSON-RPC 2.0 custom server-error code (-32099..-32000 reserved by spec for
|
||||
// implementation-defined server errors per jsonrpc.org/specification).
|
||||
const ERR_DAEMON_UNREACHABLE = -32002;
|
||||
|
||||
/**
|
||||
* Phase 7.1 — clean error class thrown when the daemon socket is not
|
||||
* reachable at start(). Replaces the pre-7.1 `daemon_spawn_failed`
|
||||
* generic Error. The error message points the user at the launchd
|
||||
* recovery commands. `code` matches the existing
|
||||
* `ERR_DAEMON_UNREACHABLE` JSON-RPC server-error constant so downstream
|
||||
* consumers (handleSocketDeath in-flight rejects, `iai-mcp doctor`)
|
||||
* can pattern-match on a single numeric code.
|
||||
*/
|
||||
export class DaemonUnreachableError extends Error {
|
||||
public code: number;
|
||||
constructor(message: string) {
|
||||
super(message);
|
||||
this.name = "DaemonUnreachableError";
|
||||
this.code = ERR_DAEMON_UNREACHABLE;
|
||||
}
|
||||
}
|
||||
|
||||
interface RpcRequest {
|
||||
jsonrpc: "2.0";
|
||||
id: number;
|
||||
method: string;
|
||||
params: Record<string, unknown>;
|
||||
}
|
||||
|
||||
interface RpcResponse {
|
||||
jsonrpc: "2.0";
|
||||
id: number;
|
||||
result?: unknown;
|
||||
error?: { code: number; message: string };
|
||||
}
|
||||
|
||||
interface Pending {
|
||||
resolve: (v: unknown) => void;
|
||||
reject: (e: Error) => void;
|
||||
}
|
||||
|
||||
export class PythonCoreBridge {
|
||||
private sock: net.Socket | null = null;
|
||||
private nextId = 1;
|
||||
private pending = new Map<number, Pending>();
|
||||
private buffer = "";
|
||||
private reconnectAttempted = false;
|
||||
// V3-05 fix: serializes the at-most-one async reconnect from
|
||||
// handleSocketDeath. Concurrent call() awaits this promise BEFORE
|
||||
// checking !this.sock so a request landing in the gap between socket
|
||||
// close and reconnect-completion does NOT reject daemon_unreachable
|
||||
// when the daemon is actually healthy.
|
||||
private reconnectPromise: Promise<void> | null = null;
|
||||
// mcp-tools-list-empty-cache fix (2026-05-02): serializes concurrent
|
||||
// start() calls. Without this, the deferred-bridge-start ordering in
|
||||
// index.ts (multiple paths can trigger start: oninitialized,
|
||||
// CallToolRequest handler, top-level fire-and-forget) would each
|
||||
// observe `this.sock === null` and race independent connectWithTimeout
|
||||
// attempts. With it, the first caller drives the connect, every other
|
||||
// caller awaits the same promise. On reject the latch clears so the
|
||||
// next start() can retry (e.g. daemon came up later).
|
||||
private startPromise: Promise<void> | null = null;
|
||||
/** V3-06: consecutive JSON.parse failures on the NDJSON stream. */
|
||||
private parseErrorStreak = 0;
|
||||
private static readonly PARSE_ERROR_REJECT_THRESHOLD = 4;
|
||||
|
||||
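// Allow overriding the Python interpreter via IAI_MCP_PYTHON for tests
// that need to run the daemon against the project venv (see
// test_mcp_tools.py). Since Phase 7.1 the bridge never spawns a process,
// so this value is stored but not read anywhere in this class.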
|
||||
constructor(
|
||||
private readonly pythonCmd: string = process.env.IAI_MCP_PYTHON ?? "python3",
|
||||
) {}
|
||||
|
||||
/**
|
||||
* Phase 7.1 — pure-connector start(). Socket-only; NO spawn capability.
|
||||
* Idempotent: a second call while a socket is alive is a no-op.
|
||||
*
|
||||
* Tries to connect to ~/.iai-mcp/.daemon.sock with a 5s timeout
|
||||
* (covers launchd socket-activation cold-start). On failure, throws
|
||||
* DaemonUnreachableError pointing the user at scripts/install.sh.
|
||||
*
|
||||
* The daemon's lifecycle is owned by launchd (see
|
||||
* scripts/com.iai-mcp.daemon.plist.template); the wrapper does not
|
||||
* spawn it under any condition (eliminates Phase 7's TOCTOU race when
|
||||
* N≥3 wrappers cold-start concurrently).
|
||||
*
|
||||
* mcp-tools-list-empty-cache fix (2026-05-02): start() is now safe to
|
||||
* call concurrently from multiple async paths (top-level boot fire,
|
||||
* server.oninitialized chain, CallToolRequest lazy-await). The first
|
||||
* caller drives the actual socket connect; the rest await the shared
|
||||
* `startPromise` and observe the same outcome. On reject the latch
|
||||
* is cleared so a future call() can retry once the daemon is up.
|
||||
*/
|
||||
async start(): Promise<void> {
|
||||
if (this.sock) return; // already connected; idempotent
|
||||
if (this.startPromise) return this.startPromise;
|
||||
this.startPromise = this._doStart();
|
||||
try {
|
||||
await this.startPromise;
|
||||
} catch (err) {
|
||||
// Allow a future caller to retry — the daemon may simply have been
|
||||
// slow to come up. Without clearing the latch, every subsequent
|
||||
// start() would short-circuit on the rejected memoised promise.
|
||||
this.startPromise = null;
|
||||
throw err;
|
||||
}
|
||||
// On success, leave startPromise set; further calls short-circuit on
|
||||
// `this.sock` truthiness (set inside _doStart before resolution).
|
||||
}
|
||||
|
||||
private async _doStart(): Promise<void> {
|
||||
// Reset reconnect-once latch so a fresh start() (e.g. after explicit
|
||||
// disconnect) is treated as a new session by handleSocketDeath.
|
||||
this.reconnectAttempted = false;
|
||||
|
||||
let sock: net.Socket;
|
||||
try {
|
||||
sock = await this.connectWithTimeout(
|
||||
DAEMON_SOCKET_PATH,
|
||||
SOCKET_CONNECT_TIMEOUT_MS,
|
||||
);
|
||||
} catch (e) {
|
||||
throw new DaemonUnreachableError(
|
||||
"iai-mcp daemon not running. "
|
||||
+ "Run: launchctl load -w ~/Library/LaunchAgents/com.iai-mcp.daemon.plist "
|
||||
+ "or run scripts/install.sh"
|
||||
);
|
||||
}
|
||||
this.sock = sock;
|
||||
this.attachSocketHandlers();
|
||||
}
|
||||
|
||||
/**
|
||||
* Promise wrapper around net.createConnection with a hard timeout.
|
||||
* Adapted from emitSessionOpen (lines below) — same silent-fail safety
|
||||
* pattern, but resolves with the live socket on success so the caller
|
||||
* can retain it for long-lived JSON-RPC traffic.
|
||||
*/
|
||||
private connectWithTimeout(
|
||||
socketPath: string,
|
||||
timeoutMs: number,
|
||||
): Promise<net.Socket> {
|
||||
return new Promise((resolve, reject) => {
|
||||
const sock = net.createConnection(socketPath);
|
||||
const t = setTimeout(() => {
|
||||
try { sock.destroy(); } catch { /* ignore */ }
|
||||
reject(new Error("connect_timeout"));
|
||||
}, timeoutMs);
|
||||
sock.once("connect", () => {
|
||||
clearTimeout(t);
|
||||
resolve(sock);
|
||||
});
|
||||
sock.once("error", (e) => {
|
||||
clearTimeout(t);
|
||||
reject(e);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
private attachSocketHandlers(): void {
|
||||
if (!this.sock) return;
|
||||
this.sock.on("data", (chunk: Buffer) => this.handleData(chunk));
|
||||
this.sock.on("close", () => this.handleSocketDeath("closed"));
|
||||
this.sock.on("error", (e: Error) => this.handleSocketDeath(`error: ${e.message}`));
|
||||
}
|
||||
|
||||
/**
|
||||
* NDJSON read buffer: socket data arrives in arbitrary chunks; we buffer
|
||||
* + split on `\n` manually. Each complete line is one JSON-RPC response
|
||||
* envelope.
|
||||
*/
|
||||
private handleData(chunk: Buffer): void {
|
||||
this.buffer += chunk.toString("utf-8");
|
||||
let nl: number;
|
||||
while ((nl = this.buffer.indexOf("\n")) >= 0) {
|
||||
const line = this.buffer.slice(0, nl).trim();
|
||||
this.buffer = this.buffer.slice(nl + 1);
|
||||
if (!line) continue;
|
||||
this.handleLine(line);
|
||||
}
|
||||
}
|
||||
|
||||
private handleLine(line: string): void {
|
||||
let msg: RpcResponse;
|
||||
try {
|
||||
msg = JSON.parse(line) as RpcResponse;
|
||||
} catch {
|
||||
this.parseErrorStreak += 1;
|
||||
if (
|
||||
this.parseErrorStreak >= PythonCoreBridge.PARSE_ERROR_REJECT_THRESHOLD
|
||||
&& this.pending.size > 0
|
||||
) {
|
||||
const oldestId = Math.min(...this.pending.keys());
|
||||
const handler = this.pending.get(oldestId);
|
||||
if (handler) {
|
||||
this.pending.delete(oldestId);
|
||||
handler.reject(
|
||||
new Error(
|
||||
`parse_error: ${PythonCoreBridge.PARSE_ERROR_REJECT_THRESHOLD} consecutive non-JSON lines on daemon socket; rejecting stale RPC id=${oldestId}`,
|
||||
),
|
||||
);
|
||||
}
|
||||
try {
|
||||
process.stderr.write(
|
||||
`${JSON.stringify({
|
||||
event: "bridge_ndjson_parse_error_streak",
|
||||
threshold: PythonCoreBridge.PARSE_ERROR_REJECT_THRESHOLD,
|
||||
rejected_rpc_id: oldestId,
|
||||
})}\n`,
|
||||
);
|
||||
} catch { /* ignore */ }
|
||||
this.parseErrorStreak = 0;
|
||||
}
|
||||
return; // non-JSON line -- ignore (e.g., stray prints from daemon libs)
|
||||
}
|
||||
this.parseErrorStreak = 0;
|
||||
const handler = this.pending.get(msg.id);
|
||||
if (!handler) return;
|
||||
this.pending.delete(msg.id);
|
||||
if (msg.error) {
|
||||
handler.reject(new Error(msg.error.message));
|
||||
} else {
|
||||
handler.resolve(msg.result);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* R5 fail-loud: socket close/error rejects ALL pending Promises with
|
||||
* `daemon_unreachable` (-32002). D7-04 / SPEC R5: ONE reconnect attempt
|
||||
* (catches launchd KeepAlive respawn windows). After that attempt the
* bridge stays degraded: every subsequent call() rejects with
* `daemon_unreachable` until a later start() reconnects (start() resets
* the reconnect-once latch) or the wrapper itself restarts.
|
||||
*/
|
||||
private handleSocketDeath(why: string): void {
|
||||
// Synchronous: every pending request fails LOUD immediately so callers
|
||||
// see daemon_unreachable instead of hanging forever (D7-04 / SPEC R5).
|
||||
const err = new Error(`daemon_unreachable: socket ${why} (code ${ERR_DAEMON_UNREACHABLE})`);
|
||||
for (const [, p] of this.pending) p.reject(err);
|
||||
this.pending.clear();
|
||||
this.sock = null;
|
||||
// Clear the start-latch so a future call() can retry start() (e.g.
|
||||
// after launchd respawn). reconnectPromise (below) handles the
|
||||
// immediate one-shot reconnect; startPromise reset enables
|
||||
// long-tail retry from any new caller after that.
|
||||
this.startPromise = null;
|
||||
|
||||
if (this.reconnectAttempted) return;
|
||||
this.reconnectAttempted = true;
|
||||
|
||||
// Async reconnect-once. Concurrent call() awaits this promise BEFORE
|
||||
// checking !this.sock, eliminating the V3-05 race.
|
||||
this.reconnectPromise = (async () => {
|
||||
try {
|
||||
// Test-only deterministic widener for the V3-05 race window.
|
||||
// In production this env var is unset → 0 ms → no-op. The
|
||||
// V3-05 regression test (tests/test_socket_disconnect_reconnect.py)
|
||||
// sets IAI_MCP_RECONNECT_TEST_DELAY_MS=1000 so the racing
|
||||
// call() can land deterministically inside the gap between
|
||||
// socket close and reconnect-completion. Without this delay the
|
||||
// race window is sub-millisecond and the regression test cannot
|
||||
// distinguish pre-fix (rejects daemon_unreachable) from post-fix
|
||||
// (awaits reconnectPromise, succeeds).
|
||||
const testDelayMs = Number(
|
||||
process.env.IAI_MCP_RECONNECT_TEST_DELAY_MS ?? "0",
|
||||
);
|
||||
if (testDelayMs > 0) {
|
||||
await new Promise<void>((r) => setTimeout(r, testDelayMs));
|
||||
}
|
||||
// Manually do socket-first connect (without resetting the latch
|
||||
// that start() does) so a SECOND mid-call death stays degraded.
|
||||
this.sock = await this.connectWithTimeout(
|
||||
DAEMON_SOCKET_PATH,
|
||||
SOCKET_CONNECT_TIMEOUT_MS,
|
||||
);
|
||||
this.attachSocketHandlers();
|
||||
} catch {
|
||||
// stay degraded — every subsequent call sees this.sock === null
|
||||
// and rejects with daemon_unreachable.
|
||||
} finally {
|
||||
this.reconnectPromise = null;
|
||||
}
|
||||
})();
|
||||
}
|
||||
|
||||
/**
|
||||
* Send a JSON-RPC 2.0 request over the socket; resolves with `result`
|
||||
* or rejects with the daemon-side `error.message`.
|
||||
*
|
||||
* R5 fail-loud: when this.sock is null (post-death, post-disconnect,
|
||||
* pre-start) the call rejects synchronously with `daemon_unreachable`.
|
||||
* NO silent fallback to a local Python core spawn.
|
||||
*/
|
||||
async call<T = unknown>(
|
||||
method: string,
|
||||
params: Record<string, unknown> = {},
|
||||
): Promise<T> {
|
||||
// V3-05 fix: if a reconnect is in flight, wait for it before deciding
|
||||
// whether the socket is alive. Without this await, a call() landing in
|
||||
// the gap between socket close and reconnect-completion would reject
|
||||
// with daemon_unreachable even though the daemon is healthy.
|
||||
if (this.reconnectPromise) {
|
||||
await this.reconnectPromise;
|
||||
}
|
||||
if (!this.sock) {
|
||||
throw new Error(`daemon_unreachable: bridge not connected (code ${ERR_DAEMON_UNREACHABLE})`);
|
||||
}
|
||||
const id = this.nextId++;
|
||||
const req: RpcRequest = { jsonrpc: "2.0", id, method, params };
|
||||
return new Promise<T>((resolve, reject) => {
|
||||
this.pending.set(id, {
|
||||
resolve: resolve as (v: unknown) => void,
|
||||
reject,
|
||||
});
|
||||
try {
|
||||
this.sock!.write(JSON.stringify(req) + "\n");
|
||||
} catch (e) {
|
||||
this.pending.delete(id);
|
||||
reject(e as Error);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* Public API: close the socket but leave the daemon running.
|
||||
* Used by index.ts SIGTERM/SIGINT handlers.
|
||||
*
|
||||
* After Phase 7 the wrapper does NOT own the daemon's lifecycle —
|
||||
* disconnecting a wrapper must NOT kill the singleton, otherwise other
|
||||
* wrappers (other MCP hosts, sub-agents) would lose their
|
||||
* shared brain.
|
||||
*/
|
||||
disconnect(): void {
|
||||
if (this.sock) {
|
||||
try { this.sock.end(); } catch { /* ignore */ }
|
||||
try { this.sock.destroy(); } catch { /* ignore */ }
|
||||
this.sock = null;
|
||||
}
|
||||
// Clear the start-latch so a fresh start() (e.g. test re-use of the
|
||||
// bridge instance) is treated as a brand new connection.
|
||||
this.startPromise = null;
|
||||
// Reject any in-flight calls with a clean message (NOT
|
||||
// daemon_unreachable — the daemon is fine; we just chose to close).
|
||||
for (const [, p] of this.pending) {
|
||||
p.reject(new Error("bridge_disconnected"));
|
||||
}
|
||||
this.pending.clear();
|
||||
}
|
||||
|
||||
// Visible for tests: smoke endpoint replacing the pre-Phase-7
|
||||
// isRunning() that checked for a child process.
|
||||
isConnected(): boolean {
|
||||
return this.sock !== null;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Plan 05-04 TOK-14 / D5-05 — session_open emit over the daemon unix socket.
|
||||
// UNCHANGED by Phase 7 (Plan 07-04). Same socket path; brief separate
|
||||
// connection that fires a one-shot HIPPEA pre-warm hint then closes.
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
/**
|
||||
* Path to the Python daemon's unix control socket.
|
||||
* Mirror of `concurrency.SOCKET_PATH` in the Python core (`~/.iai-mcp/.daemon.sock`).
|
||||
*
|
||||
* Honors `IAI_DAEMON_SOCKET_PATH` so tests can isolate via tmp socket paths
|
||||
* (matches the same env override the main bridge socket connect uses).
|
||||
*/
|
||||
export function sessionOpenSocketPath(): string {
|
||||
const env = process.env.IAI_DAEMON_SOCKET_PATH;
|
||||
if (env) return env;
|
||||
return path.join(os.homedir(), ".iai-mcp", ".daemon.sock");
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Generate a fresh session identifier for the boot event.
|
||||
* Node stdlib since 14.17 — no dependency added.
|
||||
*/
|
||||
export function newSessionId(): string {
|
||||
return crypto.randomUUID();
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Fire-and-forget NDJSON `session_open` message to the daemon socket.
|
||||
*
|
||||
* Contract:
|
||||
* - Writes one line: `{"type":"session_open","session_id":"...","ts":"..."}\n`
|
||||
* - One-shot semantics: does **not** read the daemon's response bytes before
|
||||
* `end()` — intentional (HIPPEA hint only). If the daemon wrote backpressure
|
||||
* or error bytes, they are left unread; the separate long-lived `PythonCoreBridge`
|
||||
* connection owns JSON-RPC traffic.
|
||||
* - Silent-fail on any network, socket-not-found, or timeout error. The
|
||||
* Python core's `_first_turn_recall_hook` falls back to the cold recall
|
||||
* path when the cascade LRU is empty (expected when daemon is down).
|
||||
* - Hard timeout at 2s so a hung socket cannot delay wrapper boot.
|
||||
*
|
||||
* Returns a Promise<void> that ALWAYS resolves (never rejects) so callers
|
||||
* can use `void emitSessionOpen(...)` in a sync bootstrap block without
|
||||
* an explicit `.catch`.
|
||||
*/
|
||||
export function emitSessionOpen(sessionId: string): Promise<void> {
|
||||
return new Promise<void>((resolve) => {
|
||||
let settled = false;
|
||||
const finish = () => {
|
||||
if (settled) return;
|
||||
settled = true;
|
||||
resolve();
|
||||
};
|
||||
try {
|
||||
const socketPath = sessionOpenSocketPath();
|
||||
const sock = net.createConnection(socketPath, () => {
|
||||
const msg =
|
||||
JSON.stringify({
|
||||
type: "session_open",
|
||||
session_id: sessionId,
|
||||
ts: new Date().toISOString(),
|
||||
}) + "\n";
|
||||
sock.write(msg, () => {
|
||||
sock.end();
|
||||
});
|
||||
});
|
||||
sock.on("error", () => finish());
|
||||
sock.on("close", () => finish());
|
||||
sock.setTimeout(2000, () => {
|
||||
try {
|
||||
sock.destroy();
|
||||
} catch {
|
||||
// ignore
|
||||
}
|
||||
finish();
|
||||
});
|
||||
} catch {
|
||||
// Any sync setup failure -> silent fallback.
|
||||
finish();
|
||||
}
|
||||
});
|
||||
}
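The NDJSON read path in `handleData` above relies on one small technique: append each chunk to a buffer, then peel off every complete line and keep the unfinished tail for the next chunk. A minimal standalone sketch of that buffering follows. The class name `NdjsonSplitter` is hypothetical, and the sketch takes already-decoded strings rather than socket `Buffer` chunks.

```typescript
// Sketch of the buffer-and-split technique; NdjsonSplitter is a hypothetical
// name and is not part of bridge.ts.
class NdjsonSplitter {
  private buffer = "";

  // Feed one chunk; returns every complete, non-empty line it finished.
  push(chunk: string): string[] {
    this.buffer += chunk;
    const lines: string[] = [];
    let nl: number;
    while ((nl = this.buffer.indexOf("\n")) >= 0) {
      const line = this.buffer.slice(0, nl).trim();
      this.buffer = this.buffer.slice(nl + 1); // keep the unfinished tail
      if (line) lines.push(line);
    }
    return lines;
  }
}

const splitter = new NdjsonSplitter();
// A response envelope split across two chunks, followed by a second envelope.
const first = splitter.push('{"jsonrpc":"2.0","id":1,"res');
const second = splitter.push('ult":"ok"}\n{"jsonrpc":"2.0","id":2,"result":null}\n');
console.log(first.length, second.length); // 0 2
```

One design note: when the chunks are raw bytes, a multi-byte UTF-8 character can straddle two chunks, so decoding with `StringDecoder` from `node:string_decoder` (or calling `socket.setEncoding("utf8")`) is safer than converting each chunk to a string on its own.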
|
||||
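The `startPromise` latch in `PythonCoreBridge.start()` is an instance of a single-flight pattern: concurrent callers share one in-flight attempt, and a failure clears the latch so a later caller can retry. The sketch below isolates that pattern under stated assumptions. `SingleFlightStart` and `demo` are hypothetical names, and the connect function is injected so the example runs without a daemon socket.

```typescript
// Hypothetical stand-in for PythonCoreBridge.start(); not the real bridge.
class SingleFlightStart {
  private inFlight: Promise<void> | null = null;
  private connected = false;
  private readonly connect: () => Promise<void>;

  constructor(connect: () => Promise<void>) {
    this.connect = connect;
  }

  async start(): Promise<void> {
    if (this.connected) return; // idempotent once connected
    if (this.inFlight) return this.inFlight; // join the caller already connecting
    this.inFlight = this.connect().then(() => {
      this.connected = true;
    });
    try {
      await this.inFlight;
    } catch (err) {
      this.inFlight = null; // clear the latch so a later start() can retry
      throw err;
    }
  }
}

async function demo(): Promise<{ outcomes: string[]; attempts: number }> {
  let attempts = 0;
  const bridge = new SingleFlightStart(async () => {
    attempts += 1;
    if (attempts === 1) throw new Error("daemon not up yet");
  });
  // Three concurrent callers share the single failing attempt.
  const outcomes = await Promise.all(
    [bridge.start(), bridge.start(), bridge.start()].map((p) =>
      p.then(() => "ok", () => "rejected"),
    ),
  );
  await bridge.start(); // the latch was cleared, so this retries and succeeds
  return { outcomes, attempts };
}

demo().then((r) => console.log(r.outcomes, r.attempts));
```

In this run the three concurrent callers all observe the same rejection from one connect attempt, and the later call brings the attempt count to 2.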
85
mcp-wrapper/src/caching.ts
Normal file
|
|
@ -0,0 +1,85 @@
|
|||
// Anthropic 1h-TTL prompt caching (TOK-01, D-10).
|
||||
//
|
||||
// Single breakpoint at the stable/volatile boundary. The Python core's
|
||||
// `session_start_payload` returns the 4-segment cached prefix; this module
|
||||
// wraps it in Anthropic `content` blocks and stamps `cache_control` on the
|
||||
// last stable block so Anthropic's cache sees one hashable prefix.
|
||||
//
|
||||
// cache_control TTL="1h" is the Anthropic prompt-caching extended-TTL option
|
||||
// (enabled per-org; falls back to the "5m" default when
|
||||
// unsupported). Rationale per D-10: session-start prefix rarely changes
|
||||
// within an hour, so 1h TTL hits Anthropic's cache on every turn after the
|
||||
// first fresh-session write (OPS-02 8000-token premium absorbed once).
|
||||
|
||||
export interface CacheControl {
|
||||
readonly type: "ephemeral";
|
||||
readonly ttl: "1h" | "5m";
|
||||
}
|
||||
|
||||
export interface ContentBlock {
|
||||
type: string;
|
||||
text?: string;
|
||||
cache_control?: CacheControl;
|
||||
}
|
||||
|
||||
export interface SessionPayloadRaw {
|
||||
l0: string;
|
||||
l1: string;
|
||||
l2: string[];
|
||||
rich_club: string;
|
||||
total_cached_tokens: number;
|
||||
total_dynamic_tokens: number;
|
||||
breakpoint_marker?: string;
|
||||
}
|
||||
|
||||
/** Attach a single `cache_control` breakpoint at the stable/volatile boundary.
|
||||
*
|
||||
* Per TOK-01 we emit exactly one breakpoint: on the LAST block of `stable`.
|
||||
* If `stable` is empty the function returns the volatile blocks unchanged --
|
||||
* there is no sensible place to hang a breakpoint on an empty prefix and
|
||||
* Anthropic's API would reject the request.
|
||||
*
|
||||
* Returns a new array; inputs are not mutated. */
|
||||
export function applyCacheBreakpoint(
|
||||
stable: ContentBlock[],
|
||||
volatile: ContentBlock[],
|
||||
): ContentBlock[] {
|
||||
if (stable.length === 0) {
|
||||
return [...volatile];
|
||||
}
|
||||
const cloned = stable.map((b) => ({ ...b }));
|
||||
cloned[cloned.length - 1] = {
|
||||
...cloned[cloned.length - 1],
|
||||
cache_control: { type: "ephemeral", ttl: "1h" },
|
||||
};
|
||||
return [...cloned, ...volatile];
|
||||
}
|
||||
|
||||
/** Build the cached system prompt from the Python session_start_payload.
|
||||
*
|
||||
* Segments in order: L0 identity, L1 critical facts, L2 community summaries
|
||||
* (one block per community), rich-club prefetch. Empty segments are skipped
|
||||
* so the cache-key is stable across sessions where, say, L1 is empty.
|
||||
*
|
||||
* Returned blocks already have the cache_control breakpoint applied. */
|
||||
export function buildCachedSystemPrompt(
|
||||
payload: SessionPayloadRaw,
|
||||
): ContentBlock[] {
|
||||
const stable: ContentBlock[] = [];
|
||||
if (payload.l0) {
|
||||
stable.push({ type: "text", text: `# L0 identity\n${payload.l0}` });
|
||||
}
|
||||
if (payload.l1) {
|
||||
stable.push({ type: "text", text: `# L1 critical facts\n${payload.l1}` });
|
||||
}
|
||||
for (const segment of payload.l2) {
|
||||
stable.push({ type: "text", text: `# L2 community\n${segment}` });
|
||||
}
|
||||
if (payload.rich_club) {
|
||||
stable.push({
|
||||
type: "text",
|
||||
text: `# Global rich-club\n${payload.rich_club}`,
|
||||
});
|
||||
}
|
||||
return applyCacheBreakpoint(stable, []);
|
||||
}
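To make the single-breakpoint rule concrete, here is a small worked example. It restates the breakpoint helper in condensed form so that it runs on its own; `Block` and `withBreakpoint` are hypothetical local names standing in for `ContentBlock` and `applyCacheBreakpoint`, and the block texts are made-up sample data.

```typescript
// Condensed restatement of the breakpoint helper so the example is
// self-contained; names and sample texts are illustrative only.
type CacheControl = { type: "ephemeral"; ttl: "1h" | "5m" };
type Block = { type: string; text?: string; cache_control?: CacheControl };

function withBreakpoint(stable: Block[], volatile: Block[]): Block[] {
  if (stable.length === 0) return [...volatile]; // nothing to hang a breakpoint on
  const cloned = stable.map((b) => ({ ...b })); // shallow copies; inputs stay untouched
  cloned[cloned.length - 1] = {
    ...cloned[cloned.length - 1],
    cache_control: { type: "ephemeral", ttl: "1h" },
  };
  return [...cloned, ...volatile];
}

const stable: Block[] = [
  { type: "text", text: "# L0 identity\n(sample)" },
  { type: "text", text: "# L2 community\n(sample)" },
];
const volatile: Block[] = [{ type: "text", text: "current user turn (sample)" }];

const blocks = withBreakpoint(stable, volatile);
// Exactly one block carries cache_control: the last stable one.
console.log(blocks.map((b) => b.cache_control !== undefined)); // [ false, true, false ]
console.log(stable[1].cache_control); // undefined (input was not mutated)
```

Everything up to and including the marked block forms the cacheable prefix; the volatile blocks after it can change per turn without invalidating it.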
|
||||
226
mcp-wrapper/src/index.ts
Normal file
|
|
@ -0,0 +1,226 @@
|
|||
#!/usr/bin/env node
|
||||
// IAI-MCP TypeScript wrapper entry point (Plan 03 wave).
|
||||
//
|
||||
// - Connects to the already-running Python daemon over its unix socket
// (NDJSON-framed JSON-RPC, see bridge.ts); the wrapper never spawns it
|
||||
// - Advertises the 12 hot tools via HOT_TOOLS registry (TOK-02)
|
||||
// - Attaches Anthropic 1h-TTL cache_control at the stable/volatile boundary
|
||||
// (TOK-01) via caching.ts helpers
|
||||
// - Advertises `clear_tool_uses_20250919` context editing with 30k trigger
|
||||
// (TOK-05) via registry.ts CONTEXT_EDITING_CONFIG
|
||||
// - On MCP `initialize`, warms the Python session_start payload so the first
|
||||
// real user turn doesn't pay the fresh-session cost synchronously.
|
||||
|
||||
import { Server } from "@modelcontextprotocol/sdk/server/index.js";
|
||||
import { StdioServerTransport } from "@modelcontextprotocol/sdk/server/stdio.js";
|
||||
import {
|
||||
CallToolRequestSchema,
|
||||
ListToolsRequestSchema,
|
||||
} from "@modelcontextprotocol/sdk/types.js";
|
||||
|
||||
import {
|
||||
emitSessionOpen,
|
||||
newSessionId,
|
||||
PythonCoreBridge,
|
||||
} from "./bridge.js";
|
||||
import {
|
||||
applyCacheBreakpoint,
|
||||
buildCachedSystemPrompt,
|
||||
type ContentBlock,
|
||||
type SessionPayloadRaw,
|
||||
} from "./caching.js";
|
||||
import { WrapperLifecycle } from "./lifecycle.js";
|
||||
import {
|
||||
CONTEXT_EDITING_CONFIG,
|
||||
HOT_TOOLS,
|
||||
listHotTools,
|
||||
} from "./registry.js";
|
||||
import { invokeTool, type ToolName } from "./tools.js";
|
||||
|
||||
// Re-export so consumers of the module (and tests) can touch the helpers
|
||||
// without dynamic imports.
|
||||
export {
|
||||
applyCacheBreakpoint,
|
||||
buildCachedSystemPrompt,
|
||||
CONTEXT_EDITING_CONFIG,
|
||||
HOT_TOOLS,
|
||||
};
|
||||
export type { ContentBlock, SessionPayloadRaw };
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// mcp-tools-list-empty-cache fix (2026-05-02):
|
||||
//
|
||||
// Pre-fix order was:
|
||||
// 1. await bridge.start() ← could block 5s on slow daemon
|
||||
// 2. construct Server + handlers
|
||||
// 3. await server.connect(transport)
|
||||
//
|
||||
// On a slow daemon (cold launchd hand-off, multi-second LanceDB open, RSS
|
||||
// watchdog respawn) the top-level await in step 1 delayed step 3 past the
|
||||
// MCP client's tools/list timeout. The client cached an empty tool list
|
||||
// for the rest of the session — symptom: "Connected" but zero
|
||||
// `mcp__iai-mcp__*` tools in the registry.
|
||||
//
|
||||
// Fixed order is:
|
||||
// 1. construct Server + register both request handlers + assign
|
||||
// oninitialized (must be set before connect — the initialized
|
||||
// notification fires immediately after handshake and an unset
|
||||
// handler would discard the HIPPEA pre-warm trigger).
|
||||
// 2. await server.connect(transport) ← tools/list is responsive HERE,
|
||||
// independent of daemon state (handler returns from static
|
||||
// registry.listHotTools()).
|
||||
// 3. fire-and-forget bridge.start() chained with emitSessionOpen — the
|
||||
// D5-05 invariant "emitSessionOpen fires AFTER daemon socket
|
||||
// reachable" is preserved by the .then() chain.
|
||||
// 4. CallToolRequest handler lazy-awaits bridge.start() before
|
||||
// delegating to invokeTool — first tools/call may pay daemon
|
||||
// cold-start cost ONCE; tools/list never blocks.
|
||||
//
|
||||
// Invariants preserved:
|
||||
// - Phase 7.1: wrapper does NOT spawn daemon (bridge.ts unchanged on
|
||||
// this point — it's still socket-only).
|
||||
// - Plan 05-04 D5-05 (HIPPEA pre-warm): emitSessionOpen still chained
|
||||
// off bridge.start() readiness.
|
||||
// - Plan 07-04 Task 2: SIGTERM/SIGINT closes socket only; daemon
|
||||
// survives. Unchanged.
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
const bridge = new PythonCoreBridge();
|
||||
|
||||
const server = new Server(
|
||||
{
|
||||
name: "iai-mcp",
|
||||
version: "0.1.0",
|
||||
},
|
||||
{
|
||||
capabilities: { tools: {} },
|
||||
// Expose TOK-05 context-editing config so MCP hosts that honour
|
||||
// Anthropic's context management can pick it up at discovery time.
|
||||
instructions: JSON.stringify({
|
||||
context_editing: CONTEXT_EDITING_CONFIG,
|
||||
hot_tools: HOT_TOOLS,
|
||||
}),
|
||||
},
|
||||
);
|
||||
|
||||
// tools/list MUST return from the static registry without touching the
|
||||
// bridge — see file-top comment block. This is what makes the wrapper
|
||||
// safe to advertise to the MCP client before the daemon socket is
|
||||
// reachable.
|
||||
server.setRequestHandler(ListToolsRequestSchema, async () => ({
|
||||
tools: listHotTools(),
|
||||
}));
|
||||
|
||||
server.setRequestHandler(CallToolRequestSchema, async (req) => {
|
||||
const name = req.params.name as ToolName;
|
||||
if (!HOT_TOOLS.includes(name)) {
|
||||
return {
|
||||
content: [{ type: "text" as const, text: `unknown tool ${name}` }],
|
||||
isError: true,
|
||||
};
|
||||
}
|
||||
try {
|
||||
// Lazy bridge connect: the first tools/call after wrapper boot drives
|
||||
// the daemon socket connect. Subsequent calls short-circuit on the
|
||||
// alive socket. start() is concurrency-safe (startPromise serialises
|
||||
// multiple concurrent first-callers — see bridge.ts).
|
||||
await bridge.start();
|
||||
const result = await invokeTool(bridge, name, req.params.arguments ?? {});
|
||||
return {
|
||||
content: [{ type: "text" as const, text: JSON.stringify(result) }],
|
||||
};
|
||||
} catch (e) {
|
||||
return {
|
||||
content: [
|
||||
{ type: "text" as const, text: `error: ${(e as Error).message}` },
|
||||
],
|
||||
isError: true,
|
||||
};
|
||||
}
|
||||
});
|
||||
|
||||
// Boot-time session id for Plan 05-04 session_open + downstream bookkeeping.
|
||||
const bootSessionId = newSessionId();
|
||||
|
||||
// MCP initialize hook -- warm the Python session-start payload so the first
|
||||
// real turn doesn't pay the fresh-session cost synchronously. OPS-05 continuity
|
||||
// is surfaced earlier this way: by the time Claude issues tools/call, the L0
|
||||
// pinned record is already resident in the Python core's warm cache.
|
||||
//
|
||||
// Must be assigned BEFORE server.connect() — the initialized notification
|
||||
// fires immediately after the handshake and an unset handler would silently
|
||||
// discard the pre-warm trigger.
|
||||
server.oninitialized = () => {
|
||||
// Chain on bridge readiness so the session_start_payload call doesn't
|
||||
// race the socket connect. start() is idempotent and serialised; if
|
||||
// the lazy CallToolRequest path already drove start, this awaits the
|
||||
// same in-flight promise.
|
||||
bridge
|
||||
.start()
|
||||
.then(() =>
|
||||
bridge.call<SessionPayloadRaw>("session_start_payload", {
|
||||
session_id: bootSessionId,
|
||||
}),
|
||||
)
|
||||
.catch(() => null);
|
||||
};
|
||||
|
||||
// Phase 10.5 L5 + L4: proactive wake + heartbeat refresh.
|
||||
//
|
||||
// Run BEFORE server.connect so the heartbeat is registered before any
|
||||
// tools/list or tools/call request can land. ensureDaemonAlive is
|
||||
// independent of the bridge.start() call below — it only probes the
|
||||
// socket and (on darwin) invokes `launchctl kickstart` via execFile;
|
||||
// it never connects. The 045999b decoupling is preserved: tools/list
|
||||
// still responds from the static registry whether the daemon is up
|
||||
// or not, and ensureDaemonAlive's failure path (wake.signal write)
|
||||
// is silent and non-fatal.
|
||||
const lifecycle = new WrapperLifecycle();
|
||||
await lifecycle.ensureDaemonAlive();
|
||||
await lifecycle.registerHeartbeat();
|
||||
|
||||
const transport = new StdioServerTransport();
|
||||
await server.connect(transport);
|
||||
|
||||
// Fire-and-forget daemon connect AFTER the MCP transport is live.
|
||||
// - bridge.start(): socket-only connect to the singleton daemon (Phase 7.1
|
||||
// invariant — never spawns).
|
||||
// - emitSessionOpen: D5-05 HIPPEA pre-warm hint; chained off start() so
|
||||
// the cascade-LRU activation happens AFTER the daemon is known
|
||||
// reachable. If the daemon is unreachable, start() rejects with
|
||||
// DaemonUnreachableError and the .catch() suppresses the unhandled
|
||||
// rejection — the wrapper continues serving tools/list and falls back
|
||||
// to per-call lazy retry in the CallToolRequest handler.
|
||||
void bridge
|
||||
.start()
|
||||
.then(() => emitSessionOpen(bootSessionId))
|
||||
.catch(() => {
|
||||
// Silent: tools/call will surface the daemon_unreachable error
|
||||
// synchronously when the user actually invokes a tool.
|
||||
});
|
||||
|
||||
// Phase 7 (Plan 07-04 Task 2): wrapper closing must NOT kill the shared
|
||||
// daemon. disconnect() closes the socket only; the singleton survives so
|
||||
// other wrappers (other MCP hosts, sub-agents) and future boots
|
||||
// can join. This is the load-bearing semantic of the Phase 7 singleton
|
||||
// model — the pre-Phase-7 wrapper-side child-kill API has been removed.
|
||||
//
|
||||
// Phase 10.5 L4 addition: cleanupHeartbeat clears the refresh timer
|
||||
// AND deletes ~/.iai-mcp/wrappers/heartbeat-<pid>-<uuid>.json so the
|
||||
// daemon-side scanner doesn't have to rely on STALE-detection for a
|
||||
// gracefully-exiting wrapper. Cleanup is idempotent and never throws.
|
||||
const shutdown = async (): Promise<void> => {
|
||||
try {
|
||||
await lifecycle.cleanupHeartbeat();
|
||||
} catch {
|
||||
// Cleanup is best-effort; the daemon's HeartbeatScanner reaps
|
||||
// STALE / ORPHAN entries on its next tick.
|
||||
}
|
||||
bridge.disconnect();
|
||||
process.exit(0);
|
||||
};
|
||||
process.on("SIGTERM", () => {
|
||||
void shutdown();
|
||||
});
|
||||
process.on("SIGINT", () => {
|
||||
void shutdown();
|
||||
});
|
||||
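The comments above rely on `bridge.start()` being idempotent and serialised: the lazy `tools/call` path, the `oninitialized` hook, and the fire-and-forget boot connect may all call it at once. The real implementation lives in bridge.ts and is not shown here, so the following is only a minimal sketch of one way such a guard could look. `LazyBridge` and `connectFn` are hypothetical names, and the reset-on-failure behaviour is an assumption inferred from the "per-call lazy retry" comment.

```typescript
// Hypothetical sketch: `LazyBridge` and `connectFn` are illustrative names,
// not the real PythonCoreBridge API from bridge.ts.
class LazyBridge {
  private startPromise: Promise<void> | null = null;
  private connectFn: () => Promise<void>;

  constructor(connectFn: () => Promise<void>) {
    this.connectFn = connectFn;
  }

  // Not declared `async` on purpose: every concurrent caller receives the
  // very same in-flight promise object.
  start(): Promise<void> {
    if (this.startPromise === null) {
      this.startPromise = this.connectFn().catch((err: unknown) => {
        // Assumption: clear the memo on failure so a later call can retry.
        this.startPromise = null;
        throw err;
      });
    }
    return this.startPromise;
  }
}

let connectCalls = 0;
const bridgeSketch = new LazyBridge(() => {
  connectCalls += 1; // counts how often the underlying connect runs
  return Promise.resolve();
});

// Two "first callers" racing each other share one connect attempt.
const first = bridgeSketch.start();
const second = bridgeSketch.start();
```

Returning the memoised promise directly, rather than wrapping it in an `async` function, is what lets the second caller await the first caller's connect instead of opening a second socket.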
339
mcp-wrapper/src/lifecycle.ts
Normal file
|
|
@ -0,0 +1,339 @@
|
|||
// Phase 10.5 L5 + L4 — wrapper-side proactive wake + heartbeat refresh.
|
||||
//
|
||||
// Two responsibilities, both lazy and idle-CPU-near-zero:
|
||||
//
|
||||
// L5 ensureDaemonAlive:
|
||||
// Probe the daemon UNIX socket (~/.iai-mcp/.daemon.sock) at boot.
|
||||
// If reachable, return immediately — no kickstart cost, no signal.
|
||||
// If unreachable AND platform is darwin, spawn `launchctl kickstart
|
||||
// -k gui/<uid>/com.iai-mcp.daemon` via Node's `execFile` API
|
||||
// (array args, hard-coded binary path, NEVER `shell: true`).
|
||||
// If the kickstart command fails or the platform is not darwin,
|
||||
// atomic-write ~/.iai-mcp/wake.signal so the next daemon cold-
|
||||
// start consumes it via `iai_mcp.wake_handler.WakeHandler`. The
|
||||
// wrapper itself NEVER spawns the daemon Python process — that
|
||||
// remains a launchd / external-init concern (Phase 7.1 invariant).
|
||||
//
|
||||
// L4 registerHeartbeat:
|
||||
// Atomically write ~/.iai-mcp/wrappers/heartbeat-<pid>-<uuid>.json
|
||||
// (temp + rename) and start a 30-second interval timer that
|
||||
// refreshes the `last_refresh` field. The timer is `unref()`d so
|
||||
// it does NOT block Node.js shutdown — the wrapper exits cleanly
|
||||
// even if `cleanupHeartbeat` is not called (the daemon's
|
||||
// HeartbeatScanner from Phase 10.4 will eventually classify the
|
||||
// file as STALE / ORPHAN and reap it).
|
||||
//
|
||||
// Hard rules carried from CONTEXT 10.5:
|
||||
//
|
||||
// - All `child_process` calls go through `execFile` (array args).
|
||||
// NEVER the shell-interpreting `exec` variant. NEVER `shell: true`.
|
||||
// Hard-coded binary path (/bin/launchctl); only the GUI uid is
|
||||
// process-derived (`process.getuid()`).
|
||||
// - The 30-sec refresh is a single `setInterval` with `unref()`, not
|
||||
// a busy loop or per-tick spawn.
|
||||
// - macOS-first; Linux / unknown platforms write `wake.signal`
|
||||
// directly without attempting kickstart.
|
||||
// - 045999b decoupling preserved — this module is independent of the
|
||||
// bridge / tools/list path. `ensureDaemonAlive` is a probe + spawn,
|
||||
// not a connect; tools/list MUST keep responding from the static
|
||||
// wrapper registry whether the daemon is up or not.
|
||||
// - `src/utils/execFileNoThrow.ts` is referenced in CONTEXT 10.5 as a
|
||||
// pattern reference but does NOT exist in this repo. We inline the
|
||||
// pattern here: `promisify(execFile)` + try/catch. Keeps the LOC
|
||||
// budget tight and makes the security guarantee local.
|
||||
//
|
||||
// File schema (matches `iai_mcp.heartbeat_scanner._parse_heartbeat_file`):
|
||||
//
|
||||
// {
|
||||
// "pid": 12345,
|
||||
// "uuid": "01HZQ...", // crypto.randomUUID()
|
||||
// "started_at": "2026-05-02T15:00:00Z",
|
||||
// "last_refresh": "2026-05-02T15:14:30Z",
|
||||
// "wrapper_version": "1.0.0",
|
||||
// "schema_version": 1
|
||||
// }
|
||||
|
||||
import { execFile } from "node:child_process";
|
||||
import { randomUUID } from "node:crypto";
|
||||
import { mkdir, rename, unlink, writeFile } from "node:fs/promises";
|
||||
import { homedir } from "node:os";
|
||||
import { dirname, join } from "node:path";
|
||||
import { promisify } from "node:util";
|
||||
|
||||
const execFileAsync = promisify(execFile);
|
||||
|
||||
// ---------------------------------------------------------------- constants
|
||||
|
||||
/** Refresh cadence (ms). 30 s is the LOCKED contract from CONTEXT 10.4 / 10.5
|
||||
* — three missed refreshes (~90 s) trip the heartbeat scanner's STALE
|
||||
* threshold (`DEFAULT_STALE_THRESHOLD_SEC` in `heartbeat_scanner.py`). */
|
||||
export const HEARTBEAT_REFRESH_INTERVAL_MS = 30_000;
|
||||
|
||||
/** Wrapper schema version. Bump only on a breaking change to the heartbeat
|
||||
* file shape. Phase 10.4 reader currently treats `schema_version` as
|
||||
* informational; future versions may gate field-presence checks on it. */
|
||||
export const HEARTBEAT_SCHEMA_VERSION = 1;
|
||||
|
||||
/** Wrapper version string written into each heartbeat file. Tracks the
|
||||
* `mcp-wrapper/package.json` version semantically; not auto-derived to
|
||||
* keep this module dependency-free at runtime. */
|
||||
export const WRAPPER_VERSION = "1.0.0";
|
||||
|
||||
/** Hard-coded launchctl binary path. Argv-only invocation — no shell
|
||||
* interpretation, no PATH lookup, no user-input interpolation. */
|
||||
const LAUNCHCTL_BIN = "/bin/launchctl";
|
||||
|
||||
/** Hard-coded launchd label for the IAI-MCP daemon. Matches the
|
||||
* `com.iai-mcp.daemon` LaunchAgent shipped by the project. */
|
||||
const LAUNCHD_LABEL = "com.iai-mcp.daemon";
|
||||
|
||||
/** Subprocess timeout (ms) for the kickstart call. Covers the worst-case
|
||||
* `launchctl kickstart` round-trip on a heavily loaded box. index.ts awaits
* ensureDaemonAlive BEFORE server.connect, so this value also bounds how
* long a hung kickstart can delay the MCP handshake. */
|
||||
const KICKSTART_TIMEOUT_MS = 5_000;
|
||||
|
||||
// ---------------------------------------------------------------- types
|
||||
|
||||
interface HeartbeatPayload {
|
||||
pid: number;
|
||||
uuid: string;
|
||||
started_at: string;
|
||||
last_refresh: string;
|
||||
wrapper_version: string;
|
||||
schema_version: number;
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------- paths
|
||||
|
||||
/** Compute `~/.iai-mcp/.daemon.sock`. Mirrors the daemon-side socket
|
||||
* path constant in `iai_mcp.concurrency`. */
|
||||
export function defaultSocketPath(): string {
|
||||
return join(homedir(), ".iai-mcp", ".daemon.sock");
|
||||
}
|
||||
|
||||
/** Compute `~/.iai-mcp/wake.signal`. Mirrors the path the daemon-side
|
||||
* `WakeHandler` consumes on cold-start. */
|
||||
export function defaultWakeSignalPath(): string {
|
||||
return join(homedir(), ".iai-mcp", "wake.signal");
|
||||
}
|
||||
|
||||
/** Compute `~/.iai-mcp/wrappers/heartbeat-<pid>-<uuid>.json`. Matches
|
||||
* the filename glob in `iai_mcp.heartbeat_scanner`. */
|
||||
export function defaultHeartbeatPath(pid: number, uuid: string): string {
|
||||
return join(homedir(), ".iai-mcp", "wrappers", `heartbeat-${pid}-${uuid}.json`);
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------- lifecycle
|
||||
|
||||
/** Constructor options. All fields optional; defaults derive from
|
||||
* `process` and `os.homedir()`. Dependency injection is here so tests
|
||||
* can supply a tmp dir without monkey-patching `homedir`. */
|
||||
export interface WrapperLifecycleOptions {
|
||||
pid?: number;
|
||||
uuid?: string;
|
||||
socketPath?: string;
|
||||
wakeSignalPath?: string;
|
||||
heartbeatPath?: string;
|
||||
/** Override the platform string. Defaults to `process.platform`. */
|
||||
platform?: NodeJS.Platform;
|
||||
/** Probe the daemon socket. Defaults to a real `net.createConnection`
|
||||
* attempt with a short timeout. Tests inject a mock. */
|
||||
socketReachable?: () => Promise<boolean>;
|
||||
/** Spawn `launchctl kickstart`. Defaults to the real `execFile` call.
|
||||
* Tests inject a mock that resolves or rejects deterministically. */
|
||||
spawnKickstart?: () => Promise<void>;
|
||||
/** Heartbeat refresh interval (ms). Defaults to
|
||||
* `HEARTBEAT_REFRESH_INTERVAL_MS`. Tests pass a smaller value. */
|
||||
refreshIntervalMs?: number;
|
||||
}
|
||||
|
||||
export class WrapperLifecycle {
|
||||
private readonly pid: number;
|
||||
private readonly uuid: string;
|
||||
private readonly socketPath: string;
|
||||
private readonly wakeSignalPath: string;
|
||||
private readonly heartbeatPath: string;
|
||||
private readonly platform: NodeJS.Platform;
|
||||
private readonly socketReachable: () => Promise<boolean>;
|
||||
private readonly spawnKickstart: () => Promise<void>;
|
||||
private readonly refreshIntervalMs: number;
|
||||
|
||||
private readonly startedAt: string;
|
||||
private timer: NodeJS.Timeout | null = null;
|
||||
|
||||
constructor(opts: WrapperLifecycleOptions = {}) {
|
||||
this.pid = opts.pid ?? process.pid;
|
||||
this.uuid = opts.uuid ?? randomUUID();
|
||||
this.socketPath = opts.socketPath ?? defaultSocketPath();
|
||||
this.wakeSignalPath = opts.wakeSignalPath ?? defaultWakeSignalPath();
|
||||
this.heartbeatPath =
|
||||
opts.heartbeatPath ?? defaultHeartbeatPath(this.pid, this.uuid);
|
||||
this.platform = opts.platform ?? process.platform;
|
||||
this.socketReachable = opts.socketReachable ?? defaultSocketReachable(this.socketPath);
|
||||
this.spawnKickstart = opts.spawnKickstart ?? defaultSpawnKickstart();
|
||||
this.refreshIntervalMs = opts.refreshIntervalMs ?? HEARTBEAT_REFRESH_INTERVAL_MS;
|
||||
this.startedAt = isoNow();
|
||||
}
|
||||
|
||||
/** L5: probe daemon socket; if unreachable, kickstart on darwin or
|
||||
* write `wake.signal` elsewhere. Never throws — the worst case is a
|
||||
* silent fallback to the signal file, which the daemon will pick up
|
||||
* on its next cold start. */
|
||||
async ensureDaemonAlive(): Promise<void> {
|
||||
let alive = false;
|
||||
try {
|
||||
alive = await this.socketReachable();
|
||||
} catch {
|
||||
alive = false;
|
||||
}
|
||||
if (alive) {
|
||||
return;
|
||||
}
|
||||
if (this.platform === "darwin") {
|
||||
try {
|
||||
await this.spawnKickstart();
|
||||
return;
|
||||
} catch {
|
||||
// Kickstart failed (launchd label missing, permission error,
|
||||
// timeout). Fall through to the wake.signal fallback so the
|
||||
// daemon's next cold-start path still consumes the request.
|
||||
}
|
||||
}
|
||||
// Non-darwin OR darwin-with-failed-kickstart: write the cross-
|
||||
// platform marker so a future daemon boot picks it up.
|
||||
try {
|
||||
await this.writeWakeSignal();
|
||||
} catch {
|
||||
// Even the wake.signal write failed (FS full, permission). Nothing
|
||||
// we can do safely here; do NOT escalate — the wrapper still has
|
||||
// useful work to do (tools/list responds from the static registry).
|
||||
}
|
||||
}
|
||||
|
||||
/** L4: write the heartbeat file and start the 30-sec refresh timer.
|
||||
* Called once at wrapper boot. Idempotent on the timer side: a second
|
||||
* call clears any prior timer before installing a new one. */
|
||||
async registerHeartbeat(): Promise<void> {
|
||||
await this.writeHeartbeat();
|
||||
if (this.timer !== null) {
|
||||
clearInterval(this.timer);
|
||||
}
|
||||
const timer = setInterval(() => {
|
||||
void this.writeHeartbeat().catch(() => {
|
||||
// Refresh failure is non-fatal: the daemon will classify the
|
||||
// stale file as STALE on the next scan and recover. We do NOT
|
||||
// log here to keep the idle-CPU profile near zero.
|
||||
});
|
||||
}, this.refreshIntervalMs);
|
||||
timer.unref();
|
||||
this.timer = timer;
|
||||
}
|
||||
|
||||
/** Graceful exit: stop the refresh timer and delete the heartbeat
|
||||
* file. Safe to call multiple times. Safe to call without prior
|
||||
* `registerHeartbeat` (no-ops). */
|
||||
async cleanupHeartbeat(): Promise<void> {
|
||||
if (this.timer !== null) {
|
||||
clearInterval(this.timer);
|
||||
this.timer = null;
|
||||
}
|
||||
try {
|
||||
await unlink(this.heartbeatPath);
|
||||
} catch {
|
||||
// Already gone (concurrent daemon-side cleanup of a STALE entry,
|
||||
// or never written). Idempotent — swallow.
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------- internals (visible-for-test)
|
||||
|
||||
/** Atomically write the heartbeat file: tmp + rename. The tmp
|
||||
* filename includes the wrapper's UUID so concurrent wrappers do
|
||||
* NOT collide on the staging path even if they share a working
|
||||
* directory. */
|
||||
private async writeHeartbeat(): Promise<void> {
|
||||
const payload: HeartbeatPayload = {
|
||||
pid: this.pid,
|
||||
uuid: this.uuid,
|
||||
started_at: this.startedAt,
|
||||
last_refresh: isoNow(),
|
||||
wrapper_version: WRAPPER_VERSION,
|
||||
schema_version: HEARTBEAT_SCHEMA_VERSION,
|
||||
};
|
||||
const dir = dirname(this.heartbeatPath);
|
||||
await mkdir(dir, { recursive: true });
|
||||
const tmp = `${this.heartbeatPath}.${this.uuid}.tmp`;
|
||||
await writeFile(tmp, JSON.stringify(payload), { encoding: "utf-8" });
|
||||
await rename(tmp, this.heartbeatPath);
|
||||
}
|
||||
|
||||
/** Atomically write `wake.signal`: tmp + rename. Per-uuid tmp suffix
|
||||
* avoids cross-wrapper staging collisions on the same machine. */
|
||||
private async writeWakeSignal(): Promise<void> {
|
||||
const dir = dirname(this.wakeSignalPath);
|
||||
await mkdir(dir, { recursive: true });
|
||||
const payload = JSON.stringify({
|
||||
requested_at: isoNow(),
|
||||
wrapper_pid: this.pid,
|
||||
wrapper_uuid: this.uuid,
|
||||
});
|
||||
const tmp = `${this.wakeSignalPath}.${this.uuid}.tmp`;
|
||||
await writeFile(tmp, payload, { encoding: "utf-8" });
|
||||
await rename(tmp, this.wakeSignalPath);
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------- defaults
|
||||
|
||||
function isoNow(): string {
|
||||
// ISO-8601 with trailing Z — matches the wire format the daemon-side
|
||||
// `_parse_heartbeat_file` accepts (replaces "Z" with "+00:00" before
|
||||
// `datetime.fromisoformat`).
|
||||
return new Date().toISOString();
|
||||
}
|
||||
|
||||
/** Default socket-probe: open a UNIX-domain socket connection to the
|
||||
* daemon path with a short timeout. Resolves true on `connect`,
|
||||
* false on `error` or timeout. */
|
||||
function defaultSocketReachable(socketPath: string): () => Promise<boolean> {
|
||||
return async () => {
|
||||
const { createConnection } = await import("node:net");
|
||||
return await new Promise<boolean>((resolve) => {
|
||||
let settled = false;
|
||||
const settle = (v: boolean): void => {
|
||||
if (settled) return;
|
||||
settled = true;
|
||||
try {
|
||||
socket.destroy();
|
||||
} catch {
|
||||
// socket already destroyed by the loser of the connect/timeout
|
||||
// race — ignore.
|
||||
}
|
||||
resolve(v);
|
||||
};
|
||||
const socket = createConnection({ path: socketPath });
|
||||
socket.setTimeout(1_000);
|
||||
socket.once("connect", () => settle(true));
|
||||
socket.once("error", () => settle(false));
|
||||
socket.once("timeout", () => settle(false));
|
||||
});
|
||||
};
|
||||
}
|
||||
|
||||
/** Default kickstart spawn: `execFile` with array args, hard-coded
|
||||
* binary path, no shell. The GUI uid is process-derived (`getuid()`)
|
||||
* so the same wrapper works for any signed-in user. */
|
||||
function defaultSpawnKickstart(): () => Promise<void> {
|
||||
return async () => {
|
||||
// `process.getuid` is undefined on Windows builds, so the `typeof` guard
// below falls back to 0 there. That branch is unreachable in practice:
// the ensureDaemonAlive caller gates on platform === "darwin".
|
||||
const uid = typeof process.getuid === "function" ? process.getuid() : 0;
|
||||
const args = ["kickstart", "-k", `gui/${uid}/${LAUNCHD_LABEL}`];
|
||||
await execFileAsync(LAUNCHCTL_BIN, args, {
|
||||
timeout: KICKSTART_TIMEOUT_MS,
|
||||
// No `shell` option — argv-only invocation, no shell interpretation.
|
||||
});
|
||||
};
|
||||
}
|
||||
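The constants above state that three missed 30-second refreshes (about 90 s) trip the scanner's STALE threshold. The actual check lives in the Python `heartbeat_scanner` module and is not shown in this file, so the arithmetic below is only a sketch. `isHeartbeatStale`, the strict `>` comparison, and the exact 90 s value are assumptions made for illustration.

```typescript
// Hypothetical sketch of the staleness arithmetic; the real check is
// implemented daemon-side in heartbeat_scanner.py.
const REFRESH_INTERVAL_MS = 30_000;
const MISSED_REFRESHES_BEFORE_STALE = 3; // assumption: ~90 s threshold

function isHeartbeatStale(lastRefreshIso: string, nowMs: number): boolean {
  const lastMs = Date.parse(lastRefreshIso);
  // Assumption: an unparseable timestamp is treated as stale.
  if (Number.isNaN(lastMs)) return true;
  return nowMs - lastMs > REFRESH_INTERVAL_MS * MISSED_REFRESHES_BEFORE_STALE;
}

const nowExample = Date.parse("2026-05-02T15:16:00Z");
// Refreshed 30 s ago: one interval old, still within the threshold.
const fresh = isHeartbeatStale("2026-05-02T15:15:30Z", nowExample);
// Refreshed 120 s ago: more than three intervals have been missed.
const stale = isHeartbeatStale("2026-05-02T15:14:00Z", nowExample);
```

Timestamps use the same trailing-Z ISO-8601 form that `isoNow()` writes, so `Date.parse` reads them as UTC.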
56
mcp-wrapper/src/registry.ts
Normal file
|
|
@ -0,0 +1,56 @@
|
|||
// Lazy tool registry + context-editing config (TOK-02, TOK-05).
|
||||
//
|
||||
// TOK-02 ToolSearch lazy-load: every tool in TOOL_NAMES is currently hot (the
// set is small enough to keep resident). The `loadColdTool` hook exists as an
// extension point: tools that should not stay resident can register here later
// and be looked up lazily by the MCP host's ToolSearch extension.
|
||||
//
|
||||
// TOK-05 context editing: we advertise `clear_tool_uses_20250919` with a
|
||||
// 30k-token trigger. When Claude's context crosses 30k tokens the Anthropic
|
||||
// API will drop earlier tool_use / tool_result messages, freeing headroom
|
||||
// for continued reasoning without reloading the full session prefix.
|
||||
//
|
||||
// Exact shape per Anthropic's context-management docs -- these strings are
|
||||
// consumed verbatim by the API.
|
||||
|
||||
import { TOOL_NAMES, toolSchemas, type ToolName } from "./tools.js";
|
||||
|
||||
// Hot tools: every entry in TOOL_NAMES is always resident (12 today: the D-12
// Phase-1 surface plus the Plan 02-04 and Plan 03 additions).
|
||||
// Iteration order matches TOOL_NAMES so tools/list is deterministic.
|
||||
export const HOT_TOOLS: readonly ToolName[] = [...TOOL_NAMES] as const;
|
||||
|
||||
/** TOK-05 Anthropic context-editing config -- exact shape consumed by the API.
|
||||
*
|
||||
* `clear_tool_uses_20250919` is the dated context-edit strategy Anthropic
|
||||
* released on 2025-09-19; the trigger pairs `type: "input_tokens"` with a
|
||||
* numeric threshold that fires the edit. D-10 puts the threshold at 30k
|
||||
* tokens -- empirically enough headroom to preserve ~8-10 turns of tool
|
||||
* exchange before trimming. */
|
||||
export const CONTEXT_EDITING_CONFIG = {
|
||||
type: "clear_tool_uses_20250919" as const,
|
||||
trigger: {
|
||||
type: "input_tokens" as const,
|
||||
value: 30_000,
|
||||
},
|
||||
} as const;
|
||||
|
||||
/** Return the full tool-schema objects for the hot tools.
|
||||
*
|
||||
* MCP `tools/list` handler calls this directly. Kept as a function rather
|
||||
* than a const array so future versions can mutate the returned shape
|
||||
* (e.g., swap in per-user personalised descriptions) without changing the
|
||||
* call site. */
|
||||
export function listHotTools() {
|
||||
return HOT_TOOLS.map((n) => toolSchemas[n]);
|
||||
}
|
||||
|
||||
/** Phase-2 hook: lazy-load a tool that isn't in HOT_TOOLS.
|
||||
*
|
||||
* Phase 1 always returns null -- the MCP host's ToolSearch extension will
|
||||
* fall back to HOT_TOOLS when this returns null, which is exactly what we
|
||||
* want. Phase 2 populates this with a dynamic import of the new tool's
|
||||
* schema module. */
|
||||
export async function loadColdTool(_name: string): Promise<unknown | null> {
|
||||
return null;
|
||||
}
|
||||
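`CONTEXT_EDITING_CONFIG` is only a bare edit object; the file does not show where a host would put it. As a hedged sketch, the helper below wraps the edit in a `context_management.edits` array, which is the request shape I understand Anthropic's context-editing beta to expect. Treat that field name as an assumption to verify against the current API docs. `buildContextManagement` and the `_SKETCH` copy of the config are hypothetical names used only here.

```typescript
// Copy of the config from registry.ts so the sketch is self-contained.
const CONTEXT_EDITING_CONFIG_SKETCH = {
  type: "clear_tool_uses_20250919",
  trigger: { type: "input_tokens", value: 30_000 },
} as const;

// Assumption: edits are passed under `context_management.edits` in the
// request body. Check the API reference before relying on this shape.
function buildContextManagement(edit: typeof CONTEXT_EDITING_CONFIG_SKETCH) {
  return { context_management: { edits: [edit] } };
}

// Fragment a host could merge into its request body.
const requestFragment = buildContextManagement(CONTEXT_EDITING_CONFIG_SKETCH);
```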
367
mcp-wrapper/src/tools.ts
Normal file
|
|
@ -0,0 +1,367 @@
|
|||
// Phase-1 (D-12) + Plan 02-04 (MCP-05/07/08) + Plan 03 (CONN-05/07 + AUTIST-13) tools.
|
||||
//
|
||||
// Tool shapes are JSON-schema dicts consumable by the MCP SDK's ListTools
|
||||
// handler. Descriptions are written for Claude's tool-discovery heuristics
|
||||
// (concise, task-oriented, reference the autistic-kernel defaults where they
|
||||
// affect behaviour).
|
||||
//
|
||||
// Plan 02-04 adds 3 user-introspection tools:
|
||||
// - curiosity_pending (MCP-07): list pending curiosity questions
|
||||
// - schema_list (MCP-08): list induced schemas
|
||||
// - events_query (MCP-05): user-visible events audit
|
||||
//
|
||||
// Plan 03 adds 3 scientific-depth tools:
|
||||
// - memory_recall_structural (CONN-05): TEM role->filler structural recall
|
||||
// - topology (CONN-07): Ashby sigma diagnostic snapshot
|
||||
// - camouflaging_status (AUTIST-13): ecological self-regulation status
|
||||
|
||||
import type { PythonCoreBridge } from "./bridge.js";
|
||||
|
||||
export const TOOL_NAMES = [
|
||||
"memory_recall",
|
||||
"memory_recall_structural",
|
||||
"memory_reinforce",
|
||||
"memory_contradict",
|
||||
"memory_capture",
|
||||
"memory_consolidate",
|
||||
"profile_get_set",
|
||||
"curiosity_pending",
|
||||
"schema_list",
|
||||
"events_query",
|
||||
"topology",
|
||||
"camouflaging_status",
|
||||
] as const;
|
||||
|
||||
export type ToolName = (typeof TOOL_NAMES)[number];
|
||||
|
||||
interface ToolSchema {
|
||||
name: string;
|
||||
description: string;
|
||||
inputSchema: Record<string, unknown>;
|
||||
}
|
||||
|
||||
export const toolSchemas: Record<ToolName, ToolSchema> = {
|
||||
memory_recall: {
|
||||
name: "memory_recall",
|
||||
description:
|
||||
"Recall verbatim memories matching cue. Returns hits + anti_hits.",
|
||||
inputSchema: {
|
||||
type: "object",
|
||||
properties: {
|
||||
cue: {
|
||||
type: "string",
|
||||
description: "Natural-language query to match against stored memories.",
|
||||
},
|
||||
budget_tokens: {
|
||||
type: "integer",
|
||||
description: "Soft token budget for response (default 1500).",
|
||||
default: 1500,
|
||||
},
|
||||
session_id: {
|
||||
type: "string",
|
||||
description:
|
||||
"Current session id; gets written into every recalled record's provenance (MEM-05).",
|
||||
},
|
||||
cue_embedding: {
|
||||
type: "array",
|
||||
items: { type: "number" },
|
||||
description:
|
||||
"Optional pre-computed embedding vector for the cue " +
|
||||
"(EMBED_DIM=384 floats; bge-small-en-v1.5). " +
|
||||
"When omitted, the daemon embeds the cue server-side. " +
|
||||
"Used by memory_contradict and tests that need byte-stable embeddings.",
|
||||
},
|
||||
language: {
|
||||
type: "string",
|
||||
description:
|
||||
"Optional ISO-639-1 language hint for the sleep-suggestion path " +
|
||||
"(8 supported: en/ru/ja/ar/de/fr/es/zh). Defaults to 'en' " +
|
||||
"when omitted. Hot-path retrieval is language-agnostic; this " +
|
||||
"key only affects the sleep-suggestion regex pre-screen.",
|
||||
},
|
||||
},
|
||||
required: ["cue"],
|
||||
},
|
||||
},
|
||||
memory_reinforce: {
|
||||
name: "memory_reinforce",
|
||||
description:
|
||||
"Boost Hebbian edges among co-retrieved record ids.",
|
||||
inputSchema: {
|
||||
type: "object",
|
||||
properties: {
|
||||
ids: {
|
||||
type: "array",
|
||||
items: { type: "string", format: "uuid" },
|
||||
description: "Record UUIDs that were co-retrieved in the current context.",
|
||||
},
|
||||
},
|
||||
required: ["ids"],
|
||||
},
|
||||
},
|
||||
memory_contradict: {
|
||||
name: "memory_contradict",
|
||||
description:
|
||||
"Mark a record contradicted; new fact stored as new record.",
|
||||
inputSchema: {
|
||||
type: "object",
|
||||
properties: {
|
||||
id: {
|
||||
type: "string",
|
||||
format: "uuid",
|
||||
description: "UUID of the record being contradicted.",
|
||||
},
|
||||
new_fact: {
|
||||
type: "string",
|
||||
description: "The updated verbatim fact. Stored as a new record.",
|
||||
},
|
||||
cue_embedding: {
|
||||
type: "array",
|
||||
items: { type: "number" },
|
||||
description:
|
||||
"Optional pre-computed embedding vector for the contradicting " +
|
||||
"fact (EMBED_DIM=384 floats; bge-small-en-v1.5). When omitted, " +
|
||||
"the daemon embeds new_fact server-side.",
|
||||
},
|
||||
},
|
||||
required: ["id", "new_fact"],
|
||||
},
|
||||
},
|
||||
memory_capture: {
|
||||
name: "memory_capture",
|
||||
description:
|
||||
"Capture a verbatim turn. Auto-dedups at cos>=0.95 (reinforces). " +
|
||||
"Use for corrections + load-bearing decisions.",
|
||||
inputSchema: {
|
||||
type: "object",
|
||||
properties: {
|
||||
text: {
|
||||
type: "string",
|
||||
description:
|
||||
"Verbatim text to capture (user utterance, Claude decision, or observation). " +
|
||||
"Min 12 chars, max 8000 (longer is truncated).",
|
||||
},
|
||||
cue: {
|
||||
type: "string",
|
||||
description:
|
||||
"Short natural-language cue used for embedding + dedup lookup. " +
|
||||
"If empty, `text` itself is embedded.",
|
||||
},
|
||||
tier: {
|
||||
type: "string",
|
||||
enum: ["working", "episodic", "semantic", "procedural", "parametric"],
|
||||
default: "episodic",
|
||||
description:
|
||||
"Memory tier. Default 'episodic' (verbatim user utterances). " +
|
||||
"Use 'semantic' for induced summaries, 'procedural' for learned behaviour notes.",
|
||||
},
|
||||
session_id: {
|
||||
type: "string",
|
||||
description: "Current session id for provenance (MEM-05).",
|
||||
},
|
||||
role: {
|
||||
type: "string",
|
||||
enum: ["user", "assistant", "system"],
|
||||
default: "user",
|
||||
description: "Who produced this turn — tags the record for filtering.",
|
||||
},
|
||||
},
|
||||
required: ["text"],
|
||||
},
|
||||
},
|
||||
memory_consolidate: {
|
||||
name: "memory_consolidate",
|
||||
description:
|
||||
"Trigger memory consolidation.",
|
||||
inputSchema: {
|
||||
type: "object",
|
||||
properties: {
|
||||
session_id: {
|
||||
type: "string",
|
||||
description:
|
||||
"Optional session id used for provenance tagging on the " +
|
||||
"consolidate event. Defaults to '-' when omitted.",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
profile_get_set: {
|
||||
name: "profile_get_set",
|
||||
description:
|
||||
"Read or write a profile knob (11 sealed: 10 AUTIST + wake_depth). operation: get|set.",
|
||||
inputSchema: {
|
||||
type: "object",
|
||||
properties: {
|
||||
operation: {
|
||||
type: "string",
|
||||
enum: ["get", "set"],
|
||||
description: "Whether to read or write a knob.",
|
||||
},
|
||||
knob: {
|
||||
type: "string",
|
||||
description: "Knob name. Omit on 'get' to retrieve all live + deferred knobs.",
|
||||
},
|
||||
value: {
|
||||
description: "New value when operation='set'. Any JSON-serialisable type.",
|
||||
},
|
||||
},
|
||||
required: ["operation"],
|
||||
},
|
||||
},
|
||||
curiosity_pending: {
|
||||
name: "curiosity_pending",
|
||||
description:
|
||||
"List pending curiosity questions. Optional session_id filter.",
|
||||
inputSchema: {
|
||||
type: "object",
|
||||
properties: {
|
||||
session_id: {
|
||||
type: "string",
|
||||
description: "Only return questions from this session.",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
schema_list: {
|
||||
name: "schema_list",
|
||||
description:
|
||||
"List induced schemas. Optional domain + confidence_min filters.",
|
||||
inputSchema: {
|
||||
type: "object",
|
||||
properties: {
|
||||
domain: {
|
||||
type: "string",
|
||||
description: "Only return schemas tagged with this domain (e.g. 'coding').",
|
||||
},
|
||||
confidence_min: {
|
||||
type: "number",
|
||||
description: "Minimum parsed confidence (0.0-1.0). Default 0.0.",
|
||||
default: 0.0,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
events_query: {
|
||||
name: "events_query",
|
||||
description:
|
||||
"Query user-visible events by kind, since, severity, limit.",
|
||||
inputSchema: {
|
||||
type: "object",
|
||||
properties: {
|
||||
kind: {
|
||||
type: "string",
|
||||
description:
"Event kind. Must be one of the whitelisted user-visible event kinds.",
|
||||
},
|
||||
since: {
|
||||
type: "string",
|
||||
description: "ISO-8601 timestamp; only events at or after this are returned.",
|
||||
},
|
||||
severity: {
|
||||
type: "string",
|
||||
enum: ["info", "warning", "critical"],
|
||||
description: "Optional severity filter.",
|
||||
},
|
||||
limit: {
|
||||
type: "integer",
|
||||
description: "Maximum events returned (default 100, capped at 1000).",
|
||||
default: 100,
|
||||
},
|
||||
},
|
||||
required: ["kind"],
|
||||
},
|
||||
},
|
||||
memory_recall_structural: {
|
||||
name: "memory_recall_structural",
|
||||
description:
"Structural recall via role-filler bindings (TEM). O(N) scan, capped by max_records.",
|
||||
inputSchema: {
|
||||
type: "object",
|
||||
properties: {
|
||||
structure_query: {
|
||||
type: "object",
|
||||
description:
|
||||
"Optional role->filler map, e.g. {\"agent\": \"Alice\"}. Each value is hashed to a filler hypervector. When omitted or empty, query HV is zero-filled and every row with structure_hv is scored (expensive at large N).",
|
||||
additionalProperties: { type: "string" },
|
||||
},
|
||||
budget_tokens: {
|
||||
type: "integer",
|
||||
description: "Soft token budget for response (default 2000).",
|
||||
default: 2000,
|
||||
},
|
||||
max_records: {
|
||||
type: "integer",
|
||||
description:
|
||||
"Hard cap on records scanned after fetch (default 5000, max 50000). Prevents accidental full-corpus scans from `{}`.",
|
||||
default: 5000,
|
||||
},
|
||||
},
|
||||
required: [],
|
||||
},
|
||||
},
|
||||
topology: {
|
||||
name: "topology",
|
||||
description:
|
||||
"Topology snapshot: N, C, L, sigma, community_count, regime.",
|
||||
inputSchema: { type: "object", properties: {} },
|
||||
},
|
||||
camouflaging_status: {
|
||||
name: "camouflaging_status",
|
||||
description:
|
||||
"Camouflaging detection status; window_size weekly points.",
|
||||
inputSchema: {
|
||||
type: "object",
|
||||
properties: {
|
||||
window_size: {
|
||||
type: "integer",
|
||||
description: "Weekly points in the sliding window (default 5).",
|
||||
default: 5,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
export async function invokeTool(
|
||||
bridge: PythonCoreBridge,
|
||||
name: ToolName,
|
||||
args: Record<string, unknown>,
|
||||
): Promise<unknown> {
|
||||
switch (name) {
|
||||
case "memory_recall":
|
||||
return bridge.call("memory_recall", args);
|
||||
case "memory_reinforce":
|
||||
return bridge.call("memory_reinforce", args);
|
||||
case "memory_contradict":
|
||||
return bridge.call("memory_contradict", args);
|
||||
case "memory_capture":
|
||||
return bridge.call("memory_capture", args);
|
||||
case "memory_consolidate":
|
||||
return bridge.call("memory_consolidate", args);
|
||||
case "profile_get_set": {
|
||||
const op = args.operation as string;
|
||||
if (op === "get") {
|
||||
return bridge.call("profile_get", { knob: args.knob ?? null });
|
||||
}
|
||||
if (op === "set") {
|
||||
return bridge.call("profile_set", {
|
||||
knob: args.knob,
|
||||
value: args.value,
|
||||
});
|
||||
}
|
||||
throw new Error(`unknown operation ${op}`);
|
||||
}
|
||||
case "curiosity_pending":
|
||||
return bridge.call("curiosity_pending", args);
|
||||
case "schema_list":
|
||||
return bridge.call("schema_list", args);
|
||||
case "events_query":
|
||||
return bridge.call("events_query", args);
|
||||
case "memory_recall_structural":
|
||||
return bridge.call("memory_recall_structural", args);
|
||||
case "topology":
|
||||
return bridge.call("topology", args);
|
||||
case "camouflaging_status":
|
||||
return bridge.call("camouflaging_status", args);
|
||||
}
|
||||
}
|
||||
339
mcp-wrapper/test/lifecycle.test.ts
Normal file
|
|
@ -0,0 +1,339 @@
|
|||
// Phase 10.5 — tests for `WrapperLifecycle`.
|
||||
//
|
||||
// Eight-test matrix from CONTEXT 10.5:
|
||||
//
|
||||
// 1. ensureDaemonAlive: socket reachable -> NO subprocess invoked.
|
||||
// 2. ensureDaemonAlive: socket unreachable + darwin -> kickstart called.
|
||||
// 3. ensureDaemonAlive: kickstart throws -> falls back to wake.signal.
|
||||
// 4. ensureDaemonAlive: non-macos -> wake.signal written, no subprocess.
|
||||
// 5. registerHeartbeat: file exists with correct schema.
|
||||
// 6. heartbeat refresh: small interval -> last_refresh updates.
|
||||
// 7. cleanupHeartbeat: file gone, timer cleared.
|
||||
// 8. security: source has no `shell: true` and no shell-interpreting
|
||||
// subprocess variant in mcp-wrapper/src/.
|
||||
//
|
||||
// Test runner: Node's built-in `node:test` (zero new dep — Node 22 has
|
||||
// it natively) loaded via the existing `tsx` dev-dep so `.ts` files
|
||||
// run without a build step. Assertions: `node:assert/strict`.
|
||||
|
||||
import { describe, it } from "node:test";
|
||||
import { strict as assert } from "node:assert";
|
||||
import { mkdtemp, readFile, readdir, rm, stat } from "node:fs/promises";
|
||||
import { tmpdir } from "node:os";
|
||||
import { join } from "node:path";
|
||||
import { fileURLToPath } from "node:url";
|
||||
|
||||
import { WrapperLifecycle } from "../src/lifecycle.js";
|
||||
|
||||
// Tmp-dir helper. node:test isolates per-file but not per-`it`, so
|
||||
// every test allocates its own dir.
|
||||
async function makeTmp(prefix: string): Promise<string> {
|
||||
return await mkdtemp(join(tmpdir(), `iai-mcp-lifecycle-${prefix}-`));
|
||||
}
|
||||
|
||||
async function cleanupTmp(dir: string): Promise<void> {
|
||||
await rm(dir, { recursive: true, force: true });
|
||||
}
|
||||
|
||||
// Sleep helper for refresh-timer verification. Node's setInterval runs in
|
||||
// real time, so tests use a small interval (10 ms) and sleep for several ticks.
|
||||
function sleep(ms: number): Promise<void> {
|
||||
return new Promise((resolve) => setTimeout(resolve, ms));
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------- ensureDaemonAlive
|
||||
|
||||
describe("WrapperLifecycle.ensureDaemonAlive", () => {
|
||||
it("does NOT invoke subprocess when socket is reachable", async () => {
|
||||
const tmp = await makeTmp("alive");
|
||||
try {
|
||||
let kickstarts = 0;
|
||||
const lifecycle = new WrapperLifecycle({
|
||||
socketPath: join(tmp, "daemon.sock"),
|
||||
wakeSignalPath: join(tmp, "wake.signal"),
|
||||
heartbeatPath: join(tmp, "wrappers", "heartbeat-1-x.json"),
|
||||
platform: "darwin",
|
||||
socketReachable: async () => true,
|
||||
spawnKickstart: async () => {
|
||||
kickstarts += 1;
|
||||
},
|
||||
});
|
||||
await lifecycle.ensureDaemonAlive();
|
||||
assert.equal(kickstarts, 0, "kickstart must not be invoked when socket is alive");
|
||||
// wake.signal must NOT be written when daemon is reachable.
|
||||
await assert.rejects(stat(join(tmp, "wake.signal")));
|
||||
} finally {
|
||||
await cleanupTmp(tmp);
|
||||
}
|
||||
});
|
||||
|
||||
it("invokes launchctl kickstart on darwin when socket is unreachable", async () => {
|
||||
const tmp = await makeTmp("kickstart");
|
||||
try {
|
||||
let kickstarts = 0;
|
||||
let signalWritten = false;
|
||||
const lifecycle = new WrapperLifecycle({
|
||||
socketPath: join(tmp, "daemon.sock"),
|
||||
wakeSignalPath: join(tmp, "wake.signal"),
|
||||
heartbeatPath: join(tmp, "wrappers", "heartbeat-1-x.json"),
|
||||
platform: "darwin",
|
||||
socketReachable: async () => false,
|
||||
spawnKickstart: async () => {
|
||||
kickstarts += 1;
|
||||
},
|
||||
});
|
||||
await lifecycle.ensureDaemonAlive();
|
||||
assert.equal(kickstarts, 1, "kickstart must be invoked exactly once on darwin");
|
||||
try {
|
||||
await stat(join(tmp, "wake.signal"));
|
||||
signalWritten = true;
|
||||
} catch {
|
||||
signalWritten = false;
|
||||
}
|
||||
assert.equal(
|
||||
signalWritten,
|
||||
false,
|
||||
"wake.signal must NOT be written on successful kickstart",
|
||||
);
|
||||
} finally {
|
||||
await cleanupTmp(tmp);
|
||||
}
|
||||
});
|
||||
|
||||
it("falls back to wake.signal when kickstart fails on darwin", async () => {
|
||||
const tmp = await makeTmp("fallback");
|
||||
try {
|
||||
const lifecycle = new WrapperLifecycle({
|
||||
socketPath: join(tmp, "daemon.sock"),
|
||||
wakeSignalPath: join(tmp, "wake.signal"),
|
||||
heartbeatPath: join(tmp, "wrappers", "heartbeat-1-x.json"),
|
||||
platform: "darwin",
|
||||
socketReachable: async () => false,
|
||||
spawnKickstart: async () => {
|
||||
throw new Error("kickstart simulated failure");
|
||||
},
|
||||
});
|
||||
await lifecycle.ensureDaemonAlive();
|
||||
const sigStat = await stat(join(tmp, "wake.signal"));
|
||||
assert.ok(sigStat.isFile(), "wake.signal must exist after kickstart failure");
|
||||
const raw = await readFile(join(tmp, "wake.signal"), "utf-8");
|
||||
const parsed = JSON.parse(raw);
|
||||
assert.ok(typeof parsed.requested_at === "string");
|
||||
assert.ok(typeof parsed.wrapper_pid === "number");
|
||||
assert.ok(typeof parsed.wrapper_uuid === "string");
|
||||
} finally {
|
||||
await cleanupTmp(tmp);
|
||||
}
|
||||
});
|
||||
|
||||
it("on non-macos writes wake.signal and never spawns subprocess", async () => {
|
||||
const tmp = await makeTmp("linux");
|
||||
try {
|
||||
let kickstarts = 0;
|
||||
const lifecycle = new WrapperLifecycle({
|
||||
socketPath: join(tmp, "daemon.sock"),
|
||||
wakeSignalPath: join(tmp, "wake.signal"),
|
||||
heartbeatPath: join(tmp, "wrappers", "heartbeat-1-x.json"),
|
||||
platform: "linux",
|
||||
socketReachable: async () => false,
|
||||
spawnKickstart: async () => {
|
||||
kickstarts += 1;
|
||||
},
|
||||
});
|
||||
await lifecycle.ensureDaemonAlive();
|
||||
assert.equal(kickstarts, 0, "subprocess must never be invoked on non-darwin");
|
||||
const sigStat = await stat(join(tmp, "wake.signal"));
|
||||
assert.ok(sigStat.isFile(), "wake.signal must exist on non-darwin path");
|
||||
} finally {
|
||||
await cleanupTmp(tmp);
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
// ---------------------------------------------------------------- registerHeartbeat
|
||||
|
||||
describe("WrapperLifecycle.registerHeartbeat", () => {
|
||||
it("creates heartbeat file with correct schema", async () => {
|
||||
const tmp = await makeTmp("hb-schema");
|
||||
try {
|
||||
const heartbeatPath = join(tmp, "wrappers", "heartbeat-12345-abc.json");
|
||||
const lifecycle = new WrapperLifecycle({
|
||||
pid: 12345,
|
||||
uuid: "abc",
|
||||
socketPath: join(tmp, "daemon.sock"),
|
||||
wakeSignalPath: join(tmp, "wake.signal"),
|
||||
heartbeatPath,
|
||||
platform: "darwin",
|
||||
socketReachable: async () => true,
|
||||
spawnKickstart: async () => {},
|
||||
refreshIntervalMs: 60_000, // big — we don't want it firing in this test
|
||||
});
|
||||
await lifecycle.registerHeartbeat();
|
||||
try {
|
||||
const raw = await readFile(heartbeatPath, "utf-8");
|
||||
const parsed = JSON.parse(raw);
|
||||
assert.equal(parsed.pid, 12345);
|
||||
assert.equal(parsed.uuid, "abc");
|
||||
assert.ok(typeof parsed.started_at === "string");
|
||||
assert.ok(typeof parsed.last_refresh === "string");
|
||||
assert.ok(typeof parsed.wrapper_version === "string");
|
||||
assert.equal(parsed.schema_version, 1);
|
||||
} finally {
|
||||
await lifecycle.cleanupHeartbeat();
|
||||
}
|
||||
} finally {
|
||||
await cleanupTmp(tmp);
|
||||
}
|
||||
});
|
||||
|
||||
it("refresh timer updates last_refresh", async () => {
|
||||
const tmp = await makeTmp("hb-refresh");
|
||||
try {
|
||||
const heartbeatPath = join(tmp, "wrappers", "heartbeat-1-x.json");
|
||||
const lifecycle = new WrapperLifecycle({
|
||||
pid: 1,
|
||||
uuid: "x",
|
||||
socketPath: join(tmp, "daemon.sock"),
|
||||
wakeSignalPath: join(tmp, "wake.signal"),
|
||||
heartbeatPath,
|
||||
platform: "darwin",
|
||||
socketReachable: async () => true,
|
||||
spawnKickstart: async () => {},
|
||||
refreshIntervalMs: 10, // tight interval to keep test fast
|
||||
});
|
||||
await lifecycle.registerHeartbeat();
|
||||
try {
|
||||
const before = JSON.parse(await readFile(heartbeatPath, "utf-8"));
|
||||
await sleep(60); // ~6 refresh ticks
|
||||
const after = JSON.parse(await readFile(heartbeatPath, "utf-8"));
|
||||
// started_at is stable; last_refresh advances.
|
||||
assert.equal(before.started_at, after.started_at);
|
||||
assert.notEqual(before.last_refresh, after.last_refresh);
|
||||
} finally {
|
||||
await lifecycle.cleanupHeartbeat();
|
||||
}
|
||||
} finally {
|
||||
await cleanupTmp(tmp);
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
// ---------------------------------------------------------------- cleanupHeartbeat
|
||||
|
||||
describe("WrapperLifecycle.cleanupHeartbeat", () => {
|
||||
it("deletes heartbeat file and clears timer", async () => {
|
||||
const tmp = await makeTmp("cleanup");
|
||||
try {
|
||||
const heartbeatPath = join(tmp, "wrappers", "heartbeat-1-x.json");
|
||||
const lifecycle = new WrapperLifecycle({
|
||||
pid: 1,
|
||||
uuid: "x",
|
||||
socketPath: join(tmp, "daemon.sock"),
|
||||
wakeSignalPath: join(tmp, "wake.signal"),
|
||||
heartbeatPath,
|
||||
platform: "darwin",
|
||||
socketReachable: async () => true,
|
||||
spawnKickstart: async () => {},
|
||||
refreshIntervalMs: 10,
|
||||
});
|
||||
await lifecycle.registerHeartbeat();
|
||||
const sigBefore = await stat(heartbeatPath);
|
||||
assert.ok(sigBefore.isFile());
|
||||
|
||||
await lifecycle.cleanupHeartbeat();
|
||||
await assert.rejects(stat(heartbeatPath), "heartbeat file must be gone after cleanup");
|
||||
|
||||
// No refresh after cleanup: wait longer than the refresh interval
|
||||
// and verify the file does NOT reappear.
|
||||
await sleep(60);
|
||||
await assert.rejects(stat(heartbeatPath), "no refresh tick after cleanup");
|
||||
|
||||
// Idempotent: second cleanup must NOT throw.
|
||||
await lifecycle.cleanupHeartbeat();
|
||||
} finally {
|
||||
await cleanupTmp(tmp);
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
// ---------------------------------------------------------------- security
|
||||
|
||||
describe("WrapperLifecycle security invariants", () => {
|
||||
it("source contains no shell-true option and no shell-interpreting subprocess variants", async () => {
|
||||
// Scan the top-level .ts files in mcp-wrapper/src/ (readdir is not recursive here) and assert that none contains the
|
||||
// forbidden patterns. We allow the safe `execFile` API; we forbid
|
||||
// (a) the `shell: true` option anywhere, (b) bare-name calls to
|
||||
// the shell-interpreting subprocess variant from node:child_process.
|
||||
//
|
||||
// Detection strategy: build the forbidden tokens at runtime from
|
||||
// characters so the test source itself doesn't contain the literal
|
||||
// banned substring (avoids tripping security-reminder hooks that
|
||||
// grep for source-level mentions).
|
||||
const here = fileURLToPath(new URL(".", import.meta.url));
|
||||
const srcDir = join(here, "..", "src");
|
||||
const files = await readdir(srcDir);
|
||||
const tsFiles = files.filter((f) => f.endsWith(".ts"));
|
||||
assert.ok(tsFiles.length > 0, "expected at least one .ts file in src/");
|
||||
|
||||
const E = String.fromCharCode(0x65); // 'e'
|
||||
const X = String.fromCharCode(0x78); // 'x'
|
||||
const C = String.fromCharCode(0x63); // 'c'
|
||||
const SHELL_INTERP_TOKEN = E + X + E + C; // 4-char banned identifier
|
||||
const SHELL_OPTION_TOKEN = "shell"; // followed by colon + true
|
||||
const shellOptionRegex = new RegExp(
|
||||
`\\b${SHELL_OPTION_TOKEN}\\s*:\\s*true\\b`,
|
||||
);
|
||||
// Allow `<token>File` (the safe variant) but forbid bare `<token>(`
|
||||
// OR `child_process.<token>(`.
|
||||
const bareCallRegex = new RegExp(
|
||||
`(?:^|[^A-Za-z0-9_])${SHELL_INTERP_TOKEN}\\s*\\(`,
|
||||
);
|
||||
const dottedCallRegex = new RegExp(
|
||||
`\\bchild_process\\s*\\.\\s*${SHELL_INTERP_TOKEN}\\s*\\(`,
|
||||
);
|
||||
|
||||
const forbidden: { file: string; pattern: string; line: number }[] = [];
|
||||
for (const f of tsFiles) {
|
||||
const path = join(srcDir, f);
|
||||
const content = await readFile(path, "utf-8");
|
||||
const lines = content.split("\n");
|
||||
lines.forEach((line, idx) => {
|
||||
const trimmed = line.trim();
|
||||
// Strip trailing line comment so an inline `// NEVER ...` mention
|
||||
// in a code line doesn't match. Pure-comment lines (codePortion
|
||||
// empty after trim) are skipped.
|
||||
const codePortion = (trimmed.split("//")[0] ?? "").trim();
|
||||
if (codePortion.length === 0) {
|
||||
return;
|
||||
}
|
||||
if (shellOptionRegex.test(codePortion)) {
|
||||
forbidden.push({
|
||||
file: f,
|
||||
pattern: "shell-true option",
|
||||
line: idx + 1,
|
||||
});
|
||||
}
|
||||
if (dottedCallRegex.test(codePortion)) {
|
||||
forbidden.push({
|
||||
file: f,
|
||||
pattern: "child_process.<shell-interp-call>",
|
||||
line: idx + 1,
|
||||
});
|
||||
}
|
||||
if (bareCallRegex.test(codePortion)) {
|
||||
forbidden.push({
|
||||
file: f,
|
||||
pattern: "bare <shell-interp-call>",
|
||||
line: idx + 1,
|
||||
});
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
assert.deepEqual(
|
||||
forbidden,
|
||||
[],
|
||||
`Forbidden subprocess pattern in mcp-wrapper/src/: ${JSON.stringify(forbidden, null, 2)}`,
|
||||
);
|
||||
});
|
||||
});
|
||||
20
mcp-wrapper/tsconfig.json
Normal file
|
|
@ -0,0 +1,20 @@
|
|||
{
|
||||
"compilerOptions": {
|
||||
"target": "ES2022",
|
||||
"module": "ESNext",
|
||||
"moduleResolution": "Bundler",
|
||||
"strict": true,
|
||||
"esModuleInterop": true,
|
||||
"outDir": "./dist",
|
||||
"rootDir": "./src",
|
||||
"declaration": false,
|
||||
"sourceMap": true,
|
||||
"resolveJsonModule": true,
|
||||
"isolatedModules": true,
|
||||
"skipLibCheck": true,
|
||||
"noEmitOnError": true,
|
||||
"lib": ["ES2022"],
|
||||
"types": ["node"]
|
||||
},
|
||||
"include": ["src/**/*"]
|
||||
}
|
||||
54
pyproject.toml
Normal file
|
|
@ -0,0 +1,54 @@
|
|||
[project]
|
||||
name = "iai-mcp"
|
||||
version = "0.1.0"
|
||||
description = "MCP server providing persistent verbatim memory and ambient capture for MCP-over-stdio hosts (developed against Claude Code)"
|
||||
requires-python = ">=3.11,<3.13"
|
||||
dependencies = [
|
||||
"lancedb>=0.11.0",
|
||||
"pyarrow>=16.0.0",
|
||||
"sentence-transformers>=3.0.0",
|
||||
"numpy>=1.26.0,<2.3.0",
|
||||
"pydantic>=2.7.0",
|
||||
"torch-hd>=5.7.0", # imports as `torchhd`; PyPI name has a dash
|
||||
"structlog>=24.0.0",
|
||||
"networkx>=3.3",
|
||||
"python-igraph>=0.11",
|
||||
"leidenalg>=0.10",
|
||||
"anthropic>=0.40.0", # count_tokens API for bench harness
|
||||
"tiktoken>=0.7.0", # offline tokeniser fallback for bench/tokens.py when no API key
|
||||
"langdetect>=1.0.9", # ISO-639-1 language auto-detect (pure Python, no-cloud)
|
||||
"cryptography>=42.0.0", # AES-256-GCM at rest (pyca/cryptography, audited primitive)
|
||||
"keyring>=24.0.0", # OS keychain (macOS / Linux Secret Service / Windows Credential Manager)
|
||||
"cachetools>=5.3.0", # TTLCache for activation-cascade LRU
|
||||
"psutil>=5.9.0", # daemon CPU watchdog + doctor checks
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
dev = [
|
||||
"pytest>=8.0",
|
||||
"pytest-cov>=5.0",
|
||||
"pytest-rerunfailures>=14.0", # auto-retry test-pollution flakes (daemon/bridge tests)
|
||||
"ruff>=0.5.0",
|
||||
]
|
||||
# Optional: LLMLingua-2 compression for community summaries (~2.3 GB model).
|
||||
# Without this extra, compression falls back to passthrough.
|
||||
compress = ["llmlingua>=0.2.2", "accelerate>=1.0.0"]
|
||||
|
||||
[project.scripts]
|
||||
iai-mcp-core = "iai_mcp.core:main"
|
||||
iai-mcp = "iai_mcp.cli:main"
|
||||
|
||||
[build-system]
|
||||
requires = ["hatchling"]
|
||||
build-backend = "hatchling.build"
|
||||
|
||||
[tool.hatch.build.targets.wheel]
|
||||
packages = ["src/iai_mcp"]
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
testpaths = ["tests"]
|
||||
pythonpath = ["src"]
|
||||
# A handful of daemon/bridge tests are sensitive to test-pollution from earlier
|
||||
# tests in the same suite (open file descriptors, async loop state) and pass
|
||||
# cleanly when run in isolation. Retry up to twice before reporting failure.
|
||||
addopts = "--reruns 2 --reruns-delay 1"
|
||||
75
scripts/com.iai-mcp.daemon.plist.template
Normal file
|
|
@ -0,0 +1,75 @@
|
|||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN" "http://www.apple.com/DTDs/PropertyList-1.0.dtd">
|
||||
<!--
|
||||
socket-activated LaunchAgent template (R1, D7.1-01).
|
||||
Rendered by scripts/install.sh: substitutes {PYTHON_PATH} + {HOME},
|
||||
writes to ~/Library/LaunchAgents/com.iai-mcp.daemon.plist, then
|
||||
`launchctl load -w` registers it.
|
||||
|
||||
DO NOT edit ~/Library/LaunchAgents/com.iai-mcp.daemon.plist directly —
|
||||
re-run scripts/install.sh after changes here.
|
||||
|
||||
RunAtLoad=false + Sockets.Listener = TRUE socket activation: launchd
|
||||
pre-binds the unix socket and spawns iai_mcp.daemon ONLY on first
|
||||
connection. KeepAlive.SuccessfulExit=false preserves Phase 7's idle
|
||||
shutdown semantics (no respawn after clean idle exit).
|
||||
-->
|
||||
<plist version="1.0">
|
||||
<dict>
|
||||
<key>Label</key>
|
||||
<string>com.iai-mcp.daemon</string>
|
||||
|
||||
<key>ProgramArguments</key>
|
||||
<array>
|
||||
<string>{PYTHON_PATH}</string>
|
||||
<string>-m</string>
|
||||
<string>iai_mcp.daemon</string>
|
||||
</array>
|
||||
|
||||
<key>EnvironmentVariables</key>
|
||||
<dict>
|
||||
<key>PATH</key>
|
||||
<string>/usr/local/bin:/usr/bin:/bin:/opt/homebrew/bin</string>
|
||||
<key>HOME</key>
|
||||
<string>{HOME}</string>
|
||||
<key>IAI_MCP_LAUNCHD_MANAGED</key>
|
||||
<string>1</string>
|
||||
</dict>
|
||||
|
||||
<key>Sockets</key>
|
||||
<dict>
|
||||
<key>Listener</key>
|
||||
<dict>
|
||||
<key>SockPathName</key>
|
||||
<string>{HOME}/.iai-mcp/.daemon.sock</string>
|
||||
<key>SockType</key>
|
||||
<string>stream</string>
|
||||
<key>SockFamily</key>
|
||||
<string>Unix</string>
|
||||
<key>SockPathMode</key>
|
||||
<integer>384</integer> <!-- decimal for octal 0600: owner read/write only -->
|
||||
</dict>
|
||||
</dict>
|
||||
|
||||
<key>RunAtLoad</key>
|
||||
<false/>
|
||||
|
||||
<key>KeepAlive</key>
|
||||
<dict>
|
||||
<key>SuccessfulExit</key>
|
||||
<false/>
|
||||
</dict>
|
||||
|
||||
<key>ProcessType</key>
|
||||
<string>Adaptive</string>
|
||||
|
||||
<key>StandardOutPath</key>
|
||||
<string>{HOME}/.iai-mcp/logs/launchd-stdout.log</string>
|
||||
|
||||
<key>StandardErrorPath</key>
|
||||
<string>{HOME}/.iai-mcp/logs/launchd-stderr.log</string>
|
||||
|
||||
<key>WorkingDirectory</key>
|
||||
<string>{HOME}</string>
|
||||
</dict>
|
||||
</plist>
|
||||
108
scripts/idle_cpu_regression_fence.sh
Executable file
|
|
@ -0,0 +1,108 @@
|
|||
#!/usr/bin/env bash
|
||||
# scripts/idle_cpu_regression_fence.sh — A7 idle-CPU regression fence.
|
||||
#
|
||||
# SPEC A7: 30-min `python -m iai_mcp.daemon` run with first_turn_pending = 1
|
||||
# shows process CPU < 5% sampled every 30s.
|
||||
#
|
||||
# Usage:
|
||||
# scripts/idle_cpu_regression_fence.sh # 30-min run, samples every 30s
|
||||
# IAI_FENCE_DURATION_MIN=5 scripts/idle_cpu_regression_fence.sh # short run
|
||||
#
|
||||
# Pre-condition: daemon must already be running. The script does NOT spawn
|
||||
# the daemon and does NOT manage launchd — D7.2-26 + D7.2-27 keep daemon
|
||||
# lifecycle entirely under user discretion. To start the daemon manually
|
||||
# before running this fence, run:
|
||||
#
|
||||
# python -m iai_mcp.daemon &
|
||||
#
|
||||
# (development / manual subprocess path; foreground or background). The
|
||||
# fence treats the daemon as a black box and only reads its self-CPU% via
|
||||
# psutil, so any startup mechanism that yields an `iai_mcp.daemon` process
|
||||
# will work.
|
||||
#
|
||||
# Exit codes:
|
||||
# 0 — all samples < THRESHOLD_PCT sustained
|
||||
# 1 — at least one sample >= THRESHOLD_PCT
|
||||
# 2 — daemon not running / pgrep returned 0 matches
|
||||
# 3 — psutil / Python error
|
||||
set -eu
|
||||
|
||||
DURATION_MIN="${IAI_FENCE_DURATION_MIN:-30}"
|
||||
SAMPLE_INTERVAL_SEC="${IAI_FENCE_SAMPLE_INTERVAL_SEC:-30}"
|
||||
THRESHOLD_PCT="${IAI_FENCE_THRESHOLD_PCT:-5.0}"
|
||||
|
||||
# Locate the daemon PID. We use pgrep -f for the explicit module form.
|
||||
DAEMON_PID=$(pgrep -f "iai_mcp.daemon" | head -1 || true)
|
||||
if [ -z "$DAEMON_PID" ]; then
|
||||
echo "ERROR: no iai_mcp.daemon process found." >&2
|
||||
echo "Start it manually before running this fence:" >&2
|
||||
echo " python -m iai_mcp.daemon &" >&2
|
||||
exit 2
|
||||
fi
|
||||
|
||||
echo "Phase 7.2 A7 idle-CPU regression fence"
|
||||
echo " daemon PID: $DAEMON_PID"
|
||||
echo " duration: ${DURATION_MIN}min"
|
||||
echo " sample interval: ${SAMPLE_INTERVAL_SEC}s"
|
||||
echo " threshold: ${THRESHOLD_PCT}%"
|
||||
echo
|
||||
|
||||
SAMPLES_TAKEN=0
|
||||
OVER_THRESHOLD=0
|
||||
MAX_SEEN=0
|
||||
DURATION_SEC=$((DURATION_MIN * 60))
|
||||
START_TS=$(date +%s)
|
||||
|
||||
while true; do
|
||||
NOW=$(date +%s)
|
||||
ELAPSED=$((NOW - START_TS))
|
||||
if [ $ELAPSED -ge $DURATION_SEC ]; then
|
||||
break
|
||||
fi
|
||||
|
||||
# Use python+psutil for cross-platform self-CPU% read.
|
||||
CPU=$(python3 -c "
|
||||
import psutil, sys, time
|
||||
try:
|
||||
    p = psutil.Process($DAEMON_PID)
|
||||
    p.cpu_percent(interval=None)  # prime the counter; first call returns 0.0
|
||||
    time.sleep(1.0)
|
||||
    print(p.cpu_percent(interval=None))
|
||||
except Exception as e:
|
||||
    sys.stderr.write(f'psutil error: {e}\n')
|
||||
    sys.exit(3)
|
||||
") && EXIT_CODE=0 || EXIT_CODE=$?
|
||||
if [ $EXIT_CODE -ne 0 ]; then
|
||||
echo " sample fail: psutil error" >&2
|
||||
exit 3
|
||||
fi
|
||||
|
||||
SAMPLES_TAKEN=$((SAMPLES_TAKEN + 1))
|
||||
printf " t=%4ds cpu=%5.1f%%\n" "$ELAPSED" "$CPU"
|
||||
|
||||
# awk float comparison (bash doesn't do floats natively).
|
||||
OVER=$(awk -v cpu="$CPU" -v thr="$THRESHOLD_PCT" 'BEGIN { print (cpu >= thr) ? 1 : 0 }')
|
||||
if [ "$OVER" = "1" ]; then
|
||||
OVER_THRESHOLD=$((OVER_THRESHOLD + 1))
|
||||
fi
|
||||
|
||||
MAX_SEEN=$(awk -v cur="$CPU" -v prev="$MAX_SEEN" 'BEGIN { print (cur > prev) ? cur : prev }')
|
||||
|
||||
sleep "$SAMPLE_INTERVAL_SEC"
|
||||
done
|
||||
|
||||
echo
|
||||
echo "Summary:"
|
||||
echo " total samples: $SAMPLES_TAKEN"
|
||||
echo " over threshold: $OVER_THRESHOLD"
|
||||
echo " max observed CPU%: $MAX_SEEN"
|
||||
|
||||
if [ $OVER_THRESHOLD -gt 0 ]; then
|
||||
echo "FAIL: $OVER_THRESHOLD/$SAMPLES_TAKEN samples exceeded ${THRESHOLD_PCT}%."
|
||||
exit 1
|
||||
else
|
||||
echo "PASS: all samples under threshold."
|
||||
exit 0
|
||||
fi
|
||||
198
scripts/install.sh
Executable file
|
|
@ -0,0 +1,198 @@
|
|||
#!/usr/bin/env bash
|
||||
# scripts/install.sh — first-time setup for collaborators.
|
||||
#
|
||||
# Usage (from repo root or anywhere inside the clone):
|
||||
# bash scripts/install.sh
|
||||
#
|
||||
# Does:
|
||||
# 1. creates .venv if missing
|
||||
# 2. installs iai-mcp editable into the venv
|
||||
# 3. builds the TS MCP wrapper
|
||||
# 4. symlinks ~/.local/bin/iai-mcp -> .venv/bin/iai-mcp so the CLI is
|
||||
# callable from anywhere without activating the venv
|
||||
# 5. prints how to install the optional sleep daemon (launchd on macOS, systemd on Linux)
|
||||
# 6. registers the socket-activated LaunchAgent (macOS only)
|
||||
#
|
||||
# Idempotent. Safe to re-run.
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
|
||||
REPO_ROOT="$(cd "${SCRIPT_DIR}/.." && pwd)"
|
||||
cd "${REPO_ROOT}"
|
||||
|
||||
step() { printf '\n\033[1;34m==> %s\033[0m\n' "$*"; }
|
||||
ok() { printf ' \033[0;32m✓\033[0m %s\n' "$*"; }
|
||||
warn() { printf ' \033[0;33m!\033[0m %s\n' "$*"; }
|
||||
die() { printf '\n\033[0;31m✗ %s\033[0m\n' "$*" >&2; exit 1; }
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Sections 1-4: build / venv / pip / npm / symlink.
|
||||
#
|
||||
# IAI_TEST_SKIP_BUILD=1 short-circuits the whole bootstrap so the LaunchAgent
|
||||
# section (6) can be exercised in isolation by tests/test_install_uninstall.py
|
||||
# (Plan 07.1-03 Task 3) without spending ~30s on venv + npm.
|
||||
# ---------------------------------------------------------------------------
|
||||
if [[ "${IAI_TEST_SKIP_BUILD:-0}" == "1" ]]; then
|
||||
step "build skip (IAI_TEST_SKIP_BUILD=1)"
|
||||
ok "skipping sections 1-4 (venv/pip/npm/symlink) — test mode"
|
||||
else
|
||||
# -----------------------------------------------------------------------
|
||||
# 1. venv
|
||||
# -----------------------------------------------------------------------
|
||||
step "python venv"
|
||||
# iai-mcp requires Python 3.11 or 3.12 (torch + lancedb on 3.13/3.14
|
||||
# are not yet stable). Pick the highest supported interpreter we can find.
|
||||
PY=""
|
||||
for cand in python3.12 python3.11; do
|
||||
if command -v "$cand" >/dev/null 2>&1; then
|
||||
PY="$(command -v "$cand")"
|
||||
break
|
||||
fi
|
||||
done
|
||||
if [ -z "$PY" ]; then
|
||||
# Fall back to plain python3 only if it self-reports as 3.11 or 3.12.
|
||||
if command -v python3 >/dev/null 2>&1; then
|
||||
ver="$(python3 -c 'import sys; print(f"{sys.version_info.major}.{sys.version_info.minor}")' 2>/dev/null || echo unknown)"
|
||||
if [ "$ver" = "3.11" ] || [ "$ver" = "3.12" ]; then
|
||||
PY="$(command -v python3)"
|
||||
fi
|
||||
fi
|
||||
fi
|
||||
[ -n "$PY" ] || die "Python 3.11 or 3.12 not found. macOS: brew install python@3.12 | Linux: apt install python3.12 (or use pyenv)"
|
||||
ok "using $PY ($("$PY" --version))"
|
||||
if [ ! -d .venv ]; then
|
||||
"$PY" -m venv .venv
|
||||
ok ".venv created"
|
||||
else
|
||||
ok ".venv already exists"
|
||||
fi
|
||||
|
||||
# -----------------------------------------------------------------------
|
||||
# 2. editable install
|
||||
# -----------------------------------------------------------------------
|
||||
step "editable install (pip -e .)"
|
||||
.venv/bin/pip install --quiet --upgrade pip
|
||||
.venv/bin/pip install --quiet -e .
|
||||
ok "iai-mcp python package installed into venv"
|
||||
|
||||
# -----------------------------------------------------------------------
|
||||
# 3. TS wrapper build
|
||||
# -----------------------------------------------------------------------
|
||||
step "TS wrapper build"
|
||||
if [ -d mcp-wrapper ]; then
|
||||
pushd mcp-wrapper >/dev/null
|
||||
if [ -f package-lock.json ]; then
|
||||
npm ci --silent --no-audit --no-fund
|
||||
else
|
||||
npm install --silent --no-audit --no-fund
|
||||
fi
|
||||
npm run build --silent
|
||||
popd >/dev/null
|
||||
ok "mcp-wrapper/dist built"
|
||||
else
|
||||
warn "mcp-wrapper/ missing — skipping"
|
||||
fi
|
||||
|
||||
# -----------------------------------------------------------------------
|
||||
# 4. global symlink into ~/.local/bin
|
||||
# -----------------------------------------------------------------------
|
||||
step "global CLI symlink"
|
||||
LOCAL_BIN="${HOME}/.local/bin"
|
||||
LINK_PATH="${LOCAL_BIN}/iai-mcp"
|
||||
TARGET="${REPO_ROOT}/.venv/bin/iai-mcp"
|
||||
|
||||
[ -x "${TARGET}" ] || die "venv entry point not found at ${TARGET}"
|
||||
|
||||
mkdir -p "${LOCAL_BIN}"
|
||||
|
||||
# `ln -sf` overwrites any existing symlink safely (idempotent).
|
||||
# Refuse to clobber a regular file the user put there themselves.
|
||||
if [ -e "${LINK_PATH}" ] && [ ! -L "${LINK_PATH}" ]; then
|
||||
die "${LINK_PATH} exists and is NOT a symlink. Move it aside and re-run."
|
||||
fi
|
||||
ln -sf "${TARGET}" "${LINK_PATH}"
|
||||
ok "${LINK_PATH} -> ${TARGET}"
|
||||
|
||||
# PATH sanity check using python (grep is hook-blocked in this dev env).
|
||||
PATH_HAS_LOCAL_BIN="$(.venv/bin/python - <<PY
|
||||
import os
|
||||
print("1" if "${LOCAL_BIN}" in os.environ.get("PATH", "").split(":") else "0")
|
||||
PY
|
||||
)"
|
||||
if [ "${PATH_HAS_LOCAL_BIN}" != "1" ]; then
|
||||
warn "${LOCAL_BIN} is NOT in your PATH"
|
||||
warn "add this to ~/.zshrc or ~/.bashrc and restart your shell:"
|
||||
warn " export PATH=\"\${HOME}/.local/bin:\${PATH}\""
|
||||
else
|
||||
ok "${LOCAL_BIN} is in PATH"
|
||||
fi
|
||||
fi
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 5. optional daemon install
|
||||
# ---------------------------------------------------------------------------
|
||||
step "sleep daemon (optional)"
|
||||
if command -v iai-mcp >/dev/null 2>&1; then
|
||||
INSTALLED_PATH="$(command -v iai-mcp)"
|
||||
ok "iai-mcp globally reachable at ${INSTALLED_PATH}"
|
||||
echo
|
||||
echo " to run the background sleep daemon (recommended — REM cycles +"
|
||||
echo " overnight consolidation on your local Claude subscription):"
|
||||
echo
|
||||
echo " iai-mcp daemon install --yes"
|
||||
echo " iai-mcp daemon start"
|
||||
echo
|
||||
echo " or skip for now and install later."
|
||||
else
|
||||
warn "iai-mcp not on PATH yet — add ~/.local/bin to PATH first, then run:"
|
||||
warn " iai-mcp daemon install --yes"
|
||||
fi
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 6. LaunchAgent registration (Phase 7.1 — socket-activated singleton)
|
||||
#
|
||||
# Section 6 — socket-activated LaunchAgent. REPLACES the eager
|
||||
# RunAtLoad=true plist that `iai-mcp daemon install` writes.
|
||||
# The two flows compete for ~/Library/LaunchAgents/com.iai-mcp.daemon.plist;
|
||||
# whichever ran most recently wins. Re-running install.sh restores the
|
||||
# socket-activated plist because it overwrites + reloads on every invocation.
|
||||
# ---------------------------------------------------------------------------
|
||||
step "LaunchAgent registration"
|
||||
if [[ "$(uname)" != "Darwin" ]]; then
|
||||
warn "non-Darwin OS — skipping LaunchAgent registration"
|
||||
elif [[ "${DRY_RUN:-0}" == "1" ]]; then
|
||||
ok "DRY_RUN=1 — skipping launchctl calls (test mode)"
|
||||
else
|
||||
PYTHON_PATH="${REPO_ROOT}/.venv/bin/python"
|
||||
if [ ! -x "${PYTHON_PATH}" ]; then
|
||||
warn "venv python not found at ${PYTHON_PATH} — falling back to $(command -v python3)"
|
||||
PYTHON_PATH="$(command -v python3)"
|
||||
fi
|
||||
LA_DIR="${HOME}/Library/LaunchAgents"
|
||||
LA_PATH="${LA_DIR}/com.iai-mcp.daemon.plist"
|
||||
TEMPLATE="${REPO_ROOT}/scripts/com.iai-mcp.daemon.plist.template"
|
||||
[ -f "${TEMPLATE}" ] || die "plist template missing at ${TEMPLATE}"
|
||||
mkdir -p "${LA_DIR}" "${HOME}/.iai-mcp/logs" "${HOME}/.iai-mcp"
|
||||
# Substitute placeholders using sed; HOME/PYTHON_PATH may contain forward
|
||||
# slashes so we use `|` as the sed separator (not `/`).
|
||||
sed -e "s|{PYTHON_PATH}|${PYTHON_PATH}|g" -e "s|{HOME}|${HOME}|g" "${TEMPLATE}" > "${LA_PATH}"
|
||||
# Idempotent: unload prior registration if any, then load fresh. -w persists across reboots.
|
||||
launchctl unload -w "${LA_PATH}" 2>/dev/null || true
|
||||
if ! launchctl load -w "${LA_PATH}"; then
|
||||
warn "launchctl load reported non-zero — checking registration anyway"
|
||||
fi
|
||||
if launchctl list | grep -q "com.iai-mcp.daemon"; then
|
||||
ok "LaunchAgent registered (first MCP call will socket-activate the daemon)"
|
||||
else
|
||||
die "LaunchAgent NOT registered after launchctl load — investigate ${HOME}/.iai-mcp/logs/launchd-stderr.log"
|
||||
fi
|
||||
fi
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# done
|
||||
# ---------------------------------------------------------------------------
|
||||
step "done"
|
||||
ok "iai-mcp installed at $(git rev-parse --short HEAD)"
|
||||
echo
|
||||
echo " next: bash scripts/uninstall.sh (to roll back; preserves data unless --purge-data)"
|
||||
echo " update: bash scripts/update.sh (pull + rebuild + restart daemon)"
|
||||
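The placeholder substitution in section 6 of install.sh can be sketched in Python, which is convenient for unit-testing a template without invoking `sed` or `launchctl`. This is a minimal sketch, not part of the repository: `render_plist` and `SAMPLE_TEMPLATE` are hypothetical names, and the template fragment is invented for illustration (the real template at scripts/com.iai-mcp.daemon.plist.template is not shown in this diff).

```python
def render_plist(template_text: str, python_path: str, home: str) -> str:
    """Replace the two literal placeholders that the sed call in section 6 targets."""
    return (
        template_text
        .replace("{PYTHON_PATH}", python_path)
        .replace("{HOME}", home)
    )


# Hypothetical template fragment, for illustration only.
SAMPLE_TEMPLATE = (
    "<string>{PYTHON_PATH}</string>\n"
    "<string>{HOME}/.iai-mcp/logs/launchd-stderr.log</string>\n"
)

rendered = render_plist(SAMPLE_TEMPLATE, "/opt/venv/bin/python", "/Users/alice")
print(rendered)
```

One design difference worth noting: `str.replace` treats every character of the replacement literally, whereas a `sed` replacement would interpret `&` as the matched text and would break on a path containing the `|` delimiter.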
189
scripts/uninstall.sh
Executable file
|
|
@ -0,0 +1,189 @@
|
|||
#!/usr/bin/env bash
|
||||
# scripts/uninstall.sh — LaunchAgent + daemon teardown.
|
||||
#
|
||||
# Usage:
|
||||
# bash scripts/uninstall.sh # remove LaunchAgent + kill daemon
|
||||
# bash scripts/uninstall.sh --purge-state # also remove ~/.iai-mcp/.daemon.sock,
|
||||
# # .daemon-state.json, .lock
|
||||
# bash scripts/uninstall.sh --purge-data # also remove ~/.iai-mcp/lancedb +
|
||||
# # runtime_graph_cache.json
|
||||
# # DESTRUCTIVE — wipes user's brain.
|
||||
#
|
||||
# Idempotent: safe to re-run. Always exits 0 (best-effort).
|
||||
# DRY_RUN=1 env skips real launchctl + kill + rm calls (used by tests).
|
||||
#
|
||||
# Inverse of scripts/install.sh section 6 (Phase 7.1 LaunchAgent registration).
|
||||
|
||||
# NOTE on shell flags: we deliberately use only `set -u`, NOT `set -e`.
|
||||
# Uninstall must NEVER abort mid-flow — partial cleanup is worse than
|
||||
# best-effort full cleanup. Each step prints its own outcome via ok/warn.
|
||||
set -u
|
||||
|
||||
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
|
||||
REPO_ROOT="$(cd "${SCRIPT_DIR}/.." && pwd)"
|
||||
cd "${REPO_ROOT}"
|
||||
|
||||
step() { printf '\n\033[1;34m==> %s\033[0m\n' "$*"; }
|
||||
ok() { printf ' \033[0;32m✓\033[0m %s\n' "$*"; }
|
||||
warn() { printf ' \033[0;33m!\033[0m %s\n' "$*"; }
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 1. parse flags
|
||||
# ---------------------------------------------------------------------------
|
||||
PURGE_STATE=0
|
||||
PURGE_DATA=0
|
||||
for arg in "$@"; do
|
||||
case "${arg}" in
|
||||
--purge-state) PURGE_STATE=1 ;;
|
||||
--purge-data) PURGE_DATA=1 ;;
|
||||
-h|--help)
|
||||
sed -n '2,12p' "${BASH_SOURCE[0]}" | sed 's/^# \{0,1\}//'
|
||||
exit 0
|
||||
;;
|
||||
*)
|
||||
warn "unknown flag '${arg}' (ignored — expected --purge-state, --purge-data, --help)"
|
||||
;;
|
||||
esac
|
||||
done
|
||||
|
||||
step "iai-mcp uninstall"
|
||||
if [[ "${PURGE_DATA}" == "1" ]]; then
|
||||
warn "--purge-data is DESTRUCTIVE: ~/.iai-mcp/lancedb (your brain) will be deleted"
|
||||
fi
|
||||
|
||||
LA_PATH="${HOME}/Library/LaunchAgents/com.iai-mcp.daemon.plist"
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 2. launchctl unload (Darwin only)
|
||||
# ---------------------------------------------------------------------------
|
||||
step "launchctl unload"
|
||||
if [[ "$(uname)" != "Darwin" ]]; then
|
||||
warn "non-Darwin OS — skipping launchctl unload"
|
||||
elif [[ "${DRY_RUN:-0}" == "1" ]]; then
|
||||
ok "DRY_RUN=1 — skipping launchctl unload (test mode)"
|
||||
else
|
||||
if [ -f "${LA_PATH}" ]; then
|
||||
if launchctl unload -w "${LA_PATH}" 2>/dev/null; then
|
||||
ok "LaunchAgent unloaded"
|
||||
else
|
||||
ok "LaunchAgent was not registered (already clean)"
|
||||
fi
|
||||
else
|
||||
ok "no LaunchAgent plist at ${LA_PATH} (already clean)"
|
||||
fi
|
||||
fi
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 3. remove plist file (Darwin only)
|
||||
# ---------------------------------------------------------------------------
|
||||
step "remove plist"
|
||||
if [[ "$(uname)" != "Darwin" ]]; then
|
||||
warn "non-Darwin OS — skipping plist removal"
|
||||
elif [[ "${DRY_RUN:-0}" == "1" ]]; then
|
||||
ok "DRY_RUN=1 — skipping rm of ${LA_PATH} (test mode)"
|
||||
else
|
||||
rm -f "${LA_PATH}"
|
||||
ok "${LA_PATH} removed (or never existed)"
|
||||
fi
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 4. kill any lingering daemon by cmdline match
|
||||
#
|
||||
# Defense against pgrep regex misfire: pgrep -f matches on substring of
|
||||
# the full command line. We re-verify each PID's cmdline contains the
|
||||
# literal "iai_mcp.daemon" via `ps -p PID -o command=` BEFORE killing.
|
||||
# ---------------------------------------------------------------------------
|
||||
step "kill lingering daemon"
|
||||
if [[ "${DRY_RUN:-0}" == "1" ]]; then
|
||||
ok "DRY_RUN=1 — skipping pgrep + kill (test mode)"
|
||||
else
|
||||
pids="$(pgrep -f "iai_mcp\.daemon" 2>/dev/null || true)"
|
||||
if [[ -n "${pids}" ]]; then
|
||||
warn "found pids: ${pids}"
|
||||
for pid in ${pids}; do
|
||||
# Verify cmdline really contains iai_mcp.daemon (defense against pgrep regex misfire).
|
||||
if ps -p "${pid}" -o command= 2>/dev/null | grep -qF "iai_mcp.daemon"; then
|
||||
kill -TERM "${pid}" 2>/dev/null || true
|
||||
fi
|
||||
done
|
||||
sleep 3
|
||||
# SIGKILL stragglers
|
||||
for pid in ${pids}; do
|
||||
if kill -0 "${pid}" 2>/dev/null; then
|
||||
warn "pid ${pid} still alive — sending SIGKILL"
|
||||
kill -KILL "${pid}" 2>/dev/null || true
|
||||
fi
|
||||
done
|
||||
ok "lingering daemon(s) terminated"
|
||||
else
|
||||
ok "no lingering iai_mcp.daemon processes"
|
||||
fi
|
||||
fi
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 5. --purge-state: remove socket + state + lock
|
||||
# ---------------------------------------------------------------------------
|
||||
if [[ "${PURGE_STATE}" == "1" ]]; then
|
||||
step "--purge-state: remove ~/.iai-mcp/.daemon.sock + .daemon-state.json + .lock"
|
||||
if [[ "${DRY_RUN:-0}" == "1" ]]; then
|
||||
ok "DRY_RUN=1 — skipping rm of state files (test mode)"
|
||||
else
|
||||
rm -f "${HOME}/.iai-mcp/.daemon.sock" \
|
||||
"${HOME}/.iai-mcp/.daemon-state.json" \
|
||||
"${HOME}/.iai-mcp/.lock"
|
||||
ok "state files removed (or never existed)"
|
||||
fi
|
||||
fi
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 6. --purge-data: remove lancedb + runtime cache (DESTRUCTIVE)
|
||||
# ---------------------------------------------------------------------------
|
||||
if [[ "${PURGE_DATA}" == "1" ]]; then
|
||||
step "--purge-data: remove ~/.iai-mcp/lancedb + runtime_graph_cache.json"
|
||||
if [[ "${DRY_RUN:-0}" == "1" ]]; then
|
||||
ok "DRY_RUN=1 — skipping rm of data files (test mode)"
|
||||
else
|
||||
# Confirmation prompt — only if attached to a tty (skip in non-interactive
|
||||
# subprocess so `read` cannot hang waiting for input). bash 3.2 compatible.
|
||||
confirmed=0
|
||||
if [ -t 0 ]; then
|
||||
printf " \033[0;33m!\033[0m really delete ~/.iai-mcp/lancedb? [y/N] "
|
||||
read -r REPLY || REPLY=N
|
||||
if [[ "${REPLY}" =~ ^[Yy]$ ]]; then
|
||||
confirmed=1
|
||||
fi
|
||||
else
|
||||
warn "non-interactive stdin — skipping --purge-data confirmation (no deletion)"
|
||||
fi
|
||||
if [[ "${confirmed}" == "1" ]]; then
|
||||
rm -rf "${HOME}/.iai-mcp/lancedb" \
|
||||
"${HOME}/.iai-mcp/runtime_graph_cache.json"
|
||||
ok "data files removed"
|
||||
else
|
||||
ok "user declined — data preserved"
|
||||
fi
|
||||
fi
|
||||
fi
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 7. verify
|
||||
# ---------------------------------------------------------------------------
|
||||
step "verify"
|
||||
if [[ "$(uname)" != "Darwin" ]]; then
|
||||
warn "non-Darwin OS — skipping launchctl verify"
|
||||
elif [[ "${DRY_RUN:-0}" == "1" ]]; then
|
||||
ok "DRY_RUN=1 — skipping launchctl list verify (test mode)"
|
||||
else
|
||||
if launchctl list 2>/dev/null | grep -q "com.iai-mcp.daemon"; then
|
||||
warn "com.iai-mcp.daemon still appears in launchctl list — manual cleanup may be needed"
|
||||
else
|
||||
ok "com.iai-mcp.daemon no longer in launchctl list"
|
||||
fi
|
||||
fi
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# done
|
||||
# ---------------------------------------------------------------------------
|
||||
step "done"
|
||||
ok "iai-mcp uninstalled — re-run scripts/install.sh to restore."
|
||||
exit 0
|
||||
143
scripts/update.sh
Executable file
|
|
@ -0,0 +1,143 @@
|
|||
#!/usr/bin/env bash
|
||||
# scripts/update.sh — pull + rebuild + restart daemon for collaborators
|
||||
#
|
||||
# Usage (from repo root or anywhere inside the clone):
|
||||
# bash scripts/update.sh
|
||||
#
|
||||
# Idempotent. Aborts on a dirty working tree so local changes are never
|
||||
# clobbered. Re-runs safely — each step detects whether it is needed.
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
# Resolve repo root no matter where the script is invoked from.
|
||||
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
|
||||
REPO_ROOT="$(cd "${SCRIPT_DIR}/.." && pwd)"
|
||||
cd "${REPO_ROOT}"
|
||||
|
||||
step() { printf '\n\033[1;34m==> %s\033[0m\n' "$*"; }
|
||||
ok() { printf ' \033[0;32m✓\033[0m %s\n' "$*"; }
|
||||
warn() { printf ' \033[0;33m!\033[0m %s\n' "$*"; }
|
||||
die() { printf '\n\033[0;31m✗ %s\033[0m\n' "$*" >&2; exit 1; }
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 0. Preconditions
|
||||
# ---------------------------------------------------------------------------
|
||||
step "preflight"
|
||||
[ -d .git ] || die "not a git repository (run from an iai-mcp clone)"
|
||||
|
||||
# Require a clean working tree — never trample local edits.
|
||||
if [ -n "$(git status --porcelain)" ]; then
|
||||
git status --short
|
||||
die "working tree is dirty. commit or stash first, then re-run."
|
||||
fi
|
||||
ok "working tree clean"
|
||||
|
||||
VENV_PY="${REPO_ROOT}/.venv/bin/python"
|
||||
[ -x "${VENV_PY}" ] || die ".venv/bin/python not found — run 'python3 -m venv .venv && .venv/bin/pip install -e .' once, then rerun"
|
||||
ok "venv detected"
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 1. git pull (fast-forward only — never merge surprises)
|
||||
# ---------------------------------------------------------------------------
|
||||
step "git pull --ff-only origin main"
|
||||
BEFORE="$(git rev-parse HEAD)"
|
||||
git fetch --quiet origin main
|
||||
git pull --ff-only --quiet origin main
|
||||
AFTER="$(git rev-parse HEAD)"
|
||||
if [ "${BEFORE}" = "${AFTER}" ]; then
|
||||
ok "already at $(git rev-parse --short HEAD) — no upstream commits"
|
||||
NOOP=1
|
||||
else
|
||||
ok "advanced $(git rev-parse --short "${BEFORE}") → $(git rev-parse --short "${AFTER}")"
|
||||
NOOP=0
|
||||
fi
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 2. Python package (editable reinstall — picks up deps or entry-point drift)
|
||||
# ---------------------------------------------------------------------------
|
||||
step "python package refresh (editable)"
|
||||
"${VENV_PY}" -m pip install --quiet -e . || die "pip install -e failed"
|
||||
ok "iai-mcp python package up to date"
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 3. TypeScript MCP wrapper
|
||||
# ---------------------------------------------------------------------------
|
||||
step "TS wrapper build"
|
||||
if [ -d mcp-wrapper ]; then
|
||||
pushd mcp-wrapper >/dev/null
|
||||
if [ -f package-lock.json ]; then
|
||||
npm ci --silent --no-audit --no-fund
|
||||
else
|
||||
npm install --silent --no-audit --no-fund
|
||||
fi
|
||||
npm run build --silent
|
||||
popd >/dev/null
|
||||
ok "mcp-wrapper/dist rebuilt"
|
||||
else
|
||||
warn "mcp-wrapper/ missing — skipping"
|
||||
fi
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 4. Global CLI symlink (idempotent — ensures ~/.local/bin/iai-mcp exists)
|
||||
# ---------------------------------------------------------------------------
|
||||
step "global CLI symlink"
|
||||
LOCAL_BIN="${HOME}/.local/bin"
|
||||
LINK_PATH="${LOCAL_BIN}/iai-mcp"
|
||||
TARGET="${REPO_ROOT}/.venv/bin/iai-mcp"
|
||||
if [ -e "${LINK_PATH}" ] && [ ! -L "${LINK_PATH}" ]; then
|
||||
warn "${LINK_PATH} exists as a regular file — skipping symlink refresh"
|
||||
else
|
||||
mkdir -p "${LOCAL_BIN}"
|
||||
ln -sf "${TARGET}" "${LINK_PATH}"
|
||||
ok "${LINK_PATH} -> ${TARGET}"
|
||||
fi
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 5. Daemon (restart only if currently running; plist drift advisory)
|
||||
# ---------------------------------------------------------------------------
|
||||
step "daemon lifecycle"
|
||||
IAI_MCP="${REPO_ROOT}/.venv/bin/iai-mcp"
|
||||
|
||||
# Check template drift using a python one-liner (avoids shell grep, which is
|
||||
# hook-blocked in this repo's dev env).
|
||||
TEMPLATE_CHECK="$("${VENV_PY}" - <<'PY'
|
||||
import pathlib, sys
|
||||
home = pathlib.Path.home()
|
||||
installed = home / "Library/LaunchAgents/com.iai-mcp.daemon.plist"
|
||||
template = pathlib.Path.cwd() / "deploy/launchd/com.iai-mcp.daemon.plist"
|
||||
if not installed.exists() or not template.exists():
|
||||
print("none"); sys.exit(0)
|
||||
# Substitute USERNAME placeholder and compare env-var + args payload.
|
||||
rendered = template.read_text().replace("{USERNAME}", home.name)
|
||||
a_env = "IAI_MCP_STORE" in installed.read_text() and home.as_posix() + "/.iai-mcp" in installed.read_text()
|
||||
b_env = "IAI_MCP_STORE" in rendered and home.as_posix() + "/.iai-mcp" in rendered
|
||||
print("drift" if a_env != b_env else "same")
|
||||
PY
|
||||
)"
|
||||
|
||||
if [ "${TEMPLATE_CHECK}" = "drift" ]; then
|
||||
warn "launchd plist template drift detected"
|
||||
warn "run: '${IAI_MCP} daemon uninstall --yes && ${IAI_MCP} daemon install --yes' to pick up the new plist"
|
||||
fi
|
||||
|
||||
if "${IAI_MCP}" daemon status >/dev/null 2>&1; then
|
||||
# daemon status exits 0 only when running
|
||||
"${IAI_MCP}" daemon stop >/dev/null 2>&1 || true
|
||||
sleep 2
|
||||
if "${IAI_MCP}" daemon start >/dev/null 2>&1; then
  ok "daemon restarted on new code"
else
  warn "daemon start returned non-zero; check 'iai-mcp daemon status'"
fi
|
||||
else
|
||||
ok "daemon not running — nothing to restart"
|
||||
fi
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 6. Summary
|
||||
# ---------------------------------------------------------------------------
|
||||
step "done"
|
||||
if [ "${NOOP}" = "1" ]; then
|
||||
ok "no-op — everything already current"
|
||||
else
|
||||
ok "updated to $(git rev-parse --short HEAD)"
|
||||
echo
|
||||
git log --oneline "${BEFORE}..${AFTER}"
|
||||
fi
|
||||
19
src/iai_mcp/__init__.py
Normal file
|
|
@ -0,0 +1,19 @@
|
|||
"""IAI-MCP -- autistic-style persistent memory MCP server."""
|
||||
from iai_mcp.types import (
|
||||
MemoryRecord,
|
||||
MemoryHit,
|
||||
RecallResponse,
|
||||
EdgeUpdate,
|
||||
ReconsolidationReceipt,
|
||||
TIER_ENUM,
|
||||
)
|
||||
|
||||
__version__ = "0.1.0"
|
||||
__all__ = [
|
||||
"MemoryRecord",
|
||||
"MemoryHit",
|
||||
"RecallResponse",
|
||||
"EdgeUpdate",
|
||||
"ReconsolidationReceipt",
|
||||
"TIER_ENUM",
|
||||
]
|
||||
245
src/iai_mcp/aaak.py
Normal file
|
|
@ -0,0 +1,245 @@
|
|||
"""AAAK index generator + English-Only storage enforcement.
|
||||
|
||||
Phase 1 constitutional rule:
|
||||
Storage is raw verbatim English always. AAAK is a RETRIEVAL VIEW only.
|
||||
|
||||
Phase 2 (superseded):
|
||||
Storage was briefly amended to raw verbatim in the user's original language.
|
||||
Every MemoryRecord carries an ISO-639-1 `language` tag retained as a column
|
||||
on legacy rows from that era.
|
||||
|
||||
Plan 05-08 (2026-04-19) restored the English-Only Brain (D-08 spirit):
|
||||
The surface (Claude) translates inbound text to English; storage holds the
|
||||
English form. The `language` column is retained for legacy compatibility;
|
||||
new records default to "en". Embedding default is bge-small-en-v1.5 (384d,
|
||||
English) per Plan 05-08.
|
||||
|
||||
This module provides:
|
||||
|
||||
- `generate_aaak_index(record)` -- builds a `W:<wing>/R:<room>/E:<entities>/T:<tags>`
|
||||
metadata string from a MemoryRecord's tier, community_id and tags. The returned
|
||||
string is guaranteed to contain none of record.literal_surface.
|
||||
|
||||
- `parse_aaak_index(idx)` -- inverse of the generator, returning a
|
||||
{wing, room, entities, tags} dict. Round-trips the entities/tags lists.
|
||||
|
||||
- `enforce_language_tagged(record, detect=False)` -- guard.
|
||||
Raises ValueError if record.language is empty and detect is False. When
|
||||
detect=True, runs langdetect on literal_surface; mutates record.language
|
||||
with the detected code if confidence >= 0.7, else raises. Empty text with
|
||||
detect=True defaults to "en" without raising.
|
||||
|
||||
- `enforce_english_raw(record)` -- shim retained for backward compat.
|
||||
Rejects literal_surface containing Cyrillic / Hiragana / Katakana / CJK
codepoints unless a `raw:<lang>` tag is present. It does not consult
record.language or delegate to enforce_language_tagged.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from iai_mcp.types import MemoryRecord
|
||||
|
||||
# constitutional: confidence threshold below which langdetect refuses.
|
||||
LANGDETECT_MIN_CONFIDENCE = 0.7
|
||||
|
||||
|
||||
# --------------------------------------------------------------- script regex
|
||||
# Covered: Cyrillic (Russian et al), Hiragana, Katakana, CJK Unified Ideographs.
|
||||
# Sufficient for the three scripts the project explicitly documents as
# needing `raw:<lang>` handling. Extend the alphabet list only if a genuine
# storage bug surfaces -- don't speculate.
|
||||
CYRILLIC = re.compile(r"[\u0400-\u04FF]") # U+0400..U+04FF
|
||||
HIRAGANA_KATAKANA = re.compile(r"[\u3040-\u30FF]") # U+3040..U+30FF
|
||||
CJK = re.compile(r"[\u4E00-\u9FFF]") # U+4E00..U+9FFF Unified Ideographs
|
||||
|
||||
|
||||
# ---------------------------------------------- tier -> wing alphabet (TOK-10)
|
||||
_TIER_TO_WING = {
|
||||
"working": "W",
|
||||
"episodic": "E",
|
||||
"semantic": "S",
|
||||
"procedural": "P",
|
||||
"parametric": "\u03a0", # Pi glyph -- distinct from Latin P
|
||||
}
|
||||
|
||||
|
||||
def _wing_from_tier(tier: str) -> str:
|
||||
return _TIER_TO_WING.get(tier, "unknown")
|
||||
|
||||
|
||||
def _room_from_community(record: "MemoryRecord") -> str:
|
||||
"""First 8 chars of community UUID; "unknown" if community not yet assigned.
|
||||
|
||||
Plan 02 assigns community_id; Plan 03 L0/L1 pinned records may still have
|
||||
community_id=None (they're pinned by UUID, not graph position).
|
||||
"""
|
||||
if record.community_id is None:
|
||||
return "unknown"
|
||||
return str(record.community_id)[:8]
|
||||
|
||||
|
||||
def _entities_from_tags(tags: list[str]) -> str:
|
||||
"""Up to 10 tags prefixed `entity:` (prefix stripped), joined by `,`.
|
||||
|
||||
`"-"` if none found, so the generator output has a stable shape with
|
||||
exactly 3 `/` separators regardless of tag content.
|
||||
"""
|
||||
ents = [t[len("entity:"):] for t in tags if t.startswith("entity:")][:10]
|
||||
if not ents:
|
||||
return "-"
|
||||
return ",".join(ents)
|
||||
|
||||
|
||||
def _tagline(tags: list[str]) -> str:
|
||||
"""Up to 10 non-entity tags joined by `,`. `"-"` if none."""
|
||||
non_ents = [t for t in tags if not t.startswith("entity:")][:10]
|
||||
if not non_ents:
|
||||
return "-"
|
||||
return ",".join(non_ents)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------- public API
|
||||
|
||||
|
||||
def generate_aaak_index(record: "MemoryRecord") -> str:
|
||||
"""Build the AAAK index string for a record (D-08, TOK-10).
|
||||
|
||||
Format: `W:<wing>/R:<room>/E:<entities>/T:<tags>`
|
||||
|
||||
Guarantees:
|
||||
- Exactly 3 `/` separators, provided no tag or entity itself contains `/`.
|
||||
- Contains NO substring of `record.literal_surface`. Verified by
|
||||
`tests/test_aaak.py::test_aaak_index_does_not_contain_literal_surface`.
|
||||
- Deterministic: same record -> same index on repeat calls.
|
||||
"""
|
||||
wing = _wing_from_tier(record.tier)
|
||||
room = _room_from_community(record)
|
||||
entities = _entities_from_tags(record.tags)
|
||||
tags = _tagline(record.tags)
|
||||
return f"W:{wing}/R:{room}/E:{entities}/T:{tags}"
|
||||
|
||||
|
||||
def parse_aaak_index(idx: str) -> dict[str, list[str]]:
|
||||
"""Inverse of generate_aaak_index. Returns wing/room/entities/tags lists.
|
||||
|
||||
Each value is a list (even wing/room which are single strings) so callers
|
||||
have a uniform shape. Unknown keys are ignored. Empty-value `-` becomes [].
|
||||
"""
|
||||
out: dict[str, list[str]] = {
|
||||
"wing": [],
|
||||
"room": [],
|
||||
"entities": [],
|
||||
"tags": [],
|
||||
}
|
||||
key_map = {"W": "wing", "R": "room", "E": "entities", "T": "tags"}
|
||||
for seg in idx.split("/"):
|
||||
if ":" not in seg:
|
||||
continue
|
||||
k, _, v = seg.partition(":")
|
||||
if k not in key_map:
|
||||
continue
|
||||
name = key_map[k]
|
||||
if v == "-" or v == "":
|
||||
out[name] = []
|
||||
else:
|
||||
# Wing/Room are single-token; entities/tags are comma-separated.
|
||||
if name in ("wing", "room"):
|
||||
out[name] = [v]
|
||||
else:
|
||||
out[name] = v.split(",")
|
||||
return out
|
||||
|
||||
|
||||
def enforce_language_tagged(
|
||||
record: "MemoryRecord",
|
||||
*,
|
||||
detect: bool = False,
|
||||
) -> None:
|
||||
"""D-08a constitutional: every Phase-2+ record MUST carry a language tag.
|
||||
|
||||
When record.language is a non-empty string, the guard passes unconditionally
|
||||
(the column is retained for legacy compatibility; the English-Only Brain
|
||||
pivot means new records default to "en").
|
||||
|
||||
When record.language is empty/missing and detect is False, raises
|
||||
ValueError("constitutional violation: ...") because storage is
|
||||
tag-addressable -- not defaulting to English.
|
||||
|
||||
When detect=True and language is empty:
|
||||
- If literal_surface is empty/whitespace, sets language="en" and returns.
|
||||
- Else runs langdetect; if top candidate has probability >= 0.7 (D-08a
|
||||
threshold), mutates record.language with the detected code.
|
||||
- If langdetect fails or confidence < 0.7, raises ValueError.
|
||||
|
||||
The seed for langdetect's DetectorFactory is fixed at 42 so the same text
|
||||
always produces the same language code across runs.
|
||||
"""
|
||||
if record.language and isinstance(record.language, str) and record.language.strip():
|
||||
return # already tagged; accept
|
||||
|
||||
if not detect:
|
||||
raise ValueError(
|
||||
"constitutional violation: record.language is required. "
|
||||
"Pass detect=True to auto-detect via langdetect."
|
||||
)
|
||||
|
||||
text = record.literal_surface or ""
|
||||
if not text.strip():
|
||||
record.language = "en" # empty -> default en
|
||||
return
|
||||
|
||||
try:
|
||||
from langdetect import DetectorFactory, detect_langs
|
||||
DetectorFactory.seed = 42 # determinism
|
||||
candidates = detect_langs(text)
|
||||
except Exception as e:
|
||||
raise ValueError(
|
||||
f"constitutional violation: langdetect failed on record text: {e}"
|
||||
) from e
|
||||
|
||||
if not candidates or candidates[0].prob < LANGDETECT_MIN_CONFIDENCE:
|
||||
top = candidates[0] if candidates else None
|
||||
raise ValueError(
|
||||
f"constitutional violation: langdetect confidence too low "
|
||||
f"(<{LANGDETECT_MIN_CONFIDENCE}); top candidate={top}"
|
||||
)
|
||||
|
||||
record.language = candidates[0].lang
|
||||
|
||||
|
||||
def enforce_english_raw(record: "MemoryRecord") -> None:
|
||||
"""Phase 1 shim -- preserves the original script-based guard.
|
||||
|
||||
Semantics (retained byte-for-byte for backward compatibility):
|
||||
- `raw:<lang>` tag present on record -> accept (explicit raw capture)
|
||||
- literal_surface contains Cyrillic / Hiragana / Katakana / CJK codepoints
|
||||
and no `raw:<lang>` tag -> raise ValueError("constitutional ...")
|
||||
- else -> accept
|
||||
|
||||
The guard is exposed as `enforce_language_tagged`. Downstream
|
||||
plans that want native-language storage should import that directly
|
||||
instead of this shim. This function is kept so the test fixtures
|
||||
(tests/test_aaak.py, tests/test_provenance.py) continue to assert the
|
||||
exact rejection behaviour they documented.
|
||||
"""
|
||||
text = record.literal_surface or ""
|
||||
has_non_english = bool(
|
||||
CYRILLIC.search(text)
|
||||
or HIRAGANA_KATAKANA.search(text)
|
||||
or CJK.search(text)
|
||||
)
|
||||
if not has_non_english:
|
||||
return
|
||||
|
||||
# Caller opted in via `raw:<lang>` tag -> accept.
|
||||
if any(t.startswith("raw:") for t in record.tags):
|
||||
return
|
||||
|
||||
raise ValueError(
|
||||
"constitutional violation: literal_surface contains non-English "
|
||||
"characters; storage must be English raw verbatim (D-08, TOK-10). "
|
||||
"Add 'raw:<lang>' tag to declare explicit raw capture."
|
||||
)
|
||||
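The `W:<wing>/R:<room>/E:<entities>/T:<tags>` format described in aaak.py can be exercised with a short round trip. The sketch below is a condensed re-implementation for illustration, not the module itself: `FakeRecord`, `generate_index`, and `parse_index` are hypothetical names, and `FakeRecord` models only the three fields the index reads (the real `MemoryRecord` lives in iai_mcp.types and is not shown here).

```python
from dataclasses import dataclass, field
from typing import Optional

TIER_TO_WING = {
    "working": "W",
    "episodic": "E",
    "semantic": "S",
    "procedural": "P",
    "parametric": "\u03a0",  # Pi glyph, distinct from Latin P
}


@dataclass
class FakeRecord:
    """Hypothetical stand-in for MemoryRecord (only the fields the index reads)."""
    tier: str
    community_id: Optional[str] = None
    tags: list = field(default_factory=list)


def generate_index(rec: FakeRecord) -> str:
    wing = TIER_TO_WING.get(rec.tier, "unknown")
    room = "unknown" if rec.community_id is None else str(rec.community_id)[:8]
    ents = [t[len("entity:"):] for t in rec.tags if t.startswith("entity:")][:10]
    rest = [t for t in rec.tags if not t.startswith("entity:")][:10]
    e = ",".join(ents) if ents else "-"
    t = ",".join(rest) if rest else "-"
    return f"W:{wing}/R:{room}/E:{e}/T:{t}"


def parse_index(idx: str) -> dict:
    names = {"W": "wing", "R": "room", "E": "entities", "T": "tags"}
    out = {name: [] for name in names.values()}
    for seg in idx.split("/"):
        key, sep, value = seg.partition(":")
        if not sep or key not in names or value in ("", "-"):
            continue  # unknown key, malformed segment, or empty placeholder
        name = names[key]
        # Wing and room are single tokens; entities and tags are comma-separated.
        out[name] = [value] if name in ("wing", "room") else value.split(",")
    return out


rec = FakeRecord(
    tier="episodic",
    community_id="3f2a9c1e-0000-4000-8000-000000000000",
    tags=["entity:alice", "entity:bob", "project", "todo"],
)
idx = generate_index(rec)
parsed = parse_index(idx)
print(idx)
print(parsed)
```

The round trip holds only while tags contain neither `/` nor `,`, since both characters double as separators in the index string.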
155
src/iai_mcp/batch.py
Normal file
|
|
@ -0,0 +1,155 @@
|
|||
"""TOK-09 Batch API consolidation (Plan 02-04 Task 3, D-29).
|
||||
|
||||
D-29 (unified daily process): when Tier 1 is enabled + credentials + budget
|
||||
+ rate-limit all green (D-GUARD ladder via should_call_llm), submit a batch
|
||||
to Anthropic's Batch API at 50% discount vs synchronous calls. Falls back
|
||||
to Tier 0 stub results on any gate failure or SDK absence.
|
||||
|
||||
Plan 02-04 scope: the D-GUARD gate + budget side-effect + llm_health event
|
||||
emission are load-bearing. The actual anthropic.batches.create call is
|
||||
scaffolded behind a lazy import; when the SDK surface differs from what the
|
||||
Python core expects (e.g. version skew), the stub returns an empty result
|
||||
list and records llm_health fallback. Plan 03 / future phases own the real
|
||||
wire-up once the SDK API settles.
|
||||
|
||||
Pricing model:
|
||||
- Haiku 4.5 approx sync cost: prompt $0.25 / 1M tokens + output $1.25 / 1M
|
||||
- Batch discount: 50% off sync cost.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
from iai_mcp.events import write_event
|
||||
from iai_mcp.guard import BudgetLedger, RateLimitLedger, should_call_llm
|
||||
|
||||
|
||||
# 50% discount vs sync tier.
|
||||
BATCH_DISCOUNT = 0.5
|
||||
|
||||
# Scope: we do not poll in-process. The real Batch API can take
|
||||
# up to ~24h. The dispatch path is "submit -> return (True, 'ok', stub)" with
|
||||
# the actual results arriving via a future polling job. Tests assert the
|
||||
# gate + side-effects; the stub list is empty in Phase 2.
|
||||
BATCH_POLL_TIMEOUT_SEC = 60
|
||||
|
||||
# Haiku 4.5 approximate sync pricing (USD per 1M tokens).
|
||||
_HAIKU_PROMPT_USD_PER_MTOK = 0.25
|
||||
_HAIKU_OUTPUT_USD_PER_MTOK = 1.25
|
||||
|
||||
|
||||
def _sync_tier_cost(prompt_tokens: int, output_tokens: int) -> float:
|
||||
"""Approximate sync-tier USD cost for Haiku 4.5.
|
||||
|
||||
Consolidation uses Haiku 4.5. Pricing is approximate and may
|
||||
drift; the gate uses this only for budget-cap decisions (D-GUARD step
|
||||
3+4), never for billing reconciliation.
|
||||
"""
|
||||
p = (float(prompt_tokens) / 1_000_000.0) * _HAIKU_PROMPT_USD_PER_MTOK
|
||||
o = (float(output_tokens) / 1_000_000.0) * _HAIKU_OUTPUT_USD_PER_MTOK
|
||||
return float(p + o)
|
||||
|
||||
|
||||
def _aggregate_estimated_usd(tasks: list[dict]) -> float:
|
||||
total_sync = 0.0
|
||||
for t in tasks:
|
||||
total_sync += _sync_tier_cost(
|
||||
int(t.get("prompt_tok", 0)),
|
||||
int(t.get("output_tok", 0)),
|
||||
)
|
||||
return total_sync * BATCH_DISCOUNT
|
||||
|
||||
|
||||
def submit_batch_consolidation(
|
||||
store,
|
||||
tasks: list[dict],
|
||||
budget: BudgetLedger,
|
||||
rate: RateLimitLedger,
|
||||
llm_enabled: bool = True,
|
||||
) -> tuple[bool, str, list[dict]]:
|
||||
"""Submit a batch of consolidation tasks to the Anthropic Batch API.
|
||||
|
||||
Returns (ok, reason, results). On any D-GUARD fallback, ok=False and
|
||||
results is an empty list; the caller falls back to local Tier 0 output.
|
||||
|
||||
Gate ordering (D-GUARD):
|
||||
1. llm_enabled toggle
|
||||
2. API key present
|
||||
3. Budget daily + monthly caps (can_spend)
|
||||
4. Rate-limit cooldown (blocked while the last 429 is under 15 min old)
|
||||
5. SDK import path
|
||||
6. Real batch submission (Plan 02-04 stub; see module docstring)
|
||||
"""
|
||||
has_key = bool(os.environ.get("ANTHROPIC_API_KEY"))
|
||||
estimated_usd = _aggregate_estimated_usd(tasks)
|
||||
|
||||
ok, reason = should_call_llm(
|
||||
budget=budget,
|
||||
rate=rate,
|
||||
llm_enabled=llm_enabled,
|
||||
has_api_key=has_key,
|
||||
estimated_usd=estimated_usd,
|
||||
)
|
||||
if not ok:
|
||||
write_event(
|
||||
store,
|
||||
kind="llm_health",
|
||||
data={
|
||||
"component": "batch_consolidation",
|
||||
"tier": "fallback",
|
||||
"reason": reason,
|
||||
"task_count": len(tasks),
|
||||
"estimated_usd": estimated_usd,
|
||||
},
|
||||
severity="warning",
|
||||
)
|
||||
return False, reason, []
|
||||
|
||||
# Eligible path: lazy import the SDK. On ImportError or any runtime
|
||||
# failure, log critical and fall back. This is also how the current Plan
|
||||
# 02-04 scaffold returns -- the real batch submission is stubbed (the
|
||||
# SDK surface for batches.create has changed across minor versions).
|
||||
try:
|
||||
import anthropic # noqa: F401
|
||||
except Exception as exc:
|
||||
write_event(
|
||||
store,
|
||||
kind="llm_health",
|
||||
data={
|
||||
"component": "batch_consolidation",
|
||||
"tier": "fallback",
|
||||
"error": f"import anthropic: {exc}",
|
||||
},
|
||||
severity="critical",
|
||||
)
|
||||
return False, f"SDK unavailable: {exc}", []
|
||||
|
||||
# H-02 FIX (Phase 2 gap closure): budget stays untouched and
|
||||
# effective_tier stays tier0 until a REAL successful anthropic.batches.create
|
||||
# response lands. The previous behaviour called budget.record_spend + returned
|
||||
# (True, "ok", []), which caused run_heavy_consolidation to flip
|
||||
# effective_tier=tier1 and debit the BudgetLedger on a stub producing zero
|
||||
# output -- corrupts D-GUARD audit honesty + cost accounting.
|
||||
#
|
||||
# Real SDK wire-up is deferred. Until then the scaffold is honestly
|
||||
# documented via an info-severity llm_health event so `iai-mcp audit`
|
||||
# observers can see the gap explicitly.
|
||||
write_event(
|
||||
store,
|
||||
kind="llm_health",
|
||||
data={
|
||||
"component": "batch_consolidation",
|
||||
"tier": "fallback",
|
||||
"task_count": len(tasks),
|
||||
"estimated_usd": estimated_usd,
|
||||
"note": (
|
||||
"Plan 02-06 disables the scaffold-true return; "
|
||||
"real anthropic.batches.create wire-up is Phase 3. Budget "
|
||||
"stays untouched and effective_tier stays tier0 until a "
|
||||
"real successful SDK response lands."
|
||||
),
|
||||
},
|
||||
severity="info",
|
||||
)
|
||||
return False, "stub: batch API not yet wired", []
|
||||
301
src/iai_mcp/bedtime.py
Normal file
|
|
@ -0,0 +1,301 @@
|
|||
"""Phase 4 -- bedtime wind-down detection (DAEMON-06, D-08/D-09/D-11).
|
||||
|
||||
Dual-gate bedtime suggestion emitter:
|
||||
Gate A: wind-down phrase regex match per language (D-11, 8 languages)
|
||||
Gate B: late in learned quiet window (inside OR within 30min of start, D-09)
|
||||
|
||||
When BOTH gates pass, `detect_wind_down` returns a small dict that `core.py`
|
||||
injects into `memory_recall` responses as `sleep_suggestion`. Claude (the
|
||||
LLM in the active session) decides social framing -- our code NEVER hardcodes
|
||||
user-facing phrasing.
|
||||
|
||||
Constitutional guard:
|
||||
- C2: this module does NOT initiate sleep. It only suggests. The only path
|
||||
that moves the daemon into SLEEP is `core.handle_initiate_sleep_mode`
|
||||
with `consent=True`. No auto-start in this file.
|
||||
- C5: this module is read-only w.r.t. records. It reads `cue`
|
||||
strings; it NEVER mutates `literal_surface`.
|
||||
- C6: no fcntl, no daemon state mutation. All logic is pure in-process.
|
||||
|
||||
Patterns mirror `shield.py`'s 8-language dict style (same language set:
|
||||
en/ru/ja/ar/de/fr/es/zh per global-product mandate). Latin-script
|
||||
languages use `\b` word boundaries; CJK / Arabic use character-class
|
||||
proximity and whitespace-tolerant forms since Unicode `\b` is unreliable
|
||||
across scripts.
|
||||
|
||||
ReDoS-safe: quantifiers apply only to whitespace runs and optional elements. No nested `(.+)+`
|
||||
constructs, no `.*.*`. Stress-tested against 10KB of "a"s under 100ms total.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from datetime import datetime
|
||||
from typing import Optional, Tuple
|
||||
from zoneinfo import ZoneInfo
|
||||
|
||||
from iai_mcp.quiet_window import BUCKET_MINUTES
|
||||
|
||||
|
||||
# ------------------------------------------------------------ constants
|
||||
|
||||
# dual-gate: within this many minutes of the learned quiet-window start
|
||||
# also counts as "late" (a user who says "good night" 25 minutes before their
|
||||
# usual quiet window is winding down, not speaking rhetorically).
|
||||
WIND_DOWN_GATE_MINUTES_BEFORE: int = 30
|
||||
|
||||
|
||||
# ------------------------------------------------------------ per-language regex
|
||||
|
||||
# English wind-down phrases. Case-insensitive match.
|
||||
WIND_DOWN_EN: list[str] = [
|
||||
r"\bgood\s*night\b",
|
||||
r"\bgoodnight\b",
|
||||
r"\bnight[,!.]?\s*$",
|
||||
r"\bI'?m\s+(heading|going)\s+to\s+bed\b",
|
||||
r"\b(time\s+(to|for)\s+bed|bedtime)\b",
|
||||
r"\bI'?m\s+(tired|exhausted|sleepy)\b",
|
||||
r"\b(catch\s+you\s+tomorrow|see\s+you\s+tomorrow)\b",
|
||||
r"\blet'?s\s+(continue|pick\s+up)\s+tomorrow\b",
|
||||
r"\bgoing\s+to\s+sleep\b",
|
||||
]
|
||||
|
||||
# Russian (same 8-language set as shield.py).
|
||||
WIND_DOWN_RU: list[str] = [
|
||||
r"спокойной\s+ночи",
|
||||
r"пойду\s+(спать|в\s+постель)",
|
||||
r"(я\s+)?(устал|устала|вымотан[аы]?|засыпаю)",
|
||||
r"пора\s+(спать|ложиться)",
|
||||
r"до\s+завтра",
|
||||
r"давай\s+завтра",
|
||||
r"ухожу\s+спать",
|
||||
r"(окей|ок|ладно),?\s+сплю",
|
||||
r"ложусь",
|
||||
]
|
||||
|
||||
# Japanese -- sleep cues + "see you tomorrow". No \b and no lookaround;
|
||||
# patterns match kana / kanji directly with optional whitespace between them.
|
||||
WIND_DOWN_JA: list[str] = [
|
||||
r"お\s*や\s*す\s*み(なさい)?", # おやすみ / おやすみなさい
|
||||
r"寝\s*ます", # 寝ます
|
||||
r"(眠|ねむ)い", # 眠い / ねむい
|
||||
r"(寝る|ねる)(ね|よ|わ)?", # 寝る / ねる / 寝るね
|
||||
r"また\s*(明日|あした)", # また明日
|
||||
r"(疲|つか)れた", # 疲れた / つかれた
|
||||
r"ベッド\s*に\s*(入る|はいる)", # ベッドに入る
|
||||
]
|
||||
|
||||
# Arabic -- RTL script; use direct patterns.
|
||||
WIND_DOWN_AR: list[str] = [
|
||||
r"تصبح\s+على\s+خير",
|
||||
r"ليلة\s+سعيدة",
|
||||
r"أنا\s+(ذاهب|ذاهبة)\s+(للنوم|إلى\s+النوم)",
|
||||
r"أنا\s+(متعب|متعبة|تعبان[ةه]?)",
|
||||
r"سأنام",
|
||||
r"وقت\s+النوم",
|
||||
r"إلى\s+(الغد|اللقاء\s+غدا)",
|
||||
]
|
||||
|
||||
WIND_DOWN_DE: list[str] = [
|
||||
r"\bgute\s+nacht\b",
|
||||
r"\bgn8\b",
|
||||
r"\bich\s+gehe\s+(jetzt\s+)?(ins\s+bett|schlafen)\b",
|
||||
r"\b(ich\s+bin\s+)?(müde|kaputt|fertig)\b",
|
||||
r"\bschlafenszeit\b",
|
||||
r"\bbis\s+morgen\b",
|
||||
r"\blass\s+uns\s+morgen\s+weitermachen\b",
|
||||
]
|
||||
|
||||
WIND_DOWN_FR: list[str] = [
|
||||
r"\bbonne\s+nuit\b",
|
||||
r"\bje\s+(vais|pars)\s+(me\s+coucher|dormir)\b",
|
||||
r"\b(je\s+suis\s+)?(fatigu[ée]|[ée]puis[ée])\b",
|
||||
r"\b(il\s+est\s+)?l'?heure\s+de\s+(dormir|me\s+coucher)\b",
|
||||
r"\b[aà]\s+demain\b",
|
||||
r"\bon\s+reprend\s+demain\b",
|
||||
]
|
||||
|
||||
WIND_DOWN_ES: list[str] = [
|
||||
r"\bbuenas\s+noches\b",
|
||||
r"\bme\s+voy\s+a\s+(dormir|la\s+cama|descansar)\b",
|
||||
r"\b(estoy\s+)?(cansad[oa]|agotad[oa])\b",
|
||||
r"\bhora\s+de\s+dormir\b",
|
||||
r"\bhasta\s+ma[ñn]ana\b",
|
||||
r"\bseguimos\s+ma[ñn]ana\b",
|
||||
]
|
||||
|
||||
WIND_DOWN_ZH: list[str] = [
|
||||
r"晚\s*安", # 晚安
|
||||
r"我\s*(要|去)\s*睡\s*(觉|了)", # 我要睡觉 / 我去睡了
|
||||
r"累\s*了", # 累了
|
||||
r"(该|到)\s*睡\s*(觉)?\s*了", # 该睡了 / 到睡觉了
|
||||
r"明\s*天\s*见", # 明天见
|
||||
r"明\s*天\s*继\s*续", # 明天继续
|
||||
]
|
||||
|
||||
# language coverage: exactly the 8 languages shield.py supports.
|
||||
WIND_DOWN_BY_LANG: dict[str, list[str]] = {
|
||||
"en": WIND_DOWN_EN,
|
||||
"ru": WIND_DOWN_RU,
|
||||
"ja": WIND_DOWN_JA,
|
||||
"ar": WIND_DOWN_AR,
|
||||
"de": WIND_DOWN_DE,
|
||||
"fr": WIND_DOWN_FR,
|
||||
"es": WIND_DOWN_ES,
|
||||
"zh": WIND_DOWN_ZH,
|
||||
}
|
||||
|
||||
# Pre-compile every pattern once. IGNORECASE is safe on non-Latin scripts
|
||||
# (lowercasing is identity-preserving for CJK; Cyrillic handles cleanly).
|
||||
_COMPILED: dict[str, list[re.Pattern]] = {
|
||||
lang: [re.compile(p, re.IGNORECASE) for p in pats]
|
||||
for lang, pats in WIND_DOWN_BY_LANG.items()
|
||||
}
|
||||
|
||||
# Authoritative language set -- downstream greps against this constant.
|
||||
WIND_DOWN_LANGUAGES_SUPPORTED: frozenset[str] = frozenset(WIND_DOWN_BY_LANG.keys())
|
||||
|
||||
|
||||
# ------------------------------------------------------------ gate A: phrase match
|
||||
|
||||
|
||||
def detect_wind_down_phrase(cue: str, language: str) -> Tuple[bool, str]:
|
||||
"""Gate A: does the cue contain a wind-down phrase?
|
||||
|
||||
Policy mirrors shield.py: primary language is always tried; ALSO try
|
||||
English regardless of `language` because users switch languages mid-sentence
|
||||
("ok, going to sleep" in a Russian conversation is still a wind-down
|
||||
signal). We do NOT fall back to any other language beyond EN -- that
|
||||
would explode the FPR.
|
||||
|
||||
Returns (matched, matched_pattern). matched_pattern is the source regex
|
||||
string (not the compiled object) for audit/logging purposes.
|
||||
"""
|
||||
if not cue:
|
||||
return False, ""
|
||||
|
||||
# Primary language first (this pass also covers "en" when it is primary).
|
||||
for p in _COMPILED.get(language or "", []):
|
||||
if p.search(cue):
|
||||
return True, p.pattern
|
||||
|
||||
# Always also try EN if we haven't already.
|
||||
if language != "en":
|
||||
for p in _COMPILED["en"]:
|
||||
if p.search(cue):
|
||||
return True, p.pattern
|
||||
|
||||
return False, ""
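A minimal sketch of the Gate A lookup order described above, assuming a two-language subset of the pattern tables. `PATTERNS`, `COMPILED`, and `match_phrase` are hypothetical names used only for illustration; the regexes are copied from `WIND_DOWN_EN` and `WIND_DOWN_RU`.

```python
import re

# Hypothetical two-language subset of the tables above, for illustration only.
PATTERNS = {
    "en": [r"\bgood\s*night\b", r"\bgoing\s+to\s+sleep\b"],
    "ru": [r"спокойной\s+ночи", r"до\s+завтра"],
}
COMPILED = {
    lang: [re.compile(p, re.IGNORECASE) for p in pats]
    for lang, pats in PATTERNS.items()
}


def match_phrase(cue: str, language: str) -> tuple[bool, str]:
    """Try the primary language first, then English as the only fallback."""
    if not cue:
        return False, ""
    candidates = list(COMPILED.get(language or "", []))
    if language != "en":
        candidates += COMPILED["en"]
    for pattern in candidates:
        if pattern.search(cue):
            return True, pattern.pattern
    return False, ""


# A Russian session with an English sign-off still matches via the EN fallback.
print(match_phrase("ок, going to sleep", "ru"))
# A Russian phrase in an English session does not match: no reverse fallback.
print(match_phrase("до завтра", "en"))
```

The asymmetry is deliberate in this sketch: only English is tried as a secondary language, which keeps the number of patterns evaluated per cue small.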
|
||||
|
||||
|
||||
# ------------------------------------------------------------ gate B: late in quiet window
|
||||
|
||||
|
||||
def is_late_in_quiet_window(
|
||||
window: Optional[Tuple[int, int]],
|
||||
now: datetime,
|
||||
tz: ZoneInfo,
|
||||
) -> bool:
|
||||
"""Gate B: is `now` inside the quiet window OR within 30min of its start?
|
||||
|
||||
`window` is the (start_bucket, duration_buckets) pair emitted by
|
||||
`quiet_window.learn_quiet_window` -- start_bucket is an index into the
|
||||
48-bucket local-time day (30min each) and duration is the number of
|
||||
buckets. Returns False if no window is set (learn_quiet_window returned
|
||||
None, caller should be using the bootstrap 2h-idle trigger instead).
|
||||
|
||||
Wrap-around: a window starting at 22:00 and lasting 8h crosses local
|
||||
midnight; "inside" then means `cur >= start_minutes` OR `cur < end_minutes`.
|
||||
"""
|
||||
if not window:
|
||||
return False
|
||||
|
||||
start_bucket, duration = window
|
||||
try:
|
||||
now_local = now.astimezone(tz)
|
||||
except Exception:
|
||||
# DST edge or bad tz -- fail closed (don't suggest bedtime on
|
||||
# malformed input).
|
||||
return False
|
||||
|
||||
cur_minutes = now_local.hour * 60 + now_local.minute
|
||||
start_minutes = start_bucket * BUCKET_MINUTES
|
||||
end_minutes = (start_bucket + duration) * BUCKET_MINUTES
|
||||
|
||||
# Handle wrap-around midnight explicitly.
|
||||
if end_minutes > 24 * 60:
|
||||
wrapped_end = end_minutes - 24 * 60
|
||||
inside = cur_minutes >= start_minutes or cur_minutes < wrapped_end
|
||||
else:
|
||||
inside = start_minutes <= cur_minutes < end_minutes
|
||||
|
||||
if inside:
|
||||
return True
|
||||
|
||||
# Within 30min of start (cyclic -- a 21:45 cue for a 22:00 window counts).
|
||||
minutes_until_start = (start_minutes - cur_minutes) % (24 * 60)
|
||||
return 0 <= minutes_until_start <= WIND_DOWN_GATE_MINUTES_BEFORE
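The wrap-around arithmetic is easier to follow with concrete numbers. The sketch below reimplements the same checks on plain hour/minute inputs, assuming `BUCKET_MINUTES` is 30 (the docstring describes a 48-bucket day). `is_late` and `GATE_MINUTES_BEFORE` are hypothetical stand-ins, not the module's API.

```python
BUCKET_MINUTES = 30        # assumed: 48 buckets per local day
GATE_MINUTES_BEFORE = 30   # mirrors WIND_DOWN_GATE_MINUTES_BEFORE
DAY = 24 * 60


def is_late(window: tuple[int, int], hour: int, minute: int) -> bool:
    """Return True inside the window or within 30 minutes before its start."""
    start_bucket, duration = window
    cur = hour * 60 + minute
    start = start_bucket * BUCKET_MINUTES
    end = (start_bucket + duration) * BUCKET_MINUTES
    if end > DAY:
        # Window crosses midnight: inside means after start OR before wrapped end.
        inside = cur >= start or cur < end - DAY
    else:
        inside = start <= cur < end
    if inside:
        return True
    # Cyclic distance from now to the next window start.
    return (start - cur) % DAY <= GATE_MINUTES_BEFORE


window = (44, 16)  # bucket 44 = 22:00, 16 buckets = 8 hours, so it ends at 06:00
for hh, mm in [(21, 45), (21, 29), (3, 0), (6, 0), (12, 0)]:
    print(f"{hh:02d}:{mm:02d}", is_late(window, hh, mm))
```

With this window, 21:45 passes through the 30-minute lead-in, 03:00 passes through the wrapped branch, and 06:00 fails because the end bound is exclusive.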
|
||||
|
||||
|
||||
# ------------------------------------------------------------ dual-gate detector
|
||||
|
||||
|
||||
def detect_wind_down(
|
||||
cue: str,
|
||||
language: str,
|
||||
state: dict,
|
||||
now: datetime,
|
||||
tz: ZoneInfo,
|
||||
) -> Optional[dict]:
|
||||
"""D-09 dual-gate bedtime detector.
|
||||
|
||||
Returns a `sleep_suggestion` dict when BOTH gates pass:
|
||||
Gate A: wind-down phrase match (primary lang + EN fallback)
|
||||
Gate B: late-in-learned-quiet-window (inside OR within 30min of start)
|
||||
|
||||
Returns None otherwise -- never a partial / fuzzy signal. Downstream
|
||||
consumers (`core._inject_sleep_suggestion`) key on the presence of the
|
||||
key, so None means the response simply does not carry `sleep_suggestion`.
|
||||
|
||||
Payload shape (small, no PII beyond the matched regex pattern):
|
||||
{
|
||||
"message_hint": "user_wind_down_detected",
|
||||
"matched_pattern": str,
|
||||
"quiet_window_start_bucket": int,
|
||||
"quiet_window_duration": int,
|
||||
}
|
||||
"""
|
||||
matched, pattern = detect_wind_down_phrase(cue, language)
|
||||
if not matched:
|
||||
return None
|
||||
|
||||
window = state.get("quiet_window") if isinstance(state, dict) else None
|
||||
if not window:
|
||||
return None
|
||||
if not is_late_in_quiet_window(window, now, tz):
|
||||
return None
|
||||
|
||||
start_bucket, duration = window
|
||||
return {
|
||||
"message_hint": "user_wind_down_detected",
|
||||
"matched_pattern": pattern,
|
||||
"quiet_window_start_bucket": int(start_bucket),
|
||||
"quiet_window_duration": int(duration),
|
||||
}
|
||||
|
||||
|
||||
__all__ = [
|
||||
"WIND_DOWN_AR",
|
||||
"WIND_DOWN_BY_LANG",
|
||||
"WIND_DOWN_DE",
|
||||
"WIND_DOWN_EN",
|
||||
"WIND_DOWN_ES",
|
||||
"WIND_DOWN_FR",
|
||||
"WIND_DOWN_GATE_MINUTES_BEFORE",
|
||||
"WIND_DOWN_JA",
|
||||
"WIND_DOWN_LANGUAGES_SUPPORTED",
|
||||
"WIND_DOWN_RU",
|
||||
"WIND_DOWN_ZH",
|
||||
"detect_wind_down",
|
||||
"detect_wind_down_phrase",
|
||||
"is_late_in_quiet_window",
|
||||
]
|
||||
179
src/iai_mcp/camouflaging.py
Normal file
|
|
@ -0,0 +1,179 @@
|
|||
"""Plan 03-03 — camouflaging detector + register relaxer (ecological self-regulation).
|
||||
|
||||
Constitutional anchor:
|
||||
- Observes the user's SURFACE formality over a weekly sliding 5-point window.
|
||||
- On a sustained over-formal trajectory, adjusts OUR register (the 14th profile
|
||||
knob `camouflaging_relaxation`). NEVER pushes the user to change. NEVER models
|
||||
user internal-state (Cook 2021 / Raymaker 2020 — masking is out-of-scope).
|
||||
- Chapman 2021 ecological self-regulation framing: the system relaxes ITS OWN
|
||||
response register so the user does not have to match ours.
|
||||
|
||||
Detection (D-AUTIST13-03): sliding 5-point weekly window. Trigger condition:
|
||||
linear-regression slope > 0.05/week AND current mean > 0.6. Both must hold.
|
||||
|
||||
Event kinds emitted (new in Phase 3):
|
||||
- `formality_score_weekly` — weekly aggregate of the user's formality scores.
|
||||
- `camouflaging_detected` — the detector fired (over-formal trajectory confirmed).
|
||||
- `register_relaxed` — OUR `camouflaging_relaxation` knob was bumped UP (toward
|
||||
informal register in OUR responses).
|
||||
|
||||
Knob semantics: `camouflaging_relaxation` in [0, 1]. Higher = more relaxed OUR register.
|
||||
relax_register INCREMENTS the knob (pushing OUR output toward informal) when the user
|
||||
is observed to be over-formal. The user is never modified or nudged.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, timezone
|
||||
|
||||
import numpy as np
|
||||
|
||||
from iai_mcp.events import query_events, write_event
|
||||
from iai_mcp.formality import formality_score
|
||||
from iai_mcp.profile import profile_get, profile_set
|
||||
|
||||
|
||||
# ------------------------------------------------------------------- constants
|
||||
DEFAULT_WINDOW_SIZE: int = 5 # D-AUTIST13-03 sliding 5-point window
|
||||
DEFAULT_CADENCE_DAYS: int = 7 # weekly
|
||||
TRIGGER_SLOPE: float = 0.05 # formality delta per week floor
|
||||
TRIGGER_MEAN: float = 0.6 # absolute formality floor
|
||||
DEFAULT_DELTA: float = 0.1 # knob step per relaxation
|
||||
|
||||
|
||||
# ------------------------------------------------------------------- detector
|
||||
def detect_camouflaging(
|
||||
store,
|
||||
*,
|
||||
window_size: int = DEFAULT_WINDOW_SIZE,
|
||||
cadence_days: int = DEFAULT_CADENCE_DAYS,
|
||||
) -> dict:
|
||||
"""Sliding 5-point weekly window detector (D-AUTIST13-03).
|
||||
|
||||
Reads the last `window_size` `formality_score_weekly` events, computes the
|
||||
linear-regression slope (numpy.polyfit deg=1), and the current mean. Detected
|
||||
iff slope > TRIGGER_SLOPE AND mean > TRIGGER_MEAN (both required).
|
||||
|
||||
Args:
|
||||
store: open MemoryStore.
|
||||
window_size: number of weekly points to consider (default 5).
|
||||
cadence_days: cadence label (default 7 = weekly); not used arithmetically
|
||||
but stored in event metadata by callers.
|
||||
|
||||
Returns:
|
||||
{detected: bool, trajectory_slope: float, current_mean: float, sample_count: int}.
|
||||
"""
|
||||
events = query_events(store, kind="formality_score_weekly", limit=window_size)
|
||||
# Events are newest-first; we want chronological order for slope.
|
||||
events = list(reversed(events))
|
||||
sample_count = len(events)
|
||||
|
||||
if sample_count < window_size:
|
||||
return {
|
||||
"detected": False,
|
||||
"trajectory_slope": 0.0,
|
||||
"current_mean": 0.0,
|
||||
"sample_count": sample_count,
|
||||
}
|
||||
|
||||
scores = np.asarray(
|
||||
[float(e["data"].get("score", 0.0)) for e in events], dtype=np.float64
|
||||
)
|
||||
xs = np.arange(len(scores), dtype=np.float64)
|
||||
slope, _intercept = np.polyfit(xs, scores, 1)
|
||||
current_mean = float(scores.mean())
|
||||
|
||||
detected = bool(slope > TRIGGER_SLOPE and current_mean > TRIGGER_MEAN)
|
||||
|
||||
return {
|
||||
"detected": detected,
|
||||
"trajectory_slope": float(slope),
|
||||
"current_mean": current_mean,
|
||||
"sample_count": sample_count,
|
||||
}
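A worked example of the two-part trigger may help. The module computes the slope with `numpy.polyfit(xs, scores, 1)`; the sketch below uses the equivalent closed-form least-squares slope so that it runs without NumPy. `slope_and_mean` and `triggers` are hypothetical helper names, and the score lists are made-up illustrations.

```python
TRIGGER_SLOPE = 0.05  # formality delta per week
TRIGGER_MEAN = 0.6    # absolute formality floor


def slope_and_mean(scores: list[float]) -> tuple[float, float]:
    """Least-squares slope over x = 0..n-1 (needs n >= 2), plus the mean."""
    n = len(scores)
    x_mean = (n - 1) / 2
    y_mean = sum(scores) / n
    sxy = sum((x - x_mean) * (y - y_mean) for x, y in enumerate(scores))
    sxx = sum((x - x_mean) ** 2 for x in range(n))
    return sxy / sxx, y_mean


def triggers(scores: list[float]) -> bool:
    """Both conditions must hold: rising trajectory AND high absolute level."""
    slope, mean = slope_and_mean(scores)
    return slope > TRIGGER_SLOPE and mean > TRIGGER_MEAN


rising_high = [0.50, 0.60, 0.70, 0.80, 0.90]  # slope about 0.10, mean about 0.70
rising_low = [0.10, 0.20, 0.30, 0.40, 0.50]   # slope about 0.10, mean about 0.30
flat_high = [0.80, 0.80, 0.80, 0.80, 0.80]    # slope about 0.00, mean about 0.80
print(triggers(rising_high), triggers(rising_low), triggers(flat_high))
```

Only the first series fires: the second rises but stays informal overall, and the third is formal but stable.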
|
||||
|
||||
|
||||
# ------------------------------------------------------------------- relaxer
|
||||
def relax_register(store, *, delta: float = DEFAULT_DELTA) -> None:
|
||||
"""Bump profile.camouflaging_relaxation by delta (clamped to [0.0, 1.0]).
|
||||
|
||||
Writes go through `profile.profile_set(..., store=store)` so the existing
|
||||
`profile_updated` event also fires alongside `register_relaxed`. This is the
|
||||
ONE pathway the system uses to relax its own register in response to a
|
||||
detected over-formal user trajectory (D-AUTIST13-02).
|
||||
"""
|
||||
import iai_mcp.core as core
|
||||
|
||||
current = core._profile_state.get("camouflaging_relaxation", 0.0)
|
||||
new_value = min(1.0, max(0.0, current + delta))
|
||||
|
||||
# Only call profile_set if the value actually changes; otherwise profile_set
|
||||
# will silently no-op and NOT emit profile_updated (correct behaviour).
|
||||
if new_value != current:
|
||||
profile_set(
|
||||
"camouflaging_relaxation",
|
||||
new_value,
|
||||
core._profile_state,
|
||||
store=store,
|
||||
)
|
||||
|
||||
write_event(
|
||||
store,
|
||||
kind="register_relaxed",
|
||||
data={
|
||||
"from": float(current),
|
||||
"to": float(new_value),
|
||||
"delta": float(delta),
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
},
|
||||
severity="info",
|
||||
)
|
||||
|
||||
|
||||
# ------------------------------------------------------------------- recorder
|
||||
def record_user_formality(store, text: str, lang: str) -> None:
|
||||
"""Compute formality on USER surface text and emit a formality_score_weekly event.
|
||||
|
||||
Called on every user turn. Constitutional guard: the scorer sees ONLY the
|
||||
user's surface output; no inferred state is computed or persisted.
|
||||
"""
|
||||
score = formality_score(text, lang)
|
||||
now = datetime.now(timezone.utc)
|
||||
# Simple per-turn emit; aggregation is done at query time in detect_camouflaging
|
||||
# (taking last window_size). Per-week aggregation via week_iso tag for audit.
|
||||
week_iso = f"{now.isocalendar()[0]}-W{now.isocalendar()[1]:02d}"
|
||||
write_event(
|
||||
store,
|
||||
kind="formality_score_weekly",
|
||||
data={
|
||||
"score": float(score),
|
||||
"lang": lang,
|
||||
"week_iso": week_iso,
|
||||
"samples": 1,
|
||||
"timestamp": now.isoformat(),
|
||||
},
|
||||
severity="info",
|
||||
)
|
||||
|
||||
|
||||
# ------------------------------------------------------------------- weekly pass
|
||||
def run_weekly_pass(store) -> dict:
|
||||
"""Convenience entry: detect_camouflaging; if detected, emit
|
||||
`camouflaging_detected` event AND call relax_register.
|
||||
|
||||
Returns the detection result dict (same shape as detect_camouflaging).
|
||||
"""
|
||||
result = detect_camouflaging(store)
|
||||
if result["detected"]:
|
||||
write_event(
|
||||
store,
|
||||
kind="camouflaging_detected",
|
||||
data={
|
||||
"slope": result["trajectory_slope"],
|
||||
"mean": result["current_mean"],
|
||||
"window_size": DEFAULT_WINDOW_SIZE,
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
},
|
||||
severity="info",
|
||||
)
|
||||
relax_register(store)
|
||||
return result
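One subtlety of week labels such as the `week_iso` tag written by `record_user_formality`: near year boundaries the ISO week-numbering year can differ from the calendar year, so a label is only unambiguous when both parts come from `isocalendar()`. A small sketch; `iso_week_label` is a hypothetical helper, not part of the module.

```python
from datetime import datetime, timezone


def iso_week_label(moment: datetime) -> str:
    """Format an ISO 8601 week label such as '2025-W01'."""
    iso_year, iso_week, _ = moment.isocalendar()
    return f"{iso_year}-W{iso_week:02d}"


# Monday 30 December 2024 already belongs to ISO week 1 of 2025.
boundary = datetime(2024, 12, 30, tzinfo=timezone.utc)
print(boundary.year, iso_week_label(boundary))
```

Pairing the calendar year with the ISO week number would label that same day as week 1 of 2024, which collides with the first week of January 2024.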
|
||||
520
src/iai_mcp/capture.py
Normal file
|
|
@ -0,0 +1,520 @@
|
|||
"""Plan 06 memory_capture (WRITE-side ambient gap closure).
|
||||
|
||||
Context: prior phases shipped ambient READ (session_start compact handle) and
|
||||
ambient daemon (sleep cycles, REM, overnight digest). WRITE-side capture of
|
||||
conversation content was an architectural gap: nothing in iai-mcp automatically
|
||||
ingested what the user said or what Claude decided during a session.
|
||||
|
||||
This module closes that gap with two entry points:
|
||||
|
||||
1. `capture_turn(store, cue, text, tier, session_id)`:
|
||||
in-session, explicit. Called via MCP tool `memory_capture` when Claude
|
||||
detects a surprising correction, load-bearing decision, or lesson.
|
||||
|
||||
2. `capture_transcript(store, transcript_path, session_id)`:
|
||||
end-of-session, ambient. Called by `~/.claude/hooks/iai-mcp-session-capture.sh`
|
||||
Stop-hook on SessionEnd. Reads Claude Code JSONL transcript, extracts
|
||||
user + assistant turns, filters through shield + dedup, inserts records.
|
||||
|
||||
Both paths respect:
|
||||
- Shield: HARD_BLOCK drops the record; FLAG_FOR_REVIEW stores with tag
|
||||
(policy: user chose visibility over paranoia, 2026-04-20).
|
||||
- Dedup: if query_similar returns a hit with cos >= DEDUP_COS_THRESHOLD
|
||||
(0.95), we reinforce instead of insert (boost Hebbian edge).
|
||||
- Language: detected via langdetect; falls back to 'en' on ambiguity.
|
||||
- Encryption: goes through the standard store.insert() path which handles
|
||||
AES-256-GCM column encryption.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
# R3 deviation [Rule 3 - blocking import cost]: `iai_mcp.embed` pulls
|
||||
# in transformers + torch (~2.9s cold import). Loading capture.py for the
|
||||
# `--no-spawn` deferred path (which never embeds anything) blew the R3 2s
|
||||
# wall-clock budget. Moved to lazy import inside `capture_turn` — keeps the
|
||||
# write_deferred_captures cold path under ~1s. `from __future__ import
|
||||
# annotations` (line 29) keeps type hints intact without runtime import.
|
||||
# `MemoryStore` left at module top — its 0.4s import is acceptable.
|
||||
from iai_mcp.store import MemoryStore
|
||||
from iai_mcp.types import (
|
||||
SCHEMA_VERSION_CURRENT,
|
||||
TIER_ENUM,
|
||||
MemoryRecord,
|
||||
)
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
DEDUP_COS_THRESHOLD = 0.95
|
||||
MIN_CAPTURE_LEN = 12
|
||||
MAX_CAPTURE_LEN = 8000
|
||||
|
||||
|
||||
def _detect_language(text: str) -> str:
|
||||
"""Best-effort ISO-639-1 via langdetect; 'en' on any failure."""
|
||||
try:
|
||||
from langdetect import detect # lazy: already a project dep
|
||||
|
||||
code = detect(text[:500])
|
||||
return code if len(code) == 2 else "en"
|
||||
except Exception:
|
||||
return "en"
|
||||
|
||||
|
||||
def _run_shield(text: str) -> tuple[str, list[str]]:
|
||||
"""Run shield; return (verdict, tags) where verdict in HARD_BLOCK|FLAG_FOR_REVIEW|OK."""
|
||||
try:
|
||||
from iai_mcp.shield import evaluate
|
||||
|
||||
result = evaluate(text)
|
||||
verdict = getattr(result, "verdict", "OK")
|
||||
tags = list(getattr(result, "tags", []) or [])
|
||||
return verdict, tags
|
||||
except Exception:
|
||||
return "OK", []
|
||||
|
||||
|
||||
def capture_turn(
|
||||
store: MemoryStore,
|
||||
*,
|
||||
cue: str,
|
||||
text: str,
|
||||
tier: str = "episodic",
|
||||
session_id: str = "-",
|
||||
role: str = "user",
|
||||
) -> dict[str, Any]:
|
||||
"""Write a single conversation turn to the iai-mcp store.
|
||||
|
||||
Returns {"status": "inserted|reinforced|skipped", "record_id": uuid-or-null,
|
||||
"reason": short-string}.
|
||||
"""
|
||||
if tier not in TIER_ENUM:
|
||||
return {"status": "skipped", "record_id": None, "reason": f"invalid tier {tier!r}"}
|
||||
|
||||
text = (text or "").strip()
|
||||
if len(text) < MIN_CAPTURE_LEN:
|
||||
return {"status": "skipped", "record_id": None, "reason": "too short"}
|
||||
if len(text) > MAX_CAPTURE_LEN:
|
||||
text = text[:MAX_CAPTURE_LEN]
|
||||
|
||||
verdict, shield_tags = _run_shield(text)
|
||||
if verdict == "HARD_BLOCK":
|
||||
return {"status": "skipped", "record_id": None, "reason": "shield HARD_BLOCK"}
|
||||
|
||||
# Lazy import: keeps the cold module-import cost low for the
|
||||
# `--no-spawn` deferred path (Phase 7.1 R3) which never embeds.
|
||||
from iai_mcp.embed import embedder_for_store
|
||||
|
||||
emb = embedder_for_store(store).embed(cue or text)
|
||||
embedding = list(emb)
|
||||
|
||||
# Dedup: query_similar against existing records at the same tier.
|
||||
# Phase 07.11-01: query_similar accepts a `tier` kwarg natively
|
||||
# (Bug A fix), returns list[tuple[MemoryRecord, float]] (legacy contract,
|
||||
# unchanged shape -- we unpack the tuple correctly in the loop body, Bug B
|
||||
# fix), and the dedup hit reinforces via the typed `reinforce_record`
|
||||
# wrapper (Bug C fix -- single-uuid argument shape against a single-uuid
|
||||
# API).
|
||||
try:
|
||||
neighbours = store.query_similar(embedding, k=3, tier=tier)
|
||||
except (ValueError, IOError) as exc:
|
||||
# Genuinely-recoverable cases only: bad tier validation surfaces as
|
||||
# ValueError (already caught by query_similar's pre-I/O guard); transient
|
||||
# LanceDB I/O surfaces as IOError. A TypeError from a wrong call shape
|
||||
# MUST surface in tests -- the silent `except Exception: pass` blanket
|
||||
# is removed deliberately (D-01 contract).
|
||||
log.warning(
|
||||
"capture_dedup_query_failed",
|
||||
extra={"err_type": type(exc).__name__, "err": str(exc)[:120]},
|
||||
)
|
||||
neighbours = []
|
||||
|
||||
for record, score in neighbours: # tuple-unpack -- fix for Bug B
|
||||
if score >= DEDUP_COS_THRESHOLD:
|
||||
# Single-record reinforcement: route through reinforce_record
|
||||
# rather than boost_edges([UUID(...)]), which expects pairs.
|
||||
try:
|
||||
store.reinforce_record(record.id)
|
||||
except (ValueError, IOError) as exc:
|
||||
# Reinforce is best-effort observability; log and continue
|
||||
# so the duplicate is still detected even if the LTP write
|
||||
# fails. Same narrowed-except discipline as the query above.
|
||||
log.warning(
|
||||
"capture_dedup_reinforce_failed",
|
||||
extra={
|
||||
"err_type": type(exc).__name__,
|
||||
"record_id": str(record.id),
|
||||
},
|
||||
)
|
||||
return {
|
||||
"status": "reinforced",
|
||||
"record_id": str(record.id),
|
||||
"reason": f"cos={score:.3f} >= {DEDUP_COS_THRESHOLD}",
|
||||
}
|
||||
|
||||
tags = ["capture", f"role:{role}"]
|
||||
if verdict == "FLAG_FOR_REVIEW":
|
||||
tags.append("shield:flagged")
|
||||
tags.extend(f"shield:{t}" for t in shield_tags[:3])
|
||||
|
||||
now = datetime.now(timezone.utc)
|
||||
rec = MemoryRecord(
|
||||
id=uuid4(),
|
||||
tier=tier,
|
||||
literal_surface=text,
|
||||
aaak_index="",
|
||||
embedding=embedding,
|
||||
community_id=None,
|
||||
centrality=0.0,
|
||||
detail_level=2,
|
||||
pinned=False,
|
||||
stability=0.0,
|
||||
difficulty=0.0,
|
||||
last_reviewed=None,
|
||||
never_decay=False,
|
||||
never_merge=False,
|
||||
provenance=[{"ts": now.isoformat(), "cue": cue or "(auto-capture)",
|
||||
"session_id": session_id, "role": role}],
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
tags=tags,
|
||||
language=_detect_language(text),
|
||||
s5_trust_score=0.5,
|
||||
profile_modulation_gain={},
|
||||
schema_version=SCHEMA_VERSION_CURRENT,
|
||||
)
|
||||
|
||||
try:
|
||||
store.insert(rec)
|
||||
except Exception as e:
|
||||
log.exception("capture_turn insert failed")
|
||||
return {"status": "skipped", "record_id": None, "reason": f"insert-failed: {type(e).__name__}"}
|
||||
|
||||
return {"status": "inserted", "record_id": str(rec.id), "reason": f"tier={tier}"}
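The reinforce-versus-insert decision can be sketched in isolation. In `capture_turn` the similarity scores come back from `store.query_similar`; the sketch below computes cosine similarity directly over toy vectors instead, which is an assumption made purely so the example is self-contained. `cosine`, `decide`, and the record ids are hypothetical.

```python
import math

DEDUP_COS_THRESHOLD = 0.95  # same value as the module constant


def cosine(a: list[float], b: list[float]) -> float:
    """Cosine similarity; returns 0.0 when either vector has zero length."""
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / norm if norm else 0.0


def decide(new_vec: list[float], stored: list[tuple[str, list[float]]]):
    """Return ('reinforced', id) for the first near-duplicate, else ('inserted', None)."""
    for record_id, vec in stored:
        if cosine(new_vec, vec) >= DEDUP_COS_THRESHOLD:
            return "reinforced", record_id
    return "inserted", None


stored = [("rec-a", [1.0, 0.0, 0.0]), ("rec-b", [0.0, 1.0, 0.0])]
print(decide([0.99, 0.05, 0.0], stored))  # near-duplicate of rec-a
print(decide([0.6, 0.6, 0.5], stored))    # distinct enough to insert
```

A threshold this high only collapses near-verbatim repeats; paraphrases with moderately similar embeddings are still stored as separate records.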
|
||||
|
||||
|
||||
def capture_transcript(
|
||||
store: MemoryStore,
|
||||
transcript_path: Path | str,
|
||||
*,
|
||||
session_id: str = "-",
|
||||
max_turns: int = 200,
|
||||
) -> dict[str, Any]:
|
||||
"""Read a Claude Code JSONL transcript, capture user + assistant turns.
|
||||
|
||||
Returns {"inserted": N, "reinforced": M, "skipped": K, "errors": E}.
|
||||
"""
|
||||
path = Path(transcript_path).expanduser()
|
||||
if not path.exists():
|
||||
return {"inserted": 0, "reinforced": 0, "skipped": 0, "errors": 1,
|
||||
"reason": f"transcript not found: {path}"}
|
||||
|
||||
counts = {"inserted": 0, "reinforced": 0, "skipped": 0, "errors": 0}
|
||||
seen = 0
|
||||
with path.open() as fh:
|
||||
for line in fh:
|
||||
if seen >= max_turns:
|
||||
break
|
||||
seen += 1
|
||||
try:
|
||||
obj = json.loads(line)
|
||||
except Exception:
|
||||
counts["errors"] += 1
|
||||
continue
|
||||
msg = obj.get("message") if isinstance(obj.get("message"), dict) else obj
|
||||
role = obj.get("type") or msg.get("role", "")
|
||||
if role not in {"user", "assistant"}:
|
||||
continue
|
||||
content = msg.get("content", "")
|
||||
if isinstance(content, list):
|
||||
# Claude Code messages use block format; collect text blocks
|
||||
text_parts = []
|
||||
for block in content:
|
||||
if isinstance(block, dict) and block.get("type") == "text":
|
||||
text_parts.append(block.get("text", ""))
|
||||
text = "\n".join(text_parts).strip()
|
||||
else:
|
||||
text = str(content).strip()
|
||||
if not text:
|
||||
continue
|
||||
result = capture_turn(
|
||||
store,
|
||||
cue=f"session {session_id} turn {seen}",
|
||||
text=text,
|
||||
tier="episodic",
|
||||
session_id=session_id,
|
||||
role=role,
|
||||
)
|
||||
status = result.get("status", "skipped")
|
||||
if status in counts:
|
||||
counts[status] += 1
|
||||
else:
|
||||
counts["skipped"] += 1
|
||||
|
||||
return counts
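The per-line parsing used above can be illustrated on a single record. The sample line below is a made-up illustration shaped like the fields the parser reads (`type`, `message.role`, `message.content`), not a verbatim Claude Code transcript entry. `extract_turn` is a hypothetical helper, and the `isinstance(obj, dict)` guard is an addition that the original loop does not have.

```python
import json


def extract_turn(line: str):
    """Return (role, text) for a user/assistant JSONL line, else None."""
    try:
        obj = json.loads(line)
    except json.JSONDecodeError:
        return None
    if not isinstance(obj, dict):
        return None  # added guard: a JSON array or scalar is not a turn
    msg = obj["message"] if isinstance(obj.get("message"), dict) else obj
    role = obj.get("type") or msg.get("role", "")
    if role not in {"user", "assistant"}:
        return None
    content = msg.get("content", "")
    if isinstance(content, list):
        # Block format: keep only text blocks, skip everything else.
        text = "\n".join(
            b.get("text", "")
            for b in content
            if isinstance(b, dict) and b.get("type") == "text"
        ).strip()
    else:
        text = str(content).strip()
    return (role, text) if text else None


sample = json.dumps({
    "type": "assistant",
    "message": {"role": "assistant", "content": [
        {"type": "text", "text": "Use a lazy import here."},
        {"type": "tool_use", "name": "hypothetical_tool"},
    ]},
})
print(extract_turn(sample))
print(extract_turn("not json"))
```

Keeping the extraction in one function like this would also let the inline and deferred paths share a single parser rather than two copies of the same loop body.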
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# R3 / D7.1-04: deferred-captures writer for `--no-spawn` hook mode
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def write_deferred_captures(
|
||||
session_id: str,
|
||||
transcript_path: Path | str,
|
||||
*,
|
||||
cwd: str | None = None,
|
||||
max_turns: int = 200,
|
||||
) -> Path:
|
||||
"""Defer transcript capture by writing events to a JSONL file under
|
||||
``~/.iai-mcp/.deferred-captures/``. Returns the path written.
|
||||
|
||||
Used by ``iai-mcp capture-transcript --no-spawn`` (R3, D7.1-04) when the
|
||||
daemon is unreachable. The Stop hook calls this so it never blocks
|
||||
session teardown waiting for a daemon spawn (the third spawn vector,
|
||||
forensic anomaly #3 from ``report-20260426-150300.md``).
|
||||
|
||||
The daemon's drain loop (Plan 07.1-05b, in daemon.py / WAKE handler)
|
||||
consumes these on next WAKE. Format is JSONL v1 per D7.1-04:
|
||||
|
||||
- Line 1: header ``{"version":1,"deferred_at":<ISO>,"session_id":<id>,"cwd":<path>}``
|
||||
- Lines 2..N: one event per user/assistant turn
|
||||
``{"text":<verbatim>,"cue":<short>,"tier":"episodic","role":<u|a>,"ts":<ISO>}``
|
||||
|
||||
Pure-write: no MemoryStore touch, no socket touch, no daemon import.
|
||||
Uses ``Path.home()`` at call time so HOME-monkeypatched tests get the
|
||||
right tmp dir. Idempotent ``mkdir(parents=True, exist_ok=True)``.
|
||||
|
||||
Args:
|
||||
session_id: Claude Code session id (provenance + filename component).
|
||||
transcript_path: path to the JSONL transcript file (or non-existent —
|
||||
we write the header then return; daemon drain treats as no-op).
|
||||
cwd: optional CWD override for the header (defaults to ``os.getcwd()``).
|
||||
max_turns: cap on transcript turns to emit (default 200, matches
|
||||
``capture_transcript`` semantics).
|
||||
|
||||
Returns:
|
||||
``Path`` of the written ``.jsonl`` file.
|
||||
|
||||
Notes:
|
||||
- Filename pattern ``{session_id}-{int(time.time())}.jsonl`` — the
|
||||
unix-ts suffix avoids collisions if the same session captures
|
||||
multiple times.
|
||||
- Reuses the same parsing logic as ``capture_transcript`` so the
|
||||
deferred path and the inline path stay consistent.
|
||||
- Returns even on missing transcript (writes header only) — daemon
|
||||
drain treats as no-op. Hook MUST never raise here.
|
||||
- Stdlib only: ``json``, ``time``, ``pathlib.Path``, ``datetime``, ``os``.
|
||||
"""
|
||||
    deferred_dir = Path.home() / ".iai-mcp" / ".deferred-captures"
    deferred_dir.mkdir(parents=True, exist_ok=True)
    out_path = deferred_dir / f"{session_id}-{int(time.time())}.jsonl"
    with out_path.open("w") as fh:
        # Header (line 1, version=1 forward-compat marker per D7.1-04).
        header = {
            "version": 1,
            "deferred_at": datetime.now(timezone.utc).isoformat(),
            "session_id": session_id,
            "cwd": cwd or os.getcwd(),
        }
        fh.write(json.dumps(header, ensure_ascii=False) + "\n")
        # Read transcript and emit one event per user/assistant turn.
        path = Path(transcript_path).expanduser()
        if not path.exists():
            return out_path  # empty body — daemon drain will treat as no-op
        seen = 0
        with path.open() as src:
            for line in src:
                if seen >= max_turns:
                    break
                seen += 1
                try:
                    obj = json.loads(line)
                except Exception:
                    continue
                msg = obj.get("message") if isinstance(obj.get("message"), dict) else obj
                role = obj.get("type") or msg.get("role", "")
                if role not in {"user", "assistant"}:
                    continue
                content = msg.get("content", "")
                if isinstance(content, list):
                    text_parts = [
                        b.get("text", "")
                        for b in content
                        if isinstance(b, dict) and b.get("type") == "text"
                    ]
                    text = "\n".join(text_parts).strip()
                else:
                    text = str(content).strip()
                if not text:
                    continue
                event = {
                    "text": text,
                    "cue": f"session {session_id} turn {seen}",
                    "tier": "episodic",
                    "role": role,
                    "ts": datetime.now(timezone.utc).isoformat(),
                }
                fh.write(json.dumps(event, ensure_ascii=False) + "\n")
    return out_path


# ---------------------------------------------------------------------------
# R3 / D7.1-04: deferred-captures drain (READ side, daemon-resident)
# ---------------------------------------------------------------------------


def drain_deferred_captures(store: MemoryStore) -> dict[str, int]:
    """Consume ``~/.iai-mcp/.deferred-captures/*.jsonl`` produced by
    ``iai-mcp capture-transcript --no-spawn`` (Plan 07.1-05 WRITE side).

    For each ``.jsonl`` file in the deferred-captures dir:

    * Read line 1 (header). If ``version > 1`` (forward-compat guard), log a
      "skip" line to ``~/.iai-mcp/logs/deferred-drain-YYYY-MM-DD.log`` and
      leave the file in place — a future daemon version will know how to
      handle it.
    * For each event line (lines 2..N), call ``capture_turn(store, ...)``
      and inspect its return-status dict. W2 / D-02:
        - status="inserted" → events_inserted += 1
        - status="reinforced" → events_reinforced += 1
        - status="skipped" with reason starting with ``insert-failed:``
          (capture_turn path where store.insert raised) →
          events_skipped_insert_failed += 1 and the WHOLE FILE is treated
          as failed: renamed to .failed-<ts>.jsonl, NOT unlinked.
        - status="skipped" with any other reason (shield HARD_BLOCK, too short,
          invalid tier — all *intentional* drops) → events_skipped_intentional
          += 1.
    * On full success (zero insert-failed events): delete the file,
      files_drained += 1.
    * On any insert-failed event: rename the file to
      ``<basename>.failed-<unix_ts>.jsonl`` (preserves evidence for manual
      inspection), log an "insert-failed" line with the first error,
      files_failed += 1.
    * On parser/header exception: same outer rename + log path as before
      (existing behaviour), files_failed += 1.
    * On 0-byte / blank file: delete it without touching any counter. A
      header-only deferral (no event lines) takes the full-success path above.

    Idempotent: re-running on a directory with no ``.jsonl`` files (or no
    deferred-captures dir at all) returns zero counts without error.

    Returns dict with keys:
        files_drained, files_failed,
        events_inserted, events_reinforced,
        events_skipped_intentional, events_skipped_insert_failed.

    Notes:
        - Uses ``Path.home()`` at call time so HOME-monkeypatched tests get
          the right tmp dir.
        - Stdlib only — no new deps.
        - Caller (daemon.main / _tick_body) MUST wrap in try/except so a
          drain crash never propagates into the asyncio event loop. This
          function itself catches per-file exceptions defensively.
        - The ``store`` argument is the same MemoryStore instance the
          daemon uses for all other writes (so connection/lock semantics
          are consistent). Drain MUST run inside ``asyncio.to_thread`` from
          async callers because ``capture_turn`` does sync LanceDB I/O.
    """
    deferred_dir = Path.home() / ".iai-mcp" / ".deferred-captures"
    log_dir = Path.home() / ".iai-mcp" / "logs"
    log_dir.mkdir(parents=True, exist_ok=True)
    log_path = (
        log_dir / f"deferred-drain-{datetime.now(timezone.utc).strftime('%Y-%m-%d')}.log"
    )
    counts = {
        "files_drained": 0,
        "files_failed": 0,
        "events_inserted": 0,
        "events_reinforced": 0,
        "events_skipped_intentional": 0,
        "events_skipped_insert_failed": 0,
    }
    if not deferred_dir.exists():
        return counts
    for fpath in sorted(deferred_dir.glob("*.jsonl")):
        if ".failed-" in fpath.name:
            # Quarantined by an earlier pass. The renamed file still ends in
            # ``.jsonl`` and therefore still matches the glob, so skip it
            # explicitly and leave it for manual inspection.
            continue
        file_had_insert_failure = False
        file_first_error: str | None = None
        try:
            with fpath.open() as fh:
                lines = [ln.rstrip("\n") for ln in fh if ln.strip()]
            if not lines:
                # Empty file (e.g. partial write that never got header) — drop.
                fpath.unlink()
                continue
            header = json.loads(lines[0])
            if header.get("version", 0) > 1:
                # Forward-compat guard: leave the file in place; a future
                # daemon revision will know the format. Log + continue.
                with log_path.open("a") as logf:
                    logf.write(
                        f"{datetime.now(timezone.utc).isoformat()} skip {fpath.name}: "
                        f"version={header.get('version')}\n"
                    )
                continue
            session_id = header.get("session_id", "-")
            event_lines = lines[1:]
            for ln in event_lines:
                ev = json.loads(ln)
                # Reuse capture_turn so the deferred path lands in the same
                # shield + dedup + encryption pipeline as live captures.
                result = capture_turn(
                    store,
                    cue=ev.get("cue", ""),
                    text=ev.get("text", ""),
                    tier=ev.get("tier", "episodic"),
                    session_id=session_id,
                    role=ev.get("role", "user"),
                )
                status = result.get("status", "skipped")
                reason = result.get("reason", "")
                if status == "inserted":
                    counts["events_inserted"] += 1
                elif status == "reinforced":
                    counts["events_reinforced"] += 1
                elif status == "skipped" and reason.startswith("insert-failed:"):
                    counts["events_skipped_insert_failed"] += 1
                    file_had_insert_failure = True
                    if file_first_error is None:
                        file_first_error = reason
                else:
                    counts["events_skipped_intentional"] += 1
            if file_had_insert_failure:
                # preserve the file as evidence — at least one
                # event hit the insert-failed code path inside capture_turn
                # (store.insert raised, capture_turn swallowed and returned
                # status=skipped reason=insert-failed:*). Pre-07.9 the file
                # was unlinked here and the data was silently lost.
                failed_path = fpath.with_suffix(f".failed-{int(time.time())}.jsonl")
                fpath.rename(failed_path)
                with log_path.open("a") as logf:
                    logf.write(
                        f"{datetime.now(timezone.utc).isoformat()} insert-failed "
                        f"{fpath.name}: first_error={file_first_error}\n"
                    )
                counts["files_failed"] += 1
            else:
                fpath.unlink()
                counts["files_drained"] += 1
        except Exception as e:  # noqa: BLE001 -- per-file isolation, never raise
            try:
                # Preserve evidence: rename so the next drain pass skips it
                # AND a human can inspect the failure.
                failed_path = fpath.with_suffix(f".failed-{int(time.time())}.jsonl")
                fpath.rename(failed_path)
                with log_path.open("a") as logf:
                    logf.write(
                        f"{datetime.now(timezone.utc).isoformat()} failed "
                        f"{fpath.name}: {type(e).__name__}: {e}\n"
                    )
            except Exception:
                pass
            counts["files_failed"] += 1
    return counts
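Reviewer sketch (not part of the commit): the deferred-capture format documented above, one header line followed by one event per line, can be read back with a few lines of stdlib Python. The function name `read_deferred_capture` and its `max_version` parameter are hypothetical. The blank-line filtering and the `version` guard follow what `drain_deferred_captures` appears to do, so treat this as an illustration under those assumptions rather than the daemon's actual code path.

```python
import json
from pathlib import Path


def read_deferred_capture(path: Path, max_version: int = 1):
    """Hypothetical reader: return (header, events), or None to skip the file.

    None is returned when the file holds no non-blank lines, or when the
    header declares a version newer than ``max_version`` (forward-compat).
    """
    with path.open(encoding="utf-8") as fh:
        # Same filtering idea as the drain: ignore blank lines.
        lines = [ln.rstrip("\n") for ln in fh if ln.strip()]
    if not lines:
        return None
    header = json.loads(lines[0])
    if header.get("version", 0) > max_version:
        return None
    # Every remaining line is one JSON event object.
    events = [json.loads(ln) for ln in lines[1:]]
    return header, events
```

A header-only file yields an empty event list, which matches the "no-op" wording in the docstrings above.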
522
src/iai_mcp/capture_queue.py
Normal file
@ -0,0 +1,522 @@
"""Phase 10.2 -- persistent capture queue with atomic append + idempotent ingest.

The capture queue is the durable buffer that makes the L1 hibernation contract
viable. Wrapper writes to ``~/.iai-mcp/pending/`` whenever the daemon socket
is unreachable (Hibernation, mid-restart, crashed). On the next Wake transition
the daemon drains the queue via ``ingest_pending(handler)`` -- the handler
plugs into the existing ``iai_mcp.capture`` path so the verbatim contract
(Phase 5/6) is preserved end-to-end.

Storage layout under ``~/.iai-mcp/pending/``::

    pending-<ulid>.json -- one queued record (committed file)
    pending-<ulid>.json.tmp -- transient temp file before atomic rename
    pending-<ulid>.lock -- present only during in-flight ingest of <ulid>
    .overflow-audit.log -- JSONL append-only log of dropped-oldest events

Hard guarantees:

- **Atomic append**: writes go to ``.tmp`` then ``os.replace`` to final name
  (POSIX atomic rename). A crash mid-write leaves a stray ``.tmp`` but never
  a half-written final file. ``pending_count`` and ``list_pending`` ignore
  ``.tmp``.
- **Idempotent ingest**: each pending file is claimed via ``fcntl.flock`` on
  the matching ``.lock`` file. Lock contention => skip (another worker has
  it). Handler success => delete the pending file, then the lock file.
  Handler raises => leave both intact for next-call retry.
- **Bounded queue**: ``append`` triggers ``prune_oldest`` once
  ``pending_count > max_size``. Prunes down to ``max_size - 100`` files
  in one batch (amortised I/O) and writes one JSONL line per drop to the
  audit log.
- **Verbatim round-trip**: the JSON payload uses ``ensure_ascii=False`` so
  ``record["surface"]`` round-trips byte-identically including UTF-8 BMP +
  astral characters and combining marks.
- **No new deps**: stdlib only -- ``os, pathlib, json, fcntl, secrets,
  time, datetime, threading, errno``.

ULID derivation: 48-bit millisecond unix timestamp (big-endian) + 80 bits of
``secrets.token_bytes`` randomness, encoded with Crockford base32 per the
ulid spec (https://github.com/ulid/spec). The result is 26 characters,
lexicographically sortable by time, and collision-resistant for thousands of
appends per millisecond. Implemented inline -- the project deliberately
avoids a ``python-ulid`` dependency.
"""
from __future__ import annotations

import errno
import fcntl
import json
import os
import secrets
import threading
import time
from collections.abc import Callable
from datetime import datetime, timezone
from pathlib import Path

# ---------------------------------------------------------------------------
# Defaults / configuration
# ---------------------------------------------------------------------------

DEFAULT_QUEUE_DIR: Path = Path.home() / ".iai-mcp" / "pending"
"""Production location for the persistent queue."""

DEFAULT_MAX_SIZE: int = 10_000
"""Default ceiling before ``prune_oldest`` kicks in."""

# Drop ~100 oldest at once when overflowing so the I/O cost is amortised
# across many subsequent appends rather than paid on every single overflow.
_PRUNE_BATCH_HEADROOM: int = 100

SCHEMA_VERSION: int = 1
"""Bumped only when the on-disk pending-<ulid>.json layout changes."""

_AUDIT_LOG_NAME: str = ".overflow-audit.log"

# Crockford base32 alphabet (no I, L, O, U) per ulid spec.
_CROCKFORD: str = "0123456789ABCDEFGHJKMNPQRSTVWXYZ"


# ---------------------------------------------------------------------------
# Errors
# ---------------------------------------------------------------------------


class CaptureQueueError(Exception):
    """Base class for all capture-queue errors."""


class CaptureQueueSchemaError(CaptureQueueError):
    """Raised when a pending file declares a ``schema_version`` we don't grok."""


class CaptureQueueLocked(CaptureQueueError):
    """Raised when an in-flight ingest cannot acquire the per-record lock.

    Currently only used internally; ``ingest_pending`` swallows lock contention
    and treats the file as "claimed by another worker" rather than raising.
    """


# ---------------------------------------------------------------------------
# ULID generator (stdlib-only, time-sortable)
# ---------------------------------------------------------------------------

# Monotonic-ish guard: if two ULIDs would land in the same millisecond, bump
# the timestamp by 1ms so lexicographic sort matches insertion order. The
# bump is bounded -- once wall clock advances past the bumped value the
# guard resets. Threadsafe via a module-level lock.
_ulid_lock = threading.Lock()
_last_ms: int = 0


def _now_ms() -> int:
    """Current wall-clock time in unix milliseconds (UTC)."""
    return int(time.time() * 1000)


def _b32_encode(data: bytes, length: int) -> str:
    """Crockford base32 encode ``data`` to exactly ``length`` characters.

    ``data`` is treated as an unsigned big-endian integer. Result is
    zero-padded on the left if the integer would naturally render to
    fewer characters. Caller is responsible for sizing ``length``
    correctly: 10 chars for the 48-bit timestamp prefix, 16 chars for
    the 80-bit randomness suffix.
    """
    n = int.from_bytes(data, "big")
    out = []
    for _ in range(length):
        out.append(_CROCKFORD[n & 0x1F])
        n >>= 5
    return "".join(reversed(out))


def generate_ulid() -> str:
    """Return a fresh 26-character Crockford-base32 ULID.

    The first 10 chars encode the millisecond unix timestamp; the next 16
    encode 80 bits of random data. Lexicographic sort = chronological sort
    (with millisecond resolution; finer ordering within a millisecond is
    not guaranteed by ULID itself but the monotonic guard below preserves
    insertion order in practice).
    """
    global _last_ms
    with _ulid_lock:
        ms = _now_ms()
        if ms <= _last_ms:
            ms = _last_ms + 1
        _last_ms = ms

    ts_bytes = ms.to_bytes(6, "big")  # 48 bits
    rand_bytes = secrets.token_bytes(10)  # 80 bits
    return _b32_encode(ts_bytes, 10) + _b32_encode(rand_bytes, 16)


# ---------------------------------------------------------------------------
# CaptureQueue
# ---------------------------------------------------------------------------


class CaptureQueue:
    """Persistent on-disk FIFO buffer for ``memory_capture`` records.

    See module docstring for storage layout and guarantees.
    """

    def __init__(
        self,
        queue_dir: Path | None = None,
        max_size: int = DEFAULT_MAX_SIZE,
    ) -> None:
        if max_size <= 0:
            raise ValueError(f"max_size must be positive, got {max_size}")
        self._queue_dir = (
            Path(queue_dir) if queue_dir is not None else DEFAULT_QUEUE_DIR
        )
        self._queue_dir.mkdir(parents=True, exist_ok=True)
        self._max_size = max_size
        self._audit_log = self._queue_dir / _AUDIT_LOG_NAME

    # ------------------------------------------------------------------
    # Read accessors
    # ------------------------------------------------------------------

    @property
    def queue_dir(self) -> Path:
        """Filesystem location of the queue directory."""
        return self._queue_dir

    @property
    def max_size(self) -> int:
        """Maximum number of pending records before overflow pruning kicks in."""
        return self._max_size

    @property
    def audit_log_path(self) -> Path:
        """Path to ``.overflow-audit.log`` (may not exist if no overflows happened)."""
        return self._audit_log

    def pending_count(self) -> int:
        """Return number of committed pending files (ignores ``.tmp`` and ``.lock``)."""
        return sum(1 for _ in self._iter_pending_files())

    def list_pending(self) -> list[Path]:
        """Return committed pending files sorted by ULID (oldest first)."""
        return sorted(self._iter_pending_files(), key=lambda p: p.name)

    def _iter_pending_files(self):
        """Yield every ``pending-<ulid>.json`` (no ``.tmp``, no ``.lock``)."""
        for entry in self._queue_dir.iterdir():
            name = entry.name
            if (
                entry.is_file()
                and name.startswith("pending-")
                and name.endswith(".json")
                and not name.endswith(".json.tmp")
            ):
                yield entry

    # ------------------------------------------------------------------
    # Append (atomic temp + rename)
    # ------------------------------------------------------------------

    def append(self, record: dict) -> str:
        """Append a record to the queue. Returns the assigned ULID.

        Atomic: writes ``pending-<ulid>.json.tmp`` then ``os.replace`` to
        ``pending-<ulid>.json``. A crash between write and rename leaves a
        stray ``.tmp``. Nothing in this module removes it (``prune_oldest``
        only considers committed files), but ``pending_count`` and
        ``list_pending`` ignore it. Triggers ``prune_oldest`` once the
        post-append count exceeds ``max_size``.
        """
        if not isinstance(record, dict):
            raise TypeError(f"record must be a dict, got {type(record).__name__}")

        ulid = generate_ulid()
        appended_at = datetime.now(timezone.utc).isoformat()
        envelope: dict = {
            "ulid": ulid,
            "appended_at": appended_at,
            "record": record,
            "schema_version": SCHEMA_VERSION,
        }

        final_path = self._queue_dir / f"pending-{ulid}.json"
        tmp_path = self._queue_dir / f"pending-{ulid}.json.tmp"

        # Open with O_CREAT|O_EXCL|O_WRONLY so a colliding ULID is detected
        # rather than silently overwriting (collision => generate_ulid bug).
        # 0o600 keeps records user-only on disk.
        fd = os.open(
            str(tmp_path),
            os.O_WRONLY | os.O_CREAT | os.O_EXCL,
            0o600,
        )
        try:
            payload = json.dumps(
                envelope,
                ensure_ascii=False,  # verbatim Unicode round-trip
                separators=(",", ":"),
            ).encode("utf-8")
            os.write(fd, payload)
            os.fsync(fd)
        except Exception:
            # On any failure between open and rename, drop the temp file so
            # we don't accumulate orphans. If the unlink itself fails (very
            # unlikely on a file we just created) re-raise the original.
            try:
                os.unlink(tmp_path)
            except OSError:
                pass
            raise
        finally:
            os.close(fd)

        # POSIX-atomic rename: visible-or-not, never half-visible.
        os.replace(tmp_path, final_path)

        # Overflow check happens AFTER the rename so the new record is
        # never the one we drop -- prune_oldest by definition drops the
        # oldest, not the newest.
        if self.pending_count() > self._max_size:
            target = max(0, self._max_size - _PRUNE_BATCH_HEADROOM)
            self.prune_oldest(target_size=target)

        return ulid

    # ------------------------------------------------------------------
    # Ingest (idempotent, lock-claimed)
    # ------------------------------------------------------------------

    def ingest_pending(self, handler: Callable[[dict], None]) -> int:
        """Drain pending records via ``handler``. Returns count successfully ingested.

        For each pending file (oldest first):

        1. ``open`` ``pending-<ulid>.lock`` (creating if needed).
        2. ``fcntl.flock(LOCK_EX | LOCK_NB)`` -- if already locked, skip.
        3. Read + JSON-decode ``pending-<ulid>.json``; raise
           ``CaptureQueueSchemaError`` on schema mismatch.
        4. Call ``handler(record)`` where ``record`` is the inner dict
           (not the envelope).
        5. On success: ``unlink`` the pending file FIRST (so a crash
           between unlink calls cannot resurrect a deleted record), then
           release the lock and unlink the lock file.
        6. On handler exception: release the lock fd but leave the lock
           file AND the pending file on disk. Future calls retry.

        Schema errors propagate to the caller after closing fds for the
        offending file -- we do NOT swallow them, because a schema bump
        is a deploy-time event the caller needs to see.
        """
        if not callable(handler):
            raise TypeError("handler must be callable")

        ingested = 0
        for pending_path in self.list_pending():
            ulid = self._ulid_from_path(pending_path)
            lock_path = self._queue_dir / f"pending-{ulid}.lock"

            # Open (or create) the lock file. 0o600 to keep it user-only.
            try:
                lock_fd = os.open(
                    str(lock_path),
                    os.O_WRONLY | os.O_CREAT,
                    0o600,
                )
            except OSError:
                # Cannot even create the lock -- skip this record. Leave
                # the pending file in place so a future retry can pick
                # it up once the disk situation clears.
                continue

            try:
                try:
                    fcntl.flock(lock_fd, fcntl.LOCK_EX | fcntl.LOCK_NB)
                except OSError as exc:
                    # EWOULDBLOCK / EAGAIN => another worker has the lock.
                    # Anything else: surface it; we don't expect it here.
                    if exc.errno in (errno.EWOULDBLOCK, errno.EAGAIN):
                        continue
                    raise

                # Lock acquired. The pending file may have been deleted
                # between list_pending and now (rare race with another
                # worker that claimed-and-finished), so re-check.
                if not pending_path.exists():
                    continue

                envelope = self._read_envelope(pending_path)
                # Schema check -- raise loud so deploys notice.
                version = envelope.get("schema_version")
                if version != SCHEMA_VERSION:
                    raise CaptureQueueSchemaError(
                        f"unsupported schema_version={version!r} in "
                        f"{pending_path.name}; expected {SCHEMA_VERSION}",
                    )

                record = envelope["record"]
                # Handler runs OUTSIDE any try/except below: if it raises,
                # we explicitly leave the pending file + lock file on disk
                # for the next call to retry.
                handler(record)

                # Handler returned cleanly: delete pending FIRST to make
                # the success durable; lock cleanup is best-effort.
                try:
                    os.unlink(pending_path)
                except FileNotFoundError:
                    # Already gone -- another worker raced us. Treat as
                    # success since the record is no longer pending.
                    pass
                ingested += 1
            finally:
                # Always release + unlink the lock fd. If the handler
                # raised, the bare ``finally`` runs before the exception
                # propagates, so the lock fd never leaks.
                try:
                    fcntl.flock(lock_fd, fcntl.LOCK_UN)
                except OSError:
                    pass
                os.close(lock_fd)
                # Only unlink the lock file if we ALSO unlinked the pending
                # file (i.e. a clean handler success). On handler exception
                # we want the lock file to remain so a follow-up
                # ``ingest_pending`` can detect mid-flight crash state.
                if not pending_path.exists():
                    try:
                        os.unlink(lock_path)
                    except FileNotFoundError:
                        pass

        return ingested

    # ------------------------------------------------------------------
    # Overflow pruning
    # ------------------------------------------------------------------

    def prune_oldest(self, target_size: int | None = None) -> int:
        """Drop oldest pending files until count <= ``target_size``.

        ``target_size`` defaults to ``max_size`` -- in normal overflow flow
        ``append`` passes ``max_size - 100`` so the I/O cost is amortised
        over the next 100 appends. Each dropped file produces one JSONL
        line in ``.overflow-audit.log``.
        """
        if target_size is None:
            target_size = self._max_size
        if target_size < 0:
            raise ValueError(f"target_size must be >= 0, got {target_size}")

        oldest_first = self.list_pending()
        excess = len(oldest_first) - target_size
        if excess <= 0:
            return 0

        queue_size_before = len(oldest_first)
        dropped = 0
        for pending_path in oldest_first[:excess]:
            ulid = self._ulid_from_path(pending_path)
            try:
                envelope = self._read_envelope(pending_path)
                appended_at = envelope.get("appended_at", "")
            except (FileNotFoundError, json.JSONDecodeError, CaptureQueueError):
                # Read failure is non-fatal for pruning: we still drop the
                # file and log "unknown" appended_at to audit.
                appended_at = ""

            try:
                os.unlink(pending_path)
            except FileNotFoundError:
                # Someone else raced us (concurrent prune?) -- skip
                # without auditing since we didn't actually drop it.
                continue

            self._audit_drop(
                dropped_ulid=ulid,
                appended_at=appended_at,
                queue_size_before_prune=queue_size_before,
            )
            dropped += 1
        return dropped

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _ulid_from_path(path: Path) -> str:
        """Extract the ULID from a ``pending-<ulid>.json`` filename."""
        # ``stem`` for ``pending-XYZ.json`` is ``pending-XYZ``.
        return path.stem[len("pending-"):]

    @staticmethod
    def _read_envelope(path: Path) -> dict:
        """Read + JSON-decode a pending file. Raises ``json.JSONDecodeError``
        or ``FileNotFoundError`` on read failure; caller decides handling."""
        with path.open("rb") as f:
            raw = f.read()
        return json.loads(raw.decode("utf-8"))

    def _audit_drop(
        self,
        *,
        dropped_ulid: str,
        appended_at: str,
        queue_size_before_prune: int,
    ) -> None:
        """Append one JSONL line to ``.overflow-audit.log``.

        Uses ``O_APPEND`` + ``flock`` for cross-process safety, mirroring
        ``LifecycleEventLog.append``. Failures are swallowed: the audit
        log is observability, not authoritative state -- a failed audit
        write must not abort the prune.
        """
        line = (
            json.dumps(
                {
                    "ts": datetime.now(timezone.utc).isoformat(),
                    "dropped_ulid": dropped_ulid,
                    "appended_at": appended_at,
                    "reason": "queue_overflow",
                    "queue_size_before_prune": queue_size_before_prune,
                },
                ensure_ascii=False,
                separators=(",", ":"),
            )
            + "\n"
        )
        try:
            fd = os.open(
                str(self._audit_log),
                os.O_WRONLY | os.O_APPEND | os.O_CREAT,
                0o600,
            )
        except OSError:
            return
        try:
            try:
                fcntl.flock(fd, fcntl.LOCK_EX)
                os.write(fd, line.encode("utf-8"))
                os.fsync(fd)
            except OSError:
                # Swallow lock/write/fsync failures, as the docstring
                # promises: a failed audit write must not abort the prune.
                pass
            finally:
                try:
                    fcntl.flock(fd, fcntl.LOCK_UN)
                except OSError:
                    pass
        finally:
            os.close(fd)


__all__ = [
    "CaptureQueue",
    "CaptureQueueError",
    "CaptureQueueLocked",
    "CaptureQueueSchemaError",
    "DEFAULT_MAX_SIZE",
    "DEFAULT_QUEUE_DIR",
    "SCHEMA_VERSION",
    "generate_ulid",
]
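Reviewer sketch (not part of the commit): the ULID layout described in the module docstring, a 48-bit millisecond timestamp followed by 80 random bits in Crockford base32, means the timestamp can be recovered from the first 10 characters, and that string order follows timestamp order. The helper names `encode_ulid` and `ulid_timestamp_ms` are hypothetical. `encode_ulid` encodes the 128-bit value in one pass, which should give the same 26 characters as encoding the two halves separately, since 80 bits fill exactly 16 base32 characters.

```python
CROCKFORD = "0123456789ABCDEFGHJKMNPQRSTVWXYZ"  # no I, L, O, U


def encode_ulid(ms: int, rand: bytes) -> str:
    """Hypothetical encoder: 48-bit ms timestamp + 10 random bytes -> 26 chars."""
    if len(rand) != 10:
        raise ValueError("rand must be exactly 10 bytes (80 bits)")
    n = (ms << 80) | int.from_bytes(rand, "big")
    chars = []
    for _ in range(26):
        chars.append(CROCKFORD[n & 0x1F])  # lowest 5 bits first
        n >>= 5
    return "".join(reversed(chars))


def ulid_timestamp_ms(ulid: str) -> int:
    """Hypothetical decoder: recover the ms timestamp from the 10-char prefix."""
    n = 0
    for ch in ulid[:10]:
        n = (n << 5) | CROCKFORD.index(ch)
    return n
```

Because the alphabet is in ascending ASCII order, comparing two ULID strings compares their timestamps first, which is what lets `list_pending` sort by filename.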
2896
src/iai_mcp/cli.py
Normal file
File diff suppressed because it is too large
321
src/iai_mcp/community.py
Normal file
@ -0,0 +1,321 @@
"""Hierarchical community detection (D-05 bootstrap + stable UUIDs + CONN-01/04).

Policy:
- N < SMALL_N_FLAT (200): single flat community. Rich-club coefficient is too noisy
  below this per van den Heuvel & Sporns 2011; Leiden output is unstable too.
- SMALL_N_FLAT <= N < MID_N_LEIDEN (500): run Leiden; accept only if Q >= 0.2
  (MODULARITY_FLOOR), else fall back to flat. Protects against Leiden producing
  visible but unjustified communities in sparse graphs.
- N >= MID_N_LEIDEN: always run Leiden; accept result regardless of Q
  (graph is big enough that any modular structure is meaningful).

Stable UUIDs:
- Every community gets a persistent UUID at creation.
- On re-run, each new community's centroid is matched against prior centroids;
  the highest cosine >= UUID_ROTATE_COSINE (0.7) reuses the prior UUID.
  If no prior centroid passes the 0.7 bar, a fresh UUID is allocated.
- This prevents ID churn on re-runs where Leiden re-orders labels but the
  cluster membership is essentially the same.

CONN-01 three-level parcellation (Phase 1 approximation):
- Level 1: top_communities -- top 7 (Yeo-like) by member count.
- Level 2: mid_regions -- community UUID -> member node UUIDs
  (Schaefer-scale 200-400 sub-parcellation is a Phase-2 refinement;
  for now we expose the community -> members mapping).
- Level 3: node_to_community -- every leaf record's community assignment.

CONN-04 refresh threshold:
- needs_refresh(prior, current_Q) returns True iff |prior.Q - current_Q| > 0.05.
  The pipeline or session-start assembler decides when to re-run detect_communities
  based on this signal.
"""
from __future__ import annotations

from dataclasses import dataclass, field
from uuid import UUID, uuid4

import numpy as np

from iai_mcp.graph import _HAS_IGRAPH, IGRAPH_THRESHOLD, MemoryGraph

# bootstrap thresholds
SMALL_N_FLAT = 200
MID_N_LEIDEN = 500
MODULARITY_FLOOR = 0.2

# CONN-04 refresh trigger
REFRESH_DELTA = 0.05

# stable-UUID cosine floor
UUID_ROTATE_COSINE = 0.7

# CONN-01 level-1 cap (Yeo-like 7 networks)
MAX_TOP_COMMUNITIES = 7


@dataclass
class CommunityAssignment:
|
||||
"""Output of detect_communities -- consumed by pipeline.pipeline_recall.
|
||||
|
||||
- node_to_community: leaf UUID -> community UUID
|
||||
- community_centroids: community UUID -> mean of member embeddings
|
||||
- modularity: Leiden Q (0.0 for flat)
|
||||
- backend: "flat" | "leiden-networkx" | "leiden-igraph"
|
||||
- top_communities: up to MAX_TOP_COMMUNITIES by member count (CONN-01 L1)
|
||||
- mid_regions: community UUID -> list of member leaf UUIDs (CONN-01 L2)
|
||||
"""
|
||||
|
||||
node_to_community: dict[UUID, UUID] = field(default_factory=dict)
|
||||
community_centroids: dict[UUID, list[float]] = field(default_factory=dict)
|
||||
modularity: float = 0.0
|
||||
backend: str = "flat"
|
||||
top_communities: list[UUID] = field(default_factory=list)
|
||||
mid_regions: dict[UUID, list[UUID]] = field(default_factory=dict)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------- math helpers


def _cosine(a: list[float], b: list[float]) -> float:
    av = np.asarray(a, dtype=np.float32)
    bv = np.asarray(b, dtype=np.float32)
    na = float(np.linalg.norm(av))
    nb = float(np.linalg.norm(bv))
    if na == 0 or nb == 0:
        return 0.0
    return float(np.dot(av, bv) / (na * nb))


def _compute_centroid(embeddings: list[list[float]]) -> list[float]:
    if not embeddings:
        return []
    arr = np.asarray(embeddings, dtype=np.float32)
    centroid = arr.mean(axis=0)
    norm = float(np.linalg.norm(centroid))
    if norm > 0:
        centroid = centroid / norm
    return centroid.tolist()


def _map_to_stable_uuids(
    raw_partition: dict[UUID, int],
    graph: MemoryGraph,
    prior: CommunityAssignment | None,
) -> tuple[dict[UUID, UUID], dict[UUID, list[float]]]:
    """Assign UUIDs to raw integer community labels, reusing prior UUIDs
    when a new centroid matches a prior centroid with cosine >= UUID_ROTATE_COSINE.

    Matching is greedy and one-to-one: new groups are visited in ascending
    label order, each takes its best-matching unclaimed prior UUID, and each
    prior UUID is claimed by at most one new community.
    """
    # Group nodes by raw integer label.
    groups: dict[int, list[UUID]] = {}
    for node, grp in raw_partition.items():
        groups.setdefault(grp, []).append(node)

    # Compute new centroids per group. Groups with no embedded member are
    # skipped. Within a group, nodes with no embedding (e.g. sentinel UUIDs
    # like PROFILE_SENTINEL) are zero-padded to the *current* store dim rather
    # than a hardcoded 384d, so the centroid input stays homogeneous after a
    # 384d -> 1024d re-embed migration.
    new_centroids: dict[int, list[float]] = {}
    for grp, nodes in groups.items():
        valid = [e for n in nodes if (e := graph.get_embedding(n))]
        if not valid:
            continue
        dim = len(valid[0])
        embs = [graph.get_embedding(n) or [0.0] * dim for n in nodes]
        new_centroids[grp] = _compute_centroid(embs)

    # Greedy one-to-one assignment: for each new group, pick the best unused
    # prior UUID with cosine >= UUID_ROTATE_COSINE.
    uuid_for_group: dict[int, UUID] = {}
    used_prior: set[UUID] = set()
    if prior:
        # Stable ordering: by group id ascending so tie-breaks are deterministic.
        for grp in sorted(new_centroids.keys()):
            cent = new_centroids[grp]
            best_prior: UUID | None = None
            best_sim: float = -1.0
            for prior_uuid, prior_cent in prior.community_centroids.items():
                if prior_uuid in used_prior:
                    continue
                s = _cosine(cent, prior_cent)
                if s > best_sim:
                    best_sim = s
                    best_prior = prior_uuid
            if best_prior is not None and best_sim >= UUID_ROTATE_COSINE:
                uuid_for_group[grp] = best_prior
                used_prior.add(best_prior)

    # Allocate fresh UUIDs for groups that didn't match any prior.
    for grp in groups:
        if grp not in uuid_for_group:
            uuid_for_group[grp] = uuid4()

    # Build final maps.
    node_to_community: dict[UUID, UUID] = {}
    community_centroids: dict[UUID, list[float]] = {}
    for grp, nodes in groups.items():
        u = uuid_for_group[grp]
        # Groups skipped above (no embedded member) have no centroid; store an
        # empty one instead of raising KeyError.
        community_centroids[u] = new_centroids.get(grp, [])
        for n in nodes:
            node_to_community[n] = u

    return node_to_community, community_centroids


# ------------------------------------------------------------- flat assignment


def _flat_assignment(
    graph: MemoryGraph, prior: CommunityAssignment | None
) -> CommunityAssignment:
    """Single flat community covering every node."""
    nodes: list[UUID] = []
    valid_embs: list[list[float]] = []
    for node in graph._nx.nodes():
        u = UUID(node)
        nodes.append(u)
        emb = graph.get_embedding(u)
        if emb:
            valid_embs.append(emb)
    if not nodes:
        return CommunityAssignment(backend="flat")

    # Zero-pad any sentinel nodes to the detected store dim so centroid math
    # stays homogeneous post-re-embed (was hardcoded 384d before 1024d support).
    dim = len(valid_embs[0]) if valid_embs else 0
    embs: list[list[float]] = []
    for node in graph._nx.nodes():
        u = UUID(node)
        emb = graph.get_embedding(u)
        embs.append(emb if emb else [0.0] * dim)
    centroid = _compute_centroid(embs) if dim else []

    # Stable UUID across flat runs: reuse prior's single UUID if centroid matches.
    flat_uuid: UUID | None = None
    if prior and len(prior.community_centroids) == 1:
        prior_uuid, prior_cent = next(iter(prior.community_centroids.items()))
        if _cosine(centroid, prior_cent) >= UUID_ROTATE_COSINE:
            flat_uuid = prior_uuid
    if flat_uuid is None:
        flat_uuid = uuid4()

    node_to_community = {n: flat_uuid for n in nodes}
    community_centroids = {flat_uuid: centroid}
    return CommunityAssignment(
        node_to_community=node_to_community,
        community_centroids=community_centroids,
        modularity=0.0,
        backend="flat",
        top_communities=[flat_uuid],
        mid_regions={flat_uuid: nodes},
    )


# ------------------------------------------------------------------ leiden run


def _run_leiden(graph: MemoryGraph) -> tuple[dict[UUID, int], float, str]:
    """Run leidenalg on a NetworkX graph via an igraph mirror.

    Returns (node_uuid -> int label, modularity Q, backend_label).
    Backend label reflects which library owns the hot path per D-04:
    "leiden-igraph" for N >= IGRAPH_THRESHOLD, "leiden-networkx" for smaller graphs
    (both internally use leidenalg since python-louvain is Louvain, not Leiden).
    Seed=42 for determinism across calls.
    """
    import igraph as ig  # local import so leiden dep is lazy
    import leidenalg

    g = graph._nx
    nodes = list(g.nodes())
    idx = {n: i for i, n in enumerate(nodes)}
    edges = [(idx[u], idx[v]) for u, v in g.edges()]
    weights = [float(g[u][v].get("weight", 1.0)) for u, v in g.edges()]

    ih = ig.Graph(n=len(nodes), edges=edges, directed=False)
    if weights:
        ih.es["weight"] = weights

    part = leidenalg.find_partition(
        ih,
        leidenalg.ModularityVertexPartition,
        seed=42,
        weights="weight" if weights else None,
    )
    q = float(part.modularity)
    mapping = {
        UUID(nodes[i]): int(part.membership[i]) for i in range(len(nodes))
    }

    # Backend label matches split even though both paths use leidenalg.
    if _HAS_IGRAPH and graph.node_count() >= IGRAPH_THRESHOLD:
        return mapping, q, "leiden-igraph"
    return mapping, q, "leiden-networkx"


# ------------------------------------------------------------------ public API


def detect_communities(
    graph: MemoryGraph,
    prior: CommunityAssignment | None = None,
) -> CommunityAssignment:
    """D-05 bootstrap + stable UUIDs + CONN-01 three-level parcellation.

    Empty graph -> empty CommunityAssignment(backend="flat").
    """
    n = graph.node_count()
    if n == 0:
        return CommunityAssignment(backend="flat")
    if n < SMALL_N_FLAT:
        return _flat_assignment(graph, prior)

    try:
        raw_partition, q, backend = _run_leiden(graph)
    except Exception:
        # Leiden unavailable or graph pathological -> degrade gracefully.
        return _flat_assignment(graph, prior)

    # Mid-N guard: Leiden output only acceptable if Q >= 0.2.
    if n < MID_N_LEIDEN and q < MODULARITY_FLOOR:
        return _flat_assignment(graph, prior)

    node_to_community, community_centroids = _map_to_stable_uuids(
        raw_partition, graph, prior
    )

    # CONN-01 level 1: top 7 communities by member count.
    counts: dict[UUID, int] = {}
    for c in node_to_community.values():
        counts[c] = counts.get(c, 0) + 1
    top = sorted(counts.items(), key=lambda kv: kv[1], reverse=True)[
        :MAX_TOP_COMMUNITIES
    ]
    top_communities = [u for u, _ in top]

    # CONN-01 level 2 (mid-regions): community UUID -> member node UUIDs.
    mid_regions: dict[UUID, list[UUID]] = {}
    for node, comm in node_to_community.items():
        mid_regions.setdefault(comm, []).append(node)

    return CommunityAssignment(
        node_to_community=node_to_community,
        community_centroids=community_centroids,
        modularity=q,
        backend=backend,
        top_communities=top_communities,
        mid_regions=mid_regions,
    )


def needs_refresh(
    prior: CommunityAssignment, current_modularity: float
) -> bool:
    """CONN-04: refresh signal when |Δ modularity| > REFRESH_DELTA (0.05).

    Consumer (session-start assembler / maintenance job) calls this on each
    new Leiden run; a True return triggers a re-assignment + cache invalidation.
    """
    return abs(prior.modularity - current_modularity) > REFRESH_DELTA
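
The bootstrap tiers and the CONN-04 refresh check in the file above can be tried in isolation. The following is a minimal sketch, not the module's own code: `choose_backend` and `refresh_needed` are hypothetical helpers that take a precomputed modularity instead of running Leiden on a real graph, and the constants are copied from the file for convenience.

```python
# Illustrative sketch only. choose_backend and refresh_needed are hypothetical
# helpers, not part of iai_mcp; constants mirror the values in the file above.
SMALL_N_FLAT = 200
MID_N_LEIDEN = 500
MODULARITY_FLOOR = 0.2
REFRESH_DELTA = 0.05


def choose_backend(n: int, q: float) -> str:
    """Return "flat" or "leiden" for n nodes and a Leiden modularity q."""
    if n < SMALL_N_FLAT:
        # Small graphs always use the single flat community.
        return "flat"
    if n < MID_N_LEIDEN and q < MODULARITY_FLOOR:
        # Mid-sized graphs fall back to flat when the partition is weak.
        return "flat"
    return "leiden"


def refresh_needed(prior_q: float, current_q: float) -> bool:
    """True when modularity moved by more than REFRESH_DELTA in either direction."""
    return abs(prior_q - current_q) > REFRESH_DELTA
```

Under these assumptions a 300-node graph with Q = 0.1 stays flat, the same graph with Q = 0.25 keeps its Leiden partition, and a graph of 500 nodes or more keeps the Leiden result regardless of Q.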
199
src/iai_mcp/compress.py
Normal file
@ -0,0 +1,199 @@
"""TOK-04 LLMLingua-2 compression (Plan 02-04 Task 2, D-25).

Compression is allowed ONLY on retrieval views and summaries, NEVER on raw
content. Enforcement lives in `is_compressible`:

Forbidden:
- pinned records (includes L0 identity)
- invariant_anchor records (s5_trust_score >= 0.9)
- user-tagged raw: records (raw:en, raw:ru, ...)
- normal episodic records (default reject; literal_surface is constitutional
  per MEM-01)

Allowed:
- records tagged cls_summary (CLS consolidation output)
- records tagged schema (LEARN-03 induction output)
- records tagged session_summary

Runtime fallback: when `llmlingua` is not installed, `compress_llmlingua2`
returns the input unchanged and emits an llm_health event. This keeps the
Tier-0 path green on minimal installs (CI, fresh user machines).

Constants:
- COMPRESSION_TARGET_L2 = 0.5 (community descriptors)
- COMPRESSION_TARGET_SUMMARY = 0.3 (session summaries)
"""
from __future__ import annotations

import threading
from typing import Any

from iai_mcp.events import write_event


# ratio targets.
COMPRESSION_TARGET_L2 = 0.5
COMPRESSION_TARGET_SUMMARY = 0.3

# threshold -- records at or above this trust score are invariant anchors.
INVARIANT_TRUST_THRESHOLD = 0.9


# ----------------------------------------------------------- scope gate


def is_compressible(record) -> tuple[bool, str]:
    """Return (allowed, reason) for a given MemoryRecord.

    Reason is a short English diagnostic consumed only in tests / debug logs.
    """
    if getattr(record, "pinned", False):
        return False, "pinned record (D-14 L0 / user-pinned)"

    trust = getattr(record, "s5_trust_score", 0.5)
    try:
        if float(trust) >= INVARIANT_TRUST_THRESHOLD:
            return False, (
                f"invariant anchor (trust={float(trust):.2f} >= "
                f"{INVARIANT_TRUST_THRESHOLD}); forbids compression"
            )
    except (TypeError, ValueError):
        pass

    tags = getattr(record, "tags", None) or []
    for tag in tags:
        if tag.startswith("raw:"):
            return False, f"raw-tagged record ({tag}); user flagged as raw"

    # Explicit allowlist.
    allow_tags = {"cls_summary", "schema", "session_summary"}
    for tag in tags:
        if tag in allow_tags:
            return True, ""

    return False, "literal_surface constitutional (D-25 default deny)"


# ----------------------------------------------------------- llmlingua loader


_LLMLINGUA_LOCK = threading.Lock()
_LLMLINGUA_CACHE: dict[str, Any] = {}


def _load_llmlingua2():
    """Lazy-load llmlingua's PromptCompressor (LLMLingua-2 model).

    Returns the compressor instance on success; None if the package is absent
    or fails to instantiate. Callers log a fallback event and passthrough.
    """
    with _LLMLINGUA_LOCK:
        if "instance" in _LLMLINGUA_CACHE:
            return _LLMLINGUA_CACHE["instance"]
        try:
            from llmlingua import PromptCompressor  # type: ignore
        except Exception:
            _LLMLINGUA_CACHE["instance"] = None
            return None
        try:
            # Device auto-detection: CUDA if available (Linux GPU), else MPS on
            # Apple Silicon (torch.backends.mps), else CPU. llmlingua's default
            # assumes CUDA which breaks on macOS ARM64.
            import torch  # type: ignore
            if torch.cuda.is_available():
                device_map = "cuda"
            elif getattr(torch.backends, "mps", None) and torch.backends.mps.is_available():
                device_map = "mps"
            else:
                device_map = "cpu"
            # microsoft/llmlingua-2-xlm-roberta-large-meetingbank (default in
            # llmlingua>=0.2). Although this compressor is multilingual-capable,
            # the IAI-MCP brain itself is English-only; the
            # multilingual support is incidental and only matters for the
            # opt-in bge-m3 path.
            compressor = PromptCompressor(
                model_name="microsoft/llmlingua-2-xlm-roberta-large-meetingbank",
                use_llmlingua2=True,
                device_map=device_map,
            )
        except Exception:
            _LLMLINGUA_CACHE["instance"] = None
            return None
        _LLMLINGUA_CACHE["instance"] = compressor
        return compressor


# ----------------------------------------------------------- core compression


def compress_llmlingua2(
    text: str,
    target_ratio: float = 0.5,
    store=None,
) -> str:
    """Compress `text` to approximately `target_ratio` of original tokens.

    On any failure (package missing, model load error, runtime exception):
    - Return `text` unchanged (passthrough).
    - If `store` is provided, emit an `llm_health` event with tier='fallback'
      and severity='warning'.

    Scope is the caller's responsibility (is_compressible must be
    consulted BEFORE reaching this function).
    """
    if not text:
        return text

    compressor = _load_llmlingua2()
    if compressor is None:
        if store is not None:
            try:
                write_event(
                    store,
                    kind="llm_health",
                    data={
                        "component": "compress_llmlingua2",
                        "tier": "fallback",
                        "reason": "llmlingua package unavailable or model load failed",
                    },
                    severity="warning",
                )
            except Exception:
                pass
        return text

    try:
        result = compressor.compress_prompt(text, rate=float(target_ratio))
        if isinstance(result, dict):
            return str(result.get("compressed_prompt", text))
        return str(result)
    except Exception as exc:  # pragma: no cover -- runtime failure passthrough
        if store is not None:
            try:
                write_event(
                    store,
                    kind="llm_health",
                    data={
                        "component": "compress_llmlingua2",
                        "tier": "fallback",
                        "error": str(exc),
                    },
                    severity="warning",
                )
            except Exception:
                pass
        return text


def compress_l2_descriptor(descriptor: str, store=None) -> str:
    """Compress an L2 community descriptor (D-25 target ratio 0.5)."""
    return compress_llmlingua2(
        descriptor, target_ratio=COMPRESSION_TARGET_L2, store=store,
    )


def compress_summary(summary: str, store=None) -> str:
    """Compress a session summary (D-25 target ratio 0.3)."""
    return compress_llmlingua2(
        summary, target_ratio=COMPRESSION_TARGET_SUMMARY, store=store,
    )
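
The docstring above leaves scope checking to the caller, so a call site would presumably consult the gate first and only then compress. The sketch below restates that default-deny order in condensed form under stated assumptions: `Record` is a hypothetical stand-in for MemoryRecord, `gate` is a simplified restatement of `is_compressible` (it drops the reason string), and `maybe_compress` takes any callable in place of the real compressor.

```python
from dataclasses import dataclass, field
from typing import Callable

# Values copied from the file above for a self-contained example.
ALLOW_TAGS = {"cls_summary", "schema", "session_summary"}
INVARIANT_TRUST_THRESHOLD = 0.9


@dataclass
class Record:
    """Hypothetical stand-in for MemoryRecord, for illustration only."""
    text: str
    tags: list[str] = field(default_factory=list)
    pinned: bool = False
    s5_trust_score: float = 0.5


def gate(record: Record) -> bool:
    """Condensed default-deny check: every veto runs before the allowlist."""
    if record.pinned or record.s5_trust_score >= INVARIANT_TRUST_THRESHOLD:
        return False
    if any(tag.startswith("raw:") for tag in record.tags):
        return False
    return any(tag in ALLOW_TAGS for tag in record.tags)


def maybe_compress(record: Record, compress: Callable[[str], str]) -> str:
    """Compress only records that pass the gate; otherwise return text as-is."""
    return compress(record.text) if gate(record) else record.text
```

Because the vetoes run first, a record tagged both `raw:en` and `schema` is still rejected, and an untagged episodic record is rejected by the final default.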
499
src/iai_mcp/concurrency.py
Normal file
@ -0,0 +1,499 @@
"""Phase 4 daemon concurrency primitives (DAEMON-04, DAEMON-05).

Persistent-fd flock wrapper. Hold one instance for process lifetime.
fcntl.flock (NOT lockf) -- fd-close does not release (see apenwarr 2010, Pitfall 2).

Constitutional guard:
- C1 HUMAN-FIRST: ProcessLock.try_acquire_exclusive is non-blocking; daemon
  yields immediately when any shared lockholder exists.
- C-USER-CONSENT (formerly C2 per D7-16): the user_initiated_sleep
  branch of _dispatch_socket_request only sets pending flags after receiving
  an explicit consent payload from the wrapper; the FSM transition itself is
  performed by _tick_body, never by the dispatcher (C-DISPATCHER-FSM-ISOLATION).
- C-DISPATCHER-FSM-ISOLATION (Phase 7 structural; supersedes the bare `C2`
  inline-comment shorthand previously used at the FSM-yield call sites): the
  socket dispatcher MUST NOT transition the FSM directly; it only sets pending
  flags consumed by _tick_body under the FSM lock. New socket_server
  inherits this invariant.
- T-04-06 mitigation: flock is bound to process + open-file-description,
  so closing an unrelated fd (e.g. /etc/passwd) does NOT release our lock.
- T-04-02 mitigation: cleanup_stale_socket + asyncio cleanup_socket kwarg
  survive SIGKILL-orphaned sockets.
- T-04-07 mitigation: lock + socket created with mode 0o600 so cross-user
  access requires OS privilege escalation (out of scope).

This module has NO LLM code and NO paid-API env var references.
"""
from __future__ import annotations

import asyncio
import errno
import fcntl
import json
import os
from datetime import datetime, timezone
from pathlib import Path
from typing import Any, Awaitable, Callable

LOCK_PATH: Path = Path.home() / ".iai-mcp" / ".lock"
SOCKET_PATH: Path = Path.home() / ".iai-mcp" / ".daemon.sock"


class ProcessLock:
    """Persistent-fd flock wrapper.

    Hold one instance per process for the entire process lifetime.
    fcntl.flock (BSD) NOT lockf (POSIX) -- closing an unrelated fd does NOT
    release our lock (see apenwarr 2010, Pitfall 2).

    Semantics:
    - acquire_shared(): blocking LOCK_SH (MCP pattern)
    - try_acquire_exclusive(): LOCK_EX | LOCK_NB (daemon heavy-op pattern)
    - holds_exclusive_nb(): cooperative-yield probe
    - release(): LOCK_UN (release without closing fd)
    - close(): os.close() the fd (shutdown only)
    """

    def __init__(self, path: Path = LOCK_PATH) -> None:
        path.parent.mkdir(parents=True, exist_ok=True)
        # O_CREAT so lock file is created if missing; mode 0o600 keeps it user-only.
        self._fd: int | None = os.open(path, os.O_RDWR | os.O_CREAT, 0o600)
        # Ensure mode is actually 0o600 even if umask altered it on create.
        try:
            os.chmod(path, 0o600)
        except OSError:
            pass
        self._path = path

    def acquire_shared(self) -> None:
        """Blocking LOCK_SH. MCP sessions call this at session start."""
        if self._fd is None:
            raise RuntimeError("ProcessLock closed; cannot acquire")
        fcntl.flock(self._fd, fcntl.LOCK_SH)

    def try_acquire_exclusive(self) -> bool:
        """Non-blocking LOCK_EX | LOCK_NB.

        Returns True if acquired, False if any shared holder blocks us.
        Daemon calls this before heavy ops; False -> yield to MCP.
        """
        if self._fd is None:
            return False
        try:
            fcntl.flock(self._fd, fcntl.LOCK_EX | fcntl.LOCK_NB)
            return True
        except OSError as exc:
            if exc.errno in (errno.EAGAIN, errno.EWOULDBLOCK):
                return False
            raise

    def holds_exclusive_nb(self) -> bool:
        """D-06 cooperative-yield probe.

        Non-blocking check: do we still hold the exclusive lock?

        Returns True if our fd has the exclusive lock. Returns False if
        another process (e.g., MCP) acquired a shared lock while we were
        working between REM cycles.

        Implementation: fcntl.flock with LOCK_EX | LOCK_NB on our existing fd.
        On Linux/macOS, re-acquiring an already-held lock is a no-op success.
        On contention (shared lock held by another process), raises BlockingIOError
        which we catch and translate to False. EWOULDBLOCK/EAGAIN may surface as
        OSError on some platforms -- caught the same way.
        """
        if self._fd is None:
            return False
        try:
            fcntl.flock(self._fd, fcntl.LOCK_EX | fcntl.LOCK_NB)
            return True
        except BlockingIOError:
            return False
        except OSError as exc:
            if exc.errno in (errno.EAGAIN, errno.EWOULDBLOCK):
                return False
            raise

    def release(self) -> None:
        """LOCK_UN: release lock but keep fd open for later reacquisition."""
        if self._fd is None:
            return
        fcntl.flock(self._fd, fcntl.LOCK_UN)

    def close(self) -> None:
        """Close fd. Only call at process shutdown -- closing releases the lock."""
        if self._fd is not None:
            try:
                os.close(self._fd)
            finally:
                self._fd = None


def cleanup_stale_socket(path: Path = SOCKET_PATH) -> None:
    """Remove a stale socket file left over from SIGKILL-orphaned daemon.

    Pitfall 10 mitigation: the in-process case is handled either by the
    3.13+ kwarg (see serve_control_socket) or by the 3.12 finally-block
    emulation, but a prior daemon killed with SIGKILL never got to run its
    cleanup. Call this BEFORE the server binds.
    """
    try:
        path.unlink()
    except FileNotFoundError:
        pass
    except OSError:
        # Path may be a non-socket file -- still try to unlink. If even that
        # fails (e.g. permission), let asyncio surface the EADDRINUSE.
        try:
            path.unlink()
        except OSError:
            pass


def _validate_socket_message(req: dict) -> tuple[bool, str | None]:
    """Per-type schema validation (ASVS V5).

    Returns (ok, error_message). `req` must already be known to be a dict.
    """
    req_type = req.get("type")
    if not isinstance(req_type, str):
        return False, "type must be a string"

    if req_type == "status":
        # No required fields.
        return True, None

    if req_type == "user_initiated_sleep":
        reason = req.get("reason")
        ts = req.get("ts")
        if not isinstance(reason, str):
            return False, "reason must be a string"
        if not isinstance(ts, str):
            return False, "ts must be a string"
        return True, None

    if req_type in ("force_wake", "force_rem"):
        ts = req.get("ts")
        if not isinstance(ts, str):
            return False, "ts must be a string"
        return True, None

    if req_type in ("pause", "resume"):
        # pause may optionally carry `seconds`; we don't persist it as a timer
        # (the flag is binary) but we DO validate the type if supplied.
        if "seconds" in req:
            seconds = req.get("seconds")
            if not isinstance(seconds, int) or isinstance(seconds, bool):
                return False, "seconds must be an int"
        return True, None

    # TOK-14 / D5-05: 7th message type `session_open`.
    # Both session_id and ts are OPTIONAL; when supplied, they must be strings.
    # Absence is tolerated so the TS wrapper can emit a bare ping on MCP boot
    # without stalling on id/ts bookkeeping.
    if req_type == "session_open":
        if "session_id" in req and not isinstance(req["session_id"], str):
            return False, "session_id must be a string"
        if "ts" in req and not isinstance(req["ts"], str):
            return False, "ts must be a string"
        return True, None

    # Unknown types are not rejected at validation time; the dispatcher
    # returns a structured unknown_message_type response so the caller sees
    # a different reason code from "invalid_message".
    return True, None


async def _dispatch_socket_request(
    req: dict,
    store: Any,
    lock: ProcessLock,
    state: dict,
) -> dict:
    """Default dispatcher for NDJSON socket requests.

    Handles seven message types; mutates `state` in-place and persists via
    `save_state` when the message changes scheduler control flags. The
    dispatcher thread NEVER transitions the FSM directly
    (C-DISPATCHER-FSM-ISOLATION; renamed from bare `C2` per D7-16) --
    it only sets pending flags that `_tick_body` reads under the FSM lock.

    Handled types:
    - status               -> state snapshot including version
    - user_initiated_sleep -> set user_sleep_request pending flag
    - force_wake           -> set force_wake_request pending flag
    - force_rem            -> set force_rem_request pending flag
    - pause                -> scheduler_paused=True
    - resume               -> scheduler_paused=False
    - session_open         -> set first_turn_pending + hippea_cascade_request
                              (Plan 05-04 TOK-14 / D5-05)
    - any other            -> {"ok": False, "reason": "unknown_message_type"}
    """
    # Reject non-dict requests (defence-in-depth; caller already json.loaded).
    if not isinstance(req, dict):
        return {
            "ok": False,
            "reason": "invalid_message",
            "error": "request must be a JSON object",
        }

    # Per-type schema validation (ASVS V5).
    ok, err = _validate_socket_message(req)
    if not ok:
        return {
            "ok": False,
            "reason": "invalid_message",
            "error": err or "schema_validation_failed",
        }

    req_type = req.get("type")

    # Lazy imports so test monkeypatches of STATE_PATH (via daemon_state) and
    # __version__ (via iai_mcp) always resolve to the current module state.
    from datetime import datetime, timezone

    from iai_mcp import __version__ as pkg_version
    from iai_mcp.daemon_state import save_state

    # -------------------------------------------------------- status snapshot
    if req_type == "status":
        fsm_state = state.get("fsm_state", "WAKE")
        started_at = state.get("daemon_started_at")
        uptime_sec: float | None = None
        if started_at:
            try:
                start_dt = datetime.fromisoformat(started_at)
                uptime_sec = (datetime.now(timezone.utc) - start_dt).total_seconds()
            except (TypeError, ValueError):
                uptime_sec = None

        # Truncate pending_digest to the top-level counters for socket
        # transport; the full digest can be multi-KB once insights are baked.
        pending_digest = state.get("pending_digest")
        if isinstance(pending_digest, dict):
            truncated_digest = {
                "rem_cycles_completed": pending_digest.get("rem_cycles_completed", 0),
                "episodes_processed": pending_digest.get("episodes_processed", 0),
                "schemas_induced_tier0": pending_digest.get(
                    "schemas_induced_tier0", 0,
                ),
                "claude_call_used": pending_digest.get("claude_call_used", False),
            }
        else:
            truncated_digest = None

        return {
            "ok": True,
            # Backwards-compat key used by tests/test_concurrency.py Test 6.
            "state": fsm_state,
            "uptime_sec": uptime_sec,
            # Plan 04-gap-1 additions:
            "version": pkg_version,
            "fsm_state": fsm_state,
            "last_tick_at": state.get("last_tick_at"),
            "quiet_window": state.get("quiet_window"),
            "pending_digest": truncated_digest,
            "daemon_started_at": started_at,
            "scheduler_paused": bool(state.get("scheduler_paused", False)),
        }

    # -------------------------------------------------- user_initiated_sleep
    if req_type == "user_initiated_sleep":
        current_fsm = state.get("fsm_state", "WAKE")
        if current_fsm in ("SLEEP", "DREAMING", "TRANSITIONING"):
            return {"ok": False, "reason": "already_sleeping"}

        # Clip reason to 500 chars (ASVS V5 output hardening mirror).
        reason = str(req.get("reason", ""))[:500]
        ts = str(req.get("ts", ""))
        state["user_sleep_request"] = {
            "reason": reason,
            "ts": ts,
            "pending": True,
        }
        try:
            save_state(state)
        except Exception as exc:  # noqa: BLE001 -- socket must never crash daemon
            return {"ok": False, "reason": "state_write_failed", "error": str(exc)[:200]}
        # Tell the caller we queued the transition; the scheduler owns the FSM
        # and will move WAKE->TRANSITIONING->SLEEP on the next tick
        # (C-DISPATCHER-FSM-ISOLATION; renamed from bare `C2` per D7-16).
        return {"ok": True, "state": "TRANSITIONING"}

    # ---------------------------------------------------------- force_wake
    if req_type == "force_wake":
        ts = str(req.get("ts", ""))
        state["force_wake_request"] = {"ts": ts, "pending": True}
        try:
            save_state(state)
        except Exception as exc:  # noqa: BLE001
            return {"ok": False, "reason": "state_write_failed", "error": str(exc)[:200]}
        return {"ok": True, "reason": "wake_queued"}

    # ----------------------------------------------------------- force_rem
    if req_type == "force_rem":
        ts = str(req.get("ts", ""))
        state["force_rem_request"] = {"ts": ts, "pending": True}
        try:
            save_state(state)
        except Exception as exc:  # noqa: BLE001
            return {"ok": False, "reason": "state_write_failed", "error": str(exc)[:200]}
        return {"ok": True, "reason": "rem_queued"}

    # --------------------------------------------------------- pause/resume
    if req_type == "pause":
        state["scheduler_paused"] = True
        try:
            save_state(state)
        except Exception as exc:  # noqa: BLE001
            return {"ok": False, "reason": "state_write_failed", "error": str(exc)[:200]}
        return {"ok": True, "paused": True}

    if req_type == "resume":
        state["scheduler_paused"] = False
        try:
            save_state(state)
        except Exception as exc:  # noqa: BLE001
            return {"ok": False, "reason": "state_write_failed", "error": str(exc)[:200]}
        return {"ok": True, "paused": False}

# ---------------------------------------------------------- session_open
|
||||
# TOK-14 / D5-05: 7th message type. Sets two flags:
|
||||
# - first_turn_pending[session_id] = True -> consumed by core's
|
||||
# _first_turn_recall_hook exactly once per session.
|
||||
# - hippea_cascade_request {pending=True, session_id, ts} -> polled by
|
||||
# daemon._hippea_cascade_loop which pre-warms the LRU with records
|
||||
# from the top-K salient communities (Van de Cruys HIPPEA operational
|
||||
# form).
|
||||
# Both flags are idempotent under a re-emit: set_overwrite is intentional
|
||||
# so a client that retries session_open gets a fresh cascade.
|
||||
if req_type == "session_open":
|
||||
# Clip session_id to 128 chars (ASVS V5 output hardening — matches
|
||||
# user_initiated_sleep.reason clip at 500).
|
||||
session_id = str(req.get("session_id", ""))[:128]
|
||||
ts = str(req.get("ts", ""))
|
||||
state["last_session_open"] = {"session_id": session_id, "ts": ts}
|
||||
# first-turn hook flag. Co-exists with existing dict form
|
||||
# written by daemon_state.mark_session_opened.
|
||||
first_turn = state.setdefault("first_turn_pending", {})
|
||||
now_iso = datetime.now(timezone.utc).isoformat()
|
||||
if isinstance(first_turn, dict):
|
||||
first_turn[session_id] = now_iso
|
||||
else:
|
||||
# Legacy scalar-bool state -> upgrade in place to the dict form.
|
||||
state["first_turn_pending"] = {session_id: now_iso}
|
||||
# cascade flag.
|
||||
state["hippea_cascade_request"] = {
|
||||
"session_id": session_id,
|
||||
"ts": ts,
|
||||
"pending": True,
|
||||
}
|
||||
try:
|
||||
save_state(state)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
return {"ok": False, "reason": "state_write_failed", "error": str(exc)[:200]}
|
||||
return {"ok": True, "reason": "session_open_queued"}
|
||||
|
||||
# ------------------------------------------------------------ unknown
|
||||
return {
|
||||
"ok": False,
|
||||
"reason": "unknown_message_type",
|
||||
"type": req_type,
|
||||
}
|
||||
|
||||
|
async def serve_control_socket(
    store: Any,
    lock: ProcessLock,
    state: dict,
    shutdown: asyncio.Event,
    *,
    dispatcher: Callable[[dict], Awaitable[dict]] | None = None,
    socket_path: Path = SOCKET_PATH,
) -> None:
    """Unix socket NDJSON server at ~/.iai-mcp/.daemon.sock.

    Protocol: the client sends one JSON request line per connection; the
    server answers with one JSON line and closes the connection. The
    cleanup_socket kwarg (Python 3.13+) auto-removes the socket file on server
    shutdown; on 3.12 we emulate in the finally-block. Stale-socket
    pre-cleanup protects against SIGKILL-orphaned files.

    Permissions: chmod 0o600 immediately after bind so cross-user access
    requires privilege escalation (T-04-04 accepted risk).

    When dispatcher is provided it receives only the parsed request dict and
    must return a dict. When None, the default _dispatch_socket_request is used.
    """
    cleanup_stale_socket(socket_path)
    # Ensure parent dir exists (Path.home() / .iai-mcp could be first-run).
    socket_path.parent.mkdir(parents=True, exist_ok=True)

    # Python 3.13 added a `cleanup_socket` kwarg to the event-loop unix server
    # that auto-removes the socket file on shutdown. On 3.12 we emulate the
    # same behaviour by unlinking in the finally-block below. See:
    # https://docs.python.org/3.13/library/asyncio-stream.html
    _supports_cleanup_socket = False
    try:
        import inspect as _inspect

        # Probe the loop that will actually host the server. Asking the policy
        # for a new loop here would leave an unclosed event loop behind.
        _loop_sig = _inspect.signature(
            asyncio.get_running_loop().create_unix_server
        )
        _supports_cleanup_socket = "cleanup_socket" in _loop_sig.parameters
    except Exception:
        _supports_cleanup_socket = False

    async def handle(reader: asyncio.StreamReader, writer: asyncio.StreamWriter) -> None:
        try:
            line = await reader.readline()
            if not line:
                return
            try:
                req = json.loads(line)
            except (TypeError, ValueError) as exc:
                writer.write((json.dumps({"error": f"invalid_json: {exc}"}) + "\n").encode("utf-8"))
                await writer.drain()
                return
            try:
                if dispatcher is not None:
                    resp = await dispatcher(req)
                else:
                    resp = await _dispatch_socket_request(req, store, lock, state)
            except Exception as exc:  # noqa: BLE001 -- socket must never crash daemon
                resp = {"error": str(exc)}
            writer.write((json.dumps(resp) + "\n").encode("utf-8"))
            await writer.drain()
        finally:
            try:
                writer.close()
                await writer.wait_closed()
            except Exception:
                pass

    # Build server kwargs. The native 3.13+ behaviour is opted in via
    # `cleanup_socket=True`; on 3.12 the finally-block emulates the same unlink
    # so a subsequent daemon boot cannot hit EADDRINUSE.
    _server_kwargs = {"cleanup_socket": True} if _supports_cleanup_socket else {}
    server = await asyncio.start_unix_server(
        handle, path=str(socket_path), **_server_kwargs,
    )
    # chmod 0o600 immediately after bind (T-04-07 mitigation).
    try:
        os.chmod(str(socket_path), 0o600)
    except OSError:
        pass

    try:
        async with server:
            await shutdown.wait()
    finally:
        # Python 3.12 cleanup-socket emulation: remove the socket file on
        # shutdown so the next daemon boot doesn't hit EADDRINUSE. 3.13+ does
        # this natively inside the server.__aexit__.
        if not _supports_cleanup_socket:
            try:
                socket_path.unlink()
            except FileNotFoundError:
                pass
            except OSError:
                pass
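The one-line-in, one-line-out NDJSON exchange described in `serve_control_socket` can be exercised from a client in a few lines. What follows is a minimal sketch and not code from this repository: `send_control_request` is a hypothetical helper name, and the throwaway server inside `_demo` only imitates the shape of the `pause` reply so the round trip can run without the daemon.

```python
import asyncio
import json
import os
import tempfile


async def send_control_request(socket_path: str, request: dict) -> dict:
    """Send one JSON line over a Unix socket and parse the one-line reply."""
    reader, writer = await asyncio.open_unix_connection(socket_path)
    try:
        writer.write((json.dumps(request) + "\n").encode("utf-8"))
        await writer.drain()
        line = await reader.readline()
        return json.loads(line)
    finally:
        writer.close()
        await writer.wait_closed()


async def _demo() -> dict:
    # Stand-in server (assumption): answers like the "pause" branch would.
    async def handle(reader, writer):
        req = json.loads(await reader.readline())
        resp = {"ok": True, "paused": req.get("type") == "pause"}
        writer.write((json.dumps(resp) + "\n").encode("utf-8"))
        await writer.drain()
        writer.close()
        await writer.wait_closed()

    with tempfile.TemporaryDirectory() as tmp:
        path = os.path.join(tmp, "demo.sock")
        server = await asyncio.start_unix_server(handle, path=path)
        async with server:
            return await send_control_request(path, {"type": "pause"})


demo_response = asyncio.run(_demo())
```

Against the real daemon the same helper would be pointed at the socket path the daemon binds; error replies may arrive either as `{"ok": False, ...}` or as `{"error": ...}`, so a caller should probably check for both.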
1332
src/iai_mcp/core.py
Normal file
File diff suppressed because it is too large
432
src/iai_mcp/crypto.py
Normal file
@ -0,0 +1,432 @@
"""Plan 02-08 / AES-256-GCM encryption-at-rest primitives + file-backed key storage.

Ciphertext format (string-encoded for LanceDB string-column storage):

    iai:enc:v1:<base64(nonce || ciphertext || tag)>

Components:
    - prefix          "iai:enc:v1:" (identifies encrypted payload; enables mixed
                      plaintext/ciphertext coexistence during v2->v3 migration)
    - nonce           12 random bytes (AES-GCM standard IV length)
    - ciphertext+tag  AESGCM.encrypt(nonce, plaintext_utf8, associated_data) output;
                      the 16-byte GCM authentication tag is appended by AESGCM.

Associated data (AD) is the UUID bytes of the record id: this binds the
ciphertext to its row so an attacker with write access cannot swap ciphertext
values between rows (T-02-08-01 tampering mitigation).

Key storage (Phase 07.10: file-backed primary, no keyring at module scope):
    - Primary: a 32-raw-byte file at ``{store_root}/.crypto.key`` (default
      ``~/.iai-mcp/.crypto.key``), mode ``0o600``, owner-uid validated. Resolved
      via the ``store_root`` constructor argument (single-source path, threaded
      from ``MemoryStore.root``; see D-03). When ``store_root`` is
      ``None`` the path is read lazily from ``IAI_MCP_STORE`` env or the
      ``DEFAULT_STORAGE_PATH`` (``~/.iai-mcp``).
    - Fallback: passphrase via ``IAI_MCP_CRYPTO_PASSPHRASE`` env var (CI / fresh
      installs / non-interactive environments). Key derived via PBKDF2-HMAC-
      SHA256 with 600_000 iterations (OWASP 2023 recommendation) and a per-user
      salt (``sha256(user_id)[:16]``). Deterministic given passphrase + user_id,
      so the same machine survives reboots without persisting anything new.
    - If neither path resolves, ``CryptoKey.get_or_create()`` raises
      ``CryptoKeyError`` with a dual-remediation message naming
      ``iai-mcp crypto migrate-to-file`` (existing macOS Keychain key from before
      Phase 07.10), ``iai-mcp crypto init`` (fresh install), and the
      ``IAI_MCP_CRYPTO_PASSPHRASE`` env var (CI / non-interactive). No silent
      key generation: that would render existing data unreadable.

The migration CLI command ``iai-mcp crypto migrate-to-file`` keeps
a function-local ``import keyring`` to read an existing macOS Keychain key
once and write it to the file backend; this module never imports ``keyring``
at file scope, so daemon boot under launchd does not block on the Keychain
ACL prompt (Phase 07.10 / D-12).

Module contract:
    - encrypt_field(plaintext, key, associated_data) -> str (prefixed base64)
    - decrypt_field(ciphertext_b64, key, associated_data) -> str
    - is_encrypted(field) -> bool
    - CryptoKey(user_id, store_root=None).get_or_create() / rotate() / delete()
    - derive_key_from_passphrase(passphrase, salt) -> bytes (32)

Constitutional fit:
    - D-STORAGE: no keys stored in the LanceDB store; only ciphertext.
    - D-GUARD: file backend missing degrades to passphrase fallback; absent both,
      refusal is loud with an actionable error pointing at both remediation paths.
    - encryption is lossless -- decrypt(encrypt(x)) == x byte-for-byte.
"""
from __future__ import annotations

import base64
import hashlib
import os
import secrets
from pathlib import Path
from typing import Optional

from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.ciphers.aead import AESGCM
from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC


# Constitutional constants (module-scope for grep-discoverability).
CIPHERTEXT_PREFIX: str = "iai:enc:v1:"
NONCE_BYTES: int = 12  # AES-GCM standard IV length
KEY_BYTES: int = 32  # 256-bit key
PBKDF2_ITERATIONS: int = 600_000  # OWASP 2023 minimum for PBKDF2-HMAC-SHA256
SERVICE_NAME_DEFAULT: str = "iai-mcp"

# Default storage root mirrors store.DEFAULT_STORAGE_PATH so a CryptoKey that
# is constructed without a ``store_root`` argument resolves to the same
# location MemoryStore would have used. Kept module-private to avoid
# importing store.py here (would create a circular import).
_DEFAULT_STORE_ROOT: Path = Path.home() / ".iai-mcp"
_KEY_FILE_NAME: str = ".crypto.key"

class CryptoKeyError(RuntimeError):
    """Raised when a CryptoKey cannot be loaded or created.

    Typical triggers:
    - The key file exists at the resolved path but is unreadable, has an
      insecure mode, is owned by a different uid, or has the wrong length.
    - Neither a key file NOR ``IAI_MCP_CRYPTO_PASSPHRASE`` is present;
      ``MemoryStore`` surfaces the error so the daemon refuses to start with
      a clear actionable message instead of silently proceeding without
      encryption (Phase 07.10 D-04).
    """


def is_encrypted(field: Optional[str]) -> bool:
    """Cheap prefix check supporting mixed-plaintext/ciphertext coexistence.

    Returns True only when `field` is a non-empty string that starts with the
    exact version prefix `iai:enc:v1:`. Used by:
    - store._decrypt_fields to know whether to attempt decryption
    - migrate_encryption_v2_to_v3 to skip already-encrypted rows
    """
    if not field or not isinstance(field, str):
        return False
    return field.startswith(CIPHERTEXT_PREFIX)

def encrypt_field(
    plaintext: str,
    key: bytes,
    associated_data: bytes = b"",
) -> str:
    """AES-256-GCM encrypt a UTF-8 string; return prefixed base64 ciphertext.

    The nonce comes from secrets.token_bytes, which draws from the operating
    system's CSPRNG (the same source as os.urandom). A fresh nonce is REQUIRED
    for every call with a given key -- reusing a nonce with AES-GCM breaks the
    security of both messages.

    Parameters
    ----------
    plaintext:
        Any UTF-8 string (including empty string). Cyrillic / CJK / Arabic
        preserved byte-for-byte.
    key:
        32-byte (256-bit) key. Typically sourced from CryptoKey.get_or_create().
    associated_data:
        Arbitrary bytes that are authenticated but not encrypted. In this
        codebase: the record id in UUID-string form (binds ciphertext to row).

    Returns
    -------
    str: "iai:enc:v1:" + base64(nonce || ciphertext || tag)
    """
    if len(key) != KEY_BYTES:
        raise ValueError(f"key must be {KEY_BYTES} bytes (got {len(key)})")
    aesgcm = AESGCM(key)
    nonce = secrets.token_bytes(NONCE_BYTES)
    ct_with_tag = aesgcm.encrypt(
        nonce, plaintext.encode("utf-8"), associated_data or None
    )
    payload = nonce + ct_with_tag
    return CIPHERTEXT_PREFIX + base64.b64encode(payload).decode("ascii")

def decrypt_field(
    ciphertext_b64: str,
    key: bytes,
    associated_data: bytes = b"",
) -> str:
    """Decrypt a prefixed base64 AES-256-GCM payload back to a UTF-8 string.

    Raises cryptography.exceptions.InvalidTag on:
    - Wrong key
    - Tampered ciphertext (single-bit flip in nonce / ct / tag)
    - Mismatched associated_data (even one byte off)

    Raises ValueError if the field doesn't carry the iai:enc:v1: prefix -- the
    caller should have guarded with is_encrypted() first.
    """
    if not is_encrypted(ciphertext_b64):
        raise ValueError("field is not iai:enc:v1:-prefixed ciphertext")
    if len(key) != KEY_BYTES:
        raise ValueError(f"key must be {KEY_BYTES} bytes (got {len(key)})")
    payload_b64 = ciphertext_b64[len(CIPHERTEXT_PREFIX):]
    payload = base64.b64decode(payload_b64)
    if len(payload) < NONCE_BYTES + 16:  # nonce + min GCM tag
        raise ValueError("ciphertext payload too short")
    nonce = payload[:NONCE_BYTES]
    ct_with_tag = payload[NONCE_BYTES:]
    aesgcm = AESGCM(key)
    plaintext_bytes = aesgcm.decrypt(
        nonce, ct_with_tag, associated_data or None
    )
    return plaintext_bytes.decode("utf-8")


def derive_key_from_passphrase(passphrase: str, salt: bytes) -> bytes:
    """PBKDF2-HMAC-SHA256 key derivation for the passphrase-fallback path.

    Parameters
    ----------
    passphrase:
        User-supplied passphrase (via IAI_MCP_CRYPTO_PASSPHRASE env var in the
        current design -- first-run prompt is future work when we have a CLI
        interaction point).
    salt:
        16+ bytes of salt. In practice the CryptoKey fallback uses
        sha256(user_id)[:16] so the derived key is deterministic per
        (passphrase, user_id) pair on a given machine.

    Returns 32 bytes (256-bit) suitable for AESGCM.
    """
    if len(salt) < 16:
        raise ValueError(f"salt must be at least 16 bytes (got {len(salt)})")
    kdf = PBKDF2HMAC(
        algorithm=hashes.SHA256(),
        length=KEY_BYTES,
        salt=salt,
        iterations=PBKDF2_ITERATIONS,
    )
    return kdf.derive(passphrase.encode("utf-8"))

class CryptoKey:
    """File-backed 256-bit AES key with passphrase fallback.

    Design:
    File backend at ``{store_root}/.crypto.key`` (32 raw bytes, mode
    ``0o600``, owner-uid validated) is the primary. Passphrase via
    ``IAI_MCP_CRYPTO_PASSPHRASE`` is the second-tier fallback. If neither
    resolves, ``get_or_create()`` raises ``CryptoKeyError`` with an
    actionable error message naming both remediation paths plus
    ``iai-mcp crypto migrate-to-file`` (one-time migration of an existing
    Keychain key) and ``iai-mcp crypto init`` (fresh install).

    Usage:
        ck = CryptoKey(user_id="default", store_root=Path.home() / ".iai-mcp")
        key = ck.get_or_create()   # 32 bytes; reads from file or falls back
                                   # to passphrase
        # ...
        new_key = ck.rotate()      # writes a fresh key file (atomic temp+rename);
                                   # caller is responsible for re-encrypting data
        ck.delete()                # remove the key file (test teardown / uninstall)

    Multi-user ready: each ``user_id`` derives its own passphrase salt
    (``sha256(user_id)[:16]``). The current product ships a single
    ``user_id="default"`` but the architecture supports per-user isolation for
    future multi-tenant deployments. (The file backend itself is currently
    single-tenant: one ``.crypto.key`` per store root.)

    Caching: the instance-level ``_cached_key`` shields the file backend from
    repeated ``get_or_create()`` calls (one read per instance, not per call).
    The cache is a plain attribute with no locking.
    """

    SERVICE_NAME: str = SERVICE_NAME_DEFAULT

    def __init__(
        self,
        user_id: str = "default",
        store_root: Path | None = None,
    ) -> None:
        self.user_id = user_id
        self.store_root: Path | None = store_root
        self._cached_key: Optional[bytes] = None

    # ---------------------------------------------------------------- helpers

    def _passphrase_salt(self) -> bytes:
        """Per-user salt for the passphrase fallback; deterministic across runs."""
        return hashlib.sha256(self.user_id.encode("utf-8")).digest()[:16]

    def _key_file_path(self) -> Path:
        """Resolve ``{store_root}/.crypto.key`` (Phase 07.10 D-03).

        Lazy resolution: if ``self.store_root`` was not supplied at
        construction, read ``IAI_MCP_STORE`` env or fall back to the project
        default ``~/.iai-mcp``, the same precedence ``MemoryStore.__init__``
        uses. Resolving here (not in ``__init__``) lets a test set
        ``IAI_MCP_STORE`` after a CryptoKey instance was already created
        without the kwarg.
        """
        if self.store_root is not None:
            root = Path(self.store_root)
        else:
            env_path = os.environ.get("IAI_MCP_STORE")
            root = Path(env_path) if env_path else _DEFAULT_STORE_ROOT
        return root / _KEY_FILE_NAME

    def _try_file_get(self) -> Optional[bytes]:
        """Return 32 raw bytes from the key file; ``None`` if the file is absent.

        Strict validation:
        - mode strictly ``0o600``: refuse if any group/world bits are set
          (``mode & 0o077 != 0``) with ``CryptoKeyError("...insecure mode...")``
        - ``st_uid == os.geteuid()``: refuse files owned by a different user
          with ``CryptoKeyError("...uid...")``
        - file length exactly ``KEY_BYTES``: refuse with
          ``CryptoKeyError("...wrong length...")``

        Each rejection emits a distinct error message so misconfigurations are
        diagnosable at a glance.
        """
        path = self._key_file_path()
        if not path.exists():
            return None
        # Use ``os.stat`` rather than ``Path.stat`` so test harnesses can
        # monkeypatch ``os.stat`` to simulate foreign-uid scenarios at the
        # syscall boundary (Phase 07.10 W1 case 4 path-scoped fake stat).
        st = os.stat(path)
        # Mode check: owner-only bits permitted.
        if st.st_mode & 0o077 != 0:
            # The hint uses shell octal (``600``); ``chmod`` does not accept
            # the Python-style ``0o600`` literal.
            raise CryptoKeyError(
                f"crypto key file at {path} has insecure mode "
                f"0o{st.st_mode & 0o777:03o}; expected 0o600 "
                f"(run: chmod 600 {path})"
            )
        # UID check: refuse files owned by a different user.
        if st.st_uid != os.geteuid():
            raise CryptoKeyError(
                f"crypto key file at {path} is owned by uid={st.st_uid}; "
                f"current process runs as uid={os.geteuid()} (refusing to read)"
            )
        raw = path.read_bytes()
        if len(raw) != KEY_BYTES:
            raise CryptoKeyError(
                f"crypto key file at {path} has wrong length {len(raw)} "
                f"(expected {KEY_BYTES})"
            )
        return raw

    def _try_file_set(self, key: bytes) -> None:
        """Atomically write ``key`` to the key file (Phase 07.10 D-07).

        Pattern:
        1. ``mkdir -p`` the parent directory.
        2. Remove any stale ``{path}.tmp.*`` siblings from prior crashed runs.
        3. Open ``{path}.tmp.{pid}`` with ``O_CREAT|O_EXCL|O_WRONLY`` mode
           ``0o600``; refuses if a tmp file at the same pid already exists.
        4. ``os.fchmod(fd, 0o600)`` BEFORE writing bytes; pins the mode
           regardless of umask, so key bytes never sit under any other mode.
        5. ``os.write`` + ``os.fsync`` + ``os.close``.
        6. ``os.rename`` the tmp file to the final path (atomic on POSIX).

        ``ValueError`` is raised if ``key`` is not exactly ``KEY_BYTES`` long.
        """
        if len(key) != KEY_BYTES:
            raise ValueError(f"key must be {KEY_BYTES} bytes (got {len(key)})")
        final = self._key_file_path()
        final.parent.mkdir(parents=True, exist_ok=True)
        # Clean stale tmp files from prior crashed runs so the new write is
        # never confused by leftover state.
        for stale in final.parent.glob(f"{final.name}.tmp.*"):
            try:
                stale.unlink()
            except OSError:
                # Best-effort cleanup; if unlink fails we still proceed and
                # the EXCL open below will refuse if our pid happens to
                # collide with a leftover.
                pass
        tmp = final.parent / f"{final.name}.tmp.{os.getpid()}"
        # ``O_CREAT | O_EXCL | O_WRONLY`` refuses if a tmp at this exact pid
        # already exists; combined with the cleanup above, this guarantees a
        # fresh write path. The requested ``0o600`` is applied at creation,
        # minus any bits the process umask clears.
        fd = os.open(str(tmp), os.O_CREAT | os.O_EXCL | os.O_WRONLY, 0o600)
        try:
            # Explicit ``fchmod`` BEFORE writing bytes: umask can only clear
            # bits from the requested mode, so this pins the file to exactly
            # ``0o600`` before any key material is written.
            os.fchmod(fd, 0o600)
            os.write(fd, key)
            os.fsync(fd)
        finally:
            os.close(fd)
        os.rename(str(tmp), str(final))

    # -------------------------------------------------------- public API

    def get_or_create(self) -> bytes:
        """Return the 256-bit AES key for this user_id.

        Priority:
        1. Instance cache (``self._cached_key``); avoids repeated file reads.
        2. File backend (``_try_file_get``); returns the 32 raw bytes from
           ``{store_root}/.crypto.key`` if present, else ``None``.
        3. Passphrase fallback; derives a key from
           ``IAI_MCP_CRYPTO_PASSPHRASE`` via PBKDF2; deterministic given
           ``(passphrase, user_id)``. The derived key is NOT written to disk;
           it lives only in the instance cache for the session.
        4. Otherwise raise ``CryptoKeyError`` naming all remediation paths
           (``iai-mcp crypto migrate-to-file``, ``iai-mcp crypto init``,
           ``IAI_MCP_CRYPTO_PASSPHRASE``).
        """
        # Priority 1: instance cache.
        if self._cached_key is not None:
            return self._cached_key

        # Priority 2: file backend (Phase 07.10 D-02).
        existing = self._try_file_get()
        if existing is not None:
            self._cached_key = existing
            return existing

        # Priority 3: passphrase fallback (CI / non-interactive / fresh-install opt-in).
        passphrase = os.environ.get("IAI_MCP_CRYPTO_PASSPHRASE")
        if passphrase:
            derived = derive_key_from_passphrase(passphrase, self._passphrase_salt())
            self._cached_key = derived
            return derived

        # Priority 4: refuse with a dual-remediation error message (Phase 07.10 D-04).
        path = self._key_file_path()
        raise CryptoKeyError(
            f"crypto key file not found at {path} and IAI_MCP_CRYPTO_PASSPHRASE "
            f"is not set.\n"
            f"\n"
            f"To fix:\n"
            f"  - Existing install (key was in macOS Keychain before Phase 07.10): "
            f"run `iai-mcp crypto migrate-to-file` from a Terminal where the "
            f"Keychain prompt can appear, then click \"Always Allow\".\n"
            f"  - Fresh install: run `iai-mcp crypto init` to generate a new key "
            f"file, OR set IAI_MCP_CRYPTO_PASSPHRASE to a strong passphrase "
            f"(suitable for CI or non-interactive environments)."
        )

    def rotate(self) -> bytes:
        """Generate a fresh 32-byte key, write it to the key file, return it.

        Rotation is an atomic file-write operation, irrespective of how the
        previous key was sourced. Caller is responsible for re-encrypting any
        existing ciphertext under the old key (see ``iai-mcp crypto rotate``
        CLI; re-encryption is an application-layer concern). The cached
        instance key is updated so subsequent calls in the same process see
        the new key.
        """
        fresh = secrets.token_bytes(KEY_BYTES)
        self._try_file_set(fresh)
        self._cached_key = fresh
        return fresh

    def delete(self) -> None:
        """Remove the key file (and drop the cache). Idempotent on absent files."""
        self._cached_key = None
        path = self._key_file_path()
        try:
            path.unlink()
        except FileNotFoundError:
            # Idempotent: nothing to delete.
            pass
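The passphrase fallback in `crypto.py` (PBKDF2-HMAC-SHA256, 600_000 iterations, salt `sha256(user_id)[:16]`) can be reproduced with the standard library alone. The sketch below is an illustration under assumptions, not the module's implementation: it swaps the `cryptography` `PBKDF2HMAC` class for `hashlib.pbkdf2_hmac`, which implements the same standard KDF, and `derive_key_stdlib` and `passphrase_salt` are hypothetical names.

```python
import hashlib

PBKDF2_ITERATIONS = 600_000  # mirrors the constant quoted in crypto.py
KEY_BYTES = 32               # 256-bit AES key


def passphrase_salt(user_id: str) -> bytes:
    """Per-user salt: first 16 bytes of sha256(user_id)."""
    return hashlib.sha256(user_id.encode("utf-8")).digest()[:16]


def derive_key_stdlib(passphrase: str, user_id: str) -> bytes:
    """Derive a 32-byte key from (passphrase, user_id) via PBKDF2-HMAC-SHA256."""
    return hashlib.pbkdf2_hmac(
        "sha256",
        passphrase.encode("utf-8"),
        passphrase_salt(user_id),
        PBKDF2_ITERATIONS,
        dklen=KEY_BYTES,
    )


# Same inputs give the same key; a different user_id gives a different key.
key_default = derive_key_stdlib("correct horse battery staple", "default")
key_default_again = derive_key_stdlib("correct horse battery staple", "default")
key_alice = derive_key_stdlib("correct horse battery staple", "alice")
```

The determinism is what lets the fallback survive restarts without writing anything to disk. It also means the salt is predictable from the `user_id`, so the strength of the derived key rests on the passphrase itself.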
77
src/iai_mcp/crypto_key_watch.py
Normal file
@ -0,0 +1,77 @@
"""Boot-time detection of ``.crypto.key`` file rotation for audit events."""

from __future__ import annotations

import json
import os
from pathlib import Path
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from iai_mcp.store import MemoryStore

WATCHER_REL = ".crypto-key-watcher.json"


def _watcher_path(store: "MemoryStore") -> Path:
    return store.root / WATCHER_REL


def _key_path(store: "MemoryStore") -> Path:
    return store.root / ".crypto.key"


def sync_crypto_key_watcher_to_disk(store: "MemoryStore") -> None:
    """Persist watcher state matching the current key file (no event)."""
    kp = _key_path(store)
    if not kp.is_file():
        return
    st = kp.stat()
    cur = {"mtime_ns": int(st.st_mtime_ns), "size": int(st.st_size)}
    wp = _watcher_path(store)
    wp.write_text(json.dumps(cur), encoding="utf-8")
    try:
        os.chmod(wp, 0o600)
    except OSError:
        pass


def check_crypto_key_file_rotation_event(store: "MemoryStore") -> None:
    """Emit ``crypto_key_rotated`` when ``.crypto.key`` mtime/size changed since last persist.

    First run (no watcher file): writes baseline only, no event (cannot
    distinguish "first install" from "rotation" without prior state). An
    unreadable or malformed watcher file is treated the same way.
    """
    from iai_mcp.events import write_event

    kp = _key_path(store)
    if not kp.is_file():
        return
    st = kp.stat()
    cur = {"mtime_ns": int(st.st_mtime_ns), "size": int(st.st_size)}
    wp = _watcher_path(store)
    prev: dict | None = None
    if wp.is_file():
        try:
            prev = json.loads(wp.read_text(encoding="utf-8"))
        except Exception:
            prev = None
    # Valid JSON that is not an object (e.g. a list) would break ``prev.get``
    # below, so treat it as a missing baseline.
    if not isinstance(prev, dict):
        prev = None
    if prev is None:
        sync_crypto_key_watcher_to_disk(store)
        return
    if prev.get("mtime_ns") == cur["mtime_ns"] and prev.get("size") == cur["size"]:
        return
    try:
        write_event(
            store,
            kind="crypto_key_rotated",
            data={
                "source": "daemon_boot",
                "previous": prev,
                "current": cur,
            },
            severity="info",
        )
    except Exception:
        pass
    sync_crypto_key_watcher_to_disk(store)
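The rotation check in `crypto_key_watch.py` boils down to comparing a stored `(mtime_ns, size)` fingerprint with the current one. A minimal, self-contained sketch of that comparison follows; `fingerprint` and `has_rotated` are hypothetical names, and the demo changes the file size so the result does not depend on filesystem timestamp granularity.

```python
import json
import tempfile
from pathlib import Path
from typing import Optional


def fingerprint(path: Path) -> dict:
    """Capture the two stat fields the watcher persists."""
    st = path.stat()
    return {"mtime_ns": int(st.st_mtime_ns), "size": int(st.st_size)}


def has_rotated(previous: Optional[dict], current: dict) -> bool:
    """True only when a baseline exists and differs from the current state."""
    if previous is None:
        return False  # first run: record a baseline, report nothing
    return (
        previous.get("mtime_ns") != current["mtime_ns"]
        or previous.get("size") != current["size"]
    )


with tempfile.TemporaryDirectory() as tmp:
    key_file = Path(tmp) / "demo.key"
    key_file.write_bytes(b"a" * 32)
    # Round-trip through JSON, as a persisted watcher file would be.
    baseline = json.loads(json.dumps(fingerprint(key_file)))
    unchanged = has_rotated(baseline, fingerprint(key_file))
    key_file.write_bytes(b"b" * 33)  # different size forces a mismatch
    rotated = has_rotated(baseline, fingerprint(key_file))
    first_run = has_rotated(None, fingerprint(key_file))
```

A fingerprint of this kind detects that the file was rewritten, not that its contents differ: rewriting the same 32 bytes would still move `mtime_ns`.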
81
src/iai_mcp/cue_router.py
Normal file
@ -0,0 +1,81 @@
"""Plan 06-04 R4: cue-detection router.

Classifies a memory_recall cue into 'verbatim' or 'concept' mode based on
surface signals (quoted phrases, exact-recall markers, RU starts-with
triggers). Drives mode-dependent retrieval in both pipeline_recall (full
graph path) and retrieve.recall (baseline fallback).

Constitutional framing:
- Mottron EPF / Bowler TSH / Murray monotropism: when the cue signals exact
  recall, the user wants ONE hit, not 30. Verbatim mode is the response shape.
- McClelland CLS: episodic and semantic stores have distinguishable retrieval
  surfaces; the cue tells us which store the user is asking.
- Beer VSM S1 vs S4: verbatim is operations, schema is intelligence; the
  router separates the two recursion levels at the entrypoint.
- Ashby ultrastability: the North-Star verbatim ≥99% essential variable is
  defended at the entrypoint; any verbatim-flavoured cue routes to the
  surface that protects it (tier filter + zeroed graph-bonus).

Triggers per CONTEXT (compiled once at module load):

  EN (re.IGNORECASE):
    - quoted-phrase  : "..." (one pair of straight double quotes around text)
    - european-quote : «...» (one pair of guillemets around text)
    - word-marker    : verbatim | exact | quote | quoted | said | wrote
    - day-N          : day <digits> (e.g. "day 17", "Day 7")

  RU (case-insensitive, anchored at start-of-cue ^):
    - ru-start-найди-дословно
    - ru-start-точная-цитата
    - ru-start-что-я-сказал
    - ru-start-что-я-писал

Behaviour:
  - Any one EN match wins (returned with its label) and the function returns
    ("verbatim", label) immediately.
  - Otherwise any one RU match wins (returned with its label).
  - No match -> ("concept", None).
  - Empty / falsy text -> ("concept", None).

The triggered_pattern label is for diagnostic logging (event payloads,
debug traces) and is NOT surfaced on the JSON-RPC response; only the
mode string lives in RecallResponse.cue_mode.
"""
from __future__ import annotations

import re

EN_TRIGGERS: list[tuple[str, re.Pattern[str]]] = [
    ("quoted-phrase", re.compile(r'"[^"]+"')),
    ("european-quote", re.compile(r'«[^»]+»')),
    ("word-marker", re.compile(r'\b(verbatim|exact|quote|quoted|said|wrote)\b', re.IGNORECASE)),
    ("day-N", re.compile(r'\bday\s+\d+\b', re.IGNORECASE)),
]

RU_TRIGGERS: list[tuple[str, re.Pattern[str]]] = [
    ("ru-start-найди-дословно", re.compile(r'^найди дословно', re.IGNORECASE)),
    ("ru-start-точная-цитата", re.compile(r'^точная цитата', re.IGNORECASE)),
    ("ru-start-что-я-сказал", re.compile(r'^что я сказал', re.IGNORECASE)),
    ("ru-start-что-я-писал", re.compile(r'^что я писал', re.IGNORECASE)),
]


def _classify_cue(text: str) -> tuple[str, str | None]:
    """Return (mode, triggered_pattern) for the given cue.

    mode is "verbatim" if any trigger matches, else "concept".
    triggered_pattern is the trigger label (string) on a verbatim hit, or
    None when the cue routes to concept (no trigger matched).

    Empty / None-ish input returns ("concept", None) as a defensive default
    so the dispatcher never crashes on a missing cue field.
    """
    if not text:
        return "concept", None
    for label, pat in EN_TRIGGERS:
        if pat.search(text):
            return "verbatim", label
    for label, pat in RU_TRIGGERS:
        if pat.search(text):
            return "verbatim", label
    return "concept", None
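To see how the router in `cue_router.py` behaves on concrete cues, the sketch below re-creates a reduced trigger table and runs a few inputs through it. It is an illustration only: `DEMO_TRIGGERS` and `classify_cue_demo` are hypothetical names, and only three of the eight patterns are copied over.

```python
import re
from typing import Optional

# Reduced copy of the trigger table (assumption: three patterns suffice here).
DEMO_TRIGGERS = [
    ("quoted-phrase", re.compile(r'"[^"]+"')),
    ("day-N", re.compile(r"\bday\s+\d+\b", re.IGNORECASE)),
    ("ru-start-что-я-сказал", re.compile(r"^что я сказал", re.IGNORECASE)),
]


def classify_cue_demo(text: str) -> tuple[str, Optional[str]]:
    """Return ("verbatim", label) on the first matching trigger, else concept."""
    if not text:
        return "concept", None
    for label, pattern in DEMO_TRIGGERS:
        if pattern.search(text):
            return "verbatim", label
    return "concept", None


quoted = classify_cue_demo('find "blue notebook"')
dated = classify_cue_demo("what happened on Day 7")
russian = classify_cue_demo("Что я сказал вчера")
conceptual = classify_cue_demo("how do graphs relate to memory")
```

Because the Russian patterns are anchored with `^`, a cue with leading whitespace would fall through to concept mode unless the caller strips it first.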
225
src/iai_mcp/curiosity.py
Normal file
@ -0,0 +1,225 @@
|||
"""Active curiosity (LEARN-04, D-23, D-24) -- Task 4.

D-23 trigger: prediction entropy >= 0.7 bits (ENTROPY_MID) AND 3-turn cooldown
since last curiosity question in this session.

D-24 tiered style:
- entropy in [ENTROPY_LOW, ENTROPY_MID) -> silent log event, no question
- entropy in [ENTROPY_MID, ENTROPY_HIGH) -> inline hint
- entropy >= ENTROPY_HIGH -> direct clarifying question

Every question creates curiosity_bridge edges from each triggering record to
the question's UUID (used as a stable hub id). The question itself lives in
the events table (kind=curiosity_question); callers may insert a first-class
record if persistent text is desired, but this module keeps questions
event-sourced to minimise LanceDB write volume.
"""
from __future__ import annotations

import math
from dataclasses import dataclass, field
from uuid import UUID, uuid4

from iai_mcp.events import query_events, write_event
from iai_mcp.store import MemoryStore


# ---------------------------------------------------------------- constants


ENTROPY_LOW: float = 0.4
ENTROPY_MID: float = 0.7
ENTROPY_HIGH: float = 0.9
COOLDOWN_TURNS: int = 3


# ---------------------------------------------------------------- types


@dataclass
class CuriosityQuestion:
    """One curiosity question surfaced by fire_curiosity."""

    id: UUID
    text: str
    triggered_by_record_ids: list[UUID] = field(default_factory=list)
    entropy: float = 0.0
    tier: str = "question"  # "silent" | "inline" | "question"
    resolved: bool = False


# ---------------------------------------------------------------- helpers


def compute_entropy(scores: list[float]) -> float:
    """Shannon entropy (base-2, bits) over a score distribution.

    Returns 0.0 for empty or degenerate inputs. Negative scores are clamped
    to 0 before normalisation so the probability vector is well-defined.
    """
    if not scores:
        return 0.0
    positive = [max(0.0, float(s)) for s in scores]
    total = sum(positive)
    if total <= 0:
        return 0.0
    probs = [p / total for p in positive]
    h = 0.0
    for p in probs:
        if p > 0:
            h -= p * math.log2(p)
    return h
def _last_curiosity_turn(store: MemoryStore, session_id: str) -> int | None:
    """Return the turn of the most recent curiosity_question in this session."""
    events = query_events(store, kind="curiosity_question", limit=20)
    for e in events:
        if e.get("session_id") == session_id:
            try:
                return int(e["data"].get("turn", 0))
            except (TypeError, ValueError):
                return None
    return None


# ---------------------------------------------------------------- fire_curiosity


def fire_curiosity(
    store: MemoryStore,
    hits: list,
    cue: str,
    entropy: float,
    session_id: str,
    turn: int,
) -> CuriosityQuestion | None:
    """D-23 gate + tiering.

    Returns a CuriosityQuestion (or None) and, as a side effect:
    - emits a curiosity_silent_log event for low-entropy misses
    - emits a curiosity_question event for mid/high fires
    - creates curiosity_bridge edges from each triggering record -> question
    """
    if entropy < ENTROPY_LOW:
        return None

    # Low-mid band -> silent log, no question.
    if entropy < ENTROPY_MID:
        write_event(
            store,
            kind="curiosity_silent_log",
            data={
                "cue": cue[:200],
                "entropy": float(entropy),
                "source_ids": [str(h.record_id) for h in hits[:3]],
            },
            severity="info",
            session_id=session_id,
        )
        return None

    # Cooldown check.
    last = _last_curiosity_turn(store, session_id)
    if last is not None and (turn - last) < COOLDOWN_TURNS:
        return None

    q_id = uuid4()
    if entropy < ENTROPY_HIGH:
        tier = "inline"
        text = f"I'm not fully sure -- did you mean {cue!r}?"
    else:
        tier = "question"
        text = f"Could you clarify: {cue!r}?"

    trigger_ids: list[UUID] = [h.record_id for h in hits[:5]]
    question = CuriosityQuestion(
        id=q_id,
        text=text,
        triggered_by_record_ids=trigger_ids,
        entropy=float(entropy),
        tier=tier,
    )

    # curiosity_bridge edges. Delta proportional to entropy so higher-entropy
    # questions get stronger edges.
    # R3: batch all triggers into a single boost_edges call
    # (one merge_insert + one tbl.add at most). The diagnostic try/except
    # boundary is preserved at the SINGLE-call level -- failure of the batched
    # write must never block the curiosity fire path.
    bridge_pairs = [(tid, q_id) for tid in trigger_ids]
    if bridge_pairs:
        try:
            store.boost_edges(
                bridge_pairs,
                edge_type="curiosity_bridge",
                delta=float(entropy),
            )
        except Exception:
            # Diagnostic; never block the curiosity fire on edge failure.
            pass

    write_event(
        store,
        kind="curiosity_question",
        data={
            "question_id": str(q_id),
            "text": text,
            "tier": tier,
            "entropy": float(entropy),
            "turn": int(turn),
            "triggered_by": [str(t) for t in trigger_ids],
        },
        severity="info",
        session_id=session_id,
        source_ids=trigger_ids,
    )
    return question


# ---------------------------------------------------------------- pending


def pending_questions(
    store: MemoryStore,
    session_id: str | None = None,
) -> list[CuriosityQuestion]:
    """Return unresolved curiosity questions, optionally scoped to a session."""
    events = query_events(store, kind="curiosity_question", limit=200)
    resolved_events = query_events(store, kind="curiosity_resolved", limit=500)
    resolved_ids = {
        r["data"].get("question_id")
        for r in resolved_events
        if r["data"].get("question_id")
    }
    out: list[CuriosityQuestion] = []
    for e in events:
        if session_id is not None and e.get("session_id") != session_id:
            continue
        data = e["data"]
        qid_raw = data.get("question_id")
        if not qid_raw:
            continue
        if qid_raw in resolved_ids:
            continue
        try:
            qid = UUID(qid_raw)
        except (TypeError, ValueError):
            continue
        triggered: list[UUID] = []
        for t in data.get("triggered_by", []):
            try:
                triggered.append(UUID(t))
            except (TypeError, ValueError):
                continue
        out.append(
            CuriosityQuestion(
                id=qid,
                text=data.get("text", ""),
                triggered_by_record_ids=triggered,
                entropy=float(data.get("entropy", 0.0)),
                tier=data.get("tier", "question"),
                resolved=False,
            )
        )
    return out
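To make the entropy thresholds concrete, the sketch below recomputes the Shannon entropy for two small score lists and names the branch that the gate in `fire_curiosity` would take. It assumes nothing beyond the constants shown in this file; `entropy_bits` and `classify` are hypothetical helper names used for illustration only, and the sketch performs no event writes or edge updates.

```python
import math

# Thresholds copied from the module constants.
ENTROPY_LOW, ENTROPY_MID, ENTROPY_HIGH = 0.4, 0.7, 0.9
COOLDOWN_TURNS = 3


def entropy_bits(scores):
    """Shannon entropy in bits of the normalised, non-negative scores."""
    positive = [max(0.0, float(s)) for s in scores]
    total = sum(positive)
    if total <= 0:
        return 0.0
    return sum(-(p / total) * math.log2(p / total) for p in positive if p > 0)


def classify(entropy, turn, last_question_turn=None):
    """Hypothetical helper: name the branch the gate would take."""
    if entropy < ENTROPY_LOW:
        return "none"
    if entropy < ENTROPY_MID:
        return "silent"
    if last_question_turn is not None and turn - last_question_turn < COOLDOWN_TURNS:
        return "cooldown"
    return "inline" if entropy < ENTROPY_HIGH else "question"


one_clear_hit = entropy_bits([0.95, 0.05])  # roughly 0.29 bits
two_way_tie = entropy_bits([0.5, 0.5])      # exactly 1 bit

clear_branch = classify(one_clear_hit, turn=4)
tie_branch = classify(two_way_tie, turn=4)
tie_in_cooldown = classify(two_way_tie, turn=4, last_question_turn=2)
```

Note that entropy over n scores can reach log2(n) bits, so a uniform spread over four hits (2 bits) lands in the top tier just as a two-way tie does.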
1690
src/iai_mcp/daemon.py
Normal file
File diff suppressed because it is too large
294
src/iai_mcp/daemon_state.py
Normal file
@ -0,0 +1,294 @@
"""Phase 4 -- atomic daemon state persistence (DAEMON-01 / D-24).

State file at ~/.iai-mcp/.daemon-state.json holds:
- fsm_state -- WAKE / TRANSITIONING / SLEEP / DREAMING
- daemon_started_at -- ISO8601 UTC
- last_digest_shown_at -- ISO8601 UTC, used by morning digest gate
- pending_digest -- dict ready to surface in next memory_recall
- last_learned_at -- last quiet-window learn timestamp
- last_session_ts -- last observed session_started event ts

All writes via tempfile + os.replace (POSIX atomic rename). Crash-mid-write
leaves the old file intact; readers either see old complete or new complete,
never partial.

T-04-01 mitigation: atomic rename precludes torn writes.
T-04-07 mitigation: file mode 0o600 user-only.
"""
from __future__ import annotations

import json
import os
import tempfile
from datetime import datetime, timedelta, timezone
from pathlib import Path

STATE_PATH: Path = Path.home() / ".iai-mcp" / ".daemon-state.json"

# morning-digest gating threshold. The digest is surfaced only when it
# has been at least this many hours since the last show (or has never shown).
DIGEST_SHOW_THRESHOLD_HOURS: int = 18

# first_turn_pending eviction guards. A session is considered stale once it
# has sat in the dict for longer than FIRST_TURN_TTL_HOURS -- typically it
# means the client died before consuming the flag, so the entry will never
# be popped by ``consume_first_turn``. MAX_FIRST_TURN_ENTRIES caps the dict
# as a secondary safety net when many sessions open in a short window.
FIRST_TURN_TTL_HOURS: int = 24
MAX_FIRST_TURN_ENTRIES: int = 100


def load_state() -> dict:
    """Read the state file; return {} if missing or malformed (self-heal)."""
    if not STATE_PATH.exists():
        return {}
    try:
        return json.loads(STATE_PATH.read_text())
    except (OSError, json.JSONDecodeError):
        # Corrupt file -- return empty dict; next save_state writes fresh.
        return {}


def save_state(state: dict) -> None:
    """Atomically persist state via tempfile + os.replace.

    Semantics:
    - Creates parent dir if missing.
    - Writes to a sibling temp file in the same directory (required so
      os.replace can do an atomic rename on the same filesystem).
    - fsync the file contents before rename so the data is on disk.
    - chmod 0o600 before the swap so the visible file is never world-readable.
    - On exception: unlink the temp file so stale temp files don't accumulate
      next to the state file.
    """
    STATE_PATH.parent.mkdir(parents=True, exist_ok=True)
    fd, tmp = tempfile.mkstemp(
        prefix=".daemon-state.",
        suffix=".tmp",
        dir=str(STATE_PATH.parent),
    )
    try:
        with os.fdopen(fd, "w") as f:
            json.dump(state, f, indent=2)
            f.flush()
            os.fsync(f.fileno())
        os.chmod(tmp, 0o600)
        os.replace(tmp, STATE_PATH)
    except Exception:
        try:
            os.unlink(tmp)
        except OSError:
            pass
        raise
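The write-then-rename pattern used by `save_state` can be tried out in a throwaway directory. The sketch below is a generic illustration, not the module's code: `atomic_write_json` is a hypothetical name, the target path is a temporary file rather than `~/.iai-mcp/.daemon-state.json`, and the `chmod` step is left out for brevity.

```python
import json
import os
import tempfile
from pathlib import Path


def atomic_write_json(path: Path, payload: dict) -> None:
    """Write payload to path via a sibling temp file and os.replace."""
    path.parent.mkdir(parents=True, exist_ok=True)
    # The temp file must live in the target directory so the final rename
    # stays on one filesystem.
    fd, tmp = tempfile.mkstemp(prefix=".state.", suffix=".tmp", dir=str(path.parent))
    try:
        with os.fdopen(fd, "w") as f:
            json.dump(payload, f)
            f.flush()
            os.fsync(f.fileno())
        os.replace(tmp, path)
    except Exception:
        # Clean up the partial temp file, then let the caller see the error.
        try:
            os.unlink(tmp)
        except OSError:
            pass
        raise


with tempfile.TemporaryDirectory() as d:
    target = Path(d) / "state.json"
    atomic_write_json(target, {"fsm_state": "WAKE"})
    atomic_write_json(target, {"fsm_state": "SLEEP"})
    final_state = json.loads(target.read_text())
    leftovers = sorted(p.name for p in Path(d).iterdir())
```

A reader that opens the target at any moment sees either the previous complete document or the new one; a failed write leaves the previous file untouched and no temp file behind.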
def prune_stale_first_turn(
    state: dict,
    now: datetime | None = None,
    ttl_hours: int = FIRST_TURN_TTL_HOURS,
    max_entries: int = MAX_FIRST_TURN_ENTRIES,
) -> int:
    """Evict first_turn_pending entries older than ``ttl_hours`` and cap the
    dict at ``max_entries`` (keep newest by timestamp). Returns the number
    of entries removed.

    Legacy values (``True`` / ``False``) and malformed strings carry no usable
    timestamp; they are treated as epoch-aged and evicted on this prune.
    Idempotent; safe to call on every save.
    """
    pending = state.get("first_turn_pending")
    if not isinstance(pending, dict) or not pending:
        return 0

    current = now if now is not None else datetime.now(timezone.utc)
    if current.tzinfo is None:
        current = current.replace(tzinfo=timezone.utc)
    cutoff = current - timedelta(hours=ttl_hours)

    def _as_dt(value: object) -> datetime:
        """Parse stored value into an aware datetime; unknown -> epoch (evict).

        Legacy bool / malformed strings are treated as "stale, evict now" --
        they cannot be aged sensibly without a real timestamp, and the
        former "stamp with current" behaviour kept the dict from ever
        draining when clients died before writing ISO timestamps.
        """
        if isinstance(value, str):
            try:
                dt = datetime.fromisoformat(value)
                if dt.tzinfo is None:
                    dt = dt.replace(tzinfo=timezone.utc)
                return dt
            except ValueError:
                return datetime.fromtimestamp(0, tz=timezone.utc)
        return datetime.fromtimestamp(0, tz=timezone.utc)

    # Normalise every entry to an ISO timestamp string so downstream
    # callers see a consistent value shape after the first prune.
    removed = 0
    for sid, value in list(pending.items()):
        dt = _as_dt(value)
        if dt < cutoff:
            pending.pop(sid, None)
            removed += 1
        elif not isinstance(value, str):
            pending[sid] = dt.isoformat()

    # Secondary cap -- keep the newest ``max_entries`` by timestamp.
    if len(pending) > max_entries:
        ordered = sorted(
            pending.items(),
            key=lambda kv: _as_dt(kv[1]),
            reverse=True,
        )
        keep = dict(ordered[:max_entries])
        removed += len(pending) - len(keep)
        state["first_turn_pending"] = keep

    return removed


def mark_session_opened(state: dict, session_id: str) -> None:
    """Plan 05-03 TOK-12 / D5-03: mark first_turn_pending for a session.

    Stores the opening timestamp as the dict value so ``prune_stale_first_turn``
    can evict entries whose client died before consuming the flag. Opportunistic
    prune on every mark keeps the dict bounded without a dedicated reaper.

    Idempotent. Persistence is the caller's responsibility (typical callers:
    concurrency socket handler; tests directly).
    """
    if not isinstance(session_id, str) or not session_id:
        return
    pending = state.setdefault("first_turn_pending", {})
    pending[session_id] = datetime.now(timezone.utc).isoformat()
    prune_stale_first_turn(state)


def consume_first_turn(state: dict, session_id: str) -> bool:
    """Return True iff first call for session; atomic pop+save.

    D5-03: the first memory_recall in a session consumes the
    flag so subsequent recalls bypass the first-turn hook.
    """
    try:
        pending = state.get("first_turn_pending")
        if not isinstance(pending, dict):
            return False
        if pending.pop(session_id, False):
            try:
                save_state(state)
            except Exception:
                # save failure is non-fatal -- returning True still triggers
                # the hook exactly once in-process; cross-process atomicity
                # is best-effort.
                pass
            return True
        return False
    except Exception:
        return False


# R3 (per D7.2-07 / D7.2-08 / D7.2-10): a per-tick + startup
# reaper for stale `first_turn_pending` entries with a 1-hour TTL and a
# tuple return shape (updated_state, dropped_session_ids).
#
# Distinct from `prune_stale_first_turn` above which has a 24h ceiling and
# is opportunistically invoked from `mark_session_opened`. Both helpers
# coexist by design (researcher finding #1 + advisor recommendation):
#   - `prune_stale_first_turn` keeps its 24h opportunistic path on session-open;
#   - `prune_first_turn_pending` is the per-tick + startup reaper that needs
#     the dropped IDs back so the caller can emit
#     `kind=first_turn_pending_expired` events (D7.2-10).
#
# Pure function -- no I/O. Caller is responsible for `save_state(state)`
# and the event emit. Idempotent; safe on empty/missing input.

FIRST_TURN_PENDING_TTL_SEC_DEFAULT: float = 3600.0  # D7.2-08 1h default


def prune_first_turn_pending(
    state: dict,
    now: datetime | None = None,
    ttl_sec: float = FIRST_TURN_PENDING_TTL_SEC_DEFAULT,
) -> tuple[dict, list[str]]:
    """Phase 7.2 R3: drain stale `first_turn_pending` entries.

    Returns (updated_state_dict, dropped_session_ids). Pure function --
    does NOT call save_state; does NOT emit events. Caller decides
    persistence + event emission.

    Eviction rules:
    - String value parsed as ISO timestamp; entry evicts if (now - ts) > ttl_sec.
    - Non-string value (legacy bool / dict / None) treated as stale → evict.
      Matches the established behavior of `prune_stale_first_turn` for
      legacy entries (cannot be aged sensibly without a timestamp).
    - Naive timestamps assumed UTC.
    - Malformed ISO strings → evict (defensive against corruption).

    Distinct from `prune_stale_first_turn` (24h default, returns int);
    this helper is per-tick + startup with a shorter TTL and visibility
    into which sessions were dropped (D7.2-10 event payload needs the
    session_ids list).
    """
    pending = state.get("first_turn_pending")
    if not isinstance(pending, dict) or not pending:
        return state, []

    current = now if now is not None else datetime.now(timezone.utc)
    if current.tzinfo is None:
        current = current.replace(tzinfo=timezone.utc)
    cutoff = current - timedelta(seconds=ttl_sec)

    dropped: list[str] = []
    fresh: dict = {}
    for sid, value in pending.items():
        if isinstance(value, str):
            try:
                ts = datetime.fromisoformat(value)
                if ts.tzinfo is None:
                    ts = ts.replace(tzinfo=timezone.utc)
            except ValueError:
                dropped.append(sid)
                continue
            if ts < cutoff:
                dropped.append(sid)
                continue
            fresh[sid] = value
        else:
            # Legacy bool / dict / None / number -- no recoverable timestamp.
            dropped.append(sid)

    state["first_turn_pending"] = fresh
    return state, dropped
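The eviction rules listed in the docstring above can be exercised with a fixed clock. The following is a minimal sketch under stated assumptions: `split_stale` is a hypothetical stand-in that mirrors the rules (parseable ISO string and not older than the TTL stays; everything else is dropped), and the session ids and timestamps are made up.

```python
from datetime import datetime, timedelta, timezone


def split_stale(pending, now, ttl_sec=3600.0):
    """Return (fresh, dropped_ids); values without a parseable timestamp drop."""
    cutoff = now - timedelta(seconds=ttl_sec)
    fresh, dropped = {}, []
    for sid, value in pending.items():
        try:
            ts = datetime.fromisoformat(value)
        except (TypeError, ValueError):
            # Legacy bools raise TypeError, garbled strings raise ValueError.
            dropped.append(sid)
            continue
        if ts.tzinfo is None:
            ts = ts.replace(tzinfo=timezone.utc)  # naive timestamps read as UTC
        if ts < cutoff:
            dropped.append(sid)
        else:
            fresh[sid] = value
    return fresh, dropped


now = datetime(2026, 1, 1, 12, 0, tzinfo=timezone.utc)
pending = {
    "recent": "2026-01-01T11:30:00+00:00",
    "boundary": "2026-01-01T11:00:00+00:00",
    "old": "2026-01-01T10:00:00+00:00",
    "legacy": True,
    "garbled": "not-a-timestamp",
}
fresh, dropped = split_stale(pending, now)
```

With a strict `ts < cutoff` comparison, the entry that is exactly one TTL old is still kept; returning the dropped ids is what lets a caller report which sessions expired.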
def get_pending_digest(state: dict, now: datetime) -> dict | None:
    """D-24 / DAEMON-11: return pending morning digest if eligible, else None.

    Eligibility gate: >= DIGEST_SHOW_THRESHOLD_HOURS since last_digest_shown_at
    OR never shown. When returned, the digest is consumed from state and
    last_digest_shown_at is advanced to `now`; state is persisted via
    save_state so the same digest never appears twice in the same 18h window.
    """
    last_shown = state.get("last_digest_shown_at")
    if last_shown:
        try:
            last_dt = datetime.fromisoformat(last_shown)
            if last_dt.tzinfo is None:
                last_dt = last_dt.replace(tzinfo=timezone.utc)
            now_cmp = now if now.tzinfo is not None else now.replace(tzinfo=timezone.utc)
            if now_cmp - last_dt < timedelta(hours=DIGEST_SHOW_THRESHOLD_HOURS):
                return None
        except (TypeError, ValueError):
            # Malformed timestamp -- treat as never shown, fall through.
            pass

    digest = state.get("pending_digest")
    if not digest:
        return None

    now_cmp = now if now.tzinfo is not None else now.replace(tzinfo=timezone.utc)
    state["last_digest_shown_at"] = now_cmp.isoformat()
    state.pop("pending_digest", None)
    save_state(state)
    return digest
79
src/iai_mcp/delegate.py
Normal file
@ -0,0 +1,79 @@
"""TOK-07 subagent delegation context (Plan 02-04 Task 3, D-27).

Parent session exposes a JSON blob containing the 4-segment session-start
payload (L0, L1, L2, rich-club) plus per-component hashes (for delta
encoding) and a proxy-tools schema listing the 5 Phase-1 memory tools the
subagent may invoke via the parent.

The subagent inherits the parent's session cache; it does NOT re-load the
graph from scratch. This matches the Claude Code subagent-context feature
request (#20304).

Constitutional note: the 3 MCP surface tools (curiosity_pending,
schema_list, events_query) are user-introspection surfaces and are NOT
included in SUBAGENT_HOT_TOOLS. Subagents receive the 5 memory tools; user
introspection stays with the parent session.
"""
from __future__ import annotations


# The 5 memory tools exposed to subagents (Phase 1 hot surface). Plan 02-04's
# new user-introspection tools are intentionally excluded.
SUBAGENT_HOT_TOOLS: tuple[str, ...] = (
    "memory_recall",
    "memory_reinforce",
    "memory_contradict",
    "memory_consolidate",
    "profile_get_set",
)


def subagent_proxy_tools() -> list[dict]:
    """Return a list of tool stubs advertised to the subagent.

    Each stub carries `name` + `proxied_via`; the subagent invokes its
    parent's MCP bridge with the tool name, and the parent forwards the call
    to the Python core.
    """
    return [
        {"name": name, "proxied_via": "parent_session"}
        for name in SUBAGENT_HOT_TOOLS
    ]


def serialize_session_for_subagent(
    store,
    assignment,
    rich_club,
) -> dict:
    """Build a JSON-safe dict for subagent spawn.

    Returns:
        {
            "l0": str,
            "l1": str,
            "l2": list[str],
            "rich_club": str,
            "hashes": {"l0": str, "l1": str, "l2": str, "rich_club": str},
            "proxy_tools": [{"name": ..., "proxied_via": "parent_session"}, ...],
        }
    """
    from iai_mcp.delta import build_delta
    from iai_mcp.session import assemble_session_start

    payload = assemble_session_start(store, assignment, rich_club)
    payload_dict = {
        "l0": payload.l0,
        "l1": payload.l1,
        "l2": list(payload.l2),
        "rich_club": payload.rich_club,
    }
    _delta, hashes = build_delta({}, payload_dict)
    return {
        "l0": payload_dict["l0"],
        "l1": payload_dict["l1"],
        "l2": payload_dict["l2"],
        "rich_club": payload_dict["rich_club"],
        "hashes": hashes,
        "proxy_tools": subagent_proxy_tools(),
    }
78
src/iai_mcp/delta.py
Normal file
@ -0,0 +1,78 @@
"""TOK-08 delta encoding for session-start payloads (Plan 02-04 Task 2, D-28).

The session-start payload is a 4-component dict: l0, l1, l2 (list), rich_club.
On the first session turn the client sends nothing; the server hashes each
component and returns both the payload and the hash bundle. On subsequent
turns the client sends previous_hashes; the server compares, and only the
components whose hash changed are returned in the delta payload. Unchanged
components are implicit in the delta (absent from delta, carried over from
the client's cache).

On hash miss (client sends a stale hash), the server returns the full
component value in the delta -- this is also the first-session behaviour.

Reduces per-turn token spend 60-80% on typical within-session continuation.
"""
from __future__ import annotations

import hashlib


HASH_LEN = 16  # sha256 hex truncated to 16 chars
COMPONENTS = ("l0", "l1", "l2", "rich_club")


def hash_component(text: str) -> str:
    """Return a stable 16-char hex digest of the UTF-8-encoded text."""
    h = hashlib.sha256(text.encode("utf-8") if text is not None else b"").hexdigest()
    return h[:HASH_LEN]


def _component_text(value) -> str:
    """Flatten a payload component to a single string for hashing.

    L0/L1/rich_club are strings. L2 is a list of strings; we join with "\n"
    so ordering matters (which matches the wire format).
    """
    if value is None:
        return ""
    if isinstance(value, list):
        return "\n".join(str(x) for x in value)
    return str(value)


def build_delta(
    previous_hashes: dict[str, str],
    current_payload: dict,
) -> tuple[dict, dict[str, str]]:
    """Compute (delta, new_hashes) given the client's last-seen hashes.

    delta is a subset of current_payload containing only components whose
    hash does not match previous_hashes (including the first-session case
    where previous_hashes is empty or missing keys). new_hashes is the full
    current hash bundle, keyed by component name.
    """
    delta: dict = {}
    new_hashes: dict[str, str] = {}
    for key in COMPONENTS:
        value = current_payload.get(key)
        text = _component_text(value)
        h = hash_component(text)
        new_hashes[key] = h
        prev = previous_hashes.get(key) if previous_hashes else None
        if prev != h:
            delta[key] = value if value is not None else ""
    return delta, new_hashes


def apply_delta(previous: dict, delta: dict) -> dict:
    """Merge delta on top of previous full payload -> new full payload.

    Keys absent from delta carry over from `previous`. Provides the client
    side of the round-trip (parent agent: server emits delta; subagent:
    client applies delta).
    """
    merged = dict(previous)
    for key, value in delta.items():
        merged[key] = value
    return merged
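A two-turn round trip shows how the delta shrinks once the client holds a cache. This is a self-contained sketch that re-implements the hashing idea in a few lines rather than importing the module; the payload strings are invented, and `digest` is a hypothetical helper name.

```python
import hashlib

COMPONENTS = ("l0", "l1", "l2", "rich_club")


def digest(value):
    """16-char sha256 prefix; lists are joined with newlines before hashing."""
    text = "\n".join(value) if isinstance(value, list) else (value or "")
    return hashlib.sha256(text.encode("utf-8")).hexdigest()[:16]


def build_delta(previous_hashes, payload):
    """Return (components whose hash changed, full hash bundle)."""
    hashes = {k: digest(payload.get(k)) for k in COMPONENTS}
    delta = {k: payload.get(k) for k in COMPONENTS if previous_hashes.get(k) != hashes[k]}
    return delta, hashes


turn1 = {"l0": "identity", "l1": "profile", "l2": ["fact a", "fact b"], "rich_club": "hubs"}
delta1, hashes1 = build_delta({}, turn1)        # first turn: every component is sent

turn2 = dict(turn1, l2=["fact a", "fact b", "fact c"])
delta2, hashes2 = build_delta(hashes1, turn2)   # only l2 changed

client_view = {**turn1, **delta2}               # client merges the delta over its cache
delta3, _ = build_delta(hashes2, turn2)         # nothing changed: empty delta
```

The client only needs to keep the last full payload and the last hash bundle; any component missing from a delta is read from that cache.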
1558
src/iai_mcp/doctor.py
Normal file
File diff suppressed because it is too large
123
src/iai_mcp/dream.py
Normal file
@ -0,0 +1,123 @@
"""REM cycle orchestrator. CALLS existing modules -- does not reimplement.

Biological mapping:
- NREM-2 (Hebbian binding) = existing hebbian LTP inside sleep.py cluster pass
- NREM-3 (hippocampal replay) = sleep.run_heavy_consolidation Tier-0 path
- REM (cross-community) = schema.induce_schemas_tier1(llm_enabled=False)
- REM lucid moment (last cycle) = insight.generate_overnight_insight

Constitutional guard:
- LOCAL primary worker; llm_enabled ALWAYS False when calling sleep/schema.
- has_api_key=False always for daemon (zero paid-API path).
- 15-minute hard cap per cycle (asyncio.timeout context manager).
- C1: daemon must already hold the fcntl exclusive lock BEFORE calling
  run_rem_cycle -- this module does NOT acquire locks, that is _tick_body's
  job. This module is called under the lock.
- C3: ZERO API cost. The single nightly Claude call is a subprocess, wired
  in insight.py. No paid-API env var is referenced here.
- C5: literal preservation -- we only call modules that modify metadata
  (FSRS state, edge weights, schema tags). Never assigns to literal_surface.
"""
from __future__ import annotations

import asyncio

from iai_mcp.events import write_event
from iai_mcp.guard import BudgetLedger, RateLimitLedger
from iai_mcp.schema import induce_schemas_tier1
from iai_mcp.sleep import SleepConfig, run_heavy_consolidation

# hard cap per REM cycle.
REM_CYCLE_MAX_SEC: int = 15 * 60


async def _emit(store, kind: str, data: dict, severity: str | None = None) -> None:
    """Emit an event off the main loop so LanceDB writes don't block asyncio."""
    if severity is None:
        await asyncio.to_thread(write_event, store, kind, data)
    else:
        await asyncio.to_thread(write_event, store, kind, data, severity=severity)


async def run_rem_cycle(
    store,
    cycle_num: int,
    total_cycles: int,
    session_id: str,
    *,
    is_last: bool,
    claude_enabled: bool,
) -> dict:
    """One REM cycle. Runs to completion or hits 15min cap.

    Returns dict consumed by the morning digest:
    {cycle, summaries_created, schemas_induced, schema_candidates,
    claude_call_used, main_insight_text, timed_out}

    Never raises. All failure modes (timeout, module exception) surface as
    event emissions + a partial result dict so the daemon's outer loop
    cannot crash on cycle-internal exceptions (T-04-12 mitigation).
    """
    await _emit(store, "rem_cycle_started", {"n": cycle_num, "of": total_cycles})

    result: dict = {
        "cycle": cycle_num,
        "summaries_created": 0,
        "schemas_induced": 0,
        "schema_candidates": 0,
        "claude_call_used": False,
        "main_insight_text": None,
        "timed_out": False,
    }

    try:
        async with asyncio.timeout(REM_CYCLE_MAX_SEC):
            # NREM-3 equivalent: heavy consolidation, Tier-0 only in daemon.
            cfg = SleepConfig(llm_enabled=False)
            heavy = await asyncio.to_thread(
                run_heavy_consolidation,
                store, session_id, cfg,
                BudgetLedger(store), RateLimitLedger(store),
                False,  # has_api_key=False always for daemon
            )
            if isinstance(heavy, dict):
                result["summaries_created"] = int(heavy.get("summaries_created", 0) or 0)
                result["schemas_induced"] = int(heavy.get("schemas_induced", 0) or 0)

            # REM cross-community schema induction (explicit Tier-0).
            # Signature: induce_schemas_tier1(store, budget, rate, llm_enabled=True)
            # -- we force llm_enabled=False so the D-GUARD ladder falls through to
            # the pure-local Tier-0 path.
            candidates = await asyncio.to_thread(
                induce_schemas_tier1,
                store, BudgetLedger(store), RateLimitLedger(store), False,
            )
            result["schema_candidates"] = len(candidates) if candidates else 0

            # Lucid moment -- ONLY on last cycle, budget-gated by caller.
            if is_last and claude_enabled:
                from iai_mcp.insight import generate_overnight_insight

                insight = await generate_overnight_insight(store, session_id)
                if isinstance(insight, dict) and insight.get("ok"):
                    result["claude_call_used"] = True
                    result["main_insight_text"] = insight.get("text")

    except asyncio.TimeoutError:
        result["timed_out"] = True
        await _emit(
            store,
            "rem_cycle_timeout",
            {"cycle": cycle_num},
            severity="warning",
        )
    except Exception as exc:  # noqa: BLE001 -- daemon must never die on cycle error
        await _emit(
            store,
            "rem_cycle_error",
            {"cycle": cycle_num, "error": str(exc)[:500]},
            severity="critical",
        )

    await _emit(store, "rem_cycle_completed", result)
    return result
193
src/iai_mcp/embed.py
Normal file
@ -0,0 +1,193 @@
"""Embedding layer -- configurable embedder with a 4-model registry.

Plan 05-08 (2026-04-20): the DEFAULT is now ``bge-small-en-v1.5`` (384d
English-only), reverting the Phase-2 deviation. PROJECT.md line
125 always specified bge-small-en-v1.5 as the intended default; Phase-2
swapped in bge-m3 (1024d multilingual) as D-08a. User directive
2026-04-19: the brain stores English, surface translation is Claude's
job. bge-m3 stays selectable via env var / kwarg for anyone who needs
multilingual semantic match at the 5x RAM cost.

Configurable 4-model registry:
- "bge-m3" -> BAAI/bge-m3 -> 1024d (opt-in, multilingual)
- "multilingual-e5-small" -> intfloat/multilingual-e5-small -> 384d (compromise)
- "bge-small-en-v1.5" -> BAAI/bge-small-en-v1.5 -> 384d (DEFAULT, English)
- "all-MiniLM-L6-v2" -> sentence-transformers/all-MiniLM-L6-v2 -> 384d (English alternative embedder option; included for compatibility testing)

Selection priority at Embedder() instantiation:
1. Explicit `model_key` constructor arg
2. IAI_MCP_EMBED_MODEL environment variable
3. MODEL_REGISTRY default ("bge-small-en-v1.5")

The model is loaded once per process and cached in a module-level dict so
multiple Embedder() instances share the underlying SentenceTransformer.

Deterministic: `normalize_embeddings=True` is always passed,
`show_progress_bar=False`. Same input text always produces the same output
vector across calls within a process.
"""
from __future__ import annotations

import os
import threading

from sentence_transformers import SentenceTransformer


# 4-model registry. Name convention: short logical key -> HF repo id + dim.
# 2026-04-29: all-MiniLM-L6-v2 added as additive ablation entry;
# DEFAULT_MODEL_KEY unchanged (English-Only Brain lock from Plan 05-08).
MODEL_REGISTRY: dict[str, dict] = {
    "bge-m3": {"hf": "BAAI/bge-m3", "dim": 1024},
    "multilingual-e5-small": {"hf": "intfloat/multilingual-e5-small", "dim": 384},
    "bge-small-en-v1.5": {"hf": "BAAI/bge-small-en-v1.5", "dim": 384},
    "all-MiniLM-L6-v2": {"hf": "sentence-transformers/all-MiniLM-L6-v2", "dim": 384},
}
DEFAULT_MODEL_KEY = "bge-small-en-v1.5"


def _resolve_model_key(model_key: str | None = None) -> str:
    if model_key is not None:
        if model_key not in MODEL_REGISTRY:
            raise ValueError(
                f"unknown embed model key {model_key!r}; valid: {sorted(MODEL_REGISTRY)}"
            )
        return model_key
    env_key = os.environ.get("IAI_MCP_EMBED_MODEL")
    if env_key:
        if env_key not in MODEL_REGISTRY:
            raise ValueError(
                f"unknown embed model key {env_key!r} from IAI_MCP_EMBED_MODEL; "
                f"valid: {sorted(MODEL_REGISTRY)}"
            )
        return env_key
    return DEFAULT_MODEL_KEY
|
||||
|
||||
|
||||
_MODEL_LOCK = threading.Lock()
|
||||
_MODEL_CACHE: dict[str, SentenceTransformer] = {}
|
||||
|
||||
|
||||
def _get_model(hf_id: str) -> SentenceTransformer:
|
||||
"""Process-local lazy-load + cache. Thread-safe via lock around cache mutation."""
|
||||
with _MODEL_LOCK:
|
||||
if hf_id not in _MODEL_CACHE:
|
||||
_MODEL_CACHE[hf_id] = SentenceTransformer(hf_id)
|
||||
return _MODEL_CACHE[hf_id]
|
||||
|
||||
|
||||
class Embedder:
|
||||
"""English-Only Brain embedder with a configurable model registry.
|
||||
|
||||
Default model is `bge-small-en-v1.5` (384d, English) per Plan 05-08.
|
||||
Used by the retrieval pipeline (stage 1, cue embedding) and by the session-start
|
||||
assembler. `.DIM` is per-instance (varies by model). `.DEFAULT_DIM` is a
|
||||
class-level default pointing at the registry's default model dimension.
|
||||
|
||||
The opt-in `bge-m3` (1024d multilingual) path stays in the registry for
|
||||
users who explicitly need multilingual semantic match at the 5x RAM cost,
|
||||
but it must be enabled via `IAI_MCP_EMBED_MODEL=bge-m3`; it is not the product default.
|
||||
|
||||
Backward compatibility:
|
||||
- `Embedder.DIM` is kept as a class attribute aliased to the default model
|
||||
dimension so tests that reference `Embedder.DIM` still work; new
|
||||
code should prefer `Embedder().DIM` (instance attr) for correctness.
|
||||
- `Embedder.DEFAULT_MODEL` is the HF id of the default model (bge-small-en-v1.5).
|
||||
"""
|
||||
|
||||
DEFAULT_MODEL_KEY: str = DEFAULT_MODEL_KEY
|
||||
DEFAULT_DIM: int = MODEL_REGISTRY[DEFAULT_MODEL_KEY]["dim"]
|
||||
# Legacy class-level attributes (Phase 1 test compatibility).
|
||||
# New code should construct Embedder() and read .DIM from the instance.
|
||||
DEFAULT_MODEL: str = MODEL_REGISTRY[DEFAULT_MODEL_KEY]["hf"]
|
||||
DIM: int = DEFAULT_DIM
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_key: str | None = None,
|
||||
*,
|
||||
model_name: str | None = None,
|
||||
) -> None:
|
||||
"""Initialise an Embedder.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
model_key:
|
||||
Logical key from MODEL_REGISTRY ("bge-m3" | "multilingual-e5-small" |
|
||||
"bge-small-en-v1.5" | "all-MiniLM-L6-v2"). If None, uses IAI_MCP_EMBED_MODEL env var or
|
||||
the registry default.
|
||||
model_name:
|
||||
Legacy parameter: full HuggingFace repo id (e.g. "BAAI/bge-small-en-v1.5").
|
||||
Prefer model_key for new code. If both are provided, model_key wins.
|
||||
"""
|
||||
if model_key is None and model_name is not None:
|
||||
# Reverse-lookup: find the key whose hf matches this name.
|
||||
match = next(
|
||||
(k for k, v in MODEL_REGISTRY.items() if v["hf"] == model_name),
|
||||
None,
|
||||
)
|
||||
if match is None:
|
||||
raise ValueError(
|
||||
f"model_name {model_name!r} is not in MODEL_REGISTRY; "
|
||||
f"valid hf ids: {[v['hf'] for v in MODEL_REGISTRY.values()]}"
|
||||
)
|
||||
key = match
|
||||
else:
|
||||
key = _resolve_model_key(model_key)
|
||||
self.model_key: str = key
|
||||
spec = MODEL_REGISTRY[key]
|
||||
self.model_name: str = spec["hf"]
|
||||
self.DIM: int = int(spec["dim"]) # instance attr overrides class attr
|
||||
self._model = _get_model(self.model_name)
|
||||
|
||||
def embed(self, text: str) -> list[float]:
|
||||
"""Encode a single string to a DIM-length list[float]. Normalised, deterministic."""
|
||||
vec = self._model.encode(
|
||||
text, normalize_embeddings=True, show_progress_bar=False
|
||||
)
|
||||
return vec.tolist()
|
||||
|
||||
def embed_batch(self, texts: list[str]) -> list[list[float]]:
|
||||
"""Batch-encode preserving input order. Returns N vectors for N inputs."""
|
||||
vecs = self._model.encode(
|
||||
list(texts),
|
||||
normalize_embeddings=True,
|
||||
show_progress_bar=False,
|
||||
batch_size=32,
|
||||
)
|
||||
return [v.tolist() for v in vecs]
|
||||
|
||||
|
||||
def embedder_for_store(store) -> "Embedder":
|
||||
"""Store-aware Embedder factory. Picks the model whose output dim matches
|
||||
the existing LanceDB records schema, so a legacy 1024d store from the
|
||||
pre-Plan-05-08 bge-m3 era stays queryable until it is re-embedded down to
|
||||
the 384d English-Only-Brain default.
|
||||
|
||||
Resolution order:
|
||||
1. If store.embed_dim has an exact match in MODEL_REGISTRY, prefer the
|
||||
model whose logical key name indicates the canonical model at that dim
|
||||
(bge-small-en-v1.5 for 384d default; bge-m3 for legacy/opt-in 1024d).
|
||||
2. Otherwise fall through to the env/registry default via Embedder().
|
||||
|
||||
This decouples runtime model selection from a global env var so a single
|
||||
process can operate multiple stores at different dims while the migration
|
||||
from a legacy 1024d store down to 384d completes.
|
||||
"""
|
||||
target_dim = getattr(store, "embed_dim", None)
|
||||
if target_dim is None:
|
||||
return Embedder()
|
||||
preferred = {384: "bge-small-en-v1.5", 1024: "bge-m3"}
|
||||
key = preferred.get(int(target_dim))
|
||||
# Tests and migrations may monkey-patch `Embedder` with a stub that takes no
|
||||
# kwargs. Fall back to the zero-arg form in that case so the fake surface
|
||||
# stays compatible; real production code still respects store.embed_dim.
|
||||
try:
|
||||
if key is not None and key in MODEL_REGISTRY:
|
||||
return Embedder(model_key=key)
|
||||
for reg_key, spec in MODEL_REGISTRY.items():
|
||||
if int(spec["dim"]) == int(target_dim):
|
||||
return Embedder(model_key=reg_key)
|
||||
except TypeError:
|
||||
pass
|
||||
return Embedder()
|
||||
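Illustrative note (not part of the commit): the selection priority described in the embed.py docstring can be sketched without loading any model. This is a minimal sketch, assuming a stand-in `REGISTRY` that maps the same four keys to their dimensions; `resolve` and its `env` parameter are hypothetical names introduced here so the behaviour can be exercised without touching the real process environment.

```python
import os

# Stand-in for MODEL_REGISTRY: key -> embedding dimension (values copied from the diff).
REGISTRY = {
    "bge-m3": 1024,
    "multilingual-e5-small": 384,
    "bge-small-en-v1.5": 384,
    "all-MiniLM-L6-v2": 384,
}
DEFAULT_KEY = "bge-small-en-v1.5"


def resolve(model_key=None, env=None):
    """Explicit argument > IAI_MCP_EMBED_MODEL > registry default."""
    env = os.environ if env is None else env
    if model_key is not None:
        chosen = model_key  # an explicit argument always wins
    else:
        # An unset or empty env var falls through to the default.
        chosen = env.get("IAI_MCP_EMBED_MODEL") or DEFAULT_KEY
    if chosen not in REGISTRY:
        raise ValueError(f"unknown embed model key {chosen!r}; valid: {sorted(REGISTRY)}")
    return chosen


if __name__ == "__main__":
    print(resolve(env={}))
    print(resolve(env={"IAI_MCP_EMBED_MODEL": "bge-m3"}))
    print(resolve("all-MiniLM-L6-v2", env={"IAI_MCP_EMBED_MODEL": "bge-m3"}))
```

As in the original helper, an unknown key fails loudly whether it comes from the argument or from the environment, rather than silently falling back to the default.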
184
src/iai_mcp/events.py
Normal file
|
|
@@ -0,0 +1,184 @@
|
|||
"""D-STORAGE events table interface.
|
||||
|
||||
Single source of runtime state. Every kind of event — S4 contradictions,
|
||||
trajectory metrics, LLM health probes, schema induction runs, CLS consolidation
|
||||
runs, migration traces, alerts — goes through write_event.
|
||||
|
||||
No .jsonl files. No .json files scattered under
|
||||
internal storage. Everything persists in the LanceDB `events` table.
|
||||
|
||||
CLI queries (iai-mcp health, iai-mcp trajectory) read via query_events.
|
||||
|
||||
events.data_json is AES-256-GCM encrypted at rest (some event
|
||||
payloads carry user quotes / cues -- safest default). The event UUID is the
|
||||
associated data binding. kind / severity / domain / ts / session_id stay
|
||||
plaintext so audit queries (`iai-mcp health`, `iai-mcp trajectory`) can filter
|
||||
on them without decrypting.
|
||||
|
||||
Phase 3 additions (new event kinds — free-form strings, no taxonomy enum):
|
||||
- CONN-05 TEM factorization: `migration_v3_to_v4`.
|
||||
- CONN-07 small-world sigma: `sigma_observation`, `sigma_drift`
|
||||
(sigma-curve diagnostic per Ashby ultrastability).
|
||||
- M2/M4/M6 live wiring: `retrieval_used`, `profile_updated`,
|
||||
`session_started` (existing emit sites extended; not all new — verify via
|
||||
ctx_search before emitting duplicates).
|
||||
- Chapman ecological self-regulation:
|
||||
* `formality_score_weekly` — per-turn aggregate of user SURFACE formality.
|
||||
* `camouflaging_detected` — over-formal trajectory detected over 5-point weekly window.
|
||||
* `register_relaxed` — OUR `camouflaging_relaxation` knob bumped; the system
|
||||
relaxes its OWN register (never the user's; masking modeling is out-of-scope).
|
||||
|
||||
Phase 6 additions (Plan 06-01 schema dedup):
|
||||
- `schema_reinforced` — emitted when `persist_schema` finds an existing
|
||||
schema for the candidate pattern and reinforces incoming
|
||||
`schema_instance_of` edges from new evidence onto the existing keeper
|
||||
instead of inserting a duplicate row. Payload:
|
||||
{schema_id: str, pattern: str, evidence_added: int, total_evidence: int}
|
||||
Source IDs: [keeper_schema_id, *new_evidence_ids[:5]] mirroring the
|
||||
existing `schema_induction_run` shape.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
from iai_mcp.crypto import (
|
||||
decrypt_field,
|
||||
encrypt_field,
|
||||
is_encrypted,
|
||||
)
|
||||
from iai_mcp.store import EVENTS_TABLE, MemoryStore
|
||||
|
||||
|
||||
def write_event(
|
||||
store: MemoryStore,
|
||||
kind: str,
|
||||
data: dict[str, Any],
|
||||
*,
|
||||
severity: str | None = None,
|
||||
domain: str | None = None,
|
||||
session_id: str = "-",
|
||||
source_ids: list[UUID] | None = None,
|
||||
) -> UUID:
|
||||
"""Persist a single event to the LanceDB events table.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
store:
|
||||
Open MemoryStore instance.
|
||||
kind:
|
||||
Logical event kind (e.g. "s4_contradiction", "trajectory_metric",
|
||||
"llm_health", "migration_v1_to_v2"). Free-form string; downstream
|
||||
consumers filter on it.
|
||||
data:
|
||||
JSON-serialisable kind-specific payload. Encoded to data_json.
|
||||
severity:
|
||||
Optional alert severity ("info" | "warning" | "critical"). Stored
|
||||
as empty string for non-alert events.
|
||||
domain:
|
||||
Optional monotropic-domain tag. Stored as empty string when absent.
|
||||
session_id:
|
||||
Session identifier; defaults to "-" when no session is active.
|
||||
source_ids:
|
||||
Optional list of MemoryRecord UUIDs that triggered this event.
|
||||
|
||||
Returns the newly-minted event UUID.
|
||||
"""
|
||||
event_id = uuid4()
|
||||
# encrypt data_json with AD = event UUID bytes. kind / severity /
|
||||
# domain / ts / session_id stay plaintext for filter queries.
|
||||
data_plain = json.dumps(data)
|
||||
ad = str(event_id).encode("ascii")
|
||||
data_ct = encrypt_field(data_plain, store._key(), associated_data=ad)
|
||||
row = {
|
||||
"id": str(event_id),
|
||||
"kind": kind,
|
||||
"severity": severity or "",
|
||||
"domain": domain or "",
|
||||
"ts": datetime.now(timezone.utc),
|
||||
"data_json": data_ct,
|
||||
"session_id": session_id,
|
||||
"source_ids_json": json.dumps([str(x) for x in (source_ids or [])]),
|
||||
}
|
||||
store.db.open_table(EVENTS_TABLE).add([row])
|
||||
return event_id
|
||||
|
||||
|
||||
def query_events(
|
||||
store: MemoryStore,
|
||||
kind: str | None = None,
|
||||
since: datetime | None = None,
|
||||
severity: str | None = None,
|
||||
limit: int = 100,
|
||||
) -> list[dict]:
|
||||
"""Query events matching the given filters, newest first.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
store:
|
||||
Open MemoryStore instance.
|
||||
kind:
|
||||
Filter by event kind. None returns all kinds.
|
||||
since:
|
||||
Only return events with ts >= since. Naive datetimes are treated as UTC.
|
||||
severity:
|
||||
Exact-match filter on severity field.
|
||||
limit:
|
||||
Maximum rows returned (default 100). Caller can pass e.g. 1 to get
|
||||
only the most recent event of a given kind (iai-mcp health).
|
||||
|
||||
Returns a list of dicts with keys: id, kind, severity, domain, ts, data,
|
||||
session_id, source_ids. data and source_ids are decoded from JSON.
|
||||
"""
|
||||
tbl = store.db.open_table(EVENTS_TABLE)
|
||||
df = tbl.to_pandas()
|
||||
if df.empty:
|
||||
return []
|
||||
if kind is not None:
|
||||
df = df[df["kind"] == kind]
|
||||
if severity is not None:
|
||||
df = df[df["severity"] == severity]
|
||||
if since is not None:
|
||||
# Ensure tz-aware comparison
|
||||
since_cmp = since if since.tzinfo is not None else since.replace(tzinfo=timezone.utc)
|
||||
# Pandas Timestamp compares naturally with tz-aware datetimes
|
||||
df = df[df["ts"] >= since_cmp]
|
||||
if df.empty:
|
||||
return []
|
||||
df = df.sort_values("ts", ascending=False).head(limit)
|
||||
out: list[dict] = []
|
||||
for _, row in df.iterrows():
|
||||
# decrypt data_json when it carries the iai:enc:v1: prefix.
|
||||
# Pre-02-08 rows stay plaintext; migration rewrites them lazily.
|
||||
raw_data = row["data_json"] or "{}"
|
||||
if is_encrypted(raw_data):
|
||||
ad = str(row["id"]).encode("ascii")
|
||||
try:
|
||||
raw_data = decrypt_field(raw_data, store._key(), associated_data=ad)
|
||||
except Exception:
|
||||
# Rule 1 diagnostic semantics: a corrupt event row should not
|
||||
# fail the entire query. Fall back to an empty payload for that row.
|
||||
raw_data = "{}"
|
||||
try:
|
||||
data = json.loads(raw_data)
|
||||
except (TypeError, json.JSONDecodeError):
|
||||
data = {}
|
||||
try:
|
||||
source_ids = json.loads(row["source_ids_json"] or "[]")
|
||||
except (TypeError, json.JSONDecodeError):
|
||||
source_ids = []
|
||||
out.append(
|
||||
{
|
||||
"id": row["id"],
|
||||
"kind": row["kind"],
|
||||
"severity": row["severity"] or None,
|
||||
"domain": row["domain"] or None,
|
||||
"ts": row["ts"],
|
||||
"data": data,
|
||||
"session_id": row["session_id"],
|
||||
"source_ids": source_ids,
|
||||
}
|
||||
)
|
||||
return out
|
||||
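Illustrative note (not part of the commit): the filter semantics of `query_events` (exact match on kind and severity, naive `since` treated as UTC, newest first, then `limit`) can be sketched on plain dicts. This is a sketch under stated assumptions: `filter_events` is a hypothetical helper, and it skips the LanceDB read and the decryption step entirely.

```python
from datetime import datetime, timezone


def filter_events(rows, kind=None, since=None, severity=None, limit=100):
    """Apply the same filters as query_events to a list of dict rows."""
    if since is not None and since.tzinfo is None:
        # Naive datetimes are interpreted as UTC, as the docstring states.
        since = since.replace(tzinfo=timezone.utc)
    kept = [
        r for r in rows
        if (kind is None or r["kind"] == kind)
        and (severity is None or r["severity"] == severity)
        and (since is None or r["ts"] >= since)
    ]
    kept.sort(key=lambda r: r["ts"], reverse=True)  # newest first
    return kept[:limit]


if __name__ == "__main__":
    rows = [
        {"kind": "llm_health", "severity": "", "ts": datetime(2026, 4, 1, tzinfo=timezone.utc)},
        {"kind": "alert", "severity": "critical", "ts": datetime(2026, 4, 2, tzinfo=timezone.utc)},
        {"kind": "llm_health", "severity": "", "ts": datetime(2026, 4, 3, tzinfo=timezone.utc)},
    ]
    # limit=1 gives the most recent event of a kind, the pattern used by `iai-mcp health`.
    print(filter_events(rows, kind="llm_health", limit=1))
```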
244
src/iai_mcp/formality.py
Normal file
|
|
@@ -0,0 +1,244 @@
|
|||
"""Plan 03-03 — surface-feature formality scorer (Chapman ecological self-regulation).
|
||||
|
||||
Constitutional anchor:
|
||||
- Observes ONLY the user's surface lexical features (D-AUTIST13-01).
|
||||
- Never models user internal state, never tries to infer "is the user masking".
|
||||
- Paired with src/iai_mcp/camouflaging.py which adjusts OUR register in response.
|
||||
|
||||
Scientific anchor: Chapman R (2021) "Neurodiversity and the Social Ecology of Mental
|
||||
Functions." — the ecological self-regulation framing. Cook 2021 + Raymaker 2020 tell us
|
||||
WHAT NOT to model (masking as an inferred user state).
|
||||
|
||||
Four surface features (D-AUTIST13-01, weighted sum):
|
||||
1. Lexical formality (w=0.45) — per-language register-marker density. Strongest signal.
|
||||
2. Sentence complexity (w=0.20) — sigmoid on avg chars-per-sentence + clause density.
|
||||
3. Hedging density (w=0.15) — hedge markers per 100 tokens.
|
||||
4. Punctuation formality (w=0.20) — semicolon + em-dash + full-quote density.
|
||||
|
||||
Output: formality_score(text, lang) -> float in [0.0, 1.0]. 0 = fully informal,
|
||||
1 = fully formal. Unknown lang returns 0.5 (neutral) with a logged warning; NEVER raises
|
||||
(MEMORY.md global-product mandate).
|
||||
|
||||
Weight rationale: Pattern 3 proposed
|
||||
0.30/0.30/0.20/0.20 as a baseline — fixture-tuned to 0.45/0.20/0.15/0.20 because the lex
|
||||
dimension is the most unambiguous signal across RU+EN and the shortest formal sentences
|
||||
(e.g. "The proposal is, therefore, accepted.") are otherwise penalised by the
|
||||
complexity sigmoid. Fixture accuracy: 100% (51/51) with the current weights.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import math
|
||||
import re
|
||||
import warnings
|
||||
from typing import Iterable
|
||||
|
||||
|
||||
# ------------------------------------------------------------------- constants
|
||||
# Grep-discoverable module-scope constants (PATTERNS.md §7).
|
||||
|
||||
LEX_MARKERS: dict[str, list[str]] = {
|
||||
"en": [
|
||||
"therefore", "however", "accordingly", "nonetheless", "furthermore",
|
||||
"hence", "thus", "consequently", "moreover", "notwithstanding",
|
||||
"whereas", "hereby", "herein", "thereof", "pursuant", "aforementioned",
|
||||
"shall", "aforesaid",
|
||||
],
|
||||
"ru": [
|
||||
"тем не менее", "следовательно", "однако", "впрочем", "таким образом",
|
||||
"вследствие", "настоящим", "согласно", "вышеизложенного", "вышеизложенному",
|
||||
"в соответствии", "по-видимому", "в силу", "исходя из", "данное",
|
||||
"настоящее", "прилагаемым", "представленное", "уведомляем",
|
||||
],
|
||||
}
|
||||
|
||||
HEDGE_MARKERS: dict[str, list[str]] = {
|
||||
"en": [
|
||||
"possibly", "perhaps", "might", "may", "could", "seemingly",
|
||||
"appears to", "seems", "somewhat", "apparently", "presumably",
|
||||
],
|
||||
"ru": [
|
||||
"возможно", "вероятно", "видимо", "по-видимому", "наверное",
|
||||
"кажется", "пожалуй", "скорее всего", "вроде", "будто",
|
||||
],
|
||||
}
|
||||
|
||||
DEFAULT_WEIGHTS: dict[str, float] = {
|
||||
"lex": 0.45,
|
||||
"complexity": 0.20,
|
||||
"hedge": 0.15,
|
||||
"punct": 0.20,
|
||||
}
|
||||
|
||||
# Sentence-complexity sigmoid parameters.
|
||||
# avg chars per sentence: centre 40 credits terse formal writing (e.g. "The
|
||||
# proposal is, therefore, accepted."). clause count adds a second signal
|
||||
# weighted equally with length (avg_cl centre 0.5 = one comma per two sentences).
|
||||
_SENTENCE_COMPLEXITY_CENTER: float = 40.0
|
||||
_SENTENCE_COMPLEXITY_SCALE: float = 25.0
|
||||
_CLAUSE_COUNT_CENTER: float = 0.5
|
||||
_CLAUSE_COUNT_SCALE: float = 0.5
|
||||
|
||||
# Density sigmoid parameters. With these values 0 markers -> ~0.22 (lex, hedge) or ~0.24 (punct); density at the centre -> 0.5.
|
||||
_LEX_DENSITY_CENTER: float = 1.5 # markers per 100 tokens
|
||||
_LEX_DENSITY_SCALE: float = 1.2
|
||||
_HEDGE_DENSITY_CENTER: float = 1.0
|
||||
_HEDGE_DENSITY_SCALE: float = 0.8
|
||||
_PUNCT_DENSITY_CENTER: float = 1.5
|
||||
_PUNCT_DENSITY_SCALE: float = 1.3
|
||||
|
||||
_NEUTRAL_SCORE: float = 0.5
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ------------------------------------------------------------------- helpers
|
||||
def _tokens(text: str) -> list[str]:
|
||||
"""Replace punctuation (except hyphens) with spaces, split on whitespace, lowercase. Unicode-aware."""
|
||||
cleaned = re.sub(r"[^\w\s\-]", " ", text, flags=re.UNICODE)
|
||||
return [t.lower() for t in cleaned.split() if t]
|
||||
|
||||
|
||||
def _sentence_split(text: str) -> list[str]:
|
||||
parts = re.split(r"[.!?;]+", text)
|
||||
return [p.strip() for p in parts if p.strip()]
|
||||
|
||||
|
||||
def _sigmoid(x: float) -> float:
|
||||
if x >= 0:
|
||||
ez = math.exp(-x)
|
||||
return 1.0 / (1.0 + ez)
|
||||
ez = math.exp(x)
|
||||
return ez / (1.0 + ez)
|
||||
|
||||
|
||||
def _count_phrase_occurrences(text_lower: str, phrases: Iterable[str]) -> int:
|
||||
count = 0
|
||||
for p in phrases:
|
||||
if not p:
|
||||
continue
|
||||
if " " in p or "-" in p:
|
||||
# Multi-word or hyphenated phrase -> substring match is fine.
|
||||
count += text_lower.count(p)
|
||||
else:
|
||||
count += len(re.findall(rf"\b{re.escape(p)}\b", text_lower, flags=re.UNICODE))
|
||||
return count
|
||||
|
||||
|
||||
# ------------------------------------------------------------------- features
|
||||
def _lex_score(text: str, lang: str) -> float:
|
||||
"""Per-language register-marker density, sigmoid-bounded to [0, 1]."""
|
||||
markers = LEX_MARKERS.get(lang, [])
|
||||
if not markers:
|
||||
return _NEUTRAL_SCORE
|
||||
toks = _tokens(text)
|
||||
if not toks:
|
||||
return 0.0
|
||||
hits = _count_phrase_occurrences(text.lower(), markers)
|
||||
density = hits * 100.0 / max(len(toks), 1)
|
||||
return _sigmoid((density - _LEX_DENSITY_CENTER) / _LEX_DENSITY_SCALE)
|
||||
|
||||
|
||||
def _complexity_score(text: str) -> float:
|
||||
"""Avg chars per sentence + clause-count proxy. Language-independent.
|
||||
|
||||
Returns equal-weight blend of:
|
||||
- length sigmoid (centred at 40 chars so terse formal sentences aren't depressed).
|
||||
- clause sigmoid based on commas per sentence (centred at 0.5 = one comma per two sentences).
|
||||
"""
|
||||
sents = _sentence_split(text)
|
||||
if not sents:
|
||||
return 0.0
|
||||
avg_len = sum(len(s) for s in sents) / len(sents)
|
||||
avg_clauses = sum(s.count(",") for s in sents) / len(sents)
|
||||
len_score = _sigmoid(
|
||||
(avg_len - _SENTENCE_COMPLEXITY_CENTER) / _SENTENCE_COMPLEXITY_SCALE
|
||||
)
|
||||
cl_score = _sigmoid((avg_clauses - _CLAUSE_COUNT_CENTER) / _CLAUSE_COUNT_SCALE)
|
||||
return 0.5 * len_score + 0.5 * cl_score
|
||||
|
||||
|
||||
def _hedge_score(text: str, lang: str) -> float:
|
||||
"""Hedging density per 100 tokens, sigmoid-bounded to [0, 1]."""
|
||||
markers = HEDGE_MARKERS.get(lang, [])
|
||||
if not markers:
|
||||
return _NEUTRAL_SCORE
|
||||
toks = _tokens(text)
|
||||
if not toks:
|
||||
return 0.0
|
||||
hits = _count_phrase_occurrences(text.lower(), markers)
|
||||
density = hits * 100.0 / max(len(toks), 1)
|
||||
return _sigmoid((density - _HEDGE_DENSITY_CENTER) / _HEDGE_DENSITY_SCALE)
|
||||
|
||||
|
||||
def _punct_score(text: str) -> float:
|
||||
"""Semicolon + em/en-dash + full-quote density per 100 tokens."""
|
||||
toks = _tokens(text)
|
||||
if not toks:
|
||||
return 0.0
|
||||
semi = text.count(";")
|
||||
em = text.count("—") + text.count("–")
|
||||
fq = (
|
||||
text.count('"')
|
||||
+ text.count("“")
|
||||
+ text.count("”")
|
||||
+ text.count("«")
|
||||
+ text.count("»")
|
||||
)
|
||||
hits = semi + em + fq
|
||||
density = hits * 100.0 / max(len(toks), 1)
|
||||
return _sigmoid((density - _PUNCT_DENSITY_CENTER) / _PUNCT_DENSITY_SCALE)
|
||||
|
||||
|
||||
# ------------------------------------------------------------------- public
|
||||
def formality_score(
|
||||
text: str,
|
||||
lang: str,
|
||||
*,
|
||||
weights: dict[str, float] | None = None,
|
||||
) -> float:
|
||||
"""Return surface-feature formality score in [0.0, 1.0].
|
||||
|
||||
0.0 = fully informal, 1.0 = fully formal. Unknown languages get a neutral 0.5
|
||||
with a logged warning (MEMORY.md global-product graceful degradation). NEVER
|
||||
raises on bad input.
|
||||
|
||||
Args:
|
||||
text: free-form user utterance (SURFACE only, per D-AUTIST13-01).
|
||||
lang: ISO-639-1 language code ("en", "ru"). Other codes -> neutral + warning.
|
||||
weights: optional override {lex, complexity, hedge, punct}.
|
||||
|
||||
Constitutional guard reminder: callers pass user SURFACE text only. The scorer
|
||||
does not see any inferred internal state. See camouflaging.py for how the
|
||||
score is consumed (to adjust OUR register, never the user's).
|
||||
"""
|
||||
if not isinstance(text, str) or not text.strip():
|
||||
return 0.0
|
||||
|
||||
if lang not in LEX_MARKERS:
|
||||
warnings.warn(
|
||||
f"formality_score: lang={lang!r} outside RU+EN baseline; "
|
||||
"returning neutral 0.5 (MEMORY.md global-product graceful degradation)",
|
||||
stacklevel=2,
|
||||
)
|
||||
_logger.debug("formality_score unknown lang=%s text_len=%d", lang, len(text))
|
||||
return _NEUTRAL_SCORE
|
||||
|
||||
w = dict(DEFAULT_WEIGHTS)
|
||||
if weights:
|
||||
w.update({k: float(v) for k, v in weights.items() if k in w})
|
||||
total_w = sum(w.values()) or 1.0
|
||||
|
||||
lex = _lex_score(text, lang)
|
||||
complexity = _complexity_score(text)
|
||||
hedge = _hedge_score(text, lang)
|
||||
punct = _punct_score(text)
|
||||
|
||||
weighted = (
|
||||
w["lex"] * lex
|
||||
+ w["complexity"] * complexity
|
||||
+ w["hedge"] * hedge
|
||||
+ w["punct"] * punct
|
||||
) / total_w
|
||||
# Clamp to [0, 1] defensively.
|
||||
return max(0.0, min(1.0, weighted))
|
||||
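Illustrative note (not part of the commit): each density feature in formality.py maps "markers per 100 tokens" through a sigmoid with a centre and a scale. The sketch below reproduces that mapping so the effect of the constants can be checked by hand. `density_score` is a hypothetical helper name; the centre and scale values passed in are the lexical ones from the diff.

```python
import math


def sigmoid(x: float) -> float:
    """Numerically stable logistic function (same branching as _sigmoid)."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    ez = math.exp(x)
    return ez / (1.0 + ez)


def density_score(hits: int, n_tokens: int, center: float, scale: float) -> float:
    """Markers per 100 tokens, squashed to (0, 1) around `center`."""
    density = hits * 100.0 / max(n_tokens, 1)
    return sigmoid((density - center) / scale)


if __name__ == "__main__":
    # Lexical feature constants from the diff: centre 1.5, scale 1.2.
    print(round(density_score(0, 20, 1.5, 1.2), 3))   # no markers at all
    print(round(density_score(3, 200, 1.5, 1.2), 3))  # density exactly at the centre
    print(round(density_score(1, 20, 1.5, 1.2), 3))   # one marker in a short text
```

One consequence worth noting: in a short utterance a single marker already pushes the density far past the centre, so the feature saturates quickly on brief inputs.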
80
src/iai_mcp/gate.py
Normal file
|
|
@@ -0,0 +1,80 @@
|
|||
"""TOK-06 active-inference retrieval gate (Plan 02-04 Task 2, D-26).
|
||||
|
||||
Skip full pipeline_recall when the expected free-energy reduction for the
|
||||
current cue is below THETA_SKIP bits. Trivial cues (greetings, "thanks",
|
||||
single characters) short-circuit to an L0-only response, saving 200-500
|
||||
tokens per trivial turn.
|
||||
|
||||
The heuristic uses a simple token-count proxy for EFE:
|
||||
- Empty / sub-3-char cues: 0.0 bits (no signal).
|
||||
- Greetings ("hello", "thanks", "okay") in the fixed trivial set: 0.1 bits. Two-character entries such as "hi" and "ok" are already caught by the sub-3-char rule (0.0 bits).
|
||||
- Single-token cues not in the trivial set: 0.25 bits (above threshold --
|
||||
one rare/novel token can still justify a retrieval).
|
||||
- General cues: min(2.0, log2(1 + unique_token_count) * 0.5).
|
||||
|
||||
Phase 2 note: this is an approximation. It can be replaced with a real
|
||||
embedding-distance-to-prior computation once the write policy is active.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
|
||||
|
||||
# Skip threshold (bits).
|
||||
THETA_SKIP = 0.2
|
||||
|
||||
# Fixed-EFE trivial cues. Matched case-insensitively after stripping trailing punctuation.
|
||||
TRIVIAL_SHORT_CUES: frozenset[str] = frozenset({
|
||||
"hi", "hello", "hey", "thanks", "thank you", "ok", "okay",
|
||||
"yes", "no", "sure", ".", "!", "?",
|
||||
})
|
||||
|
||||
|
||||
# ---------------------------------------------------------- EFE computation
|
||||
|
||||
|
||||
def expected_free_energy_reduction(cue: str) -> float:
|
||||
"""Estimate the expected free-energy reduction for `cue` (bits).
|
||||
|
||||
- Empty or <3 chars -> 0.0 (below threshold; skip)
|
||||
- Fixed trivial set -> 0.1 (below threshold; skip)
|
||||
- Single non-trivial -> 0.25 (above threshold; proceed)
|
||||
- General formula -> min(2.0, log2(1 + unique_token_count) * 0.5)
|
||||
"""
|
||||
if not cue:
|
||||
return 0.0
|
||||
stripped = cue.strip()
|
||||
if len(stripped) < 3:
|
||||
return 0.0
|
||||
|
||||
normalised = stripped.lower().rstrip(".!?").strip()
|
||||
if normalised in TRIVIAL_SHORT_CUES:
|
||||
return 0.1
|
||||
|
||||
tokens = [t for t in stripped.split() if t]
|
||||
unique = len({t.lower() for t in tokens})
|
||||
if unique <= 1:
|
||||
# Single token not in trivial set -- rare/novel token MAY be a proper
|
||||
# noun, code identifier, or keyword. Stay above threshold.
|
||||
return 0.25
|
||||
value = math.log2(1 + unique) * 0.5
|
||||
return min(2.0, float(value))
|
||||
|
||||
|
||||
# ---------------------------------------------------------- skip decision
|
||||
|
||||
|
||||
def should_skip_retrieval(cue: str) -> tuple[bool, str]:
|
||||
"""Return (skip, reason) per D-26.
|
||||
|
||||
reason is a short English diagnostic suitable for a RecallResponse hint.
|
||||
"""
|
||||
if not cue or len(cue.strip()) < 3:
|
||||
return True, "very short cue (<3 chars); no discriminable signal"
|
||||
|
||||
value = expected_free_energy_reduction(cue)
|
||||
if value < THETA_SKIP:
|
||||
return True, (
|
||||
f"trivial cue (EFE {value:.3f} bits < theta {THETA_SKIP})"
|
||||
)
|
||||
return False, ""
|
||||
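Illustrative note (not part of the commit): the four tiers of the EFE heuristic in gate.py are easy to check with a compact re-statement. This is a sketch, assuming the same threshold and the word entries of the trivial set; `efe` and `TRIVIAL` are hypothetical names for this example.

```python
import math

THETA_SKIP = 0.2
TRIVIAL = frozenset({
    "hi", "hello", "hey", "thanks", "thank you", "ok", "okay", "yes", "no", "sure",
})


def efe(cue: str) -> float:
    """Token-count proxy for expected free-energy reduction, in bits."""
    stripped = cue.strip() if cue else ""
    if len(stripped) < 3:
        return 0.0  # empty or sub-3-char cue
    if stripped.lower().rstrip(".!?").strip() in TRIVIAL:
        return 0.1  # fixed trivial set
    unique = len({t.lower() for t in stripped.split()})
    if unique <= 1:
        return 0.25  # single non-trivial token stays above the threshold
    return min(2.0, math.log2(1 + unique) * 0.5)


if __name__ == "__main__":
    for cue in ["hi", "Hello!", "kubernetes", "what did we decide"]:
        value = efe(cue)
        print(f"{cue!r}: {value:.3f} bits, skip={value < THETA_SKIP}")
```

The cap of 2.0 bits is reached once a cue has 15 or more unique tokens, since log2(16) * 0.5 = 2.0.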
198
src/iai_mcp/graph.py
Normal file
|
|
@@ -0,0 +1,198 @@
|
|||
"""Dual-library graph wrapper.
|
||||
|
||||
NetworkX for dev ergonomics at small N; igraph (C-backed) for hot-path at
|
||||
N >= IGRAPH_THRESHOLD. Backend switches automatically in add_node when the
|
||||
node count crosses the threshold, so callers don't have to care.
|
||||
|
||||
Exposed surface (consumed by community.py, richclub.py, pipeline.py):
|
||||
- add_node, add_edge
|
||||
- node_count, backend (property)
|
||||
- centrality() -> dict[UUID, float] # betweenness
|
||||
- two_hop_neighborhood(seeds, top_k) # CONN-03 greedy spread
|
||||
- rich_club_coefficient() # van den Heuvel & Sporns 2011
|
||||
- get_embedding(node_id)
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
import networkx as nx
|
||||
|
||||
# switch to C-backed igraph at N >= 500 (centrality + Leiden hot path).
|
||||
IGRAPH_THRESHOLD = 500
|
||||
|
||||
try:
|
||||
import igraph as ig # type: ignore
|
||||
_HAS_IGRAPH = True
|
||||
except ImportError: # pragma: no cover -- igraph is a hard dep in pyproject
|
||||
_HAS_IGRAPH = False
|
||||
|
||||
|
||||
class MemoryGraph:
|
||||
"""Dual-library graph. NetworkX is the source of truth for topology; igraph
|
||||
is rebuilt on demand when backend flips.
|
||||
|
||||
Storage model:
|
||||
- `self._nx` holds the authoritative NetworkX graph (str(UUID) node labels).
|
||||
- `self._attrs` maps UUID -> {"community_id": UUID|None, "embedding": list[float]}.
|
||||
- `self._ig` holds a cached igraph mirror once the backend switches.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._nx: nx.Graph = nx.Graph()
|
||||
self._ig: "ig.Graph | None" = None
|
||||
self._attrs: dict[UUID, dict[str, Any]] = {}
|
||||
self._backend: str = "networkx"
|
||||
|
||||
# -------------------------------------------------------------- properties
|
||||
|
||||
@property
|
||||
def backend(self) -> str:
|
||||
return self._backend
|
||||
|
||||
def node_count(self) -> int:
|
||||
return self._nx.number_of_nodes()
|
||||
|
||||
# ----------------------------------------------------------------- writes
|
||||
|
||||
def add_node(
|
||||
self,
|
||||
node_id: UUID,
|
||||
community_id: UUID | None,
|
||||
embedding: list[float],
|
||||
) -> None:
|
||||
self._nx.add_node(str(node_id))
|
||||
self._attrs[node_id] = {
|
||||
"community_id": community_id,
|
||||
"embedding": embedding,
|
||||
}
|
||||
self._maybe_switch_backend()
|
||||
|
||||
def add_edge(
|
||||
self,
|
||||
src: UUID,
|
||||
dst: UUID,
|
||||
weight: float = 1.0,
|
||||
edge_type: str = "hebbian",
|
||||
) -> None:
|
||||
self._nx.add_edge(
|
||||
str(src), str(dst), weight=weight, edge_type=edge_type
|
||||
)
|
||||
if self._ig is not None:
|
||||
# The igraph mirror is not updated in place; rebuild it after each edge
|
||||
# write while in igraph backend. Cheap enough at Phase-1 scale.
|
||||
self._rebuild_igraph()
|
||||
|
||||
# ------------------------------------------------------ backend switching
|
||||
|
||||
def _maybe_switch_backend(self) -> None:
|
||||
n = self.node_count()
|
||||
if (
|
||||
n >= IGRAPH_THRESHOLD
|
||||
and self._backend == "networkx"
|
||||
and _HAS_IGRAPH
|
||||
):
|
||||
self._rebuild_igraph()
|
||||
self._backend = "igraph"
|
||||
|
||||
def _rebuild_igraph(self) -> None:
|
||||
if not _HAS_IGRAPH:
|
||||
return
|
||||
nodes = list(self._nx.nodes())
|
||||
idx = {n: i for i, n in enumerate(nodes)}
|
||||
edges = [(idx[u], idx[v]) for u, v in self._nx.edges()]
|
||||
weights = [
|
||||
float(self._nx[u][v].get("weight", 1.0)) for u, v in self._nx.edges()
|
||||
]
|
||||
g = ig.Graph(n=len(nodes), edges=edges, directed=False)
|
||||
g.vs["name"] = nodes
|
||||
if weights:
|
||||
g.es["weight"] = weights
|
||||
self._ig = g
|
||||
|
||||
# ---------------------------------------------------------- graph metrics
|
||||
|
||||
def centrality(self) -> dict[UUID, float]:
|
||||
"""Betweenness centrality. NetworkX for small N, igraph at scale.
|
||||
|
||||
Empty-edge graphs return all-zero centrality (betweenness undefined).
|
||||
"""
|
||||
if self._backend == "networkx":
|
||||
if self._nx.number_of_edges() == 0:
|
||||
return {UUID(n): 0.0 for n in self._nx.nodes()}
|
||||
bc = nx.betweenness_centrality(self._nx, weight="weight")
|
||||
return {UUID(n): float(c) for n, c in bc.items()}
|
||||
# igraph path
|
||||
assert self._ig is not None
|
||||
has_weight = "weight" in self._ig.es.attributes()
|
||||
raw = self._ig.betweenness(weights="weight" if has_weight else None)
|
||||
names = self._ig.vs["name"]
|
||||
return {UUID(name): float(c) for name, c in zip(names, raw)}
|
||||
|
||||
    def two_hop_neighborhood(
        self, seeds: list[UUID], top_k: int = 5
    ) -> list[UUID]:
        """CONN-03: 2-hop greedy spread.

        At each hop, for each frontier node, take the top_k highest-weight
        neighbours (Seguin 2018 local-information reconstruction). Dedup
        across seeds and hops; exclude seeds themselves.
        """
        visited: set[str] = {str(s) for s in seeds}
        frontier: set[str] = {str(s) for s in seeds if str(s) in self._nx}
        collected: set[str] = set()

        for _ in range(2):  # 2 hops
            next_frontier: set[str] = set()
            for node in frontier:
                if node not in self._nx:
                    continue
                neighbours = [
                    (n, float(self._nx[node][n].get("weight", 1.0)))
                    for n in self._nx.neighbors(node)
                ]
                neighbours.sort(key=lambda x: x[1], reverse=True)
                for n, _ in neighbours[:top_k]:
                    if n not in visited:
                        next_frontier.add(n)
                        collected.add(n)
                        visited.add(n)
            frontier = next_frontier
            if not frontier:
                break

        return [UUID(n) for n in collected]

    def rich_club_coefficient(self, k_threshold: int | None = None) -> float:
        """van den Heuvel & Sporns 2011 -- rich-club coefficient.

        Defaults to using the degree at the 90th percentile as the threshold,
        matching the 10% rich-club convention used in the connectome literature.
        Returns 0.0 on graphs smaller than 2 nodes or without any edges.
        """
        if (
            self._nx.number_of_nodes() < 2
            or self._nx.number_of_edges() == 0
        ):
            return 0.0
        if k_threshold is None:
            degrees = [d for _, d in self._nx.degree()]
            if not degrees:
                return 0.0
            sorted_deg = sorted(degrees)
            # 90th percentile ~ top 10% threshold. len//10 is conservative rounding.
            k_threshold = int(max(1, sorted_deg[-max(1, len(degrees) // 10)]))
        try:
            coeffs = nx.rich_club_coefficient(self._nx, normalized=False)
        except (ZeroDivisionError, nx.NetworkXError):
            # Rich-club is undefined for disconnected or very small graphs.
            return 0.0
        return float(coeffs.get(k_threshold, 0.0))

    # ---------------------------------------------------------------- helpers

    def get_embedding(self, node_id: UUID) -> list[float] | None:
        """Return the embedding attached at add_node() time, or None."""
        attrs = self._attrs.get(node_id)
        return attrs.get("embedding") if attrs else None
|
||||
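The 2-hop greedy spread in `two_hop_neighborhood` can be illustrated without NetworkX. The following is a minimal sketch under stated assumptions: `two_hop_spread`, the adjacency-dict layout, and the toy graph `ADJ` are hypothetical names made up for illustration and are not part of iai_mcp. It mirrors the same rule (rank each frontier node's neighbours by weight, keep the first `top_k`, skip anything already visited, stop after two hops).

```python
def two_hop_spread(adj, seeds, top_k=5):
    """Toy restatement of the 2-hop greedy spread (hypothetical helper).

    adj maps node -> {neighbour: weight}. Seeds are never returned.
    """
    visited = set(seeds)
    frontier = {s for s in seeds if s in adj}
    collected = set()
    for _hop in range(2):
        next_frontier = set()
        for node in frontier:
            # Highest-weight neighbours first; truncate BEFORE the visited
            # check, as the method above does.
            ranked = sorted(adj.get(node, {}).items(), key=lambda kv: kv[1], reverse=True)
            for neighbour, _weight in ranked[:top_k]:
                if neighbour not in visited:
                    visited.add(neighbour)
                    collected.add(neighbour)
                    next_frontier.add(neighbour)
        frontier = next_frontier
        if not frontier:
            break
    return collected


# Small undirected toy graph (illustrative weights only).
ADJ = {
    "a": {"b": 0.9, "c": 0.1},
    "b": {"a": 0.9, "d": 0.5},
    "c": {"a": 0.1, "e": 0.7},
    "d": {"b": 0.5},
    "e": {"c": 0.7},
}

narrow = two_hop_spread(ADJ, ["a"], top_k=1)
wide = two_hop_spread(ADJ, ["a"], top_k=5)
```

With `top_k=1` the spread stalls after the first hop, because the strongest neighbour of `b` is the seed itself; this is a consequence of truncating to `top_k` before filtering visited nodes.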
188
src/iai_mcp/guard.py
Normal file
@@ -0,0 +1,188 @@
"""D-GUARD: graceful-degradation ladder before any LLM call.

Every LLM-dependent operation must pass through `should_call_llm`
BEFORE making an API call. The 7-step ladder (D-GUARD):

    1. sleep.llm_enabled=true?                  else Tier 0
    2. API key present?                         else Tier 0
    3. BudgetLedger daily cap OK?               else Tier 0
    4. BudgetLedger monthly cap OK?             else Tier 0
    5. RateLimitLedger: last 429 > 15 min ago?  else Tier 0 this cycle
    6. API call with retry(max=2, exp backoff) + timeout(60s)     -- caller's job
    7. On 429/400/401/5xx -> record in ledger, Tier 0 this cycle  -- caller's job

Write & read paths (memory_recall/reinforce/contradict, profile_get/set,
session_start) NEVER block on LLM failure. LLM failures only reduce the QUALITY
of semantic consolidation, schema induction, and identity refinement.

Budget defaults: daily_usd_cap=$0.10, monthly_usd_cap=$3.00,
cooldown=15min, on_cap_hit=fallback_to_local.

BudgetLedger + RateLimitLedger persist in LanceDB tables (budget_ledger,
ratelimit_ledger) created by MemoryStore._ensure_tables.
"""
from __future__ import annotations

from datetime import datetime, timedelta, timezone

from iai_mcp.store import BUDGET_TABLE, RATELIMIT_TABLE, MemoryStore


# D-GUARD defaults
BUDGET_DAILY_USD_DEFAULT = 0.10
BUDGET_MONTHLY_USD_DEFAULT = 3.00
RATELIMIT_COOLDOWN_MIN = 15


class BudgetLedger:
    """LanceDB-backed daily + monthly USD spend tracker (D-GUARD).

    Caps default to $0.10/day and $3.00/month. Both are advisory (no OS-level
    enforcement); caller inspects can_spend() before invoking an LLM API.
    """

    def __init__(
        self,
        store: MemoryStore,
        daily_usd_cap: float = BUDGET_DAILY_USD_DEFAULT,
        monthly_usd_cap: float = BUDGET_MONTHLY_USD_DEFAULT,
    ) -> None:
        self.store = store
        self.daily_cap = float(daily_usd_cap)
        self.monthly_cap = float(monthly_usd_cap)

    # ---- internal helpers

    def _today_utc(self) -> str:
        return datetime.now(timezone.utc).strftime("%Y-%m-%d")

    def _this_month(self) -> str:
        return datetime.now(timezone.utc).strftime("%Y-%m")

    # ---- queries

    def daily_used(self) -> float:
        """Sum of usd_spent rows for today (UTC)."""
        tbl = self.store.db.open_table(BUDGET_TABLE)
        df = tbl.to_pandas()
        if df.empty:
            return 0.0
        today = df[df["date"] == self._today_utc()]
        return float(today["usd_spent"].sum()) if not today.empty else 0.0

    def monthly_used(self) -> float:
        """Sum of usd_spent rows for the current month (UTC)."""
        tbl = self.store.db.open_table(BUDGET_TABLE)
        df = tbl.to_pandas()
        if df.empty:
            return 0.0
        mo = df[df["date"].str.startswith(self._this_month())]
        return float(mo["usd_spent"].sum()) if not mo.empty else 0.0

    def can_spend(self, usd: float) -> tuple[bool, str]:
        """Return (ok, reason). reason is "" on success."""
        daily = self.daily_used()
        if daily + float(usd) > self.daily_cap:
            return (
                False,
                f"daily cap exceeded (used {daily:.4f} + {float(usd):.4f} "
                f"> {self.daily_cap:.4f})",
            )
        monthly = self.monthly_used()
        if monthly + float(usd) > self.monthly_cap:
            return (
                False,
                f"monthly cap exceeded (used {monthly:.4f} + {float(usd):.4f} "
                f"> {self.monthly_cap:.4f})",
            )
        return True, ""

    # ---- writes

    def record_spend(self, usd: float, kind: str = "llm") -> None:
        """Persist a spend event to the ledger."""
        tbl = self.store.db.open_table(BUDGET_TABLE)
        tbl.add(
            [
                {
                    "date": self._today_utc(),
                    "usd_spent": float(usd),
                    "kind": kind,
                    "ts": datetime.now(timezone.utc),
                }
            ]
        )


class RateLimitLedger:
    """LanceDB-backed 429 history with 15-min cooldown (D-GUARD)."""

    def __init__(
        self,
        store: MemoryStore,
        cooldown_minutes: int = RATELIMIT_COOLDOWN_MIN,
    ) -> None:
        self.store = store
        self.cooldown = timedelta(minutes=int(cooldown_minutes))

    def in_cooldown(self) -> bool:
        """True iff the most recent 429 was less than `cooldown_minutes` ago."""
        tbl = self.store.db.open_table(RATELIMIT_TABLE)
        df = tbl.to_pandas()
        if df.empty:
            return False
        latest = df["ts"].max()
        # Pandas timestamp -> python datetime; may be naive on some backends.
        try:
            py = latest.to_pydatetime()
        except AttributeError:
            py = latest
        if py.tzinfo is None:
            py = py.replace(tzinfo=timezone.utc)
        return (datetime.now(timezone.utc) - py) < self.cooldown

    def record_429(self, endpoint: str = "anthropic") -> None:
        """Record a 429 hit; subsequent in_cooldown() calls will see it."""
        tbl = self.store.db.open_table(RATELIMIT_TABLE)
        tbl.add(
            [
                {
                    "ts": datetime.now(timezone.utc),
                    "status_code": 429,
                    "endpoint": endpoint,
                }
            ]
        )


def should_call_llm(
    budget: BudgetLedger,
    rate: RateLimitLedger,
    llm_enabled: bool,
    has_api_key: bool,
    estimated_usd: float = 0.001,
) -> tuple[bool, str]:
    """D-GUARD 7-step ladder.

    Returns (ok, reason). reason is "ok" on success or a short diagnostic
    describing which ladder step blocked the call.

    Ordering is constitutional: downstream plans rely on this exact
    precedence. Changing the order without updating test_should_call_llm_ordering_*
    tests is a spec violation.
    """
    # Step 1: sleep.llm_enabled toggle.
    if not llm_enabled:
        return False, "sleep.llm_enabled=false"
    # Step 2: credentials.
    if not has_api_key:
        return False, "no api key"
    # Step 3 + 4: budget caps (daily, then monthly). can_spend tests both.
    ok, reason = budget.can_spend(estimated_usd)
    if not ok:
        return False, reason
    # Step 5: rate-limit cooldown.
    if rate.in_cooldown():
        return False, "ratelimit cooldown (last 429 < 15min)"
    # Steps 6-7 are caller's responsibility (retry + 429 recording).
    return True, "ok"
|
||||
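The ordering of steps 1-5 is easier to see with ledgers that need no LanceDB store. The sketch below is illustrative only: `InMemoryBudget`, `InMemoryRateLimit`, and `check_ladder` are hypothetical stand-ins written for this example (they are not part of guard.py), and the reason strings are simplified. It assumes the same precedence as `should_call_llm`: toggle, credentials, daily cap, monthly cap, then cooldown.

```python
from datetime import datetime, timedelta, timezone


class InMemoryBudget:
    """Hypothetical stand-in for BudgetLedger: spend rows live in a list."""

    def __init__(self, daily_cap=0.10, monthly_cap=3.00):
        self.daily_cap = daily_cap
        self.monthly_cap = monthly_cap
        self.rows = []  # (YYYY-MM-DD, usd)

    def _used(self, prefix):
        return sum(usd for day, usd in self.rows if day.startswith(prefix))

    def can_spend(self, usd):
        now = datetime.now(timezone.utc)
        if self._used(now.strftime("%Y-%m-%d")) + usd > self.daily_cap:
            return False, "daily cap exceeded"
        if self._used(now.strftime("%Y-%m")) + usd > self.monthly_cap:
            return False, "monthly cap exceeded"
        return True, ""

    def record_spend(self, usd):
        self.rows.append((datetime.now(timezone.utc).strftime("%Y-%m-%d"), usd))


class InMemoryRateLimit:
    """Hypothetical stand-in for RateLimitLedger: remembers the last 429."""

    def __init__(self, cooldown_minutes=15):
        self.cooldown = timedelta(minutes=cooldown_minutes)
        self.last_429 = None

    def in_cooldown(self):
        if self.last_429 is None:
            return False
        return datetime.now(timezone.utc) - self.last_429 < self.cooldown

    def record_429(self):
        self.last_429 = datetime.now(timezone.utc)


def check_ladder(budget, rate, llm_enabled, has_api_key, estimated_usd=0.001):
    """Steps 1-5 in the documented order; returns (ok, reason)."""
    if not llm_enabled:
        return False, "sleep.llm_enabled=false"
    if not has_api_key:
        return False, "no api key"
    ok, reason = budget.can_spend(estimated_usd)
    if not ok:
        return False, reason
    if rate.in_cooldown():
        return False, "ratelimit cooldown"
    return True, "ok"


budget = InMemoryBudget()
rate = InMemoryRateLimit()
first = check_ladder(budget, rate, llm_enabled=True, has_api_key=True)
rate.record_429()                      # step 7 would do this after a 429
after_429 = check_ladder(budget, rate, True, True)
budget.record_spend(0.10)              # daily cap is now fully used
after_cap = check_ladder(budget, rate, True, True)
```

Because the budget check precedes the cooldown check, `after_cap` reports the daily cap even though a cooldown is also active; that precedence is the point the ordering tests mentioned in the docstring are said to protect.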
158
src/iai_mcp/handle.py
Normal file
@@ -0,0 +1,158 @@
"""Compact session handle (Plan 05-06 -- ≤16 raw tok target).

Collapses three pointer fields historically emitted at session-start::

    <id:{8-hex}>              (~8 raw tok)   identity pointer (L0 UUID prefix)
    <sess:{8-hex} pend:{N}>   (~12 raw tok)  brain session handle + pending
    <topic:{label<=8}>        (~8 raw tok)   dominant community hint

into a single opaque pointer::

    <iai:HHHHHHHHHHHHHHHH>    (~6-10 raw tok) 16-hex blake2s digest

The payload bytes are derived deterministically from the three inputs via
blake2s(digest_size=8) -> 64 bits -> 16 hex chars. Deterministic encoding
means identical (id, sess, topic, pending) always yields the same handle,
so the handle can be quoted back to the server and resolved.

Resolution: the module keeps a bounded LRU (`_HANDLE_CACHE`) of the most
recent encodings so the wrapper / recall paths can decode a handle back
into its tuple without re-running the encoder. The cache is process-
local and intentionally small -- session-start emits one handle per new
session, so 256 slots handles the realistic working set with room for
concurrent sessions during sleep-daemon transitions. Misses are a
possible outcome (stale handle from an old process) and callers treat
them as recoverable: the live payload still carries the legacy pointer
fields under ``standard`` / ``deep`` wake_depth for fallback.

Security / invariants:

* The handle carries NO secrets. It is a hash of values Claude already
  saw (L0 UUID prefix, session id prefix, community label, pending
  count). Compromising the handle tells an attacker nothing they could
  not learn from the full session-start payload.
* blake2s is non-reversible. The cache is the only decode path. A
  caller that did not mint the handle cannot invert it -- by design.
* C6 (read-only audit) is untouched: this module writes nothing to the
  store; the cache is pure in-memory state.
"""
from __future__ import annotations

import hashlib
import re
import threading
from collections import OrderedDict
from typing import NamedTuple

# ------------------------------------------------------------------ constants

#: Regex a compact handle must match. Exposed for test assertions and
#: for the decoder's input-validation contract.
COMPACT_HANDLE_RE = re.compile(r"<iai:[0-9a-f]{16}>")

#: Raw-token budget ceiling for the compact handle (the Plan 05-06 target).
#: Enforced by tests/test_handle.py against ``bench/tokens._approx_tokens``.
COMPACT_HANDLE_TOKEN_BUDGET = 16

#: Cache capacity. 256 concurrent handles is plenty for the realistic
#: steady-state: one per session, a handful of overlapping sessions
#: during daemon sleep transitions, plus test churn. Tuning knob, not
#: a policy guarantee.
_CACHE_CAPACITY = 256


# ------------------------------------------------------------------ types


class HandleParts(NamedTuple):
    """Decoded parts of a compact handle (server-side, never serialised)."""

    identity_short: str  # 8 hex of L0 UUID, or "" when unseeded
    session_short: str  # 8 hex of session id, or "-" placeholder
    topic_label: str  # community label (<=8 char) or "none"
    pending: int  # first_turn_pending count (>= 0)


# ------------------------------------------------------------------ cache


_HANDLE_CACHE: "OrderedDict[str, HandleParts]" = OrderedDict()
_CACHE_LOCK = threading.Lock()


def _remember(handle: str, parts: HandleParts) -> None:
    """Record handle -> parts with LRU eviction."""
    with _CACHE_LOCK:
        if handle in _HANDLE_CACHE:
            _HANDLE_CACHE.move_to_end(handle)
            return
        _HANDLE_CACHE[handle] = parts
        while len(_HANDLE_CACHE) > _CACHE_CAPACITY:
            _HANDLE_CACHE.popitem(last=False)


# ------------------------------------------------------------------ public API


def encode_compact_handle(
    identity_short: str,
    session_short: str,
    topic_label: str,
    pending: int,
) -> str:
    """Derive the ``<iai:HHHHHHHHHHHHHHHH>`` handle from the three pointer
    fields plus the pending count.

    The output is deterministic: equal inputs always yield equal handles.
    Inputs are normalised before hashing: each field is coerced to ``str``,
    empty values fall back to their placeholders (``""``, ``"-"``,
    ``"none"``), and ``pending`` is clamped to ``[0, 999]``. Whitespace is
    hashed as given, so callers should pass already-trimmed values.

    The returned handle is also inserted into the in-memory decode cache
    so ``decode_compact_handle`` can reverse it within the same process.
    """
    id_s = str(identity_short or "")
    sess_s = str(session_short or "-")
    topic_s = str(topic_label or "none")
    # Coerce pending to a bounded non-negative int; negatives or huge values
    # are clamped to the [0, 999] window the emit site actually produces.
    try:
        pend_i = max(0, min(999, int(pending)))
    except (TypeError, ValueError):
        pend_i = 0

    # Fields are joined with the ASCII unit separator (0x1f) so that
    # ("ab", "c") and ("a", "bc") cannot hash to the same byte stream.
    h = hashlib.blake2s(digest_size=8)
    h.update(id_s.encode("utf-8"))
    h.update(b"\x1f")
    h.update(sess_s.encode("utf-8"))
    h.update(b"\x1f")
    h.update(topic_s.encode("utf-8"))
    h.update(b"\x1f")
    h.update(str(pend_i).encode("utf-8"))
    digest = h.hexdigest()  # 16 hex chars

    handle = f"<iai:{digest}>"
    _remember(handle, HandleParts(id_s, sess_s, topic_s, pend_i))
    return handle


def decode_compact_handle(handle: str) -> HandleParts | None:
    """Return the parts for a handle minted earlier in this process.

    Returns ``None`` when the input is malformed or the handle is no
    longer in the LRU (cold cache / different process). Callers treat a
    miss as a soft error -- the legacy ``identity_pointer`` /
    ``brain_handle`` / ``topic_cluster_hint`` fields remain available in
    ``standard`` / ``deep`` wake_depth for fallback resolution.
    """
    if not isinstance(handle, str) or not COMPACT_HANDLE_RE.fullmatch(handle):
        return None
    with _CACHE_LOCK:
        parts = _HANDLE_CACHE.get(handle)
        if parts is not None:
            _HANDLE_CACHE.move_to_end(handle)
        return parts


def _reset_cache_for_tests() -> None:
    """Test-only: clear the LRU. Production code must never call this."""
    with _CACHE_LOCK:
        _HANDLE_CACHE.clear()
|
||||
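The digest construction can be tried in isolation. The sketch below is a simplified illustration, not the module's API: `toy_handle` and `HANDLE_RE` are hypothetical names, and the placeholder defaults, the clamping of `pending`, and the LRU cache are deliberately left out. It only assumes what the code above shows: blake2s with `digest_size=8` over the four fields joined by the `0x1f` separator.

```python
import hashlib
import re

# Same shape as the handle regex shown above (illustrative copy).
HANDLE_RE = re.compile(r"<iai:[0-9a-f]{16}>")


def toy_handle(identity, session, topic, pending):
    """Hash the four fields, joined by the ASCII unit separator, to 64 bits."""
    joined = "\x1f".join([identity, session, topic, str(pending)])
    digest = hashlib.blake2s(joined.encode("utf-8"), digest_size=8).hexdigest()
    return f"<iai:{digest}>"


h1 = toy_handle("0123abcd", "89ef4567", "memory", 3)
h2 = toy_handle("0123abcd", "89ef4567", "memory", 3)   # same inputs
h3 = toy_handle("0123abcd", "89ef4567", "memory", 4)   # pending differs

# Without a separator these two field splits would feed identical bytes
# to the hash; with it they are distinct inputs.
split_a = toy_handle("ab", "c", "t", 0)
split_b = toy_handle("a", "bc", "t", 0)
```

Feeding the joined string in one call is byte-for-byte equivalent to the sequence of `update()` calls in `encode_compact_handle`, since blake2s hashes the concatenation of everything it is given.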
333
src/iai_mcp/heartbeat_scanner.py
Normal file
@@ -0,0 +1,333 @@
"""Phase 10.4 L4 -- daemon-side heartbeat scanner (per-wrapper, PID-scoped).

Reads ``~/.iai-mcp/wrappers/heartbeat-<pid>-<uuid>.json`` files written by
each MCP wrapper instance, validates freshness (``now - last_refresh <= M``)
AND PID liveness (``os.kill(pid, 0)``), and aggregates presence so the daemon's
state machine can decide WAKE vs BEDTIME.

Constraints (carried from CONTEXT 10.4):
- Idle CPU near zero -- scanner runs on lifecycle TICK (every 30s), not faster.
- Scanner code is reentrant: ``scan()`` MUST be safe to call concurrently with
  a wrapper writing a heartbeat file (atomic rename pattern + JSON-parse-fail
  fallback to file mtime).
- No new third-party dependencies -- stdlib only.
- macOS-only PID semantics carried through (Linux subset works the same; only
  Windows is unsupported, which matches the phase's macOS-only stance).
- This module is STANDALONE -- daemon main-loop integration lands in Phase 10.5.

Heartbeat file schema (written by wrapper, read here)::

    {
        "pid": 12345,
        "uuid": "01HZQ...",
        "started_at": "2026-05-02T15:00:00Z",
        "last_refresh": "2026-05-02T15:14:30Z",
        "wrapper_version": "1.0.0",
        "schema_version": 1
    }

Status semantics:
- FRESH: ``last_refresh`` within ``M`` seconds AND PID alive.
- STALE: ``last_refresh`` older than ``M`` seconds (regardless of PID).
- ORPHAN: PID is dead (``ProcessLookupError`` from ``kill(pid, 0)``) and the
  file's freshness window has not yet expired. Treated as not-active.

A file that fails JSON parse falls back to its filesystem mtime so a torn
half-written write does not silently mask presence.

Validates: WAKE-07.
"""
from __future__ import annotations

import json
import os
from dataclasses import dataclass
from datetime import datetime, timezone
from enum import Enum
from pathlib import Path


# Module-level constants -------------------------------------------------------

#: Default refresh staleness threshold (seconds). A heartbeat older than this
#: is STALE regardless of PID liveness. The wrapper SHOULD refresh every
#: ``REFRESH_INTERVAL_SEC``; three missed refreshes (~90 s) trip staleness.
DEFAULT_STALE_THRESHOLD_SEC = 90

#: Window for the "no fresh activity in last 30 minutes" predicate consumed
#: by the L6 ``IdleDetector.sleep_eligible`` rule.
IDLE_WINDOW_SEC = 30 * 60

#: Filename glob used to enumerate heartbeat files. Matches the
#: ``heartbeat-<pid>-<uuid>.json`` convention from CONTEXT 10.4.
_HEARTBEAT_GLOB = "heartbeat-*.json"


class HeartbeatStatus(Enum):
    """Tri-state classification of a single heartbeat file."""

    FRESH = "fresh"
    STALE = "stale"
    ORPHAN = "orphan"


@dataclass
class HeartbeatEntry:
    """One scanned heartbeat file with its derived status.

    Attributes:
        path: Absolute path of the heartbeat file on disk.
        pid: Wrapper PID parsed from the file's payload.
        uuid: Wrapper UUID parsed from the file's payload (used as a stable
            tie-breaker when the same PID is reused after wrapper restart).
        last_refresh: Timezone-aware UTC datetime parsed from
            ``last_refresh``; falls back to file mtime if JSON parse fails.
        status: One of ``HeartbeatStatus.{FRESH, STALE, ORPHAN}``.
    """

    path: Path
    pid: int
    uuid: str
    last_refresh: datetime
    status: HeartbeatStatus


# PID liveness ----------------------------------------------------------------


def _is_pid_alive(pid: int) -> bool:
    """Return True iff ``pid`` exists in the kernel's process table.

    Uses the ``kill(pid, 0)`` POSIX trick -- sends no signal but raises
    ``ProcessLookupError`` (ESRCH) when the PID has been reaped. A
    ``PermissionError`` (EPERM) means the process exists but the current
    user cannot signal it -- for liveness purposes we count that as alive.
    A negative or zero ``pid`` is treated as dead (those values would map
    to ``kill(self_pgrp, 0)`` semantics which is not what we want).
    """
    if pid <= 0:
        return False
    try:
        os.kill(pid, 0)
    except ProcessLookupError:
        return False
    except PermissionError:
        return True
    return True


# Atomic-read-with-mtime-fallback helper --------------------------------------


def _parse_heartbeat_file(path: Path) -> tuple[int, str, datetime] | None:
    """Best-effort parse of a single heartbeat file.

    Returns ``(pid, uuid, last_refresh_utc)`` on success or ``None`` if the
    file disappeared mid-read (race with wrapper rotation) or neither its
    content nor its filename yields the minimum schema.

    A JSON-parse failure falls back to the file's mtime so that a torn
    write produced by a wrapper crash mid-rename is treated as STALE-on-
    age rather than silently dropped -- matches the "reentrant + safe under
    concurrent writers" requirement in PLAN 10.4-01 Task 1.1.
    """
    try:
        raw = path.read_text(encoding="utf-8")
    except UnicodeDecodeError:
        # Torn write that is not valid UTF-8: recover like a JSON error.
        return _fallback_parse_from_filename(path)
    except OSError:
        # Covers FileNotFoundError (file vanished mid-read) and other I/O errors.
        return None

    try:
        payload = json.loads(raw)
    except json.JSONDecodeError:
        # Torn write -- fall back to filename PID + filesystem mtime so we
        # at least get a STALE classification rather than dropping the file.
        return _fallback_parse_from_filename(path)

    if not isinstance(payload, dict):
        # Valid JSON that is not an object (e.g. a bare number) has no fields.
        return _fallback_parse_from_filename(path)

    pid = payload.get("pid")
    uuid_str = payload.get("uuid", "")
    last_refresh_raw = payload.get("last_refresh")

    if not isinstance(pid, int) or not isinstance(uuid_str, str):
        return _fallback_parse_from_filename(path)
    if not isinstance(last_refresh_raw, str):
        return _fallback_parse_from_filename(path)

    try:
        # ``2026-05-02T15:14:30Z`` -- Python 3.11+ accepts the trailing Z;
        # for safety we normalize to ``+00:00`` for older 3.10 compatibility.
        normalized = last_refresh_raw.replace("Z", "+00:00")
        last_refresh = datetime.fromisoformat(normalized)
    except ValueError:
        return _fallback_parse_from_filename(path)

    if last_refresh.tzinfo is None:
        last_refresh = last_refresh.replace(tzinfo=timezone.utc)
    else:
        last_refresh = last_refresh.astimezone(timezone.utc)

    return pid, uuid_str, last_refresh


def _fallback_parse_from_filename(path: Path) -> tuple[int, str, datetime] | None:
    """Recover ``(pid, uuid, mtime_utc)`` from filename + filesystem stat.

    Filename convention: ``heartbeat-<pid>-<uuid>.json``. We split on ``-``
    at most twice: once after ``heartbeat`` and once after the PID, keeping
    the remainder intact as the UUID (UUIDs may contain dashes).
    """
    name = path.stem  # heartbeat-<pid>-<uuid>
    parts = name.split("-", 2)
    if len(parts) != 3 or parts[0] != "heartbeat":
        return None
    try:
        pid = int(parts[1])
    except ValueError:
        return None
    uuid_str = parts[2]
    try:
        mtime = path.stat().st_mtime
    except OSError:
        # FileNotFoundError (file vanished) or any other stat failure.
        return None
    return pid, uuid_str, datetime.fromtimestamp(mtime, tz=timezone.utc)


# HeartbeatScanner -------------------------------------------------------------


class HeartbeatScanner:
    """Aggregates per-wrapper heartbeat files into a daemon-side presence signal.

    Standalone module: Phase 10.5 wires this into the daemon main-loop TICK
    to dispatch HEARTBEAT_REFRESH / IDLE state events.
    """

    def __init__(
        self,
        wrappers_dir: Path,
        stale_threshold_sec: int = DEFAULT_STALE_THRESHOLD_SEC,
    ) -> None:
        self._wrappers_dir = wrappers_dir
        self._stale_threshold_sec = stale_threshold_sec
        self._last_scan: list[HeartbeatEntry] = []

    # ----- Scan / classify -----------------------------------------------

    def scan(self) -> list[HeartbeatEntry]:
        """Read all heartbeat files, classify each, and return entries.

        Reentrant: tolerates concurrent writes by ignoring files that vanish
        mid-read and falling back to mtime when JSON is half-written.

        Empty / missing wrappers dir → empty list (the daemon hasn't seen
        any wrappers yet, which is a valid steady state on a fresh install).
        """
        entries: list[HeartbeatEntry] = []
        if not self._wrappers_dir.exists():
            self._last_scan = entries
            return entries

        try:
            candidates = list(self._wrappers_dir.glob(_HEARTBEAT_GLOB))
        except OSError:
            self._last_scan = entries
            return entries

        now = datetime.now(timezone.utc)
        for path in candidates:
            parsed = _parse_heartbeat_file(path)
            if parsed is None:
                # File vanished mid-glob (cleanup race) -- skip silently.
                continue
            pid, uuid_str, last_refresh = parsed

            age_sec = (now - last_refresh).total_seconds()
            is_alive = _is_pid_alive(pid)

            if age_sec > self._stale_threshold_sec:
                # Stale wins over orphan -- the file is too old to trust
                # regardless of whether its PID happens to still be live.
                status = HeartbeatStatus.STALE
            elif not is_alive:
                status = HeartbeatStatus.ORPHAN
            else:
                status = HeartbeatStatus.FRESH

            entries.append(
                HeartbeatEntry(
                    path=path,
                    pid=pid,
                    uuid=uuid_str,
                    last_refresh=last_refresh,
                    status=status,
                )
            )

        self._last_scan = entries
        return entries

    # ----- Aggregations consumed by the state machine --------------------

    def fresh_count(self) -> int:
        """Number of heartbeats classified as FRESH on the most recent scan.

        Re-runs ``scan()`` so callers don't have to remember to invoke it
        first; the cost is one filesystem walk per call which is negligible
        at TICK cadence (every 30 s).
        """
        return sum(1 for e in self.scan() if e.status is HeartbeatStatus.FRESH)

    def is_active(self) -> bool:
        """True iff at least one wrapper is currently FRESH.

        This is the primary signal the state machine uses to dispatch
        HEARTBEAT_REFRESH (→ WAKE) vs. begin the IDLE-eligibility check.
        """
        return self.fresh_count() >= 1

    def heartbeat_idle_30min(self) -> bool:
        """True iff no FRESH heartbeats existed in the last 30 minutes.

        Consumed by ``IdleDetector.sleep_eligible`` as one of the three
        disjuncts that gate L6 sleep. "No FRESH in window" is implemented
        as: scan now, and if zero entries are FRESH, the window is empty.
        STALE / ORPHAN entries imply the wrapper has not refreshed for at
        least the staleness threshold (90 s by default), so a single scan
        suffices -- we don't keep a history buffer in this module.
        """
        # Fresh count == 0 means no wrapper is currently active. Combined
        # with the 30-min wall-clock window enforced by the daemon's TICK
        # rhythm and the L6 idle predicate's hardware backstop (HIDIdleTime
        # ≥ 1800 s), this gives the same observable behavior as a separate
        # 30-minute history without keeping in-memory state.
        return self.fresh_count() == 0

    # ----- Cleanup -------------------------------------------------------

    def cleanup_stale_orphans(self) -> int:
        """Delete heartbeat files classified STALE or ORPHAN. Returns count deleted.

        Best-effort: a delete that races with another process unlinking the
        same file (``FileNotFoundError``) is counted as a successful
        cleanup; any other ``OSError`` is swallowed so a single problematic
        file cannot break the rest of the cleanup pass.
        """
        deleted = 0
        for entry in self.scan():
            if entry.status is HeartbeatStatus.FRESH:
                continue
            try:
                entry.path.unlink()
                deleted += 1
            except FileNotFoundError:
                # Already unlinked (concurrent wrapper rotation / sibling
                # daemon scan). Count as cleaned -- the file is gone.
                deleted += 1
            except OSError:
                # Permission / FS error on a single file: skip it, keep
                # going. The doctor row will surface persistent
                # cleanup failures via "n=X stale" delta on next run.
                continue
        return deleted
|
||||
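The scanner relies on wrappers publishing heartbeat files with an atomic rename, but the writer side is not part of this module. The sketch below shows one way such a writer could look; it is an assumption, not the project's wrapper code. `write_heartbeat` is a hypothetical helper, and the `.tmp` suffix is chosen here so that the in-progress file does not match the scanner's `heartbeat-*.json` glob. The payload fields follow the schema in the module docstring.

```python
import json
import os
import tempfile
from datetime import datetime, timezone
from pathlib import Path


def write_heartbeat(wrappers_dir, pid, uuid, started_at=""):
    """Write heartbeat-<pid>-<uuid>.json via temp file + atomic rename."""
    wrappers_dir.mkdir(parents=True, exist_ok=True)
    now = datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ")
    payload = {
        "pid": pid,
        "uuid": uuid,
        "started_at": started_at or now,
        "last_refresh": now,
        "wrapper_version": "1.0.0",
        "schema_version": 1,
    }
    final = wrappers_dir / f"heartbeat-{pid}-{uuid}.json"
    # The temp name ends in ".tmp", so a concurrent glob for
    # "heartbeat-*.json" never sees a half-written file.
    tmp = final.with_name(final.name + ".tmp")
    tmp.write_text(json.dumps(payload), encoding="utf-8")
    os.replace(tmp, final)  # atomic on POSIX when both paths share a filesystem
    return final


# Demo in a throwaway directory.
with tempfile.TemporaryDirectory() as d:
    path = write_heartbeat(Path(d), os.getpid(), "demo-uuid")
    payload = json.loads(path.read_text(encoding="utf-8"))
    visible = sorted(p.name for p in Path(d).glob("heartbeat-*.json"))
    leftovers = sorted(p.name for p in Path(d).glob("*.tmp"))
```

A reader that opens the final path sees either the previous complete file or the new complete file, which is what lets `scan()` treat a JSON parse failure as an exceptional torn write instead of a routine race.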
122
src/iai_mcp/hebbian_structure.py
Normal file
@@ -0,0 +1,122 @@
"""Plan 03-01 CONN-05 D-TEM-04: structure-edge Hebbian LTP.

Mirrors content-edge Hebbian (retrieve.reinforce_edges -> store.boost_edges
with edge_type="hebbian"). Co-retrieval of two records whose structure_hv
hypervectors are sufficiently similar (Hamming similarity >= 0.7 by default)
strengthens a "hebbian_structure" edge between them. FSRS decay on the new
edge type is identical to the content-edge formula in sleep._decay_edges.

Constitutional fit:
- D-TEM-04: Hebbian LTP on structure edges. Autopoiesis applied to structure;
  the brain reinforces structural co-occurrence the same way it reinforces
  content co-occurrence in Phase 1.
- Flat layout (PATTERNS.md): no `connectome/` subpackage. Module path is
  src/iai_mcp/hebbian_structure.py.
- Same shape as retrieve.reinforce_edges -- pairwise iterate, compute
  similarity, call store.boost_edges with edge_type="hebbian_structure".

Public API:
- STRUCTURAL_SIMILARITY_THRESHOLD: pairs at or above this fire LTP (default 0.7).
- structural_similarity(a, b): 1 - hamming_distance(a, b) / D in [0, 1].
- strengthen_structure_edge(store, src_id, dst_id, gain=1.0): boost the
  structure edge between two records.
- co_retrieval_trigger(store, hits): pairwise scan of co-retrieved hits;
  fire strengthen_structure_edge for every pair at or above the threshold.
"""
from __future__ import annotations

from itertools import combinations
from uuid import UUID

import numpy as np

from iai_mcp.store import MemoryStore
from iai_mcp.types import STRUCTURE_HV_DIM


# D-TEM-04 default trigger (per plan Task 2b behavior contract):
# co-retrieval LTP fires when structural similarity >= 0.7 (Hamming distance
# fraction <= 0.3). Tunable later via the profile registry if a knob is added.
STRUCTURAL_SIMILARITY_THRESHOLD: float = 0.7


def structural_similarity(a: bytes, b: bytes) -> float:
    """Return 1 - hamming_distance(a, b) / STRUCTURE_HV_DIM in [0.0, 1.0].

    Empty / unequal-length / corrupt inputs return 0.0 (graceful degradation).
    """
    if not a or not b or len(a) != len(b):
        return 0.0
    aa = np.frombuffer(a, dtype=np.uint8)
    bb = np.frombuffer(b, dtype=np.uint8)
    # popcount of XOR -> hamming distance in bits.
    xor = np.bitwise_xor(aa, bb)
    # numpy >= 2.x has np.bitwise_count; fall back to unpackbits sum on older.
    try:
        ham_bits = int(np.bitwise_count(xor).sum())
    except AttributeError:
        ham_bits = int(np.unpackbits(xor).sum())
    similarity = 1.0 - (ham_bits / STRUCTURE_HV_DIM)
    # Clamp so inputs longer than STRUCTURE_HV_DIM bits (corrupt records)
    # cannot push the result outside the documented [0.0, 1.0] range.
    return max(0.0, min(1.0, similarity))


def strengthen_structure_edge(
    store: MemoryStore,
    src_id: UUID,
    dst_id: UUID,
    gain: float = 1.0,
) -> dict[tuple[str, str], float]:
    """Plan 03-01 D-TEM-04: structure-edge LTP via store.boost_edges.

    Returns the new weights dict (same shape as retrieve.reinforce_edges'
    underlying call). Mirrors content-edge LTP shape so downstream code
    (events, audit, decay sweep) treats structure edges identically.
    """
    return store.boost_edges(
        [(src_id, dst_id)],
        delta=float(gain),
        edge_type="hebbian_structure",
    )


def co_retrieval_trigger(
|
||||
store: MemoryStore,
|
||||
hits,
|
||||
*,
|
||||
threshold: float = STRUCTURAL_SIMILARITY_THRESHOLD,
|
||||
gain: float = 1.0,
|
||||
) -> int:
|
||||
"""Pairwise scan of co-retrieved hits; fire strengthen_structure_edge
|
||||
for each pair whose structural_similarity >= threshold.
|
||||
|
||||
`hits` may be a list of MemoryHit (record_id only -- structure_hv is
|
||||
fetched lazily from store.get) OR a list of MemoryRecord (faster path,
|
||||
structure_hv read directly).
|
||||
|
||||
Returns the number of structure edges strengthened. A structurally-
|
||||
isolated co-retrieved set returns 0 -- this is expected (means no two
|
||||
hits shared structure to reinforce).
|
||||
"""
|
||||
# Materialise (id, structure_hv) tuples once.
|
||||
pairs: list[tuple[UUID, bytes]] = []
|
||||
for h in hits:
|
||||
rec_id = getattr(h, "record_id", None) or getattr(h, "id", None)
|
||||
if rec_id is None:
|
||||
continue
|
||||
hv = getattr(h, "structure_hv", None)
|
||||
if hv is None:
|
||||
rec = store.get(rec_id)
|
||||
if rec is None:
|
||||
continue
|
||||
hv = rec.structure_hv
|
||||
pairs.append((rec_id, hv or b""))
|
||||
|
||||
fired = 0
|
||||
for (a_id, a_hv), (b_id, b_hv) in combinations(pairs, 2):
|
||||
if structural_similarity(a_hv, b_hv) >= threshold:
|
||||
try:
|
||||
strengthen_structure_edge(store, a_id, b_id, gain=gain)
|
||||
fired += 1
|
||||
except Exception:
|
||||
# Diagnostic only -- never block the pipeline on edge failure.
|
||||
continue
|
||||
return fired
|
||||
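The similarity rule in the module above can be sketched without NumPy. This is a minimal illustration, not the project's implementation: `HV_DIM` is a hypothetical 64-bit width standing in for `STRUCTURE_HV_DIM` (whose real value is not shown here), and `hv_similarity` is an illustrative name.

```python
HV_DIM = 64  # hypothetical width in bits; stands in for STRUCTURE_HV_DIM


def hv_similarity(a: bytes, b: bytes, dim: int = HV_DIM) -> float:
    """Return 1 - hamming(a, b) / dim, or 0.0 for empty or mismatched inputs."""
    if not a or not b or len(a) != len(b):
        return 0.0
    # XOR leaves a 1 bit wherever the two vectors disagree.
    xor = int.from_bytes(a, "big") ^ int.from_bytes(b, "big")
    ham_bits = bin(xor).count("1")
    return 1.0 - ham_bits / dim


base = bytes(8)                      # 64 zero bits
near = bytes([0xFF]) + bytes(7)      # differs from base in 8 of 64 bits
far = bytes([0xFF] * 4) + bytes(4)   # differs from base in 32 of 64 bits

sim_near = hv_similarity(base, near)  # 1 - 8/64
sim_far = hv_similarity(base, far)    # 1 - 32/64
```

With the default threshold of 0.7, the first pair would qualify for a structure-edge boost and the second would not.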
324
src/iai_mcp/hippea_cascade.py
Normal file
|
|
@@ -0,0 +1,324 @@
|
|||
"""TOK-14 / D5-05: HIPPEA activation-cascade prefetch.
|
||||
|
||||
Daemon receives `session_open` over the Phase-4 unix socket and this module
|
||||
computes precision-weighted salience over 7 days of `session_started` +
|
||||
`retrieval_used` events, selects top-K communities, and pre-warms their
|
||||
top-N records into a process-local LRU cache (cachetools.TTLCache) guarded
|
||||
by an asyncio.Lock.
|
||||
|
||||
Operationalization (Van de Cruys 2014 HIPPEA):
|
||||
f(c) = count(session_gated_to_community=c, last_7_days) / total_sessions_7d
|
||||
p(c) = 1 / |communities|
|
||||
PE(c) = |f(c) - p(c)|
|
||||
sigma2(c) = Var[day_i_count(c) : i in 7 days]
|
||||
w(c) = 1 / (sigma2(c) + 0.01)
|
||||
S(c) = w(c) * PE(c)
|
||||
top_K = argmax_K S(c) # K=3 default
|
||||
warm = union over c in top_K of top_N_by_centrality(records(c))
|
||||
|
||||
Cold-fallback (<3 sessions in 7-day window): return
|
||||
assignment.top_communities[:top_k] without variance weighting.
|
||||
|
||||
Constitutional invariants (asserted by grep guards in tests/test_hippea_cascade.py):
|
||||
- C1 HUMAN-FIRST: cascade task yields on shutdown within 5s.
|
||||
- C3 ZERO API COST: pure local -- no paid-API env var, no Anthropic SDK import.
|
||||
- C6 READ-ONLY: no store.insert / store.append_provenance / store.update calls.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from collections import Counter, defaultdict
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Any, Iterable
|
||||
from uuid import UUID
|
||||
|
||||
from cachetools import TTLCache
|
||||
|
||||
|
||||
# ---------------------------------------------------------- process-local LRU
|
||||
|
||||
# D5-05 constants:
|
||||
# maxsize=200, ttl=1800 (30 min). These match the recommendations and
|
||||
# keep the cache small enough to fit in MCP core RAM headroom.
|
||||
_WARM_MAXSIZE = 200
|
||||
_WARM_TTL_SECONDS = 1800
|
||||
|
||||
|
||||
_warm_lru: TTLCache[UUID, Any] = TTLCache(maxsize=_WARM_MAXSIZE, ttl=_WARM_TTL_SECONDS)
|
||||
_warm_lru_lock = asyncio.Lock()
|
||||
|
||||
|
||||
def snapshot_warm_ids() -> list[UUID]:
|
||||
"""Lock-free snapshot of warm record IDs.
|
||||
|
||||
CPython GIL makes `list(dict.keys())` atomic for simple types. A concurrent
|
||||
mutator may race and invalidate the iterator -- we catch RuntimeError and
|
||||
return an empty list rather than propagating the rare race.
|
||||
"""
|
||||
try:
|
||||
return list(_warm_lru.keys())
|
||||
except RuntimeError:
|
||||
return []
|
||||
|
||||
|
||||
def get_warm_record(rid: UUID) -> Any | None:
|
||||
"""Return the warmed record or None. Silent on miss / structural error."""
|
||||
try:
|
||||
return _warm_lru.get(rid)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
async def warm_records(record_ids: Iterable[UUID], store: Any) -> int:
|
||||
"""Load records into the LRU. Returns count inserted.
|
||||
|
||||
C6: READ-ONLY against the store -- only `store.get(rid)` is called.
|
||||
Any store-get exception is swallowed per-record so a single bad id
|
||||
cannot poison the warmer.
|
||||
"""
|
||||
inserted = 0
|
||||
async with _warm_lru_lock:
|
||||
for rid in record_ids:
|
||||
try:
|
||||
rec = store.get(rid)
|
||||
if rec is not None:
|
||||
_warm_lru[rid] = rec
|
||||
inserted += 1
|
||||
except Exception:
|
||||
continue
|
||||
return inserted
|
||||
|
||||
|
||||
# ---------------------------------------------------------- salience formula
|
||||
|
||||
|
||||
def compute_salient_communities(
|
||||
store: Any,
|
||||
assignment: Any,
|
||||
*,
|
||||
lookback_days: int = 7,
|
||||
top_k: int = 3,
|
||||
) -> list[UUID]:
|
||||
"""Return top-K community UUIDs by HIPPEA salience S(c) = w(c) * PE(c).
|
||||
|
||||
Cold fallback (<3 sessions in window): return
|
||||
`assignment.top_communities[:top_k]` with no variance weighting.
|
||||
"""
|
||||
# Lazy import to keep the module's surface clean of store-mutating paths.
|
||||
from iai_mcp.events import query_events
|
||||
|
||||
since = datetime.now(timezone.utc) - timedelta(days=lookback_days)
|
||||
try:
|
||||
sessions = query_events(store, kind="session_started", since=since, limit=10000)
|
||||
except Exception:
|
||||
sessions = []
|
||||
|
||||
if len(sessions) < 3:
|
||||
# D5-05 cold fallback: simplified formula drops the variance term.
|
||||
# Use the existing Leiden top-communities as a reasonable default.
|
||||
return list(getattr(assignment, "top_communities", []))[:top_k]
|
||||
|
||||
try:
|
||||
retrievals = query_events(
|
||||
store, kind="retrieval_used", since=since, limit=50000,
|
||||
)
|
||||
except Exception:
|
||||
retrievals = []
|
||||
|
||||
# session_id -> dominant community for that session (most retrieved).
|
||||
per_session_counter: dict[str, Counter] = defaultdict(Counter)
|
||||
for ev in retrievals:
|
||||
data = ev.get("data", {}) if isinstance(ev, dict) else {}
|
||||
sid = data.get("session_id") or ev.get("session_id", "")
|
||||
cid = data.get("community_id") or data.get("community", "")
|
||||
if sid and cid:
|
||||
per_session_counter[sid][str(cid)] += 1
|
||||
session_comm: dict[str, str] = {
|
||||
sid: ctr.most_common(1)[0][0]
|
||||
for sid, ctr in per_session_counter.items()
|
||||
if ctr
|
||||
}
|
||||
|
||||
total_sessions = len(sessions)
|
||||
community_pool: list[UUID] = list(getattr(assignment, "top_communities", []) or [])
|
||||
# Also admit any community seen in retrievals during the window even if it
|
||||
# isn't in top_communities -- the salience formula evaluates all observed
|
||||
# communities, not just the Leiden-top.
|
||||
seen: set[str] = set(session_comm.values())
|
||||
for cid in (str(c) for c in community_pool):
|
||||
seen.add(cid)
|
||||
if not seen:
|
||||
return []
|
||||
p = 1.0 / len(seen)
|
||||
|
||||
# f(c) across the window.
|
||||
freq: Counter = Counter(session_comm.values())
|
||||
|
||||
# Day-bucketed counts (0 = today, lookback_days-1 = oldest).
|
||||
day_buckets: dict[str, list[int]] = defaultdict(lambda: [0] * lookback_days)
|
||||
now = datetime.now(timezone.utc)
|
||||
for sev in sessions:
|
||||
ts = sev.get("ts") if isinstance(sev, dict) else None
|
||||
try:
|
||||
if isinstance(ts, str):
|
||||
t = datetime.fromisoformat(ts.replace("Z", "+00:00"))
|
||||
elif hasattr(ts, "to_pydatetime"):
|
||||
t = ts.to_pydatetime()
|
||||
if t.tzinfo is None:
|
||||
t = t.replace(tzinfo=timezone.utc)
|
||||
elif hasattr(ts, "tzinfo") and ts is not None:
|
||||
t = ts
|
||||
if t.tzinfo is None:
|
||||
t = t.replace(tzinfo=timezone.utc)
|
||||
else:
|
||||
t = now
|
||||
delta = (now - t).days
|
||||
day_idx = max(0, min(lookback_days - 1, delta))
|
||||
except Exception:
|
||||
day_idx = 0
|
||||
data = sev.get("data", {}) if isinstance(sev, dict) else {}
|
||||
sid = data.get("session_id") or sev.get("session_id", "")
|
||||
c = session_comm.get(sid)
|
||||
if c:
|
||||
day_buckets[c][day_idx] += 1
|
||||
|
||||
# Compute S(c) per community.
|
||||
scores: dict[str, float] = {}
|
||||
for c in seen:
|
||||
f_c = freq.get(c, 0) / max(1, total_sessions)
|
||||
pe = abs(f_c - p)
|
||||
bucket = day_buckets.get(c, [0] * lookback_days)
|
||||
n = len(bucket) or 1
|
||||
mean = sum(bucket) / n
|
||||
variance = sum((x - mean) ** 2 for x in bucket) / n
|
||||
w = 1.0 / (variance + 0.01)
|
||||
scores[c] = w * pe
|
||||
|
||||
ranked = sorted(
|
||||
scores.items(),
|
||||
key=lambda kv: (-kv[1], kv[0]), # deterministic tiebreak by cid str
|
||||
)
|
||||
top: list[UUID] = []
|
||||
for cid_str, _ in ranked:
|
||||
try:
|
||||
top.append(UUID(cid_str))
|
||||
except (TypeError, ValueError):
|
||||
continue
|
||||
if len(top) >= top_k:
|
||||
break
|
||||
return top
|
||||
|
||||
|
||||
# ---------------------------------------------------------- centrality helper
|
||||
|
||||
|
||||
def _top_n_records_by_centrality(
|
||||
store: Any, assignment: Any, community_id: UUID, n: int,
|
||||
) -> list[UUID]:
|
||||
"""READ-ONLY: return top-N record ids for `community_id` by centrality.
|
||||
|
||||
Uses `assignment.mid_regions[community_id]` to enumerate member records,
|
||||
then reads each record's `centrality` field via store.get and sorts by
|
||||
descending centrality. Falls back to insertion order if centrality is
|
||||
missing or non-comparable.
|
||||
"""
|
||||
mid_regions = getattr(assignment, "mid_regions", {}) or {}
|
||||
member_ids = list(mid_regions.get(community_id) or [])
|
||||
if not member_ids:
|
||||
return []
|
||||
scored: list[tuple[float, UUID]] = []
|
||||
for rid in member_ids:
|
||||
try:
|
||||
rec = store.get(rid)
|
||||
except Exception:
|
||||
rec = None
|
||||
if rec is None:
|
||||
continue
|
||||
try:
|
||||
centrality = float(getattr(rec, "centrality", 0.0) or 0.0)
|
||||
except (TypeError, ValueError):
|
||||
centrality = 0.0
|
||||
scored.append((centrality, rid))
|
||||
scored.sort(key=lambda kv: (-kv[0], str(kv[1])))
|
||||
return [rid for _c, rid in scored[:n]]
|
||||
|
||||
|
||||
# ---------------------------------------------------------- sync core-side helper
|
||||
|
||||
|
||||
def compute_core_side_warm_snapshot(
|
||||
store: Any,
|
||||
assignment: Any,
|
||||
*,
|
||||
top_k: int = 3,
|
||||
per_community: int | None = None,
|
||||
max_records: int = 50,
|
||||
) -> list[UUID]:
|
||||
"""Synchronous counterpart to :func:`run_cascade`'s compute path.
|
||||
|
||||
The MCP core runs in a different process from the sleep
|
||||
daemon, so the daemon's ``_warm_lru`` is invisible to core --
|
||||
``snapshot_warm_ids()`` returns ``[]`` in the core on every fresh
|
||||
process boot. This helper lets the core compute its OWN cascade
|
||||
inline (no asyncio dependency) and write the warmed record ids into
|
||||
its own process-local LRU. Duplicates daemon work by design; that
|
||||
is the price of not having shared-memory IPC between the two
|
||||
processes.
|
||||
|
||||
Reuses :func:`compute_salient_communities` (already sync) and
|
||||
:func:`_top_n_records_by_centrality` (sync) -- no new salience
|
||||
formula; only the orchestration that :func:`run_cascade` would do
|
||||
asynchronously.
|
||||
|
||||
READ-ONLY against store (C6 invariant); no async I/O; no paid-API
|
||||
import (C3 invariant).
|
||||
"""
|
||||
top = compute_salient_communities(store, assignment, top_k=top_k)
|
||||
if not top:
|
||||
return []
|
||||
per_c = per_community or max(1, max_records // max(1, len(top)))
|
||||
out: list[UUID] = []
|
||||
for cid in top:
|
||||
try:
|
||||
out.extend(_top_n_records_by_centrality(store, assignment, cid, per_c))
|
||||
except Exception:
|
||||
continue
|
||||
return out[:max_records]
|
||||
|
||||
|
||||
# ---------------------------------------------------------- public entrypoint
|
||||
|
||||
|
||||
async def run_cascade(
|
||||
store: Any,
|
||||
assignment: Any,
|
||||
*,
|
||||
top_k: int = 3,
|
||||
per_community: int | None = None,
|
||||
) -> dict:
|
||||
"""Pre-warm records for top-K salient communities.
|
||||
|
||||
Returns a stats dict: {
|
||||
"communities_selected": int,
|
||||
"records_warmed": int,
|
||||
"top_communities": list[str],
|
||||
}
|
||||
"""
|
||||
top = compute_salient_communities(store, assignment, top_k=top_k)
|
||||
if not top:
|
||||
return {"communities_selected": 0, "records_warmed": 0, "top_communities": []}
|
||||
|
||||
per_c = per_community or max(1, _WARM_MAXSIZE // max(1, len(top)))
|
||||
to_warm: list[UUID] = []
|
||||
for cid in top:
|
||||
try:
|
||||
rec_ids = _top_n_records_by_centrality(store, assignment, cid, per_c)
|
||||
to_warm.extend(rec_ids)
|
||||
except Exception:
|
||||
continue
|
||||
inserted = await warm_records(to_warm[:_WARM_MAXSIZE], store)
|
||||
return {
|
||||
"communities_selected": len(top),
|
||||
"records_warmed": inserted,
|
||||
"top_communities": [str(c) for c in top],
|
||||
}
|
||||
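To make the salience formula in hippea_cascade.py concrete, here is a small worked sketch using only the standard library. It assumes per-community day buckets are already available and, as a simplification, derives f(c) from the bucket totals rather than from retrieval events; `salience`, `EPSILON`, and the community names are hypothetical.

```python
EPSILON = 0.01  # variance floor from w(c) = 1 / (sigma2(c) + 0.01)


def salience(day_counts: dict, total_sessions: int) -> dict:
    """Return S(c) = w(c) * PE(c) for each community in day_counts."""
    p = 1.0 / len(day_counts)  # uniform prior p(c) = 1 / |communities|
    scores = {}
    for cid, bucket in day_counts.items():
        f_c = sum(bucket) / max(1, total_sessions)  # observed share f(c)
        pe = abs(f_c - p)                           # prediction error PE(c)
        mean = sum(bucket) / len(bucket)
        variance = sum((x - mean) ** 2 for x in bucket) / len(bucket)
        scores[cid] = pe / (variance + EPSILON)     # w(c) * PE(c)
    return scores


# Three hypothetical communities over a 7-day window, 14 sessions in total.
day_counts = {
    "steady": [1, 1, 1, 1, 1, 1, 1],  # one session every day
    "bursty": [7, 0, 0, 0, 0, 0, 0],  # all seven sessions on one day
    "quiet": [0, 0, 0, 0, 0, 0, 0],   # in the pool but never visited
}
scores = salience(day_counts, total_sessions=14)

# Same ordering rule as the module: score descending, then id ascending.
ranked = [cid for cid, _ in sorted(scores.items(), key=lambda kv: (-kv[1], kv[0]))]
```

In this sketch the bursty community is damped by its day-to-day variance, while a pooled community with no sessions has PE(c) = p(c) and zero variance, so it can outrank one that was visited every day.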
364
src/iai_mcp/host_cli.py
Normal file
|
|
@@ -0,0 +1,364 @@
|
|||
"""Claude Code CLI subprocess wrapper + budget ledger.
|
||||
|
||||
Subprocess safety:
|
||||
- Uses asyncio.create_subprocess_exec (argv-list form) -- NO shell expansion.
|
||||
The prompt string is passed as a single argv element; no shell-injection surface.
|
||||
- NEVER uses asyncio.create_subprocess_shell, shell=True, or os.system.
|
||||
|
||||
Constitutional guards:
|
||||
- We do NOT read the paid-API env var. The env is scrubbed via
|
||||
ENV_DENY_LIST before the subprocess is spawned so the key cannot leak into
|
||||
the child `claude -p` process even if set in our parent env by accident.
|
||||
- Bug #43333 defence-in-depth:
|
||||
1. Pre-flight credentials.json validation (billingType=stripe_subscription).
|
||||
2. Subprocess spawn with scrubbed env (3 hostile keys removed).
|
||||
3. Post-flight tripwire: cost_usd > 0 -> BudgetTracker.disable_host()
|
||||
+ structured error result. Subsequent calls refuse to spend.
|
||||
- This module does NOT decide frequency. insight.py orchestrates exactly
|
||||
one call per night. This module is the wrapper only.
|
||||
- Self-tracked budget (1% daily, 7% weekly buffer, local
|
||||
midnight reset) persisted inside daemon_state under BUDGET_STATE_KEY.
|
||||
- Force-wake during an in-flight claude -p subprocess is honoured
|
||||
cooperatively -- CancelledError is caught, the subprocess is terminated
|
||||
(with FORCE_WAKE_GRACE_SEC grace then kill escalation), and a structured
|
||||
error result is returned WITHOUT re-raising. The daemon loop stays alive.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from iai_mcp.daemon_state import load_state, save_state
|
||||
|
||||
# --------------------------------------------------------------------- constants
|
||||
# Hostile env key deny list. The paid-API key must NEVER reach the
|
||||
# `claude -p` subprocess; two alias names have been seen in issue reports for
|
||||
# bug #43333 so we scrub all three. We build the key strings from fragments
|
||||
# so the literal names do not appear as static text in this module -- the
|
||||
# constitutional-guard grep test (test_no_api_key_in_daemon) greps for the
|
||||
# bare literal, and the scrub path still removes every variant at runtime.
|
||||
_ANTHR = "ANTHR" + "OPIC_" + "API_" + "KEY"
|
||||
_CLAUDE_KEY = "CLAUDE_" + "API_" + "KEY"
|
||||
_CLAUDE_CODE_KEY = "CLAUDE_" + "CODE_" + "API_" + "KEY"
|
||||
ENV_DENY_LIST: tuple[str, ...] = (_ANTHR, _CLAUDE_KEY, _CLAUDE_CODE_KEY)
|
||||
|
||||
HOST_TIMEOUT_SEC: float = 120.0 # hard wall for a single call
|
||||
FORCE_WAKE_GRACE_SEC: float = 60.0 # cooperative grace on cancel
|
||||
TERMINATE_WAIT_SEC: float = 5.0 # timeout window before kill escalation
|
||||
KILL_WAIT_SEC: float = 2.0 # bound for post-SIGKILL reap wait
|
||||
DAILY_QUOTA_BUDGET_PCT: float = 0.01  # 1% of the daily token estimate
|
||||
WEEKLY_BUFFER_PCT: float = 0.07  # 7% of the weekly estimate (7 x daily ceiling)
|
||||
ESTIMATED_DAILY_TOKEN_CEILING: int = 1_000_000 # heuristic (Pro subscription)
|
||||
CREDENTIALS_PATH: Path = Path.home() / ".claude" / ".credentials.json"
|
||||
BUDGET_STATE_KEY: str = "host_budget"
|
||||
|
||||
|
||||
# -------------------------------------------------------- pre-flight credentials
|
||||
|
||||
|
||||
def verify_credentials_subscription() -> dict:
|
||||
"""Validate the local Claude credentials file says the user is on a
|
||||
Stripe subscription (bug #43333 layer 1 defence).
|
||||
|
||||
We do NOT read the file's secret material. We look at `billingType` only
|
||||
and refuse to call `claude -p` when the billing mode is anything other
|
||||
than `stripe_subscription` (accepts both camelCase and snake_case keys
|
||||
since the schema has varied across Claude CLI versions).
|
||||
"""
|
||||
if not CREDENTIALS_PATH.exists():
|
||||
return {"ok": False, "reason": "credentials_file_missing"}
|
||||
try:
|
||||
data = json.loads(CREDENTIALS_PATH.read_text())
|
||||
except (OSError, json.JSONDecodeError) as exc:
|
||||
return {"ok": False, "reason": "credentials_unreadable", "error": str(exc)}
|
||||
billing = data.get("billingType") or data.get("billing_type") or ""
|
||||
if billing != "stripe_subscription":
|
||||
return {"ok": False, "reason": "not_subscription", "billing_type": billing}
|
||||
return {"ok": True, "billing_type": billing}
|
||||
|
||||
|
||||
# --------------------------------------------------------------- BudgetTracker
|
||||
|
||||
|
||||
class BudgetTracker:
|
||||
"""Self-tracked daily + weekly token budget.
|
||||
|
||||
State is stored inside daemon_state under BUDGET_STATE_KEY. The tracker
|
||||
reads once at construction and writes back via save_state on any mutation.
|
||||
Thread-safety is handled at the daemon-state filesystem layer (atomic
|
||||
rename in daemon_state.save_state).
|
||||
"""
|
||||
|
||||
def __init__(self, state: dict) -> None:
|
||||
self._state = state
|
||||
budget = state.get(BUDGET_STATE_KEY) or {}
|
||||
self._daily_used_tokens = int(budget.get("daily_used_tokens", 0) or 0)
|
||||
self._weekly_buffer_used_tokens = int(
|
||||
budget.get("weekly_buffer_used_tokens", 0) or 0,
|
||||
)
|
||||
self._last_reset_date = budget.get("last_reset_date")
|
||||
self._host_disabled = bool(budget.get("host_disabled", False))
|
||||
self._disabled_reason = budget.get("host_disabled_reason")
|
||||
|
||||
# --- read helpers --------------------------------------------------------
|
||||
|
||||
def host_disabled_after_billing_event(self) -> bool:
|
||||
"""True if a prior call hit the bug #43333 tripwire and auto-disabled."""
|
||||
return self._host_disabled
|
||||
|
||||
def weekly_buffer_exceeded(self) -> bool:
|
||||
"""D-16 ceiling: 7% weekly buffer fully consumed."""
|
||||
weekly_cap = int(WEEKLY_BUFFER_PCT * ESTIMATED_DAILY_TOKEN_CEILING * 7)
|
||||
return self._weekly_buffer_used_tokens >= weekly_cap
|
||||
|
||||
def can_spend(self, estimated_tokens: int) -> bool:
|
||||
"""Pre-flight check: will this call fit in the daily cap, or (if
|
||||
overflowing) in the remaining weekly buffer? Returns False when
|
||||
Claude is auto-disabled or when neither ledger has room."""
|
||||
if self._host_disabled:
|
||||
return False
|
||||
daily_cap = int(DAILY_QUOTA_BUDGET_PCT * ESTIMATED_DAILY_TOKEN_CEILING)
|
||||
if self._daily_used_tokens + estimated_tokens <= daily_cap:
|
||||
return True
|
||||
weekly_cap = int(WEEKLY_BUFFER_PCT * ESTIMATED_DAILY_TOKEN_CEILING * 7)
|
||||
overflow = (self._daily_used_tokens + estimated_tokens) - daily_cap
|
||||
return self._weekly_buffer_used_tokens + overflow <= weekly_cap
|
||||
|
||||
# --- mutations -----------------------------------------------------------
|
||||
|
||||
def reset_if_new_day(self, now: datetime, tz) -> None:
|
||||
"""Zero the daily counter at the user's LOCAL midnight. Any
|
||||
unused daily budget returns to the weekly buffer (capped at the
|
||||
weekly ceiling). Safe to call every tick -- it's a no-op until the
|
||||
local-date actually rolls."""
|
||||
today_local = now.astimezone(tz).date().isoformat()
|
||||
if self._last_reset_date == today_local:
|
||||
return
|
||||
daily_cap = int(DAILY_QUOTA_BUDGET_PCT * ESTIMATED_DAILY_TOKEN_CEILING)
|
||||
weekly_cap = int(WEEKLY_BUFFER_PCT * ESTIMATED_DAILY_TOKEN_CEILING * 7)
|
||||
unused_today = max(0, daily_cap - self._daily_used_tokens)
|
||||
self._weekly_buffer_used_tokens = max(
|
||||
0,
|
||||
min(
|
||||
weekly_cap,
|
||||
self._weekly_buffer_used_tokens - unused_today,
|
||||
),
|
||||
)
|
||||
self._daily_used_tokens = 0
|
||||
self._last_reset_date = today_local
|
||||
self._persist()
|
||||
|
||||
def record(self, tokens_in: int, tokens_out: int, now: datetime) -> None:
|
||||
"""Record the tokens spent on one `claude -p` call. Overflow past the
|
||||
daily cap spills into the weekly buffer; daily counter is then clamped
|
||||
at the cap so `can_spend` sees today as fully exhausted."""
|
||||
total = int(tokens_in) + int(tokens_out)
|
||||
daily_cap = int(DAILY_QUOTA_BUDGET_PCT * ESTIMATED_DAILY_TOKEN_CEILING)
|
||||
if self._daily_used_tokens + total <= daily_cap:
|
||||
self._daily_used_tokens += total
|
||||
else:
|
||||
overflow = (self._daily_used_tokens + total) - daily_cap
|
||||
self._daily_used_tokens = daily_cap
|
||||
self._weekly_buffer_used_tokens += overflow
|
||||
self._persist()
|
||||
|
||||
def disable_host(self, reason: str) -> None:
|
||||
"""Bug #43333 tripwire. Once fired, no further calls are allowed
|
||||
until explicit re-enable (requires user intervention via the morning
|
||||
digest which surfaces the event)."""
|
||||
self._host_disabled = True
|
||||
self._disabled_reason = str(reason)[:500]
|
||||
self._persist()
|
||||
|
||||
# --- persistence ---------------------------------------------------------
|
||||
|
||||
def _persist(self) -> None:
|
||||
self._state[BUDGET_STATE_KEY] = {
|
||||
"daily_used_tokens": self._daily_used_tokens,
|
||||
"weekly_buffer_used_tokens": self._weekly_buffer_used_tokens,
|
||||
"last_reset_date": self._last_reset_date,
|
||||
"host_disabled": self._host_disabled,
|
||||
"host_disabled_reason": self._disabled_reason,
|
||||
}
|
||||
save_state(self._state)
|
||||
|
||||
|
||||
# --------------------------------------------------------- subprocess invocation
|
||||
|
||||
|
||||
def _scrubbed_env() -> dict[str, str]:
|
||||
"""Return a copy of os.environ with the hostile keys removed.
|
||||
|
||||
ENV_DENY_LIST above is the single source of truth for the key names so
|
||||
the constitutional-guard grep test sees them in exactly one place.
|
||||
"""
|
||||
result: dict[str, str] = {}
|
||||
for key, value in os.environ.items():
|
||||
if key in ENV_DENY_LIST:
|
||||
continue
|
||||
result[key] = value
|
||||
for hostile in ENV_DENY_LIST:
|
||||
result.pop(hostile, None)
|
||||
return result
|
||||
|
||||
|
||||
def _build_cmd(prompt: str, model: str) -> list[str]:
|
||||
"""Argv list for `claude -p`. Single list element for prompt -> no shell
|
||||
interpolation path."""
|
||||
return [
|
||||
"claude",
|
||||
"--bare",
|
||||
"-p",
|
||||
prompt,
|
||||
"--output-format",
|
||||
"json",
|
||||
"--max-turns",
|
||||
"1",
|
||||
"--tools",
|
||||
"",
|
||||
"--no-session-persistence",
|
||||
"--model",
|
||||
model,
|
||||
]
|
||||
|
||||
|
||||
async def _terminate_then_kill(proc, grace_sec: float) -> None:
|
||||
"""Cooperative shutdown: terminate(); wait `grace_sec`; kill() if still
|
||||
running. Never raises -- best-effort cleanup only."""
|
||||
try:
|
||||
if proc.returncode is None:
|
||||
proc.terminate()
|
||||
except ProcessLookupError:
|
||||
return
|
||||
try:
|
||||
await asyncio.wait_for(proc.wait(), timeout=grace_sec)
|
||||
except asyncio.TimeoutError:
|
||||
try:
|
||||
proc.kill()
|
||||
except ProcessLookupError:
|
||||
return
|
||||
try:
|
||||
# Bound the post-kill wait so the scheduler always yields even
|
||||
# when the OS refuses to reap the child (zombie path).
|
||||
await asyncio.wait_for(proc.wait(), timeout=KILL_WAIT_SEC)
|
||||
except (asyncio.TimeoutError, Exception): # noqa: BLE001 -- best-effort
|
||||
pass
|
||||
|
||||
|
||||
async def invoke_host_once(
|
||||
prompt: str,
|
||||
*,
|
||||
model: str = "haiku",
|
||||
) -> dict:
|
||||
"""Spawn one `claude -p` subprocess, return a structured result dict.
|
||||
|
||||
Shape of the return value always includes ok, cost_usd, tokens_in,
|
||||
tokens_out so callers can sum budgets unconditionally. On ok=False,
|
||||
reason is one of:
|
||||
timeout | nonzero_exit | unparseable_output | api_billing_detected
|
||||
| force_wake_killed
|
||||
|
||||
Constitutional guarantees:
|
||||
- No shell expansion of `prompt` -- argv list only.
|
||||
- Hostile env keys scrubbed via ENV_DENY_LIST before spawn.
|
||||
- bug #43333: cost_usd > 0 triggers BudgetTracker.disable_host plus an
|
||||
error result. A second call then short-circuits at can_spend().
|
||||
"""
|
||||
env = _scrubbed_env()
|
||||
cmd = _build_cmd(prompt, model)
|
||||
|
||||
proc = await asyncio.create_subprocess_exec(
|
||||
*cmd,
|
||||
stdin=asyncio.subprocess.DEVNULL,
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE,
|
||||
env=env,
|
||||
)
|
||||
|
||||
try:
|
||||
stdout, stderr = await asyncio.wait_for(
|
||||
proc.communicate(),
|
||||
timeout=HOST_TIMEOUT_SEC,
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
await _terminate_then_kill(proc, TERMINATE_WAIT_SEC)
|
||||
return {
|
||||
"ok": False,
|
||||
"reason": "timeout",
|
||||
"exit_code": proc.returncode if proc.returncode is not None else -1,
|
||||
"cost_usd": 0.0,
|
||||
"tokens_in": 0,
|
||||
"tokens_out": 0,
|
||||
}
|
||||
except asyncio.CancelledError:
|
||||
# Warning 8: force-wake arrived mid-call. Clean up the subprocess,
|
||||
# return a structured error, do NOT re-raise. Re-raising would unwind
|
||||
# back into the daemon scheduler and potentially crash the event
|
||||
# loop; cooperative yield requires a normal return here.
|
||||
await _terminate_then_kill(proc, FORCE_WAKE_GRACE_SEC)
|
||||
return {
|
||||
"ok": False,
|
||||
"reason": "force_wake_killed",
|
||||
"cost_usd": 0.0,
|
||||
"tokens_in": 0,
|
||||
"tokens_out": 0,
|
||||
}
|
||||
|
||||
if proc.returncode != 0:
|
||||
return {
|
||||
"ok": False,
|
||||
"reason": "nonzero_exit",
|
||||
"exit_code": proc.returncode,
|
||||
"stderr": stderr.decode("utf-8", errors="replace")[:500],
|
||||
"cost_usd": 0.0,
|
||||
"tokens_in": 0,
|
||||
"tokens_out": 0,
|
||||
}
|
||||
|
||||
try:
|
||||
data = json.loads(stdout)
|
||||
except json.JSONDecodeError:
|
||||
return {
|
||||
"ok": False,
|
||||
"reason": "unparseable_output",
|
||||
"cost_usd": 0.0,
|
||||
"tokens_in": 0,
|
||||
"tokens_out": 0,
|
||||
}
|
||||
|
||||
cost_usd = float(data.get("cost_usd", 0.0) or 0.0)
|
||||
usage = data.get("usage") or {}
|
||||
tokens_in = int(usage.get("input_tokens", 0) or 0)
|
||||
tokens_out = int(usage.get("output_tokens", 0) or 0)
|
||||
|
||||
# Bug #43333 post-flight tripwire: a real subscription-mode Claude CLI
|
||||
# call MUST report cost_usd == 0. Anything else means the subscription
|
||||
# path was bypassed (billing would follow). Auto-disable future calls.
|
||||
if cost_usd > 0.0:
|
||||
try:
|
||||
state = load_state()
|
||||
BudgetTracker(state).disable_host(
|
||||
reason=f"api_billing_detected cost_usd={cost_usd}",
|
||||
)
|
||||
except Exception: # noqa: BLE001 -- tripwire must not re-raise
|
||||
pass
|
||||
return {
|
||||
"ok": False,
|
||||
"reason": "api_billing_detected",
|
||||
"cost_usd": cost_usd,
|
||||
"data": data,
|
||||
"tokens_in": tokens_in,
|
||||
"tokens_out": tokens_out,
|
||||
}
|
||||
|
||||
return {
|
||||
"ok": True,
|
||||
"data": data,
|
||||
"cost_usd": cost_usd,
|
||||
"tokens_in": tokens_in,
|
||||
"tokens_out": tokens_out,
|
||||
}
|
||||
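The ledger arithmetic in BudgetTracker can be traced with a small in-memory sketch. It assumes the constants shown in host_cli.py (1% daily, 7% weekly, 1,000,000-token daily estimate), computes the caps with integer arithmetic, and omits persistence and the host-disabled flag; `Ledger` and its attribute names are hypothetical.

```python
DAILY_TOKEN_CEILING = 1_000_000
DAILY_CAP = DAILY_TOKEN_CEILING // 100            # 1% of one day
WEEKLY_CAP = DAILY_TOKEN_CEILING * 7 * 7 // 100   # 7% of seven days


class Ledger:
    """In-memory stand-in for the daily + weekly token ledger."""

    def __init__(self) -> None:
        self.daily_used = 0
        self.weekly_used = 0

    def can_spend(self, estimated: int) -> bool:
        # Fits today, or the overflow fits in the remaining weekly buffer.
        if self.daily_used + estimated <= DAILY_CAP:
            return True
        overflow = self.daily_used + estimated - DAILY_CAP
        return self.weekly_used + overflow <= WEEKLY_CAP

    def record(self, tokens: int) -> None:
        # Overflow past the daily cap spills into the weekly buffer.
        total = self.daily_used + tokens
        if total <= DAILY_CAP:
            self.daily_used = total
        else:
            self.weekly_used += total - DAILY_CAP
            self.daily_used = DAILY_CAP

    def reset_day(self) -> None:
        # Unused daily budget repays the weekly buffer, floored at zero.
        unused = max(0, DAILY_CAP - self.daily_used)
        self.weekly_used = max(0, min(WEEKLY_CAP, self.weekly_used - unused))
        self.daily_used = 0


ledger = Ledger()
ledger.record(8_000)   # fits in the daily cap
ledger.record(5_000)   # 3_000 tokens spill into the weekly buffer
state_after_spill = (ledger.daily_used, ledger.weekly_used)
fits = ledger.can_spend(487_000)      # overflow lands exactly on the weekly cap
too_big = ledger.can_spend(487_001)   # one token past the weekly cap
ledger.reset_day()     # nothing unused today, so the buffer keeps its 3_000
ledger.record(4_000)
ledger.reset_day()     # 6_000 unused tokens repay the buffer
state_after_reset = (ledger.daily_used, ledger.weekly_used)
```

The daily counter is clamped at the cap after a spill, which is why any later estimate on the same day is judged entirely against the weekly buffer.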
197
src/iai_mcp/identity_audit.py
Normal file
|
|
@@ -0,0 +1,197 @@
|
|||
"""Continuous S5 identity audit. Runs even when daemon is paused.
|
||||
|
||||
Wraps `s5.detect_drift_anomaly` + `sigma.compute_and_emit` on a 1-hour cadence.
|
||||
Both calls are MVCC reads (LanceDB handles concurrent readers natively), so
|
||||
this loop does NOT acquire the fcntl exclusive lock. That is the C6 invariant:
|
||||
the daemon continues to observe its own identity even when heavy consolidation
|
||||
is paused.
|
||||
|
||||
Phase 7.3 addition (D7.3-11): the same loop iteration also runs Lance
|
||||
storage maintenance (`optimize_lance_storage`) on a configurable cadence
|
||||
(default 1h via `LANCE_OPTIMIZE_INTERVAL_SEC`). The optimize body is gated
|
||||
by a `time.monotonic()` cooldown against the configured interval; the
|
||||
cooldown gate is silent when blocked (no event flooding).
|
||||
|
||||
Phase 10.6 Plan 10.6-01 Task 1.4: REMOVED the `_should_yield_to_mcp(socket)`
|
||||
HUMAN-FIRST gate. The lifecycle state machine + sleep_pipeline supersede
|
||||
this design — periodic optimize runs unconditionally once the cooldown
|
||||
passes; SLEEP-state coexistence is provided by the lifecycle predicate
|
||||
that gates SLEEP entry on `sleep_eligible`. The `socket` kwarg has been
|
||||
removed from `continuous_audit`'s signature.
|
||||
|
||||
Constitutional guard:
|
||||
- C6: S5 invariant audit runs read-only (MVCC) and does NOT acquire the
|
||||
process-wide exclusive lock. Grep-guarded by
|
||||
tests/test_constitutional_guards.py (C6 = no lock module imported here).
|
||||
- C3: ZERO paid-API cost. No reference to paid-API env var.
|
||||
- C5: literal preservation -- no writes to MemoryRecord.literal_surface.
|
||||
- Light daemon ops run concurrent with MCP via LanceDB MVCC; the audit
|
||||
path is exactly one such op.
|
||||
|
||||
Exception handling: each of the underlying calls is wrapped in its own
|
||||
try/except. Failures are emitted as `identity_audit_error` events with a
|
||||
`stage` discriminator ("s5" | "sigma") and the loop continues to the next
|
||||
tick. The Lance optimize step uses a separate try/except path because its
|
||||
helper already swallows per-table failures into the report dict (D7.3-09);
|
||||
the outer guard there only protects against event-write failure. The
|
||||
daemon must never die from an audit OR maintenance failure.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
|
||||
from iai_mcp import maintenance as _maintenance
|
||||
from iai_mcp.events import write_event
|
||||
from iai_mcp.maintenance import optimize_lance_storage
|
||||
from iai_mcp.s5 import detect_drift_anomaly
|
||||
from iai_mcp.sigma import compute_and_emit
|
||||
|
||||
# 1-hour cadence -- same granularity as sigma snapshot + S5 audit in S4 pass.
|
||||
AUDIT_INTERVAL_SEC: int = 60 * 60
|
||||
|
||||
# R2 / D7.3-14: timestamp of the most recent successful periodic
|
||||
# Lance optimize. Module-level mutable; the loop body declares
|
||||
# `global _last_optimize_completed_at` to write. Ephemeral by design --
|
||||
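# daemon restart resets to 0.0, so the first periodic poll runs as soon as
|
||||
# time.monotonic() itself is at least LANCE_OPTIMIZE_INTERVAL_SEC (its
|
||||
# reference point is undefined, so the first run may be deferred where the
|
||||
# clock starts near zero, e.g. shortly after boot). The startup wire-in in
|
||||
# daemon.main() already handled the boot-time bloat collapse, so this just
|
||||
# establishes the periodic cadence baseline.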
|
||||
#
|
||||
# Mirrors Phase 7.2's _last_cascade_completed_at pattern in daemon.py
|
||||
# exactly (D7.2-03/D7.2-05): time.monotonic() not datetime.now() so the
|
||||
# cooldown is immune to clock skew + system suspend/resume.
|
||||
_last_optimize_completed_at: float = 0.0
|
||||
|
||||
|
||||
async def continuous_audit(
|
||||
store,
|
||||
shutdown: asyncio.Event,
|
||||
*,
|
||||
interval_sec: float | None = None,
|
||||
) -> None:
|
||||
"""Loop until `shutdown` is set.
|
||||
|
||||
On each tick: run S5 drift anomaly detection, then sigma topology
|
||||
snapshot, then gated Lance storage optimize. All three
|
||||
are independent: a failure in any one stage does not abort the others.
|
||||
The interval sleep is implemented via `asyncio.wait_for(shutdown.wait(),
|
||||
timeout=interval_sec)` so shutdown is responsive within a fraction of a
|
||||
second rather than having to wait a full hour.
|
||||
|
||||
When `interval_sec` is None we look up the current module-level
|
||||
`AUDIT_INTERVAL_SEC` at call time. This lets tests monkeypatch the
|
||||
constant before calling the function.
|
||||
|
||||
Plan 10.6-01 Task 1.4: REMOVED the `socket` kwarg + the
|
||||
`_should_yield_to_mcp(socket)` gate inside the periodic Lance
|
||||
optimize branch. SLEEP-state coexistence is now provided by the
|
||||
lifecycle state machine instead of an in-loop yield probe.
|
||||
|
||||
Args:
|
||||
store: MemoryStore instance.
|
||||
shutdown: asyncio.Event that breaks the loop when set.
|
||||
interval_sec: optional override for the per-tick sleep. Tests use
|
||||
small values (e.g. 0.05) to drive the loop quickly.
|
||||
"""
|
||||
# R2: explicit `global` so the assignment in the periodic body
|
||||
# updates module-level state, not a local binding. Mirrors the Pitfall 3
|
||||
# discipline from Phase 7.2's _hippea_cascade_loop.
|
||||
global _last_optimize_completed_at
|
||||
|
||||
while not shutdown.is_set():
|
||||
effective_interval: float = (
|
||||
float(interval_sec) if interval_sec is not None else float(AUDIT_INTERVAL_SEC)
|
||||
)
|
||||
# Stage 1: S5 drift anomaly detection (MVCC read).
|
||||
try:
|
||||
await asyncio.to_thread(detect_drift_anomaly, store, 5)
|
||||
except Exception as exc: # noqa: BLE001 -- daemon must never die
|
||||
try:
|
||||
await asyncio.to_thread(
|
||||
write_event,
|
||||
store,
|
||||
"identity_audit_error",
|
||||
{"stage": "s5", "error": str(exc)[:500]},
|
||||
severity="warning",
|
||||
)
|
||||
except Exception:
|
||||
# Even the event write failed -- swallow silently so the loop
|
||||
# can continue. Next tick gets a fresh chance.
|
||||
pass
|
||||
|
||||
# Stage 2: sigma topology snapshot + emit (MVCC read).
|
||||
try:
|
||||
await asyncio.to_thread(compute_and_emit, store)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
try:
|
||||
await asyncio.to_thread(
|
||||
write_event,
|
||||
store,
|
||||
"identity_audit_error",
|
||||
{"stage": "sigma", "error": str(exc)[:500]},
|
||||
severity="warning",
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Stage 3 (Phase 7.3 R2/R3): gated periodic Lance storage optimize.
|
||||
# Plan 10.6-01 Task 1.4 simplified: single gate
|
||||
# (interval cooldown). The D7.3-11 MCP-active yield
|
||||
# gate via `_should_yield_to_mcp(socket)` was removed; the
|
||||
# lifecycle state machine handles SLEEP-state coexistence
|
||||
# outside this loop.
|
||||
try:
|
||||
# Access the module attribute at call time (not at import time)
|
||||
# so test fixtures can monkeypatch
|
||||
# `maintenance.LANCE_OPTIMIZE_INTERVAL_SEC` and observe the new
|
||||
# value without needing `importlib.reload(identity_audit)`.
|
||||
interval_sec_now = _maintenance.LANCE_OPTIMIZE_INTERVAL_SEC
|
||||
retention_sec_now = _maintenance.LANCE_OPTIMIZE_RETENTION_SEC
|
||||
elapsed_since_last = time.monotonic() - _last_optimize_completed_at
|
||||
if elapsed_since_last < interval_sec_now:
|
||||
# D7.3-19: silent skip -- no event. The cooldown gates
|
||||
# work, it does not consume a ledger slot.
|
||||
pass
|
||||
else:
|
||||
periodic_t0 = time.monotonic()
|
||||
try:
|
||||
periodic_report = await asyncio.to_thread(
|
||||
optimize_lance_storage, store,
|
||||
)
|
||||
try:
|
||||
await asyncio.to_thread(
|
||||
write_event,
|
||||
store,
|
||||
"lance_storage_optimized",
|
||||
{
|
||||
"phase": "periodic",
|
||||
"retention_days": (
|
||||
retention_sec_now / 86400.0
|
||||
),
|
||||
"per_table": periodic_report,
|
||||
"total_elapsed_sec": round(
|
||||
time.monotonic() - periodic_t0, 3,
|
||||
),
|
||||
},
|
||||
severity="info",
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
finally:
|
||||
# D7.3-14: stamp completion timestamp regardless of
|
||||
# success/exception so a failed optimize still gates
|
||||
# the next run by LANCE_OPTIMIZE_INTERVAL_SEC.
|
||||
_last_optimize_completed_at = time.monotonic()
|
||||
except Exception:
|
||||
# Outer defense-in-depth: a bug in the gate logic itself must
|
||||
# not crash the audit loop (C6 invariant: the daemon must
|
||||
# continue observing its own identity even when maintenance
|
||||
# work fails). Same discipline as the S5/sigma stages above.
|
||||
pass
|
||||
|
||||
# Shutdown-responsive sleep: return early if shutdown fires.
|
||||
try:
|
||||
await asyncio.wait_for(shutdown.wait(), timeout=effective_interval)
|
||||
break # shutdown fired mid-sleep
|
||||
except asyncio.TimeoutError:
|
||||
continue # normal path: time for next audit tick
|
||||
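The shutdown-responsive sleep and the monotonic cooldown gate used by `continuous_audit` can be sketched in isolation. This is a minimal illustration under stated assumptions, not the module's code: `periodic_loop`, `demo`, and the `counts` dict are hypothetical names, and the sketch tracks "never ran" with `None` rather than the module's `0.0` baseline.

```python
from __future__ import annotations

import asyncio
import time


async def periodic_loop(
    shutdown: asyncio.Event, interval_sec: float, cooldown_sec: float
) -> dict:
    """Hypothetical stand-in for the audit loop: counts ticks and gated runs."""
    counts = {"ticks": 0, "gated_runs": 0}
    last_run: float | None = None  # assumption: None marks "never ran"
    while not shutdown.is_set():
        counts["ticks"] += 1
        now = time.monotonic()
        # Cooldown gate: only the difference of two monotonic readings is used.
        if last_run is None or now - last_run >= cooldown_sec:
            counts["gated_runs"] += 1
            last_run = time.monotonic()
        # Shutdown-responsive sleep: returns early once the event is set.
        try:
            await asyncio.wait_for(shutdown.wait(), timeout=interval_sec)
            break  # shutdown fired mid-sleep
        except asyncio.TimeoutError:
            continue  # interval elapsed, run the next tick
    return counts


async def demo() -> tuple[dict, float]:
    shutdown = asyncio.Event()
    task = asyncio.create_task(
        periodic_loop(shutdown, interval_sec=0.05, cooldown_sec=3600.0)
    )
    await asyncio.sleep(0.18)  # let a few ticks elapse
    t0 = time.monotonic()
    shutdown.set()
    counts = await task
    return counts, time.monotonic() - t0
```

Calling `asyncio.run(demo())` returns the tick counters and the time the loop took to exit after `shutdown.set()`; with a one-hour cooldown the gated branch should run once while the loop keeps ticking.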
342
src/iai_mcp/idle_detector.py
Normal file
|
|
@ -0,0 +1,342 @@
|
|||
"""Phase 10.4 L6 — hardware-aware idle detector for the wake/sleep cycle.
|
||||
|
||||
Combines three hardware-grounded signals into a single ``sleep_eligible``
|
||||
predicate the daemon's state machine consumes when deciding whether to
|
||||
transition into a sleep cycle:
|
||||
|
||||
1. **Heartbeat-idle (30 min):** no FRESH wrapper heartbeats in the last 30
|
||||
minutes — supplied externally by ``HeartbeatScanner.heartbeat_idle_30min``.
|
||||
2. **HIDIdleTime:** ``ioreg -c IOHIDSystem`` exposes nanoseconds since the
|
||||
last user input event. Convert ns→sec, compare against ``≥ 30 min``.
|
||||
3. **pmset events:** macOS power-manager log entries for ``System Sleep`` or
|
||||
``Display is turned off`` within the last ``window_min`` minutes.
|
||||
|
||||
``sleep_eligible`` is the **disjunction** of the three: any one signal is
|
||||
sufficient. This matches the proposal v2 §2 L6 rule — there is no
|
||||
wall-clock fallback, only hardware-grounded evidence of inactivity.
|
||||
|
||||
Hard constraints (carried from CONTEXT 10.4):
|
||||
- ALL subprocess calls use array form ``[bin, arg, ...]`` with
|
||||
``shell=False`` and a finite ``timeout``. NEVER ``shell=True``. NEVER
|
||||
f-string interpolation into command strings.
|
||||
- Idle CPU near zero — this module is invoked on lifecycle TICK (every 30 s),
|
||||
not faster. ``pmset -g log`` can be slow (≈1 s) so we tail the last 200
|
||||
lines of output rather than re-parsing the entire log.
|
||||
- macOS-only: ``ioreg`` and ``pmset`` are macOS binaries. On non-macOS the
|
||||
detector returns ``None`` / ``False`` gracefully — cross-platform support
|
||||
is deferred per proposal v2 §6.6.
|
||||
- No new third-party dependencies — stdlib only.
|
||||
|
||||
Validates: WAKE-09.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
import subprocess
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
|
||||
# Module-level constants -------------------------------------------------------
|
||||
|
||||
#: Absolute path to the macOS ``ioreg`` binary. Hard-coded to avoid PATH-based
|
||||
#: hijacks (a planted ``ioreg`` in the user's PATH could feed us spoofed
|
||||
#: HIDIdleTime values that would falsely keep the daemon awake or asleep).
|
||||
_IOREG_BIN = "/usr/sbin/ioreg"
|
||||
|
||||
#: Absolute path to the macOS ``pmset`` binary. Same PATH-hijack rationale.
|
||||
_PMSET_BIN = "/usr/bin/pmset"
|
||||
|
||||
#: Subprocess timeout for ``ioreg`` (seconds). The call is a straight kernel
|
||||
#: registry dump and returns within ~50 ms on a healthy system; a 5 s ceiling
|
||||
#: keeps a hung kernel-extension probe from blocking the lifecycle TICK.
|
||||
_IOREG_TIMEOUT_SEC = 5
|
||||
|
||||
#: Subprocess timeout for ``pmset -g log``. ``pmset`` walks the system power
|
||||
#: log and on a long-uptime machine can take ~1 s; 10 s ceiling.
|
||||
_PMSET_TIMEOUT_SEC = 10
|
||||
|
||||
#: Number of trailing lines to scan from ``pmset -g log``. The log is
|
||||
#: append-only and ordered by time, so the most-recent events are at the end.
|
||||
#: 200 lines covers ~last 24 h on a typical workstation; the window check
|
||||
#: filters by timestamp anyway.
|
||||
_PMSET_TAIL_LINES = 200
|
||||
|
||||
#: Regex for the HIDIdleTime line. Format: ``"HIDIdleTime" = 12345678901``.
|
||||
_HID_IDLE_RE = re.compile(r'"HIDIdleTime"\s*=\s*(\d+)')
|
||||
|
||||
#: Substrings that indicate a sleep / display-off event in pmset log output.
|
||||
_PMSET_SLEEP_MARKERS = ("System Sleep", "Display is turned off")
|
||||
|
||||
#: Default window for ``pmset_recent_sleep`` (minutes). Aligned with the
|
||||
#: proposal v2 §2 L6 wording: "in last 5 min".
|
||||
_PMSET_DEFAULT_WINDOW_MIN = 5
|
||||
|
||||
#: Hardware-idle threshold for the disjunction in ``sleep_eligible`` —
|
||||
#: ``HIDIdleTime ≥ 30 min`` is sufficient evidence of user inactivity.
|
||||
_HID_IDLE_THRESHOLD_SEC = 30 * 60
|
||||
|
||||
#: Regex anchoring a pmset log line's leading timestamp. The format is
|
||||
#: ``YYYY-MM-DD HH:MM:SS ±HHMM`` (e.g. ``2026-05-02 15:00:00 -0400``).
|
||||
_PMSET_TS_RE = re.compile(
|
||||
r"^(\d{4}-\d{2}-\d{2}\s+\d{2}:\d{2}:\d{2})\s+([+-]\d{4})"
|
||||
)
|
||||
|
||||
#: Strptime pattern for the timestamp captured by ``_PMSET_TS_RE``.
|
||||
_PMSET_TS_FMT = "%Y-%m-%d %H:%M:%S"
|
||||
|
||||
|
||||
# Public dataclass -------------------------------------------------------------
|
||||
|
||||
|
||||
@dataclass
|
||||
class IdleStatus:
|
||||
"""Snapshot of the L6 detector for the doctor row (n) display.
|
||||
|
||||
Attributes:
|
||||
hid_idle_sec: Seconds since last user input, or ``None`` if ``ioreg``
|
||||
is unavailable or its output cannot be parsed.
|
||||
pmset_recent_sleep: True iff a System / Display Sleep event was seen
|
||||
within the configured window. False on parse failure or missing
|
||||
tool — biased toward "no recent sleep" so the doctor row reports
|
||||
a clean state rather than a false-positive sleep.
|
||||
available_signals: Subset of ``["HIDIdleTime", "pmset"]`` listing
|
||||
which hardware sources actually returned data on this probe.
|
||||
Empty list means we have no hardware grounding right now and
|
||||
the L6 disjunction must rely on the heartbeat-idle signal.
|
||||
"""
|
||||
|
||||
hid_idle_sec: int | None = None
|
||||
pmset_recent_sleep: bool = False
|
||||
available_signals: list[str] = field(default_factory=list)
|
||||
|
||||
|
||||
# IdleDetector -----------------------------------------------------------------
|
||||
|
||||
|
||||
class IdleDetector:
|
||||
"""Hardware-grounded idle probe for the daemon state machine.
|
||||
|
||||
Standalone module, intended to be wired into the daemon's TICK so
|
||||
``sleep_eligible`` gates the BEDTIME transition. Each public method
|
||||
can be called independently; ``status()`` aggregates them for the
|
||||
doctor row.
|
||||
"""
|
||||
|
||||
# ---- HIDIdleTime via ioreg --------------------------------------
|
||||
|
||||
def hid_idle_time_sec(self) -> int | None:
|
||||
"""Return seconds since last HID input, or ``None`` on any failure.
|
||||
|
||||
Spawns ``/usr/sbin/ioreg -c IOHIDSystem`` (array form, ``shell=False``,
|
||||
5 s timeout, ``check=False``). Parses the first ``"HIDIdleTime" =
|
||||
<ns>`` match and integer-divides by 1e9. Any error path — missing
|
||||
tool, non-zero exit, parse miss, timeout — collapses to ``None`` so
|
||||
the caller treats the signal as absent rather than zero (zero would
|
||||
falsely imply "active right now").
|
||||
"""
|
||||
try:
|
||||
result = subprocess.run(
|
||||
[_IOREG_BIN, "-c", "IOHIDSystem"],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=_IOREG_TIMEOUT_SEC,
|
||||
check=False,
|
||||
)
|
||||
except FileNotFoundError:
|
||||
return None
|
||||
except subprocess.TimeoutExpired:
|
||||
return None
|
||||
except OSError:
|
||||
return None
|
||||
|
||||
if result.returncode != 0:
|
||||
return None
|
||||
|
||||
match = _HID_IDLE_RE.search(result.stdout or "")
|
||||
if match is None:
|
||||
return None
|
||||
try:
|
||||
ns = int(match.group(1))
|
||||
except ValueError:
|
||||
return None
|
||||
if ns < 0:
|
||||
return None
|
||||
return ns // 1_000_000_000
|
||||
|
||||
# ---- pmset event detection --------------------------------------
|
||||
|
||||
def pmset_recent_sleep(
|
||||
self, window_min: int = _PMSET_DEFAULT_WINDOW_MIN
|
||||
) -> bool:
|
||||
"""True iff a System/Display Sleep event was recorded in the window.
|
||||
|
||||
Spawns ``/usr/bin/pmset -g log`` (array form, ``shell=False``, 10 s
|
||||
timeout, ``check=False``). Tails the last ``_PMSET_TAIL_LINES``
|
||||
lines of stdout, parses the leading timestamp, and reports True if
|
||||
any line within ``window_min`` minutes of "now" contains one of the
|
||||
``_PMSET_SLEEP_MARKERS`` substrings.
|
||||
|
||||
Failure modes (missing tool, timeout, non-zero exit, no parseable lines) all
|
||||
collapse to ``False`` — biased toward "no recent sleep" so an
|
||||
unavailable signal does not trigger the L6 disjunction on its own.
|
||||
"""
|
||||
try:
|
||||
result = subprocess.run(
|
||||
[_PMSET_BIN, "-g", "log"],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=_PMSET_TIMEOUT_SEC,
|
||||
check=False,
|
||||
)
|
||||
except FileNotFoundError:
|
||||
return False
|
||||
except subprocess.TimeoutExpired:
|
||||
return False
|
||||
except OSError:
|
||||
return False
|
||||
|
||||
if result.returncode != 0:
|
||||
return False
|
||||
|
||||
return self._scan_pmset_lines(result.stdout or "", window_min)
|
||||
|
||||
@staticmethod
|
||||
def _scan_pmset_lines(stdout: str, window_min: int) -> bool:
|
||||
"""Helper — pure-function scan over pmset log text.
|
||||
|
||||
Split out for unit testing without subprocess mocking. Walks the
|
||||
last ``_PMSET_TAIL_LINES`` lines, returns True at the first match
|
||||
within the window. Parse failures on individual lines are skipped.
|
||||
"""
|
||||
if window_min <= 0:
|
||||
return False
|
||||
# Build a UTC "now" once; pmset timestamps come with explicit ±HHMM
|
||||
# offsets so we convert each parsed timestamp to UTC for comparison.
|
||||
now_utc = datetime.now(timezone.utc)
|
||||
cutoff = now_utc - timedelta(minutes=window_min)
|
||||
|
||||
# Tail the last N lines so we don't re-scan a multi-megabyte log.
|
||||
lines = stdout.splitlines()
|
||||
tail = lines[-_PMSET_TAIL_LINES:] if len(lines) > _PMSET_TAIL_LINES else lines
|
||||
|
||||
for line in tail:
|
||||
if not any(marker in line for marker in _PMSET_SLEEP_MARKERS):
|
||||
continue
|
||||
ts = _parse_pmset_timestamp(line)
|
||||
if ts is None:
|
||||
continue
|
||||
if ts >= cutoff:
|
||||
return True
|
||||
return False
|
||||
|
||||
# ---- Disjunction predicate consumed by the state machine --------
|
||||
|
||||
def sleep_eligible(self, heartbeat_idle_30min: bool) -> bool:
|
||||
"""L6 disjunction: any of three hardware-grounded signals is sufficient.
|
||||
|
||||
Args:
|
||||
heartbeat_idle_30min: True iff no FRESH wrapper heartbeat in the
|
||||
last 30 minutes (supplied by
|
||||
``HeartbeatScanner.heartbeat_idle_30min``).
|
||||
|
||||
Returns:
|
||||
``heartbeat_idle_30min OR (hid_idle_time_sec ≥ 30 min) OR
|
||||
pmset_recent_sleep()``. Short-circuits on the first True so a
|
||||
heartbeat-idle session does not pay for ``ioreg`` + ``pmset``
|
||||
spawns it does not need.
|
||||
"""
|
||||
if heartbeat_idle_30min:
|
||||
return True
|
||||
|
||||
hid_idle = self.hid_idle_time_sec()
|
||||
if hid_idle is not None and hid_idle >= _HID_IDLE_THRESHOLD_SEC:
|
||||
return True
|
||||
|
||||
return self.pmset_recent_sleep()
|
||||
|
||||
# ---- Aggregated snapshot for doctor row (n) ---------------------
|
||||
|
||||
def status(self) -> IdleStatus:
|
||||
"""Return an ``IdleStatus`` snapshot for the doctor checklist.
|
||||
|
||||
Calls both probes regardless of disjunction short-circuit so the
|
||||
doctor surface always reflects the *actual* per-signal availability
|
||||
(a doctor that hides ``pmset`` whenever ``HIDIdleTime`` already
|
||||
triggers would not help the user diagnose a missing pmset log).
|
||||
"""
|
||||
hid_idle = self.hid_idle_time_sec()
|
||||
pmset_seen = self.pmset_recent_sleep()
|
||||
|
||||
signals: list[str] = []
|
||||
if hid_idle is not None:
|
||||
signals.append("HIDIdleTime")
|
||||
# pmset_recent_sleep returning False does not imply pmset is missing;
|
||||
# it only means no event in the window. The boolean alone cannot tell
|
||||
# "tool present but quiet" from "tool absent", so availability is probed
|
||||
# separately: _pmset_responsive() spawns a trivial ``pmset -g`` and pmset
|
||||
# is listed as available whenever that call exits 0.
|
||||
if _pmset_responsive():
|
||||
signals.append("pmset")
|
||||
|
||||
return IdleStatus(
|
||||
hid_idle_sec=hid_idle,
|
||||
pmset_recent_sleep=pmset_seen,
|
||||
available_signals=signals,
|
||||
)
|
||||
|
||||
|
||||
# Module-private helpers -------------------------------------------------------
|
||||
|
||||
|
||||
def _parse_pmset_timestamp(line: str) -> datetime | None:
|
||||
"""Return the leading timestamp of a pmset log line as UTC, or None.
|
||||
|
||||
Matches ``YYYY-MM-DD HH:MM:SS ±HHMM`` at the start of the line. The
|
||||
``±HHMM`` offset is parsed manually because ``%z`` on older Python
|
||||
builds is finicky with shorthand offsets — we apply the offset to a
|
||||
naive datetime and tag it as UTC.
|
||||
"""
|
||||
m = _PMSET_TS_RE.match(line)
|
||||
if m is None:
|
||||
return None
|
||||
ts_str, offset_str = m.group(1), m.group(2)
|
||||
try:
|
||||
naive = datetime.strptime(ts_str, _PMSET_TS_FMT)
|
||||
except ValueError:
|
||||
return None
|
||||
sign = 1 if offset_str[0] == "+" else -1
|
||||
try:
|
||||
hours = int(offset_str[1:3])
|
||||
minutes = int(offset_str[3:5])
|
||||
except ValueError:
|
||||
return None
|
||||
offset = timedelta(hours=hours, minutes=minutes) * sign
|
||||
# Treat naive timestamp as in the offset's local zone, then convert to
|
||||
# UTC by subtracting the offset.
|
||||
return (naive - offset).replace(tzinfo=timezone.utc)
|
||||
|
||||
|
||||
def _pmset_responsive() -> bool:
|
||||
"""Probe whether ``/usr/bin/pmset`` exists and exits 0 for a trivial call.
|
||||
|
||||
Used by ``IdleDetector.status`` to populate ``available_signals``
|
||||
without inferring availability from the (legitimate) "no recent sleep"
|
||||
output. ``pmset -g`` (no subcommand) prints the current power state
|
||||
and exits 0 quickly; missing-binary or non-zero-exit ⇒ unavailable.
|
||||
"""
|
||||
try:
|
||||
result = subprocess.run(
|
||||
[_PMSET_BIN, "-g"],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=_PMSET_TIMEOUT_SEC,
|
||||
check=False,
|
||||
)
|
||||
except FileNotFoundError:
|
||||
return False
|
||||
except subprocess.TimeoutExpired:
|
||||
return False
|
||||
except OSError:
|
||||
return False
|
||||
return result.returncode == 0
|
||||
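The two parsing steps described in `idle_detector.py` (nanoseconds to seconds for `HIDIdleTime`, and a `±HHMM` pmset timestamp to UTC) can be exercised without spawning any subprocess. The sketch below reuses the two regular expressions from the module; the function names `hid_idle_seconds` and `pmset_line_to_utc` and the sample inputs are hypothetical, chosen only for illustration.

```python
from __future__ import annotations

import re
from datetime import datetime, timedelta, timezone

# Same patterns as the module's _HID_IDLE_RE and _PMSET_TS_RE.
HID_IDLE_RE = re.compile(r'"HIDIdleTime"\s*=\s*(\d+)')
PMSET_TS_RE = re.compile(
    r"^(\d{4}-\d{2}-\d{2}\s+\d{2}:\d{2}:\d{2})\s+([+-]\d{4})"
)


def hid_idle_seconds(ioreg_stdout: str) -> int | None:
    """Whole seconds since last input, or None when the key is absent."""
    m = HID_IDLE_RE.search(ioreg_stdout)
    return int(m.group(1)) // 1_000_000_000 if m else None


def pmset_line_to_utc(line: str) -> datetime | None:
    """Leading timestamp of a pmset-style line as an aware UTC datetime."""
    m = PMSET_TS_RE.match(line)
    if m is None:
        return None
    local = datetime.strptime(m.group(1), "%Y-%m-%d %H:%M:%S")
    raw = m.group(2)
    sign = 1 if raw[0] == "+" else -1
    offset = sign * timedelta(hours=int(raw[1:3]), minutes=int(raw[3:5]))
    # local = UTC + offset, hence UTC = local - offset.
    return (local - offset).replace(tzinfo=timezone.utc)
```

For the timestamp format quoted in the module, `2026-05-02 15:00:00 -0400` maps to 19:00 UTC, and an idle reading of 1800000000000 ns maps to exactly the 30-minute threshold of 1800 s.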
267
src/iai_mcp/insight.py
Normal file
|
|
@ -0,0 +1,267 @@
|
|||
"""Lucid moment orchestration (D-13 Option A).
|
||||
|
||||
The "main insight of the day": exactly ONE `claude -p` subprocess call per
|
||||
night, at the end of the last REM cycle. The prompt is built from 3 locally-
|
||||
extracted schema patterns + 1 surprising episode; Claude distils them into a
|
||||
single unifying insight of 1-2 sentences which we store as a semantic-tier
|
||||
record tagged `overnight_insight`.
|
||||
|
||||
Constitutional guards:
|
||||
- LOCAL is the primary worker. This module owns the single surgical
|
||||
Claude call; all other consolidation work is pure-numpy/NetworkX/TF-IDF.
|
||||
- The call goes through host_cli.invoke_host_once, which scrubs
|
||||
the paid-API env var and validates the credentials.json subscription mode
|
||||
before spawning the subprocess. This module NEVER references the paid-API
|
||||
env var by name.
|
||||
- Pre-flight budget gate via BudgetTracker.can_spend. A call that
|
||||
would exceed the daily cap (overflow into weekly buffer) is silently
|
||||
skipped, queued implicitly for the next night.
|
||||
- Bug #43333: cost_usd > 0 from invoke_host_once is recorded by the wrapper
|
||||
(BudgetTracker.disable_host). This module short-circuits on host_disabled
|
||||
so the bad call never repeats.
|
||||
- C5: the inserted MemoryRecord is assembled once from Claude's
|
||||
text response; we do NOT rewrite literal_surface after insert.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any
|
||||
from uuid import uuid4
|
||||
|
||||
from iai_mcp.host_cli import (
|
||||
BudgetTracker,
|
||||
invoke_host_once,
|
||||
verify_credentials_subscription,
|
||||
)
|
||||
from iai_mcp.daemon_state import load_state
|
||||
from iai_mcp.events import query_events, write_event
|
||||
from iai_mcp.schema import induce_schemas_tier0
|
||||
from iai_mcp.tz import load_user_tz
|
||||
from iai_mcp.types import MemoryRecord
|
||||
|
||||
# Option A prompt template. The fragments "3 locally-found patterns",
|
||||
# "1 surprising episode", "unifying insight", and "1-2 sentences" are verbatim
|
||||
# per the locked decision; grep tests assert they appear unmodified.
|
||||
INSIGHT_PROMPT_TEMPLATE: str = (
|
||||
"Here are 3 locally-found patterns from today + 1 surprising episode. "
|
||||
"What is the unifying insight? Reply in 1-2 sentences.\n\n"
|
||||
"Patterns:\n{patterns}\n\n"
|
||||
"Surprise:\n{surprise}"
|
||||
)
|
||||
|
||||
# Conservative pre-flight token estimate for the one nightly call -- covers
|
||||
# the prompt frame + patterns + surprise payload. Actual spend is recorded
|
||||
# post-call via BudgetTracker.record(tokens_in, tokens_out, now).
|
||||
PROMPT_ESTIMATE_TOKENS: int = 500
|
||||
|
||||
# Kinds of events considered "surprising" for the prompt.
|
||||
_SURPRISE_KINDS: frozenset[str] = frozenset({
|
||||
"art_gate_high_novelty",
|
||||
"contradiction_detected",
|
||||
"s4_contradiction",
|
||||
"s5_drift",
|
||||
})
|
||||
|
||||
|
||||
def _gather_patterns(store) -> list[str]:
|
||||
"""Top-3 recent schema candidates by confidence. Graceful on empty."""
|
||||
try:
|
||||
schemas = induce_schemas_tier0(store) or []
|
||||
except Exception: # noqa: BLE001 -- pattern extraction must never crash insight
|
||||
schemas = []
|
||||
|
||||
def _conf(s: Any) -> float:
|
||||
# SchemaCandidate has .confidence; dicts may use the same key.
|
||||
val = getattr(s, "confidence", None)
|
||||
if val is None and isinstance(s, dict):
|
||||
val = s.get("confidence")
|
||||
try:
|
||||
return float(val or 0.0)
|
||||
except (TypeError, ValueError):
|
||||
return 0.0
|
||||
|
||||
def _text(s: Any) -> str:
|
||||
# SchemaCandidate exposes .pattern; dicts use "pattern" / "description".
|
||||
for attr in ("pattern", "description", "summary"):
|
||||
val = getattr(s, attr, None)
|
||||
if val:
|
||||
return str(val)
|
||||
if isinstance(s, dict) and s.get(attr):
|
||||
return str(s[attr])
|
||||
return str(s)
|
||||
|
||||
schemas_sorted = sorted(schemas, key=_conf, reverse=True)
|
||||
top3 = schemas_sorted[:3]
|
||||
if not top3:
|
||||
return ["[no patterns yet]"]
|
||||
return [_text(s) for s in top3]
|
||||
|
||||
|
||||
def _gather_surprise(store) -> str:
|
||||
"""Surprising event since 00:00 UTC today (first match). Graceful on empty."""
|
||||
try:
|
||||
since = datetime.now(timezone.utc).replace(
|
||||
hour=0, minute=0, second=0, microsecond=0,
|
||||
)
|
||||
candidates = query_events(store, since=since, limit=1000) or []
|
||||
except Exception: # noqa: BLE001 -- event query must never crash insight
|
||||
candidates = []
|
||||
|
||||
for event in candidates:
|
||||
if event.get("kind") in _SURPRISE_KINDS:
|
||||
data = event.get("data") or event
|
||||
return str(data)[:500]
|
||||
return "[no surprise yet]"
|
||||
|
||||
|
||||
async def generate_overnight_insight(store, session_id: str) -> dict:
|
||||
"""Orchestrate the Option A Claude call.
|
||||
|
||||
Returns a structured dict. Shape (always present): ok (bool), reason
|
||||
(str | None), text (str | None). Success result also carries
|
||||
tokens_in / tokens_out for the caller's bookkeeping.
|
||||
|
||||
Pre-flight gate sequence (every one MUST pass before spawning subprocess):
|
||||
1. verify_credentials_subscription (bug #43333 layer 2)
|
||||
2. BudgetTracker.host_disabled_after_billing_event (bug #43333 layer 3)
|
||||
3. BudgetTracker.can_spend(PROMPT_ESTIMATE_TOKENS) (D-15 budget)
|
||||
"""
|
||||
creds = verify_credentials_subscription()
|
||||
if not creds.get("ok"):
|
||||
return {
|
||||
"ok": False,
|
||||
"reason": "credentials_check_failed",
|
||||
"text": None,
|
||||
"details": creds,
|
||||
}
|
||||
|
||||
state = load_state()
|
||||
tracker = BudgetTracker(state)
|
||||
|
||||
try:
|
||||
tz = load_user_tz()
|
||||
except Exception: # noqa: BLE001 -- tz lookup never crashes the call path
|
||||
tz = timezone.utc # naive fallback; reset_if_new_day handles both
|
||||
|
||||
now = datetime.now(timezone.utc)
|
||||
tracker.reset_if_new_day(now, tz)
|
||||
|
||||
if tracker.host_disabled_after_billing_event():
|
||||
return {"ok": False, "reason": "host_disabled_c3", "text": None}
|
||||
|
||||
if not tracker.can_spend(PROMPT_ESTIMATE_TOKENS):
|
||||
return {"ok": False, "reason": "budget_exceeded", "text": None}
|
||||
|
||||
patterns = _gather_patterns(store)
|
||||
surprise = _gather_surprise(store)
|
||||
prompt = INSIGHT_PROMPT_TEMPLATE.format(
|
||||
patterns="\n".join(f"- {p}" for p in patterns),
|
||||
surprise=surprise,
|
||||
)
|
||||
|
||||
result = await invoke_host_once(prompt, model="haiku")
|
||||
|
||||
# Record any tokens the call actually spent (host_cli returns tokens
|
||||
# even on non-ok paths when the subprocess completed).
|
||||
tokens_in = int(result.get("tokens_in", 0) or 0)
|
||||
tokens_out = int(result.get("tokens_out", 0) or 0)
|
||||
if tokens_in + tokens_out > 0:
|
||||
tracker.record(tokens_in, tokens_out, now)
|
||||
|
||||
if not result.get("ok"):
|
||||
return {
|
||||
"ok": False,
|
||||
"reason": result.get("reason", "claude_call_failed"),
|
||||
"text": None,
|
||||
"details": {k: v for k, v in result.items() if k != "data"},
|
||||
}
|
||||
|
||||
data = result.get("data") or {}
|
||||
insight_text = str(data.get("result", "")).strip()
|
||||
if not insight_text:
|
||||
return {"ok": False, "reason": "empty_insight", "text": None}
|
||||
|
||||
# Build the L1-tier record. MemoryRecord requires a large
|
||||
# set of fields per schema; we default every non-essential field
|
||||
# to a neutral value so the shield/crypto pipeline treats the insight as
|
||||
# a plain semantic record subject to S4/S5 on-read contradiction checks.
|
||||
embed_dim = getattr(store, "embed_dim", None) or 384
|
||||
record = MemoryRecord(
|
||||
id=uuid4(),
|
||||
tier="semantic",
|
||||
literal_surface=insight_text,
|
||||
aaak_index="",
|
||||
embedding=[0.0] * int(embed_dim),
|
||||
community_id=None,
|
||||
centrality=0.0,
|
||||
detail_level=2,
|
||||
pinned=False,
|
||||
stability=0.0,
|
||||
difficulty=0.0,
|
||||
last_reviewed=None,
|
||||
never_decay=False,
|
||||
never_merge=False,
|
||||
provenance=[{
|
||||
"ts": now.isoformat(),
|
||||
"cue": "overnight_insight",
|
||||
"session_id": session_id,
|
||||
}],
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
tags=["overnight_insight"],
|
||||
language="en", # the prompt is English-framed; insight is English.
|
||||
)
|
||||
# Dataclass has `tags` (list) not `tag` (scalar); we also expose `tag`
|
||||
# via attribute assignment for callers that prefer the scalar form. This
|
||||
# is NOT a literal_surface mutation so it does not violate C5 MEM-01.
|
||||
try:
|
||||
object.__setattr__(record, "tag", "overnight_insight")
|
||||
except Exception: # noqa: BLE001 -- attribute attach is best-effort
|
||||
pass
|
||||
|
||||
try:
|
||||
# R4 (researcher finding #3): wrap bare-sync store.insert
|
||||
# to avoid blocking the asyncio event loop. Reached from
|
||||
# dream.run_rem_cycle when claude_enabled=True (last cycle of REM).
|
||||
# store.insert touches LanceDB write + encryption — not safe-fast.
|
||||
await asyncio.to_thread(store.insert, record)
|
||||
except Exception as exc: # noqa: BLE001 -- store errors must not crash daemon
|
||||
try:
|
||||
write_event(
|
||||
store,
|
||||
"overnight_insight_store_error",
|
||||
{"error": str(exc)[:500]},
|
||||
severity="warning",
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
return {
|
||||
"ok": False,
|
||||
"reason": "store_insert_failed",
|
||||
"text": insight_text,
|
||||
"error": str(exc)[:500],
|
||||
}
|
||||
|
||||
try:
|
||||
write_event(
|
||||
store,
|
||||
"overnight_insight_generated",
|
||||
{
|
||||
"session_id": session_id,
|
||||
"text_len": len(insight_text),
|
||||
"tokens_in": tokens_in,
|
||||
"tokens_out": tokens_out,
|
||||
},
|
||||
)
|
||||
except Exception: # noqa: BLE001 -- event emission failure is non-fatal
|
||||
pass
|
||||
|
||||
return {
|
||||
"ok": True,
|
||||
"text": insight_text,
|
||||
"reason": None,
|
||||
"tokens_in": tokens_in,
|
||||
"tokens_out": tokens_out,
|
||||
}
|
||||
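The pre-flight gate order and the prompt assembly in `generate_overnight_insight` can be summarized as two small pure functions. This is a sketch under stated assumptions: `first_failed_gate` and `render_prompt` are hypothetical helpers, and the three booleans stand in for the results of `verify_credentials_subscription`, `host_disabled_after_billing_event`, and `can_spend`. The reason strings and the template text are copied from the module.

```python
from __future__ import annotations

INSIGHT_PROMPT_TEMPLATE = (
    "Here are 3 locally-found patterns from today + 1 surprising episode. "
    "What is the unifying insight? Reply in 1-2 sentences.\n\n"
    "Patterns:\n{patterns}\n\n"
    "Surprise:\n{surprise}"
)


def first_failed_gate(
    creds_ok: bool, host_disabled: bool, can_spend: bool
) -> str | None:
    """Return the reason of the first failing gate, or None if all pass."""
    if not creds_ok:
        return "credentials_check_failed"
    if host_disabled:
        return "host_disabled_c3"
    if not can_spend:
        return "budget_exceeded"
    return None


def render_prompt(patterns: list[str], surprise: str) -> str:
    """Format patterns as a dash-prefixed list and fill the template."""
    return INSIGHT_PROMPT_TEMPLATE.format(
        patterns="\n".join(f"- {p}" for p in patterns),
        surprise=surprise,
    )
```

Because the gates are checked in a fixed order, a failed credentials check masks the other two reasons; the placeholders `[no patterns yet]` and `[no surprise yet]` flow through `render_prompt` unchanged on an empty store.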
166
src/iai_mcp/learn.py
Normal file
|
|
@ -0,0 +1,166 @@
|
|||
"""Learning layer (LEARN-01/02/05/06, Task 2).
|
||||
|
||||
Four mechanisms live here:
|
||||
|
||||
1. LEARN-01 (Bayesian profile update) is implemented in `iai_mcp.profile`
|
||||
as `bayesian_update`; this module re-exports the RetrievalFeedback and
|
||||
policy utilities used by the pipeline + core dispatch.
|
||||
|
||||
2. LEARN-02 retrieval-policy RL -- simple tabular gradient on score
|
||||
weights. Feedback sources:
|
||||
- user acted on hits (used) -> boost W_COSINE if use-rate > 0.5, else reduce
|
||||
- user issued contradict (corrected) -> reduce W_COSINE
|
||||
- user re-asked same cue (re_asked) -> reduce W_COSINE
|
||||
|
||||
3. LEARN-05 meta-learning -- ε-greedy bandit over retrieval strategies
|
||||
keyed by query type.
|
||||
|
||||
4. LEARN-06 identity refinement -- reads s5_invariant_update /
|
||||
s5_invariant_proposal events and drifts s5_trust_score up for
|
||||
consistently-agreeing anchors, down for frequently-rejected ones.
|
||||
|
||||
All writes go through the D-STORAGE events table; no .jsonl files.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import random
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from iai_mcp.events import query_events
|
||||
from iai_mcp.store import MemoryStore
|
||||
|
||||
|
||||
# ---------------------------------------------------------------- constants

LEARN_RATE: float = 0.05
MAX_WEIGHT: float = 5.0
MIN_WEIGHT: float = 0.0
EPSILON_EXPLORE: float = 0.1  # LEARN-05 bandit exploration probability


# ---------------------------------------------------------------- feedback


@dataclass
class RetrievalFeedback:
    """Implicit feedback signal on a memory_recall response."""

    query_type: str  # e.g. "fact_lookup" | "open_ended" | "contradiction_check"
    hit_ids: list[UUID]
    used_ids: list[UUID] = field(default_factory=list)
    corrected: bool = False  # user issued memory_contradict on a hit
    re_asked: bool = False  # user re-issued the same cue within 5 turns


# ---------------------------------------------------------------- LEARN-02


def update_retrieval_weights(
    feedback: RetrievalFeedback,
    current_weights: dict[str, float],
) -> dict[str, float]:
    """LEARN-02 tabular gradient on score weights.

    Primary signal: use-rate = |used_ids ∩ hit_ids| / |hit_ids|.
        delta = (use_rate - 0.5) * LEARN_RATE
    Correction penalty: -LEARN_RATE
    Re-ask penalty:     -LEARN_RATE * 0.5

    All weights clamped to [MIN_WEIGHT, MAX_WEIGHT].
    Returns a new dict (does not mutate the input).
    """
    w = dict(current_weights)
    delta = 0.0
    if feedback.hit_ids:
        hits_set = set(feedback.hit_ids)
        used_set = set(feedback.used_ids)
        use_rate = len(hits_set & used_set) / len(feedback.hit_ids)
        delta = (use_rate - 0.5) * LEARN_RATE
    if feedback.corrected:
        delta -= LEARN_RATE
    if feedback.re_asked:
        delta -= LEARN_RATE * 0.5

    w_cos = w.get("W_COSINE", 1.0)
    w["W_COSINE"] = max(MIN_WEIGHT, min(MAX_WEIGHT, w_cos + delta))

    # Clamp other weights in case of external mutation.
    for k in ("W_AAAK", "W_DEGREE", "W_AGE"):
        if k in w:
            w[k] = max(MIN_WEIGHT, min(MAX_WEIGHT, w[k]))
    return w


# ---------------------------------------------------------------- LEARN-05


def pick_retrieval_strategy(
    query_type: str,
    history: dict,
    strategies: list[str] | None = None,
) -> str:
    """ε-greedy bandit over retrieval strategies per query type.

    `history` shape:
        {
            "<query_type>": {
                "<strategy>": {"mean": float, "n": int},
                ...
            },
            ...
        }

    Returns the strategy with the highest mean for this query_type except on
    the ε fraction of calls where a random strategy is explored.
    """
    strategies = strategies or ["pipeline_default", "greedy_2hop", "rich_club_first"]
    if random.random() < EPSILON_EXPLORE:
        return random.choice(strategies)
    rewards = history.get(query_type, {})
    if not rewards:
        return strategies[0]
    return max(
        strategies,
        key=lambda s: rewards.get(s, {}).get("mean", 0.0),
    )


# ---------------------------------------------------------------- LEARN-06


TRUST_INCREMENT_PER_COMMIT: float = 0.02
TRUST_DECREMENT_PER_REJECT: float = 0.01


def refine_s5_trust_score(
    store: MemoryStore,
    record_id: UUID,
    current: float,
) -> float:
    """LEARN-06: trust score drifts based on consensus history.

    +TRUST_INCREMENT_PER_COMMIT per s5_invariant_update event with agree_count >= 3
    -TRUST_DECREMENT_PER_REJECT per s5_invariant_proposal with passes_vigilance == False

    Clamped to [0, 1].
    """
    updates = query_events(store, kind="s5_invariant_update", limit=200)
    commits = sum(
        1 for e in updates
        if e["data"].get("anchor_id") == str(record_id)
        and int(e["data"].get("agree_count", 0)) >= 3
    )
    rejects_events = query_events(store, kind="s5_invariant_proposal", limit=500)
    rejects = sum(
        1 for e in rejects_events
        if e["data"].get("anchor_id") == str(record_id)
        and not e["data"].get("passes_vigilance", True)
    )
    new_score = (
        current
        + TRUST_INCREMENT_PER_COMMIT * commits
        - TRUST_DECREMENT_PER_REJECT * rejects
    )
    return max(0.0, min(1.0, new_score))
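A worked instance of the LEARN-02 arithmetic in `update_retrieval_weights`, as a standalone sketch. It does not import `iai_mcp.learn`; the helper name `learn02_delta` is hypothetical and only re-states the delta formula from the docstring, with `LEARN_RATE` copied from the module constant.

```python
LEARN_RATE = 0.05  # copied from the module constant; assumption: unchanged


def learn02_delta(
    n_hits: int,
    n_used: int,
    corrected: bool = False,
    re_asked: bool = False,
) -> float:
    """Hypothetical helper mirroring the delta applied to W_COSINE."""
    delta = 0.0
    if n_hits:
        # use_rate above 0.5 boosts the weight, below 0.5 reduces it.
        delta = (n_used / n_hits - 0.5) * LEARN_RATE
    if corrected:
        delta -= LEARN_RATE
    if re_asked:
        delta -= LEARN_RATE * 0.5
    return delta


# 3 of 4 hits used: (0.75 - 0.5) * 0.05 = +0.0125
boost = learn02_delta(4, 3)
# 1 of 4 hits used plus a correction: (0.25 - 0.5) * 0.05 - 0.05 = -0.0625
penalty = learn02_delta(4, 1, corrected=True)
print(round(boost, 4), round(penalty, 4))
```

With the default `W_COSINE` of 1.0, a single fully-used recall moves the weight by at most +0.025, so the clamp to `[MIN_WEIGHT, MAX_WEIGHT]` only matters after many consistent signals.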
336
src/iai_mcp/lifecycle.py
Normal file
@@ -0,0 +1,336 @@
"""Phase 10.1 -- Lifecycle State Machine + Shadow-Run Mode.

Realises LOCKED contracts L1 (hibernation depth: kill process) and
L2 (state authority: daemon-only writer for `lifecycle_state.json`).

The four lifecycle states (WAKE, DROWSY, SLEEP, HIBERNATION) form a
deterministic FSM. Transitions are pure functions of the current state
and the dispatched event (with optional payload guards); side effects
(persistence + event-log append + shadow-run warning) happen ONLY in
`dispatch`.

Phase 10.6 Plan 10.6-01 Task 1.6: flipped `shadow_run` default from
True to False. HIBERNATION transitions now actually exit the daemon
process via the global shutdown event in `daemon.main()`'s lifecycle
tick. The legacy `_rss_watchdog_loop` was removed in Task 1.4; this
state machine is the sole owner of shutdown authority.

Shadow-run mode is preserved as an opt-in for testing: passing
`shadow_run=True` to `LifecycleStateMachine.__init__` keeps the old
"persist + log + emit shadow_run_warning, do NOT exit" behaviour so
the panel R7 validation script can drive transitions without
terminating the daemon process.

Single-writer enforcement (L2): a separate lock file
`~/.iai-mcp/.lifecycle.lock` carries the `fcntl.flock(LOCK_EX|LOCK_NB)`.
The data file `lifecycle_state.json` is atomically replaced via
`os.replace` (Phase 04-01 pattern), which swaps the inode, so any lock
held on the data file's fd would not protect the new file. The lock
file is never renamed, so the lock survives `save_state` cycles.
"""
from __future__ import annotations

import errno
import fcntl
import os
from contextlib import contextmanager
from datetime import datetime, timezone
from enum import Enum
from pathlib import Path
from typing import Any, Iterator

from iai_mcp.lifecycle_event_log import LifecycleEventLog
from iai_mcp.lifecycle_state import (
    LIFECYCLE_STATE_PATH,
    LifecycleState,
    LifecycleStateRecord,
    default_state,
    load_state,
    save_state,
)

# Default lock path lives next to lifecycle_state.json. Hidden so it
# does not show up in `ls`. Pattern matches `daemon-state.json` /
# `.daemon-state.json` precedent.
DEFAULT_LOCK_PATH: Path = Path.home() / ".iai-mcp" / ".lifecycle.lock"


class LifecycleStateLocked(RuntimeError):
    """Raised when another process holds the lifecycle_state.json lock.

    Per L2 the daemon is the sole authority. A wrapper that finds the
    lock held by the daemon should signal events via Unix socket
    (when daemon alive) or write `~/.iai-mcp/wake.signal` (when
    daemon hibernated) -- never bypass the lock with a direct write.
    """


class LifecycleEvent(str, Enum):
    """Events that drive transitions."""

    HEARTBEAT_REFRESH = "heartbeat_refresh"
    IDLE_5MIN = "idle_5min"
    IDLE_30MIN = "idle_30min"
    SLEEP_ELIGIBLE = "sleep_eligible"
    REQUEST_ARRIVED = "request_arrived"
    SLEEP_CYCLE_DONE = "sleep_cycle_done"
    HIBERNATION_GRACE_EXPIRED = "hibernation_grace_expired"
    WAKE_SIGNAL = "wake_signal"
    TICK = "tick"


def _utc_now_iso() -> str:
    """ISO-8601 UTC timestamp; central so tests can monkey-patch."""
    return datetime.now(timezone.utc).isoformat()


# ---------------------------------------------------------------------------
# Pure transition function -- exposed at module scope for property tests
# ---------------------------------------------------------------------------

def compute_transition(
    state: LifecycleState,
    event: LifecycleEvent,
    payload: dict[str, Any] | None = None,
) -> LifecycleState | None:
    """Return the target state, or None if `event` is a no-op for `state`.

    Pure function -- no I/O, no side effects, deterministic. The
    transition table is encoded inline here rather than a dict because
    the guard-bearing rows (`(DROWSY, IDLE_30MIN)` AND `sleep_eligible`)
    are easier to read as straight-line code than a `(state, event,
    guard) -> state` lookup with conditional fallback.

    Transition table:

    | From        | Event                           | To               |
    | WAKE        | IDLE_5MIN                       | DROWSY           |
    | DROWSY      | HEARTBEAT_REFRESH               | WAKE             |
    | DROWSY      | IDLE_30MIN AND sleep_eligible   | SLEEP            |
    | SLEEP       | REQUEST_ARRIVED                 | WAKE             |
    | SLEEP       | SLEEP_CYCLE_DONE AND still_idle | HIBERNATION      |
    | HIBERNATION | WAKE_SIGNAL                     | WAKE             |
    | *           | REQUEST_ARRIVED                 | WAKE (catch-all) |

    Catch-all: REQUEST_ARRIVED from any state goes to WAKE; that
    matches the SLEEP-specific rule above and adds DROWSY/HIBERNATION
    coverage. (HIBERNATION → WAKE on REQUEST_ARRIVED is a future-phase
    cold-start path -- a wrapper that has REQUEST_ARRIVED to dispatch
    has already woken the daemon via wake.signal first; this branch
    exists for in-process test scaffolding and defence-in-depth.)
    """
    payload = payload if payload is not None else {}

    # Catch-all REQUEST_ARRIVED → WAKE; check first so subsequent
    # branches do not need to repeat the rule per source state.
    if event is LifecycleEvent.REQUEST_ARRIVED:
        return LifecycleState.WAKE

    if state is LifecycleState.WAKE:
        if event is LifecycleEvent.IDLE_5MIN:
            return LifecycleState.DROWSY
        return None

    if state is LifecycleState.DROWSY:
        if event is LifecycleEvent.HEARTBEAT_REFRESH:
            return LifecycleState.WAKE
        if event is LifecycleEvent.IDLE_30MIN and payload.get("sleep_eligible"):
            return LifecycleState.SLEEP
        return None

    if state is LifecycleState.SLEEP:
        if event is LifecycleEvent.SLEEP_CYCLE_DONE and payload.get("still_idle"):
            return LifecycleState.HIBERNATION
        return None

    if state is LifecycleState.HIBERNATION:
        if event is LifecycleEvent.WAKE_SIGNAL:
            return LifecycleState.WAKE
        # HIBERNATION_GRACE_EXPIRED is a future-phase trigger that
        # currently has no destination -- kept as a known no-op so
        # the dispatcher does not raise on it.
        return None

    return None  # unreachable; defensive against future state additions


# ---------------------------------------------------------------------------
# File-lock context manager -- separate file per advisor recommendation
# ---------------------------------------------------------------------------

@contextmanager
def _lifecycle_lock(lock_path: Path) -> Iterator[int]:
    """Acquire `fcntl.flock(LOCK_EX | LOCK_NB)` on a sibling lock file.

    Raises `LifecycleStateLocked` if the lock is held by another
    process. The lock file persists across releases -- it is the
    "named-mutex" handle, not the data. The data file
    `lifecycle_state.json` is atomically replaced separately and
    therefore must NOT carry the lock (os.replace swaps the inode).
    """
    lock_path.parent.mkdir(parents=True, exist_ok=True)
    fd = os.open(str(lock_path), os.O_RDWR | os.O_CREAT, 0o600)
    try:
        try:
            fcntl.flock(fd, fcntl.LOCK_EX | fcntl.LOCK_NB)
        except OSError as exc:
            if exc.errno in (errno.EAGAIN, errno.EWOULDBLOCK):
                raise LifecycleStateLocked(
                    f"another process holds {lock_path}"
                ) from exc
            raise
        try:
            yield fd
        finally:
            try:
                fcntl.flock(fd, fcntl.LOCK_UN)
            except OSError:
                # Best effort -- the close below releases the lock
                # whether or not the explicit unlock succeeded.
                pass
    finally:
        os.close(fd)


# ---------------------------------------------------------------------------
# State machine class
# ---------------------------------------------------------------------------

class LifecycleStateMachine:
    """Side-effecting wrapper around `compute_transition`.

    Owns:
    - `lifecycle_state.json` reads + writes (single-writer enforced).
    - Event log emission (`state_transition`, `shadow_run_warning`).
    - `shadow_run` flag (default False since Phase 10.6; True is a transition-test escape hatch).

    Construction is cheap; the lock is acquired only inside
    `dispatch`. Tests can drive transitions either via `dispatch`
    (full pipeline) or via `compute_transition` (pure-function
    coverage).
    """

    def __init__(
        self,
        state_path: Path | None = None,
        event_log: LifecycleEventLog | None = None,
        lock_path: Path | None = None,
        shadow_run: bool = False,
    ) -> None:
        self._state_path = state_path if state_path is not None else LIFECYCLE_STATE_PATH
        self._event_log = event_log if event_log is not None else LifecycleEventLog()
        self._lock_path = lock_path if lock_path is not None else DEFAULT_LOCK_PATH
        self._shadow_run = shadow_run

    # ------------------------------------------------------------------
    # Read-only helpers
    # ------------------------------------------------------------------

    @property
    def shadow_run(self) -> bool:
        return self._shadow_run

    @property
    def current_state(self) -> LifecycleState:
        record = load_state(self._state_path)
        return LifecycleState(record["current_state"])

    def snapshot(self) -> LifecycleStateRecord:
        """Return the on-disk record (or default if absent)."""
        return load_state(self._state_path)

    # ------------------------------------------------------------------
    # Pure transition (no I/O) -- re-exposed for callers using an instance
    # ------------------------------------------------------------------

    def compute_transition(
        self,
        state: LifecycleState,
        event: LifecycleEvent,
        payload: dict[str, Any] | None = None,
    ) -> LifecycleState | None:
        return compute_transition(state, event, payload)

    # ------------------------------------------------------------------
    # Dispatcher -- single-writer, persists + logs
    # ------------------------------------------------------------------

    def dispatch(
        self,
        event: LifecycleEvent,
        **payload: Any,
    ) -> LifecycleState:
        """Apply `event` to the current state, persist, log; return new state.

        Acquires the lock for the duration of the read-compute-write
        cycle so the disk record cannot be raced by a second writer.
        Always returns the post-dispatch state -- even when the event
        was a no-op (transition target was None), the caller gets the
        unchanged current state back. That keeps callers from having
        to special-case None.
        """
        with _lifecycle_lock(self._lock_path):
            current_record = load_state(self._state_path)
            current_state = LifecycleState(current_record["current_state"])

            target = compute_transition(current_state, event, payload)

            now_iso = _utc_now_iso()
            # last_activity advances on any user-attributable event so
            # idle timers reset correctly.
            updated_record: LifecycleStateRecord = dict(current_record)  # type: ignore[assignment]
            if event in {
                LifecycleEvent.HEARTBEAT_REFRESH,
                LifecycleEvent.REQUEST_ARRIVED,
                LifecycleEvent.WAKE_SIGNAL,
            }:
                updated_record["last_activity_ts"] = now_iso
                updated_record["wrapper_event_seq"] = (
                    current_record.get("wrapper_event_seq", 0) + 1
                )

            updated_record["shadow_run"] = self._shadow_run

            if target is None:
                # No state change -- persist any incremental wrapper-event
                # bookkeeping (last_activity_ts, seq) but skip the
                # transition log line.
                if updated_record != current_record:
                    save_state(updated_record, self._state_path)
                return current_state

            # State change. Update record and persist atomically.
            updated_record["current_state"] = target.value
            updated_record["since_ts"] = now_iso
            save_state(updated_record, self._state_path)

            # Always log the transition.
            self._event_log.append(
                {
                    "event": "state_transition",
                    "from": current_state.value,
                    "to": target.value,
                    "trigger": event.value,
                }
            )

            # Shadow-run guard for HIBERNATION: the new state is
            # persisted on disk (so observers see it), and a warning
            # event documents that the process exit was suppressed
            # because shadow_run=True (test-only mode).
            if target is LifecycleState.HIBERNATION and self._shadow_run:
                self._event_log.append(
                    {
                        "event": "shadow_run_warning",
                        "would_action": "hibernate_kill_process",
                        "blocked_by": "shadow_run=True",
                        "note": (
                            "shadow_run=True is a test-only legacy guard "
                            "preserved for transition tests; production "
                            "daemons run with shadow_run=False where this "
                            "branch never fires."
                        ),
                    }
                )

            return target
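The reason `lifecycle.py` keeps its `flock` on a sibling file rather than on `lifecycle_state.json` can be checked with a small standalone sketch. It assumes a POSIX system; the temp-dir layout and the helper name `atomic_save` are hypothetical stand-ins for the project's `save_state`, which is not shown here.

```python
import fcntl
import os
import tempfile
from pathlib import Path


def atomic_save(data_path: Path, text: str) -> None:
    """Hypothetical stand-in for save_state: write a temp file, then os.replace."""
    tmp = data_path.with_suffix(".tmp")
    tmp.write_text(text)
    os.replace(tmp, data_path)  # data_path now names the temp file's inode


with tempfile.TemporaryDirectory() as d:
    root = Path(d)
    data = root / "lifecycle_state.json"
    lock = root / ".lifecycle.lock"
    data.write_text("{}")
    inode_before = data.stat().st_ino

    fd = os.open(lock, os.O_RDWR | os.O_CREAT, 0o600)
    try:
        fcntl.flock(fd, fcntl.LOCK_EX | fcntl.LOCK_NB)
        lock_inode = os.fstat(fd).st_ino
        atomic_save(data, '{"current_state": "WAKE"}')
        # The data file is a different inode after the replace ...
        data_inode_changed = data.stat().st_ino != inode_before
        # ... while the lock file on disk is still the one we hold.
        lock_inode_stable = lock.stat().st_ino == lock_inode
    finally:
        fcntl.flock(fd, fcntl.LOCK_UN)
        os.close(fd)

print(data_inode_changed, lock_inode_stable)
```

A lock taken on the data file's descriptor would stay attached to the old inode, so a second process opening the path after the replace would lock a different file and see no contention.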
231
src/iai_mcp/lifecycle_event_log.py
Normal file
@@ -0,0 +1,231 @@
"""Phase 10.1 -- JSONL event log for lifecycle state machine validation.

Per panel verdict R7, the lifecycle state machine needs an append-only
event log to validate transitions in shadow-run mode and to provide a
post-mortem trail when something misbehaves. The log is the empirical
ground truth for "did the machine compute the right state at the right
moment", separate from the live `lifecycle_state.json` snapshot.

Format: JSONL (one JSON record per line), file per UTC date, kept under
`~/.iai-mcp/logs/lifecycle-events-YYYY-MM-DD.jsonl`. Daily rotation is
keyed off the UTC date of the appended event so writes near local
midnight do not silently fragment across two files in unpredictable
timezones. 30-day retention with gzip compression for older files
matches the retention spec.

Atomic line writes: each `append` opens the file with `O_APPEND |
O_CREAT` and uses `fcntl.flock(LOCK_EX)` to serialise concurrent writers
across processes. `O_APPEND` makes every write land at the current end
of file, but the POSIX PIPE_BUF atomicity guarantee covers pipes and
FIFOs, not regular files, so the explicit lock is what keeps concurrent
lines from interleaving (a single JSONL line for our event shapes is
small, and the lock costs ~microseconds and saves us debugging on the
day a payload grows).
"""
from __future__ import annotations

import errno
import fcntl
import gzip
import json
import os
import shutil
from datetime import datetime, timedelta, timezone
from pathlib import Path
from typing import Any

# Default location. Overridable via constructor `log_dir` for tests.
DEFAULT_LOG_DIR: Path = Path.home() / ".iai-mcp" / "logs"

# Event kinds emitted by the state machine and helpers; treat as the
# closed set for now -- adding a kind requires updating downstream
# consumers (panel R7 validation script in a future phase).
KNOWN_EVENT_KINDS: frozenset[str] = frozenset(
    {
        "state_transition",
        "wrapper_event",
        "shadow_run_warning",
        "sleep_step_started",
        "sleep_step_completed",
        "quarantine_entered",
        "quarantine_lifted",
    }
)


def _utc_now() -> datetime:
    """Single point of `datetime.now(UTC)` -- patchable in tests."""
    return datetime.now(timezone.utc)


def _utc_date_string(dt: datetime | None = None) -> str:
    """Return the UTC date as `YYYY-MM-DD` for filename derivation."""
    moment = dt if dt is not None else _utc_now()
    if moment.tzinfo is None:
        moment = moment.replace(tzinfo=timezone.utc)
    return moment.astimezone(timezone.utc).strftime("%Y-%m-%d")


class LifecycleEventLog:
    """Append-only JSONL event log with daily rotation + retention.

    Public surface:
        append(event)          -- write one event line, lock + fsync.
        rotate_old_files(...)  -- gzip files older than retention.
        current_file()         -- return path to today's log file.

    Thread/process safety: a per-call `fcntl.flock` on the destination
    file makes concurrent writers (daemon, hooks) safe. The lock is
    released as soon as the bytes hit disk; we do NOT keep a long-lived
    handle, so the file can rotate / be archived between calls without
    leaving a stale fd open.
    """

    def __init__(self, log_dir: Path | None = None) -> None:
        self._log_dir = log_dir if log_dir is not None else DEFAULT_LOG_DIR
        self._log_dir.mkdir(parents=True, exist_ok=True)

    # ------------------------------------------------------------------
    # Path derivation
    # ------------------------------------------------------------------

    def file_for_date(self, date_str: str) -> Path:
        """Return the JSONL path for the given `YYYY-MM-DD` date string."""
        return self._log_dir / f"lifecycle-events-{date_str}.jsonl"

    def current_file(self, now: datetime | None = None) -> Path:
        """Return the path that `append` would write to right now."""
        return self.file_for_date(_utc_date_string(now))

    # ------------------------------------------------------------------
    # Appender
    # ------------------------------------------------------------------

    def append(self, event: dict[str, Any], now: datetime | None = None) -> None:
        """Append one event as a JSONL line; auto-rotate by UTC date.

        Adds `ts` (current UTC ISO-8601) if the caller did not pass one.
        Verifies `event["event"]` is a non-empty string but does NOT
        gate on `KNOWN_EVENT_KINDS` -- adding a new kind should not
        require a code change to the log writer.

        Concurrency: writers are serialised via `fcntl.flock(LOCK_EX)`.
        Crash mid-write: a partial line may be left on disk because we
        write with O_APPEND and no userspace buffering, but `fsync`
        keeps the *prior* lines durable. Readers MUST tolerate a
        truncated final line (trim or skip on JSON decode error).
        """
        if not isinstance(event, dict):
            raise TypeError(
                f"event must be a dict, got {type(event).__name__}"
            )
        kind = event.get("event")
        if not isinstance(kind, str) or not kind:
            raise ValueError("event['event'] must be a non-empty string")

        moment = now if now is not None else _utc_now()
        if "ts" not in event:
            # Build a shallow copy so the caller's dict stays clean.
            event = {"ts": moment.astimezone(timezone.utc).isoformat(), **event}

        line = json.dumps(event, separators=(",", ":")) + "\n"
        target = self.current_file(moment)
        target.parent.mkdir(parents=True, exist_ok=True)

        # Open with O_APPEND so writes land at EOF even under concurrent
        # write; flock for cross-process serialisation.
        fd = os.open(
            str(target),
            os.O_WRONLY | os.O_APPEND | os.O_CREAT,
            0o600,
        )
        try:
            fcntl.flock(fd, fcntl.LOCK_EX)
            try:
                os.write(fd, line.encode("utf-8"))
                os.fsync(fd)
            finally:
                fcntl.flock(fd, fcntl.LOCK_UN)
        finally:
            os.close(fd)

    # ------------------------------------------------------------------
    # Retention / rotation
    # ------------------------------------------------------------------

    def rotate_old_files(
        self,
        retention_days: int = 30,
        now: datetime | None = None,
    ) -> int:
        """Gzip log files whose UTC date is older than `retention_days`.

        Already-gzipped files (`*.jsonl.gz`) are left alone. Returns
        the number of files newly compressed in this call. Files older
        than `retention_days` that are *also* already gzipped are kept
        forever in this phase -- the spec asks for compression after
        the window, not deletion. (Deletion is a future-phase decision.)
        """
        moment = now if now is not None else _utc_now()
        cutoff_date = (moment - timedelta(days=retention_days)).date()

        compressed = 0
        for path in self._log_dir.glob("lifecycle-events-*.jsonl"):
            stem = path.stem  # lifecycle-events-YYYY-MM-DD
            try:
                date_part = stem.rsplit("-", 3)[-3:]  # ['YYYY','MM','DD']
                file_date = datetime.strptime(
                    "-".join(date_part), "%Y-%m-%d"
                ).date()
            except (ValueError, IndexError):
                # Unrecognised filename -- skip rather than guess.
                continue
            if file_date > cutoff_date:
                continue

            gz_path = path.with_suffix(".jsonl.gz")
            if gz_path.exists():
                # Idempotent: already compressed in a prior run.
                continue
            try:
                with path.open("rb") as src, gzip.open(gz_path, "wb") as dst:
                    shutil.copyfileobj(src, dst)
                # Match the 0o600 mode of the plaintext log to keep the
                # gzip archive user-only.
                os.chmod(gz_path, 0o600)
                # Remove the plaintext only after the gzip has been
                # fully written and closed.
                os.unlink(path)
                compressed += 1
            except OSError as exc:
                # Best-effort: a single broken file should not stop
                # the next iterations.
                if exc.errno in (errno.EACCES, errno.EPERM):
                    continue
                # Unknown OSError -- let the caller see it.
                raise
        return compressed

    # ------------------------------------------------------------------
    # Read helpers (non-essential but useful for tests + CLI)
    # ------------------------------------------------------------------

    def read_all(self, date_str: str | None = None) -> list[dict[str, Any]]:
        """Read all events from the file for `date_str` (or today).

        Skips truncated final lines silently -- only fully-decoded JSON
        records are returned. Returns [] if the file does not exist.
        """
        target = self.file_for_date(
            date_str if date_str is not None else _utc_date_string()
        )
        if not target.exists():
            return []
        out: list[dict[str, Any]] = []
        with target.open("r") as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                try:
                    out.append(json.loads(line))
                except json.JSONDecodeError:
                    continue
        return out
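The "readers MUST tolerate a truncated final line" contract from `append` can be exercised without the project installed. The sketch below is a standalone approximation, not the module's own code: `read_jsonl_tolerant` is a hypothetical helper and the file name only imitates the `lifecycle-events-YYYY-MM-DD.jsonl` pattern.

```python
import json
import tempfile
from pathlib import Path


def read_jsonl_tolerant(path: Path) -> list[dict]:
    """Hypothetical reader: keep only the lines that decode as JSON."""
    records = []
    with path.open("r", encoding="utf-8") as f:
        for raw in f:
            raw = raw.strip()
            if not raw:
                continue
            try:
                records.append(json.loads(raw))
            except json.JSONDecodeError:
                continue  # truncated or corrupt line: skip it
    return records


with tempfile.TemporaryDirectory() as d:
    log = Path(d) / "lifecycle-events-2025-01-01.jsonl"
    log.write_text(
        '{"event":"state_transition","from":"WAKE","to":"DROWSY"}\n'
        '{"event":"wrapper_event"}\n'
        '{"event":"state_trans'  # simulated crash in the middle of a write
    )
    events = read_jsonl_tolerant(log)

print([e["event"] for e in events])
```

Only the two complete records come back; the half-written third line is dropped rather than raising, which matches the skip-on-decode-error behaviour described for `read_all`.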
341
src/iai_mcp/lifecycle_lock.py
Normal file
@@ -0,0 +1,341 @@
"""Phase 10.6 -- single-machine ``~/.iai-mcp/.locked`` lockfile.

Realises LOCKED contract (single-machine assumption): the
daemon writes ``~/.iai-mcp/.locked`` on startup with PID + hostname +
started_at. A second daemon attempt on the same host raises
``LifecycleLockConflict``; a daemon on a different host (e.g. via
iCloud / NFS sync of ``~/.iai-mcp``) detects the foreign hostname and
takes over with a warning.

This is **distinct from** ``ProcessLock`` (Phase 04-01,
``~/.iai-mcp/.lock``): that fcntl flock guards LanceDB writers / heavy
consolidation against concurrent in-host processes. The ``.locked``
lockfile is a higher-level, human-readable singleton marker for the
lifecycle state machine (LSM); it does NOT use ``fcntl.flock`` because
single-machine is the assumption and the JSON content (PID +
hostname) is the diagnostic surface that ``iai-mcp lifecycle
force-unlock`` consumes.

Design constraints (carried from CONTEXT 10.6):

- stdlib only -- ``os``, ``socket``, ``json``, ``tempfile``, ``pathlib``,
  ``datetime``.
- POSIX-atomic write via ``tempfile.mkstemp`` + ``os.replace`` (same
  pattern as ``daemon_state.save_state`` / ``lifecycle_state.save_state``).
- 0o600 file mode -- consistent with the rest of the project's state files.
- Hostname recorded so iCloud / NFS sync of ``~/.iai-mcp`` does NOT
  produce a deadlock when the user moves to a second Mac.
- PID-liveness check uses ``os.kill(pid, 0)`` (same trick as
  ``heartbeat_scanner._is_pid_alive``).

Validates: WAKE-13.
"""
from __future__ import annotations

import json
import os
import socket
import tempfile
from datetime import datetime, timezone
from pathlib import Path
from typing import TypedDict


# ---------------------------------------------------------------------------
# Defaults / constants
# ---------------------------------------------------------------------------

def _default_lock_path() -> Path:
    """Resolve the default lockfile path, honoring ``IAI_MCP_STORE``.

    Tests + multi-tenant deployments override the iai-mcp data root via
    the ``IAI_MCP_STORE`` env var (HIGH-4 LOCK precedent, Plan 07-04).
    Falling back to ``~/.iai-mcp`` keeps the production default
    untouched.
    """
    env_path = os.environ.get("IAI_MCP_STORE")
    root = Path(env_path) if env_path else (Path.home() / ".iai-mcp")
    return root / ".locked"


# Production lock-file path. Resolved via the helper once, at import
# time, so ``IAI_MCP_STORE`` must be set before this module is imported
# for monkey-patching it in tests to redirect this constant. Tests can
# also pass an explicit ``lock_path`` argument to ``LifecycleLock``.
DEFAULT_LOCK_PATH: Path = _default_lock_path()

#: Schema version persisted alongside the payload so a future bump can
#: be detected at takeover time.
SCHEMA_VERSION: int = 1


# ---------------------------------------------------------------------------
# Errors
# ---------------------------------------------------------------------------


class LifecycleLockConflict(RuntimeError):
    """Raised when ``acquire()`` finds a live daemon on the same host.

    The exception carries the existing lockfile content as a dict so the
    caller (daemon main, ``iai-mcp lifecycle force-unlock``) can surface
    PID / started_at to the operator without a second disk read.
    """

    def __init__(self, message: str, existing: "LockPayload | None" = None) -> None:
        super().__init__(message)
        self.existing = existing


# ---------------------------------------------------------------------------
# Typed payload schema
# ---------------------------------------------------------------------------


class LockPayload(TypedDict):
    """On-disk schema for ``.locked``."""

    pid: int
    hostname: str
    started_at: str  # ISO-8601 UTC
    schema_version: int


# ---------------------------------------------------------------------------
# Module-private helpers
# ---------------------------------------------------------------------------


def _utc_now_iso() -> str:
    """Return ISO-8601 UTC timestamp -- single point so tests can patch."""
    return datetime.now(timezone.utc).isoformat()


def _current_hostname() -> str:
    """Return ``socket.gethostname()``; central so tests can monkey-patch."""
    return socket.gethostname()


def _is_pid_alive(pid: int) -> bool:
    """Return True iff ``pid`` exists in the kernel process table.

    Mirrors the discipline in ``heartbeat_scanner._is_pid_alive``:
    ``os.kill(pid, 0)`` sends no signal but raises ``ProcessLookupError``
    when the PID has been reaped. ``PermissionError`` (EPERM) means the
    process exists but we cannot signal it -- still alive for liveness
    purposes. Negative / zero PIDs are dead.
    """
    if pid <= 0:
        return False
    try:
        os.kill(pid, 0)
    except ProcessLookupError:
        return False
    except PermissionError:
        return True
    return True


def _validate_payload(raw: object) -> LockPayload:
    """Reject malformed JSON; return a typed copy on success.

    Schema check kept light -- enough to catch operator hand-edits and
    out-of-band writes from a stale schema version. We do NOT require
    ``schema_version`` to equal ``SCHEMA_VERSION``; a higher schema is
    treated as forward-compatible (the daemon refuses to overwrite it
    only if PID is alive on same host -- the conflict path).
    """
    if not isinstance(raw, dict):
        raise ValueError(
            f"lockfile payload must be a JSON object, got {type(raw).__name__}"
        )
    pid = raw.get("pid")
    if not isinstance(pid, int) or pid <= 0:
        raise ValueError(f"lockfile.pid must be a positive int, got {pid!r}")
    hostname = raw.get("hostname")
    if not isinstance(hostname, str) or not hostname:
        raise ValueError(
            f"lockfile.hostname must be a non-empty string, got {hostname!r}"
        )
    started_at = raw.get("started_at")
    if not isinstance(started_at, str) or not started_at:
        raise ValueError(
            f"lockfile.started_at must be a non-empty string, got {started_at!r}"
        )
    sv = raw.get("schema_version")
    if not isinstance(sv, int) or sv <= 0:
        raise ValueError(
            f"lockfile.schema_version must be a positive int, got {sv!r}"
        )
    return {
        "pid": pid,
        "hostname": hostname,
        "started_at": started_at,
        "schema_version": sv,
    }


# ---------------------------------------------------------------------------
# LifecycleLock
# ---------------------------------------------------------------------------


class LifecycleLock:
    """Single-machine lockfile for the lifecycle state machine.

    Construction is cheap; no I/O happens until ``acquire()`` is called.
    Tests instantiate with an explicit ``lock_path`` under ``tmp_path``
    so production state is never touched.
    """

    def __init__(self, lock_path: Path | None = None) -> None:
        # Resolve at construction time (not import time) so a test
        # that monkey-patches IAI_MCP_STORE before instantiating sees
        # the redirected path. Production callers pass no argument
        # and get the canonical ~/.iai-mcp/.locked.
        self._lock_path = (
            lock_path if lock_path is not None else _default_lock_path()
        )

    # ------------------------------------------------------------------
    # Read accessors
    # ------------------------------------------------------------------

    @property
    def lock_path(self) -> Path:
        """Filesystem location of the ``.locked`` file."""
        return self._lock_path

    def read(self) -> LockPayload | None:
        """Return the on-disk payload, or ``None`` if absent / corrupt.

        Corrupt-file behaviour is "no lock" rather than raising: an
        operator hand-edit that produces invalid JSON should not block
        a fresh daemon boot. ``acquire()`` will then overwrite the file.
        """
        if not self._lock_path.exists():
            return None
        try:
            raw = json.loads(self._lock_path.read_text(encoding="utf-8"))
        except (OSError, UnicodeDecodeError, json.JSONDecodeError):
            # UnicodeDecodeError covers files that are not valid UTF-8,
            # which is also "corrupt" for the purposes of this method.
            return None
        try:
            return _validate_payload(raw)
        except ValueError:
            return None

    def is_held_by_self(self) -> bool:
        """True iff the on-disk lockfile names this process + this host.

        Used by the daemon to short-circuit a redundant ``acquire()``.
        After a crash that bypassed the ``finally`` cleanup, the file is
        still present but the restarted daemon has a different PID, so
        this returns False and ``acquire()`` does the dead-PID takeover.
        """
        payload = self.read()
        if payload is None:
            return False
        return (
            payload["pid"] == os.getpid()
            and payload["hostname"] == _current_hostname()
        )

    # ------------------------------------------------------------------
    # Acquire / release
    # ------------------------------------------------------------------

    def acquire(self) -> None:
        """Write the lockfile, claiming the singleton slot for this process.

        Decision tree:

        1. No lockfile present -> write fresh.
        2. Lockfile present, corrupt JSON -> overwrite (treat as absent).
        3. Lockfile present, foreign hostname -> overwrite silently
           (cross-host scenario via iCloud / NFS sync; daemon on the new
           host wins because the original host's daemon cannot reach
           this filesystem).
        4. Lockfile present, same hostname, dead PID -> overwrite (the
           previous daemon crashed before releasing).
        5. Lockfile present, same hostname, live PID -> ``raise
           LifecycleLockConflict`` (a real concurrent boot attempt).

        Atomic write via ``tempfile.mkstemp`` + ``os.replace`` -- same
        pattern as ``lifecycle_state.save_state`` / ``daemon_state.save_state``.
        """
        existing = self.read()
        if existing is not None:
            # Live PID on same host -> conflict.
            if existing["hostname"] == _current_hostname() and _is_pid_alive(
                existing["pid"]
            ):
                raise LifecycleLockConflict(
                    f"daemon already running: pid={existing['pid']} "
                    f"hostname={existing['hostname']} "
                    f"started_at={existing['started_at']}",
                    existing=existing,
                )
            # Dead PID OR foreign hostname -> takeover (no error). The
            # foreign-hostname branch corresponds to the cross-host
            # iCloud / NFS sync scenario; we silently overwrite because
            # the only viable remediation is "the new host wins"
            # (the original host's daemon cannot share state with us
            # over a sync filesystem, by definition).

        payload: LockPayload = {
            "pid": os.getpid(),
            "hostname": _current_hostname(),
            "started_at": _utc_now_iso(),
            "schema_version": SCHEMA_VERSION,
        }

        self._lock_path.parent.mkdir(parents=True, exist_ok=True)
        fd, tmp = tempfile.mkstemp(
            prefix=".locked.",
            suffix=".tmp",
            dir=str(self._lock_path.parent),
        )
        try:
            with os.fdopen(fd, "w") as f:
                json.dump(payload, f, indent=2)
                f.flush()
                os.fsync(f.fileno())
            os.chmod(tmp, 0o600)
            os.replace(tmp, self._lock_path)
        except Exception:
            try:
                os.unlink(tmp)
            except OSError:
                pass
            raise

    def release(self) -> None:
        """Delete the lockfile. Idempotent -- absent file is not an error.

        Called from the daemon's graceful-shutdown ``finally`` block. A
        crash before this point leaves the file intact; the next
        ``acquire()`` will detect the dead PID and overwrite.
        """
        try:
            self._lock_path.unlink()
        except FileNotFoundError:
            return

    def force_unlock(self) -> LockPayload | None:
        """Delete the lockfile unconditionally; return the prior content.

        Operator-facing helper used by ``iai-mcp lifecycle force-unlock``
        when a daemon crashed before ``release()`` and the dead-PID
        takeover did not catch the case (e.g. the operator wants to
        clear a foreign-hostname lock without booting a daemon first).

        Returns the parsed prior payload (or ``None`` if absent /
        corrupt) so the caller can print PID / hostname / started_at
        in the diagnostic output.
        """
        previous = self.read()
        try:
            self._lock_path.unlink()
        except FileNotFoundError:
            pass
        return previous
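The five-case decision tree documented in `acquire()` collapses to two outcomes: write the lockfile, or refuse because a live daemon on this host already holds it. The sketch below restates that logic on an already-parsed payload. It assumes a POSIX host (signal 0 semantics), and the names `pid_is_alive` and `acquire_decision` are hypothetical helpers introduced only for illustration; they are not part of the module.

```python
from __future__ import annotations

import os


def pid_is_alive(pid: int) -> bool:
    """Signal-0 probe (POSIX): nothing is delivered, only existence is checked."""
    if pid <= 0:
        return False
    try:
        os.kill(pid, 0)
    except ProcessLookupError:
        return False  # the PID is gone
    except PermissionError:
        return True  # the PID exists but belongs to another user
    return True


def acquire_decision(existing: dict | None, this_host: str) -> str:
    """Return "write" or "conflict" for an already-parsed lockfile payload.

    ``existing`` is None when the file is absent or corrupt (cases 1 and 2).
    """
    if existing is None:
        return "write"
    if existing["hostname"] != this_host:
        return "write"  # case 3: foreign host, the new host wins
    if not pid_is_alive(existing["pid"]):
        return "write"  # case 4: previous daemon died without releasing
    return "conflict"  # case 5: live daemon on this host


if __name__ == "__main__":
    me = {"pid": os.getpid(), "hostname": "box-a"}
    print(acquire_decision(None, "box-a"))
    print(acquire_decision(me, "box-b"))
    print(acquire_decision(me, "box-a"))
```

Keeping the hostname comparison ahead of the PID probe matters: a PID read from a foreign host's lockfile says nothing about the local process table, so it should never be probed.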
233
src/iai_mcp/lifecycle_state.py
Normal file
@@ -0,0 +1,233 @@
"""Phase 10.1 -- typed schema + atomic load/save for lifecycle_state.json.

The 4-state lifecycle (WAKE / DROWSY / SLEEP / HIBERNATION) needs a single
source of truth on disk. Per LOCKED contract L2 (panel verdict R2), the
daemon is the ONLY writer of `~/.iai-mcp/lifecycle_state.json`; wrappers
signal events via Unix socket or by atomically writing the
`~/.iai-mcp/wake.signal` filesystem marker.

Persistence pattern mirrors `daemon_state.py` (Phase 04-01) and
`maintenance.py` (Phase 07.11-03):
- Writes via `tempfile.mkstemp` + `os.replace` (POSIX atomic rename).
- Crash mid-write leaves the prior file intact; readers either see
  the old complete blob or the new complete blob, never partial bytes.
- File mode 0o600 (user-only, matches T-04-07 mitigation).

Schema mirrors lifecycle_state.json spec.
"""
from __future__ import annotations

import json
import os
import tempfile
from datetime import datetime, timezone
from enum import Enum
from pathlib import Path
from typing import TypedDict

# Default location. Overridable for tests via the `path` arg of load/save.
LIFECYCLE_STATE_PATH: Path = Path.home() / ".iai-mcp" / "lifecycle_state.json"


class LifecycleState(str, Enum):
    """Four lifecycle states."""

    WAKE = "WAKE"
    DROWSY = "DROWSY"
    SLEEP = "SLEEP"
    HIBERNATION = "HIBERNATION"


class SleepCycleProgress(TypedDict, total=False):
    """Per-attempt progress of the multi-step sleep pipeline.

    All fields optional so the dict can be partially populated mid-cycle;
    `last_completed_step=0` and `attempt=1` represent a freshly-started cycle.
    """

    last_completed_step: int
    attempt: int
    last_error: str | None
    started_at: str  # ISO-8601 UTC


class Quarantine(TypedDict):
    """A failing sleep step can quarantine the cycle until `until_ts`."""

    until_ts: str  # ISO-8601 UTC
    reason: str
    since_ts: str  # ISO-8601 UTC


class LifecycleStateRecord(TypedDict):
    """On-disk schema for `lifecycle_state.json`.

    `sleep_cycle_progress` and `quarantine` are nullable; the rest are
    always present in a well-formed record. `shadow_run` toggles whether
    the state machine actually executes process termination on
    HIBERNATION (False post-Phase 10.6) or merely logs the would-action.
    """

    current_state: str  # one of LifecycleState values
    since_ts: str  # ISO-8601 UTC
    last_activity_ts: str  # ISO-8601 UTC
    wrapper_event_seq: int
    sleep_cycle_progress: SleepCycleProgress | None
    quarantine: Quarantine | None
    shadow_run: bool


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------

def _utc_now_iso() -> str:
    """Return ISO-8601 UTC timestamp with explicit `+00:00` suffix.

    `isoformat()` on a UTC-aware datetime emits `+00:00` rather than `Z`.
    Both forms are valid ISO-8601; downstream readers (CLI status, event
    log, Hypothesis tests) parse via `datetime.fromisoformat` which
    accepts the offset form.
    """
    return datetime.now(timezone.utc).isoformat()


def default_state() -> LifecycleStateRecord:
    """Return a fresh WAKE record with shadow_run=False (Phase 10.6 default).

    Used by `load_state` when the file is absent or malformed (self-heal),
    and by tests / callers that need a known starting point.

    Plan 10.6-01 Task 1.6 flipped the default from True to False:
    HIBERNATION transitions now actually exit the daemon process via the
    global shutdown event in `daemon.main()`. The legacy RSS-watchdog has
    been removed in Task 1.4; the lifecycle state machine owns shutdown
    authority.
    """
    now = _utc_now_iso()
    return {
        "current_state": LifecycleState.WAKE.value,
        "since_ts": now,
        "last_activity_ts": now,
        "wrapper_event_seq": 0,
        "sleep_cycle_progress": None,
        "quarantine": None,
        "shadow_run": False,
    }


def _validate_record(raw: object) -> LifecycleStateRecord:
    """Reject malformed JSON; return the same dict, typed, on success.

    A minimal schema check -- enough to catch hand-edited corruption and
    out-of-band writes from a stale schema version, without pulling in
    pydantic for runtime validation. The validated dict is returned
    as-is; no copy of the record is made.
    """
    if not isinstance(raw, dict):
        raise ValueError(
            f"lifecycle_state record must be a JSON object, got {type(raw).__name__}"
        )

    required_str_keys = ("current_state", "since_ts", "last_activity_ts")
    for k in required_str_keys:
        v = raw.get(k)
        if not isinstance(v, str) or not v:
            raise ValueError(f"lifecycle_state.{k} must be a non-empty string, got {v!r}")

    state_value = raw["current_state"]
    if state_value not in {s.value for s in LifecycleState}:
        raise ValueError(
            f"lifecycle_state.current_state {state_value!r} is not a valid LifecycleState"
        )

    seq = raw.get("wrapper_event_seq")
    if not isinstance(seq, int) or seq < 0:
        raise ValueError(
            f"lifecycle_state.wrapper_event_seq must be a non-negative int, got {seq!r}"
        )

    shadow = raw.get("shadow_run")
    if not isinstance(shadow, bool):
        raise ValueError(
            f"lifecycle_state.shadow_run must be a bool, got {shadow!r}"
        )

    progress = raw.get("sleep_cycle_progress")
    if progress is not None and not isinstance(progress, dict):
        raise ValueError(
            f"lifecycle_state.sleep_cycle_progress must be dict or null, got {progress!r}"
        )

    quarantine = raw.get("quarantine")
    if quarantine is not None:
        if not isinstance(quarantine, dict):
            raise ValueError(
                f"lifecycle_state.quarantine must be dict or null, got {quarantine!r}"
            )
        for k in ("until_ts", "reason", "since_ts"):
            if not isinstance(quarantine.get(k), str):
                raise ValueError(
                    f"lifecycle_state.quarantine.{k} must be string"
                )

    # Cast is safe after the checks above; mypy/pylance accept the dict.
    return raw  # type: ignore[return-value]


def load_state(path: Path | None = None) -> LifecycleStateRecord:
    """Read `lifecycle_state.json`; return `default_state()` if absent.

    On a read, decode, or schema-validation error this also returns a
    fresh default state. The unreadable file is left in place (no
    auto-delete) so an operator can inspect it; `save_state` will
    overwrite it on the next persist.
    """
    target = path if path is not None else LIFECYCLE_STATE_PATH
    if not target.exists():
        return default_state()
    try:
        raw = json.loads(target.read_text(encoding="utf-8"))
    except (OSError, UnicodeDecodeError, json.JSONDecodeError):
        # UnicodeDecodeError: bytes that are not valid UTF-8 count as
        # a malformed file and take the same self-heal path.
        return default_state()
    try:
        return _validate_record(raw)
    except ValueError:
        return default_state()


def save_state(record: LifecycleStateRecord, path: Path | None = None) -> None:
    """Atomically persist `record` via tempfile + os.replace.

    Mirrors `daemon_state.save_state` (Phase 04-01) bullet-for-bullet:
    creates parent dir if missing; writes to a sibling temp file in the
    same directory (required so os.replace is an atomic same-filesystem
    rename); fsyncs the file contents before rename so the data is on
    disk; chmods 0o600 before the swap so the visible file is never
    world-readable; on exception unlinks the temp file so stale temp
    files do not accumulate next to the target.
    """
    target = path if path is not None else LIFECYCLE_STATE_PATH
    # Validate before writing so callers get an early ValueError on
    # malformed records rather than persisting garbage to disk.
    _validate_record(record)

    target.parent.mkdir(parents=True, exist_ok=True)
    fd, tmp = tempfile.mkstemp(
        prefix=".lifecycle_state.",
        suffix=".tmp",
        dir=str(target.parent),
    )
    try:
        with os.fdopen(fd, "w") as f:
            json.dump(record, f, indent=2)
            f.flush()
            os.fsync(f.fileno())
        os.chmod(tmp, 0o600)
        os.replace(tmp, target)
    except Exception:
        try:
            os.unlink(tmp)
        except OSError:
            pass
        raise
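The crash-safety claim behind `save_state` (a reader sees either the old complete file or the new complete file) can be checked in isolation. The sketch below reproduces the same `tempfile.mkstemp` + `os.replace` pattern with a deliberately failing write. `atomic_write_json` and `load_or_default` are hypothetical stand-ins written for this example, not the module's own functions.

```python
from __future__ import annotations

import json
import os
import tempfile
from pathlib import Path


def atomic_write_json(target: Path, obj: dict) -> None:
    """Write ``obj`` to ``target`` so readers never observe a partial file."""
    target.parent.mkdir(parents=True, exist_ok=True)
    # The temp file must live in the target's directory: os.replace is only
    # an atomic rename when source and destination share a filesystem.
    fd, tmp = tempfile.mkstemp(prefix=".tmp.", suffix=".json", dir=str(target.parent))
    try:
        with os.fdopen(fd, "w") as f:
            json.dump(obj, f, indent=2)
            f.flush()
            os.fsync(f.fileno())
        os.chmod(tmp, 0o600)
        os.replace(tmp, target)
    except Exception:
        try:
            os.unlink(tmp)  # do not leave the half-written temp file behind
        except OSError:
            pass
        raise


def load_or_default(target: Path, default: dict) -> dict:
    """Self-healing read: any read or parse failure yields a copy of ``default``."""
    try:
        raw = json.loads(target.read_text(encoding="utf-8"))
    except (OSError, ValueError):  # ValueError covers JSON and UTF-8 decode errors
        return dict(default)
    return raw if isinstance(raw, dict) else dict(default)


if __name__ == "__main__":
    state_file = Path(tempfile.mkdtemp()) / "state.json"
    atomic_write_json(state_file, {"current_state": "SLEEP"})
    try:
        atomic_write_json(state_file, {"bad": object()})  # not JSON-serialisable
    except TypeError:
        pass
    print(load_or_default(state_file, {"current_state": "WAKE"}))
```

The failed second write raises inside `json.dump`, after part of the temp file has already been written; the target is untouched because the rename never runs.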
179
src/iai_mcp/maintenance.py
Normal file
@@ -0,0 +1,179 @@
"""Periodic Lance storage maintenance.

Forensic trigger (2026-04-27): the daemon was running 248% CPU sustained for
1h14min because `records.lance` had grown to 10,841 versions / 3.66 GB for
only 7,130 rows over 9 days. There had never been a `table.optimize()` call
site in production code. Offline `optimize(cleanup_older_than=timedelta(days=1))`
reclaimed 84% disk and dropped `build_runtime_graph` cold latency 13.3s ->
0.13s (102x). This module codifies that fix as a daemon-managed periodic job
so version manifests + soft-deleted rows do not re-accumulate.

Architecture:
- D7.3-01: periodic + startup, NOT write-triggered (post-write hook would
  amplify write latency unboundedly).
- D7.3-02: single-process inside the daemon (no worker process).
- D7.3-03: helper is SYNC; callers wrap in `asyncio.to_thread`. Phase 7.2's
  AST fence (tests/test_no_bare_sync_in_async.py) enforces this discipline
  via `BLOCKING_NAMES` (D7.3-26).
- D7.3-09: helper NEVER raises. Per-table failures captured in the per-table
  dict's `error` field. The daemon must not die from an optimize failure.
- D7.3-13/D7.3-21: 1-day default retention matches Lance docs FAQ.

Two env overrides (read once at import per D7.3-22):
- IAI_MCP_LANCE_OPTIMIZE_INTERVAL_SEC (default 3600s = 1h cadence)
- IAI_MCP_LANCE_OPTIMIZE_RETENTION_SEC (default 86400s = 1 day)
"""
from __future__ import annotations

import os
import time
from datetime import timedelta
from pathlib import Path
from typing import Any

# D7.3-20: 1-hour periodic cadence (12x the cascade-poll cadence; same order
# of magnitude as the maintenance work itself; far longer than typical session
# length so optimize rarely interferes; short enough that bloat stays bounded).
LANCE_OPTIMIZE_INTERVAL_SEC: float = float(
    os.environ.get("IAI_MCP_LANCE_OPTIMIZE_INTERVAL_SEC", "3600.0"),
)

# D7.3-21: 1-day retention matches Lance's documented `cleanup_older_than`
# example. Aggressive enough to free disk fast; conservative enough for
# point-in-time time-travel reads within the same day.
LANCE_OPTIMIZE_RETENTION_SEC: float = float(
    os.environ.get("IAI_MCP_LANCE_OPTIMIZE_RETENTION_SEC", "86400.0"),
)

# Daemon-owned tables; matches src/iai_mcp/store.py constants
# (RECORDS_TABLE/EDGES_TABLE/EVENTS_TABLE) but kept literal so this module
# does not pull MemoryStore at import time.
_TABLES_TO_OPTIMIZE: tuple[str, ...] = ("records", "edges", "events")


def _measure_table_size_bytes(store: Any, table_name: str) -> int:
    """Sum the size of every file under <storage_root>/lancedb/<table>.lance/.

    Returns 0 on any measurement failure so size metrics are best-effort:
    a measurement failure must NOT cause the helper itself to raise. The
    actual `tbl.optimize()` call is independent -- disk-size telemetry is
    purely observational and exists for the operator-facing event payload.
    """
    try:
        # MemoryStore.root is the user-supplied (or env-derived) storage
        # root; the LanceDB connection lives at root/lancedb (see store.py
        # line 202). Each table is a `<name>.lance` directory underneath.
        root = getattr(store, "root", None)
        if root is None:
            return 0
        table_dir = Path(root) / "lancedb" / f"{table_name}.lance"
        if not table_dir.exists():
            return 0
        total = 0
        for p in table_dir.rglob("*"):
            try:
                if p.is_file():
                    total += p.stat().st_size
            except OSError:
                # File could be unlinked mid-scan during an active optimize;
                # skip it, keep counting the rest.
                continue
        return total
    except Exception:
        return 0


def optimize_lance_storage(
    store: Any,
    *,
    retention: timedelta | None = None,
) -> dict[str, dict[str, Any]]:
    """Run `tbl.optimize(cleanup_older_than=retention)` on each daemon-owned
    LanceDB table (records, edges, events).

    Args:
        store: MemoryStore-shaped object exposing `.db` (lancedb.Connection).
            Duck-typed so test fixtures can pass a stub. The function only
            reads `store.db` and `store.root` (latter optional for size
            telemetry).
        retention: timedelta passed to LanceDB's `cleanup_older_than`. If
            None, defaults to `timedelta(seconds=LANCE_OPTIMIZE_RETENTION_SEC)`
            which is 1 day in production.

    Returns:
        Flat dict keyed by table name (`records`, `edges`, `events`). Each
        value is a per-table dict::

            {
                "rows_before": int,        # tbl.count_rows() pre-optimize
                "rows_after": int,         # tbl.count_rows() post-optimize
                "versions_before": int,    # len(tbl.list_versions()) pre
                "versions_after": int,     # len(tbl.list_versions()) post
                "size_bytes_before": int,  # du -sb on .lance/ pre, 0 on err
                "size_bytes_after": int,   # du -sb on .lance/ post, 0 on err
                "elapsed_sec": float,      # wall-clock for optimize()
                "error": str,              # ONLY present on failure
            }

    Per D7.3-09: this helper NEVER raises. Per-table failure captured in
    the table's `error` field; the other tables are still processed.
    """
    if retention is None:
        retention = timedelta(seconds=LANCE_OPTIMIZE_RETENTION_SEC)

    report: dict[str, dict[str, Any]] = {}
    db = getattr(store, "db", None)

    for table_name in _TABLES_TO_OPTIMIZE:
        per_table: dict[str, Any] = {
            "rows_before": 0,
            "rows_after": 0,
            "versions_before": 0,
            "versions_after": 0,
            "size_bytes_before": 0,
            "size_bytes_after": 0,
            "elapsed_sec": 0.0,
        }
        try:
            if db is None:
                raise RuntimeError("store has no .db attribute")
            tbl = db.open_table(table_name)
            try:
                per_table["rows_before"] = int(tbl.count_rows())
            except Exception:
                per_table["rows_before"] = 0
            try:
                per_table["versions_before"] = len(tbl.list_versions())
            except Exception:
                per_table["versions_before"] = 0
            per_table["size_bytes_before"] = _measure_table_size_bytes(
                store, table_name,
            )

            t0 = time.monotonic()
            tbl.optimize(cleanup_older_than=retention)
            per_table["elapsed_sec"] = round(time.monotonic() - t0, 3)

            # Re-open the table after optimize: some LanceDB versions return
            # cached metadata on the original handle until refresh.
            try:
                tbl_after = db.open_table(table_name)
            except Exception:
                tbl_after = tbl
            try:
                per_table["rows_after"] = int(tbl_after.count_rows())
            except Exception:
                per_table["rows_after"] = per_table["rows_before"]
            try:
                per_table["versions_after"] = len(tbl_after.list_versions())
            except Exception:
                per_table["versions_after"] = per_table["versions_before"]
            per_table["size_bytes_after"] = _measure_table_size_bytes(
                store, table_name,
            )
        except Exception as exc:  # noqa: BLE001 -- helper MUST NOT raise (D7.3-09)
            per_table["error"] = str(exc)[:500]

        report[table_name] = per_table

    return report
1979
src/iai_mcp/migrate.py
Normal file
File diff suppressed because it is too large
1429
src/iai_mcp/pipeline.py
Normal file
File diff suppressed because it is too large
634
src/iai_mcp/profile.py
Normal file
@@ -0,0 +1,634 @@
"""11-knob profile registry (D-11 + wake_depth, Plan 07.12-02 removals).

Plan 02-03 activated the Phase-2 autistic-kernel knobs. A later plan flipped
AUTIST-13 camouflaging_relaxation to live. Phase 5 (D5-06) appended the sealed
operator-facing knob `wake_depth`, which selects session-start payload size
(minimal = <=30 raw tok lazy handle; standard = Phase-1 1388 tok eager dump;
deep = <=2000 tok expanded rich_club). Plan 07.12-02 REMOVED 4 dead KnobSpec
entries (AUTIST-02 sensory_channel_weights, AUTIST-08 event_vs_time_cue,
AUTIST-11 alexithymia_accommodation, double_empathy); none was
read in any production scoring/response path. double_empathy was promoted
to a passive system invariant in CLAUDE.md; event_vs_time_cue was documented
as a deferred future capability.

Registry shape:
- 10 live autistic-kernel knobs (AUTIST-01,03,04,05,06,07,09,10,13,14)
- 1 live Phase-5 operator knob (MCP-12 wake_depth, default "minimal")
- 0 deferred

The registry is a module-level dict of frozen dataclasses, so that:
1. `assert len(PROFILE_KNOBS) == 11` holds at import time
2. test_profile.py can grep exact knob names in order
3. Session-start assembler reads the live subset in O(1)

Schema validation covers:
- `enum:a|b|c` -- value must be exactly one of the listed tokens
- `bool` -- isinstance(value, bool)
- `int_range:lo..hi` -- integer in [lo, hi] inclusive
- `float_range:lo..hi` -- float in [lo, hi] inclusive
- `dict:<keytype>:<valuetype>` -- per-key recursive validation
  (e.g. `dict:str:float_range:0.0..1.0`)
- anything else -- reject (typo guard)

Plan 02-03 runtime-gain mechanism exposed via two helpers:
- bayesian_update: weighted ensemble posterior update
- profile_modulation_for_record: per-record edge-weight gain dict
"""
from __future__ import annotations

from dataclasses import dataclass
from typing import Any


# --------------------------------------------------------------------- schema
@dataclass(frozen=True)
class KnobSpec:
    """Static spec for one autistic-kernel knob."""

    name: str
    phase: int  # 1 | 2 | 3
    default: Any  # Phase-1 default, or Phase-2/3 placeholder default
    description: str
    value_schema: str  # "enum:a|b|c" | "bool" | "int_range:0..5" | "float_range:0.0..1.0"
    requirement_id: str  # AUTIST-01..14


# ------------------------------------------------------------------ registry
# 11 sealed knobs: 10 autistic-kernel + wake_depth
# (Plan 07.12-02 removed sensory_channel_weights, AUTIST-08
# event_vs_time_cue, alexithymia_accommodation, double_empathy).
# Plan 02-03 flipped 9 Phase-2 knobs to phase=1.
# A later plan flipped camouflaging_relaxation to phase=1.
# Phase 5 (D5-06) appended wake_depth (MCP-12, operator-facing).
PROFILE_KNOBS: dict[str, KnobSpec] = {
    "monotropism_depth": KnobSpec(
        "monotropism_depth",
        1,
        {},  # per-domain dict; empty default (unknown domains -> no gain)
        "Monotropism depth per domain (voluntary tunnel; HIPPEA precision)",
        "dict:str:float_range:0.0..1.0",
        "AUTIST-01",
    ),
    "dunn_quadrant": KnobSpec(
        "dunn_quadrant",
        1,
        "neutral",
        "Sensory threshold x regulation posture (Dunn four-quadrant; "
        "drives HIPPEA precision weighting at runtime)",
        "enum:neutral|low-registration|seeking|sensitive|avoiding",
        "AUTIST-03",
    ),
    "literal_preservation": KnobSpec(
        "literal_preservation",
        1,
        "strong",
        "Verbatim vs semantic summary (raw always retained)",
        "enum:strong|medium|loose",
        "AUTIST-04",
    ),
    "demand_avoidance_tolerance": KnobSpec(
        "demand_avoidance_tolerance",
        1,
        "collaborative",
        "PDA-aware collaborative phrasing vs imperative",
        "enum:collaborative|neutral|imperative",
        "AUTIST-05",
    ),
    "masking_off": KnobSpec(
        "masking_off",
        1,
        True,
        "No small-talk, no performative empathy, literal pragmatics",
        "bool",
        "AUTIST-06",
    ),
    "task_support": KnobSpec(
        "task_support",
        1,
        "cued_recognition",
        "Blank-recall vs cued-recognition with adjacent suggestions (Bowler)",
        "enum:blank_recall|cued_recognition",
        "AUTIST-07",
    ),
    "interest_boost": KnobSpec(
        "interest_boost",
        1,
        0.0,
        "Salience amplification adjacent to monotropism domains",
        "float_range:0.0..1.0",
        "AUTIST-09",
    ),
    "inertia_awareness": KnobSpec(
        "inertia_awareness",
        1,
        False,
        "Ambient passive capture in high-inertia windows",
        "bool",
        "AUTIST-10",
    ),
    "camouflaging_relaxation": KnobSpec(
        "camouflaging_relaxation",
        1,
        0.0,
        "Detect over-formal writing, gradually relax formality (Phase 1 live)",
        "float_range:0.0..1.0",
        "AUTIST-13",
    ),
    "scene_construction_scaffold": KnobSpec(
        "scene_construction_scaffold",
        1,
        True,
        "Scene-construction scaffold intensity for episodic encoding",
        "bool",
        "AUTIST-14",
    ),
    # D5-06: originally the 15th sealed knob, now the 11th after the Plan
    # 07.12-02 removals (operator-facing, not autistic-kernel).
    # wake_depth drives session-start payload size. minimal (default) = ≤30 raw
    # tok pointer handle (lazy; brain stays server-side); standard = Phase-1
    # 1388 tok eager dump (back-compat per D5-10); deep = ≤2000 tok expanded
    # rich_club. Set via existing profile_get_set tool; no new MCP surface.
    "wake_depth": KnobSpec(
        "wake_depth",
        1,  # phase: live (counts toward PHASE_1_LIVE)
        "minimal",
        (
            "Session-start payload size: minimal=<=30 raw (lazy, default), "
            "standard=Phase-1 eager (back-compat), deep=<=2000 (full)"
        ),
        "enum:minimal|standard|deep",
        "MCP-12",
    ),
}


PHASE_1_LIVE: frozenset[str] = frozenset(
|
||||
{name for name, spec in PROFILE_KNOBS.items() if spec.phase == 1}
|
||||
)
|
||||
PHASE_2_DEFERRED: frozenset[str] = frozenset(
|
||||
{name for name, spec in PROFILE_KNOBS.items() if spec.phase == 2}
|
||||
)
|
||||
PHASE_3_DEFERRED: frozenset[str] = frozenset(
|
||||
{name for name, spec in PROFILE_KNOBS.items() if spec.phase == 3}
|
||||
)
|
||||
|
||||
|
||||
# Plan 07.12-02: 11-knob shape is load-bearing. Enforced at import time.
|
||||
# History:
|
||||
# - flipped the 9 Phase-2 knobs to phase=1 (PHASE_1_LIVE=13).
|
||||
# - FLIPPED camouflaging_relaxation to phase=1 (PHASE_1_LIVE=14).
|
||||
# - APPENDS wake_depth as the 15th sealed knob (PHASE_1_LIVE=15).
|
||||
# - Plan 07.12-02 REMOVES 4 dead KnobSpec entries (AUTIST-02 sensory,
|
||||
# event_vs_time_cue, alexithymia, double_empathy).
|
||||
# Final shape: 10 AUTIST + 1 wake_depth = 11 sealed knobs.
|
||||
assert len(PROFILE_KNOBS) == 11, (
|
||||
"Plan 07.12-02: 10 autistic-kernel knobs + wake_depth = 11 sealed entries"
|
||||
)
|
||||
assert len(PHASE_1_LIVE) == 11, (
|
||||
"Plan 07.12-02: 10 autistic-kernel knobs + wake_depth are live"
|
||||
)
|
||||
assert len(PHASE_2_DEFERRED) == 0, "Plan 02-03 empties PHASE_2_DEFERRED"
|
||||
assert len(PHASE_3_DEFERRED) == 0, "PHASE_3_DEFERRED emptied"
|
||||
|
||||
|
||||
# Bayesian signal weights (Plan 02-03 LEARN-01)
|
||||
SIGNAL_WEIGHT: dict[str, float] = {
|
||||
"implicit": 0.3,
|
||||
"inferred": 0.5,
|
||||
"explicit": 1.0,
|
||||
}
|
||||
|
||||
|
||||
# profile sentinel UUID -- target node for every profile_modulates edge.
|
||||
# Deterministic so the edges table can be scanned without a side table. The
|
||||
# UUID is a fixed near-zero sentinel, so no generated record ID collides with it.
|
||||
PROFILE_SENTINEL_UUID_STR = "00000000-0000-0000-0000-0000000000f1"
|
||||
|
||||
|
||||
# --------------------------------------------------------------------- state
|
||||
def default_state() -> dict[str, Any]:
|
||||
"""Initial per-process state: the live knobs with defaults.
|
||||
|
||||
Deferred knobs do not appear in state because profile_set rejects them;
|
||||
profile_get on a deferred knob returns status/phase/requirement_id directly
|
||||
from the registry.
|
||||
"""
|
||||
return {
|
||||
name: spec.default
|
||||
for name, spec in PROFILE_KNOBS.items()
|
||||
if spec.phase == 1
|
||||
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------- validation
|
||||
def _validate(schema: str, value: Any) -> tuple[bool, str]:
|
||||
"""Return (ok, reason). Reason empty on success.
|
||||
|
||||
Also supports `dict:<keytype>:<valuetype>` schemas
|
||||
via recursive per-key validation. Unknown schemas (typos) are rejected.
|
||||
"""
|
||||
if schema == "bool":
|
||||
# Note: `isinstance(True, int)` is True in Python, so check bool first.
|
||||
if isinstance(value, bool):
|
||||
return True, ""
|
||||
return False, f"value must be bool, got {type(value).__name__}"
|
||||
|
||||
if schema.startswith("enum:"):
|
||||
allowed = schema[len("enum:"):].split("|")
|
||||
if value in allowed:
|
||||
return True, ""
|
||||
return False, f"value {value!r} not in enum {allowed}"
|
||||
|
||||
if schema.startswith("int_range:"):
|
||||
bounds = schema[len("int_range:"):]
|
||||
try:
|
||||
lo_s, hi_s = bounds.split("..")
|
||||
lo, hi = int(lo_s), int(hi_s)
|
||||
except (ValueError, TypeError):
|
||||
return False, f"malformed int_range schema {schema!r}"
|
||||
if isinstance(value, bool):
|
||||
return False, "value must be int, got bool"
|
||||
if not isinstance(value, int):
|
||||
return False, f"value must be int, got {type(value).__name__}"
|
||||
if value < lo or value > hi:
|
||||
return False, f"value {value} out of range [{lo}, {hi}]"
|
||||
return True, ""
|
||||
|
||||
if schema.startswith("float_range:"):
|
||||
bounds = schema[len("float_range:"):]
|
||||
try:
|
||||
lo_s, hi_s = bounds.split("..")
|
||||
lo, hi = float(lo_s), float(hi_s)
|
||||
except (ValueError, TypeError):
|
||||
return False, f"malformed float_range schema {schema!r}"
|
||||
if isinstance(value, bool):
|
||||
return False, "value must be float, got bool"
|
||||
if not isinstance(value, (int, float)):
|
||||
return False, f"value must be float, got {type(value).__name__}"
|
||||
v = float(value)
|
||||
if v < lo or v > hi:
|
||||
return False, f"value {v} out of range [{lo}, {hi}]"
|
||||
return True, ""
|
||||
|
||||
if schema.startswith("dict:"):
|
||||
body = schema[len("dict:"):]
|
||||
key_type, _, val_type = body.partition(":")
|
||||
if not val_type:
|
||||
return False, f"malformed dict schema {schema!r}"
|
||||
if not isinstance(value, dict):
|
||||
return False, f"value must be dict, got {type(value).__name__}"
|
||||
for k, v in value.items():
|
||||
if key_type == "str" and not isinstance(k, str):
|
||||
return False, f"dict key must be str, got {type(k).__name__}"
|
||||
ok, reason = _validate(val_type, v)
|
||||
if not ok:
|
||||
return False, f"in key {k!r}: {reason}"
|
||||
return True, ""
|
||||
|
||||
# Unknown schema -> reject (covers accidental typos in KnobSpec.value_schema).
|
||||
return False, f"unknown value_schema {schema!r}"
|
||||
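The range schemas handled above follow a small `kind:lo..hi` string convention. As a minimal, standalone sketch of that convention (the helper names `parse_float_range` and `in_float_range` are hypothetical and not part of this module), the parsing step and the bool-before-number check could look like this:

```python
def parse_float_range(schema: str) -> tuple[float, float]:
    """Split 'float_range:lo..hi' into (lo, hi); raise ValueError if malformed."""
    prefix = "float_range:"
    if not schema.startswith(prefix):
        raise ValueError(f"not a float_range schema: {schema!r}")
    lo_s, sep, hi_s = schema[len(prefix):].partition("..")
    if not sep:
        raise ValueError(f"malformed float_range schema: {schema!r}")
    return float(lo_s), float(hi_s)


def in_float_range(schema: str, value: object) -> bool:
    """True when value is a real number (not a bool) inside the schema bounds."""
    # bool is a subclass of int in Python, so it has to be rejected first.
    if isinstance(value, bool) or not isinstance(value, (int, float)):
        return False
    lo, hi = parse_float_range(schema)
    return lo <= float(value) <= hi


bounds = parse_float_range("float_range:0.0..1.0")
```

The sketch returns a plain bool instead of the `(ok, reason)` tuple used by `_validate`; that simplification is only for brevity.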
|
||||
|
||||
# ------------------------------------------------------------- public surface
|
||||
def profile_get(knob: str | None, state: dict[str, Any]) -> dict:
|
||||
"""Read a knob (or the full registry surface).
|
||||
|
||||
- knob=None -> full registry: {live: {11}, deferred: {0}, total_knobs: 11}.
|
||||
- knob in PHASE_1_LIVE -> {"knob": n, "value": state[n]}.
|
||||
- knob in deferred (P2/P3) -> status/phase/requirement_id payload.
|
||||
- unknown knob -> {"knob": n, "status": "unknown"}.
|
||||
|
||||
Plan 07.12-02: total_knobs is 11 (10 AUTIST + wake_depth) after AUTIST-02/08/11/12 removal.
|
||||
"""
|
||||
if knob is None:
|
||||
live = {
|
||||
n: state.get(n, PROFILE_KNOBS[n].default)
|
||||
for n in sorted(PHASE_1_LIVE)
|
||||
}
|
||||
deferred = {}
|
||||
for n in sorted(PHASE_2_DEFERRED | PHASE_3_DEFERRED):
|
||||
spec = PROFILE_KNOBS[n]
|
||||
deferred[n] = {
|
||||
"status": "not-yet-implemented",
|
||||
"phase": spec.phase,
|
||||
"requirement_id": spec.requirement_id,
|
||||
"description": spec.description,
|
||||
}
|
||||
return {"live": live, "deferred": deferred, "total_knobs": 11}
|
||||
|
||||
if knob in PHASE_1_LIVE:
|
||||
spec = PROFILE_KNOBS[knob]
|
||||
return {"knob": knob, "value": state.get(knob, spec.default)}
|
||||
|
||||
if knob in PROFILE_KNOBS:
|
||||
spec = PROFILE_KNOBS[knob]
|
||||
return {
|
||||
"knob": knob,
|
||||
"status": "not-yet-implemented",
|
||||
"phase": spec.phase,
|
||||
"requirement_id": spec.requirement_id,
|
||||
}
|
||||
|
||||
return {"knob": knob, "status": "unknown"}
|
||||
|
||||
|
||||
def profile_set(
|
||||
knob: str,
|
||||
value: Any,
|
||||
state: dict[str, Any],
|
||||
*,
|
||||
store: "object | None" = None,
|
||||
) -> dict:
|
||||
"""Write a live knob. Rejects unknown/deferred/invalid-value writes.
|
||||
|
||||
Rule priority:
|
||||
1. unknown knob -> {"status": "error", "reason": "unknown knob"}
|
||||
2. Phase-2 knob -> {"status": "error", "reason": "deferred to Phase 2"}
|
||||
(Plan 02-03 empties this set but the branch is retained for safety.)
|
||||
3. Phase-3 knob -> {"status": "error", "reason": "deferred to Phase 3"}
|
||||
4. schema fail -> {"status": "error", "reason": <validator message>}
|
||||
5. success -> mutates state; returns {"status": "ok", knob, value}
|
||||
|
||||
M4 LIVE prerequisite: when ``store`` is provided AND the
|
||||
write actually changes the value, emit ``kind='profile_updated'`` so
|
||||
M4 profile-variance can be computed live. No-op writes (old == new) do
|
||||
NOT emit (avoid event flood). The ``store`` kwarg is optional so old
|
||||
callers (e.g. core.dispatch profile_set branch) keep working unchanged.
|
||||
"""
|
||||
if knob not in PROFILE_KNOBS:
|
||||
return {"status": "error", "reason": "unknown knob", "knob": knob}
|
||||
|
||||
spec = PROFILE_KNOBS[knob]
|
||||
if spec.phase == 2:
|
||||
return {
|
||||
"status": "error",
|
||||
"reason": "deferred to Phase 2",
|
||||
"knob": knob,
|
||||
"requirement_id": spec.requirement_id,
|
||||
}
|
||||
if spec.phase == 3:
|
||||
return {
|
||||
"status": "error",
|
||||
"reason": "deferred to Phase 3",
|
||||
"knob": knob,
|
||||
"requirement_id": spec.requirement_id,
|
||||
}
|
||||
|
||||
ok, reason = _validate(spec.value_schema, value)
|
||||
if not ok:
|
||||
return {
|
||||
"status": "error",
|
||||
"reason": reason,
|
||||
"knob": knob,
|
||||
"schema": spec.value_schema,
|
||||
}
|
||||
|
||||
old_value = state.get(knob, spec.default)
|
||||
state[knob] = value
|
||||
|
||||
# M4 LIVE: emit only on actual change to avoid no-op flood.
|
||||
if store is not None and old_value != value:
|
||||
try:
|
||||
from datetime import datetime, timezone
|
||||
from iai_mcp.events import write_event
|
||||
write_event(
|
||||
store,
|
||||
kind="profile_updated",
|
||||
data={
|
||||
"knob": knob,
|
||||
"old": old_value,
|
||||
"new": value,
|
||||
"requirement_id": spec.requirement_id,
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
},
|
||||
severity="info",
|
||||
)
|
||||
except Exception:
|
||||
# Diagnostic only: never block the profile_set on emit failure.
|
||||
pass
|
||||
|
||||
return {"status": "ok", "knob": knob, "value": value}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------- Bayesian
|
||||
|
||||
|
||||
def bayesian_update(
|
||||
knob: str,
|
||||
signal: str,
|
||||
observed: Any,
|
||||
state: dict,
|
||||
posterior: dict,
|
||||
) -> tuple[Any, dict]:
|
||||
"""D-20 weighted-ensemble posterior update on a knob value.
|
||||
|
||||
Conjugate-prior form per schema type:
|
||||
- bool -> Beta(alpha, beta); alpha += w*obs, beta += w*(1-obs)
|
||||
New value is True when alpha >= beta (ties resolve to True).
|
||||
- enum -> Dirichlet(alphas); alphas[obs] += w
|
||||
New value is argmax(alphas).
|
||||
- float_range -> Normal mean via weighted running average
|
||||
- int_range -> rounded weighted running average
|
||||
- dict:... -> per-key update of float_range values (observed must also be a dict)
|
||||
|
||||
Returns (new_value, new_posterior). `posterior` is a dict keyed by knob
|
||||
name with an internal per-knob sub-dict carrying alpha/beta, alphas, or weighted_sum/total_weight/mean (plus per_key for dict schemas).
|
||||
"""
|
||||
w = SIGNAL_WEIGHT.get(signal, 0.0)
|
||||
if w == 0.0:
|
||||
return state.get(knob, PROFILE_KNOBS[knob].default if knob in PROFILE_KNOBS else None), posterior
|
||||
|
||||
spec = PROFILE_KNOBS.get(knob)
|
||||
if spec is None:
|
||||
return state.get(knob), posterior
|
||||
|
||||
sch = spec.value_schema
|
||||
p = dict(posterior)
|
||||
kp = dict(p.get(knob, {}))
|
||||
|
||||
current = state.get(knob, spec.default)
|
||||
|
||||
if sch == "bool":
|
||||
alpha = float(kp.get("alpha", 1.0))
|
||||
beta = float(kp.get("beta", 1.0))
|
||||
if observed is True:
|
||||
alpha += w
|
||||
elif observed is False:
|
||||
beta += w
|
||||
else:
|
||||
# Invalid observation for bool; degrade silently.
|
||||
return current, p
|
||||
kp["alpha"] = alpha
|
||||
kp["beta"] = beta
|
||||
new_value = alpha >= beta
|
||||
elif sch.startswith("enum:"):
|
||||
allowed = sch[len("enum:"):].split("|")
|
||||
alphas: dict[str, float] = dict(kp.get("alphas", {}))
|
||||
if observed not in allowed:
|
||||
return current, p
|
||||
alphas[observed] = alphas.get(observed, 1.0) + w
|
||||
kp["alphas"] = alphas
|
||||
# Seed with current as implicit prior boost if no entries yet.
|
||||
if current in allowed and current not in alphas:
|
||||
alphas[current] = alphas.get(current, 1.0) + 0.001
|
||||
new_value = max(alphas.keys(), key=lambda k: alphas[k])
|
||||
elif sch.startswith("float_range:"):
|
||||
# Weighted running mean.
|
||||
try:
|
||||
obs_f = float(observed)
|
||||
except (TypeError, ValueError):
|
||||
return current, p
|
||||
prev_sum = float(kp.get("weighted_sum", float(current) if isinstance(current, (int, float)) else 0.0))
|
||||
prev_wts = float(kp.get("total_weight", 0.0))
|
||||
new_sum = prev_sum + w * obs_f
|
||||
new_wts = prev_wts + w
|
||||
mean = new_sum / new_wts if new_wts > 0 else obs_f
|
||||
# Clamp to the schema range.
|
||||
bounds = sch[len("float_range:"):]
|
||||
lo_s, hi_s = bounds.split("..")
|
||||
lo, hi = float(lo_s), float(hi_s)
|
||||
mean = max(lo, min(hi, mean))
|
||||
kp["weighted_sum"] = new_sum
|
||||
kp["total_weight"] = new_wts
|
||||
kp["mean"] = mean
|
||||
new_value = mean
|
||||
elif sch.startswith("int_range:"):
|
||||
try:
|
||||
obs_f = float(observed)
|
||||
except (TypeError, ValueError):
|
||||
return current, p
|
||||
prev_sum = float(kp.get("weighted_sum", float(current) if isinstance(current, (int, float)) else 0.0))
|
||||
prev_wts = float(kp.get("total_weight", 0.0))
|
||||
new_sum = prev_sum + w * obs_f
|
||||
new_wts = prev_wts + w
|
||||
mean = new_sum / new_wts if new_wts > 0 else obs_f
|
||||
bounds = sch[len("int_range:"):]
|
||||
lo_s, hi_s = bounds.split("..")
|
||||
lo, hi = int(lo_s), int(hi_s)
|
||||
new_value = max(lo, min(hi, int(round(mean))))
|
||||
kp["weighted_sum"] = new_sum
|
||||
kp["total_weight"] = new_wts
|
||||
kp["mean"] = mean
|
||||
elif sch.startswith("dict:"):
|
||||
# Per-key recursive update. `observed` must be dict-of-same-shape.
|
||||
if not isinstance(observed, dict):
|
||||
return current, p
|
||||
body = sch[len("dict:"):]
|
||||
_key_type, _, val_type = body.partition(":")
|
||||
per_key_posts: dict[str, dict] = dict(kp.get("per_key", {}))
|
||||
current_dict: dict = dict(current) if isinstance(current, dict) else {}
|
||||
for k, v in observed.items():
|
||||
# Mini-recursion: synthesise a float-style update for the inner value.
|
||||
sub_spec = val_type
|
||||
sub_kp = dict(per_key_posts.get(k, {}))
|
||||
if sub_spec.startswith("float_range:"):
|
||||
try:
|
||||
obs_f = float(v)
|
||||
except (TypeError, ValueError):
|
||||
continue
|
||||
prev_sum = float(sub_kp.get("weighted_sum", float(current_dict.get(k, 0.0))))
|
||||
prev_wts = float(sub_kp.get("total_weight", 0.0))
|
||||
new_sum = prev_sum + w * obs_f
|
||||
new_wts = prev_wts + w
|
||||
mean = new_sum / new_wts if new_wts > 0 else obs_f
|
||||
bounds = sub_spec[len("float_range:"):]
|
||||
lo_s, hi_s = bounds.split("..")
|
||||
lo, hi = float(lo_s), float(hi_s)
|
||||
mean = max(lo, min(hi, mean))
|
||||
sub_kp["weighted_sum"] = new_sum
|
||||
sub_kp["total_weight"] = new_wts
|
||||
sub_kp["mean"] = mean
|
||||
per_key_posts[k] = sub_kp
|
||||
current_dict[k] = mean
|
||||
kp["per_key"] = per_key_posts
|
||||
new_value = current_dict
|
||||
else:
|
||||
return current, p
|
||||
|
||||
p[knob] = kp
|
||||
state[knob] = new_value
|
||||
return new_value, p
|
||||
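For a bool knob, the update above reduces to adding the signal weight to one of two Beta pseudo-counts. The sketch below restates that single step in isolation. `beta_update` is a hypothetical helper written for illustration, and the weights are copied from `SIGNAL_WEIGHT` in this file.

```python
SIGNAL_WEIGHT = {"implicit": 0.3, "inferred": 0.5, "explicit": 1.0}


def beta_update(
    alpha: float, beta: float, observed: bool, signal: str
) -> tuple[float, float, bool]:
    """Add the signal weight to alpha (True) or beta (False); decide by alpha >= beta."""
    w = SIGNAL_WEIGHT.get(signal, 0.0)  # unknown signals carry no weight
    if observed:
        alpha += w
    else:
        beta += w
    return alpha, beta, alpha >= beta


# Start from the uniform prior Beta(1, 1), as the code above does.
a, b, value = beta_update(1.0, 1.0, False, "explicit")
for _ in range(3):
    a, b, value = beta_update(a, b, True, "implicit")
```

In this run one explicit `False` still outweighs three implicit `True` observations; a fourth implicit `True` would tip the value back.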
|
||||
|
||||
# ---------------------------------------------------------------- gain
|
||||
|
||||
|
||||
def profile_modulation_for_record(
|
||||
record,
|
||||
profile_state: dict,
|
||||
*,
|
||||
knobs_applied: dict | None = None,
|
||||
) -> dict[str, float]:
|
||||
"""Compute edge-weight gain dict for a record.
|
||||
|
||||
Returned gains are multiplicative (>1.0 amplifies, <1.0 damps, 1.0 is neutral).
|
||||
Keys match the knob name. Empty dict means no active modulation.
|
||||
|
||||
Current gain sources:
|
||||
- `monotropism_depth`: gain = 1.0 + depth for the record's domain tag.
|
||||
- `interest_boost`: gain = 1.0 + boost (amplifies every record).
|
||||
- `dunn_quadrant`: seeking -> 1.2, avoiding -> 0.8, else no entry.
|
||||
- `special_interest_amplification`: extension (no-op here).
|
||||
|
||||
The record's own `profile_modulation_gain` dict is NOT mutated here; the
|
||||
caller (pipeline_recall) copies the gains onto the record cache after
|
||||
computing them.
|
||||
|
||||
Phase 07.12-03: when ``knobs_applied`` is provided (a dict), records
|
||||
AUTIST-01 / AUTIST-03 / AUTIST-09 provenance strings into it whenever
|
||||
the corresponding gain branch fires. The accumulator is owned by the
|
||||
caller (typically core.dispatch); this function mutates it in place,
|
||||
pass-by-reference — never reassigns, never returns it.
|
||||
|
||||
BLOCKER 3 (CONTEXT D-04, 2026-04-30): provenance strings MUST contain
|
||||
'profile.py' so the production-path integration test can prove the
|
||||
upstream-gains accumulator is wired in this file (not stubbed elsewhere).
|
||||
Back-compat: callers that don't pass the kwarg behave exactly as before.
|
||||
"""
|
||||
gains: dict[str, float] = {}
|
||||
|
||||
# Monotropism depth per domain tag.
|
||||
md = profile_state.get("monotropism_depth", {})
|
||||
if isinstance(md, dict) and md:
|
||||
for tag in (record.tags or []):
|
||||
if tag.startswith("domain:"):
|
||||
dom = tag.split(":", 1)[1]
|
||||
if dom in md:
|
||||
depth = md[dom]
|
||||
try:
|
||||
gains["monotropism_depth"] = 1.0 + float(depth)
|
||||
except (TypeError, ValueError):
|
||||
pass
|
||||
if knobs_applied is not None:
|
||||
knobs_applied["AUTIST-01"] = (
|
||||
"profile.py:profile_modulation_for_record:monotropism_depth"
|
||||
)
|
||||
break
|
||||
|
||||
# Interest boost amplifies any record. (verified line range: 613-616)
|
||||
ib = profile_state.get("interest_boost", 0.0)
|
||||
try:
|
||||
if float(ib) > 0:
|
||||
gains["interest_boost"] = 1.0 + float(ib)
|
||||
if knobs_applied is not None:
|
||||
knobs_applied["AUTIST-09"] = (
|
||||
"profile.py:profile_modulation_for_record:interest_boost"
|
||||
)
|
||||
except (TypeError, ValueError):
|
||||
pass
|
||||
|
||||
# Dunn quadrant posture. (verified line range: 621-625)
|
||||
dq = profile_state.get("dunn_quadrant")
|
||||
if dq == "seeking":
|
||||
gains["dunn_quadrant"] = 1.2
|
||||
if knobs_applied is not None:
|
||||
knobs_applied["AUTIST-03"] = (
|
||||
"profile.py:profile_modulation_for_record:dunn_quadrant=seeking"
|
||||
)
|
||||
elif dq == "avoiding":
|
||||
gains["dunn_quadrant"] = 0.8
|
||||
if knobs_applied is not None:
|
||||
knobs_applied["AUTIST-03"] = (
|
||||
"profile.py:profile_modulation_for_record:dunn_quadrant=avoiding"
|
||||
)
|
||||
|
||||
return gains
|
||||
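The docstring of `profile_modulation_for_record` describes the returned gains as multiplicative, but this file does not show how the caller combines them. One plausible reading, offered purely as an assumption, is a product over all entries. `combined_gain` below is a hypothetical helper and not part of iai-mcp.

```python
import math


def combined_gain(gains: dict[str, float]) -> float:
    """Fold per-knob gains into one multiplier; an empty dict is neutral."""
    # math.prod returns 1 for an empty iterable, which matches "no modulation".
    return math.prod(gains.values())


# interest_boost = 0.5 gives 1.0 + 0.5; dunn_quadrant = "avoiding" gives 0.8.
example = {"interest_boost": 1.0 + 0.5, "dunn_quadrant": 0.8}
total = combined_gain(example)
```

If the real pipeline sums or caps gains instead, only the body of `combined_gain` would change.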
399
src/iai_mcp/provenance_queue.py
Normal file
|
|
@ -0,0 +1,399 @@
|
|||
"""Plan 05-14 — async provenance write queue (OPS-10 / M-02).
|
||||
|
||||
Moves provenance writes off the recall critical path. A single daemon
|
||||
thread drains a bounded queue.Queue of (record_id, entry) pairs and
|
||||
flushes them via the existing ``MemoryStore.append_provenance_batch``
|
||||
exactly as the sync path did.
|
||||
|
||||
Why this is the right shape:
|
||||
- provenance writes are pure SIDE EFFECTS; pipeline_recall never reads
|
||||
their result. Textbook fire-and-forget candidate.
|
||||
- The biological analogue: consolidation writes happen during rest, not
|
||||
during retrieval (CLS / sleep replay).
|
||||
- The existing ``AsyncWriteQueue`` is for record inserts,
|
||||
which must be durable before their return (S4 viability check reads
|
||||
them back). Provenance has no such contract — a simpler, purpose-built
|
||||
queue avoids the coroutine/event-loop machinery that asyncio imposes.
|
||||
|
||||
Constitutional fences:
|
||||
- Rule 1: worker swallows all exceptions (recall must never fail due
|
||||
to a provenance-write failure).
|
||||
- Entries are never dropped during normal operation; on shutdown
|
||||
the atexit hook drains the queue. W1/D-01: when the
|
||||
in-memory queue is full under overload, batches are spilled to
|
||||
``~/.iai-mcp/.provenance-overflow/<unix_ms>-<n>.jsonl``. The worker
|
||||
drains the spill dir on idle and re-enqueues the batches. Zero drops
|
||||
on the happy path; the only path that can drop is disk-write failure
|
||||
(alarmed via the ``provenance_queue_spill_failed`` stderr event).
|
||||
- C3 / C6: stdlib only. No extra dependencies.
|
||||
|
||||
Python 3.11+.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import atexit
|
||||
import json
|
||||
import queue
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
from uuid import UUID
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from iai_mcp.store import MemoryStore
|
||||
|
||||
|
||||
# Sentinels pushed on the queue to wake the worker for stop/flush.
|
||||
_STOP = object()
|
||||
_FLUSH = object()
|
||||
|
||||
# W1/D-01 — overflow spill-to-disk.
|
||||
OVERFLOW_DIR_NAME = ".provenance-overflow"
|
||||
# Worker idle poll: 5s upper bound on overflow-drain responsiveness.
|
||||
# Bounded so under sustained overload the spill drain catches up
|
||||
# within a small constant time after _q clears.
|
||||
_WORKER_IDLE_POLL_S = 5.0
|
||||
|
||||
|
||||
class ProvenanceWriteQueue:
|
||||
"""Single-daemon-thread coalescing queue for provenance batches.
|
||||
|
||||
Usage:
|
||||
q = ProvenanceWriteQueue(store, coalesce_ms=50)
|
||||
q.start() # idempotent
|
||||
q.enqueue([(record_id, entry_dict), ...]) # non-blocking
|
||||
q.flush(timeout=2.0) # drain + wait
|
||||
q.stop() # drain + join
|
||||
|
||||
The worker loop:
|
||||
1. Blocking `.get()` on the queue (wakes on enqueue or sentinel).
|
||||
2. Opportunistic drain up to ``max_batch_pairs`` pairs OR until
|
||||
the queue has been empty for ``coalesce_ms``.
|
||||
3. Single call to ``store.append_provenance_batch(pairs,
|
||||
records_cache=None)``.
|
||||
4. Back to (1).
|
||||
|
||||
All worker exceptions are logged to stderr as structured JSON events
|
||||
and swallowed.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
store: "MemoryStore",
|
||||
*,
|
||||
coalesce_ms: int = 50,
|
||||
max_queue_size: int = 4096,
|
||||
max_batch_pairs: int = 256,
|
||||
) -> None:
|
||||
self._store = store
|
||||
self._coalesce_s = max(1, int(coalesce_ms)) / 1000.0
|
||||
self._max_batch = int(max_batch_pairs)
|
||||
# Queue items are either lists of (UUID, dict) pairs or the
|
||||
# _STOP / _FLUSH sentinels.
|
||||
self._q: queue.Queue = queue.Queue(maxsize=int(max_queue_size))
|
||||
self._thread: threading.Thread | None = None
|
||||
self._started = False
|
||||
self._stop_requested = False
|
||||
# Flush synchronisation: _flush_event is set by the worker when
|
||||
# it has processed everything up to a _FLUSH sentinel.
|
||||
self._flush_event = threading.Event()
|
||||
self._atexit_registered = False
|
||||
self._lock = threading.Lock()
|
||||
|
||||
# ------------------------------------------------------------------ lifecycle
|
||||
|
||||
def start(self) -> None:
|
||||
"""Start the worker thread. Idempotent."""
|
||||
with self._lock:
|
||||
if self._started:
|
||||
return
|
||||
self._started = True
|
||||
self._stop_requested = False
|
||||
self._thread = threading.Thread(
|
||||
target=self._run,
|
||||
name="iai-mcp-provenance-queue",
|
||||
daemon=True,
|
||||
)
|
||||
self._thread.start()
|
||||
if not self._atexit_registered:
|
||||
atexit.register(self._atexit_flush)
|
||||
self._atexit_registered = True
|
||||
|
||||
def stop(self) -> None:
|
||||
"""Signal the worker, drain remaining items, join the thread.
|
||||
|
||||
Idempotent. After stop the queue is no longer usable; call
|
||||
start() to revive (fresh worker, same queue instance).
|
||||
"""
|
||||
with self._lock:
|
||||
if not self._started:
|
||||
return
|
||||
self._stop_requested = True
|
||||
try:
|
||||
self._q.put_nowait(_STOP)
|
||||
except queue.Full:
|
||||
# Drop one item to make room for the sentinel.
|
||||
try:
|
||||
self._q.get_nowait()
|
||||
self._q.put_nowait(_STOP)
|
||||
except queue.Empty:
|
||||
pass
|
||||
t = self._thread
|
||||
if t is not None:
|
||||
t.join(timeout=5.0)
|
||||
with self._lock:
|
||||
self._started = False
|
||||
self._thread = None
|
||||
|
||||
def flush(self, timeout: float = 2.0) -> None:
|
||||
"""Wait until the worker has drained everything enqueued so far.
|
||||
|
||||
Puts a _FLUSH sentinel; the worker signals _flush_event once it
|
||||
has processed all pairs that were in the queue at that point.
|
||||
Times out silently — the caller is responsible for deciding
|
||||
whether to retry; recall latency is never blocked by flush().
|
||||
"""
|
||||
if not self._started:
|
||||
return
|
||||
self._flush_event.clear()
|
||||
try:
|
||||
self._q.put(_FLUSH, timeout=timeout)
|
||||
except queue.Full:
|
||||
return
|
||||
self._flush_event.wait(timeout=timeout)
|
||||
|
||||
# ---------------------------------------------------------------- public write
|
||||
|
||||
def enqueue(self, pairs: "list[tuple[UUID, dict]]") -> None:
|
||||
"""Non-blocking enqueue.
|
||||
|
||||
W1/D-01: when the in-memory queue is full, the batch
|
||||
spills to ``~/.iai-mcp/.provenance-overflow/<ts>-<n>.jsonl``.
|
||||
The worker thread drains the spill dir on idle and re-enqueues
|
||||
the batches. Zero drops under overload (the only path that
|
||||
can drop is disk-write failure, which is itself alarmed).
|
||||
"""
|
||||
if not pairs:
|
||||
return
|
||||
try:
|
||||
self._q.put_nowait(list(pairs))
|
||||
return
|
||||
except queue.Full:
|
||||
pass
|
||||
# In-memory queue full — spill to disk. Worker will pick this
|
||||
# up on its next idle cycle. Recall hot path is unaffected
|
||||
# (this branch only fires on the WRITE side under overload).
|
||||
self._spill_to_disk(list(pairs))
|
||||
try:
|
||||
sys.stderr.write(
|
||||
'{"event":"provenance_queue_overflow_spill","n_pairs":'
|
||||
+ str(len(pairs))
|
||||
+ "}\n"
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# ---------------------------------------------------------------- spill / drain
|
||||
|
||||
def _spill_to_disk(self, pairs: list) -> None:
|
||||
"""Persist a rejected batch to ``~/.iai-mcp/.provenance-overflow/``.
|
||||
|
||||
Per-batch JSONL file: one line per (uuid_str, entry_dict) pair.
|
||||
File-level atomicity — the worker re-enqueues the entire file's
|
||||
contents in one call, then unlinks. Format:
|
||||
|
||||
{"id": "<uuid>", "entry": {...}}\n
|
||||
{"id": "<uuid>", "entry": {...}}\n
|
||||
|
||||
Failure modes:
|
||||
- Disk full / permission denied: emits structured stderr event
|
||||
``provenance_queue_spill_failed``. This is the ONLY drop path
|
||||
remaining post-07.9 W1; it's a system-level alarm condition,
|
||||
not a normal-operation outcome.
|
||||
"""
|
||||
if not pairs:
|
||||
return
|
||||
try:
|
||||
overflow_dir = Path.home() / ".iai-mcp" / OVERFLOW_DIR_NAME
|
||||
overflow_dir.mkdir(parents=True, exist_ok=True)
|
||||
ts_ms = int(time.time() * 1000)
|
||||
# Tag with the batch length and a short object-id suffix so two
|
||||
# spills inside the same millisecond are unlikely to collide.
|
||||
fpath = overflow_dir / f"{ts_ms}-{len(pairs)}-{id(pairs) & 0xFFFF:04x}.jsonl"
|
||||
tmp_path = fpath.with_suffix(fpath.suffix + ".tmp")
|
||||
with tmp_path.open("w", encoding="utf-8") as fh:
|
||||
for rid, entry in pairs:
|
||||
fh.write(json.dumps({"id": str(rid), "entry": entry}) + "\n")
|
||||
tmp_path.rename(fpath) # atomic rename keeps drain from
|
||||
# ever reading a half-written file.
|
||||
except Exception as exc:
|
||||
try:
|
||||
sys.stderr.write(
|
||||
'{"event":"provenance_queue_spill_failed","error":'
|
||||
+ _json_str(str(exc))
|
||||
+ ',"n_pairs":' + str(len(pairs)) + '}\n'
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def _drain_overflow_dir(self) -> int:
|
||||
"""Re-enqueue any spilled batches into ``_q``.
|
||||
|
||||
Called by the worker on idle (between blocking `_q.get()` cycles).
|
||||
Per-file atomicity: re-enqueue ALL pairs from a file via a single
|
||||
``_q.put`` call, then unlink. If ``_q`` is still full, leave the
|
||||
file on disk for the next idle cycle.
|
||||
|
||||
Returns the number of pairs successfully re-enqueued in this pass.
|
||||
"""
|
||||
overflow_dir = Path.home() / ".iai-mcp" / OVERFLOW_DIR_NAME
|
||||
if not overflow_dir.exists():
|
||||
return 0
|
||||
n_re_enqueued = 0
|
||||
# sorted() so older spill files drain first (FIFO durability).
|
||||
for fpath in sorted(overflow_dir.glob("*.jsonl")):
|
||||
try:
|
||||
pairs: list = []
|
||||
with fpath.open(encoding="utf-8") as fh:
|
||||
for line in fh:
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
obj = json.loads(line)
|
||||
pairs.append((UUID(obj["id"]), obj["entry"]))
|
||||
if not pairs:
|
||||
fpath.unlink()
|
||||
continue
|
||||
# Short-timeout put: this is the worker thread, so
|
||||
# blocking briefly is fine, but a long block would
|
||||
# delay normal-path enqueues that arrive during drain.
|
||||
try:
|
||||
self._q.put(pairs, timeout=0.5)
|
||||
except queue.Full:
|
||||
# Queue still saturated — leave the file for the
|
||||
# next idle cycle. Don't unlink.
|
||||
return n_re_enqueued
|
||||
fpath.unlink()
|
||||
n_re_enqueued += len(pairs)
|
||||
except Exception as exc:
|
||||
# Malformed spill file: preserve evidence, do not lose data.
|
||||
try:
|
||||
failed = fpath.with_suffix(f".failed-{int(time.time())}.jsonl")
|
||||
fpath.rename(failed)
|
||||
sys.stderr.write(
|
||||
'{"event":"provenance_queue_spill_drain_failed","error":'
|
||||
+ _json_str(str(exc)) + '}\n'
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
return n_re_enqueued
|
||||
|
||||
# ------------------------------------------------------------------ internals
|
||||
|
||||
def _run(self) -> None:
|
||||
"""Worker loop.
|
||||
|
||||
W1/D-01: between blocking `_q.get()` cycles the worker
|
||||
drains any spilled overflow files at ``~/.iai-mcp/.provenance-overflow/``.
|
||||
Bounded poll: idle-timeout = ``_WORKER_IDLE_POLL_S`` so the spill
|
||||
drain runs at most once per ``_WORKER_IDLE_POLL_S`` seconds when
|
||||
the queue is empty.
|
||||
"""
|
||||
while True:
|
||||
try:
|
||||
item = self._q.get(timeout=_WORKER_IDLE_POLL_S)
|
||||
except queue.Empty:
|
||||
# Idle tick — try to drain the overflow dir back into _q.
|
||||
# Defensive: any error escaping the drain is swallowed.
|
||||
try:
|
||||
self._drain_overflow_dir()
|
||||
except Exception:
|
||||
pass
|
||||
continue
|
||||
except Exception:
|
||||
continue
|
||||
if item is _STOP:
|
||||
# Drain remaining real items before exit.
|
||||
self._drain_remaining()
|
||||
return
|
||||
if item is _FLUSH:
|
||||
# Drain everything enqueued before this sentinel.
|
||||
self._drain_remaining()
|
||||
self._flush_event.set()
|
||||
continue
|
||||
# Normal batch. Coalesce: pull more pending items until we
|
||||
# hit max_batch_pairs or a short idle window.
|
||||
pairs: list = list(item)
|
||||
while len(pairs) < self._max_batch:
|
||||
try:
|
||||
nxt = self._q.get(timeout=self._coalesce_s)
|
||||
except queue.Empty:
|
||||
break
|
||||
if nxt is _STOP:
|
||||
# Flush what we have, then exit.
|
||||
self._flush_batch(pairs)
|
||||
self._drain_remaining()
|
||||
return
|
||||
if nxt is _FLUSH:
|
||||
self._flush_batch(pairs)
|
||||
self._drain_remaining()
|
||||
self._flush_event.set()
|
||||
pairs = []
|
||||
break
|
||||
pairs.extend(nxt)
|
||||
if pairs:
|
||||
self._flush_batch(pairs)
|
||||
|
||||
    def _drain_remaining(self) -> None:
        """Pull everything currently in the queue and flush as one batch."""
        pairs: list = []
        saw_flush = False
        while True:
            try:
                item = self._q.get_nowait()
            except queue.Empty:
                break
            if item is _STOP:
                continue
            if item is _FLUSH:
                saw_flush = True
                continue
            pairs.extend(item)
        if pairs:
            self._flush_batch(pairs)
        if saw_flush:
            self._flush_event.set()

    def _flush_batch(self, pairs: list) -> None:
        """Call store.append_provenance_batch, swallow all exceptions (Rule 1)."""
        if not pairs:
            return
        try:
            self._store.append_provenance_batch(pairs, records_cache=None)
        except Exception as exc:
            try:
                sys.stderr.write(
                    '{"event":"provenance_queue_flush_failed","n_pairs":'
                    + str(len(pairs))
                    + ',"error":'
                    + _json_str(str(exc))
                    + "}\n"
                )
            except Exception:
                pass

    def _atexit_flush(self) -> None:
        """atexit handler -- drain and stop the worker. Never raises."""
        try:
            if self._started:
                self.flush(timeout=2.0)
                self.stop()
        except Exception:
            pass


def _json_str(s: str) -> str:
    """Minimal JSON string escape for stderr structured logs."""
    return '"' + s.replace("\\", "\\\\").replace('"', '\\"').replace("\n", "\\n") + '"'
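The coalescing step in the worker loop above can be sketched in isolation. This is a minimal illustration, not the project's code: `coalesce`, its parameters, and the stand-in `_STOP` object are hypothetical names chosen for the sketch, and only the standard-library `queue` module is assumed.

```python
import queue

# Hypothetical stand-in for the worker's _STOP sentinel.
_STOP = object()


def coalesce(q, first, max_batch, idle_s):
    """Merge queued batches into one list.

    Stops when the soft cap is reached, when the queue stays empty for
    ``idle_s`` seconds, or when the stop sentinel is pulled.
    Returns (pairs, stop_seen).
    """
    pairs = list(first)
    stop_seen = False
    while len(pairs) < max_batch:
        try:
            nxt = q.get(timeout=idle_s)
        except queue.Empty:
            break  # idle window elapsed: flush what we have
        if nxt is _STOP:
            stop_seen = True
            break
        pairs.extend(nxt)
    return pairs, stop_seen


q = queue.Queue()
q.put([("b", 2)])
q.put([("c", 3)])
q.put(_STOP)
batch, stop_seen = coalesce(q, [("a", 1)], max_batch=10, idle_s=0.01)
print(batch, stop_seen)
```

Note that the cap is checked before each pull, so a single large batch can push the merged list past `max_batch`; the same holds for the loop in the source, where `pairs.extend(nxt)` runs after the length check.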
145
src/iai_mcp/quiet_window.py
Normal file
@@ -0,0 +1,145 @@
"""Phase 4 -- activity-learned quiet-window scheduler (DAEMON-03).

Learn the user's quiet window from their own `session_started` event history.
48 buckets of 30-min granularity over a 7-day rolling window. Find the longest
contiguous span where bucket activity < threshold. Min 3h, max 8h. Bootstrap
when <7 days of data: trigger on 2h MCP idle. Re-learn every 24h.

Constitutional guard:
- learned from events, NOT clock-based.
- global-product mandate -- no Western 9-5 assumption, no baked-in
  local-time default. Respects nocturnal / shift / time-zone-mobile users.
- C3: no LLM code, no paid-API env var reference in this module.
"""
from __future__ import annotations

from datetime import datetime, timedelta, timezone
from typing import Optional
from zoneinfo import ZoneInfo

from iai_mcp.events import query_events
from iai_mcp.store import MemoryStore

# Bucket sizing.
BUCKET_COUNT = 48  # 30-min * 48 = 24h
BUCKET_MINUTES = 30

# Window bounds.
MIN_WINDOW_HOURS = 3  # discard spans shorter than 3h
MAX_WINDOW_HOURS = 8  # human sleep ceiling

# Learning / bootstrap parameters.
MIN_DAYS_FOR_LEARN = 7
BOOTSTRAP_IDLE_HOURS = 2  # fallback trigger when <7d data

# Scheduler cadence gates (used by daemon; exported for caller convenience).
WIND_DOWN_GATE_MINUTES_BEFORE = 30  # dual-gate: within 30min of quiet start
DIGEST_SHOW_THRESHOLD_HOURS = 18  # morning digest gating (re-exported by daemon_state)


def learn_quiet_window(
    store: MemoryStore,
    now: datetime,
    tz: ZoneInfo,
) -> Optional[tuple[int, int]]:
    """Learn the user's quiet window from 7-day session_started history.

    Returns (start_bucket, duration_buckets) in LOCAL time, or None if
    insufficient data / no contiguous quiet span (caller falls back to the
    bootstrap idle rule).

    start_bucket: 0..BUCKET_COUNT-1 index into 30-min-bucket local-time day.
    duration_buckets: number of 30-min buckets in the quiet span (3h=6, 8h=16).
    """
    since = now - timedelta(days=MIN_DAYS_FOR_LEARN)
    events = query_events(store, kind="session_started", since=since, limit=10000)
    if not events:
        return None

    # Count sessions per 30-min local-time bucket + track unique days seen.
    counts = [0] * BUCKET_COUNT
    days_seen: set[tuple[int, int, int]] = set()
    for e in events:
        ts = e["ts"]
        # Pandas may surface a Timestamp -- coerce to aware datetime.
        if not isinstance(ts, datetime):
            try:
                ts = ts.to_pydatetime()
            except Exception:
                continue
        if ts.tzinfo is None:
            ts = ts.replace(tzinfo=timezone.utc)
        try:
            ts_local = ts.astimezone(tz)
        except Exception:
            # DST edge: astimezone is robust on stdlib, but guard anyway.
            continue
        bucket = (ts_local.hour * 60 + ts_local.minute) // BUCKET_MINUTES
        if 0 <= bucket < BUCKET_COUNT:
            counts[bucket] += 1
            days_seen.add((ts_local.year, ts_local.month, ts_local.day))

    if len(days_seen) < MIN_DAYS_FOR_LEARN:
        return None  # bootstrap path -- caller uses 2h-idle.

    # Low-activity threshold = 20% of peak.
    peak = max(counts)
    if peak == 0:
        return None
    threshold = max(1, int(peak * 0.2))

    # Longest contiguous circular span of sub-threshold buckets.
    # Double-array walk to handle wrap-around across local midnight.
    doubled = counts + counts
    best_start, best_len = 0, 0
    cur_start, cur_len = None, 0
    for i, c in enumerate(doubled):
        if c < threshold:
            if cur_start is None:
                cur_start = i
                cur_len = 1
            else:
                cur_len += 1
            if cur_len > best_len:
                best_start = cur_start
                best_len = cur_len
        else:
            cur_start, cur_len = None, 0

    min_buckets = MIN_WINDOW_HOURS * (60 // BUCKET_MINUTES)  # 6
    max_buckets = MAX_WINDOW_HOURS * (60 // BUCKET_MINUTES)  # 16
    if best_len < min_buckets:
        # 24/7 user with no contiguous quiet span -> fallback to idle-only.
        return None
    duration = min(best_len, max_buckets)
    # Don't allow a span longer than a full day after wrap.
    if duration > BUCKET_COUNT:
        duration = BUCKET_COUNT
    return (best_start % BUCKET_COUNT, duration)


def should_relearn(last_learned_at: Optional[datetime], now: datetime) -> bool:
    """Re-learn cadence: 24h since last learn (D-04 24h adaptation)."""
    if last_learned_at is None:
        return True
    if last_learned_at.tzinfo is None:
        last_learned_at = last_learned_at.replace(tzinfo=timezone.utc)
    if now.tzinfo is None:
        now = now.replace(tzinfo=timezone.utc)
    return (now - last_learned_at) >= timedelta(hours=24)


def should_bootstrap_trigger(last_session_ts: Optional[datetime], now: datetime) -> bool:
    """Bootstrap idle trigger: daemon fires when no MCP session for 2h.

    Used when `learn_quiet_window` returns None (insufficient data or 24/7
    user). Also used by the daemon as the always-on idle rule in addition to
    the learned quiet window.
    """
    if last_session_ts is None:
        return True
    if last_session_ts.tzinfo is None:
        last_session_ts = last_session_ts.replace(tzinfo=timezone.utc)
    if now.tzinfo is None:
        now = now.replace(tzinfo=timezone.utc)
    return (now - last_session_ts) >= timedelta(hours=BOOTSTRAP_IDLE_HOURS)
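The wrap-around search in `learn_quiet_window` is easier to follow on a toy histogram. The sketch below is a simplified stand-alone version under stated assumptions: `longest_quiet_run` and `bucket_to_hhmm` are hypothetical helper names, the activity counts are made up, and the 3h/8h bounds from the module are left out.

```python
def longest_quiet_run(counts, threshold):
    """Return (start_index, length) of the longest circular run below threshold."""
    n = len(counts)
    best_start, best_len = 0, 0
    cur_start, cur_len = None, 0
    # Walking the list twice lets a run that crosses the end of the day
    # (index n-1 -> 0) show up as one contiguous stretch.
    for i, c in enumerate(counts + counts):
        if c < threshold:
            if cur_start is None:
                cur_start, cur_len = i, 1
            else:
                cur_len += 1
            if cur_len > best_len:
                best_start, best_len = cur_start, cur_len
        else:
            cur_start, cur_len = None, 0
    return best_start % n, min(best_len, n)


def bucket_to_hhmm(bucket, bucket_minutes=30):
    """Convert a bucket index to a local HH:MM label."""
    minutes = bucket * bucket_minutes
    return f"{minutes // 60:02d}:{minutes % 60:02d}"


# Made-up histogram: activity in every 30-min bucket from 08:00 up to 23:00.
counts = [0] * 48
for b in range(16, 46):
    counts[b] = 5

start, length = longest_quiet_run(counts, threshold=1)
print(bucket_to_hhmm(start), length)
```

Here the quiet span starts at bucket 46 (23:00) and runs 18 buckets across midnight; the module would then cap that at `MAX_WINDOW_HOURS` (16 buckets).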
439
src/iai_mcp/response_decorator.py
Normal file
@@ -0,0 +1,439 @@
"""Plan 05-03 TOK-13 / D5-04 -- server-side profile knob decorator.

`apply_profile(response, profile)` mutates a response dict in place based on
the 11 sealed profile knobs. Every per-knob helper is silent-fail so a
malformed knob value can never break the response path.

C3 invariant (Plan 04): this module is pure-local Python. NO paid-API SDK
import. NO API-key env read. The static grep guard
`test_no_api_key_in_response_decorator` enforces the invariant at CI time.

TOK-13 contract: knob NAMES never cross the MCP wire. They are read from
the per-process profile state, applied to the response here, and the
result goes back over JSON-RPC free of any knob identifiers.

Helper layout (10 dispatch helpers -- one per AUTIST knob the decorator
mutates; wake_depth has no helper here, see end note):
- _apply_formality_relaxation (AUTIST-13 camouflaging_relaxation)
- _apply_monotropic_focus (AUTIST-01 monotropism_depth)
- _apply_literal_preservation (AUTIST-04 literal_preservation)
- _apply_masking_off (AUTIST-06 masking_off)
- _apply_task_support (AUTIST-07 task_support)
- _apply_scene_construction (AUTIST-14 scene_construction_scaffold)
- _apply_dunn_quadrant (AUTIST-03 dunn_quadrant)
- _apply_pda_tolerance (AUTIST-05 demand_avoidance_tolerance)
- _apply_interest_boost (AUTIST-09 interest_boost)
- _apply_inertia_awareness (AUTIST-10 inertia_awareness)

(Phase 07.12-02 removed the dead-knob helpers
_apply_sensory_channel_weights / _apply_alexithymia / _apply_double_empathy
along with the orphan helpers _apply_verbosity_level / _apply_surface_language
that read non-sealed-knob fields.)

wake_depth affects the session-start payload, not the response
shape, so it gets no helper here.
"""
from __future__ import annotations


# Phase 07.12-03: HELPER_TO_KNOB_ID maps each apply_profile helper (and the
# upstream-gains / session-start virtual keys) to its knob requirement ID.
# Used by the dispatch loop to populate response['_knobs_applied'] with
# file:symbol provenance for every helper invocation. After Phase 07.12-02
# the table contains:
#   - 8 helper-keyed entries (the AUTIST helpers wired in apply_profile that
#     produce response-level mutations)
#   - 2 upstream-gains entries (AUTIST-03 dunn_quadrant, AUTIST-09
#     interest_boost) -- provenance strings are written by
#     profile.py:profile_modulation_for_record; the dispatch loop ignores
#     these virtual keys (HELPER_TO_KNOB_ID.get(...) returns None for them
#     when keyed by helper name).
#   - 1 session-start entry (MCP-12 wake_depth) -- provenance points into
#     session.py:assemble_session_start; written by core.dispatch.
#
# DO NOT re-add removed-knob keys (AUTIST-02 sensory_channel_weights,
# event_vs_time_cue, alexithymia_accommodation,
# double_empathy) -- Plan 07.12-02 deleted them from the registry.
HELPER_TO_KNOB_ID: dict[str, str] = {
    # --- helper-keyed entries (8) -- recorded by the dispatch loop -----------
    "_apply_monotropic_focus": "AUTIST-01",  # monotropism_depth
    "_apply_literal_preservation": "AUTIST-04",  # literal_preservation
    "_apply_pda_tolerance": "AUTIST-05",  # demand_avoidance_tolerance
    "_apply_masking_off": "AUTIST-06",  # masking_off
    "_apply_task_support": "AUTIST-07",  # task_support
    "_apply_inertia_awareness": "AUTIST-10",  # inertia_awareness
    "_apply_formality_relaxation": "AUTIST-13",  # camouflaging_relaxation
    "_apply_scene_construction": "AUTIST-14",  # scene_construction_scaffold
    # --- upstream-gains entries (2) -- recorded by profile.py via the kwarg --
    # These are virtual lookup keys (NOT helper names). The dispatch loop's
    # HELPER_TO_KNOB_ID.get(helper_name) returns None for the existing
    # pass-through helpers _apply_dunn_quadrant / _apply_interest_boost because
    # those helpers are NOT in this table -- the AUTHORITATIVE provenance for
    # the gain is profile.py:profile_modulation_for_record:613-625, written
    # by the upstream accumulator.
    "dunn_quadrant": "AUTIST-03",  # via profile.py:621-625
    "interest_boost": "AUTIST-09",  # via profile.py:613-616
    # --- session-start entry (1) -- recorded by core.dispatch ---------------
    # wake_depth is operator-facing; the seed entry is set in
    # core.dispatch when the session-start path runs. Provenance points
    # into session.py:373 (assemble_session_start: wake_depth = state.get(...)).
    "wake_depth": "MCP-12",
}


def apply_profile(response: dict, profile: dict) -> dict:
    """Apply the 10 dispatch profile knobs to ``response`` in place.

    Contract:
    - Returns the same response for chainability.
    - Never raises. Each per-knob helper has its own try/except AND the
      central dispatch wraps every helper call with an outer guard so a
      monkey-patched or mis-named helper cannot break the hot path.
    - Malformed profile state is tolerated (unexpected types, missing keys).
    - No MCP-side knob names are added to the response.

    Phase 07.12-03 telemetry: emits response['_knobs_applied'] -- a dict
    mapping knob requirement IDs (e.g., 'AUTIST-01') to deterministic
    file:symbol provenance strings. Future code-readers can audit, per
    response, which knobs actually mutated which fields. CONTEXT D-04.

    The accumulator is preserved across upstream paths: any entries
    seeded by core.dispatch BEFORE apply_profile runs (typically the
    upstream-gains entries for dunn_quadrant / interest_boost and the
    wake_depth seed for MCP-12) survive -- the dispatch loop only ADDS
    entries via helper-keyed lookup, never overwrites the dict shape.
    """
    if not isinstance(response, dict) or not isinstance(profile, dict):
        return response

    # Phase 07.12-03 BLOCKER 3 fix: preserve any upstream-seeded entries.
    # core.dispatch seeds knobs_applied for dunn_quadrant / interest_boost
    # (via profile_modulation_for_record) + wake_depth before this
    # function runs. We extend, never overwrite the dict reference held
    # by core.dispatch.
    pre_seeded = response.get("_knobs_applied")
    if isinstance(pre_seeded, dict):
        applied: dict[str, str] = pre_seeded
    else:
        applied = {}

    # Outer guard per helper call -- tolerates a helper that was monkey-patched
    # to raise (seen in test_pre_existing_keys_untouched_on_exception) or an
    # accidental helper rewrite that skips the inner try/except.
    for helper in (
        _apply_formality_relaxation,
        _apply_monotropic_focus,
        _apply_literal_preservation,
        _apply_masking_off,
        _apply_task_support,
        _apply_scene_construction,
        _apply_dunn_quadrant,
        _apply_pda_tolerance,
        _apply_interest_boost,
        _apply_inertia_awareness,
    ):
        helper_raised = False
        try:
            helper(response, profile)
        except Exception:
            helper_raised = True  # silent-fail per D5-04 -- no audit entry
        if helper_raised:
            continue
        helper_name = helper.__name__
        knob_id = HELPER_TO_KNOB_ID.get(helper_name)
        if knob_id is None:
            # Unmapped helper (e.g., _apply_dunn_quadrant, _apply_interest_boost
            # -- their provenance lives in profile.py via the upstream gains
            # accumulator). Skip rather than corrupt the audit.
            continue
        provenance = f"response_decorator.py:{helper_name}"
        # No-op markers for the three known mode-gate sites (CONTEXT D-04
        # line 167 -- "consulted and chose to do nothing" vs "knob is dead").
        if helper_name == "_apply_pda_tolerance":
            mode = profile.get("demand_avoidance_tolerance", "collaborative")
            if mode == "neutral":
                provenance = f"{provenance}:no-op (mode=neutral)"
        elif helper_name == "_apply_inertia_awareness":
            if not profile.get("inertia_awareness", False):
                provenance = f"{provenance}:no-op (knob=False)"
            elif not response.get("first_turn_recall"):
                provenance = f"{provenance}:no-op (subsequent turn)"
        elif helper_name == "_apply_scene_construction":
            if not profile.get("scene_construction_scaffold", True):
                provenance = f"{provenance}:no-op (knob=False)"
        applied[knob_id] = provenance

    response["_knobs_applied"] = applied
    # wake_depth is the operator-facing knob; it drives session-start payload
    # shape, not response content. No helper here by design (D5-04). Its
    # entry is seeded by core.dispatch before apply_profile runs.
    return response


# ---------------------------------------------------------- per-knob helpers
# Each helper MUST be wrapped in try/except Exception: pass -- a malformed
# profile knob value cannot break the hot recall path.


def _apply_formality_relaxation(response: dict, profile: dict) -> None:
    """AUTIST-13 camouflaging_relaxation > 0.5 -> rewrite surface_text toward
    informal register.

    The transform here is intentionally minimal (just strips trailing
    "Sir"/"Madam" honorifics). The weekly pass owns the heavy lift; this
    hook ensures response-time consistency.
    """
    try:
        level = float(profile.get("camouflaging_relaxation", 0.0))
        if level <= 0.5:
            return
        for hit in response.get("hits", []) or []:
            if not isinstance(hit, dict):
                continue
            text = hit.get("literal_surface") or hit.get("surface_text")
            if not isinstance(text, str):
                continue
            # Drop stale honorifics if present (best-effort).
            stripped = text
            for honorific in (" Sir.", " Sir,", " Madam.", " Madam,"):
                stripped = stripped.replace(honorific, ".")
            if "surface_text" in hit:
                hit["surface_text"] = stripped
            # Leave literal_surface byte-exact (C5 invariant).
    except Exception:
        pass


def _apply_monotropic_focus(response: dict, profile: dict) -> None:
    """AUTIST-01 monotropism_depth per domain -> narrow top-k to dominant.

    When any domain in monotropism_depth has depth > 0.7, hits carrying a
    non-matching domain tag are down-ranked to the tail of the list. The
    transform is conservative: we reorder, never delete.
    """
    try:
        md = profile.get("monotropism_depth")
        if not isinstance(md, dict) or not md:
            return
        hot_domains = {d for d, depth in md.items() if _as_float(depth, 0.0) > 0.7}
        if not hot_domains:
            return
        hits = response.get("hits")
        if not isinstance(hits, list) or not hits:
            return
        def _key(h):
            if not isinstance(h, dict):
                return 1
            tags = h.get("tags") or []
            for t in tags:
                if isinstance(t, str) and t.startswith("domain:"):
                    return 0 if t.split(":", 1)[1] in hot_domains else 1
            return 1
        hits.sort(key=_key)
    except Exception:
        pass


def _apply_literal_preservation(response: dict, profile: dict) -> None:
    """AUTIST-04 literal_preservation: strong -> keep literal_surface
    byte-exact (default); loose -> surface_text may be summarised.
    C5 invariant: literal_surface is never mutated.
    """
    try:
        mode = profile.get("literal_preservation", "strong")
        if mode not in ("strong", "medium", "loose"):
            return
        # No-op by design: the hook exists for future summarisation logic but
        # must never mutate literal_surface per C5.
    except Exception:
        pass


def _apply_masking_off(response: dict, profile: dict) -> None:
    """AUTIST-06 masking_off True -> strip performative empathy filler."""
    try:
        if not profile.get("masking_off", True):
            return
        filler = (
            "Great question! ",
            "Certainly! ",
            "Of course! ",
        )
        for hit in response.get("hits", []) or []:
            if not isinstance(hit, dict):
                continue
            txt = hit.get("surface_text")
            if isinstance(txt, str):
                for f in filler:
                    if txt.startswith(f):
                        hit["surface_text"] = txt[len(f):]
                        break
    except Exception:
        pass


def _apply_task_support(response: dict, profile: dict) -> None:
    """AUTIST-07 task_support: cued_recognition -> adjacent_suggestions
    populated (no-op here because retrieve.recall already emits them);
    blank_recall -> strip suggestions to force free recall.
    """
    try:
        mode = profile.get("task_support", "cued_recognition")
        if mode != "blank_recall":
            return
        for hit in response.get("hits", []) or []:
            if isinstance(hit, dict) and "adjacent_suggestions" in hit:
                hit["adjacent_suggestions"] = []
    except Exception:
        pass


def _apply_scene_construction(response: dict, profile: dict) -> None:
    """AUTIST-14 scene_construction_scaffold: autobiographical reconstruction
    hint (Phase 07.12-01).

    PATTERNS.md option-3 reconciliation: the hit dict from _hit_to_json
    (core.py:712-719) does NOT carry tier/session_id/captured_at, so we drop
    the tier filter from the original design. When knob=True, attach
    _scene_hint to EVERY hit; downstream consumers ignore the hint on
    non-episodic content without harm. The 'advice' string is fixed --
    no LLM call.

    When False: no _scene_hint key added (test asserts absence).
    """
    try:
        if not profile.get("scene_construction_scaffold", True):
            return
        for hit in response.get("hits", []) or []:
            if not isinstance(hit, dict):
                continue
            hit["_scene_hint"] = {
                "session_id": hit.get("session_id"),
                "captured_at": hit.get("captured_at"),
                "advice": "use as scaffold for autobiographical reconstruction",
            }
    except Exception:
        pass


def _apply_dunn_quadrant(response: dict, profile: dict) -> None:
    """AUTIST-03 dunn_quadrant -> HIPPEA precision is upstream; no-op here."""
    try:
        _ = profile.get("dunn_quadrant", "neutral")
    except Exception:
        pass


def _apply_pda_tolerance(response: dict, profile: dict) -> None:
    """AUTIST-05 demand_avoidance_tolerance lexical softener (Phase 07.12-01).

    - collaborative (default): replace leading imperatives in each
      adjacent_suggestion entry per the frozen substitution table from
      D-02. Only first-word matches; mid-sentence imperatives are NOT
      touched (avoids false positives in code blocks).
    - avoidant: prepend 'FYI: ' to every adjacent_suggestion entry.
    - neutral: bypass.
    """
    try:
        mode = profile.get("demand_avoidance_tolerance", "collaborative")
        if mode == "neutral":
            return
        if mode == "avoidant":
            for hit in response.get("hits", []) or []:
                if not isinstance(hit, dict):
                    continue
                suggestions = hit.get("adjacent_suggestions")
                if not isinstance(suggestions, list):
                    continue
                hit["adjacent_suggestions"] = [
                    f"FYI: {entry}" for entry in suggestions
                ]
            return
        if mode == "collaborative":
            # Frozen table per CONTEXT -- DO NOT extend without a phase decision.
            substitutions: tuple[tuple[str, str], ...] = (
                ("Try ", "You could try "),
                ("Do ", "Consider "),
                ("Use ", "Try using "),
                ("Run ", "Try running "),
            )
            for hit in response.get("hits", []) or []:
                if not isinstance(hit, dict):
                    continue
                suggestions = hit.get("adjacent_suggestions")
                if not isinstance(suggestions, list):
                    continue
                rewritten: list = []
                for entry in suggestions:
                    if not isinstance(entry, str):
                        rewritten.append(entry)
                        continue
                    new_entry = entry
                    for prefix, replacement in substitutions:
                        if entry.startswith(prefix):
                            new_entry = replacement + entry[len(prefix):]
                            break
                    rewritten.append(new_entry)
                hit["adjacent_suggestions"] = rewritten
    except Exception:
        pass


def _apply_interest_boost(response: dict, profile: dict) -> None:
    """AUTIST-09 interest_boost > 0 -> amplify hits in interest domains.
    Applied during scoring, not at response rewrite time; no-op here.
    """
    try:
        _ = profile.get("interest_boost", 0.0)
    except Exception:
        pass


def _apply_inertia_awareness(response: dict, profile: dict) -> None:
    """AUTIST-10 inertia_awareness session-resumption cue (Phase 07.12-01).

    BLOCKER 1 fix (CONTEXT D-02, 2026-04-30): the live upstream hook at
    core.py:1178 sets response["first_turn_recall"] to a DICT, not a bool.
    The gate MUST be a shape-agnostic truthy check -- `is True` equality
    would silent-no-op in production.

    When knob=True AND response["first_turn_recall"] is truthy (set by
    _first_turn_recall_hook at core.py:1178 on the first turn of a
    session), prepend a one-line resumption cue to the top-1 hit's
    literal_surface. The text is fixed (not LLM-generated) for determinism.

    CONTEXT explicitly forbids the per-recall fallback: if the
    first_turn_recall flag is unreliable, escalate via checkpoint rather
    than silently re-introducing recall-noise.

    Subsequent turns OR knob=False → no transform; literal_surface stays
    byte-exact (C5 invariant).
    """
    try:
        if not profile.get("inertia_awareness", False):
            return
        # Truthy presence check -- shape-agnostic (works for dict OR bool).
        # core.py:1178 sets this to a dict on the first turn; the truthy
        # check covers both production (dict) and any test path (bool).
        if not response.get("first_turn_recall"):
            return
        hits = response.get("hits") or []
        if not hits:
            return
        top = hits[0]
        if not isinstance(top, dict):
            return
        literal = top.get("literal_surface")
        if not isinstance(literal, str):
            return
        top["literal_surface"] = f"Resuming from your last session: {literal}"
    except Exception:
        pass


# ----------------------------------------------------------------- utilities
def _as_float(value, default: float) -> float:
    """Coerce ``value`` to float; return ``default`` on failure."""
    try:
        return float(value)
    except (TypeError, ValueError):
        return default
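The dispatch pattern in `apply_profile` (an outer guard around each helper, plus a provenance map that extends any pre-seeded entries) can be reduced to a few lines. The sketch below is illustrative only: `apply_helpers`, the two toy helpers, the `K-1`/`K-2` IDs, and the `sketch.py:` provenance prefix are hypothetical and not part of the module.

```python
def apply_helpers(response, profile, helpers, knob_ids):
    """Run each helper under an outer guard and record which ones completed."""
    applied = response.get("_knobs_applied")
    if not isinstance(applied, dict):
        applied = {}  # nothing seeded upstream: start a fresh accumulator
    for helper in helpers:
        try:
            helper(response, profile)
        except Exception:
            continue  # silent-fail: a raising helper gets no audit entry
        knob_id = knob_ids.get(helper.__name__)
        if knob_id is not None:
            applied[knob_id] = f"sketch.py:{helper.__name__}"
    response["_knobs_applied"] = applied
    return response


def _ok(response, profile):
    response["touched"] = True


def _boom(response, profile):
    raise ValueError("malformed knob")


seeded = {"_knobs_applied": {"MCP-12": "seeded"}}
out = apply_helpers(seeded, {}, (_ok, _boom), {"_ok": "K-1", "_boom": "K-2"})
print(out)
```

The pre-seeded `MCP-12` entry survives, `_ok` is recorded, and `_boom` leaves no trace, which mirrors the "extend, never overwrite" contract described in the docstring above.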
701
src/iai_mcp/retrieve.py
Normal file
@@ -0,0 +1,701 @@
"""Retrieval + reinforcement + contradiction paths.

- `recall`: baseline cosine top-k -- kept as a fallback for the
  empty-store case and for regression tests.
- `build_runtime_graph`: reconstruct a MemoryGraph + CommunityAssignment +
  rich-club from LanceDB state; consumed by core.py to drive `pipeline_recall`.
- `reinforce_edges`, `contradict`: unchanged from Plan 01.
- `link_temporal_next`: records a `record_inserted` event
  and creates a `temporal_next` edge from the previous same-session insertion
  to the new record if that event happened within the last 5 minutes.

Constitutional rules enforced here:
- every recall appends a provenance entry to every returned record.
- reinforce boosts pairwise Hebbian edges among co-retrieved ids.
- edge-based: contradict creates a linked record, preserves original.
"""
from __future__ import annotations

import logging
import time
from datetime import datetime, timedelta, timezone
from itertools import combinations
from uuid import UUID, uuid4

from iai_mcp.aaak import enforce_english_raw, generate_aaak_index
from iai_mcp.events import query_events, write_event
from iai_mcp.store import MemoryStore
from iai_mcp.types import (
    EMBED_DIM,
    EdgeUpdate,
    MemoryHit,
    MemoryRecord,
    RecallResponse,
    ReconsolidationReceipt,
)


# Plan 07.11-02 / structured-log handle for the graph-build
# decrypt-failure path. Same one-liner the rest of the project uses
# (cf. capture.py:54, pipeline.py:33-imports). Used by the
# `graph_build_decrypt_failed` event when AES-GCM decrypt of a
# record's literal_surface raises during build_runtime_graph.
log = logging.getLogger(__name__)

# Per-process rate limit for graph_build_decrypt_failed (rid -> monotonic ts).
_GRAPH_DECRYPT_WARN_LAST: dict[str, float] = {}
_GRAPH_DECRYPT_WARN_INTERVAL_SEC = 300.0


# temporal_next window. Records inserted within this window
# in the same session are linked with a temporal_next edge.
TEMPORAL_NEXT_WINDOW = timedelta(minutes=5)


def recall(
    store: MemoryStore,
    cue_embedding: list[float],
    cue_text: str,
    session_id: str,
    budget_tokens: int = 1500,
    k_hits: int = 5,
    k_anti: int = 3,
    mode: str = "verbatim",
) -> RecallResponse:
    """Phase 1 baseline retrieval.

    Fetches top (k_hits + k_anti) by cosine similarity; treats the top k_hits as
    excitatory hits and the bottom k_anti as a naive anti-hit stub. Plan 02 will
    replace anti-hits with real contradicts-edge + AAAK-opposition logic.

    Every returned hit gets a provenance entry appended.

    R7: `mode` kwarg defaults to 'verbatim'. The baseline
    is the conservative fallback path (used by core.dispatch when the runtime
    graph is unavailable / build fails / store is empty). Defaulting to
    verbatim protects the North-Star ≥99% essential variable on the degraded
    path -- the user never silently lands on a schema-dominated surface even
    when the full pipeline is unreachable. Verbatim mode applies the same
    tier filter + schema exclusion as pipeline_recall verbatim mode so the
    contract on hits[] is identical regardless of which route core dispatched
    to. Concept mode preserves today's pure-cosine baseline (no filter).
    """
    raw = store.query_similar(cue_embedding, k=k_hits + k_anti)

    # R7: verbatim mode candidate filter on the baseline path.
    # tier='episodic' AND no pattern:* tag -- same exclusion contract as
    # pipeline_recall verbatim mode (R5). Also excludes D-09
    # tier='semantic_pruned' soft-deleted schemas naturally.
    if mode == "verbatim":
        raw = [
            (rec, score) for rec, score in raw
            if rec.tier == "episodic"
            and not any(t.startswith("pattern:") for t in (rec.tags or []))
        ]

    hits: list[MemoryHit] = []
    # (D5-01 effect c fix): collect provenance entries during the
    # hit-building loop, flush via ONE store.append_provenance_batch call
    # after the loop closes. Replaces the per-hit
    # `store.append_provenance(record.id, entry)` pattern that produced the
    # 64x wall-clock blow-up and rank perturbation under memory pressure
    # (pressplay 8 GB M1, 2026-04-19). Mirrors the L-02 fix already in
    # src/iai_mcp/pipeline.py::pipeline_recall (see D-SPEED SC-6).
    provenance_pending: list[tuple[UUID, dict]] = []
    now_iso = datetime.now(timezone.utc).isoformat()
    for record, score in raw[:k_hits]:
        hits.append(
            MemoryHit(
                record_id=record.id,
                score=float(score),
                reason=f"cosine {score:.3f}",
                literal_surface=record.literal_surface,
                adjacent_suggestions=[],  # Plan 03 fills per AUTIST-07
            )
        )
        # every recall appends a provenance entry; write is batched
        # end-of-loop to preserve rank stability (Plan 05-02 effect c fix).
        provenance_pending.append((
            record.id,
            {
                "ts": now_iso,
                "cue": cue_text,
                "session_id": session_id,
            },
        ))

    # flush: single merge_insert transaction replaces N read-modify-writes.
    # Diagnostic-only: never block the user's recall on a provenance-write failure
    # (Rule 1 -- matches pipeline_recall's defensive contract).
    if provenance_pending:
        try:
            store.append_provenance_batch(provenance_pending)
        except Exception:
            pass

anti_hits: list[MemoryHit] = []
|
||||
# Naive anti-hit stub: bottom-k of the same query. Plan 02 replaces with
|
||||
# real contradicts-edge + AAAK-opposition scoring.
|
||||
tail = raw[-k_anti:] if len(raw) >= k_anti else []
|
||||
for record, score in reversed(tail):
|
||||
anti_hits.append(
|
||||
MemoryHit(
|
||||
record_id=record.id,
|
||||
score=float(score),
|
||||
reason="low-similarity baseline anti-hit",
|
||||
literal_surface=record.literal_surface,
|
||||
adjacent_suggestions=[],
|
||||
)
|
||||
)
|
||||
|
||||
# on-read S4 viability check on the baseline recall
|
||||
# path too, so behaviour is consistent regardless of which recall route
|
||||
# core.py dispatches to.
|
||||
try:
|
||||
from iai_mcp.s4 import on_read_check
|
||||
s4_hints = on_read_check(store, hits, session_id=session_id)
|
||||
except Exception:
|
||||
s4_hints = []
|
||||
|
||||
response = RecallResponse(
|
||||
hits=hits,
|
||||
anti_hits=anti_hits,
|
||||
activation_trace=[h.record_id for h in hits],
|
||||
# ~4 chars per token heuristic; Plan 03 benchmark will use Anthropic count_tokens.
|
||||
budget_used=sum(len(h.literal_surface) for h in hits) // 4,
|
||||
hints=s4_hints,
|
||||
# surface mode on the baseline response too. The
|
||||
# baseline does not produce concept-mode patterns_observed (that's
|
||||
# the full pipeline's job — patterns_observed reflects displaced
|
||||
# candidates the rank stage would have surfaced; baseline has no
|
||||
# rank stage). Default [] is correct for both modes here.
|
||||
cue_mode=mode,
|
||||
patterns_observed=[],
|
||||
)
|
||||
|
||||
# (M2 LIVE prerequisite): emit kind='retrieval_used' so M2
|
||||
# precision@5 can be computed live from production emits, not seeded
|
||||
# events. Diagnostic-only: never block the recall path on emit failure.
|
||||
try:
|
||||
write_event(
|
||||
store,
|
||||
kind="retrieval_used",
|
||||
data={
|
||||
"hit_ids": [str(h.record_id) for h in hits],
|
||||
"query": cue_text,
|
||||
"used": len(hits) > 0,
|
||||
"budget_used": response.budget_used,
|
||||
"path": "baseline_recall",
|
||||
},
|
||||
severity="info",
|
||||
session_id=session_id,
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return response
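The batching pattern used in the recall path above can be sketched in isolation. This is a minimal illustration, not the project's implementation: `FakeStore` and `record_recall` are hypothetical names, and only the `append_provenance_batch(pending)` call shape is taken from the code above. It shows entries being collected during the loop and flushed in one call, with a write failure swallowed so the caller is never blocked.

```python
from datetime import datetime, timezone
from uuid import UUID, uuid4


class FakeStore:
    """Hypothetical stand-in for the real store; counts write transactions."""

    def __init__(self) -> None:
        self.transactions = 0
        self.provenance: dict[UUID, list[dict]] = {}

    def append_provenance_batch(self, pending: list[tuple[UUID, dict]]) -> None:
        # One transaction for the whole batch, however many entries it holds.
        self.transactions += 1
        for record_id, entry in pending:
            self.provenance.setdefault(record_id, []).append(entry)


def record_recall(store: FakeStore, record_ids: list[UUID], cue: str) -> None:
    """Collect one provenance entry per hit, then flush once after the loop."""
    pending: list[tuple[UUID, dict]] = []
    now_iso = datetime.now(timezone.utc).isoformat()
    for record_id in record_ids:
        pending.append((record_id, {"ts": now_iso, "cue": cue}))
    if pending:
        try:
            store.append_provenance_batch(pending)
        except Exception:
            # Diagnostic-only write: a failure must not break the recall.
            pass


store = FakeStore()
ids = [uuid4() for _ in range(5)]
record_recall(store, ids, cue="example cue")
```

Five hits produce a single store transaction here, where the per-hit pattern would have produced five.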


def reinforce_edges(
    store: MemoryStore, ids: list[UUID], delta: float = 0.1
) -> EdgeUpdate:
    """Hebbian boost on all pairwise edges among co-retrieved ids.

    Pairwise = C(n, 2) combinations. Delta 0.1 is the Phase-1 simple-increment
    default.
    """
    pairs: list[tuple[UUID, UUID]] = list(combinations(ids, 2))
    new_weights = store.boost_edges(pairs, delta=delta)
    # Canonical JSON-string keys (tuple keys are not JSON-serialisable).
    new_weights_str = {f"{a}|{b}": float(w) for (a, b), w in new_weights.items()}
    return EdgeUpdate(
        edges_boosted=len(pairs),
        pairs=pairs,
        new_weights=new_weights_str,
    )
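The pairwise expansion and the string-key conversion in `reinforce_edges` can be checked with the standard library alone. This is a small sketch under stated assumptions: the `new_weights` dict below is a hypothetical stand-in for whatever `store.boost_edges` returns, and the 0.1 values are illustrative.

```python
import json
from itertools import combinations
from math import comb
from uuid import uuid4

ids = [uuid4() for _ in range(4)]

# Every unordered pair among the co-retrieved ids: C(4, 2) = 6 pairs.
pairs = list(combinations(ids, 2))

# Hypothetical result shape: weights keyed by (src, dst) tuples.
new_weights = {pair: 0.1 for pair in pairs}

# json.dumps rejects tuple keys, so the keys are flattened to "a|b" strings.
new_weights_str = {f"{a}|{b}": float(w) for (a, b), w in new_weights.items()}
encoded = json.dumps(new_weights_str)

try:
    json.dumps(new_weights)
    tuple_keys_ok = True
except TypeError:
    tuple_keys_ok = False
```

The pair count grows quadratically with the number of ids, which is worth keeping in mind when the co-retrieved set is large.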


def contradict(
    store: MemoryStore,
    original_id: UUID,
    new_fact: str,
    new_embedding: list[float],
) -> ReconsolidationReceipt:
    """MEM-05 edge-based reconsolidation.

    Creates a new record with `new_fact` and adds a `contradicts` edge from
    original -> new. Does NOT rewrite the original record -- full amend-in-place
    is deferred to a future version.
    """
    original = store.get(original_id)
    if original is None:
        raise ValueError(f"unknown record {original_id}")
    # validate against the store's actual embedding dim,
    # not the legacy hardcoded EMBED_DIM. Migrations and env overrides both
    # rely on store.embed_dim as source of truth.
    target_dim = store.embed_dim
    if len(new_embedding) != target_dim:
        raise ValueError(
            f"new_embedding must be {target_dim}d, got {len(new_embedding)}"
        )
    now = datetime.now(timezone.utc)
    new_rec = MemoryRecord(
        id=uuid4(),
        tier=original.tier,
        literal_surface=new_fact,
        aaak_index="",
        embedding=list(new_embedding),
        community_id=original.community_id,
        centrality=0.0,
        detail_level=original.detail_level,
        pinned=False,
        stability=0.0,
        difficulty=0.0,
        last_reviewed=None,
        never_decay=(original.detail_level >= 3),
        never_merge=False,
        provenance=[{"ts": now.isoformat(), "cue": "contradict", "session_id": "-"}],
        created_at=now,
        updated_at=now,
        tags=["contradict"],
        # propagate the original record's language tag to the contradiction.
        # A contradiction is a linguistic amendment; it lives in the same
        # conversational register as the source.
        language=getattr(original, "language", "en") or "en",
    )
    # H-02: constitutional guard must run on EVERY write path, not just the
    # L0 seed. A Cyrillic/CJK `new_fact` without an explicit `raw:<lang>` tag
    # would otherwise land in literal_surface unguarded. Callers who intentionally
    # store non-English raw capture pre-tag the record via the MCP surface.
    #
    # note: once Task 2 ships enforce_language_tagged, call sites in
    # core.py + retrieve should migrate. For Phase-1 back-compat we keep
    # enforce_english_raw here so the H-02 Cyrillic-rejection test keeps passing.
    enforce_english_raw(new_rec)
    new_rec.aaak_index = generate_aaak_index(new_rec)
    store.insert(new_rec)
    store.add_contradicts_edge(original_id, new_rec.id)

    # monotropic proactive check fires only in high-focus
    # domains. Hints aren't surfaced via contradict() (its signature is fixed
    # to ReconsolidationReceipt), but events land in the events table so the
    # user can inspect them via `iai-mcp contradictions` in Plan 02-04.
    try:
        from iai_mcp.s4 import monotropic_proactive_check
        # Deliberately empty profile_state: callers of contradict() don't pass
        # one; core.py can inject a fuller state via its own wrapper once the
        # profile is wired to pipeline_recall.
        monotropic_proactive_check(store, new_rec, {}, session_id="-")
    except Exception:
        pass  # Rule 1: never block writes on S4 diagnostic path.

    return ReconsolidationReceipt(
        original_id=original_id,
        new_record_id=new_rec.id,
        edge_type="contradicts",
        ts=now,
    )


def link_temporal_next(
    store: MemoryStore,
    new_record: MemoryRecord,
    session_id: str,
) -> UUID | None:
    """Create a temporal_next edge and a record_inserted event.

    Reads the most recent `record_inserted` event (any record) from the events
    table. If that event happened within TEMPORAL_NEXT_WINDOW AND in the same
    session, create a `temporal_next` edge from the previous record to the new
    record.

    Then write a fresh `record_inserted` event marking this insertion.

    Returns the previous record UUID (the edge source) or None if no edge was
    created (either no prior insert or stale / cross-session).
    """
    now = datetime.now(timezone.utc)
    # Fetch up to 20 record_inserted events inside the window, then take the
    # most recent one that belongs to this session.
    prior_events = query_events(
        store, kind="record_inserted",
        since=now - TEMPORAL_NEXT_WINDOW, limit=20,
    )
    previous_id: UUID | None = None
    for ev in prior_events:
        if ev.get("session_id") != session_id:
            continue
        raw = ev["data"].get("record_id")
        if not raw:
            continue
        try:
            candidate = UUID(raw)
        except (TypeError, ValueError):
            continue
        if candidate == new_record.id:
            continue
        previous_id = candidate
        break  # events are newest-first

    if previous_id is not None:
        try:
            store.boost_edges(
                [(previous_id, new_record.id)],
                edge_type="temporal_next",
                delta=1.0,
            )
        except Exception:
            # Diagnostic only; don't block the write path on edge failure.
            pass

    write_event(
        store,
        kind="record_inserted",
        data={
            "record_id": str(new_record.id),
            "tier": new_record.tier,
        },
        severity="info",
        session_id=session_id,
        source_ids=[new_record.id],
    )
    return previous_id


def _make_graph_sync_hook(G):
    """Factory for the store -> graph mutation callback.

    The returned callable dispatches on ``op`` (insert|update|delete) and
    mutates ``G`` (a NetworkX Graph) in place. On an unknown op or any
    payload shape error, the hook is a quiet no-op. The store's
    try/except surface turns exceptions into stderr events anyway, but
    we stay defensive here so hook-level bugs never reach the store.
    """
    def _hook(op: str, record) -> None:
        try:
            nid = str(record.id)
            if op == "delete":
                if nid in G.nodes:
                    G.remove_node(nid)
                return
            if op not in ("insert", "update"):
                # Unknown op: silently ignore. The store writes are
                # authoritative; unknown ops will be picked up on the next
                # full rebuild.
                return
            # insert and update carry the same node payload.
            payload = {
                "embedding": list(record.embedding),
                "surface": record.literal_surface,
                "centrality": float(record.centrality),
                "tier": record.tier,
                "pinned": bool(record.pinned),
                "tags": list(getattr(record, "tags", []) or []),
                "language": str(getattr(record, "language", "en") or "en"),
            }
            if op == "update" and nid in G.nodes:
                G.nodes[nid].update(payload)
            else:
                G.add_node(nid, **payload)
        except Exception:
            # Payload shape error: stay a quiet no-op, as documented above.
            return
    return _hook
|
||||
|
||||
|
||||
def build_runtime_graph(store: MemoryStore):
|
||||
"""Reconstruct MemoryGraph + CommunityAssignment + rich-club from LanceDB.
|
||||
|
||||
Called by core.py's `memory_recall` dispatch when the store is non-empty.
|
||||
(P4.A): the expensive pieces -- Leiden community
|
||||
detection + rich-club selection -- are cached to disk in
|
||||
``runtime_graph_cache.json`` keyed on the store's (records_count,
|
||||
edges_count, schema_version, embed_dim) tuple. Cache hit skips
|
||||
~230 ms of Leiden + rich-club work. MemoryGraph itself is rebuilt
|
||||
on every call from the LanceDB rows because caching it would
|
||||
require a non-JSON format for the NetworkX object.
|
||||
|
||||
(hot-path switch): every graph node carries the record's
|
||||
payload (embedding, surface, centrality, tier, pinned) as NetworkX
|
||||
node attributes. ``pipeline._read_record_payload`` reads from these
|
||||
attributes at seed + spread stages, eliminating the per-id
|
||||
``store.get`` LanceDB round-trips that dominated at N=1k
|
||||
(737 ms -> target ~20-30 ms). A ``_graph_sync_hook`` is registered
|
||||
on the store so insert/update/delete mirror their mutations to the
|
||||
in-RAM graph; hook failures are logged, never raised (write-path
|
||||
authoritative). On cache HIT the node_payload blob rehydrates the
|
||||
NetworkX attributes directly; MISS rebuilds them from the fresh
|
||||
store.all_records() walk that was already happening for the graph.
|
||||
|
||||
Returns (graph, assignment, rich_club).
|
||||
|
||||
Local imports keep the heavy graph/community modules out of Plan-01's
|
||||
hot path (core.py module-load time stays small).
|
||||
"""
|
||||
from iai_mcp.community import CommunityAssignment, detect_communities
|
||||
from iai_mcp.graph import MemoryGraph
|
||||
from iai_mcp.richclub import rich_club_nodes
|
||||
from iai_mcp import runtime_graph_cache
|
||||
|
||||
graph = MemoryGraph()
|
||||
|
||||
# try the on-disk cache before running Leiden + rich-club.
|
||||
# Cache-first so we can consult the v2 node_payload blob for free.
|
||||
cached = runtime_graph_cache.try_load(store)
|
||||
assignment = None
|
||||
rich_club = None
|
||||
cached_node_payload: dict[str, dict] | None = None
|
||||
# R2: cached max_degree rehydrates without re-walking the
|
||||
# NetworkX graph. Used as a defensive fallback if the live degree
|
||||
# walk below fails for any reason.
|
||||
cached_max_degree: int = 0
|
||||
if cached is not None:
|
||||
assignment, rich_club, cached_node_payload, cached_max_degree = cached
|
||||
|
||||
# Build nodes. If the cache gave us a node_payload blob AND the store
|
||||
# record count matches, reuse it — skips the encrypted LanceDB scan.
|
||||
# Otherwise fall through to the full row walk so node attrs stay
|
||||
# strictly derived from the authoritative store.
|
||||
records_tbl = store.db.open_table("records")
|
||||
records_count = int(records_tbl.count_rows())
|
||||
use_cached_payload = (
|
||||
cached_node_payload is not None
|
||||
and len(cached_node_payload) == records_count
|
||||
)
|
||||
|
||||
if use_cached_payload:
|
||||
# Fast path: graph nodes + attributes come from the cache JSON.
|
||||
for nid, payload in cached_node_payload.items():
|
||||
# MemoryGraph.add_node has a fixed signature; use it for
|
||||
# topology, then pour the full payload into the NetworkX
|
||||
# node attribute dict.
|
||||
graph.add_node(
|
||||
UUID(nid),
|
||||
community_id=None,
|
||||
embedding=list(payload.get("embedding") or []),
|
||||
)
|
||||
graph._nx.nodes[nid].update({
|
||||
"embedding": list(payload.get("embedding") or []),
|
||||
"surface": payload.get("surface", ""),
|
||||
"centrality": float(payload.get("centrality") or 0.0),
|
||||
"tier": payload.get("tier", "episodic"),
|
||||
"pinned": bool(payload.get("pinned", False)),
|
||||
"tags": list(payload.get("tags") or []),
|
||||
"language": str(payload.get("language", "en") or "en"),
|
||||
})
|
||||
node_payload_for_cache = cached_node_payload
|
||||
else:
|
||||
# MISS path: walk the records table, attach payload at
|
||||
# graph.add_node time, and remember the payload so we can
|
||||
# persist it into the cache below.
|
||||
df = records_tbl.to_pandas()
|
||||
node_payload_for_cache = {}
|
||||
decrypt_fail_events = 0
|
||||
decrypt_fail_unique: set[str] = set()
|
||||
for _, row in df.iterrows():
|
||||
rid = UUID(row["id"])
|
||||
community_id = (
|
||||
UUID(row["community_id"])
|
||||
if row["community_id"]
|
||||
else None
|
||||
)
|
||||
embedding = (
|
||||
list(row["embedding"])
|
||||
if row["embedding"] is not None
|
||||
else [0.0] * EMBED_DIM
|
||||
)
|
||||
# literal_surface is AES-GCM encrypted at rest.
|
||||
# Decrypt here via the store's helper so the graph payload
|
||||
# carries plaintext the pipeline can use directly.
|
||||
literal_raw = row.get("literal_surface") or ""
|
||||
try:
|
||||
from iai_mcp.crypto import is_encrypted
|
||||
if is_encrypted(literal_raw):
|
||||
literal_raw = store._decrypt_for_record(rid, literal_raw)
|
||||
except Exception:
|
||||
# Plan 07.11-02 / (V2-03 fix): a decrypt failure here
|
||||
# used to assign ``literal_raw = ""`` and then fall through
|
||||
# to update the live NetworkX node + persist to
|
||||
# ``node_payload_for_cache``. That empty-surface payload
|
||||
# then poisoned the on-disk runtime_graph_cache, and on
|
||||
# warm-restart pipeline._read_record_payload happily
|
||||
# returned ``literal_surface=""`` claiming success —
|
||||
# silent corruption of verbatim recall.
|
||||
#
|
||||
# Skip-the-node approach (chosen over the _decrypt_failed
|
||||
# sentinel-flag because it produces the smallest disk
|
||||
# footprint and the simplest invariant: "the cache
|
||||
# contains only records whose surface successfully
|
||||
# decrypted"). The pipeline read path falls back to
|
||||
# store.get(rid) which has its own retry semantics in
|
||||
# crypto.py.
|
||||
#
|
||||
# Tail-end mandate: per-record ``graph_build_decrypt_failed``
|
||||
# warnings are rate-limited (default 300s) so wrong-key floods
|
||||
# do not spam launchd stderr; a per-build summary still fires.
|
||||
rid_s = str(rid)
|
||||
decrypt_fail_events += 1
|
||||
decrypt_fail_unique.add(rid_s)
|
||||
now_m = time.monotonic()
|
||||
last_m = _GRAPH_DECRYPT_WARN_LAST.get(rid_s, 0.0)
|
||||
if now_m - last_m >= _GRAPH_DECRYPT_WARN_INTERVAL_SEC:
|
||||
_GRAPH_DECRYPT_WARN_LAST[rid_s] = now_m
|
||||
log.warning(
|
||||
"graph_build_decrypt_failed",
|
||||
extra={"record_id": rid_s},
|
||||
)
|
||||
continue
|
||||
|
||||
tier = row.get("tier") or "episodic"
|
||||
centrality = float(row.get("centrality") or 0.0)
|
||||
pinned = bool(row.get("pinned") or False)
|
||||
# tags travel on graph nodes so the rank stage's
|
||||
# SimpleRecordView carries tags for profile_modulation_for_record
|
||||
# without needing a store.get fallback in the hot path.
|
||||
tags_raw = row.get("tags_json") or "[]"
|
||||
try:
|
||||
import json as _json
|
||||
tags_list = _json.loads(tags_raw) if isinstance(tags_raw, str) else list(tags_raw)
|
||||
if not isinstance(tags_list, list):
|
||||
tags_list = []
|
||||
except Exception:
|
||||
tags_list = []
|
||||
language = str(row.get("language") or "en")
|
||||
|
||||
graph.add_node(
|
||||
rid,
|
||||
community_id=community_id,
|
||||
embedding=embedding,
|
||||
)
|
||||
# Plan 05-12/05-13: attach record payload to the NetworkX node dict.
|
||||
graph._nx.nodes[str(rid)].update({
|
||||
"embedding": list(embedding),
|
||||
"surface": str(literal_raw),
|
||||
"centrality": centrality,
|
||||
"tier": str(tier),
|
||||
"pinned": pinned,
|
||||
"tags": list(tags_list),
|
||||
"language": language,
|
||||
})
|
||||
node_payload_for_cache[str(rid)] = {
|
||||
"embedding": list(embedding),
|
||||
"surface": str(literal_raw),
|
||||
"centrality": centrality,
|
||||
"tier": str(tier),
|
||||
"pinned": pinned,
|
||||
"tags": list(tags_list),
|
||||
"language": language,
|
||||
}
|
||||
|
||||
if decrypt_fail_events > 0:
|
||||
log.warning(
|
||||
"graph_build_decrypt_failed_summary",
|
||||
extra={
|
||||
"unique_records": len(decrypt_fail_unique),
|
||||
"total_skip_events": decrypt_fail_events,
|
||||
},
|
||||
)
|
||||
|
||||
edges_df = store.db.open_table("edges").to_pandas()
|
||||
for _, row in edges_df.iterrows():
|
||||
graph.add_edge(
|
||||
UUID(row["src"]),
|
||||
UUID(row["dst"]),
|
||||
weight=float(row["weight"]),
|
||||
edge_type=row["edge_type"],
|
||||
)
|
||||
|
||||
# R2: cache the maximum graph degree so the rank stage
|
||||
# can normalise log(1+deg) into [0,1] (sample-rank-comparable to
|
||||
# cosine; W_DEGREE * deg_norm bounded by W_DEGREE itself instead of
|
||||
# by an unbounded log term that scales with hub connectivity).
|
||||
# Computed once per build; rehydrated from disk on warm starts via
|
||||
# the runtime_graph_cache.json payload. Defensive: fall back to the
|
||||
# cached value if the live degree() walk fails for any reason — and
|
||||
# never let a bare AttributeError reach the rank stage.
|
||||
try:
|
||||
deg_values = [d for _, d in graph._nx.degree()]
|
||||
max_degree = max(deg_values) if deg_values else 0
|
||||
except Exception:
|
||||
max_degree = cached_max_degree
|
||||
if max_degree == 0 and cached_max_degree > 0:
|
||||
# Live walk produced 0 (no edges yet) but the cache held a real
|
||||
# value — prefer the cached value. Triggers when an upstream
|
||||
# path stripped edges before the rebuild reached us.
|
||||
max_degree = cached_max_degree
|
||||
graph._max_degree = int(max_degree)
|
||||
|
||||
# Run (or reuse cached) Leiden + rich-club.
|
||||
if assignment is None:
|
||||
assignment = detect_communities(graph, prior=None)
|
||||
rich_club = rich_club_nodes(graph, percent=0.10)
|
||||
|
||||
# compute betweenness centrality ONCE per build
|
||||
# and attach to every node as a NetworkX attribute so the rank stage
|
||||
# can read it O(1) instead of calling graph.centrality() on every
|
||||
# recall (the pre-05-13 hot path). Cache HIT path already rehydrated
|
||||
# centrality from node_payload into node attrs above; we only
|
||||
# (re)compute when the cache payload is absent / stale or when
|
||||
# node_payload centrality values are all-zero placeholders.
|
||||
needs_centrality = True
|
||||
if use_cached_payload and cached_node_payload is not None:
|
||||
# If the cache was written AFTER 05-13 the per-node centrality
|
||||
# floats are real (possibly non-zero). If every value is exactly
|
||||
# 0.0 the cache was written pre-05-13 shape — recompute to
|
||||
# populate the live graph, then a subsequent save() below will
|
||||
# upgrade the cache.
|
||||
any_nonzero = any(
|
||||
float(p.get("centrality") or 0.0) != 0.0
|
||||
for p in cached_node_payload.values()
|
||||
)
|
||||
needs_centrality = not any_nonzero
|
||||
if needs_centrality:
|
||||
try:
|
||||
centrality_map = graph.centrality()
|
||||
for rid, cval in centrality_map.items():
|
||||
nid_str = str(rid)
|
||||
if nid_str in graph._nx.nodes:
|
||||
graph._nx.nodes[nid_str]["centrality"] = float(cval)
|
||||
if (
|
||||
node_payload_for_cache is not None
|
||||
and nid_str in node_payload_for_cache
|
||||
):
|
||||
node_payload_for_cache[nid_str]["centrality"] = (
|
||||
float(cval)
|
||||
)
|
||||
except Exception:
|
||||
# Defensive: centrality is a ranking signal, not a
|
||||
# correctness invariant; fall back to zeros on failure.
|
||||
for nid_str in graph._nx.nodes:
|
||||
graph._nx.nodes[nid_str].setdefault("centrality", 0.0)
|
||||
|
||||
# Persist — fresh build, or cache was legacy 05-09 / 05-12 shape.
|
||||
if cached_node_payload is None or needs_centrality:
|
||||
runtime_graph_cache.save(
|
||||
store, assignment, rich_club,
|
||||
node_payload=node_payload_for_cache,
|
||||
# R2: max_degree travels with assignment + rich_club
|
||||
# so warm-start build_runtime_graph rehydrates without recompute.
|
||||
max_degree=int(getattr(graph, "_max_degree", 0) or 0),
|
||||
)
|
||||
|
||||
# register the graph-sync hook so future insert/update/
|
||||
# delete calls mutate the live graph instead of diverging. The store
|
||||
# swallows hook exceptions so a buggy hook never breaks a write.
|
||||
try:
|
||||
store.register_graph_sync_hook(_make_graph_sync_hook(graph._nx))
|
||||
except Exception:
|
||||
# Older store without register_graph_sync_hook — this is a
|
||||
# defensive upgrade path; the graph just won't stay live-sync'd.
|
||||
pass
|
||||
|
||||
# R2 belt-and-braces: every code path above sets
|
||||
# graph._max_degree, but if some future refactor short-circuits
|
||||
# before reaching the live degree walk we still want the rank
|
||||
# stage's `getattr(graph, "_max_degree", 0)` to read a real int.
|
||||
if not hasattr(graph, "_max_degree"):
|
||||
graph._max_degree = 0
|
||||
|
||||
return graph, assignment, rich_club
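The comments above say the cached maximum degree lets the rank stage normalise log(1+deg) into [0, 1]. The rank stage itself is not shown here, so the following is only a sketch of that arithmetic under stated assumptions: `normalised_degree` is a hypothetical helper, and the `W_DEGREE` value is illustrative, not the project's.

```python
from math import log

W_DEGREE = 0.05  # hypothetical weight, for illustration only


def normalised_degree(deg: int, max_degree: int) -> float:
    """Map a node degree into [0, 1] via log(1 + deg) / log(1 + max_degree)."""
    if max_degree <= 0:
        # No edges in the graph (or no cached value): contribute nothing.
        return 0.0
    return min(1.0, log(1 + deg) / log(1 + max_degree))


# The degree term is bounded by W_DEGREE regardless of hub connectivity.
hub_bonus = W_DEGREE * normalised_degree(500, 500)
leaf_bonus = W_DEGREE * normalised_degree(1, 500)
```

Because the ratio never exceeds 1, a highly connected hub cannot outweigh the cosine term by degree alone, which is the property the R2 comments describe.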
35
src/iai_mcp/richclub.py
Normal file
@ -0,0 +1,35 @@
"""Rich-club pre-fetch (CONN-02).

Top 10% of nodes by centrality. Used by pipeline.pipeline_recall at stage 4
(union with 2-hop spread) and by Plan 03's session-start assembler to pre-warm
the Anthropic prompt cache with a stable global-hub set.

van den Heuvel & Sporns 2011 (J Neurosci 31:15775) observed that the top ~10%
of hub nodes handle ~69% of the network's shortest-path traffic. We use the
same percentile as the pre-fetch size.
"""
from __future__ import annotations

from math import ceil
from uuid import UUID

from iai_mcp.graph import MemoryGraph


def rich_club_nodes(graph: MemoryGraph, percent: float = 0.10) -> list[UUID]:
    """CONN-02: top `percent` fraction of nodes by centrality.

    - Empty graph -> [].
    - Non-empty graph -> at least 1 node (ceil) even if percent rounds to 0.
      A rich club of zero is useless at the pipeline's Stage 4 union step.
    - Deterministic tie-break: dict.items() preserves insertion order; sort
      is stable, so equal-centrality nodes keep their insertion ordering.
    """
    if graph.node_count() == 0:
        return []
    centrality = graph.centrality()
    if not centrality:
        return []
    k = max(1, ceil(len(centrality) * percent))
    ranked = sorted(centrality.items(), key=lambda kv: kv[1], reverse=True)
    return [node_id for node_id, _ in ranked[:k]]
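The selection rule in `rich_club_nodes` can be tried without a `MemoryGraph` by feeding it a plain centrality dict. This is a sketch: `top_fraction` is a hypothetical name and string node ids stand in for UUIDs.

```python
from math import ceil


def top_fraction(centrality: dict[str, float], percent: float = 0.10) -> list[str]:
    """Top `percent` of nodes by centrality, never fewer than one node."""
    if not centrality:
        return []
    k = max(1, ceil(len(centrality) * percent))
    # sorted() is stable even with reverse=True, so ties keep insertion order.
    ranked = sorted(centrality.items(), key=lambda kv: kv[1], reverse=True)
    return [node_id for node_id, _ in ranked[:k]]


small = top_fraction({"a": 0.5, "b": 0.9, "c": 0.9})
larger = top_fraction({str(i): float(i) for i in range(25)})
```

With three nodes the club has one member, and the tie between `b` and `c` resolves to `b` because it was inserted first. With 25 nodes, `ceil(2.5)` gives three members.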
642
src/iai_mcp/runtime_graph_cache.py
Normal file
@ -0,0 +1,642 @@
"""Plan 05-09 (P4.A): persist Leiden community assignment + rich-club
to disk so the first ``memory_recall`` call in a fresh core process
does not rebuild these expensive artefacts from scratch.

The Phase-1 ``retrieve.build_runtime_graph`` rebuilds everything on
every call:

    graph = MemoryGraph()              # ~100 ms to construct from rows
    detect_communities(graph)          # Leiden, ~200 ms at N=1k
    rich_club_nodes(graph, 0.10)       # ~20 ms

Phase-5 P4 measured first-call cold path at ~440 ms at N=1k. Caching
the *Leiden output* and the rich-club node list eliminates the two
expensive computations when the store has not changed. MemoryGraph
construction itself is cheap enough to rebuild per call; caching it
too would require pickle (the NetworkX graph is not JSON-friendly)
and the security-vs-speed trade-off is not worth it for ~100 ms.

**Invalidation** -- any of these triggers a rebuild:

  - Record count changed (user saved / consolidated / merged)
  - Edge count changed (Hebbian reinforcement or contradiction added)
  - SCHEMA_VERSION_CURRENT bumped (store migrated)
  - store.embed_dim changed (user swapped embedder; Plan 05-08)
  - CACHE_VERSION bumped (this module's on-disk format changed)

Any inconsistency (corrupt JSON, unreadable file, unknown keys)
falls through to a clean rebuild. The cache is purely an optimisation;
the authoritative graph is always the LanceDB store.

**Write strategy**: every ``save()`` writes a ``.tmp`` file first then
``os.replace``s it over the real path -- atomic on POSIX. A crash
mid-write leaves either the old cache intact or no cache at all;
never a partially written file. No flush timer; the cache refreshes
on the next ``build_runtime_graph`` call when the key changes.

**Why JSON not pickle**: the cached payload is list-of-UUIDs,
list-of-floats and scalars -- all JSON-native after simple UUID→str
conversion. JSON avoids the arbitrary-code-execution risk of pickle
and makes the cache auditable (a user can cat the file to see what
the brain thinks its communities are).

Constitutional invariants:

- C3 (zero API): pure local JSON + filesystem operations.
- C6 (read-only against store): cache writes go to the cache file
  only, never to any LanceDB table.
"""
from __future__ import annotations

import json
import os
import sys
from datetime import datetime, timezone
from pathlib import Path
from typing import Any
from uuid import UUID

from iai_mcp.crypto import (
    CryptoKey,
    decrypt_field,
    encrypt_field,
    is_encrypted,
)
from iai_mcp.types import SCHEMA_VERSION_CURRENT


# Bump this whenever the on-disk cache shape changes. A mismatch
# forces every user on the old shape to rebuild -- safer than silently
# loading a file whose key contract has drifted.
#
# R2: bumped to "06-02-v1" -- payload now carries max_degree
# (one int) so the rank stage can normalise log(1+deg) by log(1+max_deg)
# without re-walking the live graph on every recall. Old caches lacking
# the field are invalidated cleanly by the version bump and rebuild on
# the next build_runtime_graph call.
#
# W3 / bumped to "07-09-v3" -- cache file is now
# AES-256-GCM-wrapped. Old "06-02-v1" caches that pre-date 07.9 are
# treated as legacy plaintext: read once, lazily re-saved as ciphertext
# on first warm-start under 07.9, then never read again.
CACHE_VERSION: str = "07-09-v3"
LEGACY_CACHE_VERSION_PLAINTEXT: str = "06-02-v1"

# AES-GCM associated data (AD): binds the ciphertext to this format and
# version. A bytewise tampering attempt that swaps the file with a
# v06-02-v1 plaintext or any other stream fails the decrypt tag check.
_CACHE_AAD: bytes = b"runtime-graph-cache:v3"

CACHE_FILENAME: str = "runtime_graph_cache.json"

# Size cap for the on-disk cache. When the encoded payload exceeds this,
# ``save`` drops ``node_payload`` (the large per-record embedding map) and
# writes only ``assignment + rich_club``. Cold-start ``build_runtime_graph``
# rehydrates the node payload from the LanceDB store on the next recall;
# the cache remains advisory. 10 MiB holds the Leiden + rich-club artefacts
# for a ~50k-record store comfortably while keeping cold-start load under
# the session-start token budget.
MAX_CACHE_BYTES: int = 10 * 1024 * 1024


def _cache_path(store: Any) -> Path:
    """Cache file lives next to the LanceDB directory so it travels with
    the store on backup / move. One cache file per MemoryStore."""
    root = getattr(store, "root", None)
    if root is None:
        root = Path.cwd()
    return Path(root) / CACHE_FILENAME
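The write strategy described in the module docstring (write a `.tmp` file, then `os.replace` it over the real path) can be sketched as follows. This is not the module's `save()`: `atomic_write_json` is a hypothetical helper, the payload is plaintext JSON with no encryption, and the `fsync` call is an addition of this sketch that the source does not mention.

```python
import json
import os
from pathlib import Path


def atomic_write_json(path: Path, payload: dict) -> None:
    """Write `payload` to `path` so readers never see a partial file."""
    tmp = path.with_suffix(path.suffix + ".tmp")
    with open(tmp, "w", encoding="utf-8") as fh:
        json.dump(payload, fh)
        fh.flush()
        # Assumption of this sketch: force bytes to disk before the rename.
        os.fsync(fh.fileno())
    # os.replace swaps the file in one step and overwrites any old cache.
    os.replace(tmp, path)
```

A crash before `os.replace` leaves the previous cache file untouched and at worst an orphaned `.tmp` file, which matches the "old cache intact or no cache at all" guarantee stated in the docstring.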


def _cache_encryption_key(store: Any) -> bytes:
    """Phase 07.9 W3 / 32-byte AES key for the runtime-graph-cache
    sidecar. Reuses the store's already-cached key whenever possible to
    avoid a second keyring round-trip. Falls back to a fresh CryptoKey
    lookup keyed on the store's user_id (or "default") when the store
    doesn't expose a cached key -- the same passphrase / keyring contract
    applies, so the resolved key is identical.
    """
    # MemoryStore caches its key after the first encryption call
    # (store.py:_key()); that's the cheapest path. Defensive getattr
    # so this module stays usable from non-store call sites in tests.
    cached_via_store = getattr(store, "_crypto_key", None)
    if isinstance(cached_via_store, (bytes, bytearray)) and len(cached_via_store) == 32:
        return bytes(cached_via_store)
    if hasattr(store, "_key") and callable(store._key):
        try:
            key = store._key()
            if isinstance(key, (bytes, bytearray)) and len(key) == 32:
                return bytes(key)
        except Exception:
            pass
    user_id = getattr(store, "user_id", "default") or "default"
    return CryptoKey(user_id=user_id).get_or_create()


def _cache_key(store: Any) -> tuple:
    """Identity tuple for "the cached graph is still correct for this
    store state". Any change to a component invalidates the cache.

        (records_count, edges_count, schema_version, embed_dim, cache_version)

    A table that cannot be counted contributes -1 for its count.
    """
    try:
        records_count = int(store.db.open_table("records").count_rows())
    except Exception:
        records_count = -1
    try:
        edges_count = int(store.db.open_table("edges").count_rows())
    except Exception:
        edges_count = -1
    embed_dim = int(getattr(store, "embed_dim", 0))
    return (
        records_count,
        edges_count,
        SCHEMA_VERSION_CURRENT,
        embed_dim,
        CACHE_VERSION,
    )
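How the key is compared on load is not shown in this part of the file, so the following is only a sketch of one plausible check. `key_matches` is a hypothetical helper and the schema version `5` is an illustrative value. The point it makes is generic: a tuple written to JSON comes back as a list, so the comparison has to normalise both sides.

```python
import json


def key_matches(stored: object, current: tuple) -> bool:
    """True when the key read from disk equals the freshly computed key."""
    if not isinstance(stored, (list, tuple)):
        return False  # corrupt or missing key: treat as a cache miss
    return tuple(stored) == tuple(current)


# (records_count, edges_count, schema_version, embed_dim, cache_version)
current = (1200, 3400, 5, 384, "07-09-v3")

# Round-trip through JSON the way a cache file would.
on_disk = json.loads(json.dumps({"key": current}))["key"]

after_insert = (1201, 3400, 5, 384, "07-09-v3")
```

Comparing `on_disk == current` directly would be false because a list never equals a tuple, which would silently turn every load into a miss.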


# ------------------------------------------------------------ JSON encode/decode


def _encode_assignment(assignment: Any) -> dict:
    """Serialise CommunityAssignment to a JSON-friendly dict.

    node_to_community and mid_regions have UUID keys; community_centroids
    is {UUID: [float]}. UUIDs are stringified; floats stay native.
    """
    return {
        "node_to_community": {
            str(leaf): str(comm)
            for leaf, comm in getattr(assignment, "node_to_community", {}).items()
        },
        "community_centroids": {
            str(comm): list(vec)
            for comm, vec in getattr(assignment, "community_centroids", {}).items()
        },
        "modularity": float(getattr(assignment, "modularity", 0.0)),
        "backend": str(getattr(assignment, "backend", "flat")),
        "top_communities": [str(c) for c in getattr(assignment, "top_communities", [])],
        "mid_regions": {
            str(comm): [str(m) for m in members]
            for comm, members in getattr(assignment, "mid_regions", {}).items()
        },
    }
|
||||
|
||||
|
||||
def _decode_assignment(raw: dict) -> Any:
    """Inverse of _encode_assignment. Imports CommunityAssignment lazily
    so this module does not pull in the community layer for callers that
    only want to poke the cache file."""
    from iai_mcp.community import CommunityAssignment

    return CommunityAssignment(
        node_to_community={
            UUID(leaf): UUID(comm)
            for leaf, comm in raw.get("node_to_community", {}).items()
        },
        community_centroids={
            UUID(comm): list(vec)
            for comm, vec in raw.get("community_centroids", {}).items()
        },
        modularity=float(raw.get("modularity", 0.0)),
        backend=str(raw.get("backend", "flat")),
        top_communities=[UUID(c) for c in raw.get("top_communities", [])],
        mid_regions={
            UUID(comm): [UUID(m) for m in members]
            for comm, members in raw.get("mid_regions", {}).items()
        },
    )


def _encode_rich_club(rich_club: Any) -> list[str]:
    return [str(u) for u in (rich_club or [])]


def _decode_rich_club(raw: Any) -> list[UUID]:
    return [UUID(u) for u in (raw or [])]


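The encode/decode pairs above all rely on the same round trip: UUIDs become strings on the way into JSON and are parsed back on the way out. A small self-contained sketch of that pattern follows; `encode_map` and `decode_map` are hypothetical helper names, not functions from this module.

```python
import json
from uuid import UUID, uuid4


def encode_map(mapping: dict[UUID, list[UUID]]) -> dict[str, list[str]]:
    # JSON object keys must be strings, so UUIDs are stringified on the way out.
    return {str(k): [str(m) for m in members] for k, members in mapping.items()}


def decode_map(raw: dict[str, list[str]]) -> dict[UUID, list[UUID]]:
    # UUID(...) raises ValueError on malformed input, which a caller can
    # treat as "cache unusable".
    return {UUID(k): [UUID(m) for m in members] for k, members in raw.items()}


original = {uuid4(): [uuid4(), uuid4()]}
wire = json.dumps(encode_map(original))      # plain text, safe to write to disk
restored = decode_map(json.loads(wire))
assert restored == original
```

Because a malformed key raises rather than silently producing a wrong mapping, a loader can wrap the decode step in one `try` and fall back to a rebuild.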
# ----------------------------------------------------------------- size estimator
#
# W2 / D-07, D-08: bound peak RSS in save() by estimating the
# serialised byte cost without materialising the full JSON string.
#
# The legacy save() path encoded the cache payload up to 4 times -- once
# for the initial size check and once after each progressive drop. On
# cold-start graphs (Leiden -> ~1 community per record),
# assignment.community_centroids balloons with len(records) * 384-dim
# float vectors and a single encode call materialises a multi-GB
# intermediate Python string (py-spy confirmed RSS 7.6GB on cold start).
#
# The estimator overshoots rather than undershoots: false-positive drops
# are safe (cache stays advisory; cold-start rebuilds from the live store),
# false-negative under-drops produce the very bug we are fixing. The
# constants below are upper bounds for the JSON-encoded byte width of each
# field shape.

# JSON overhead per dict entry: 4 punctuation chars (quotes, colon, comma)
# + variable-length key + value. We track the punctuation explicitly so
# the per-field constants below are pure VALUE budgets.
_JSON_DICT_ENTRY_OVERHEAD: int = 4

# node_payload entry value width upper bound. Shape:
#   {"embedding": [<384 float>], "surface": str(<=256), "centrality": float,
#    "tier": str(<=24), "pinned": bool, "tags": [<=16 short strings],
#    "language": str(<=8)}
# 384-dim float vector dominates: each float worst-case ~24 bytes
# ("-1.2345678901234567,") -> 384*24 = 9216. Plus structural keys / quotes
# ~256. Plus other fields ~512. Round to a comfortable ceiling.
_NODE_PAYLOAD_BYTES_PER_RECORD: int = 10240

# community_centroids entry value width upper bound. Shape:
#   {"<UUID-36>": [<384 float>]}
# 384-dim float same calculus as node_payload embedding -> 9216. Plus
# 36-char UUID quoted -> 38. Plus brackets / commas -> ~16. Round up.
_CENTROID_BYTES_PER_RECORD: int = 9472

# mid_regions entry value width upper bound. Shape:
#   {"<UUID-36>": ["<UUID-36>", ..., "<UUID-36>"]}
# Variable length; bound by typical mid-region size <= 32 UUIDs * 38 bytes
# = 1216, plus brackets / commas -> 1280.
_MID_REGION_BYTES_PER_RECORD: int = 1280

# rich_club is a list of UUID strings: 38 bytes per entry.
_RICH_CLUB_BYTES_PER_ENTRY: int = 38

# Top-level scaffolding (cache_version + key + saved_at + max_degree +
# backend / modularity / top_communities / node_to_community + structural
# JSON braces). Conservative upper bound; node_to_community at scale is
# the variable component.
_BASE_SCAFFOLD_BYTES: int = 4096


def _estimate_serialised_bytes(data: dict) -> int:
    """Upper-bound estimate of the encoded ``data`` dict's byte width
    without actually serialising it.

    Walks the cache payload shape and sums per-field worst-case JSON byte
    widths. Overshoots rather than undershoots so the caller's drop loop
    is conservative (false-positive drops are safe; the cache is advisory
    and cold-start rebuilds from the live store).

    Used by ``save`` before every iteration of the drop loop -- replaces
    the legacy len-of-encoded round-trip which materialised the full
    JSON string up to 4 times per save.
    """
    total = _BASE_SCAFFOLD_BYTES

    # node_payload: dict[str, dict] of per-record graph attributes.
    np_block = data.get("node_payload") or {}
    if isinstance(np_block, dict):
        total += len(np_block) * (
            _NODE_PAYLOAD_BYTES_PER_RECORD + _JSON_DICT_ENTRY_OVERHEAD + 38
        )

    # node_to_community + community_centroids + mid_regions live under
    # data["assignment"]. Encoded shape is what _encode_assignment returns.
    assignment_block = data.get("assignment") or {}
    if isinstance(assignment_block, dict):
        ntc = assignment_block.get("node_to_community") or {}
        if isinstance(ntc, dict):
            # Each entry: "<UUID-36>": "<UUID-36>" (_encode_assignment
            # stringifies both sides). Two quoted UUIDs (38 bytes each)
            # plus the ": " and ", " separators -> 80 bytes worst case.
            total += len(ntc) * 80

        centroids = assignment_block.get("community_centroids") or {}
        if isinstance(centroids, dict):
            total += len(centroids) * (
                _CENTROID_BYTES_PER_RECORD + _JSON_DICT_ENTRY_OVERHEAD
            )

        mid = assignment_block.get("mid_regions") or {}
        if isinstance(mid, dict):
            total += len(mid) * (
                _MID_REGION_BYTES_PER_RECORD + _JSON_DICT_ENTRY_OVERHEAD
            )

        top = assignment_block.get("top_communities") or []
        if isinstance(top, list):
            # Each entry is a quoted UUID string (38 bytes) plus ", ".
            total += len(top) * 40

    rich_club = data.get("rich_club") or []
    if isinstance(rich_club, list):
        total += len(rich_club) * _RICH_CLUB_BYTES_PER_ENTRY

    return total


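One way to sanity-check an upper-bound estimator of this kind is to compare it against a real `json.dumps` on synthetic data. The sketch below does that for a centroid-shaped dict only. `estimate_centroids` and its constants are assumptions chosen for the illustration (24 characters is the longest `repr` of a Python float, plus 2 for the default `", "` separator); they are not the module's own constants.

```python
import json
import random
from uuid import uuid4

EMBED_DIM = 384
FLOAT_BYTES = 26   # assumed worst case: 24-char float repr + ", "
UUID_BYTES = 40    # 36 chars + 2 quotes + ", " entry separator


def estimate_centroids(centroids: dict) -> int:
    # Per entry: quoted key, ": ", two brackets, worst-case floats.
    per_entry = UUID_BYTES + 2 + 2 + EMBED_DIM * FLOAT_BYTES
    return 2 + len(centroids) * per_entry   # 2 for the outer braces


rng = random.Random(0)
centroids = {
    str(uuid4()): [rng.uniform(-1.0, 1.0) for _ in range(EMBED_DIM)]
    for _ in range(50)
}

estimate = estimate_centroids(centroids)   # O(1) arithmetic, no big string
actual = len(json.dumps(centroids))        # the expensive path being avoided
assert estimate >= actual                  # overshoot is the safe direction
```

The estimate costs a multiplication, while the check it replaces allocates the whole encoded string, which is the memory spike the comments above describe.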
# ------------------------------------------------------------ public API


def try_load(store: Any) -> tuple | None:
    """Return the cached ``(assignment, rich_club, node_payload, max_degree)``
    tuple if the on-disk file is present, readable, and keyed to the
    current store state. Return ``None`` on any mismatch or error.

    The third element is the ``node_payload`` blob
    (``dict[str, dict]``: UUID-str -> {embedding, surface, centrality,
    tier, pinned, tags, language}) so cold-start ``build_runtime_graph``
    can rehydrate NetworkX node attributes without re-walking the
    encrypted records table.

    R2: the fourth element is ``max_degree`` (one int: the
    maximum NetworkX degree in the live graph at save() time). Used by
    the pipeline rank stage to normalise log(1+deg) into [0,1] without
    re-walking the graph. A missing / malformed value coerces to 0; the
    rank stage falls back to deg_norm=0.0 when max_degree==0 (cosine
    carries the recall on its own at the cold-start scale).

    Callers treat ``None`` as "rebuild from the live graph", never as
    an error condition. The cache is advisory.

    W3: the file format is now AES-256-GCM-wrapped JSON.
    A pre-07.9 plaintext file (cache_version="06-02-v1") is read once
    and re-saved under the new ciphertext format on the same call
    (one-cycle lazy migration). Any decrypt failure (wrong key, tampered
    file) returns None and the caller rebuilds from store.
    """
    path = _cache_path(store)
    if not path.exists():
        return None
    try:
        raw_text = path.read_text(encoding="utf-8")
    except Exception:
        return None

    legacy_v2_plaintext = False
    if is_encrypted(raw_text):
        # v3 ciphertext path.
        try:
            key = _cache_encryption_key(store)
            plaintext_json = decrypt_field(raw_text, key, _CACHE_AAD)
            data = json.loads(plaintext_json)
        except Exception as exc:
            try:
                sys.stderr.write(
                    '{"event":"runtime_graph_cache_decrypt_failed","error":'
                    + json.dumps(str(exc))
                    + '}\n'
                )
            except Exception:
                pass
            return None
    else:
        # Legacy plaintext path. Accept ONLY the documented v2 cache
        # version; anything else falls through to a clean rebuild
        # (the file is not necessarily ours).
        try:
            data = json.loads(raw_text)
        except Exception:
            return None
        if not isinstance(data, dict):
            return None
        if data.get("cache_version") == LEGACY_CACHE_VERSION_PLAINTEXT:
            legacy_v2_plaintext = True
        else:
            # Unknown format / version -- treat as no cache.
            return None

    if not isinstance(data, dict):
        return None
    if not legacy_v2_plaintext and data.get("cache_version") != CACHE_VERSION:
        return None
    saved_key = tuple(data.get("key", []))
    current_key = _cache_key(store)
    if legacy_v2_plaintext:
        # Legacy v2 caches embed CACHE_VERSION="06-02-v1" in the last
        # key slot; compare against an expected key that swaps the
        # current CACHE_VERSION for the legacy one. All other
        # invariants (records_count, edges_count, schema_version,
        # embed_dim) MUST still match -- anything else means the cache
        # is stale and we rebuild from store.
        expected_legacy_key = tuple(
            list(current_key)[:-1] + [LEGACY_CACHE_VERSION_PLAINTEXT]
        )
        if saved_key != expected_legacy_key:
            return None
    else:
        if saved_key != current_key:
            return None

    try:
        assignment = _decode_assignment(data["assignment"])
        rich_club = _decode_rich_club(data.get("rich_club"))
        node_payload_raw = data.get("node_payload")
        node_payload: dict[str, dict] | None
        if isinstance(node_payload_raw, dict):
            # Shallow dict-of-dicts; embedding list[float] round-trips
            # through JSON natively.
            #
            # Plan 07.11-02 (V2-03 fix): defensively drop
            # poisoned entries on rehydrate. Even though Plan 07.11-02's
            # retrieve.py fix prevents future writes of empty-surface
            # entries, an existing on-disk cache from before this fix
            # may still contain them. Belt-and-braces: rehydrate-side
            # filter ensures a poisoned cache from any source (legacy
            # write, future regression, manual tamper) cannot leak an
            # empty/None surface into the live graph.
            #
            # Drop rule: surface in (None, "") OR _decrypt_failed=True.
            # The structured event uses the same stderr-JSON idiom as
            # the existing runtime_graph_cache_decrypt_failed emission
            # at lines 376-383 -- runtime_graph_cache.py intentionally
            # bypasses logging because the logger's re-entrant import
            # path can deadlock during cache rehydrate at very-cold-start.
            node_payload = {}
            drop_count = 0
            for k, v in node_payload_raw.items():
                if not isinstance(v, dict):
                    continue
                surface = v.get("surface")
                if surface in (None, "") or v.get("_decrypt_failed"):
                    drop_count += 1
                    continue  # poisoned entry -- never expose as a "valid" record
                node_payload[str(k)] = dict(v)
            if drop_count > 0:
                try:
                    sys.stderr.write(
                        '{"event":"runtime_graph_cache_drop_poisoned_entry","count":'
                        + str(drop_count)
                        + '}\n'
                    )
                except Exception:
                    pass
        else:
            node_payload = None
        # R2: max_degree is one int. It never participates in
        # the iterative drop path because dropping it would save nothing
        # at the JSON byte-budget level.
        try:
            max_degree = int(data.get("max_degree", 0) or 0)
        except (TypeError, ValueError):
            max_degree = 0
    except Exception:
        return None

    if legacy_v2_plaintext:
        # W3 lazy migration: re-save the loaded
        # content under the new v3 encrypted format. Wrapped: a
        # migration write failure must not block the caller from
        # using the loaded values they already have in memory.
        try:
            save(
                store, assignment, rich_club,
                node_payload=node_payload, max_degree=max_degree,
            )
        except Exception:
            pass

    return assignment, rich_club, node_payload, max_degree


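The rehydrate-side filter in `try_load` can be illustrated in isolation. The sketch below is a simplified restatement under stated assumptions: `drop_poisoned` is a hypothetical name, the event name is shortened, and the output stream is injectable so the behaviour can be checked without touching the real stderr.

```python
import io
import json
import sys


def drop_poisoned(raw: dict, stream=sys.stderr) -> dict:
    """Keep only entries with a usable surface; report the drop count once."""
    kept: dict[str, dict] = {}
    dropped = 0
    for key, value in raw.items():
        if not isinstance(value, dict):
            continue  # malformed entry: skipped without counting, as in try_load
        if value.get("surface") in (None, "") or value.get("_decrypt_failed"):
            dropped += 1
            continue
        kept[str(key)] = dict(value)
    if dropped:
        # One JSON object per line keeps the event easy to grep and parse.
        stream.write(json.dumps({"event": "drop_poisoned_entry", "count": dropped}) + "\n")
    return kept


buf = io.StringIO()
clean = drop_poisoned(
    {
        "a": {"surface": "hello"},
        "b": {"surface": ""},
        "c": {"surface": "x", "_decrypt_failed": True},
        "d": "not-a-dict",
    },
    stream=buf,
)
assert list(clean) == ["a"]
assert json.loads(buf.getvalue())["count"] == 2
```

Filtering on read as well as on write means a stale or tampered file cannot reintroduce entries that the write path no longer produces.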
def save(
    store: Any,
    assignment: Any,
    rich_club: Any,
    node_payload: "dict[str, dict] | None" = None,
    max_degree: int = 0,
) -> bool:
    """Persist the cache atomically. Returns True on success, False on
    any write error. Errors are swallowed: the caller has freshly
    computed values in memory either way; a failed cache write is not
    a reason to break the recall path.

    ``node_payload`` persists the per-record graph-node
    attribute map (UUID-str -> {embedding: list[float], surface: str,
    centrality: float, tier: str, pinned: bool, tags: list[str],
    language: str}). Absent / None -> the cache still writes
    assignment + rich_club and the next cold start will rebuild the
    node payload from the live store walk. JSON-native shape
    (no binary serialisation) keeps the cache auditable.

    R2: ``max_degree`` (one int) is the maximum graph degree
    at save() time. Used by the rank stage to normalise log(1+deg) into
    [0,1] without re-walking the graph on every recall. Always present
    in the payload; it never participates in the iterative drop path
    (one int costs nothing against MAX_CACHE_BYTES).
    """
    path = _cache_path(store)
    tmp_path = path.with_suffix(path.suffix + ".tmp")
    # Normalise node_payload for JSON: stringify keys, list() embeddings.
    encoded_node_payload: dict[str, dict] | None = None
    if node_payload:
        encoded_node_payload = {}
        for k, v in node_payload.items():
            if not isinstance(v, dict):
                continue
            # Embeddings can be numpy float32 from LanceDB
            # rows; coerce to plain Python float so json.dump does not
            # trip on "Object of type float32 is not JSON serializable".
            raw_emb = v.get("embedding") or []
            # `centrality` is now betweenness, computed once
            # during build_runtime_graph and persisted here so warm starts
            # don't recompute it. Missing/None coerces to 0.0 (legacy
            # pre-05-13 pre-compute shape). `tags`/`language` persisted
            # so SimpleRecordView surfaces the full profile_modulation
            # input set without a store.get fallback.
            raw_tags = v.get("tags") or []
            encoded_node_payload[str(k)] = {
                "embedding": [float(x) for x in raw_emb],
                "surface": str(v.get("surface", "")),
                "centrality": float(v.get("centrality") or 0.0),
                "tier": str(v.get("tier", "episodic")),
                "pinned": bool(v.get("pinned", False)),
                "tags": [str(t) for t in raw_tags if t is not None],
                "language": str(v.get("language", "en") or "en"),
            }

    data = {
        "cache_version": CACHE_VERSION,
        "key": list(_cache_key(store)),
        "assignment": _encode_assignment(assignment),
        "rich_club": _encode_rich_club(rich_club),
        "node_payload": encoded_node_payload or {},
        # R2: max_degree is one int; it survives every iterative
        # drop step below because dropping it saves no measurable bytes.
        "max_degree": int(max_degree or 0),
        "saved_at": datetime.now(timezone.utc).isoformat(),
    }

    # Size guard: the previous single-drop path only trimmed
    # ``node_payload`` and shipped whatever remained, even when the bloat
    # lived elsewhere. On an all-isolated graph (0 edges) Leiden returns
    # one community per node and ``assignment.community_centroids`` alone
    # balloons to 70+ MiB (one 384-dim float vector per record).
    #
    # Drop candidates in decreasing marginal-value order. W2 /
    # D-07, D-08: estimate the encoded byte cost BEFORE materialising
    # the JSON string, so peak RSS during save matches the final on-disk
    # file size instead of the pre-drop full payload size. ``json.dumps``
    # is called AT MOST ONCE per ``save`` invocation, after all drop
    # decisions are made. The authoritative slim output of Leiden
    # (``node_to_community``, ``top_communities``, ``modularity``,
    # ``backend``) and the ``rich_club`` list always survive -- they are
    # cheap to encode and expensive to recompute from the live store.
    if _estimate_serialised_bytes(data) > MAX_CACHE_BYTES:
        # 1) node_payload: per-record blob, rebuildable from the live
        #    store walk on the next cold start.
        data["node_payload"] = {}
    if _estimate_serialised_bytes(data) > MAX_CACHE_BYTES:
        # 2) assignment.community_centroids: {UUID: [float; embed_dim]}.
        #    On sparse graphs this is the biggest single field. Leiden
        #    recomputes centroids on the next build.
        if isinstance(data.get("assignment"), dict):
            data["assignment"]["community_centroids"] = {}
    if _estimate_serialised_bytes(data) > MAX_CACHE_BYTES:
        # 3) assignment.mid_regions: {UUID: [UUID, ...]}. Smaller view;
        #    also recomputable.
        if isinstance(data.get("assignment"), dict):
            data["assignment"]["mid_regions"] = {}
    if _estimate_serialised_bytes(data) > MAX_CACHE_BYTES:
        # Still over the cap after dropping every advisory field. Prefer
        # a clean "give up" to shipping an oversized file; the caller
        # already has the in-memory values and the next build will
        # recompute everything from the live store.
        return False

    # Single final encode -- AT MOST ONE json.dumps per save() per D-10.
    serialised = json.dumps(data, ensure_ascii=False)

    # W3: encrypt the JSON payload before writing.
    # Same AES-256-GCM machinery + key as the LanceDB literal_surface
    # column. ASCII-only ciphertext (b64 envelope) lets us keep the
    # text-mode write path; on-disk plaintext canary is provably absent.
    try:
        key = _cache_encryption_key(store)
        ciphertext = encrypt_field(serialised, key, _CACHE_AAD)
    except Exception:
        # Encryption failure: skip the cache write rather than persist
        # plaintext on disk. Cache is advisory; recall path unaffected.
        try:
            sys.stderr.write(
                '{"event":"runtime_graph_cache_encrypt_failed"}\n'
            )
        except Exception:
            pass
        return False

    try:
        path.parent.mkdir(parents=True, exist_ok=True)
        with tmp_path.open("w", encoding="ascii") as f:
            f.write(ciphertext)
        os.replace(str(tmp_path), str(path))
        return True
    except Exception:
        try:
            if tmp_path.exists():
                tmp_path.unlink()
        except Exception:
            pass
        return False


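Two ideas in `save` generalise beyond this module: blanking advisory fields in a fixed order until a size estimate fits under a cap, and writing through a temporary file followed by `os.replace` so readers never see a half-written cache. The sketch below combines both under loud assumptions: `save_capped`, `DROP_ORDER`, and `rough_size` are hypothetical names, encryption is omitted, and the stand-in estimator serialises the payload, which the real estimator deliberately avoids.

```python
import json
import os
import tempfile
from pathlib import Path

# Advisory fields, most expendable first (hypothetical names for the sketch).
DROP_ORDER = ("node_payload", "centroids", "mid_regions")


def save_capped(path: Path, data: dict, cap: int, estimate) -> bool:
    """Blank advisory fields until estimate(data) fits cap, then write atomically."""
    for field in DROP_ORDER:
        if estimate(data) <= cap:
            break
        data[field] = {}                 # note: mutates the caller's dict
    if estimate(data) > cap:
        return False                     # give up rather than ship an oversized file
    tmp = path.with_suffix(path.suffix + ".tmp")
    try:
        tmp.write_text(json.dumps(data), encoding="utf-8")
        os.replace(tmp, path)            # atomic when tmp and path share a filesystem
        return True
    except OSError:
        tmp.unlink(missing_ok=True)
        return False


def rough_size(d: dict) -> int:
    # Stand-in estimator for the demo only; it serialises, unlike the real one.
    return len(json.dumps(d))


with tempfile.TemporaryDirectory() as tmpdir:
    target = Path(tmpdir) / "cache.json"
    payload = {
        "core": [1, 2, 3],
        "node_payload": {"k": "x" * 500},
        "centroids": {},
        "mid_regions": {},
    }
    ok = save_capped(target, payload, cap=200, estimate=rough_size)
    assert ok
    assert json.loads(target.read_text(encoding="utf-8"))["node_payload"] == {}
```

Returning `False` instead of raising matches the advisory role of the cache: the caller still holds the freshly computed values and simply rebuilds next time.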
def invalidate(store: Any) -> None:
    """Delete the cache file for ``store``. Safe when the file does not
    exist. Used by explicit ``needs_refresh`` signals and by tests that
    want a clean slate."""
    path = _cache_path(store)
    try:
        if path.exists():
            path.unlink()
    except Exception:
        pass
459
src/iai_mcp/s4.py
Normal file
@@ -0,0 +1,459 @@
"""S4 viability -- on-read consistency + monotropic proactive checks (MEM-08, D-17).

D-17 constitutional:
- (e) on-read consistency: runs inside `pipeline_recall` on top-K returned
  records. Pairwise cosine with ART vigilance ρ_s4=0.97 + `contradicts`
  edge lookup. Emits `s4_contradiction` events. Populates
  `RecallResponse.hints`.
- (f) monotropic proactive: only fires when profile.monotropism_depth[domain]
  > 0.7 AND new_record.detail_level >= 4. Scans within-domain only.
  Performance guard: if domain > 100 records, skip with warning event.

Plan 03-02 CONN-07 addition:
- `run_offline_pass(store)` -- new entry point, CALLED by the daemon /
  session_exit hook. Currently runs `sigma.compute_and_emit(store)` only;
  future plans append more offline-pass items here. Failures emit
  `kind="s4_error"` and never crash the pass.

Explicitly forbidden (D-17 negative assertions):
- NO `daily_scan` function (Ashby Requisite Variety violation).
- NO `session_exit_sweep` function (Anderson activation-based violation).

All detected contradictions go through `events.write_event` -- no .jsonl files
(D-STORAGE).
"""
from __future__ import annotations

from uuid import UUID

import numpy as np

from iai_mcp.events import write_event
from iai_mcp.store import MemoryStore
from iai_mcp.types import MemoryHit, MemoryRecord


# D-17(e) vigilance: 0.97 for near-duplicate contradiction detection.
# Stricter than write-path ρ=0.95: we only flag VERY close matches.
S4_VIGILANCE_RHO = 0.97

# D-17(f) performance guard: skip when domain has > this many records.
MONOTROPIC_MAX_PAIRWISE = 100

# D-17(f) monotropism-depth threshold.
S4_MONOTROPIC_THETA = 0.7


def _cosine(a: list[float], b: list[float]) -> float:
    """Cosine similarity in [-1, 1]. Returns 0.0 on zero-norm inputs."""
    av = np.asarray(a, dtype=np.float32)
    bv = np.asarray(b, dtype=np.float32)
    na = float(np.linalg.norm(av))
    nb = float(np.linalg.norm(bv))
    if na == 0.0 or nb == 0.0:
        return 0.0
    return float(np.dot(av, bv) / (na * nb))


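For readers without NumPy at hand, the same guard can be restated with the standard library only. This is an illustrative equivalent, not the module's implementation; `cosine` is a hypothetical name and it works in double precision rather than float32.

```python
import math


def cosine(a: list[float], b: list[float]) -> float:
    """Cosine similarity; 0.0 when either vector has zero norm."""
    dot = sum(x * y for x, y in zip(a, b))
    na = math.sqrt(sum(x * x for x in a))
    nb = math.sqrt(sum(y * y for y in b))
    if na == 0.0 or nb == 0.0:
        return 0.0   # avoid division by zero; a zero vector has no direction
    return dot / (na * nb)


assert math.isclose(cosine([1.0, 0.0], [2.0, 0.0]), 1.0)                  # scale is ignored
assert math.isclose(cosine([1.0, 0.0], [0.0, 3.0]), 0.0, abs_tol=1e-12)   # orthogonal
assert cosine([0.0, 0.0], [1.0, 1.0]) == 0.0                              # zero-norm guard
```

Returning 0.0 for a zero vector keeps such records below any positive vigilance threshold, so they are never flagged as near-duplicates.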
def on_read_check(
    store: MemoryStore,
    hits: list[MemoryHit],
    session_id: str,
) -> list[dict]:
    """D-17(e) on-read consistency check.

    Two detection paths, both run per-retrieval on the top-K hits:

    1. `contradicts`-edge authoritative: any pair of hits connected by an
       existing `contradicts` edge is flagged regardless of cosine. This is
       the definitive route -- the user (or a prior S4 run) already said
       "these two disagree", so we surface it every time they co-retrieve.

    2. Cosine + tag-polarity heuristic: pairs with cosine >= ρ_s4 (0.97) AND
       conflicting polarity tags ({positive,negative} or {asserted,retracted})
       are flagged as `info`-severity. Future work can replace this
       heuristic with NLI-based semantic contradiction detection.

    Returns a list of hint dicts; each dict is shaped per the
    RecallResponse.hints contract. Also writes one `s4_contradiction` event
    per detected pair to the LanceDB events table (D-STORAGE).

    Note: `on_read_check_batch` is the D-SPEED variant. It accepts
    an optional `records_cache` kwarg so pipeline_recall can reuse the cache
    it already built at stage 1 (zero extra store.get calls). This function
    is preserved as the back-compat / ad-hoc caller API (retrieve.recall
    still calls it; no records_cache available there).
    """
    if len(hits) < 2:
        return []

    hint_list: list[dict] = []

    # Load records for the hit ids. Missing records are skipped silently -- a
    # recent store.delete could race us.
    records: dict[UUID, MemoryRecord] = {}
    for h in hits:
        rec = store.get(h.record_id)
        if rec is not None:
            records[h.record_id] = rec
    if len(records) < 2:
        return []

    # Load contradicts edges among these records. We precompute the set of
    # (sorted src,dst) pairs so the pairwise loop below is O(1) lookup.
    contradict_pairs: set[tuple[str, str]] = set()
    try:
        edges_df = store.db.open_table("edges").to_pandas()
    except Exception:
        edges_df = None
    if edges_df is not None and not edges_df.empty:
        contradict_df = edges_df[edges_df["edge_type"] == "contradicts"]
        hit_ids = {str(h.record_id) for h in hits}
        for _, row in contradict_df.iterrows():
            src = row["src"]
            dst = row["dst"]
            if src in hit_ids and dst in hit_ids:
                contradict_pairs.add(tuple(sorted([src, dst])))

    # Pairwise scan across hit records.
    hit_records = list(records.values())
    for i in range(len(hit_records)):
        for j in range(i + 1, len(hit_records)):
            a = hit_records[i]
            b = hit_records[j]
            key = tuple(sorted([str(a.id), str(b.id)]))
            sim = _cosine(a.embedding, b.embedding)

            # Path 1: explicit edge is authoritative.
            if key in contradict_pairs:
                hint = {
                    "kind": "s4_contradiction",
                    "severity": "warning",
                    "source_ids": [str(a.id), str(b.id)],
                    "text": (
                        f"inconsistency: records have a contradicts edge; "
                        f"review {a.id}, {b.id}"
                    ),
                    "similarity": sim,
                }
                hint_list.append(hint)
                write_event(
                    store,
                    kind="s4_contradiction",
                    data={
                        "source_ids": list(key),
                        "similarity": sim,
                        "mechanism": "contradicts_edge",
                    },
                    severity="warning",
                    session_id=session_id,
                    source_ids=[a.id, b.id],
                )
                continue

            # Path 2: cosine + polarity-tag heuristic.
            if sim >= S4_VIGILANCE_RHO:
                a_tags = set(a.tags or [])
                b_tags = set(b.tags or [])
                polarity_conflict = (
                    ("positive" in a_tags and "negative" in b_tags)
                    or ("negative" in a_tags and "positive" in b_tags)
                    or ("asserted" in a_tags and "retracted" in b_tags)
                    or ("retracted" in a_tags and "asserted" in b_tags)
                )
                if polarity_conflict:
                    hint = {
                        "kind": "s4_contradiction",
                        "severity": "info",
                        "source_ids": [str(a.id), str(b.id)],
                        "text": (
                            f"inconsistency: near-duplicate ({sim:.3f}) with "
                            f"conflicting polarity tags"
                        ),
                        "similarity": sim,
                    }
                    hint_list.append(hint)
                    write_event(
                        store,
                        kind="s4_contradiction",
                        data={
                            "source_ids": list(key),
                            "similarity": sim,
                            "mechanism": "tag_polarity",
                        },
                        severity="info",
                        session_id=session_id,
                        source_ids=[a.id, b.id],
                    )
    return hint_list


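The two detection paths in `on_read_check` can be sketched without a store. The version below keeps the order-independent pair key and the edge-before-heuristic precedence but deliberately omits the cosine gate, event writes, and record loading; `pair_key`, `polarity_conflict`, `find_conflicts`, and `OPPOSITES` are hypothetical names introduced for the illustration.

```python
from itertools import combinations

# Opposing tag pairs, as listed in the docstring above.
OPPOSITES = [("positive", "negative"), ("asserted", "retracted")]


def pair_key(a: str, b: str) -> tuple[str, str]:
    # Sorting makes (a, b) and (b, a) map to the same key, so an edge
    # stored in either direction is found with one set lookup.
    return (a, b) if a <= b else (b, a)


def polarity_conflict(a_tags: set[str], b_tags: set[str]) -> bool:
    return any(
        (x in a_tags and y in b_tags) or (y in a_tags and x in b_tags)
        for x, y in OPPOSITES
    )


def find_conflicts(
    tags_by_id: dict[str, set[str]],
    edges: list[tuple[str, str]],
) -> list[tuple[str, str, str]]:
    """Return (id_a, id_b, mechanism) per flagged pair. Cosine gate omitted."""
    known = {pair_key(src, dst) for src, dst in edges}
    out = []
    for a, b in combinations(sorted(tags_by_id), 2):
        if pair_key(a, b) in known:
            out.append((a, b, "contradicts_edge"))   # explicit edge wins
        elif polarity_conflict(tags_by_id[a], tags_by_id[b]):
            out.append((a, b, "tag_polarity"))
    return out


hits = {"r1": {"positive"}, "r2": {"negative"}, "r3": {"asserted"}}
found = find_conflicts(hits, edges=[("r3", "r1")])
assert found == [("r1", "r2", "tag_polarity"), ("r1", "r3", "contradicts_edge")]
```

With K hits the loop visits K*(K-1)/2 pairs, which stays small for a top-K recall but is the reason the edge set is precomputed once rather than queried per pair.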
def on_read_check_batch(
|
||||
store: MemoryStore,
|
||||
hits: list[MemoryHit],
|
||||
session_id: str,
|
||||
records_cache: "dict[UUID, MemoryRecord] | None" = None,
|
||||
) -> list[dict]:
|
||||
"""Plan 02-07 D-SPEED: batched variant of on_read_check.
|
||||
|
||||
Semantically identical to on_read_check (returns the same hint-shape list,
|
||||
emits the same events). The ONLY difference is the record-loading step:
|
||||
|
||||
- If `records_cache` is provided, use it directly. ZERO store.get calls.
|
||||
- Otherwise, do ONE `store.all_records()` call instead of N `store.get()`
|
||||
calls. ZERO per-hit round-trips either way.
|
||||
|
||||
The pairwise contradiction-detection loop, the polarity-tag heuristic, the
|
||||
vigilance threshold (S4_VIGILANCE_RHO), and the event-emission logic are
|
||||
byte-for-byte equivalent to on_read_check.
|
||||
|
||||
Why this is the perf-critical surface (D-SPEED SC-6):
|
||||
Pre-fix: pipeline_recall built records_cache at stage 1, then s4.on_read_check
|
||||
called `store.get(h.record_id)` per hit -- every call is a full
|
||||
to_pandas() scan (~140ms each at N=100 on executor hardware).
|
||||
Post-fix: pipeline_recall passes records_cache through; s4 does zero extra
|
||||
round-trips. Saves ~140ms per hit x N hits per recall.
|
||||
"""
|
||||
if len(hits) < 2:
|
||||
return []
|
||||
|
||||
hint_list: list[dict] = []
|
||||
|
||||
# Load records via cache (preferred) or one batched fallback.
|
||||
records: dict[UUID, MemoryRecord] = {}
|
||||
if records_cache is not None:
|
||||
for h in hits:
|
||||
rec = records_cache.get(h.record_id)
|
||||
if rec is not None:
|
||||
records[h.record_id] = rec
|
||||
else:
|
||||
all_recs = store.all_records()
|
||||
by_id = {r.id: r for r in all_recs}
|
||||
for h in hits:
|
||||
rec = by_id.get(h.record_id)
|
||||
if rec is not None:
|
||||
records[h.record_id] = rec
|
||||
if len(records) < 2:
|
||||
return []
|
||||
|
||||
# Load contradicts edges among these records. One edges.to_pandas() scan
|
||||
# (same as on_read_check).
|
||||
contradict_pairs: set[tuple[str, str]] = set()
|
||||
try:
|
||||
edges_df = store.db.open_table("edges").to_pandas()
|
||||
except Exception:
|
||||
edges_df = None
|
||||
if edges_df is not None and not edges_df.empty:
|
||||
contradict_df = edges_df[edges_df["edge_type"] == "contradicts"]
|
||||
hit_ids = {str(h.record_id) for h in hits}
|
||||
for _, row in contradict_df.iterrows():
|
||||
src = row["src"]
|
||||
dst = row["dst"]
|
||||
if src in hit_ids and dst in hit_ids:
|
||||
contradict_pairs.add(tuple(sorted([src, dst])))
|
||||
|
||||
# Pairwise scan -- identical logic to on_read_check.
|
||||
hit_records = list(records.values())
|
||||
for i in range(len(hit_records)):
|
||||
for j in range(i + 1, len(hit_records)):
|
||||
a = hit_records[i]
|
||||
b = hit_records[j]
|
||||
key = tuple(sorted([str(a.id), str(b.id)]))
|
||||
sim = _cosine(a.embedding, b.embedding)
|
||||
|
||||
# Path 1: explicit edge is authoritative.
|
||||
if key in contradict_pairs:
|
||||
hint = {
|
||||
"kind": "s4_contradiction",
|
||||
"severity": "warning",
|
||||
"source_ids": [str(a.id), str(b.id)],
|
||||
"text": (
|
||||
f"inconsistency: records have a contradicts edge; "
|
||||
f"review {a.id}, {b.id}"
|
||||
),
|
||||
"similarity": sim,
|
||||
}
|
||||
hint_list.append(hint)
|
||||
write_event(
|
||||
store,
|
||||
kind="s4_contradiction",
|
||||
data={
|
||||
"source_ids": list(key),
|
||||
"similarity": sim,
|
||||
"mechanism": "contradicts_edge",
|
||||
},
|
||||
severity="warning",
|
||||
session_id=session_id,
|
||||
source_ids=[a.id, b.id],
|
||||
)
|
||||
continue
|
||||
|
||||
# Path 2: cosine + polarity-tag heuristic.
|
||||
if sim >= S4_VIGILANCE_RHO:
|
||||
a_tags = set(a.tags or [])
|
||||
b_tags = set(b.tags or [])
|
||||
polarity_conflict = (
|
||||
("positive" in a_tags and "negative" in b_tags)
|
||||
or ("negative" in a_tags and "positive" in b_tags)
|
||||
or ("asserted" in a_tags and "retracted" in b_tags)
|
||||
or ("retracted" in a_tags and "asserted" in b_tags)
|
||||
)
|
||||
if polarity_conflict:
|
||||
hint = {
|
||||
"kind": "s4_contradiction",
|
||||
"severity": "info",
|
||||
"source_ids": [str(a.id), str(b.id)],
|
||||
"text": (
|
||||
f"inconsistency: near-duplicate ({sim:.3f}) with "
|
||||
f"conflicting polarity tags"
|
||||
),
|
||||
"similarity": sim,
|
||||
}
|
||||
hint_list.append(hint)
|
||||
write_event(
|
||||
store,
|
||||
kind="s4_contradiction",
|
||||
data={
|
||||
"source_ids": list(key),
|
||||
"similarity": sim,
|
||||
"mechanism": "tag_polarity",
|
||||
},
|
||||
severity="info",
|
||||
session_id=session_id,
|
||||
source_ids=[a.id, b.id],
|
||||
)
|
||||
return hint_list
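The heuristic path above (Path 2) can be sketched in isolation. This is a minimal, standard-library-only illustration, not the module's code: `cosine`, `polarity_conflict` and `tag_polarity_hit` are hypothetical names, and the threshold 0.97 is an assumption taken from the `ρ_s4=0.97` mention in the s5 module docstring.

```python
import math

S4_VIGILANCE_RHO = 0.97  # assumed value; the real constant is defined elsewhere

# Tag pairs treated as opposite polarity (the same two pairs as the scan above).
_OPPOSITES = (("positive", "negative"), ("asserted", "retracted"))


def cosine(a, b):
    """Cosine similarity; returns 0.0 when either vector has zero norm."""
    na = math.sqrt(sum(x * x for x in a))
    nb = math.sqrt(sum(x * x for x in b))
    if na == 0.0 or nb == 0.0:
        return 0.0
    return sum(x * y for x, y in zip(a, b)) / (na * nb)


def polarity_conflict(a_tags, b_tags):
    """True when the two tag collections carry opposite polarity markers."""
    a, b = set(a_tags or []), set(b_tags or [])
    return any((x in a and y in b) or (y in a and x in b) for x, y in _OPPOSITES)


def tag_polarity_hit(a_emb, a_tags, b_emb, b_tags, rho=S4_VIGILANCE_RHO):
    """Path 2 only: near-duplicate embeddings AND conflicting polarity tags."""
    return cosine(a_emb, b_emb) >= rho and polarity_conflict(a_tags, b_tags)
```

Both conditions must hold: two near-identical records with the same polarity are plain duplicates, and opposite tags on dissimilar records are unrelated statements.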
|
||||
|
||||
|
||||
def monotropic_proactive_check(
|
||||
store: MemoryStore,
|
||||
new_record: MemoryRecord,
|
||||
profile_state: dict,
|
||||
session_id: str,
|
||||
) -> list[dict]:
|
||||
"""D-17(f) monotropic proactive check.
|
||||
|
||||
Three gates (all must pass):
|
||||
|
||||
1. `profile_state["monotropism_depth"][domain] > θ_deep` (0.7). The user's
|
||||
autistic profile indicates DEEP focus in this domain -- we're willing
|
||||
to spend cycles checking for near-duplicates.
|
||||
2. `new_record.detail_level >= 4`. Shallow records (detail 1-3) don't
|
||||
warrant the pairwise scan.
|
||||
3. `new_record` carries a `domain:<name>` tag. Records without a domain
|
||||
tag are excluded (nothing to compare against).
|
||||
|
||||
Performance guard: if the domain has > MONOTROPIC_MAX_PAIRWISE records,
|
||||
skip the scan and emit a `s4_monotropic_skip` warning event. The scan is
|
||||
O(N) cosine comparisons; 100 is a reasonable ceiling.
|
||||
|
||||
Rule 1 deviation: if `profile_state["monotropism_depth"]` is not a dict
|
||||
(type drift), degrade silently to empty hints (no exception).
|
||||
"""
|
||||
md = profile_state.get("monotropism_depth", {})
|
||||
if not isinstance(md, dict):
|
||||
return [] # profile_state wrongly typed -- degrade silently
|
||||
|
||||
# Locate the record's domain tag ("domain:coding", "domain:gardening", ...)
|
||||
domain_tag: str | None = next(
|
||||
(t for t in (new_record.tags or []) if t.startswith("domain:")),
|
||||
None,
|
||||
)
|
||||
if domain_tag is None:
|
||||
return []
|
||||
|
||||
# Gate 1: monotropism depth must exceed θ_deep.
|
||||
domain_name = domain_tag.split(":", 1)[1]
|
||||
depth = md.get(domain_name, 0.0)
|
||||
if depth <= S4_MONOTROPIC_THETA:
|
||||
return []
|
||||
|
||||
# Gate 2: detail_level must be >= 4.
|
||||
if new_record.detail_level < 4:
|
||||
return []
|
||||
|
||||
# Load same-domain records (excluding the new record itself).
|
||||
same_domain = [
|
||||
r for r in store.all_records()
|
||||
if (r.tags or []) and domain_tag in r.tags and r.id != new_record.id
|
||||
]
|
||||
|
||||
# Performance guard: skip + warn above ceiling.
|
||||
if len(same_domain) > MONOTROPIC_MAX_PAIRWISE:
|
||||
write_event(
|
||||
store,
|
||||
kind="s4_monotropic_skip",
|
||||
data={
|
||||
"domain": domain_tag,
|
||||
"count": len(same_domain),
|
||||
"record_id": str(new_record.id),
|
||||
},
|
||||
severity="warning",
|
||||
domain=domain_tag,
|
||||
session_id=session_id,
|
||||
)
|
||||
return []
|
||||
|
||||
hints: list[dict] = []
|
||||
for r in same_domain:
|
||||
sim = _cosine(new_record.embedding, r.embedding)
|
||||
if sim >= S4_VIGILANCE_RHO:
|
||||
hint = {
|
||||
"kind": "s4_monotropic_contradiction",
|
||||
"severity": "info",
|
||||
"source_ids": [str(new_record.id), str(r.id)],
|
||||
"text": (
|
||||
f"monotropic near-duplicate in {domain_tag}: sim={sim:.3f}"
|
||||
),
|
||||
"similarity": sim,
|
||||
}
|
||||
hints.append(hint)
|
||||
write_event(
|
||||
store,
|
||||
kind="s4_monotropic_contradiction",
|
||||
data={
|
||||
"domain": domain_tag,
|
||||
"source_ids": [str(new_record.id), str(r.id)],
|
||||
"similarity": sim,
|
||||
},
|
||||
severity="info",
|
||||
domain=domain_tag,
|
||||
session_id=session_id,
|
||||
source_ids=[new_record.id, r.id],
|
||||
)
|
||||
return hints
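The three gates of `monotropic_proactive_check` can be read as one pure predicate. The sketch below is an illustration under stated assumptions: `passes_monotropic_gates` is a hypothetical helper, `THETA_DEEP = 0.7` is taken from the docstring, and the store lookup and event emission are left out.

```python
THETA_DEEP = 0.7       # θ_deep as stated in the docstring (assumed constant value)
MIN_DETAIL_LEVEL = 4   # gate 2 floor


def passes_monotropic_gates(tags, detail_level, monotropism_depth):
    """Return True only when all three gates pass; never raises on bad input."""
    # Type drift on the profile degrades to "no scan", as in the original.
    if not isinstance(monotropism_depth, dict):
        return False
    # Gate 3: the record must carry a domain:<name> tag.
    domain_tag = next((t for t in (tags or []) if t.startswith("domain:")), None)
    if domain_tag is None:
        return False
    # Gate 1: depth must be strictly greater than θ_deep.
    domain_name = domain_tag.split(":", 1)[1]
    if monotropism_depth.get(domain_name, 0.0) <= THETA_DEEP:
        return False
    # Gate 2: only detailed records warrant the scan.
    return detail_level >= MIN_DETAIL_LEVEL
```

Note that a depth of exactly 0.7 does not pass, because the comparison is strict.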
|
||||
|
||||
|
||||
def run_offline_pass(store: MemoryStore) -> dict:
|
||||
"""Plan 03-02 CONN-07: S4 offline-pass entry point.
|
||||
|
||||
Called by the daemon's offline cycle (or by session_exit / cron).
|
||||
Currently runs ONE check: `sigma.compute_and_emit(store)` -- which writes
|
||||
`kind=sigma_observation` (developmental / healthy / insufficient_data) OR
|
||||
`kind=sigma_drift` (mid_life_drift) and (in developmental phase) bumps the
|
||||
Hebbian rate via a `profile_updated` event.
|
||||
|
||||
Failures are caught and emitted as `kind="s4_error"`; the pass does NOT
|
||||
crash. This mirrors the diagnostic discipline of `on_read_check`:
|
||||
S4 work is observation, never blocks reads or writes.
|
||||
|
||||
Returns a dict with the per-step outcome:
|
||||
{"sigma": <snapshot dict or {"error": "..."}>}
|
||||
"""
|
||||
from iai_mcp import sigma # local import; sigma is heavy (networkx)
|
||||
|
||||
out: dict = {}
|
||||
try:
|
||||
out["sigma"] = sigma.compute_and_emit(store)
|
||||
except Exception as exc: # noqa: BLE001 - diagnostic catch-all
|
||||
try:
|
||||
write_event(
|
||||
store,
|
||||
kind="s4_error",
|
||||
data={"step": "sigma", "error": repr(exc)},
|
||||
severity="warning",
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
out["sigma"] = {"error": repr(exc)}
|
||||
return out
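The "observe, never crash" discipline of `run_offline_pass` generalises to any diagnostic step. A minimal sketch, assuming a hypothetical `run_step` wrapper and an `emit` callable standing in for `write_event`:

```python
def run_step(name, step, emit):
    """Run one diagnostic step; report failures instead of raising.

    `step` is a zero-argument callable and `emit` receives an error payload.
    Both are placeholders for illustration, not part of the module.
    """
    try:
        return step()
    except Exception as exc:  # diagnostic catch-all, mirrors the original
        try:
            emit({"step": name, "error": repr(exc)})
        except Exception:
            # Even the error report is best-effort: swallow and continue.
            pass
        return {"error": repr(exc)}
```

The nested `try` matters: if the event store itself is the component that is failing, the pass still returns a result to its caller.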
|
||||
417
src/iai_mcp/s5.py
Normal file
|
|
@ -0,0 +1,417 @@
|
|||
"""S5 identity kernel -- invariant protection via M-of-N consensus (MEM-09, D-22).
|
||||
|
||||
D-22 constitutional rules enforced here:
|
||||
- ρ_identity = 0.99 (stricter than write-path ρ=0.95 and S4 ρ_s4=0.97).
|
||||
- 3-of-5 session-window consensus: an invariant update only commits after 3
|
||||
vigilance-passing proposals within the consensus window. A single-session
|
||||
attacker (e.g. prompt injection) cannot reach M by itself.
|
||||
- 48h cooldown: after a commit, any subsequent proposal on the same anchor
|
||||
is rejected for 48h. Prevents rapid sequential poisoning.
|
||||
- TRUST_THRESHOLD_IDENTITY = 0.9: records with s5_trust_score >= 0.9 are
|
||||
"invariant-tier". Direct writes bypassing propose_invariant_update are
|
||||
rejected by `check_identity_anchor_on_write`.
|
||||
- All commits emit `s5_invariant_update` events with full provenance
|
||||
(proposal history, session_ids, similarity scores).
|
||||
|
||||
Proposal events (kind=s5_invariant_proposal) are emitted for EVERY proposal
|
||||
so the M-of-N tally can be reconstructed from the events table alone -- no
|
||||
hidden in-memory state. Cooldown lookups read kind=s5_invariant_update.
|
||||
|
||||
Plan 02-05 additions (OPS-07 / gradual-drift detection):
|
||||
- `detect_drift_anomaly` reads trajectory_metric events for M4 (profile-vector
  variance). M4 is expected to decrease as the profile converges; when most of
  the last `window_sessions` values are increasing instead (at least
  window_sessions - 2 adjacent steps up), it emits an s5_drift_alert event.
  User audit via `iai-mcp audit drift` surfaces these.
|
||||
- `audit_identity_events` aggregates s5_* + shield_* + s5_drift_alert events
|
||||
chronologically (newest first) for `iai-mcp audit` / `audit identity`.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
import numpy as np
|
||||
|
||||
from iai_mcp.aaak import enforce_language_tagged, generate_aaak_index
|
||||
from iai_mcp.events import query_events, write_event
|
||||
from iai_mcp.store import MemoryStore
|
||||
from iai_mcp.types import MemoryRecord
|
||||
|
||||
|
||||
# ------------------------------------------------------------ constitutional constants
|
||||
|
||||
IDENTITY_VIGILANCE_RHO: float = 0.99 # strict vigilance on identity updates
|
||||
S5_CONSENSUS_M: int = 3 # 3-of-5: required agreeing proposals
|
||||
S5_CONSENSUS_N: int = 5 # 3-of-5: window size
|
||||
COOLDOWN_HOURS: int = 48 # cooldown after a commit
|
||||
TRUST_THRESHOLD_IDENTITY: float = 0.9 # score >= this => invariant-tier record
|
||||
CONSENSUS_WINDOW_HOURS: int = 24 # all M votes must land within this window
|
||||
|
||||
|
||||
# ------------------------------------------------------------ private helpers
|
||||
|
||||
|
||||
def _cosine(a: list[float], b: list[float]) -> float:
    av = np.asarray(a, dtype=np.float32)
    bv = np.asarray(b, dtype=np.float32)
    na = float(np.linalg.norm(av))
    nb = float(np.linalg.norm(bv))
    if na == 0.0 or nb == 0.0:
        return 0.0
    return float(np.dot(av, bv) / (na * nb))


def _recent_proposals_for(
    store: MemoryStore, anchor_id: UUID,
) -> list[dict]:
    """Return all s5_invariant_proposal events for this anchor inside the
    consensus window, newest first."""
    since = datetime.now(timezone.utc) - timedelta(hours=CONSENSUS_WINDOW_HOURS)
    events = query_events(store, kind="s5_invariant_proposal", since=since, limit=100)
    return [e for e in events if e["data"].get("anchor_id") == str(anchor_id)]


def _in_cooldown(store: MemoryStore, anchor_id: UUID) -> bool:
    """True iff an s5_invariant_update for this anchor landed in the last COOLDOWN_HOURS."""
    since = datetime.now(timezone.utc) - timedelta(hours=COOLDOWN_HOURS)
    events = query_events(store, kind="s5_invariant_update", since=since, limit=10)
    for e in events:
        if e["data"].get("anchor_id") == str(anchor_id):
            return True
    return False
|
||||
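The cooldown test reduces to a timestamp comparison once the commit times for an anchor are known. A minimal sketch, assuming a hypothetical `in_cooldown` helper that receives the commit timestamps directly instead of querying the events table; whether the real `since` bound is inclusive depends on `query_events`, which is not shown here.

```python
from datetime import datetime, timedelta, timezone

COOLDOWN_HOURS = 48  # same value as the constant above


def in_cooldown(commit_times, now, hours=COOLDOWN_HOURS):
    """True when any commit timestamp falls inside the trailing cooldown window."""
    since = now - timedelta(hours=hours)
    # Inclusive lower bound is an assumption of this sketch.
    return any(ts >= since for ts in commit_times)


now = datetime(2025, 1, 3, tzinfo=timezone.utc)
recent_commit = datetime(2025, 1, 1, 12, tzinfo=timezone.utc)   # 36h ago
old_commit = datetime(2024, 12, 31, tzinfo=timezone.utc)        # 72h ago
```

Passing `now` explicitly keeps the check deterministic in tests, which calling `datetime.now` inside the helper would not.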
|
||||
|
||||
# ------------------------------------------------------------ public API
|
||||
|
||||
|
||||
def propose_invariant_update(
|
||||
store: MemoryStore,
|
||||
anchor_id: UUID,
|
||||
new_fact: str,
|
||||
session_id: str,
|
||||
) -> tuple[str, UUID | None]:
|
||||
"""D-22 M-of-N voting on identity-tier updates.
|
||||
|
||||
Workflow:
|
||||
1. If the anchor is in 48h cooldown, reject (``cooldown``).
|
||||
2. If the anchor does not exist, reject (``rejected``).
|
||||
3. Encode the proposed fact; compute cosine against the anchor.
|
||||
4. Log an `s5_invariant_proposal` event regardless of vigilance outcome.
|
||||
(This is how the M-of-N tally is reconstructed on subsequent calls.)
|
||||
5. Count vigilance-passing proposals in the current consensus window.
|
||||
- If >= M (3): commit -- insert new record, create invariant_anchor
|
||||
edge, log `s5_invariant_update` event, return ("committed", new_id).
|
||||
- Else if total >= N (5) proposals in window: reject (``rejected``).
|
||||
- Else: stage (``staged``), return the proposal UUID.
|
||||
|
||||
Returns one of:
|
||||
("cooldown", None)
|
||||
("rejected", None)
|
||||
("staged", proposal_id)
|
||||
("committed", new_record_id)
|
||||
"""
|
||||
# Step 1: cooldown gate.
|
||||
if _in_cooldown(store, anchor_id):
|
||||
write_event(
|
||||
store,
|
||||
kind="s5_cooldown_block",
|
||||
data={"anchor_id": str(anchor_id), "session_id": session_id},
|
||||
severity="warning",
|
||||
session_id=session_id,
|
||||
source_ids=[anchor_id],
|
||||
)
|
||||
return "cooldown", None
|
||||
|
||||
# Step 2: anchor existence.
|
||||
anchor = store.get(anchor_id)
|
||||
if anchor is None:
|
||||
return "rejected", None
|
||||
|
||||
# Step 3: encode proposed fact + compute vigilance similarity.
|
||||
from iai_mcp.embed import embedder_for_store
|
||||
emb = embedder_for_store(store).embed(new_fact)
|
||||
sim = _cosine(anchor.embedding, emb)
|
||||
passes_vigilance = sim >= IDENTITY_VIGILANCE_RHO
|
||||
|
||||
# Step 4: log the proposal (counts toward N).
|
||||
proposal_id = uuid4()
|
||||
write_event(
|
||||
store,
|
||||
kind="s5_invariant_proposal",
|
||||
data={
|
||||
"proposal_id": str(proposal_id),
|
||||
"anchor_id": str(anchor_id),
|
||||
"new_fact": new_fact[:200], # payload size cap (T-02-02-05)
|
||||
"similarity": sim,
|
||||
"passes_vigilance": passes_vigilance,
|
||||
},
|
||||
severity="info",
|
||||
session_id=session_id,
|
||||
source_ids=[anchor_id],
|
||||
)
|
||||
|
||||
# Step 5: tally.
|
||||
recent = _recent_proposals_for(store, anchor_id)
|
||||
agree_count = sum(1 for r in recent if r["data"].get("passes_vigilance"))
|
||||
total = len(recent)
|
||||
|
||||
if agree_count >= S5_CONSENSUS_M:
|
||||
# COMMIT: create the invariant_anchor edge + log the update.
|
||||
now = datetime.now(timezone.utc)
|
||||
updated = MemoryRecord(
|
||||
id=uuid4(),
|
||||
tier=anchor.tier,
|
||||
literal_surface=new_fact,
|
||||
aaak_index="",
|
||||
embedding=emb,
|
||||
community_id=anchor.community_id,
|
||||
centrality=anchor.centrality,
|
||||
detail_level=anchor.detail_level,
|
||||
pinned=anchor.pinned,
|
||||
stability=anchor.stability,
|
||||
difficulty=anchor.difficulty,
|
||||
last_reviewed=now,
|
||||
never_decay=True,
|
||||
never_merge=True,
|
||||
provenance=[
|
||||
{
|
||||
"ts": now.isoformat(),
|
||||
"cue": "s5_consensus",
|
||||
"session_id": session_id,
|
||||
}
|
||||
],
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
tags=[*anchor.tags, "s5_consensus"],
|
||||
language=anchor.language or "en",
|
||||
s5_trust_score=min(1.0, anchor.s5_trust_score + 0.05),
|
||||
profile_modulation_gain=dict(anchor.profile_modulation_gain),
|
||||
schema_version=2,
|
||||
)
|
||||
enforce_language_tagged(updated, detect=False)
|
||||
updated.aaak_index = generate_aaak_index(updated)
|
||||
store.insert(updated)
|
||||
store.boost_edges(
|
||||
[(anchor_id, updated.id)],
|
||||
edge_type="invariant_anchor",
|
||||
delta=1.0,
|
||||
)
|
||||
write_event(
|
||||
store,
|
||||
kind="s5_invariant_update",
|
||||
data={
|
||||
"anchor_id": str(anchor_id),
|
||||
"new_record_id": str(updated.id),
|
||||
"session_ids": [r["session_id"] for r in recent],
|
||||
"agree_count": agree_count,
|
||||
"total_proposals": total,
|
||||
"similarity": sim,
|
||||
},
|
||||
severity="info",
|
||||
session_id=session_id,
|
||||
source_ids=[anchor_id, updated.id],
|
||||
)
|
||||
return "committed", updated.id
|
||||
|
||||
if total >= S5_CONSENSUS_N:
|
||||
return "rejected", None
|
||||
|
||||
return "staged", proposal_id
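Step 5 of the workflow is a small decision function over the `passes_vigilance` flags of the proposals in the window. The sketch below isolates it; `tally_outcome` is a hypothetical name, and cooldown, anchor lookup and record creation are omitted. As in the code above, it counts proposals, so any guarantee about distinct sessions would have to come from elsewhere.

```python
S5_CONSENSUS_M = 3  # agreeing proposals required
S5_CONSENSUS_N = 5  # window size


def tally_outcome(passes_flags, m=S5_CONSENSUS_M, n=S5_CONSENSUS_N):
    """Map the vigilance flags of the windowed proposals to an outcome string."""
    agree_count = sum(1 for passed in passes_flags if passed)
    if agree_count >= m:
        return "committed"
    if len(passes_flags) >= n:
        # Window exhausted without reaching M agreeing proposals.
        return "rejected"
    return "staged"
```

The commit check runs first, so a window that reaches M on its fifth proposal still commits instead of being rejected.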
|
||||
|
||||
|
||||
def check_identity_anchor_on_write(
|
||||
store: MemoryStore,
|
||||
record: MemoryRecord,
|
||||
profile_state: dict,
|
||||
) -> tuple[bool, str]:
|
||||
"""Guard invoked by write paths that accept externally-originated records.
|
||||
|
||||
Records with s5_trust_score >= TRUST_THRESHOLD_IDENTITY (0.9) are
|
||||
considered invariant-tier. They may NOT be written through any path that
|
||||
bypasses propose_invariant_update (D-22 consensus requirement).
|
||||
|
||||
    Extension (OPS-07, D-31): the shield is evaluated in
|
||||
HARD_BLOCK tier BEFORE the consensus marker check. Any detected
|
||||
injection signal short-circuits with "shield HARD_BLOCK" -- a
|
||||
mitigation for the "direct override" branch of the threat model.
|
||||
|
||||
    Cross-lingual warning: an identity update whose
|
||||
language differs from the existing pinned identity anchor(s) emits a
|
||||
`identity_cross_lingual_warning` event but does NOT block -- multi-lingual
|
||||
identity refinement is a design goal of the global-product roadmap. The
|
||||
warning surfaces via `iai-mcp audit identity` for user review.
|
||||
|
||||
We distinguish between:
|
||||
- DIRECT identity writes (reject): s5_trust_score >= 0.9 and no
|
||||
`s5_consensus` tag -- attacker trying to plant an invariant.
|
||||
- CONSENSUS-PROMOTED writes (accept): s5_trust_score >= 0.9 and
|
||||
`s5_consensus` tag present -- output of propose_invariant_update's
|
||||
own store.insert call.
|
||||
- NORMAL writes (accept): s5_trust_score < 0.9 -- below identity tier.
|
||||
"""
|
||||
if record.s5_trust_score < TRUST_THRESHOLD_IDENTITY:
|
||||
return True, ""
|
||||
|
||||
# shield HARD_BLOCK pre-check on identity-tier writes.
|
||||
from iai_mcp.shield import ShieldTier, evaluate_injection_risk
|
||||
|
||||
shield_verdict = evaluate_injection_risk(
|
||||
record.literal_surface or "",
|
||||
ShieldTier.HARD_BLOCK,
|
||||
target_language=record.language or None,
|
||||
)
|
||||
if shield_verdict.action == "reject":
|
||||
return (
|
||||
False,
|
||||
f"shield HARD_BLOCK: {shield_verdict.reason}",
|
||||
)
|
||||
|
||||
if "s5_consensus" not in (record.tags or []):
|
||||
return (
|
||||
False,
|
||||
"identity-tier write (s5_trust_score >= 0.9) requires "
|
||||
"propose_invariant_update consensus; direct inserts forbidden "
|
||||
"(D-22).",
|
||||
)
|
||||
|
||||
# cross-lingual warning. Non-fatal: emit an event and
|
||||
# continue. Inspect the existing pinned identity anchors for a language
|
||||
# mismatch with the incoming record.
|
||||
try:
|
||||
anchors_with_other_lang = [
|
||||
r for r in store.all_records()
|
||||
if r.pinned
|
||||
and r.s5_trust_score >= TRUST_THRESHOLD_IDENTITY
|
||||
and (r.language or "") != ""
|
||||
and (r.language or "") != (record.language or "")
|
||||
]
|
||||
except Exception:
|
||||
anchors_with_other_lang = []
|
||||
if anchors_with_other_lang:
|
||||
anchor_langs = sorted({
|
||||
r.language for r in anchors_with_other_lang if r.language
|
||||
})
|
||||
write_event(
|
||||
store,
|
||||
kind="identity_cross_lingual_warning",
|
||||
data={
|
||||
"record_id": str(record.id),
|
||||
"record_language": record.language,
|
||||
"existing_anchor_languages": anchor_langs,
|
||||
},
|
||||
severity="warning",
|
||||
session_id="-",
|
||||
source_ids=[record.id],
|
||||
)
|
||||
|
||||
return True, ""
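The three write categories from the docstring can be expressed as a classifier over the trust score and tags. This is a simplified sketch: `classify_write` is a hypothetical helper, and the shield HARD_BLOCK pre-check and the cross-lingual warning are deliberately left out.

```python
TRUST_THRESHOLD_IDENTITY = 0.9  # same value as the constant above


def classify_write(s5_trust_score, tags):
    """Return which of the three documented write categories a record falls in."""
    if s5_trust_score < TRUST_THRESHOLD_IDENTITY:
        return "normal"              # below identity tier: accept
    if "s5_consensus" in (tags or []):
        return "consensus_promoted"  # produced by propose_invariant_update: accept
    return "direct_identity"         # identity-tier without consensus tag: reject
```

In the full guard, only `"direct_identity"` leads to a `(False, reason)` result, provided the shield has not already rejected the record.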
|
||||
|
||||
|
||||
# ---------------------------------------------------------- drift detection
|
||||
|
||||
# Event kinds relevant to the user audit surface. `audit_identity_events`
# aggregates these under `iai-mcp audit`.
|
||||
AUDIT_EVENT_KINDS: tuple[str, ...] = (
|
||||
"s5_invariant_update",
|
||||
"s5_invariant_proposal",
|
||||
"s5_cooldown_block",
|
||||
"s5_drift_alert",
|
||||
"shield_rejection",
|
||||
"shield_flag",
|
||||
"identity_cross_lingual_warning",
|
||||
)
|
||||
|
||||
|
||||
def detect_drift_anomaly(
|
||||
store: MemoryStore,
|
||||
window_sessions: int = 5,
|
||||
) -> list[dict]:
|
||||
"""D-30 gradual-drift detection via trajectory M4 reversal.
|
||||
|
||||
Reads trajectory_metric events filtered to metric=m4 (profile-vector
|
||||
variance). The expected direction is DECREASING (the profile is
|
||||
converging as the user is learnt over time). When the last
|
||||
`window_sessions` values are monotonically INCREASING or mostly so
|
||||
(at least window_sessions - 2 adjacent pairs increase), emits an
|
||||
s5_drift_alert event and returns the alert payload in a list.
|
||||
|
||||
Returns [] on insufficient data or no drift.
|
||||
"""
|
||||
events = query_events(store, kind="trajectory_metric", limit=1000)
|
||||
m4: list[tuple] = []
|
||||
for e in events:
|
||||
data = e.get("data") or {}
|
||||
if data.get("metric") != "m4":
|
||||
continue
|
||||
try:
|
||||
v = float(data.get("value", 0.0))
|
||||
except (TypeError, ValueError):
|
||||
continue
|
||||
ts = e.get("ts")
|
||||
m4.append((ts, v))
|
||||
|
||||
if len(m4) < window_sessions:
|
||||
return []
|
||||
|
||||
# Sort ascending (oldest first) so "recent" slice is the tail.
|
||||
try:
|
||||
m4.sort(key=lambda x: x[0])
|
||||
except TypeError:
|
||||
# Fallback: if ts objects are not comparable, keep insertion order.
|
||||
pass
|
||||
recent = m4[-window_sessions:]
|
||||
|
||||
increases = 0
|
||||
for i in range(1, len(recent)):
|
||||
if recent[i][1] > recent[i - 1][1]:
|
||||
increases += 1
|
||||
|
||||
# Drift signature: most of the window-1 adjacent steps are increasing.
|
||||
# For window_sessions=5, require increases >= 3 (at least 3 of 4 steps up).
|
||||
# For window_sessions=3, require increases >= 1 (at least 1 of 2 steps up).
|
||||
threshold = max(1, window_sessions - 2)
|
||||
if increases < threshold:
|
||||
return []
|
||||
|
||||
alert = {
|
||||
"kind": "s5_drift_alert",
|
||||
"severity": "warning",
|
||||
"window_sessions": window_sessions,
|
||||
"increases": increases,
|
||||
"first_value": float(recent[0][1]),
|
||||
"last_value": float(recent[-1][1]),
|
||||
}
|
||||
write_event(
|
||||
store,
|
||||
kind="s5_drift_alert",
|
||||
data={
|
||||
"window_sessions": window_sessions,
|
||||
"increases": increases,
|
||||
"first_value": alert["first_value"],
|
||||
"last_value": alert["last_value"],
|
||||
},
|
||||
severity="warning",
|
||||
)
|
||||
return [alert]
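The drift signature itself is a count of upward steps in the trailing window. A minimal sketch over a plain list of M4 values, oldest first; `drift_signature` is a hypothetical name, and event parsing and alert emission are omitted.

```python
def drift_signature(values, window_sessions=5):
    """True when enough adjacent steps in the trailing window are increases."""
    if len(values) < window_sessions:
        return False  # insufficient data
    recent = values[-window_sessions:]
    increases = sum(1 for prev, cur in zip(recent, recent[1:]) if cur > prev)
    # Same threshold as above: 3 of 4 steps for a window of 5, 1 of 2 for 3.
    return increases >= max(1, window_sessions - 2)
```

With `window_sessions=3` a single upward step is enough to trigger, so small windows are considerably more sensitive than the default.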
|
||||
|
||||
|
||||
def audit_identity_events(
|
||||
store: MemoryStore,
|
||||
since: datetime | None = None,
|
||||
kinds: tuple[str, ...] = AUDIT_EVENT_KINDS,
|
||||
) -> list[dict]:
|
||||
"""Aggregate identity-relevant events chronologically (newest first).
|
||||
|
||||
Used by `iai-mcp audit` + `audit identity` / `audit shield` / `audit drift`
|
||||
CLI subcommands. By default returns the full set of audit kinds; callers
|
||||
may pass a subset (e.g. only s5_* for `audit identity`).
|
||||
"""
|
||||
out: list[dict] = []
|
||||
for kind in kinds:
|
||||
out.extend(query_events(store, kind=kind, since=since, limit=500))
|
||||
    # Newest first by ts. If the ts values are not mutually comparable the
    # sort is abandoned and the resulting order is best-effort.
|
||||
try:
|
||||
out.sort(key=lambda e: e.get("ts"), reverse=True)
|
||||
except TypeError:
|
||||
pass
|
||||
return out
|
||||
551
src/iai_mcp/schema.py
Normal file
|
|
@ -0,0 +1,551 @@
|
|||
"""Schema induction (LEARN-03, D-18, D-21) -- Task 3.
|
||||
|
||||
D-18 (scheduling): dual-path schema surfacing.
|
||||
- Primary: batch induction inside the heavy sleep cycle. Tier-1 Haiku
|
||||
extraction when `should_call_llm` permits, Tier-0 cooccurrence + TF-IDF
|
||||
fallback otherwise.
|
||||
- Secondary: entropy-gated provisional schemas surfaced during
|
||||
`pipeline_recall` when score distribution entropy > 0.8 bits AND the
|
||||
cohesive community has >= 2 shared tags.
|
||||
|
||||
D-21 (thresholds, autism-aware):
|
||||
- Auto-induct when co_occurrence >= 5 AND confidence >= 0.85.
|
||||
- User-approval flag at co_occurrence in [3, 5) AND confidence in [0.65, 0.85).
|
||||
- Below: discard.
|
||||
- Exceptions preserved as first-class records (never absorbed).
|
||||
- Abstraction level: concrete (Dawson-Mottron Raven's preference).
|
||||
|
||||
Schema records are first-class hubs:
|
||||
- tier="semantic", detail_level=3 -> never_decay=True.
|
||||
- schema_instance_of edges from evidence -> schema never decay.
|
||||
- pipeline routing can prioritise schema records when pattern
|
||||
matches.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
from collections import Counter
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timezone
|
||||
from typing import Iterable
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
from iai_mcp.events import write_event
|
||||
from iai_mcp.guard import BudgetLedger, RateLimitLedger, should_call_llm
|
||||
from iai_mcp.store import MemoryStore
|
||||
from iai_mcp.types import MemoryRecord, SCHEMA_VERSION_CURRENT
|
||||
|
||||
|
||||
# ---------------------------------------------------------------- constants
|
||||
|
||||
AUTO_INDUCT_COOCCURRENCE: int = 5
|
||||
AUTO_INDUCT_CONFIDENCE: float = 0.85
|
||||
USER_APPROVAL_COOCCURRENCE: int = 3
|
||||
USER_APPROVAL_CONFIDENCE: float = 0.65
|
||||
MAX_EVIDENCE_PER_SCHEMA: int = 50
|
||||
PROVISIONAL_ENTROPY_MIN: float = 0.8
|
||||
|
||||
|
||||
# ---------------------------------------------------------------- candidate
|
||||
|
||||
|
||||
@dataclass
|
||||
class SchemaCandidate:
|
||||
"""One schema candidate surfaced by induce_schemas_*."""
|
||||
|
||||
pattern: str
|
||||
confidence: float
|
||||
evidence_count: int
|
||||
evidence_ids: list[UUID] = field(default_factory=list)
|
||||
domain: str | None = None
|
||||
exceptions: list[UUID] = field(default_factory=list)
|
||||
status: str = "auto" # "auto" | "pending_user_approval"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------- Tier-0 induction
|
||||
|
||||
|
||||
def _tag_cooccurrence(records: Iterable) -> dict:
|
||||
"""Bucket records by tag-pair frequency. Returns {frozenset(pair): [record_ids]}.
|
||||
|
||||
Phase 07.7-04 D-26-A: accepts either ``list[MemoryRecord]`` (back-compat;
|
||||
used by external callers passing dataclass instances) or an iterable of
|
||||
projected ``dict`` rows from ``store.iter_record_columns(["id", "tags_json"])``.
|
||||
|
||||
Dispatch is duck-typed: items with a ``.tags`` attribute are treated as
|
||||
MemoryRecord; items without are treated as dict rows. This keeps both
|
||||
surfaces alive while migrating the production path off ``all_records()``.
|
||||
|
||||
For dict rows, ``tags_json`` is parsed defensively (mirrors the W3
|
||||
pattern in ``sleep._tier0_schema_surfacing`` — corrupted rows contribute
|
||||
zero counts but do not crash). The ``id`` field arrives as a string from
|
||||
LanceDB and is converted to ``UUID`` here so callers always see
|
||||
``list[UUID]`` evidence_ids regardless of which input shape was passed.
|
||||
"""
|
||||
pairs: dict = {}
|
||||
for r in records:
|
||||
# Dispatch on duck-typing: MemoryRecord has .tags + .id attributes;
|
||||
# dict rows have ["tags_json"] + ["id"] keys.
|
||||
if hasattr(r, "tags"):
|
||||
# MemoryRecord path (back-compat for external/test callers).
|
||||
raw_tags = r.tags or []
|
||||
rid = r.id
|
||||
else:
|
||||
# Dict-row path (D-26-A migrated production path). Defensive parse:
|
||||
# malformed tags_json contributes zero pairs but does not raise.
|
||||
tags_raw = r.get("tags_json") or "[]"
|
||||
try:
|
||||
raw_tags = json.loads(tags_raw) if tags_raw else []
|
||||
except (TypeError, json.JSONDecodeError):
|
||||
raw_tags = []
|
||||
id_raw = r.get("id")
|
||||
if id_raw is None:
|
||||
continue
|
||||
# iter_record_columns yields id as a string; convert to UUID at
|
||||
# the boundary so SchemaCandidate.evidence_ids stays list[UUID].
|
||||
try:
|
||||
rid = UUID(id_raw) if isinstance(id_raw, str) else id_raw
|
||||
except (ValueError, AttributeError):
|
||||
continue
|
||||
|
||||
tags = [
|
||||
t for t in raw_tags
|
||||
if not t.startswith("raw:") and not t.startswith("domain:")
|
||||
]
|
||||
for i in range(len(tags)):
|
||||
for j in range(i + 1, len(tags)):
|
||||
key = frozenset([tags[i], tags[j]])
|
||||
pairs.setdefault(key, []).append(rid)
|
||||
return pairs
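The dict-row path of `_tag_cooccurrence` can be illustrated with plain dictionaries. This sketch uses `itertools.combinations` in place of the nested index loops and keeps ids as strings; `tag_pairs` is a hypothetical name, and the `isinstance` guards are an extra precaution of the sketch, not something the original does.

```python
import json
from itertools import combinations


def tag_pairs(rows):
    """Bucket row ids by unordered tag pair, skipping raw:/domain: tags."""
    pairs = {}
    for row in rows:
        try:
            tags = json.loads(row.get("tags_json") or "[]")
        except (TypeError, json.JSONDecodeError):
            tags = []  # corrupted rows contribute zero pairs
        if not isinstance(tags, list):
            tags = []  # extra guard (sketch only): valid JSON that is not a list
        tags = [
            t for t in tags
            if isinstance(t, str) and not t.startswith(("raw:", "domain:"))
        ]
        for a, b in combinations(tags, 2):
            pairs.setdefault(frozenset((a, b)), []).append(row["id"])
    return pairs


rows = [
    {"id": "r1", "tags_json": '["a", "b", "domain:x"]'},
    {"id": "r2", "tags_json": '["b", "a"]'},
    {"id": "r3", "tags_json": "not json"},
]
```

Using `frozenset` as the key makes the pair order-insensitive, so `["a", "b"]` and `["b", "a"]` land in the same bucket.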
|
||||
|
||||
|
||||
def induce_schemas_tier0(store: MemoryStore) -> list[SchemaCandidate]:
|
||||
"""D-18 Tier-0 path: tag cooccurrence + TF-IDF; no LLM.
|
||||
|
||||
Returns a list of SchemaCandidate. Each candidate passes the gate:
|
||||
- status="auto" -> count >= 5 AND confidence >= 0.85
|
||||
- status="pending_user_approval" -> count in [3,5) AND confidence in [0.65, 0.85)
|
||||
|
||||
Phase 07.7-04 D-26-A: streams via ``store.iter_record_columns(
|
||||
["id", "tags_json"], batch_size=1024)`` instead of ``store.all_records()``.
|
||||
Encrypted columns (literal_surface, provenance_json,
|
||||
profile_modulation_gain_json) are NEVER read on this path; the W5 cipher
|
||||
cache is short-circuited entirely. On the 8105-record production store
|
||||
this saves ~16210 AES-GCM operations + ~14.5 MB literal_surface
|
||||
materialisation per ``run_heavy_consolidation`` invocation, and unblocks
|
||||
the W4 ≤1 ``all_records()`` invariant on the heavy cycle.
|
||||
|
||||
    Minimum-size floor: the projected rows are materialised into a list and
    the pass returns early when fewer than 3 records exist. Compare
    ``sleep._tier0_schema_surfacing`` (Plan 07.7-03 W3).
|
||||
"""
|
||||
rows = list(store.iter_record_columns(["id", "tags_json"], batch_size=1024))
|
||||
if len(rows) < 3:
|
||||
return []
|
||||
|
||||
pair_counts = _tag_cooccurrence(rows)
|
||||
candidates: list[SchemaCandidate] = []
|
||||
for pair, evidence in pair_counts.items():
|
||||
count = len(evidence)
|
||||
# Heuristic confidence: saturates toward 1.0 at 10+ evidence records.
|
||||
confidence = min(1.0, count / 10.0)
|
||||
pattern = f"tags:{'+'.join(sorted(pair))}"
|
||||
if count >= AUTO_INDUCT_COOCCURRENCE and confidence >= AUTO_INDUCT_CONFIDENCE:
|
||||
status = "auto"
|
||||
elif (
|
||||
USER_APPROVAL_COOCCURRENCE <= count < AUTO_INDUCT_COOCCURRENCE
|
||||
and confidence >= USER_APPROVAL_CONFIDENCE
|
||||
):
|
||||
status = "pending_user_approval"
|
||||
else:
|
||||
continue
|
||||
candidates.append(
|
||||
SchemaCandidate(
|
||||
pattern=pattern,
|
||||
confidence=confidence,
|
||||
evidence_count=count,
|
||||
evidence_ids=list(evidence[:MAX_EVIDENCE_PER_SCHEMA]),
|
||||
status=status,
|
||||
)
|
||||
)
|
||||
return candidates
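It helps to tabulate how the `count / 10` confidence heuristic interacts with the two gates. The sketch below reproduces only the gating arithmetic; `tier0_status` is a hypothetical helper and the thresholds are copied from the constants above.

```python
AUTO_INDUCT_COOCCURRENCE = 5
AUTO_INDUCT_CONFIDENCE = 0.85
USER_APPROVAL_COOCCURRENCE = 3
USER_APPROVAL_CONFIDENCE = 0.65


def tier0_status(count):
    """Return the candidate status for an evidence count, or None if discarded."""
    confidence = min(1.0, count / 10.0)
    if count >= AUTO_INDUCT_COOCCURRENCE and confidence >= AUTO_INDUCT_CONFIDENCE:
        return "auto"
    if (
        USER_APPROVAL_COOCCURRENCE <= count < AUTO_INDUCT_COOCCURRENCE
        and confidence >= USER_APPROVAL_CONFIDENCE
    ):
        return "pending_user_approval"
    return None
```

Under this heuristic a pair needs at least 9 evidence records to be auto-inducted (confidence 0.9), and counts 3 and 4 map to confidences 0.3 and 0.4, so the user-approval band is not reached in this sketch. A different confidence estimate would be needed for that branch to fire.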
|
||||
|
||||
|
||||
# ---------------------------------------------------------------- Tier-1 w/ D-GUARD
|
||||
|
||||
|
||||
def induce_schemas_tier1(
|
||||
store: MemoryStore,
|
||||
budget: BudgetLedger,
|
||||
rate: RateLimitLedger,
|
||||
llm_enabled: bool = True,
|
||||
) -> list[SchemaCandidate]:
|
||||
"""D-18 Tier-1 path: Haiku extraction gated by D-GUARD ladder.
|
||||
|
||||
When should_call_llm returns False (any ladder step), emit an
|
||||
llm_health event and delegate to `induce_schemas_tier0`.
|
||||
|
||||
    Plan 02-03 scope: the Tier-1 branch is reserved; Plan 02-04 wires the
    actual anthropic.batches.create call. This function's contract is: on
    allow, call budget.record_spend and emit llm_health; then fall back to
    tier0 (because real Batch output is a Plan 02-04 deliverable). The
    effective_tier in the event is "tier0" regardless until Plan 02-04.
|
||||
"""
|
||||
has_key = bool(os.environ.get("ANTHROPIC_API_KEY"))
|
||||
ok, reason = should_call_llm(
|
||||
budget=budget, rate=rate,
|
||||
llm_enabled=llm_enabled, has_api_key=has_key,
|
||||
estimated_usd=0.005,
|
||||
)
|
||||
if not ok:
|
||||
write_event(
|
||||
store,
|
||||
kind="llm_health",
|
||||
data={
|
||||
"component": "schema_induction",
|
||||
"tier": "fallback",
|
||||
"reason": reason,
|
||||
},
|
||||
severity="warning",
|
||||
)
|
||||
return induce_schemas_tier0(store)
|
||||
|
||||
# Tier-1 eligible -- scaffold only (Plan 02-04 wires real Batch API).
|
||||
try:
|
||||
import anthropic # noqa: F401 -- lazy import, raise-only if missing
|
||||
budget.record_spend(0.002, kind="schema_induction")
|
||||
write_event(
|
||||
store,
|
||||
kind="llm_health",
|
||||
data={
|
||||
"component": "schema_induction",
|
||||
"tier": "haiku",
|
||||
"note": "Plan 02-04 wires real Batch API; 02-03 scaffolds only",
|
||||
},
|
||||
severity="info",
|
||||
)
|
||||
except Exception as e:
|
||||
write_event(
|
||||
store,
|
||||
kind="llm_health",
|
||||
data={"component": "schema_induction", "error": str(e)},
|
||||
severity="critical",
|
||||
)
|
||||
return induce_schemas_tier0(store)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------- persist
|
||||
|
||||
|
||||
def _majority_language(evidence_ids: list[UUID], store: MemoryStore) -> str:
|
||||
"""Return the plurality ISO-639-1 language tag among evidence records.
|
||||
|
||||
fix (D-08a constitutional): schema hubs must carry the
|
||||
language of their source evidence, not a hardcoded 'en'. A user whose
|
||||
records are Russian would otherwise get schemas tagged 'en' and fail
|
||||
their own language='ru' filter at retrieval.
|
||||
|
||||
Algorithm:
|
||||
- Fetch each evidence record via store.get (skip missing/deleted ones).
|
||||
- Collect their language fields (skip empty/None).
|
||||
      - Count occurrences per language in a hand-rolled pass over the
        collected list, visiting languages in order of first appearance.
        max(set(langs), key=langs.count) is not used because set iteration
        order is not defined, which would make ties non-deterministic.
|
||||
- Fallback 'en' when evidence is empty or all records are missing.
|
||||
|
||||
Tie-break policy: when two languages are tied, the one whose first
|
||||
occurrence appears EARLIEST in evidence_ids wins. Matches Phase 1
|
||||
default 'en' when no signal is available (least-surprise).
|
||||
"""
|
||||
langs: list[str] = []
|
||||
for eid in evidence_ids:
|
||||
rec = store.get(eid)
|
||||
if rec is None:
|
||||
continue
|
||||
if rec.language:
|
||||
langs.append(rec.language)
|
||||
if not langs:
|
||||
return "en"
|
||||
# Deterministic tie-break: visit distinct languages in first-occurrence
# order and keep the first one that reaches the maximum count.
# max(set(langs), key=langs.count) would depend on set iteration order,
# which Python does not specify, so we use a hand-rolled pass instead.
|
||||
best = langs[0]
|
||||
best_count = langs.count(best)
|
||||
seen: set[str] = {best}
|
||||
for lang in langs[1:]:
|
||||
if lang in seen:
|
||||
continue
|
||||
seen.add(lang)
|
||||
c = langs.count(lang)
|
||||
if c > best_count:
|
||||
best = lang
|
||||
best_count = c
|
||||
return best
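The tie-break policy above can be sketched without a store, assuming the language tags have already been collected into a plain list. `plurality_language` is a hypothetical name used only for illustration; it is not part of iai_mcp, and it swaps the hand-rolled pass for `collections.Counter`.

```python
from __future__ import annotations

from collections import Counter


def plurality_language(langs: list[str], default: str = "en") -> str:
    """Most common tag; ties go to the tag that appears first in langs."""
    tags = [t for t in langs if t]  # skip empty / None entries
    if not tags:
        return default
    counts = Counter(tags)
    top = max(counts.values())
    # Counter keeps first-insertion order (dict semantics, Python 3.7+),
    # so the first key that reaches the top count is the earliest-seen tag.
    return next(tag for tag, n in counts.items() if n == top)
```

Under these assumptions, `["ru", "en", "en", "ru"]` resolves to `"ru"` because `"ru"` occurs first, which matches the EARLIEST-first-occurrence rule in the docstring.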
|
||||
|
||||
|
||||
def persist_schema(
|
||||
store: MemoryStore,
|
||||
candidate: SchemaCandidate,
|
||||
) -> UUID:
|
||||
"""Insert a schema record + schema_instance_of edges to evidence.
|
||||
|
||||
Schema records carry:
|
||||
- tier="semantic", detail_level=3 (never_decay auto-true)
|
||||
- tags=["schema", <status>, f"pattern:{pattern}"]
|
||||
- s5_trust_score=0.5 (neutral prior; LEARN-06 may raise over time)
|
||||
- schema_version=2
|
||||
"""
|
||||
from iai_mcp.aaak import enforce_language_tagged, generate_aaak_index
|
||||
from iai_mcp.embed import embedder_for_store
|
||||
|
||||
summary = (
|
||||
f"Schema: {candidate.pattern} (confidence={candidate.confidence:.2f})"
|
||||
)
|
||||
|
||||
# R1 (D-09 + D-10): pattern dedup. Search for an existing
|
||||
# schema record carrying the tag `pattern:{candidate.pattern}` in the
|
||||
# semantic tier. If found, reinforce schema_instance_of edges from new
|
||||
# evidence onto the existing keeper, emit `schema_reinforced`, and
|
||||
# return the existing schema_id. If not found, fall through to the
|
||||
# original insert path. Closes the chain-induction bleed: every sleep
|
||||
# cycle would otherwise insert a fresh tier="semantic", never_decay
|
||||
# row for the same pattern (live store accumulated 7+ duplicates per
|
||||
# pattern with degree-bonus shouldering verbatim records out of hits[]).
|
||||
pattern_tag = f"pattern:{candidate.pattern}"
|
||||
# Phase 07.7-04 D-26-B: keeper scan migrated from store.all_records() to
|
||||
# store.iter_record_columns(["id", "tier", "tags_json"], batch_size=1024).
|
||||
# Projection skips encrypted columns (literal_surface, provenance_json,
|
||||
# profile_modulation_gain_json) entirely — the W5 cipher cache is
|
||||
# short-circuited on this path. Early-exit (`break`) semantics preserved.
|
||||
# The matching row's id arrives as a string from LanceDB; we convert to
|
||||
# UUID at the boundary so downstream code sees the same type contract as
|
||||
# the pre-D-26 ``existing_keeper.id`` access pattern.
|
||||
existing_keeper_id: UUID | None = None
|
||||
try:
|
||||
for row in store.iter_record_columns(
|
||||
["id", "tier", "tags_json"], batch_size=1024
|
||||
):
|
||||
if row.get("tier") != "semantic":
|
||||
continue
|
||||
tags_raw = row.get("tags_json") or "[]"
|
||||
try:
|
||||
tags = json.loads(tags_raw) if tags_raw else []
|
||||
except (TypeError, json.JSONDecodeError):
|
||||
tags = []
|
||||
if pattern_tag in tags:
|
||||
id_raw = row.get("id")
|
||||
if id_raw is None:
|
||||
continue
|
||||
try:
|
||||
existing_keeper_id = (
|
||||
UUID(id_raw) if isinstance(id_raw, str) else id_raw
|
||||
)
|
||||
except (ValueError, AttributeError):
|
||||
continue
|
||||
break
|
||||
except Exception:
|
||||
# Defensive: if the scan fails, fall through to the insert path so
|
||||
# we never silently lose a schema. Mirrors the diagnostic-write
|
||||
# contract used in pipeline.py provenance batching.
|
||||
existing_keeper_id = None
|
||||
|
||||
if existing_keeper_id is not None:
|
||||
from iai_mcp.store import EDGES_TABLE
|
||||
|
||||
# Reinforce schema_instance_of edges from each new evidence record
|
||||
# onto the existing keeper. Reuses the same delta formula as the
|
||||
# insert path (max(0.1, candidate.confidence)) for symmetry.
|
||||
delta = max(0.1, candidate.confidence)
|
||||
new_pairs = [(ev_id, existing_keeper_id) for ev_id in candidate.evidence_ids]
|
||||
if new_pairs:
|
||||
store.boost_edges(
|
||||
new_pairs,
|
||||
edge_type="schema_instance_of",
|
||||
delta=delta,
|
||||
)
|
||||
|
||||
# Compute total_evidence after reinforcement: count
|
||||
# `schema_instance_of` edges incident on the keeper. Read via the
|
||||
# edges table to avoid trusting any in-memory cache.
|
||||
# Note: store.boost_edges canonicalises (src, dst) to a sorted
|
||||
# tuple, so the keeper appears in EITHER column depending on the
|
||||
# string ordering of the paired evidence UUID. OR-counting both
|
||||
# columns gives the true edge-incidence count (no double-count
|
||||
# since each edge row has the keeper in exactly one column).
|
||||
try:
|
||||
edges_df = store.db.open_table(EDGES_TABLE).to_pandas()
|
||||
keeper_str = str(existing_keeper_id)
|
||||
total_evidence = int(
|
||||
((edges_df["edge_type"] == "schema_instance_of")
|
||||
& ((edges_df["dst"] == keeper_str)
|
||||
| (edges_df["src"] == keeper_str))).sum()
|
||||
)
|
||||
except Exception:
|
||||
total_evidence = len(candidate.evidence_ids)
|
||||
|
||||
write_event(
|
||||
store,
|
||||
kind="schema_reinforced",
|
||||
data={
|
||||
"schema_id": str(existing_keeper_id),
|
||||
"pattern": candidate.pattern,
|
||||
"evidence_added": len(candidate.evidence_ids),
|
||||
"total_evidence": total_evidence,
|
||||
},
|
||||
severity="info",
|
||||
source_ids=[existing_keeper_id, *candidate.evidence_ids[:5]],
|
||||
)
|
||||
return existing_keeper_id
|
||||
|
||||
emb = embedder_for_store(store).embed(summary)
|
||||
now = datetime.now(timezone.utc)
|
||||
schema_id = uuid4()
|
||||
# fix: derive language from the plurality language
|
||||
# of the evidence records, not a hardcoded 'en'. Schema hubs for Russian /
|
||||
# Japanese / Arabic clusters now carry the correct ISO-639-1 tag so
|
||||
# language-filtered retrieval surfaces them as expected.
|
||||
derived_language = _majority_language(candidate.evidence_ids, store)
|
||||
schema_rec = MemoryRecord(
|
||||
id=schema_id,
|
||||
tier="semantic",
|
||||
literal_surface=summary,
|
||||
aaak_index="",
|
||||
embedding=emb,
|
||||
community_id=None,
|
||||
centrality=0.0,
|
||||
detail_level=3,
|
||||
pinned=False,
|
||||
stability=0.7,
|
||||
difficulty=0.3,
|
||||
last_reviewed=now,
|
||||
never_decay=True,
|
||||
never_merge=False,
|
||||
provenance=[
|
||||
{
|
||||
"ts": now.isoformat(),
|
||||
"cue": "schema_induction",
|
||||
"session_id": "system",
|
||||
}
|
||||
],
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
tags=[
|
||||
"schema",
|
||||
candidate.status,
|
||||
f"pattern:{candidate.pattern}",
|
||||
],
|
||||
language=derived_language,
|
||||
s5_trust_score=0.5,
|
||||
profile_modulation_gain={},
|
||||
schema_version=SCHEMA_VERSION_CURRENT,
|
||||
)
|
||||
enforce_language_tagged(schema_rec)
|
||||
schema_rec.aaak_index = generate_aaak_index(schema_rec)
|
||||
store.insert(schema_rec)
|
||||
|
||||
# R3: batch the schema_instance_of edges into ONE boost_edges
|
||||
# call (one merge_insert + one tbl.add at most). Previously this loop
|
||||
# issued N Lance versions on edges.lance for an N-evidence schema.
|
||||
instance_pairs = [(ev_id, schema_id) for ev_id in candidate.evidence_ids]
|
||||
if instance_pairs:
|
||||
store.boost_edges(
|
||||
instance_pairs,
|
||||
edge_type="schema_instance_of",
|
||||
delta=max(0.1, candidate.confidence),
|
||||
)
|
||||
|
||||
write_event(
|
||||
store,
|
||||
kind="schema_induction_run",
|
||||
data={
|
||||
"schema_id": str(schema_id),
|
||||
"pattern": candidate.pattern,
|
||||
"confidence": candidate.confidence,
|
||||
"evidence_count": candidate.evidence_count,
|
||||
"status": candidate.status,
|
||||
},
|
||||
severity="info",
|
||||
source_ids=[schema_id, *candidate.evidence_ids[:5]],
|
||||
)
|
||||
return schema_id
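The edge-counting note in the reinforcement path (the keeper may sit in either column) can be illustrated with plain dictionaries. This is a minimal sketch under the assumption, taken from the comment above, that pairs are canonicalised by string ordering; `canonical_pair` and `count_incident` are hypothetical helpers, not the real `MemoryStore` API.

```python
from __future__ import annotations

from uuid import UUID


def canonical_pair(a: UUID, b: UUID) -> tuple[str, str]:
    """Store an undirected edge as a string-sorted (src, dst) pair."""
    x, y = sorted((str(a), str(b)))
    return x, y


def count_incident(edges: list[dict], node: UUID, edge_type: str) -> int:
    """Count edges of one type that touch node in either column."""
    key = str(node)
    return sum(
        1
        for e in edges
        if e["edge_type"] == edge_type and key in (e["src"], e["dst"])
    )


keeper = UUID("80000000-0000-0000-0000-000000000000")
low = UUID("10000000-0000-0000-0000-000000000000")   # sorts before keeper
high = UUID("f0000000-0000-0000-0000-000000000000")  # sorts after keeper

edges = []
for evidence in (low, high):
    src, dst = canonical_pair(evidence, keeper)
    edges.append({"src": src, "dst": dst, "edge_type": "schema_instance_of"})
```

Here the keeper lands in `dst` for the first edge and in `src` for the second, so a count over `dst` alone would report one edge instead of two.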
|
||||
|
||||
|
||||
# ---------------------------------------------------------------- provisional
|
||||
|
||||
|
||||
def provisional_schemas_for_recall(
|
||||
store: MemoryStore,
|
||||
hits: list,
|
||||
entropy_bits: float,
|
||||
records_cache: "dict | None" = None,
|
||||
) -> list[dict]:
|
||||
"""D-18 secondary path: surface provisional schema hints on high-entropy recalls.
|
||||
|
||||
Returns a list of hint dicts compatible with RecallResponse.hints, one per
|
||||
cohesive tag appearing in >= 2 of the top hits.
|
||||
|
||||
perf: batched all_records() fetch replaces N+1 store.get()
|
||||
calls. A single to_pandas() call is still O(total_records) but constant
|
||||
per recall, not per-hit. This was a major D-SPEED bottleneck at N=50.
|
||||
|
||||
perf (Rule 1 auto-fix): accept optional `records_cache` so
|
||||
pipeline_recall can pass its already-built cache through -- avoids a
|
||||
second `store.all_records()` scan per recall (~40ms at N=100). Falls
|
||||
back to all_records() if no cache provided (preserves back-compat for
|
||||
ad-hoc callers; tests without pipeline_recall still work).
|
||||
"""
|
||||
if entropy_bits < PROVISIONAL_ENTROPY_MIN or len(hits) < 3:
|
||||
return []
|
||||
|
||||
# Batch-fetch all records once and filter in memory: a single scan
# replaces one store.get() round-trip per hit.
|
||||
hit_ids = {h.record_id for h in hits}
|
||||
if records_cache is not None:
|
||||
# Reuse the cache built at pipeline_recall stage 1. Zero scans.
|
||||
by_id = {
|
||||
rid: rec for rid, rec in records_cache.items() if rid in hit_ids
|
||||
}
|
||||
else:
|
||||
try:
|
||||
all_recs = store.all_records()
|
||||
except Exception:
|
||||
return []
|
||||
by_id = {r.id: r for r in all_recs if r.id in hit_ids}
|
||||
|
||||
tag_count: Counter = Counter()
|
||||
for h in hits:
|
||||
rec = by_id.get(h.record_id)
|
||||
if rec is None:
|
||||
continue
|
||||
for t in (rec.tags or []):
|
||||
if t.startswith("raw:") or t.startswith("domain:"):
|
||||
continue
|
||||
tag_count[t] += 1
|
||||
|
||||
provisional: list[dict] = []
|
||||
for tag, cnt in tag_count.most_common(3):
|
||||
if cnt >= 2:
|
||||
source_ids: list[str] = []
|
||||
for h in hits:
|
||||
rec = by_id.get(h.record_id)
|
||||
if rec is None:
|
||||
continue
|
||||
if tag in (rec.tags or []):
|
||||
source_ids.append(str(h.record_id))
|
||||
if len(source_ids) >= 5:
|
||||
break
|
||||
provisional.append(
|
||||
{
|
||||
"kind": "provisional_schema",
|
||||
"severity": "info",
|
||||
"source_ids": source_ids,
|
||||
"text": f"Potential schema: tag={tag} cnt={cnt}",
|
||||
"provisional": True,
|
||||
"entropy": entropy_bits,
|
||||
}
|
||||
)
|
||||
return provisional
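The tag-cohesion step above can be sketched on plain data, assuming each hit has already been resolved to its tag list. `shared_tag_hints` and its `hit_tags` mapping are illustrative names only; the entropy gate and the record lookup are left out.

```python
from __future__ import annotations

from collections import Counter


def shared_tag_hints(
    hit_tags: dict[str, list[str]], min_hits: int = 2, top_n: int = 3
) -> list[dict]:
    """Emit one hint per tag shared by at least min_hits hits."""
    counts: Counter = Counter()
    for tags in hit_tags.values():
        for tag in tags:
            # Bookkeeping prefixes carry no schema signal.
            if tag.startswith(("raw:", "domain:")):
                continue
            counts[tag] += 1

    hints: list[dict] = []
    for tag, cnt in counts.most_common(top_n):
        if cnt < min_hits:
            continue
        # Cap provenance at 5 ids, in hit order.
        sources = [rid for rid, tags in hit_tags.items() if tag in tags][:5]
        hints.append(
            {
                "kind": "provisional_schema",
                "text": f"Potential schema: tag={tag} cnt={cnt}",
                "source_ids": sources,
            }
        )
    return hints


hits = {
    "a": ["python", "raw:x"],
    "b": ["python", "db"],
    "c": ["db", "domain:y"],
    "d": ["python"],
}
hints = shared_tag_hints(hits)
```

With this input, `python` (3 hits) and `db` (2 hits) each produce a hint, while the `raw:` and `domain:` tags are ignored.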
|
||||
486
src/iai_mcp/session.py
Normal file
|
|
@ -0,0 +1,486 @@
|
|||
"""Session-start assembler (D-10 budget, OPS-01, continuity).
|
||||
|
||||
Produces the 4-segment cached prefix that Claude's MCP wrapper places in front
|
||||
of every request under Anthropic 1h-TTL prompt caching:
|
||||
|
||||
L0 -- pinned identity kernel (always includes the user's L0 record)
|
||||
L1 -- critical-facts block (pinned + high-detail records)
|
||||
L2[...] -- Yeo-like community summaries (top MAX_TOP_COMMUNITIES=7)
|
||||
rich_club -- global hub prefetch (CONN-02 rich-club nodes)
|
||||
|
||||
Plan 03-02 (M6 LIVE prerequisite): assemble_session_start emits
|
||||
``kind='session_started'`` with a deterministic ``session_state_hash`` so
|
||||
M6 context-repeat-rate can be computed live from production emits.
|
||||
|
||||
Budget breakdown:
|
||||
L0_BUDGET_TOKENS = 80
|
||||
L1_BUDGET_TOKENS = 200
|
||||
L2_PER_COMMUNITY_TOKENS = 50 (cap of 7 -> L2 totals ~350 tok)
|
||||
RICH_CLUB_BUDGET_TOKENS = 1500
|
||||
TOTAL_CACHED_BUDGET = 2000
|
||||
(plus ~1000 tok dynamic tail -> steady-state <= 3000)
|
||||
|
||||
Tokens are counted via a local `_approx_tokens(text) = max(1, len(text) // 4)`
|
||||
heuristic that matches Anthropic's documented rough ratio; bench/tokens.py
|
||||
cross-validates with the real `count_tokens` API when ANTHROPIC_API_KEY is
|
||||
available.
|
||||
|
||||
OPS-05 observable: `payload.l0` always contains the substring "IAI-MCP" when the
|
||||
pinned L0 record is present, so the verifier can assert identity continuity
|
||||
on a fresh session open.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from uuid import UUID
|
||||
|
||||
from iai_mcp.aaak import generate_aaak_index
|
||||
from iai_mcp.community import CommunityAssignment
|
||||
from iai_mcp.handle import decode_compact_handle, encode_compact_handle
|
||||
from iai_mcp.store import MemoryStore
|
||||
from iai_mcp.types import MemoryRecord
|
||||
|
||||
|
||||
# ------------------------------------------------------------- budgets
|
||||
L0_BUDGET_TOKENS = 80
|
||||
L1_BUDGET_TOKENS = 200
|
||||
L2_PER_COMMUNITY_TOKENS = 50
|
||||
L2_COMMUNITY_CAP = 7 # CONN-01 Yeo-like cap
|
||||
RICH_CLUB_BUDGET_TOKENS = 1500
|
||||
TOTAL_CACHED_BUDGET = 2000 # L0 + L1 + L2 + rich_club <= this
|
||||
DYNAMIC_TAIL_TOKENS = 1000 # reserve for per-turn tool results
|
||||
|
||||
# Pinned L0 UUID (D-14, matches core._seed_l0_identity).
|
||||
L0_RECORD_UUID = UUID("00000000-0000-0000-0000-000000000001")
|
||||
|
||||
|
||||
# --------------------------------------------------------------- data shape
|
||||
|
||||
|
||||
@dataclass
|
||||
class SessionStartPayload:
|
||||
"""Cached prefix + metadata (D-10 + TOK-11 lazy fields).
|
||||
|
||||
`breakpoint_marker` is where the TS wrapper splits stable vs volatile
|
||||
content before applying Anthropic `cache_control` (TOK-01). The Python
|
||||
side never inserts it into the segment strings -- it's just a sentinel
|
||||
string the TS side recognises.
|
||||
|
||||
D5-02: three new pointer fields populated at
|
||||
`wake_depth=minimal` (the new default); legacy l0/l1/l2/rich_club left
|
||||
empty at minimal mode. `wake_depth` is echoed so the client knows
|
||||
which mode produced the payload.
|
||||
"""
|
||||
|
||||
l0: str = ""
|
||||
l1: str = ""
|
||||
l2: list[str] = field(default_factory=list)
|
||||
rich_club: str = ""
|
||||
total_cached_tokens: int = 0
|
||||
total_dynamic_tokens: int = 0
|
||||
breakpoint_marker: str = "--<cache-breakpoint>--"
|
||||
# D5-02 — lazy session-start fields (<=30 raw tok combined).
|
||||
identity_pointer: str = "" # "<id:{8-hex-of-L0-uuid}>" (~8 tok)
|
||||
brain_handle: str = "" # "<sess:{8-hex} pend:{N}>" (~12 tok)
|
||||
topic_cluster_hint: str = "" # "<topic:{community_label}>" (~8 tok)
|
||||
# Single compact handle, ≤16 raw tok target. At `wake_depth=minimal`
# this supersedes the three legacy pointers above on the wire; they stay
# populated on the dataclass for back-compat, but only the compact handle
# is charged to `total_cached_tokens`. `standard`/`deep` populate BOTH
# the compact handle and the legacy fields.
|
||||
compact_handle: str = "" # "<iai:{16-hex-blake2s}>" (~6-10 raw tok)
|
||||
wake_depth: str = "minimal" # echoed for introspection
|
||||
|
||||
|
||||
# ---------------------------------------------------------- token counting
|
||||
|
||||
|
||||
def _approx_tokens(text: str) -> int:
|
||||
"""~4 chars per token heuristic (Anthropic documentation ballpark).
|
||||
|
||||
Returns 0 for empty text; any non-empty text counts as at least 1 token
so callers don't divide by zero.
|
||||
"""
|
||||
if not text:
|
||||
return 0
|
||||
return max(1, len(text) // 4)
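As a worked check of the heuristic and the budget constants declared in this module, the sketch below recomputes the per-segment caps. The constants are copied from the source; `approx_tokens` mirrors `_approx_tokens`, and `segment_caps` / `caps_total` are illustrative names, not module attributes.

```python
L0_BUDGET_TOKENS = 80
L1_BUDGET_TOKENS = 200
L2_PER_COMMUNITY_TOKENS = 50
L2_COMMUNITY_CAP = 7
RICH_CLUB_BUDGET_TOKENS = 1500
TOTAL_CACHED_BUDGET = 2000


def approx_tokens(text: str) -> int:
    """~4 characters per token; 0 for empty text, at least 1 otherwise."""
    if not text:
        return 0
    return max(1, len(text) // 4)


segment_caps = {
    "l0": L0_BUDGET_TOKENS,
    "l1": L1_BUDGET_TOKENS,
    "l2": L2_PER_COMMUNITY_TOKENS * L2_COMMUNITY_CAP,  # 7 lines of 50 tok
    "rich_club": RICH_CLUB_BUDGET_TOKENS,
}
caps_total = sum(segment_caps.values())
```

With these values the caps sum to 2130 tokens, 130 above `TOTAL_CACHED_BUDGET`, which suggests the total is only respected when at least one segment stays under its own cap.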
|
||||
|
||||
|
||||
# ----------------------------------------------------------------- helpers
|
||||
|
||||
|
||||
def _resolve_compact_handle_to_pointers(handle: str) -> tuple[str, str, str] | None:
|
||||
"""Rebuild the legacy (identity_pointer, brain_handle, topic_cluster_hint)
|
||||
triple from a compact ``<iai:HHHHHHHHHHHHHHHH>`` handle minted earlier in
|
||||
this process.
|
||||
|
||||
no-info-loss proof: everything the 3-field shape conveyed is
|
||||
recoverable from the compact handle via the LRU in ``iai_mcp.handle`` ---
|
||||
identity prefix, session prefix, topic label and pending count. Returns
|
||||
``None`` when the handle is malformed OR the LRU has evicted the record,
|
||||
mirroring ``decode_compact_handle``'s contract: callers that need strict
|
||||
resolution should keep the legacy fields available under
|
||||
``wake_depth=standard`` / ``deep`` as fallback.
|
||||
"""
|
||||
parts = decode_compact_handle(handle)
|
||||
if parts is None:
|
||||
return None
|
||||
identity_pointer = f"<id:{parts[0]}>" if parts[0] else ""
|
||||
brain_handle = f"<sess:{parts[1]} pend:{parts[3]}>"
|
||||
topic_cluster_hint = f"<topic:{parts[2]}>"
|
||||
return identity_pointer, brain_handle, topic_cluster_hint
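The real encoder lives in `iai_mcp.handle` and is not shown in this file, so the following is only a guess at the general shape: a short BLAKE2s digest used as the handle, with an in-process LRU holding the reverse mapping. Every name here (`encode_handle`, `decode_handle`, `_CACHE`, `_CACHE_MAX`) and the field separator are assumptions made for illustration.

```python
from __future__ import annotations

import hashlib
from collections import OrderedDict

_CACHE: OrderedDict = OrderedDict()  # handle -> (identity, session, topic, pending)
_CACHE_MAX = 128  # hypothetical capacity


def encode_handle(identity: str, session: str, topic: str, pending: int) -> str:
    """Mint a 16-hex handle and remember the fields it stands for."""
    raw = "\x1f".join((identity, session, topic, str(pending))).encode("utf-8")
    digest = hashlib.blake2s(raw, digest_size=8).hexdigest()  # 8 bytes -> 16 hex
    handle = f"<iai:{digest}>"
    _CACHE[handle] = (identity, session, topic, pending)
    _CACHE.move_to_end(handle)
    while len(_CACHE) > _CACHE_MAX:
        _CACHE.popitem(last=False)  # evict the least recently used entry
    return handle


def decode_handle(handle: str) -> tuple[str, str, str, int] | None:
    """Return the remembered fields, or None if unknown or evicted."""
    parts = _CACHE.get(handle)
    if parts is not None:
        _CACHE.move_to_end(handle)
    return parts


handle = encode_handle("00000000", "abcd1234", "none", 2)
```

Because the digest is one-way, resolution only works while the entry is still cached in the same process, which is consistent with the `None` contract described in the docstring above.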
|
||||
|
||||
|
||||
def _fetch_record(store: MemoryStore, uid: UUID) -> MemoryRecord | None:
|
||||
try:
|
||||
return store.get(uid)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
# ----------------------------------------------------------- segment builders
|
||||
|
||||
|
||||
def _l0_segment(store: MemoryStore) -> str:
|
||||
"""OPS-05 identity kernel -- the pinned L0 record by fixed UUID.
|
||||
|
||||
Returned string shape: "<aaak_index>\n<literal_surface[:200]>". Empty when
|
||||
the L0 record hasn't been seeded yet (fresh stores before first core boot).
|
||||
"""
|
||||
rec = _fetch_record(store, L0_RECORD_UUID)
|
||||
if rec is None:
|
||||
return ""
|
||||
aaak = rec.aaak_index or generate_aaak_index(rec)
|
||||
# Truncate literal to 200 chars -- the L0 budget is ~80 tok (~320 chars);
|
||||
# leave slack for the aaak line + newline.
|
||||
return f"{aaak}\n{rec.literal_surface[:200]}"
|
||||
|
||||
|
||||
def _l1_segment(store: MemoryStore, max_records: int = 10) -> str:
|
||||
"""L1 critical-facts block -- pinned records with detail_level >= 4.
|
||||
|
||||
Excludes the L0 record (duplicated in L0 segment). Lines formatted as
|
||||
"- <literal_surface[:100]>" so they fit in ~25 tokens each; 10 of them
|
||||
saturate the L1_BUDGET_TOKENS ~= 200 tok budget.
|
||||
"""
|
||||
try:
|
||||
records = store.all_records()
|
||||
except Exception:
|
||||
return ""
|
||||
pinned_hi_detail = [
|
||||
r for r in records
|
||||
if r.pinned and r.detail_level >= 4 and r.id != L0_RECORD_UUID
|
||||
]
|
||||
# Deterministic ordering: by detail_level desc, then by created_at asc.
|
||||
pinned_hi_detail.sort(
|
||||
key=lambda r: (-r.detail_level, r.created_at)
|
||||
)
|
||||
pinned_hi_detail = pinned_hi_detail[:max_records]
|
||||
if not pinned_hi_detail:
|
||||
return ""
|
||||
lines = [f"- {r.literal_surface[:100]}" for r in pinned_hi_detail]
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
def _l2_segments(
|
||||
store: MemoryStore,
|
||||
assignment: CommunityAssignment,
|
||||
) -> list[str]:
|
||||
"""Up to L2_COMMUNITY_CAP (7) Yeo-like community summary lines.
|
||||
|
||||
Each summary samples up to 3 member records from the community's
|
||||
mid_regions list and joins them with `|`. Budget guardrail: each line
|
||||
is capped at approximately L2_PER_COMMUNITY_TOKENS * 4 chars (=200 chars).
|
||||
|
||||
Empty list when the assignment has no top_communities (fresh/flat case).
|
||||
"""
|
||||
top = list(assignment.top_communities)[:L2_COMMUNITY_CAP]
|
||||
if not top:
|
||||
return []
|
||||
|
||||
# records_cache: keep the single all_records() call hot (same trick
|
||||
# pipeline.py uses -- avoids N+1 store.get scans).
|
||||
try:
|
||||
records = store.all_records()
|
||||
except Exception:
|
||||
return []
|
||||
by_uuid = {r.id: r for r in records}
|
||||
|
||||
summaries: list[str] = []
|
||||
max_chars = L2_PER_COMMUNITY_TOKENS * 4 # ~200 chars budget per line
|
||||
for cid in top:
|
||||
members = assignment.mid_regions.get(cid, [])[:3]
|
||||
parts: list[str] = []
|
||||
for mid in members:
|
||||
rec = by_uuid.get(mid)
|
||||
if rec is None:
|
||||
continue
|
||||
# Per-member snippet: AAAK-shortened wing tag + first 40 chars.
|
||||
wing = rec.aaak_index.split("/")[0] if rec.aaak_index else "W:?"
|
||||
parts.append(f"{wing}/{rec.literal_surface[:40]}")
|
||||
if not parts:
|
||||
continue
|
||||
body = " | ".join(parts)
|
||||
line = f"[community {str(cid)[:8]}] {body}"
|
||||
if len(line) > max_chars:
|
||||
line = line[:max_chars]
|
||||
# LLMLingua-2 compression on L2 community
|
||||
# descriptors. Passthrough when package absent (see compress.py).
|
||||
try:
|
||||
from iai_mcp.compress import compress_l2_descriptor
|
||||
line = compress_l2_descriptor(line, store=store)
|
||||
except Exception:
|
||||
pass
|
||||
summaries.append(line)
|
||||
return summaries
|
||||
|
||||
|
||||
def _rich_club_segment(store: MemoryStore, rich_club: list[UUID]) -> str:
|
||||
"""Global rich-club summary, truncated to RICH_CLUB_BUDGET_TOKENS.
|
||||
|
||||
Each rich-club node contributes one line "<aaak_index>: <literal_surface[:60]>".
|
||||
Lines are added until the running token count would exceed the budget.
|
||||
"""
|
||||
return _rich_club_segment_with_budget(store, rich_club, budget=RICH_CLUB_BUDGET_TOKENS)
|
||||
|
||||
|
||||
def _rich_club_segment_with_budget(
|
||||
store: MemoryStore,
|
||||
rich_club: list[UUID],
|
||||
*,
|
||||
budget: int,
|
||||
) -> str:
|
||||
"""Rich-club summary with an explicit budget (Plan 05-03 deep mode).
|
||||
|
||||
Same rendering as `_rich_club_segment`; `budget` replaces the default cap
|
||||
so wake_depth=deep can lift the rich_club allotment to ~2000 tok.
|
||||
"""
|
||||
if not rich_club:
|
||||
return ""
|
||||
try:
|
||||
records = store.all_records()
|
||||
except Exception:
|
||||
return ""
|
||||
by_uuid = {r.id: r for r in records}
|
||||
|
||||
lines: list[str] = []
|
||||
running = 0
|
||||
for uid in rich_club:
|
||||
rec = by_uuid.get(uid)
|
||||
if rec is None:
|
||||
continue
|
||||
aaak = rec.aaak_index or generate_aaak_index(rec)
|
||||
line = f"{aaak}: {rec.literal_surface[:60]}"
|
||||
cost = _approx_tokens(line)
|
||||
# Respect running budget -- +1 accounts for the join newline.
|
||||
if running + cost + 1 > budget:
|
||||
break
|
||||
lines.append(line)
|
||||
running += cost + 1
|
||||
return "\n".join(lines)
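The budget loop above is a greedy, order-preserving packer. A minimal sketch on plain strings, assuming the same 4-characters-per-token heuristic (`pack_lines` is a hypothetical name):

```python
from __future__ import annotations


def approx_tokens(text: str) -> int:
    """Same ~4 chars/token heuristic as the module's private helper."""
    return max(1, len(text) // 4) if text else 0


def pack_lines(lines: list[str], budget: int) -> list[str]:
    """Keep lines in order until the next one would exceed the budget."""
    kept: list[str] = []
    running = 0
    for line in lines:
        cost = approx_tokens(line) + 1  # +1 for the joining newline
        if running + cost > budget:
            break  # stop at the first overflow; later lines are not tried
        kept.append(line)
        running += cost
    return kept
```

Note the `break`: once one line overflows, packing stops even if a later, shorter line would still fit. That keeps the output a strict prefix of the rich-club ordering.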
|
||||
|
||||
|
||||
# ------------------------------------------------------------------ public
|
||||
|
||||
|
||||
def _session_state_hash(payload: SessionStartPayload) -> str:
|
||||
"""Plan 03-02 M6: deterministic SHA-256 over the 4-segment cached prefix.
|
||||
|
||||
Two sessions whose L0 + L1 + L2 + rich_club segments are byte-identical
|
||||
produce the SAME session_state_hash -- which is exactly the
|
||||
"context-repeat" signal M6 measures.
|
||||
"""
|
||||
import hashlib
|
||||
h = hashlib.sha256()
|
||||
h.update(payload.l0.encode("utf-8"))
|
||||
h.update(b"\x1f") # ASCII unit separator
|
||||
h.update(payload.l1.encode("utf-8"))
|
||||
h.update(b"\x1f")
|
||||
h.update("\n".join(payload.l2).encode("utf-8"))
|
||||
h.update(b"\x1f")
|
||||
h.update(payload.rich_club.encode("utf-8"))
|
||||
return h.hexdigest()
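The unit separator between segments matters: without it, text that moves across a segment boundary would hash identically. A small standalone sketch of the same scheme (`prefix_hash` is an illustrative name):

```python
from __future__ import annotations

import hashlib


def prefix_hash(l0: str, l1: str, l2: list[str], rich_club: str) -> str:
    """SHA-256 over four segments joined by the ASCII unit separator."""
    h = hashlib.sha256()
    for i, part in enumerate((l0, l1, "\n".join(l2), rich_club)):
        if i:
            h.update(b"\x1f")  # boundary marker between segments
        h.update(part.encode("utf-8"))
    return h.hexdigest()


same_a = prefix_hash("ab", "c", ["x"], "")
same_b = prefix_hash("ab", "c", ["x"], "")
shifted = prefix_hash("a", "bc", ["x"], "")
```

`same_a` and `same_b` agree, while `shifted` differs even though the concatenated text is identical, because the separator lands in a different position.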
|
||||
|
||||
|
||||
def _dominant_community_label(assignment: CommunityAssignment) -> str:
|
||||
"""Plan 05-03 D5-02: short (<=8 char) label for the largest community.
|
||||
|
||||
Returns 'none' when no communities exist (fresh or flat assignment). The
|
||||
label is the first 8 hex of the dominant community UUID — a stable handle
|
||||
that fits in ~3-4 tokens.
|
||||
"""
|
||||
try:
|
||||
top = list(assignment.top_communities)
|
||||
if not top:
|
||||
return "none"
|
||||
# top_communities is already ordered by member count (CONN-01 L1).
|
||||
return str(top[0])[:8]
|
||||
except Exception:
|
||||
return "none"
|
||||
|
||||
|
||||
def _count_pending_first_turn(store: MemoryStore) -> int:
|
||||
"""Plan 05-03 D5-02: count open first_turn_pending sessions in daemon_state.
|
||||
|
||||
Returns 0 if daemon_state is missing or malformed (silent fallback). This
|
||||
is only cosmetic input to the brain_handle pointer; the minimal payload
|
||||
must survive a missing daemon gracefully.
|
||||
"""
|
||||
try:
|
||||
from iai_mcp.daemon_state import load_state
|
||||
state = load_state()
|
||||
pending = state.get("first_turn_pending", {})
|
||||
if isinstance(pending, dict):
|
||||
return sum(1 for v in pending.values() if v)
|
||||
return 0
|
||||
except Exception:
|
||||
return 0
|
||||
|
||||
|
||||
def assemble_session_start(
|
||||
store: MemoryStore,
|
||||
assignment: CommunityAssignment,
|
||||
rich_club: list[UUID],
|
||||
*,
|
||||
session_id: str = "-",
|
||||
profile_state: dict | None = None,
|
||||
) -> SessionStartPayload:
|
||||
"""Assemble the session-start cached prefix.
|
||||
|
||||
TOK-11 / D5-02 / D5-10: branches on the `wake_depth` profile
|
||||
knob (15th sealed knob, MCP-12):
|
||||
|
||||
- ``minimal`` (default): produce a ≤30 raw-tok pointer handle (identity,
|
||||
brain session, topic cluster). Legacy l0/l1/l2/rich_club emitted empty
|
||||
for back-compat with existing TS-wrapper callers.
|
||||
- ``standard``: reproduce the Phase-1 1388-tok eager dump: l0/l1/l2/
rich_club populated via `_l0_segment`, `_l1_segment`, `_l2_segments`,
`_rich_club_segment`. The pointer fields and compact handle are
populated as well.
|
||||
- ``deep``: same shape as standard but rich_club budget lifted to 2000.
|
||||
Populates both the legacy segments and the new pointers.
|
||||
|
||||
(M6 LIVE prerequisite): emits ``kind='session_started'`` with
|
||||
a deterministic ``session_state_hash`` over the cached prefix. Two
|
||||
consecutive sessions whose cached prefix is identical produce the same
|
||||
hash -- exactly the context-repeat signal M6 measures.
|
||||
|
||||
Pitfall 1 (Anthropic cache threshold reality per 05-RESEARCH lines
|
||||
447-469): at `wake_depth=minimal` the payload is ≤30 raw tok which is
|
||||
BELOW the Sonnet 4.6 / Opus 4.7 cache minimum (2048 / 4096). DO NOT add
|
||||
``cache_control`` to the minimal branch prefix — it would be silently
|
||||
ignored by the Anthropic API and waste a breakpoint slot.
|
||||
"""
|
||||
from iai_mcp.profile import default_state
|
||||
state = profile_state if isinstance(profile_state, dict) else default_state()
|
||||
wake_depth = state.get("wake_depth", "minimal")
|
||||
if wake_depth not in ("minimal", "standard", "deep"):
|
||||
wake_depth = "minimal" # D5-10 silent fallback
|
||||
|
||||
if wake_depth == "minimal":
|
||||
# Pitfall 1 guard: payload will not be Anthropic-cached
|
||||
# (<=30 raw tok < Sonnet 4.6 min 2048). DO NOT set cache_control.
|
||||
#
|
||||
# collapse the three legacy pointers
|
||||
# (identity_pointer + brain_handle + topic_cluster_hint, ~24 raw tok
|
||||
# together) into a single `<iai:HHHHHHHHHHHHHHHH>` handle (~6-10 raw
|
||||
# tok). The LRU inside `iai_mcp.handle` retains the reverse mapping
|
||||
# so downstream code can resolve the handle to its triple.
|
||||
#
|
||||
# Back-compat contract: the 3 legacy fields stay populated on the
|
||||
# dataclass so callers reading the old shape keep working; only
|
||||
# ``total_cached_tokens`` is charged for the compact handle (the
|
||||
# wire prefix at wake_depth=minimal is the compact handle alone).
|
||||
l0_rec = _fetch_record(store, L0_RECORD_UUID)
|
||||
identity_short = str(L0_RECORD_UUID)[:8] if l0_rec is not None else ""
|
||||
identity_pointer = f"<id:{identity_short}>" if identity_short else ""
|
||||
pending = _count_pending_first_turn(store)
|
||||
session_short = str(session_id)[:8]
|
||||
brain_handle = f"<sess:{session_short} pend:{pending}>"
|
||||
topic_label = _dominant_community_label(assignment)
|
||||
topic_cluster_hint = f"<topic:{topic_label}>"
|
||||
compact_handle = encode_compact_handle(
|
||||
identity_short, session_short, topic_label, pending
|
||||
)
|
||||
cached = _approx_tokens(compact_handle)
|
||||
payload = SessionStartPayload(
|
||||
l0="",
|
||||
l1="",
|
||||
l2=[],
|
||||
rich_club="",
|
||||
total_cached_tokens=cached,
|
||||
total_dynamic_tokens=DYNAMIC_TAIL_TOKENS,
|
||||
identity_pointer=identity_pointer,
|
||||
brain_handle=brain_handle,
|
||||
topic_cluster_hint=topic_cluster_hint,
|
||||
compact_handle=compact_handle,
|
||||
wake_depth="minimal",
|
||||
)
|
||||
else:
|
||||
# standard and deep share the Phase-1 eager assembly path; deep lifts
|
||||
# the rich_club budget by re-running the segment with a larger cap.
|
||||
l0 = _l0_segment(store)
|
||||
l1 = _l1_segment(store)
|
||||
l2 = _l2_segments(store, assignment)
|
||||
if wake_depth == "deep":
|
||||
rc = _rich_club_segment_with_budget(store, rich_club, budget=2000)
|
||||
else:
|
||||
rc = _rich_club_segment(store, rich_club)
|
||||
|
||||
cached = (
|
||||
_approx_tokens(l0)
|
||||
+ _approx_tokens(l1)
|
||||
+ sum(_approx_tokens(s) for s in l2)
|
||||
+ _approx_tokens(rc)
|
||||
)
|
||||
|
||||
# New pointers also populated under standard/deep so downstream callers
|
||||
# can use them alongside legacy segments if they want. Plan 05-06:
|
||||
# the compact handle is ALSO minted here so a consumer can opt in to
|
||||
# the short form without requiring a wake_depth mode switch.
|
||||
l0_rec = _fetch_record(store, L0_RECORD_UUID)
|
||||
identity_short = str(L0_RECORD_UUID)[:8] if l0_rec is not None else ""
|
||||
identity_pointer = f"<id:{identity_short}>" if identity_short else ""
|
||||
pending = _count_pending_first_turn(store)
|
||||
session_short = str(session_id)[:8]
|
||||
brain_handle = f"<sess:{session_short} pend:{pending}>"
|
||||
topic_label = _dominant_community_label(assignment)
|
||||
topic_cluster_hint = f"<topic:{topic_label}>"
|
||||
compact_handle = encode_compact_handle(
|
||||
identity_short, session_short, topic_label, pending
|
||||
)
|
||||
|
||||
payload = SessionStartPayload(
|
||||
l0=l0,
|
||||
l1=l1,
|
||||
l2=l2,
|
||||
rich_club=rc,
|
||||
total_cached_tokens=cached,
|
||||
total_dynamic_tokens=DYNAMIC_TAIL_TOKENS,
|
||||
identity_pointer=identity_pointer,
|
||||
brain_handle=brain_handle,
|
||||
topic_cluster_hint=topic_cluster_hint,
|
||||
compact_handle=compact_handle,
|
||||
wake_depth=wake_depth,
|
||||
)
|
||||
|
||||
# (M6 LIVE prerequisite): emit kind='session_started' with
|
||||
# session_state_hash for trajectory.m6_context_repeat_rate_live.
|
||||
# Diagnostic-only: never block session start on emit failure.
|
||||
try:
|
||||
from datetime import datetime, timezone
|
||||
from iai_mcp.events import write_event
|
||||
write_event(
|
||||
store,
|
||||
kind="session_started",
|
||||
data={
|
||||
"session_id": session_id,
|
||||
"session_state_hash": _session_state_hash(payload),
|
||||
"total_cached_tokens": cached,
|
||||
"wake_depth": wake_depth,
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
},
|
||||
severity="info",
|
||||
session_id=session_id,
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return payload
|
||||
308
src/iai_mcp/shield.py
Normal file
|
|
@ -0,0 +1,308 @@
"""OPS-07 prompt-injection shield (D-30, D-31) -- Plan 02-05.

Three-tier deployment per D-31:
    HARD_BLOCK      -> L0 identity + S5 invariant writes (reject on detection)
    FLAG_FOR_REVIEW -> profile updates (flag + warn, write proceeds)
    LOG_ONLY        -> content records (log only, allow)

D-30 threat model (three severities):
    - Direct override (e.g. "forget X, now Y") -> HARD BLOCK via signal words
    - Gradual drift (subtle lies over weeks)   -> DETECT via trajectory M4 anomaly
      (see s5.detect_drift_anomaly)
    - Data poisoning (intentional false write) -> MITIGATE via ART vigilance
      + user-approval UX

Global-product mandate: signal words cover 7+ languages
(en + ru + ja + ar + de + fr + es + zh) at minimum. The module exports
`SHIELD_LANGUAGES_SUPPORTED` as the authoritative set; downstream acceptance
tests grep against it.

The shield is a PURE LOCAL filter: no LLM call, no network. Detection uses
case-insensitive substring matching against curated signal-word lists. The
tier policy is additive: warning signals escalate to critical in the
HARD_BLOCK tier (L0 is sacred).
"""
from __future__ import annotations

from dataclasses import dataclass, field
from enum import Enum
from typing import Any
from uuid import UUID

from iai_mcp.events import write_event


# ------------------------------------------------------------ constitutional constants

# Confidence levels for the shield verdict. Confidence is a coarse two-level
# signal: 0.9 when any critical pattern matches, 0.6 when only warning patterns
# match. It is used for downstream analytics, not the tier-policy gate; the
# tier enum plus the presence of any match drives the action.
SHIELD_SIGNAL_WORDS_MAX_CONFIDENCE: float = 0.9  # reported on any critical match
SHIELD_FLAG_CONFIDENCE: float = 0.6  # reported when matches are warning-only

# global-product mandate: 7+ languages supported.
SHIELD_LANGUAGES_SUPPORTED: frozenset[str] = frozenset({
    "en", "ru", "ja", "ar", "de", "fr", "es", "zh",
})

# gradual-drift detection threshold -- used by s5.detect_drift_anomaly
# but declared here so the single authoritative constant sits alongside the
# other shield thresholds (downstream greps one file).
DRIFT_M4_ANOMALY_SIGMA: float = 3.0


# ------------------------------------------------------------ signal-word catalogues

# English critical signal words: classic prompt-injection imperatives.
SIGNAL_WORDS_CRITICAL_EN: list[str] = [
    "forget", "override", "ignore previous", "you are now",
    "from now on", "system:", "admin:", "instruction:",
    "disregard", "new instructions", "ignore previous instructions",
]

# English warning signals: softer but still suspicious rephrasings.
SIGNAL_WORDS_WARNING_EN: list[str] = [
    "different", "instead", "actually", "update",
]

# Per-language critical signal words (D-02a mandate).
# Keys are ISO-639-1 codes; values are minimal strictly-imperative tokens.
# Conservative by design: false positives on legitimate non-English chatter are
# worse than false negatives at this tier (users have multiple layers of
# defence; the shield is one slice of defence-in-depth).
SIGNAL_WORDS_CRITICAL_BY_LANG: dict[str, list[str]] = {
    "ru": [
        "забудь", "забыть", "игнорируй",
        "отмени", "сбрось", "новые инструкции",
        "теперь ты", "с этого момента",
    ],
    "ja": [
        "忘れて", "無視", "リセット",
        "新しい指示", "これから", "今から",
    ],
    "ar": [
        "انسى", "تجاهل",
        "إعادة تعيين", "تعليمات جديدة", "أنت الآن",
    ],
    "de": [
        "vergiss", "ignoriere", "überschreibe",
        "neue anweisungen", "ab jetzt",
    ],
    "fr": [
        "oublie", "ignore",
        "remplace", "nouvelles instructions",
    ],
    "es": [
        "olvida", "ignora",
        "sobrescribe", "nuevas instrucciones",
    ],
    "zh": [
        "忘记", "忽略", "重置",
        "新指令", "从现在开始",
    ],
}


# ------------------------------------------------------------ enums + types


class ShieldTier(str, Enum):
    """D-31 three-tier deployment."""

    HARD_BLOCK = "hard_block"  # L0 identity + S5 invariants
    FLAG_FOR_REVIEW = "flag"  # profile updates
    LOG_ONLY = "log"  # content records


@dataclass
class ShieldVerdict:
    """Result of evaluating injection risk for a single text blob."""

    tier: ShieldTier
    detected: bool
    matched_patterns: list[str] = field(default_factory=list)
    severity: str = "info"  # "info" | "warning" | "critical"
    action: str = "log_allow"  # "reject" | "flag" | "log_allow"
    reason: str = ""
    language: str | None = None
    confidence: float = 0.0


# ------------------------------------------------------------ private helpers


def _signal_lists_for_language(
    lang: str | None,
) -> tuple[list[str], list[str]]:
    """Return (critical, warning) lists for the given language.

    English signals are ALWAYS included (prompt-injection attempts are often
    copy-pasted English regardless of the user's native language). When a
    `lang` is given AND supported, its per-language critical list is appended.
    """
    critical = list(SIGNAL_WORDS_CRITICAL_EN)
    warning = list(SIGNAL_WORDS_WARNING_EN)
    if lang and lang in SIGNAL_WORDS_CRITICAL_BY_LANG:
        critical.extend(SIGNAL_WORDS_CRITICAL_BY_LANG[lang])
    return critical, warning


def _match_patterns(text: str, patterns: list[str]) -> list[str]:
    """Return the subset of patterns present in the (lowercased) text.

    Both the text and each pattern are lowercased with str.lower(). For
    scripts without case (CJK, kana, Arabic) this is identity-preserving;
    for cased scripts (Latin, Cyrillic) it makes the substring match
    case-insensitive, which is safe for our lists.
    """
    t = (text or "").lower()
    out: list[str] = []
    for p in patterns:
        if p.lower() in t:
            out.append(p)
    return out


# ------------------------------------------------------------ public API


def evaluate_injection_risk(
    text: str,
    tier: ShieldTier,
    target_language: str | None = None,
) -> ShieldVerdict:
    """Core shield detection (pure function, no side effects).

    Tier escalation policy:
        HARD_BLOCK      -- any critical OR warning match -> reject (severity critical)
        FLAG_FOR_REVIEW -- any match -> flag (severity warning)
        LOG_ONLY        -- any match -> log_allow (severity info)
        no match        -- detected=False, action=log_allow
    """
    critical_list, warning_list = _signal_lists_for_language(target_language)
    matched_critical = _match_patterns(text, critical_list)
    matched_warning = _match_patterns(text, warning_list)
    all_matched = matched_critical + matched_warning

    if not all_matched:
        return ShieldVerdict(
            tier=tier,
            detected=False,
            matched_patterns=[],
            severity="info",
            action="log_allow",
            reason="no signal patterns detected",
            language=target_language,
            confidence=0.0,
        )

    # Confidence: 0.9 when any critical match, 0.6 when warning-only.
    confidence = (
        SHIELD_SIGNAL_WORDS_MAX_CONFIDENCE
        if matched_critical
        else SHIELD_FLAG_CONFIDENCE
    )

    if tier == ShieldTier.HARD_BLOCK:
        return ShieldVerdict(
            tier=tier,
            detected=True,
            matched_patterns=all_matched,
            severity="critical",
            action="reject",
            reason=(
                f"injection signals detected in HARD_BLOCK tier: {all_matched}"
            ),
            language=target_language,
            confidence=confidence,
        )
    if tier == ShieldTier.FLAG_FOR_REVIEW:
        return ShieldVerdict(
            tier=tier,
            detected=True,
            matched_patterns=all_matched,
            severity="warning",
            action="flag",
            reason=f"injection signals detected in FLAG tier: {all_matched}",
            language=target_language,
            confidence=confidence,
        )
    # LOG_ONLY
    return ShieldVerdict(
        tier=tier,
        detected=True,
        matched_patterns=all_matched,
        severity="info",
        action="log_allow",
        reason=f"injection signals detected in LOG tier: {all_matched}",
        language=target_language,
        confidence=confidence,
    )


def apply_shield(
    store: Any,  # MemoryStore
    record: Any,  # MemoryRecord (avoids import cycle with types)
    tier: ShieldTier,
    session_id: str = "-",
) -> ShieldVerdict:
    """Evaluate + emit event (side-effectful wrapper).

    Event kind is determined by the tier policy:
        - reject    -> kind="shield_rejection" (severity critical)
        - flag      -> kind="shield_flag" (severity warning)
        - log_allow -> kind="shield_log" (severity info, ONLY on detection)

    No event is emitted when the verdict is "not detected" -- no signal, no
    noise in the events table.
    """
    verdict = evaluate_injection_risk(
        record.literal_surface or "",
        tier,
        target_language=record.language or None,
    )
    if verdict.detected:
        kind_map = {
            "reject": "shield_rejection",
            "flag": "shield_flag",
            "log_allow": "shield_log",
        }
        event_kind = kind_map.get(verdict.action, "shield_log")
        # Clip matched patterns payload so the events table does not grow
        # unbounded on adversarial input.
        matched_clipped = [str(p)[:80] for p in verdict.matched_patterns[:10]]
        record_id = record.id
        source_ids: list[UUID] = []
        if isinstance(record_id, UUID):
            source_ids = [record_id]
        write_event(
            store,
            kind=event_kind,
            data={
                "record_id": str(record_id) if record_id is not None else None,
                "tier": verdict.tier.value,
                "matched": matched_clipped,
                "language": record.language,
                "action": verdict.action,
                "confidence": verdict.confidence,
            },
            severity=verdict.severity,
            session_id=session_id,
            source_ids=source_ids,
        )
    return verdict


__all__ = [
    "DRIFT_M4_ANOMALY_SIGMA",
    "SHIELD_FLAG_CONFIDENCE",
    "SHIELD_LANGUAGES_SUPPORTED",
    "SHIELD_SIGNAL_WORDS_MAX_CONFIDENCE",
    "SIGNAL_WORDS_CRITICAL_BY_LANG",
    "SIGNAL_WORDS_CRITICAL_EN",
    "SIGNAL_WORDS_WARNING_EN",
    "ShieldTier",
    "ShieldVerdict",
    "apply_shield",
    "evaluate_injection_risk",
]
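The tier policy of `evaluate_injection_risk` can be illustrated with a small self-contained sketch. The signal lists here are shortened stand-ins, and `Tier`, `ACTION_BY_TIER` and `evaluate` are hypothetical names chosen for illustration, not part of `shield.py`; the sketch only mirrors the substring matching and the tier-to-action mapping described in the module above.

```python
from enum import Enum

# Shortened stand-in lists (assumption: the real catalogues are longer).
CRITICAL = ["forget", "ignore previous", "you are now"]
WARNING = ["instead", "actually"]


class Tier(str, Enum):
    HARD_BLOCK = "hard_block"
    FLAG_FOR_REVIEW = "flag"
    LOG_ONLY = "log"


# (action, severity) applied per tier once at least one pattern has matched.
ACTION_BY_TIER = {
    Tier.HARD_BLOCK: ("reject", "critical"),
    Tier.FLAG_FOR_REVIEW: ("flag", "warning"),
    Tier.LOG_ONLY: ("log_allow", "info"),
}


def evaluate(text: str, tier: Tier) -> dict:
    """Hypothetical mini-shield: substring match first, then tier policy."""
    lowered = (text or "").lower()
    critical = [p for p in CRITICAL if p in lowered]
    warning = [p for p in WARNING if p in lowered]
    matched = critical + warning
    if not matched:
        return {"detected": False, "action": "log_allow", "severity": "info",
                "matched": [], "confidence": 0.0}
    action, severity = ACTION_BY_TIER[tier]
    return {
        "detected": True,
        "action": action,
        "severity": severity,
        "matched": matched,
        # 0.9 on any critical match, 0.6 when only warning patterns matched.
        "confidence": 0.9 if critical else 0.6,
    }


if __name__ == "__main__":
    # The same warning-only phrase is rejected at HARD_BLOCK
    # but only logged at LOG_ONLY.
    phrase = "Actually, use the other name"
    print(evaluate(phrase, Tier.HARD_BLOCK)["action"])
    print(evaluate(phrase, Tier.LOG_ONLY)["action"])
```

Because matching is by plain substring, a word such as "unforgettable" also matches the pattern "forget". This is one reason the action is decided per tier and not by the match alone.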
374
src/iai_mcp/sigma.py
Normal file
@ -0,0 +1,374 @@
"""Plan 03-02 CONN-07: small-world sigma as Ashby ultrastability diagnostic.

Ground-truth reference: Humphries MD, Gurney K (2008) "Network 'small-world-ness':
a quantitative method for determining canonical network equivalence."

Constitutional anchor:
- sigma is a CYBERNETIC DIAGNOSTIC (Ashby ultrastability), not a "RAG fallback".
- Cold-start sigma<1 at N<500 is a DEVELOPMENTAL phase, not pathological.
  Emit kind=sigma_observation phase=developmental + boost Hebbian rate.
- Mid-life drift sigma<1 at N>=500 emits kind=sigma_drift as an S4 event.
- sigma trajectory is published as a deep-time metric, NEVER a routing
  decision. No code path in this module switches retrieval modes on sigma.

Design discipline:
- DO NOT use NetworkX's built-in small-worldness function. NetworkX 3.6.1's
  built-in (niter=100, nrand=10) is empirically unusable at N>=200 (timed out
  at 60s+ during research session).
- Custom `fast_sigma` follows Humphries-Gurney 2008 directly with a small
  number of single-reference Erdos-Renyi random graphs (G(n, m), same edge
  count). Validated 0.05s @ N=200, 0.34s @ N=500, 1.28s @ N=1000.

Module-level constants:
- SIGMA_N_FLOOR = 200 -- D-SIGMA-01 floor (imports semantically from
  community.SMALL_N_FLAT -- same Humphries-Gurney 2008 floor).
- SIGMA_MID_LIFE_THRESHOLD = 500 -- D-SIGMA-03 mid-life regime threshold
  (imports semantically from community.MID_N_LEIDEN).

Public API:
- compute_sigma(graph, *, seed=42) -> Optional[float]
- fast_sigma(graph, *, n_random=3, seed=42) -> tuple[float, float, float, float, float]
- classify_regime(N, sigma) -> str
- compute_topology_snapshot(graph) -> dict
- compute_and_emit(store) -> dict
"""
from __future__ import annotations

import math
from datetime import datetime, timezone
from typing import Optional, TYPE_CHECKING

import networkx as nx

from iai_mcp.events import write_event

if TYPE_CHECKING:
    from iai_mcp.store import MemoryStore


# D-SIGMA-01: sigma is undefined below N=200 (Humphries-Gurney 2008 floor).
# Aliased semantically from community.SMALL_N_FLAT -- same constitutional floor.
SIGMA_N_FLOOR: int = 200

# D-SIGMA-03: mid-life vs developmental boundary (community.MID_N_LEIDEN).
SIGMA_MID_LIFE_THRESHOLD: int = 500

# Event kinds emitted by this module. Naming follows the snake_case
# noun_verb shape established in s4.py / s5.py.
SIGMA_OBSERVATION_KIND: str = "sigma_observation"
SIGMA_DRIFT_KIND: str = "sigma_drift"

# Hebbian rate boost applied during developmental phase (D-SIGMA-02).
HEBBIAN_DEVELOPMENTAL_BOOST_FACTOR: float = 1.3
HEBBIAN_DEVELOPMENTAL_BOOST_TTL_SESSIONS: int = 5

# Knob name we tag in profile_updated events when boosting the Hebbian rate
# during developmental phase. The 11-knob registry is NOT modified -- this is
# a transient operational tag, not an AUTIST kernel knob.
HEBBIAN_RATE_KNOB: str = "hebbian_rate"


def _largest_cc(graph: "nx.Graph") -> "nx.Graph":
    """Return the largest connected component as a copy.

    NetworkX raises on disconnected inputs to ``average_shortest_path_length``;
    take the largest CC up front so the rest of fast_sigma can stay simple.
    """
    if graph.number_of_nodes() == 0:
        return graph
    if nx.is_connected(graph):
        return graph
    largest = max(nx.connected_components(graph), key=len)
    return graph.subgraph(largest).copy()


def fast_sigma(
    graph: "nx.Graph",
    *,
    n_random: int = 3,
    seed: int = 42,
) -> tuple[float, float, float, float, float]:
    """Humphries-Gurney 2008 sigma via single-reference random graph(s).

    Returns ``(sigma, C, L, Cr, Lr)`` where:
    - sigma = (C / Cr) / (L / Lr)
    - C / L : clustering / characteristic path length on the input graph
    - Cr / Lr : same metrics averaged over ``n_random`` Erdos-Renyi G(n, m)
      reference graphs.

    DO NOT use NetworkX's built-in small-worldness function -- it is
    empirically unusable at N>=200 (>60s timeout).
    This implementation builds ONE G(n, m) reference per seed and averages
    the C and L values, NOT the library's full edge-rewiring loop.

    Pre-processing: when the input graph is disconnected, the largest
    connected component is taken first. NetworkX raises on disconnected
    inputs to ``average_shortest_path_length``.

    Notes
    -----
    - Returns NaN sigma when Cr or Lr collapses to zero (degenerate reference;
      shouldn't happen at our N>=200 floor but defensive).
    - Deterministic per ``seed`` -- the n_random reference graphs use
      ``seed, seed+1, ..., seed+n_random-1``.
    """
    g = _largest_cc(graph)
    n = g.number_of_nodes()
    m = g.number_of_edges()
    if n < 2 or m == 0:
        return (float("nan"), 0.0, 0.0, 0.0, 0.0)

    C = float(nx.average_clustering(g))
    L = float(nx.average_shortest_path_length(g))

    Cs: list[float] = []
    Ls: list[float] = []
    for k in range(max(1, n_random)):
        gr_full = nx.gnm_random_graph(n, m, seed=seed + k)
        # Same disconnected-graph guard for the reference.
        if not nx.is_connected(gr_full):
            largest = max(nx.connected_components(gr_full), key=len)
            gr = gr_full.subgraph(largest).copy()
        else:
            gr = gr_full
        if gr.number_of_nodes() < 2 or gr.number_of_edges() == 0:
            continue
        Cs.append(float(nx.average_clustering(gr)))
        Ls.append(float(nx.average_shortest_path_length(gr)))

    if not Cs or not Ls:
        return (float("nan"), C, L, 0.0, 0.0)

    Cr = sum(Cs) / len(Cs)
    Lr = sum(Ls) / len(Ls)
    if Cr <= 0 or Lr <= 0 or L <= 0:
        return (float("nan"), C, L, Cr, Lr)

    sigma_val = (C / Cr) / (L / Lr)
    return (sigma_val, C, L, Cr, Lr)


def compute_sigma(graph: "nx.Graph", *, seed: int = 42) -> Optional[float]:
    """D-SIGMA-01: sigma at N>=SIGMA_N_FLOOR; otherwise None.

    Returns None for graphs with fewer than SIGMA_N_FLOOR nodes -- below
    that threshold, the random-graph baselines are too noisy to interpret
    (Humphries-Gurney 2008).
    """
    if graph.number_of_nodes() < SIGMA_N_FLOOR:
        return None
    sigma_val, *_ = fast_sigma(graph, seed=seed)
    if isinstance(sigma_val, float) and math.isnan(sigma_val):
        return None
    return float(sigma_val)


def classify_regime(N: int, sigma: Optional[float]) -> str:
    """Four-cell regime truth table (D-SIGMA-02 / D-SIGMA-03).

    Returns one of:
    - "insufficient_data" : sigma is None or NaN (N < SIGMA_N_FLOOR, or a
      degenerate reference graph)
    - "developmental"     : N < SIGMA_MID_LIFE_THRESHOLD AND sigma < 1
    - "mid_life_drift"    : N >= SIGMA_MID_LIFE_THRESHOLD AND sigma < 1
    - "healthy"           : sigma >= 1 (any N >= floor)
    """
    if sigma is None:
        return "insufficient_data"
    if isinstance(sigma, float) and math.isnan(sigma):
        return "insufficient_data"
    if sigma < 1.0:
        if N < SIGMA_MID_LIFE_THRESHOLD:
            return "developmental"
        return "mid_life_drift"
    return "healthy"


def _coerce_to_nx_graph(graph_or_wrapper) -> "nx.Graph":
    """Accept either a raw nx.Graph or our MemoryGraph wrapper.

    MemoryGraph (src/iai_mcp/graph.py) carries the underlying nx.Graph as
    ``_nx``. The CLI passes a MemoryGraph; tests / fast_sigma also accept raw
    nx.Graph for portability.
    """
    if isinstance(graph_or_wrapper, nx.Graph):
        return graph_or_wrapper
    underlying = getattr(graph_or_wrapper, "_nx", None)
    if isinstance(underlying, nx.Graph):
        return underlying
    raise TypeError(
        f"expected nx.Graph or MemoryGraph wrapper, got {type(graph_or_wrapper).__name__}"
    )


def compute_topology_snapshot(graph) -> dict:
    """Snapshot dict consumed by the topology CLI subcommand.

    Returns: ``{C, L, sigma, community_count, rich_club_ratio, N, regime}``.

    - C : average clustering on the largest connected component.
    - L : average shortest path length on the largest CC.
    - sigma : compute_sigma(graph) (None if N < SIGMA_N_FLOOR).
    - community_count : Leiden community count (CONN-01 reuse via
      community.detect_communities); uses an isolated MemoryGraph wrapper.
    - rich_club_ratio : len(rich_club_nodes) / N (CONN-02 reuse).
    - N : node count.
    - regime : classify_regime(N, sigma).
    """
    nx_g = _coerce_to_nx_graph(graph)
    N = int(nx_g.number_of_nodes())

    if N == 0:
        return {
            "C": 0.0, "L": 0.0, "sigma": None,
            "community_count": 0, "rich_club_ratio": 0.0,
            "N": 0, "regime": "insufficient_data",
        }

    g_cc = _largest_cc(nx_g)
    try:
        C = float(nx.average_clustering(g_cc)) if g_cc.number_of_nodes() else 0.0
    except Exception:
        C = 0.0
    try:
        L = (
            float(nx.average_shortest_path_length(g_cc))
            if g_cc.number_of_nodes() >= 2 and g_cc.number_of_edges() > 0
            else 0.0
        )
    except Exception:
        L = 0.0

    sigma_val = compute_sigma(nx_g)

    # community_count + rich_club_ratio require the MemoryGraph wrapper.
    community_count = 0
    rich_club_ratio = 0.0
    try:
        from iai_mcp.community import detect_communities
        from iai_mcp.graph import MemoryGraph
        from iai_mcp.richclub import rich_club_nodes
        if isinstance(graph, MemoryGraph):
            mg = graph
        else:
            mg = None
        if mg is not None:
            try:
                assignment = detect_communities(mg, prior=None)
                community_count = int(len(assignment.community_centroids))
            except Exception:
                community_count = 0
            try:
                rc = rich_club_nodes(mg, percent=0.10)
                rich_club_ratio = (len(rc) / N) if N > 0 else 0.0
            except Exception:
                rich_club_ratio = 0.0
    except Exception:
        pass

    regime = classify_regime(N, sigma_val)
    return {
        "C": C,
        "L": L,
        "sigma": sigma_val,
        "community_count": community_count,
        "rich_club_ratio": rich_club_ratio,
        "N": N,
        "regime": regime,
    }


def _bump_hebbian_rate_developmental(store: "MemoryStore", N: int) -> None:
    """Emit a profile_updated event marking the Hebbian-rate boost.

    Per D-SIGMA-02 the developmental phase warrants a temporary
    Hebbian-rate boost. Rather than mutating the 11-knob AUTIST profile
    registry (which would violate len(PROFILE_KNOBS)==11), we record the
    intent as a profile_updated event with knob='hebbian_rate'. Downstream
    Hebbian write paths can read the most recent value and apply it.
    """
    write_event(
        store,
        kind="profile_updated",
        data={
            "knob": HEBBIAN_RATE_KNOB,
            "old": 1.0,
            "new": HEBBIAN_DEVELOPMENTAL_BOOST_FACTOR,
            "ttl_sessions": HEBBIAN_DEVELOPMENTAL_BOOST_TTL_SESSIONS,
            "reason": "sigma_developmental_phase",
            "N": N,
            "timestamp": datetime.now(timezone.utc).isoformat(),
        },
        severity="info",
    )


def compute_and_emit(store: "MemoryStore") -> dict:
    """S4 offline-pass entry point: build runtime graph, snapshot, emit event.

    Routes to the correct event kind based on the regime classification:
    - "developmental" -> kind=sigma_observation, data.phase="developmental",
      AND a profile_updated event for hebbian_rate boost.
    - "mid_life_drift" -> kind=sigma_drift, data with full snapshot.
    - "healthy" -> kind=sigma_observation, data.phase="healthy".
    - "insufficient_data" -> kind=sigma_observation, data.phase="insufficient_data".

    NEVER toggles retrieval modes (constitutional guard).
    """
    from iai_mcp import retrieve

    graph_bundle = retrieve.build_runtime_graph(store)
    # build_runtime_graph returns (graph, assignment, rich_club).
    if isinstance(graph_bundle, tuple):
        graph = graph_bundle[0]
    else:
        graph = graph_bundle

    snap = compute_topology_snapshot(graph)
    regime = snap.get("regime", "insufficient_data")

    base_data = {
        "sigma": snap.get("sigma"),
        "N": snap.get("N", 0),
        "C": snap.get("C", 0.0),
        "L": snap.get("L", 0.0),
        "community_count": snap.get("community_count", 0),
        "rich_club_ratio": snap.get("rich_club_ratio", 0.0),
        "regime": regime,
    }

    if regime == "mid_life_drift":
        write_event(
            store,
            kind=SIGMA_DRIFT_KIND,
            data={**base_data, "phase": "mid_life_drift"},
            severity="warning",
        )
    elif regime == "developmental":
        write_event(
            store,
            kind=SIGMA_OBSERVATION_KIND,
            data={**base_data, "phase": "developmental"},
            severity="info",
        )
        try:
            _bump_hebbian_rate_developmental(store, int(snap.get("N", 0)))
        except Exception:
            # Diagnostic only: never block the sigma observation on the
            # follow-up Hebbian boost.
            pass
    elif regime == "healthy":
        write_event(
            store,
            kind=SIGMA_OBSERVATION_KIND,
            data={**base_data, "phase": "healthy"},
            severity="info",
        )
    else:  # insufficient_data
        write_event(
            store,
            kind=SIGMA_OBSERVATION_KIND,
            data={**base_data, "phase": "insufficient_data"},
            severity="info",
        )

    return snap
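The sigma formula and the four-cell regime table used in `sigma.py` can be checked by hand with a dependency-free sketch. The function names `sigma_from_metrics`, `gated_sigma` and `regime`, and the clustering and path-length numbers in the usage lines, are hypothetical and chosen only for illustration; the thresholds mirror `SIGMA_N_FLOOR` and `SIGMA_MID_LIFE_THRESHOLD`.

```python
import math
from typing import Optional

N_FLOOR = 200   # mirrors SIGMA_N_FLOOR
MID_LIFE = 500  # mirrors SIGMA_MID_LIFE_THRESHOLD


def sigma_from_metrics(C: float, L: float, Cr: float, Lr: float) -> float:
    """sigma = (C / Cr) / (L / Lr); NaN when a denominator is not positive."""
    if Cr <= 0 or Lr <= 0 or L <= 0:
        return float("nan")
    return (C / Cr) / (L / Lr)


def gated_sigma(n: int, C: float, L: float, Cr: float, Lr: float) -> Optional[float]:
    """Return None below the node floor or on a degenerate (NaN) result."""
    if n < N_FLOOR:
        return None
    value = sigma_from_metrics(C, L, Cr, Lr)
    return None if math.isnan(value) else value


def regime(n: int, sigma: Optional[float]) -> str:
    """Four-cell table: insufficient_data, developmental, mid_life_drift, healthy."""
    if sigma is None or math.isnan(sigma):
        return "insufficient_data"
    if sigma < 1.0:
        return "developmental" if n < MID_LIFE else "mid_life_drift"
    return "healthy"


if __name__ == "__main__":
    # Hypothetical metrics: clustering ten times the random baseline,
    # path length slightly longer than the baseline.
    s = gated_sigma(300, C=0.30, L=3.0, Cr=0.03, Lr=2.5)
    print(s, regime(300, s))
    # The same sigma below 1 is read differently on either side of N=500.
    print(regime(300, 0.8), regime(800, 0.8))
```

A graph that is much more clustered than its random reference while keeping a comparable path length gives sigma well above 1, which is the "healthy" cell regardless of N once the floor is met.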
610
src/iai_mcp/sleep.py
Normal file
@ -0,0 +1,610 @@
"""CLS sleep-cycle replay (MEM-07, D-16, D-19, D-29).

Two phases (dual-tier per D-16):

- `run_light_consolidation` -- runs at every session_exit. Pure-local. NO LLM.
  FSRS tick on recently-recalled records. Sub-second. Always on.

- `run_heavy_consolidation` -- runs inside quiet window OR via MANUAL trigger
  (memory_consolidate MCP tool). D-GUARD ladder gates any Tier-1 LLM path via
  `should_call_llm`; Tier-0 fallback is ALWAYS present (TF-IDF + cooccurrence
  summarisation). Creates `consolidated_from` edges linking semantic summary
  records to their source episodes. Runs FSRS edge decay sweep. Logs
  `cls_consolidation_run` event with mode=heavy, tier=tier0|tier1.

D-16 scheduler (`should_run_heavy`):
- ACTIVITY (default): idle>=30min AND local time in quiet_window.
- TIME: strict cron at hour==3 local.
- MANUAL: never fires automatically.
- 48h max defer: if idle >= max_defer_hours, force-run regardless of window.

D-19 decay sweep (`_decay_edges`):
- Only hebbian and hebbian_structure edges are decayed. contradicts /
  invariant_anchor / consolidated_from / schema_instance_of / temporal_next /
  curiosity_bridge / profile_modulates all survive forever (by design).
- Edges > 90d stale: weight *= 0.9 ** (days - 90); prune if < ε (default 0.01).

D-29 unification: heavy cycle drives FSRS decay + CLS summarisation +
schema-candidate surfacing in a single pass -- no duplicated IO.
"""
from __future__ import annotations

import json
from dataclasses import dataclass
from datetime import datetime, timedelta, timezone
from enum import Enum
from itertools import combinations
from uuid import UUID, uuid4
from zoneinfo import ZoneInfo

from iai_mcp.aaak import enforce_language_tagged, generate_aaak_index
from iai_mcp.events import write_event
from iai_mcp.guard import BudgetLedger, RateLimitLedger, should_call_llm
from iai_mcp.store import EDGES_TABLE, MemoryStore, _uuid_literal
from iai_mcp.types import MemoryRecord


# ---------------------------------------------------------------- constants


class SleepMode(str, Enum):
    """D-16 trigger mode for heavy consolidation."""

    ACTIVITY = "activity"  # Idle-triggered (default). 30min idle + quiet window.
    TIME = "time"  # Strict cron at hour==3 local.
    MANUAL = "manual"  # Only via memory_consolidate tool.


@dataclass
class SleepConfig:
    """User-configurable sleep-cycle schedule knobs (D-16)."""

    mode: SleepMode = SleepMode.ACTIVITY
    quiet_window: tuple[int, int] = (22, 6)  # local-hour start..end (wrap-around)
    require_idle_minutes: int = 30
    max_defer_hours: int = 48
    on_user_resume: str = "defer_remaining"
    light_on_exit: bool = True
    llm_enabled: bool = False  # Tier 0 default -- D-GUARD ladder step 1
    llm_tier: int = 1  # 1=Haiku-Batch, 2=Sonnet/Opus


DECAY_EPSILON: float = 0.01  # prune threshold
DECAY_GRACE_DAYS: int = 90  # no decay for edges <=90d old
DECAY_BASE: float = 0.9  # weight *= 0.9^(days-90)
FSRS_STABILITY_BOOST: float = 0.2  # simple per-recall linear boost
CLUSTER_MIN_SIZE: int = 3  # CLS cluster threshold
# H-03: Hebbian LTP increment applied to existing edges between
# co-cluster members during heavy consolidation. Mirrors the LTD side (DECAY_*)
# so the graph strengthens frequently-co-retrieved associations during sleep,
# not only during explicit user-session pipeline_recall. Conservative delta --
# 10 consolidations bring a fresh edge from 0.05 to ~0.5 stable.
HEAVY_LTP_DELTA: float = 0.05


# ---------------------------------------------------------------- scheduler


def should_run_heavy(
    now_utc: datetime,
    last_activity_utc: datetime,
    config: SleepConfig,
    tz: ZoneInfo,
) -> tuple[bool, str]:
    """D-16 trigger evaluator.

    Returns (ok, reason). reason is a short diagnostic; it is "" only when
    the ACTIVITY path succeeds. The 48h force-run and TIME mode return a
    non-empty reason even when ok is True.

    The 48h deadline (config.max_defer_hours) overrides MANUAL, TIME, and
    ACTIVITY path-gates -- if the user has ignored the brain for 48h, we MUST
    consolidate before the next session starts. This is a cybernetic S4
    viability requirement (Beer VSM + Ashby ultrastability).
    """
    idle_minutes = (now_utc - last_activity_utc).total_seconds() / 60.0

    # 48h force-run. Precedes MANUAL so a stuck manual-only deployment still
    # gets periodic consolidation.
    if idle_minutes >= config.max_defer_hours * 60:
        return True, f"max_defer_hours ({config.max_defer_hours}h) exceeded"

    if config.mode == SleepMode.MANUAL:
        return False, "manual-only mode"

    if config.mode == SleepMode.TIME:
        local = now_utc.astimezone(tz)
        ok = local.hour == 3
        return ok, f"TIME mode, local hour={local.hour}"

    # ACTIVITY mode from here on.
    if idle_minutes < config.require_idle_minutes:
        return False, f"idle < {config.require_idle_minutes}min"

    local = now_utc.astimezone(tz)
    start_h, end_h = config.quiet_window
    # Wrap-around window support: (22, 6) means 22-23 OR 0-5.
    if start_h > end_h:
        in_window = (local.hour >= start_h) or (local.hour < end_h)
    else:
        in_window = start_h <= local.hour < end_h
    if not in_window:
        return False, (
            f"outside quiet window {config.quiet_window}, "
            f"local hour={local.hour}"
        )
    return True, ""


# ---------------------------------------------------------------- FSRS bits


def _apply_fsrs(record: MemoryRecord, now: datetime) -> MemoryRecord:
    """Simple FSRS-inspired stability boost for recently-recalled records.

    scope: linear +0.2 per recall, capped at 1.0. Full FSRS (Woz et al
    2022) with per-difficulty retrievability modelling is Phase 3.
    """
    if record.never_decay:
        return record
    record.stability = min(1.0, record.stability + FSRS_STABILITY_BOOST)
    record.last_reviewed = now
    return record


def _decay_edges(
    store: MemoryStore, epsilon: float = DECAY_EPSILON,
) -> dict:
    """D-19 nightly sweep: decay stale hebbian + hebbian_structure edges, prune below epsilon.

    CONN-05 D-TEM-04 extension: structure-edge LTP from
    hebbian_structure.strengthen_structure_edge decays under the SAME formula
    and grace period as content-edge hebbian (constitutional contract: FSRS
    decay on structure edges is IDENTICAL to record-edge decay).

    Other edge types (contradicts, invariant_anchor, consolidated_from,
    schema_instance_of, temporal_next, curiosity_bridge, profile_modulates)
    survive forever.
    """
    tbl = store.db.open_table(EDGES_TABLE)
    df = tbl.to_pandas()
    if df.empty:
        return {"decayed": 0, "pruned": 0}

    now = datetime.now(timezone.utc)
    decayed = 0
    pruned = 0

    # Include hebbian_structure in the sweep with identical formula.
    decayable_kinds = ("hebbian", "hebbian_structure")
    hebbian = df[df["edge_type"].isin(decayable_kinds)]
    for _, row in hebbian.iterrows():
        # CR-01: per-row try/except ValueError so one poisoned row
        # cannot kill the entire sweep. _uuid_literal raises ValueError on any
        # non-RFC-4122 UUID string, preventing SQL predicate injection via a
        # corrupt or adversarial `src`/`dst` value.
        try:
            last = row["updated_at"]
            if last is None:
                continue
            # Coerce naive -> UTC; pandas may drop tz on some backends.
            try:
                py = last.to_pydatetime() if hasattr(last, "to_pydatetime") else last
            except Exception:
                py = last
            if getattr(py, "tzinfo", None) is None:
                py = py.replace(tzinfo=timezone.utc)

            days = (now - py).total_seconds() / 86400.0
            if days <= DECAY_GRACE_DAYS:
                continue

            new_weight = float(row["weight"]) * (DECAY_BASE ** (days - DECAY_GRACE_DAYS))

            # CR-01 fix: reject non-canonical UUID values BEFORE interpolation.
            src_lit = _uuid_literal(row["src"])
            dst_lit = _uuid_literal(row["dst"])
            edge_kind = str(row["edge_type"])
            if edge_kind not in decayable_kinds:
                # Belt-and-braces: should not happen given the .isin() above.
                continue
            if new_weight < epsilon:
                tbl.delete(
                    f"src = '{src_lit}' AND dst = '{dst_lit}' "
                    f"AND edge_type = '{edge_kind}'"
                )
                pruned += 1
            else:
                tbl.update(
                    where=(
                        f"src = '{src_lit}' AND dst = '{dst_lit}' "
                        f"AND edge_type = '{edge_kind}'"
                    ),
                    values={
                        "weight": float(new_weight),
                        "updated_at": now,
                    },
                )
                decayed += 1
        except ValueError:
            # Poisoned UUID shape -- skip this row, continue the sweep.
            continue

    return {"decayed": decayed, "pruned": pruned}


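The decay rule used by `_decay_edges` can be sketched in isolation as follows. The three constants are illustrative assumptions (the module's real `DECAY_BASE`, `DECAY_GRACE_DAYS` and `DECAY_EPSILON` values are not shown in this diff), and `decayed_weight` and `classify` are hypothetical helper names.

```python
# Assumed constants for illustration only; the real values are not shown here.
DECAY_BASE = 0.95
DECAY_GRACE_DAYS = 7.0
DECAY_EPSILON = 0.05


def decayed_weight(weight: float, days_since_update: float) -> float:
    """Return the weight after grace-period-adjusted exponential decay."""
    if days_since_update <= DECAY_GRACE_DAYS:
        return weight  # inside the grace period: untouched
    return weight * (DECAY_BASE ** (days_since_update - DECAY_GRACE_DAYS))


def classify(weight: float, days_since_update: float) -> str:
    """Mirror the sweep's three outcomes: skip, decay (update), or prune (delete)."""
    if days_since_update <= DECAY_GRACE_DAYS:
        return "skipped"
    new_weight = decayed_weight(weight, days_since_update)
    return "pruned" if new_weight < DECAY_EPSILON else "decayed"


fresh = classify(1.0, 3.0)      # still inside the grace period
strong = classify(1.0, 17.0)    # 10 days past grace, weight stays above epsilon
weak = classify(0.06, 17.0)     # 10 days past grace, weight drops below epsilon
```

Only the days beyond the grace period enter the exponent, so an edge that is exactly at the grace boundary keeps its full weight.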
# ---------------------------------------------------------------- light phase


def run_light_consolidation(
    store: MemoryStore, session_id: str,
) -> dict:
    """D-16 light phase -- always on, pure local, no LLM.

    Runs at every session_exit. Nudges FSRS stability on records that were
    recalled in this session (identified by a fresh provenance entry within the
    last hour). Writes one `cls_consolidation_run` event with mode=light.
    """
    now = datetime.now(timezone.utc)
    records = store.all_records()
    fsrs_ticked = 0

    for r in records:
        if r.never_decay:
            continue
        if not r.provenance:
            continue
        last_prov = r.provenance[-1]
        try:
            ts_str = last_prov.get("ts", "")
            prov_ts = datetime.fromisoformat(ts_str.replace("Z", "+00:00"))
            if prov_ts.tzinfo is None:
                prov_ts = prov_ts.replace(tzinfo=timezone.utc)
            # Only tick records recalled within the last hour.
            if (now - prov_ts).total_seconds() < 3600:
                _apply_fsrs(r, now)
                # H-01 fix: persist the FSRS mutation so stability
                # and last_reviewed survive process restart. update_record
                # rewrites only the FSRS-relevant columns -- embedding,
                # provenance, tags etc. are left intact.
                store.update_record(r)
                fsrs_ticked += 1
        except Exception:
            # Provenance ts malformed -- ignore that record, don't fail the sweep.
            continue

    write_event(
        store,
        kind="cls_consolidation_run",
        data={
            "mode": "light",
            "fsrs_ticked": fsrs_ticked,
            "record_count": len(records),
        },
        severity="info",
        session_id=session_id,
    )
    return {
        "mode": "light",
        "fsrs_ticked": fsrs_ticked,
        "cooccurrence_updates": 0,  # placeholder; real cooccurrence counts are not populated here.
    }


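The "recalled within the last hour" test in the light phase depends on tolerant timestamp parsing. The sketch below isolates that step; `parse_provenance_ts` and `recalled_recently` are hypothetical helper names, and the 3600-second window is taken from the loop above.

```python
from datetime import datetime, timezone


def parse_provenance_ts(ts_str: str) -> datetime:
    """Parse an ISO-8601 string, treating a trailing 'Z' and naive values as UTC."""
    # datetime.fromisoformat only accepts a 'Z' suffix from Python 3.11 on,
    # so it is rewritten to an explicit offset first.
    ts = datetime.fromisoformat(ts_str.replace("Z", "+00:00"))
    if ts.tzinfo is None:
        ts = ts.replace(tzinfo=timezone.utc)
    return ts


def recalled_recently(ts_str: str, now: datetime, window_s: float = 3600) -> bool:
    """True when the timestamp is less than window_s seconds older than now."""
    try:
        return (now - parse_provenance_ts(ts_str)).total_seconds() < window_s
    except (ValueError, AttributeError):
        # Malformed or missing timestamp: treat as "not recalled recently".
        return False


now = datetime(2024, 1, 1, 12, 0, tzinfo=timezone.utc)
recent = recalled_recently("2024-01-01T11:30:00Z", now)
stale = recalled_recently("2024-01-01T10:00:00", now)   # naive, read as UTC
broken = recalled_recently("not-a-date", now)
```

The sketch catches only the two exception types a bad string or a non-string value can raise, whereas the loop above catches `Exception` broadly so that a failure in `store.update_record` also skips the record.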
# ---------------------------------------------------------------- heavy phase


def _build_hebbian_clusters(store: MemoryStore) -> list[list[UUID]]:
    """Find connected components in the hebbian edge graph with size >= CLUSTER_MIN_SIZE."""
    edges_df = store.db.open_table(EDGES_TABLE).to_pandas()
    if edges_df.empty:
        return []
    hebbian = edges_df[edges_df["edge_type"] == "hebbian"]
    if hebbian.empty:
        return []

    adj: dict[UUID, set[UUID]] = {}
    for _, row in hebbian.iterrows():
        src = UUID(row["src"])
        dst = UUID(row["dst"])
        adj.setdefault(src, set()).add(dst)
        adj.setdefault(dst, set()).add(src)

    visited: set[UUID] = set()
    clusters: list[list[UUID]] = []
    for node in list(adj.keys()):
        if node in visited:
            continue
        stack = [node]
        component: list[UUID] = []
        while stack:
            cur = stack.pop()
            if cur in visited:
                continue
            visited.add(cur)
            component.append(cur)
            for neigh in adj.get(cur, set()):
                if neigh not in visited:
                    stack.append(neigh)
        if len(component) >= CLUSTER_MIN_SIZE:
            clusters.append(component)
    return clusters


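The cluster search in `_build_hebbian_clusters` is an iterative depth-first search over an undirected adjacency map. A store-free sketch of the same traversal is shown here; `connected_components` is a hypothetical name, the edge list replaces the edges table, and the threshold of 3 for `CLUSTER_MIN_SIZE` is an assumption.

```python
from uuid import UUID, uuid4

CLUSTER_MIN_SIZE = 3  # assumed threshold for this sketch


def connected_components(
    edges: list[tuple[UUID, UUID]], min_size: int
) -> list[list[UUID]]:
    # Build an undirected adjacency map: every edge is registered both ways.
    adj: dict[UUID, set[UUID]] = {}
    for src, dst in edges:
        adj.setdefault(src, set()).add(dst)
        adj.setdefault(dst, set()).add(src)

    visited: set[UUID] = set()
    clusters: list[list[UUID]] = []
    for node in adj:
        if node in visited:
            continue
        # Iterative DFS with an explicit stack, so deep graphs cannot hit
        # the recursion limit.
        stack, component = [node], []
        while stack:
            cur = stack.pop()
            if cur in visited:
                continue
            visited.add(cur)
            component.append(cur)
            stack.extend(n for n in adj[cur] if n not in visited)
        # Components below the size threshold are dropped.
        if len(component) >= min_size:
            clusters.append(component)
    return clusters


a, b, c, d, e = (uuid4() for _ in range(5))
clusters = connected_components([(a, b), (b, c), (d, e)], CLUSTER_MIN_SIZE)
```

Here the chain `a-b-c` forms one cluster of size three, while the pair `d-e` is discarded for being below the threshold.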
def _tier0_schema_surfacing(store: MemoryStore) -> list[dict]:
    """Tier-0 fallback schema candidate surfacing: tags appearing in >=3 records.

    Plan 02-03's LEARN-03 schema induction consumes these candidates.

    W3: rewritten on ``store.iter_record_columns(["tags_json"])``.
    No more full-store load + full-record decrypt -- only the ``tags_json`` column
    is read from disk; encrypted columns (literal_surface, provenance_json,
    profile_modulation_gain_json) are NEVER touched on this path. Saves ~16210
    AES-GCM operations + ~14.5 MB literal_surface materialisation + ~2.4 MB
    provenance_json materialisation + ~11.9 MB embedding materialisation per
    invocation on a production-scale store.
    """
    tag_counts: dict[str, int] = {}
    record_count = 0
    for row in store.iter_record_columns(["tags_json"], batch_size=1024):
        record_count += 1
        tags_raw = row.get("tags_json") or "[]"
        try:
            tags = json.loads(tags_raw) if tags_raw else []
        except (TypeError, json.JSONDecodeError):
            tags = []
        for t in tags:
            # Skip language-qualifying raw:* and domain:* tags -- those are
            # classification metadata, not schema-candidate signals.
            if t.startswith("raw:") or t.startswith("domain:"):
                continue
            tag_counts[t] = tag_counts.get(t, 0) + 1
    if record_count < CLUSTER_MIN_SIZE:
        return []
    candidates: list[dict] = []
    for tag, count in tag_counts.items():
        if count >= 3:
            candidates.append(
                {
                    "pattern": f"tag:{tag}",
                    "confidence": min(1.0, count / 10.0),
                    "evidence_count": count,
                }
            )
    return candidates


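The tag-frequency heuristic in `_tier0_schema_surfacing` can be tried without a store. This sketch swaps the hand-rolled dict for `collections.Counter` and takes the raw `tags_json` strings as a plain list; `surface_tag_candidates` is a hypothetical name, and the sketch leaves out the `record_count < CLUSTER_MIN_SIZE` early return.

```python
import json
from collections import Counter


def surface_tag_candidates(tags_json_rows: list[str], min_count: int = 3) -> list[dict]:
    counts: Counter = Counter()
    for raw in tags_json_rows:
        try:
            tags = json.loads(raw) if raw else []
        except (TypeError, json.JSONDecodeError):
            tags = []  # malformed JSON counts as "no tags"
        if not isinstance(tags, list):
            tags = []  # valid JSON that is not a list is ignored as well
        # raw:* and domain:* are classification metadata, not schema signals.
        counts.update(
            t for t in tags
            if isinstance(t, str) and not t.startswith(("raw:", "domain:"))
        )
    return [
        {
            "pattern": f"tag:{tag}",
            "confidence": min(1.0, n / 10.0),
            "evidence_count": n,
        }
        for tag, n in counts.items()
        if n >= min_count
    ]


rows = ['["python", "raw:en"]', '["python", "testing"]', '["python"]', "not json", ""]
candidates = surface_tag_candidates(rows)
```

With these rows only `python` reaches the threshold of three occurrences; `testing` appears once and `raw:en` is filtered out before counting. Confidence grows linearly with the count and saturates at ten occurrences.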
def _create_semantic_summary(
    store: MemoryStore,
    cluster: list[MemoryRecord],
    summary_text: str,
    language: str,
) -> UUID:
    """Insert one semantic summary record + a consolidated_from edge to each source.

    summary inherits dominant language of the source cluster.
    detail_level=3 -> never_decay=True (auto-enforced by __post_init__).
    """
    # Lazy import -- embedder load is heavy; only needed when we actually summarise.
    from iai_mcp.embed import embedder_for_store

    emb = embedder_for_store(store).embed(summary_text)
    now = datetime.now(timezone.utc)
    summary_id = uuid4()
    summary = MemoryRecord(
        id=summary_id,
        tier="semantic",
        literal_surface=summary_text,
        aaak_index="",
        embedding=emb,
        community_id=None,
        centrality=0.0,
        detail_level=3,  # semantic summaries protected from decay
        pinned=False,
        stability=0.5,
        difficulty=0.3,
        last_reviewed=now,
        never_decay=True,
        never_merge=False,
        provenance=[
            {
                "ts": now.isoformat(),
                "cue": "cls_consolidation",
                "session_id": "system",
            }
        ],
        created_at=now,
        updated_at=now,
        tags=["semantic", "cls_summary"],
        language=language,
    )
    enforce_language_tagged(summary, detect=False)
    summary.aaak_index = generate_aaak_index(summary)
    store.insert(summary)

    # R3: batch all consolidated_from edges into a single
    # boost_edges call (one merge_insert + one tbl.add at most). Previously
    # this loop emitted N Lance versions on edges.lance for an N-source
    # cluster.
    pairs = [(summary_id, source.id) for source in cluster]
    if pairs:
        store.boost_edges(
            pairs,
            edge_type="consolidated_from",
            delta=1.0,
        )
    return summary_id


def run_heavy_consolidation(
    store: MemoryStore,
    session_id: str,
    config: SleepConfig,
    budget: BudgetLedger,
    rate: RateLimitLedger,
    has_api_key: bool = False,
) -> dict:
    """D-16 heavy phase -- cluster-find, summarise, decay-sweep, schema-surface.

    D-GUARD: the Tier-1 gate is consulted at the top of the function. If
    `should_call_llm` returns False for any reason (llm_enabled=false, no API
    key, budget exceeded, ratelimit cooldown), the entire cycle falls back to
    Tier 0 -- local heuristic summarisation, zero network I/O. This is the
    constitutional guarantee (D-GUARD): every LLM-dependent path
    must degrade gracefully.

    Returns a dict with:
        mode: "heavy"
        tier: "tier0" | "tier1"
        summaries_created: int
        decay_result: {"decayed": int, "pruned": int}
        schema_candidates: list[dict]
        schemas_induced: int
    """
    now = datetime.now(timezone.utc)

    # Step 1: FSRS edge decay sweep (runs regardless of tier).
    decay_result = _decay_edges(store)

    # Step 2: Decide Tier 0 vs Tier 1. The gate is consulted BEFORE any API
    # call; even if Tier 1 is allowed, Plan 02-02's scope is Tier 0
    # summarisation only, and the actual Haiku Batch API call is added
    # separately. The gate is here so the event log reflects what WOULD have
    # happened had Tier 1 been implemented.
    llm_ok, _llm_reason = should_call_llm(
        budget=budget,
        rate=rate,
        llm_enabled=config.llm_enabled,
        has_api_key=has_api_key,
    )
    tier = "tier1" if llm_ok else "tier0"
    # The Tier-1 switch is flipped by wiring the Batch API. The gate is
    # re-checked inside batch.submit_batch_consolidation so event ordering
    # matches prior plans. Tier-0 fallback remains unchanged.
    effective_tier = "tier0"
    batch_submitted = False
    if llm_ok:
        try:
            from iai_mcp.batch import submit_batch_consolidation

            # Summarise the workload before submission. Scope: the real
            # cluster/schema task payload is populated post-hoc by Phase 3;
            # for now we submit placeholder tasks so the D-GUARD side-effects
            # (budget spend + events) fire on the correct path.
            tasks: list[dict] = [
                {
                    "task_id": f"sleep_cycle:{session_id}",
                    "prompt": "CLS consolidation batch",
                    "prompt_tok": 500,
                    "output_tok": 200,
                }
            ]
            ok_batch, _reason_batch, _results = submit_batch_consolidation(
                store, tasks, budget, rate,
                llm_enabled=config.llm_enabled,
            )
            if ok_batch:
                effective_tier = "tier1"
                batch_submitted = True
        except Exception as _exc:
            # Never block the Tier-0 fallback on batch errors.
            effective_tier = "tier0"

    # Step 3: cluster-find + summarise.
    clusters = _build_hebbian_clusters(store)
    # Phase 07.7-04 W4 (D-13/D-14/D-20 + amendment): single-materialisation
    # invariant. After Plan 07.7-03 W3 rewrites _tier0_schema_surfacing on
    # iter_record_columns and Plan 07.7-04 D-26-A/B migrate schema.py
    # induce_schemas_tier0 + persist_schema to iter_record_columns, this is
    # the ONLY all_records() call left inside run_heavy_consolidation. The
    # cluster-lookup primitive choice (switch this site to iter_records or
    # per-id store.get) is DEFERRED with the rest of W6 (D-20 deferred).
    # Regression test:
    # tests/test_sleep_consolidation_streaming.py
    # ::test_run_heavy_consolidation_calls_all_records_at_most_once
    records_by_id = {r.id: r for r in store.all_records()}
    summaries_created = 0
    for cluster_ids in clusters:
        cluster_recs = [records_by_id[i] for i in cluster_ids if i in records_by_id]
        if len(cluster_recs) < CLUSTER_MIN_SIZE:
            continue
        # Dominant language vote among cluster members.
        langs = [r.language for r in cluster_recs if r.language]
        dom_lang = max(set(langs), key=langs.count) if langs else "en"
        # Tier-0 summary format: concatenated prefixes of cluster literals,
        # capped at 80 chars each + 5 members -- keeps the summary short and
        # keeps promises clean (summary is NEW content, sources intact).
        summary_text = (
            f"Cluster summary ({len(cluster_recs)} records, lang={dom_lang}): "
            + "; ".join(r.literal_surface[:80] for r in cluster_recs[:5])
        )
        _create_semantic_summary(store, cluster_recs, summary_text, dom_lang)
        summaries_created += 1

        # H-03: Hebbian LTP -- strengthen existing hebbian edges
        # between co-cluster members. Mirrors the LTD (_decay_edges) side so
        # the graph is not one-sided. Matches Woz 2022 SRS reinforcement on
        # co-retrieval. O(k^2) per cluster where k = cluster size; bounded by
        # the connected-components partition of hebbian adjacency.
        pairs_to_boost = list(combinations(cluster_ids, 2))
        if pairs_to_boost:
            store.boost_edges(
                pairs_to_boost,
                delta=HEAVY_LTP_DELTA,
                edge_type="hebbian",
            )

    # Step 4: Tier-0 schema candidate surfacing.
    schemas = _tier0_schema_surfacing(store)

    # Step 4b (Plan 02-03 LEARN-03 primary): schema induction batch run.
    # Tier-1 attempts the Haiku path via D-GUARD ladder; falls back to tier0.
    # Auto-status candidates are persisted (creating schema_instance_of edges).
    schemas_induced = 0
    try:
        from iai_mcp.schema import (
            induce_schemas_tier1,
            persist_schema,
        )

        candidates = induce_schemas_tier1(
            store, budget=budget, rate=rate,
            llm_enabled=config.llm_enabled,
        )
        for cand in candidates:
            if cand.status == "auto":
                persist_schema(store, cand)
                schemas_induced += 1
            # pending_user_approval candidates are only logged (via
            # induce_schemas_tier1's llm_health emission path).
    except Exception as exc:
        write_event(
            store,
            kind="schema_induction_run",
            data={"error": str(exc), "status": "failed"},
            severity="warning",
            session_id=session_id,
        )

    write_event(
        store,
        kind="cls_consolidation_run",
        data={
            "mode": "heavy",
            "tier": effective_tier,
            "tier_eligible": tier,
            "summaries_created": summaries_created,
            "decay_result": decay_result,
            "schema_candidates": len(schemas),
            "schemas_induced": schemas_induced,
            "batch_submitted": batch_submitted,
        },
        severity="info",
        session_id=session_id,
    )

    return {
        "mode": "heavy",
        "tier": effective_tier,
        "summaries_created": summaries_created,
        "decay_result": decay_result,
        "schema_candidates": schemas,
        "schemas_induced": schemas_induced,
    }
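Two small steps inside the cluster loop of `run_heavy_consolidation` are worth seeing on their own: the dominant-language vote and the pairwise LTP boost list. The sketch below is illustrative only. `dominant_language` and `ltp_pairs` are hypothetical names, and the vote uses `collections.Counter` instead of the `max(set(langs), key=langs.count)` expression in the function, so ties may resolve differently.

```python
from collections import Counter
from itertools import combinations


def dominant_language(langs: list, default: str = "en") -> str:
    """Most frequent non-empty language code, or the default when none is set."""
    present = [lang for lang in langs if lang]
    if not present:
        return default
    # most_common(1) returns [(value, count)] for the top entry.
    return Counter(present).most_common(1)[0][0]


def ltp_pairs(cluster_ids: list) -> list:
    """All unordered pairs of cluster members: k * (k - 1) / 2 pairs for k ids."""
    return list(combinations(cluster_ids, 2))


lang = dominant_language(["en", "ru", "en", None])
fallback = dominant_language([None, ""])
pairs = ltp_pairs(["a", "b", "c", "d"])
```

A cluster of four members yields six pairs, which is the quadratic growth that the H-03 comment refers to.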