Merge upstream/dev into feature/multi-agent

This commit is contained in:
CREDO23 2026-05-12 21:23:37 +02:00
commit 246dae40a8
229 changed files with 36484 additions and 436 deletions

View file

@ -0,0 +1,285 @@
"""Tests for the @-mention resolver.
These tests pin down the contract that ``mention_resolver`` is the
single seam between ``MentionedDocumentInfo`` chips (frontend) and the
canonical ``/documents/...`` virtual paths (agent). The streaming task,
priority middleware, and persistence layer all consume the resolver's
output keeping the tests focused on substitute-in-text + the
returned id partition keeps the seam stable across refactors.
"""
from __future__ import annotations
from types import SimpleNamespace
from unittest.mock import AsyncMock, MagicMock
import pytest
from app.agents.new_chat import mention_resolver
from app.agents.new_chat.mention_resolver import (
ResolvedMention,
ResolvedMentionSet,
resolve_mentions,
substitute_in_text,
)
from app.agents.new_chat.path_resolver import DOCUMENTS_ROOT, PathIndex
from app.schemas.new_chat import MentionedDocumentInfo
pytestmark = pytest.mark.unit
class TestSubstituteInText:
"""``substitute_in_text`` is a pure string transform and is exercised
on every cloud-mode turn, so it has to be both fast and behaviour-
identical to the frontend's ``parseMentionSegments`` (longest-token
first, single forward pass)."""
def test_returns_text_unchanged_when_no_tokens(self):
assert substitute_in_text("hello @foo", []) == "hello @foo"
def test_returns_text_unchanged_when_empty(self):
assert substitute_in_text("", [("@x", "/documents/x.xml")]) == ""
def test_replaces_single_token_with_backtick_path(self):
out = substitute_in_text(
"see @notes please",
[("@notes", "/documents/notes.xml")],
)
assert out == "see `/documents/notes.xml` please"
def test_longest_token_wins_over_prefix(self):
# ``@Project Roadmap`` must NOT be partially matched by ``@Project``.
# Mirrors the FE's parseMentionSegments contract.
token_to_path = [
("@Project Roadmap", "/documents/Roadmap.xml"),
("@Project", "/documents/Project.xml"),
]
out = substitute_in_text("about @Project Roadmap today", token_to_path)
assert out == "about `/documents/Roadmap.xml` today"
def test_handles_repeated_mentions(self):
out = substitute_in_text(
"@A and @A again @B",
[
("@A", "/documents/a.xml"),
("@B", "/documents/b.xml"),
],
)
assert (
out == "`/documents/a.xml` and `/documents/a.xml` again `/documents/b.xml`"
)
def test_does_not_match_inside_word(self):
# Substitution is positional — there's no word-boundary semantics.
# ``@Pro`` inside ``foo@Project`` still matches; this is the same
# behaviour as parseMentionSegments. The test pins it so a
# future "fix" doesn't accidentally diverge between FE/BE.
out = substitute_in_text("foo@Pro", [("@Pro", "/documents/p.xml")])
assert out == "foo`/documents/p.xml`"
def test_idempotent_after_substitution(self):
# The output starts with a backtick, not ``@``, so re-running
# the substitution leaves it alone.
once = substitute_in_text("@A", [("@A", "/documents/a.xml")])
twice = substitute_in_text(once, [("@A", "/documents/a.xml")])
assert once == twice
class TestResolveMentions:
"""``resolve_mentions`` resolves chip ids → virtual paths and emits
a ``ResolvedMentionSet`` whose id partitions feed
``KnowledgePriorityMiddleware``."""
@pytest.mark.asyncio
async def test_returns_empty_when_no_mentions(self):
session = MagicMock()
session.execute = AsyncMock()
result = await resolve_mentions(
session,
search_space_id=1,
mentioned_documents=None,
)
assert isinstance(result, ResolvedMentionSet)
assert result.mentions == []
assert result.token_to_path == []
assert result.mentioned_document_ids == []
assert result.mentioned_folder_ids == []
# No DB roundtrips when there's nothing to resolve.
session.execute.assert_not_awaited()
@pytest.mark.asyncio
async def test_resolves_doc_chip_to_virtual_path(self, monkeypatch):
chip = MentionedDocumentInfo(
id=42,
title="Notes",
document_type="EXTENSION",
kind="doc",
)
doc_row = SimpleNamespace(id=42, title="Notes", folder_id=None)
async def fake_build_index(_session, _ssid):
return PathIndex()
monkeypatch.setattr(mention_resolver, "build_path_index", fake_build_index)
scalars = MagicMock()
scalars.all.return_value = [doc_row]
result = MagicMock()
result.scalars.return_value = scalars
session = MagicMock()
session.execute = AsyncMock(return_value=result)
out = await resolve_mentions(
session,
search_space_id=5,
mentioned_documents=[chip],
)
assert len(out.mentions) == 1
mention = out.mentions[0]
assert mention.kind == "doc"
assert mention.id == 42
assert mention.virtual_path == f"{DOCUMENTS_ROOT}/Notes.xml"
assert out.mentioned_document_ids == [42]
assert out.mentioned_folder_ids == []
assert ("@Notes", f"{DOCUMENTS_ROOT}/Notes.xml") in out.token_to_path
@pytest.mark.asyncio
async def test_resolves_folder_chip_with_trailing_slash(self, monkeypatch):
chip = MentionedDocumentInfo(
id=9,
title="Reports",
document_type="FOLDER",
kind="folder",
)
folder_row = SimpleNamespace(id=9, name="Reports")
async def fake_build_index(_session, _ssid):
return PathIndex(folder_paths={9: f"{DOCUMENTS_ROOT}/Reports"})
monkeypatch.setattr(mention_resolver, "build_path_index", fake_build_index)
scalars = MagicMock()
scalars.all.return_value = [folder_row]
result = MagicMock()
result.scalars.return_value = scalars
session = MagicMock()
session.execute = AsyncMock(return_value=result)
out = await resolve_mentions(
session,
search_space_id=3,
mentioned_documents=[chip],
)
assert len(out.mentions) == 1
mention = out.mentions[0]
assert mention.kind == "folder"
assert mention.id == 9
assert mention.virtual_path == f"{DOCUMENTS_ROOT}/Reports/"
assert out.mentioned_document_ids == []
assert out.mentioned_folder_ids == [9]
@pytest.mark.asyncio
async def test_drops_chip_when_doc_is_missing(self, monkeypatch):
chip = MentionedDocumentInfo(
id=99, title="ghost", document_type="EXTENSION", kind="doc"
)
async def fake_build_index(_session, _ssid):
return PathIndex()
monkeypatch.setattr(mention_resolver, "build_path_index", fake_build_index)
scalars = MagicMock()
scalars.all.return_value = []
result = MagicMock()
result.scalars.return_value = scalars
session = MagicMock()
session.execute = AsyncMock(return_value=result)
out = await resolve_mentions(
session,
search_space_id=1,
mentioned_documents=[chip],
)
assert out.mentions == []
assert out.mentioned_document_ids == []
assert out.token_to_path == []
@pytest.mark.asyncio
async def test_token_to_path_is_longest_first(self, monkeypatch):
# Two chips whose titles are prefixes of each other — the
# resolver MUST sort longest-first so substitution doesn't
# break the ``@Project Roadmap`` vs ``@Project`` invariant.
chip_short = MentionedDocumentInfo(
id=1, title="A", document_type="EXTENSION", kind="doc"
)
chip_long = MentionedDocumentInfo(
id=2, title="A long one", document_type="EXTENSION", kind="doc"
)
rows = [
SimpleNamespace(id=1, title="A", folder_id=None),
SimpleNamespace(id=2, title="A long one", folder_id=None),
]
async def fake_build_index(_session, _ssid):
return PathIndex()
monkeypatch.setattr(mention_resolver, "build_path_index", fake_build_index)
scalars = MagicMock()
scalars.all.return_value = rows
result = MagicMock()
result.scalars.return_value = scalars
session = MagicMock()
session.execute = AsyncMock(return_value=result)
out = await resolve_mentions(
session,
search_space_id=1,
mentioned_documents=[chip_short, chip_long],
)
tokens = [tok for tok, _ in out.token_to_path]
assert tokens == sorted(tokens, key=len, reverse=True)
@pytest.mark.asyncio
async def test_legacy_id_arrays_resolve_without_chip_metadata(self, monkeypatch):
# ``mentioned_document_ids`` (the legacy parallel array) must
# still resolve when no chip metadata is available — covers
# callers that haven't migrated to the discriminated chip list.
doc_row = SimpleNamespace(id=7, title="Legacy", folder_id=None)
async def fake_build_index(_session, _ssid):
return PathIndex()
monkeypatch.setattr(mention_resolver, "build_path_index", fake_build_index)
scalars = MagicMock()
scalars.all.return_value = [doc_row]
result = MagicMock()
result.scalars.return_value = scalars
session = MagicMock()
session.execute = AsyncMock(return_value=result)
out = await resolve_mentions(
session,
search_space_id=2,
mentioned_documents=None,
mentioned_document_ids=[7],
)
assert out.mentioned_document_ids == [7]
assert len(out.mentions) == 1
assert out.mentions[0].title == "Legacy"
class TestResolvedMentionEquality:
"""Smoke check on the dataclass behaviour we rely on for asserting
test outputs."""
def test_equal_when_fields_equal(self):
a = ResolvedMention(
kind="doc", id=1, title="x", virtual_path="/documents/x.xml"
)
b = ResolvedMention(
kind="doc", id=1, title="x", virtual_path="/documents/x.xml"
)
assert a == b

View file

@ -196,3 +196,50 @@ class TestVirtualPathToDoc:
)
assert document is target_doc
assert session.execute.await_count == 2
@pytest.mark.asyncio
async def test_resolves_double_extension_for_uploaded_pdf(self):
# Regression: the agent renders every KB document under
# ``/documents/`` with a trailing ``.xml`` (via ``safe_filename``),
# so an uploaded PDF whose DB title is ``2025-W2.pdf`` shows up as
# ``/documents/2025-W2.pdf.xml`` in answers. Clicking that path
# must round-trip back to the row even though the title itself
# does NOT end in ``.xml``.
target_doc = SimpleNamespace(id=99, title="2025-W2.pdf", folder_id=None)
session = MagicMock()
session.execute = AsyncMock(
side_effect=[
_result_from_one(None),
_result_from_scalars([target_doc]),
]
)
document = await virtual_path_to_doc(
session,
search_space_id=5,
virtual_path=f"{DOCUMENTS_ROOT}/2025-W2.pdf.xml",
)
assert document is target_doc
@pytest.mark.asyncio
async def test_resolves_path_without_xml_suffix(self):
# The user (or a hand-edited link) may pass the title-only form
# ``/documents/2025-W2.pdf``. The resolver must still find the row
# by literal title equality.
target_doc = SimpleNamespace(id=99, title="2025-W2.pdf", folder_id=None)
session = MagicMock()
session.execute = AsyncMock(
side_effect=[
_result_from_one(None),
_result_from_scalars([target_doc]),
]
)
document = await virtual_path_to_doc(
session,
search_space_id=5,
virtual_path=f"{DOCUMENTS_ROOT}/2025-W2.pdf",
)
assert document is target_doc

View file

@ -0,0 +1,38 @@
from tests.e2e.fakes.composio_module import _drive_list_files
def _ids(result: dict) -> set[str]:
return {item["id"] for item in result["data"]["files"]}
def test_drive_list_files_filters_shortcuts_and_trashed_items():
result = _drive_list_files(
{
"q": (
"'root' in parents and trashed = false and "
"mimeType != 'application/vnd.google-apps.shortcut'"
)
}
)
ids = _ids(result)
assert "fake-file-canary" in ids
assert "fake-shortcut-canary" not in ids
assert "fake-file-trashed" not in ids
def test_drive_list_files_filters_to_exact_mime_type():
result = _drive_list_files(
{"q": "'root' in parents and trashed = false and mimeType = 'text/plain'"}
)
assert _ids(result) == {"fake-file-canary"}
def test_drive_list_files_uses_requested_parent_folder():
result = _drive_list_files(
{"q": "'fake-folder-projects' in parents and trashed = false"}
)
assert _ids(result) == {"fake-file-roadmap"}

View file

@ -1,8 +1,17 @@
"""Unit tests: build_composio_credentials returns valid Google Credentials.
"""Unit tests: Composio credential helpers + ``get_access_token`` masking guard.
Mocks the Composio SDK (external system boundary) and verifies that the
returned ``google.oauth2.credentials.Credentials`` object is correctly
configured with a token and a working refresh handler.
Covers two seams between Surfsense and Composio:
1. ``build_composio_credentials`` returns a ``google.oauth2.credentials.Credentials``
object with a working refresh handler (mocks the whole ``ComposioService``).
2. ``ComposioService.get_access_token`` rejects masked / missing tokens with
actionable error messages (mocks only the Composio SDK boundary so the
real guard logic is exercised).
The masking guard is the boundary handler that production tripped over when
Composio's "Mask Connected Account Secrets" project setting was enabled.
The corresponding fix landed in ``cea8618``; these tests lock that contract
in place so any future weakening of the guard surfaces immediately.
"""
from datetime import UTC, datetime
@ -14,6 +23,11 @@ from google.oauth2.credentials import Credentials
pytestmark = pytest.mark.unit
# ---------------------------------------------------------------------------
# build_composio_credentials — high-level wrapper tests
# ---------------------------------------------------------------------------
@patch("app.services.composio_service.ComposioService")
def test_returns_credentials_with_token_and_expiry(mock_composio_service):
"""build_composio_credentials returns a Credentials object with the Composio access token."""
@ -54,3 +68,85 @@ def test_refresh_handler_fetches_fresh_token(mock_composio_service):
assert new_token == "refreshed-token"
assert new_expiry > datetime.now(UTC).replace(tzinfo=None)
assert mock_service.get_access_token.call_count == 2
# ---------------------------------------------------------------------------
# ComposioService.get_access_token — boundary masking guard tests
# ---------------------------------------------------------------------------
def _service_with_account(account: object):
"""Build a real ``ComposioService`` whose underlying Composio SDK is faked.
Only the SDK boundary is patched the real ``get_access_token`` method
runs, so changes to the masking / missing-token guards surface here.
"""
from app.services import composio_service as composio_service_module
with patch.object(composio_service_module, "Composio") as mock_composio_cls:
mock_client = MagicMock()
mock_client.connected_accounts.get.return_value = account
mock_composio_cls.return_value = mock_client
service = composio_service_module.ComposioService(api_key="unit-test-api-key")
# ``service.client`` already references ``mock_client`` even after the
# patch context exits because the constructor captured it during init.
return service
@pytest.mark.parametrize("masked_token", ["x", "xxxxxxxx", "x" * 19])
def test_get_access_token_raises_on_masked_token(masked_token):
"""Tokens shorter than the 20-char unmask threshold must raise with the dashboard hint.
Composio masks ``state.val.access_token`` by default (project setting
"Mask Connected Account Secrets"). A masked token will always silently
fail downstream OAuth calls, so the guard surfaces it with the exact
text needed to fix the dashboard config.
"""
fake_account = MagicMock()
fake_account.state.val.access_token = masked_token
service = _service_with_account(fake_account)
with pytest.raises(ValueError, match="Mask Connected Account Secrets"):
service.get_access_token("any-account-id")
def test_get_access_token_raises_when_state_val_missing():
"""No ``state.val`` on the connected account is a hard failure with an account-id hint."""
fake_account = MagicMock()
fake_account.state = None
service = _service_with_account(fake_account)
with pytest.raises(ValueError, match=r"No state\.val.*missing-state-account"):
service.get_access_token("missing-state-account")
def test_get_access_token_raises_when_access_token_empty():
"""``state.val`` present but ``access_token`` empty must fail before the masking check."""
fake_account = MagicMock()
fake_account.state.val.access_token = ""
service = _service_with_account(fake_account)
with pytest.raises(ValueError, match=r"No access_token.*missing-token-account"):
service.get_access_token("missing-token-account")
def test_get_access_token_raises_when_access_token_none():
"""``state.val.access_token = None`` must fail before the masking check."""
fake_account = MagicMock()
fake_account.state.val.access_token = None
service = _service_with_account(fake_account)
with pytest.raises(ValueError, match=r"No access_token.*none-token-account"):
service.get_access_token("none-token-account")
def test_get_access_token_returns_unmasked_token():
"""Happy path: a >=20-char access token is returned verbatim."""
fake_account = MagicMock()
unmasked = "u" * 32
fake_account.state.val.access_token = unmasked
service = _service_with_account(fake_account)
assert service.get_access_token("happy-account") == unmasked

View file

@ -118,12 +118,8 @@ def test_get_by_tool_call_id_matches_action_request_payload() -> None:
tasks=(
_Task(
interrupts=(
_Interrupt(
value=_hitl("a", tool_call_id="call_xxx"), id="int_a"
),
_Interrupt(
value=_hitl("b", tool_call_id="call_yyy"), id="int_b"
),
_Interrupt(value=_hitl("a", tool_call_id="call_xxx"), id="int_a"),
_Interrupt(value=_hitl("b", tool_call_id="call_yyy"), id="int_b"),
)
),
)
@ -146,9 +142,7 @@ def test_first_pending_interrupt_matches_legacy_first_wins_behaviour() -> None:
def test_interrupt_without_id_falls_back_to_none() -> None:
"""Snapshots from older LangGraph versions may omit ``id`` — preserve that."""
state = _State(
tasks=(_Task(interrupts=(_Interrupt(value=_hitl("a"), id=None),)),)
)
state = _State(tasks=(_Task(interrupts=(_Interrupt(value=_hitl("a"), id=None),)),))
pending = list_pending_interrupts(state)
assert len(pending) == 1
assert pending[0].interrupt_id is None

View file

@ -37,9 +37,7 @@ def test_custom_interrupt_primitive_is_converted_to_canonical_shape() -> None:
"context": {"reason": "destructive"},
}
out = normalize_interrupt_payload(raw)
assert out["action_requests"] == [
{"name": "send_email", "args": {"to": "a@b"}}
]
assert out["action_requests"] == [{"name": "send_email", "args": {"to": "a@b"}}]
assert out["review_configs"] == [
{
"action_name": "send_email",

View file

@ -158,9 +158,7 @@ def _classify_cases() -> list[Exception]:
"""Inputs that the FE depends on being mapped to specific error codes."""
return [
Exception("totally generic error"),
Exception(
'{"error":{"type":"rate_limit_error","message":"slow down"}}'
),
Exception('{"error":{"type":"rate_limit_error","message":"slow down"}}'),
Exception(
'OpenrouterException - {"error":{"message":"Provider returned error",'
'"code":429}}'
@ -220,7 +218,7 @@ class _FakeStreamingService:
self.calls.append(
{"message": message, "error_code": error_code, "extra": extra}
)
return f"data: {{\"type\":\"error\",\"errorText\":\"{message}\"}}\n\n"
return f'data: {{"type":"error","errorText":"{message}"}}\n\n'
def test_emit_stream_terminal_error_matches_old_output_and_logs(caplog) -> None:

View file

@ -60,8 +60,14 @@ async def test_stream_output_emits_text_lifecycle_and_updates_result() -> None:
service = _StreamingService()
agent = _Agent(
[
{"event": "on_chat_model_stream", "data": {"chunk": _Chunk(content="Hello")}},
{"event": "on_chat_model_stream", "data": {"chunk": _Chunk(content=" world")}},
{
"event": "on_chat_model_stream",
"data": {"chunk": _Chunk(content="Hello")},
},
{
"event": "on_chat_model_stream",
"data": {"chunk": _Chunk(content=" world")},
},
]
)
result = StreamingResult()

View file

@ -37,7 +37,9 @@ def test_clear_ignored_for_non_task_tool() -> None:
def test_clear_ignored_when_task_run_id_mismatches() -> None:
state = AgentEventRelayState.for_invocation()
open_task_span(state, run_id="run-open")
clear_task_span_if_delegating_task_ended(state, tool_name="task", run_id="run-other")
clear_task_span_if_delegating_task_ended(
state, tool_name="task", run_id="run-other"
)
assert state.active_span_id is not None
assert state.active_task_run_id == "run-open"

View file

@ -240,9 +240,7 @@ class TestToolHeavyTurn:
class TestToolCallSpanMetadata:
def test_input_available_merges_new_metadata_keys_after_start(self):
b = AssistantContentBuilder()
b.on_tool_input_start(
"call_t", "task", "lc_t", metadata={"spanId": "spn_1"}
)
b.on_tool_input_start("call_t", "task", "lc_t", metadata={"spanId": "spn_1"})
b.on_tool_input_available(
"call_t",
"task",
@ -257,9 +255,7 @@ class TestToolCallSpanMetadata:
def test_input_available_does_not_overwrite_existing_metadata_keys(self):
b = AssistantContentBuilder()
b.on_tool_input_start(
"call_t", "task", "lc_t", metadata={"spanId": "spn_keep"}
)
b.on_tool_input_start("call_t", "task", "lc_t", metadata={"spanId": "spn_keep"})
b.on_tool_input_available(
"call_t", "task", {}, "lc_t", metadata={"spanId": "spn_other"}
)

View file

@ -0,0 +1,93 @@
import base64
import hashlib
import hmac
import json
import time
from uuid import uuid4
import pytest
from fastapi import HTTPException
from app.utils.oauth_security import OAuthStateManager
SECRET = "unit-test-secret"
def _encode_state(payload: dict, *, signature: str | None = None) -> str:
"""Build an OAuth state payload compatible with OAuthStateManager."""
signature_payload = payload.copy()
payload_str = json.dumps(signature_payload, sort_keys=True)
computed_signature = hmac.new(
SECRET.encode(),
payload_str.encode(),
hashlib.sha256,
).hexdigest()
encoded_payload = {
**signature_payload,
"signature": signature if signature is not None else computed_signature,
}
return base64.urlsafe_b64encode(json.dumps(encoded_payload).encode()).decode()
def test_validate_state_accepts_fresh_signed_state():
mgr = OAuthStateManager(secret_key=SECRET, max_age_seconds=600)
user_id = uuid4()
state = mgr.generate_secure_state(
space_id=1,
user_id=user_id,
toolkit_id="googledrive",
)
decoded = mgr.validate_state(state)
assert decoded["space_id"] == 1
assert decoded["user_id"] == str(user_id)
assert decoded["toolkit_id"] == "googledrive"
def test_validate_state_rejects_expired_state():
mgr = OAuthStateManager(secret_key=SECRET, max_age_seconds=600)
expired_state = _encode_state(
{
"space_id": 1,
"user_id": str(uuid4()),
"timestamp": int(time.time()) - 3600,
"toolkit_id": "googledrive",
}
)
with pytest.raises(HTTPException) as exc:
mgr.validate_state(expired_state)
assert exc.value.status_code == 400
assert "expired" in exc.value.detail.lower()
def test_validate_state_rejects_tampered_signature():
mgr = OAuthStateManager(secret_key=SECRET, max_age_seconds=600)
tampered_state = _encode_state(
{
"space_id": 1,
"user_id": str(uuid4()),
"timestamp": int(time.time()),
"toolkit_id": "googledrive",
},
signature="deadbeef" * 8,
)
with pytest.raises(HTTPException) as exc:
mgr.validate_state(tampered_state)
assert exc.value.status_code == 400
assert "tampering" in exc.value.detail.lower()
def test_validate_state_rejects_malformed_state():
mgr = OAuthStateManager(secret_key=SECRET)
with pytest.raises(HTTPException) as exc:
mgr.validate_state("not-base64-and-not-json")
assert exc.value.status_code == 400
assert "invalid state format" in exc.value.detail.lower()