Merge upstream/dev into feature/multi-agent

This commit is contained in:
CREDO23 2026-05-12 21:23:37 +02:00
commit 246dae40a8
229 changed files with 36484 additions and 436 deletions

View file

@ -24,4 +24,6 @@ wheels/
*.egg
.pytest_cache/
.coverage
htmlcov/
htmlcov/
tests/

View file

@ -46,6 +46,10 @@ class SurfSenseContextSchema:
Read by ``KnowledgePriorityMiddleware`` to seed its priority
list. Stays out of the compiled-agent cache key — that's the
whole point of putting it here.
mentioned_folder_ids: KB folders the user @-mentioned this turn
(cloud filesystem mode). Surfaced as ``[USER-MENTIONED]``
entries in ``<priority_documents>`` so the agent prioritises
walking those folders with ``ls`` / ``find_documents``.
file_operation_contract: One-shot file operation contract emitted
by ``FileIntentMiddleware`` for the upcoming turn.
turn_id / request_id: Correlation IDs surfaced by the streaming
@ -59,6 +63,7 @@ class SurfSenseContextSchema:
search_space_id: int | None = None
mentioned_document_ids: list[int] = field(default_factory=list)
mentioned_folder_ids: list[int] = field(default_factory=list)
file_operation_contract: FileOperationContractState | None = None
turn_id: str | None = None
request_id: str | None = None

View file

@ -0,0 +1,281 @@
"""Resolve @-mention chips to canonical virtual paths and substitute the
user-visible ``@title`` tokens with backtick-wrapped paths in the prompt
the agent sees.
The frontend's mention seam is a single discriminated-union list of
``{kind: "doc" | "folder", id, title, document_type?}`` chips (see
``surfsense_web/atoms/chat/mentioned-documents.atom.ts``). When a turn
reaches the backend stream task we have three needs that this module
centralises:
1. Map each chip to its canonical virtual path
(``/documents/.../file.xml`` for docs, ``/documents/MyFolder/`` for
folders) so the agent sees concrete filesystem locations instead of
ambiguous ``@``-titles.
2. Substitute ``@title`` tokens in the user-typed text with backtick-
wrapped paths so the path becomes part of the ``HumanMessage`` body
the LLM consumes — without rewriting the persisted user message
text (which keeps ``@title`` so chip rendering on reload is
unchanged).
3. Surface the resolved id sets (docs + folders) to the priority
middleware so it can render ``[USER-MENTIONED]`` priority entries
without re-doing path resolution.
This is intentionally one module — see the architectural note in the
``mention-paths-and-folders`` plan: previously the doc resolution lived
inline in ``stream_new_chat`` and folder mentions had no resolution
at all. Centralising both behind a single ``resolve_mentions`` call
turns a leaky multi-field seam into a single deeper interface.
"""
from __future__ import annotations
import logging
from dataclasses import dataclass, field
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.agents.new_chat.path_resolver import (
DOCUMENTS_ROOT,
build_path_index,
doc_to_virtual_path,
)
from app.db import Document, Folder
from app.schemas.new_chat import MentionedDocumentInfo
logger = logging.getLogger(__name__)
@dataclass(frozen=True)
class ResolvedMention:
    """One @-mention chip paired with the virtual path the agent sees.

    Instances are immutable. Document paths carry no trailing slash,
    while folder paths end in ``/`` to follow the directory convention
    used by ``KnowledgeTreeMiddleware``.
    """

    # Chip discriminator: "doc" or "folder".
    kind: str
    # Id of the referenced document or folder row.
    id: int
    # Title the ``@title`` substitution token is built from.
    title: str
    # Canonical agent-facing path for this chip.
    virtual_path: str
@dataclass
class ResolvedMentionSet:
    """Everything ``resolve_mentions`` produces for one turn.

    ``token_to_path`` pairs each literal ``@title`` token (as typed by
    the user and emitted by the editor) with the canonical virtual path
    of its chip. The resolver orders it longest token first so that
    substitution behaves like the frontend's ``parseMentionSegments``:
    ``@Project Roadmap`` is matched before its shorter prefix
    ``@Project`` gets a chance to shadow it.

    ``mentioned_document_ids`` merges doc and surfsense_doc chips into
    one ordered, deduplicated list, since the priority middleware
    handles them identically downstream (see
    ``KnowledgePriorityMiddleware._compute_priority_paths``).
    """

    # Every chip that resolved to an existing row, in resolution order.
    mentions: list[ResolvedMention] = field(default_factory=list)
    # ``(token, virtual_path)`` pairs used for text substitution.
    token_to_path: list[tuple[str, str]] = field(default_factory=list)
    # Ids of the documents that were found.
    mentioned_document_ids: list[int] = field(default_factory=list)
    # Ids of the folders that were found.
    mentioned_folder_ids: list[int] = field(default_factory=list)
def _folder_virtual_path(folder_id: int, folder_paths: dict[int, str]) -> str:
    """Build the slash-terminated virtual path for a folder id.

    Args:
        folder_id: Id of the folder to look up.
        folder_paths: Mapping of folder id to its indexed path.

    Returns:
        The indexed path with exactly one trailing ``/`` appended when
        it is missing (e.g. ``/documents/Folder/Sub/``). When the id is
        absent from ``folder_paths`` the documents root is used instead.
        The trailing slash follows the ``KnowledgeTreeMiddleware``
        directory convention.
    """
    located = folder_paths.get(folder_id, DOCUMENTS_ROOT)
    if located.endswith("/"):
        return located
    return f"{located}/"
async def resolve_mentions(
    session: AsyncSession,
    *,
    search_space_id: int,
    mentioned_documents: list[MentionedDocumentInfo] | None,
    mentioned_document_ids: list[int] | None = None,
    mentioned_surfsense_doc_ids: list[int] | None = None,
    mentioned_folder_ids: list[int] | None = None,
) -> ResolvedMentionSet:
    """Turn a turn's @-mention chips and id arrays into virtual paths.

    Both the chip list and the parallel id arrays are accepted: legacy
    clients may send only the id arrays, so the union of the two is
    treated as authoritative, and the validated id lists returned here
    can be forwarded unchanged to ``KnowledgePriorityMiddleware``.

    Args:
        session: Async database session used for all lookups.
        search_space_id: Search space the documents and folders must
            belong to; ids outside it are dropped.
        mentioned_documents: Chip metadata; supplies the titles used to
            build ``@title`` tokens.
        mentioned_document_ids: Document ids mentioned this turn.
        mentioned_surfsense_doc_ids: SurfSense doc ids mentioned this
            turn; pooled together with ``mentioned_document_ids``.
        mentioned_folder_ids: Folder ids mentioned this turn.

    Returns:
        A ``ResolvedMentionSet``. Resolution is best-effort: an id with
        no matching row in the search space is logged at debug level
        and left out, so no substitution happens for that chip.
    """
    # Bucket chip ids by kind and remember each chip's display title.
    titles_by_chip: dict[tuple[str, int], str] = {}
    doc_ids_from_chips: list[int] = []
    folder_ids_from_chips: list[int] = []
    for chip in mentioned_documents or ():
        kind = chip.kind
        bucket = folder_ids_from_chips if kind == "folder" else doc_ids_from_chips
        bucket.append(chip.id)
        titles_by_chip[(kind, chip.id)] = chip.title

    # ``dict.fromkeys`` dedupes while preserving first-seen order.
    wanted_doc_ids: list[int] = list(
        dict.fromkeys(
            [
                *(mentioned_document_ids or []),
                *(mentioned_surfsense_doc_ids or []),
                *doc_ids_from_chips,
            ]
        )
    )
    wanted_folder_ids: list[int] = list(
        dict.fromkeys([*(mentioned_folder_ids or []), *folder_ids_from_chips])
    )

    if not (wanted_doc_ids or wanted_folder_ids):
        return ResolvedMentionSet()

    index = await build_path_index(session, search_space_id)

    docs_by_id: dict[int, Document] = {}
    if wanted_doc_ids:
        doc_query = select(Document).where(
            Document.search_space_id == search_space_id,
            Document.id.in_(wanted_doc_ids),
        )
        fetched_docs = await session.execute(doc_query)
        docs_by_id = {row.id: row for row in fetched_docs.scalars().all()}

    folders_by_id: dict[int, Folder] = {}
    if wanted_folder_ids:
        folder_query = select(Folder).where(
            Folder.search_space_id == search_space_id,
            Folder.id.in_(wanted_folder_ids),
        )
        fetched_folders = await session.execute(folder_query)
        folders_by_id = {row.id: row for row in fetched_folders.scalars().all()}

    mentions: list[ResolvedMention] = []

    # Documents first, in pool order, then folders.
    kept_doc_ids: list[int] = []
    for doc_id in wanted_doc_ids:
        doc = docs_by_id.get(doc_id)
        if doc is None:
            logger.debug(
                "mention_resolver: dropping doc id=%s (not found in space=%s)",
                doc_id,
                search_space_id,
            )
            continue
        # The chip title wins over the stored title for the token text.
        doc_title = titles_by_chip.get(("doc", doc_id), str(doc.title or ""))
        doc_path = doc_to_virtual_path(
            doc_id=doc.id,
            title=str(doc.title or "untitled"),
            folder_id=doc.folder_id,
            index=index,
        )
        mentions.append(
            ResolvedMention(
                kind="doc", id=doc.id, title=doc_title, virtual_path=doc_path
            )
        )
        kept_doc_ids.append(doc.id)

    kept_folder_ids: list[int] = []
    for folder_id in wanted_folder_ids:
        folder = folders_by_id.get(folder_id)
        if folder is None:
            logger.debug(
                "mention_resolver: dropping folder id=%s (not found in space=%s)",
                folder_id,
                search_space_id,
            )
            continue
        folder_title = titles_by_chip.get(
            ("folder", folder_id), str(folder.name or "")
        )
        folder_path = _folder_virtual_path(folder.id, index.folder_paths)
        mentions.append(
            ResolvedMention(
                kind="folder",
                id=folder.id,
                title=folder_title,
                virtual_path=folder_path,
            )
        )
        kept_folder_ids.append(folder.id)

    # First mention to claim a token keeps it; untitled mentions get none.
    path_by_token: dict[str, str] = {}
    for mention in mentions:
        if mention.title:
            path_by_token.setdefault(f"@{mention.title}", mention.virtual_path)

    # Stable sort, longest token first, so a short title never shadows
    # a longer one sharing its prefix.
    ordered_tokens: list[tuple[str, str]] = sorted(
        path_by_token.items(), key=lambda pair: len(pair[0]), reverse=True
    )

    return ResolvedMentionSet(
        mentions=mentions,
        token_to_path=ordered_tokens,
        mentioned_document_ids=kept_doc_ids,
        mentioned_folder_ids=kept_folder_ids,
    )
def substitute_in_text(text: str, token_to_path: list[tuple[str, str]]) -> str:
    """Replace each ``@title`` token with a backtick-wrapped virtual path.

    Mirrors ``parseMentionSegments`` on the frontend: longest token
    first, single forward pass, no regex (titles can contain regex
    metacharacters). Substituted output is never re-scanned, so a path
    that happens to contain a token is left untouched.

    Args:
        text: The user-typed message text.
        token_to_path: ``(token, virtual_path)`` pairs. Empty tokens are
            ignored. The pairs are re-sorted longest token first here
            (stable, so ties keep the caller's order), which means the
            caller's ordering is not required for correctness.

    Returns:
        ``text`` with every token occurrence replaced by its path
        wrapped in backticks. ``text`` is returned unchanged when it is
        empty or when there is nothing usable to substitute.
    """
    if not text or not token_to_path:
        return text

    # An empty token matches at every index and consumes nothing, which
    # would stall the scan forever, so discard those up front.
    candidates = sorted(
        ((token, path) for token, path in token_to_path if token),
        key=lambda pair: len(pair[0]),
        reverse=True,
    )
    if not candidates:
        return text

    pieces: list[str] = []
    literal_start = 0  # start of the pending run of unmatched text
    pos = 0
    end = len(text)
    while pos < end:
        for token, path in candidates:
            if text.startswith(token, pos):
                # Flush the literal run as one slice, then the path.
                pieces.append(text[literal_start:pos])
                pieces.append(f"`{path}`")
                pos += len(token)
                literal_start = pos
                break
        else:
            pos += 1
    pieces.append(text[literal_start:])
    return "".join(pieces)
__all__ = [
"ResolvedMention",
"ResolvedMentionSet",
"resolve_mentions",
"substitute_in_text",
]

View file

@ -54,6 +54,7 @@ from app.db import (
NATIVE_TO_LEGACY_DOCTYPE,
Chunk,
Document,
Folder,
shielded_async_session,
)
from app.retriever.chunks_hybrid_search import ChucksHybridSearchRetriever
@ -836,6 +837,22 @@ class KnowledgePriorityMiddleware(AgentMiddleware): # type: ignore[type-arg]
mention_ids = list(self.mentioned_document_ids)
self.mentioned_document_ids = []
# Folder mentions live alongside doc mentions on the runtime
# context. They never feed hybrid search (folders aren't
# embedded) — they're surfaced purely as ``[USER-MENTIONED]``
# priority entries so the agent walks the folder with ``ls`` /
# ``find_documents`` instead of ignoring it. Cloud filesystem
# mode only.
folder_mention_ids: list[int] = []
if (
ctx is not None
and getattr(self, "filesystem_mode", FilesystemMode.CLOUD)
== FilesystemMode.CLOUD
):
ctx_folders = getattr(ctx, "mentioned_folder_ids", None)
if ctx_folders:
folder_mention_ids = list(ctx_folders)
mentioned_results: list[dict[str, Any]] = []
if mention_ids:
mentioned_results = await fetch_mentioned_documents(
@ -880,12 +897,17 @@ class KnowledgePriorityMiddleware(AgentMiddleware): # type: ignore[type-arg]
priority, matched_chunk_ids = await self._materialize_priority(merged)
if folder_mention_ids:
folder_entries = await self._materialize_folder_priority(folder_mention_ids)
priority = folder_entries + priority
_perf_log.info(
"[kb_priority] completed in %.3fs query=%r priority=%d mentioned=%d",
"[kb_priority] completed in %.3fs query=%r priority=%d mentioned=%d folders=%d",
asyncio.get_event_loop().time() - t0,
user_text[:80],
len(priority),
len(mentioned_results),
len(folder_mention_ids),
)
update: dict[str, Any] = {
@ -899,6 +921,58 @@ class KnowledgePriorityMiddleware(AgentMiddleware): # type: ignore[type-arg]
update["messages"] = new_messages
return update
async def _materialize_folder_priority(
    self, folder_ids: list[int]
) -> list[dict[str, Any]]:
    """Build ``<priority_documents>`` entries for user-mentioned folders.

    Args:
        folder_ids: Folder ids mentioned this turn; duplicates are
            collapsed, keeping first-seen order.

    Returns:
        One dict per folder found in the path index, with keys
        ``path``, ``score``, ``document_id``, ``folder_id``, ``title``
        and ``mentioned``. ``path`` is the slash-terminated virtual
        path (``/documents/Folder/Sub/``). ``score`` and
        ``document_id`` are ``None`` because folders are not ranked —
        the agent decides which children to read. ``mentioned`` is
        always ``True``. Folders missing from the path index are logged
        at debug level and skipped.
    """
    if not folder_ids:
        return []

    async with shielded_async_session() as session:
        index: PathIndex = await build_path_index(session, self.search_space_id)
        name_query = select(Folder.id, Folder.name).where(
            Folder.search_space_id == self.search_space_id,
            Folder.id.in_(folder_ids),
        )
        name_rows = await session.execute(name_query)
        folder_titles: dict[int, str] = {
            row.id: row.name for row in name_rows.all()
        }

    entries: list[dict[str, Any]] = []
    # ``dict.fromkeys`` drops repeated ids while keeping their order.
    for folder_id in dict.fromkeys(folder_ids):
        base = index.folder_paths.get(folder_id)
        if base is None:
            logger.debug(
                "kb_priority: dropping folder id=%s (missing from path index)",
                folder_id,
            )
            continue
        if not base.endswith("/"):
            base = f"{base}/"
        entries.append(
            {
                "path": base,
                "score": None,
                "document_id": None,
                "folder_id": folder_id,
                "title": folder_titles.get(folder_id, ""),
                "mentioned": True,
            }
        )
    return entries
async def _materialize_priority(
self, merged: list[dict[str, Any]]
) -> tuple[list[dict[str, Any]], dict[int, list[int]]]:

View file

@ -7,6 +7,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from sqlalchemy.orm import selectinload
from app.agents.new_chat.path_resolver import virtual_path_to_doc
from app.db import (
Chunk,
Document,
@ -752,7 +753,24 @@ async def get_document_by_virtual_path(
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
):
"""Resolve a knowledge-base document id by exact virtual path."""
"""Resolve a knowledge-base document by its agent-facing virtual path.
The agent renders every document under ``/documents/...`` with a
``.xml`` extension appended via ``safe_filename`` (so a PDF titled
``2025-W2.pdf`` becomes ``/documents/2025-W2.pdf.xml``). When the user
clicks that path in an answer, this endpoint must round-trip back to
the underlying ``Document`` row regardless of its type agent-created
NOTE docs (which carry ``virtual_path`` in metadata), uploaded PDFs,
and connector docs all flow through here.
Resolution is delegated to :func:`virtual_path_to_doc`, the single
source of truth that handles:
* ``unique_identifier_hash`` lookup (agent NOTE fast path)
* ``" (<doc_id>).xml"`` disambiguation suffixes
* ``.xml`` extension stripping for title-based fallback
* ``safe_filename`` round-trip for connector titles with lossy chars
"""
try:
await check_permission(
session,
@ -762,24 +780,19 @@ async def get_document_by_virtual_path(
"You don't have permission to read documents in this search space",
)
result = await session.execute(
select(
Document.id,
Document.title,
Document.document_type,
).filter(
Document.search_space_id == search_space_id,
Document.document_metadata["virtual_path"].as_string() == virtual_path,
)
document = await virtual_path_to_doc(
session,
search_space_id=search_space_id,
virtual_path=virtual_path,
)
row = result.first()
if row is None:
if document is None:
raise HTTPException(status_code=404, detail="Document not found")
return DocumentTitleRead(
id=row.id,
title=row.title,
document_type=row.document_type,
id=document.id,
title=document.title,
document_type=document.document_type,
folder_id=document.folder_id,
)
except HTTPException:
raise

View file

@ -1781,6 +1781,7 @@ async def handle_new_chat(
llm_config_id=llm_config_id,
mentioned_document_ids=request.mentioned_document_ids,
mentioned_surfsense_doc_ids=request.mentioned_surfsense_doc_ids,
mentioned_folder_ids=request.mentioned_folder_ids,
mentioned_documents=mentioned_documents_payload,
needs_history_bootstrap=thread.needs_history_bootstrap,
thread_visibility=thread.visibility,
@ -2266,6 +2267,7 @@ async def regenerate_response(
llm_config_id=llm_config_id,
mentioned_document_ids=request.mentioned_document_ids,
mentioned_surfsense_doc_ids=request.mentioned_surfsense_doc_ids,
mentioned_folder_ids=request.mentioned_folder_ids,
mentioned_documents=mentioned_documents_payload,
checkpoint_id=target_checkpoint_id,
needs_history_bootstrap=thread.needs_history_bootstrap,

View file

@ -201,18 +201,34 @@ class NewChatUserImagePart(BaseModel):
class MentionedDocumentInfo(BaseModel):
"""Display metadata for a single ``@``-mentioned document.
"""Display metadata for a single ``@``-mention chip.
The full triple ``{id, title, document_type}`` is forwarded by the
frontend mention chip so the server can embed it in the persisted
user message ``ContentPart[]`` (single ``mentioned-documents`` part).
The history loader then renders the chips on reload without an extra
Carries either a knowledge-base document or a knowledge-base folder
(discriminated by ``kind``). The full triple
``{id, title, document_type}`` is forwarded by the frontend mention
chip so the server can embed it in the persisted user message
``ContentPart[]`` (single ``mentioned-documents`` part). The
history loader then renders the chips on reload without an extra
fetch mirrors the pre-refactor frontend ``persistUserTurn`` shape.
``kind`` defaults to ``"doc"`` so legacy clients and persisted rows
that predate folder mentions deserialise unchanged.
"""
id: int
title: str = Field(..., min_length=1, max_length=500)
document_type: str = Field(..., min_length=1, max_length=100)
kind: Literal["doc", "folder"] = Field(
default="doc",
description=(
"Discriminator for the chip's referent: ``doc`` is a "
"knowledge-base ``Document`` row, ``folder`` is a "
"knowledge-base ``Folder`` row. Folders carry the sentinel "
"``document_type='FOLDER'`` to keep the frontend dedup key "
"``(kind:document_type:id)`` from colliding doc and folder "
"ids that happen to share an integer value."
),
)
class NewChatRequest(BaseModel):
@ -228,15 +244,26 @@ class NewChatRequest(BaseModel):
mentioned_surfsense_doc_ids: list[int] | None = (
None # Optional SurfSense documentation IDs mentioned with @ in the chat
)
mentioned_folder_ids: list[int] | None = Field(
default=None,
description=(
"Optional knowledge-base folder IDs the user mentioned with "
"@. Resolved to virtual paths (``/documents/.../``) by "
"``mention_resolver`` and surfaced to the agent via "
"(a) backtick-wrapped substitution in ``user_query`` and "
"(b) a ``[USER-MENTIONED]`` entry in ``<priority_documents>``. "
"The agent's ``ls`` tool can then walk the folder itself."
),
)
mentioned_documents: list[MentionedDocumentInfo] | None = Field(
default=None,
description=(
"Display metadata (id, title, document_type) for every "
"@-mentioned document. Persisted as a ``mentioned-documents`` "
"ContentPart on the user message so reload renders chips "
"without an extra fetch. Optional and additive — when None "
"the user message is persisted without a mentioned-documents "
"part."
"Display metadata (id, title, document_type, kind) for every "
"@-mention chip — both documents and folders. Persisted as a "
"``mentioned-documents`` ContentPart on the user message so "
"reload renders chips without an extra fetch. Optional and "
"additive — when None the user message is persisted without "
"a mentioned-documents part."
),
)
disabled_tools: list[str] | None = (
@ -290,14 +317,22 @@ class RegenerateRequest(BaseModel):
)
mentioned_document_ids: list[int] | None = None
mentioned_surfsense_doc_ids: list[int] | None = None
mentioned_folder_ids: list[int] | None = Field(
default=None,
description=(
"Optional knowledge-base folder IDs the user mentioned with "
"@ on the edited user turn. Only used when ``user_query`` is "
"non-None (edit). Mirrors ``NewChatRequest.mentioned_folder_ids``."
),
)
mentioned_documents: list[MentionedDocumentInfo] | None = Field(
default=None,
description=(
"Display metadata (id, title, document_type) for every "
"@-mentioned document on the edited user turn. Only used "
"when ``user_query`` is non-None (edit). Persisted as a "
"``mentioned-documents`` ContentPart on the new user "
"message. None means no chip metadata."
"Display metadata (id, title, document_type, kind) for every "
"@-mention chip on the edited user turn — both documents and "
"folders. Only used when ``user_query`` is non-None (edit). "
"Persisted as a ``mentioned-documents`` ContentPart on the "
"new user message. None means no chip metadata."
),
)
disabled_tools: list[str] | None = None
@ -373,6 +408,16 @@ class ResumeRequest(BaseModel):
filesystem_mode: Literal["cloud", "desktop_local_folder"] = "cloud"
client_platform: Literal["web", "desktop"] = "web"
local_filesystem_mounts: list[LocalFilesystemMountPayload] | None = None
mentioned_folder_ids: list[int] | None = Field(
default=None,
description=(
"Forwarded for symmetry with /new_chat and /regenerate. "
"Resume reuses the original interrupted user turn so this "
"field is informational only — the originating turn's "
"folder mentions already shaped the priority hints baked "
"into the agent's checkpoint."
),
)
mentioned_documents: list[MentionedDocumentInfo] | None = Field(
default=None,
description=(

View file

@ -43,9 +43,7 @@ class EmitterRegistry:
return main_emitter()
def has_active_subagents(self) -> bool:
return any(
emitter.level == "subagent" for emitter in self._by_run_id.values()
)
return any(emitter.level == "subagent" for emitter in self._by_run_id.values())
def clear(self) -> None:
self._by_run_id.clear()

View file

@ -6,9 +6,7 @@ from ..emitter import Emitter, attach_emitted_by
from ..envelope import format_sse
def format_reasoning_start(
reasoning_id: str, *, emitter: Emitter | None = None
) -> str:
def format_reasoning_start(reasoning_id: str, *, emitter: Emitter | None = None) -> str:
return format_sse(
attach_emitted_by({"type": "reasoning-start", "id": reasoning_id}, emitter)
)
@ -28,9 +26,7 @@ def format_reasoning_delta(
)
def format_reasoning_end(
reasoning_id: str, *, emitter: Emitter | None = None
) -> str:
def format_reasoning_end(reasoning_id: str, *, emitter: Emitter | None = None) -> str:
return format_sse(
attach_emitted_by({"type": "reasoning-end", "id": reasoning_id}, emitter)
)

View file

@ -7,9 +7,7 @@ from ..envelope import format_sse
def format_text_start(text_id: str, *, emitter: Emitter | None = None) -> str:
return format_sse(
attach_emitted_by({"type": "text-start", "id": text_id}, emitter)
)
return format_sse(attach_emitted_by({"type": "text-start", "id": text_id}, emitter))
def format_text_delta(
@ -26,6 +24,4 @@ def format_text_delta(
def format_text_end(text_id: str, *, emitter: Emitter | None = None) -> str:
return format_sse(
attach_emitted_by({"type": "text-end", "id": text_id}, emitter)
)
return format_sse(attach_emitted_by({"type": "text-end", "id": text_id}, emitter))

View file

@ -84,9 +84,7 @@ class StreamingService:
def format_step_finish(self, *, emitter: Emitter | None = None) -> str:
return lifecycle.format_step_finish(emitter=emitter)
def format_text_start(
self, text_id: str, *, emitter: Emitter | None = None
) -> str:
def format_text_start(self, text_id: str, *, emitter: Emitter | None = None) -> str:
return text.format_text_start(text_id, emitter=emitter)
def format_text_delta(
@ -94,9 +92,7 @@ class StreamingService:
) -> str:
return text.format_text_delta(text_id, delta, emitter=emitter)
def format_text_end(
self, text_id: str, *, emitter: Emitter | None = None
) -> str:
def format_text_end(self, text_id: str, *, emitter: Emitter | None = None) -> str:
return text.format_text_end(text_id, emitter=emitter)
def format_reasoning_start(

View file

@ -51,7 +51,9 @@ logger = logging.getLogger(__name__)
_MEANINGFUL_PART_TYPES: frozenset[str] = frozenset({"text", "reasoning", "tool-call"})
def _merge_tool_part_metadata(part: dict[str, Any], metadata: dict[str, Any] | None) -> None:
def _merge_tool_part_metadata(
part: dict[str, Any], metadata: dict[str, Any] | None
) -> None:
"""Shallow-merge ``metadata`` into ``part["metadata"]``; first key wins.
Used for tool-call linkage (``spanId``, ``thinkingStepId``, ): a later

View file

@ -109,17 +109,18 @@ def _build_user_content(
[{"type": "text", "text": "..."},
{"type": "image", "image": "data:..."},
{"type": "mentioned-documents", "documents": [{"id": int,
"title": str, "document_type": str}, ...]}]
"title": str, "document_type": str, "kind": "doc" | "folder"},
...]}]
The companion reader is
``app.utils.user_message_multimodal.split_persisted_user_content_parts``
which expects exactly this shape keep them in sync.
``mentioned_documents``: optional list of ``{id, title, document_type}``
dicts. When non-empty (and a ``mentioned-documents`` part is not already
in some other input shape), a single ``{"type": "mentioned-documents",
"documents": [...]}`` part is appended. Mirrors the FE injection at
``page.tsx:281-286`` (``persistUserTurn``).
``mentioned_documents``: optional list of mention chip dicts. Each
dict may include a ``kind`` discriminator (``"doc"`` or ``"folder"``)
so the persisted ContentPart round-trips folder chips on reload.
When ``kind`` is missing we default to ``"doc"`` so legacy clients
that haven't migrated to the union schema still persist correctly.
"""
parts: list[dict[str, Any]] = [{"type": "text", "text": user_query or ""}]
for url in user_image_data_urls or ():
@ -135,11 +136,14 @@ def _build_user_content(
document_type = doc.get("document_type")
if doc_id is None or title is None or document_type is None:
continue
kind_raw = doc.get("kind", "doc")
kind = kind_raw if kind_raw in ("doc", "folder") else "doc"
normalized.append(
{
"id": doc_id,
"title": str(title),
"document_type": str(document_type),
"kind": kind,
}
)
if normalized:

View file

@ -43,6 +43,7 @@ from app.agents.new_chat.memory_extraction import (
extract_and_save_memory,
extract_and_save_team_memory,
)
from app.agents.new_chat.mention_resolver import resolve_mentions, substitute_in_text
from app.agents.new_chat.middleware.busy_mutex import (
end_turn,
get_cancel_state,
@ -929,6 +930,7 @@ async def stream_new_chat(
llm_config_id: int = -1,
mentioned_document_ids: list[int] | None = None,
mentioned_surfsense_doc_ids: list[int] | None = None,
mentioned_folder_ids: list[int] | None = None,
mentioned_documents: list[dict[str, Any]] | None = None,
checkpoint_id: str | None = None,
needs_history_bootstrap: bool = False,
@ -958,6 +960,7 @@ async def stream_new_chat(
needs_history_bootstrap: If True, load message history from DB (for cloned chats)
mentioned_document_ids: Optional list of document IDs mentioned with @ in the chat
mentioned_surfsense_doc_ids: Optional list of SurfSense doc IDs mentioned with @ in the chat
mentioned_folder_ids: Optional list of knowledge-base folder IDs mentioned with @ (cloud mode)
checkpoint_id: Optional checkpoint ID to rewind/fork from (for edit/reload operations)
Yields:
@ -1502,6 +1505,53 @@ async def stream_new_chat(
)
recent_reports = list(recent_reports_result.scalars().all())
# Resolve @-mention chips to canonical virtual paths and rewrite
# the user-typed text so the LLM sees ``\`/documents/...\``` instead
# of bare ``@title``. The persisted user-message text keeps
# ``@title`` so chip rendering on reload is unchanged — see
# ``persistence._build_user_content``.
#
# Cloud mode only: local-folder mode keeps the legacy
# ``@title`` text path; mention support there is a follow-up
# task because the path scheme (mount-rooted) and the picker
# UI both need separate work.
accepted_folder_ids: list[int] = []
if fs_mode == FilesystemMode.CLOUD.value and (
mentioned_document_ids
or mentioned_surfsense_doc_ids
or mentioned_folder_ids
or mentioned_documents
):
from app.schemas.new_chat import (
MentionedDocumentInfo as _MentionedDocumentInfo,
)
chip_objs: list[_MentionedDocumentInfo] | None = None
if mentioned_documents:
chip_objs = []
for raw in mentioned_documents:
if isinstance(raw, _MentionedDocumentInfo):
chip_objs.append(raw)
continue
try:
chip_objs.append(_MentionedDocumentInfo.model_validate(raw))
except Exception:
logger.debug(
"stream_new_chat: dropping malformed mention chip %r",
raw,
)
resolved = await resolve_mentions(
session,
search_space_id=search_space_id,
mentioned_documents=chip_objs,
mentioned_document_ids=mentioned_document_ids,
mentioned_surfsense_doc_ids=mentioned_surfsense_doc_ids,
mentioned_folder_ids=mentioned_folder_ids,
)
user_query = substitute_in_text(user_query, resolved.token_to_path)
accepted_folder_ids = resolved.mentioned_folder_ids
# Format the user query with context (SurfSense docs + reports only)
final_query = user_query
context_parts = []
@ -1901,6 +1951,9 @@ async def stream_new_chat(
runtime_context = SurfSenseContextSchema(
search_space_id=search_space_id,
mentioned_document_ids=list(mentioned_document_ids or []),
mentioned_folder_ids=list(
accepted_folder_ids or mentioned_folder_ids or []
),
request_id=request_id,
turn_id=stream_result.turn_id,
)

View file

@ -26,9 +26,7 @@ def handle_report_progress(
return None, last_active_step_items
phase = data.get("phase", "")
topic_items = [
item for item in last_active_step_items if item.startswith("Topic:")
]
topic_items = [item for item in last_active_step_items if item.startswith("Topic:")]
if phase in ("revising_section", "adding_section"):
plan_items = [
@ -56,7 +54,9 @@ def handle_report_progress(
return frame, new_items
def handle_document_created(data: dict[str, Any], *, streaming_service: Any) -> str | None:
def handle_document_created(
data: dict[str, Any], *, streaming_service: Any
) -> str | None:
if not data.get("id"):
return None
return streaming_service.format_data(

View file

@ -13,7 +13,9 @@ from app.tasks.chat.streaming.handlers.tools import (
)
from app.tasks.chat.streaming.helpers.tool_output import tool_output_has_error
from app.tasks.chat.streaming.relay.state import AgentEventRelayState
from app.tasks.chat.streaming.relay.task_span import clear_task_span_if_delegating_task_ended
from app.tasks.chat.streaming.relay.task_span import (
clear_task_span_if_delegating_task_ended,
)
from app.tasks.chat.streaming.relay.thinking_step_sse import emit_thinking_step_frame
@ -32,9 +34,7 @@ def iter_tool_end_frames(
run_id = event.get("run_id", "")
tool_name = event.get("name", "unknown_tool")
raw_output = event.get("data", {}).get("output", "")
staged_file_path = (
state.file_path_by_run.pop(run_id, None) if run_id else None
)
staged_file_path = state.file_path_by_run.pop(run_id, None) if run_id else None
if tool_name == "update_memory":
state.called_update_memory = True
@ -116,6 +116,4 @@ def iter_tool_end_frames(
)
yield from iter_tool_completion_emission_frames(emission_ctx)
clear_task_span_if_delegating_task_ended(
state, tool_name=tool_name, run_id=run_id
)
clear_task_span_if_delegating_task_ended(state, tool_name=tool_name, run_id=run_id)

View file

@ -15,7 +15,9 @@ def resolve_start_thinking(tool_name: str, tool_input: Any) -> ToolStartThinking
def resolve_completed_thinking(
tool_name: str, tool_output: Any, last_items: list[str],
tool_name: str,
tool_output: Any,
last_items: list[str],
) -> tuple[str, list[str]]:
return default_thinking.resolve_completed_thinking(
tool_name, tool_output, last_items

View file

@ -23,7 +23,9 @@ def resolve_start_thinking(tool_name: str, tool_input: Any) -> ToolStartThinking
def resolve_completed_thinking(
tool_name: str, tool_output: Any, last_items: list[str],
tool_name: str,
tool_output: Any,
last_items: list[str],
) -> tuple[str, list[str]]:
del tool_name
items = last_items

View file

@ -34,7 +34,9 @@ def resolve_start_thinking(tool_name: str, tool_input: Any) -> ToolStartThinking
def resolve_completed_thinking(
tool_name: str, tool_output: Any, last_items: list[str],
tool_name: str,
tool_output: Any,
last_items: list[str],
) -> tuple[str, list[str]]:
del tool_name
items = last_items

View file

@ -29,7 +29,9 @@ def resolve_start_thinking(tool_name: str, tool_input: Any) -> ToolStartThinking
def resolve_completed_thinking(
tool_name: str, tool_output: Any, last_items: list[str],
tool_name: str,
tool_output: Any,
last_items: list[str],
) -> tuple[str, list[str]]:
del tool_name
items = last_items
@ -44,9 +46,7 @@ def resolve_completed_thinking(
else "Report"
)
word_count = (
tool_output.get("word_count", 0)
if isinstance(tool_output, dict)
else 0
tool_output.get("word_count", 0) if isinstance(tool_output, dict) else 0
)
is_revision = (
tool_output.get("is_revision", False)

View file

@ -17,7 +17,9 @@ def resolve_start_thinking(tool_name: str, tool_input: Any) -> ToolStartThinking
def resolve_completed_thinking(
tool_name: str, tool_output: Any, last_items: list[str],
tool_name: str,
tool_output: Any,
last_items: list[str],
) -> tuple[str, list[str]]:
return default_thinking.resolve_completed_thinking(
tool_name, tool_output, last_items

View file

@ -17,7 +17,9 @@ def resolve_start_thinking(tool_name: str, tool_input: Any) -> ToolStartThinking
def resolve_completed_thinking(
tool_name: str, tool_output: Any, last_items: list[str],
tool_name: str,
tool_output: Any,
last_items: list[str],
) -> tuple[str, list[str]]:
del tool_name
items = last_items

View file

@ -21,7 +21,9 @@ def resolve_start_thinking(tool_name: str, tool_input: Any) -> ToolStartThinking
def resolve_completed_thinking(
tool_name: str, tool_output: Any, last_items: list[str],
tool_name: str,
tool_output: Any,
last_items: list[str],
) -> tuple[str, list[str]]:
del tool_name
items = last_items

View file

@ -21,7 +21,9 @@ def resolve_start_thinking(tool_name: str, tool_input: Any) -> ToolStartThinking
def resolve_completed_thinking(
tool_name: str, tool_output: Any, last_items: list[str],
tool_name: str,
tool_output: Any,
last_items: list[str],
) -> tuple[str, list[str]]:
del tool_output, tool_name
return ("Editing file", last_items)

View file

@ -24,9 +24,7 @@ def iter_completion_emission_frames(
output_text = om.group(1) if om else ""
thread_id_str = ctx.langgraph_config.get("configurable", {}).get("thread_id", "")
for sf_match in re.finditer(
r"^SANDBOX_FILE:\s*(.+)$", output_text, re.MULTILINE
):
for sf_match in re.finditer(r"^SANDBOX_FILE:\s*(.+)$", output_text, re.MULTILINE):
fpath = sf_match.group(1).strip()
if fpath and fpath not in ctx.stream_result.sandbox_files:
ctx.stream_result.sandbox_files.append(fpath)

View file

@ -22,7 +22,9 @@ def resolve_start_thinking(tool_name: str, tool_input: Any) -> ToolStartThinking
def resolve_completed_thinking(
tool_name: str, tool_output: Any, last_items: list[str],
tool_name: str,
tool_output: Any,
last_items: list[str],
) -> tuple[str, list[str]]:
del tool_name
items = last_items

View file

@ -21,7 +21,9 @@ def resolve_start_thinking(tool_name: str, tool_input: Any) -> ToolStartThinking
def resolve_completed_thinking(
tool_name: str, tool_output: Any, last_items: list[str],
tool_name: str,
tool_output: Any,
last_items: list[str],
) -> tuple[str, list[str]]:
del tool_output, tool_name
return ("Searching files", last_items)

View file

@ -25,7 +25,9 @@ def resolve_start_thinking(tool_name: str, tool_input: Any) -> ToolStartThinking
def resolve_completed_thinking(
tool_name: str, tool_output: Any, last_items: list[str],
tool_name: str,
tool_output: Any,
last_items: list[str],
) -> tuple[str, list[str]]:
del tool_output, tool_name
return ("Searching content", last_items)

View file

@ -20,7 +20,9 @@ def resolve_start_thinking(tool_name: str, tool_input: Any) -> ToolStartThinking
def resolve_completed_thinking(
tool_name: str, tool_output: Any, last_items: list[str],
tool_name: str,
tool_output: Any,
last_items: list[str],
) -> tuple[str, list[str]]:
del tool_name
if isinstance(tool_output, dict):
@ -38,9 +40,7 @@ def resolve_completed_thinking(
paths = [str(p) for p in parsed]
except (ValueError, SyntaxError):
paths = [
line.strip()
for line in ls_output.strip().split("\n")
if line.strip()
line.strip() for line in ls_output.strip().split("\n") if line.strip()
]
for p in paths:
name = p.rstrip("/").split("/")[-1]

View file

@ -17,11 +17,15 @@ def resolve_start_thinking(tool_name: str, tool_input: Any) -> ToolStartThinking
d = as_tool_input_dict(tool_input)
p = d.get("path", "") if isinstance(tool_input, dict) else str(tool_input)
display = p if len(p) <= 80 else "" + p[-77:]
return ToolStartThinking(title="Creating folder", items=[display] if display else [])
return ToolStartThinking(
title="Creating folder", items=[display] if display else []
)
def resolve_completed_thinking(
tool_name: str, tool_output: Any, last_items: list[str],
tool_name: str,
tool_output: Any,
last_items: list[str],
) -> tuple[str, list[str]]:
del tool_output, tool_name
return ("Creating folder", last_items)

View file

@ -27,7 +27,9 @@ def resolve_start_thinking(tool_name: str, tool_input: Any) -> ToolStartThinking
def resolve_completed_thinking(
tool_name: str, tool_output: Any, last_items: list[str],
tool_name: str,
tool_output: Any,
last_items: list[str],
) -> tuple[str, list[str]]:
del tool_output, tool_name
return ("Moving file", last_items)

View file

@ -21,7 +21,9 @@ def resolve_start_thinking(tool_name: str, tool_input: Any) -> ToolStartThinking
def resolve_completed_thinking(
tool_name: str, tool_output: Any, last_items: list[str],
tool_name: str,
tool_output: Any,
last_items: list[str],
) -> tuple[str, list[str]]:
del tool_output, tool_name
return ("Reading file", last_items)

View file

@ -22,7 +22,9 @@ def resolve_start_thinking(tool_name: str, tool_input: Any) -> ToolStartThinking
def resolve_completed_thinking(
tool_name: str, tool_output: Any, last_items: list[str],
tool_name: str,
tool_output: Any,
last_items: list[str],
) -> tuple[str, list[str]]:
del tool_output, tool_name
return ("Deleting file", last_items)

View file

@ -17,11 +17,15 @@ def resolve_start_thinking(tool_name: str, tool_input: Any) -> ToolStartThinking
d = as_tool_input_dict(tool_input)
p = d.get("path", "") if isinstance(tool_input, dict) else str(tool_input)
display = p if len(p) <= 80 else "" + p[-77:]
return ToolStartThinking(title="Deleting folder", items=[display] if display else [])
return ToolStartThinking(
title="Deleting folder", items=[display] if display else []
)
def resolve_completed_thinking(
tool_name: str, tool_output: Any, last_items: list[str],
tool_name: str,
tool_output: Any,
last_items: list[str],
) -> tuple[str, list[str]]:
del tool_output, tool_name
return ("Deleting folder", last_items)

View file

@ -21,7 +21,9 @@ def resolve_start_thinking(tool_name: str, tool_input: Any) -> ToolStartThinking
def resolve_completed_thinking(
tool_name: str, tool_output: Any, last_items: list[str],
tool_name: str,
tool_output: Any,
last_items: list[str],
) -> tuple[str, list[str]]:
del tool_output, tool_name
return ("Writing file", last_items)

View file

@ -20,15 +20,15 @@ def resolve_start_thinking(tool_name: str, tool_input: Any) -> ToolStartThinking
return ToolStartThinking(
title="Planning tasks",
items=(
[f"{todo_count} task{'s' if todo_count != 1 else ''}"]
if todo_count
else []
[f"{todo_count} task{'s' if todo_count != 1 else ''}"] if todo_count else []
),
)
def resolve_completed_thinking(
tool_name: str, tool_output: Any, last_items: list[str],
tool_name: str,
tool_output: Any,
last_items: list[str],
) -> tuple[str, list[str]]:
del tool_output, tool_name
return ("Planning tasks", last_items)

View file

@ -58,14 +58,18 @@ def _emission_module(tool_name: str) -> str:
def _import_thinking(tool_name: str):
try:
return importlib.import_module(f"{_BASE}.{_thinking_module(tool_name)}.thinking")
return importlib.import_module(
f"{_BASE}.{_thinking_module(tool_name)}.thinking"
)
except ModuleNotFoundError:
return importlib.import_module(f"{_BASE}.default.thinking")
def _import_emission(tool_name: str):
try:
return importlib.import_module(f"{_BASE}.{_emission_module(tool_name)}.emission")
return importlib.import_module(
f"{_BASE}.{_emission_module(tool_name)}.emission"
)
except ModuleNotFoundError:
return importlib.import_module(f"{_BASE}.default.emission")

View file

@ -23,7 +23,9 @@ def resolve_start_thinking(tool_name: str, tool_input: Any) -> ToolStartThinking
def resolve_completed_thinking(
tool_name: str, tool_output: Any, last_items: list[str],
tool_name: str,
tool_output: Any,
last_items: list[str],
) -> tuple[str, list[str]]:
del tool_name
items = last_items

View file

@ -28,11 +28,7 @@ def iter_completion_emission_frames(
xml,
):
chunk_url, content = m.group(1).strip(), m.group(2).strip()
if (
chunk_url.startswith("http")
and chunk_url in citations
and content
):
if chunk_url.startswith("http") and chunk_url in citations and content:
citations[chunk_url]["snippet"] = (
content[:200] + "" if len(content) > 200 else content
)

View file

@ -0,0 +1,69 @@
# Backend E2E Test Harness
Strict fakes + alternative entrypoints used **only** by Playwright E2E.
Excluded from the production Docker image via `.dockerignore`.
## Files
| Path | Role |
| -------------------------------- | ------------------------------------------------------------------------------- |
| `run_backend.py` | FastAPI entrypoint that hijacks `sys.modules` before importing `app.app:app` |
| `run_celery.py` | Celery worker entrypoint with the same hijack + patch logic |
| `middleware/scenario.py` | `X-E2E-Scenario` header → ContextVar (read by fakes) |
| `fakes/composio_module.py` | Strict drop-in for the `composio` package; raises on unknown surface |
| `fakes/llm.py` | `fake_get_user_long_context_llm` returning a `FakeListChatModel` |
| `fakes/embeddings.py` | Deterministic 0.1-vector `embed_text` / `embed_texts` |
| `fakes/fixtures/drive_files.json`| Canned Drive listings + file contents (incl. canary tokens) |
## Why a sys.modules hijack?
Production code does `from composio import Composio` at module load
time. By the time the FastAPI app object exists, that binding has
already been resolved. The hijack runs **before** any `app.*` import,
so the binding resolves to our strict fake. No production source
changes; fakes are physically excluded from production images.
Defence in depth: the strict `__getattr__` in every fake raises
`NotImplementedError` as soon as production code touches an SDK
attribute the fake does not implement. As a second layer, CI sets
`HTTPS_PROXY=http://127.0.0.1:1` plus sentinel API keys, so any
outbound HTTP request that slips past the fakes fails immediately.
## Adding a new fake
1. Create `fakes/<sdk>_module.py` modelled on `composio_module.py`.
2. In `run_backend.py` and `run_celery.py`, register
`sys.modules["<sdk>"] = _fake_<sdk>` before the `from app.app import app`
line.
3. If the new fake needs scenario branching, read from
`tests.e2e.middleware.scenario.current_scenario()`.
## Reused by backend integration tests
The strict fakes are not only for Playwright. Backend route integration
tests can import the same fake before importing `app.app`, so Composio
route tests exercise production route code without touching the real
SDK:
```python
from tests.e2e.fakes import composio_module as _fake_composio
sys.modules["composio"] = _fake_composio
from app.app import app
```
See `surfsense_backend/tests/integration/composio/conftest.py` for the
current pattern.
## Running locally
```bash
cd surfsense_backend
uv run python tests/e2e/run_backend.py
# in a second shell:
uv run python tests/e2e/run_celery.py
```
Then in `surfsense_web`:
```bash
pnpm test:e2e
```

View file

@ -0,0 +1,7 @@
"""E2E test harness root.
This package is loaded only by the test entrypoints
(`tests/e2e/run_backend.py`, `tests/e2e/run_celery.py`). It is excluded
from the production Docker image via `surfsense_backend/.dockerignore`,
so production binaries never see this code.
"""

View file

@ -0,0 +1,8 @@
"""Strict fakes for third-party SDKs, used in E2E mode only.
Every fake here implements __getattr__ that raises NotImplementedError
on any unknown surface. Combined with sys.modules-level hijacking in
run_backend.py / run_celery.py, this makes silent pass-through to the
real SDK impossible: a future production code path that introduces a
new SDK call site fails CI with a clear "add this to the fake" message.
"""

View file

@ -0,0 +1,23 @@
"""Helpers for serving text and binary fixture file bodies."""
from __future__ import annotations
from pathlib import Path
from typing import Any
def _resolve_file_bytes(
    fixture: dict[str, Any], key: str | None, fixtures_dir: Path
) -> bytes | None:
    """Look up the fake body registered for ``key`` in ``fixture``.

    Args:
        fixture: Fixture mapping. ``_file_binary_paths`` maps keys to paths
            relative to ``fixtures_dir``; ``_file_contents`` maps keys to
            inline text.
        key: File key to resolve. A falsy key (``None`` or ``""``) never
            resolves.
        fixtures_dir: Directory the binary fixture paths are relative to.

    Returns:
        The raw bytes of the binary fixture file when one is registered,
        otherwise the inline text encoded as UTF-8, otherwise ``None``.
    """
    if not key:
        return None

    relative_path = fixture.get("_file_binary_paths", {}).get(key)
    if relative_path is not None:
        # An on-disk binary fixture wins over inline text for the same key.
        return (fixtures_dir / relative_path).read_bytes()

    inline_text = fixture.get("_file_contents", {}).get(key)
    return None if inline_text is None else inline_text.encode("utf-8")

View file

@ -0,0 +1,733 @@
from __future__ import annotations
import json
from collections.abc import AsyncIterator
from typing import Any, Self
from langchain_core.callbacks import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessage
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
# Canary markers. The fake chat model searches the conversation text for
# these values and echoes the matching ``*_TOKEN`` back in its reply.

# Drive (plain text file, native PDF, and Composio PDF variants).
DRIVE_CANARY_TOKEN = "SURFSENSE_E2E_CANARY_TOKEN_DRIVE_001"
DRIVE_PDF_CANARY_TOKEN = "SURFSENSE_E2E_CANARY_TOKEN_DRIVE_PDF_001"
DRIVE_PDF_CANARY_FILE = "e2e-canary.pdf"
COMPOSIO_DRIVE_PDF_CANARY_TOKEN = "SURFSENSE_E2E_CANARY_TOKEN_COMPOSIO_DRIVE_PDF_001"
COMPOSIO_DRIVE_PDF_CANARY_FILE = "e2e-composio-canary.pdf"
# Gmail.
GMAIL_CANARY_TOKEN = "SURFSENSE_E2E_CANARY_TOKEN_GMAIL_001"
GMAIL_CANARY_SUBJECT = "E2E Canary Email"
GMAIL_CANARY_MESSAGE_ID = "fake-msg-canary-001"
# Calendar.
CALENDAR_CANARY_TOKEN = "SURFSENSE_E2E_CANARY_TOKEN_CALENDAR_001"
CALENDAR_CANARY_SUMMARY = "E2E Canary Calendar Event"
# OneDrive (text and PDF).
ONEDRIVE_CANARY_TOKEN = "SURFSENSE_E2E_CANARY_TOKEN_ONEDRIVE_001"
ONEDRIVE_CANARY_FILE = "e2e-onedrive-canary.txt"
ONEDRIVE_PDF_CANARY_TOKEN = "SURFSENSE_E2E_CANARY_TOKEN_ONEDRIVE_PDF_001"
ONEDRIVE_PDF_CANARY_FILE = "e2e-onedrive-canary.pdf"
# Dropbox (text and PDF).
DROPBOX_CANARY_TOKEN = "SURFSENSE_E2E_CANARY_TOKEN_DROPBOX_001"
DROPBOX_CANARY_FILE = "e2e-dropbox-canary.txt"
DROPBOX_PDF_CANARY_TOKEN = "SURFSENSE_E2E_CANARY_TOKEN_DROPBOX_PDF_001"
DROPBOX_PDF_CANARY_FILE = "e2e-dropbox-canary.pdf"
# Notion.
NOTION_CANARY_TOKEN = "SURFSENSE_E2E_CANARY_TOKEN_NOTION_001"
NOTION_CANARY_TITLE = "E2E Canary Notion Page"
# Confluence.
CONFLUENCE_CANARY_TOKEN = "SURFSENSE_E2E_CANARY_TOKEN_CONFLUENCE_001"
CONFLUENCE_CANARY_TITLE = "E2E Canary Confluence Page"
# Linear.
LINEAR_CANARY_TOKEN = "SURFSENSE_E2E_CANARY_TOKEN_LINEAR_001"
LINEAR_CANARY_TITLE = "E2E Canary Linear Issue"
# Jira.
JIRA_CANARY_TOKEN = "SURFSENSE_E2E_CANARY_TOKEN_JIRA_001"
JIRA_CANARY_SUMMARY = "E2E Canary Jira Issue"
JIRA_CANARY_KEY = "E2E-101"
# Slack.
SLACK_CANARY_TOKEN = "SURFSENSE_E2E_CANARY_TOKEN_SLACK_001"
SLACK_CANARY_CHANNEL = "slack-e2e-canary"
# ClickUp.
CLICKUP_CANARY_TOKEN = "SURFSENSE_E2E_CANARY_TOKEN_CLICKUP_001"
CLICKUP_CANARY_TITLE = "E2E Canary ClickUp Task"
CLICKUP_CANARY_TASK_ID = "fake-clickup-task-canary-001"
# Manually uploaded files (Markdown and PDF).
MANUAL_UPLOAD_MD_CANARY_TOKEN = "E2E-MANUAL-UPLOAD-MD-CANARY-7f3a"
MANUAL_UPLOAD_MD_CANARY_FILE = "canary.md"
MANUAL_UPLOAD_PDF_CANARY_TOKEN = "E2E-MANUAL-UPLOAD-PDF-CANARY-9d2b"
MANUAL_UPLOAD_PDF_CANARY_FILE = "canary.pdf"
# Reply returned when no canary matches the conversation.
NO_RELEVANT_CONTENT_SENTINEL = "No relevant indexed content found."
# A human message containing this string always gets the sentinel reply.
NO_RELEVANT_CONTENT_QUERY = "E2E_NO_RELEVANT_CONTENT_SMOKE"
def _content_to_text(content: Any) -> str:
if isinstance(content, str):
return content
if isinstance(content, list):
return " ".join(_content_to_text(item) for item in content)
if isinstance(content, dict):
text = content.get("text") or content.get("content")
if isinstance(text, str):
return text
return json.dumps(content, sort_keys=True)
if content is None:
return ""
return str(content)
def _messages_to_text(messages: list[BaseMessage]) -> str:
    """Render every message's content as text, one message per line."""
    rendered = [_content_to_text(msg.content) for msg in messages]
    return "\n".join(rendered)
def _contains_any(text: str, needles: tuple[str, ...]) -> bool:
    """Return True if any needle occurs in ``text``, ignoring case.

    Both ``text`` and each needle are lower-cased before the substring
    check. An empty ``needles`` tuple yields False.
    """
    haystack = text.lower()
    for needle in needles:
        if needle.lower() in haystack:
            return True
    return False
def _latest_tool_message(messages: list[BaseMessage]) -> BaseMessage | None:
    """Return the last message whose ``type`` is ``"tool"``, or ``None``."""
    for candidate in reversed(messages):
        if candidate.type == "tool":
            return candidate
    return None
class FakeChatLLM(BaseChatModel):
@property
def _llm_type(self) -> str:
return "e2e-fake-chat"
def bind_tools(self, tools: Any, **kwargs: Any) -> Self:
return self
def _response_for(self, messages: list[BaseMessage]) -> str:
latest_human = next(
(
_content_to_text(message.content)
for message in reversed(messages)
if message.type == "human"
),
"",
)
if NO_RELEVANT_CONTENT_QUERY in latest_human:
return NO_RELEVANT_CONTENT_SENTINEL
prompt_text = _messages_to_text(messages)
latest_tool = _latest_tool_message(messages)
latest_tool_name = getattr(latest_tool, "name", None)
latest_tool_text = _content_to_text(latest_tool.content) if latest_tool else ""
if (
latest_tool_name == "read_gmail_email"
and GMAIL_CANARY_TOKEN in latest_tool_text
):
return f"Gmail live tool content found: {GMAIL_CANARY_TOKEN}"
if (
latest_tool_name == "search_gmail"
and GMAIL_CANARY_MESSAGE_ID in latest_tool_text
):
return "Reading the matching Gmail message next."
if (
latest_tool_name == "search_calendar_events"
and CALENDAR_CANARY_TOKEN in latest_tool_text
):
return f"Calendar live tool content found: {CALENDAR_CANARY_TOKEN}"
if (
latest_tool_name == "list_issues"
and LINEAR_CANARY_TOKEN in latest_tool_text
):
return f"Linear live tool content found: {LINEAR_CANARY_TOKEN}"
if (
latest_tool_name == "searchJiraIssuesUsingJql"
and JIRA_CANARY_TOKEN in latest_tool_text
):
return f"Jira live tool content found: {JIRA_CANARY_TOKEN}"
if (
latest_tool_name == "slack_search_channels"
and SLACK_CANARY_TOKEN in latest_tool_text
):
return f"Slack live tool content found: {SLACK_CANARY_TOKEN}"
if (
latest_tool_name in {"clickup_search", "clickup_get_task"}
and CLICKUP_CANARY_TOKEN in latest_tool_text
):
return f"ClickUp live tool content found: {CLICKUP_CANARY_TOKEN}"
wants_gmail = _contains_any(
latest_human,
("gmail", "email", "message", GMAIL_CANARY_SUBJECT),
)
wants_calendar = _contains_any(
latest_human,
("calendar", "event", "meeting", CALENDAR_CANARY_SUMMARY),
)
wants_drive = _contains_any(
latest_human,
("drive", "file", "e2e-canary.txt"),
)
wants_drive_pdf = _contains_any(
latest_human,
(
"drive pdf",
DRIVE_PDF_CANARY_FILE,
DRIVE_PDF_CANARY_TOKEN,
COMPOSIO_DRIVE_PDF_CANARY_FILE,
COMPOSIO_DRIVE_PDF_CANARY_TOKEN,
),
) or (wants_drive and "pdf" in latest_human.lower())
wants_onedrive = _contains_any(
latest_human,
("onedrive", ONEDRIVE_CANARY_FILE, ONEDRIVE_CANARY_TOKEN),
)
wants_onedrive_pdf = wants_onedrive and _contains_any(
latest_human,
("pdf", ONEDRIVE_PDF_CANARY_FILE, ONEDRIVE_PDF_CANARY_TOKEN),
)
wants_dropbox = _contains_any(
latest_human,
("dropbox", DROPBOX_CANARY_FILE, DROPBOX_CANARY_TOKEN),
)
wants_dropbox_pdf = wants_dropbox and _contains_any(
latest_human,
("pdf", DROPBOX_PDF_CANARY_FILE, DROPBOX_PDF_CANARY_TOKEN),
)
wants_notion = _contains_any(
latest_human,
("notion", "page", NOTION_CANARY_TITLE),
)
wants_confluence = _contains_any(
latest_human,
("confluence", CONFLUENCE_CANARY_TITLE),
)
wants_linear = _contains_any(
latest_human,
("linear", "issue", LINEAR_CANARY_TITLE),
)
wants_jira = _contains_any(
latest_human,
(
"jira",
"atlassian",
JIRA_CANARY_SUMMARY,
JIRA_CANARY_KEY,
"surfsense-e2e.atlassian.net",
"fake-jira-cloud-001",
),
)
wants_slack = _contains_any(
latest_human,
("slack", SLACK_CANARY_TOKEN),
)
wants_clickup = _contains_any(
latest_human,
("clickup", CLICKUP_CANARY_TITLE),
)
wants_manual_upload = _contains_any(
latest_human,
(
"uploaded",
"manual upload",
MANUAL_UPLOAD_MD_CANARY_FILE,
MANUAL_UPLOAD_PDF_CANARY_FILE,
MANUAL_UPLOAD_MD_CANARY_TOKEN,
MANUAL_UPLOAD_PDF_CANARY_TOKEN,
),
)
wants_manual_upload_pdf = wants_manual_upload and _contains_any(
latest_human,
("pdf", MANUAL_UPLOAD_PDF_CANARY_FILE, MANUAL_UPLOAD_PDF_CANARY_TOKEN),
)
wants_manual_upload_md = wants_manual_upload and _contains_any(
latest_human,
(
"markdown",
".md",
MANUAL_UPLOAD_MD_CANARY_FILE,
MANUAL_UPLOAD_MD_CANARY_TOKEN,
),
)
has_gmail_evidence = (
GMAIL_CANARY_SUBJECT in prompt_text
or GMAIL_CANARY_MESSAGE_ID in prompt_text
or GMAIL_CANARY_TOKEN in prompt_text
)
has_calendar_evidence = (
CALENDAR_CANARY_SUMMARY in prompt_text
or CALENDAR_CANARY_TOKEN in prompt_text
)
has_drive_evidence = (
"e2e-canary.txt" in prompt_text
or "fake-file-canary" in prompt_text
or DRIVE_CANARY_TOKEN in prompt_text
)
has_native_drive_pdf_evidence = (
DRIVE_PDF_CANARY_FILE in prompt_text
or "fake-file-pdf-native" in prompt_text
or DRIVE_PDF_CANARY_TOKEN in prompt_text
)
has_composio_drive_pdf_evidence = (
COMPOSIO_DRIVE_PDF_CANARY_FILE in prompt_text
or "fake-file-pdf-composio" in prompt_text
or COMPOSIO_DRIVE_PDF_CANARY_TOKEN in prompt_text
)
has_onedrive_evidence = (
ONEDRIVE_CANARY_FILE in prompt_text
or "fake-onedrive-canary" in prompt_text
or ONEDRIVE_CANARY_TOKEN in prompt_text
)
has_onedrive_pdf_evidence = (
ONEDRIVE_PDF_CANARY_FILE in prompt_text
or "fake-onedrive-pdf-canary" in prompt_text
or ONEDRIVE_PDF_CANARY_TOKEN in prompt_text
)
has_dropbox_evidence = (
DROPBOX_CANARY_FILE in prompt_text
or "id:fake-dropbox-canary" in prompt_text
or DROPBOX_CANARY_TOKEN in prompt_text
)
has_dropbox_pdf_evidence = (
DROPBOX_PDF_CANARY_FILE in prompt_text
or "id:fake-dropbox-pdf-canary" in prompt_text
or DROPBOX_PDF_CANARY_TOKEN in prompt_text
)
has_notion_evidence = (
NOTION_CANARY_TITLE in prompt_text or NOTION_CANARY_TOKEN in prompt_text
)
has_confluence_evidence = (
CONFLUENCE_CANARY_TITLE in prompt_text
or CONFLUENCE_CANARY_TOKEN in prompt_text
or "fake-confluence-page-canary-001" in prompt_text
or "fake-confluence-space-001" in prompt_text
)
has_linear_evidence = (
LINEAR_CANARY_TITLE in prompt_text
or LINEAR_CANARY_TOKEN in prompt_text
or "fake-linear-issue-canary-001" in prompt_text
)
has_jira_evidence = (
JIRA_CANARY_SUMMARY in prompt_text
or JIRA_CANARY_TOKEN in prompt_text
or JIRA_CANARY_KEY in prompt_text
or "fake-jira-issue-canary-001" in prompt_text
or "fake-jira-cloud-001" in prompt_text
or "surfsense-e2e.atlassian.net" in prompt_text
)
has_slack_evidence = (
SLACK_CANARY_CHANNEL in prompt_text
or SLACK_CANARY_TOKEN in prompt_text
or "C_FAKE_SLACK_CANARY" in prompt_text
or "T_FAKE_SLACK_TEAM" in prompt_text
)
has_clickup_evidence = (
CLICKUP_CANARY_TITLE in prompt_text
or CLICKUP_CANARY_TOKEN in prompt_text
or CLICKUP_CANARY_TASK_ID in prompt_text
)
has_manual_upload_md_evidence = (
MANUAL_UPLOAD_MD_CANARY_FILE in prompt_text
or MANUAL_UPLOAD_MD_CANARY_TOKEN in prompt_text
)
has_manual_upload_pdf_evidence = (
MANUAL_UPLOAD_PDF_CANARY_FILE in prompt_text
or MANUAL_UPLOAD_PDF_CANARY_TOKEN in prompt_text
)
if wants_clickup and has_clickup_evidence:
return f"ClickUp content found: {CLICKUP_CANARY_TOKEN}"
if wants_slack and has_slack_evidence:
return f"Slack content found: {SLACK_CANARY_TOKEN}"
if wants_jira and has_jira_evidence:
return f"Jira content found: {JIRA_CANARY_TOKEN}"
if wants_linear and has_linear_evidence:
return f"Linear content found: {LINEAR_CANARY_TOKEN}"
if wants_confluence and has_confluence_evidence:
return f"Confluence content found: {CONFLUENCE_CANARY_TOKEN}"
if wants_notion and has_notion_evidence:
return f"Notion content found: {NOTION_CANARY_TOKEN}"
if wants_calendar and has_calendar_evidence:
return f"Calendar content found: {CALENDAR_CANARY_TOKEN}"
if wants_gmail and has_gmail_evidence:
return f"Gmail content found: {GMAIL_CANARY_TOKEN}"
if wants_onedrive_pdf and has_onedrive_pdf_evidence:
return f"OneDrive PDF content found: {ONEDRIVE_PDF_CANARY_TOKEN}"
if wants_onedrive and has_onedrive_evidence:
return f"OneDrive content found: {ONEDRIVE_CANARY_TOKEN}"
if wants_dropbox_pdf and has_dropbox_pdf_evidence:
return f"Dropbox PDF content found: {DROPBOX_PDF_CANARY_TOKEN}"
if wants_dropbox and has_dropbox_evidence:
return f"Dropbox content found: {DROPBOX_CANARY_TOKEN}"
if wants_drive_pdf and has_native_drive_pdf_evidence:
return f"Drive PDF content found: {DRIVE_PDF_CANARY_TOKEN}"
if wants_drive_pdf and has_composio_drive_pdf_evidence:
return f"Drive PDF content found: {COMPOSIO_DRIVE_PDF_CANARY_TOKEN}"
if wants_drive and has_drive_evidence:
return f"Drive content found: {DRIVE_CANARY_TOKEN}"
if wants_manual_upload_pdf and has_manual_upload_pdf_evidence:
return f"Manual upload PDF content found: {MANUAL_UPLOAD_PDF_CANARY_TOKEN}"
if wants_manual_upload_md and has_manual_upload_md_evidence:
return f"Manual upload MD content found: {MANUAL_UPLOAD_MD_CANARY_TOKEN}"
if (
has_notion_evidence
and not has_confluence_evidence
and not has_jira_evidence
and not has_linear_evidence
and not has_calendar_evidence
and not has_gmail_evidence
and not has_drive_evidence
and not has_onedrive_evidence
and not has_dropbox_evidence
and not has_slack_evidence
and not has_clickup_evidence
):
return f"Notion content found: {NOTION_CANARY_TOKEN}"
if (
has_confluence_evidence
and not has_jira_evidence
and not has_linear_evidence
and not has_notion_evidence
and not has_calendar_evidence
and not has_gmail_evidence
and not has_drive_evidence
and not has_onedrive_evidence
and not has_dropbox_evidence
and not has_slack_evidence
and not has_clickup_evidence
):
return f"Confluence content found: {CONFLUENCE_CANARY_TOKEN}"
if (
has_jira_evidence
and not has_confluence_evidence
and not has_linear_evidence
and not has_notion_evidence
and not has_calendar_evidence
and not has_gmail_evidence
and not has_drive_evidence
and not has_onedrive_evidence
and not has_dropbox_evidence
and not has_slack_evidence
and not has_clickup_evidence
):
return f"Jira content found: {JIRA_CANARY_TOKEN}"
if (
has_linear_evidence
and not has_confluence_evidence
and not has_jira_evidence
and not has_notion_evidence
and not has_calendar_evidence
and not has_gmail_evidence
and not has_drive_evidence
and not has_onedrive_evidence
and not has_dropbox_evidence
and not has_slack_evidence
and not has_clickup_evidence
):
return f"Linear content found: {LINEAR_CANARY_TOKEN}"
if (
has_calendar_evidence
and not has_confluence_evidence
and not has_jira_evidence
and not has_linear_evidence
and not has_notion_evidence
and not has_gmail_evidence
and not has_drive_evidence
and not has_onedrive_evidence
and not has_dropbox_evidence
and not has_slack_evidence
and not has_clickup_evidence
):
return f"Calendar content found: {CALENDAR_CANARY_TOKEN}"
if (
has_gmail_evidence
and not has_confluence_evidence
and not has_jira_evidence
and not has_linear_evidence
and not has_notion_evidence
and not has_drive_evidence
and not has_onedrive_evidence
and not has_dropbox_evidence
and not has_slack_evidence
and not has_clickup_evidence
):
return f"Gmail content found: {GMAIL_CANARY_TOKEN}"
if (
has_onedrive_evidence
and not has_confluence_evidence
and not has_jira_evidence
and not has_linear_evidence
and not has_notion_evidence
and not has_calendar_evidence
and not has_gmail_evidence
and not has_drive_evidence
and not has_dropbox_evidence
and not has_slack_evidence
and not has_clickup_evidence
):
return f"OneDrive content found: {ONEDRIVE_CANARY_TOKEN}"
if (
has_dropbox_evidence
and not has_confluence_evidence
and not has_jira_evidence
and not has_linear_evidence
and not has_notion_evidence
and not has_calendar_evidence
and not has_gmail_evidence
and not has_drive_evidence
and not has_onedrive_evidence
and not has_slack_evidence
and not has_clickup_evidence
):
return f"Dropbox content found: {DROPBOX_CANARY_TOKEN}"
if (
has_drive_evidence
and not has_confluence_evidence
and not has_jira_evidence
and not has_linear_evidence
and not has_notion_evidence
and not has_gmail_evidence
and not has_onedrive_evidence
and not has_dropbox_evidence
and not has_slack_evidence
and not has_clickup_evidence
):
return f"Drive content found: {DRIVE_CANARY_TOKEN}"
if (
has_slack_evidence
and not has_confluence_evidence
and not has_jira_evidence
and not has_linear_evidence
and not has_notion_evidence
and not has_calendar_evidence
and not has_gmail_evidence
and not has_drive_evidence
and not has_onedrive_evidence
and not has_dropbox_evidence
and not has_clickup_evidence
):
return f"Slack content found: {SLACK_CANARY_TOKEN}"
if (
has_clickup_evidence
and not has_confluence_evidence
and not has_jira_evidence
and not has_linear_evidence
and not has_notion_evidence
and not has_calendar_evidence
and not has_gmail_evidence
and not has_drive_evidence
and not has_onedrive_evidence
and not has_dropbox_evidence
and not has_slack_evidence
):
return f"ClickUp content found: {CLICKUP_CANARY_TOKEN}"
if (
has_manual_upload_pdf_evidence
and not has_confluence_evidence
and not has_jira_evidence
and not has_linear_evidence
and not has_notion_evidence
and not has_calendar_evidence
and not has_gmail_evidence
and not has_drive_evidence
and not has_onedrive_evidence
and not has_dropbox_evidence
and not has_slack_evidence
and not has_clickup_evidence
):
return f"Manual upload PDF content found: {MANUAL_UPLOAD_PDF_CANARY_TOKEN}"
if (
has_manual_upload_md_evidence
and not has_confluence_evidence
and not has_jira_evidence
and not has_linear_evidence
and not has_notion_evidence
and not has_calendar_evidence
and not has_gmail_evidence
and not has_drive_evidence
and not has_onedrive_evidence
and not has_dropbox_evidence
and not has_slack_evidence
and not has_clickup_evidence
):
return f"Manual upload MD content found: {MANUAL_UPLOAD_MD_CANARY_TOKEN}"
return NO_RELEVANT_CONTENT_SENTINEL
def _tool_call_message_for(self, messages: list[BaseMessage]) -> AIMessage | None:
latest_human = next(
(
_content_to_text(message.content)
for message in reversed(messages)
if message.type == "human"
),
"",
)
latest_tool = _latest_tool_message(messages)
latest_tool_name = getattr(latest_tool, "name", None)
latest_tool_text = _content_to_text(latest_tool.content) if latest_tool else ""
if (
latest_tool_name == "search_gmail"
and GMAIL_CANARY_MESSAGE_ID in latest_tool_text
):
return AIMessage(
content="",
tool_calls=[
{
"name": "read_gmail_email",
"args": {"message_id": GMAIL_CANARY_MESSAGE_ID},
"id": "call_e2e_read_gmail",
}
],
)
if latest_tool is None and _contains_any(
latest_human,
("gmail", "email", "message", GMAIL_CANARY_SUBJECT),
):
return AIMessage(
content="",
tool_calls=[
{
"name": "search_gmail",
"args": {
"query": f"subject:{GMAIL_CANARY_SUBJECT}",
"max_results": 5,
},
"id": "call_e2e_search_gmail",
}
],
)
if latest_tool is None and _contains_any(
latest_human,
("calendar", "event", "meeting", CALENDAR_CANARY_SUMMARY),
):
return AIMessage(
content="",
tool_calls=[
{
"name": "search_calendar_events",
"args": {
"start_date": "2026-05-07",
"end_date": "2026-05-21",
"max_results": 10,
},
"id": "call_e2e_search_calendar_events",
}
],
)
if latest_tool is None and _contains_any(
latest_human,
(
"jira",
"atlassian",
JIRA_CANARY_SUMMARY,
JIRA_CANARY_KEY,
"surfsense-e2e.atlassian.net",
"fake-jira-cloud-001",
),
):
return AIMessage(
content="",
tool_calls=[
{
"name": "searchJiraIssuesUsingJql",
"args": {
"jql": f'summary ~ "{JIRA_CANARY_SUMMARY}"',
"maxResults": 5,
},
"id": "call_e2e_search_jira_issues",
}
],
)
if latest_tool is None and _contains_any(
latest_human,
("linear", "issue", LINEAR_CANARY_TITLE),
):
return AIMessage(
content="",
tool_calls=[
{
"name": "list_issues",
"args": {"query": LINEAR_CANARY_TITLE, "limit": 5},
"id": "call_e2e_list_linear_issues",
}
],
)
if latest_tool is None and _contains_any(
latest_human,
("slack", SLACK_CANARY_TOKEN),
):
return AIMessage(
content="",
tool_calls=[
{
"name": "slack_search_channels",
"args": {"query": SLACK_CANARY_CHANNEL, "limit": 5},
"id": "call_e2e_search_slack_channels",
}
],
)
if latest_tool is None and _contains_any(
latest_human,
("clickup", CLICKUP_CANARY_TITLE),
):
return AIMessage(
content="",
tool_calls=[
{
"name": "clickup_search",
"args": {"query": CLICKUP_CANARY_TITLE, "limit": 5},
"id": "call_e2e_search_clickup_tasks",
}
],
)
return None
def _generate(
self,
messages: list[BaseMessage],
stop: list[str] | None = None,
run_manager: CallbackManagerForLLMRun | None = None,
**kwargs: Any,
) -> ChatResult:
del stop, run_manager, kwargs
message = self._tool_call_message_for(messages) or AIMessage(
content=self._response_for(messages), tool_calls=[]
)
return ChatResult(generations=[ChatGeneration(message=message)])
async def _astream(
    self,
    messages: list[BaseMessage],
    stop: list[str] | None = None,
    run_manager: AsyncCallbackManagerForLLMRun | None = None,
    **kwargs: Any,
) -> AsyncIterator[ChatGenerationChunk]:
    """Stream the deterministic reply as chat generation chunks.

    When ``_tool_call_message_for`` yields a scripted tool-call message,
    one chunk is emitted per tool call and the stream ends. Otherwise a
    single text chunk carrying ``_response_for(messages)`` is emitted.

    ``stop``, ``run_manager`` and ``kwargs`` are accepted for interface
    compatibility only and are discarded.
    """
    del stop, run_manager, kwargs
    tool_call_message = self._tool_call_message_for(messages)
    if tool_call_message:
        # Each tool call gets its own chunk index. Re-using index 0 for
        # every call would make chunk aggregation treat separate calls as
        # fragments of one call and concatenate their names/args.
        for index, tool_call in enumerate(tool_call_message.tool_calls):
            yield ChatGenerationChunk(
                message=AIMessageChunk(
                    content="",
                    tool_call_chunks=[
                        {
                            "name": tool_call["name"],
                            # Chunk args travel as a JSON string, not a dict.
                            "args": json.dumps(tool_call["args"]),
                            "id": tool_call["id"],
                            "index": index,
                        }
                    ],
                )
            )
        return
    yield ChatGenerationChunk(
        message=AIMessageChunk(content=self._response_for(messages))
    )
def fake_create_chat_litellm_from_agent_config(
    *args: Any, **kwargs: Any
) -> FakeChatLLM:
    """Return a fresh ``FakeChatLLM``, ignoring every argument.

    NOTE(review): presumably installed in place of the production factory
    of the same name — confirm at the patch site.
    """
    del args
    del kwargs
    llm = FakeChatLLM()
    return llm
def fake_create_chat_litellm_from_config(*args: Any, **kwargs: Any) -> FakeChatLLM:
    """Return a fresh ``FakeChatLLM``, ignoring every argument.

    NOTE(review): presumably installed in place of the production factory
    of the same name — confirm at the patch site.
    """
    del args
    del kwargs
    llm = FakeChatLLM()
    return llm

View file

@ -0,0 +1,134 @@
"""Strict ClickUp MCP OAuth/tool fakes for Playwright E2E."""
from __future__ import annotations
import json
from pathlib import Path
from types import SimpleNamespace
from typing import Any
from tests.e2e.fakes import mcp_oauth_runtime, mcp_runtime
_FIXTURE_PATH = Path(__file__).parent / "fixtures" / "clickup_tasks.json"
_AUTHORIZATION_URL = "https://mcp.clickup.com/authorize"
_REGISTRATION_URL = "https://mcp.clickup.com/register"
_TOKEN_URL = "https://mcp.clickup.com/token"
_MCP_URL = "https://mcp.clickup.com/mcp"
_CLIENT_ID = "fake-clickup-mcp-client-id"
_CLIENT_SECRET = "fake-clickup-mcp-client-secret"
_ACCESS_TOKEN = "fake-clickup-mcp-access-token"
_REFRESH_TOKEN = "fake-clickup-mcp-refresh-token"
_OAUTH_CODE = "fake-clickup-oauth-code"
def _load_fixture() -> dict[str, Any]:
    """Read and parse the ClickUp task fixture at ``_FIXTURE_PATH``.

    The file is opened explicitly as UTF-8 so parsing does not depend on
    the platform's locale encoding.

    Returns:
        The decoded JSON document.
    """
    with _FIXTURE_PATH.open(encoding="utf-8") as f:
        return json.load(f)
_FIXTURE = _load_fixture()
def _task_text(task: dict[str, Any]) -> str:
    """Render one fixture task as the newline-separated text a tool returns.

    The first line is the task name; the remaining lines are
    ``label: value`` pairs for id, workspace, list, status and description.
    """
    lines = [
        f"{task['name']}",
        f"id: {task['id']}",
        f"workspace: {task['workspace_name']} ({task['workspace_id']})",
        f"list: {task['list_name']}",
        f"status: {task['status']}",
        f"description: {task['description']}",
    ]
    return "\n".join(lines)
async def _list_tools() -> SimpleNamespace:
    """Advertise the two tools the fake ClickUp MCP server exposes.

    Returns:
        A namespace whose ``tools`` attribute lists ``clickup_search`` and
        ``clickup_get_task``, each with a name, description and JSON-schema
        ``inputSchema``.
    """
    search_schema = {
        "type": "object",
        "properties": {
            "query": {
                "type": "string",
                "description": "Text to search for in ClickUp tasks.",
            },
            "limit": {
                "type": "integer",
                "description": "Maximum number of tasks to return.",
            },
        },
        "required": ["query"],
    }
    get_task_schema = {
        "type": "object",
        "properties": {
            "task_id": {
                "type": "string",
                "description": "ClickUp task id.",
            }
        },
        "required": ["task_id"],
    }
    search_tool = SimpleNamespace(
        name="clickup_search",
        description="Search ClickUp tasks visible to the authenticated user.",
        inputSchema=search_schema,
    )
    get_task_tool = SimpleNamespace(
        name="clickup_get_task",
        description="Get a ClickUp task by id.",
        inputSchema=get_task_schema,
    )
    return SimpleNamespace(tools=[search_tool, get_task_tool])
async def _call_tool(
    tool_name: str, arguments: dict[str, Any] | None = None
) -> SimpleNamespace:
    """Serve a fake ClickUp MCP tool call from the first fixture task.

    Args:
        tool_name: ``clickup_search`` or ``clickup_get_task``.
        arguments: Tool arguments; ``None`` is treated as an empty dict.

    Returns:
        A namespace with a single text content item describing the task.

    Raises:
        ValueError: A non-empty search query does not contain the task
            name (case-insensitive), or the task id does not match.
        NotImplementedError: ``tool_name`` is not one of the two tools.
    """
    params = arguments or {}
    canary = _FIXTURE["tasks"][0]
    if tool_name == "clickup_search":
        query = str(params.get("query", ""))
        # An empty query is accepted; otherwise it must mention the task.
        mismatch = bool(query) and canary["name"].lower() not in query.lower()
        if mismatch:
            raise ValueError(f"Unexpected ClickUp task query: {query!r}")
    elif tool_name == "clickup_get_task":
        task_id = params.get("task_id")
        if task_id != canary["id"]:
            raise ValueError(f"Unexpected ClickUp task id: {task_id!r}")
    else:
        raise NotImplementedError(f"Unexpected ClickUp MCP tool call: {tool_name!r}")
    return SimpleNamespace(content=[SimpleNamespace(text=_task_text(canary))])
def install(active_patches: list[Any]) -> None:
    """Register the ClickUp MCP OAuth service and tool handlers.

    ``active_patches`` is accepted for signature parity and discarded:
    this fake registers with ``mcp_oauth_runtime`` and ``mcp_runtime``
    rather than starting patches of its own.
    """
    del active_patches
    discovery_metadata = {
        "issuer": "https://mcp.clickup.com",
        "authorization_endpoint": _AUTHORIZATION_URL,
        "token_endpoint": _TOKEN_URL,
        "registration_endpoint": _REGISTRATION_URL,
        "code_challenge_methods_supported": ["S256"],
        "grant_types_supported": ["authorization_code", "refresh_token"],
        "response_types_supported": ["code"],
    }
    mcp_oauth_runtime.register_service(
        mcp_url=_MCP_URL,
        discovery_metadata=discovery_metadata,
        client_id=_CLIENT_ID,
        client_secret=_CLIENT_SECRET,
        token_endpoint=_TOKEN_URL,
        registration_endpoint=_REGISTRATION_URL,
        oauth_code=_OAUTH_CODE,
        access_token=_ACCESS_TOKEN,
        refresh_token=_REFRESH_TOKEN,
        scope="read write",
        redirect_uri_substring="/api/v1/auth/mcp/clickup/connector/callback",
    )
    # Tool traffic is authorised with the token issued by the OAuth fake.
    mcp_runtime.register(
        url=_MCP_URL,
        expected_bearer=_ACCESS_TOKEN,
        list_tools=_list_tools,
        call_tool=_call_tool,
    )

View file

@ -0,0 +1,559 @@
"""Strict drop-in replacement for the `composio` Python SDK.
Registered as `sys.modules["composio"]` by `tests/e2e/run_backend.py`
and `tests/e2e/run_celery.py` BEFORE any production code imports
`composio`. From that point on, every `from composio import Composio`
in production resolves to `Composio` defined here.
Every class implements __getattr__ that raises NotImplementedError on
unknown attributes. A future production code path that introduces a
new SDK call (e.g. `client.bulk_operations.run`) fails CI loudly with
a clear "add this surface to the fake" message instead of silently
passing through to the real SDK.
Scenario branching is read from the request-scoped ContextVar in
`tests/e2e/middleware/scenario.py`, set by the X-E2E-Scenario header.
"""
from __future__ import annotations
import json
import logging
from pathlib import Path
from typing import Any
from .binary_loader import _resolve_file_bytes
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Fixture loading
# ---------------------------------------------------------------------------
_FIXTURES_DIR = Path(__file__).parent / "fixtures"
_DRIVE_FIXTURE_PATH = _FIXTURES_DIR / "drive_files.json"
_GMAIL_FIXTURE_PATH = _FIXTURES_DIR / "gmail_messages.json"
_CALENDAR_FIXTURE_PATH = _FIXTURES_DIR / "calendar_events.json"
_DRIVE_DOWNLOAD_DIR = Path("/tmp/surfsense-e2e-composio-downloads")
def _load_drive_fixture() -> dict[str, Any]:
    """Read and parse the canned Drive fixture at ``_DRIVE_FIXTURE_PATH``.

    The file is opened explicitly as UTF-8 so parsing does not depend on
    the platform's locale encoding.
    """
    with _DRIVE_FIXTURE_PATH.open(encoding="utf-8") as f:
        return json.load(f)
def _load_gmail_fixture() -> dict[str, Any]:
    """Read and parse the canned Gmail fixture at ``_GMAIL_FIXTURE_PATH``.

    The file is opened explicitly as UTF-8 so parsing does not depend on
    the platform's locale encoding.
    """
    with _GMAIL_FIXTURE_PATH.open(encoding="utf-8") as f:
        return json.load(f)
def _load_calendar_fixture() -> dict[str, Any]:
    """Read and parse the canned Calendar fixture at ``_CALENDAR_FIXTURE_PATH``.

    The file is opened explicitly as UTF-8 so parsing does not depend on
    the platform's locale encoding.
    """
    with _CALENDAR_FIXTURE_PATH.open(encoding="utf-8") as f:
        return json.load(f)
_DRIVE_FIXTURE = _load_drive_fixture()
_GMAIL_FIXTURE = _load_gmail_fixture()
_CALENDAR_FIXTURE = _load_calendar_fixture()
def _get_scenario() -> str:
    """Return the active E2E scenario name, falling back to ``"happy"``.

    The scenario helper is imported inside the function so this module can
    be loaded before ``tests.e2e.middleware.scenario`` is importable. Any
    failure, whether from the import or the lookup, yields the default.
    """
    try:
        from tests.e2e.middleware.scenario import current_scenario

        active = current_scenario()
    except Exception:
        active = "happy"
    return active
# ---------------------------------------------------------------------------
# Strict mixin: every fake class raises on unknown attribute access
# ---------------------------------------------------------------------------
class _StrictFakeMixin:
    """Mixin that makes a fake reject every attribute it does not define."""

    # Label used in the error message; subclasses override it.
    _component_name: str = "<unknown>"

    def __getattr__(self, name: str) -> Any:
        """Raise ``NotImplementedError`` naming the missing attribute.

        Python calls this only after normal attribute lookup has failed,
        so explicitly defined members are unaffected.
        """
        message = (
            f"E2E Composio fake missing surface: {self._component_name}.{name!r}. "
            f"If production code needs this, add an explicit method to "
            f"surfsense_backend/tests/e2e/fakes/composio_module.py — "
            f"the strict fake refuses to silently fall through to the real SDK."
        )
        raise NotImplementedError(message)
# ---------------------------------------------------------------------------
# Result objects mimicking the real SDK's response dataclasses
# ---------------------------------------------------------------------------
class _ConnectionRequest:
    """Stand-in for the value returned by ``connected_accounts.initiate``.

    Exposes the redirect URL and the account id (as ``id``).
    """

    def __init__(self, *, redirect_url: str, account_id: str) -> None:
        # The account id is exposed under the attribute name ``id``.
        self.id = account_id
        self.redirect_url = redirect_url
class _ConnectedAccount:
    """Stand-in for a connected-account row.

    Returned by the fake ``wait_for_connection``, ``get`` and ``refresh``
    calls. The access token is reachable as ``state.val.access_token``.
    """

    def __init__(
        self,
        *,
        account_id: str,
        status: str = "ACTIVE",
        redirect_url: str | None = None,
        access_token: str = "fake-e2e-access-token-not-real-do-not-use-32chars",
    ) -> None:
        # Wrap the token in the nested state object built by _AccountState.
        token_state = _AccountState(access_token=access_token)
        self.id = account_id
        self.status = status
        self.redirect_url = redirect_url
        self.state = token_state
class _AccountState:
    """Outer layer of the account state; holds the token wrapper as ``val``."""

    def __init__(self, *, access_token: str) -> None:
        inner = _AccountStateVal(access_token=access_token)
        self.val = inner
class _AccountStateVal:
    """Innermost account-state object; carries only the access token."""

    def __init__(self, *, access_token: str) -> None:
        token = access_token
        self.access_token = token
class _AuthConfig:
    """Stand-in for one auth-config row listed by ``auth_configs.list()``.

    Exposes the config id as ``id`` and the toolkit slug as ``toolkit.slug``.
    """

    def __init__(self, *, config_id: str, toolkit_slug: str) -> None:
        toolkit = _Toolkit(slug=toolkit_slug)
        self.id = config_id
        self.toolkit = toolkit
class _Toolkit:
    """Minimal toolkit descriptor; carries only the toolkit slug."""

    def __init__(self, *, slug: str) -> None:
        toolkit_slug = slug
        self.slug = toolkit_slug
class _AuthConfigsListResult:
    """Stand-in for the list response; the rows are exposed as ``items``."""

    def __init__(self, items: list[_AuthConfig]) -> None:
        # The caller's list object is stored as-is, not copied.
        rows = items
        self.items = rows
# ---------------------------------------------------------------------------
# Sub-clients on the Composio top-level object
# ---------------------------------------------------------------------------
class _ConnectedAccounts(_StrictFakeMixin):
    """Strict fake for ``client.connected_accounts.*``."""

    _component_name = "connected_accounts"

    def initiate(
        self,
        *,
        user_id: str,
        auth_config_id: str,
        callback_url: str,
        allow_multiple: bool = True,
        **_: Any,
    ) -> _ConnectionRequest:
        """Start a fake connection and return where the browser should go.

        The redirect points straight back at ``callback_url``. It carries
        ``error=access_denied`` when the scenario is ``"denied"`` and
        ``connectedAccountId=<id>`` otherwise.
        """
        scenario = _get_scenario()

        # The account id is derived from the toolkit and user only, so the
        # same toolkit on the same entity always maps to the same id. This
        # lets the duplicate-connection scenario reach the reconnect branch
        # in composio_routes.py.
        toolkit_id = auth_config_id.replace("auth-config-", "")
        account_id = f"fake-acct-{toolkit_id}-{user_id}"

        # The real SDK redirects through Composio's hosted OAuth UI. Here
        # we jump directly to our own callback so no network is needed.
        # The result is appended as a query parameter, honouring any query
        # string already present on the callback URL.
        joiner = "&" if "?" in callback_url else "?"
        if scenario == "denied":
            redirect = f"{callback_url}{joiner}error=access_denied"
        else:
            redirect = f"{callback_url}{joiner}connectedAccountId={account_id}"

        logger.info(
            "[fake-composio] initiate scenario=%s toolkit=%s redirect=%s",
            scenario,
            toolkit_id,
            redirect,
        )
        return _ConnectionRequest(redirect_url=redirect, account_id=account_id)

    def wait_for_connection(
        self, *, id: str, timeout: float = 30.0, **_: Any
    ) -> _ConnectedAccount:
        """Return an ACTIVE account for ``id`` immediately; ``timeout`` is unused."""
        account = _ConnectedAccount(account_id=id, status="ACTIVE")
        return account

    def get(self, *, nanoid: str, **_: Any) -> _ConnectedAccount:
        """Return an ACTIVE account whose id is ``nanoid``."""
        account = _ConnectedAccount(account_id=nanoid, status="ACTIVE")
        return account

    def delete(self, account_id: str, /, **_: Any) -> dict[str, Any]:
        """Log the deletion and report success for ``account_id``."""
        logger.info("[fake-composio] delete account=%s", account_id)
        outcome = {"success": True, "id": account_id}
        return outcome

    def refresh(
        self,
        *,
        nanoid: str,
        body_redirect_url: str | None = None,
        **_: Any,
    ) -> _ConnectedAccount:
        """Return an ACTIVE account that echoes ``body_redirect_url``."""
        account = _ConnectedAccount(
            account_id=nanoid,
            status="ACTIVE",
            redirect_url=body_redirect_url,
        )
        return account
class _AuthConfigs(_StrictFakeMixin):
    """Strict fake for ``client.auth_configs.*``."""

    _component_name = "auth_configs"

    def list(self, **_: Any) -> _AuthConfigsListResult:
        """Return exactly one auth config per toolkit covered by E2E.

        The real SDK allows several configs per toolkit; one each is
        enough for these tests.
        """
        known_configs = (
            ("auth-config-googledrive", "googledrive"),
            ("auth-config-gmail", "gmail"),
            ("auth-config-googlecalendar", "googlecalendar"),
        )
        rows = [
            _AuthConfig(config_id=config_id, toolkit_slug=slug)
            for config_id, slug in known_configs
        ]
        return _AuthConfigsListResult(items=rows)
class _Tools(_StrictFakeMixin):
    """Strict fake for ``client.tools.*``."""

    _component_name = "tools"

    def execute(
        self,
        *,
        slug: str,
        connected_account_id: str,
        user_id: str | None = None,
        arguments: dict[str, Any] | None = None,
        dangerously_skip_version_check: bool = True,
        **_: Any,
    ) -> dict[str, Any]:
        """Run the canned handler for a Composio tool slug.

        Args:
            slug: Tool identifier, e.g. ``GOOGLEDRIVE_LIST_FILES``.
            arguments: Tool arguments; ``None`` is treated as empty.

        Returns:
            The handler's response dict (payload under ``"data"``).

        Raises:
            _AuthExpiredError: The active scenario is ``"auth_expired"``.
            NotImplementedError: No handler exists for ``slug``.
        """
        scenario = _get_scenario()
        args = arguments or {}
        # Only argument names are logged, never their values.
        logger.info(
            "[fake-composio] tools.execute slug=%s scenario=%s args=%s",
            slug,
            scenario,
            list(args.keys()),
        )

        if scenario == "auth_expired":
            # The wording matches the error strings that composio_routes.py
            # classifies as authentication failures.
            raise _AuthExpiredError(
                "Token has been expired or revoked. (HTTP 401: invalid_grant)"
            )

        # Slugs whose response is computed from the arguments and fixtures.
        handlers = {
            "GOOGLEDRIVE_LIST_FILES": _drive_list_files,
            "GOOGLEDRIVE_DOWNLOAD_FILE": _drive_download_file,
            "GOOGLEDRIVE_GET_FILE_METADATA": _drive_get_metadata,
            "GMAIL_FETCH_EMAILS": _gmail_fetch_emails,
            "GMAIL_FETCH_MESSAGE_BY_MESSAGE_ID": _gmail_fetch_message_by_message_id,
            "GOOGLECALENDAR_EVENTS_LIST": _calendar_events_list,
        }
        handler = handlers.get(slug)
        if handler is not None:
            return handler(args)

        # Slugs with a fixed response. The table is rebuilt on every call,
        # so each caller receives its own dict.
        canned = {
            "GOOGLEDRIVE_GET_CHANGES_START_PAGE_TOKEN": {
                "data": {"startPageToken": "fake-start-page-token-1"}
            },
            "GOOGLEDRIVE_LIST_CHANGES": {
                "data": {"changes": [], "newStartPageToken": "fake-start-page-token-1"}
            },
            # Used by ComposioService.get_connected_account_email for
            # googledrive; a fake email gives the connector a display name.
            "GOOGLEDRIVE_GET_ABOUT": {
                "data": {"user": {"emailAddress": "e2e-fake@surfsense.example"}}
            },
            "GMAIL_GET_PROFILE": {
                "data": {"emailAddress": "e2e-fake@surfsense.example"}
            },
            "GOOGLECALENDAR_GET_CALENDAR": {
                "data": {
                    "id": "primary",
                    "summary": "e2e-fake@surfsense.example",
                    "primary": True,
                    "timeZone": "UTC",
                }
            },
            "GOOGLECALENDAR_CALENDARS_LIST": {
                "data": {
                    "items": [
                        {
                            "id": "primary",
                            "summary": "e2e-fake@surfsense.example",
                            "primary": True,
                        }
                    ]
                }
            },
        }
        if slug in canned:
            return canned[slug]

        # No silent passthrough: an unmodelled slug is a test bug.
        raise NotImplementedError(
            f"E2E Composio fake has no handler for tool slug {slug!r}. "
            f"Add it to surfsense_backend/tests/e2e/fakes/composio_module.py "
            f"in `_Tools.execute` if production code needs it."
        )
# ---------------------------------------------------------------------------
# Drive tool handlers
# ---------------------------------------------------------------------------
def _drive_list_files(args: dict[str, Any]) -> dict[str, Any]:
    """Serve GOOGLEDRIVE_LIST_FILES from the Drive fixture.

    The Drive-style ``q`` argument looks like
    ``'<folder_id>' in parents and trashed = false ...``. The first
    single-quoted token is taken as the folder id; the listing defaults to
    ``"root"`` when the query has no ``in parents`` clause or no quotes.
    The folder's entries are then filtered by the rest of the query.
    """
    query = args.get("q", "")
    folder_id = "root"
    if "in parents" in query:
        pieces = query.split("'")
        # pieces[1] is the text between the first pair of quotes.
        if len(pieces) > 1:
            folder_id = pieces[1]
    folder_entries = _DRIVE_FIXTURE.get(folder_id, [])
    listing = _filter_drive_files_for_query(query, folder_entries)
    return {"data": {"files": listing, "nextPageToken": None}}
def _extract_quoted_value(q: str, anchor: str) -> str | None:
anchor_idx = q.find(anchor)
if anchor_idx == -1:
return None
after_anchor = q[anchor_idx + len(anchor) :]
first_quote_idx = after_anchor.find("'")
if first_quote_idx == -1:
return None
after_first_quote = after_anchor[first_quote_idx + 1 :]
second_quote_idx = after_first_quote.find("'")
if second_quote_idx == -1:
return None
return after_first_quote[:second_quote_idx]
def _filter_drive_files_for_query(
    q: str, files: list[dict[str, Any]]
) -> list[dict[str, Any]]:
    """Apply the supported parts of a Drive ``q`` query to ``files``.

    Three clauses are recognised: ``trashed = false``,
    ``mimeType != '<type>'`` and ``mimeType = '<type>'``. A new list is
    always returned; ``files`` is not modified.
    """
    drop_trashed = "trashed = false" in q
    excluded = _extract_quoted_value(q, "mimeType !=")
    included = _extract_quoted_value(q, "mimeType =")

    kept = list(files)
    if drop_trashed:
        kept = [item for item in kept if item.get("trashed") is not True]
    if excluded:
        kept = [item for item in kept if item.get("mimeType") != excluded]
    if included:
        kept = [item for item in kept if item.get("mimeType") == included]
    return kept
def _drive_download_file(args: dict[str, Any]) -> dict[str, Any]:
    """Serve GOOGLEDRIVE_DOWNLOAD_FILE by writing fixture bytes to disk.

    The real SDK downloads to a local file and returns its path, which
    composio_service.py then reads. This fake writes the fixture content
    into ``_DRIVE_DOWNLOAD_DIR`` and returns the path, name and size.

    Raises:
        NotImplementedError: No canned content or metadata exists for the
            requested ``file_id``.
    """
    file_id = args.get("file_id", "")
    payload = _resolve_file_bytes(_DRIVE_FIXTURE, file_id, _FIXTURES_DIR)
    if payload is None:
        # An unknown file id is a test bug, so fail loudly.
        raise NotImplementedError(
            f"E2E Composio fake has no canned content for file_id={file_id!r}. "
            f"Add it under '_file_contents' in "
            f"surfsense_backend/tests/e2e/fakes/fixtures/drive_files.json."
        )

    _DRIVE_DOWNLOAD_DIR.mkdir(parents=True, exist_ok=True)
    entry = _drive_get_metadata({"file_id": file_id})["data"]
    # Fall back to a synthetic name when the fixture entry has none.
    file_name = entry.get("name") or f"{file_id}.txt"
    target = _DRIVE_DOWNLOAD_DIR / file_name
    target.write_bytes(payload)

    details = {
        "file_path": str(target),
        "file_name": file_name,
        "size": len(payload),
    }
    return {"data": details}
def _drive_get_metadata(args: dict[str, Any]) -> dict[str, Any]:
    """Serve GOOGLEDRIVE_GET_FILE_METADATA from the Drive fixture.

    Every list-valued fixture section is searched for an entry whose
    ``id`` equals ``args["file_id"]``; the first match wins.

    Raises:
        NotImplementedError: No fixture entry has that id.
    """
    file_id = args.get("file_id", "")
    # Non-list fixture values are skipped rather than searched.
    matches = (
        entry
        for section in _DRIVE_FIXTURE.values()
        if isinstance(section, list)
        for entry in section
        if entry.get("id") == file_id
    )
    found = next(matches, None)
    if found is None:
        raise NotImplementedError(
            f"E2E fake: no metadata fixture for file_id={file_id!r}. "
            f"Add it to drive_files.json."
        )
    return {"data": found}
# ---------------------------------------------------------------------------
# Gmail tool handlers
# ---------------------------------------------------------------------------
def _gmail_fetch_emails(args: dict[str, Any]) -> dict[str, Any]:
    """Serve GMAIL_FETCH_EMAILS as a single, unpaginated list page.

    The arguments are ignored: every fixture message is returned. The
    production indexer treats this as a list page and then calls
    GMAIL_FETCH_MESSAGE_BY_MESSAGE_ID for each id.
    """
    del args
    page = list(_GMAIL_FIXTURE.get("messages", []))
    body = {
        "messages": page,
        "nextPageToken": None,
        "resultSizeEstimate": len(page),
    }
    return {"data": body}
def _gmail_fetch_message_by_message_id(args: dict[str, Any]) -> dict[str, Any]:
    """Serve GMAIL_FETCH_MESSAGE_BY_MESSAGE_ID from the ``details`` fixture.

    Raises:
        NotImplementedError: No detail fixture exists for the message id.
    """
    message_id = args.get("message_id", "")
    found = _GMAIL_FIXTURE.get("details", {}).get(message_id)
    if found is not None:
        return {"data": found}
    raise NotImplementedError(
        f"E2E Composio fake has no Gmail detail fixture for "
        f"message_id={message_id!r}. Add it under 'details' in "
        f"surfsense_backend/tests/e2e/fakes/fixtures/gmail_messages.json."
    )
# ---------------------------------------------------------------------------
# Google Calendar tool handlers
# ---------------------------------------------------------------------------
def _calendar_events_list(args: dict[str, Any]) -> dict[str, Any]:
    """Serve GOOGLECALENDAR_EVENTS_LIST from the calendar fixture.

    At most ``max_results`` events are returned; a missing, ``None`` or
    zero value falls back to 250.
    """
    requested = args.get("max_results", 250)
    limit = int(requested or 250)
    events = list(_CALENDAR_FIXTURE.get("items", []))
    page = events[:limit]
    body = {
        "items": page,
        "summary": "e2e-fake@surfsense.example",
        "timeZone": "UTC",
    }
    return {"data": body}
# ---------------------------------------------------------------------------
# Errors
# ---------------------------------------------------------------------------
class _AuthExpiredError(Exception):
    """Error the fake raises while the scenario is ``auth_expired``.

    composio_service.execute_tool catches every exception and puts
    ``str(error)`` into its result dict. composio_routes.py then looks for
    the "expired or revoked" / "401" markers in that text and sets
    connector.config.auth_expired.
    """
# ---------------------------------------------------------------------------
# Top-level Composio class — the only public symbol production imports
# ---------------------------------------------------------------------------
class Composio(_StrictFakeMixin):
    """Drop-in replacement for ``composio.Composio``.

    Production constructs it as
    ``Composio(api_key=..., file_download_dir=...)``. The instance exposes
    the ``connected_accounts``, ``tools`` and ``auth_configs`` sub-clients;
    any other attribute raises through the strict mixin.
    """

    _component_name = "Composio"

    def __init__(
        self,
        *,
        api_key: str | None = None,
        file_download_dir: str | None = None,
        **_: Any,
    ) -> None:
        # Constructor arguments are recorded; extra keywords are ignored.
        self.api_key = api_key
        self.file_download_dir = file_download_dir

        # Sub-clients, built in the same order as the original fake.
        self.connected_accounts = _ConnectedAccounts()
        self.tools = _Tools()
        self.auth_configs = _AuthConfigs()

        logger.info(
            "[fake-composio] Composio() constructed (E2E mode, no real network)"
        )
# Public re-exports so `from composio import Composio` resolves correctly
# when this module is registered as sys.modules["composio"].
__all__ = ["Composio"]

View file

@ -0,0 +1,58 @@
"""Strict Confluence indexer fake for Playwright E2E."""
from __future__ import annotations
import json
from pathlib import Path
from typing import Any
from unittest.mock import patch
_FIXTURE_PATH = Path(__file__).parent / "fixtures" / "confluence_pages.json"
def _load_fixture() -> dict[str, Any]:
    """Read and parse the Confluence pages fixture at ``_FIXTURE_PATH``.

    The file is opened explicitly as UTF-8 so parsing does not depend on
    the platform's locale encoding.
    """
    with _FIXTURE_PATH.open(encoding="utf-8") as f:
        return json.load(f)
_FIXTURE = _load_fixture()
class _FakeConfluenceHistoryConnector:
    """Strict stand-in for ``ConfluenceHistoryConnector`` that serves fixture pages."""

    def __init__(self, *, session: Any, connector_id: int, **kwargs: Any):
        # Only the connector id is kept; the session and extras are dropped.
        del session, kwargs
        self.connector_id = connector_id

    async def get_pages_by_date_range(
        self,
        *,
        start_date: str,
        end_date: str,
        include_comments: bool = False,
    ) -> tuple[list[dict[str, Any]], None]:
        """Return every fixture page together with a ``None`` error slot.

        The date range is only checked for presence, not used to filter.

        Raises:
            ValueError: ``start_date`` or ``end_date`` is empty.
        """
        del include_comments
        have_range = bool(start_date) and bool(end_date)
        if not have_range:
            raise ValueError(
                "Confluence indexer fake expected start_date and end_date."
            )
        pages = _FIXTURE["pages"]
        return pages, None

    async def close(self) -> None:
        """Do nothing; the fake holds no resources."""
        return None

    def __getattr__(self, name: str) -> Any:
        """Reject any attribute the fake does not define."""
        message = (
            "E2E Confluence indexer fake missing surface: "
            f"ConfluenceHistoryConnector.{name!r}. "
            "Add it to surfsense_backend/tests/e2e/fakes/confluence_indexer.py."
        )
        raise NotImplementedError(message)
def install(active_patches: list[Any]) -> None:
    """Swap the Confluence indexer's connector class for the fake.

    Only the name bound inside the indexer module is patched. The started
    patch is appended to ``active_patches``.
    """
    target = "app.tasks.connector_indexers.confluence_indexer.ConfluenceHistoryConnector"
    patcher = patch(target, _FakeConfluenceHistoryConnector)
    patcher.start()
    active_patches.append(patcher)

View file

@ -0,0 +1,146 @@
"""Strict Confluence OAuth fakes for Playwright E2E."""
from __future__ import annotations
import json
from pathlib import Path
from typing import Any
from unittest.mock import patch
_FIXTURE_PATH = Path(__file__).parent / "fixtures" / "confluence_pages.json"
_TOKEN_URL = "https://auth.atlassian.com/oauth/token"
_RESOURCES_URL = "https://api.atlassian.com/oauth/token/accessible-resources"
_ACCESS_TOKEN = "fake-confluence-access-token"
_REFRESH_TOKEN = "fake-confluence-refresh-token"
_OAUTH_CODE = "fake-confluence-oauth-code"
def _load_fixture() -> dict[str, Any]:
    """Read and parse the Confluence pages fixture at ``_FIXTURE_PATH``.

    The file is opened explicitly as UTF-8 so parsing does not depend on
    the platform's locale encoding.
    """
    with _FIXTURE_PATH.open(encoding="utf-8") as f:
        return json.load(f)
_FIXTURE = _load_fixture()
class _StrictFakeMixin:
    """Mixin that makes a Confluence OAuth fake reject undefined attributes."""

    # Label used in the error message; subclasses override it.
    _component_name: str = "<unknown>"

    def __getattr__(self, name: str) -> Any:
        """Raise ``NotImplementedError`` naming the missing attribute."""
        message = (
            f"E2E Confluence OAuth fake missing surface: "
            f"{self._component_name}.{name!r}. "
            "Add it to surfsense_backend/tests/e2e/fakes/confluence_oauth.py."
        )
        raise NotImplementedError(message)
class _FakeResponse(_StrictFakeMixin):
    """Minimal response fake exposing ``status_code``, ``text`` and ``json()``."""

    _component_name = "httpx.Response"

    def __init__(self, payload: Any, status_code: int = 200):
        # ``text`` is the payload serialised with sorted keys.
        serialised = json.dumps(payload, sort_keys=True)
        self._payload = payload
        self.status_code = status_code
        self.text = serialised

    def json(self) -> Any:
        """Return the payload object passed to the constructor."""
        return self._payload
class _FakeHttpxAsyncClient(_StrictFakeMixin):
    """Async HTTP client fake that validates the Confluence OAuth calls.

    It answers only the token-exchange POST and the accessible-resources
    GET. Any other URL raises ``NotImplementedError``.
    """

    _component_name = "httpx.AsyncClient"

    def __init__(self, *args: Any, **kwargs: Any):
        del args, kwargs

    async def __aenter__(self) -> _FakeHttpxAsyncClient:
        return self

    async def __aexit__(self, exc_type: Any, exc: Any, tb: Any) -> None:
        del exc_type, exc, tb

    async def post(self, url: str, **kwargs: Any) -> _FakeResponse:
        """Validate a token request and return the canned token response.

        Raises:
            NotImplementedError: ``url`` is not the token endpoint.
            ValueError: The headers, grant type, code, credentials,
                redirect URI or refresh token are not the expected values.
        """
        if url != _TOKEN_URL:
            raise NotImplementedError(f"Unexpected Confluence OAuth POST url={url!r}")

        body = kwargs.get("json") or {}
        request_headers = kwargs.get("headers") or {}
        if request_headers.get("Content-Type") != "application/json":
            raise ValueError("Confluence OAuth token exchange expected JSON headers.")

        grant_type = body.get("grant_type")
        if grant_type == "authorization_code":
            if body.get("code") != _OAUTH_CODE:
                raise ValueError(
                    f"Unexpected fake Confluence OAuth code: {body.get('code')!r}"
                )
            has_creds = body.get("client_id") and body.get("client_secret")
            if not has_creds:
                raise ValueError(
                    "Confluence OAuth token exchange missing client creds."
                )
            # The redirect URI only needs to contain the callback path.
            redirect_text = str(body.get("redirect_uri", ""))
            if "/api/v1/auth/confluence/connector/callback" not in redirect_text:
                raise ValueError(
                    "Confluence OAuth token exchange got unexpected redirect_uri: "
                    f"{body.get('redirect_uri')!r}"
                )
        elif grant_type == "refresh_token":
            if body.get("refresh_token") != _REFRESH_TOKEN:
                raise ValueError(
                    "Unexpected fake Confluence refresh token: "
                    f"{body.get('refresh_token')!r}"
                )
        else:
            raise ValueError(f"Unexpected fake Confluence grant_type: {grant_type!r}")

        token_payload = {
            "access_token": _ACCESS_TOKEN,
            "refresh_token": _REFRESH_TOKEN,
            "expires_in": 3600,
            "scope": "read:confluence-user read:space:confluence read:page:confluence",
            "token_type": "Bearer",
        }
        return _FakeResponse(token_payload)

    async def get(self, url: str, **kwargs: Any) -> _FakeResponse:
        """Validate the bearer token and return the fixture site as a resource.

        Raises:
            NotImplementedError: ``url`` is not the resources endpoint.
            ValueError: The Authorization header is not the fake bearer.
        """
        if url != _RESOURCES_URL:
            raise NotImplementedError(f"Unexpected Confluence OAuth GET url={url!r}")

        request_headers = kwargs.get("headers") or {}
        auth = request_headers.get("Authorization")
        if auth != f"Bearer {_ACCESS_TOKEN}":
            raise ValueError(f"Unexpected Confluence resources Authorization: {auth!r}")

        site = _FIXTURE["site"]
        resource = {
            "id": site["cloud_id"],
            "name": site["name"],
            "url": site["url"],
            "scopes": [
                "read:confluence-user",
                "read:space:confluence",
                "read:page:confluence",
            ],
        }
        return _FakeResponse([resource])
class _FakeHttpxModule(_StrictFakeMixin):
    """Module-shaped fake for ``httpx`` that exposes only ``AsyncClient``."""

    _component_name = "httpx"

    # Any other ``httpx.<name>`` lookup raises through the strict mixin.
    AsyncClient = _FakeHttpxAsyncClient
def install(active_patches: list[Any]) -> None:
    """Replace ``httpx`` as seen by the Confluence connector route.

    Only the route module's own binding is patched. The started patch is
    appended to ``active_patches``.
    """
    target = "app.routes.confluence_add_connector_route.httpx"
    patcher = patch(target, _FakeHttpxModule())
    patcher.start()
    active_patches.append(patcher)

View file

@ -0,0 +1,182 @@
"""Strict Dropbox HTTP/API fakes for Playwright E2E.
This module patches the Dropbox OAuth route and indexer consumer-site
bindings. It keeps the production add/callback/indexing flow intact while
serving deterministic Dropbox-shaped token, profile, metadata, and file
content responses.
"""
from __future__ import annotations
import json
from pathlib import Path
from typing import Any
from unittest.mock import patch
import httpx
from .binary_loader import _resolve_file_bytes
_FIXTURES_DIR = Path(__file__).parent / "fixtures"
_DROPBOX_FIXTURE_PATH = _FIXTURES_DIR / "dropbox_files.json"
def _load_dropbox_fixture() -> dict[str, Any]:
    """Read and parse the Dropbox fixture at ``_DROPBOX_FIXTURE_PATH``.

    The file is opened explicitly as UTF-8 so parsing does not depend on
    the platform's locale encoding.
    """
    with _DROPBOX_FIXTURE_PATH.open(encoding="utf-8") as f:
        return json.load(f)
_DROPBOX_FIXTURE = _load_dropbox_fixture()
class _StrictFakeMixin:
    """Mixin that makes a Dropbox fake reject undefined attributes."""

    # Label used in the error message; subclasses override it.
    _component_name: str = "<unknown>"

    def __getattr__(self, name: str) -> Any:
        """Raise ``NotImplementedError`` naming the missing attribute."""
        message = (
            f"E2E Dropbox fake missing surface: "
            f"{self._component_name}.{name!r}. Add it to "
            f"surfsense_backend/tests/e2e/fakes/dropbox_api.py."
        )
        raise NotImplementedError(message)
class _FakeDropboxClient(_StrictFakeMixin):
    """Strict ``DropboxClient`` fake backed by the Dropbox fixture.

    Methods follow a ``(value, error)`` convention: ``error`` is ``None``
    on success and a message string when the fixture has no matching data.
    """

    _component_name = "DropboxClient"

    def __init__(self, session: Any, connector_id: int):
        self._session = session
        self._connector_id = connector_id

    async def _get_valid_token(self) -> str:
        """Return the fixed fake access token."""
        return "fake-dropbox-access-token"

    async def list_folder(
        self, path: str = ""
    ) -> tuple[list[dict[str, Any]], str | None]:
        """Return copies of the fixture entries stored under ``path``."""
        entries = _DROPBOX_FIXTURE.get(path)
        if isinstance(entries, list):
            # Copies keep callers from mutating the shared fixture.
            return [dict(entry) for entry in entries], None
        return [], f"E2E Dropbox fake has no folder for path={path!r}."

    async def get_latest_cursor(self, path: str = "") -> tuple[str | None, str | None]:
        """Return a synthetic cursor for a folder present in the fixture."""
        if path in _DROPBOX_FIXTURE:
            return f"fake-dropbox-cursor:{path or 'root'}", None
        return None, f"E2E Dropbox fake has no cursor for path={path!r}."

    async def get_changes(
        self, cursor: str
    ) -> tuple[list[dict[str, Any]], str | None, str | None]:
        """Report no changes and hand the same cursor back."""
        return [], cursor, None

    async def get_metadata(self, path: str) -> tuple[dict[str, Any] | None, str | None]:
        """Look up a fixture entry by lower-cased path or id."""
        entry = _dropbox_get_metadata(path)
        if entry is not None:
            return entry, None
        return None, f"E2E Dropbox fake has no metadata for path={path!r}."

    async def download_file(self, path: str) -> tuple[bytes | None, str | None]:
        """Return the canned bytes for ``path``."""
        payload = _resolve_file_bytes(_DROPBOX_FIXTURE, path, _FIXTURES_DIR)
        if payload is not None:
            return payload, None
        return None, f"E2E Dropbox fake has no content for path={path!r}."

    async def download_file_to_disk(self, path: str, dest_path: str) -> str | None:
        """Write the canned bytes for ``path`` to ``dest_path``.

        Returns ``None`` on success, or an error message when the fixture
        has no content for ``path``.
        """
        payload = _resolve_file_bytes(_DROPBOX_FIXTURE, path, _FIXTURES_DIR)
        if payload is None:
            return f"E2E Dropbox fake has no content for path={path!r}."
        with open(dest_path, "wb") as handle:
            handle.write(payload)
        return None

    async def get_current_account(self) -> tuple[dict[str, Any] | None, str | None]:
        """Return the canned account profile."""
        profile = _dropbox_current_account()
        return profile, None
class _FakeAsyncClient(_StrictFakeMixin):
    """Async HTTP client fake for the Dropbox token and account endpoints."""

    _component_name = "httpx.AsyncClient"

    def __init__(self, *args: Any, **kwargs: Any):
        del args, kwargs

    async def __aenter__(self) -> _FakeAsyncClient:
        return self

    async def __aexit__(self, exc_type: Any, exc: Any, tb: Any) -> None:
        del exc_type, exc, tb

    async def post(self, url: str, *args: Any, **kwargs: Any) -> httpx.Response:
        """Answer the two known Dropbox POST endpoints with canned JSON.

        The request body and headers are ignored.

        Raises:
            NotImplementedError: ``url`` is not a modelled endpoint.
        """
        del args, kwargs
        if url == "https://api.dropboxapi.com/oauth2/token":
            payload = {
                "access_token": "fake-dropbox-access-token",
                "refresh_token": "fake-dropbox-refresh-token",
                "token_type": "bearer",
                "expires_in": 3600,
                "account_id": "dbid:fake-dropbox-account",
            }
        elif url == "https://api.dropboxapi.com/2/users/get_current_account":
            payload = _dropbox_current_account()
        else:
            raise NotImplementedError(f"E2E Dropbox fake unexpected POST URL: {url!r}")
        return _json_response("POST", url, payload)

    async def request(
        self, method: str, url: str, *args: Any, **kwargs: Any
    ) -> httpx.Response:
        """Reject every generic request; none are modelled by this fake."""
        del args, kwargs
        message = f"E2E Dropbox fake unexpected request: {method!r} {url!r}"
        raise NotImplementedError(message)
class _FakeHttpxModule(_StrictFakeMixin):
    """Module-shaped fake for ``httpx`` that exposes only ``AsyncClient``."""

    _component_name = "httpx"

    # Any other ``httpx.<name>`` lookup raises through the strict mixin.
    AsyncClient = _FakeAsyncClient
def _json_response(
    method: str, url: str, payload: dict[str, Any], status_code: int = 200
) -> httpx.Response:
    """Build an ``httpx.Response`` carrying ``payload`` as JSON.

    A matching ``httpx.Request`` built from ``method`` and ``url`` is
    attached to the response.
    """
    originating_request = httpx.Request(method, url)
    return httpx.Response(
        status_code=status_code,
        json=payload,
        request=originating_request,
    )
def _dropbox_current_account() -> dict[str, Any]:
    """Return a fresh copy of the canned Dropbox account profile."""
    display = {"display_name": "SurfSense Dropbox E2E"}
    profile: dict[str, Any] = {
        "email": "dropbox-e2e@surfsense.example",
        "name": display,
        "account_id": "dbid:fake-dropbox-account",
    }
    return profile
def _dropbox_get_metadata(path: str | None) -> dict[str, Any] | None:
    """Find the first fixture entry whose ``path_lower`` or ``id`` equals ``path``.

    Returns a copy of the entry, or ``None`` when nothing matches.
    Non-list fixture values are skipped.
    """
    matches = (
        dict(entry)
        for section in _DROPBOX_FIXTURE.values()
        if isinstance(section, list)
        for entry in section
        if entry.get("path_lower") == path or entry.get("id") == path
    )
    return next(matches, None)
def install(active_patches: list[Any]) -> None:
    """Point the production Dropbox bindings at the strict fakes.

    Four bindings are patched: ``httpx`` and ``DropboxClient`` in the
    connector route, ``DropboxClient`` in the indexer, and ``httpx`` in the
    Dropbox client module. Each started patch is appended to
    ``active_patches``.
    """
    replacements = (
        ("app.routes.dropbox_add_connector_route.httpx", _FakeHttpxModule()),
        ("app.routes.dropbox_add_connector_route.DropboxClient", _FakeDropboxClient),
        (
            "app.tasks.connector_indexers.dropbox_indexer.DropboxClient",
            _FakeDropboxClient,
        ),
        ("app.connectors.dropbox.client.httpx", _FakeHttpxModule()),
    )
    for dotted_path, fake in replacements:
        patcher = patch(dotted_path, fake)
        patcher.start()
        active_patches.append(patcher)

View file

@ -0,0 +1,80 @@
"""Deterministic embedding fakes for E2E.
Mirrors the existing `patched_embed_texts` fixture in
`surfsense_backend/tests/integration/conftest.py`:
MagicMock(side_effect=lambda texts: [[0.1] * _EMBEDDING_DIM for _ in texts])
The dimension matches whatever `config.embedding_model_instance.dimension`
returns in the running process so the fakes are vector-compatible with
the documents.embedding pgvector column.
"""
from __future__ import annotations
import logging
from typing import Any
import numpy as np
from app.config import config
logger = logging.getLogger(__name__)
def _embedding_dim() -> int:
    """Return the configured embedding dimension as an ``int``.

    The value is read from ``config.embedding_model_instance`` on every
    call, so the fake works with whichever embedding model is configured.
    """
    dimension = config.embedding_model_instance.dimension
    return int(dimension)
def fake_embed_text(text: str) -> np.ndarray:
    """Return a constant float32 vector of 0.1, regardless of ``text``."""
    dim = _embedding_dim()
    return np.full(shape=(dim,), fill_value=0.1, dtype=np.float32)
def fake_embed_texts(texts: list[str]) -> list[np.ndarray]:
    """Return one constant float32 vector of 0.1 for each input text.

    An empty input short-circuits to ``[]`` without reading the configured
    dimension.
    """
    if not texts:
        return []
    dim = _embedding_dim()
    vectors = []
    for _ in texts:
        # A separate array per text, so callers never share a buffer.
        vectors.append(np.full(shape=(dim,), fill_value=0.1, dtype=np.float32))
    return vectors
def install(patches: list[Any]) -> None:
    """Patch every known embedding binding site with the deterministic fakes.

    Each patcher is started and then appended to ``patches`` so the caller can
    keep track of it. Nothing in this function stops the patchers.

    Raises:
        RuntimeError: If a binding cannot be patched because its module or
            attribute is missing.
    """
    from unittest.mock import patch as _patch

    bindings = (
        # Where the real implementations are defined.
        ("app.utils.document_converters.embed_text", fake_embed_text),
        ("app.utils.document_converters.embed_texts", fake_embed_texts),
        # Modules that imported the functions by name from document_converters.
        ("app.indexing_pipeline.document_embedder.embed_text", fake_embed_text),
        ("app.indexing_pipeline.document_embedder.embed_texts", fake_embed_texts),
        # The pipeline service's own binding.
        (
            "app.indexing_pipeline.indexing_pipeline_service.embed_texts",
            fake_embed_texts,
        ),
    )
    for dotted_path, fake in bindings:
        try:
            patcher = _patch(dotted_path, fake)
            patcher.start()
            patches.append(patcher)
            logger.info("[fake-embeddings] patched %s", dotted_path)
        except (ModuleNotFoundError, AttributeError) as exc:
            # Fail loudly: a binding that silently stays unpatched would let
            # the real embedding code run.
            message = (
                f"Could not patch embedding binding {dotted_path!r}: {exc!s}. "
                f"Update surfsense_backend/tests/e2e/fakes/embeddings.py "
                f"to point at the new binding site."
            )
            raise RuntimeError(message) from exc

View file

@ -0,0 +1,84 @@
"""Generate deterministic one-page PDFs for connector E2E fixtures."""
from __future__ import annotations
from pathlib import Path
# Maps each output filename to the three text lines drawn on its single page.
# main() passes every value to _build_pdf(); the third line of each entry is
# that fixture's canary token.
PDF_FIXTURES = {
    "drive-canary.pdf": (
        "Native Drive PDF Canary",
        "This one-page text-layer PDF proves native Drive Docling coverage.",
        "SURFSENSE_E2E_CANARY_TOKEN_DRIVE_PDF_001",
    ),
    "onedrive-canary.pdf": (
        "OneDrive PDF Canary",
        "This one-page text-layer PDF proves OneDrive Docling coverage.",
        "SURFSENSE_E2E_CANARY_TOKEN_ONEDRIVE_PDF_001",
    ),
    "dropbox-canary.pdf": (
        "Dropbox PDF Canary",
        "This one-page text-layer PDF proves Dropbox Docling coverage.",
        "SURFSENSE_E2E_CANARY_TOKEN_DROPBOX_PDF_001",
    ),
    "composio-drive-canary.pdf": (
        "Composio Drive PDF Canary",
        "This one-page text-layer PDF proves Composio Drive Docling coverage.",
        "SURFSENSE_E2E_CANARY_TOKEN_COMPOSIO_DRIVE_PDF_001",
    ),
}
def _escape_pdf_text(text: str) -> str:
    """Prefix every backslash and parenthesis in ``text`` with a backslash.

    The result is meant to be placed between the ``(`` and ``)`` of a PDF
    string operand.
    """
    escapes = {"\\": "\\\\", "(": "\\(", ")": "\\)"}
    return "".join(escapes.get(char, char) for char in text)
def _build_pdf(lines: tuple[str, str, str]) -> bytes:
    """Assemble a minimal single-page PDF that shows ``lines`` in Helvetica 12.

    The file is written by hand: five numbered objects (catalog, page tree,
    page, font, content stream) followed by a cross-reference table and a
    trailer. Text must be ASCII-encodable.

    Args:
        lines: Text lines drawn top to bottom, 18 units apart.

    Returns:
        The complete PDF file as bytes.
    """
    # Content stream: start a text object, select the font, then show each
    # line, moving down before every line except the first.
    operators = ["BT", "/F1 12 Tf", "72 760 Td"]
    for position, text in enumerate(lines):
        if position > 0:
            operators.append("0 -18 Td")
        operators.append(f"({_escape_pdf_text(text)}) Tj")
    operators.append("ET")
    content = "\n".join(operators).encode("ascii")

    length_entry = str(len(content)).encode("ascii")
    bodies = [
        b"<< /Type /Catalog /Pages 2 0 R >>",
        b"<< /Type /Pages /Kids [3 0 R] /Count 1 >>",
        b"<< /Type /Page /Parent 2 0 R /MediaBox [0 0 612 792] "
        b"/Resources << /Font << /F1 4 0 R >> >> /Contents 5 0 R >>",
        b"<< /Type /Font /Subtype /Type1 /BaseFont /Helvetica >>",
        b"<< /Length " + length_entry + b" >>\nstream\n" + content + b"\nendstream",
    ]

    document = bytearray(b"%PDF-1.4\n")
    # Byte offset of each object, recorded for the cross-reference table.
    object_offsets: list[int] = []
    for number, body in enumerate(bodies, start=1):
        object_offsets.append(len(document))
        document += f"{number} 0 obj\n".encode("ascii")
        document += body
        document += b"\nendobj\n"

    xref_start = len(document)
    entry_count = len(bodies) + 1  # the objects plus the leading free entry
    document += f"xref\n0 {entry_count}\n".encode("ascii")
    document += b"0000000000 65535 f \n"
    for position in object_offsets:
        document += f"{position:010d} 00000 n \n".encode("ascii")
    document += (
        f"trailer\n<< /Size {entry_count} /Root 1 0 R >>\n"
        f"startxref\n{xref_start}\n%%EOF\n".encode("ascii")
    )
    return bytes(document)
def main() -> None:
    """Write every PDF in ``PDF_FIXTURES`` into this module's directory."""
    target_dir = Path(__file__).parent
    for pdf_name, page_lines in PDF_FIXTURES.items():
        pdf_bytes = _build_pdf(page_lines)
        target_dir.joinpath(pdf_name).write_bytes(pdf_bytes)


if __name__ == "__main__":
    main()

View file

@ -0,0 +1,48 @@
{
"items": [
{
"id": "fake-calendar-event-canary-001",
"status": "confirmed",
"summary": "E2E Canary Calendar Event",
"description": "This Calendar event proves the Composio Calendar live tool fetched event details. SURFSENSE_E2E_CANARY_TOKEN_CALENDAR_001",
"location": "SurfSense E2E Room",
"htmlLink": "https://calendar.google.com/calendar/event?eid=fake-calendar-event-canary-001",
"start": {
"dateTime": "2026-05-12T10:00:00Z",
"timeZone": "UTC"
},
"end": {
"dateTime": "2026-05-12T10:30:00Z",
"timeZone": "UTC"
},
"attendees": [
{
"email": "e2e-fake@surfsense.example",
"responseStatus": "accepted"
}
]
},
{
"id": "fake-calendar-event-planning-001",
"status": "confirmed",
"summary": "E2E Planning Sync",
"description": "Non-canary planning sync used to prove list responses can contain multiple events.",
"location": "SurfSense Planning Room",
"htmlLink": "https://calendar.google.com/calendar/event?eid=fake-calendar-event-planning-001",
"start": {
"dateTime": "2026-05-13T15:00:00Z",
"timeZone": "UTC"
},
"end": {
"dateTime": "2026-05-13T15:45:00Z",
"timeZone": "UTC"
},
"attendees": [
{
"email": "planner@surfsense.example",
"responseStatus": "needsAction"
}
]
}
]
}

View file

@ -0,0 +1,13 @@
{
"tasks": [
{
"id": "fake-clickup-task-canary-001",
"name": "E2E Canary ClickUp Task",
"list_name": "SurfSense E2E ClickUp List",
"workspace_id": "fake-clickup-workspace-001",
"workspace_name": "SurfSense E2E ClickUp Workspace",
"status": "open",
"description": "Canary task body containing SURFSENSE_E2E_CANARY_TOKEN_CLICKUP_001"
}
]
}

View file

@ -0,0 +1,33 @@
{
"site": {
"cloud_id": "fake-confluence-cloud-001",
"name": "SurfSense E2E Confluence",
"url": "https://surfsense-e2e-confluence.atlassian.net"
},
"pages": [
{
"id": "fake-confluence-page-canary-001",
"title": "E2E Canary Confluence Page",
"spaceId": "fake-confluence-space-001",
"body": {
"storage": {
"value": "<h1>E2E Canary Confluence Page</h1><p>This page proves Confluence OAuth indexing works end-to-end. SURFSENSE_E2E_CANARY_TOKEN_CONFLUENCE_001</p>"
}
},
"comments": [
{
"id": "fake-confluence-comment-canary-001",
"body": {
"storage": {
"value": "<p>Confluence comment content is included in indexed markdown.</p>"
}
},
"version": {
"authorId": "fake-confluence-user-001",
"createdAt": "2026-05-08T00:00:00.000Z"
}
}
]
}
]
}

View file

@ -0,0 +1,99 @@
{
"root": [
{
"id": "fake-folder-projects",
"name": "Projects",
"mimeType": "application/vnd.google-apps.folder",
"modifiedTime": "2025-01-15T10:00:00.000Z",
"createdTime": "2024-12-01T08:00:00.000Z"
},
{
"id": "fake-folder-archive",
"name": "Archive",
"mimeType": "application/vnd.google-apps.folder",
"modifiedTime": "2024-11-20T14:30:00.000Z",
"createdTime": "2024-09-10T12:00:00.000Z"
},
{
"id": "fake-file-readme",
"name": "README.md",
"mimeType": "text/markdown",
"modifiedTime": "2025-02-01T09:15:00.000Z",
"createdTime": "2025-01-30T16:20:00.000Z"
},
{
"id": "fake-file-canary",
"name": "e2e-canary.txt",
"mimeType": "text/plain",
"modifiedTime": "2025-02-10T11:00:00.000Z",
"createdTime": "2025-02-10T11:00:00.000Z"
},
{
"id": "fake-file-pdf-native",
"name": "e2e-canary.pdf",
"mimeType": "application/pdf",
"size": 735,
"modifiedTime": "2025-02-10T11:05:00.000Z",
"createdTime": "2025-02-10T11:05:00.000Z"
},
{
"id": "fake-file-pdf-composio",
"name": "e2e-composio-canary.pdf",
"mimeType": "application/pdf",
"size": 748,
"modifiedTime": "2025-02-10T11:10:00.000Z",
"createdTime": "2025-02-10T11:10:00.000Z"
},
{
"id": "fake-file-budget",
"name": "Q1-Budget.csv",
"mimeType": "text/csv",
"modifiedTime": "2025-01-25T13:45:00.000Z",
"createdTime": "2025-01-25T13:45:00.000Z"
},
{
"id": "fake-shortcut-canary",
"name": "Shortcut to Canary",
"mimeType": "application/vnd.google-apps.shortcut",
"modifiedTime": "2025-02-10T12:00:00.000Z",
"createdTime": "2025-02-10T12:00:00.000Z"
},
{
"id": "fake-file-trashed",
"name": "trashed-e2e-note.txt",
"mimeType": "text/plain",
"modifiedTime": "2025-02-11T09:00:00.000Z",
"createdTime": "2025-02-11T09:00:00.000Z",
"trashed": true
}
],
"fake-folder-projects": [
{
"id": "fake-file-roadmap",
"name": "2025-Roadmap.md",
"mimeType": "text/markdown",
"modifiedTime": "2025-02-12T08:30:00.000Z",
"createdTime": "2025-01-05T10:00:00.000Z"
}
],
"fake-folder-archive": [
{
"id": "fake-file-old-notes",
"name": "old-meeting-notes.txt",
"mimeType": "text/plain",
"modifiedTime": "2024-08-15T15:00:00.000Z",
"createdTime": "2024-08-15T15:00:00.000Z"
}
],
"_file_contents": {
"fake-file-readme": "# E2E Fake Drive\n\nThis README is served by the strict Composio fake. SURFSENSE_E2E_README_MARKER",
"fake-file-canary": "Canary token for E2E tests: SURFSENSE_E2E_CANARY_TOKEN_DRIVE_001\nThis file's content is asserted by indexing.spec.ts to confirm the indexing pipeline ran end-to-end.",
"fake-file-budget": "Quarter,Revenue,Expenses\nQ1,100000,75000\nSURFSENSE_E2E_BUDGET_MARKER,2025,test",
"fake-file-roadmap": "# 2025 Roadmap\n\n- E2E Drive indexing\n- Composio Gmail/Calendar SURFSENSE_E2E_ROADMAP_MARKER",
"fake-file-old-notes": "Old meeting notes archived in 2024. SURFSENSE_E2E_ARCHIVE_MARKER"
},
"_file_binary_paths": {
"fake-file-pdf-native": "binary/drive-canary.pdf",
"fake-file-pdf-composio": "binary/composio-drive-canary.pdf"
}
}

View file

@ -0,0 +1,34 @@
{
"": [
{
".tag": "file",
"id": "id:fake-dropbox-canary",
"name": "e2e-dropbox-canary.txt",
"path_lower": "/e2e-dropbox-canary.txt",
"path_display": "/e2e-dropbox-canary.txt",
"size": 152,
"is_downloadable": true,
"server_modified": "2026-05-08T00:00:00Z",
"client_modified": "2026-05-08T00:00:00Z",
"content_hash": "fake-dropbox-hash-001"
},
{
".tag": "file",
"id": "id:fake-dropbox-pdf-canary",
"name": "e2e-dropbox-canary.pdf",
"path_lower": "/e2e-dropbox-canary.pdf",
"path_display": "/e2e-dropbox-canary.pdf",
"size": 727,
"is_downloadable": true,
"server_modified": "2026-05-08T00:05:00Z",
"client_modified": "2026-05-08T00:05:00Z",
"content_hash": "fake-dropbox-hash-pdf-001"
}
],
"_file_contents": {
"/e2e-dropbox-canary.txt": "Canary token for Dropbox E2E tests: SURFSENSE_E2E_CANARY_TOKEN_DROPBOX_001\nThis file proves the Dropbox indexing pipeline ran end-to-end."
},
"_file_binary_paths": {
"/e2e-dropbox-canary.pdf": "binary/dropbox-canary.pdf"
}
}

View file

@ -0,0 +1,34 @@
{
"messages": [
{
"id": "fake-msg-canary-001",
"threadId": "fake-thread-canary-001",
"snippet": "E2E canary email body is loaded through the Gmail detail endpoint."
},
{
"id": "fake-msg-planning-001",
"threadId": "fake-thread-planning-001",
"snippet": "Planning email used to keep Gmail fixtures representative."
}
],
"details": {
"fake-msg-canary-001": {
"id": "fake-msg-canary-001",
"threadId": "fake-thread-canary-001",
"subject": "E2E Canary Email",
"from": "sender@surfsense.example",
"to": "e2e-fake@surfsense.example",
"date": "Mon, 10 Feb 2025 11:00:00 +0000",
"messageText": "This Gmail message body proves the Composio Gmail live tool fetched message details. SURFSENSE_E2E_CANARY_TOKEN_GMAIL_001"
},
"fake-msg-planning-001": {
"id": "fake-msg-planning-001",
"threadId": "fake-thread-planning-001",
"subject": "E2E Planning Notes",
"from": "planner@surfsense.example",
"to": "e2e-fake@surfsense.example",
"date": "Tue, 11 Feb 2025 09:30:00 +0000",
"messageText": "Planning notes for a non-canary Gmail fixture."
}
}
}

View file

@ -0,0 +1,15 @@
{
"site": {
"cloud_id": "fake-jira-cloud-001",
"name": "SurfSense E2E Atlassian",
"url": "https://surfsense-e2e.atlassian.net"
},
"issues": [
{
"id": "fake-jira-issue-canary-001",
"key": "E2E-101",
"summary": "E2E Canary Jira Issue",
"description": "This Jira issue proves live MCP tool calls work end-to-end. SURFSENSE_E2E_CANARY_TOKEN_JIRA_001"
}
]
}

View file

@ -0,0 +1,17 @@
{
"organization": {
"name": "SurfSense E2E Linear Org",
"url_key": "surfsense-e2e"
},
"issues": [
{
"id": "fake-linear-issue-canary-001",
"identifier": "E2E-101",
"title": "E2E Canary Linear Issue",
"description": "This Linear issue proves the live MCP tool fetched issue details. SURFSENSE_E2E_CANARY_TOKEN_LINEAR_001",
"state": "Todo",
"assignee": "E2E Owner",
"team": "SurfSense E2E"
}
]
}

View file

@ -0,0 +1,64 @@
{
"pages": [
{
"object": "page",
"id": "fake-notion-page-canary-001",
"created_time": "2026-05-07T00:00:00.000Z",
"last_edited_time": "2026-05-07T00:00:00.000Z",
"url": "https://notion.so/fake-notion-page-canary-001",
"properties": {
"Name": {
"id": "title",
"type": "title",
"title": [
{
"type": "text",
"plain_text": "E2E Canary Notion Page",
"text": {
"content": "E2E Canary Notion Page"
}
}
]
}
}
}
],
"blocks": {
"fake-notion-page-canary-001": [
{
"object": "block",
"id": "fake-notion-block-heading-001",
"type": "heading_2",
"has_children": false,
"heading_2": {
"rich_text": [
{
"type": "text",
"plain_text": "E2E Notion Canary",
"text": {
"content": "E2E Notion Canary"
}
}
]
}
},
{
"object": "block",
"id": "fake-notion-block-body-001",
"type": "paragraph",
"has_children": false,
"paragraph": {
"rich_text": [
{
"type": "text",
"plain_text": "This Notion page proves the indexed connector fetched Notion blocks through OAuth credentials. SURFSENSE_E2E_CANARY_TOKEN_NOTION_001",
"text": {
"content": "This Notion page proves the indexed connector fetched Notion blocks through OAuth credentials. SURFSENSE_E2E_CANARY_TOKEN_NOTION_001"
}
}
]
}
}
]
}
}

View file

@ -0,0 +1,82 @@
{
"root": [
{
"id": "fake-onedrive-folder-projects",
"name": "Projects",
"size": 0,
"folder": {
"childCount": 1
},
"parentReference": {
"id": "root",
"path": "/drive/root:"
},
"createdDateTime": "2025-02-01T08:00:00Z",
"lastModifiedDateTime": "2025-02-10T10:00:00Z",
"webUrl": "https://onedrive.example/fake/projects"
},
{
"id": "fake-onedrive-canary",
"name": "e2e-onedrive-canary.txt",
"size": 164,
"file": {
"mimeType": "text/plain",
"hashes": {
"quickXorHash": "fake-onedrive-canary-qxh"
}
},
"parentReference": {
"id": "root",
"path": "/drive/root:"
},
"createdDateTime": "2025-02-10T11:00:00Z",
"lastModifiedDateTime": "2025-02-10T11:00:00Z",
"webUrl": "https://onedrive.example/fake/e2e-onedrive-canary.txt"
},
{
"id": "fake-onedrive-pdf-canary",
"name": "e2e-onedrive-canary.pdf",
"size": 730,
"file": {
"mimeType": "application/pdf",
"hashes": {
"quickXorHash": "fake-onedrive-pdf-canary-qxh"
}
},
"parentReference": {
"id": "root",
"path": "/drive/root:"
},
"createdDateTime": "2025-02-10T11:05:00Z",
"lastModifiedDateTime": "2025-02-10T11:05:00Z",
"webUrl": "https://onedrive.example/fake/e2e-onedrive-canary.pdf"
}
],
"fake-onedrive-folder-projects": [
{
"id": "fake-onedrive-roadmap",
"name": "OneDrive Roadmap.md",
"size": 82,
"file": {
"mimeType": "text/markdown",
"hashes": {
"quickXorHash": "fake-onedrive-roadmap-qxh"
}
},
"parentReference": {
"id": "fake-onedrive-folder-projects",
"path": "/drive/root:/Projects"
},
"createdDateTime": "2025-02-11T08:00:00Z",
"lastModifiedDateTime": "2025-02-12T08:30:00Z",
"webUrl": "https://onedrive.example/fake/projects/onedrive-roadmap.md"
}
],
"_file_contents": {
"fake-onedrive-canary": "Canary token for OneDrive E2E tests: SURFSENSE_E2E_CANARY_TOKEN_ONEDRIVE_001\nThis file proves the OneDrive indexing pipeline ran end-to-end.",
"fake-onedrive-roadmap": "# OneDrive Roadmap\n\n- E2E OneDrive indexing\n- Deterministic Graph fake"
},
"_file_binary_paths": {
"fake-onedrive-pdf-canary": "binary/onedrive-canary.pdf"
}
}

View file

@ -0,0 +1,18 @@
{
"team": {
"id": "T_FAKE_SLACK_TEAM",
"name": "SurfSense E2E Slack Workspace"
},
"channel": {
"id": "C_FAKE_SLACK_CANARY",
"name": "slack-e2e-canary",
"purpose": "SurfSense E2E Slack canary channel"
},
"messages": [
{
"ts": "1715000000.000100",
"user": "U_FAKE_SLACK_USER",
"text": "This Slack message proves the live MCP tool fetched channel content. SURFSENSE_E2E_CANARY_TOKEN_SLACK_001"
}
]
}

View file

@ -0,0 +1,133 @@
"""Strict Jira MCP OAuth/tool fakes for Playwright E2E."""
from __future__ import annotations
import json
from pathlib import Path
from types import SimpleNamespace
from typing import Any
from tests.e2e.fakes import mcp_oauth_runtime, mcp_runtime
# Fixture JSON located in the "fixtures" directory next to this module.
_FIXTURE_PATH = Path(__file__).parent / "fixtures" / "jira_issues.json"
# Endpoint URLs that install() registers with the shared OAuth and MCP fakes.
_AUTHORIZATION_URL = "https://mcp.atlassian.com/v1/authorize"
_REGISTRATION_URL = "https://cf.mcp.atlassian.com/v1/register"
_TOKEN_URL = "https://cf.mcp.atlassian.com/v1/token"
_MCP_URL = "https://mcp.atlassian.com/v1/mcp"
# Fixed fake credentials, tokens and authorization code registered by install().
_CLIENT_ID = "fake-jira-mcp-client-id"
_CLIENT_SECRET = "fake-jira-mcp-client-secret"
_ACCESS_TOKEN = "fake-jira-mcp-access-token"
_REFRESH_TOKEN = "fake-jira-mcp-refresh-token"
_OAUTH_CODE = "fake-jira-oauth-code"
def _load_fixture() -> dict[str, Any]:
    """Read and parse the Jira fixture JSON file.

    The file is opened as UTF-8 explicitly so decoding does not depend on the
    platform's default text encoding.

    Returns:
        The parsed JSON document.
    """
    with _FIXTURE_PATH.open(encoding="utf-8") as f:
        return json.load(f)
# Parsed once at import time; _call_tool reads the site and first issue from it.
_FIXTURE = _load_fixture()
def _issue_text(issue: dict[str, Any]) -> str:
    """Render an issue as a three-line text block.

    The lines are ``<key> <summary>``, ``id: <id>`` and
    ``description: <description>``, with no trailing newline.
    """
    headline = f"{issue['key']} {issue['summary']}\n"
    id_line = f"id: {issue['id']}\n"
    description_line = f"description: {issue['description']}"
    return headline + id_line + description_line
async def _list_tools() -> SimpleNamespace:
    """Describe the two Jira MCP tools this fake implements.

    Returns:
        A namespace whose ``tools`` list holds one entry per tool, each with
        ``name``, ``description`` and ``inputSchema`` attributes.
    """
    accessible_resources = SimpleNamespace(
        name="getAccessibleAtlassianResources",
        description="Get Jira sites accessible to the authenticated Atlassian user.",
        inputSchema={
            "type": "object",
            "properties": {},
            "required": [],
        },
    )
    jql_properties = {
        "jql": {
            "type": "string",
            "description": "JQL query used to search Jira issues.",
        },
        "maxResults": {
            "type": "integer",
            "description": "Maximum number of matching issues to return.",
        },
    }
    jql_search = SimpleNamespace(
        name="searchJiraIssuesUsingJql",
        description="Search Jira issues using a Jira Query Language expression.",
        inputSchema={
            "type": "object",
            "properties": jql_properties,
            "required": ["jql"],
        },
    )
    return SimpleNamespace(tools=[accessible_resources, jql_search])
async def _call_tool(
    tool_name: str, arguments: dict[str, Any] | None = None
) -> SimpleNamespace:
    """Serve one fake Jira MCP tool call from the fixture data.

    Args:
        tool_name: Name of the MCP tool being invoked.
        arguments: Tool arguments; ``None`` is treated as an empty dict.

    Returns:
        A namespace whose ``content`` list holds one item with a ``text``
        attribute.

    Raises:
        ValueError: If the resources tool receives any arguments, or the JQL
            contains neither the fixture issue's summary nor its key
            (case-insensitive).
        NotImplementedError: If ``tool_name`` is not one of the two fake tools.
    """
    params = arguments or {}
    site = _FIXTURE["site"]
    issue = _FIXTURE["issues"][0]

    def reply(text: str) -> SimpleNamespace:
        return SimpleNamespace(content=[SimpleNamespace(text=text)])

    if tool_name == "getAccessibleAtlassianResources":
        if params:
            raise ValueError(
                f"Unexpected Jira getAccessibleAtlassianResources args: {params!r}"
            )
        return reply(
            f"{site['name']}\ncloud_id: {site['cloud_id']}\nurl: {site['url']}"
        )

    if tool_name == "searchJiraIssuesUsingJql":
        jql = str(params.get("jql", ""))
        haystack = jql.lower()
        # The query must mention the canary issue by summary or by key.
        mentions_issue = any(
            issue[field].lower() in haystack for field in ("summary", "key")
        )
        if not mentions_issue:
            raise ValueError(f"Unexpected Jira JQL query: {jql!r}")
        return reply(_issue_text(issue))

    raise NotImplementedError(f"Unexpected Jira MCP tool call: {tool_name!r}")
def install(active_patches: list[Any]) -> None:
    """Register the Jira fake with the shared MCP OAuth and MCP tool dispatchers.

    ``active_patches`` is accepted but not used; this module starts no
    patchers of its own.
    """
    del active_patches
    oauth_metadata = {
        "issuer": "https://mcp.atlassian.com",
        "authorization_endpoint": _AUTHORIZATION_URL,
        "token_endpoint": _TOKEN_URL,
        "registration_endpoint": _REGISTRATION_URL,
        "code_challenge_methods_supported": ["S256"],
        "grant_types_supported": ["authorization_code", "refresh_token"],
        "response_types_supported": ["code"],
    }
    # OAuth side: discovery, client registration and token exchange.
    mcp_oauth_runtime.register_service(
        mcp_url=_MCP_URL,
        discovery_metadata=oauth_metadata,
        client_id=_CLIENT_ID,
        client_secret=_CLIENT_SECRET,
        token_endpoint=_TOKEN_URL,
        registration_endpoint=_REGISTRATION_URL,
        oauth_code=_OAUTH_CODE,
        access_token=_ACCESS_TOKEN,
        refresh_token=_REFRESH_TOKEN,
        scope="read:jira-work read:me write:jira-work",
        redirect_uri_substring="/api/v1/auth/mcp/jira/connector/callback",
    )
    # Tool side: listing and calling tools with the issued access token.
    mcp_runtime.register(
        url=_MCP_URL,
        expected_bearer=_ACCESS_TOKEN,
        list_tools=_list_tools,
        call_tool=_call_tool,
    )

View file

@ -0,0 +1,305 @@
"""Strict Linear MCP OAuth/tool fakes for Playwright E2E."""
from __future__ import annotations
import json
from pathlib import Path
from types import SimpleNamespace
from typing import Any
from tests.e2e.fakes import mcp_oauth_runtime, mcp_runtime
# Fixture JSON located in the "fixtures" directory next to this module.
_FIXTURE_PATH = Path(__file__).parent / "fixtures" / "linear_issues.json"
# NOTE(review): _DISCOVERY_URL is not referenced by the code visible in this
# module - confirm whether anything else still uses it.
_DISCOVERY_URL = "https://mcp.linear.app/.well-known/oauth-authorization-server"
# Endpoint URLs checked by the fakes below and registered by install().
_AUTHORIZATION_URL = "https://mcp.linear.app/authorize"
_REGISTRATION_URL = "https://mcp.linear.app/register"
_TOKEN_URL = "https://mcp.linear.app/token"
_MCP_URL = "https://mcp.linear.app/mcp"
# Fixed fake credentials, tokens and authorization code used by the fakes.
_CLIENT_ID = "fake-linear-mcp-client-id"
_CLIENT_SECRET = "fake-linear-mcp-client-secret"
_ACCESS_TOKEN = "fake-linear-mcp-access-token"
_REFRESH_TOKEN = "fake-linear-mcp-refresh-token"
_OAUTH_CODE = "fake-linear-oauth-code"
def _load_fixture() -> dict[str, Any]:
    """Read and parse the Linear fixture JSON file.

    The file is opened as UTF-8 explicitly so decoding does not depend on the
    platform's default text encoding.

    Returns:
        The parsed JSON document.
    """
    with _FIXTURE_PATH.open(encoding="utf-8") as f:
        return json.load(f)
# Parsed once at import time; the fake session reads the first issue from it.
_FIXTURE = _load_fixture()
class _StrictFakeMixin:
    """Mixin that makes any lookup of an undefined attribute fail loudly.

    ``__getattr__`` only runs when normal attribute lookup fails, so the
    attributes a fake really defines keep working.
    """

    # Label used in the error message; subclasses override it.
    _component_name: str = "<unknown>"

    def __getattr__(self, name: str) -> Any:
        """Raise ``NotImplementedError`` naming the missing attribute."""
        message = (
            f"E2E Linear fake missing surface: {self._component_name}.{name!r}. "
            "Add it to surfsense_backend/tests/e2e/fakes/linear_module.py."
        )
        raise NotImplementedError(message)
async def _fake_discover_oauth_metadata(
    mcp_url: str,
    *,
    origin_override: str | None = None,
    timeout: float = 15.0,
) -> dict[str, Any]:
    """Return canned OAuth server metadata for the fake Linear MCP URL.

    ``origin_override`` and ``timeout`` are accepted but ignored.

    Raises:
        NotImplementedError: If ``mcp_url`` is not the fake Linear MCP URL.
    """
    del origin_override, timeout
    if mcp_url == _MCP_URL:
        return {
            "issuer": "https://mcp.linear.app",
            "authorization_endpoint": _AUTHORIZATION_URL,
            "token_endpoint": _TOKEN_URL,
            "registration_endpoint": _REGISTRATION_URL,
            "code_challenge_methods_supported": ["S256"],
            "grant_types_supported": ["authorization_code", "refresh_token"],
            "response_types_supported": ["code"],
        }
    raise NotImplementedError(f"Unexpected Linear MCP discovery url={mcp_url!r}")
async def _fake_register_client(
    registration_endpoint: str,
    redirect_uri: str,
    *,
    client_name: str = "SurfSense",
    timeout: float = 15.0,
) -> dict[str, Any]:
    """Fake dynamic client registration that issues the fixed Linear client.

    ``timeout`` is accepted but ignored.

    Raises:
        NotImplementedError: If the endpoint is not the fake Linear one.
        ValueError: If ``client_name`` is not ``"SurfSense"`` or
            ``redirect_uri`` lacks the Linear connector callback path.
    """
    del timeout
    callback_path = "/api/v1/auth/mcp/linear/connector/callback"
    if registration_endpoint != _REGISTRATION_URL:
        raise NotImplementedError(
            f"Unexpected Linear DCR endpoint={registration_endpoint!r}"
        )
    if client_name != "SurfSense":
        raise ValueError(f"Unexpected Linear DCR client_name={client_name!r}")
    if callback_path not in redirect_uri:
        raise ValueError(f"Unexpected Linear redirect_uri={redirect_uri!r}")
    registration = {
        "client_id": _CLIENT_ID,
        "client_secret": _CLIENT_SECRET,
        "client_id_issued_at": 1_776_621_600,
        "token_endpoint_auth_method": "client_secret_basic",
    }
    return registration
async def _fake_exchange_code_for_tokens(
    token_endpoint: str,
    code: str,
    redirect_uri: str,
    client_id: str,
    client_secret: str,
    code_verifier: str,
    *,
    timeout: float = 30.0,
) -> dict[str, Any]:
    """Fake authorization-code exchange that returns the fixed Linear tokens.

    Each input is checked against the values this fake expects, in the order
    the parameters are listed. ``timeout`` is accepted but ignored.

    Raises:
        NotImplementedError: If ``token_endpoint`` is not the fake Linear one.
        ValueError: If the code, redirect URI or client credentials do not
            match, or ``code_verifier`` is empty.
    """
    del timeout
    callback_path = "/api/v1/auth/mcp/linear/connector/callback"
    if token_endpoint != _TOKEN_URL:
        raise NotImplementedError(
            f"Unexpected Linear token_endpoint={token_endpoint!r}"
        )
    if code != _OAUTH_CODE:
        raise ValueError(f"Unexpected fake Linear OAuth code: {code!r}")
    if callback_path not in redirect_uri:
        raise ValueError(f"Unexpected Linear redirect_uri={redirect_uri!r}")
    if (client_id, client_secret) != (_CLIENT_ID, _CLIENT_SECRET):
        raise ValueError(
            "Unexpected Linear client credentials: "
            f"client_id={client_id!r} client_secret={client_secret!r}"
        )
    if not code_verifier:
        raise ValueError("Linear token exchange missing code_verifier.")
    token_response = {
        "access_token": _ACCESS_TOKEN,
        "refresh_token": _REFRESH_TOKEN,
        "expires_in": 3600,
        "scope": "read write",
        "token_type": "Bearer",
    }
    return token_response
async def _fake_refresh_access_token(
    token_endpoint: str,
    refresh_token: str,
    client_id: str,
    client_secret: str,
    *,
    timeout: float = 30.0,
) -> dict[str, Any]:
    """Fake refresh-token grant that returns the same fixed Linear tokens.

    ``timeout`` is accepted but ignored.

    Raises:
        NotImplementedError: If ``token_endpoint`` is not the fake Linear one.
        ValueError: If the refresh token or client credentials do not match.
    """
    del timeout
    if token_endpoint != _TOKEN_URL:
        raise NotImplementedError(
            f"Unexpected Linear token_endpoint={token_endpoint!r}"
        )
    if refresh_token != _REFRESH_TOKEN:
        raise ValueError(f"Unexpected fake Linear refresh token: {refresh_token!r}")
    if (client_id, client_secret) != (_CLIENT_ID, _CLIENT_SECRET):
        raise ValueError(
            "Unexpected Linear refresh client credentials: "
            f"client_id={client_id!r} client_secret={client_secret!r}"
        )
    token_response = {
        "access_token": _ACCESS_TOKEN,
        "refresh_token": _REFRESH_TOKEN,
        "expires_in": 3600,
        "scope": "read write",
        "token_type": "Bearer",
    }
    return token_response
class _FakeStreamableHttpClient(_StrictFakeMixin):
    """Async context manager standing in for the MCP streamable HTTP client.

    Construction validates the URL and bearer token; entering the context
    yields two placeholder objects and ``None``.
    """

    _component_name = "streamablehttp_client"

    def __init__(
        self, url: str, *, headers: dict[str, str] | None = None, **kwargs: Any
    ):
        """Validate ``url`` and the ``Authorization`` header, then store them.

        Extra keyword arguments are accepted and discarded.

        Raises:
            NotImplementedError: If ``url`` is not the fake Linear MCP URL.
            ValueError: If the header is not the expected bearer token.
        """
        del kwargs
        if url != _MCP_URL:
            raise NotImplementedError(f"Unexpected Linear MCP url={url!r}")
        supplied = headers or {}
        auth = supplied.get("Authorization")
        if auth != f"Bearer {_ACCESS_TOKEN}":
            raise ValueError(f"Unexpected Linear MCP Authorization header: {auth!r}")
        self.url = url
        self.headers = supplied

    async def __aenter__(self) -> tuple[object, object, None]:
        """Return two placeholder objects followed by ``None``."""
        first, second = object(), object()
        return first, second, None

    async def __aexit__(self, exc_type: Any, exc: Any, tb: Any) -> None:
        """Do nothing; exceptions from the ``with`` body are not suppressed."""
        del exc_type, exc, tb
class _FakeClientSession(_StrictFakeMixin):
    """Async context manager standing in for an MCP client session.

    It advertises two Linear tools and answers calls to them from the first
    issue in the fixture.
    """

    _component_name = "ClientSession"

    def __init__(self, read: object, write: object):
        """Store the two stream objects; they are never read from or written to here."""
        self.read = read
        self.write = write

    async def __aenter__(self) -> _FakeClientSession:
        """Return this session itself."""
        return self

    async def __aexit__(self, exc_type: Any, exc: Any, tb: Any) -> None:
        """Do nothing; exceptions from the ``with`` body are not suppressed."""
        del exc_type, exc, tb

    async def initialize(self) -> None:
        """No-op handshake."""
        return None

    async def list_tools(self) -> SimpleNamespace:
        """Describe the ``list_issues`` and ``get_issue`` tools."""
        list_issues = SimpleNamespace(
            name="list_issues",
            description="List Linear issues visible to the authenticated user.",
            inputSchema={
                "type": "object",
                "properties": {
                    "query": {
                        "type": "string",
                        "description": "Text to search for in Linear issues.",
                    },
                    "limit": {
                        "type": "integer",
                        "description": "Maximum number of issues to return.",
                    },
                },
                "required": [],
            },
        )
        get_issue = SimpleNamespace(
            name="get_issue",
            description="Get a Linear issue by id or identifier.",
            inputSchema={
                "type": "object",
                "properties": {
                    "id": {
                        "type": "string",
                        "description": "Issue id or identifier.",
                    }
                },
                "required": ["id"],
            },
        )
        return SimpleNamespace(tools=[list_issues, get_issue])

    async def call_tool(
        self, tool_name: str, *, arguments: dict[str, Any] | None = None
    ) -> SimpleNamespace:
        """Answer a tool call with the fixture issue rendered as text.

        Args:
            tool_name: ``"list_issues"`` or ``"get_issue"``.
            arguments: Tool arguments; ``None`` is treated as an empty dict.

        Raises:
            ValueError: If a non-empty ``query`` does not contain the issue
                title (case-insensitive), or ``id`` is neither the issue's id
                nor its identifier.
            NotImplementedError: For any other tool name.
        """
        params = arguments or {}
        issue = _FIXTURE["issues"][0]

        def reply() -> SimpleNamespace:
            # Both tools return the same three-line rendering of the issue.
            text = (
                f"{issue['identifier']} {issue['title']}\n"
                f"id: {issue['id']}\n"
                f"description: {issue['description']}"
            )
            return SimpleNamespace(content=[SimpleNamespace(text=text)])

        if tool_name == "list_issues":
            query = str(params.get("query", ""))
            if query and issue["title"].lower() not in query.lower():
                raise ValueError(f"Unexpected Linear issue query: {query!r}")
            return reply()

        if tool_name == "get_issue":
            issue_id = params.get("id")
            if issue_id not in {issue["id"], issue["identifier"]}:
                raise ValueError(f"Unexpected Linear issue id: {issue_id!r}")
            return reply()

        raise NotImplementedError(f"Unexpected Linear MCP tool call: {tool_name!r}")
def _fake_streamablehttp_client(
    url: str, *, headers: dict[str, str] | None = None, **kwargs: Any
) -> _FakeStreamableHttpClient:
    """Function-style factory that builds a ``_FakeStreamableHttpClient``.

    All arguments are forwarded unchanged to the class constructor.
    """
    client = _FakeStreamableHttpClient(url, headers=headers, **kwargs)
    return client
async def _list_tools() -> SimpleNamespace:
    """List the fake Linear tools via a throwaway ``_FakeClientSession``."""
    session = _FakeClientSession(object(), object())
    return await session.list_tools()
async def _call_tool(tool_name: str, arguments: dict[str, Any]) -> SimpleNamespace:
    """Run one fake Linear tool call via a throwaway ``_FakeClientSession``."""
    session = _FakeClientSession(object(), object())
    result = await session.call_tool(tool_name, arguments=arguments)
    return result
def install(active_patches: list[Any]) -> None:
    """Register the Linear fake with the shared MCP OAuth and MCP tool dispatchers.

    ``active_patches`` is accepted but not used; this module starts no
    patchers of its own.
    """
    del active_patches
    oauth_metadata = {
        "issuer": "https://mcp.linear.app",
        "authorization_endpoint": _AUTHORIZATION_URL,
        "token_endpoint": _TOKEN_URL,
        "registration_endpoint": _REGISTRATION_URL,
        "code_challenge_methods_supported": ["S256"],
        "grant_types_supported": ["authorization_code", "refresh_token"],
        "response_types_supported": ["code"],
    }
    # OAuth side: discovery, client registration and token exchange.
    mcp_oauth_runtime.register_service(
        mcp_url=_MCP_URL,
        discovery_metadata=oauth_metadata,
        client_id=_CLIENT_ID,
        client_secret=_CLIENT_SECRET,
        token_endpoint=_TOKEN_URL,
        registration_endpoint=_REGISTRATION_URL,
        oauth_code=_OAUTH_CODE,
        access_token=_ACCESS_TOKEN,
        refresh_token=_REFRESH_TOKEN,
        scope="read write",
        redirect_uri_substring="/api/v1/auth/mcp/linear/connector/callback",
    )
    # Tool side: listing and calling tools with the issued access token.
    mcp_runtime.register(
        url=_MCP_URL,
        expected_bearer=_ACCESS_TOKEN,
        list_tools=_list_tools,
        call_tool=_call_tool,
    )

View file

@ -0,0 +1,48 @@
"""Deterministic LLM fake for the E2E indexing pipeline.
The production indexing pipeline summarizes documents with:
summary_chain = SUMMARY_PROMPT_TEMPLATE | llm
summary_result = await summary_chain.ainvoke({"document": ...})
summary_content = summary_result.content
The `llm` parameter is supplied per-document by
`app.services.llm_service.get_user_long_context_llm`. We patch THAT
function to return a langchain-native FakeListChatModel so the rest of
the chain works unchanged. No real LLM provider package is touched.
Run-backend / run-celery use unittest.mock.patch.start() to install
this at every binding site (the source module + every consumer that
did `from app.services.llm_service import get_user_long_context_llm`
at module load time).
"""
from __future__ import annotations
import logging
from typing import Any
from langchain_core.language_models.fake_chat_models import FakeListChatModel
logger = logging.getLogger(__name__)
def _make_fake_llm() -> FakeListChatModel:
    """Create a new ``FakeListChatModel`` loaded with one fixed summary string.

    The string starts with the ``E2E_FAKE_SUMMARY`` marker, which specs may
    assert on if they want to check the summary itself.
    """
    canned_summary = (
        "E2E_FAKE_SUMMARY: Indexed by Playwright E2E run with deterministic LLM stub."
    )
    return FakeListChatModel(responses=[canned_summary])
async def fake_get_user_long_context_llm(*args: Any, **kwargs: Any) -> Any:
    """Stand-in for ``app.services.llm_service.get_user_long_context_llm``.

    All arguments are ignored; every call logs a message and returns a newly
    built fake chat model.
    """
    logger.info("[fake-llm] returning FakeListChatModel for E2E indexing")
    fake_llm = _make_fake_llm()
    return fake_llm

View file

@ -0,0 +1,217 @@
"""Shared strict MCP OAuth fake dispatcher for E2E tests."""
from __future__ import annotations
from dataclasses import dataclass
from typing import Any
from unittest.mock import patch
@dataclass(frozen=True)
class _OAuthHandler:
    """Immutable record of the canned OAuth behavior registered for one MCP service."""

    # MCP server URL; lookup key for the fake discovery call.
    mcp_url: str
    # Metadata returned (as a shallow copy) by the fake discovery call.
    discovery_metadata: dict[str, Any]
    # Credentials issued by the fake DCR call and required at token exchange.
    client_id: str
    client_secret: str
    # Endpoint URLs used as lookup keys for token and registration requests.
    token_endpoint: str
    registration_endpoint: str
    # The only authorization code the fake token exchange accepts.
    oauth_code: str
    # Tokens and scope returned by a successful token exchange.
    access_token: str
    refresh_token: str
    scope: str
    # Substring that must appear in the redirect_uri of DCR and token requests.
    redirect_uri_substring: str
    # Value the discovery call's origin_override must equal.
    expected_origin_override: str | None = None
# Lookup tables filled by register_service(). The same handler is stored in all
# three, keyed by the URL each fake OAuth call receives as its first argument.
_SERVICES_BY_MCP_URL: dict[str, _OAuthHandler] = {}
_SERVICES_BY_REGISTRATION_URL: dict[str, _OAuthHandler] = {}
_SERVICES_BY_TOKEN_URL: dict[str, _OAuthHandler] = {}
def register_service(
    *,
    mcp_url: str,
    discovery_metadata: dict[str, Any],
    client_id: str,
    client_secret: str,
    token_endpoint: str,
    registration_endpoint: str,
    oauth_code: str,
    access_token: str,
    refresh_token: str,
    scope: str,
    redirect_uri_substring: str,
    expected_origin_override: str | None = None,
) -> None:
    """Register the canned OAuth behavior for one MCP service.

    The arguments are bundled into a single handler, which is then stored
    under the MCP URL, the registration endpoint and the token endpoint, in
    that order.

    Raises:
        ValueError: If one of those keys is already registered to a different
            handler.
    """
    handler_fields = {
        "mcp_url": mcp_url,
        "discovery_metadata": discovery_metadata,
        "client_id": client_id,
        "client_secret": client_secret,
        "token_endpoint": token_endpoint,
        "registration_endpoint": registration_endpoint,
        "oauth_code": oauth_code,
        "access_token": access_token,
        "refresh_token": refresh_token,
        "scope": scope,
        "redirect_uri_substring": redirect_uri_substring,
        "expected_origin_override": expected_origin_override,
    }
    handler = _OAuthHandler(**handler_fields)
    registrations = (
        (_SERVICES_BY_MCP_URL, mcp_url, "MCP URL"),
        (_SERVICES_BY_REGISTRATION_URL, registration_endpoint, "registration endpoint"),
        (_SERVICES_BY_TOKEN_URL, token_endpoint, "token endpoint"),
    )
    for registry, key, label in registrations:
        _register_unique(registry, key, handler, label)
def _register_unique(
    registry: dict[str, _OAuthHandler],
    key: str,
    handler: _OAuthHandler,
    label: str,
) -> None:
    """Store ``handler`` under ``key`` unless a different handler already owns it.

    Registering an equal handler again is allowed and simply rewrites the
    entry. ``label`` only appears in the error message.

    Raises:
        ValueError: If ``key`` already maps to a handler that is not equal to
            ``handler``.
    """
    previous = registry.get(key)
    conflict = previous is not None and previous != handler
    if conflict:
        raise ValueError(f"MCP OAuth fake {label} already registered: {key!r}.")
    registry[key] = handler
async def _fake_discover_oauth_metadata(
    mcp_url: str,
    *,
    origin_override: str | None = None,
    timeout: float = 15.0,
) -> dict[str, Any]:
    """Return the discovery metadata registered for ``mcp_url``.

    ``timeout`` is accepted but ignored.

    Returns:
        A shallow copy of the registered metadata dict.

    Raises:
        NotImplementedError: If no service is registered for ``mcp_url``.
        ValueError: If ``origin_override`` differs from the value the
            registered service expects.
    """
    del timeout
    service = _SERVICES_BY_MCP_URL.get(mcp_url)
    if service is None:
        raise NotImplementedError(f"Unexpected MCP OAuth discovery url={mcp_url!r}")
    if origin_override != service.expected_origin_override:
        raise ValueError(
            f"Unexpected MCP OAuth origin_override for {mcp_url!r}: {origin_override!r}"
        )
    metadata_copy = dict(service.discovery_metadata)
    return metadata_copy
async def _fake_register_client(
    registration_endpoint: str,
    redirect_uri: str,
    *,
    client_name: str = "SurfSense",
    timeout: float = 15.0,
) -> dict[str, Any]:
    """Fake dynamic client registration for a registered service.

    Returns:
        A registration payload carrying the handler's ``client_id`` and
        ``client_secret`` plus a fixed issue timestamp and auth method.

    Raises:
        NotImplementedError: if ``registration_endpoint`` is unknown.
        ValueError: if ``client_name`` is not ``"SurfSense"`` or
            ``redirect_uri`` lacks the handler's expected substring.
    """
    del timeout  # accepted but unused by the fake
    service = _SERVICES_BY_REGISTRATION_URL.get(registration_endpoint)
    if service is None:
        raise NotImplementedError(
            f"Unexpected MCP OAuth DCR endpoint={registration_endpoint!r}"
        )
    if client_name != "SurfSense":
        raise ValueError(f"Unexpected MCP OAuth DCR client_name={client_name!r}")
    required_fragment = service.redirect_uri_substring
    if required_fragment not in redirect_uri:
        raise ValueError(
            f"Unexpected MCP OAuth DCR redirect_uri={redirect_uri!r}; "
            f"expected {required_fragment!r}"
        )
    registration: dict[str, Any] = {
        "client_id": service.client_id,
        "client_secret": service.client_secret,
        "client_id_issued_at": 1_776_621_600,
        "token_endpoint_auth_method": "client_secret_basic",
    }
    return registration
async def _fake_exchange_code_for_tokens(
    token_endpoint: str,
    code: str,
    redirect_uri: str,
    client_id: str,
    client_secret: str,
    code_verifier: str,
    *,
    timeout: float = 30.0,
) -> dict[str, Any]:
    """Fake the authorization-code token exchange for a registered service.

    Every input is checked against the handler registered for
    ``token_endpoint`` before the handler's tokens are returned.

    Raises:
        NotImplementedError: if ``token_endpoint`` is unknown.
        ValueError: on a wrong code, redirect URI, client credentials, or
            an empty ``code_verifier``.
    """
    del timeout  # accepted but unused by the fake
    service = _SERVICES_BY_TOKEN_URL.get(token_endpoint)
    if service is None:
        raise NotImplementedError(
            f"Unexpected MCP OAuth token_endpoint={token_endpoint!r}"
        )
    if code != service.oauth_code:
        raise ValueError(f"Unexpected fake MCP OAuth code: {code!r}")
    required_fragment = service.redirect_uri_substring
    if required_fragment not in redirect_uri:
        raise ValueError(
            f"Unexpected MCP OAuth token redirect_uri={redirect_uri!r}; "
            f"expected {required_fragment!r}"
        )
    if (client_id, client_secret) != (service.client_id, service.client_secret):
        raise ValueError(
            "Unexpected MCP OAuth client credentials: "
            f"client_id={client_id!r} client_secret={client_secret!r}"
        )
    if not code_verifier:
        raise ValueError("MCP OAuth token exchange missing code_verifier.")
    token_payload: dict[str, Any] = {
        "access_token": service.access_token,
        "refresh_token": service.refresh_token,
        "expires_in": 3600,
        "scope": service.scope,
        "token_type": "Bearer",
    }
    return token_payload
async def _fake_refresh_access_token(
    token_endpoint: str,
    refresh_token: str,
    client_id: str,
    client_secret: str,
    *,
    timeout: float = 30.0,
) -> dict[str, Any]:
    """Fake the refresh-token grant for a registered service.

    Returns the same token payload shape as the code exchange, using the
    handler's registered access and refresh tokens.

    Raises:
        NotImplementedError: if ``token_endpoint`` is unknown.
        ValueError: on a wrong refresh token or client credentials.
    """
    del timeout  # accepted but unused by the fake
    service = _SERVICES_BY_TOKEN_URL.get(token_endpoint)
    if service is None:
        raise NotImplementedError(
            f"Unexpected MCP OAuth refresh token_endpoint={token_endpoint!r}"
        )
    if refresh_token != service.refresh_token:
        raise ValueError(f"Unexpected fake MCP OAuth refresh token: {refresh_token!r}")
    if (client_id, client_secret) != (service.client_id, service.client_secret):
        raise ValueError(
            "Unexpected MCP OAuth refresh client credentials: "
            f"client_id={client_id!r} client_secret={client_secret!r}"
        )
    token_payload: dict[str, Any] = {
        "access_token": service.access_token,
        "refresh_token": service.refresh_token,
        "expires_in": 3600,
        "scope": service.scope,
        "token_type": "Bearer",
    }
    return token_payload
def install(active_patches: list[Any]) -> None:
    """Patch generic MCP OAuth helper boundaries exactly once.

    Each ``app.services.mcp_oauth.discovery`` helper is replaced by its
    fake. Every patcher is started and then appended to
    ``active_patches``; this function never stops them itself.
    """
    discovery = "app.services.mcp_oauth.discovery"
    fakes = {
        f"{discovery}.discover_oauth_metadata": _fake_discover_oauth_metadata,
        f"{discovery}.register_client": _fake_register_client,
        f"{discovery}.exchange_code_for_tokens": _fake_exchange_code_for_tokens,
        f"{discovery}.refresh_access_token": _fake_refresh_access_token,
    }
    for dotted_path, fake in fakes.items():
        patcher = patch(dotted_path, fake)
        patcher.start()
        active_patches.append(patcher)

View file

@ -0,0 +1,148 @@
"""Shared strict MCP streamable-HTTP runtime fake for E2E tests."""
from __future__ import annotations
import inspect
from collections.abc import Awaitable, Callable
from dataclasses import dataclass
from types import SimpleNamespace
from typing import Any
from unittest.mock import patch
# Callback shapes accepted by ``register``. A callback may return its result
# directly or return an awaitable resolving to it; the fake session awaits
# the result only when it is awaitable.
ListToolsFn = Callable[[], Any | Awaitable[Any]]
CallToolFn = Callable[[str, dict[str, Any]], Any | Awaitable[Any]]
@dataclass(frozen=True)
class _RuntimeHandler:
expected_bearer: str
list_tools: ListToolsFn
call_tool: CallToolFn
# Registered fake MCP servers, keyed by MCP URL (filled by ``register``).
_HANDLERS: dict[str, _RuntimeHandler] = {}
class _StrictFakeMixin:
_component_name: str = "<unknown>"
def __getattr__(self, name: str) -> Any:
raise NotImplementedError(
f"E2E MCP runtime fake missing surface: {self._component_name}.{name!r}. "
"Add it to surfsense_backend/tests/e2e/fakes/mcp_runtime.py."
)
class _FakeEndpoint(_StrictFakeMixin):
    """Stand-in for one side (read or write) of a streamable-HTTP connection."""

    _component_name = "streamablehttp_endpoint"

    def __init__(self, url: str, handler: _RuntimeHandler):
        # Keep the URL and the handler so the session can find its callbacks.
        self.url, self.handler = url, handler
class _FakeStreamableHttpClient(_StrictFakeMixin):
    """Async context manager standing in for ``streamablehttp_client``.

    Construction fails unless the URL is registered and the request
    carries the exact ``Authorization: Bearer <expected_bearer>`` header.
    """

    _component_name = "streamablehttp_client"

    def __init__(
        self, url: str, *, headers: dict[str, str] | None = None, **kwargs: Any
    ):
        del kwargs  # extra client options are accepted and ignored
        registered = _HANDLERS.get(url)
        if registered is None:
            raise NotImplementedError(f"Unexpected MCP streamable-http url={url!r}")
        sent_headers = headers or {}
        presented = sent_headers.get("Authorization")
        if presented != f"Bearer {registered.expected_bearer}":
            raise ValueError(
                f"Unexpected MCP Authorization header for {url!r}: {presented!r}"
            )
        self.url = url
        self.headers = sent_headers
        self.handler = registered

    async def __aenter__(self) -> tuple[_FakeEndpoint, _FakeEndpoint, None]:
        # Two separate endpoint objects sharing one handler, plus None as
        # the third tuple element.
        read_side = _FakeEndpoint(self.url, self.handler)
        write_side = _FakeEndpoint(self.url, self.handler)
        return read_side, write_side, None

    async def __aexit__(self, exc_type: Any, exc: Any, tb: Any) -> None:
        del exc_type, exc, tb
class _FakeClientSession(_StrictFakeMixin):
    """Fake MCP ``ClientSession`` delegating to the registered callbacks."""

    _component_name = "ClientSession"

    def __init__(self, read: _FakeEndpoint, write: _FakeEndpoint):
        # Both endpoints must share one handler object.
        shared = read.handler
        if shared is not write.handler:
            raise ValueError("MCP fake received mismatched read/write endpoints.")
        self.read = read
        self.write = write
        self.handler = shared

    async def __aenter__(self) -> _FakeClientSession:
        return self

    async def __aexit__(self, exc_type: Any, exc: Any, tb: Any) -> None:
        del exc_type, exc, tb

    async def initialize(self) -> None:
        """No-op handshake."""
        return None

    async def list_tools(self) -> SimpleNamespace:
        """Return the handler's ``list_tools`` result, awaiting it if needed."""
        outcome = self.handler.list_tools()
        return await outcome if inspect.isawaitable(outcome) else outcome

    async def call_tool(
        self, tool_name: str, *, arguments: dict[str, Any] | None = None
    ) -> SimpleNamespace:
        """Invoke the handler's ``call_tool``; missing arguments become ``{}``."""
        outcome = self.handler.call_tool(tool_name, arguments or {})
        return await outcome if inspect.isawaitable(outcome) else outcome
def _fake_streamablehttp_client(
    url: str, *, headers: dict[str, str] | None = None, **kwargs: Any
) -> _FakeStreamableHttpClient:
    """Function-shaped replacement that builds the fake client context manager."""
    client = _FakeStreamableHttpClient(url, headers=headers, **kwargs)
    return client
def register(
    *,
    url: str,
    expected_bearer: str,
    list_tools: ListToolsFn,
    call_tool: CallToolFn,
) -> None:
    """Register a fake streamable-HTTP MCP server by canonical MCP URL.

    Registering an equal handler for the same URL again is allowed.

    Raises:
        ValueError: if ``url`` already has a handler that differs from
            the one described by these arguments.
    """
    candidate = _RuntimeHandler(
        expected_bearer=expected_bearer,
        list_tools=list_tools,
        call_tool=call_tool,
    )
    current = _HANDLERS.get(url)
    if current is None or current == candidate:
        _HANDLERS[url] = candidate
        return
    raise ValueError(f"MCP runtime fake handler already registered for {url!r}.")
def install(active_patches: list[Any]) -> None:
    """Patch production MCP streamable-HTTP boundaries exactly once.

    Each patcher is started and then appended to ``active_patches``;
    this function never stops them itself.
    """
    module = "app.agents.new_chat.tools.mcp_tool"
    fakes = {
        f"{module}.streamablehttp_client": _fake_streamablehttp_client,
        f"{module}.ClientSession": _FakeClientSession,
    }
    for dotted_path, fake in fakes.items():
        patcher = patch(dotted_path, fake)
        patcher.start()
        active_patches.append(patcher)

View file

@ -0,0 +1,444 @@
"""Strict native Google SDK fakes for Playwright E2E.
This module patches the production Google OAuth and Drive SDK bindings used by
the native Google connector happy paths. It deliberately does not replace the
whole Google package; unmodelled service methods fail loudly.
"""
from __future__ import annotations
import base64
import json
from datetime import UTC, datetime, timedelta
from pathlib import Path
from typing import Any
from unittest.mock import patch
from urllib.parse import parse_qs, urlencode, urlparse, urlunparse
from google.oauth2.credentials import Credentials
from .binary_loader import _resolve_file_bytes
# JSON fixtures live in the ``fixtures`` directory next to this module.
_FIXTURES_DIR = Path(__file__).parent / "fixtures"
_DRIVE_FIXTURE_PATH = _FIXTURES_DIR / "drive_files.json"
_GMAIL_FIXTURE_PATH = _FIXTURES_DIR / "gmail_messages.json"
_CALENDAR_FIXTURE_PATH = _FIXTURES_DIR / "calendar_events.json"
def _load_drive_fixture() -> dict[str, Any]:
    """Read and parse the Drive fixture JSON file.

    The file is opened explicitly as UTF-8 so parsing does not depend on
    the platform's default text encoding.
    """
    with _DRIVE_FIXTURE_PATH.open(encoding="utf-8") as f:
        return json.load(f)


# Parsed once at import time; the Drive fakes read from this mapping.
_DRIVE_FIXTURE = _load_drive_fixture()
def _load_gmail_fixture() -> dict[str, Any]:
    """Read and parse the Gmail fixture JSON file.

    The file is opened explicitly as UTF-8 so parsing does not depend on
    the platform's default text encoding.
    """
    with _GMAIL_FIXTURE_PATH.open(encoding="utf-8") as f:
        return json.load(f)


# Parsed once at import time; the Gmail fakes read from this mapping.
_GMAIL_FIXTURE = _load_gmail_fixture()
def _load_calendar_fixture() -> dict[str, Any]:
    """Read and parse the Calendar fixture JSON file.

    The file is opened explicitly as UTF-8 so parsing does not depend on
    the platform's default text encoding.
    """
    with _CALENDAR_FIXTURE_PATH.open(encoding="utf-8") as f:
        return json.load(f)


# Parsed once at import time; the Calendar fakes read from this mapping.
_CALENDAR_FIXTURE = _load_calendar_fixture()
class _StrictFakeMixin:
_component_name: str = "<unknown>"
def __getattr__(self, name: str) -> Any:
raise NotImplementedError(
f"E2E native Google fake missing surface: "
f"{self._component_name}.{name!r}. Add it to "
f"surfsense_backend/tests/e2e/fakes/native_google.py."
)
class _FakeFlow(_StrictFakeMixin):
    """Fake OAuth ``Flow`` that short-circuits the Google consent screen."""

    _component_name = "Flow"

    def __init__(
        self, *, redirect_uri: str | None = None, scopes: list[str] | None = None
    ):
        self.redirect_uri = redirect_uri
        self.scopes = scopes or []
        self.code_verifier: str | None = None
        self.credentials = _fake_credentials(scopes=self.scopes)

    @classmethod
    def from_client_config(
        cls,
        client_config: dict[str, Any],
        scopes: list[str],
        redirect_uri: str | None = None,
        **_: Any,
    ) -> _FakeFlow:
        """Build a fake flow; the client configuration itself is ignored."""
        del client_config
        return cls(redirect_uri=redirect_uri, scopes=scopes)

    def authorization_url(self, *, state: str, **_: Any) -> tuple[str, str]:
        """Return ``(url, state)`` where ``url`` is the redirect URI itself.

        The fixed fake code and the caller's ``state`` are set as query
        parameters on the redirect URI, replacing any existing values.

        Raises:
            ValueError: if the flow has no redirect URI.
        """
        if not self.redirect_uri:
            raise ValueError("Fake Google Flow requires redirect_uri.")
        parts = urlparse(self.redirect_uri)
        params = {
            **parse_qs(parts.query),
            "code": ["fake-native-drive-oauth-code"],
            "state": [state],
        }
        callback_url = urlunparse(
            parts._replace(query=urlencode(params, doseq=True))
        )
        return callback_url, state

    def fetch_token(self, *, code: str, **_: Any) -> None:
        """Accept only the fixed fake code, then refresh ``credentials``."""
        if code != "fake-native-drive-oauth-code":
            raise ValueError(f"Unexpected fake Google OAuth code: {code!r}")
        self.credentials = _fake_credentials(scopes=self.scopes)
def _fake_credentials(*, scopes: list[str] | None = None) -> Credentials:
    """Build fixed fake Google ``Credentials`` expiring one hour from now.

    Falls back to the Drive scope when ``scopes`` is empty or ``None``.
    """
    # Current UTC time with the tzinfo stripped, i.e. a naive datetime.
    naive_utc_now = datetime.now(UTC).replace(tzinfo=None)
    granted = scopes or ["https://www.googleapis.com/auth/drive"]
    return Credentials(
        token="fake-native-drive-access-token",
        refresh_token="fake-native-drive-refresh-token",
        token_uri="https://oauth2.googleapis.com/token",
        client_id="fake-native-drive-client-id",
        client_secret="fake-native-drive-client-secret",
        scopes=granted,
        expiry=naive_utc_now + timedelta(hours=1),
    )
class _FakeRequest(_StrictFakeMixin):
    """Fake request object whose ``execute`` returns a canned payload."""

    _component_name = "request"

    def __init__(self, payload: Any):
        self.payload, self.http = payload, None

    def execute(self, **_: Any) -> Any:
        """Return the payload given at construction; keyword options are ignored."""
        return self.payload
class _FakeMediaRequest(_StrictFakeMixin):
    """Fake media request carrying the bytes a download should produce."""

    _component_name = "media_request"

    def __init__(self, content: bytes):
        self.content, self.http = content, None
class _FakeMediaIoBaseDownload(_StrictFakeMixin):
    """Fake downloader that writes the whole payload in a single chunk."""

    _component_name = "MediaIoBaseDownload"

    def __init__(self, fd, request: _FakeMediaRequest, chunksize: int | None = None):
        del chunksize  # the fake never splits content
        self.fd = fd
        self.request = request
        self._done = False

    def next_chunk(self) -> tuple[None, bool]:
        """Write the content on the first call; always return ``(None, True)``."""
        if self._done:
            return None, True
        self.fd.write(self.request.content)
        self._done = True
        return None, True
class _FakeDriveFiles(_StrictFakeMixin):
    """Fake of the Drive ``files()`` resource, backed by the Drive fixture."""

    _component_name = "drive.files"

    def list(self, **kwargs: Any) -> _FakeRequest:
        """List fixture files for the folder named in the ``q`` query.

        The folder id is the quoted value directly bound to ``in parents``
        (``'<id>' in parents``), not merely the first quoted token in the
        query, so a preceding clause such as ``mimeType != '...'`` cannot
        be mistaken for the parent. Queries without such a clause, and a
        missing or ``None`` query, list the ``"root"`` folder.
        """
        import re  # local import: keeps this block self-contained

        q = kwargs.get("q") or ""
        parent_clause = re.search(r"'([^']*)'\s+in\s+parents", q)
        folder_id = parent_clause.group(1) if parent_clause else "root"
        files = _filter_drive_files_for_query(q, _DRIVE_FIXTURE.get(folder_id, []))
        return _FakeRequest({"files": files, "nextPageToken": None})

    def get(self, **kwargs: Any) -> _FakeRequest:
        """Return a request yielding the fixture metadata for ``fileId``."""
        file_id = kwargs.get("fileId")
        metadata = _drive_get_metadata(file_id)
        return _FakeRequest(metadata)

    def get_media(self, **kwargs: Any) -> _FakeMediaRequest:
        """Return a media request with the bytes resolved for ``fileId``.

        Raises:
            NotImplementedError: if no content can be resolved.
        """
        file_id = kwargs.get("fileId")
        content = _resolve_file_bytes(_DRIVE_FIXTURE, file_id, _FIXTURES_DIR)
        if content is None:
            raise NotImplementedError(
                f"E2E native Google fake has no content for fileId={file_id!r}."
            )
        return _FakeMediaRequest(content)

    def export(self, **kwargs: Any) -> _FakeRequest:
        """Return a request yielding the UTF-8 encoded ``_file_contents`` entry.

        Raises:
            NotImplementedError: if the fixture has no entry for ``fileId``.
        """
        file_id = kwargs.get("fileId")
        content = _DRIVE_FIXTURE.get("_file_contents", {}).get(file_id)
        if content is None:
            raise NotImplementedError(
                f"E2E native Google fake has no export content for fileId={file_id!r}."
            )
        return _FakeRequest(content.encode("utf-8"))
class _FakeDriveChanges(_StrictFakeMixin):
    """Fake of the Drive ``changes()`` resource; it never reports changes."""

    _component_name = "drive.changes"

    def getStartPageToken(self, **_: Any) -> _FakeRequest:  # noqa: N802
        """Return a request yielding the fixed start page token."""
        token_payload = {"startPageToken": "fake-native-start-page-token-1"}
        return _FakeRequest(token_payload)

    def list(self, **_: Any) -> _FakeRequest:
        """Return a request yielding an empty change list and the same token."""
        empty_page = {
            "changes": [],
            "newStartPageToken": "fake-native-start-page-token-1",
        }
        return _FakeRequest(empty_page)
class _FakeDriveService(_StrictFakeMixin):
    """Fake Drive v3 service exposing only ``files`` and ``changes``."""

    _component_name = "drive_service"

    def files(self) -> _FakeDriveFiles:
        """Return a fresh fake ``files`` resource."""
        resource = _FakeDriveFiles()
        return resource

    def changes(self) -> _FakeDriveChanges:
        """Return a fresh fake ``changes`` resource."""
        resource = _FakeDriveChanges()
        return resource
class _FakeGmailUsers(_StrictFakeMixin):
    """Fake of the Gmail ``users()`` resource."""

    _component_name = "gmail.users"

    def getProfile(self, **kwargs: Any) -> _FakeRequest:  # noqa: N802
        """Return the fixed profile; only ``userId="me"`` is modelled."""
        requested_user = kwargs.get("userId")
        if requested_user == "me":
            return _FakeRequest(
                {"emailAddress": "native-drive-e2e@surfsense.example"}
            )
        raise NotImplementedError(
            f"Unexpected fake Gmail profile userId={requested_user!r}"
        )

    def messages(self) -> _FakeGmailMessages:
        """Return a fresh fake ``messages`` resource."""
        return _FakeGmailMessages()
class _FakeGmailMessages(_StrictFakeMixin):
    """Fake of the Gmail ``users().messages()`` resource, backed by the fixture."""

    _component_name = "gmail.messages"

    def list(self, **kwargs: Any) -> _FakeRequest:
        """Return id/threadId stubs for up to ``maxResults`` fixture messages.

        A missing or falsy ``maxResults`` is treated as 10. Only
        ``userId="me"`` is modelled.
        """
        requested_user = kwargs.get("userId")
        if requested_user != "me":
            raise NotImplementedError(
                f"Unexpected fake Gmail list userId={requested_user!r}"
            )
        limit = int(kwargs.get("maxResults", 10) or 10)
        page = list(_GMAIL_FIXTURE.get("messages", []))[:limit]
        stubs = [
            {"id": entry["id"], "threadId": entry.get("threadId")} for entry in page
        ]
        return _FakeRequest({"messages": stubs, "resultSizeEstimate": len(page)})

    def get(self, **kwargs: Any) -> _FakeRequest:
        """Return the fixture detail for ``id`` converted to a message payload.

        Raises:
            NotImplementedError: for a ``userId`` other than ``"me"`` or an
                id with no entry under the fixture's ``details``.
        """
        requested_user = kwargs.get("userId")
        if requested_user != "me":
            raise NotImplementedError(
                f"Unexpected fake Gmail get userId={requested_user!r}"
            )
        message_id = kwargs.get("id")
        known_details = _GMAIL_FIXTURE.get("details", {})
        detail = known_details.get(message_id)
        if detail is None:
            raise NotImplementedError(
                f"E2E native Gmail fake has no message detail for id={message_id!r}."
            )
        return _FakeRequest(_gmail_detail_to_native_message(detail))
class _FakeGmailService(_StrictFakeMixin):
    """Fake Gmail v1 service exposing only ``users``."""

    _component_name = "gmail_service"

    def users(self) -> _FakeGmailUsers:
        """Return a fresh fake ``users`` resource."""
        resource = _FakeGmailUsers()
        return resource
class _FakeCalendarEvents(_StrictFakeMixin):
    """Fake of the Calendar ``events()`` resource, backed by the fixture."""

    _component_name = "calendar.events"

    def list(self, **kwargs: Any) -> _FakeRequest:
        """Return fixture events within ``timeMin``/``timeMax``.

        Only the ``"primary"`` calendar is modelled. A missing or falsy
        ``maxResults`` is treated as 250; the cap is applied after the
        time-range filter.
        """
        calendar_id = kwargs.get("calendarId")
        if calendar_id != "primary":
            raise NotImplementedError(
                f"Unexpected fake Calendar events calendarId={calendar_id!r}"
            )
        limit = int(kwargs.get("maxResults", 250) or 250)
        lower, upper = kwargs.get("timeMin"), kwargs.get("timeMax")
        matching = [
            candidate
            for candidate in _CALENDAR_FIXTURE.get("items", [])
            if _calendar_event_in_range(candidate, lower, upper)
        ]
        response = {
            "items": matching[:limit],
            "summary": "native-calendar-e2e@surfsense.example",
            "timeZone": "UTC",
        }
        return _FakeRequest(response)
class _FakeCalendarService(_StrictFakeMixin):
    """Fake Calendar v3 service exposing only ``events``."""

    _component_name = "calendar_service"

    def events(self) -> _FakeCalendarEvents:
        """Return a fresh fake ``events`` resource."""
        resource = _FakeCalendarEvents()
        return resource
def _fake_build(service_name: str, version: str, **_: Any) -> Any:
    """Return a fake service for the requested API name and version.

    Modelled pairs: ``drive``/``v3``, ``gmail``/``v1``, ``calendar``/``v3``.

    Raises:
        NotImplementedError: for any other pair.
    """
    factories = {
        ("drive", "v3"): _FakeDriveService,
        ("gmail", "v1"): _FakeGmailService,
        ("calendar", "v3"): _FakeCalendarService,
    }
    factory = factories.get((service_name, version))
    if factory is None:
        raise NotImplementedError(
            f"E2E native Google fake cannot build {service_name!r} {version!r}."
        )
    return factory()
def _extract_quoted_value(q: str, anchor: str) -> str | None:
anchor_idx = q.find(anchor)
if anchor_idx == -1:
return None
after_anchor = q[anchor_idx + len(anchor) :]
first_quote_idx = after_anchor.find("'")
if first_quote_idx == -1:
return None
after_first_quote = after_anchor[first_quote_idx + 1 :]
second_quote_idx = after_first_quote.find("'")
if second_quote_idx == -1:
return None
return after_first_quote[:second_quote_idx]
def _filter_drive_files_for_query(
    q: str, files: list[dict[str, Any]]
) -> list[dict[str, Any]]:
    """Apply the modelled subset of a Drive ``q`` query to ``files``.

    Three clauses are understood: ``trashed = false`` (drops entries whose
    ``trashed`` is ``True``), ``mimeType != '<type>'`` and
    ``mimeType = '<type>'``. Everything else in ``q`` is ignored. A new
    list is always returned.
    """
    checks = []
    if "trashed = false" in q:
        checks.append(lambda entry: entry.get("trashed") is not True)
    excluded_mime_type = _extract_quoted_value(q, "mimeType !=")
    if excluded_mime_type:
        checks.append(lambda entry: entry.get("mimeType") != excluded_mime_type)
    included_mime_type = _extract_quoted_value(q, "mimeType =")
    if included_mime_type:
        checks.append(lambda entry: entry.get("mimeType") == included_mime_type)
    # With no active clause, all() is vacuously true and every entry is kept.
    return [entry for entry in files if all(check(entry) for check in checks)]
def _drive_get_metadata(file_id: str | None) -> dict[str, Any]:
    """Return a copy of the first fixture entry whose ``id`` is ``file_id``.

    Only list-valued fixture entries are searched.

    Raises:
        NotImplementedError: if no entry matches.
    """
    listings = (
        value for value in _DRIVE_FIXTURE.values() if isinstance(value, list)
    )
    for listing in listings:
        found = next(
            (entry for entry in listing if entry.get("id") == file_id), None
        )
        if found is not None:
            return dict(found)
    raise NotImplementedError(
        f"E2E native Google fake has no metadata for fileId={file_id!r}."
    )
def _parse_rfc3339(value: str | None) -> datetime | None:
if not value:
return None
normalized = value.replace("Z", "+00:00")
return datetime.fromisoformat(normalized)
def _calendar_event_start(event: dict[str, Any]) -> datetime | None:
start = event.get("start", {})
value = start.get("dateTime") or start.get("date")
parsed = _parse_rfc3339(value)
if parsed and parsed.tzinfo is None:
return parsed.replace(tzinfo=UTC)
return parsed
def _calendar_event_in_range(
event: dict[str, Any], time_min: str | None, time_max: str | None
) -> bool:
event_start = _calendar_event_start(event)
if event_start is None:
return True
parsed_min = _parse_rfc3339(time_min)
parsed_max = _parse_rfc3339(time_max)
if parsed_min and event_start < parsed_min:
return False
return not (parsed_max and event_start > parsed_max)
def _gmail_detail_to_native_message(detail: dict[str, Any]) -> dict[str, Any]:
message_text = detail.get("messageText", "")
encoded_body = base64.urlsafe_b64encode(message_text.encode("utf-8")).decode(
"ascii"
)
return {
"id": detail.get("id"),
"threadId": detail.get("threadId"),
"labelIds": ["INBOX", "IMPORTANT"],
"snippet": message_text[:160],
"payload": {
"mimeType": "text/plain",
"headers": [
{"name": "Subject", "value": detail.get("subject", "No Subject")},
{"name": "From", "value": detail.get("from", "Unknown Sender")},
{"name": "To", "value": detail.get("to", "Unknown Recipient")},
{"name": "Date", "value": detail.get("date", "Unknown Date")},
],
"body": {
"data": encoded_body,
"size": len(message_text),
},
},
}
def install(active_patches: list[Any]) -> None:
    """Patch production bindings to use native Google SDK fakes.

    Replaces ``Flow`` in the three connector routes, ``build`` in the
    connectors and calendar tools, ``MediaIoBaseDownload``, and the Drive
    client's ``_build_thread_http`` (which then returns ``None``). Each
    patcher is started and then appended to ``active_patches``.
    """
    flow_targets = (
        "app.routes.google_drive_add_connector_route.Flow",
        "app.routes.google_gmail_add_connector_route.Flow",
        "app.routes.google_calendar_add_connector_route.Flow",
    )
    build_targets = (
        "app.connectors.google_drive.client.build",
        "app.connectors.google_gmail_connector.build",
        "app.connectors.google_calendar_connector.build",
        "app.agents.new_chat.tools.google_calendar.create_event.build",
        "app.agents.new_chat.tools.google_calendar.update_event.build",
        "app.agents.new_chat.tools.google_calendar.delete_event.build",
    )
    plan = [(target, _FakeFlow) for target in flow_targets]
    plan += [(target, _fake_build) for target in build_targets]
    plan += [
        ("googleapiclient.http.MediaIoBaseDownload", _FakeMediaIoBaseDownload),
        (
            "app.connectors.google_drive.client._build_thread_http",
            lambda credentials: None,
        ),
    ]
    for dotted_path, fake in plan:
        patcher = patch(dotted_path, fake)
        patcher.start()
        active_patches.append(patcher)

View file

@ -0,0 +1,206 @@
"""Strict Notion OAuth/API fakes for Playwright E2E."""
from __future__ import annotations
import json
import types
from pathlib import Path
from typing import Any
from unittest.mock import patch
# Fixture location, the only token endpoint the fake HTTP client accepts, and
# the token values the fake issues and expects back.
_FIXTURE_PATH = Path(__file__).parent / "fixtures" / "notion_pages.json"
_TOKEN_URL = "https://api.notion.com/v1/oauth/token"
_ACCESS_TOKEN = "fake-notion-access-token"
_REFRESH_TOKEN = "fake-notion-refresh-token"
def _load_fixture() -> dict[str, Any]:
    """Read and parse the Notion fixture JSON file.

    The file is opened explicitly as UTF-8 so parsing does not depend on
    the platform's default text encoding.
    """
    with _FIXTURE_PATH.open(encoding="utf-8") as f:
        return json.load(f)


# Parsed once at import time; the Notion fakes read from this mapping.
_FIXTURE = _load_fixture()
class _StrictFakeMixin:
_component_name: str = "<unknown>"
def __getattr__(self, name: str) -> Any:
raise NotImplementedError(
f"E2E Notion fake missing surface: {self._component_name}.{name!r}. "
"Add it to surfsense_backend/tests/e2e/fakes/notion_module.py."
)
class APIResponseError(Exception):
    """Exception carrying status, code, headers and body of a failed call.

    ``headers`` defaults to an empty dict and ``body`` to
    ``{"message": message}`` when omitted or falsy.
    """

    def __init__(
        self,
        message: str,
        *,
        status: int = 400,
        code: str = "validation_error",
        headers: dict[str, str] | None = None,
        body: Any | None = None,
    ):
        super().__init__(message)
        self.status, self.code = status, code
        self.headers = headers if headers else {}
        self.body = body if body else {"message": message}


# Module object named ``notion_client.errors`` that exposes the fake exception.
errors = types.ModuleType("notion_client.errors")
errors.APIResponseError = APIResponseError
class _FakeBlocksChildren(_StrictFakeMixin):
    """Fake of ``client.blocks.children``, backed by the fixture's ``blocks``."""

    _component_name = "notion.blocks.children"

    async def list(self, **kwargs: Any) -> dict[str, Any]:
        """Return all fixture blocks under ``block_id`` as one unpaginated page.

        Raises:
            NotImplementedError: if a ``start_cursor`` is supplied.
            APIResponseError: 404 ``object_not_found`` for an unknown block.
        """
        block_id = kwargs.get("block_id")
        cursor = kwargs.get("start_cursor")
        if cursor is not None:
            raise NotImplementedError(
                f"E2E Notion fake does not model block pagination cursor={cursor!r}."
            )
        children = _FIXTURE.get("blocks", {}).get(block_id)
        if children is None:
            not_found = f"Could not find block: {block_id}"
            raise APIResponseError(
                not_found,
                status=404,
                code="object_not_found",
                body={"message": not_found},
            )
        return {
            "object": "list",
            "results": children,
            "has_more": False,
            "next_cursor": None,
        }
class _FakeBlocks(_StrictFakeMixin):
    """Fake of ``client.blocks`` exposing only ``children``."""

    _component_name = "notion.blocks"

    def __init__(self) -> None:
        children_resource = _FakeBlocksChildren()
        self.children = children_resource
class AsyncClient(_StrictFakeMixin):
    """Fake Notion ``AsyncClient`` that only accepts the fake access token."""

    _component_name = "notion.AsyncClient"

    def __init__(self, *, auth: str, **kwargs: Any):
        del kwargs  # other client options are accepted and ignored
        if auth != _ACCESS_TOKEN:
            raise ValueError(f"Unexpected fake Notion auth token: {auth!r}")
        self.auth = auth
        self.blocks = _FakeBlocks()

    async def search(self, **kwargs: Any) -> dict[str, Any]:
        """Return every fixture page as a single unpaginated result list.

        Only ``filter``, ``sort`` and ``start_cursor`` are accepted; the
        filter must be exactly the page-object filter and the cursor must
        be absent or ``None``.

        Raises:
            NotImplementedError: on any other keyword, cursor or filter.
        """
        allowed = {"filter", "sort", "start_cursor"}
        unsupported = sorted(key for key in kwargs if key not in allowed)
        if unsupported:
            raise NotImplementedError(
                f"E2E Notion fake search got unsupported kwargs: {unsupported}"
            )
        cursor = kwargs.get("start_cursor")
        if cursor is not None:
            raise NotImplementedError(
                f"E2E Notion fake does not model search cursor={cursor!r}."
            )
        expected_filter = {"value": "page", "property": "object"}
        actual_filter = kwargs.get("filter")
        if actual_filter != expected_filter:
            raise NotImplementedError(
                f"E2E Notion fake search expected filter={expected_filter!r}, "
                f"got {actual_filter!r}."
            )
        return {
            "object": "list",
            "results": _FIXTURE.get("pages", []),
            "has_more": False,
            "next_cursor": None,
        }

    async def aclose(self) -> None:
        """No-op close."""
        return None
class _FakeTokenResponse(_StrictFakeMixin):
    """Minimal HTTP response exposing ``status_code``, ``text`` and ``json()``."""

    _component_name = "notion.oauth.response"

    def __init__(self, payload: dict[str, Any], status_code: int = 200):
        self._payload = payload
        self.status_code = status_code
        # Key-sorted JSON rendering of the payload.
        serialized = json.dumps(payload, sort_keys=True)
        self.text = serialized

    def json(self) -> dict[str, Any]:
        """Return the payload object given at construction (not a copy)."""
        return self._payload
class _FakeHttpxAsyncClient(_StrictFakeMixin):
    """Fake ``httpx.AsyncClient`` that only models the Notion token endpoint."""

    _component_name = "httpx.AsyncClient"

    def __init__(self, *args: Any, **kwargs: Any):
        del args, kwargs  # constructor options are accepted and ignored

    async def __aenter__(self) -> _FakeHttpxAsyncClient:
        return self

    async def __aexit__(self, exc_type: Any, exc: Any, tb: Any) -> None:
        del exc_type, exc, tb

    async def post(self, url: str, **kwargs: Any) -> _FakeTokenResponse:
        """Validate a token request and return the canned token response.

        Both the ``authorization_code`` and ``refresh_token`` grants are
        modelled and yield the same payload.

        Raises:
            NotImplementedError: if ``url`` is not the token endpoint.
            ValueError: if the Authorization header is missing, the grant
                type is unknown, or the code / refresh token is wrong.
        """
        if url != _TOKEN_URL:
            raise NotImplementedError(f"Unexpected Notion OAuth POST url={url!r}")
        form = kwargs.get("json") or {}
        sent_headers = kwargs.get("headers") or {}
        if "Authorization" not in sent_headers:
            raise ValueError(
                "Notion OAuth token exchange missing Authorization header."
            )
        grant_type = form.get("grant_type")
        if grant_type not in ("authorization_code", "refresh_token"):
            raise ValueError(f"Unexpected fake Notion grant_type: {grant_type!r}")
        if grant_type == "authorization_code":
            if form.get("code") != "fake-notion-oauth-code":
                raise ValueError(
                    f"Unexpected fake Notion OAuth code: {form.get('code')!r}"
                )
        elif form.get("refresh_token") != _REFRESH_TOKEN:
            raise ValueError(
                f"Unexpected fake Notion refresh token: {form.get('refresh_token')!r}"
            )
        token_payload = {
            "access_token": _ACCESS_TOKEN,
            "refresh_token": _REFRESH_TOKEN,
            "expires_in": 3600,
            "workspace_id": "fake-notion-workspace-001",
            "workspace_name": "SurfSense E2E Notion Workspace",
            "workspace_icon": "https://surfsense.example/notion-icon.png",
            "bot_id": "fake-notion-bot-001",
        }
        return _FakeTokenResponse(token_payload)
class _FakeHttpxModule(_StrictFakeMixin):
    """Stand-in for the ``httpx`` module exposing only ``AsyncClient``."""

    _component_name = "httpx"

    # Any other attribute looked up on an instance raises via the mixin.
    AsyncClient = _FakeHttpxAsyncClient
def install(active_patches: list[Any]) -> None:
    """Patch production bindings that cannot be covered by sys.modules hijack.

    The Notion connector route's ``httpx`` binding is replaced by a fake
    module instance. The patcher is started and then appended to
    ``active_patches``.
    """
    fakes = {
        "app.routes.notion_add_connector_route.httpx": _FakeHttpxModule(),
    }
    for dotted_path, fake in fakes.items():
        patcher = patch(dotted_path, fake)
        patcher.start()
        active_patches.append(patcher)

View file

@ -0,0 +1,190 @@
"""Strict Microsoft OneDrive Graph fakes for Playwright E2E.
This module patches the OneDrive OAuth route and indexer consumer-site
bindings. It keeps the production add/callback/indexing flow intact while
serving deterministic Microsoft-shaped token, profile, metadata, and file
content responses.
"""
from __future__ import annotations
import json
from pathlib import Path
from typing import Any
from unittest.mock import patch
import httpx
from .binary_loader import _resolve_file_bytes
# The OneDrive JSON fixture lives in the ``fixtures`` directory next to this module.
_FIXTURES_DIR = Path(__file__).parent / "fixtures"
_ONEDRIVE_FIXTURE_PATH = _FIXTURES_DIR / "onedrive_files.json"
def _load_onedrive_fixture() -> dict[str, Any]:
    """Read and parse the OneDrive fixture JSON file.

    The file is opened explicitly as UTF-8 so parsing does not depend on
    the platform's default text encoding.
    """
    with _ONEDRIVE_FIXTURE_PATH.open(encoding="utf-8") as f:
        return json.load(f)


# Parsed once at import time; the OneDrive fakes read from this mapping.
_ONEDRIVE_FIXTURE = _load_onedrive_fixture()
class _StrictFakeMixin:
_component_name: str = "<unknown>"
def __getattr__(self, name: str) -> Any:
raise NotImplementedError(
f"E2E OneDrive fake missing surface: "
f"{self._component_name}.{name!r}. Add it to "
f"surfsense_backend/tests/e2e/fakes/onedrive_graph.py."
)
class _FakeOneDriveClient(_StrictFakeMixin):
    """Fake ``OneDriveClient`` serving listings and content from the fixture.

    Each method returns its result together with an error string (or
    ``None``) rather than raising when the fixture has no matching data.
    """

    _component_name = "OneDriveClient"

    def __init__(self, session: Any, connector_id: int):
        self._session, self._connector_id = session, connector_id

    async def list_children(
        self, item_id: str = "root"
    ) -> tuple[list[dict[str, Any]], str | None]:
        """Return copies of the fixture children of ``item_id``."""
        children = _ONEDRIVE_FIXTURE.get(item_id)
        if isinstance(children, list):
            return [dict(child) for child in children], None
        return [], f"E2E OneDrive fake has no children for item_id={item_id!r}."

    async def get_item_metadata(
        self, item_id: str
    ) -> tuple[dict[str, Any] | None, str | None]:
        """Return the fixture metadata for ``item_id``, or an error string."""
        metadata = _onedrive_get_metadata(item_id)
        if metadata is not None:
            return metadata, None
        return None, f"E2E OneDrive fake has no metadata for item_id={item_id!r}."

    async def download_file(self, item_id: str) -> tuple[bytes | None, str | None]:
        """Return the bytes resolved for ``item_id``, or an error string."""
        payload = _resolve_file_bytes(_ONEDRIVE_FIXTURE, item_id, _FIXTURES_DIR)
        if payload is not None:
            return payload, None
        return None, f"E2E OneDrive fake has no content for item_id={item_id!r}."

    async def download_file_to_disk(self, item_id: str, dest_path: str) -> str | None:
        """Write the resolved bytes to ``dest_path``; return an error or ``None``."""
        payload = _resolve_file_bytes(_ONEDRIVE_FIXTURE, item_id, _FIXTURES_DIR)
        if payload is None:
            return f"E2E OneDrive fake has no content for item_id={item_id!r}."
        Path(dest_path).write_bytes(payload)
        return None

    async def get_delta(
        self, folder_id: str | None = None, delta_link: str | None = None
    ) -> tuple[list[dict[str, Any]], str | None, str | None]:
        """Return ``(changes, next_delta_link, error)`` with no changes.

        The folder key comes from ``delta_link`` when given (last path
        segment minus the ``-delta`` suffix), otherwise from ``folder_id``
        with ``"root"`` as the default.
        """
        if delta_link:
            folder_key = delta_link.rsplit("/", 1)[-1].removesuffix("-delta")
        else:
            folder_key = folder_id or "root"
        if folder_key in _ONEDRIVE_FIXTURE:
            next_link = (
                f"https://graph.microsoft.com/v1.0/fake-delta/{folder_key}-delta"
            )
            return [], next_link, None
        return (
            [],
            None,
            f"E2E OneDrive fake has no delta for folder={folder_key!r}.",
        )
class _FakeAsyncClient(_StrictFakeMixin):
    """Fake ``httpx.AsyncClient`` modelling the token POST and profile GET."""

    _component_name = "httpx.AsyncClient"

    def __init__(self, *args: Any, **kwargs: Any):
        del args, kwargs  # constructor options are accepted and ignored

    async def __aenter__(self) -> _FakeAsyncClient:
        return self

    async def __aexit__(self, exc_type: Any, exc: Any, tb: Any) -> None:
        del exc_type, exc, tb

    async def post(self, url: str, *args: Any, **kwargs: Any) -> httpx.Response:
        """Return fake tokens for a ``login.microsoftonline.com`` ``/token`` URL.

        Raises:
            NotImplementedError: for any other URL.
        """
        del args, kwargs
        is_token_url = "login.microsoftonline.com" in url and url.endswith("/token")
        if not is_token_url:
            raise NotImplementedError(
                f"E2E OneDrive fake unexpected POST URL: {url!r}"
            )
        token_payload = {
            "access_token": "fake-onedrive-access-token",
            "refresh_token": "fake-onedrive-refresh-token",
            "token_type": "Bearer",
            "expires_in": 3600,
            "scope": "offline_access User.Read Files.Read.All Files.ReadWrite.All",
        }
        return _json_response("POST", url, token_payload)

    async def get(self, url: str, *args: Any, **kwargs: Any) -> httpx.Response:
        """Return the fake user profile for the Graph ``/me`` URL.

        Raises:
            NotImplementedError: for any other URL.
        """
        del args, kwargs
        if url != "https://graph.microsoft.com/v1.0/me":
            raise NotImplementedError(
                f"E2E OneDrive fake unexpected GET URL: {url!r}"
            )
        profile = {
            "mail": "onedrive-e2e@surfsense.example",
            "userPrincipalName": "onedrive-e2e@surfsense.example",
            "displayName": "SurfSense OneDrive E2E",
        }
        return _json_response("GET", url, profile)

    async def request(
        self, method: str, url: str, *args: Any, **kwargs: Any
    ) -> httpx.Response:
        """Always fail: generic requests are not modelled."""
        del args, kwargs
        raise NotImplementedError(
            f"E2E OneDrive fake unexpected request: {method!r} {url!r}"
        )
class _FakeHttpxModule(_StrictFakeMixin):
    """Stand-in for the ``httpx`` module exposing only ``AsyncClient``."""

    _component_name = "httpx"

    # Any other attribute looked up on an instance raises via the mixin.
    AsyncClient = _FakeAsyncClient
def _json_response(
    method: str, url: str, payload: dict[str, Any], status_code: int = 200
) -> httpx.Response:
    """Build an ``httpx.Response`` with a JSON body and an attached request."""
    originating_request = httpx.Request(method, url)
    return httpx.Response(
        status_code=status_code,
        json=payload,
        request=originating_request,
    )
def _onedrive_get_metadata(item_id: str | None) -> dict[str, Any] | None:
    """Return a copy of the first fixture entry whose ``id`` is ``item_id``.

    Only list-valued fixture entries are searched; ``None`` means no match.
    """
    listings = (
        value for value in _ONEDRIVE_FIXTURE.values() if isinstance(value, list)
    )
    for listing in listings:
        found = next(
            (entry for entry in listing if entry.get("id") == item_id), None
        )
        if found is not None:
            return dict(found)
    return None
def install(active_patches: list[Any]) -> None:
    """Patch production OneDrive bindings to use strict Graph fakes.

    The two ``httpx`` bindings each receive their own fake module
    instance. Every patcher is started and then appended to
    ``active_patches``.
    """
    route = "app.routes.onedrive_add_connector_route"
    plan = (
        (f"{route}.httpx", _FakeHttpxModule()),
        (f"{route}.OneDriveClient", _FakeOneDriveClient),
        (
            "app.tasks.connector_indexers.onedrive_indexer.OneDriveClient",
            _FakeOneDriveClient,
        ),
        ("app.connectors.onedrive.client.httpx", _FakeHttpxModule()),
    )
    for dotted_path, fake in plan:
        patcher = patch(dotted_path, fake)
        patcher.start()
        active_patches.append(patcher)

View file

@ -0,0 +1,213 @@
"""Strict Slack MCP OAuth/tool fakes for Playwright E2E."""
from __future__ import annotations
import json
from pathlib import Path
from types import SimpleNamespace
from typing import Any
from unittest.mock import patch
from tests.e2e.fakes import mcp_oauth_runtime, mcp_runtime
# Fixture path plus the fixed endpoint, credential, token and scope values
# used by the Slack fake.
_FIXTURE_PATH = Path(__file__).parent / "fixtures" / "slack_messages.json"
_AUTHORIZATION_URL = "https://slack.com/oauth/v2_user/authorize"
_REGISTRATION_URL = "https://e2e-fake.invalid/mcp/slack-unused-register"
_TOKEN_URL = "https://slack.com/api/oauth.v2.user.access"
_MCP_URL = "https://mcp.slack.com/mcp"
_CLIENT_ID = "fake-slack-mcp-client-id"
_CLIENT_SECRET = "fake-slack-mcp-client-secret"
_ACCESS_TOKEN = "fake-slack-mcp-access-token"
_REFRESH_TOKEN = "fake-slack-mcp-refresh-token"
_OAUTH_CODE = "fake-slack-oauth-code"
# Space-separated scope string, built by implicit string concatenation.
_SCOPE = (
    "search:read.public search:read.private search:read.mpim search:read.im "
    "channels:history groups:history mpim:history im:history"
)
def _load_fixture() -> dict[str, Any]:
    """Read and parse the Slack fixture JSON file.

    The file is opened explicitly as UTF-8 so parsing does not depend on
    the platform's default text encoding.
    """
    with _FIXTURE_PATH.open(encoding="utf-8") as f:
        return json.load(f)


# Parsed once at import time; the Slack fakes read from this mapping.
_FIXTURE = _load_fixture()
async def _list_tools() -> SimpleNamespace:
return SimpleNamespace(
tools=[
SimpleNamespace(
name="slack_search_channels",
description="Search Slack channels visible to the authenticated user.",
inputSchema={
"type": "object",
"properties": {
"query": {
"type": "string",
"description": "Text to search for in Slack channel names.",
},
"limit": {
"type": "integer",
"description": "Maximum number of channels to return.",
},
},
"required": [],
},
),
SimpleNamespace(
name="slack_read_channel",
description="Read messages from a Slack channel.",
inputSchema={
"type": "object",
"properties": {
"channel_id": {
"type": "string",
"description": "Slack channel id.",
}
},
"required": ["channel_id"],
},
),
SimpleNamespace(
name="slack_read_thread",
description="Read a Slack thread from a channel.",
inputSchema={
"type": "object",
"properties": {
"channel_id": {
"type": "string",
"description": "Slack channel id.",
},
"thread_ts": {
"type": "string",
"description": "Slack thread timestamp.",
},
},
"required": ["channel_id", "thread_ts"],
},
),
]
)
async def _call_tool(
tool_name: str, arguments: dict[str, Any] | None = None
) -> SimpleNamespace:
arguments = arguments or {}
channel = _FIXTURE["channel"]
message = _FIXTURE["messages"][0]
if tool_name == "slack_search_channels":
query = str(arguments.get("query", ""))
if query and channel["name"].lower() not in query.lower():
raise ValueError(f"Unexpected Slack channel query: {query!r}")
text = (
f"#{channel['name']} ({channel['id']})\n"
f"purpose: {channel['purpose']}\n"
f"latest_message: {message['text']}"
)
return SimpleNamespace(content=[SimpleNamespace(text=text)])
if tool_name in {"slack_read_channel", "slack_read_thread"}:
raise NotImplementedError(
f"Slack E2E fake does not exercise {tool_name!r}; "
"extend slack_module.py before using it in a journey."
)
raise NotImplementedError(f"Unexpected Slack MCP tool call: {tool_name!r}")
async def _fake_exchange_code_for_tokens(
    token_endpoint: str,
    code: str,
    redirect_uri: str,
    client_id: str,
    client_secret: str,
    code_verifier: str,
    *,
    timeout: float = 30.0,
) -> dict[str, Any]:
    """Fake OAuth code-for-token exchange for the Slack token endpoint.

    Requests aimed at any other token endpoint are delegated unchanged to
    ``mcp_oauth_runtime._fake_exchange_code_for_tokens``. For the Slack
    endpoint every input is checked against the module's fake constants
    and a Slack-shaped payload is returned, with the tokens both at the
    top level and nested under ``authed_user``.

    Raises:
        ValueError: On an unexpected code, redirect URI, or client
            credentials, or when ``code_verifier`` is empty.
    """
    if token_endpoint != _TOKEN_URL:
        # Not the Slack endpoint: let the shared MCP OAuth fake handle it.
        delegated = await mcp_oauth_runtime._fake_exchange_code_for_tokens(
            token_endpoint,
            code,
            redirect_uri,
            client_id,
            client_secret,
            code_verifier,
            timeout=timeout,
        )
        return delegated

    # The timeout is irrelevant on the Slack path; drop the unused binding.
    del timeout

    if code != _OAUTH_CODE:
        raise ValueError(f"Unexpected fake Slack OAuth code: {code!r}")
    if "/api/v1/auth/mcp/slack/connector/callback" not in redirect_uri:
        raise ValueError(f"Unexpected Slack redirect_uri={redirect_uri!r}")
    if client_id != _CLIENT_ID or client_secret != _CLIENT_SECRET:
        raise ValueError(
            "Unexpected Slack client credentials: "
            f"client_id={client_id!r} client_secret={client_secret!r}"
        )
    if not code_verifier:
        raise ValueError("Slack token exchange missing code_verifier.")

    team_info = _FIXTURE["team"]
    user_grant = {
        "id": "U_FAKE_SLACK_USER",
        "scope": _SCOPE,
        "access_token": _ACCESS_TOKEN,
        "refresh_token": _REFRESH_TOKEN,
        "expires_in": 3600,
        "token_type": "Bearer",
    }
    team_summary = {"id": team_info["id"], "name": team_info["name"]}
    return {
        "ok": True,
        "scope": _SCOPE,
        "token_type": "Bearer",
        "expires_in": 3600,
        "refresh_token": _REFRESH_TOKEN,
        "authed_user": user_grant,
        "team": team_summary,
    }
def install(active_patches: list[Any]) -> None:
    """Register Slack MCP OAuth/tool handlers with the shared dispatchers.

    Performs three steps in order: registers the Slack OAuth service with
    ``mcp_oauth_runtime``, registers the tool handlers with
    ``mcp_runtime``, then patches
    ``app.services.mcp_oauth.discovery.exchange_code_for_tokens`` with the
    Slack-aware fake. The started patcher is appended to
    ``active_patches``.
    """
    discovery_document = {
        "issuer": "https://slack.com",
        "authorization_endpoint": _AUTHORIZATION_URL,
        "token_endpoint": _TOKEN_URL,
        "registration_endpoint": _REGISTRATION_URL,
        "code_challenge_methods_supported": ["S256"],
        "grant_types_supported": ["authorization_code", "refresh_token"],
        "response_types_supported": ["code"],
    }
    mcp_oauth_runtime.register_service(
        mcp_url=_MCP_URL,
        discovery_metadata=discovery_document,
        client_id=_CLIENT_ID,
        client_secret=_CLIENT_SECRET,
        token_endpoint=_TOKEN_URL,
        registration_endpoint=_REGISTRATION_URL,
        oauth_code=_OAUTH_CODE,
        access_token=_ACCESS_TOKEN,
        refresh_token=_REFRESH_TOKEN,
        scope=_SCOPE,
        redirect_uri_substring="/api/v1/auth/mcp/slack/connector/callback",
    )

    mcp_runtime.register(
        url=_MCP_URL,
        expected_bearer=_ACCESS_TOKEN,
        list_tools=_list_tools,
        call_tool=_call_tool,
    )

    exchange_patcher = patch(
        "app.services.mcp_oauth.discovery.exchange_code_for_tokens",
        _fake_exchange_code_for_tokens,
    )
    exchange_patcher.start()
    active_patches.append(exchange_patcher)

View file

@ -0,0 +1,4 @@
"""Test-only middleware. Mounted on the FastAPI `app` object inside
`tests/e2e/run_backend.py`, never registered by production startup
(`python main.py`).
"""

View file

@ -0,0 +1,54 @@
"""X-E2E-Scenario middleware.
Reads the X-E2E-Scenario request header and pipes the value into a
ContextVar that the strict fakes consult to switch between happy-path
and error scenarios on a per-request basis.
Mounted by tests/e2e/run_backend.py only. Production never adds this
middleware, so production never reads the header.
Supported scenarios:
- "happy" (default): everything succeeds with deterministic fixtures.
- "denied": Composio.connected_accounts.initiate returns a redirect URL
pointing at our callback with ?error=access_denied.
- "auth_expired": GOOGLEDRIVE_LIST_FILES returns an authentication
failure that the route translates to connector.config.auth_expired.
- "duplicate": no special fake behavior; the duplicate path is exercised
by running the OAuth flow twice with the same toolkit.
"""
from __future__ import annotations
from contextvars import ContextVar
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from starlette.responses import Response
from starlette.types import ASGIApp
_scenario: ContextVar[str] = ContextVar("e2e_scenario", default="happy")


def current_scenario() -> str:
    """Return the E2E scenario active in the current context.

    Falls back to ``"happy"`` when nothing has been set.
    """
    active_scenario = _scenario.get()
    return active_scenario
class ScenarioMiddleware(BaseHTTPMiddleware):
    """Publish the ``X-E2E-Scenario`` request header through a ContextVar.

    The value (``"happy"`` when the header is absent) is also stored on
    ``request.state.e2e_scenario`` so route handlers can read it directly.
    """

    def __init__(self, app: ASGIApp) -> None:
        super().__init__(app)

    async def dispatch(self, request: Request, call_next) -> Response:
        scenario_name = request.headers.get("X-E2E-Scenario", "happy")
        reset_token = _scenario.set(scenario_name)
        try:
            request.state.e2e_scenario = scenario_name
            response = await call_next(request)
        finally:
            # Always restore the previous value, even if the handler raised.
            _scenario.reset(reset_token)
        return response

View file

@ -0,0 +1,247 @@
"""E2E backend entrypoint.
Hijacks third-party SDKs at sys.modules level BEFORE any production
code is imported, then starts the same FastAPI app + uvicorn that
`main.py` would run.
Production code is byte-identical with or without this file:
- `python main.py` is the production entrypoint (unchanged).
- `python tests/e2e/run_backend.py` is the test entrypoint, never imported by production.
- `surfsense_backend/.dockerignore` excludes `tests/`, so this file
physically does not exist in the production Docker image.
Defense in depth (see Composio Drive E2E Phase 1 plan):
1. sys.modules hijack here (Composio).
2. Strict __getattr__ inside fakes (NotImplementedError on unknown surface).
3. Network deny-list set in CI env (HTTPS_PROXY=http://127.0.0.1:1
plus sentinel API keys) so any leaked outbound HTTP fails loudly.
Usage:
cd surfsense_backend
uv run python tests/e2e/run_backend.py
"""
from __future__ import annotations
import logging
import os
import sys
# ---------------------------------------------------------------------------
# 1) Hijack sys.modules BEFORE any production import.
# Production: composio_service.py:11 does `from composio import Composio`.
# With this hijack in place, that import resolves to our strict fake.
# ---------------------------------------------------------------------------
# Make the surfsense_backend root importable as a top-level package so
# `import tests.e2e.fakes...` works regardless of how the entrypoint is
# invoked (uv run python tests/e2e/run_backend.py from repo root or from
# surfsense_backend/).
_THIS_DIR = os.path.dirname(os.path.abspath(__file__))
_BACKEND_ROOT = os.path.abspath(os.path.join(_THIS_DIR, "..", ".."))
if _BACKEND_ROOT not in sys.path:
sys.path.insert(0, _BACKEND_ROOT)
import tests.e2e.fakes.composio_module as _fake_composio # noqa: E402
import tests.e2e.fakes.notion_module as _fake_notion # noqa: E402
sys.modules["composio"] = _fake_composio
sys.modules["notion_client"] = _fake_notion
sys.modules["notion_client.errors"] = _fake_notion.errors
# ---------------------------------------------------------------------------
# 2) Standard logging + dotenv so the rest of the app behaves like main.py.
# ---------------------------------------------------------------------------
from dotenv import load_dotenv # noqa: E402
load_dotenv()
os.environ.setdefault("ATLASSIAN_CLIENT_ID", "fake-atlassian-client-id")
os.environ.setdefault("ATLASSIAN_CLIENT_SECRET", "fake-atlassian-client-secret")
os.environ.setdefault(
"CONFLUENCE_REDIRECT_URI",
"http://localhost:8000/api/v1/auth/confluence/connector/callback",
)
os.environ.setdefault("NOTION_CLIENT_ID", "fake-notion-client-id")
os.environ.setdefault("NOTION_CLIENT_SECRET", "fake-notion-client-secret")
os.environ.setdefault(
"NOTION_REDIRECT_URI",
"http://localhost:8000/api/v1/auth/notion/connector/callback",
)
os.environ.setdefault("MICROSOFT_CLIENT_ID", "fake-microsoft-client-id")
os.environ.setdefault("MICROSOFT_CLIENT_SECRET", "fake-microsoft-client-secret")
os.environ.setdefault(
"ONEDRIVE_REDIRECT_URI",
"http://localhost:8000/api/v1/auth/onedrive/connector/callback",
)
os.environ.setdefault("DROPBOX_APP_KEY", "fake-dropbox-app-key")
os.environ.setdefault("DROPBOX_APP_SECRET", "fake-dropbox-app-secret")
os.environ.setdefault(
"DROPBOX_REDIRECT_URI",
"http://localhost:8000/api/v1/auth/dropbox/connector/callback",
)
os.environ["SLACK_CLIENT_ID"] = "fake-slack-mcp-client-id"
os.environ["SLACK_CLIENT_SECRET"] = "fake-slack-mcp-client-secret"
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
)
logger = logging.getLogger("surfsense.e2e.backend")
logger.warning(
"*** SURFSENSE E2E BACKEND ENTRYPOINT — fake Composio + LLM + embeddings ***"
)
# ---------------------------------------------------------------------------
# 3) Now import the production app. Every module in app.* loads here,
# creating their bindings (some of which we will patch in step 4).
# ---------------------------------------------------------------------------
# ---------------------------------------------------------------------------
# 4) Patch LLM + embedding bindings at every consumer site.
# Composio is already covered by the sys.modules hijack in step 1.
# ---------------------------------------------------------------------------
from unittest.mock import patch # noqa: E402
from app.app import app # noqa: E402
from tests.e2e.fakes import ( # noqa: E402
clickup_module as _fake_clickup_module,
confluence_indexer as _fake_confluence_indexer,
confluence_oauth as _fake_confluence_oauth,
dropbox_api as _fake_dropbox_api,
embeddings as _fake_embeddings,
jira_module as _fake_jira_module,
linear_module as _fake_linear_module,
mcp_oauth_runtime as _fake_mcp_oauth_runtime,
mcp_runtime as _fake_mcp_runtime,
native_google as _fake_native_google,
notion_module as _fake_notion_module,
onedrive_graph as _fake_onedrive_graph,
slack_module as _fake_slack_module,
)
from tests.e2e.fakes.chat_llm import ( # noqa: E402
fake_create_chat_litellm_from_agent_config,
fake_create_chat_litellm_from_config,
)
from tests.e2e.fakes.llm import fake_get_user_long_context_llm # noqa: E402
_active_patches: list = []
def _patch_llm_bindings() -> None:
    """Replace LLM factories at every known binding site.

    Two families are patched: modules that bind
    ``get_user_long_context_llm`` and modules that bind the chat LiteLLM
    factories. Every patcher that starts is appended to
    ``_active_patches``; a target that cannot be resolved is logged as a
    warning and skipped.
    """
    # Modules holding their own reference to ``get_user_long_context_llm``.
    long_context_modules = (
        "app.services.llm_service",
        "app.tasks.connector_indexers.confluence_indexer",
        "app.tasks.connector_indexers.google_drive_indexer",
        "app.tasks.connector_indexers.google_gmail_indexer",
        "app.tasks.connector_indexers.notion_indexer",
        "app.tasks.connector_indexers.onedrive_indexer",
        "app.tasks.connector_indexers.dropbox_indexer",
        "app.tasks.connector_indexers.local_folder_indexer",
        "app.tasks.document_processors._save",
        "app.tasks.document_processors.markdown_processor",
    )
    for module_path in long_context_modules:
        dotted = f"{module_path}.get_user_long_context_llm"
        try:
            patcher = patch(dotted, fake_get_user_long_context_llm)
            patcher.start()
        except (ModuleNotFoundError, AttributeError) as err:
            # Some indexers may not be loaded in every env. Log and move
            # on — but do not silently let a known binding through.
            logger.warning(
                "[fake-llm] could not patch %s: %s. If production code "
                "uses this path in E2E it will hit the real provider; "
                "update tests/e2e/run_backend.py.",
                dotted,
                err,
            )
        else:
            _active_patches.append(patcher)
            logger.info("[fake-llm] patched %s", dotted)

    # Each chat factory is bound in both modules, so patch every pairing.
    chat_modules = (
        "app.agents.new_chat.llm_config",
        "app.tasks.chat.stream_new_chat",
    )
    chat_fakes = {
        "create_chat_litellm_from_agent_config": (
            fake_create_chat_litellm_from_agent_config
        ),
        "create_chat_litellm_from_config": fake_create_chat_litellm_from_config,
    }
    for module_path in chat_modules:
        for attribute, fake_factory in chat_fakes.items():
            dotted = f"{module_path}.{attribute}"
            try:
                patcher = patch(dotted, fake_factory)
                patcher.start()
            except (ModuleNotFoundError, AttributeError) as err:
                logger.warning(
                    "[fake-chat-llm] could not patch %s: %s.", dotted, err
                )
            else:
                _active_patches.append(patcher)
                logger.info("[fake-chat-llm] patched %s", dotted)
_patch_llm_bindings()
_fake_embeddings.install(_active_patches)
_fake_confluence_oauth.install(_active_patches)
_fake_confluence_indexer.install(_active_patches)
_fake_native_google.install(_active_patches)
_fake_onedrive_graph.install(_active_patches)
_fake_dropbox_api.install(_active_patches)
_fake_notion_module.install(_active_patches)
_fake_linear_module.install(_active_patches)
_fake_jira_module.install(_active_patches)
_fake_clickup_module.install(_active_patches)
_fake_mcp_runtime.install(_active_patches)
_fake_mcp_oauth_runtime.install(_active_patches)
_fake_slack_module.install(_active_patches)
# ---------------------------------------------------------------------------
# 5) Mount test-only middleware. Production never reaches this code.
# ---------------------------------------------------------------------------
from tests.e2e.middleware.scenario import ScenarioMiddleware # noqa: E402
app.add_middleware(ScenarioMiddleware)
# ---------------------------------------------------------------------------
# 6) Start uvicorn, mirroring main.py's behaviour.
# ---------------------------------------------------------------------------
import asyncio # noqa: E402
import uvicorn # noqa: E402
def _main() -> None:
    """Run the patched FastAPI app under uvicorn.

    Host, port and log level come from ``UVICORN_HOST``, ``UVICORN_PORT``
    and ``UVICORN_LOG_LEVEL`` (defaults ``0.0.0.0``, ``8000``, ``info``).
    On Windows the selector event loop policy is installed first.
    """
    if sys.platform == "win32":
        selector_policy = asyncio.WindowsSelectorEventLoopPolicy()
        asyncio.set_event_loop_policy(selector_policy)

    bind_host = os.getenv("UVICORN_HOST", "0.0.0.0")
    bind_port = int(os.getenv("UVICORN_PORT", "8000"))
    verbosity = os.getenv("UVICORN_LOG_LEVEL", "info")

    uvicorn.Server(
        uvicorn.Config(
            app=app,
            host=bind_host,
            port=bind_port,
            log_level=verbosity,
            reload=False,
        )
    ).run()
if __name__ == "__main__":
_main()

View file

@ -0,0 +1,214 @@
"""E2E Celery worker entrypoint.
Same sys.modules hijack + LLM/embedding patches as run_backend.py,
applied before importing the production celery_app. Celery workers
run in a separate Python interpreter, so the patches must be applied
here too — they would NOT carry over from the FastAPI process.
Production is unaffected: celery_worker.py at the repo root is the
production entrypoint and never imports this file.
Usage:
cd surfsense_backend
uv run python tests/e2e/run_celery.py
"""
from __future__ import annotations
import logging
import os
import sys
_THIS_DIR = os.path.dirname(os.path.abspath(__file__))
_BACKEND_ROOT = os.path.abspath(os.path.join(_THIS_DIR, "..", ".."))
if _BACKEND_ROOT not in sys.path:
sys.path.insert(0, _BACKEND_ROOT)
# ---------------------------------------------------------------------------
# 1) Hijack sys.modules BEFORE production celery imports anything.
# ---------------------------------------------------------------------------
import tests.e2e.fakes.composio_module as _fake_composio # noqa: E402
import tests.e2e.fakes.notion_module as _fake_notion # noqa: E402
sys.modules["composio"] = _fake_composio
sys.modules["notion_client"] = _fake_notion
sys.modules["notion_client.errors"] = _fake_notion.errors
# ---------------------------------------------------------------------------
# 2) Logging + dotenv.
# ---------------------------------------------------------------------------
from dotenv import load_dotenv # noqa: E402
load_dotenv()
os.environ.setdefault("ATLASSIAN_CLIENT_ID", "fake-atlassian-client-id")
os.environ.setdefault("ATLASSIAN_CLIENT_SECRET", "fake-atlassian-client-secret")
os.environ.setdefault(
"CONFLUENCE_REDIRECT_URI",
"http://localhost:8000/api/v1/auth/confluence/connector/callback",
)
os.environ.setdefault("NOTION_CLIENT_ID", "fake-notion-client-id")
os.environ.setdefault("NOTION_CLIENT_SECRET", "fake-notion-client-secret")
os.environ.setdefault(
"NOTION_REDIRECT_URI",
"http://localhost:8000/api/v1/auth/notion/connector/callback",
)
os.environ.setdefault("MICROSOFT_CLIENT_ID", "fake-microsoft-client-id")
os.environ.setdefault("MICROSOFT_CLIENT_SECRET", "fake-microsoft-client-secret")
os.environ.setdefault(
"ONEDRIVE_REDIRECT_URI",
"http://localhost:8000/api/v1/auth/onedrive/connector/callback",
)
os.environ.setdefault("DROPBOX_APP_KEY", "fake-dropbox-app-key")
os.environ.setdefault("DROPBOX_APP_SECRET", "fake-dropbox-app-secret")
os.environ.setdefault(
"DROPBOX_REDIRECT_URI",
"http://localhost:8000/api/v1/auth/dropbox/connector/callback",
)
os.environ["SLACK_CLIENT_ID"] = "fake-slack-mcp-client-id"
os.environ["SLACK_CLIENT_SECRET"] = "fake-slack-mcp-client-secret"
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
)
logger = logging.getLogger("surfsense.e2e.celery")
logger.warning("*** SURFSENSE E2E CELERY WORKER — fake Composio + LLM + embeddings ***")
# ---------------------------------------------------------------------------
# 3) Import the production celery_app. All task modules load here.
# ---------------------------------------------------------------------------
# ---------------------------------------------------------------------------
# 4) Patch LLM + embedding bindings inside the worker process.
# ---------------------------------------------------------------------------
from unittest.mock import patch # noqa: E402
from app.celery_app import celery_app # noqa: E402
from tests.e2e.fakes import ( # noqa: E402
clickup_module as _fake_clickup_module,
confluence_indexer as _fake_confluence_indexer,
confluence_oauth as _fake_confluence_oauth,
dropbox_api as _fake_dropbox_api,
embeddings as _fake_embeddings,
jira_module as _fake_jira_module,
linear_module as _fake_linear_module,
mcp_oauth_runtime as _fake_mcp_oauth_runtime,
mcp_runtime as _fake_mcp_runtime,
native_google as _fake_native_google,
notion_module as _fake_notion_module,
onedrive_graph as _fake_onedrive_graph,
slack_module as _fake_slack_module,
)
from tests.e2e.fakes.chat_llm import ( # noqa: E402
fake_create_chat_litellm_from_agent_config,
fake_create_chat_litellm_from_config,
)
from tests.e2e.fakes.llm import fake_get_user_long_context_llm # noqa: E402
_active_patches: list = []
def _patch_llm_bindings() -> None:
    """Replace LLM factories at every known binding site in this worker.

    Two families are patched: modules that bind
    ``get_user_long_context_llm`` and modules that bind the chat LiteLLM
    factories. Every patcher that starts is appended to
    ``_active_patches``; a target that cannot be resolved is logged as a
    warning and skipped.
    """
    # Modules holding their own reference to ``get_user_long_context_llm``.
    long_context_modules = (
        "app.services.llm_service",
        "app.tasks.connector_indexers.confluence_indexer",
        "app.tasks.connector_indexers.google_drive_indexer",
        "app.tasks.connector_indexers.google_gmail_indexer",
        "app.tasks.connector_indexers.notion_indexer",
        "app.tasks.connector_indexers.onedrive_indexer",
        "app.tasks.connector_indexers.dropbox_indexer",
        "app.tasks.connector_indexers.local_folder_indexer",
        "app.tasks.document_processors._save",
        "app.tasks.document_processors.markdown_processor",
    )
    for module_path in long_context_modules:
        dotted = f"{module_path}.get_user_long_context_llm"
        try:
            patcher = patch(dotted, fake_get_user_long_context_llm)
            patcher.start()
        except (ModuleNotFoundError, AttributeError) as err:
            logger.warning(
                "[fake-llm] could not patch %s in celery worker: %s.",
                dotted,
                err,
            )
        else:
            _active_patches.append(patcher)
            logger.info("[fake-llm] patched %s in celery worker", dotted)

    # Each chat factory is bound in both modules, so patch every pairing.
    chat_modules = (
        "app.agents.new_chat.llm_config",
        "app.tasks.chat.stream_new_chat",
    )
    chat_fakes = {
        "create_chat_litellm_from_agent_config": (
            fake_create_chat_litellm_from_agent_config
        ),
        "create_chat_litellm_from_config": fake_create_chat_litellm_from_config,
    }
    for module_path in chat_modules:
        for attribute, fake_factory in chat_fakes.items():
            dotted = f"{module_path}.{attribute}"
            try:
                patcher = patch(dotted, fake_factory)
                patcher.start()
            except (ModuleNotFoundError, AttributeError) as err:
                logger.warning(
                    "[fake-chat-llm] could not patch %s in celery worker: %s.",
                    dotted,
                    err,
                )
            else:
                _active_patches.append(patcher)
                logger.info(
                    "[fake-chat-llm] patched %s in celery worker", dotted
                )
_patch_llm_bindings()
_fake_embeddings.install(_active_patches)
_fake_confluence_oauth.install(_active_patches)
_fake_confluence_indexer.install(_active_patches)
_fake_native_google.install(_active_patches)
_fake_onedrive_graph.install(_active_patches)
_fake_dropbox_api.install(_active_patches)
_fake_notion_module.install(_active_patches)
_fake_linear_module.install(_active_patches)
_fake_jira_module.install(_active_patches)
_fake_clickup_module.install(_active_patches)
_fake_mcp_runtime.install(_active_patches)
_fake_mcp_oauth_runtime.install(_active_patches)
_fake_slack_module.install(_active_patches)
# ---------------------------------------------------------------------------
# 5) Start the worker.
# ---------------------------------------------------------------------------
def _main() -> None:
    """Start the patched Celery worker.

    Consumes the default queue (``CELERY_TASK_DEFAULT_QUEUE``, falling
    back to ``surfsense``) plus its ``.connectors`` sibling, mirroring
    production so Drive indexing tasks are picked up.
    """
    base_queue = os.getenv("CELERY_TASK_DEFAULT_QUEUE", "surfsense")
    queue_list = ",".join((base_queue, f"{base_queue}.connectors"))
    worker_argv = [
        "worker",
        "--loglevel=info",
        f"--queues={queue_list}",
        "--concurrency=2",
        "--without-gossip",
        "--without-mingle",
    ]
    celery_app.worker_main(argv=worker_argv)
if __name__ == "__main__":
_main()

View file

@ -367,18 +367,26 @@ class TestPersistUserTurn:
db_thread,
patched_shielded_session,
):
"""The full ``{id, title, document_type}`` triple forwarded by
the FE must round-trip into a single ``mentioned-documents``
ContentPart on the persisted user message the history loader
renders the chips on reload from this part directly.
"""The full ``{id, title, document_type, kind}`` chip metadata
forwarded by the FE must round-trip into a single
``mentioned-documents`` ContentPart on the persisted user
message the history loader renders the chips on reload from
this part directly. Folder chips ride alongside doc chips so
the FE can render mixed mention bars without a second fetch.
"""
thread_id = db_thread.id
user_id_str = str(db_user.id)
turn_id = f"{thread_id}:8200"
mentioned = [
{"id": 11, "title": "Alpha", "document_type": "GENERAL"},
{"id": 22, "title": "Beta", "document_type": "GENERAL"},
{"id": 11, "title": "Alpha", "document_type": "GENERAL", "kind": "doc"},
{"id": 22, "title": "Beta", "document_type": "GENERAL", "kind": "doc"},
{
"id": 33,
"title": "Reports",
"document_type": "FOLDER",
"kind": "folder",
},
]
msg_id = await persist_user_turn(
chat_id=thread_id,
@ -397,8 +405,61 @@ class TestPersistUserTurn:
assert row.content[1] == {
"type": "mentioned-documents",
"documents": [
{"id": 11, "title": "Alpha", "document_type": "GENERAL"},
{"id": 22, "title": "Beta", "document_type": "GENERAL"},
{"id": 11, "title": "Alpha", "document_type": "GENERAL", "kind": "doc"},
{"id": 22, "title": "Beta", "document_type": "GENERAL", "kind": "doc"},
{
"id": 33,
"title": "Reports",
"document_type": "FOLDER",
"kind": "folder",
},
],
}
async def test_legacy_chip_without_kind_defaults_to_doc(
    self,
    db_session,
    db_user,
    db_thread,
    patched_shielded_session,
):
    """Pre-folder clients send chips without ``kind``. The persistence
    layer defaults them to ``"doc"`` so the round-trip stays
    consistent on reload — the FE schema's optional default
    produces the same value, but persisting it explicitly keeps
    the DB row self-describing.
    """
    chat_id = db_thread.id
    author_id = str(db_user.id)
    legacy_turn_id = f"{chat_id}:8201"
    legacy_chips = [
        {"id": 77, "title": "Legacy", "document_type": "GENERAL"},
    ]

    message_id = await persist_user_turn(
        chat_id=chat_id,
        user_id=author_id,
        turn_id=legacy_turn_id,
        user_query="hi",
        mentioned_documents=legacy_chips,
    )
    assert isinstance(message_id, int)

    persisted = await db_session.get(NewChatMessage, message_id)
    assert persisted is not None
    assert isinstance(persisted.content, list)

    # Locate the first mentioned-documents part of the stored message.
    mentions_part = next(
        part
        for part in persisted.content
        if part.get("type") == "mentioned-documents"
    )
    expected_chip = {
        "id": 77,
        "title": "Legacy",
        "document_type": "GENERAL",
        "kind": "doc",
    }
    assert mentions_part == {
        "type": "mentioned-documents",
        "documents": [expected_chip],
    }

View file

@ -0,0 +1 @@
"""Integration tests for Composio connector routes."""

View file

@ -0,0 +1,90 @@
"""Composio route integration fixtures.
The sys.modules hijack happens at module import time, before importing
app.app, so production `from composio import Composio` bindings resolve to
the strict E2E fake in this pytest process too.
"""
from __future__ import annotations
import sys
from collections.abc import AsyncGenerator
import httpx
import pytest
import pytest_asyncio
from httpx import ASGITransport
from sqlalchemy.ext.asyncio import AsyncSession
from tests.e2e.fakes import composio_module as _fake_composio
sys.modules["composio"] = _fake_composio
from app.app import app, limiter # noqa: E402
from app.config import config # noqa: E402
from app.db import ( # noqa: E402
SearchSourceConnector,
SearchSourceConnectorType,
User,
get_async_session,
)
from app.users import current_active_user # noqa: E402
pytestmark = pytest.mark.integration
limiter.enabled = False
config.COMPOSIO_ENABLED = True
config.COMPOSIO_API_KEY = "e2e-integration-composio-sentinel"
config.NEXT_FRONTEND_URL = "http://localhost:3000"
@pytest_asyncio.fixture
async def client(
    db_session: AsyncSession,
    db_user: User,
) -> AsyncGenerator[httpx.AsyncClient, None]:
    """Yield an in-process HTTP client bound to the FastAPI app.

    While the fixture is active, the session and current-user
    dependencies resolve to the test's ``db_session`` and ``db_user``.
    The overrides that existed beforehand are restored on teardown.
    """

    async def _use_test_session() -> AsyncGenerator[AsyncSession, None]:
        yield db_session

    async def _use_test_user() -> User:
        return db_user

    saved_overrides = app.dependency_overrides.copy()
    app.dependency_overrides[get_async_session] = _use_test_session
    app.dependency_overrides[current_active_user] = _use_test_user

    try:
        asgi_client = httpx.AsyncClient(
            transport=ASGITransport(app=app),
            base_url="http://test",
            timeout=30.0,
            follow_redirects=False,
        )
        async with asgi_client as http_client:
            yield http_client
    finally:
        # Put back exactly what was registered before this fixture ran.
        app.dependency_overrides.clear()
        app.dependency_overrides.update(saved_overrides)
@pytest_asyncio.fixture
async def drive_connector(
    db_session: AsyncSession,
    db_user: User,
    db_search_space,
) -> SearchSourceConnector:
    """Seed a Composio Google Drive connector for the test user.

    The row is added to the session and flushed (not committed), then
    returned.
    """
    composio_config = {
        "composio_connected_account_id": "fake-acct-googledrive-existing",
        "toolkit_id": "googledrive",
        "toolkit_name": "Google Drive",
        "is_indexable": True,
    }
    seeded = SearchSourceConnector(
        name="Google Drive (Composio) - e2e-fake@surfsense.example",
        connector_type=SearchSourceConnectorType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR,
        is_indexable=True,
        config=composio_config,
        search_space_id=db_search_space.id,
        user_id=db_user.id,
    )
    db_session.add(seeded)
    await db_session.flush()
    return seeded

View file

@ -0,0 +1,99 @@
from typing import Any
import httpx
import pytest
from sqlalchemy.ext.asyncio import AsyncSession
from app.db import SearchSourceConnector
from tests.e2e.fakes import composio_module
pytestmark = pytest.mark.integration
async def test_root_listing_returns_canned_items(
    client: httpx.AsyncClient,
    drive_connector: SearchSourceConnector,
):
    """The Drive folder listing route returns the fake's canned root
    entries, including the ``Projects`` folder and the canary file.
    """
    listing_url = f"/api/v1/connectors/{drive_connector.id}/composio-drive/folders"
    response = await client.get(listing_url)

    assert response.status_code == 200
    entries = response.json()["items"]
    entry_names = {entry["name"] for entry in entries}
    assert "Projects" in entry_names
    assert "e2e-canary.txt" in entry_names

    projects_folder_listed = any(
        entry["id"] == "fake-folder-projects" and entry["isFolder"] is True
        for entry in entries
    )
    assert projects_folder_listed
async def test_save_round_trips_selected_files(
    client: httpx.AsyncClient,
    db_session: AsyncSession,
    drive_connector: SearchSourceConnector,
):
    """Saving a file selection through the connector update route stores
    it verbatim in the connector's config.
    """
    canary_file = {
        "id": "fake-file-canary",
        "name": "e2e-canary.txt",
        "mimeType": "text/plain",
    }
    selected_files = [canary_file]
    update_payload = {
        "config": {
            "selected_folders": [],
            "selected_files": selected_files,
            "indexing_options": {
                "max_files_per_folder": 10,
                "incremental_sync": False,
                "include_subfolders": False,
            },
        }
    }

    response = await client.put(
        f"/api/v1/search-source-connectors/{drive_connector.id}",
        json=update_payload,
    )

    assert response.status_code == 200
    # Reload the row so the assertions see what was actually persisted.
    await db_session.refresh(drive_connector)
    stored_config = drive_connector.config
    assert stored_config["selected_files"] == selected_files
    assert stored_config["selected_folders"] == []
async def test_auth_expired_error_classifies_and_flags_connector(
    monkeypatch: pytest.MonkeyPatch,
    client: httpx.AsyncClient,
    db_session: AsyncSession,
    drive_connector: SearchSourceConnector,
):
    """When the fake Composio ``execute`` raises a revoked-token error, the
    listing route answers 400 with an authentication-expired message and
    the connector config is flagged ``auth_expired``.
    """

    def _revoked_token_execute(
        self: Any,
        *,
        slug: str,
        connected_account_id: str,
        user_id: str | None = None,
        arguments: dict[str, Any] | None = None,
        dangerously_skip_version_check: bool = True,
        **kwargs: Any,
    ) -> dict[str, Any]:
        raise RuntimeError(
            "Token has been expired or revoked. (HTTP 401: invalid_grant)"
        )

    monkeypatch.setattr(composio_module._Tools, "execute", _revoked_token_execute)

    listing_url = f"/api/v1/connectors/{drive_connector.id}/composio-drive/folders"
    response = await client.get(listing_url)

    assert response.status_code == 400
    lowered_body = response.text.lower()
    assert "authentication" in lowered_body
    assert "expired" in lowered_body

    await db_session.refresh(drive_connector)
    assert drive_connector.config["auth_expired"] is True

View file

@ -0,0 +1,113 @@
from uuid import UUID
import httpx
import pytest
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.config import config
from app.db import (
SearchSourceConnector,
SearchSourceConnectorType,
SearchSpace,
User,
)
from app.utils.oauth_security import OAuthStateManager
pytestmark = pytest.mark.integration
def _state_for(space_id: int, user_id: UUID, toolkit_id: str = "googledrive") -> str:
    """Build an OAuth ``state`` value for the given space, user and toolkit,
    generated with the app's ``SECRET_KEY``.
    """
    state_manager = OAuthStateManager(config.SECRET_KEY)
    return state_manager.generate_secure_state(
        space_id=space_id,
        user_id=user_id,
        toolkit_id=toolkit_id,
    )
async def _drive_connectors(
    session: AsyncSession,
    *,
    user_id: UUID,
    search_space_id: int,
) -> list[SearchSourceConnector]:
    """Return the Composio Google Drive connectors owned by ``user_id``
    within ``search_space_id``.
    """
    drive_type = SearchSourceConnectorType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR
    criteria = (
        SearchSourceConnector.user_id == user_id,
        SearchSourceConnector.search_space_id == search_space_id,
        SearchSourceConnector.connector_type == drive_type,
    )
    statement = select(SearchSourceConnector).where(*criteria)
    rows = await session.execute(statement)
    return list(rows.scalars().all())
async def test_callback_with_error_param_redirects_to_denied_page(
    client: httpx.AsyncClient,
    db_session: AsyncSession,
    db_user: User,
    db_search_space: SearchSpace,
):
    """A callback carrying ``error=access_denied`` redirects to the
    dashboard's denied page and creates no Drive connector.
    """
    oauth_state = _state_for(db_search_space.id, db_user.id)
    callback_url = (
        f"/api/v1/auth/composio/connector/callback?state={oauth_state}"
        "&error=access_denied"
    )

    response = await client.get(callback_url)

    assert response.status_code in {302, 303, 307}
    redirect_target = response.headers["location"]
    denied_fragment = (
        f"/dashboard/{db_search_space.id}/connectors/callback?"
        "error=composio_oauth_denied"
    )
    assert denied_fragment in redirect_target

    remaining = await _drive_connectors(
        db_session,
        user_id=db_user.id,
        search_space_id=db_search_space.id,
    )
    assert remaining == []
async def test_second_oauth_for_same_toolkit_takes_reconnection_branch(
    client: httpx.AsyncClient,
    db_session: AsyncSession,
    db_user: User,
    db_search_space: SearchSpace,
):
    """Completing the OAuth callback twice for the same toolkit updates the
    existing connector's connected account instead of adding a second row.
    """

    async def _complete_oauth(account_id: str) -> list[SearchSourceConnector]:
        # Run one callback round with a fresh state, then read the rows back.
        oauth_state = _state_for(db_search_space.id, db_user.id)
        response = await client.get(
            "/api/v1/auth/composio/connector/callback"
            f"?state={oauth_state}&connectedAccountId={account_id}"
        )
        assert response.status_code in {302, 303, 307}
        return await _drive_connectors(
            db_session,
            user_id=db_user.id,
            search_space_id=db_search_space.id,
        )

    after_first = await _complete_oauth("fake-acct-googledrive-first")
    assert len(after_first) == 1
    original_connector = after_first[0]
    assert original_connector.config["composio_connected_account_id"] == (
        "fake-acct-googledrive-first"
    )

    after_second = await _complete_oauth("fake-acct-googledrive-second")
    assert len(after_second) == 1
    reconnected = after_second[0]
    assert reconnected.id == original_connector.id
    assert reconnected.config["composio_connected_account_id"] == (
        "fake-acct-googledrive-second"
    )

View file

@ -1,13 +1,26 @@
"""Integration tests: Drive indexer credential resolution for Composio vs native connectors.
"""Integration tests: Drive indexer client + credential resolution.
Exercises ``index_google_drive_files`` with a real PostgreSQL database
containing seeded connector records. Google API and Composio SDK are
mocked at their system boundaries.
Locks in the post-cea8618 architectural contract:
- Composio Drive connectors MUST use ``ComposioDriveClient`` (which routes
through ``composio.tools.execute``) and MUST NOT depend on a raw OAuth
access token via ``ComposioService.get_access_token``.
- Native Drive connectors MUST continue to use ``GoogleDriveClient`` with
credentials loaded from the connector config (no Composio involvement).
- Composio Drive connectors missing ``composio_connected_account_id`` MUST
short-circuit with a clear error before any client is constructed.
Background: prior to ``cea8618`` the Composio path used
``build_composio_credentials → GoogleDriveClient``. That broke in production
once Composio's "Mask Connected Account Secrets" project toggle was on,
because the masked token failed the ``len(access_token) >= 20`` guard in
``ComposioService.get_access_token``. The structural assertions here make
any future regression to that token-based path fail at PR time.
"""
from __future__ import annotations
from unittest.mock import MagicMock, patch
from unittest.mock import ANY, AsyncMock, MagicMock, patch
import pytest
import pytest_asyncio
@ -25,6 +38,19 @@ pytestmark = pytest.mark.integration
_COMPOSIO_ACCOUNT_ID = "composio-test-account-123"
_INDEXER_MODULE = "app.tasks.connector_indexers.google_drive_indexer"
_GET_ACCESS_TOKEN = "app.services.composio_service.ComposioService.get_access_token"
def _mock_drive_client(*, list_files_return: tuple = ([], None, None)) -> MagicMock:
"""Duck-typed client mock whose ``list_files`` yields the supplied tuple.
Returning an empty file list short-circuits the indexer's full-scan
loop after the first page so the test exercises only the
construction + listing path, not download / ETL / DB writes.
"""
mock = MagicMock()
mock.list_files = AsyncMock(return_value=list_files_return)
return mock
@pytest_asyncio.fixture
@ -69,25 +95,31 @@ async def committed_composio_no_account_id(async_engine):
await cleanup_space(async_engine, data["search_space_id"])
@patch(_GET_ACCESS_TOKEN)
@patch(f"{_INDEXER_MODULE}.TaskLoggingService")
@patch(f"{_INDEXER_MODULE}.GoogleDriveClient")
@patch(f"{_INDEXER_MODULE}.build_composio_credentials")
async def test_composio_connector_uses_composio_credentials(
mock_build_creds,
mock_client_cls,
@patch(f"{_INDEXER_MODULE}.ComposioDriveClient")
async def test_composio_drive_indexer_uses_composio_drive_client(
mock_composio_client_cls,
mock_native_client_cls,
mock_task_logger_cls,
mock_get_access_token,
async_engine,
committed_drive_connector,
):
"""Drive indexer calls build_composio_credentials for a Composio connector
and passes the result to GoogleDriveClient."""
"""Composio Drive must construct ComposioDriveClient and never read raw tokens.
Reverting to the pre-cea8618 ``build_composio_credentials → GoogleDriveClient``
path would either trip ``mock_native_client_cls.assert_not_called()`` (because
GoogleDriveClient would be constructed) or ``mock_get_access_token.assert_not_called()``
(because the credential builder reads the raw token).
"""
from app.tasks.connector_indexers.google_drive_indexer import (
index_google_drive_files,
)
data = committed_drive_connector
mock_creds = MagicMock(name="composio-credentials")
mock_build_creds.return_value = mock_creds
mock_composio_client_cls.return_value = _mock_drive_client()
mock_task_logger_cls.return_value = mock_task_logger()
maker = make_session_factory(async_engine)
@ -100,21 +132,29 @@ async def test_composio_connector_uses_composio_credentials(
folder_id="test-folder-id",
)
mock_build_creds.assert_called_once_with(_COMPOSIO_ACCOUNT_ID)
mock_client_cls.assert_called_once()
_, kwargs = mock_client_cls.call_args
assert kwargs.get("credentials") is mock_creds
mock_composio_client_cls.assert_called_once_with(
ANY,
data["connector_id"],
_COMPOSIO_ACCOUNT_ID,
entity_id=f"surfsense_{data['user_id']}",
)
mock_native_client_cls.assert_not_called()
mock_get_access_token.assert_not_called()
@patch(_GET_ACCESS_TOKEN)
@patch(f"{_INDEXER_MODULE}.TaskLoggingService")
@patch(f"{_INDEXER_MODULE}.build_composio_credentials")
@patch(f"{_INDEXER_MODULE}.GoogleDriveClient")
@patch(f"{_INDEXER_MODULE}.ComposioDriveClient")
async def test_composio_connector_without_account_id_returns_error(
mock_build_creds,
mock_composio_client_cls,
mock_native_client_cls,
mock_task_logger_cls,
mock_get_access_token,
async_engine,
committed_composio_no_account_id,
):
"""Drive indexer returns an error when Composio connector lacks connected_account_id."""
"""Missing ``composio_connected_account_id`` must short-circuit before any client construction."""
from app.tasks.connector_indexers.google_drive_indexer import (
index_google_drive_files,
)
@ -134,28 +174,32 @@ async def test_composio_connector_without_account_id_returns_error(
assert count == 0
assert error is not None
assert (
"composio_connected_account_id" in error.lower() or "composio" in error.lower()
)
mock_build_creds.assert_not_called()
assert "composio" in error.lower()
assert "connected_account_id" in error.lower()
mock_composio_client_cls.assert_not_called()
mock_native_client_cls.assert_not_called()
mock_get_access_token.assert_not_called()
@patch(_GET_ACCESS_TOKEN)
@patch(f"{_INDEXER_MODULE}.TaskLoggingService")
@patch(f"{_INDEXER_MODULE}.ComposioDriveClient")
@patch(f"{_INDEXER_MODULE}.GoogleDriveClient")
@patch(f"{_INDEXER_MODULE}.build_composio_credentials")
async def test_native_connector_does_not_use_composio_credentials(
mock_build_creds,
mock_client_cls,
async def test_native_connector_uses_google_drive_client(
mock_native_client_cls,
mock_composio_client_cls,
mock_task_logger_cls,
mock_get_access_token,
async_engine,
committed_native_drive_connector,
):
"""Drive indexer does NOT call build_composio_credentials for a native connector."""
"""Native Drive connector must use GoogleDriveClient (no Composio involvement at all)."""
from app.tasks.connector_indexers.google_drive_indexer import (
index_google_drive_files,
)
data = committed_native_drive_connector
mock_native_client_cls.return_value = _mock_drive_client()
mock_task_logger_cls.return_value = mock_task_logger()
maker = make_session_factory(async_engine)
@ -168,7 +212,6 @@ async def test_native_connector_does_not_use_composio_credentials(
folder_id="test-folder-id",
)
mock_build_creds.assert_not_called()
mock_client_cls.assert_called_once()
_, kwargs = mock_client_cls.call_args
assert kwargs.get("credentials") is None
mock_native_client_cls.assert_called_once_with(ANY, data["connector_id"])
mock_composio_client_cls.assert_not_called()
mock_get_access_token.assert_not_called()

View file

@ -0,0 +1,285 @@
"""Tests for the @-mention resolver.
These tests pin down the contract that ``mention_resolver`` is the
single seam between ``MentionedDocumentInfo`` chips (frontend) and the
canonical ``/documents/...`` virtual paths (agent). The streaming task,
priority middleware, and persistence layer all consume the resolver's
output — keeping the tests focused on substitute-in-text + the
returned id partition keeps the seam stable across refactors.
"""
from __future__ import annotations
from types import SimpleNamespace
from unittest.mock import AsyncMock, MagicMock
import pytest
from app.agents.new_chat import mention_resolver
from app.agents.new_chat.mention_resolver import (
ResolvedMention,
ResolvedMentionSet,
resolve_mentions,
substitute_in_text,
)
from app.agents.new_chat.path_resolver import DOCUMENTS_ROOT, PathIndex
from app.schemas.new_chat import MentionedDocumentInfo
pytestmark = pytest.mark.unit
class TestSubstituteInText:
    """Pin the pure-string contract of ``substitute_in_text``.

    Each ``(token, path)`` pair turns occurrences of ``token`` into the
    backtick-wrapped ``path``. According to the original author's notes the
    transform runs on every cloud-mode turn and is meant to stay behaviour-
    identical to the frontend's ``parseMentionSegments`` (longest token
    first, single forward pass).
    """

    def test_returns_text_unchanged_when_no_tokens(self):
        text = "hello @foo"
        assert substitute_in_text(text, []) == text

    def test_returns_text_unchanged_when_empty(self):
        pairs = [("@x", "/documents/x.xml")]
        assert substitute_in_text("", pairs) == ""

    def test_replaces_single_token_with_backtick_path(self):
        pairs = [("@notes", "/documents/notes.xml")]
        rewritten = substitute_in_text("see @notes please", pairs)
        assert rewritten == "see `/documents/notes.xml` please"

    def test_longest_token_wins_over_prefix(self):
        # The longer ``@Project Roadmap`` token must win outright; the
        # shorter ``@Project`` must not claim its prefix.
        pairs = [
            ("@Project Roadmap", "/documents/Roadmap.xml"),
            ("@Project", "/documents/Project.xml"),
        ]
        rewritten = substitute_in_text("about @Project Roadmap today", pairs)
        assert rewritten == "about `/documents/Roadmap.xml` today"

    def test_handles_repeated_mentions(self):
        pairs = [
            ("@A", "/documents/a.xml"),
            ("@B", "/documents/b.xml"),
        ]
        rewritten = substitute_in_text("@A and @A again @B", pairs)
        expected = (
            "`/documents/a.xml` and `/documents/a.xml` again `/documents/b.xml`"
        )
        assert rewritten == expected

    def test_does_not_match_inside_word(self):
        # Despite the test name, matching is purely positional: there are
        # no word-boundary semantics, so ``@Pro`` glued to the end of
        # ``foo`` is still replaced. Pinned on purpose so the backend does
        # not drift from the frontend behaviour.
        rewritten = substitute_in_text("foo@Pro", [("@Pro", "/documents/p.xml")])
        assert rewritten == "foo`/documents/p.xml`"

    def test_idempotent_after_substitution(self):
        # After one pass the text begins with a backtick rather than
        # ``@``, so a second pass has nothing left to replace.
        first_pass = substitute_in_text("@A", [("@A", "/documents/a.xml")])
        second_pass = substitute_in_text(first_pass, [("@A", "/documents/a.xml")])
        assert first_pass == second_pass
class TestResolveMentions:
    """``resolve_mentions`` resolves chip ids → virtual paths and emits
    a ``ResolvedMentionSet`` whose id partitions feed
    ``KnowledgePriorityMiddleware``.

    The DB session and ``build_path_index`` are stubbed in every test; the
    two private helpers below hold that scaffolding so each test body shows
    only its inputs and assertions.
    """

    @staticmethod
    def _session_returning(rows):
        """Return a session mock whose awaited ``execute`` yields ``rows``.

        Every ``await session.execute(...)`` resolves to the same result
        object, and ``result.scalars().all()`` returns ``rows``.
        """
        scalars = MagicMock()
        scalars.all.return_value = rows
        result = MagicMock()
        result.scalars.return_value = scalars
        session = MagicMock()
        session.execute = AsyncMock(return_value=result)
        return session

    @staticmethod
    def _stub_path_index(monkeypatch, **index_kwargs):
        """Replace ``mention_resolver.build_path_index`` with an async stub.

        The stub ignores its arguments and returns a fresh
        ``PathIndex(**index_kwargs)`` on every call.
        """

        async def fake_build_index(_session, _ssid):
            return PathIndex(**index_kwargs)

        monkeypatch.setattr(mention_resolver, "build_path_index", fake_build_index)

    @pytest.mark.asyncio
    async def test_returns_empty_when_no_mentions(self):
        session = MagicMock()
        session.execute = AsyncMock()

        result = await resolve_mentions(
            session,
            search_space_id=1,
            mentioned_documents=None,
        )

        assert isinstance(result, ResolvedMentionSet)
        assert result.mentions == []
        assert result.token_to_path == []
        assert result.mentioned_document_ids == []
        assert result.mentioned_folder_ids == []
        # No DB roundtrips when there's nothing to resolve.
        session.execute.assert_not_awaited()

    @pytest.mark.asyncio
    async def test_resolves_doc_chip_to_virtual_path(self, monkeypatch):
        chip = MentionedDocumentInfo(
            id=42,
            title="Notes",
            document_type="EXTENSION",
            kind="doc",
        )
        doc_row = SimpleNamespace(id=42, title="Notes", folder_id=None)
        self._stub_path_index(monkeypatch)
        session = self._session_returning([doc_row])

        out = await resolve_mentions(
            session,
            search_space_id=5,
            mentioned_documents=[chip],
        )

        assert len(out.mentions) == 1
        mention = out.mentions[0]
        assert mention.kind == "doc"
        assert mention.id == 42
        assert mention.virtual_path == f"{DOCUMENTS_ROOT}/Notes.xml"
        assert out.mentioned_document_ids == [42]
        assert out.mentioned_folder_ids == []
        assert ("@Notes", f"{DOCUMENTS_ROOT}/Notes.xml") in out.token_to_path

    @pytest.mark.asyncio
    async def test_resolves_folder_chip_with_trailing_slash(self, monkeypatch):
        chip = MentionedDocumentInfo(
            id=9,
            title="Reports",
            document_type="FOLDER",
            kind="folder",
        )
        folder_row = SimpleNamespace(id=9, name="Reports")
        # The index stores the folder path without a trailing slash; the
        # resolved mention is expected to carry one.
        self._stub_path_index(
            monkeypatch, folder_paths={9: f"{DOCUMENTS_ROOT}/Reports"}
        )
        session = self._session_returning([folder_row])

        out = await resolve_mentions(
            session,
            search_space_id=3,
            mentioned_documents=[chip],
        )

        assert len(out.mentions) == 1
        mention = out.mentions[0]
        assert mention.kind == "folder"
        assert mention.id == 9
        assert mention.virtual_path == f"{DOCUMENTS_ROOT}/Reports/"
        assert out.mentioned_document_ids == []
        assert out.mentioned_folder_ids == [9]

    @pytest.mark.asyncio
    async def test_drops_chip_when_doc_is_missing(self, monkeypatch):
        chip = MentionedDocumentInfo(
            id=99, title="ghost", document_type="EXTENSION", kind="doc"
        )
        self._stub_path_index(monkeypatch)
        # The lookup returns no rows, so the chip has nothing to resolve to.
        session = self._session_returning([])

        out = await resolve_mentions(
            session,
            search_space_id=1,
            mentioned_documents=[chip],
        )

        assert out.mentions == []
        assert out.mentioned_document_ids == []
        assert out.token_to_path == []

    @pytest.mark.asyncio
    async def test_token_to_path_is_longest_first(self, monkeypatch):
        # Two chips whose titles are prefixes of each other — the
        # resolver MUST sort longest-first so substitution doesn't
        # break the ``@Project Roadmap`` vs ``@Project`` invariant.
        chip_short = MentionedDocumentInfo(
            id=1, title="A", document_type="EXTENSION", kind="doc"
        )
        chip_long = MentionedDocumentInfo(
            id=2, title="A long one", document_type="EXTENSION", kind="doc"
        )
        rows = [
            SimpleNamespace(id=1, title="A", folder_id=None),
            SimpleNamespace(id=2, title="A long one", folder_id=None),
        ]
        self._stub_path_index(monkeypatch)
        session = self._session_returning(rows)

        out = await resolve_mentions(
            session,
            search_space_id=1,
            mentioned_documents=[chip_short, chip_long],
        )

        tokens = [tok for tok, _ in out.token_to_path]
        assert tokens == sorted(tokens, key=len, reverse=True)

    @pytest.mark.asyncio
    async def test_legacy_id_arrays_resolve_without_chip_metadata(self, monkeypatch):
        # ``mentioned_document_ids`` (the legacy parallel array) must
        # still resolve when no chip metadata is available — covers
        # callers that haven't migrated to the discriminated chip list.
        doc_row = SimpleNamespace(id=7, title="Legacy", folder_id=None)
        self._stub_path_index(monkeypatch)
        session = self._session_returning([doc_row])

        out = await resolve_mentions(
            session,
            search_space_id=2,
            mentioned_documents=None,
            mentioned_document_ids=[7],
        )

        assert out.mentioned_document_ids == [7]
        assert len(out.mentions) == 1
        assert out.mentions[0].title == "Legacy"
class TestResolvedMentionEquality:
    """Smoke check that ``ResolvedMention`` compares by field values, which
    the other tests rely on when asserting resolver output."""

    def test_equal_when_fields_equal(self):
        field_values = {
            "kind": "doc",
            "id": 1,
            "title": "x",
            "virtual_path": "/documents/x.xml",
        }
        left = ResolvedMention(**field_values)
        right = ResolvedMention(**field_values)
        assert left == right

View file

@ -196,3 +196,50 @@ class TestVirtualPathToDoc:
)
assert document is target_doc
assert session.execute.await_count == 2
@pytest.mark.asyncio
async def test_resolves_double_extension_for_uploaded_pdf(self):
    # Regression guard. Per the original note, KB documents are shown
    # under ``/documents/`` with ``.xml`` appended (via ``safe_filename``),
    # so a row titled ``2025-W2.pdf`` appears as
    # ``/documents/2025-W2.pdf.xml``. That rendered path has to resolve
    # back to the row although the stored title has no ``.xml`` suffix.
    uploaded_pdf = SimpleNamespace(id=99, title="2025-W2.pdf", folder_id=None)

    # First execute() result is empty; the second carries the row.
    query_results = [
        _result_from_one(None),
        _result_from_scalars([uploaded_pdf]),
    ]
    session = MagicMock()
    session.execute = AsyncMock(side_effect=query_results)

    resolved = await virtual_path_to_doc(
        session,
        search_space_id=5,
        virtual_path=f"{DOCUMENTS_ROOT}/2025-W2.pdf.xml",
    )

    assert resolved is uploaded_pdf
@pytest.mark.asyncio
async def test_resolves_path_without_xml_suffix(self):
    # A caller (or a hand-edited link) may supply the bare title form
    # ``/documents/2025-W2.pdf``; the row must still be found through a
    # literal title match.
    uploaded_pdf = SimpleNamespace(id=99, title="2025-W2.pdf", folder_id=None)

    # First execute() result is empty; the second carries the row.
    query_results = [
        _result_from_one(None),
        _result_from_scalars([uploaded_pdf]),
    ]
    session = MagicMock()
    session.execute = AsyncMock(side_effect=query_results)

    resolved = await virtual_path_to_doc(
        session,
        search_space_id=5,
        virtual_path=f"{DOCUMENTS_ROOT}/2025-W2.pdf",
    )

    assert resolved is uploaded_pdf

View file

@ -0,0 +1,38 @@
from tests.e2e.fakes.composio_module import _drive_list_files
def _ids(result: dict) -> set[str]:
return {item["id"] for item in result["data"]["files"]}
def test_drive_list_files_filters_shortcuts_and_trashed_items():
    """A root query excluding trashed items and shortcuts returns the canary
    file but neither the shortcut nor the trashed fixture."""
    query = (
        "'root' in parents and trashed = false and "
        "mimeType != 'application/vnd.google-apps.shortcut'"
    )

    listed = _ids(_drive_list_files({"q": query}))

    assert "fake-file-canary" in listed
    assert "fake-shortcut-canary" not in listed
    assert "fake-file-trashed" not in listed
def test_drive_list_files_filters_to_exact_mime_type():
    """An exact ``mimeType = 'text/plain'`` clause narrows the root listing to
    the single canary file."""
    query = "'root' in parents and trashed = false and mimeType = 'text/plain'"

    listed = _ids(_drive_list_files({"q": query}))

    assert listed == {"fake-file-canary"}
def test_drive_list_files_uses_requested_parent_folder():
    """The parent named in the query decides which files are listed: the
    projects folder yields only the roadmap file."""
    query = "'fake-folder-projects' in parents and trashed = false"

    listed = _ids(_drive_list_files({"q": query}))

    assert listed == {"fake-file-roadmap"}

View file

@ -1,8 +1,17 @@
"""Unit tests: build_composio_credentials returns valid Google Credentials.
"""Unit tests: Composio credential helpers + ``get_access_token`` masking guard.
Mocks the Composio SDK (external system boundary) and verifies that the
returned ``google.oauth2.credentials.Credentials`` object is correctly
configured with a token and a working refresh handler.
Covers two seams between Surfsense and Composio:
1. ``build_composio_credentials`` returns a ``google.oauth2.credentials.Credentials``
object with a working refresh handler (mocks the whole ``ComposioService``).
2. ``ComposioService.get_access_token`` rejects masked / missing tokens with
actionable error messages (mocks only the Composio SDK boundary so the
real guard logic is exercised).
The masking guard is the boundary handler that production tripped over when
Composio's "Mask Connected Account Secrets" project setting was enabled.
The corresponding fix landed in ``cea8618``; these tests lock that contract
in place so any future weakening of the guard surfaces immediately.
"""
from datetime import UTC, datetime
@ -14,6 +23,11 @@ from google.oauth2.credentials import Credentials
pytestmark = pytest.mark.unit
# ---------------------------------------------------------------------------
# build_composio_credentials — high-level wrapper tests
# ---------------------------------------------------------------------------
@patch("app.services.composio_service.ComposioService")
def test_returns_credentials_with_token_and_expiry(mock_composio_service):
"""build_composio_credentials returns a Credentials object with the Composio access token."""
@ -54,3 +68,85 @@ def test_refresh_handler_fetches_fresh_token(mock_composio_service):
assert new_token == "refreshed-token"
assert new_expiry > datetime.now(UTC).replace(tzinfo=None)
assert mock_service.get_access_token.call_count == 2
# ---------------------------------------------------------------------------
# ComposioService.get_access_token — boundary masking guard tests
# ---------------------------------------------------------------------------
def _service_with_account(account: object):
    """Build a real ``ComposioService`` on top of a faked Composio SDK.

    Only the ``Composio`` class in the service module is patched — the real
    ``get_access_token`` method runs, so changes to the masking /
    missing-token guards surface in the tests that use this helper. The fake
    SDK client returns ``account`` from ``connected_accounts.get(...)``.
    """
    from app.services import composio_service as composio_service_module

    fake_sdk_client = MagicMock()
    fake_sdk_client.connected_accounts.get.return_value = account

    with patch.object(
        composio_service_module, "Composio", return_value=fake_sdk_client
    ):
        # The service is built while the patch is active. The original
        # author notes that the constructor captures the client during
        # init, so ``service.client`` keeps pointing at the fake after the
        # patch is undone on return.
        return composio_service_module.ComposioService(api_key="unit-test-api-key")
@pytest.mark.parametrize("masked_token", ["x", "xxxxxxxx", "x" * 19])
def test_get_access_token_raises_on_masked_token(masked_token):
    """Tokens below the 20-char unmask threshold must raise with the dashboard hint.

    The parametrized tokens are 1, 8 and 19 characters long. Each must make
    ``get_access_token`` raise a ``ValueError`` whose message names the
    "Mask Connected Account Secrets" project setting, i.e. the exact toggle
    an operator has to change in the Composio dashboard.
    """
    masked_account = MagicMock()
    masked_account.state.val.access_token = masked_token

    composio = _service_with_account(masked_account)

    with pytest.raises(ValueError, match="Mask Connected Account Secrets"):
        composio.get_access_token("any-account-id")
def test_get_access_token_raises_when_state_val_missing():
    """An account whose ``state`` is ``None`` must raise a ``ValueError`` that
    mentions the missing ``state.val`` and the account id."""
    stateless_account = MagicMock()
    stateless_account.state = None

    composio = _service_with_account(stateless_account)

    with pytest.raises(ValueError, match=r"No state\.val.*missing-state-account"):
        composio.get_access_token("missing-state-account")
def test_get_access_token_raises_when_access_token_empty():
    """An empty-string ``access_token`` must be reported as missing (with the
    account id), not as a masked token."""
    empty_token_account = MagicMock()
    empty_token_account.state.val.access_token = ""

    composio = _service_with_account(empty_token_account)

    with pytest.raises(ValueError, match=r"No access_token.*missing-token-account"):
        composio.get_access_token("missing-token-account")
def test_get_access_token_raises_when_access_token_none():
    """A ``None`` ``access_token`` must be reported as missing (with the
    account id), not as a masked token."""
    none_token_account = MagicMock()
    none_token_account.state.val.access_token = None

    composio = _service_with_account(none_token_account)

    with pytest.raises(ValueError, match=r"No access_token.*none-token-account"):
        composio.get_access_token("none-token-account")
def test_get_access_token_returns_unmasked_token():
    """Happy path: a 32-character access token is handed back unchanged."""
    real_token = "u" * 32
    healthy_account = MagicMock()
    healthy_account.state.val.access_token = real_token

    composio = _service_with_account(healthy_account)

    assert composio.get_access_token("happy-account") == real_token

View file

@ -118,12 +118,8 @@ def test_get_by_tool_call_id_matches_action_request_payload() -> None:
tasks=(
_Task(
interrupts=(
_Interrupt(
value=_hitl("a", tool_call_id="call_xxx"), id="int_a"
),
_Interrupt(
value=_hitl("b", tool_call_id="call_yyy"), id="int_b"
),
_Interrupt(value=_hitl("a", tool_call_id="call_xxx"), id="int_a"),
_Interrupt(value=_hitl("b", tool_call_id="call_yyy"), id="int_b"),
)
),
)
@ -146,9 +142,7 @@ def test_first_pending_interrupt_matches_legacy_first_wins_behaviour() -> None:
def test_interrupt_without_id_falls_back_to_none() -> None:
"""Snapshots from older LangGraph versions may omit ``id`` — preserve that."""
state = _State(
tasks=(_Task(interrupts=(_Interrupt(value=_hitl("a"), id=None),)),)
)
state = _State(tasks=(_Task(interrupts=(_Interrupt(value=_hitl("a"), id=None),)),))
pending = list_pending_interrupts(state)
assert len(pending) == 1
assert pending[0].interrupt_id is None

View file

@ -37,9 +37,7 @@ def test_custom_interrupt_primitive_is_converted_to_canonical_shape() -> None:
"context": {"reason": "destructive"},
}
out = normalize_interrupt_payload(raw)
assert out["action_requests"] == [
{"name": "send_email", "args": {"to": "a@b"}}
]
assert out["action_requests"] == [{"name": "send_email", "args": {"to": "a@b"}}]
assert out["review_configs"] == [
{
"action_name": "send_email",

View file

@ -158,9 +158,7 @@ def _classify_cases() -> list[Exception]:
"""Inputs that the FE depends on being mapped to specific error codes."""
return [
Exception("totally generic error"),
Exception(
'{"error":{"type":"rate_limit_error","message":"slow down"}}'
),
Exception('{"error":{"type":"rate_limit_error","message":"slow down"}}'),
Exception(
'OpenrouterException - {"error":{"message":"Provider returned error",'
'"code":429}}'
@ -220,7 +218,7 @@ class _FakeStreamingService:
self.calls.append(
{"message": message, "error_code": error_code, "extra": extra}
)
return f"data: {{\"type\":\"error\",\"errorText\":\"{message}\"}}\n\n"
return f'data: {{"type":"error","errorText":"{message}"}}\n\n'
def test_emit_stream_terminal_error_matches_old_output_and_logs(caplog) -> None:

View file

@ -60,8 +60,14 @@ async def test_stream_output_emits_text_lifecycle_and_updates_result() -> None:
service = _StreamingService()
agent = _Agent(
[
{"event": "on_chat_model_stream", "data": {"chunk": _Chunk(content="Hello")}},
{"event": "on_chat_model_stream", "data": {"chunk": _Chunk(content=" world")}},
{
"event": "on_chat_model_stream",
"data": {"chunk": _Chunk(content="Hello")},
},
{
"event": "on_chat_model_stream",
"data": {"chunk": _Chunk(content=" world")},
},
]
)
result = StreamingResult()

View file

@ -37,7 +37,9 @@ def test_clear_ignored_for_non_task_tool() -> None:
def test_clear_ignored_when_task_run_id_mismatches() -> None:
state = AgentEventRelayState.for_invocation()
open_task_span(state, run_id="run-open")
clear_task_span_if_delegating_task_ended(state, tool_name="task", run_id="run-other")
clear_task_span_if_delegating_task_ended(
state, tool_name="task", run_id="run-other"
)
assert state.active_span_id is not None
assert state.active_task_run_id == "run-open"

View file

@ -240,9 +240,7 @@ class TestToolHeavyTurn:
class TestToolCallSpanMetadata:
def test_input_available_merges_new_metadata_keys_after_start(self):
b = AssistantContentBuilder()
b.on_tool_input_start(
"call_t", "task", "lc_t", metadata={"spanId": "spn_1"}
)
b.on_tool_input_start("call_t", "task", "lc_t", metadata={"spanId": "spn_1"})
b.on_tool_input_available(
"call_t",
"task",
@ -257,9 +255,7 @@ class TestToolCallSpanMetadata:
def test_input_available_does_not_overwrite_existing_metadata_keys(self):
b = AssistantContentBuilder()
b.on_tool_input_start(
"call_t", "task", "lc_t", metadata={"spanId": "spn_keep"}
)
b.on_tool_input_start("call_t", "task", "lc_t", metadata={"spanId": "spn_keep"})
b.on_tool_input_available(
"call_t", "task", {}, "lc_t", metadata={"spanId": "spn_other"}
)

View file

@ -0,0 +1,93 @@
import base64
import hashlib
import hmac
import json
import time
from uuid import uuid4
import pytest
from fastapi import HTTPException
from app.utils.oauth_security import OAuthStateManager
SECRET = "unit-test-secret"
def _encode_state(payload: dict, *, signature: str | None = None) -> str:
    """Serialize ``payload`` into a signed, URL-safe base64 state string.

    The HMAC-SHA256 hex digest is computed with ``SECRET`` over the payload
    dumped as JSON with sorted keys. The payload plus a ``"signature"`` entry
    is then JSON-encoded and base64url-encoded. The layout is intended to be
    compatible with ``OAuthStateManager``.

    Args:
        payload: Fields to embed in the state; the caller's dict is not mutated.
        signature: When given, stored instead of the computed digest so tests
            can forge a tampered state.

    Returns:
        The base64url-encoded state as ``str``.
    """
    body = dict(payload)
    canonical = json.dumps(body, sort_keys=True).encode()
    digest = hmac.new(SECRET.encode(), canonical, hashlib.sha256).hexdigest()

    body["signature"] = digest if signature is None else signature
    raw = json.dumps(body).encode()
    return base64.urlsafe_b64encode(raw).decode()
def test_validate_state_accepts_fresh_signed_state():
    """A state produced by ``generate_secure_state`` validates on the same
    manager and decodes to the original fields (``user_id`` as a string)."""
    owner = uuid4()
    manager = OAuthStateManager(secret_key=SECRET, max_age_seconds=600)

    fresh_state = manager.generate_secure_state(
        space_id=1,
        user_id=owner,
        toolkit_id="googledrive",
    )
    payload = manager.validate_state(fresh_state)

    assert payload["space_id"] == 1
    assert payload["user_id"] == str(owner)
    assert payload["toolkit_id"] == "googledrive"
def test_validate_state_rejects_expired_state():
    """A correctly signed state stamped one hour ago is rejected with HTTP 400
    and an "expired" detail when the max age is 600 seconds."""
    manager = OAuthStateManager(secret_key=SECRET, max_age_seconds=600)
    stale_payload = {
        "space_id": 1,
        "user_id": str(uuid4()),
        "timestamp": int(time.time()) - 3600,
        "toolkit_id": "googledrive",
    }
    stale_state = _encode_state(stale_payload)

    with pytest.raises(HTTPException) as raised:
        manager.validate_state(stale_state)

    error = raised.value
    assert error.status_code == 400
    assert "expired" in error.detail.lower()
def test_validate_state_rejects_tampered_signature():
    """A fresh state carrying a forged signature is rejected with HTTP 400 and
    a "tampering" detail."""
    manager = OAuthStateManager(secret_key=SECRET, max_age_seconds=600)
    current_payload = {
        "space_id": 1,
        "user_id": str(uuid4()),
        "timestamp": int(time.time()),
        "toolkit_id": "googledrive",
    }
    # 64 hex characters, the length of a SHA-256 hex digest, but not a valid one.
    forged_state = _encode_state(current_payload, signature="deadbeef" * 8)

    with pytest.raises(HTTPException) as raised:
        manager.validate_state(forged_state)

    error = raised.value
    assert error.status_code == 400
    assert "tampering" in error.detail.lower()
def test_validate_state_rejects_malformed_state():
    """A string that is neither valid base64 nor JSON is rejected with HTTP
    400 and an "invalid state format" detail."""
    manager = OAuthStateManager(secret_key=SECRET)
    garbage_state = "not-base64-and-not-json"

    with pytest.raises(HTTPException) as raised:
        manager.validate_state(garbage_state)

    error = raised.value
    assert error.status_code == 400
    assert "invalid state format" in error.detail.lower()