Merge remote-tracking branch 'upstream/dev' into fix/backend-tests

This commit is contained in:
Anish Sarkar 2026-05-16 19:40:01 +05:30
commit 8de7d86d56
603 changed files with 45074 additions and 4695 deletions

View file

@ -1,48 +1,48 @@
# Backend E2E Test Harness
# Backend E2E Harness
Strict fakes + alternative entrypoints used **only** by Playwright E2E.
Excluded from the production Docker image via `.dockerignore`.
This directory contains the test-only backend entrypoints and fakes used by
Playwright. They are not part of the production image: `.dockerignore` excludes
`tests/`, and the E2E Docker stage copies this directory through a separate
build context.
## Files
| Path | Role |
| -------------------------------- | ------------------------------------------------------------------------------- |
| `run_backend.py` | FastAPI entrypoint that hijacks `sys.modules` before importing `app.app:app` |
| `run_celery.py` | Celery worker entrypoint with the same hijack + patch logic |
| `middleware/scenario.py` | `X-E2E-Scenario` header → ContextVar (read by fakes) |
| `fakes/composio_module.py` | Strict drop-in for the `composio` package; raises on unknown surface |
| `fakes/llm.py` | `fake_get_user_long_context_llm` returning a `FakeListChatModel` |
| `fakes/embeddings.py` | Deterministic 0.1-vector `embed_text` / `embed_texts` |
| `fakes/fixtures/drive_files.json`| Canned Drive listings + file contents (incl. canary tokens) |
| Path | Purpose |
| --- | --- |
| `run_backend.py` | Starts FastAPI after installing the test fakes into `sys.modules`. |
| `run_celery.py` | Starts the Celery worker with the same fake setup. |
| `middleware/scenario.py` | Reads `X-E2E-Scenario` into a request-scoped context var. |
| `fakes/composio_module.py` | Fake `composio` package used by connector flows. |
| `fakes/llm.py` | Fake chat model factory. |
| `fakes/embeddings.py` | Deterministic embedding helpers. |
| `fakes/fixtures/drive_files.json` | Drive fixture data and canary file contents. |
## Why a sys.modules hijack?
## Why the import hook exists
Production code does `from composio import Composio` at module load
time. By the time the FastAPI app object exists, that binding has
already been resolved. The hijack runs **before** any `app.*` import,
so the binding resolves to our strict fake. No production source
changes; fakes are physically excluded from production images.
Some production modules import SDK clients at module load time, for example
`from composio import Composio`. By the time `app.app` has been imported, those
bindings are already fixed.
Belt + suspenders + no internet: the strict `__getattr__` in every
fake raises `NotImplementedError` if a future production code path
introduces a new SDK call. CI also sets `HTTPS_PROXY=http://127.0.0.1:1`
plus sentinel API keys so any leaked outbound HTTP fails immediately.
The E2E entrypoints install fake modules in `sys.modules` before importing any
`app.*` module. That lets the normal production code run while SDK calls resolve
to local fakes.
## Adding a new fake
The fakes should fail loudly. If production starts using a new SDK method that
the fake does not implement, add that method to the fake instead of letting the
test call the real service.
1. Create `fakes/<sdk>_module.py` modelled on `composio_module.py`.
2. In `run_backend.py` and `run_celery.py`, register
`sys.modules["<sdk>"] = _fake_<sdk>` before the `from app.app import app`
line.
3. If the new fake needs scenario branching, read from
## Adding a fake
1. Add `fakes/<sdk>_module.py`.
2. Register it in both `run_backend.py` and `run_celery.py` before importing
`app.app` or `app.celery_app`.
3. If the fake needs per-test behavior, read the current scenario from
`tests.e2e.middleware.scenario.current_scenario()`.
## Reused by backend integration tests
## Shared with backend integration tests
The strict fakes are not only for Playwright. Backend route integration
tests can import the same fake before importing `app.app`, so Composio
route tests exercise production route code without touching the real
SDK:
Backend integration tests can use the same fakes when they need production route
code without the real SDK:
```python
from tests.e2e.fakes import composio_module as _fake_composio
@ -50,20 +50,93 @@ sys.modules["composio"] = _fake_composio
from app.app import app
```
See `surfsense_backend/tests/integration/composio/conftest.py` for the
current pattern.
See `surfsense_backend/tests/integration/composio/conftest.py` for the current
pattern.
## Running locally
The recommended local flow runs only Postgres and Redis in Docker, and the
backend + Celery worker on the host. No `.env` file is required: both
entrypoints `setdefault` every variable they need (DB URL, Redis URL,
sentinel API keys, etc.) to values that match `docker-compose.deps-only.yml`.
### One-time setup
From `surfsense_web/`:
```bash
cd surfsense_backend
pnpm install
pnpm exec playwright install --with-deps chromium
```
### Each run
**1. Bring up Postgres + Redis** from the repo root (the other deps-only
services (SearXNG, Zero, pgAdmin) are not needed for E2E):
```bash
docker compose -f docker/docker-compose.deps-only.yml up -d db redis
```
**2. Start the backend** in `surfsense_backend/`, terminal A:
```bash
uv sync
uv run alembic upgrade head
uv run python tests/e2e/run_backend.py
# in a second shell:
```
**3. Start the Celery worker** in `surfsense_backend/`, terminal B:
```bash
uv run python tests/e2e/run_celery.py
```
Then in `surfsense_web`:
**4. Register the Playwright user**:
```bash
pnpm test:e2e
curl -X POST http://localhost:8000/auth/register \
-H "Content-Type: application/json" \
-d '{"email":"e2e-test@surfsense.net","password":"E2eTestPassword123!"}'
```
**5. Run Playwright** from `surfsense_web/`, terminal C:
```bash
pnpm test:e2e # dev server (fast iteration)
pnpm test:e2e:headed # show the browser
pnpm test:e2e:ui # Playwright UI mode
pnpm test:e2e:prod # build + start (matches CI exactly)
```
`playwright.config.ts` and the run scripts share defaults, so this works on a
fresh checkout. Set `PLAYWRIGHT_TEST_EMAIL`, `PLAYWRIGHT_TEST_PASSWORD`,
`NEXT_PUBLIC_FASTAPI_BACKEND_URL`, or any backend env (e.g. `DATABASE_URL`)
only when pointing tests at a different stack.
### Cleanup
```bash
docker compose -f docker/docker-compose.deps-only.yml down
```
Add `-v` to also wipe the Postgres volume.
### Hermetic alternative (matches CI)
To reproduce the CI environment exactly — backend and Celery in containers,
network egress denied at L3 — replace steps 13 with:
```bash
docker compose -f docker/docker-compose.e2e.yml up -d --build --wait
```
Then run steps 4 (curl register) and 5 (`pnpm test:e2e:prod`) as above. Tear
down with:
```bash
docker compose -f docker/docker-compose.e2e.yml down -v --remove-orphans
```
This builds the ~9 GB `surfsense-e2e-backend:local` image, so the deps-only
flow above is faster for day-to-day development.

View file

@ -0,0 +1,66 @@
"""Test-only token mint endpoint for the E2E backend entrypoint.
Mounted by ``tests/e2e/run_backend.py`` so Playwright can authenticate
the seeded e2e user without hitting ``/auth/jwt/login`` (rate-limited
to 5/min/IP in production). NEVER ships to production: this whole
``tests/`` tree is excluded from the production Docker image by
``surfsense_backend/.dockerignore``.
Authn: shared secret in ``X-E2E-Mint-Secret``. Same value is set on the
backend container env (``docker/docker-compose.e2e.yml``) and exported
to the Playwright runner (``.github/workflows/e2e-tests.yml``).
"""
from __future__ import annotations
import logging
import os
from fastapi import APIRouter, FastAPI, Header, HTTPException
from pydantic import BaseModel
from sqlalchemy import select
from app.db import User, async_session_maker
from app.users import get_jwt_strategy
_logger = logging.getLogger("surfsense.e2e.auth_mint")
class MintRequest(BaseModel):
email: str = "e2e-test@surfsense.net"
class MintResponse(BaseModel):
access_token: str
token_type: str = "bearer"
def _expected_secret() -> str:
return os.environ.get("E2E_MINT_SECRET", "local-e2e-mint-secret-not-for-production")
router = APIRouter(prefix="/__e2e__", tags=["__e2e__"])
@router.post("/auth/token", response_model=MintResponse)
async def mint_test_token(
body: MintRequest,
x_e2e_mint_secret: str = Header(..., alias="X-E2E-Mint-Secret"),
) -> MintResponse:
if x_e2e_mint_secret != _expected_secret():
raise HTTPException(status_code=403, detail="invalid e2e mint secret")
async with async_session_maker() as session:
result = await session.execute(select(User).where(User.email == body.email))
user = result.scalar_one_or_none()
if user is None:
raise HTTPException(
status_code=404, detail=f"e2e user {body.email!r} not seeded"
)
token = await get_jwt_strategy().write_token(user)
return MintResponse(access_token=token)
def install(app: FastAPI) -> None:
"""Mount the test-only mint router onto the given FastAPI app."""
app.include_router(router)
_logger.warning("[e2e] mounted POST /__e2e__/auth/token (test-only token mint)")

View file

@ -0,0 +1,141 @@
"""Stub DoclingService.process_document for E2E.
The real ``DoclingService.process_document`` calls
``DocumentConverter.convert(file_path)`` which lazily downloads the
``docling-project/docling-layout-heron`` model from Hugging Face Hub.
The hermetic E2E container sets ``HF_HUB_OFFLINE=1`` (see
``docker/docker-compose.e2e.yml``), so that download fails with
``LocalEntryNotFoundError`` and the indexing Celery task retries until
the Playwright test hits its ~4-minute step timeout. In CI that is the
difference between the suite finishing and the 30-minute job timeout
killing the run before any report can upload.
Stubbing ``process_document`` bypasses ``DocumentConverter.convert()``
entirely. ``DoclingService.__init__`` is intentionally left untouched
because constructing ``DocumentConverter(...)`` is cheap and offline
it is only ``.convert()`` that triggers the offline-model download.
Every canary PDF under ``tests/e2e/fakes/fixtures/binary/`` is produced
by ``generate_canary_pdfs.py`` and embeds its canary token as plain
``(text) Tj`` PDF text operators. Extracting those operators gives us
the canary string back, which is what the Playwright assertions look
for in the resulting Document row.
"""
from __future__ import annotations
import logging
import re
from pathlib import Path
from typing import Any
logger = logging.getLogger(__name__)
# Matches the `(escaped text) Tj` text-show operator emitted by
# generate_canary_pdfs.py. Inside the parens, the escape rules are:
# \\ -> backslash
# \( -> literal (
# \) -> literal )
# The character class [^\\()] consumes any non-escape byte; \\. consumes
# an escape sequence. Sufficient for our synthetic fixtures.
_TJ_PATTERN = re.compile(rb"\(((?:[^\\()]|\\.)*)\)\s*Tj")
def _extract_text_from_synthetic_pdf(file_path: str) -> str:
"""Pull every ``(text) Tj`` payload out of a fixture PDF in order.
Returns an empty string if the file cannot be read. We do not try to
handle arbitrary PDFs because the fake is only ever invoked against
fixtures we generate ourselves.
"""
try:
data = Path(file_path).read_bytes()
except OSError as exc:
logger.warning("[fake-docling] could not read %s: %s", file_path, exc)
return ""
lines: list[str] = []
for match in _TJ_PATTERN.finditer(data):
raw = match.group(1)
# Order-sensitive unescape via sentinel: protect `\\` first so
# the subsequent `\(` / `\)` passes do not corrupt it.
text = (
raw.replace(rb"\\", b"\x00")
.replace(rb"\(", b"(")
.replace(rb"\)", b")")
.replace(b"\x00", b"\\")
)
try:
lines.append(text.decode("utf-8"))
except UnicodeDecodeError:
lines.append(text.decode("latin-1"))
return "\n".join(lines)
async def fake_process_document(
self,
file_path: str,
filename: str | None = None,
) -> dict[str, Any]:
"""Drop-in replacement for ``DoclingService.process_document``.
Returns the same dict shape as the production method so callers
(``app/etl_pipeline/parsers/docling.py``) can keep reading
``result["content"]`` without changes.
"""
extracted = _extract_text_from_synthetic_pdf(file_path)
display_name = filename or Path(file_path).name
if extracted:
content = f"# {display_name}\n\n{extracted}\n"
else:
# Empty fallback so the indexing pipeline does not error out on
# an unexpected payload. A failing canary assertion is a much
# clearer failure mode than a hard parser exception.
content = (
f"# {display_name}\n\n(empty docling fake — no text-show operators found)\n"
)
logger.info(
"[fake-docling] returning %d chars for %s",
len(content),
display_name,
)
return {
"content": content,
"full_text": content,
"service_used": "docling-fake",
"status": "success",
"processing_notes": "e2e fake DoclingService — no real PDF parsing",
}
def install(patches: list[Any]) -> None:
"""Patch ``DoclingService.process_document`` at the class level.
Patching the class method (rather than each call site) is correct
here because every consumer goes through
``create_docling_service()`` ``DoclingService()`` instance method
dispatch, so the descriptor protocol picks up our replacement. There
is exactly one such consumer today
(``app/etl_pipeline/parsers/docling.py``), but patching the class is
future-proof.
Fails loud rather than warning, because a silent passthrough means
real Docling + ``HF_HUB_OFFLINE=1`` = 4 minutes of CI hang per test.
"""
from unittest.mock import patch as _patch
target = "app.services.docling_service.DoclingService.process_document"
try:
p = _patch(target, fake_process_document)
p.start()
patches.append(p)
logger.info("[fake-docling] patched %s", target)
except (ModuleNotFoundError, AttributeError) as exc:
raise RuntimeError(
f"Could not patch Docling binding {target!r}: {exc!s}. "
f"Update surfsense_backend/tests/e2e/fakes/docling_service.py "
f"to point at the new binding site."
) from exc

View file

@ -0,0 +1,71 @@
# Synthetic Global LLM configuration for E2E ONLY.
#
# Why this file exists:
# surfsense_backend/app/config/global_llm_config.yaml is gitignored
# (operators ship real API keys there). In CI that file does not exist,
# so app.config.load_global_llm_configs() returns [], every chat-stream
# test fails fast with "No usable global LLM configs are available for
# Auto mode" raised by auto_model_pin_service._global_candidates().
#
# What this file does:
# tests/e2e/run_backend.py and tests/e2e/run_celery.py copy this file
# to app/config/global_llm_config.yaml at startup, BEFORE app.config
# is imported. The copy lives only inside the E2E Docker container.
#
# Why a fake api_key is safe:
# tests.e2e.fakes.chat_llm patches
# app.tasks.chat.stream_new_chat.create_chat_litellm_from_agent_config
# app.tasks.chat.stream_new_chat.create_chat_litellm_from_config
# so the resolved auto-pin id is never sent to a real LLM provider.
# The values below only need to pass
# auto_model_pin_service._is_usable_global_config()
# which requires id / model_name / provider / api_key all truthy.
#
# Why TWO entries (premium + free):
# auto_model_pin_service.resolve_or_get_pinned_llm_config_id() splits
# candidates by billing_tier based on _is_premium_eligible(user):
# premium_eligible == True -> keeps only tier=="premium" configs
# premium_eligible == False -> keeps only tier!="premium" configs
# A single-tier fixture would fail one of the two branches with
# "Auto mode could not find an eligible LLM config for this user and
# quota state". Shipping one of each guarantees every quota state
# resolves to a viable pin in E2E.
router_settings:
routing_strategy: "simple-shuffle"
num_retries: 0
allowed_fails: 1
cooldown_time: 1
global_llm_configs:
- id: -9001
name: "E2E Fake Auto Model (premium)"
billing_tier: "premium"
anonymous_enabled: false
seo_enabled: false
quality_score: 1.0
provider: "OPENAI"
model_name: "fake-e2e-model-premium"
api_key: "fake-e2e-api-key-not-for-production"
supports_image_input: false
quota_reserve_tokens: 1024
rpm: 1000
tpm: 100000
litellm_params:
model: "openai/fake-e2e-model-premium"
- id: -9002
name: "E2E Fake Auto Model (free)"
billing_tier: "free"
anonymous_enabled: false
seo_enabled: false
quality_score: 1.0
provider: "OPENAI"
model_name: "fake-e2e-model-free"
api_key: "fake-e2e-api-key-not-for-production"
supports_image_input: false
quota_reserve_tokens: 1024
rpm: 1000
tpm: 100000
litellm_params:
model: "openai/fake-e2e-model-free"

View file

@ -23,15 +23,12 @@ Usage:
from __future__ import annotations
import asyncio
import logging
import os
import sys
# ---------------------------------------------------------------------------
# 1) Hijack sys.modules BEFORE any production import.
# Production: composio_service.py:11 does `from composio import Composio`.
# With this hijack in place, that import resolves to our strict fake.
# ---------------------------------------------------------------------------
import uvicorn
# Make the surfsense_backend root importable as a top-level package so
# `import tests.e2e.fakes...` works regardless of how the entrypoint is
@ -42,97 +39,175 @@ _BACKEND_ROOT = os.path.abspath(os.path.join(_THIS_DIR, "..", ".."))
if _BACKEND_ROOT not in sys.path:
sys.path.insert(0, _BACKEND_ROOT)
import tests.e2e.fakes.composio_module as _fake_composio # noqa: E402
import tests.e2e.fakes.notion_module as _fake_notion # noqa: E402
sys.modules["composio"] = _fake_composio
sys.modules["notion_client"] = _fake_notion
sys.modules["notion_client.errors"] = _fake_notion.errors
# ---------------------------------------------------------------------------
# 2) Standard logging + dotenv so the rest of the app behaves like main.py.
# ---------------------------------------------------------------------------
from dotenv import load_dotenv # noqa: E402
load_dotenv()
os.environ.setdefault("ATLASSIAN_CLIENT_ID", "fake-atlassian-client-id")
os.environ.setdefault("ATLASSIAN_CLIENT_SECRET", "fake-atlassian-client-secret")
os.environ.setdefault(
"CONFLUENCE_REDIRECT_URI",
"http://localhost:8000/api/v1/auth/confluence/connector/callback",
)
os.environ.setdefault("NOTION_CLIENT_ID", "fake-notion-client-id")
os.environ.setdefault("NOTION_CLIENT_SECRET", "fake-notion-client-secret")
os.environ.setdefault(
"NOTION_REDIRECT_URI",
"http://localhost:8000/api/v1/auth/notion/connector/callback",
)
os.environ.setdefault("MICROSOFT_CLIENT_ID", "fake-microsoft-client-id")
os.environ.setdefault("MICROSOFT_CLIENT_SECRET", "fake-microsoft-client-secret")
os.environ.setdefault(
"ONEDRIVE_REDIRECT_URI",
"http://localhost:8000/api/v1/auth/onedrive/connector/callback",
)
os.environ.setdefault("DROPBOX_APP_KEY", "fake-dropbox-app-key")
os.environ.setdefault("DROPBOX_APP_SECRET", "fake-dropbox-app-secret")
os.environ.setdefault(
"DROPBOX_REDIRECT_URI",
"http://localhost:8000/api/v1/auth/dropbox/connector/callback",
)
os.environ["SLACK_CLIENT_ID"] = "fake-slack-mcp-client-id"
os.environ["SLACK_CLIENT_SECRET"] = "fake-slack-mcp-client-secret"
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
)
logger = logging.getLogger("surfsense.e2e.backend")
logger.warning(
"*** SURFSENSE E2E BACKEND ENTRYPOINT — fake Composio + LLM + embeddings ***"
)
# ---------------------------------------------------------------------------
# 3) Now import the production app. Every module in app.* loads here,
# creating their bindings (some of which we will patch in step 4).
# ---------------------------------------------------------------------------
# ---------------------------------------------------------------------------
# 4) Patch LLM + embedding bindings at every consumer site.
# Composio is already covered by the sys.modules hijack in step 1.
# ---------------------------------------------------------------------------
from unittest.mock import patch # noqa: E402
from app.app import app # noqa: E402
from tests.e2e.fakes import ( # noqa: E402
clickup_module as _fake_clickup_module,
confluence_indexer as _fake_confluence_indexer,
confluence_oauth as _fake_confluence_oauth,
dropbox_api as _fake_dropbox_api,
embeddings as _fake_embeddings,
jira_module as _fake_jira_module,
linear_module as _fake_linear_module,
mcp_oauth_runtime as _fake_mcp_oauth_runtime,
mcp_runtime as _fake_mcp_runtime,
native_google as _fake_native_google,
notion_module as _fake_notion_module,
onedrive_graph as _fake_onedrive_graph,
slack_module as _fake_slack_module,
)
from tests.e2e.fakes.chat_llm import ( # noqa: E402
fake_create_chat_litellm_from_agent_config,
fake_create_chat_litellm_from_config,
)
from tests.e2e.fakes.llm import fake_get_user_long_context_llm # noqa: E402
# Patches started during bootstrap are kept alive for the lifetime of the
# process. We never call .stop() on them.
_active_patches: list = []
def _hijack_external_sdks() -> None:
"""Replace composio + notion_client in sys.modules.
Production does ``from composio import Composio`` and
``import notion_client`` at import time. With this hijack in place,
those imports resolve to our strict fakes.
MUST run before _import_production_app().
"""
import tests.e2e.fakes.composio_module as _fake_composio
import tests.e2e.fakes.notion_module as _fake_notion
sys.modules["composio"] = _fake_composio
sys.modules["notion_client"] = _fake_notion
sys.modules["notion_client.errors"] = _fake_notion.errors
def _load_dotenv_and_set_env_defaults() -> None:
"""Load .env and set every env var the production config reads on import.
MUST run before _import_production_app(), since app.config consumes
these values at import time.
"""
from dotenv import load_dotenv
load_dotenv()
os.environ.setdefault(
"DATABASE_URL",
"postgresql+asyncpg://postgres:postgres@localhost:5432/surfsense",
)
os.environ.setdefault("CELERY_BROKER_URL", "redis://localhost:6379/0")
os.environ.setdefault("CELERY_RESULT_BACKEND", "redis://localhost:6379/0")
os.environ.setdefault("REDIS_APP_URL", "redis://localhost:6379/0")
os.environ.setdefault("CELERY_TASK_DEFAULT_QUEUE", "surfsense")
os.environ.setdefault("SECRET_KEY", "local-e2e-secret-not-for-production")
os.environ.setdefault("AUTH_TYPE", "LOCAL")
os.environ.setdefault("REGISTRATION_ENABLED", "TRUE")
os.environ.setdefault("ETL_SERVICE", "DOCLING")
os.environ.setdefault("EMBEDDING_MODEL", "sentence-transformers/all-MiniLM-L6-v2")
os.environ.setdefault("NEXT_FRONTEND_URL", "http://localhost:3000")
# Sentinel keys — fakes never read them; turns leaked real calls into 401s.
os.environ.setdefault("COMPOSIO_API_KEY", "local-deny-real-call-sentinel")
os.environ.setdefault("COMPOSIO_ENABLED", "TRUE")
os.environ.setdefault("OPENAI_API_KEY", "local-deny-real-call-sentinel")
os.environ.setdefault("ANTHROPIC_API_KEY", "local-deny-real-call-sentinel")
os.environ.setdefault("LITELLM_API_KEY", "local-deny-real-call-sentinel")
os.environ.setdefault("ATLASSIAN_CLIENT_ID", "fake-atlassian-client-id")
os.environ.setdefault("ATLASSIAN_CLIENT_SECRET", "fake-atlassian-client-secret")
os.environ.setdefault(
"CONFLUENCE_REDIRECT_URI",
"http://localhost:8000/api/v1/auth/confluence/connector/callback",
)
os.environ.setdefault("NOTION_CLIENT_ID", "fake-notion-client-id")
os.environ.setdefault("NOTION_CLIENT_SECRET", "fake-notion-client-secret")
os.environ.setdefault(
"NOTION_REDIRECT_URI",
"http://localhost:8000/api/v1/auth/notion/connector/callback",
)
os.environ.setdefault("MICROSOFT_CLIENT_ID", "fake-microsoft-client-id")
os.environ.setdefault("MICROSOFT_CLIENT_SECRET", "fake-microsoft-client-secret")
os.environ.setdefault(
"ONEDRIVE_REDIRECT_URI",
"http://localhost:8000/api/v1/auth/onedrive/connector/callback",
)
os.environ.setdefault("DROPBOX_APP_KEY", "fake-dropbox-app-key")
os.environ.setdefault("DROPBOX_APP_SECRET", "fake-dropbox-app-secret")
os.environ.setdefault(
"DROPBOX_REDIRECT_URI",
"http://localhost:8000/api/v1/auth/dropbox/connector/callback",
)
# Native Google OAuth — fake Flow in tests.e2e.fakes.native_google
# raises "Fake Google Flow requires redirect_uri." if these are empty,
# so connector/add routes return 500 in CI where no .env supplies them.
os.environ.setdefault(
"GOOGLE_DRIVE_REDIRECT_URI",
"http://localhost:8000/api/v1/auth/google/drive/connector/callback",
)
os.environ.setdefault(
"GOOGLE_GMAIL_REDIRECT_URI",
"http://localhost:8000/api/v1/auth/google/gmail/connector/callback",
)
os.environ.setdefault(
"GOOGLE_CALENDAR_REDIRECT_URI",
"http://localhost:8000/api/v1/auth/google/calendar/connector/callback",
)
os.environ["SLACK_CLIENT_ID"] = "fake-slack-mcp-client-id"
os.environ["SLACK_CLIENT_SECRET"] = "fake-slack-mcp-client-secret"
def _install_synthetic_global_llm_config() -> None:
"""Materialise a fake ``app/config/global_llm_config.yaml`` for E2E.
The real file is gitignored (production operators ship their own with
real API keys), so a fresh CI checkout has no YAML at the path
``app.config.load_global_llm_configs()`` reads. With an empty
``GLOBAL_LLM_CONFIGS`` list, ``auto_model_pin_service`` raises
``"No usable global LLM configs are available for Auto mode"`` on
every chat-stream request.
We copy the synthetic fixture from ``tests/e2e/fixtures/`` into the
production-expected location BEFORE ``_import_production_app()`` so
``app.config`` picks it up on import. Production code is untouched
this is purely a test-time scaffold.
Only installs when the destination is missing. A developer running
the E2E entrypoint locally keeps their real ``global_llm_config.yaml``
intact (the patched ``create_chat_litellm_from_*`` factories make the
actual model values irrelevant either way).
MUST run before _import_production_app().
"""
import shutil
src = os.path.join(_THIS_DIR, "fixtures", "global_llm_config.yaml")
dst = os.path.join(_BACKEND_ROOT, "app", "config", "global_llm_config.yaml")
if not os.path.exists(src):
raise RuntimeError(
f"E2E synthetic global LLM config fixture missing at {src!r}. "
f"This file is checked into tests/e2e/fixtures/ — if it has gone "
f"missing, restore it from VCS before running the E2E entrypoint."
)
if os.path.exists(dst):
logger.info(
"[e2e-global-llm-config] %s already exists; leaving it alone "
"(local dev config preserved)",
dst,
)
return
os.makedirs(os.path.dirname(dst), exist_ok=True)
shutil.copyfile(src, dst)
logger.info("[e2e-global-llm-config] installed %s -> %s", src, dst)
def _import_production_app():
"""Import and return the production FastAPI app.
Every module under ``app.*`` loads here, creating their bindings.
The LLM/embedding factories captured at this point will be replaced
by patches in _patch_llm_bindings() below.
"""
from app.app import app as production_app
return production_app
def _patch_llm_bindings() -> None:
"""Replace LLM factories at every known binding site."""
from unittest.mock import patch
from tests.e2e.fakes.chat_llm import (
fake_create_chat_litellm_from_agent_config,
fake_create_chat_litellm_from_config,
)
from tests.e2e.fakes.llm import fake_get_user_long_context_llm
targets = [
"app.services.llm_service.get_user_long_context_llm",
"app.tasks.connector_indexers.confluence_indexer.get_user_long_context_llm",
@ -190,38 +265,90 @@ def _patch_llm_bindings() -> None:
logger.warning("[fake-chat-llm] could not patch %s: %s.", target, exc)
_patch_llm_bindings()
_fake_embeddings.install(_active_patches)
_fake_confluence_oauth.install(_active_patches)
_fake_confluence_indexer.install(_active_patches)
_fake_native_google.install(_active_patches)
_fake_onedrive_graph.install(_active_patches)
_fake_dropbox_api.install(_active_patches)
_fake_notion_module.install(_active_patches)
_fake_linear_module.install(_active_patches)
_fake_jira_module.install(_active_patches)
_fake_clickup_module.install(_active_patches)
_fake_mcp_runtime.install(_active_patches)
_fake_mcp_oauth_runtime.install(_active_patches)
_fake_slack_module.install(_active_patches)
def _install_runtime_fakes() -> None:
"""Run each fake's install() against the active patch stack."""
from tests.e2e.fakes import (
clickup_module as _fake_clickup_module,
confluence_indexer as _fake_confluence_indexer,
confluence_oauth as _fake_confluence_oauth,
docling_service as _fake_docling_service,
dropbox_api as _fake_dropbox_api,
embeddings as _fake_embeddings,
jira_module as _fake_jira_module,
linear_module as _fake_linear_module,
mcp_oauth_runtime as _fake_mcp_oauth_runtime,
mcp_runtime as _fake_mcp_runtime,
native_google as _fake_native_google,
notion_module as _fake_notion_module,
onedrive_graph as _fake_onedrive_graph,
slack_module as _fake_slack_module,
)
_fake_embeddings.install(_active_patches)
_fake_docling_service.install(_active_patches)
_fake_confluence_oauth.install(_active_patches)
_fake_confluence_indexer.install(_active_patches)
_fake_native_google.install(_active_patches)
_fake_onedrive_graph.install(_active_patches)
_fake_dropbox_api.install(_active_patches)
_fake_notion_module.install(_active_patches)
_fake_linear_module.install(_active_patches)
_fake_jira_module.install(_active_patches)
_fake_clickup_module.install(_active_patches)
_fake_mcp_runtime.install(_active_patches)
_fake_mcp_oauth_runtime.install(_active_patches)
_fake_slack_module.install(_active_patches)
# ---------------------------------------------------------------------------
# 5) Mount test-only middleware. Production never reaches this code.
# ---------------------------------------------------------------------------
def _install_test_only_app_extensions(app) -> None:
"""Mount test-only middleware + the /__e2e__ token mint router.
from tests.e2e.middleware.scenario import ScenarioMiddleware # noqa: E402
POST /__e2e__/auth/token bypasses /auth/jwt/login's 5/min/IP rate
limit so Playwright workers can authenticate without thrashing the
production auth surface. See tests/e2e/auth_mint.py.
"""
from tests.e2e.auth_mint import install as install_e2e_mint
from tests.e2e.middleware.scenario import ScenarioMiddleware
app.add_middleware(ScenarioMiddleware)
app.add_middleware(ScenarioMiddleware)
install_e2e_mint(app)
# ---------------------------------------------------------------------------
# 6) Start uvicorn, mirroring main.py's behaviour.
# ---------------------------------------------------------------------------
def _bootstrap():
"""Run the full E2E bootstrap and return the production FastAPI app.
import asyncio # noqa: E402
Ordering is load-bearing:
1) Hijack composio + notion_client in sys.modules.
2) Load .env + set env defaults (app.config reads env on import).
3) Configure logging.
4) Materialise the synthetic global_llm_config.yaml so Auto-mode
pin resolution finds at least one usable candidate.
5) Import production app (which transitively imports the now-faked
external SDKs and reads the env defaults + YAML).
6) Patch LLM / embedding bindings at every consumer site.
7) Mount test-only middleware + /__e2e__ routes onto the app.
"""
_hijack_external_sdks()
_load_dotenv_and_set_env_defaults()
import uvicorn # noqa: E402
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
)
logger.warning(
"*** SURFSENSE E2E BACKEND ENTRYPOINT — fake Composio + LLM + embeddings ***"
)
_install_synthetic_global_llm_config()
production_app = _import_production_app()
_patch_llm_bindings()
_install_runtime_fakes()
_install_test_only_app_extensions(production_app)
return production_app
app = _bootstrap()
def _main() -> None:

View file

@ -25,96 +25,166 @@ if _BACKEND_ROOT not in sys.path:
sys.path.insert(0, _BACKEND_ROOT)
# ---------------------------------------------------------------------------
# 1) Hijack sys.modules BEFORE production celery imports anything.
# ---------------------------------------------------------------------------
import tests.e2e.fakes.composio_module as _fake_composio # noqa: E402
import tests.e2e.fakes.notion_module as _fake_notion # noqa: E402
sys.modules["composio"] = _fake_composio
sys.modules["notion_client"] = _fake_notion
sys.modules["notion_client.errors"] = _fake_notion.errors
# ---------------------------------------------------------------------------
# 2) Logging + dotenv.
# ---------------------------------------------------------------------------
from dotenv import load_dotenv # noqa: E402
load_dotenv()
os.environ.setdefault("ATLASSIAN_CLIENT_ID", "fake-atlassian-client-id")
os.environ.setdefault("ATLASSIAN_CLIENT_SECRET", "fake-atlassian-client-secret")
os.environ.setdefault(
"CONFLUENCE_REDIRECT_URI",
"http://localhost:8000/api/v1/auth/confluence/connector/callback",
)
os.environ.setdefault("NOTION_CLIENT_ID", "fake-notion-client-id")
os.environ.setdefault("NOTION_CLIENT_SECRET", "fake-notion-client-secret")
os.environ.setdefault(
"NOTION_REDIRECT_URI",
"http://localhost:8000/api/v1/auth/notion/connector/callback",
)
os.environ.setdefault("MICROSOFT_CLIENT_ID", "fake-microsoft-client-id")
os.environ.setdefault("MICROSOFT_CLIENT_SECRET", "fake-microsoft-client-secret")
os.environ.setdefault(
"ONEDRIVE_REDIRECT_URI",
"http://localhost:8000/api/v1/auth/onedrive/connector/callback",
)
os.environ.setdefault("DROPBOX_APP_KEY", "fake-dropbox-app-key")
os.environ.setdefault("DROPBOX_APP_SECRET", "fake-dropbox-app-secret")
os.environ.setdefault(
"DROPBOX_REDIRECT_URI",
"http://localhost:8000/api/v1/auth/dropbox/connector/callback",
)
os.environ["SLACK_CLIENT_ID"] = "fake-slack-mcp-client-id"
os.environ["SLACK_CLIENT_SECRET"] = "fake-slack-mcp-client-secret"
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
)
logger = logging.getLogger("surfsense.e2e.celery")
logger.warning("*** SURFSENSE E2E CELERY WORKER — fake Composio + LLM + embeddings ***")
# ---------------------------------------------------------------------------
# 3) Import the production celery_app. All task modules load here.
# ---------------------------------------------------------------------------
# ---------------------------------------------------------------------------
# 4) Patch LLM + embedding bindings inside the worker process.
# ---------------------------------------------------------------------------
from unittest.mock import patch # noqa: E402
from app.celery_app import celery_app # noqa: E402
from tests.e2e.fakes import ( # noqa: E402
clickup_module as _fake_clickup_module,
confluence_indexer as _fake_confluence_indexer,
confluence_oauth as _fake_confluence_oauth,
dropbox_api as _fake_dropbox_api,
embeddings as _fake_embeddings,
jira_module as _fake_jira_module,
linear_module as _fake_linear_module,
mcp_oauth_runtime as _fake_mcp_oauth_runtime,
mcp_runtime as _fake_mcp_runtime,
native_google as _fake_native_google,
notion_module as _fake_notion_module,
onedrive_graph as _fake_onedrive_graph,
slack_module as _fake_slack_module,
)
from tests.e2e.fakes.chat_llm import ( # noqa: E402
fake_create_chat_litellm_from_agent_config,
fake_create_chat_litellm_from_config,
)
from tests.e2e.fakes.llm import fake_get_user_long_context_llm # noqa: E402
# Patches started during bootstrap are kept alive for the lifetime of the
# process. We never call .stop() on them.
_active_patches: list = []
def _hijack_external_sdks() -> None:
"""Replace composio + notion_client in sys.modules.
Production does ``from composio import Composio`` and
``import notion_client`` at import time. With this hijack in place,
those imports resolve to our strict fakes.
MUST run before _import_celery_app().
"""
import tests.e2e.fakes.composio_module as _fake_composio
import tests.e2e.fakes.notion_module as _fake_notion
sys.modules["composio"] = _fake_composio
sys.modules["notion_client"] = _fake_notion
sys.modules["notion_client.errors"] = _fake_notion.errors
def _load_dotenv_and_set_env_defaults() -> None:
"""Load .env and set every env var the production config reads on import.
MUST run before _import_celery_app(), since app.config consumes
these values at import time.
"""
from dotenv import load_dotenv
load_dotenv()
os.environ.setdefault(
"DATABASE_URL",
"postgresql+asyncpg://postgres:postgres@localhost:5432/surfsense",
)
os.environ.setdefault("CELERY_BROKER_URL", "redis://localhost:6379/0")
os.environ.setdefault("CELERY_RESULT_BACKEND", "redis://localhost:6379/0")
os.environ.setdefault("REDIS_APP_URL", "redis://localhost:6379/0")
os.environ.setdefault("CELERY_TASK_DEFAULT_QUEUE", "surfsense")
os.environ.setdefault("SECRET_KEY", "local-e2e-secret-not-for-production")
os.environ.setdefault("AUTH_TYPE", "LOCAL")
os.environ.setdefault("REGISTRATION_ENABLED", "TRUE")
os.environ.setdefault("ETL_SERVICE", "DOCLING")
os.environ.setdefault("EMBEDDING_MODEL", "sentence-transformers/all-MiniLM-L6-v2")
os.environ.setdefault("NEXT_FRONTEND_URL", "http://localhost:3000")
# Sentinel keys — fakes never read them; turns leaked real calls into 401s.
os.environ.setdefault("COMPOSIO_API_KEY", "local-deny-real-call-sentinel")
os.environ.setdefault("COMPOSIO_ENABLED", "TRUE")
os.environ.setdefault("OPENAI_API_KEY", "local-deny-real-call-sentinel")
os.environ.setdefault("ANTHROPIC_API_KEY", "local-deny-real-call-sentinel")
os.environ.setdefault("LITELLM_API_KEY", "local-deny-real-call-sentinel")
os.environ.setdefault("ATLASSIAN_CLIENT_ID", "fake-atlassian-client-id")
os.environ.setdefault("ATLASSIAN_CLIENT_SECRET", "fake-atlassian-client-secret")
os.environ.setdefault(
"CONFLUENCE_REDIRECT_URI",
"http://localhost:8000/api/v1/auth/confluence/connector/callback",
)
os.environ.setdefault("NOTION_CLIENT_ID", "fake-notion-client-id")
os.environ.setdefault("NOTION_CLIENT_SECRET", "fake-notion-client-secret")
os.environ.setdefault(
"NOTION_REDIRECT_URI",
"http://localhost:8000/api/v1/auth/notion/connector/callback",
)
os.environ.setdefault("MICROSOFT_CLIENT_ID", "fake-microsoft-client-id")
os.environ.setdefault("MICROSOFT_CLIENT_SECRET", "fake-microsoft-client-secret")
os.environ.setdefault(
"ONEDRIVE_REDIRECT_URI",
"http://localhost:8000/api/v1/auth/onedrive/connector/callback",
)
os.environ.setdefault("DROPBOX_APP_KEY", "fake-dropbox-app-key")
os.environ.setdefault("DROPBOX_APP_SECRET", "fake-dropbox-app-secret")
os.environ.setdefault(
"DROPBOX_REDIRECT_URI",
"http://localhost:8000/api/v1/auth/dropbox/connector/callback",
)
# Native Google OAuth — fake Flow in tests.e2e.fakes.native_google raises
# "Fake Google Flow requires redirect_uri." when these are empty.
os.environ.setdefault(
"GOOGLE_DRIVE_REDIRECT_URI",
"http://localhost:8000/api/v1/auth/google/drive/connector/callback",
)
os.environ.setdefault(
"GOOGLE_GMAIL_REDIRECT_URI",
"http://localhost:8000/api/v1/auth/google/gmail/connector/callback",
)
os.environ.setdefault(
"GOOGLE_CALENDAR_REDIRECT_URI",
"http://localhost:8000/api/v1/auth/google/calendar/connector/callback",
)
os.environ["SLACK_CLIENT_ID"] = "fake-slack-mcp-client-id"
os.environ["SLACK_CLIENT_SECRET"] = "fake-slack-mcp-client-secret"
def _install_synthetic_global_llm_config() -> None:
"""Materialise a fake ``app/config/global_llm_config.yaml`` for E2E.
The real file is gitignored (production operators ship their own with
real API keys), so a fresh CI checkout has no YAML at the path
``app.config.load_global_llm_configs()`` reads. With an empty
``GLOBAL_LLM_CONFIGS`` list, the worker's view of the config diverges
from the API container.
We copy the synthetic fixture from ``tests/e2e/fixtures/`` into the
production-expected location BEFORE _import_celery_app() so
``app.config`` picks it up on import. Install-only-if-missing so a
developer's local config (with real API keys) is preserved.
MUST run before _import_celery_app().
"""
import shutil
src = os.path.join(_THIS_DIR, "fixtures", "global_llm_config.yaml")
dst = os.path.join(_BACKEND_ROOT, "app", "config", "global_llm_config.yaml")
if not os.path.exists(src):
raise RuntimeError(
f"E2E synthetic global LLM config fixture missing at {src!r}. "
f"Restore tests/e2e/fixtures/global_llm_config.yaml from VCS."
)
if os.path.exists(dst):
logger.info(
"[e2e-global-llm-config] %s already exists; leaving it alone "
"(local dev config preserved)",
dst,
)
return
os.makedirs(os.path.dirname(dst), exist_ok=True)
shutil.copyfile(src, dst)
logger.info("[e2e-global-llm-config] installed %s -> %s", src, dst)
def _import_celery_app():
"""Import and return the production Celery app.
Every module under ``app.*`` (including all task modules) loads here,
creating their bindings. The LLM/embedding factories captured at this
point will be replaced by patches in _patch_llm_bindings() below.
"""
from app.celery_app import celery_app
return celery_app
def _patch_llm_bindings() -> None:
"""Replace LLM factories at every known binding site in worker tasks."""
from unittest.mock import patch
from tests.e2e.fakes.chat_llm import (
fake_create_chat_litellm_from_agent_config,
fake_create_chat_litellm_from_config,
)
from tests.e2e.fakes.llm import fake_get_user_long_context_llm
targets = [
"app.services.llm_service.get_user_long_context_llm",
"app.tasks.connector_indexers.confluence_indexer.get_user_long_context_llm",
@ -172,38 +242,93 @@ def _patch_llm_bindings() -> None:
)
_patch_llm_bindings()
_fake_embeddings.install(_active_patches)
_fake_confluence_oauth.install(_active_patches)
_fake_confluence_indexer.install(_active_patches)
_fake_native_google.install(_active_patches)
_fake_onedrive_graph.install(_active_patches)
_fake_dropbox_api.install(_active_patches)
_fake_notion_module.install(_active_patches)
_fake_linear_module.install(_active_patches)
_fake_jira_module.install(_active_patches)
_fake_clickup_module.install(_active_patches)
_fake_mcp_runtime.install(_active_patches)
_fake_mcp_oauth_runtime.install(_active_patches)
_fake_slack_module.install(_active_patches)
def _install_runtime_fakes() -> None:
"""Run each fake's install() against the active patch stack."""
from tests.e2e.fakes import (
clickup_module as _fake_clickup_module,
confluence_indexer as _fake_confluence_indexer,
confluence_oauth as _fake_confluence_oauth,
docling_service as _fake_docling_service,
dropbox_api as _fake_dropbox_api,
embeddings as _fake_embeddings,
jira_module as _fake_jira_module,
linear_module as _fake_linear_module,
mcp_oauth_runtime as _fake_mcp_oauth_runtime,
mcp_runtime as _fake_mcp_runtime,
native_google as _fake_native_google,
notion_module as _fake_notion_module,
onedrive_graph as _fake_onedrive_graph,
slack_module as _fake_slack_module,
)
_fake_embeddings.install(_active_patches)
_fake_docling_service.install(_active_patches)
_fake_confluence_oauth.install(_active_patches)
_fake_confluence_indexer.install(_active_patches)
_fake_native_google.install(_active_patches)
_fake_onedrive_graph.install(_active_patches)
_fake_dropbox_api.install(_active_patches)
_fake_notion_module.install(_active_patches)
_fake_linear_module.install(_active_patches)
_fake_jira_module.install(_active_patches)
_fake_clickup_module.install(_active_patches)
_fake_mcp_runtime.install(_active_patches)
_fake_mcp_oauth_runtime.install(_active_patches)
_fake_slack_module.install(_active_patches)
# ---------------------------------------------------------------------------
# 5) Start the worker.
# ---------------------------------------------------------------------------
def _bootstrap():
"""Run the full E2E bootstrap and return the production Celery app.
Ordering is load-bearing:
1) Hijack composio + notion_client in sys.modules.
2) Load .env + set env defaults (app.config reads env on import).
3) Configure logging.
4) Materialise the synthetic global_llm_config.yaml so the worker's
view of GLOBAL_LLM_CONFIGS matches the API container.
5) Import production celery_app (which transitively imports the
now-faked external SDKs and reads the env defaults + YAML).
6) Patch LLM / embedding bindings at every consumer site.
7) Install runtime fakes for connectors and chat backends.
"""
_hijack_external_sdks()
_load_dotenv_and_set_env_defaults()
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
)
logger.warning(
"*** SURFSENSE E2E CELERY WORKER — fake Composio + LLM + embeddings ***"
)
_install_synthetic_global_llm_config()
celery_app = _import_celery_app()
_patch_llm_bindings()
_install_runtime_fakes()
return celery_app
celery_app = _bootstrap()
def _main() -> None:
# Default queues mirror production (default queue + connectors queue
# so Drive indexing tasks are picked up).
queue_name = os.getenv("CELERY_TASK_DEFAULT_QUEUE", "surfsense")
queues = f"{queue_name},{queue_name}.connectors"
# macOS forks-after-MPS-init crash prefork workers; threads avoid it.
default_pool = "threads" if sys.platform == "darwin" else "prefork"
pool = os.getenv("CELERY_POOL", default_pool)
concurrency = os.getenv("CELERY_CONCURRENCY", "2")
celery_app.worker_main(
argv=[
"worker",
"--loglevel=info",
f"--queues={queues}",
"--concurrency=2",
f"--pool={pool}",
f"--concurrency={concurrency}",
"--without-gossip",
"--without-mingle",
]

View file

@ -3,15 +3,24 @@
from __future__ import annotations
import ast
import asyncio
from types import SimpleNamespace
import pytest
from langchain.tools import ToolRuntime
from langchain_core.messages import HumanMessage
from langchain_core.messages import AIMessage, HumanMessage
from langgraph.checkpoint.memory import InMemorySaver
from langgraph.graph import END, START, StateGraph
from langgraph.types import Command, interrupt
from typing_extensions import TypedDict
from app.agents.multi_agent_chat.middleware.main_agent.checkpointed_subagent_middleware.config import (
subagent_invoke_config,
)
from app.agents.multi_agent_chat.middleware.main_agent.checkpointed_subagent_middleware.resume_routing import (
collect_pending_tool_calls,
slice_decisions_by_tool_call,
)
from app.agents.multi_agent_chat.middleware.main_agent.checkpointed_subagent_middleware.task_tool import (
build_task_tool_with_parent_config,
)
@ -24,8 +33,6 @@ class _SubagentState(TypedDict, total=False):
def _build_single_interrupt_subagent():
def approve_node(state):
from langchain_core.messages import AIMessage
decision = interrupt(
{
"action_requests": [
@ -50,17 +57,27 @@ def _build_single_interrupt_subagent():
return graph.compile(checkpointer=InMemorySaver())
def _make_runtime(config: dict) -> ToolRuntime:
def _make_runtime(config: dict, *, tool_call_id: str = "parent-tcid-1") -> ToolRuntime:
return ToolRuntime(
state={"messages": [HumanMessage(content="seed")]},
context=None,
config=config,
stream_writer=None,
tool_call_id="parent-tcid-1",
tool_call_id=tool_call_id,
store=None,
)
def _prime_subagent_at_runtime_thread(subagent, runtime: ToolRuntime) -> dict:
"""Build the per-call ``RunnableConfig`` the production ``task`` tool will use.
Mirrors what the ``task`` tool does on first invocation so test fixtures
can prime the subagent's pending interrupt at the same checkpoint slot
(per-call ``thread_id``) the bridge looks at on resume.
"""
return subagent_invoke_config(runtime)
@pytest.mark.asyncio
async def test_resume_bridge_dispatches_decision_into_pending_subagent():
"""Side-channel decision must reach the subagent's pending interrupt verbatim."""
@ -79,16 +96,17 @@ async def test_resume_bridge_dispatches_decision_into_pending_subagent():
"configurable": {"thread_id": "shared-thread"},
"recursion_limit": 100,
}
await subagent.ainvoke({"messages": [HumanMessage(content="seed")]}, parent_config)
snap = await subagent.aget_state(parent_config)
runtime = _make_runtime(parent_config)
sub_config = _prime_subagent_at_runtime_thread(subagent, runtime)
await subagent.ainvoke({"messages": [HumanMessage(content="seed")]}, sub_config)
snap = await subagent.aget_state(sub_config)
assert snap.tasks and snap.tasks[0].interrupts, (
"fixture broken: subagent should be paused on its interrupt"
)
parent_config["configurable"]["surfsense_resume_value"] = {
"decisions": ["APPROVED"]
runtime.tool_call_id: {"decisions": ["APPROVED"]}
}
runtime = _make_runtime(parent_config)
result = await task_tool.coroutine(
description="please approve",
@ -101,7 +119,7 @@ async def test_resume_bridge_dispatches_decision_into_pending_subagent():
assert update["decision_text"] == repr({"decisions": ["APPROVED"]})
assert "surfsense_resume_value" not in parent_config["configurable"]
final = await subagent.aget_state(parent_config)
final = await subagent.aget_state(sub_config)
assert not final.tasks or all(not t.interrupts for t in final.tasks)
@ -123,11 +141,11 @@ async def test_pending_interrupt_without_resume_value_raises_runtime_error():
"configurable": {"thread_id": "guard-thread"},
"recursion_limit": 100,
}
await subagent.ainvoke({"messages": [HumanMessage(content="seed")]}, parent_config)
snap = await subagent.aget_state(parent_config)
assert snap.tasks and snap.tasks[0].interrupts, "fixture broken"
runtime = _make_runtime(parent_config)
sub_config = _prime_subagent_at_runtime_thread(subagent, runtime)
await subagent.ainvoke({"messages": [HumanMessage(content="seed")]}, sub_config)
snap = await subagent.aget_state(sub_config)
assert snap.tasks and snap.tasks[0].interrupts, "fixture broken"
with pytest.raises(RuntimeError, match="resume bridge is broken"):
await task_tool.coroutine(
@ -139,8 +157,6 @@ async def test_pending_interrupt_without_resume_value_raises_runtime_error():
def _build_bundle_subagent():
def bundle_node(state):
from langchain_core.messages import AIMessage
decision = interrupt(
{
"action_requests": [
@ -181,7 +197,9 @@ async def test_bundle_three_mixed_decisions_arrive_in_order():
"configurable": {"thread_id": "bundle-thread"},
"recursion_limit": 100,
}
await subagent.ainvoke({"messages": [HumanMessage(content="seed")]}, parent_config)
runtime = _make_runtime(parent_config)
sub_config = _prime_subagent_at_runtime_thread(subagent, runtime)
await subagent.ainvoke({"messages": [HumanMessage(content="seed")]}, sub_config)
decisions_payload = {
"decisions": [
@ -190,8 +208,9 @@ async def test_bundle_three_mixed_decisions_arrive_in_order():
{"type": "reject", "args": {"message": "no thanks"}},
]
}
parent_config["configurable"]["surfsense_resume_value"] = decisions_payload
runtime = _make_runtime(parent_config)
parent_config["configurable"]["surfsense_resume_value"] = {
runtime.tool_call_id: decisions_payload
}
result = await task_tool.coroutine(
description="run bundle",
@ -206,3 +225,182 @@ async def test_bundle_three_mixed_decisions_arrive_in_order():
assert received["decisions"][1]["type"] == "edit"
assert received["decisions"][1]["args"] == {"args": {"name": "edited-b"}}
assert received["decisions"][2]["type"] == "reject"
@pytest.mark.asyncio
async def test_parallel_atask_routes_each_decision_to_its_own_subagent():
"""Two ``atask`` calls with distinct ``tool_call_id``s must each get their own decision.
With per-call ``thread_id`` isolation and per-call resume keying, A's
decision must reach A's pending interrupt and B's must reach B's. They
must NOT cross-contaminate even though they share ``configurable``.
"""
subagent_a = _build_single_interrupt_subagent()
subagent_b = _build_single_interrupt_subagent()
task_tool = build_task_tool_with_parent_config(
[
{
"name": "approver_a",
"description": "approves A",
"runnable": subagent_a,
},
{
"name": "approver_b",
"description": "approves B",
"runnable": subagent_b,
},
]
)
parent_config: dict = {
"configurable": {"thread_id": "parallel-thread"},
"recursion_limit": 100,
}
runtime_a = _make_runtime(parent_config, tool_call_id="tcid-A")
runtime_b = _make_runtime(parent_config, tool_call_id="tcid-B")
sub_config_a = _prime_subagent_at_runtime_thread(subagent_a, runtime_a)
sub_config_b = _prime_subagent_at_runtime_thread(subagent_b, runtime_b)
await subagent_a.ainvoke(
{"messages": [HumanMessage(content="seed-A")]}, sub_config_a
)
await subagent_b.ainvoke(
{"messages": [HumanMessage(content="seed-B")]}, sub_config_b
)
parent_config["configurable"]["surfsense_resume_value"] = {
"tcid-A": {"decisions": ["DECISION-A"]},
"tcid-B": {"decisions": ["DECISION-B"]},
}
result_a, result_b = await asyncio.gather(
task_tool.coroutine(
description="please approve A",
subagent_type="approver_a",
runtime=runtime_a,
),
task_tool.coroutine(
description="please approve B",
subagent_type="approver_b",
runtime=runtime_b,
),
)
assert isinstance(result_a, Command)
assert isinstance(result_b, Command)
assert result_a.update["decision_text"] == repr({"decisions": ["DECISION-A"]})
assert result_b.update["decision_text"] == repr({"decisions": ["DECISION-B"]})
assert "surfsense_resume_value" not in parent_config["configurable"]
@pytest.mark.asyncio
async def test_full_resume_routing_glue_for_two_paused_subagents():
"""End-to-end: extractor + slicer + bridge correctly route a flat decisions list.
This simulates exactly what ``stream_resume_chat`` will do on resume:
given a paused parent state with two pending interrupts (one per
subagent) and a flat ``decisions`` list, build the per-tool-call dict
via ``collect_pending_tool_calls`` + ``slice_decisions_by_tool_call``,
then resume the bridge concurrently and verify each subagent received
only its own slice.
"""
subagent_a = _build_bundle_subagent()
subagent_b = _build_single_interrupt_subagent()
task_tool = build_task_tool_with_parent_config(
[
{
"name": "bundler",
"description": "three-action bundle",
"runnable": subagent_a,
},
{
"name": "approver",
"description": "single approval",
"runnable": subagent_b,
},
]
)
parent_config: dict = {
"configurable": {"thread_id": "glue-thread"},
"recursion_limit": 100,
}
runtime_a = _make_runtime(parent_config, tool_call_id="tcid-bundler")
runtime_b = _make_runtime(parent_config, tool_call_id="tcid-approver")
sub_config_a = _prime_subagent_at_runtime_thread(subagent_a, runtime_a)
sub_config_b = _prime_subagent_at_runtime_thread(subagent_b, runtime_b)
await subagent_a.ainvoke(
{"messages": [HumanMessage(content="seed-A")]}, sub_config_a
)
await subagent_b.ainvoke(
{"messages": [HumanMessage(content="seed-B")]}, sub_config_b
)
# Synthetic parent state mirroring what the parent's pregel would have
# bundled: one Interrupt per subagent, value carrying tool_call_id +
# action_requests (exactly the shape ``propagation.wrap_with_tool_call_id``
# produces).
parent_interrupts = (
SimpleNamespace(
id="i-bundler",
value={
"action_requests": [
{"name": "create_a", "args": {}, "description": ""},
{"name": "create_b", "args": {}, "description": ""},
{"name": "create_c", "args": {}, "description": ""},
],
"review_configs": [{}, {}, {}],
"tool_call_id": "tcid-bundler",
},
),
SimpleNamespace(
id="i-approver",
value={
"action_requests": [{"name": "approve", "args": {}, "description": ""}],
"review_configs": [{}],
"tool_call_id": "tcid-approver",
},
),
)
parent_state = SimpleNamespace(interrupts=parent_interrupts)
flat_decisions = [
{"type": "approve"},
{"type": "edit", "args": {"args": {"name": "edited-b"}}},
{"type": "reject", "args": {"message": "no thanks"}},
{"type": "approve"},
]
pending = collect_pending_tool_calls(parent_state)
assert pending == [("tcid-bundler", 3), ("tcid-approver", 1)]
routed = slice_decisions_by_tool_call(flat_decisions, pending)
parent_config["configurable"]["surfsense_resume_value"] = routed
result_a, result_b = await asyncio.gather(
task_tool.coroutine(
description="run bundle",
subagent_type="bundler",
runtime=runtime_a,
),
task_tool.coroutine(
description="please approve",
subagent_type="approver",
runtime=runtime_b,
),
)
assert isinstance(result_a, Command)
assert isinstance(result_b, Command)
received_a = ast.literal_eval(result_a.update["decision_text"])
assert received_a == {"decisions": flat_decisions[0:3]}
assert result_b.update["decision_text"] == repr({"decisions": flat_decisions[3:4]})
assert "surfsense_resume_value" not in parent_config["configurable"]

View file

@ -0,0 +1,259 @@
"""Real-graph contract: heterogeneous decisions route correctly across parallel subagents.
The simple "approve everything" parallel test (see
``test_parallel_resume_command_keying``) proves the routing wires up at all,
but it doesn't exercise the actual production user flow: rejecting one card
while approving another, or editing one action's args before approving the
rest. Those are the decisions ``HumanInTheLoopMiddleware`` differentiates on,
and they're exactly where a slicer/router bug silently mis-applies a reject
to the wrong subagent.
This module pins:
1. **Order preservation** across the slice boundary flat decisions enter
in the order the SSE stream rendered cards; each subagent must receive
only its slice in the original order.
2. **Per-decision metadata pass-through** ``message`` and ``edited_action``
payloads must reach the subagent intact (not just the ``type`` discriminator).
3. **Off-by-one-sensitive bundle sizes** both paused subagents have action
counts ``> 1`` (``2`` and ``3``). With those sizes a buggy
``cursor += 1`` slicer (instead of ``cursor += action_count``) produces a
different B-slice from the correct one, so this test catches the most
common refactor mistake. A ``(1, 2)`` configuration would silently pass
such a bug because ``+= 1`` and ``+= count`` are arithmetically identical
when ``count == 1``.
"""
from __future__ import annotations
import contextlib
import json
from typing import Annotated
import pytest
from langchain.tools import ToolRuntime
from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.runnables import RunnableConfig
from langgraph.checkpoint.memory import InMemorySaver
from langgraph.graph import END, START, StateGraph
from langgraph.graph.message import add_messages
from langgraph.types import Command, Send, interrupt
from typing_extensions import TypedDict
from app.agents.multi_agent_chat.middleware.main_agent.checkpointed_subagent_middleware.resume_routing import (
build_lg_resume_map,
collect_pending_tool_calls,
slice_decisions_by_tool_call,
)
from app.agents.multi_agent_chat.middleware.main_agent.checkpointed_subagent_middleware.task_tool import (
build_task_tool_with_parent_config,
)
class _SubState(TypedDict, total=False):
messages: list
class _DispatchState(TypedDict, total=False):
messages: Annotated[list, add_messages]
tcid: str
desc: str
subtype: str
def _build_capturing_subagent(checkpointer: InMemorySaver, *, action_count: int):
"""Subagent that pauses with an N-action bundle and on resume records what it received.
The recorded ``AIMessage`` content is the JSON-serialized resume payload, so
the assertions can inspect exactly which decisions reached this subagent
(vs. its sibling) including the ``message`` and ``edited_action``
metadata, not just the ``type``.
"""
def hitl_node(_state):
decision_payload = interrupt(
{
"action_requests": [
{
"name": f"act_{i}",
"args": {"i": i},
"description": f"action {i}",
}
for i in range(action_count)
],
"review_configs": [
{
"action_name": f"act_{i}",
"allowed_decisions": ["approve", "reject", "edit"],
}
for i in range(action_count)
],
}
)
return {
"messages": [
AIMessage(content=json.dumps(decision_payload, sort_keys=True))
]
}
g = StateGraph(_SubState)
g.add_node("hitl", hitl_node)
g.add_edge(START, "hitl")
g.add_edge("hitl", END)
return g.compile(checkpointer=checkpointer)
def _parent_dispatching_two_subagents(
task_tool, *, dispatches: list[dict[str, str]], checkpointer
):
"""Parent that fans out to ``len(dispatches)`` parallel ``task`` tool calls.
Each entry in ``dispatches`` is ``{"tcid": ..., "subtype": ..., "desc": ...}``
so different parallel branches can target different subagent types the
actual production scenario (Linear + Jira, etc.).
"""
def fanout(_state) -> list[Send]:
return [Send("call_task", d) for d in dispatches]
async def call_task(state: _DispatchState, config: RunnableConfig):
rt = ToolRuntime(
state=state,
config=config,
context=None,
stream_writer=None,
tool_call_id=state["tcid"],
store=None,
)
return await task_tool.coroutine(
description=state["desc"], subagent_type=state["subtype"], runtime=rt
)
g = StateGraph(_DispatchState)
g.add_node("call_task", call_task)
g.add_conditional_edges(START, fanout, ["call_task"])
g.add_edge("call_task", END)
return g.compile(checkpointer=checkpointer)
@pytest.mark.asyncio
async def test_heterogeneous_decisions_route_to_correct_subagents_with_metadata_intact():
"""Mixed approve/reject/edit decisions across two parallel subagents.
Setup chosen so the slicer's cursor arithmetic is sensitive to off-by-one
refactors:
- Sub-A pauses with a 2-action bundle (``act_0``, ``act_1``).
- Sub-B pauses with a 3-action bundle (``act_0``, ``act_1``, ``act_2``).
- Parent ends up with 2 pending interrupts (one per subagent).
With both counts ``> 1``, a buggy ``cursor += 1`` (instead of
``cursor += action_count``) produces a different B-slice from the correct
one, so the assertions catch it. A ``(1, 2)`` configuration would not
because ``+= 1`` and ``+= count`` are arithmetically identical when
``count == 1``.
The frontend submits a flat
``[A_approve, A_reject, B_edit, B_approve, B_reject]`` list with distinct
``message`` and ``edited_action`` payloads; our slicer must split into
``{tcid_A: [A_approve, A_reject], tcid_B: [B_edit, B_approve, B_reject]}``
and the bridge must forward each subagent's slice intact — including all
metadata, in original order.
"""
checkpointer = InMemorySaver()
sub_a = _build_capturing_subagent(checkpointer, action_count=2)
sub_b = _build_capturing_subagent(checkpointer, action_count=3)
task_tool = build_task_tool_with_parent_config(
[
{"name": "agent-a", "description": "first", "runnable": sub_a},
{"name": "agent-b", "description": "second", "runnable": sub_b},
]
)
parent = _parent_dispatching_two_subagents(
task_tool,
dispatches=[
{"tcid": "tcid-A", "subtype": "agent-a", "desc": "do A"},
{"tcid": "tcid-B", "subtype": "agent-b", "desc": "do B"},
],
checkpointer=checkpointer,
)
config: dict = {
"configurable": {"thread_id": "het-decisions-thread"},
"recursion_limit": 100,
}
await parent.ainvoke({"messages": [HumanMessage(content="seed")]}, config)
paused_state = await parent.aget_state(config)
assert len(paused_state.interrupts) == 2, (
f"fixture broken: expected 2 paused subagents, got {len(paused_state.interrupts)}"
)
pending = collect_pending_tool_calls(paused_state)
pending_by_tcid = dict(pending)
assert pending_by_tcid == {"tcid-A": 2, "tcid-B": 3}, (
f"REGRESSION: action-count accounting wrong; got {pending_by_tcid!r}"
)
a_approve = {"type": "approve"}
a_reject = {"type": "reject", "message": "A[1] looks redundant"}
b_edit = {
"type": "edit",
"edited_action": {"name": "act_0", "args": {"i": 0, "edited": True}},
}
b_approve = {"type": "approve"}
b_reject = {"type": "reject", "message": "B[2] needs more context"}
flat_decisions = [a_approve, a_reject, b_edit, b_approve, b_reject]
by_tool_call_id = slice_decisions_by_tool_call(flat_decisions, pending)
assert by_tool_call_id == {
"tcid-A": {"decisions": [a_approve, a_reject]},
"tcid-B": {"decisions": [b_edit, b_approve, b_reject]},
}, f"REGRESSION: slicer mis-routed decisions: {by_tool_call_id!r}"
config["configurable"]["surfsense_resume_value"] = by_tool_call_id
lg_resume_map = build_lg_resume_map(paused_state, by_tool_call_id)
await parent.ainvoke(Command(resume=lg_resume_map), config)
final_state = await parent.aget_state(config)
assert not final_state.interrupts, (
f"REGRESSION: leftover pending interrupts after resume: {final_state.interrupts!r}"
)
payloads: list[dict] = []
for msg in final_state.values.get("messages", []) or []:
content = getattr(msg, "content", None)
if isinstance(content, str):
with contextlib.suppress(json.JSONDecodeError):
payloads.append(json.loads(content))
expected_a = {"decisions": [a_approve, a_reject]}
expected_b = {"decisions": [b_edit, b_approve, b_reject]}
assert expected_a in payloads, (
f"REGRESSION: sub-A did not receive its 2-decision slice in original order; "
f"payloads seen: {payloads!r}"
)
assert expected_b in payloads, (
f"REGRESSION: sub-B did not receive its 3-decision slice in original order; "
f"payloads seen: {payloads!r}"
)
@pytest.mark.asyncio
async def test_decision_count_mismatch_fails_loud_before_dispatch():
"""The slicer must refuse a flat list whose total != sum(action_counts).
Otherwise a frontend/backend contract drift would silently send a
truncated/padded slice to one of the subagents the worst possible
failure mode (mis-applied reject on a long-lived ticket).
"""
pending = [("tcid-A", 1), ("tcid-B", 2)]
decisions = [{"type": "approve"}, {"type": "approve"}]
with pytest.raises(ValueError, match="Decision count mismatch"):
slice_decisions_by_tool_call(decisions, pending)

View file

@ -0,0 +1,253 @@
"""Real-graph contract: one parallel branch completes while a sibling pauses with HITL.
The two existing parallel-routing tests
(``test_parallel_resume_command_keying`` and
``test_parallel_heterogeneous_decisions``) both pause **all** branches
simultaneously. That's the easy case — every dispatched ``task`` call has a
matching pending interrupt, and the routing helpers see a uniform shape.
Production rarely matches that uniform shape. The orchestrator typically
delegates "create a Linear ticket and summarize the user's recent activity":
one branch needs HITL, the other returns its result and exits. At the pause
moment::
state.values["messages"] += [ToolMessage(from-A)] # A merged in
state.interrupts = [Interrupt(value-from-B)] # B alone is pending
So ``len(state.interrupts) < num_dispatched_tasks``. The slicer and
``build_lg_resume_map`` must:
1. **Key off ``state.interrupts``, never off the originally dispatched tcids.**
A flat decisions list of length 1 must route only to B; if anything tries
to look up A in the resume map, langgraph rejects an unknown
``Interrupt.id``.
2. **Leave A's contributions intact across resume.** A's ToolMessage was
committed at the pause; resuming the paused branch must not re-run A nor
drop its message.
3. **Drain the single pending interrupt.** Final ``state.interrupts`` is
empty regardless of whether sibling branches were paused.
The langgraph semantics this test relies on were verified empirically in the
exploratory probe before this test was authored.
"""
from __future__ import annotations
import contextlib
import json
from typing import Annotated
import pytest
from langchain.tools import ToolRuntime
from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.runnables import RunnableConfig
from langgraph.checkpoint.memory import InMemorySaver
from langgraph.graph import END, START, StateGraph
from langgraph.graph.message import add_messages
from langgraph.types import Command, Send, interrupt
from typing_extensions import TypedDict
from app.agents.multi_agent_chat.middleware.main_agent.checkpointed_subagent_middleware.resume_routing import (
build_lg_resume_map,
collect_pending_tool_calls,
slice_decisions_by_tool_call,
)
from app.agents.multi_agent_chat.middleware.main_agent.checkpointed_subagent_middleware.task_tool import (
build_task_tool_with_parent_config,
)
class _SubState(TypedDict, total=False):
messages: Annotated[list, add_messages]
class _DispatchState(TypedDict, total=False):
messages: Annotated[list, add_messages]
tcid: str
desc: str
subtype: str
_QUICK_MARKER = "quick-subagent-finished-without-pausing"
def _build_quick_subagent(checkpointer: InMemorySaver):
"""Subagent that completes synchronously without firing any interrupt."""
def quick_node(_state):
return {"messages": [AIMessage(content=_QUICK_MARKER)]}
g = StateGraph(_SubState)
g.add_node("quick", quick_node)
g.add_edge(START, "quick")
g.add_edge("quick", END)
return g.compile(checkpointer=checkpointer)
def _build_pausing_subagent(checkpointer: InMemorySaver):
"""Subagent that pauses with a single-action HITL bundle and records its resume payload."""
def hitl_node(_state):
decision = interrupt(
{
"action_requests": [
{"name": "act_0", "args": {"i": 0}, "description": ""}
],
"review_configs": [
{
"action_name": "act_0",
"allowed_decisions": ["approve", "reject", "edit"],
}
],
}
)
return {"messages": [AIMessage(content=json.dumps(decision, sort_keys=True))]}
g = StateGraph(_SubState)
g.add_node("hitl", hitl_node)
g.add_edge(START, "hitl")
g.add_edge("hitl", END)
return g.compile(checkpointer=checkpointer)
def _parent_with_two_branches(task_tool, *, dispatches, checkpointer):
def fanout(_state) -> list[Send]:
return [Send("call_task", d) for d in dispatches]
async def call_task(state: _DispatchState, config: RunnableConfig):
rt = ToolRuntime(
state=state,
config=config,
context=None,
stream_writer=None,
tool_call_id=state["tcid"],
store=None,
)
return await task_tool.coroutine(
description=state["desc"], subagent_type=state["subtype"], runtime=rt
)
g = StateGraph(_DispatchState)
g.add_node("call_task", call_task)
g.add_conditional_edges(START, fanout, ["call_task"])
g.add_edge("call_task", END)
return g.compile(checkpointer=checkpointer)
def _quick_marker_count(state) -> int:
"""How many messages anywhere in parent state contain the quick subagent's marker."""
n = 0
for msg in state.values.get("messages", []) or []:
content = getattr(msg, "content", "")
if isinstance(content, str) and _QUICK_MARKER in content:
n += 1
return n
@pytest.mark.asyncio
async def test_partial_pause_routes_only_to_paused_branch_without_rerunning_completed_one():
"""One branch completes synchronously; the other pauses with HITL — resume routes only to B.
Setup:
- Sub-A (``quick``): no interrupt, finishes immediately, writes a marker
message to parent state.
- Sub-B (``pausing``): interrupts with a 1-action HITL bundle.
At pause, parent state has A's marker already merged in and exactly one
pending interrupt (B's). Resume sends a 1-element flat decisions list;
the routing helpers must not look up A in the resume map (would explode
with an unknown ``Interrupt.id``) and must not re-invoke A on resume
(would duplicate the marker).
"""
checkpointer = InMemorySaver()
quick_sub = _build_quick_subagent(checkpointer)
pausing_sub = _build_pausing_subagent(checkpointer)
task_tool = build_task_tool_with_parent_config(
[
{"name": "quick-agent", "description": "instant", "runnable": quick_sub},
{
"name": "pausing-agent",
"description": "needs review",
"runnable": pausing_sub,
},
]
)
parent = _parent_with_two_branches(
task_tool,
dispatches=[
{"tcid": "tcid-A", "subtype": "quick-agent", "desc": "do A fast"},
{
"tcid": "tcid-B",
"subtype": "pausing-agent",
"desc": "needs approval",
},
],
checkpointer=checkpointer,
)
config: dict = {
"configurable": {"thread_id": "partial-pause-thread"},
"recursion_limit": 100,
}
await parent.ainvoke({"messages": [HumanMessage(content="seed")]}, config)
paused = await parent.aget_state(config)
assert len(paused.interrupts) == 1, (
f"REGRESSION: expected exactly 1 pending interrupt (sub-B alone), "
f"got {len(paused.interrupts)}"
)
pending = collect_pending_tool_calls(paused)
assert pending == [("tcid-B", 1)], (
f"REGRESSION: pending list contains stale tcids; got {pending!r}"
)
pre_resume_marker_count = _quick_marker_count(paused)
assert pre_resume_marker_count == 1, (
f"REGRESSION: sub-A's contribution missing or duplicated at pause "
f"(found {pre_resume_marker_count}, expected 1)"
)
flat_decisions = [{"type": "approve"}]
by_tool_call_id = slice_decisions_by_tool_call(flat_decisions, pending)
assert by_tool_call_id == {"tcid-B": {"decisions": [{"type": "approve"}]}}, (
f"REGRESSION: slicer routed to a non-pending tcid: {by_tool_call_id!r}"
)
config["configurable"]["surfsense_resume_value"] = by_tool_call_id
lg_resume_map = build_lg_resume_map(paused, by_tool_call_id)
assert set(lg_resume_map.keys()) == {paused.interrupts[0].id}, (
f"REGRESSION: resume map keyed by an unknown Interrupt.id "
f"(would crash langgraph): {lg_resume_map!r}"
)
await parent.ainvoke(Command(resume=lg_resume_map), config)
final = await parent.aget_state(config)
assert not final.interrupts, (
f"REGRESSION: pending interrupts after resume: {final.interrupts!r}"
)
post_resume_marker_count = _quick_marker_count(final)
assert post_resume_marker_count == 1, (
f"REGRESSION: sub-A re-ran on resume (marker count went "
f"{pre_resume_marker_count}{post_resume_marker_count}); "
f"resume must touch only the paused branch."
)
payloads: list[dict] = []
for msg in final.values.get("messages", []) or []:
content = getattr(msg, "content", None)
if isinstance(content, str):
with contextlib.suppress(json.JSONDecodeError):
payloads.append(json.loads(content))
assert {"decisions": [{"type": "approve"}]} in payloads, (
f"REGRESSION: sub-B did not receive its single approve on resume; "
f"payloads seen: {payloads!r}"
)

View file

@ -0,0 +1,215 @@
"""Real-graph contract: all-reject decisions route correctly across parallel subagents.
Heterogeneous routing is covered by ``test_parallel_heterogeneous_decisions``.
This module pins the narrower edge case where **every** card on **every**
paused subagent is rejected.
Why a separate pin:
1. **No approval-bias in the slicer.** A future "if no approvals, short-circuit
resume" optimization would be tempting (skips a langgraph round-trip) and
would also silently break this scenario. Pin it.
2. **``message`` metadata pass-through across a run of rejects.** The reject
``message`` is the user-visible reason ("looks suspicious", "duplicate",
etc.). Losing it would silently swallow user intent the worst UX
failure mode for HITL. Heterogeneous covers one reject; here we verify a
sequence of rejects survives the slicer + bridge with distinct messages
intact and in order.
3. **All branches complete with no leftover pending.** Even when nothing was
approved, the parent must drain every paused subagent so the SSE stream
can close cleanly. A bug that left one ``Interrupt.id`` un-keyed would
strand the conversation in "pending" forever.
"""
from __future__ import annotations
import contextlib
import json
from typing import Annotated
import pytest
from langchain.tools import ToolRuntime
from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.runnables import RunnableConfig
from langgraph.checkpoint.memory import InMemorySaver
from langgraph.graph import END, START, StateGraph
from langgraph.graph.message import add_messages
from langgraph.types import Command, Send, interrupt
from typing_extensions import TypedDict
from app.agents.multi_agent_chat.middleware.main_agent.checkpointed_subagent_middleware.resume_routing import (
build_lg_resume_map,
collect_pending_tool_calls,
slice_decisions_by_tool_call,
)
from app.agents.multi_agent_chat.middleware.main_agent.checkpointed_subagent_middleware.task_tool import (
build_task_tool_with_parent_config,
)
class _SubState(TypedDict, total=False):
messages: list
class _DispatchState(TypedDict, total=False):
messages: Annotated[list, add_messages]
tcid: str
desc: str
subtype: str
def _build_recording_subagent(checkpointer: InMemorySaver, *, action_count: int):
"""Subagent that pauses with ``action_count`` actions and records its resume payload.
The recorded ``AIMessage`` content is the JSON-serialized payload, so the
test can match each subagent's slice by content.
"""
def hitl_node(_state):
decision_payload = interrupt(
{
"action_requests": [
{"name": f"act_{i}", "args": {"i": i}, "description": ""}
for i in range(action_count)
],
"review_configs": [
{
"action_name": f"act_{i}",
"allowed_decisions": ["approve", "reject", "edit"],
}
for i in range(action_count)
],
}
)
return {
"messages": [
AIMessage(content=json.dumps(decision_payload, sort_keys=True))
]
}
g = StateGraph(_SubState)
g.add_node("hitl", hitl_node)
g.add_edge(START, "hitl")
g.add_edge("hitl", END)
return g.compile(checkpointer=checkpointer)
def _parent_two_branches(task_tool, *, dispatches, checkpointer):
def fanout(_state) -> list[Send]:
return [Send("call_task", d) for d in dispatches]
async def call_task(state: _DispatchState, config: RunnableConfig):
rt = ToolRuntime(
state=state,
config=config,
context=None,
stream_writer=None,
tool_call_id=state["tcid"],
store=None,
)
return await task_tool.coroutine(
description=state["desc"], subagent_type=state["subtype"], runtime=rt
)
g = StateGraph(_DispatchState)
g.add_node("call_task", call_task)
g.add_conditional_edges(START, fanout, ["call_task"])
g.add_edge("call_task", END)
return g.compile(checkpointer=checkpointer)
@pytest.mark.asyncio
async def test_all_reject_decisions_route_to_each_subagent_with_messages_intact():
"""All cards rejected across two parallel subagents — order and messages preserved.
Setup mirrors a real "user reviews two parallel ticket creations and
rejects everything with distinct reasons":
- Sub-A pauses with 2 actions.
- Sub-B pauses with 1 action.
- Flat decisions: 3 rejects, each with a unique ``message``.
Asserts each subagent receives only its slice, in original order,
with every ``message`` intact and no ``edited_action`` fields fabricated.
"""
checkpointer = InMemorySaver()
sub_a = _build_recording_subagent(checkpointer, action_count=2)
sub_b = _build_recording_subagent(checkpointer, action_count=1)
task_tool = build_task_tool_with_parent_config(
[
{"name": "agent-a", "description": "first", "runnable": sub_a},
{"name": "agent-b", "description": "second", "runnable": sub_b},
]
)
parent = _parent_two_branches(
task_tool,
dispatches=[
{"tcid": "tcid-A", "subtype": "agent-a", "desc": "do A"},
{"tcid": "tcid-B", "subtype": "agent-b", "desc": "do B"},
],
checkpointer=checkpointer,
)
config: dict = {
"configurable": {"thread_id": "all-reject-thread"},
"recursion_limit": 100,
}
await parent.ainvoke({"messages": [HumanMessage(content="seed")]}, config)
paused_state = await parent.aget_state(config)
assert len(paused_state.interrupts) == 2, (
f"fixture broken: expected 2 paused subagents, got {len(paused_state.interrupts)}"
)
a_reject_0 = {"type": "reject", "message": "A[0] looks suspicious"}
a_reject_1 = {"type": "reject", "message": "A[1] duplicates A[0]"}
b_reject_0 = {"type": "reject", "message": "B[0] needs more context"}
flat_decisions = [a_reject_0, a_reject_1, b_reject_0]
pending = collect_pending_tool_calls(paused_state)
by_tool_call_id = slice_decisions_by_tool_call(flat_decisions, pending)
assert by_tool_call_id == {
"tcid-A": {"decisions": [a_reject_0, a_reject_1]},
"tcid-B": {"decisions": [b_reject_0]},
}, f"REGRESSION: slicer mis-routed all-reject decisions: {by_tool_call_id!r}"
config["configurable"]["surfsense_resume_value"] = by_tool_call_id
lg_resume_map = build_lg_resume_map(paused_state, by_tool_call_id)
await parent.ainvoke(Command(resume=lg_resume_map), config)
final_state = await parent.aget_state(config)
assert not final_state.interrupts, (
f"REGRESSION: leftover pending interrupts after all-reject resume: "
f"{final_state.interrupts!r}"
)
payloads: list[dict] = []
for msg in final_state.values.get("messages", []) or []:
content = getattr(msg, "content", None)
if isinstance(content, str):
with contextlib.suppress(json.JSONDecodeError):
payloads.append(json.loads(content))
expected_a = {"decisions": [a_reject_0, a_reject_1]}
expected_b = {"decisions": [b_reject_0]}
assert expected_a in payloads, (
f"REGRESSION: sub-A did not receive its 2-reject slice in order; "
f"payloads seen: {payloads!r}"
)
assert expected_b in payloads, (
f"REGRESSION: sub-B did not receive its single reject; "
f"payloads seen: {payloads!r}"
)
for p in payloads:
for d in p.get("decisions", []):
assert "edited_action" not in d, (
f"REGRESSION: spurious ``edited_action`` on a reject — "
f"slicer/bridge mutated payload: {d!r}"
)

View file

@ -0,0 +1,300 @@
"""Real-graph contract: parallel resume must key ``Command(resume=...)`` by ``Interrupt.id``.
When the parent state has multiple pending interrupts, langgraph rejects a
scalar ``Command(resume=v)`` with::
RuntimeError: When there are multiple pending interrupts, you must specify
the interrupt id when resuming.
The fix is to map each ``Interrupt.id`` from ``state.interrupts`` to the
per-subagent slice orthogonal to our ``tool_call_id``-keyed
``surfsense_resume_value`` side-channel (different consumer: langgraph's
pregel vs. our subagent bridge).
This test reproduces the production failure with a real two-task parallel
``Send`` parent graph, exercises the full resume cycle, and asserts both
subagents complete cleanly with their per-subagent slice intact.
Bundle sizes are chosen heterogeneous (``2`` and ``3``) so the assertions
also catch slicer arithmetic regressions (e.g., ``cursor += 1`` instead of
``cursor += action_count``). A symmetric ``(1, 1)`` configuration would
silently pass such a bug because the slices would coincide.
"""
from __future__ import annotations
import contextlib
import json
from typing import Annotated
import pytest
from langchain.tools import ToolRuntime
from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.runnables import RunnableConfig
from langgraph.checkpoint.memory import InMemorySaver
from langgraph.graph import END, START, StateGraph
from langgraph.graph.message import add_messages
from langgraph.types import Command, Send, interrupt
from typing_extensions import TypedDict
from app.agents.multi_agent_chat.middleware.main_agent.checkpointed_subagent_middleware.resume_routing import (
build_lg_resume_map,
collect_pending_tool_calls,
slice_decisions_by_tool_call,
)
from app.agents.multi_agent_chat.middleware.main_agent.checkpointed_subagent_middleware.task_tool import (
build_task_tool_with_parent_config,
)
class _SubState(TypedDict, total=False):
messages: list
class _DispatchState(TypedDict, total=False):
# ``add_messages`` reducer matches production agent state shape and is
# required when two parallel ``Send`` branches both write to ``messages``
# in the same superstep (post-resume both subagents return their own
# ``{"messages": [...]}``). Without a reducer langgraph raises
# ``InvalidUpdateError: At key 'messages': Can receive only one value``.
messages: Annotated[list, add_messages]
tcid: str
desc: str
subtype: str
def _build_pausing_subagent(checkpointer: InMemorySaver, *, action_count: int):
"""Subagent that pauses with an ``action_count``-action HITL bundle.
On resume it captures the decision payload as a JSON-serialized
``AIMessage`` content so the test can inspect exactly which slice
reached this subagent the strongest assertion against slicer
routing regressions.
"""
def approve_node(_state):
decision = interrupt(
{
"action_requests": [
{"name": f"act_{i}", "args": {"i": i}, "description": ""}
for i in range(action_count)
],
"review_configs": [{} for _ in range(action_count)],
}
)
return {"messages": [AIMessage(content=json.dumps(decision, sort_keys=True))]}
g = StateGraph(_SubState)
g.add_node("approve", approve_node)
g.add_edge(START, "approve")
g.add_edge("approve", END)
return g.compile(checkpointer=checkpointer)
def _parent_graph_dispatching_two_tasks_via_send(
task_tool, *, tool_call_id_a: str, tool_call_id_b: str, checkpointer
):
def fanout_edge(_state) -> list[Send]:
return [
Send(
"call_task",
{"tcid": tool_call_id_a, "desc": "approve A", "subtype": "agent-a"},
),
Send(
"call_task",
{"tcid": tool_call_id_b, "desc": "approve B", "subtype": "agent-b"},
),
]
async def call_task(state: _DispatchState, config: RunnableConfig):
rt = ToolRuntime(
state=state,
config=config,
context=None,
stream_writer=None,
tool_call_id=state["tcid"],
store=None,
)
return await task_tool.coroutine(
description=state["desc"], subagent_type=state["subtype"], runtime=rt
)
g = StateGraph(_DispatchState)
g.add_node("call_task", call_task)
g.add_conditional_edges(START, fanout_edge, ["call_task"])
g.add_edge("call_task", END)
return g.compile(checkpointer=checkpointer)
def _build_two_subagents_task_tool(checkpointer: InMemorySaver):
"""Register two subagents under distinct names with heterogeneous bundle sizes.
Sub-A: 2-action bundle. Sub-B: 3-action bundle. Both ``> 1`` so the slice
arithmetic is sensitive to off-by-one mistakes.
"""
sub_a = _build_pausing_subagent(checkpointer, action_count=2)
sub_b = _build_pausing_subagent(checkpointer, action_count=3)
return build_task_tool_with_parent_config(
[
{"name": "agent-a", "description": "first", "runnable": sub_a},
{"name": "agent-b", "description": "second", "runnable": sub_b},
]
)
@pytest.mark.asyncio
async def test_parallel_resume_with_command_resume_scalar_raises_lg_runtime_error():
"""Confirm the production failure mode: scalar resume on multi-pending state explodes.
This is a contract pin: if langgraph relaxes the requirement in a future
release, this test starts passing and we know we can simplify
``stream_resume_chat``. Until then, the keyed form is mandatory.
"""
checkpointer = InMemorySaver()
task_tool = _build_two_subagents_task_tool(checkpointer)
parent = _parent_graph_dispatching_two_tasks_via_send(
task_tool,
tool_call_id_a="parent-tcid-A",
tool_call_id_b="parent-tcid-B",
checkpointer=checkpointer,
)
config: dict = {
"configurable": {"thread_id": "parallel-resume-scalar"},
"recursion_limit": 100,
}
await parent.ainvoke({"messages": [HumanMessage(content="seed")]}, config)
with pytest.raises(RuntimeError, match="multiple pending interrupts"):
await parent.ainvoke(Command(resume={"decisions": ["A"]}), config)
@pytest.mark.asyncio
async def test_parallel_resume_with_per_interrupt_id_keying_completes_both_subagents():
"""Production-shape resume: builds the langgraph-keyed map and resumes both subagents.
Mirrors what ``stream_resume_chat`` does: collects pending interrupts,
slices the flat decisions list by ``tool_call_id``, builds the
``Interrupt.id``-keyed map for ``Command(resume=...)``, and resumes.
Post-conditions checked:
1. The langgraph-keyed map has exactly one entry per pending interrupt
id (``str`` keys, count matches).
2. Both subagents complete with no leftover pending interrupts.
3. **Each subagent receives its exact slice in the original order**
this catches slicer arithmetic regressions (e.g., ``cursor += 1``)
that wouldn't surface by checking only "no leftover pending".
"""
checkpointer = InMemorySaver()
task_tool = _build_two_subagents_task_tool(checkpointer)
tcid_a = "parent-tcid-A"
tcid_b = "parent-tcid-B"
parent = _parent_graph_dispatching_two_tasks_via_send(
task_tool,
tool_call_id_a=tcid_a,
tool_call_id_b=tcid_b,
checkpointer=checkpointer,
)
config: dict = {
"configurable": {"thread_id": "parallel-resume-keyed"},
"recursion_limit": 100,
}
await parent.ainvoke({"messages": [HumanMessage(content="seed")]}, config)
paused_state = await parent.aget_state(config)
assert len(paused_state.interrupts) == 2, (
"fixture broken: expected 2 paused subagents"
)
pending = collect_pending_tool_calls(paused_state)
assert dict(pending) == {tcid_a: 2, tcid_b: 3}, (
f"fixture broken: heterogeneous bundle sizes not detected; got {pending!r}"
)
a_d0 = {"type": "approve"}
a_d1 = {"type": "reject", "message": "A[1] is redundant"}
b_d0 = {
"type": "edit",
"edited_action": {"name": "act_0", "args": {"i": 0, "edited": True}},
}
b_d1 = {"type": "approve"}
b_d2 = {"type": "reject", "message": "B[2] needs more context"}
flat_decisions = [a_d0, a_d1, b_d0, b_d1, b_d2]
by_tool_call_id = slice_decisions_by_tool_call(flat_decisions, pending)
lg_resume_map = build_lg_resume_map(paused_state, by_tool_call_id)
assert len(lg_resume_map) == 2, (
f"expected one entry per pending interrupt id, got {lg_resume_map!r}"
)
assert all(isinstance(k, str) for k in lg_resume_map), (
f"keys must be Interrupt.id strings, got {[type(k).__name__ for k in lg_resume_map]}"
)
config["configurable"]["surfsense_resume_value"] = by_tool_call_id
await parent.ainvoke(Command(resume=lg_resume_map), config)
final_state = await parent.aget_state(config)
assert not final_state.interrupts, (
f"expected no leftover pending interrupts after resume, got "
f"{final_state.interrupts!r}"
)
payloads: list[dict] = []
for msg in final_state.values.get("messages", []) or []:
content = getattr(msg, "content", None)
if isinstance(content, str):
with contextlib.suppress(json.JSONDecodeError):
payloads.append(json.loads(content))
expected_a = {"decisions": [a_d0, a_d1]}
expected_b = {"decisions": [b_d0, b_d1, b_d2]}
assert expected_a in payloads, (
f"REGRESSION: sub-A did not receive its 2-decision slice in order; "
f"payloads seen: {payloads!r}"
)
assert expected_b in payloads, (
f"REGRESSION: sub-B did not receive its 3-decision slice in order; "
f"payloads seen: {payloads!r}"
)
def test_build_lg_resume_map_returns_empty_when_no_interrupts_carry_stamps():
"""Unstamped interrupts can't be routed; we don't fabricate keys for them.
If a regression lets an unstamped interrupt reach the parent state, the
empty map propagates to the call site and surfaces as a clear count
mismatch instead of a silent mis-route.
"""
from types import SimpleNamespace
fake_interrupt = SimpleNamespace(id="i-foreign", value={"action_requests": [{}]})
state = SimpleNamespace(interrupts=(fake_interrupt,))
assert build_lg_resume_map(state, {"some-tcid": {"decisions": ["x"]}}) == {}
def test_build_lg_resume_map_skips_interrupts_without_corresponding_slice():
"""Skip rather than silently mis-route when the slice and interrupts disagree.
Only emit a resume entry when both an interrupt id and a tool_call_id
slice are present; a mismatch indicates upstream contract drift and
should not be papered over.
"""
from types import SimpleNamespace
state = SimpleNamespace(
interrupts=(
SimpleNamespace(
id="i-A",
value={"action_requests": [{}], "tool_call_id": "tcid-A"},
),
SimpleNamespace(
id="i-B",
value={"action_requests": [{}], "tool_call_id": "tcid-B"},
),
)
)
out = build_lg_resume_map(state, {"tcid-A": {"decisions": ["only-A"]}})
assert out == {"i-A": {"decisions": ["only-A"]}}

View file

@ -0,0 +1,275 @@
"""Real-graph parallel HITL across both approval kinds — the keystone regression.
Pre-fix bug: the parallel-HITL routing layer (``collect_pending_tool_calls``
+ ``slice_decisions_by_tool_call`` + ``build_lg_resume_map``) only
recognized middleware-gated approvals (LC HITL shape from
``HumanInTheLoopMiddleware``). Self-gated approvals from
``request_approval`` and middleware-gated permission asks from
``PermissionMiddleware`` both used the SurfSense-specific
``{type, action, context}`` shape, so when the orchestrator dispatched
two parallel ``task`` calls one self-gated, one middleware-gated only
one interrupt was visible to the routing layer and resume crashed with
``Decision count mismatch``.
This test fans out two real subagents via ``Send``: one calls
``request_approval`` (self-gated), the other calls
``request_permission_decision`` (middleware-gated). Both pause; the routing
layer must see TWO LC HITL interrupts, slice the decisions by
``tool_call_id``, key by ``Interrupt.id``, and resume both branches with
their per-slice payload.
"""
from __future__ import annotations
import contextlib
import json
from typing import Annotated
import pytest
from langchain.tools import ToolRuntime
from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.runnables import RunnableConfig
from langgraph.checkpoint.memory import InMemorySaver
from langgraph.graph import END, START, StateGraph
from langgraph.graph.message import add_messages
from langgraph.types import Command, Send
from typing_extensions import TypedDict
from app.agents.multi_agent_chat.middleware.main_agent.checkpointed_subagent_middleware.resume_routing import (
build_lg_resume_map,
collect_pending_tool_calls,
slice_decisions_by_tool_call,
)
from app.agents.multi_agent_chat.middleware.main_agent.checkpointed_subagent_middleware.task_tool import (
build_task_tool_with_parent_config,
)
from app.agents.multi_agent_chat.middleware.shared.permissions.ask.request import (
request_permission_decision,
)
from app.agents.multi_agent_chat.subagents.shared.hitl.approvals.self_gated import (
request_approval,
)
from app.agents.new_chat.permissions import Rule
class _SubState(TypedDict, total=False):
messages: list
class _DispatchState(TypedDict, total=False):
# ``add_messages`` is mandatory: parallel ``Send`` branches both append
# to ``messages`` in the same superstep; without a reducer langgraph
# raises ``InvalidUpdateError``.
messages: Annotated[list, add_messages]
tcid: str
desc: str
subtype: str
def _build_self_gated_subagent(checkpointer: InMemorySaver):
"""Subagent that pauses via ``request_approval`` (self-gated path)."""
def gate_node(_state):
result = request_approval(
action_type="gmail_email_send",
tool_name="send_gmail_email",
params={"to": "alice@example.com"},
)
return {
"messages": [
AIMessage(
content=json.dumps(
{
"kind": "self_gated",
"decision_type": result.decision_type,
"params": result.params,
"rejected": result.rejected,
},
sort_keys=True,
)
)
]
}
g = StateGraph(_SubState)
g.add_node("gate", gate_node)
g.add_edge(START, "gate")
g.add_edge("gate", END)
return g.compile(checkpointer=checkpointer)
def _build_middleware_gated_subagent(checkpointer: InMemorySaver):
"""Subagent that pauses via ``request_permission_decision`` (middleware-gated path)."""
def perm_node(_state):
decision = request_permission_decision(
tool_name="rm",
args={"path": "/tmp/file"},
patterns=["rm/*"],
rules=[Rule(permission="rm", pattern="*", action="ask")],
emit_interrupt=True,
)
return {
"messages": [
AIMessage(
content=json.dumps(
{"kind": "middleware_gated", "decision": decision},
sort_keys=True,
)
)
]
}
g = StateGraph(_SubState)
g.add_node("perm", perm_node)
g.add_edge(START, "perm")
g.add_edge("perm", END)
return g.compile(checkpointer=checkpointer)
def _build_mixed_task_tool(checkpointer: InMemorySaver):
"""Two subagents, one per approval kind, registered under distinct names."""
return build_task_tool_with_parent_config(
[
{
"name": "self-gated-agent",
"description": "uses request_approval",
"runnable": _build_self_gated_subagent(checkpointer),
},
{
"name": "middleware-gated-agent",
"description": "uses request_permission_decision",
"runnable": _build_middleware_gated_subagent(checkpointer),
},
]
)
def _parent_dispatching_one_of_each(
task_tool, *, tcid_self: str, tcid_mw: str, checkpointer
):
def fanout_edge(_state) -> list[Send]:
return [
Send(
"call_task",
{
"tcid": tcid_self,
"desc": "approve email",
"subtype": "self-gated-agent",
},
),
Send(
"call_task",
{
"tcid": tcid_mw,
"desc": "approve rm",
"subtype": "middleware-gated-agent",
},
),
]
async def call_task(state: _DispatchState, config: RunnableConfig):
rt = ToolRuntime(
state=state,
config=config,
context=None,
stream_writer=None,
tool_call_id=state["tcid"],
store=None,
)
return await task_tool.coroutine(
description=state["desc"], subagent_type=state["subtype"], runtime=rt
)
g = StateGraph(_DispatchState)
g.add_node("call_task", call_task)
g.add_conditional_edges(START, fanout_edge, ["call_task"])
g.add_edge("call_task", END)
return g.compile(checkpointer=checkpointer)
@pytest.mark.asyncio
async def test_parallel_self_gated_and_middleware_gated_route_and_resume_cleanly():
"""Both interrupt kinds must reach the routing layer in LC HITL shape and resume independently."""
checkpointer = InMemorySaver()
task_tool = _build_mixed_task_tool(checkpointer)
tcid_self = "tcid-self-gated"
tcid_mw = "tcid-middleware-gated"
parent = _parent_dispatching_one_of_each(
task_tool,
tcid_self=tcid_self,
tcid_mw=tcid_mw,
checkpointer=checkpointer,
)
config: dict = {
"configurable": {"thread_id": "mixed-parallel"},
"recursion_limit": 100,
}
await parent.ainvoke({"messages": [HumanMessage(content="seed")]}, config)
paused = await parent.aget_state(config)
assert len(paused.interrupts) == 2, (
"fixture broken: expected one paused interrupt per approval kind"
)
# Both interrupts must speak the same wire shape — the whole point of
# the unification. If either one regresses to the legacy SurfSense shape
# ``collect_pending_tool_calls`` would silently skip it and the count
# below would be 1.
pending = collect_pending_tool_calls(paused)
assert dict(pending) == {tcid_self: 1, tcid_mw: 1}, (
f"REGRESSION: not all interrupt kinds reached the routing layer; "
f"got {pending!r}"
)
# Verify the actual wire payloads carry the LC HITL standard fields
# (extra defensive assertion against partial regressions where one
# path stamps tool_call_id but reverts the body shape).
interrupt_types = {i.value.get("interrupt_type") for i in paused.interrupts}
assert interrupt_types == {"gmail_email_send", "permission_ask"}
# Resume order: same order the SSE stream would emit (interrupts list).
decision_self = {"type": "approve"}
decision_mw = {"type": "approve_always"}
flat_decisions = [
# Match `pending` order.
decision_self if pending[0][0] == tcid_self else decision_mw,
decision_mw if pending[0][0] == tcid_self else decision_self,
]
by_tool_call_id = slice_decisions_by_tool_call(flat_decisions, pending)
lg_resume_map = build_lg_resume_map(paused, by_tool_call_id)
assert len(lg_resume_map) == 2
config["configurable"]["surfsense_resume_value"] = by_tool_call_id
await parent.ainvoke(Command(resume=lg_resume_map), config)
final = await parent.aget_state(config)
assert not final.interrupts, (
f"expected both branches resumed, but state still has interrupts: "
f"{final.interrupts!r}"
)
# Each subagent must have received its own slice — verify by inspecting
# the JSON-serialized result messages.
payloads: list[dict] = []
for msg in final.values.get("messages", []) or []:
content = getattr(msg, "content", None)
if isinstance(content, str):
with contextlib.suppress(json.JSONDecodeError):
payloads.append(json.loads(content))
self_payloads = [p for p in payloads if p.get("kind") == "self_gated"]
mw_payloads = [p for p in payloads if p.get("kind") == "middleware_gated"]
assert len(self_payloads) == 1, (
f"self-gated subagent did not complete; payloads: {payloads!r}"
)
assert len(mw_payloads) == 1, (
f"middleware-gated subagent did not complete; payloads: {payloads!r}"
)
# Self-gated approve → HITLResult(decision_type="approve", rejected=False).
assert self_payloads[0]["decision_type"] == "approve"
assert self_payloads[0]["rejected"] is False
# Middleware-gated approve_always → canonical permission shape unchanged.
assert mw_payloads[0]["decision"] == {"decision_type": "approve_always"}

View file

@ -0,0 +1,237 @@
"""Behavioural guarantees for parallel ``task`` tool calls (non-HITL cases).
The HITL bridge tests in ``test_hitl_bridge.py`` cover the parallel-interrupt
flow. This file covers the *normal* parallel paths (no interrupts) and the
failure-isolation guarantee together they pin the behaviour we promise the
user about ``asyncio.gather`` over two ``atask`` coroutines.
"""
from __future__ import annotations
import asyncio
import pytest
from langchain.tools import ToolRuntime
from langchain_core.messages import AIMessage, HumanMessage, ToolMessage
from langgraph.checkpoint.memory import InMemorySaver
from langgraph.graph import END, START, StateGraph
from langgraph.types import Command
from typing_extensions import TypedDict
from app.agents.multi_agent_chat.middleware.main_agent.checkpointed_subagent_middleware.task_tool import (
build_task_tool_with_parent_config,
)
class _SubState(TypedDict, total=False):
messages: list
def _build_success_subagent(reply: str):
"""A subagent that completes immediately with ``reply``, never interrupts."""
def node(_state):
return {"messages": [AIMessage(content=reply)]}
g = StateGraph(_SubState)
g.add_node("only", node)
g.add_edge(START, "only")
g.add_edge("only", END)
return g.compile(checkpointer=InMemorySaver())
def _build_failing_subagent(exc: Exception):
"""A subagent whose only node raises ``exc`` — simulates a tool-level failure."""
def node(_state):
raise exc
g = StateGraph(_SubState)
g.add_node("only", node)
g.add_edge(START, "only")
g.add_edge("only", END)
return g.compile(checkpointer=InMemorySaver())
def _make_runtime(parent_config: dict, *, tool_call_id: str) -> ToolRuntime:
return ToolRuntime(
state={"messages": [HumanMessage(content="seed")]},
context=None,
config=parent_config,
stream_writer=None,
tool_call_id=tool_call_id,
store=None,
)
def _tool_message_text(cmd: Command, *, expected_tcid: str) -> str:
"""Return the ToolMessage content the task tool produced for ``expected_tcid``."""
assert isinstance(cmd, Command), f"expected Command, got {type(cmd).__name__}"
messages = cmd.update["messages"]
assert len(messages) == 1, f"expected 1 ToolMessage, got {len(messages)}"
msg = messages[0]
assert isinstance(msg, ToolMessage)
assert msg.tool_call_id == expected_tcid
return msg.content
@pytest.mark.asyncio
async def test_two_parallel_atasks_to_different_subagents_both_succeed():
"""Normal happy-path: two distinct subagents complete in parallel without interrupting."""
subagent_a = _build_success_subagent("A is done")
subagent_b = _build_success_subagent("B is done")
task_tool = build_task_tool_with_parent_config(
[
{"name": "alpha", "description": "alpha agent", "runnable": subagent_a},
{"name": "beta", "description": "beta agent", "runnable": subagent_b},
]
)
parent_config: dict = {
"configurable": {"thread_id": "ok-thread"},
"recursion_limit": 100,
}
runtime_a = _make_runtime(parent_config, tool_call_id="tcid-A")
runtime_b = _make_runtime(parent_config, tool_call_id="tcid-B")
result_a, result_b = await asyncio.gather(
task_tool.coroutine(
description="do A",
subagent_type="alpha",
runtime=runtime_a,
),
task_tool.coroutine(
description="do B",
subagent_type="beta",
runtime=runtime_b,
),
)
assert _tool_message_text(result_a, expected_tcid="tcid-A") == "A is done"
assert _tool_message_text(result_b, expected_tcid="tcid-B") == "B is done"
@pytest.mark.asyncio
async def test_two_parallel_atasks_same_subagent_type_different_tool_call_ids():
"""Per-call ``thread_id`` isolation: same compiled subagent invoked twice in parallel.
Both calls share the same ``InMemorySaver`` instance but are namespaced by
distinct ``tool_call_id``s, so checkpoints land in disjoint thread slots.
"""
shared_subagent = _build_success_subagent("ok")
task_tool = build_task_tool_with_parent_config(
[
{
"name": "approver",
"description": "shared approver",
"runnable": shared_subagent,
},
]
)
parent_config: dict = {
"configurable": {"thread_id": "shared-subagent-thread"},
"recursion_limit": 100,
}
runtime_a = _make_runtime(parent_config, tool_call_id="tcid-A")
runtime_b = _make_runtime(parent_config, tool_call_id="tcid-B")
result_a, result_b = await asyncio.gather(
task_tool.coroutine(
description="first request",
subagent_type="approver",
runtime=runtime_a,
),
task_tool.coroutine(
description="second request",
subagent_type="approver",
runtime=runtime_b,
),
)
# Both calls succeed and produce ToolMessages keyed by their own tool_call_id.
assert _tool_message_text(result_a, expected_tcid="tcid-A") == "ok"
assert _tool_message_text(result_b, expected_tcid="tcid-B") == "ok"
# Verify checkpoint isolation: each call's state lives at its own thread_id.
state_a = await shared_subagent.aget_state(
{"configurable": {"thread_id": "shared-subagent-thread::task:tcid-A"}}
)
state_b = await shared_subagent.aget_state(
{"configurable": {"thread_id": "shared-subagent-thread::task:tcid-B"}}
)
assert state_a.values["messages"][-1].content == "ok"
assert state_b.values["messages"][-1].content == "ok"
# The parent's own thread_id slot is untouched by either subagent.
state_parent = await shared_subagent.aget_state(
{"configurable": {"thread_id": "shared-subagent-thread"}}
)
assert state_parent.values == {} or state_parent.values.get("messages") in (
None,
[],
)
@pytest.mark.asyncio
async def test_one_atask_failure_does_not_corrupt_sibling_atask():
"""Failure isolation: a sibling's exception must not poison the surviving atask's state.
Note: in production, langgraph's pregel runner cancels siblings when any
parallel task raises a non-``GraphBubbleUp`` exception (see
``_should_stop_others`` in ``langgraph/pregel/_runner.py``). At our layer
that policy is invisible what we *can* guarantee is that the two atask
coroutines have disjoint state, so the surviving one returns a valid
Command even when its sibling explodes.
"""
failing_subagent = _build_failing_subagent(ValueError("boom"))
surviving_subagent = _build_success_subagent("still here")
task_tool = build_task_tool_with_parent_config(
[
{
"name": "broken",
"description": "always fails",
"runnable": failing_subagent,
},
{
"name": "healthy",
"description": "always succeeds",
"runnable": surviving_subagent,
},
]
)
parent_config: dict = {
"configurable": {"thread_id": "iso-thread"},
"recursion_limit": 100,
}
runtime_fail = _make_runtime(parent_config, tool_call_id="tcid-fail")
runtime_ok = _make_runtime(parent_config, tool_call_id="tcid-ok")
results = await asyncio.gather(
task_tool.coroutine(
description="will explode",
subagent_type="broken",
runtime=runtime_fail,
),
task_tool.coroutine(
description="will work",
subagent_type="healthy",
runtime=runtime_ok,
),
return_exceptions=True,
)
fail_result, ok_result = results
assert isinstance(fail_result, Exception), (
f"expected the broken subagent to raise, got {fail_result!r}"
)
# ValueError gets wrapped in langgraph's internal exception types — the
# important guarantee is "this path errored", not the specific class.
assert "boom" in str(fail_result) or isinstance(fail_result, ValueError)
assert _tool_message_text(ok_result, expected_tcid="tcid-ok") == "still here"
# Configurable side-channel must not have been corrupted by the failure.
assert "surfsense_resume_value" not in parent_config["configurable"]

View file

@ -0,0 +1,154 @@
"""Slicing helper that routes a flat decisions list to per-tool-call payloads.
The frontend submits ``decisions: list[ResumeDecision]`` in the same order the
SSE stream emitted approval cards. When multiple parallel subagents are paused,
the backend slices that flat list into per-``tool_call_id`` payloads so each
``atask`` reads only its own decisions through ``consume_surfsense_resume``.
The extractor reads ``state.interrupts[i].value["tool_call_id"]`` which is
populated by ``propagation.wrap_with_tool_call_id`` inside ``task_tool``'s
``except GraphInterrupt`` chokepoint whenever a subagent interrupt bubbles up
through ``[a]task`` to build the ordered ``pending`` list the slicer needs.
"""
from __future__ import annotations
from types import SimpleNamespace
import pytest
from app.agents.multi_agent_chat.middleware.main_agent.checkpointed_subagent_middleware.resume_routing import (
collect_pending_tool_calls,
slice_decisions_by_tool_call,
)
class TestSliceDecisionsByToolCall:
def test_splits_flat_decisions_across_two_pending_tool_calls(self):
decisions = [
{"type": "approve"},
{"type": "edit", "edited_action": {"name": "edited-b1"}},
{"type": "reject"},
{"type": "approve"},
{"type": "approve"},
]
pending = [
("tcid-A", 3),
("tcid-B", 2),
]
routed = slice_decisions_by_tool_call(decisions, pending)
assert routed == {
"tcid-A": {"decisions": decisions[0:3]},
"tcid-B": {"decisions": decisions[3:5]},
}
def test_raises_when_decision_count_less_than_total_actions(self):
decisions = [{"type": "approve"}, {"type": "approve"}]
pending = [("tcid-A", 3), ("tcid-B", 2)]
with pytest.raises(ValueError, match=r"5 actions.*2 decisions"):
slice_decisions_by_tool_call(decisions, pending)
def test_raises_when_decision_count_greater_than_total_actions(self):
decisions = [{"type": "approve"}] * 6
pending = [("tcid-A", 3), ("tcid-B", 2)]
with pytest.raises(ValueError, match=r"5 actions.*6 decisions"):
slice_decisions_by_tool_call(decisions, pending)
def test_handles_single_pending_tool_call(self):
decisions = [{"type": "approve"}, {"type": "reject"}]
pending = [("tcid-only", 2)]
routed = slice_decisions_by_tool_call(decisions, pending)
assert routed == {"tcid-only": {"decisions": decisions}}
def test_returns_empty_dict_for_no_pending(self):
routed = slice_decisions_by_tool_call([], [])
assert routed == {}
def _interrupt_with(tool_call_id: str, action_count: int):
return SimpleNamespace(
id=f"i-{tool_call_id}",
value={
"action_requests": [{"name": "n", "args": {}}] * action_count,
"review_configs": [{}] * action_count,
"tool_call_id": tool_call_id,
},
)
class TestCollectPendingToolCalls:
def test_single_pending_returns_one_pair(self):
state = SimpleNamespace(interrupts=(_interrupt_with("tcid-only", 3),))
assert collect_pending_tool_calls(state) == [("tcid-only", 3)]
def test_multiple_pending_preserves_state_order(self):
"""Order must match what the SSE stream emitted (= state.interrupts order)."""
state = SimpleNamespace(
interrupts=(
_interrupt_with("tcid-A", 2),
_interrupt_with("tcid-B", 3),
)
)
assert collect_pending_tool_calls(state) == [("tcid-A", 2), ("tcid-B", 3)]
def test_empty_when_no_interrupts(self):
state = SimpleNamespace(interrupts=())
assert collect_pending_tool_calls(state) == []
def test_skips_interrupts_without_tool_call_id(self):
"""Defensive: interrupts not produced by our propagation layer are ignored.
``stream_resume_chat`` only owns the ``task``-routing slice; non-task
interrupts (e.g. parent-side HITL middleware on a different tool) are
not the slicer's responsibility.
"""
state = SimpleNamespace(
interrupts=(
_interrupt_with("tcid-A", 2),
SimpleNamespace(id="i-foreign", value={"action_requests": [{}]}),
_interrupt_with("tcid-B", 1),
)
)
assert collect_pending_tool_calls(state) == [("tcid-A", 2), ("tcid-B", 1)]
def test_handles_scalar_value_interrupt(self):
"""Subagents using ``interrupt("approve?")`` style propagate as ``{"value": ..., "tool_call_id": ...}``.
These have no ``action_requests`` count them as a single action so
the frontend submits exactly one decision per such interrupt.
"""
state = SimpleNamespace(
interrupts=(
SimpleNamespace(
id="i-A",
value={"value": "approve?", "tool_call_id": "tcid-A"},
),
)
)
assert collect_pending_tool_calls(state) == [("tcid-A", 1)]
def test_raises_when_interrupt_value_missing_action_count_keys(self):
"""An interrupt with ``tool_call_id`` but no usable count signals a contract bug."""
state = SimpleNamespace(
interrupts=(
SimpleNamespace(
id="i-A",
value={"tool_call_id": "tcid-A", "weird_shape": True},
),
)
)
with pytest.raises(ValueError, match="action_requests"):
collect_pending_tool_calls(state)

View file

@ -1,4 +1,4 @@
"""Resume side-channel must be read exactly once per turn."""
"""Resume side-channel is keyed per ``tool_call_id`` so parallel siblings can resume independently."""
from __future__ import annotations
@ -10,33 +10,61 @@ from app.agents.multi_agent_chat.middleware.main_agent.checkpointed_subagent_mid
)
def _runtime_with_config(config: dict) -> ToolRuntime:
def _runtime_with_config(
config: dict, *, tool_call_id: str = "tcid-test"
) -> ToolRuntime:
return ToolRuntime(
state=None,
context=None,
config=config,
stream_writer=None,
tool_call_id="tcid-test",
tool_call_id=tool_call_id,
store=None,
)
class TestConsumeSurfsenseResume:
def test_pops_value_on_first_call(self):
def test_pops_only_entry_matching_runtime_tool_call_id(self):
configurable = {
"surfsense_resume_value": {
"tcid-A": {"decisions": ["approve"]},
"tcid-B": {"decisions": ["reject"]},
}
}
runtime = _runtime_with_config(
{"configurable": {"surfsense_resume_value": {"decisions": ["approve"]}}}
{"configurable": configurable}, tool_call_id="tcid-A"
)
assert consume_surfsense_resume(runtime) == {"decisions": ["approve"]}
def test_second_call_returns_none(self):
configurable: dict = {"surfsense_resume_value": {"decisions": ["approve"]}}
runtime = _runtime_with_config({"configurable": configurable})
def test_popping_one_entry_leaves_siblings_untouched(self):
configurable = {
"surfsense_resume_value": {
"tcid-A": {"decisions": ["approve"]},
"tcid-B": {"decisions": ["reject"]},
}
}
runtime_a = _runtime_with_config(
{"configurable": configurable}, tool_call_id="tcid-A"
)
consume_surfsense_resume(runtime)
consume_surfsense_resume(runtime_a)
assert configurable["surfsense_resume_value"] == {
"tcid-B": {"decisions": ["reject"]}
}
def test_returns_none_when_no_entry_for_this_tool_call(self):
runtime = _runtime_with_config(
{
"configurable": {
"surfsense_resume_value": {"tcid-other": {"decisions": []}}
}
},
tool_call_id="tcid-A",
)
assert consume_surfsense_resume(runtime) is None
assert "surfsense_resume_value" not in configurable
def test_returns_none_when_no_payload_queued(self):
runtime = _runtime_with_config({"configurable": {}})
@ -48,22 +76,57 @@ class TestConsumeSurfsenseResume:
assert consume_surfsense_resume(runtime) is None
def test_drops_empty_dict_after_last_entry_consumed(self):
configurable = {
"surfsense_resume_value": {"tcid-A": {"decisions": ["approve"]}}
}
runtime = _runtime_with_config(
{"configurable": configurable}, tool_call_id="tcid-A"
)
consume_surfsense_resume(runtime)
assert "surfsense_resume_value" not in configurable
class TestHasSurfsenseResume:
def test_true_when_payload_queued(self):
def test_true_when_entry_for_this_tool_call_present(self):
runtime = _runtime_with_config(
{"configurable": {"surfsense_resume_value": "approve"}}
{
"configurable": {
"surfsense_resume_value": {"tcid-A": {"decisions": ["approve"]}}
}
},
tool_call_id="tcid-A",
)
assert has_surfsense_resume(runtime) is True
def test_false_when_entry_for_other_tool_call_only(self):
runtime = _runtime_with_config(
{
"configurable": {
"surfsense_resume_value": {"tcid-other": {"decisions": []}}
}
},
tool_call_id="tcid-A",
)
assert has_surfsense_resume(runtime) is False
def test_does_not_consume_payload(self):
configurable = {"surfsense_resume_value": "approve"}
runtime = _runtime_with_config({"configurable": configurable})
configurable = {
"surfsense_resume_value": {"tcid-A": {"decisions": ["approve"]}}
}
runtime = _runtime_with_config(
{"configurable": configurable}, tool_call_id="tcid-A"
)
has_surfsense_resume(runtime)
assert configurable == {"surfsense_resume_value": "approve"}
assert configurable["surfsense_resume_value"] == {
"tcid-A": {"decisions": ["approve"]}
}
def test_false_when_payload_absent(self):
runtime = _runtime_with_config({"configurable": {}})

View file

@ -0,0 +1,284 @@
"""Production-shape regression tests for ``tool_call_id`` stamping on subagent interrupts.
The production bug we're pinning here: when the orchestrator dispatches one or
more ``task`` tool calls and the targeted subagents hit a HITL ``interrupt(...)``,
the parent's persisted ``state.interrupts`` must carry the parent's
``tool_call_id`` on each interrupt value. Without that stamp,
``stream_resume_chat`` cannot route a flat ``decisions`` list back to the right
paused subagent and resume fails with ``Decision count mismatch``.
The tests in this module:
- Build a **real** ``StateGraph`` subagent that calls real ``interrupt(...)``
(no MagicMock, no patch of langgraph internals those are exactly the kind
of fakes that hid this bug).
- Invoke the ``task`` tool from **inside a parent pregel** (via a tiny parent
``StateGraph`` node) so the subagent invocation happens in the
production-shape "subgraph called from a parent tool node" context.
- Assert on ``parent.state.interrupts[*].value["tool_call_id"]`` the
observable that ``stream_resume_chat`` reads.
"""
from __future__ import annotations
import pytest
from langchain.tools import ToolRuntime
from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.runnables import RunnableConfig
from langgraph.checkpoint.memory import InMemorySaver
from langgraph.graph import END, START, StateGraph
from langgraph.types import Send, interrupt
from typing_extensions import TypedDict
from app.agents.multi_agent_chat.middleware.main_agent.checkpointed_subagent_middleware.task_tool import (
build_task_tool_with_parent_config,
)
class _S(TypedDict, total=False):
messages: list
def _build_single_interrupt_subagent(checkpointer: InMemorySaver):
"""Subagent that fires one HITL-bundle-shaped interrupt and waits for a decision."""
def approve_node(_state):
decision = interrupt(
{
"action_requests": [
{"name": "do_thing", "args": {"x": 1}, "description": ""}
],
"review_configs": [{}],
}
)
return {"messages": [AIMessage(content=f"got:{decision}")]}
g = StateGraph(_S)
g.add_node("approve", approve_node)
g.add_edge(START, "approve")
g.add_edge("approve", END)
return g.compile(checkpointer=checkpointer)
def _build_bundle_subagent(checkpointer: InMemorySaver):
"""Subagent that fires one interrupt carrying a 3-action bundle."""
def bundle_node(_state):
decision = interrupt(
{
"action_requests": [
{"name": "a", "args": {}, "description": ""},
{"name": "b", "args": {}, "description": ""},
{"name": "c", "args": {}, "description": ""},
],
"review_configs": [{}, {}, {}],
}
)
return {"messages": [AIMessage(content=f"bundle:{decision}")]}
g = StateGraph(_S)
g.add_node("bundle", bundle_node)
g.add_edge(START, "bundle")
g.add_edge("bundle", END)
return g.compile(checkpointer=checkpointer)
def _parent_graph_calling_task(task_tool, *, tool_call_id: str, checkpointer):
"""A tiny parent graph whose only node invokes ``task_tool`` from inside the pregel runtime.
This is the minimal reproduction of production's "subagent invoked from
inside a parent tool node" context — the *only* context where langgraph
treats the subagent as a subgraph and routes its interrupts back to the
parent's checkpoint.
"""
async def call_task(state, config: RunnableConfig):
rt = ToolRuntime(
state=state,
config=config,
context=None,
stream_writer=None,
tool_call_id=tool_call_id,
store=None,
)
return await task_tool.coroutine(
description="please approve",
subagent_type="approver",
runtime=rt,
)
g = StateGraph(_S)
g.add_node("call_task", call_task)
g.add_edge(START, "call_task")
g.add_edge("call_task", END)
return g.compile(checkpointer=checkpointer)
class _DispatchState(TypedDict, total=False):
messages: list
tcid: str
desc: str
def _parent_graph_dispatching_two_tasks_via_send(
task_tool, *, tool_call_id_a: str, tool_call_id_b: str, checkpointer
):
"""A parent graph that dispatches two ``task`` calls as parallel pregel
tasks via :class:`~langgraph.types.Send`.
This mirrors the production dispatch mechanism: when the orchestrator's
LLM emits two ``task`` tool calls in one turn, langchain's tool node
fans them out as parallel pregel tasks (the same primitive as ``Send``)
so each tool call gets its own pregel task that can raise
``GraphInterrupt`` independently and pregel collects *all* of them
into the parent's snapshot at the end of the superstep.
"""
def fanout_edge(_state) -> list[Send]:
return [
Send("call_task", {"tcid": tool_call_id_a, "desc": "approve A"}),
Send("call_task", {"tcid": tool_call_id_b, "desc": "approve B"}),
]
async def call_task(state: _DispatchState, config: RunnableConfig):
rt = ToolRuntime(
state=state,
config=config,
context=None,
stream_writer=None,
tool_call_id=state["tcid"],
store=None,
)
return await task_tool.coroutine(
description=state["desc"], subagent_type="approver", runtime=rt
)
g = StateGraph(_DispatchState)
g.add_node("call_task", call_task)
g.add_conditional_edges(START, fanout_edge, ["call_task"])
g.add_edge("call_task", END)
return g.compile(checkpointer=checkpointer)
def _parent_interrupt_values(snapshot) -> list[dict]:
"""Extract ``state.interrupts[*].value`` for assertions."""
return [i.value for i in (snapshot.interrupts or ())]
@pytest.mark.asyncio
async def test_single_subagent_interrupt_stamps_parent_tool_call_id():
"""A single paused subagent must surface to the parent with ``tool_call_id`` stamped.
Production bug regression: was producing
``value={"action_requests": [...], "review_configs": [...]}`` (no
``tool_call_id``), causing ``stream_resume_chat`` to skip the interrupt
and raise ``Decision count mismatch``.
"""
checkpointer = InMemorySaver()
subagent = _build_single_interrupt_subagent(checkpointer)
task_tool = build_task_tool_with_parent_config(
[{"name": "approver", "description": "approves", "runnable": subagent}]
)
parent = _parent_graph_calling_task(
task_tool, tool_call_id="parent-tcid-A", checkpointer=checkpointer
)
parent_config = {
"configurable": {"thread_id": "parent-thread"},
"recursion_limit": 100,
}
await parent.ainvoke({"messages": [HumanMessage(content="seed")]}, parent_config)
snap = await parent.aget_state(parent_config)
values = _parent_interrupt_values(snap)
assert len(values) == 1, (
f"expected exactly 1 parent interrupt, got {len(values)}: {values!r}"
)
value = values[0]
assert isinstance(value, dict)
assert value.get("tool_call_id") == "parent-tcid-A", (
f"REGRESSION: parent interrupt missing/wrong tool_call_id stamp. "
f"Expected 'parent-tcid-A', got {value.get('tool_call_id')!r}. "
f"Keys present: {sorted(value.keys())}"
)
# The original HITL payload must still be intact alongside the stamp.
assert value.get("action_requests") == [
{"name": "do_thing", "args": {"x": 1}, "description": ""}
]
@pytest.mark.asyncio
async def test_two_parallel_subagents_each_stamp_their_own_tool_call_id():
"""Two ``task`` calls dispatched in parallel must each carry their own ``tool_call_id``.
This is the actual production scenario (Linear + Jira ticket creation):
two parallel ``task`` tool calls, both subagents hit HITL, parent must
end up with two interrupts whose ``tool_call_id``s match the two
distinct parent-level ``tool_call_id``s the LLM emitted.
"""
checkpointer = InMemorySaver()
subagent = _build_single_interrupt_subagent(checkpointer)
task_tool = build_task_tool_with_parent_config(
[{"name": "approver", "description": "approves", "runnable": subagent}]
)
parent = _parent_graph_dispatching_two_tasks_via_send(
task_tool,
tool_call_id_a="parent-tcid-A",
tool_call_id_b="parent-tcid-B",
checkpointer=checkpointer,
)
parent_config = {
"configurable": {"thread_id": "parent-thread-parallel"},
"recursion_limit": 100,
}
await parent.ainvoke({"messages": [HumanMessage(content="seed")]}, parent_config)
snap = await parent.aget_state(parent_config)
values = _parent_interrupt_values(snap)
assert len(values) == 2, (
f"expected 2 parent interrupts (one per parallel task call), "
f"got {len(values)}: {values!r}"
)
stamps = {v.get("tool_call_id") for v in values}
assert stamps == {"parent-tcid-A", "parent-tcid-B"}, (
f"REGRESSION: parallel parent interrupts missing/wrong tool_call_id stamps. "
f"Expected {{'parent-tcid-A', 'parent-tcid-B'}}, got {stamps!r}. "
f"Values: {values!r}"
)
@pytest.mark.asyncio
async def test_bundle_subagent_interrupt_stamps_tool_call_id_preserving_actions():
"""A subagent emitting a multi-action bundle must surface stamped, with all actions intact.
The bundle shape (``action_requests=[3 items]``) drives the
``slice_decisions_by_tool_call`` accounting in ``stream_resume_chat``
if either the stamp or the action count is lost, resume routing
miscounts and crashes.
"""
checkpointer = InMemorySaver()
subagent = _build_bundle_subagent(checkpointer)
task_tool = build_task_tool_with_parent_config(
[{"name": "approver", "description": "approves", "runnable": subagent}]
)
parent = _parent_graph_calling_task(
task_tool, tool_call_id="parent-tcid-bundle", checkpointer=checkpointer
)
parent_config = {
"configurable": {"thread_id": "parent-thread-bundle"},
"recursion_limit": 100,
}
await parent.ainvoke({"messages": [HumanMessage(content="seed")]}, parent_config)
snap = await parent.aget_state(parent_config)
values = _parent_interrupt_values(snap)
assert len(values) == 1
value = values[0]
assert value.get("tool_call_id") == "parent-tcid-bundle"
assert isinstance(value.get("action_requests"), list)
assert len(value["action_requests"]) == 3, (
f"REGRESSION: bundle action_requests count changed during stamping; "
f"got {len(value['action_requests'])} actions: {value['action_requests']!r}"
)

View file

@ -0,0 +1,94 @@
"""Per-call ``thread_id`` derivation for nested subagent invocations.
Parallel ``task`` (and ``ask_knowledge_base``) calls must land in disjoint
checkpoint slots so their nested pregel runs do not stomp on each other or on
the parent's checkpoint state. The slot key is derived from the runtime's
``tool_call_id`` so the same call across the resume cycle keeps reading from
the same snapshot.
Note: we namespace via ``thread_id`` rather than ``checkpoint_ns`` because
langgraph's ``aget_state`` interprets a non-empty ``checkpoint_ns`` as a
subgraph path and raises ``ValueError("Subgraph X not found")``. ``thread_id``
is the primary checkpoint key and is free-form, so it's the right primitive.
"""
from __future__ import annotations
from langchain.tools import ToolRuntime
from app.agents.multi_agent_chat.middleware.main_agent.checkpointed_subagent_middleware.config import (
subagent_invoke_config,
)
def _runtime(*, tool_call_id: str, config: dict | None = None) -> ToolRuntime:
return ToolRuntime(
state=None,
context=None,
config=config or {},
stream_writer=None,
tool_call_id=tool_call_id,
store=None,
)
class TestSubagentInvokeThreadId:
def test_sets_per_call_thread_id_under_parent(self):
runtime = _runtime(
tool_call_id="tcid-A",
config={"configurable": {"thread_id": "t1"}},
)
sub_config = subagent_invoke_config(runtime)
assert sub_config["configurable"]["thread_id"] == "t1::task:tcid-A"
def test_per_call_thread_id_nests_under_already_namespaced_parent(self):
"""A subagent that itself spawns a subagent must keep nesting cleanly."""
runtime = _runtime(
tool_call_id="tcid-inner",
config={
"configurable": {
"thread_id": "t1::task:tcid-outer",
}
},
)
sub_config = subagent_invoke_config(runtime)
assert (
sub_config["configurable"]["thread_id"]
== "t1::task:tcid-outer::task:tcid-inner"
)
def test_different_tool_call_ids_produce_different_thread_ids(self):
config = {"configurable": {"thread_id": "t1"}}
rt_a = _runtime(tool_call_id="tcid-A", config=config)
rt_b = _runtime(tool_call_id="tcid-B", config=config)
tid_a = subagent_invoke_config(rt_a)["configurable"]["thread_id"]
tid_b = subagent_invoke_config(rt_b)["configurable"]["thread_id"]
assert tid_a != tid_b
def test_same_tool_call_id_produces_same_thread_id_across_repeated_calls(self):
"""Resume bridge needs to find the snapshot it primed earlier."""
config = {"configurable": {"thread_id": "t1"}}
rt_first = _runtime(tool_call_id="tcid-A", config=config)
rt_second = _runtime(tool_call_id="tcid-A", config=config)
tid_first = subagent_invoke_config(rt_first)["configurable"]["thread_id"]
tid_second = subagent_invoke_config(rt_second)["configurable"]["thread_id"]
assert tid_first == tid_second
def test_does_not_mutate_caller_config(self):
"""Repeated calls must not accumulate suffixes onto the parent's config."""
original_thread_id = "t1"
config = {"configurable": {"thread_id": original_thread_id}}
runtime = _runtime(tool_call_id="tcid-A", config=config)
subagent_invoke_config(runtime)
subagent_invoke_config(runtime)
assert config["configurable"]["thread_id"] == original_thread_id

View file

@ -0,0 +1,125 @@
"""Regression: ``request_permission_decision`` must emit the unified LC HITL wire shape.
Same bug class as :mod:`test_lc_hitl_wire` for self-gated approvals: the
permission middleware previously fired the SurfSense-specific
``{type, action, context}`` shape, which the parallel-HITL routing layer
does not recognize. Standardizing on LC HITL keeps every approval kind on
one routing path.
"""
from __future__ import annotations
import pytest
from langchain_core.messages import HumanMessage
from langgraph.checkpoint.memory import InMemorySaver
from langgraph.graph import END, START, StateGraph
from langgraph.types import Command
from typing_extensions import TypedDict
from app.agents.multi_agent_chat.middleware.shared.permissions.ask.request import (
request_permission_decision,
)
from app.agents.new_chat.permissions import Rule
class _State(TypedDict, total=False):
messages: list
final_decision: dict
def _build_graph_calling_request_permission_decision(checkpointer: InMemorySaver):
"""Real graph whose only node delegates to the permission ask primitive."""
def perm_node(_state):
decision = request_permission_decision(
tool_name="rm",
args={"path": "/tmp/file"},
patterns=["rm/*"],
rules=[Rule(permission="rm", pattern="*", action="ask")],
emit_interrupt=True,
)
return {"final_decision": decision}
g = StateGraph(_State)
g.add_node("perm", perm_node)
g.add_edge(START, "perm")
g.add_edge("perm", END)
return g.compile(checkpointer=checkpointer)
@pytest.mark.asyncio
async def test_permission_ask_payload_uses_lc_hitl_shape():
"""The permission middleware now speaks the langchain HITL standard shape."""
checkpointer = InMemorySaver()
graph = _build_graph_calling_request_permission_decision(checkpointer)
config = {"configurable": {"thread_id": "perm-wire"}}
await graph.ainvoke({"messages": [HumanMessage(content="seed")]}, config)
snap = await graph.aget_state(config)
assert len(snap.interrupts) == 1
value = snap.interrupts[0].value
assert value.get("action_requests") == [
{"name": "rm", "args": {"path": "/tmp/file"}}
], f"REGRESSION: permission ask reverted to legacy shape; got {value!r}"
review = value.get("review_configs")
assert isinstance(review, list) and len(review) == 1
palette = review[0]["allowed_decisions"]
# Native tool (no ``tool=`` argument): the palette must include the
# once/reject/edit triad. ``approve_always`` is gated on MCP-ness and
# therefore *omitted* here — palette content per tool kind is
# exercised in ``test_permission_ask_mcp_context``.
assert "approve" in palette and "reject" in palette and "edit" in palette
assert value.get("interrupt_type") == "permission_ask"
# SurfSense context rides through verbatim for FE explainability.
assert value["context"]["patterns"] == ["rm/*"]
assert value["context"]["always"] == ["rm/*"]
@pytest.mark.asyncio
async def test_resume_with_approve_envelope_returns_once_decision():
"""``approve`` from the LC envelope projects to permission-domain ``once``."""
checkpointer = InMemorySaver()
graph = _build_graph_calling_request_permission_decision(checkpointer)
config = {"configurable": {"thread_id": "perm-once"}}
await graph.ainvoke({"messages": [HumanMessage(content="seed")]}, config)
await graph.ainvoke(Command(resume={"decisions": [{"type": "approve"}]}), config)
final = await graph.aget_state(config)
assert final.values.get("final_decision") == {"decision_type": "once"}
@pytest.mark.asyncio
async def test_resume_with_approve_always_envelope_projects_unchanged():
"""``approve_always`` reply must project unchanged so the middleware can promote the rule."""
checkpointer = InMemorySaver()
graph = _build_graph_calling_request_permission_decision(checkpointer)
config = {"configurable": {"thread_id": "perm-approve-always"}}
await graph.ainvoke({"messages": [HumanMessage(content="seed")]}, config)
await graph.ainvoke(
Command(resume={"decisions": [{"type": "approve_always"}]}), config
)
final = await graph.aget_state(config)
assert final.values.get("final_decision") == {"decision_type": "approve_always"}
@pytest.mark.asyncio
async def test_resume_with_reject_and_feedback_carries_feedback_through():
"""Reject feedback must survive normalization for ``CorrectedError`` to fire downstream."""
checkpointer = InMemorySaver()
graph = _build_graph_calling_request_permission_decision(checkpointer)
config = {"configurable": {"thread_id": "perm-reject"}}
await graph.ainvoke({"messages": [HumanMessage(content="seed")]}, config)
await graph.ainvoke(
Command(
resume={"decisions": [{"type": "reject", "feedback": "use the trash bin"}]}
),
config,
)
final = await graph.aget_state(config)
assert final.values.get("final_decision") == {
"decision_type": "reject",
"feedback": "use the trash bin",
}

View file

@ -0,0 +1,232 @@
"""Permission-ask payload surfaces tool metadata for the FE card."""
from __future__ import annotations
from typing import Annotated, Any
import pytest
from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.tools import StructuredTool
from langgraph.checkpoint.memory import InMemorySaver
from langgraph.graph import END, START, StateGraph
from langgraph.graph.message import add_messages
from pydantic import BaseModel
from typing_extensions import TypedDict
from app.agents.multi_agent_chat.middleware.shared.permissions import (
build_permission_mw,
)
from app.agents.multi_agent_chat.middleware.shared.permissions.ask.payload import (
build_permission_ask_payload,
)
from app.agents.new_chat.feature_flags import AgentFeatureFlags
from app.agents.new_chat.permissions import Rule, Ruleset
class _NoArgs(BaseModel):
pass
async def _noop(**_kwargs) -> str:
return ""
def _ask_rule(tool_name: str) -> Rule:
return Rule(permission=tool_name, pattern="*", action="ask")
def _make_mcp_tool(*, name: str, connector_id: int, connector_name: str):
return StructuredTool(
name=name,
description=f"Run {name} via MCP.",
coroutine=_noop,
args_schema=_NoArgs,
metadata={
"mcp_connector_id": connector_id,
"mcp_connector_name": connector_name,
"mcp_transport": "http",
"hitl": True,
},
)
def test_payload_surfaces_mcp_fields_from_tool():
tool = _make_mcp_tool(
name="linear_create_issue", connector_id=42, connector_name="Linear (acme)"
)
payload = build_permission_ask_payload(
tool_name=tool.name,
args={"title": "bug"},
patterns=[tool.name],
rules=[_ask_rule(tool.name)],
tool=tool,
)
ctx = payload["context"]
assert ctx["mcp_connector_id"] == 42
assert ctx["mcp_server"] == "Linear (acme)"
assert ctx["tool_description"] == "Run linear_create_issue via MCP."
def test_payload_omits_tool_fields_when_tool_is_none():
payload = build_permission_ask_payload(
tool_name="rm",
args={"path": "/tmp/x"},
patterns=["rm"],
rules=[_ask_rule("rm")],
tool=None,
)
ctx = payload["context"]
assert "mcp_connector_id" not in ctx
assert "mcp_server" not in ctx
assert "tool_description" not in ctx
def test_palette_includes_approve_always_for_mcp_tool():
"""Saving to the connector's trusted-tools list is only possible for MCP tools."""
tool = _make_mcp_tool(
name="linear_create_issue", connector_id=42, connector_name="Linear"
)
palette = build_permission_ask_payload(
tool_name=tool.name,
args={},
patterns=[tool.name],
rules=[_ask_rule(tool.name)],
tool=tool,
)["review_configs"][0]["allowed_decisions"]
assert "approve_always" in palette
def test_palette_excludes_approve_always_for_native_tool():
"""Native tools have no place to persist trust, so don't offer the button."""
native = StructuredTool(
name="rm",
description="Remove a file.",
coroutine=_noop,
args_schema=_NoArgs,
metadata={"hitl": True},
)
palette = build_permission_ask_payload(
tool_name=native.name,
args={"path": "/tmp/x"},
patterns=[native.name],
rules=[_ask_rule(native.name)],
tool=native,
)["review_configs"][0]["allowed_decisions"]
assert "approve_always" not in palette
assert palette == ["approve", "reject", "edit"]
def test_palette_excludes_approve_always_when_tool_is_none():
"""Without a tool object the middleware can't tell — fall back to the safe triad."""
palette = build_permission_ask_payload(
tool_name="rm",
args={"path": "/tmp/x"},
patterns=["rm"],
rules=[_ask_rule("rm")],
tool=None,
)["review_configs"][0]["allowed_decisions"]
assert palette == ["approve", "reject", "edit"]
def test_payload_omits_falsy_mcp_metadata_fields():
tool = StructuredTool(
name="anon_tool",
description="",
coroutine=_noop,
args_schema=_NoArgs,
metadata={"mcp_connector_id": None, "mcp_connector_name": ""},
)
ctx = build_permission_ask_payload(
tool_name=tool.name,
args={},
patterns=[tool.name],
rules=[_ask_rule(tool.name)],
tool=tool,
)["context"]
assert "mcp_connector_id" not in ctx
assert "mcp_server" not in ctx
assert "tool_description" not in ctx
class _State(TypedDict, total=False):
messages: Annotated[list, add_messages]
def _emit_tool_call(tool_name: str, args: dict[str, Any], call_id: str):
def _node(_state: _State) -> dict[str, Any]:
return {
"messages": [
AIMessage(
content="",
tool_calls=[
{
"name": tool_name,
"args": args,
"id": call_id,
"type": "tool_call",
}
],
)
]
}
return _node
def _compile_graph_with(pm, tool_name: str, args: dict[str, Any], call_id: str):
def after(state: _State) -> dict[str, Any] | None:
return pm.after_model(state, None) # type: ignore[arg-type]
g = StateGraph(_State)
g.add_node("emit", _emit_tool_call(tool_name, args, call_id))
g.add_node("permission", after)
g.add_edge(START, "emit")
g.add_edge("emit", "permission")
g.add_edge("permission", END)
return g.compile(checkpointer=InMemorySaver())
@pytest.mark.asyncio
async def test_middleware_decorates_interrupt_with_mcp_tool_metadata():
tool = _make_mcp_tool(
name="linear_create_issue", connector_id=7, connector_name="Linear"
)
pm = build_permission_mw(
flags=AgentFeatureFlags(enable_permission=False),
subagent_rulesets=[
Ruleset(origin="linear", rules=[_ask_rule(tool.name)]),
],
tools=[tool],
)
assert pm is not None
graph = _compile_graph_with(pm, tool.name, {"title": "bug"}, "call-1")
config = {"configurable": {"thread_id": "linear-ask"}}
await graph.ainvoke({"messages": [HumanMessage(content="seed")]}, config)
snap = await graph.aget_state(config)
assert len(snap.interrupts) == 1
ctx = snap.interrupts[0].value["context"]
assert ctx["mcp_connector_id"] == 7
assert ctx["mcp_server"] == "Linear"
assert ctx["tool_description"] == "Run linear_create_issue via MCP."
@pytest.mark.asyncio
async def test_middleware_without_tool_index_still_asks_without_tool_fields():
pm = build_permission_mw(
flags=AgentFeatureFlags(enable_permission=False),
subagent_rulesets=[Ruleset(origin="kb", rules=[_ask_rule("rm")])],
)
assert pm is not None
graph = _compile_graph_with(pm, "rm", {"path": "/tmp/foo"}, "call-rm")
config = {"configurable": {"thread_id": "kb-rm"}}
await graph.ainvoke({"messages": [HumanMessage(content="seed")]}, config)
snap = await graph.aget_state(config)
assert len(snap.interrupts) == 1
ctx = snap.interrupts[0].value["context"]
assert "mcp_connector_id" not in ctx
assert "mcp_server" not in ctx
assert "tool_description" not in ctx

View file

@ -0,0 +1,166 @@
"""Regression: subagent-owned rulesets layer cleanly into ``PermissionMiddleware``.
The KB unification swap (legacy ``interrupt_on`` map KB-owned ``Ruleset``
threaded through ``build_permission_mw(subagent_rulesets=...)``) must
produce *exactly one* interrupt per destructive FS call, in LC HITL
shape, even when ``enable_permission`` is False destructive ops always
ask.
We exercise the production factory and a real ``PermissionMiddleware`` on a
real ``StateGraph`` so the test catches regressions in factory gating,
ruleset layering, and interrupt emission together.
"""
from __future__ import annotations
from typing import Annotated, Any
import pytest
from langchain_core.messages import AIMessage, HumanMessage
from langgraph.checkpoint.memory import InMemorySaver
from langgraph.graph import END, START, StateGraph
from langgraph.graph.message import add_messages
from langgraph.types import Command
from typing_extensions import TypedDict
from app.agents.multi_agent_chat.middleware.shared.permissions import (
build_permission_mw,
)
from app.agents.new_chat.feature_flags import AgentFeatureFlags
from app.agents.new_chat.permissions import Rule, Ruleset
def _kb_style_ruleset() -> Ruleset:
"""Mirror :data:`knowledge_base.agent.KB_RULESET` without importing it.
Importing the agent module pulls in deepagents and prompts; this test
is about the factory + middleware contract, not KB wiring.
"""
return Ruleset(
origin="knowledge_base",
rules=[
Rule(permission="rm", pattern="*", action="ask"),
Rule(permission="rmdir", pattern="*", action="ask"),
Rule(permission="move_file", pattern="*", action="ask"),
Rule(permission="edit_file", pattern="*", action="ask"),
Rule(permission="write_file", pattern="*", action="ask"),
],
)
class _State(TypedDict, total=False):
messages: Annotated[list, add_messages]
def _build_graph_with_permission_middleware(
*,
flags: AgentFeatureFlags,
subagent_rulesets: list[Ruleset] | None,
checkpointer: InMemorySaver,
):
"""Compile a one-node graph that emits a tool call for ``rm`` and
routes through the production ``PermissionMiddleware``.
The node returns an ``AIMessage`` with a tool call. The middleware's
``after_model`` hook intercepts and (if a rule says ``ask``) raises
a ``GraphInterrupt`` carrying the LC HITL payload.
"""
pm = build_permission_mw(flags=flags, subagent_rulesets=subagent_rulesets)
def node(_state: _State) -> dict[str, Any]:
msg = AIMessage(
content="",
tool_calls=[
{
"name": "rm",
"args": {"path": "/tmp/foo"},
"id": "call-rm-1",
"type": "tool_call",
}
],
)
return {"messages": [msg]}
def after_node(state: _State) -> dict[str, Any] | None:
if pm is None:
return None
return pm.after_model(state, None) # type: ignore[arg-type]
g = StateGraph(_State)
g.add_node("emit", node)
g.add_node("permission", after_node)
g.add_edge(START, "emit")
g.add_edge("emit", "permission")
g.add_edge("permission", END)
return g.compile(checkpointer=checkpointer), pm
@pytest.mark.asyncio
async def test_kb_ruleset_raises_one_lc_hitl_ask_for_rm_even_when_permission_flag_off():
"""KB ruleset: ``rm`` must ask once even with ``enable_permission=False``.
This is the keystone of the unification: the legacy ``interrupt_on``
map fired regardless of ``enable_permission``, so the migrated rules
must too. Otherwise users could opt out of "ask before rm".
"""
flags = AgentFeatureFlags(enable_permission=False)
checkpointer = InMemorySaver()
graph, pm = _build_graph_with_permission_middleware(
flags=flags,
subagent_rulesets=[_kb_style_ruleset()],
checkpointer=checkpointer,
)
assert pm is not None, "subagent rulesets must force the middleware on"
config = {"configurable": {"thread_id": "kb-cloud-rm"}}
await graph.ainvoke({"messages": [HumanMessage(content="seed")]}, config)
snap = await graph.aget_state(config)
assert len(snap.interrupts) == 1, (
f"REGRESSION: KB ruleset should raise exactly one interrupt; got "
f"{[i.value for i in snap.interrupts]!r}"
)
payload = snap.interrupts[0].value
requests = payload.get("action_requests")
assert requests == [{"name": "rm", "args": {"path": "/tmp/foo"}}], (
f"interrupt must carry the rm call in LC HITL shape; got {payload!r}"
)
assert payload.get("interrupt_type") == "permission_ask"
@pytest.mark.asyncio
async def test_kb_ruleset_resume_with_approve_lets_rm_through():
"""Resume with ``approve`` → call kept; the model continues normally."""
flags = AgentFeatureFlags(enable_permission=False)
checkpointer = InMemorySaver()
graph, _ = _build_graph_with_permission_middleware(
flags=flags,
subagent_rulesets=[_kb_style_ruleset()],
checkpointer=checkpointer,
)
config = {"configurable": {"thread_id": "kb-cloud-rm-approve"}}
await graph.ainvoke({"messages": [HumanMessage(content="seed")]}, config)
await graph.ainvoke(Command(resume={"decisions": [{"type": "approve"}]}), config)
final = await graph.aget_state(config)
assert final.next == (), "graph must complete after approve"
last_ai = next(
(m for m in reversed(final.values["messages"]) if isinstance(m, AIMessage)),
None,
)
assert last_ai is not None
assert [tc["name"] for tc in last_ai.tool_calls] == ["rm"], (
"approved rm call must remain on the AIMessage so the tool can run"
)
@pytest.mark.asyncio
async def test_no_subagent_rulesets_with_permission_off_skips_middleware_entirely():
"""No subagent rulesets + permission off → factory returns ``None`` (no engine).
The legacy gating is preserved when no caller asks for rules: nothing
runs, nothing pauses.
"""
flags = AgentFeatureFlags(enable_permission=False)
pm = build_permission_mw(flags=flags, subagent_rulesets=None)
assert pm is None

View file

@ -0,0 +1,186 @@
"""``approve_always`` decisions for MCP tools are saved via the trusted-tool saver."""
from __future__ import annotations
from typing import Annotated, Any
import pytest
from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.tools import StructuredTool
from langgraph.checkpoint.memory import InMemorySaver
from langgraph.graph import END, START, StateGraph
from langgraph.graph.message import add_messages
from langgraph.types import Command
from pydantic import BaseModel
from typing_extensions import TypedDict
from app.agents.multi_agent_chat.middleware.shared.permissions import (
build_permission_mw,
)
from app.agents.new_chat.feature_flags import AgentFeatureFlags
from app.agents.new_chat.permissions import Rule, Ruleset
class _NoArgs(BaseModel):
pass
async def _noop(**_kwargs) -> str:
return ""
def _ask_rule(tool_name: str) -> Rule:
return Rule(permission=tool_name, pattern="*", action="ask")
def _make_mcp_tool(*, name: str, connector_id: int):
return StructuredTool(
name=name,
description=f"Run {name} via MCP.",
coroutine=_noop,
args_schema=_NoArgs,
metadata={
"mcp_connector_id": connector_id,
"mcp_connector_name": "Linear",
"mcp_transport": "http",
"hitl": True,
},
)
def _make_native_tool(*, name: str):
return StructuredTool(
name=name,
description=f"Native {name}.",
coroutine=_noop,
args_schema=_NoArgs,
metadata={"hitl": True},
)
class _State(TypedDict, total=False):
messages: Annotated[list, add_messages]
def _build_graph(pm, tool_name: str):
def emit(_state: _State) -> dict[str, Any]:
return {
"messages": [
AIMessage(
content="",
tool_calls=[
{
"name": tool_name,
"args": {},
"id": "call-1",
"type": "tool_call",
}
],
)
]
}
g = StateGraph(_State)
g.add_node("emit", emit)
g.add_node("permission", pm.aafter_model) # type: ignore[arg-type]
g.add_edge(START, "emit")
g.add_edge("emit", "permission")
g.add_edge("permission", END)
return g.compile(checkpointer=InMemorySaver())
@pytest.mark.asyncio
async def test_approve_always_decision_saves_mcp_tool_via_callback():
saved: list[tuple[int, str]] = []
async def trusted_tool_saver(connector_id: int, tool_name: str) -> None:
saved.append((connector_id, tool_name))
tool = _make_mcp_tool(name="linear_create_issue", connector_id=7)
pm = build_permission_mw(
flags=AgentFeatureFlags(enable_permission=False),
subagent_rulesets=[Ruleset(origin="linear", rules=[_ask_rule(tool.name)])],
tools=[tool],
trusted_tool_saver=trusted_tool_saver,
)
assert pm is not None
graph = _build_graph(pm, tool.name)
config = {"configurable": {"thread_id": "approve-always-mcp"}}
await graph.ainvoke({"messages": [HumanMessage(content="seed")]}, config)
await graph.ainvoke(
Command(resume={"decisions": [{"type": "approve_always"}]}), config
)
assert saved == [(7, "linear_create_issue")]
@pytest.mark.asyncio
async def test_once_decision_does_not_save():
saved: list[tuple[int, str]] = []
async def trusted_tool_saver(connector_id: int, tool_name: str) -> None:
saved.append((connector_id, tool_name))
tool = _make_mcp_tool(name="linear_create_issue", connector_id=7)
pm = build_permission_mw(
flags=AgentFeatureFlags(enable_permission=False),
subagent_rulesets=[Ruleset(origin="linear", rules=[_ask_rule(tool.name)])],
tools=[tool],
trusted_tool_saver=trusted_tool_saver,
)
assert pm is not None
graph = _build_graph(pm, tool.name)
config = {"configurable": {"thread_id": "once-mcp"}}
await graph.ainvoke({"messages": [HumanMessage(content="seed")]}, config)
await graph.ainvoke(Command(resume={"decisions": [{"type": "approve"}]}), config)
assert saved == []
@pytest.mark.asyncio
async def test_approve_always_decision_for_native_tool_skips_save():
"""Native tools have no ``mcp_connector_id`` so there is nowhere to persist trust."""
saved: list[tuple[int, str]] = []
async def trusted_tool_saver(connector_id: int, tool_name: str) -> None:
saved.append((connector_id, tool_name))
tool = _make_native_tool(name="rm")
pm = build_permission_mw(
flags=AgentFeatureFlags(enable_permission=False),
subagent_rulesets=[Ruleset(origin="kb", rules=[_ask_rule(tool.name)])],
tools=[tool],
trusted_tool_saver=trusted_tool_saver,
)
assert pm is not None
graph = _build_graph(pm, tool.name)
config = {"configurable": {"thread_id": "approve-always-native"}}
await graph.ainvoke({"messages": [HumanMessage(content="seed")]}, config)
await graph.ainvoke(
Command(resume={"decisions": [{"type": "approve_always"}]}), config
)
assert saved == []
@pytest.mark.asyncio
async def test_approve_always_decision_with_no_saver_callback_is_a_noop():
"""Anonymous turns build the middleware without a ``trusted_tool_saver``; must not crash."""
tool = _make_mcp_tool(name="linear_create_issue", connector_id=7)
pm = build_permission_mw(
flags=AgentFeatureFlags(enable_permission=False),
subagent_rulesets=[Ruleset(origin="linear", rules=[_ask_rule(tool.name)])],
tools=[tool],
trusted_tool_saver=None,
)
assert pm is not None
graph = _build_graph(pm, tool.name)
config = {"configurable": {"thread_id": "anon-approve-always"}}
await graph.ainvoke({"messages": [HumanMessage(content="seed")]}, config)
await graph.ainvoke(
Command(resume={"decisions": [{"type": "approve_always"}]}), config
)

View file

@ -0,0 +1,132 @@
"""Regression: ``request_approval`` must emit the unified LC HITL wire shape.
Before this fix, self-gated approvals fired the SurfSense-specific
``{type, action, context}`` shape which the parallel-HITL routing layer
(``collect_pending_tool_calls``) does not recognize. In a parallel HITL
scenario where one subagent used self-gated approvals (e.g. Gmail send)
and another used middleware-gated approvals (e.g. Linear via
``HumanInTheLoopMiddleware``), the routing layer would silently skip the
self-gated interrupt and crash on resume with ``Decision count mismatch``.
This test pins the wire contract by running ``request_approval`` inside a
real ``StateGraph`` and asserting the paused parent observes the LC HITL
shape (``action_requests``, ``review_configs``, ``interrupt_type``).
"""
from __future__ import annotations
import pytest
from langchain_core.messages import HumanMessage
from langgraph.checkpoint.memory import InMemorySaver
from langgraph.graph import END, START, StateGraph
from langgraph.types import Command
from typing_extensions import TypedDict
from app.agents.multi_agent_chat.subagents.shared.hitl.approvals.self_gated import (
request_approval,
)
class _State(TypedDict, total=False):
messages: list
final_decision_type: str
final_params: dict
def _build_graph_calling_request_approval(checkpointer: InMemorySaver):
"""A real graph whose only node delegates to ``request_approval``."""
def gate_node(_state):
result = request_approval(
action_type="gmail_email_send",
tool_name="send_gmail_email",
params={"to": "alice@example.com", "subject": "hi"},
context={"account": "alice@gmail.com"},
)
return {
"final_decision_type": result.decision_type,
"final_params": result.params,
}
g = StateGraph(_State)
g.add_node("gate", gate_node)
g.add_edge(START, "gate")
g.add_edge("gate", END)
return g.compile(checkpointer=checkpointer)
@pytest.mark.asyncio
async def test_paused_interrupt_uses_lc_hitl_action_requests_shape():
"""The paused interrupt must speak the langchain HITL standard shape."""
checkpointer = InMemorySaver()
graph = _build_graph_calling_request_approval(checkpointer)
config = {"configurable": {"thread_id": "self-gated-wire"}}
await graph.ainvoke({"messages": [HumanMessage(content="seed")]}, config)
snap = await graph.aget_state(config)
assert len(snap.interrupts) == 1, (
f"expected one paused interrupt, got {len(snap.interrupts)}"
)
value = snap.interrupts[0].value
assert isinstance(value, dict)
# Standard LC HITL fields the routing layer reads.
assert value.get("action_requests") == [
{
"name": "send_gmail_email",
"args": {"to": "alice@example.com", "subject": "hi"},
}
], (
"REGRESSION: self-gated approval reverted to legacy SurfSense shape; "
f"got {value!r}"
)
assert value.get("review_configs") == [
{
"action_name": "send_gmail_email",
"allowed_decisions": ["approve", "reject", "edit"],
}
]
assert value.get("interrupt_type") == "gmail_email_send", (
"FE card discriminator must travel as ``interrupt_type``."
)
assert value.get("context") == {"account": "alice@gmail.com"}
@pytest.mark.asyncio
async def test_resume_with_lc_envelope_returns_hitl_result_with_edited_args():
"""Edit reply via the LC envelope must round-trip into ``HITLResult.params``."""
checkpointer = InMemorySaver()
graph = _build_graph_calling_request_approval(checkpointer)
config = {"configurable": {"thread_id": "self-gated-resume"}}
await graph.ainvoke({"messages": [HumanMessage(content="seed")]}, config)
edited = {"to": "alice@example.com", "subject": "EDITED"}
await graph.ainvoke(
Command(
resume={
"decisions": [
{"type": "edit", "edited_action": {"args": {"subject": "EDITED"}}}
]
}
),
config,
)
final = await graph.aget_state(config)
assert final.values.get("final_decision_type") == "edit"
assert final.values.get("final_params") == edited
@pytest.mark.asyncio
async def test_reject_envelope_returns_rejected_hitl_result():
"""Reject reply must surface as ``HITLResult.rejected=True`` without invoking the tool."""
checkpointer = InMemorySaver()
graph = _build_graph_calling_request_approval(checkpointer)
config = {"configurable": {"thread_id": "self-gated-reject"}}
await graph.ainvoke({"messages": [HumanMessage(content="seed")]}, config)
await graph.ainvoke(
Command(resume={"decisions": [{"type": "reject", "feedback": "no"}]}),
config,
)
final = await graph.aget_state(config)
assert final.values.get("final_decision_type") == "reject"

View file

@ -0,0 +1,170 @@
"""Unit contract for the unified LC HITL wire format.
Both the self-gated approval primitive (``request_approval``) and the
middleware-gated permission ask (``PermissionMiddleware``) must serialize
to the same wire shape so the parallel-HITL routing layer
(``collect_pending_tool_calls`` + ``slice_decisions_by_tool_call`` +
``build_lg_resume_map``) sees one format.
These tests pin the shape:
- Builder always emits ``action_requests`` (1 entry) + ``review_configs``
+ ``interrupt_type``; ``context`` rides through verbatim when present.
- Parser tolerates the standard LC envelope, bare scalar strings, and
unrecognized shapes (failing closed to ``reject``).
- Edited args round-trip through both nested (``edited_action.args``) and
flat (``args``) shapes without inventing values for the empty case.
"""
from __future__ import annotations
from app.agents.multi_agent_chat.subagents.shared.hitl.wire import (
LC_DECISION_APPROVE,
LC_DECISION_EDIT,
LC_DECISION_REJECT,
SURFSENSE_DECISION_APPROVE_ALWAYS,
build_lc_hitl_payload,
parse_lc_envelope,
)
class TestBuildLcHitlPayload:
def test_minimal_payload_has_one_action_request_and_one_review_config(self):
payload = build_lc_hitl_payload(
tool_name="send_email",
args={"to": "x@y.z"},
allowed_decisions=[LC_DECISION_APPROVE, LC_DECISION_REJECT],
interrupt_type="gmail_email_send",
)
assert payload["action_requests"] == [
{"name": "send_email", "args": {"to": "x@y.z"}}
]
assert payload["review_configs"] == [
{
"action_name": "send_email",
"allowed_decisions": [LC_DECISION_APPROVE, LC_DECISION_REJECT],
}
]
assert payload["interrupt_type"] == "gmail_email_send"
assert "context" not in payload, "context must be omitted when not provided"
def test_none_args_normalized_to_empty_dict(self):
"""FE expects a stable shape; ``None`` would crash card rendering."""
payload = build_lc_hitl_payload(
tool_name="ping",
args=None, # type: ignore[arg-type]
allowed_decisions=[LC_DECISION_APPROVE],
interrupt_type="self_gated",
)
assert payload["action_requests"][0]["args"] == {}
def test_description_attached_only_when_provided(self):
with_desc = build_lc_hitl_payload(
tool_name="t",
args={},
allowed_decisions=[LC_DECISION_APPROVE],
interrupt_type="x",
description="please review",
)
without = build_lc_hitl_payload(
tool_name="t",
args={},
allowed_decisions=[LC_DECISION_APPROVE],
interrupt_type="x",
)
assert with_desc["action_requests"][0]["description"] == "please review"
assert "description" not in without["action_requests"][0]
def test_context_passed_through_verbatim(self):
ctx = {"patterns": ["rm/*"], "rules": [], "always": ["rm/*"]}
payload = build_lc_hitl_payload(
tool_name="rm",
args={"path": "/tmp"},
allowed_decisions=[
LC_DECISION_APPROVE,
LC_DECISION_REJECT,
SURFSENSE_DECISION_APPROVE_ALWAYS,
],
interrupt_type="permission_ask",
context=ctx,
)
assert payload["context"] == ctx
def test_allowed_decisions_list_is_copied_not_aliased(self):
"""A caller mutating their original list must not corrupt the payload."""
decisions = [LC_DECISION_APPROVE]
payload = build_lc_hitl_payload(
tool_name="t",
args={},
allowed_decisions=decisions,
interrupt_type="x",
)
decisions.append(LC_DECISION_REJECT)
assert payload["review_configs"][0]["allowed_decisions"] == [
LC_DECISION_APPROVE
]
class TestParseLcEnvelope:
def test_standard_lc_envelope_returns_typed_decision(self):
parsed = parse_lc_envelope({"decisions": [{"type": "approve"}]})
assert parsed.decision_type == "approve"
assert parsed.edited_args is None
assert parsed.message is None
def test_bare_scalar_string_passes_through_lowercased(self):
assert parse_lc_envelope("APPROVE_ALWAYS").decision_type == "approve_always"
assert parse_lc_envelope("once").decision_type == "once"
def test_non_dict_non_string_collapses_to_reject(self):
"""Failing closed: ambiguous input must never proceed."""
assert parse_lc_envelope(42).decision_type == "reject"
assert parse_lc_envelope(None).decision_type == "reject"
assert parse_lc_envelope(["bogus"]).decision_type == "reject"
def test_missing_decision_type_collapses_to_reject(self):
assert parse_lc_envelope({"decisions": [{}]}).decision_type == "reject"
assert parse_lc_envelope({"foo": "bar"}).decision_type == "reject"
def test_edit_extracts_nested_args(self):
parsed = parse_lc_envelope(
{
"decisions": [
{
"type": LC_DECISION_EDIT,
"edited_action": {"args": {"to": "edited@y.z"}},
}
]
}
)
assert parsed.decision_type == "edit"
assert parsed.edited_args == {"to": "edited@y.z"}
def test_edit_falls_back_to_flat_args(self):
parsed = parse_lc_envelope(
{"decisions": [{"type": "edit", "args": {"k": "v"}}]}
)
assert parsed.edited_args == {"k": "v"}
def test_edit_with_empty_args_yields_none_edited(self):
"""Empty edited_args means "no edits" — caller treats as plain approve."""
parsed = parse_lc_envelope(
{"decisions": [{"type": "edit", "edited_action": {"args": {}}}]}
)
assert parsed.edited_args is None
def test_message_picked_from_either_feedback_or_message_field(self):
with_feedback = parse_lc_envelope(
{"decisions": [{"type": "reject", "feedback": "no thanks"}]}
)
with_message = parse_lc_envelope(
{"decisions": [{"type": "reject", "message": "no thanks"}]}
)
assert with_feedback.message == "no thanks"
assert with_message.message == "no thanks"
def test_blank_message_treated_as_absent(self):
parsed = parse_lc_envelope(
{"decisions": [{"type": "reject", "message": " "}]}
)
assert parsed.message is None

View file

@ -1,4 +1,4 @@
"""Subagent resilience contract: ``extra_middleware`` reaches the agent chain."""
"""Subagent resilience contract: ``middleware_stack`` reaches the agent chain."""
from __future__ import annotations
@ -19,9 +19,14 @@ from langchain_core.language_models.fake_chat_models import (
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
from langchain_core.outputs import ChatGeneration, ChatResult
from app.agents.multi_agent_chat.middleware.shared.permissions.middleware.core import (
PermissionMiddleware,
)
from app.agents.multi_agent_chat.subagents.shared.subagent_builder import (
pack_subagent,
)
from app.agents.new_chat.feature_flags import AgentFeatureFlags
from app.agents.new_chat.permissions import Rule, Ruleset, evaluate
class RateLimitError(Exception):
@ -67,20 +72,23 @@ class _AlwaysFailingChatModel(BaseChatModel):
@pytest.mark.asyncio
async def test_subagent_recovers_when_primary_llm_fails():
"""Fallback in ``extra_middleware`` must finish the turn when primary raises."""
"""Fallback in ``middleware_stack`` must finish the turn when primary raises."""
primary = _AlwaysFailingChatModel()
fallback = FakeMessagesListChatModel(
responses=[AIMessage(content="recovered via fallback")]
)
spec = pack_subagent(
result = pack_subagent(
name="resilience_test",
description="test subagent",
system_prompt="be helpful",
tools=[],
ruleset=Ruleset(origin="resilience_test", rules=[]),
dependencies={"flags": AgentFeatureFlags()},
model=primary,
extra_middleware=[ModelFallbackMiddleware(fallback)],
middleware_stack={"fallback": ModelFallbackMiddleware(fallback)},
)
spec = result.spec
agent = create_agent(
model=spec["model"],
@ -94,3 +102,142 @@ async def test_subagent_recovers_when_primary_llm_fails():
final = result["messages"][-1]
assert isinstance(final, AIMessage)
assert final.content == "recovered via fallback"
def _extract_permission_mw(spec) -> PermissionMiddleware:
"""Find the lone PermissionMiddleware in a subagent's middleware list."""
matches = [m for m in spec["middleware"] if isinstance(m, PermissionMiddleware)]
assert len(matches) == 1, "expected exactly one PermissionMiddleware"
return matches[0]
def test_user_allowlist_overrides_coded_ask_via_last_match_wins():
"""User ``allow`` rules promoted via "Always Allow" must beat coded ``ask`` rules."""
coded = Ruleset(
origin="connector",
rules=[Rule(permission="save_issue", pattern="*", action="ask")],
)
user_allowlist = Ruleset(
origin="user_allowlist:connector",
rules=[Rule(permission="save_issue", pattern="*", action="allow")],
)
result = pack_subagent(
name="connector",
description="test connector",
system_prompt="x",
tools=[],
ruleset=coded,
dependencies={
"flags": AgentFeatureFlags(),
"user_allowlist_by_subagent": {"connector": user_allowlist},
},
)
mw = _extract_permission_mw(result.spec)
decided = evaluate("save_issue", "*", *mw._static_rulesets)
assert decided.action == "allow", (
f"user_allowlist must override coded ask; got {decided!r}"
)
def test_coded_ask_stays_when_user_allowlist_unrelated():
"""User ``allow`` rules for OTHER tools must not leak into asked-tools."""
coded = Ruleset(
origin="connector",
rules=[Rule(permission="delete_issue", pattern="*", action="ask")],
)
user_allowlist = Ruleset(
origin="user_allowlist:connector",
rules=[Rule(permission="save_issue", pattern="*", action="allow")],
)
result = pack_subagent(
name="connector",
description="test",
system_prompt="x",
tools=[],
ruleset=coded,
dependencies={
"flags": AgentFeatureFlags(),
"user_allowlist_by_subagent": {"connector": user_allowlist},
},
)
mw = _extract_permission_mw(result.spec)
decided = evaluate("delete_issue", "*", *mw._static_rulesets)
assert decided.action == "ask"
def test_missing_user_allowlist_keeps_coded_behaviour():
"""``dependencies`` without ``user_allowlist_by_subagent`` is the common case."""
coded = Ruleset(
origin="connector",
rules=[Rule(permission="save_issue", pattern="*", action="ask")],
)
result = pack_subagent(
name="connector",
description="test",
system_prompt="x",
tools=[],
ruleset=coded,
dependencies={"flags": AgentFeatureFlags()},
)
mw = _extract_permission_mw(result.spec)
decided = evaluate("save_issue", "*", *mw._static_rulesets)
assert decided.action == "ask"
def test_user_allowlist_for_different_subagent_does_not_leak():
"""User trust for ``linear`` must not affect a ``jira`` subagent compile."""
coded = Ruleset(
origin="jira",
rules=[Rule(permission="save_issue", pattern="*", action="ask")],
)
linear_allowlist = Ruleset(
origin="user_allowlist:linear",
rules=[Rule(permission="save_issue", pattern="*", action="allow")],
)
result = pack_subagent(
name="jira",
description="test",
system_prompt="x",
tools=[],
ruleset=coded,
dependencies={
"flags": AgentFeatureFlags(),
"user_allowlist_by_subagent": {"linear": linear_allowlist},
},
)
mw = _extract_permission_mw(result.spec)
decided = evaluate("save_issue", "*", *mw._static_rulesets)
assert decided.action == "ask"
def test_empty_user_allowlist_is_tolerated():
"""An empty ``Ruleset`` (no rules) must not flip evaluation to allow-everything."""
coded = Ruleset(
origin="connector",
rules=[Rule(permission="save_issue", pattern="*", action="ask")],
)
empty = Ruleset(origin="user_allowlist:connector", rules=[])
result = pack_subagent(
name="connector",
description="test",
system_prompt="x",
tools=[],
ruleset=coded,
dependencies={
"flags": AgentFeatureFlags(),
"user_allowlist_by_subagent": {"connector": empty},
},
)
mw = _extract_permission_mw(result.spec)
decided = evaluate("save_issue", "*", *mw._static_rulesets)
assert decided.action == "ask"

View file

@ -106,9 +106,9 @@ class TestAsk:
# No new rule persisted
assert mw._runtime_ruleset.rules == []
def test_always_persists_runtime_rule(self) -> None:
def test_approve_always_persists_runtime_rule(self) -> None:
mw = PermissionMiddleware(rulesets=[])
mw._raise_interrupt = lambda **kw: {"decision_type": "always"} # type: ignore[assignment]
mw._raise_interrupt = lambda **kw: {"decision_type": "approve_always"} # type: ignore[assignment]
state = {"messages": [_msg({"name": "send_email", "args": {}, "id": "1"})]}
out = mw.after_model(state, _FakeRuntime())
assert out is None # call kept

View file

@ -741,6 +741,366 @@ async def test_extract_image_falls_back_to_document_without_vision_llm(
assert result.content_type == "document"
# ---------------------------------------------------------------------------
# Document path with vision LLM: per-image descriptions are appended
# ---------------------------------------------------------------------------
def _fake_extraction_result(*descriptions):
from app.etl_pipeline.picture_describer import (
PictureDescription,
PictureExtractionResult,
)
return PictureExtractionResult(
descriptions=[
PictureDescription(
page_number=d["page"],
ordinal_in_page=d.get("ordinal", 0),
name=d["name"],
sha256=d.get("sha", "deadbeef"),
description=d["desc"],
)
for d in descriptions
]
)
async def test_extract_pdf_with_vision_llm_inlines_image_blocks(tmp_path, mocker):
"""A PDF with an `<!-- image -->` placeholder + caption gets the
block spliced inline (no orphaned ``## Image Content`` section).
This is the headline scenario for the medxpertqa benchmark: the
image content lives in the same chunk as the surrounding case text
so retrieval pulls the question, image, and answer options together.
"""
pdf_file = tmp_path / "report.pdf"
pdf_file.write_bytes(b"%PDF-1.4 fake content")
mocker.patch("app.config.config.ETL_SERVICE", "DOCLING")
fake_docling = mocker.AsyncMock()
fake_docling.process_document.return_value = {
"content": (
"# MedXpertQA-MM MM-130\n\n"
"## Clinical case\n\nA 44-year-old man...\n\n"
"<!-- image -->\nImage: MM-130-a.jpeg\n\n"
"## Answer choices\n\nA) ...\n"
)
}
mocker.patch(
"app.services.docling_service.create_docling_service",
return_value=fake_docling,
)
extraction = _fake_extraction_result(
{
"page": 1,
"name": "Im0",
"desc": "Axial CT showing a large cystic mass.",
}
)
mocker.patch(
"app.etl_pipeline.picture_describer.describe_pictures",
new=mocker.AsyncMock(return_value=extraction),
)
fake_llm = mocker.MagicMock()
result = await EtlPipelineService(vision_llm=fake_llm).extract(
EtlRequest(file_path=str(pdf_file), filename="report.pdf")
)
md = result.markdown_content
# The placeholder + caption are gone, replaced by a horizontal-
# rule-delimited section with the captioned filename.
assert "<!-- image -->" not in md
assert "Image: MM-130-a.jpeg" not in md
assert "**Embedded image:** `MM-130-a.jpeg`" in md
assert "**Visual description:**" in md
assert "Axial CT showing a large cystic mass." in md
# No OCR section -- our fake_extraction_result has no ocr_text,
# and the format omits the section when there's no text to show.
assert "**OCR text:**" not in md
# No raw HTML / XML tags or blockquote wrapping leak.
assert "<image" not in md
assert "> **Embedded image:**" not in md
# No appended section -- everything went inline.
assert "## Image Content" not in md
# Surrounding case text + answer options are preserved.
assert "A 44-year-old man..." in md
assert "## Answer choices" in md
assert "A) ..." in md
async def test_extract_pdf_with_vision_llm_appends_when_no_marker(tmp_path, mocker):
"""When parser markdown has no image markers, descriptions get appended.
This is the fallback path for parsers that drop image placeholders
entirely. The image content still ends up in the markdown -- just
in a clearly-labeled section rather than inline.
"""
pdf_file = tmp_path / "report.pdf"
pdf_file.write_bytes(b"%PDF-1.4 fake content")
mocker.patch("app.config.config.ETL_SERVICE", "DOCLING")
fake_docling = mocker.AsyncMock()
fake_docling.process_document.return_value = {
"content": "# Parsed PDF text\n\nNo image markers anywhere.\n"
}
mocker.patch(
"app.services.docling_service.create_docling_service",
return_value=fake_docling,
)
extraction = _fake_extraction_result(
{"page": 1, "name": "Im0", "desc": "An image description."}
)
mocker.patch(
"app.etl_pipeline.picture_describer.describe_pictures",
new=mocker.AsyncMock(return_value=extraction),
)
fake_llm = mocker.MagicMock()
result = await EtlPipelineService(vision_llm=fake_llm).extract(
EtlRequest(file_path=str(pdf_file), filename="report.pdf")
)
md = result.markdown_content
assert "# Parsed PDF text" in md
assert "## Image Content (vision-LLM extracted)" in md
assert "**Embedded image:** `Im0`" in md
assert "An image description." in md
async def test_extract_pdf_without_vision_llm_skips_picture_descriptions(
tmp_path, mocker
):
"""No vision LLM -> parser markdown returned as-is."""
pdf_file = tmp_path / "report.pdf"
pdf_file.write_bytes(b"%PDF-1.4 fake content")
mocker.patch("app.config.config.ETL_SERVICE", "DOCLING")
fake_docling = mocker.AsyncMock()
fake_docling.process_document.return_value = {"content": "# Parsed PDF text"}
mocker.patch(
"app.services.docling_service.create_docling_service",
return_value=fake_docling,
)
describe_mock = mocker.patch(
"app.etl_pipeline.picture_describer.describe_pictures",
new=mocker.AsyncMock(),
)
result = await EtlPipelineService().extract(
EtlRequest(file_path=str(pdf_file), filename="report.pdf")
)
assert result.markdown_content == "# Parsed PDF text"
assert "<image" not in result.markdown_content
describe_mock.assert_not_called()
async def test_extract_pdf_with_vision_llm_swallows_describe_failure(tmp_path, mocker):
"""A pypdf or vision LLM blow-up never fails the document upload."""
pdf_file = tmp_path / "report.pdf"
pdf_file.write_bytes(b"%PDF-1.4 fake content")
mocker.patch("app.config.config.ETL_SERVICE", "DOCLING")
fake_docling = mocker.AsyncMock()
fake_docling.process_document.return_value = {"content": "# Parsed PDF text"}
mocker.patch(
"app.services.docling_service.create_docling_service",
return_value=fake_docling,
)
mocker.patch(
"app.etl_pipeline.picture_describer.describe_pictures",
new=mocker.AsyncMock(side_effect=RuntimeError("pypdf exploded")),
)
fake_llm = mocker.MagicMock()
result = await EtlPipelineService(vision_llm=fake_llm).extract(
EtlRequest(file_path=str(pdf_file), filename="report.pdf")
)
assert result.markdown_content == "# Parsed PDF text"
assert result.etl_service == "DOCLING"
async def test_extract_pdf_with_vision_llm_no_images_returns_parser_text(
tmp_path, mocker
):
"""Vision-LLM-enabled PDF with zero extracted images is unchanged."""
pdf_file = tmp_path / "report.pdf"
pdf_file.write_bytes(b"%PDF-1.4 fake content")
mocker.patch("app.config.config.ETL_SERVICE", "DOCLING")
fake_docling = mocker.AsyncMock()
fake_docling.process_document.return_value = {"content": "# Just text, no images"}
mocker.patch(
"app.services.docling_service.create_docling_service",
return_value=fake_docling,
)
empty = _fake_extraction_result()
mocker.patch(
"app.etl_pipeline.picture_describer.describe_pictures",
new=mocker.AsyncMock(return_value=empty),
)
fake_llm = mocker.MagicMock()
result = await EtlPipelineService(vision_llm=fake_llm).extract(
EtlRequest(file_path=str(pdf_file), filename="report.pdf")
)
assert result.markdown_content == "# Just text, no images"
assert "<image" not in result.markdown_content
# ---------------------------------------------------------------------------
# Per-image OCR runner: wiring + behaviour
#
# When extracting a PDF with a vision LLM, the ETL service must ALSO
# pass an ``ocr_runner`` to picture_describer. The runner is a closure
# that re-feeds each extracted image through a vision-LLM-less
# EtlPipelineService -- i.e. the same OCR engine that handles
# standalone image uploads (Docling/Azure DI/LlamaCloud) gets a crack
# at each embedded image, with the text attached to the inline block.
# ---------------------------------------------------------------------------
async def test_extract_pdf_passes_ocr_runner_to_describe_pictures(tmp_path, mocker):
"""The ETL service must wire an ocr_runner kwarg to describe_pictures."""
pdf_file = tmp_path / "report.pdf"
pdf_file.write_bytes(b"%PDF-1.4 fake content")
mocker.patch("app.config.config.ETL_SERVICE", "DOCLING")
fake_docling = mocker.AsyncMock()
fake_docling.process_document.return_value = {"content": "# Parsed PDF text"}
mocker.patch(
"app.services.docling_service.create_docling_service",
return_value=fake_docling,
)
describe_mock = mocker.patch(
"app.etl_pipeline.picture_describer.describe_pictures",
new=mocker.AsyncMock(return_value=_fake_extraction_result()),
)
fake_llm = mocker.MagicMock()
await EtlPipelineService(vision_llm=fake_llm).extract(
EtlRequest(file_path=str(pdf_file), filename="report.pdf")
)
describe_mock.assert_awaited_once()
_, kwargs = describe_mock.await_args
assert "ocr_runner" in kwargs
assert callable(kwargs["ocr_runner"])
async def test_extract_pdf_ocr_runner_invokes_document_parser_on_image(
tmp_path, mocker
):
"""The OCR runner closure should re-extract each image via the parser.
We capture the runner that the ETL service passes to
describe_pictures, invoke it with a fake image path, and assert
that Docling was called with that image. This proves the closure
is wired to a vision-LLM-less sub-pipeline (otherwise it would
recurse into the vision LLM and never hit the OCR engine).
"""
pdf_file = tmp_path / "report.pdf"
pdf_file.write_bytes(b"%PDF-1.4 fake content")
image_file = tmp_path / "Im0.png"
image_file.write_bytes(b"\x89PNG\r\n\x1a\n" + b"\x00" * 100)
mocker.patch("app.config.config.ETL_SERVICE", "DOCLING")
fake_docling = mocker.AsyncMock()
fake_docling.process_document.return_value = {"content": "Slice 24 / 60 L R"}
mocker.patch(
"app.services.docling_service.create_docling_service",
return_value=fake_docling,
)
captured: dict = {}
async def capture_runner(*args, **kwargs):
captured["runner"] = kwargs["ocr_runner"]
return _fake_extraction_result()
mocker.patch(
"app.etl_pipeline.picture_describer.describe_pictures",
new=capture_runner,
)
fake_llm = mocker.MagicMock()
await EtlPipelineService(vision_llm=fake_llm).extract(
EtlRequest(file_path=str(pdf_file), filename="report.pdf")
)
runner = captured["runner"]
ocr_text = await runner(str(image_file), "Im0.png")
assert ocr_text == "Slice 24 / 60 L R"
# Docling was invoked twice in total: once for the PDF, once for
# the image we re-fed via the runner.
assert fake_docling.process_document.await_count == 2
async def test_extract_pdf_ocr_runner_returns_empty_on_unsupported_image(
tmp_path, mocker
):
"""Unsupported image format → runner returns empty string, doesn't raise.
Common case: a PDF embeds a JPEG2000 or CCITT-TIFF image that
Docling can't load. We don't want an unsupported format on ONE
embedded image to spoil the whole PDF extraction; the runner
should swallow the EtlUnsupportedFileError and return "" so the
image gets a description but no OCR tag.
"""
pdf_file = tmp_path / "report.pdf"
pdf_file.write_bytes(b"%PDF-1.4 fake content")
weird_image = tmp_path / "Im0.jp2" # JPEG2000, unlikely to be supported
weird_image.write_bytes(b"\x00\x00\x00\x0cjP" + b"\x00" * 50)
mocker.patch("app.config.config.ETL_SERVICE", "DOCLING")
fake_docling = mocker.AsyncMock()
fake_docling.process_document.return_value = {"content": "# Parsed PDF text"}
mocker.patch(
"app.services.docling_service.create_docling_service",
return_value=fake_docling,
)
captured: dict = {}
async def capture_runner(*args, **kwargs):
captured["runner"] = kwargs["ocr_runner"]
return _fake_extraction_result()
mocker.patch(
"app.etl_pipeline.picture_describer.describe_pictures",
new=capture_runner,
)
fake_llm = mocker.MagicMock()
await EtlPipelineService(vision_llm=fake_llm).extract(
EtlRequest(file_path=str(pdf_file), filename="report.pdf")
)
runner = captured["runner"]
ocr_text = await runner(str(weird_image), "Im0.jp2")
assert ocr_text == ""
# ---------------------------------------------------------------------------
# Processing Mode enum tests
# ---------------------------------------------------------------------------

View file

@ -0,0 +1,972 @@
"""Unit tests for the picture_describer module.
Covers:
- :func:`describe_pictures` -- the PDF image walker + per-image vision
LLM call (structured output split into ``ocr_text`` and
``description``);
- :func:`inject_descriptions_inline` -- in-place replacement of image
placeholders / captions in the parser markdown;
- :func:`merge_descriptions_into_markdown` -- the top-level helper
that inlines what it can and appends what it can't;
- :func:`render_appended_section` -- the appended-fallback renderer.
"""
from __future__ import annotations
from unittest.mock import AsyncMock, MagicMock
import pytest
from app.etl_pipeline.picture_describer import (
PictureDescription,
PictureExtractionResult,
describe_pictures,
inject_descriptions_inline,
merge_descriptions_into_markdown,
render_appended_section,
)
pytestmark = pytest.mark.unit
def _make_image_obj(name: str, data: bytes):
"""Mimic pypdf's ImageFile object shape for the bits we use."""
img = MagicMock()
img.name = name
img.data = data
return img
# ---------------------------------------------------------------------------
# describe_pictures: short-circuits
# ---------------------------------------------------------------------------
async def test_describe_pictures_no_op_for_non_pdf(tmp_path):
"""Non-PDF files are silently no-op'd; we don't try to extract images."""
docx_file = tmp_path / "report.docx"
docx_file.write_bytes(b"PK fake docx")
fake_llm = AsyncMock()
result = await describe_pictures(str(docx_file), "report.docx", fake_llm)
assert result.descriptions == []
assert result.skipped_too_large == 0
fake_llm.ainvoke.assert_not_called()
async def test_describe_pictures_no_op_when_vision_llm_is_none(tmp_path):
"""If the caller didn't provide a vision LLM, we no-op even for PDFs."""
pdf_file = tmp_path / "report.pdf"
pdf_file.write_bytes(b"%PDF-1.4 fake")
result = await describe_pictures(str(pdf_file), "report.pdf", None)
assert result.descriptions == []
async def test_describe_pictures_no_op_for_pdf_with_no_images(tmp_path, mocker):
"""A PDF that pypdf can open but contains zero images returns empty."""
pdf_file = tmp_path / "report.pdf"
pdf_file.write_bytes(b"%PDF-1.4 fake")
fake_reader = MagicMock()
fake_reader.pages = [MagicMock(images=[]), MagicMock(images=[])]
mocker.patch("pypdf.PdfReader", return_value=fake_reader)
fake_llm = AsyncMock()
result = await describe_pictures(str(pdf_file), "report.pdf", fake_llm)
assert result.descriptions == []
fake_llm.ainvoke.assert_not_called()
# ---------------------------------------------------------------------------
# describe_pictures: happy paths
# ---------------------------------------------------------------------------
async def test_describe_pictures_runs_vision_llm_per_image(tmp_path, mocker):
"""Every eligible image gets exactly one description-only vision call."""
pdf_file = tmp_path / "report.pdf"
pdf_file.write_bytes(b"%PDF-1.4 fake")
img_a = _make_image_obj("Im0.jpeg", b"\xff\xd8\xff\xe0" + b"\xab" * 2000)
img_b = _make_image_obj("Im1.png", b"\x89PNG\r\n\x1a\n" + b"\xcd" * 2000)
page1 = MagicMock(images=[img_a])
page2 = MagicMock(images=[img_b])
fake_reader = MagicMock()
fake_reader.pages = [page1, page2]
mocker.patch("pypdf.PdfReader", return_value=fake_reader)
parse_mock = mocker.patch(
"app.etl_pipeline.parsers.vision_llm.parse_image_for_description",
new=AsyncMock(side_effect=["Description A", "Description B"]),
)
fake_llm = MagicMock()
result = await describe_pictures(str(pdf_file), "report.pdf", fake_llm)
assert len(result.descriptions) == 2
by_name = {d.name: d.description for d in result.descriptions}
assert by_name == {"Im0.jpeg": "Description A", "Im1.png": "Description B"}
assert all(d.page_number in (1, 2) for d in result.descriptions)
assert parse_mock.await_count == 2
async def test_describe_pictures_dedups_by_hash(tmp_path, mocker):
"""An image that appears N times in the PDF is described once."""
pdf_file = tmp_path / "report.pdf"
pdf_file.write_bytes(b"%PDF-1.4 fake")
payload = b"\x89PNG\r\n\x1a\n" + b"\x42" * 2000
img = _make_image_obj("logo.png", payload)
page1 = MagicMock(images=[img])
page2 = MagicMock(images=[_make_image_obj("logo.png", payload)])
page3 = MagicMock(images=[_make_image_obj("logo.png", payload)])
fake_reader = MagicMock()
fake_reader.pages = [page1, page2, page3]
mocker.patch("pypdf.PdfReader", return_value=fake_reader)
parse_mock = mocker.patch(
"app.etl_pipeline.parsers.vision_llm.parse_image_for_description",
new=AsyncMock(return_value="Logo desc"),
)
fake_llm = MagicMock()
result = await describe_pictures(str(pdf_file), "report.pdf", fake_llm)
assert len(result.descriptions) == 1
assert result.skipped_duplicate == 2
assert parse_mock.await_count == 1
async def test_describe_pictures_skips_too_small_images(tmp_path, mocker):
"""Sub-1KB images (tracking pixels, dots, etc.) are skipped."""
pdf_file = tmp_path / "report.pdf"
pdf_file.write_bytes(b"%PDF-1.4 fake")
tiny = _make_image_obj("dot.png", b"\x89PNG\r\n\x1a\n")
big = _make_image_obj("ct.jpeg", b"\xff\xd8\xff\xe0" + b"\xab" * 3000)
page = MagicMock(images=[tiny, big])
fake_reader = MagicMock()
fake_reader.pages = [page]
mocker.patch("pypdf.PdfReader", return_value=fake_reader)
parse_mock = mocker.patch(
"app.etl_pipeline.parsers.vision_llm.parse_image_for_description",
new=AsyncMock(return_value="CT scan"),
)
fake_llm = MagicMock()
result = await describe_pictures(str(pdf_file), "report.pdf", fake_llm)
assert len(result.descriptions) == 1
assert result.descriptions[0].name == "ct.jpeg"
assert result.skipped_too_small == 1
assert parse_mock.await_count == 1
async def test_describe_pictures_skips_too_large_images(tmp_path, mocker):
"""Images larger than the vision LLM's per-image cap are skipped."""
pdf_file = tmp_path / "report.pdf"
pdf_file.write_bytes(b"%PDF-1.4 fake")
huge = _make_image_obj("huge.jpeg", b"\xff" * (6 * 1024 * 1024))
ok = _make_image_obj("ok.jpeg", b"\xff\xd8\xff\xe0" + b"\xab" * 2000)
page = MagicMock(images=[huge, ok])
fake_reader = MagicMock()
fake_reader.pages = [page]
mocker.patch("pypdf.PdfReader", return_value=fake_reader)
parse_mock = mocker.patch(
"app.etl_pipeline.parsers.vision_llm.parse_image_for_description",
new=AsyncMock(return_value="OK image"),
)
fake_llm = MagicMock()
result = await describe_pictures(str(pdf_file), "report.pdf", fake_llm)
assert len(result.descriptions) == 1
assert result.descriptions[0].name == "ok.jpeg"
assert result.skipped_too_large == 1
assert parse_mock.await_count == 1
async def test_describe_pictures_swallows_per_image_failure(tmp_path, mocker):
"""A vision LLM failure on one image must not kill the whole document."""
pdf_file = tmp_path / "report.pdf"
pdf_file.write_bytes(b"%PDF-1.4 fake")
img_a = _make_image_obj("a.jpeg", b"\xff\xd8" + b"\xab" * 2000)
img_b = _make_image_obj("b.jpeg", b"\xff\xd8" + b"\xcd" * 2000)
page = MagicMock(images=[img_a, img_b])
fake_reader = MagicMock()
fake_reader.pages = [page]
mocker.patch("pypdf.PdfReader", return_value=fake_reader)
mocker.patch(
"app.etl_pipeline.parsers.vision_llm.parse_image_for_description",
new=AsyncMock(side_effect=[RuntimeError("vision blew up"), "Success"]),
)
fake_llm = MagicMock()
result = await describe_pictures(str(pdf_file), "report.pdf", fake_llm)
assert len(result.descriptions) == 1
assert result.descriptions[0].description == "Success"
assert result.failed == 1
async def test_describe_pictures_handles_pypdf_open_failure(tmp_path, mocker):
"""A malformed PDF that pypdf can't open returns an empty result."""
pdf_file = tmp_path / "broken.pdf"
pdf_file.write_bytes(b"not a pdf")
mocker.patch("pypdf.PdfReader", side_effect=ValueError("EOF marker not found"))
fake_llm = MagicMock()
result = await describe_pictures(str(pdf_file), "broken.pdf", fake_llm)
assert result.descriptions == []
# ---------------------------------------------------------------------------
# inject_descriptions_inline: replacement patterns
# ---------------------------------------------------------------------------
def _desc(name="Im0", description="A CT scan."):
return PictureDescription(
page_number=1,
ordinal_in_page=0,
name=name,
sha256="aa",
description=description,
)
def test_inject_no_op_when_no_descriptions():
markdown = "# Title\n\nbody text\n"
result = PictureExtractionResult()
out, n = inject_descriptions_inline(markdown, result)
assert out == markdown
assert n == 0
def test_inject_replaces_placeholder_with_caption():
"""`<!-- image -->` + `Image: <name>` together becomes one block.
This is the most common medxpertqa case: our renderer puts a caption
line right below the embedded JPEG, and Docling preserves both.
"""
markdown = (
"# Case\n\n"
"Clinical text...\n\n"
"<!-- image -->\nImage: MM-130-a.jpeg\n\n"
"Answer choices: A) ...\n"
)
result = PictureExtractionResult(descriptions=[_desc(name="Im0")])
out, n = inject_descriptions_inline(markdown, result)
assert n == 1
assert "<!-- image -->" not in out
assert "Image: MM-130-a.jpeg" not in out # caption consumed
# New format: horizontal-rule-delimited section with "Embedded
# image:" anchor and named "Visual description:" section. No
# blockquote wrapping -- nested blocks (lists, code, tables) inside
# a blockquote are silently dropped by Streamdown / remark.
assert "**Embedded image:** `MM-130-a.jpeg`" in out
assert "**Visual description:**" in out
assert "A CT scan." in out
# Block is delimited by horizontal rules so it stands out from
# surrounding paragraphs.
assert "\n---\n" in out
# No OCR section -- this fixture has no ocr_text on its descriptions.
assert "**OCR text:**" not in out
# No raw HTML tags / blockquote prefixes leak.
assert "<image" not in out
assert "</image>" not in out
assert "> **Embedded image:**" not in out # we no longer wrap in `>`
# Surrounding context is preserved.
assert "Clinical text..." in out
assert "Answer choices: A) ..." in out
def test_inject_uses_pypdf_name_when_no_caption():
"""`<!-- image -->` alone uses the pypdf-given name as the attribute."""
markdown = "# Case\n\n<!-- image -->\n\nMore text\n"
result = PictureExtractionResult(descriptions=[_desc(name="Im0")])
out, n = inject_descriptions_inline(markdown, result)
assert n == 1
assert "**Embedded image:** `Im0`" in out
def test_inject_replaces_bare_caption():
"""A bare `Image: <name>` line (no placeholder) still gets replaced."""
markdown = "# Case\n\nText...\nImage: scan.jpeg\nMore text\n"
result = PictureExtractionResult(descriptions=[_desc(name="Im0")])
out, n = inject_descriptions_inline(markdown, result)
assert n == 1
assert "**Embedded image:** `scan.jpeg`" in out
assert "Image: scan.jpeg" not in out
def test_inject_handles_multiple_images_in_order():
"""Two placeholders + two descriptions: each consumed in document order."""
markdown = (
"Page 1\n\n<!-- image -->\nImage: a.jpeg\n\n"
"Between\n\n<!-- image -->\nImage: b.jpeg\n\nEnd\n"
)
result = PictureExtractionResult(
descriptions=[
PictureDescription(
page_number=1,
ordinal_in_page=0,
name="Im0",
sha256="aa",
description="Desc A",
),
PictureDescription(
page_number=2,
ordinal_in_page=0,
name="Im1",
sha256="bb",
description="Desc B",
),
]
)
out, n = inject_descriptions_inline(markdown, result)
assert n == 2
assert "**Embedded image:** `a.jpeg`" in out
assert "**Embedded image:** `b.jpeg`" in out
assert out.index("a.jpeg") < out.index("b.jpeg")
assert "Desc A" in out and "Desc B" in out
def test_inject_returns_remaining_count_when_more_descriptions_than_markers():
"""Three descriptions, one marker -> only one inlined, two leftover."""
markdown = "Just one <!-- image --> here.\n"
result = PictureExtractionResult(
descriptions=[
_desc(name="Im0", description="First"),
_desc(name="Im1", description="Second"),
_desc(name="Im2", description="Third"),
]
)
out, n = inject_descriptions_inline(markdown, result)
assert n == 1
assert "**Embedded image:** `Im0`" in out
assert "**Embedded image:** `Im1`" not in out
def test_inject_returns_zero_when_no_markers_present():
"""Markdown with no image markers at all returns the input unchanged."""
markdown = "# Title\n\nJust text. No images mentioned at all.\n"
result = PictureExtractionResult(descriptions=[_desc(name="Im0")])
out, n = inject_descriptions_inline(markdown, result)
assert n == 0
assert out == markdown
# ---------------------------------------------------------------------------
# render_appended_section
# ---------------------------------------------------------------------------
def test_render_appended_empty_when_nothing_passed():
assert render_appended_section([]) == ""
def test_render_appended_renders_each_image_as_block():
descriptions = [
_desc(name="MM-130-a.jpeg", description="CT scan"),
_desc(name="MM-130-b.jpeg", description="Bar chart"),
]
rendered = render_appended_section(descriptions)
assert "## Image Content (vision-LLM extracted)" in rendered
assert "**Embedded image:** `MM-130-a.jpeg`" in rendered
assert "CT scan" in rendered
assert "**Embedded image:** `MM-130-b.jpeg`" in rendered
assert "Bar chart" in rendered
# Each image block is delimited by horizontal rules.
assert rendered.count("\n---\n") >= 2
# No raw HTML / XML / blockquote prefixes.
assert "<image" not in rendered
assert "> **Embedded image:**" not in rendered
assert "**OCR text:**" not in rendered
def test_render_appended_includes_skip_notes():
descriptions = [_desc()]
skip_result = PictureExtractionResult(
descriptions=descriptions,
skipped_too_small=2,
skipped_too_large=1,
skipped_duplicate=3,
failed=1,
)
rendered = render_appended_section(descriptions, skip_notes=skip_result)
assert "_Note:" in rendered
assert "2 too small" in rendered
assert "1 too large" in rendered
assert "3 duplicate" in rendered
assert "1 failed" in rendered
# ---------------------------------------------------------------------------
# merge_descriptions_into_markdown: top-level
# ---------------------------------------------------------------------------
def test_merge_inlines_when_marker_present():
markdown = "Text...\n\n<!-- image -->\nImage: scan.jpeg\n\nMore text\n"
result = PictureExtractionResult(descriptions=[_desc(name="Im0")])
out = merge_descriptions_into_markdown(markdown, result)
assert "**Embedded image:** `scan.jpeg`" in out
# Nothing leaked into an appended section -- we should NOT see the
# appended-section heading because everything went inline.
assert "## Image Content" not in out
def test_merge_appends_when_no_marker_present():
"""Zero markers means everything goes into an appended section."""
markdown = "Pure text doc, no image markers.\n"
result = PictureExtractionResult(
descriptions=[_desc(name="Im0", description="An image desc.")]
)
out = merge_descriptions_into_markdown(markdown, result)
assert "Pure text doc" in out
assert "## Image Content (vision-LLM extracted)" in out
assert "**Embedded image:** `Im0`" in out
def test_merge_appends_leftovers_with_distinct_heading():
"""One marker, two descriptions -> one inline, second appended under
a heading that signals it's a leftover.
"""
markdown = "Text\n\n<!-- image -->\nImage: a.jpeg\n\nEnd\n"
result = PictureExtractionResult(
descriptions=[
_desc(name="Im0", description="First"),
_desc(name="Im1", description="Second"),
]
)
out = merge_descriptions_into_markdown(markdown, result)
assert "**Embedded image:** `a.jpeg`" in out # inlined
assert "## Image Content (additional, no inline marker found)" in out
assert "**Embedded image:** `Im1`" in out # appended
# ---------------------------------------------------------------------------
# describe_pictures: ocr_runner integration
#
# These tests cover the per-image OCR side-channel: when the caller
# supplies an ``ocr_runner`` callable, each extracted image is sent
# both to the vision LLM (visual description) and to the OCR runner
# (text-in-image), in parallel. The OCR text -- if any -- is recorded
# on the PictureDescription and rendered in the inline block.
# ---------------------------------------------------------------------------
async def test_describe_pictures_calls_ocr_runner_per_image(tmp_path, mocker):
"""When an ocr_runner is provided, it's invoked once per eligible image."""
pdf_file = tmp_path / "report.pdf"
pdf_file.write_bytes(b"%PDF-1.4 fake")
img_a = _make_image_obj("Im0.jpeg", b"\xff\xd8\xff\xe0" + b"\xab" * 2000)
img_b = _make_image_obj("Im1.png", b"\x89PNG\r\n\x1a\n" + b"\xcd" * 2000)
fake_reader = MagicMock()
fake_reader.pages = [MagicMock(images=[img_a, img_b])]
mocker.patch("pypdf.PdfReader", return_value=fake_reader)
mocker.patch(
"app.etl_pipeline.parsers.vision_llm.parse_image_for_description",
new=AsyncMock(side_effect=["Visual A", "Visual B"]),
)
ocr_runner = AsyncMock(side_effect=["OCR text A", "OCR text B"])
fake_llm = MagicMock()
result = await describe_pictures(
str(pdf_file), "report.pdf", fake_llm, ocr_runner=ocr_runner
)
assert ocr_runner.await_count == 2
by_name = {d.name: d.ocr_text for d in result.descriptions}
assert by_name == {"Im0.jpeg": "OCR text A", "Im1.png": "OCR text B"}
async def test_describe_pictures_runs_vision_and_ocr_in_parallel(tmp_path, mocker):
"""Vision LLM and OCR run concurrently per image, not sequentially.
We verify this by recording call timestamps: if both finish within
a small window relative to the per-call sleep, they ran in parallel.
"""
import asyncio
import time
pdf_file = tmp_path / "report.pdf"
pdf_file.write_bytes(b"%PDF-1.4 fake")
img = _make_image_obj("Im0.jpeg", b"\xff\xd8\xff\xe0" + b"\xab" * 2000)
fake_reader = MagicMock()
fake_reader.pages = [MagicMock(images=[img])]
mocker.patch("pypdf.PdfReader", return_value=fake_reader)
sleep_each = 0.05 # 50ms per call
async def slow_vision(*args, **kwargs):
await asyncio.sleep(sleep_each)
return "Visual"
async def slow_ocr(*args, **kwargs):
await asyncio.sleep(sleep_each)
return "OCR"
mocker.patch(
"app.etl_pipeline.parsers.vision_llm.parse_image_for_description",
new=slow_vision,
)
fake_llm = MagicMock()
started = time.perf_counter()
result = await describe_pictures(
str(pdf_file), "report.pdf", fake_llm, ocr_runner=slow_ocr
)
elapsed = time.perf_counter() - started
assert len(result.descriptions) == 1
assert result.descriptions[0].ocr_text == "OCR"
# Sequential would be ~2*sleep_each. Parallel is ~1*sleep_each + overhead.
# Be generous with the bound so we're not flaky on slow CI.
assert elapsed < 1.5 * sleep_each, (
f"vision+OCR appear to be sequential (took {elapsed:.3f}s)"
)
async def test_describe_pictures_treats_empty_ocr_as_none(tmp_path, mocker):
"""Empty / whitespace-only OCR result is normalised to None.
This means the rendered image block won't carry an empty
"OCR text" section for images that contain no text at all
(e.g. a clean radiograph).
"""
pdf_file = tmp_path / "report.pdf"
pdf_file.write_bytes(b"%PDF-1.4 fake")
img = _make_image_obj("scan.jpeg", b"\xff\xd8\xff\xe0" + b"\xab" * 2000)
fake_reader = MagicMock()
fake_reader.pages = [MagicMock(images=[img])]
mocker.patch("pypdf.PdfReader", return_value=fake_reader)
mocker.patch(
"app.etl_pipeline.parsers.vision_llm.parse_image_for_description",
new=AsyncMock(return_value="A radiograph."),
)
ocr_runner = AsyncMock(return_value=" \n \n")
fake_llm = MagicMock()
result = await describe_pictures(
str(pdf_file), "report.pdf", fake_llm, ocr_runner=ocr_runner
)
assert len(result.descriptions) == 1
assert result.descriptions[0].ocr_text is None
async def test_describe_pictures_swallows_ocr_runner_failure(tmp_path, mocker):
"""An OCR runner exception must not kill the description for that image.
OCR is supplementary; the vision LLM's description is the primary
payload. If OCR blows up we drop the OCR field for that image and
keep the description.
"""
pdf_file = tmp_path / "report.pdf"
pdf_file.write_bytes(b"%PDF-1.4 fake")
img = _make_image_obj("scan.jpeg", b"\xff\xd8\xff\xe0" + b"\xab" * 2000)
fake_reader = MagicMock()
fake_reader.pages = [MagicMock(images=[img])]
mocker.patch("pypdf.PdfReader", return_value=fake_reader)
mocker.patch(
"app.etl_pipeline.parsers.vision_llm.parse_image_for_description",
new=AsyncMock(return_value="A radiograph."),
)
ocr_runner = AsyncMock(side_effect=RuntimeError("OCR backend down"))
fake_llm = MagicMock()
result = await describe_pictures(
str(pdf_file), "report.pdf", fake_llm, ocr_runner=ocr_runner
)
assert len(result.descriptions) == 1
assert result.descriptions[0].description == "A radiograph."
assert result.descriptions[0].ocr_text is None
assert result.failed == 0 # the IMAGE didn't fail; only its OCR did
async def test_describe_pictures_vision_failure_with_ocr_runner_skips_image(
tmp_path, mocker
):
"""If the vision LLM fails, the image is skipped even if OCR succeeded.
The inline block's primary purpose is the visual description; an
OCR-only block would be misleading (it'd look like the vision
pipeline ran when it didn't), so we treat vision failure as image
failure regardless of OCR outcome.
"""
pdf_file = tmp_path / "report.pdf"
pdf_file.write_bytes(b"%PDF-1.4 fake")
img = _make_image_obj("scan.jpeg", b"\xff\xd8\xff\xe0" + b"\xab" * 2000)
fake_reader = MagicMock()
fake_reader.pages = [MagicMock(images=[img])]
mocker.patch("pypdf.PdfReader", return_value=fake_reader)
mocker.patch(
"app.etl_pipeline.parsers.vision_llm.parse_image_for_description",
new=AsyncMock(side_effect=RuntimeError("vision blew up")),
)
ocr_runner = AsyncMock(return_value="OCR text")
fake_llm = MagicMock()
result = await describe_pictures(
str(pdf_file), "report.pdf", fake_llm, ocr_runner=ocr_runner
)
assert result.descriptions == []
assert result.failed == 1
async def test_describe_pictures_no_ocr_runner_keeps_ocr_text_none(tmp_path, mocker):
"""Backward compat: omitting ocr_runner produces description-only blocks."""
pdf_file = tmp_path / "report.pdf"
pdf_file.write_bytes(b"%PDF-1.4 fake")
img = _make_image_obj("Im0.jpeg", b"\xff\xd8\xff\xe0" + b"\xab" * 2000)
fake_reader = MagicMock()
fake_reader.pages = [MagicMock(images=[img])]
mocker.patch("pypdf.PdfReader", return_value=fake_reader)
mocker.patch(
"app.etl_pipeline.parsers.vision_llm.parse_image_for_description",
new=AsyncMock(return_value="Visual"),
)
fake_llm = MagicMock()
result = await describe_pictures(str(pdf_file), "report.pdf", fake_llm)
assert len(result.descriptions) == 1
assert result.descriptions[0].ocr_text is None
# ---------------------------------------------------------------------------
# Rendering: "OCR text" section appears iff PictureDescription.ocr_text is set
# ---------------------------------------------------------------------------
def _desc_with_ocr(name="Im0", description="A CT scan.", ocr_text="L R 10mm"):
return PictureDescription(
page_number=1,
ordinal_in_page=0,
name=name,
sha256="aa",
description=description,
ocr_text=ocr_text,
)
def test_inject_renders_ocr_section_when_ocr_text_present():
markdown = "Text\n\n<!-- image -->\nImage: scan.jpeg\n\nMore\n"
result = PictureExtractionResult(
descriptions=[_desc_with_ocr(name="Im0", ocr_text="L R 10mm")]
)
out, n = inject_descriptions_inline(markdown, result)
assert n == 1
assert "**Embedded image:** `scan.jpeg`" in out
assert "**OCR text:**" in out
assert "L R 10mm" in out
# OCR section comes before the visual description (literal text
# first, interpretation second).
assert out.index("**OCR text:**") < out.index("**Visual description:**")
# Critical: no nested-block constructs (fenced code, blockquote)
# that previous formats relied on -- both broke in Streamdown /
# PlateJS by escaping their container and dropping content.
assert "```" not in out
assert "> **" not in out
def test_inject_renders_multiline_ocr_with_hard_breaks():
"""Multi-line OCR uses trailing-two-spaces hard breaks so each
line renders on its own row, without needing a fragile fenced
code block or blockquote wrapper."""
markdown = "Text\n\n<!-- image -->\nImage: scan.jpeg\n\nMore\n"
ocr_multi = "Slice 24 / 60\nL\nR\n10 mm"
result = PictureExtractionResult(
descriptions=[_desc_with_ocr(name="Im0", ocr_text=ocr_multi)]
)
out, _ = inject_descriptions_inline(markdown, result)
# Every OCR line is present.
for line in ("Slice 24 / 60", "L", "R", "10 mm"):
assert line in out
# Non-last OCR lines get the trailing two-space hard break.
assert "Slice 24 / 60 \n" in out
assert "\nL \n" in out
assert "\nR \n" in out
# Last OCR line must NOT carry the two-space hard break (no stray <br>).
assert "10 mm \n" not in out
assert "10 mm\n" in out
def test_render_appended_renders_ocr_section_when_ocr_text_present():
descriptions = [
_desc_with_ocr(
name="MM-130-a.jpeg",
description="Axial CT.",
ocr_text="Slice 24 / 60",
),
]
rendered = render_appended_section(descriptions)
assert "**OCR text:**" in rendered
assert "Slice 24 / 60" in rendered
assert "Axial CT." in rendered
def test_render_omits_ocr_section_when_ocr_text_is_none():
descriptions = [_desc(name="Im0", description="A clean radiograph.")]
rendered = render_appended_section(descriptions)
assert "**Embedded image:** `Im0`" in rendered
assert "**OCR text:**" not in rendered
assert "**Visual description:**" in rendered
# No raw HTML / blockquote prefixes.
assert "<image" not in rendered
assert "> **" not in rendered
# ---------------------------------------------------------------------------
# inject_descriptions_inline: <figure> blocks (layout-aware parsers)
#
# Azure Document Intelligence's ``prebuilt-layout`` and LlamaCloud
# premium both emit ``<figure>...</figure>`` blocks that already contain
# the parser's own OCR of the figure (chart bar values, axis labels,
# inline ``<figcaption>``, embedded ``<table>`` for tabular figures).
# That parser-side content is useful for retrieval on its own, so we
# PRESERVE the figure verbatim and append our vision-LLM block
# immediately after rather than substituting for it.
# ---------------------------------------------------------------------------
def test_inject_appends_block_after_figure_preserving_parser_content():
"""Figure block stays intact; vision-LLM block goes right after it."""
markdown = (
"Some narrative text.\n\n"
"<figure>\n\n"
"Republican\n68\nDemocrat\n30\n"
"\n</figure>\n\n"
"Following paragraph.\n"
)
result = PictureExtractionResult(
descriptions=[_desc(name="Im0", description="Bar chart of party ID.")]
)
out, n = inject_descriptions_inline(markdown, result)
assert n == 1
# Original figure is preserved verbatim -- the parser's OCR'd
# numbers must still be searchable.
assert "<figure>" in out
assert "</figure>" in out
assert "Republican" in out and "68" in out
# Our vision-LLM block follows the figure, not before / inside it.
assert "**Embedded image:** `Im0`" in out
assert "Bar chart of party ID." in out
figure_close = out.index("</figure>")
embedded_at = out.index("**Embedded image:** `Im0`")
assert figure_close < embedded_at, "block must be appended AFTER </figure>"
# Surrounding narrative is preserved.
assert "Some narrative text." in out
assert "Following paragraph." in out
def test_inject_handles_multiple_figures_in_document_order():
"""N figures + N descriptions: each pair lands in the right place."""
markdown = (
"Page 1\n\n<figure>\nChart A bars\n</figure>\n\n"
"Between\n\n<figure>\nChart B bars\n</figure>\n\n"
"End.\n"
)
result = PictureExtractionResult(
descriptions=[
PictureDescription(
page_number=1,
ordinal_in_page=0,
name="Im0",
sha256="aa",
description="Description of chart A.",
),
PictureDescription(
page_number=2,
ordinal_in_page=0,
name="Im1",
sha256="bb",
description="Description of chart B.",
),
]
)
out, n = inject_descriptions_inline(markdown, result)
assert n == 2
# Both figures preserved; both descriptions inlined; order matches.
assert out.count("<figure>") == 2
assert out.count("</figure>") == 2
assert "Description of chart A." in out
assert "Description of chart B." in out
assert out.index("Description of chart A.") < out.index("Description of chart B.")
# Each description appears AFTER its corresponding </figure>.
first_close = out.index("</figure>")
assert first_close < out.index("Description of chart A.")
second_close = out.index("</figure>", first_close + 1)
assert second_close < out.index("Description of chart B.")
def test_inject_figures_with_attributes_and_nested_tags():
"""``<figure>`` with attributes and nested tags is matched and preserved."""
markdown = (
'<figure id="fig-3" class="chart">\n'
"<figcaption>Source: Pew Research</figcaption>\n"
"<table><tr><td>Republican</td><td>57</td></tr></table>\n"
"</figure>\n"
)
result = PictureExtractionResult(
descriptions=[_desc(name="Im0", description="Survey table.")]
)
out, n = inject_descriptions_inline(markdown, result)
assert n == 1
# All nested HTML is preserved (chunking will pick it up).
assert 'id="fig-3"' in out
assert "<figcaption>Source: Pew Research</figcaption>" in out
assert "<table>" in out and "Republican" in out and "57" in out
# Our block sits after the closing tag.
assert out.index("</figure>") < out.index("**Embedded image:** `Im0`")
def test_inject_figures_more_descriptions_than_figures_returns_remaining():
"""Three descriptions, one figure -> one inlined, two left for caller."""
markdown = "Text.\n<figure>\nbar values\n</figure>\nMore.\n"
result = PictureExtractionResult(
descriptions=[
_desc(name="Im0", description="First desc."),
_desc(name="Im1", description="Second desc."),
_desc(name="Im2", description="Third desc."),
]
)
out, n = inject_descriptions_inline(markdown, result)
assert n == 1
assert "First desc." in out
# Leftovers are the caller's job; inject_descriptions_inline does
# not append them on its own.
assert "Second desc." not in out
assert "Third desc." not in out
def test_inject_figures_more_figures_than_descriptions_leaves_extras_untouched():
"""Two figures, one description -> first figure enriched, second left raw."""
markdown = (
"<figure>\nfigure 1 content\n</figure>\n<figure>\nfigure 2 content\n</figure>\n"
)
result = PictureExtractionResult(
descriptions=[_desc(name="Im0", description="Only description.")]
)
out, n = inject_descriptions_inline(markdown, result)
assert n == 1
# Both figures still present; only the first one was enriched.
assert out.count("<figure>") == 2
assert "Only description." in out
# Second figure has no embedded-image block immediately after it.
second_open = out.index("<figure>", out.index("<figure>") + 1)
second_close = out.index("</figure>", second_open)
after_second = out[second_close:]
assert "**Embedded image:**" not in after_second
def test_merge_inlines_at_figure_boundary():
"""Top-level helper does the right thing with figures (no leftover section)."""
markdown = "Lead.\n<figure>\nbars\n</figure>\nTrailer.\n"
result = PictureExtractionResult(
descriptions=[_desc(name="Im0", description="Bar chart.")]
)
out = merge_descriptions_into_markdown(markdown, result)
# Inline succeeded -> no appended-section heading.
assert "## Image Content" not in out
assert "Bar chart." in out
assert "<figure>" in out and "</figure>" in out
def test_inject_figures_then_falls_through_to_docling_marker():
"""Mixed-marker doc: figure consumed first, then Docling placeholder.
Defensive -- single docs are usually one parser's output, but if a
pipeline ever stitches two parsers' markdowns together the inliner
should still place each description.
"""
markdown = (
"<figure>\nChart bars: 50, 40, 30\n</figure>\n\n"
"Later in the doc:\n\n"
"<!-- image -->\nImage: scan.jpeg\n\n"
"End.\n"
)
result = PictureExtractionResult(
descriptions=[
_desc(name="Im0", description="Chart description."),
_desc(name="Im1", description="Scan description."),
]
)
out, n = inject_descriptions_inline(markdown, result)
assert n == 2
# Figure preserved + augmented.
assert "<figure>" in out and "Chart bars: 50, 40, 30" in out
assert "Chart description." in out
# Docling placeholder + caption replaced.
assert "<!-- image -->" not in out
assert "Image: scan.jpeg" not in out
assert "**Embedded image:** `scan.jpeg`" in out
assert "Scan description." in out

View file

@ -0,0 +1,146 @@
"""Unit tests for the vision_llm parser helpers.
Two helpers exist:
- :func:`parse_with_vision_llm` -- single-shot for standalone image
uploads (.png/.jpg/etc). Returns combined markdown (description +
verbatim OCR mixed) since the image *is* the document.
- :func:`parse_image_for_description` -- per-image-in-PDF call. Returns
visual description only; OCR is the ETL service's job.
"""
from __future__ import annotations
from unittest.mock import AsyncMock, MagicMock
import pytest
pytestmark = pytest.mark.unit
# ---------------------------------------------------------------------------
# parse_with_vision_llm: legacy single-shot path
# ---------------------------------------------------------------------------
async def test_parse_with_vision_llm_returns_combined_markdown(tmp_path):
"""Standalone image uploads still go through the combined-markdown path."""
from app.etl_pipeline.parsers.vision_llm import parse_with_vision_llm
img = tmp_path / "scan.png"
img.write_bytes(b"\x89PNG\r\n\x1a\n" + b"\x00" * 200)
fake_response = MagicMock()
fake_response.content = "# A scan of something."
fake_llm = AsyncMock()
fake_llm.ainvoke.return_value = fake_response
out = await parse_with_vision_llm(str(img), "scan.png", fake_llm)
assert out == "# A scan of something."
fake_llm.ainvoke.assert_awaited_once()
async def test_parse_with_vision_llm_rejects_empty_response(tmp_path):
"""An empty model response raises rather than silently returning blanks."""
from app.etl_pipeline.parsers.vision_llm import parse_with_vision_llm
img = tmp_path / "scan.png"
img.write_bytes(b"\x89PNG\r\n\x1a\n" + b"\x00" * 200)
fake_response = MagicMock()
fake_response.content = ""
fake_llm = AsyncMock()
fake_llm.ainvoke.return_value = fake_response
with pytest.raises(ValueError, match="empty content"):
await parse_with_vision_llm(str(img), "scan.png", fake_llm)
# ---------------------------------------------------------------------------
# parse_image_for_description: per-image-in-PDF, description only
# ---------------------------------------------------------------------------
async def test_parse_image_for_description_returns_description(tmp_path):
"""Description-only path returns the model's markdown unchanged."""
from app.etl_pipeline.parsers.vision_llm import parse_image_for_description
img = tmp_path / "scan.png"
img.write_bytes(b"\x89PNG\r\n\x1a\n" + b"\x00" * 200)
fake_response = MagicMock()
fake_response.content = "Axial CT showing a large cystic mass."
fake_llm = AsyncMock()
fake_llm.ainvoke.return_value = fake_response
out = await parse_image_for_description(str(img), "scan.png", fake_llm)
assert out == "Axial CT showing a large cystic mass."
async def test_parse_image_for_description_uses_description_only_prompt(tmp_path):
"""The prompt explicitly tells the model NOT to transcribe text.
This is the contract that lets us drop OCR from the response: the
ETL pipeline already has the text (from page-level OCR), so asking
the vision LLM for it would be redundant cost.
"""
from app.etl_pipeline.parsers.vision_llm import parse_image_for_description
img = tmp_path / "scan.png"
img.write_bytes(b"\x89PNG\r\n\x1a\n" + b"\x00" * 200)
fake_response = MagicMock()
fake_response.content = "A description"
fake_llm = AsyncMock()
fake_llm.ainvoke.return_value = fake_response
await parse_image_for_description(str(img), "scan.png", fake_llm)
# The prompt is the first text part of the message we sent.
sent_messages = fake_llm.ainvoke.call_args.args[0]
prompt_text = sent_messages[0].content[0]["text"].lower()
assert "describe what this image visually depicts" in prompt_text
assert "do not transcribe text" in prompt_text
async def test_parse_image_for_description_rejects_empty(tmp_path):
"""Empty response surfaces as ValueError so the caller can skip the image."""
from app.etl_pipeline.parsers.vision_llm import parse_image_for_description
img = tmp_path / "scan.png"
img.write_bytes(b"\x89PNG\r\n\x1a\n" + b"\x00" * 200)
fake_response = MagicMock()
fake_response.content = " " # whitespace-only counts as empty
fake_llm = AsyncMock()
fake_llm.ainvoke.return_value = fake_response
with pytest.raises(ValueError, match="empty content"):
await parse_image_for_description(str(img), "scan.png", fake_llm)
# ---------------------------------------------------------------------------
# Image size + extension validation (shared by both paths)
# ---------------------------------------------------------------------------
def test_image_to_data_url_rejects_oversized(tmp_path):
"""Images larger than 5 MB raise before any LLM call is made."""
from app.etl_pipeline.parsers.vision_llm import _image_to_data_url
big = tmp_path / "huge.png"
big.write_bytes(b"\x89PNG" + b"\x00" * (6 * 1024 * 1024))
with pytest.raises(ValueError, match="Image too large"):
_image_to_data_url(str(big))
def test_image_to_data_url_rejects_unsupported_extension(tmp_path):
"""Unknown extensions raise rather than guessing a MIME type."""
from app.etl_pipeline.parsers.vision_llm import _image_to_data_url
weird = tmp_path / "scan.xyz"
weird.write_bytes(b"\x00" * 100)
with pytest.raises(ValueError, match="Unsupported image extension"):
_image_to_data_url(str(weird))

View file

@ -39,8 +39,9 @@ async def test_index_calls_embed_and_chunk_via_to_thread(
):
"""index() runs the chunker and embed_texts via asyncio.to_thread, not blocking the loop.
The default (non-code) path uses ``chunk_text_hybrid`` so Markdown tables stay
intact (see issue #1334); ``chunk_text`` is reserved for the code-chunker branch.
Routing between ``chunk_text`` (code path) and ``chunk_text_hybrid`` (default
path, see issue #1334) is verified separately in
``test_non_code_documents_use_hybrid_chunker``.
"""
to_thread_calls = []
original_to_thread = asyncio.to_thread
@ -86,11 +87,64 @@ async def test_index_calls_embed_and_chunk_via_to_thread(
await pipeline.index(document, connector_doc, llm=MagicMock())
assert "chunk_text_hybrid" in to_thread_calls
# Either chunker entry point satisfies the "chunking runs off the event
# loop" contract this test guards. Routing between the two is verified
# in test_non_code_documents_use_hybrid_chunker.
assert {"chunk_text", "chunk_text_hybrid"} & set(to_thread_calls)
assert "embed_texts" in to_thread_calls
assert document.status == DocumentStatus.ready()
async def test_non_code_documents_use_hybrid_chunker(
pipeline, make_connector_document, monkeypatch
):
"""Non-code documents route through ``chunk_text_hybrid`` (issue #1334).
The hybrid chunker preserves Markdown table integrity by avoiding splits
mid-row. Only documents flagged with ``should_use_code_chunker=True``
should take the ``chunk_text`` path.
"""
monkeypatch.setattr(
"app.indexing_pipeline.indexing_pipeline_service.summarize_document",
AsyncMock(return_value="Summary."),
)
mock_chunk_hybrid = MagicMock(return_value=["chunk1"])
mock_chunk_hybrid.__name__ = "chunk_text_hybrid"
monkeypatch.setattr(
"app.indexing_pipeline.indexing_pipeline_service.chunk_text_hybrid",
mock_chunk_hybrid,
)
mock_chunk_code = MagicMock(return_value=["chunk1"])
mock_chunk_code.__name__ = "chunk_text"
monkeypatch.setattr(
"app.indexing_pipeline.indexing_pipeline_service.chunk_text",
mock_chunk_code,
)
monkeypatch.setattr(
"app.indexing_pipeline.indexing_pipeline_service.embed_texts",
MagicMock(side_effect=lambda texts: [[0.1] * _EMBEDDING_DIM for _ in texts]),
)
monkeypatch.setattr(
"app.indexing_pipeline.indexing_pipeline_service.attach_chunks_to_document",
MagicMock(),
)
connector_doc = make_connector_document(
document_type=DocumentType.GOOGLE_GMAIL_CONNECTOR,
unique_id="msg-1",
search_space_id=1,
should_use_code_chunker=False,
)
document = MagicMock(spec=Document)
document.id = 1
document.status = DocumentStatus.pending()
await pipeline.index(document, connector_doc, llm=MagicMock())
mock_chunk_hybrid.assert_called_once()
mock_chunk_code.assert_not_called()
def _mock_session_factory(orm_docs_by_id):
"""Replace get_celery_session_maker with a two-level callable.

View file

@ -0,0 +1,209 @@
"""Real-graph contract: ``all_interrupt_values`` surfaces every pending interrupt.
The chat-stream emit loop must yield one ``data-interrupt-request`` SSE frame
per paused subagent, in the same order ``state.interrupts`` reports them
that's also the order the resume slicer consumes decisions. These tests pin
that contract against a **real** paused parent graph built via
:class:`~langgraph.types.Send` fan-out (no synthetic state mocks).
"""
from __future__ import annotations
import pytest
from langchain.tools import ToolRuntime
from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.runnables import RunnableConfig
from langgraph.checkpoint.memory import InMemorySaver
from langgraph.graph import END, START, StateGraph
from langgraph.types import Send, interrupt
from typing_extensions import TypedDict
from app.agents.multi_agent_chat.middleware.main_agent.checkpointed_subagent_middleware.task_tool import (
build_task_tool_with_parent_config,
)
from app.tasks.chat.streaming.helpers.interrupt_inspector import (
all_interrupt_values,
)
class _SubState(TypedDict, total=False):
messages: list
class _DispatchState(TypedDict, total=False):
messages: list
tcid: str
desc: str
def _build_pausing_subagent(checkpointer: InMemorySaver):
def approve_node(_state):
decision = interrupt(
{
"action_requests": [
{"name": "do_thing", "args": {"x": 1}, "description": ""}
],
"review_configs": [{}],
}
)
return {"messages": [AIMessage(content=f"got:{decision}")]}
g = StateGraph(_SubState)
g.add_node("approve", approve_node)
g.add_edge(START, "approve")
g.add_edge("approve", END)
return g.compile(checkpointer=checkpointer)
def _parent_graph_dispatching_two_tasks_via_send(
task_tool, *, tool_call_id_a: str, tool_call_id_b: str, checkpointer
):
def fanout_edge(_state) -> list[Send]:
return [
Send("call_task", {"tcid": tool_call_id_a, "desc": "approve A"}),
Send("call_task", {"tcid": tool_call_id_b, "desc": "approve B"}),
]
async def call_task(state: _DispatchState, config: RunnableConfig):
rt = ToolRuntime(
state=state,
config=config,
context=None,
stream_writer=None,
tool_call_id=state["tcid"],
store=None,
)
return await task_tool.coroutine(
description=state["desc"], subagent_type="approver", runtime=rt
)
g = StateGraph(_DispatchState)
g.add_node("call_task", call_task)
g.add_conditional_edges(START, fanout_edge, ["call_task"])
g.add_edge("call_task", END)
return g.compile(checkpointer=checkpointer)
@pytest.mark.asyncio
async def test_returns_every_pending_interrupt_for_two_paused_subagents():
"""Two parallel subagents -> ``all_interrupt_values`` returns 2 dicts."""
checkpointer = InMemorySaver()
subagent = _build_pausing_subagent(checkpointer)
task_tool = build_task_tool_with_parent_config(
[{"name": "approver", "description": "approves", "runnable": subagent}]
)
parent = _parent_graph_dispatching_two_tasks_via_send(
task_tool,
tool_call_id_a="parent-tcid-A",
tool_call_id_b="parent-tcid-B",
checkpointer=checkpointer,
)
parent_config = {
"configurable": {"thread_id": "all-iv-thread"},
"recursion_limit": 100,
}
await parent.ainvoke({"messages": [HumanMessage(content="seed")]}, parent_config)
state = await parent.aget_state(parent_config)
values = all_interrupt_values(state)
assert isinstance(values, list)
assert len(values) == 2, (
f"REGRESSION: expected one value per pending subagent, got "
f"{len(values)}: {values!r}"
)
stamps = [v.get("tool_call_id") for v in values]
assert sorted(stamps) == ["parent-tcid-A", "parent-tcid-B"]
for v in values:
assert isinstance(v.get("action_requests"), list)
assert len(v["action_requests"]) == 1
@pytest.mark.asyncio
async def test_preserves_state_interrupts_traversal_order():
"""Order returned by inspector must match ``state.interrupts`` order.
The resume slicer consumes decisions left-to-right against
``collect_pending_tool_calls(state)`` which walks ``state.interrupts``
in iteration order so the inspector (which drives the *emit* order)
must agree with that traversal or the slice and the wire fall out of sync.
"""
checkpointer = InMemorySaver()
subagent = _build_pausing_subagent(checkpointer)
task_tool = build_task_tool_with_parent_config(
[{"name": "approver", "description": "approves", "runnable": subagent}]
)
parent = _parent_graph_dispatching_two_tasks_via_send(
task_tool,
tool_call_id_a="parent-tcid-A",
tool_call_id_b="parent-tcid-B",
checkpointer=checkpointer,
)
parent_config = {
"configurable": {"thread_id": "order-thread"},
"recursion_limit": 100,
}
await parent.ainvoke({"messages": [HumanMessage(content="seed")]}, parent_config)
state = await parent.aget_state(parent_config)
inspector_order = [v["tool_call_id"] for v in all_interrupt_values(state)]
state_order = [
i.value["tool_call_id"]
for i in state.interrupts
if isinstance(getattr(i, "value", None), dict) and "tool_call_id" in i.value
]
assert inspector_order == state_order, (
f"inspector order {inspector_order!r} diverged from state.interrupts "
f"order {state_order!r}; the resume slicer would mis-route decisions."
)
@pytest.mark.asyncio
async def test_returns_empty_list_when_nothing_paused():
"""A graph that completes normally produces no interrupts to surface."""
def done_node(_state):
return {"messages": [AIMessage(content="done")]}
g = StateGraph(_SubState)
g.add_node("done", done_node)
g.add_edge(START, "done")
g.add_edge("done", END)
graph = g.compile(checkpointer=InMemorySaver())
config = {"configurable": {"thread_id": "no-pause-thread"}}
await graph.ainvoke({"messages": [HumanMessage(content="seed")]}, config)
state = await graph.aget_state(config)
assert all_interrupt_values(state) == []
@pytest.mark.asyncio
async def test_single_paused_subagent_returns_a_list_of_one():
"""Single-pause case must still return a list (not unwrap to a dict)."""
def approve_node(_state):
decision = interrupt(
{
"action_requests": [{"name": "x", "args": {}, "description": ""}],
"review_configs": [{}],
"tool_call_id": "lonely-tcid",
}
)
return {"messages": [AIMessage(content=f"got:{decision}")]}
g = StateGraph(_SubState)
g.add_node("approve", approve_node)
g.add_edge(START, "approve")
g.add_edge("approve", END)
graph = g.compile(checkpointer=InMemorySaver())
config = {"configurable": {"thread_id": "single-thread"}}
await graph.ainvoke({"messages": [HumanMessage(content="seed")]}, config)
state = await graph.aget_state(config)
values = all_interrupt_values(state)
assert isinstance(values, list)
assert len(values) == 1
assert values[0].get("tool_call_id") == "lonely-tcid"

View file

@ -23,7 +23,6 @@ from app.tasks.chat.stream_new_chat import (
_emit_stream_terminal_error as old_emit_terminal_error,
_extract_chunk_parts as old_extract_chunk_parts,
_extract_resolved_file_path as old_extract_resolved_file_path,
_first_interrupt_value as old_first_interrupt_value,
_tool_output_has_error as old_tool_output_has_error,
_tool_output_to_text as old_tool_output_to_text,
)
@ -36,9 +35,6 @@ from app.tasks.chat.streaming.errors.emitter import (
from app.tasks.chat.streaming.helpers.chunk_parts import (
extract_chunk_parts as new_extract_chunk_parts,
)
from app.tasks.chat.streaming.helpers.interrupt_inspector import (
first_interrupt_value as new_first_interrupt_value,
)
from app.tasks.chat.streaming.helpers.tool_output import (
extract_resolved_file_path as new_extract_resolved_file_path,
tool_output_has_error as new_tool_output_has_error,
@ -105,52 +101,6 @@ def test_extract_chunk_parts_matches_old_implementation(chunk: Any) -> None:
assert new_extract_chunk_parts(chunk) == old_extract_chunk_parts(chunk)
# ---------------------------------------------------------- interrupt inspector
@dataclass
class _Interrupt:
value: dict[str, Any]
@dataclass
class _Task:
interrupts: tuple[Any, ...] = ()
@dataclass
class _State:
tasks: tuple[Any, ...] = ()
interrupts: tuple[Any, ...] = ()
_INTERRUPT_CASES: list[Any] = [
_State(),
_State(tasks=(_Task(interrupts=(_Interrupt(value={"name": "send"}),)),)),
# Multiple tasks: must return the FIRST one in iteration order.
_State(
tasks=(
_Task(interrupts=(_Interrupt(value={"name": "first"}),)),
_Task(interrupts=(_Interrupt(value={"name": "second"}),)),
)
),
# Empty task interrupts -> falls back to root state.interrupts.
_State(
tasks=(_Task(interrupts=()),),
interrupts=(_Interrupt(value={"name": "root"}),),
),
# Interrupts as plain dicts (not wrapper objects).
_State(interrupts=({"value": {"name": "dict_root"}},)),
# A defective task whose `.interrupts` raises - must be tolerated.
_State(tasks=(object(),)),
]
@pytest.mark.parametrize("state", _INTERRUPT_CASES)
def test_first_interrupt_value_matches_old_implementation(state: Any) -> None:
assert new_first_interrupt_value(state) == old_first_interrupt_value(state)
# ----------------------------------------------------------- error classifier

View file

@ -0,0 +1,171 @@
"""Pin: thinking-step IDs must be globally unique within a thread.
The frontend rehydrates ``currentThinkingSteps`` from the prior assistant
message when starting a resume. If two consecutive resume turns emit step IDs
that overlap (e.g. both produce ``thinking-resume-1`` because each invocation
constructs a fresh :class:`AgentEventRelayState` with
``thinking_step_counter=0``), React renders sibling timeline rows with the
same key the warning the user reported in production.
The contract this module pins: each ``_stream_agent_events`` invocation must
receive a ``step_prefix`` that is unique within the thread (we salt with the
per-turn ``turn_id``), so the resulting step IDs across consecutive turns
are always disjoint.
"""
from __future__ import annotations
import json
from collections.abc import AsyncGenerator
from dataclasses import dataclass, field
from typing import Any
import pytest
from app.services.new_streaming_service import VercelStreamingService
from app.tasks.chat.stream_new_chat import (
StreamResult,
_resume_step_prefix,
_stream_agent_events,
)
pytestmark = pytest.mark.unit
@dataclass
class _FakeChunk:
content: Any = ""
additional_kwargs: dict[str, Any] = field(default_factory=dict)
tool_call_chunks: list[dict[str, Any]] = field(default_factory=list)
class _FakeAgentState:
def __init__(self) -> None:
self.values: dict[str, Any] = {}
self.tasks: list[Any] = []
class _FakeAgent:
def __init__(self, events: list[dict[str, Any]]) -> None:
self._events = events
self._state = _FakeAgentState()
async def astream_events( # type: ignore[no-untyped-def]
self, _input_data: Any, *, config: dict[str, Any], version: str
) -> AsyncGenerator[dict[str, Any], None]:
del config, version
for ev in self._events:
yield ev
async def aget_state(self, _config: dict[str, Any]) -> _FakeAgentState:
return self._state
def _tool_start(*, name: str, run_id: str) -> dict[str, Any]:
return {
"event": "on_tool_start",
"name": name,
"run_id": run_id,
"data": {"input": {}},
}
async def _drain_step_ids(
events: list[dict[str, Any]], *, step_prefix: str
) -> set[str]:
"""Run ``_stream_agent_events`` once and return every emitted thinking-step ID."""
agent = _FakeAgent(events)
service = VercelStreamingService()
result = StreamResult()
config = {"configurable": {"thread_id": "regression-thread"}}
sse_lines: list[str] = []
async for sse in _stream_agent_events(
agent, config, {}, service, result, step_prefix=step_prefix
):
sse_lines.append(sse)
ids: set[str] = set()
for line in sse_lines:
if not line.startswith("data: "):
continue
body = line[len("data: ") :].rstrip("\n")
if not body or body == "[DONE]":
continue
try:
payload = json.loads(body)
except json.JSONDecodeError:
continue
if payload.get("type") != "data-thinking-step":
continue
step_id = (payload.get("data") or {}).get("id")
if isinstance(step_id, str):
ids.add(step_id)
return ids
@pytest.mark.asyncio
async def test_consecutive_invocations_with_same_prefix_produce_overlapping_ids():
"""Pin the bug: identical ``step_prefix`` across two turns reuses ``-1``, ``-2``…
This is what production was doing for resume every resume invocation
passed ``step_prefix='thinking-resume'`` and the relay state's counter
restarted at 0. Two scrollback timelines built from such turns then
presented React with siblings keyed by the same ``thinking-resume-1``.
"""
events = [
_tool_start(name="t1", run_id="run-A-1"),
_tool_start(name="t2", run_id="run-A-2"),
]
ids_turn_one = await _drain_step_ids(events, step_prefix="thinking-resume")
ids_turn_two = await _drain_step_ids(events, step_prefix="thinking-resume")
assert ids_turn_one == ids_turn_two != set(), (
"fixture broken: expected non-empty overlapping ids when prefix is reused"
)
@pytest.mark.asyncio
async def test_per_turn_salted_prefix_yields_disjoint_step_ids_across_turns():
"""The fix: salting the prefix with the per-turn ``turn_id`` makes IDs disjoint.
Two consecutive resume calls in the same thread feed two different
``turn_id``s into the prefix, so the resulting step IDs cannot collide
no matter how many times the FE rehydrates from earlier assistant
messages which is the precondition for the React duplicate-key warning.
"""
events = [
_tool_start(name="t1", run_id="run-A-1"),
_tool_start(name="t2", run_id="run-A-2"),
]
ids_turn_one = await _drain_step_ids(
events, step_prefix="thinking-resume-104:1778698228472"
)
ids_turn_two = await _drain_step_ids(
events, step_prefix="thinking-resume-104:1778698244022"
)
assert ids_turn_one and ids_turn_two, "fixture broken: expected non-empty id sets"
assert ids_turn_one.isdisjoint(ids_turn_two), (
f"REGRESSION: per-turn-salted prefixes produced overlapping step IDs: "
f"{ids_turn_one & ids_turn_two!r}"
)
def test_resume_step_prefix_helper_includes_turn_id_verbatim():
"""Production call-site pin: ``stream_resume_chat`` builds the prefix via
this helper. Reverting it back to a hardcoded ``'thinking-resume'`` would
silently re-introduce the duplicate-key React warning across consecutive
resumes this test fails first instead.
"""
a = _resume_step_prefix("104:1778698228472")
b = _resume_step_prefix("104:1778698244022")
assert a.startswith("thinking-resume-"), (
f"prefix shape changed; the FE log filters and the timeline contract "
f"expect the ``thinking-resume-`` head to remain stable: got {a!r}"
)
assert "104:1778698228472" in a and "104:1778698244022" in b
assert a != b