mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-05-17 18:35:19 +02:00
Merge upstream/dev into feature/multi-agent
This commit is contained in:
commit
246dae40a8
229 changed files with 36484 additions and 436 deletions
|
|
@ -0,0 +1,285 @@
|
|||
"""Tests for the @-mention resolver.
|
||||
|
||||
These tests pin down the contract that ``mention_resolver`` is the
|
||||
single seam between ``MentionedDocumentInfo`` chips (frontend) and the
|
||||
canonical ``/documents/...`` virtual paths (agent). The streaming task,
|
||||
priority middleware, and persistence layer all consume the resolver's
|
||||
output — keeping the tests focused on substitute-in-text + the
|
||||
returned id partition keeps the seam stable across refactors.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from app.agents.new_chat import mention_resolver
|
||||
from app.agents.new_chat.mention_resolver import (
|
||||
ResolvedMention,
|
||||
ResolvedMentionSet,
|
||||
resolve_mentions,
|
||||
substitute_in_text,
|
||||
)
|
||||
from app.agents.new_chat.path_resolver import DOCUMENTS_ROOT, PathIndex
|
||||
from app.schemas.new_chat import MentionedDocumentInfo
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
class TestSubstituteInText:
    """Contract tests for ``substitute_in_text``.

    ``substitute_in_text`` is a pure string transform and is exercised
    on every cloud-mode turn, so it has to be both fast and behaviour-
    identical to the frontend's ``parseMentionSegments`` (longest-token
    first, single forward pass).
    """

    def test_returns_text_unchanged_when_no_tokens(self):
        """An empty token list leaves the text untouched, ``@`` included."""
        assert substitute_in_text("hello @foo", []) == "hello @foo"

    def test_returns_text_unchanged_when_empty(self):
        """Empty input stays empty even when tokens are supplied."""
        assert substitute_in_text("", [("@x", "/documents/x.xml")]) == ""

    def test_replaces_single_token_with_backtick_path(self):
        """A matched token is replaced by its path wrapped in backticks."""
        out = substitute_in_text(
            "see @notes please",
            [("@notes", "/documents/notes.xml")],
        )
        assert out == "see `/documents/notes.xml` please"

    def test_longest_token_wins_over_prefix(self):
        """The longer of two overlapping tokens is the one substituted."""
        # ``@Project Roadmap`` must NOT be partially matched by ``@Project``.
        # Mirrors the FE's parseMentionSegments contract.
        token_to_path = [
            ("@Project Roadmap", "/documents/Roadmap.xml"),
            ("@Project", "/documents/Project.xml"),
        ]
        out = substitute_in_text("about @Project Roadmap today", token_to_path)
        assert out == "about `/documents/Roadmap.xml` today"

    def test_handles_repeated_mentions(self):
        """Every occurrence of a token is replaced, not only the first."""
        out = substitute_in_text(
            "@A and @A again @B",
            [
                ("@A", "/documents/a.xml"),
                ("@B", "/documents/b.xml"),
            ],
        )
        assert (
            out == "`/documents/a.xml` and `/documents/a.xml` again `/documents/b.xml`"
        )

    def test_does_not_match_inside_word(self):
        """Pin the *absence* of word-boundary handling.

        Despite the historical test name, the asserted behaviour is that
        a token glued to a preceding word IS substituted.
        """
        # Substitution is positional — there's no word-boundary semantics.
        # ``@Pro`` glued onto ``foo`` (``foo@Pro``) still matches; this is
        # the same behaviour as parseMentionSegments. The test pins it so a
        # future "fix" doesn't accidentally diverge between FE/BE.
        out = substitute_in_text("foo@Pro", [("@Pro", "/documents/p.xml")])
        assert out == "foo`/documents/p.xml`"

    def test_idempotent_after_substitution(self):
        """Running the substitution on its own output changes nothing."""
        token_to_path = [("@A", "/documents/a.xml")]
        once = substitute_in_text("@A", token_to_path)
        # Pin the first pass explicitly: without this the test would pass
        # vacuously if the function returned its input unchanged.
        assert once == "`/documents/a.xml`"
        # The output starts with a backtick, not ``@``, so re-running
        # the substitution leaves it alone.
        twice = substitute_in_text(once, token_to_path)
        assert twice == once
|
||||
|
||||
|
||||
class TestResolveMentions:
    """``resolve_mentions`` resolves chip ids → virtual paths and emits
    a ``ResolvedMentionSet`` whose id partitions feed
    ``KnowledgePriorityMiddleware``.

    The DB session and ``mention_resolver.build_path_index`` are faked in
    every test; the two private helpers below build those fakes so each
    test body only states its inputs and expectations.
    """

    @staticmethod
    def _session_returning(rows):
        """Return a fake session whose awaited ``execute(...)`` yields *rows*.

        The awaited result exposes ``.scalars().all()`` returning *rows*,
        which is the only access pattern these tests set up.
        """
        result = MagicMock()
        result.scalars.return_value.all.return_value = rows
        session = MagicMock()
        session.execute = AsyncMock(return_value=result)
        return session

    @staticmethod
    def _stub_path_index(monkeypatch, make_index=PathIndex):
        """Patch ``mention_resolver.build_path_index`` with an async stub.

        *make_index* is called on every stub invocation so each call gets
        a fresh index object, exactly like the inline stubs it replaces.
        """

        async def fake_build_index(_session, _ssid):
            return make_index()

        monkeypatch.setattr(mention_resolver, "build_path_index", fake_build_index)

    @pytest.mark.asyncio
    async def test_returns_empty_when_no_mentions(self):
        """No chips and no ids yields an empty set and zero DB calls."""
        session = MagicMock()
        session.execute = AsyncMock()
        result = await resolve_mentions(
            session,
            search_space_id=1,
            mentioned_documents=None,
        )
        assert isinstance(result, ResolvedMentionSet)
        assert result.mentions == []
        assert result.token_to_path == []
        assert result.mentioned_document_ids == []
        assert result.mentioned_folder_ids == []
        # No DB roundtrips when there's nothing to resolve.
        session.execute.assert_not_awaited()

    @pytest.mark.asyncio
    async def test_resolves_doc_chip_to_virtual_path(self, monkeypatch):
        """A doc chip becomes ``<DOCUMENTS_ROOT>/<title>.xml``."""
        chip = MentionedDocumentInfo(
            id=42,
            title="Notes",
            document_type="EXTENSION",
            kind="doc",
        )
        doc_row = SimpleNamespace(id=42, title="Notes", folder_id=None)
        self._stub_path_index(monkeypatch)
        session = self._session_returning([doc_row])

        out = await resolve_mentions(
            session,
            search_space_id=5,
            mentioned_documents=[chip],
        )
        assert len(out.mentions) == 1
        mention = out.mentions[0]
        assert mention.kind == "doc"
        assert mention.id == 42
        assert mention.virtual_path == f"{DOCUMENTS_ROOT}/Notes.xml"
        assert out.mentioned_document_ids == [42]
        assert out.mentioned_folder_ids == []
        assert ("@Notes", f"{DOCUMENTS_ROOT}/Notes.xml") in out.token_to_path

    @pytest.mark.asyncio
    async def test_resolves_folder_chip_with_trailing_slash(self, monkeypatch):
        """A folder chip resolves to the indexed folder path plus ``/``."""
        chip = MentionedDocumentInfo(
            id=9,
            title="Reports",
            document_type="FOLDER",
            kind="folder",
        )
        folder_row = SimpleNamespace(id=9, name="Reports")
        self._stub_path_index(
            monkeypatch,
            lambda: PathIndex(folder_paths={9: f"{DOCUMENTS_ROOT}/Reports"}),
        )
        session = self._session_returning([folder_row])

        out = await resolve_mentions(
            session,
            search_space_id=3,
            mentioned_documents=[chip],
        )
        assert len(out.mentions) == 1
        mention = out.mentions[0]
        assert mention.kind == "folder"
        assert mention.id == 9
        assert mention.virtual_path == f"{DOCUMENTS_ROOT}/Reports/"
        assert out.mentioned_document_ids == []
        assert out.mentioned_folder_ids == [9]

    @pytest.mark.asyncio
    async def test_drops_chip_when_doc_is_missing(self, monkeypatch):
        """A chip whose row is not returned by the DB is silently dropped."""
        chip = MentionedDocumentInfo(
            id=99, title="ghost", document_type="EXTENSION", kind="doc"
        )
        self._stub_path_index(monkeypatch)
        session = self._session_returning([])

        out = await resolve_mentions(
            session,
            search_space_id=1,
            mentioned_documents=[chip],
        )
        assert out.mentions == []
        assert out.mentioned_document_ids == []
        assert out.token_to_path == []

    @pytest.mark.asyncio
    async def test_token_to_path_is_longest_first(self, monkeypatch):
        """``token_to_path`` is ordered by descending token length."""
        # Two chips whose titles are prefixes of each other — the
        # resolver MUST sort longest-first so substitution doesn't
        # break the ``@Project Roadmap`` vs ``@Project`` invariant.
        chip_short = MentionedDocumentInfo(
            id=1, title="A", document_type="EXTENSION", kind="doc"
        )
        chip_long = MentionedDocumentInfo(
            id=2, title="A long one", document_type="EXTENSION", kind="doc"
        )
        rows = [
            SimpleNamespace(id=1, title="A", folder_id=None),
            SimpleNamespace(id=2, title="A long one", folder_id=None),
        ]
        self._stub_path_index(monkeypatch)
        session = self._session_returning(rows)

        out = await resolve_mentions(
            session,
            search_space_id=1,
            mentioned_documents=[chip_short, chip_long],
        )
        tokens = [tok for tok, _ in out.token_to_path]
        # An empty or single-element list is trivially "sorted"; require
        # both chips to be present so the ordering check is meaningful.
        assert len(tokens) >= 2
        assert tokens == sorted(tokens, key=len, reverse=True)

    @pytest.mark.asyncio
    async def test_legacy_id_arrays_resolve_without_chip_metadata(self, monkeypatch):
        """Bare ``mentioned_document_ids`` resolve without any chip list."""
        # ``mentioned_document_ids`` (the legacy parallel array) must
        # still resolve when no chip metadata is available — covers
        # callers that haven't migrated to the discriminated chip list.
        doc_row = SimpleNamespace(id=7, title="Legacy", folder_id=None)
        self._stub_path_index(monkeypatch)
        session = self._session_returning([doc_row])

        out = await resolve_mentions(
            session,
            search_space_id=2,
            mentioned_documents=None,
            mentioned_document_ids=[7],
        )
        assert out.mentioned_document_ids == [7]
        assert len(out.mentions) == 1
        assert out.mentions[0].title == "Legacy"
|
||||
|
||||
|
||||
class TestResolvedMentionEquality:
    """Sanity check that two ``ResolvedMention`` objects built from the
    same field values compare equal — the other assertions in this module
    depend on that."""

    def test_equal_when_fields_equal(self):
        """Identical constructor arguments produce equal instances."""
        fields = {
            "kind": "doc",
            "id": 1,
            "title": "x",
            "virtual_path": "/documents/x.xml",
        }
        first = ResolvedMention(**fields)
        second = ResolvedMention(**fields)
        assert first == second
|
||||
|
|
@ -196,3 +196,50 @@ class TestVirtualPathToDoc:
|
|||
)
|
||||
assert document is target_doc
|
||||
assert session.execute.await_count == 2
|
||||
|
||||
@pytest.mark.asyncio
async def test_resolves_double_extension_for_uploaded_pdf(self):
    """A rendered ``<title>.xml`` path resolves back to a ``.pdf``-titled row.

    Regression: the agent renders every KB document under ``/documents/``
    with a trailing ``.xml`` (via ``safe_filename``), so an uploaded PDF
    whose DB title is ``2025-W2.pdf`` shows up as
    ``/documents/2025-W2.pdf.xml`` in answers. Clicking that path must
    round-trip back to the row even though the title itself does NOT end
    in ``.xml``.
    """
    pdf_row = SimpleNamespace(id=99, title="2025-W2.pdf", folder_id=None)
    rendered_path = f"{DOCUMENTS_ROOT}/2025-W2.pdf.xml"

    fake_session = MagicMock()
    # Two consecutive execute() results: an empty single-row result,
    # then a scalars result carrying the target row.
    queued_results = [
        _result_from_one(None),
        _result_from_scalars([pdf_row]),
    ]
    fake_session.execute = AsyncMock(side_effect=queued_results)

    resolved = await virtual_path_to_doc(
        fake_session,
        search_space_id=5,
        virtual_path=rendered_path,
    )
    assert resolved is pdf_row
|
||||
|
||||
@pytest.mark.asyncio
async def test_resolves_path_without_xml_suffix(self):
    """A title-only path (no ``.xml`` suffix) still resolves to the row.

    The user (or a hand-edited link) may pass the title-only form
    ``/documents/2025-W2.pdf``. The resolver must still find the row by
    literal title equality.
    """
    pdf_row = SimpleNamespace(id=99, title="2025-W2.pdf", folder_id=None)
    title_only_path = f"{DOCUMENTS_ROOT}/2025-W2.pdf"

    fake_session = MagicMock()
    # Two consecutive execute() results: an empty single-row result,
    # then a scalars result carrying the target row.
    queued_results = [
        _result_from_one(None),
        _result_from_scalars([pdf_row]),
    ]
    fake_session.execute = AsyncMock(side_effect=queued_results)

    resolved = await virtual_path_to_doc(
        fake_session,
        search_space_id=5,
        virtual_path=title_only_path,
    )
    assert resolved is pdf_row
|
||||
|
|
|
|||
|
|
@ -0,0 +1,38 @@
|
|||
from tests.e2e.fakes.composio_module import _drive_list_files
|
||||
|
||||
|
||||
def _ids(result: dict) -> set[str]:
|
||||
return {item["id"] for item in result["data"]["files"]}
|
||||
|
||||
|
||||
def test_drive_list_files_filters_shortcuts_and_trashed_items():
    """A root query excluding shortcuts and trash returns only live files."""
    query = (
        "'root' in parents and trashed = false and "
        "mimeType != 'application/vnd.google-apps.shortcut'"
    )

    listed = _ids(_drive_list_files({"q": query}))

    assert "fake-file-canary" in listed
    assert "fake-shortcut-canary" not in listed
    assert "fake-file-trashed" not in listed
|
||||
|
||||
|
||||
def test_drive_list_files_filters_to_exact_mime_type():
    """An exact ``mimeType = ...`` clause narrows the listing to that type."""
    query = "'root' in parents and trashed = false and mimeType = 'text/plain'"

    listing = _drive_list_files({"q": query})

    assert _ids(listing) == {"fake-file-canary"}
|
||||
|
||||
|
||||
def test_drive_list_files_uses_requested_parent_folder():
    """The parent id named in the query selects which folder is listed."""
    query = "'fake-folder-projects' in parents and trashed = false"

    listing = _drive_list_files({"q": query})

    assert _ids(listing) == {"fake-file-roadmap"}
|
||||
|
|
@ -1,8 +1,17 @@
|
|||
"""Unit tests: build_composio_credentials returns valid Google Credentials.
|
||||
"""Unit tests: Composio credential helpers + ``get_access_token`` masking guard.
|
||||
|
||||
Mocks the Composio SDK (external system boundary) and verifies that the
|
||||
returned ``google.oauth2.credentials.Credentials`` object is correctly
|
||||
configured with a token and a working refresh handler.
|
||||
Covers two seams between SurfSense and Composio:
|
||||
|
||||
1. ``build_composio_credentials`` returns a ``google.oauth2.credentials.Credentials``
|
||||
object with a working refresh handler (mocks the whole ``ComposioService``).
|
||||
2. ``ComposioService.get_access_token`` rejects masked / missing tokens with
|
||||
actionable error messages (mocks only the Composio SDK boundary so the
|
||||
real guard logic is exercised).
|
||||
|
||||
The masking guard is the boundary handler that production tripped over when
|
||||
Composio's "Mask Connected Account Secrets" project setting was enabled.
|
||||
The corresponding fix landed in ``cea8618``; these tests lock that contract
|
||||
in place so any future weakening of the guard surfaces immediately.
|
||||
"""
|
||||
|
||||
from datetime import UTC, datetime
|
||||
|
|
@ -14,6 +23,11 @@ from google.oauth2.credentials import Credentials
|
|||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# build_composio_credentials — high-level wrapper tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@patch("app.services.composio_service.ComposioService")
|
||||
def test_returns_credentials_with_token_and_expiry(mock_composio_service):
|
||||
"""build_composio_credentials returns a Credentials object with the Composio access token."""
|
||||
|
|
@ -54,3 +68,85 @@ def test_refresh_handler_fetches_fresh_token(mock_composio_service):
|
|||
assert new_token == "refreshed-token"
|
||||
assert new_expiry > datetime.now(UTC).replace(tzinfo=None)
|
||||
assert mock_service.get_access_token.call_count == 2
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# ComposioService.get_access_token — boundary masking guard tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _service_with_account(account: object):
    """Return a real ``ComposioService`` wired to a fake Composio SDK client.

    The ``Composio`` class inside ``app.services.composio_service`` is
    patched only while the service is constructed; the fake client's
    ``connected_accounts.get`` returns *account*. ``get_access_token``
    itself is left unpatched, so changes to the masking / missing-token
    guards surface in the tests that use this helper.
    """
    from app.services import composio_service as composio_service_module

    fake_client = MagicMock()
    fake_client.connected_accounts.get.return_value = account

    with patch.object(composio_service_module, "Composio") as composio_cls:
        composio_cls.return_value = fake_client
        built_service = composio_service_module.ComposioService(
            api_key="unit-test-api-key"
        )

    # Per the original author's note: the constructor captured the fake
    # client during init, so ``built_service.client`` keeps pointing at it
    # after the patch context has exited.
    return built_service
|
||||
|
||||
|
||||
@pytest.mark.parametrize("masked_token", ["x", "xxxxxxxx", "x" * 19])
def test_get_access_token_raises_on_masked_token(masked_token):
    """Tokens shorter than the 20-char unmask threshold must raise with the dashboard hint.

    Composio masks ``state.val.access_token`` by default (project setting
    "Mask Connected Account Secrets"). A masked token will always silently
    fail downstream OAuth calls, so the guard surfaces it with the exact
    text needed to fix the dashboard config.
    """
    account = MagicMock()
    account.state.val.access_token = masked_token
    composio = _service_with_account(account)

    with pytest.raises(ValueError, match="Mask Connected Account Secrets"):
        composio.get_access_token("any-account-id")
|
||||
|
||||
|
||||
def test_get_access_token_raises_when_state_val_missing():
    """An account with ``state = None`` raises a ``ValueError`` naming the account id."""
    account = MagicMock()
    account.state = None
    composio = _service_with_account(account)

    expected_message = r"No state\.val.*missing-state-account"
    with pytest.raises(ValueError, match=expected_message):
        composio.get_access_token("missing-state-account")
|
||||
|
||||
|
||||
def test_get_access_token_raises_when_access_token_empty():
    """An empty-string ``access_token`` raises the "No access_token" error, not the masking one."""
    account = MagicMock()
    account.state.val.access_token = ""
    composio = _service_with_account(account)

    expected_message = r"No access_token.*missing-token-account"
    with pytest.raises(ValueError, match=expected_message):
        composio.get_access_token("missing-token-account")
|
||||
|
||||
|
||||
def test_get_access_token_raises_when_access_token_none():
    """A ``None`` ``access_token`` raises the "No access_token" error, not the masking one."""
    account = MagicMock()
    account.state.val.access_token = None
    composio = _service_with_account(account)

    expected_message = r"No access_token.*none-token-account"
    with pytest.raises(ValueError, match=expected_message):
        composio.get_access_token("none-token-account")
|
||||
|
||||
|
||||
def test_get_access_token_returns_unmasked_token():
    """Happy path: a 32-character access token comes back unchanged."""
    real_token = "u" * 32
    account = MagicMock()
    account.state.val.access_token = real_token
    composio = _service_with_account(account)

    returned = composio.get_access_token("happy-account")

    assert returned == real_token
|
||||
|
|
|
|||
|
|
@ -118,12 +118,8 @@ def test_get_by_tool_call_id_matches_action_request_payload() -> None:
|
|||
tasks=(
|
||||
_Task(
|
||||
interrupts=(
|
||||
_Interrupt(
|
||||
value=_hitl("a", tool_call_id="call_xxx"), id="int_a"
|
||||
),
|
||||
_Interrupt(
|
||||
value=_hitl("b", tool_call_id="call_yyy"), id="int_b"
|
||||
),
|
||||
_Interrupt(value=_hitl("a", tool_call_id="call_xxx"), id="int_a"),
|
||||
_Interrupt(value=_hitl("b", tool_call_id="call_yyy"), id="int_b"),
|
||||
)
|
||||
),
|
||||
)
|
||||
|
|
@ -146,9 +142,7 @@ def test_first_pending_interrupt_matches_legacy_first_wins_behaviour() -> None:
|
|||
|
||||
def test_interrupt_without_id_falls_back_to_none() -> None:
|
||||
"""Snapshots from older LangGraph versions may omit ``id`` — preserve that."""
|
||||
state = _State(
|
||||
tasks=(_Task(interrupts=(_Interrupt(value=_hitl("a"), id=None),)),)
|
||||
)
|
||||
state = _State(tasks=(_Task(interrupts=(_Interrupt(value=_hitl("a"), id=None),)),))
|
||||
pending = list_pending_interrupts(state)
|
||||
assert len(pending) == 1
|
||||
assert pending[0].interrupt_id is None
|
||||
|
|
|
|||
|
|
@ -37,9 +37,7 @@ def test_custom_interrupt_primitive_is_converted_to_canonical_shape() -> None:
|
|||
"context": {"reason": "destructive"},
|
||||
}
|
||||
out = normalize_interrupt_payload(raw)
|
||||
assert out["action_requests"] == [
|
||||
{"name": "send_email", "args": {"to": "a@b"}}
|
||||
]
|
||||
assert out["action_requests"] == [{"name": "send_email", "args": {"to": "a@b"}}]
|
||||
assert out["review_configs"] == [
|
||||
{
|
||||
"action_name": "send_email",
|
||||
|
|
|
|||
|
|
@ -158,9 +158,7 @@ def _classify_cases() -> list[Exception]:
|
|||
"""Inputs that the FE depends on being mapped to specific error codes."""
|
||||
return [
|
||||
Exception("totally generic error"),
|
||||
Exception(
|
||||
'{"error":{"type":"rate_limit_error","message":"slow down"}}'
|
||||
),
|
||||
Exception('{"error":{"type":"rate_limit_error","message":"slow down"}}'),
|
||||
Exception(
|
||||
'OpenrouterException - {"error":{"message":"Provider returned error",'
|
||||
'"code":429}}'
|
||||
|
|
@ -220,7 +218,7 @@ class _FakeStreamingService:
|
|||
self.calls.append(
|
||||
{"message": message, "error_code": error_code, "extra": extra}
|
||||
)
|
||||
return f"data: {{\"type\":\"error\",\"errorText\":\"{message}\"}}\n\n"
|
||||
return f'data: {{"type":"error","errorText":"{message}"}}\n\n'
|
||||
|
||||
|
||||
def test_emit_stream_terminal_error_matches_old_output_and_logs(caplog) -> None:
|
||||
|
|
|
|||
|
|
@ -60,8 +60,14 @@ async def test_stream_output_emits_text_lifecycle_and_updates_result() -> None:
|
|||
service = _StreamingService()
|
||||
agent = _Agent(
|
||||
[
|
||||
{"event": "on_chat_model_stream", "data": {"chunk": _Chunk(content="Hello")}},
|
||||
{"event": "on_chat_model_stream", "data": {"chunk": _Chunk(content=" world")}},
|
||||
{
|
||||
"event": "on_chat_model_stream",
|
||||
"data": {"chunk": _Chunk(content="Hello")},
|
||||
},
|
||||
{
|
||||
"event": "on_chat_model_stream",
|
||||
"data": {"chunk": _Chunk(content=" world")},
|
||||
},
|
||||
]
|
||||
)
|
||||
result = StreamingResult()
|
||||
|
|
|
|||
|
|
@ -37,7 +37,9 @@ def test_clear_ignored_for_non_task_tool() -> None:
|
|||
def test_clear_ignored_when_task_run_id_mismatches() -> None:
|
||||
state = AgentEventRelayState.for_invocation()
|
||||
open_task_span(state, run_id="run-open")
|
||||
clear_task_span_if_delegating_task_ended(state, tool_name="task", run_id="run-other")
|
||||
clear_task_span_if_delegating_task_ended(
|
||||
state, tool_name="task", run_id="run-other"
|
||||
)
|
||||
assert state.active_span_id is not None
|
||||
assert state.active_task_run_id == "run-open"
|
||||
|
||||
|
|
|
|||
|
|
@ -240,9 +240,7 @@ class TestToolHeavyTurn:
|
|||
class TestToolCallSpanMetadata:
|
||||
def test_input_available_merges_new_metadata_keys_after_start(self):
|
||||
b = AssistantContentBuilder()
|
||||
b.on_tool_input_start(
|
||||
"call_t", "task", "lc_t", metadata={"spanId": "spn_1"}
|
||||
)
|
||||
b.on_tool_input_start("call_t", "task", "lc_t", metadata={"spanId": "spn_1"})
|
||||
b.on_tool_input_available(
|
||||
"call_t",
|
||||
"task",
|
||||
|
|
@ -257,9 +255,7 @@ class TestToolCallSpanMetadata:
|
|||
|
||||
def test_input_available_does_not_overwrite_existing_metadata_keys(self):
|
||||
b = AssistantContentBuilder()
|
||||
b.on_tool_input_start(
|
||||
"call_t", "task", "lc_t", metadata={"spanId": "spn_keep"}
|
||||
)
|
||||
b.on_tool_input_start("call_t", "task", "lc_t", metadata={"spanId": "spn_keep"})
|
||||
b.on_tool_input_available(
|
||||
"call_t", "task", {}, "lc_t", metadata={"spanId": "spn_other"}
|
||||
)
|
||||
|
|
|
|||
93
surfsense_backend/tests/unit/utils/test_oauth_security.py
Normal file
93
surfsense_backend/tests/unit/utils/test_oauth_security.py
Normal file
|
|
@ -0,0 +1,93 @@
|
|||
import base64
|
||||
import hashlib
|
||||
import hmac
|
||||
import json
|
||||
import time
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from fastapi import HTTPException
|
||||
|
||||
from app.utils.oauth_security import OAuthStateManager
|
||||
|
||||
SECRET = "unit-test-secret"
|
||||
|
||||
|
||||
def _encode_state(payload: dict, *, signature: str | None = None) -> str:
    """Build an OAuth state payload compatible with OAuthStateManager.

    The payload is serialised as key-sorted JSON and signed with
    HMAC-SHA256 under the module-level ``SECRET``. The result is the
    URL-safe base64 of a JSON object holding the payload fields plus a
    ``"signature"`` entry. Passing *signature* overrides the computed
    digest, which lets tests forge a tampered state.
    """
    body = payload.copy()
    if signature is None:
        canonical = json.dumps(body, sort_keys=True).encode()
        digest = hmac.new(SECRET.encode(), canonical, hashlib.sha256).hexdigest()
    else:
        digest = signature

    envelope = dict(body)
    envelope["signature"] = digest
    serialised = json.dumps(envelope).encode()
    return base64.urlsafe_b64encode(serialised).decode()
|
||||
|
||||
|
||||
def test_validate_state_accepts_fresh_signed_state():
    """A state produced by ``generate_secure_state`` validates and round-trips its fields."""
    manager = OAuthStateManager(secret_key=SECRET, max_age_seconds=600)
    owner_id = uuid4()

    token = manager.generate_secure_state(
        space_id=1,
        user_id=owner_id,
        toolkit_id="googledrive",
    )
    claims = manager.validate_state(token)

    # ``user_id`` is expected back as a string, not a UUID object.
    expected_claims = {
        "space_id": 1,
        "user_id": str(owner_id),
        "toolkit_id": "googledrive",
    }
    for field, value in expected_claims.items():
        assert claims[field] == value
|
||||
|
||||
|
||||
def test_validate_state_rejects_expired_state():
    """A correctly signed state older than ``max_age_seconds`` is rejected with HTTP 400."""
    manager = OAuthStateManager(secret_key=SECRET, max_age_seconds=600)

    # One hour old: well past the 600-second window configured above.
    stale_payload = {
        "space_id": 1,
        "user_id": str(uuid4()),
        "timestamp": int(time.time()) - 3600,
        "toolkit_id": "googledrive",
    }
    stale_state = _encode_state(stale_payload)

    with pytest.raises(HTTPException) as caught:
        manager.validate_state(stale_state)

    error = caught.value
    assert error.status_code == 400
    assert "expired" in error.detail.lower()
|
||||
|
||||
|
||||
def test_validate_state_rejects_tampered_signature():
    """A fresh state carrying a forged signature is rejected with HTTP 400."""
    manager = OAuthStateManager(secret_key=SECRET, max_age_seconds=600)

    fresh_payload = {
        "space_id": 1,
        "user_id": str(uuid4()),
        "timestamp": int(time.time()),
        "toolkit_id": "googledrive",
    }
    # 64 hex characters: shaped like a SHA-256 digest, but not the real one.
    forged_signature = "deadbeef" * 8
    forged_state = _encode_state(fresh_payload, signature=forged_signature)

    with pytest.raises(HTTPException) as caught:
        manager.validate_state(forged_state)

    error = caught.value
    assert error.status_code == 400
    assert "tampering" in error.detail.lower()
|
||||
|
||||
|
||||
def test_validate_state_rejects_malformed_state():
    """A string that is neither valid base64 nor JSON is rejected with HTTP 400."""
    manager = OAuthStateManager(secret_key=SECRET)
    garbage = "not-base64-and-not-json"

    with pytest.raises(HTTPException) as caught:
        manager.validate_state(garbage)

    error = caught.value
    assert error.status_code == 400
    assert "invalid state format" in error.detail.lower()
|
||||
Loading…
Add table
Add a link
Reference in a new issue