mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-05-19 18:45:15 +02:00
Merge upstream/dev into feature/multi-agent
This commit is contained in:
commit
246dae40a8
229 changed files with 36484 additions and 436 deletions
|
|
@ -24,4 +24,6 @@ wheels/
|
|||
*.egg
|
||||
.pytest_cache/
|
||||
.coverage
|
||||
htmlcov/
|
||||
htmlcov/
|
||||
|
||||
tests/
|
||||
|
|
@ -46,6 +46,10 @@ class SurfSenseContextSchema:
|
|||
Read by ``KnowledgePriorityMiddleware`` to seed its priority
|
||||
list. Stays out of the compiled-agent cache key — that's the
|
||||
whole point of putting it here.
|
||||
mentioned_folder_ids: KB folders the user @-mentioned this turn
|
||||
(cloud filesystem mode). Surfaced as ``[USER-MENTIONED]``
|
||||
entries in ``<priority_documents>`` so the agent prioritises
|
||||
walking those folders with ``ls`` / ``find_documents``.
|
||||
file_operation_contract: One-shot file operation contract emitted
|
||||
by ``FileIntentMiddleware`` for the upcoming turn.
|
||||
turn_id / request_id: Correlation IDs surfaced by the streaming
|
||||
|
|
@ -59,6 +63,7 @@ class SurfSenseContextSchema:
|
|||
|
||||
search_space_id: int | None = None
|
||||
mentioned_document_ids: list[int] = field(default_factory=list)
|
||||
mentioned_folder_ids: list[int] = field(default_factory=list)
|
||||
file_operation_contract: FileOperationContractState | None = None
|
||||
turn_id: str | None = None
|
||||
request_id: str | None = None
|
||||
|
|
|
|||
281
surfsense_backend/app/agents/new_chat/mention_resolver.py
Normal file
281
surfsense_backend/app/agents/new_chat/mention_resolver.py
Normal file
|
|
@ -0,0 +1,281 @@
|
|||
"""Resolve @-mention chips to canonical virtual paths and substitute the
|
||||
user-visible ``@title`` tokens with backtick-wrapped paths in the prompt
|
||||
the agent sees.
|
||||
|
||||
The frontend's mention seam is a single discriminated-union list of
|
||||
``{kind: "doc" | "folder", id, title, document_type?}`` chips (see
|
||||
``surfsense_web/atoms/chat/mentioned-documents.atom.ts``). When a turn
|
||||
reaches the backend stream task we have three needs that this module
|
||||
centralises:
|
||||
|
||||
1. Map each chip to its canonical virtual path
|
||||
(``/documents/.../file.xml`` for docs, ``/documents/MyFolder/`` for
|
||||
folders) so the agent sees concrete filesystem locations instead of
|
||||
ambiguous ``@``-titles.
|
||||
2. Substitute ``@title`` tokens in the user-typed text with backtick-
|
||||
wrapped paths so the path becomes part of the ``HumanMessage`` body
|
||||
the LLM consumes — without rewriting the persisted user message
|
||||
text (which keeps ``@title`` so chip rendering on reload is
|
||||
unchanged).
|
||||
3. Surface the resolved id sets (docs + folders) to the priority
|
||||
middleware so it can render ``[USER-MENTIONED]`` priority entries
|
||||
without re-doing path resolution.
|
||||
|
||||
This is intentionally one module — see the architectural note in
|
||||
``mention-paths-and-folders`` plan: previously the doc-resolution lived
|
||||
inline in ``stream_new_chat`` and the folder mention had no resolution
|
||||
at all. Centralising both behind a single ``resolve_mentions`` call
|
||||
turns a leaky multi-field seam into a single deeper interface.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.agents.new_chat.path_resolver import (
|
||||
DOCUMENTS_ROOT,
|
||||
build_path_index,
|
||||
doc_to_virtual_path,
|
||||
)
|
||||
from app.db import Document, Folder
|
||||
from app.schemas.new_chat import MentionedDocumentInfo
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class ResolvedMention:
    """One @-mention chip after it has been resolved to a virtual path.

    Instances are immutable. ``virtual_path`` is the location shown to
    the agent: document paths carry no trailing slash, folder paths end
    with ``/`` (the convention ``KnowledgeTreeMiddleware`` uses).
    """

    # Chip discriminator: "doc" | "folder".
    kind: str
    # Database id of the referenced document or folder row.
    id: int
    # Display title; ``@`` + this title is the token replaced in the prompt.
    title: str
    # Canonical agent-facing path for the chip.
    virtual_path: str
|
||||
|
||||
|
||||
@dataclass
class ResolvedMentionSet:
    """Everything produced by resolving the mention chips of one turn.

    ``token_to_path`` pairs each literal ``@title`` token (as typed by
    the user and emitted by the editor) with the canonical virtual path
    of its chip. The pairs are ordered longest token first, mirroring
    ``parseMentionSegments`` on the frontend, so that a long title such
    as ``@Project Roadmap`` is never shadowed by the shorter prefix
    ``@Project`` during substitution.

    ``mentioned_document_ids`` merges doc and surfsense_doc chips into
    one ordered, duplicate-free list, since the priority middleware
    handles them uniformly downstream — see
    ``KnowledgePriorityMiddleware._compute_priority_paths``.
    """

    # Every chip that resolved successfully.
    mentions: list[ResolvedMention] = field(default_factory=list)
    # (``@title`` token, virtual path) pairs, longest token first.
    token_to_path: list[tuple[str, str]] = field(default_factory=list)
    # Ids of the documents that resolved.
    mentioned_document_ids: list[int] = field(default_factory=list)
    # Ids of the folders that resolved.
    mentioned_folder_ids: list[int] = field(default_factory=list)
|
||||
|
||||
|
||||
def _folder_virtual_path(folder_id: int, folder_paths: dict[int, str]) -> str:
    """Build the directory-style virtual path for ``folder_id``.

    The path is looked up in ``folder_paths``; an id that is absent from
    the mapping (for example a deleted folder, or one belonging to
    another search space) resolves to ``DOCUMENTS_ROOT`` instead. The
    result always ends with ``/`` — the same shape
    ``KnowledgeTreeMiddleware`` uses (``/documents/MyFolder/``) — so the
    agent's ``ls`` can treat it as a directory.
    """
    location = folder_paths.get(folder_id, DOCUMENTS_ROOT)
    if location.endswith("/"):
        return location
    return location + "/"
|
||||
|
||||
|
||||
async def resolve_mentions(
    session: AsyncSession,
    *,
    search_space_id: int,
    mentioned_documents: list[MentionedDocumentInfo] | None,
    mentioned_document_ids: list[int] | None = None,
    mentioned_surfsense_doc_ids: list[int] | None = None,
    mentioned_folder_ids: list[int] | None = None,
) -> ResolvedMentionSet:
    """Turn the @-mention chips of a turn into canonical virtual paths.

    Both the ``mentioned_documents`` chip list (metadata used for
    substitution and persistence) and the parallel id arrays
    (``mentioned_document_ids``, ``mentioned_surfsense_doc_ids``,
    ``mentioned_folder_ids``) are accepted:

    * Legacy clients that have not adopted the unified chip list still
      send only the id arrays, so the union of both is authoritative.
    * The id arrays are what ``KnowledgePriorityMiddleware`` consumes
      (via ``SurfSenseContextSchema``); the deduped, validated lists
      returned here can be forwarded by the route as-is.

    Resolution is best-effort. An id with no matching row in
    ``search_space_id`` is dropped with a debug log, and the user's
    text simply receives no path substitution for that chip.

    Args:
        session: Async database session used for the lookups.
        search_space_id: Search space every lookup is restricted to.
        mentioned_documents: Chip metadata; may be ``None`` or empty.
        mentioned_document_ids: Document ids sent alongside the chips.
        mentioned_surfsense_doc_ids: SurfSense doc ids, pooled with the
            document ids.
        mentioned_folder_ids: Folder ids sent alongside the chips.

    Returns:
        A ``ResolvedMentionSet``; empty when nothing was mentioned.
    """
    # Split chips by kind and remember the client-supplied titles, keyed
    # by (kind, id) so a doc and a folder sharing an id stay separate.
    ids_from_doc_chips: list[int] = []
    ids_from_folder_chips: list[int] = []
    chip_titles: dict[tuple[str, int], str] = {}
    for chip in mentioned_documents or ():
        chip_kind = chip.kind
        target = ids_from_folder_chips if chip_kind == "folder" else ids_from_doc_chips
        target.append(chip.id)
        chip_titles[(chip_kind, chip.id)] = chip.title

    # ``dict.fromkeys`` removes duplicates while keeping first-seen order.
    doc_id_pool: list[int] = list(
        dict.fromkeys(
            (
                *(mentioned_document_ids or []),
                *(mentioned_surfsense_doc_ids or []),
                *ids_from_doc_chips,
            )
        )
    )
    folder_id_pool: list[int] = list(
        dict.fromkeys((*(mentioned_folder_ids or []), *ids_from_folder_chips))
    )

    if not (doc_id_pool or folder_id_pool):
        return ResolvedMentionSet()

    index = await build_path_index(session, search_space_id)

    # Fetch only rows that belong to this search space.
    doc_rows: dict[int, Document] = {}
    if doc_id_pool:
        doc_query = select(Document).where(
            Document.search_space_id == search_space_id,
            Document.id.in_(doc_id_pool),
        )
        fetched_docs = await session.execute(doc_query)
        doc_rows = {doc.id: doc for doc in fetched_docs.scalars().all()}

    folder_rows: dict[int, Folder] = {}
    if folder_id_pool:
        folder_query = select(Folder).where(
            Folder.search_space_id == search_space_id,
            Folder.id.in_(folder_id_pool),
        )
        fetched_folders = await session.execute(folder_query)
        folder_rows = {folder.id: folder for folder in fetched_folders.scalars().all()}

    resolved: list[ResolvedMention] = []

    # Documents first, in pool order; ids without a row are skipped.
    for doc_id in doc_id_pool:
        doc = doc_rows.get(doc_id)
        if doc is None:
            logger.debug(
                "mention_resolver: dropping doc id=%s (not found in space=%s)",
                doc_id,
                search_space_id,
            )
            continue
        # The chip title wins over the stored title so the token matches
        # what the editor emitted.
        doc_title = chip_titles.get(("doc", doc_id), str(doc.title or ""))
        doc_path = doc_to_virtual_path(
            doc_id=doc.id,
            title=str(doc.title or "untitled"),
            folder_id=doc.folder_id,
            index=index,
        )
        resolved.append(
            ResolvedMention(
                kind="doc", id=doc.id, title=doc_title, virtual_path=doc_path
            )
        )

    # Then folders, in pool order.
    for folder_id in folder_id_pool:
        folder = folder_rows.get(folder_id)
        if folder is None:
            logger.debug(
                "mention_resolver: dropping folder id=%s (not found in space=%s)",
                folder_id,
                search_space_id,
            )
            continue
        folder_title = chip_titles.get(("folder", folder_id), str(folder.name or ""))
        folder_path = _folder_virtual_path(folder.id, index.folder_paths)
        resolved.append(
            ResolvedMention(
                kind="folder",
                id=folder.id,
                title=folder_title,
                virtual_path=folder_path,
            )
        )

    # The first mention seen for a given token keeps it; untitled
    # mentions produce no token.
    path_by_token: dict[str, str] = {}
    for mention in resolved:
        if mention.title:
            path_by_token.setdefault(f"@{mention.title}", mention.virtual_path)
    # Stable sort, longest token first.
    token_to_path: list[tuple[str, str]] = sorted(
        path_by_token.items(), key=lambda pair: len(pair[0]), reverse=True
    )

    return ResolvedMentionSet(
        mentions=resolved,
        token_to_path=token_to_path,
        mentioned_document_ids=[m.id for m in resolved if m.kind == "doc"],
        mentioned_folder_ids=[m.id for m in resolved if m.kind == "folder"],
    )
|
||||
|
||||
|
||||
def substitute_in_text(text: str, token_to_path: list[tuple[str, str]]) -> str:
    """Replace each ``@title`` token with a backtick-wrapped virtual path.

    Mirrors ``parseMentionSegments`` on the frontend: a single forward
    pass over ``text`` where, at every offset, the first entry of
    ``token_to_path`` that matches wins. Callers should therefore pass
    the pairs longest token first (``resolve_mentions`` already does) so
    a short title never shadows a longer one. Matching uses plain string
    comparison, not regex, because titles can contain regex
    metacharacters.

    Inserted paths are not rescanned, so a replacement can never trigger
    a further substitution within the same call.

    Args:
        text: The user-typed message text.
        token_to_path: ``(token, virtual_path)`` pairs in priority order.
            Pairs with an empty token are ignored.

    Returns:
        ``text`` with every matched token replaced by `` `path` ``. When
        ``text`` is empty or no usable token exists, ``text`` is
        returned unchanged without scanning.
    """
    # An empty token matches at every offset without consuming any
    # input, which would stall the scan forever — discard such pairs.
    candidates = [(token, path) for token, path in token_to_path if token]
    if not text or not candidates:
        return text

    out: list[str] = []
    i = 0
    n = len(text)
    while i < n:
        for token, path in candidates:
            if text.startswith(token, i):
                out.append(f"`{path}`")
                i += len(token)
                break
        else:
            # No token starts here: copy one character through verbatim.
            out.append(text[i])
            i += 1
    return "".join(out)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"ResolvedMention",
|
||||
"ResolvedMentionSet",
|
||||
"resolve_mentions",
|
||||
"substitute_in_text",
|
||||
]
|
||||
|
|
@ -54,6 +54,7 @@ from app.db import (
|
|||
NATIVE_TO_LEGACY_DOCTYPE,
|
||||
Chunk,
|
||||
Document,
|
||||
Folder,
|
||||
shielded_async_session,
|
||||
)
|
||||
from app.retriever.chunks_hybrid_search import ChucksHybridSearchRetriever
|
||||
|
|
@ -836,6 +837,22 @@ class KnowledgePriorityMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
|||
mention_ids = list(self.mentioned_document_ids)
|
||||
self.mentioned_document_ids = []
|
||||
|
||||
# Folder mentions live alongside doc mentions on the runtime
|
||||
# context. They never feed hybrid search (folders aren't
|
||||
# embedded) — they're surfaced purely as ``[USER-MENTIONED]``
|
||||
# priority entries so the agent walks the folder with ``ls`` /
|
||||
# ``find_documents`` instead of ignoring it. Cloud filesystem
|
||||
# mode only.
|
||||
folder_mention_ids: list[int] = []
|
||||
if (
|
||||
ctx is not None
|
||||
and getattr(self, "filesystem_mode", FilesystemMode.CLOUD)
|
||||
== FilesystemMode.CLOUD
|
||||
):
|
||||
ctx_folders = getattr(ctx, "mentioned_folder_ids", None)
|
||||
if ctx_folders:
|
||||
folder_mention_ids = list(ctx_folders)
|
||||
|
||||
mentioned_results: list[dict[str, Any]] = []
|
||||
if mention_ids:
|
||||
mentioned_results = await fetch_mentioned_documents(
|
||||
|
|
@ -880,12 +897,17 @@ class KnowledgePriorityMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
|||
|
||||
priority, matched_chunk_ids = await self._materialize_priority(merged)
|
||||
|
||||
if folder_mention_ids:
|
||||
folder_entries = await self._materialize_folder_priority(folder_mention_ids)
|
||||
priority = folder_entries + priority
|
||||
|
||||
_perf_log.info(
|
||||
"[kb_priority] completed in %.3fs query=%r priority=%d mentioned=%d",
|
||||
"[kb_priority] completed in %.3fs query=%r priority=%d mentioned=%d folders=%d",
|
||||
asyncio.get_event_loop().time() - t0,
|
||||
user_text[:80],
|
||||
len(priority),
|
||||
len(mentioned_results),
|
||||
len(folder_mention_ids),
|
||||
)
|
||||
|
||||
update: dict[str, Any] = {
|
||||
|
|
@ -899,6 +921,58 @@ class KnowledgePriorityMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
|||
update["messages"] = new_messages
|
||||
return update
|
||||
|
||||
async def _materialize_folder_priority(
    self, folder_ids: list[int]
) -> list[dict[str, Any]]:
    """Build ``<priority_documents>`` entries for user-mentioned folders.

    Every entry carries the canonical ``/documents/Folder/Sub/`` virtual
    path (always ending in ``/``) and ``mentioned=True`` so the rendered
    line is tagged ``[USER-MENTIONED]``. ``score`` and ``document_id``
    are ``None``: folders are not ranked, and the agent chooses which
    children to read.

    Args:
        folder_ids: Folder ids mentioned this turn; duplicates are
            collapsed, keeping first-seen order.

    Returns:
        One entry per distinct folder id present in the path index of
        ``self.search_space_id``. Ids missing from the index are dropped
        with a debug log.
    """
    if not folder_ids:
        return []

    async with shielded_async_session() as session:
        index: PathIndex = await build_path_index(session, self.search_space_id)
        name_query = select(Folder.id, Folder.name).where(
            Folder.search_space_id == self.search_space_id,
            Folder.id.in_(folder_ids),
        )
        name_rows = await session.execute(name_query)
        folder_titles: dict[int, str] = {row.id: row.name for row in name_rows.all()}

    entries: list[dict[str, Any]] = []
    # ``dict.fromkeys`` yields each id once, in first-seen order.
    for folder_id in dict.fromkeys(folder_ids):
        base = index.folder_paths.get(folder_id)
        if base is None:
            logger.debug(
                "kb_priority: dropping folder id=%s (missing from path index)",
                folder_id,
            )
            continue
        if not base.endswith("/"):
            base = f"{base}/"
        entries.append(
            {
                "path": base,
                "score": None,
                "document_id": None,
                "folder_id": folder_id,
                "title": folder_titles.get(folder_id, ""),
                "mentioned": True,
            }
        )
    return entries
|
||||
|
||||
async def _materialize_priority(
|
||||
self, merged: list[dict[str, Any]]
|
||||
) -> tuple[list[dict[str, Any]], dict[int, list[int]]]:
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
|||
from sqlalchemy.future import select
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
from app.agents.new_chat.path_resolver import virtual_path_to_doc
|
||||
from app.db import (
|
||||
Chunk,
|
||||
Document,
|
||||
|
|
@ -752,7 +753,24 @@ async def get_document_by_virtual_path(
|
|||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
"""Resolve a knowledge-base document id by exact virtual path."""
|
||||
"""Resolve a knowledge-base document by its agent-facing virtual path.
|
||||
|
||||
The agent renders every document under ``/documents/...`` with a
|
||||
``.xml`` extension appended via ``safe_filename`` (so a PDF titled
|
||||
``2025-W2.pdf`` becomes ``/documents/2025-W2.pdf.xml``). When the user
|
||||
clicks that path in an answer, this endpoint must round-trip back to
|
||||
the underlying ``Document`` row regardless of its type — agent-created
|
||||
NOTE docs (which carry ``virtual_path`` in metadata), uploaded PDFs,
|
||||
and connector docs all flow through here.
|
||||
|
||||
Resolution is delegated to :func:`virtual_path_to_doc`, the single
|
||||
source of truth that handles:
|
||||
|
||||
* ``unique_identifier_hash`` lookup (agent NOTE fast path)
|
||||
* ``" (<doc_id>).xml"`` disambiguation suffixes
|
||||
* ``.xml`` extension stripping for title-based fallback
|
||||
* ``safe_filename`` round-trip for connector titles with lossy chars
|
||||
"""
|
||||
try:
|
||||
await check_permission(
|
||||
session,
|
||||
|
|
@ -762,24 +780,19 @@ async def get_document_by_virtual_path(
|
|||
"You don't have permission to read documents in this search space",
|
||||
)
|
||||
|
||||
result = await session.execute(
|
||||
select(
|
||||
Document.id,
|
||||
Document.title,
|
||||
Document.document_type,
|
||||
).filter(
|
||||
Document.search_space_id == search_space_id,
|
||||
Document.document_metadata["virtual_path"].as_string() == virtual_path,
|
||||
)
|
||||
document = await virtual_path_to_doc(
|
||||
session,
|
||||
search_space_id=search_space_id,
|
||||
virtual_path=virtual_path,
|
||||
)
|
||||
row = result.first()
|
||||
if row is None:
|
||||
if document is None:
|
||||
raise HTTPException(status_code=404, detail="Document not found")
|
||||
|
||||
return DocumentTitleRead(
|
||||
id=row.id,
|
||||
title=row.title,
|
||||
document_type=row.document_type,
|
||||
id=document.id,
|
||||
title=document.title,
|
||||
document_type=document.document_type,
|
||||
folder_id=document.folder_id,
|
||||
)
|
||||
except HTTPException:
|
||||
raise
|
||||
|
|
|
|||
|
|
@ -1781,6 +1781,7 @@ async def handle_new_chat(
|
|||
llm_config_id=llm_config_id,
|
||||
mentioned_document_ids=request.mentioned_document_ids,
|
||||
mentioned_surfsense_doc_ids=request.mentioned_surfsense_doc_ids,
|
||||
mentioned_folder_ids=request.mentioned_folder_ids,
|
||||
mentioned_documents=mentioned_documents_payload,
|
||||
needs_history_bootstrap=thread.needs_history_bootstrap,
|
||||
thread_visibility=thread.visibility,
|
||||
|
|
@ -2266,6 +2267,7 @@ async def regenerate_response(
|
|||
llm_config_id=llm_config_id,
|
||||
mentioned_document_ids=request.mentioned_document_ids,
|
||||
mentioned_surfsense_doc_ids=request.mentioned_surfsense_doc_ids,
|
||||
mentioned_folder_ids=request.mentioned_folder_ids,
|
||||
mentioned_documents=mentioned_documents_payload,
|
||||
checkpoint_id=target_checkpoint_id,
|
||||
needs_history_bootstrap=thread.needs_history_bootstrap,
|
||||
|
|
|
|||
|
|
@ -201,18 +201,34 @@ class NewChatUserImagePart(BaseModel):
|
|||
|
||||
|
||||
class MentionedDocumentInfo(BaseModel):
|
||||
"""Display metadata for a single ``@``-mentioned document.
|
||||
"""Display metadata for a single ``@``-mention chip.
|
||||
|
||||
The full triple ``{id, title, document_type}`` is forwarded by the
|
||||
frontend mention chip so the server can embed it in the persisted
|
||||
user message ``ContentPart[]`` (single ``mentioned-documents`` part).
|
||||
The history loader then renders the chips on reload without an extra
|
||||
Carries either a knowledge-base document or a knowledge-base folder
|
||||
(discriminated by ``kind``). The full triple
|
||||
``{id, title, document_type}`` is forwarded by the frontend mention
|
||||
chip so the server can embed it in the persisted user message
|
||||
``ContentPart[]`` (single ``mentioned-documents`` part). The
|
||||
history loader then renders the chips on reload without an extra
|
||||
fetch — mirrors the pre-refactor frontend ``persistUserTurn`` shape.
|
||||
|
||||
``kind`` defaults to ``"doc"`` so legacy clients and persisted rows
|
||||
that predate folder mentions deserialise unchanged.
|
||||
"""
|
||||
|
||||
id: int
|
||||
title: str = Field(..., min_length=1, max_length=500)
|
||||
document_type: str = Field(..., min_length=1, max_length=100)
|
||||
kind: Literal["doc", "folder"] = Field(
|
||||
default="doc",
|
||||
description=(
|
||||
"Discriminator for the chip's referent: ``doc`` is a "
|
||||
"knowledge-base ``Document`` row, ``folder`` is a "
|
||||
"knowledge-base ``Folder`` row. Folders carry the sentinel "
|
||||
"``document_type='FOLDER'`` to keep the frontend dedup key "
|
||||
"``(kind:document_type:id)`` from colliding doc and folder "
|
||||
"ids that happen to share an integer value."
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class NewChatRequest(BaseModel):
|
||||
|
|
@ -228,15 +244,26 @@ class NewChatRequest(BaseModel):
|
|||
mentioned_surfsense_doc_ids: list[int] | None = (
|
||||
None # Optional SurfSense documentation IDs mentioned with @ in the chat
|
||||
)
|
||||
mentioned_folder_ids: list[int] | None = Field(
|
||||
default=None,
|
||||
description=(
|
||||
"Optional knowledge-base folder IDs the user mentioned with "
|
||||
"@. Resolved to virtual paths (``/documents/.../``) by "
|
||||
"``mention_resolver`` and surfaced to the agent via "
|
||||
"(a) backtick-wrapped substitution in ``user_query`` and "
|
||||
"(b) a ``[USER-MENTIONED]`` entry in ``<priority_documents>``. "
|
||||
"The agent's ``ls`` tool can then walk the folder itself."
|
||||
),
|
||||
)
|
||||
mentioned_documents: list[MentionedDocumentInfo] | None = Field(
|
||||
default=None,
|
||||
description=(
|
||||
"Display metadata (id, title, document_type) for every "
|
||||
"@-mentioned document. Persisted as a ``mentioned-documents`` "
|
||||
"ContentPart on the user message so reload renders chips "
|
||||
"without an extra fetch. Optional and additive — when None "
|
||||
"the user message is persisted without a mentioned-documents "
|
||||
"part."
|
||||
"Display metadata (id, title, document_type, kind) for every "
|
||||
"@-mention chip — both documents and folders. Persisted as a "
|
||||
"``mentioned-documents`` ContentPart on the user message so "
|
||||
"reload renders chips without an extra fetch. Optional and "
|
||||
"additive — when None the user message is persisted without "
|
||||
"a mentioned-documents part."
|
||||
),
|
||||
)
|
||||
disabled_tools: list[str] | None = (
|
||||
|
|
@ -290,14 +317,22 @@ class RegenerateRequest(BaseModel):
|
|||
)
|
||||
mentioned_document_ids: list[int] | None = None
|
||||
mentioned_surfsense_doc_ids: list[int] | None = None
|
||||
mentioned_folder_ids: list[int] | None = Field(
|
||||
default=None,
|
||||
description=(
|
||||
"Optional knowledge-base folder IDs the user mentioned with "
|
||||
"@ on the edited user turn. Only used when ``user_query`` is "
|
||||
"non-None (edit). Mirrors ``NewChatRequest.mentioned_folder_ids``."
|
||||
),
|
||||
)
|
||||
mentioned_documents: list[MentionedDocumentInfo] | None = Field(
|
||||
default=None,
|
||||
description=(
|
||||
"Display metadata (id, title, document_type) for every "
|
||||
"@-mentioned document on the edited user turn. Only used "
|
||||
"when ``user_query`` is non-None (edit). Persisted as a "
|
||||
"``mentioned-documents`` ContentPart on the new user "
|
||||
"message. None means no chip metadata."
|
||||
"Display metadata (id, title, document_type, kind) for every "
|
||||
"@-mention chip on the edited user turn — both documents and "
|
||||
"folders. Only used when ``user_query`` is non-None (edit). "
|
||||
"Persisted as a ``mentioned-documents`` ContentPart on the "
|
||||
"new user message. None means no chip metadata."
|
||||
),
|
||||
)
|
||||
disabled_tools: list[str] | None = None
|
||||
|
|
@ -373,6 +408,16 @@ class ResumeRequest(BaseModel):
|
|||
filesystem_mode: Literal["cloud", "desktop_local_folder"] = "cloud"
|
||||
client_platform: Literal["web", "desktop"] = "web"
|
||||
local_filesystem_mounts: list[LocalFilesystemMountPayload] | None = None
|
||||
mentioned_folder_ids: list[int] | None = Field(
|
||||
default=None,
|
||||
description=(
|
||||
"Forwarded for symmetry with /new_chat and /regenerate. "
|
||||
"Resume reuses the original interrupted user turn so this "
|
||||
"field is informational only — the originating turn's "
|
||||
"folder mentions already shaped the priority hints baked "
|
||||
"into the agent's checkpoint."
|
||||
),
|
||||
)
|
||||
mentioned_documents: list[MentionedDocumentInfo] | None = Field(
|
||||
default=None,
|
||||
description=(
|
||||
|
|
|
|||
|
|
@ -43,9 +43,7 @@ class EmitterRegistry:
|
|||
return main_emitter()
|
||||
|
||||
def has_active_subagents(self) -> bool:
|
||||
return any(
|
||||
emitter.level == "subagent" for emitter in self._by_run_id.values()
|
||||
)
|
||||
return any(emitter.level == "subagent" for emitter in self._by_run_id.values())
|
||||
|
||||
def clear(self) -> None:
|
||||
self._by_run_id.clear()
|
||||
|
|
|
|||
|
|
@ -6,9 +6,7 @@ from ..emitter import Emitter, attach_emitted_by
|
|||
from ..envelope import format_sse
|
||||
|
||||
|
||||
def format_reasoning_start(
|
||||
reasoning_id: str, *, emitter: Emitter | None = None
|
||||
) -> str:
|
||||
def format_reasoning_start(reasoning_id: str, *, emitter: Emitter | None = None) -> str:
|
||||
return format_sse(
|
||||
attach_emitted_by({"type": "reasoning-start", "id": reasoning_id}, emitter)
|
||||
)
|
||||
|
|
@ -28,9 +26,7 @@ def format_reasoning_delta(
|
|||
)
|
||||
|
||||
|
||||
def format_reasoning_end(
|
||||
reasoning_id: str, *, emitter: Emitter | None = None
|
||||
) -> str:
|
||||
def format_reasoning_end(reasoning_id: str, *, emitter: Emitter | None = None) -> str:
|
||||
return format_sse(
|
||||
attach_emitted_by({"type": "reasoning-end", "id": reasoning_id}, emitter)
|
||||
)
|
||||
|
|
|
|||
|
|
@ -7,9 +7,7 @@ from ..envelope import format_sse
|
|||
|
||||
|
||||
def format_text_start(text_id: str, *, emitter: Emitter | None = None) -> str:
|
||||
return format_sse(
|
||||
attach_emitted_by({"type": "text-start", "id": text_id}, emitter)
|
||||
)
|
||||
return format_sse(attach_emitted_by({"type": "text-start", "id": text_id}, emitter))
|
||||
|
||||
|
||||
def format_text_delta(
|
||||
|
|
@ -26,6 +24,4 @@ def format_text_delta(
|
|||
|
||||
|
||||
def format_text_end(text_id: str, *, emitter: Emitter | None = None) -> str:
|
||||
return format_sse(
|
||||
attach_emitted_by({"type": "text-end", "id": text_id}, emitter)
|
||||
)
|
||||
return format_sse(attach_emitted_by({"type": "text-end", "id": text_id}, emitter))
|
||||
|
|
|
|||
|
|
@ -84,9 +84,7 @@ class StreamingService:
|
|||
def format_step_finish(self, *, emitter: Emitter | None = None) -> str:
|
||||
return lifecycle.format_step_finish(emitter=emitter)
|
||||
|
||||
def format_text_start(
|
||||
self, text_id: str, *, emitter: Emitter | None = None
|
||||
) -> str:
|
||||
def format_text_start(self, text_id: str, *, emitter: Emitter | None = None) -> str:
|
||||
return text.format_text_start(text_id, emitter=emitter)
|
||||
|
||||
def format_text_delta(
|
||||
|
|
@ -94,9 +92,7 @@ class StreamingService:
|
|||
) -> str:
|
||||
return text.format_text_delta(text_id, delta, emitter=emitter)
|
||||
|
||||
def format_text_end(
|
||||
self, text_id: str, *, emitter: Emitter | None = None
|
||||
) -> str:
|
||||
def format_text_end(self, text_id: str, *, emitter: Emitter | None = None) -> str:
|
||||
return text.format_text_end(text_id, emitter=emitter)
|
||||
|
||||
def format_reasoning_start(
|
||||
|
|
|
|||
|
|
@ -51,7 +51,9 @@ logger = logging.getLogger(__name__)
|
|||
_MEANINGFUL_PART_TYPES: frozenset[str] = frozenset({"text", "reasoning", "tool-call"})
|
||||
|
||||
|
||||
def _merge_tool_part_metadata(part: dict[str, Any], metadata: dict[str, Any] | None) -> None:
|
||||
def _merge_tool_part_metadata(
|
||||
part: dict[str, Any], metadata: dict[str, Any] | None
|
||||
) -> None:
|
||||
"""Shallow-merge ``metadata`` into ``part["metadata"]``; first key wins.
|
||||
|
||||
Used for tool-call linkage (``spanId``, ``thinkingStepId``, …): a later
|
||||
|
|
|
|||
|
|
@ -109,17 +109,18 @@ def _build_user_content(
|
|||
[{"type": "text", "text": "..."},
|
||||
{"type": "image", "image": "data:..."},
|
||||
{"type": "mentioned-documents", "documents": [{"id": int,
|
||||
"title": str, "document_type": str}, ...]}]
|
||||
"title": str, "document_type": str, "kind": "doc" | "folder"},
|
||||
...]}]
|
||||
|
||||
The companion reader is
|
||||
``app.utils.user_message_multimodal.split_persisted_user_content_parts``
|
||||
which expects exactly this shape — keep them in sync.
|
||||
|
||||
``mentioned_documents``: optional list of ``{id, title, document_type}``
|
||||
dicts. When non-empty (and a ``mentioned-documents`` part is not already
|
||||
in some other input shape), a single ``{"type": "mentioned-documents",
|
||||
"documents": [...]}`` part is appended. Mirrors the FE injection at
|
||||
``page.tsx:281-286`` (``persistUserTurn``).
|
||||
``mentioned_documents``: optional list of mention chip dicts. Each
|
||||
dict may include a ``kind`` discriminator (``"doc"`` or ``"folder"``)
|
||||
so the persisted ContentPart round-trips folder chips on reload.
|
||||
When ``kind`` is missing we default to ``"doc"`` so legacy clients
|
||||
that haven't migrated to the union schema still persist correctly.
|
||||
"""
|
||||
parts: list[dict[str, Any]] = [{"type": "text", "text": user_query or ""}]
|
||||
for url in user_image_data_urls or ():
|
||||
|
|
@ -135,11 +136,14 @@ def _build_user_content(
|
|||
document_type = doc.get("document_type")
|
||||
if doc_id is None or title is None or document_type is None:
|
||||
continue
|
||||
kind_raw = doc.get("kind", "doc")
|
||||
kind = kind_raw if kind_raw in ("doc", "folder") else "doc"
|
||||
normalized.append(
|
||||
{
|
||||
"id": doc_id,
|
||||
"title": str(title),
|
||||
"document_type": str(document_type),
|
||||
"kind": kind,
|
||||
}
|
||||
)
|
||||
if normalized:
|
||||
|
|
|
|||
|
|
@ -43,6 +43,7 @@ from app.agents.new_chat.memory_extraction import (
|
|||
extract_and_save_memory,
|
||||
extract_and_save_team_memory,
|
||||
)
|
||||
from app.agents.new_chat.mention_resolver import resolve_mentions, substitute_in_text
|
||||
from app.agents.new_chat.middleware.busy_mutex import (
|
||||
end_turn,
|
||||
get_cancel_state,
|
||||
|
|
@ -929,6 +930,7 @@ async def stream_new_chat(
|
|||
llm_config_id: int = -1,
|
||||
mentioned_document_ids: list[int] | None = None,
|
||||
mentioned_surfsense_doc_ids: list[int] | None = None,
|
||||
mentioned_folder_ids: list[int] | None = None,
|
||||
mentioned_documents: list[dict[str, Any]] | None = None,
|
||||
checkpoint_id: str | None = None,
|
||||
needs_history_bootstrap: bool = False,
|
||||
|
|
@ -958,6 +960,7 @@ async def stream_new_chat(
|
|||
needs_history_bootstrap: If True, load message history from DB (for cloned chats)
|
||||
mentioned_document_ids: Optional list of document IDs mentioned with @ in the chat
|
||||
mentioned_surfsense_doc_ids: Optional list of SurfSense doc IDs mentioned with @ in the chat
|
||||
mentioned_folder_ids: Optional list of knowledge-base folder IDs mentioned with @ (cloud mode)
|
||||
checkpoint_id: Optional checkpoint ID to rewind/fork from (for edit/reload operations)
|
||||
|
||||
Yields:
|
||||
|
|
@ -1502,6 +1505,53 @@ async def stream_new_chat(
|
|||
)
|
||||
recent_reports = list(recent_reports_result.scalars().all())
|
||||
|
||||
# Resolve @-mention chips to canonical virtual paths and rewrite
|
||||
# the user-typed text so the LLM sees ``\`/documents/...\``` instead
|
||||
# of bare ``@title``. The persisted user-message text keeps
|
||||
# ``@title`` so chip rendering on reload is unchanged — see
|
||||
# ``persistence._build_user_content``.
|
||||
#
|
||||
# Cloud mode only: local-folder mode keeps the legacy
|
||||
# ``@title`` text path; mention support there is a follow-up
|
||||
# task because the path scheme (mount-rooted) and the picker
|
||||
# UI both need separate work.
|
||||
accepted_folder_ids: list[int] = []
|
||||
if fs_mode == FilesystemMode.CLOUD.value and (
|
||||
mentioned_document_ids
|
||||
or mentioned_surfsense_doc_ids
|
||||
or mentioned_folder_ids
|
||||
or mentioned_documents
|
||||
):
|
||||
from app.schemas.new_chat import (
|
||||
MentionedDocumentInfo as _MentionedDocumentInfo,
|
||||
)
|
||||
|
||||
chip_objs: list[_MentionedDocumentInfo] | None = None
|
||||
if mentioned_documents:
|
||||
chip_objs = []
|
||||
for raw in mentioned_documents:
|
||||
if isinstance(raw, _MentionedDocumentInfo):
|
||||
chip_objs.append(raw)
|
||||
continue
|
||||
try:
|
||||
chip_objs.append(_MentionedDocumentInfo.model_validate(raw))
|
||||
except Exception:
|
||||
logger.debug(
|
||||
"stream_new_chat: dropping malformed mention chip %r",
|
||||
raw,
|
||||
)
|
||||
|
||||
resolved = await resolve_mentions(
|
||||
session,
|
||||
search_space_id=search_space_id,
|
||||
mentioned_documents=chip_objs,
|
||||
mentioned_document_ids=mentioned_document_ids,
|
||||
mentioned_surfsense_doc_ids=mentioned_surfsense_doc_ids,
|
||||
mentioned_folder_ids=mentioned_folder_ids,
|
||||
)
|
||||
user_query = substitute_in_text(user_query, resolved.token_to_path)
|
||||
accepted_folder_ids = resolved.mentioned_folder_ids
|
||||
|
||||
# Format the user query with context (SurfSense docs + reports only)
|
||||
final_query = user_query
|
||||
context_parts = []
|
||||
|
|
@ -1901,6 +1951,9 @@ async def stream_new_chat(
|
|||
runtime_context = SurfSenseContextSchema(
|
||||
search_space_id=search_space_id,
|
||||
mentioned_document_ids=list(mentioned_document_ids or []),
|
||||
mentioned_folder_ids=list(
|
||||
accepted_folder_ids or mentioned_folder_ids or []
|
||||
),
|
||||
request_id=request_id,
|
||||
turn_id=stream_result.turn_id,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -26,9 +26,7 @@ def handle_report_progress(
|
|||
return None, last_active_step_items
|
||||
|
||||
phase = data.get("phase", "")
|
||||
topic_items = [
|
||||
item for item in last_active_step_items if item.startswith("Topic:")
|
||||
]
|
||||
topic_items = [item for item in last_active_step_items if item.startswith("Topic:")]
|
||||
|
||||
if phase in ("revising_section", "adding_section"):
|
||||
plan_items = [
|
||||
|
|
@ -56,7 +54,9 @@ def handle_report_progress(
|
|||
return frame, new_items
|
||||
|
||||
|
||||
def handle_document_created(data: dict[str, Any], *, streaming_service: Any) -> str | None:
|
||||
def handle_document_created(
|
||||
data: dict[str, Any], *, streaming_service: Any
|
||||
) -> str | None:
|
||||
if not data.get("id"):
|
||||
return None
|
||||
return streaming_service.format_data(
|
||||
|
|
|
|||
|
|
@ -13,7 +13,9 @@ from app.tasks.chat.streaming.handlers.tools import (
|
|||
)
|
||||
from app.tasks.chat.streaming.helpers.tool_output import tool_output_has_error
|
||||
from app.tasks.chat.streaming.relay.state import AgentEventRelayState
|
||||
from app.tasks.chat.streaming.relay.task_span import clear_task_span_if_delegating_task_ended
|
||||
from app.tasks.chat.streaming.relay.task_span import (
|
||||
clear_task_span_if_delegating_task_ended,
|
||||
)
|
||||
from app.tasks.chat.streaming.relay.thinking_step_sse import emit_thinking_step_frame
|
||||
|
||||
|
||||
|
|
@ -32,9 +34,7 @@ def iter_tool_end_frames(
|
|||
run_id = event.get("run_id", "")
|
||||
tool_name = event.get("name", "unknown_tool")
|
||||
raw_output = event.get("data", {}).get("output", "")
|
||||
staged_file_path = (
|
||||
state.file_path_by_run.pop(run_id, None) if run_id else None
|
||||
)
|
||||
staged_file_path = state.file_path_by_run.pop(run_id, None) if run_id else None
|
||||
|
||||
if tool_name == "update_memory":
|
||||
state.called_update_memory = True
|
||||
|
|
@ -116,6 +116,4 @@ def iter_tool_end_frames(
|
|||
)
|
||||
yield from iter_tool_completion_emission_frames(emission_ctx)
|
||||
|
||||
clear_task_span_if_delegating_task_ended(
|
||||
state, tool_name=tool_name, run_id=run_id
|
||||
)
|
||||
clear_task_span_if_delegating_task_ended(state, tool_name=tool_name, run_id=run_id)
|
||||
|
|
|
|||
|
|
@ -15,7 +15,9 @@ def resolve_start_thinking(tool_name: str, tool_input: Any) -> ToolStartThinking
|
|||
|
||||
|
||||
def resolve_completed_thinking(
|
||||
tool_name: str, tool_output: Any, last_items: list[str],
|
||||
tool_name: str,
|
||||
tool_output: Any,
|
||||
last_items: list[str],
|
||||
) -> tuple[str, list[str]]:
|
||||
return default_thinking.resolve_completed_thinking(
|
||||
tool_name, tool_output, last_items
|
||||
|
|
|
|||
|
|
@ -23,7 +23,9 @@ def resolve_start_thinking(tool_name: str, tool_input: Any) -> ToolStartThinking
|
|||
|
||||
|
||||
def resolve_completed_thinking(
|
||||
tool_name: str, tool_output: Any, last_items: list[str],
|
||||
tool_name: str,
|
||||
tool_output: Any,
|
||||
last_items: list[str],
|
||||
) -> tuple[str, list[str]]:
|
||||
del tool_name
|
||||
items = last_items
|
||||
|
|
|
|||
|
|
@ -34,7 +34,9 @@ def resolve_start_thinking(tool_name: str, tool_input: Any) -> ToolStartThinking
|
|||
|
||||
|
||||
def resolve_completed_thinking(
|
||||
tool_name: str, tool_output: Any, last_items: list[str],
|
||||
tool_name: str,
|
||||
tool_output: Any,
|
||||
last_items: list[str],
|
||||
) -> tuple[str, list[str]]:
|
||||
del tool_name
|
||||
items = last_items
|
||||
|
|
|
|||
|
|
@ -29,7 +29,9 @@ def resolve_start_thinking(tool_name: str, tool_input: Any) -> ToolStartThinking
|
|||
|
||||
|
||||
def resolve_completed_thinking(
|
||||
tool_name: str, tool_output: Any, last_items: list[str],
|
||||
tool_name: str,
|
||||
tool_output: Any,
|
||||
last_items: list[str],
|
||||
) -> tuple[str, list[str]]:
|
||||
del tool_name
|
||||
items = last_items
|
||||
|
|
@ -44,9 +46,7 @@ def resolve_completed_thinking(
|
|||
else "Report"
|
||||
)
|
||||
word_count = (
|
||||
tool_output.get("word_count", 0)
|
||||
if isinstance(tool_output, dict)
|
||||
else 0
|
||||
tool_output.get("word_count", 0) if isinstance(tool_output, dict) else 0
|
||||
)
|
||||
is_revision = (
|
||||
tool_output.get("is_revision", False)
|
||||
|
|
|
|||
|
|
@ -17,7 +17,9 @@ def resolve_start_thinking(tool_name: str, tool_input: Any) -> ToolStartThinking
|
|||
|
||||
|
||||
def resolve_completed_thinking(
|
||||
tool_name: str, tool_output: Any, last_items: list[str],
|
||||
tool_name: str,
|
||||
tool_output: Any,
|
||||
last_items: list[str],
|
||||
) -> tuple[str, list[str]]:
|
||||
return default_thinking.resolve_completed_thinking(
|
||||
tool_name, tool_output, last_items
|
||||
|
|
|
|||
|
|
@ -17,7 +17,9 @@ def resolve_start_thinking(tool_name: str, tool_input: Any) -> ToolStartThinking
|
|||
|
||||
|
||||
def resolve_completed_thinking(
|
||||
tool_name: str, tool_output: Any, last_items: list[str],
|
||||
tool_name: str,
|
||||
tool_output: Any,
|
||||
last_items: list[str],
|
||||
) -> tuple[str, list[str]]:
|
||||
del tool_name
|
||||
items = last_items
|
||||
|
|
|
|||
|
|
@ -21,7 +21,9 @@ def resolve_start_thinking(tool_name: str, tool_input: Any) -> ToolStartThinking
|
|||
|
||||
|
||||
def resolve_completed_thinking(
|
||||
tool_name: str, tool_output: Any, last_items: list[str],
|
||||
tool_name: str,
|
||||
tool_output: Any,
|
||||
last_items: list[str],
|
||||
) -> tuple[str, list[str]]:
|
||||
del tool_name
|
||||
items = last_items
|
||||
|
|
|
|||
|
|
@ -21,7 +21,9 @@ def resolve_start_thinking(tool_name: str, tool_input: Any) -> ToolStartThinking
|
|||
|
||||
|
||||
def resolve_completed_thinking(
|
||||
tool_name: str, tool_output: Any, last_items: list[str],
|
||||
tool_name: str,
|
||||
tool_output: Any,
|
||||
last_items: list[str],
|
||||
) -> tuple[str, list[str]]:
|
||||
del tool_output, tool_name
|
||||
return ("Editing file", last_items)
|
||||
|
|
|
|||
|
|
@ -24,9 +24,7 @@ def iter_completion_emission_frames(
|
|||
output_text = om.group(1) if om else ""
|
||||
thread_id_str = ctx.langgraph_config.get("configurable", {}).get("thread_id", "")
|
||||
|
||||
for sf_match in re.finditer(
|
||||
r"^SANDBOX_FILE:\s*(.+)$", output_text, re.MULTILINE
|
||||
):
|
||||
for sf_match in re.finditer(r"^SANDBOX_FILE:\s*(.+)$", output_text, re.MULTILINE):
|
||||
fpath = sf_match.group(1).strip()
|
||||
if fpath and fpath not in ctx.stream_result.sandbox_files:
|
||||
ctx.stream_result.sandbox_files.append(fpath)
|
||||
|
|
|
|||
|
|
@ -22,7 +22,9 @@ def resolve_start_thinking(tool_name: str, tool_input: Any) -> ToolStartThinking
|
|||
|
||||
|
||||
def resolve_completed_thinking(
|
||||
tool_name: str, tool_output: Any, last_items: list[str],
|
||||
tool_name: str,
|
||||
tool_output: Any,
|
||||
last_items: list[str],
|
||||
) -> tuple[str, list[str]]:
|
||||
del tool_name
|
||||
items = last_items
|
||||
|
|
|
|||
|
|
@ -21,7 +21,9 @@ def resolve_start_thinking(tool_name: str, tool_input: Any) -> ToolStartThinking
|
|||
|
||||
|
||||
def resolve_completed_thinking(
|
||||
tool_name: str, tool_output: Any, last_items: list[str],
|
||||
tool_name: str,
|
||||
tool_output: Any,
|
||||
last_items: list[str],
|
||||
) -> tuple[str, list[str]]:
|
||||
del tool_output, tool_name
|
||||
return ("Searching files", last_items)
|
||||
|
|
|
|||
|
|
@ -25,7 +25,9 @@ def resolve_start_thinking(tool_name: str, tool_input: Any) -> ToolStartThinking
|
|||
|
||||
|
||||
def resolve_completed_thinking(
|
||||
tool_name: str, tool_output: Any, last_items: list[str],
|
||||
tool_name: str,
|
||||
tool_output: Any,
|
||||
last_items: list[str],
|
||||
) -> tuple[str, list[str]]:
|
||||
del tool_output, tool_name
|
||||
return ("Searching content", last_items)
|
||||
|
|
|
|||
|
|
@ -20,7 +20,9 @@ def resolve_start_thinking(tool_name: str, tool_input: Any) -> ToolStartThinking
|
|||
|
||||
|
||||
def resolve_completed_thinking(
|
||||
tool_name: str, tool_output: Any, last_items: list[str],
|
||||
tool_name: str,
|
||||
tool_output: Any,
|
||||
last_items: list[str],
|
||||
) -> tuple[str, list[str]]:
|
||||
del tool_name
|
||||
if isinstance(tool_output, dict):
|
||||
|
|
@ -38,9 +40,7 @@ def resolve_completed_thinking(
|
|||
paths = [str(p) for p in parsed]
|
||||
except (ValueError, SyntaxError):
|
||||
paths = [
|
||||
line.strip()
|
||||
for line in ls_output.strip().split("\n")
|
||||
if line.strip()
|
||||
line.strip() for line in ls_output.strip().split("\n") if line.strip()
|
||||
]
|
||||
for p in paths:
|
||||
name = p.rstrip("/").split("/")[-1]
|
||||
|
|
|
|||
|
|
@ -17,11 +17,15 @@ def resolve_start_thinking(tool_name: str, tool_input: Any) -> ToolStartThinking
|
|||
d = as_tool_input_dict(tool_input)
|
||||
p = d.get("path", "") if isinstance(tool_input, dict) else str(tool_input)
|
||||
display = p if len(p) <= 80 else "…" + p[-77:]
|
||||
return ToolStartThinking(title="Creating folder", items=[display] if display else [])
|
||||
return ToolStartThinking(
|
||||
title="Creating folder", items=[display] if display else []
|
||||
)
|
||||
|
||||
|
||||
def resolve_completed_thinking(
|
||||
tool_name: str, tool_output: Any, last_items: list[str],
|
||||
tool_name: str,
|
||||
tool_output: Any,
|
||||
last_items: list[str],
|
||||
) -> tuple[str, list[str]]:
|
||||
del tool_output, tool_name
|
||||
return ("Creating folder", last_items)
|
||||
|
|
|
|||
|
|
@ -27,7 +27,9 @@ def resolve_start_thinking(tool_name: str, tool_input: Any) -> ToolStartThinking
|
|||
|
||||
|
||||
def resolve_completed_thinking(
|
||||
tool_name: str, tool_output: Any, last_items: list[str],
|
||||
tool_name: str,
|
||||
tool_output: Any,
|
||||
last_items: list[str],
|
||||
) -> tuple[str, list[str]]:
|
||||
del tool_output, tool_name
|
||||
return ("Moving file", last_items)
|
||||
|
|
|
|||
|
|
@ -21,7 +21,9 @@ def resolve_start_thinking(tool_name: str, tool_input: Any) -> ToolStartThinking
|
|||
|
||||
|
||||
def resolve_completed_thinking(
|
||||
tool_name: str, tool_output: Any, last_items: list[str],
|
||||
tool_name: str,
|
||||
tool_output: Any,
|
||||
last_items: list[str],
|
||||
) -> tuple[str, list[str]]:
|
||||
del tool_output, tool_name
|
||||
return ("Reading file", last_items)
|
||||
|
|
|
|||
|
|
@ -22,7 +22,9 @@ def resolve_start_thinking(tool_name: str, tool_input: Any) -> ToolStartThinking
|
|||
|
||||
|
||||
def resolve_completed_thinking(
|
||||
tool_name: str, tool_output: Any, last_items: list[str],
|
||||
tool_name: str,
|
||||
tool_output: Any,
|
||||
last_items: list[str],
|
||||
) -> tuple[str, list[str]]:
|
||||
del tool_output, tool_name
|
||||
return ("Deleting file", last_items)
|
||||
|
|
|
|||
|
|
@ -17,11 +17,15 @@ def resolve_start_thinking(tool_name: str, tool_input: Any) -> ToolStartThinking
|
|||
d = as_tool_input_dict(tool_input)
|
||||
p = d.get("path", "") if isinstance(tool_input, dict) else str(tool_input)
|
||||
display = p if len(p) <= 80 else "…" + p[-77:]
|
||||
return ToolStartThinking(title="Deleting folder", items=[display] if display else [])
|
||||
return ToolStartThinking(
|
||||
title="Deleting folder", items=[display] if display else []
|
||||
)
|
||||
|
||||
|
||||
def resolve_completed_thinking(
|
||||
tool_name: str, tool_output: Any, last_items: list[str],
|
||||
tool_name: str,
|
||||
tool_output: Any,
|
||||
last_items: list[str],
|
||||
) -> tuple[str, list[str]]:
|
||||
del tool_output, tool_name
|
||||
return ("Deleting folder", last_items)
|
||||
|
|
|
|||
|
|
@ -21,7 +21,9 @@ def resolve_start_thinking(tool_name: str, tool_input: Any) -> ToolStartThinking
|
|||
|
||||
|
||||
def resolve_completed_thinking(
|
||||
tool_name: str, tool_output: Any, last_items: list[str],
|
||||
tool_name: str,
|
||||
tool_output: Any,
|
||||
last_items: list[str],
|
||||
) -> tuple[str, list[str]]:
|
||||
del tool_output, tool_name
|
||||
return ("Writing file", last_items)
|
||||
|
|
|
|||
|
|
@ -20,15 +20,15 @@ def resolve_start_thinking(tool_name: str, tool_input: Any) -> ToolStartThinking
|
|||
return ToolStartThinking(
|
||||
title="Planning tasks",
|
||||
items=(
|
||||
[f"{todo_count} task{'s' if todo_count != 1 else ''}"]
|
||||
if todo_count
|
||||
else []
|
||||
[f"{todo_count} task{'s' if todo_count != 1 else ''}"] if todo_count else []
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def resolve_completed_thinking(
|
||||
tool_name: str, tool_output: Any, last_items: list[str],
|
||||
tool_name: str,
|
||||
tool_output: Any,
|
||||
last_items: list[str],
|
||||
) -> tuple[str, list[str]]:
|
||||
del tool_output, tool_name
|
||||
return ("Planning tasks", last_items)
|
||||
|
|
|
|||
|
|
@ -58,14 +58,18 @@ def _emission_module(tool_name: str) -> str:
|
|||
|
||||
def _import_thinking(tool_name: str):
|
||||
try:
|
||||
return importlib.import_module(f"{_BASE}.{_thinking_module(tool_name)}.thinking")
|
||||
return importlib.import_module(
|
||||
f"{_BASE}.{_thinking_module(tool_name)}.thinking"
|
||||
)
|
||||
except ModuleNotFoundError:
|
||||
return importlib.import_module(f"{_BASE}.default.thinking")
|
||||
|
||||
|
||||
def _import_emission(tool_name: str):
|
||||
try:
|
||||
return importlib.import_module(f"{_BASE}.{_emission_module(tool_name)}.emission")
|
||||
return importlib.import_module(
|
||||
f"{_BASE}.{_emission_module(tool_name)}.emission"
|
||||
)
|
||||
except ModuleNotFoundError:
|
||||
return importlib.import_module(f"{_BASE}.default.emission")
|
||||
|
||||
|
|
|
|||
|
|
@ -23,7 +23,9 @@ def resolve_start_thinking(tool_name: str, tool_input: Any) -> ToolStartThinking
|
|||
|
||||
|
||||
def resolve_completed_thinking(
|
||||
tool_name: str, tool_output: Any, last_items: list[str],
|
||||
tool_name: str,
|
||||
tool_output: Any,
|
||||
last_items: list[str],
|
||||
) -> tuple[str, list[str]]:
|
||||
del tool_name
|
||||
items = last_items
|
||||
|
|
|
|||
|
|
@ -28,11 +28,7 @@ def iter_completion_emission_frames(
|
|||
xml,
|
||||
):
|
||||
chunk_url, content = m.group(1).strip(), m.group(2).strip()
|
||||
if (
|
||||
chunk_url.startswith("http")
|
||||
and chunk_url in citations
|
||||
and content
|
||||
):
|
||||
if chunk_url.startswith("http") and chunk_url in citations and content:
|
||||
citations[chunk_url]["snippet"] = (
|
||||
content[:200] + "…" if len(content) > 200 else content
|
||||
)
|
||||
|
|
|
|||
69
surfsense_backend/tests/e2e/README.md
Normal file
69
surfsense_backend/tests/e2e/README.md
Normal file
|
|
@ -0,0 +1,69 @@
|
|||
# Backend E2E Test Harness
|
||||
|
||||
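Strict fakes and alternative entrypoints for the Playwright E2E suite; the fakes are also reused by backend integration tests (see "Reused by backend integration tests" below).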
|
||||
Excluded from the production Docker image via `.dockerignore`.
|
||||
|
||||
## Files
|
||||
|
||||
| Path | Role |
|
||||
| -------------------------------- | ------------------------------------------------------------------------------- |
|
||||
| `run_backend.py` | FastAPI entrypoint that hijacks `sys.modules` before importing `app.app:app` |
|
||||
| `run_celery.py` | Celery worker entrypoint with the same hijack + patch logic |
|
||||
| `middleware/scenario.py` | `X-E2E-Scenario` header → ContextVar (read by fakes) |
|
||||
| `fakes/composio_module.py` | Strict drop-in for the `composio` package; raises on unknown surface |
|
||||
| `fakes/llm.py` | `fake_get_user_long_context_llm` returning a `FakeListChatModel` |
|
||||
| `fakes/embeddings.py` | Deterministic 0.1-vector `embed_text` / `embed_texts` |
|
||||
| `fakes/fixtures/drive_files.json`| Canned Drive listings + file contents (incl. canary tokens) |
|
||||
|
||||
## Why a sys.modules hijack?
|
||||
|
||||
Production code does `from composio import Composio` at module load
|
||||
time. By the time the FastAPI app object exists, that binding has
|
||||
already been resolved. The hijack runs **before** any `app.*` import,
|
||||
so the binding resolves to our strict fake. No production source
|
||||
changes; fakes are physically excluded from production images.
|
||||
|
||||
Belt + suspenders + no internet: the strict `__getattr__` in every
|
||||
fake raises `NotImplementedError` if a future production code path
|
||||
introduces a new SDK call. CI also sets `HTTPS_PROXY=http://127.0.0.1:1`
|
||||
plus sentinel API keys so any leaked outbound HTTP fails immediately.
|
||||
|
||||
## Adding a new fake
|
||||
|
||||
1. Create `fakes/<sdk>_module.py` modelled on `composio_module.py`.
|
||||
2. In `run_backend.py` and `run_celery.py`, register
|
||||
`sys.modules["<sdk>"] = _fake_<sdk>` before the `from app.app import app`
|
||||
line.
|
||||
3. If the new fake needs scenario branching, read from
|
||||
`tests.e2e.middleware.scenario.current_scenario()`.
|
||||
|
||||
## Reused by backend integration tests
|
||||
|
||||
The strict fakes are not only for Playwright. Backend route integration
|
||||
tests can import the same fake before importing `app.app`, so Composio
|
||||
route tests exercise production route code without touching the real
|
||||
SDK:
|
||||
|
||||
```python
|
||||
from tests.e2e.fakes import composio_module as _fake_composio
|
||||
sys.modules["composio"] = _fake_composio
|
||||
from app.app import app
|
||||
```
|
||||
|
||||
See `surfsense_backend/tests/integration/composio/conftest.py` for the
|
||||
current pattern.
|
||||
|
||||
## Running locally
|
||||
|
||||
```bash
|
||||
cd surfsense_backend
|
||||
uv run python tests/e2e/run_backend.py
|
||||
# in a second shell:
|
||||
uv run python tests/e2e/run_celery.py
|
||||
```
|
||||
|
||||
Then in `surfsense_web`:
|
||||
|
||||
```bash
|
||||
pnpm test:e2e
|
||||
```
|
||||
7
surfsense_backend/tests/e2e/__init__.py
Normal file
7
surfsense_backend/tests/e2e/__init__.py
Normal file
|
|
@ -0,0 +1,7 @@
|
|||
"""E2E test harness root.
|
||||
|
||||
This package is loaded only by the test entrypoints
|
||||
(`tests/e2e/run_backend.py`, `tests/e2e/run_celery.py`). It is excluded
|
||||
from the production Docker image via `surfsense_backend/.dockerignore`,
|
||||
so production binaries never see this code.
|
||||
"""
|
||||
8
surfsense_backend/tests/e2e/fakes/__init__.py
Normal file
8
surfsense_backend/tests/e2e/fakes/__init__.py
Normal file
|
|
@ -0,0 +1,8 @@
|
|||
"""Strict fakes for third-party SDKs, used in E2E mode only.
|
||||
|
||||
Every fake here implements __getattr__ that raises NotImplementedError
|
||||
on any unknown surface. Combined with sys.modules-level hijacking in
|
||||
run_backend.py / run_celery.py, this makes silent pass-through to the
|
||||
real SDK impossible: a future production code path that introduces a
|
||||
new SDK call site fails CI with a clear "add this to the fake" message.
|
||||
"""
|
||||
23
surfsense_backend/tests/e2e/fakes/binary_loader.py
Normal file
23
surfsense_backend/tests/e2e/fakes/binary_loader.py
Normal file
|
|
@ -0,0 +1,23 @@
|
|||
"""Helpers for serving text and binary fixture file bodies."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
|
||||
def _resolve_file_bytes(
|
||||
fixture: dict[str, Any], key: str | None, fixtures_dir: Path
|
||||
) -> bytes | None:
|
||||
"""Resolve a fake file body, preferring binary fixture files over text."""
|
||||
if not key:
|
||||
return None
|
||||
|
||||
binary_path = fixture.get("_file_binary_paths", {}).get(key)
|
||||
if binary_path is not None:
|
||||
return (fixtures_dir / binary_path).read_bytes()
|
||||
|
||||
content = fixture.get("_file_contents", {}).get(key)
|
||||
if content is None:
|
||||
return None
|
||||
return content.encode("utf-8")
|
||||
733
surfsense_backend/tests/e2e/fakes/chat_llm.py
Normal file
733
surfsense_backend/tests/e2e/fakes/chat_llm.py
Normal file
|
|
@ -0,0 +1,733 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from collections.abc import AsyncIterator
|
||||
from typing import Any, Self
|
||||
|
||||
from langchain_core.callbacks import (
|
||||
AsyncCallbackManagerForLLMRun,
|
||||
CallbackManagerForLLMRun,
|
||||
)
|
||||
from langchain_core.language_models import BaseChatModel
|
||||
from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessage
|
||||
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
|
||||
|
||||
DRIVE_CANARY_TOKEN = "SURFSENSE_E2E_CANARY_TOKEN_DRIVE_001"
|
||||
DRIVE_PDF_CANARY_TOKEN = "SURFSENSE_E2E_CANARY_TOKEN_DRIVE_PDF_001"
|
||||
DRIVE_PDF_CANARY_FILE = "e2e-canary.pdf"
|
||||
COMPOSIO_DRIVE_PDF_CANARY_TOKEN = "SURFSENSE_E2E_CANARY_TOKEN_COMPOSIO_DRIVE_PDF_001"
|
||||
COMPOSIO_DRIVE_PDF_CANARY_FILE = "e2e-composio-canary.pdf"
|
||||
GMAIL_CANARY_TOKEN = "SURFSENSE_E2E_CANARY_TOKEN_GMAIL_001"
|
||||
GMAIL_CANARY_SUBJECT = "E2E Canary Email"
|
||||
GMAIL_CANARY_MESSAGE_ID = "fake-msg-canary-001"
|
||||
CALENDAR_CANARY_TOKEN = "SURFSENSE_E2E_CANARY_TOKEN_CALENDAR_001"
|
||||
CALENDAR_CANARY_SUMMARY = "E2E Canary Calendar Event"
|
||||
ONEDRIVE_CANARY_TOKEN = "SURFSENSE_E2E_CANARY_TOKEN_ONEDRIVE_001"
|
||||
ONEDRIVE_CANARY_FILE = "e2e-onedrive-canary.txt"
|
||||
ONEDRIVE_PDF_CANARY_TOKEN = "SURFSENSE_E2E_CANARY_TOKEN_ONEDRIVE_PDF_001"
|
||||
ONEDRIVE_PDF_CANARY_FILE = "e2e-onedrive-canary.pdf"
|
||||
DROPBOX_CANARY_TOKEN = "SURFSENSE_E2E_CANARY_TOKEN_DROPBOX_001"
|
||||
DROPBOX_CANARY_FILE = "e2e-dropbox-canary.txt"
|
||||
DROPBOX_PDF_CANARY_TOKEN = "SURFSENSE_E2E_CANARY_TOKEN_DROPBOX_PDF_001"
|
||||
DROPBOX_PDF_CANARY_FILE = "e2e-dropbox-canary.pdf"
|
||||
NOTION_CANARY_TOKEN = "SURFSENSE_E2E_CANARY_TOKEN_NOTION_001"
|
||||
NOTION_CANARY_TITLE = "E2E Canary Notion Page"
|
||||
CONFLUENCE_CANARY_TOKEN = "SURFSENSE_E2E_CANARY_TOKEN_CONFLUENCE_001"
|
||||
CONFLUENCE_CANARY_TITLE = "E2E Canary Confluence Page"
|
||||
LINEAR_CANARY_TOKEN = "SURFSENSE_E2E_CANARY_TOKEN_LINEAR_001"
|
||||
LINEAR_CANARY_TITLE = "E2E Canary Linear Issue"
|
||||
JIRA_CANARY_TOKEN = "SURFSENSE_E2E_CANARY_TOKEN_JIRA_001"
|
||||
JIRA_CANARY_SUMMARY = "E2E Canary Jira Issue"
|
||||
JIRA_CANARY_KEY = "E2E-101"
|
||||
SLACK_CANARY_TOKEN = "SURFSENSE_E2E_CANARY_TOKEN_SLACK_001"
|
||||
SLACK_CANARY_CHANNEL = "slack-e2e-canary"
|
||||
CLICKUP_CANARY_TOKEN = "SURFSENSE_E2E_CANARY_TOKEN_CLICKUP_001"
|
||||
CLICKUP_CANARY_TITLE = "E2E Canary ClickUp Task"
|
||||
CLICKUP_CANARY_TASK_ID = "fake-clickup-task-canary-001"
|
||||
MANUAL_UPLOAD_MD_CANARY_TOKEN = "E2E-MANUAL-UPLOAD-MD-CANARY-7f3a"
|
||||
MANUAL_UPLOAD_MD_CANARY_FILE = "canary.md"
|
||||
MANUAL_UPLOAD_PDF_CANARY_TOKEN = "E2E-MANUAL-UPLOAD-PDF-CANARY-9d2b"
|
||||
MANUAL_UPLOAD_PDF_CANARY_FILE = "canary.pdf"
|
||||
NO_RELEVANT_CONTENT_SENTINEL = "No relevant indexed content found."
|
||||
NO_RELEVANT_CONTENT_QUERY = "E2E_NO_RELEVANT_CONTENT_SMOKE"
|
||||
|
||||
|
||||
def _content_to_text(content: Any) -> str:
    """Flatten message content of any supported shape into one string.

    * ``None`` becomes ``""``.
    * A ``str`` is returned unchanged.
    * A ``list`` is flattened element by element and joined with spaces.
    * A ``dict`` yields its ``"text"`` value (or ``"content"`` when
      ``"text"`` is missing or falsy) if that value is a string; otherwise
      the whole dict is serialised as key-sorted JSON.
    * Anything else goes through ``str()``.
    """
    if content is None:
        return ""
    if isinstance(content, str):
        return content
    if isinstance(content, list):
        pieces = [_content_to_text(part) for part in content]
        return " ".join(pieces)
    if not isinstance(content, dict):
        return str(content)

    candidate = content.get("text") or content.get("content")
    return (
        candidate
        if isinstance(candidate, str)
        else json.dumps(content, sort_keys=True)
    )
|
||||
|
||||
|
||||
def _messages_to_text(messages: list[BaseMessage]) -> str:
    """Flatten each message's ``content`` and join them, one per line."""
    flattened = [_content_to_text(msg.content) for msg in messages]
    return "\n".join(flattened)
|
||||
|
||||
|
||||
def _contains_any(text: str, needles: tuple[str, ...]) -> bool:
    """Return ``True`` if any needle is a substring of ``text``.

    Both sides are compared after ``str.lower()``; an empty ``needles``
    tuple gives ``False``.
    """
    haystack = text.lower()
    for candidate in needles:
        if candidate.lower() in haystack:
            return True
    return False
|
||||
|
||||
|
||||
def _latest_tool_message(messages: list[BaseMessage]) -> BaseMessage | None:
    """Return the last message whose ``type`` is ``"tool"``, else ``None``."""
    for candidate in reversed(messages):
        if candidate.type == "tool":
            return candidate
    return None
|
||||
|
||||
|
||||
class FakeChatLLM(BaseChatModel):
    """Scripted chat model used by the E2E suite in place of a real LLM.

    Every reply is computed from the conversation alone: canary strings in
    the latest tool result, keywords in the latest human message, and canary
    markers anywhere in the prompt. This class makes no external calls.
    """

    @property
    def _llm_type(self) -> str:
        """Identifier string reported for this fake model."""
        return "e2e-fake-chat"

    def bind_tools(self, tools: Any, **kwargs: Any) -> Self:
        """Accept a tool binding request and ignore it, returning ``self``."""
        return self

    def _latest_human_text(self, messages: list[BaseMessage]) -> str:
        """Return the text of the most recent human message, or ``""``."""
        for message in reversed(messages):
            if message.type == "human":
                return _content_to_text(message.content)
        return ""

    def _response_for(self, messages: list[BaseMessage]) -> str:
        """Pick the scripted text reply for ``messages``.

        Decision order (first match wins):

        1. The "no relevant content" smoke query returns the sentinel.
        2. A recognised live tool result containing its canary token.
        3. The human asked about a source AND evidence for it is in the prompt.
        4. Evidence for a source is in the prompt and no competing
           connector evidence is present.
        5. Otherwise the "no relevant content" sentinel.

        Intent keywords are matched case-insensitively against the latest
        human message; evidence markers are matched case-sensitively against
        the text of the whole conversation.
        """
        latest_human = self._latest_human_text(messages)
        if NO_RELEVANT_CONTENT_QUERY in latest_human:
            return NO_RELEVANT_CONTENT_SENTINEL

        prompt_text = _messages_to_text(messages)
        latest_tool = _latest_tool_message(messages)
        tool_name = getattr(latest_tool, "name", None)
        tool_text = _content_to_text(latest_tool.content) if latest_tool else ""

        # --- 2. live tool results -------------------------------------------
        live_rules = (
            (
                ("read_gmail_email",),
                GMAIL_CANARY_TOKEN,
                f"Gmail live tool content found: {GMAIL_CANARY_TOKEN}",
            ),
            (
                ("search_gmail",),
                GMAIL_CANARY_MESSAGE_ID,
                "Reading the matching Gmail message next.",
            ),
            (
                ("search_calendar_events",),
                CALENDAR_CANARY_TOKEN,
                f"Calendar live tool content found: {CALENDAR_CANARY_TOKEN}",
            ),
            (
                ("list_issues",),
                LINEAR_CANARY_TOKEN,
                f"Linear live tool content found: {LINEAR_CANARY_TOKEN}",
            ),
            (
                ("searchJiraIssuesUsingJql",),
                JIRA_CANARY_TOKEN,
                f"Jira live tool content found: {JIRA_CANARY_TOKEN}",
            ),
            (
                ("slack_search_channels",),
                SLACK_CANARY_TOKEN,
                f"Slack live tool content found: {SLACK_CANARY_TOKEN}",
            ),
            (
                ("clickup_search", "clickup_get_task"),
                CLICKUP_CANARY_TOKEN,
                f"ClickUp live tool content found: {CLICKUP_CANARY_TOKEN}",
            ),
        )
        for tool_names, marker, reply in live_rules:
            if tool_name in tool_names and marker in tool_text:
                return reply

        def asks(*needles: str) -> bool:
            """Case-insensitive keyword probe of the latest human message."""
            return _contains_any(latest_human, needles)

        def seen(*markers: str) -> bool:
            """Case-sensitive marker probe of the whole conversation text."""
            return any(marker in prompt_text for marker in markers)

        # --- 3. what the human asked about -----------------------------------
        wants_drive = asks("drive", "file", "e2e-canary.txt")
        wants_drive_pdf = asks(
            "drive pdf",
            DRIVE_PDF_CANARY_FILE,
            DRIVE_PDF_CANARY_TOKEN,
            COMPOSIO_DRIVE_PDF_CANARY_FILE,
            COMPOSIO_DRIVE_PDF_CANARY_TOKEN,
        ) or (wants_drive and "pdf" in latest_human.lower())
        wants_onedrive = asks("onedrive", ONEDRIVE_CANARY_FILE, ONEDRIVE_CANARY_TOKEN)
        wants_dropbox = asks("dropbox", DROPBOX_CANARY_FILE, DROPBOX_CANARY_TOKEN)
        wants_manual_upload = asks(
            "uploaded",
            "manual upload",
            MANUAL_UPLOAD_MD_CANARY_FILE,
            MANUAL_UPLOAD_PDF_CANARY_FILE,
            MANUAL_UPLOAD_MD_CANARY_TOKEN,
            MANUAL_UPLOAD_PDF_CANARY_TOKEN,
        )
        # Insertion order of this dict IS the priority order of step 3.
        wants = {
            "clickup": asks("clickup", CLICKUP_CANARY_TITLE),
            "slack": asks("slack", SLACK_CANARY_TOKEN),
            "jira": asks(
                "jira",
                "atlassian",
                JIRA_CANARY_SUMMARY,
                JIRA_CANARY_KEY,
                "surfsense-e2e.atlassian.net",
                "fake-jira-cloud-001",
            ),
            "linear": asks("linear", "issue", LINEAR_CANARY_TITLE),
            "confluence": asks("confluence", CONFLUENCE_CANARY_TITLE),
            "notion": asks("notion", "page", NOTION_CANARY_TITLE),
            "calendar": asks("calendar", "event", "meeting", CALENDAR_CANARY_SUMMARY),
            "gmail": asks("gmail", "email", "message", GMAIL_CANARY_SUBJECT),
            "onedrive_pdf": wants_onedrive
            and asks("pdf", ONEDRIVE_PDF_CANARY_FILE, ONEDRIVE_PDF_CANARY_TOKEN),
            "onedrive": wants_onedrive,
            "dropbox_pdf": wants_dropbox
            and asks("pdf", DROPBOX_PDF_CANARY_FILE, DROPBOX_PDF_CANARY_TOKEN),
            "dropbox": wants_dropbox,
            "drive_pdf_native": wants_drive_pdf,
            "drive_pdf_composio": wants_drive_pdf,
            "drive": wants_drive,
            "manual_upload_pdf": wants_manual_upload
            and asks(
                "pdf", MANUAL_UPLOAD_PDF_CANARY_FILE, MANUAL_UPLOAD_PDF_CANARY_TOKEN
            ),
            "manual_upload_md": wants_manual_upload
            and asks(
                "markdown",
                ".md",
                MANUAL_UPLOAD_MD_CANARY_FILE,
                MANUAL_UPLOAD_MD_CANARY_TOKEN,
            ),
        }

        # --- what the prompt actually contains -------------------------------
        evidence = {
            "gmail": seen(
                GMAIL_CANARY_SUBJECT, GMAIL_CANARY_MESSAGE_ID, GMAIL_CANARY_TOKEN
            ),
            "calendar": seen(CALENDAR_CANARY_SUMMARY, CALENDAR_CANARY_TOKEN),
            "drive": seen("e2e-canary.txt", "fake-file-canary", DRIVE_CANARY_TOKEN),
            "drive_pdf_native": seen(
                DRIVE_PDF_CANARY_FILE, "fake-file-pdf-native", DRIVE_PDF_CANARY_TOKEN
            ),
            "drive_pdf_composio": seen(
                COMPOSIO_DRIVE_PDF_CANARY_FILE,
                "fake-file-pdf-composio",
                COMPOSIO_DRIVE_PDF_CANARY_TOKEN,
            ),
            "onedrive": seen(
                ONEDRIVE_CANARY_FILE, "fake-onedrive-canary", ONEDRIVE_CANARY_TOKEN
            ),
            "onedrive_pdf": seen(
                ONEDRIVE_PDF_CANARY_FILE,
                "fake-onedrive-pdf-canary",
                ONEDRIVE_PDF_CANARY_TOKEN,
            ),
            "dropbox": seen(
                DROPBOX_CANARY_FILE, "id:fake-dropbox-canary", DROPBOX_CANARY_TOKEN
            ),
            "dropbox_pdf": seen(
                DROPBOX_PDF_CANARY_FILE,
                "id:fake-dropbox-pdf-canary",
                DROPBOX_PDF_CANARY_TOKEN,
            ),
            "notion": seen(NOTION_CANARY_TITLE, NOTION_CANARY_TOKEN),
            "confluence": seen(
                CONFLUENCE_CANARY_TITLE,
                CONFLUENCE_CANARY_TOKEN,
                "fake-confluence-page-canary-001",
                "fake-confluence-space-001",
            ),
            "linear": seen(
                LINEAR_CANARY_TITLE, LINEAR_CANARY_TOKEN, "fake-linear-issue-canary-001"
            ),
            "jira": seen(
                JIRA_CANARY_SUMMARY,
                JIRA_CANARY_TOKEN,
                JIRA_CANARY_KEY,
                "fake-jira-issue-canary-001",
                "fake-jira-cloud-001",
                "surfsense-e2e.atlassian.net",
            ),
            "slack": seen(
                SLACK_CANARY_CHANNEL,
                SLACK_CANARY_TOKEN,
                "C_FAKE_SLACK_CANARY",
                "T_FAKE_SLACK_TEAM",
            ),
            "clickup": seen(
                CLICKUP_CANARY_TITLE, CLICKUP_CANARY_TOKEN, CLICKUP_CANARY_TASK_ID
            ),
            "manual_upload_md": seen(
                MANUAL_UPLOAD_MD_CANARY_FILE, MANUAL_UPLOAD_MD_CANARY_TOKEN
            ),
            "manual_upload_pdf": seen(
                MANUAL_UPLOAD_PDF_CANARY_FILE, MANUAL_UPLOAD_PDF_CANARY_TOKEN
            ),
        }

        replies = {
            "clickup": f"ClickUp content found: {CLICKUP_CANARY_TOKEN}",
            "slack": f"Slack content found: {SLACK_CANARY_TOKEN}",
            "jira": f"Jira content found: {JIRA_CANARY_TOKEN}",
            "linear": f"Linear content found: {LINEAR_CANARY_TOKEN}",
            "confluence": f"Confluence content found: {CONFLUENCE_CANARY_TOKEN}",
            "notion": f"Notion content found: {NOTION_CANARY_TOKEN}",
            "calendar": f"Calendar content found: {CALENDAR_CANARY_TOKEN}",
            "gmail": f"Gmail content found: {GMAIL_CANARY_TOKEN}",
            "onedrive_pdf": f"OneDrive PDF content found: {ONEDRIVE_PDF_CANARY_TOKEN}",
            "onedrive": f"OneDrive content found: {ONEDRIVE_CANARY_TOKEN}",
            "dropbox_pdf": f"Dropbox PDF content found: {DROPBOX_PDF_CANARY_TOKEN}",
            "dropbox": f"Dropbox content found: {DROPBOX_CANARY_TOKEN}",
            "drive_pdf_native": f"Drive PDF content found: {DRIVE_PDF_CANARY_TOKEN}",
            "drive_pdf_composio": (
                f"Drive PDF content found: {COMPOSIO_DRIVE_PDF_CANARY_TOKEN}"
            ),
            "drive": f"Drive content found: {DRIVE_CANARY_TOKEN}",
            "manual_upload_pdf": (
                f"Manual upload PDF content found: {MANUAL_UPLOAD_PDF_CANARY_TOKEN}"
            ),
            "manual_upload_md": (
                f"Manual upload MD content found: {MANUAL_UPLOAD_MD_CANARY_TOKEN}"
            ),
        }

        # --- 3. intent + evidence --------------------------------------------
        for source, wanted in wants.items():
            if wanted and evidence[source]:
                return replies[source]

        # --- 4. evidence only, provided no competing connector is present ----
        connectors = (
            "confluence",
            "jira",
            "linear",
            "notion",
            "calendar",
            "gmail",
            "drive",
            "onedrive",
            "dropbox",
            "slack",
            "clickup",
        )
        # (source, connectors whose evidence does NOT disqualify the source)
        # NOTE(review): the Gmail and Drive fallbacks tolerate Calendar
        # evidence while every other fallback excludes it. Kept exactly as
        # found; confirm whether that asymmetry is intentional.
        fallback_order = (
            ("notion", ()),
            ("confluence", ()),
            ("jira", ()),
            ("linear", ()),
            ("calendar", ()),
            ("gmail", ("calendar",)),
            ("onedrive", ()),
            ("dropbox", ()),
            ("drive", ("calendar",)),
            ("slack", ()),
            ("clickup", ()),
            ("manual_upload_pdf", ()),
            ("manual_upload_md", ()),
        )
        for source, tolerated in fallback_order:
            if not evidence[source]:
                continue
            contested = any(
                evidence[other]
                for other in connectors
                if other != source and other not in tolerated
            )
            if not contested:
                return replies[source]

        return NO_RELEVANT_CONTENT_SENTINEL

    def _tool_call_message_for(self, messages: list[BaseMessage]) -> AIMessage | None:
        """Return the scripted tool-call message for ``messages``, if any.

        A ``search_gmail`` result containing the canary message id is
        followed up with ``read_gmail_email``. Otherwise, only when the
        conversation has no tool message yet, the first connector whose
        keywords appear in the latest human message gets its opening call.
        Returns None when no tool call applies.
        """
        latest_human = self._latest_human_text(messages)
        latest_tool = _latest_tool_message(messages)
        tool_name = getattr(latest_tool, "name", None)
        tool_text = _content_to_text(latest_tool.content) if latest_tool else ""

        def call(name: str, args: dict[str, Any], call_id: str) -> AIMessage:
            """Wrap one tool call in an otherwise empty AI message."""
            return AIMessage(
                content="",
                tool_calls=[{"name": name, "args": args, "id": call_id}],
            )

        if tool_name == "search_gmail" and GMAIL_CANARY_MESSAGE_ID in tool_text:
            return call(
                "read_gmail_email",
                {"message_id": GMAIL_CANARY_MESSAGE_ID},
                "call_e2e_read_gmail",
            )

        # Opening calls are only issued before any tool has run.
        if latest_tool is not None:
            return None

        # (keywords, tool name, arguments, call id) -- first match wins.
        opening_calls = (
            (
                ("gmail", "email", "message", GMAIL_CANARY_SUBJECT),
                "search_gmail",
                {"query": f"subject:{GMAIL_CANARY_SUBJECT}", "max_results": 5},
                "call_e2e_search_gmail",
            ),
            (
                ("calendar", "event", "meeting", CALENDAR_CANARY_SUMMARY),
                "search_calendar_events",
                {
                    "start_date": "2026-05-07",
                    "end_date": "2026-05-21",
                    "max_results": 10,
                },
                "call_e2e_search_calendar_events",
            ),
            (
                (
                    "jira",
                    "atlassian",
                    JIRA_CANARY_SUMMARY,
                    JIRA_CANARY_KEY,
                    "surfsense-e2e.atlassian.net",
                    "fake-jira-cloud-001",
                ),
                "searchJiraIssuesUsingJql",
                {"jql": f'summary ~ "{JIRA_CANARY_SUMMARY}"', "maxResults": 5},
                "call_e2e_search_jira_issues",
            ),
            (
                ("linear", "issue", LINEAR_CANARY_TITLE),
                "list_issues",
                {"query": LINEAR_CANARY_TITLE, "limit": 5},
                "call_e2e_list_linear_issues",
            ),
            (
                ("slack", SLACK_CANARY_TOKEN),
                "slack_search_channels",
                {"query": SLACK_CANARY_CHANNEL, "limit": 5},
                "call_e2e_search_slack_channels",
            ),
            (
                ("clickup", CLICKUP_CANARY_TITLE),
                "clickup_search",
                {"query": CLICKUP_CANARY_TITLE, "limit": 5},
                "call_e2e_search_clickup_tasks",
            ),
        )
        for keywords, name, args, call_id in opening_calls:
            if _contains_any(latest_human, keywords):
                return call(name, args, call_id)

        return None

    def _generate(
        self,
        messages: list[BaseMessage],
        stop: list[str] | None = None,
        run_manager: CallbackManagerForLLMRun | None = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Produce one generation: a scripted tool call, else a text reply.

        ``stop``, ``run_manager`` and extra keyword arguments are accepted
        and discarded.
        """
        del stop, run_manager, kwargs
        message = self._tool_call_message_for(messages)
        if not message:
            message = AIMessage(content=self._response_for(messages), tool_calls=[])
        return ChatResult(generations=[ChatGeneration(message=message)])

    async def _astream(
        self,
        messages: list[BaseMessage],
        stop: list[str] | None = None,
        run_manager: AsyncCallbackManagerForLLMRun | None = None,
        **kwargs: Any,
    ) -> AsyncIterator[ChatGenerationChunk]:
        """Stream the scripted reply as chunks.

        When a tool call is scripted, one chunk per tool call is yielded and
        nothing else; otherwise a single chunk with the text reply is
        yielded. ``stop``, ``run_manager`` and extra keyword arguments are
        accepted and discarded.
        """
        del stop, run_manager, kwargs
        tool_call_message = self._tool_call_message_for(messages)
        if tool_call_message:
            # Each tool call gets its own chunk index: chunks sharing an
            # index are treated as fragments of the same call when merged.
            for position, tool_call in enumerate(tool_call_message.tool_calls):
                yield ChatGenerationChunk(
                    message=AIMessageChunk(
                        content="",
                        tool_call_chunks=[
                            {
                                "name": tool_call["name"],
                                "args": json.dumps(tool_call["args"]),
                                "id": tool_call["id"],
                                "index": position,
                            }
                        ],
                    )
                )
            return

        yield ChatGenerationChunk(
            message=AIMessageChunk(content=self._response_for(messages))
        )
|
||||
|
||||
|
||||
def fake_create_chat_litellm_from_agent_config(
    *args: Any, **kwargs: Any
) -> FakeChatLLM:
    """Build a ``FakeChatLLM``, ignoring every positional and keyword argument."""
    del args, kwargs  # accepted for call compatibility only; never read
    fake_llm = FakeChatLLM()
    return fake_llm
|
||||
|
||||
|
||||
def fake_create_chat_litellm_from_config(*args: Any, **kwargs: Any) -> FakeChatLLM:
    """Build a ``FakeChatLLM``, ignoring every positional and keyword argument."""
    del args, kwargs  # accepted for call compatibility only; never read
    fake_llm = FakeChatLLM()
    return fake_llm
|
||||
134
surfsense_backend/tests/e2e/fakes/clickup_module.py
Normal file
134
surfsense_backend/tests/e2e/fakes/clickup_module.py
Normal file
|
|
@ -0,0 +1,134 @@
|
|||
"""Strict ClickUp MCP OAuth/tool fakes for Playwright E2E."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
from typing import Any
|
||||
|
||||
from tests.e2e.fakes import mcp_oauth_runtime, mcp_runtime
|
||||
|
||||
# JSON fixture with the canned ClickUp tasks, located next to this module.
_FIXTURE_PATH = Path(__file__).parent / "fixtures" / "clickup_tasks.json"

# Endpoints of the faked ClickUp MCP server. install() advertises the
# authorization/registration/token URLs in the OAuth discovery metadata and
# registers the tool handlers under _MCP_URL.
_AUTHORIZATION_URL = "https://mcp.clickup.com/authorize"
_REGISTRATION_URL = "https://mcp.clickup.com/register"
_TOKEN_URL = "https://mcp.clickup.com/token"
_MCP_URL = "https://mcp.clickup.com/mcp"

# Fake OAuth credentials handed to the shared OAuth runtime by install().
# _ACCESS_TOKEN is also the bearer token the MCP runtime is told to expect.
_CLIENT_ID = "fake-clickup-mcp-client-id"
_CLIENT_SECRET = "fake-clickup-mcp-client-secret"
_ACCESS_TOKEN = "fake-clickup-mcp-access-token"
_REFRESH_TOKEN = "fake-clickup-mcp-refresh-token"
_OAUTH_CODE = "fake-clickup-oauth-code"
|
||||
|
||||
|
||||
def _load_fixture() -> dict[str, Any]:
    """Read and parse the ClickUp task fixture at ``_FIXTURE_PATH``.

    The file is decoded as UTF-8 explicitly so parsing does not depend on
    the platform's locale encoding.
    """
    with _FIXTURE_PATH.open(encoding="utf-8") as f:
        return json.load(f)
|
||||
|
||||
|
||||
_FIXTURE = _load_fixture()
|
||||
|
||||
|
||||
def _task_text(task: dict[str, Any]) -> str:
|
||||
return (
|
||||
f"{task['name']}\n"
|
||||
f"id: {task['id']}\n"
|
||||
f"workspace: {task['workspace_name']} ({task['workspace_id']})\n"
|
||||
f"list: {task['list_name']}\n"
|
||||
f"status: {task['status']}\n"
|
||||
f"description: {task['description']}"
|
||||
)
|
||||
|
||||
|
||||
async def _list_tools() -> SimpleNamespace:
|
||||
return SimpleNamespace(
|
||||
tools=[
|
||||
SimpleNamespace(
|
||||
name="clickup_search",
|
||||
description="Search ClickUp tasks visible to the authenticated user.",
|
||||
inputSchema={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {
|
||||
"type": "string",
|
||||
"description": "Text to search for in ClickUp tasks.",
|
||||
},
|
||||
"limit": {
|
||||
"type": "integer",
|
||||
"description": "Maximum number of tasks to return.",
|
||||
},
|
||||
},
|
||||
"required": ["query"],
|
||||
},
|
||||
),
|
||||
SimpleNamespace(
|
||||
name="clickup_get_task",
|
||||
description="Get a ClickUp task by id.",
|
||||
inputSchema={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"task_id": {
|
||||
"type": "string",
|
||||
"description": "ClickUp task id.",
|
||||
}
|
||||
},
|
||||
"required": ["task_id"],
|
||||
},
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
async def _call_tool(
|
||||
tool_name: str, arguments: dict[str, Any] | None = None
|
||||
) -> SimpleNamespace:
|
||||
arguments = arguments or {}
|
||||
task = _FIXTURE["tasks"][0]
|
||||
|
||||
if tool_name == "clickup_search":
|
||||
query = str(arguments.get("query", ""))
|
||||
if query and task["name"].lower() not in query.lower():
|
||||
raise ValueError(f"Unexpected ClickUp task query: {query!r}")
|
||||
return SimpleNamespace(content=[SimpleNamespace(text=_task_text(task))])
|
||||
|
||||
if tool_name == "clickup_get_task":
|
||||
task_id = arguments.get("task_id")
|
||||
if task_id != task["id"]:
|
||||
raise ValueError(f"Unexpected ClickUp task id: {task_id!r}")
|
||||
return SimpleNamespace(content=[SimpleNamespace(text=_task_text(task))])
|
||||
|
||||
raise NotImplementedError(f"Unexpected ClickUp MCP tool call: {tool_name!r}")
|
||||
|
||||
|
||||
def install(active_patches: list[Any]) -> None:
    """Register ClickUp MCP OAuth/tool handlers with the shared dispatchers.

    ``active_patches`` is accepted but not used by this fake.
    """
    del active_patches

    # OAuth authorization-server metadata advertised for the fake service.
    discovery = {
        "issuer": "https://mcp.clickup.com",
        "authorization_endpoint": _AUTHORIZATION_URL,
        "token_endpoint": _TOKEN_URL,
        "registration_endpoint": _REGISTRATION_URL,
        "code_challenge_methods_supported": ["S256"],
        "grant_types_supported": ["authorization_code", "refresh_token"],
        "response_types_supported": ["code"],
    }
    mcp_oauth_runtime.register_service(
        mcp_url=_MCP_URL,
        discovery_metadata=discovery,
        client_id=_CLIENT_ID,
        client_secret=_CLIENT_SECRET,
        token_endpoint=_TOKEN_URL,
        registration_endpoint=_REGISTRATION_URL,
        oauth_code=_OAUTH_CODE,
        access_token=_ACCESS_TOKEN,
        refresh_token=_REFRESH_TOKEN,
        scope="read write",
        redirect_uri_substring="/api/v1/auth/mcp/clickup/connector/callback",
    )

    # Tool traffic to the same URL is answered by the handlers in this module.
    mcp_runtime.register(
        url=_MCP_URL,
        expected_bearer=_ACCESS_TOKEN,
        list_tools=_list_tools,
        call_tool=_call_tool,
    )
|
||||
559
surfsense_backend/tests/e2e/fakes/composio_module.py
Normal file
559
surfsense_backend/tests/e2e/fakes/composio_module.py
Normal file
|
|
@ -0,0 +1,559 @@
|
|||
"""Strict drop-in replacement for the `composio` Python SDK.
|
||||
|
||||
Registered as `sys.modules["composio"]` by `tests/e2e/run_backend.py`
|
||||
and `tests/e2e/run_celery.py` BEFORE any production code imports
|
||||
`composio`. From that point on, every `from composio import Composio`
|
||||
in production resolves to `Composio` defined here.
|
||||
|
||||
Every class implements __getattr__ that raises NotImplementedError on
|
||||
unknown attributes. A future production code path that introduces a
|
||||
new SDK call (e.g. `client.bulk_operations.run`) fails CI loudly with
|
||||
a clear "add this surface to the fake" message instead of silently
|
||||
passing through to the real SDK.
|
||||
|
||||
Scenario branching is read from the request-scoped ContextVar in
|
||||
`tests/e2e/middleware/scenario.py`, set by the X-E2E-Scenario header.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from .binary_loader import _resolve_file_bytes
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fixture loading
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Directory holding the canned JSON fixtures, located next to this module.
_FIXTURES_DIR = Path(__file__).parent / "fixtures"
# One fixture file per faked Google service.
_DRIVE_FIXTURE_PATH = _FIXTURES_DIR / "drive_files.json"
_GMAIL_FIXTURE_PATH = _FIXTURES_DIR / "gmail_messages.json"
_CALENDAR_FIXTURE_PATH = _FIXTURES_DIR / "calendar_events.json"
# NOTE(review): hard-coded POSIX temp path; its consumers are outside this
# excerpt -- presumably where fake Drive downloads are written. Confirm.
_DRIVE_DOWNLOAD_DIR = Path("/tmp/surfsense-e2e-composio-downloads")
|
||||
|
||||
|
||||
def _load_drive_fixture() -> dict[str, Any]:
    """Read and parse the canned Drive fixture at ``_DRIVE_FIXTURE_PATH``.

    The file is decoded as UTF-8 explicitly so parsing does not depend on
    the platform's locale encoding.
    """
    with _DRIVE_FIXTURE_PATH.open(encoding="utf-8") as f:
        return json.load(f)
|
||||
|
||||
|
||||
def _load_gmail_fixture() -> dict[str, Any]:
    """Read and parse the canned Gmail fixture at ``_GMAIL_FIXTURE_PATH``.

    The file is decoded as UTF-8 explicitly so parsing does not depend on
    the platform's locale encoding.
    """
    with _GMAIL_FIXTURE_PATH.open(encoding="utf-8") as f:
        return json.load(f)
|
||||
|
||||
|
||||
def _load_calendar_fixture() -> dict[str, Any]:
    """Read and parse the canned Calendar fixture at ``_CALENDAR_FIXTURE_PATH``.

    The file is decoded as UTF-8 explicitly so parsing does not depend on
    the platform's locale encoding.
    """
    with _CALENDAR_FIXTURE_PATH.open(encoding="utf-8") as f:
        return json.load(f)
|
||||
|
||||
|
||||
_DRIVE_FIXTURE = _load_drive_fixture()
|
||||
_GMAIL_FIXTURE = _load_gmail_fixture()
|
||||
_CALENDAR_FIXTURE = _load_calendar_fixture()
|
||||
|
||||
|
||||
def _get_scenario() -> str:
|
||||
"""Return the current X-E2E-Scenario, defaulting to 'happy'.
|
||||
|
||||
Imported lazily so the fake module can be loaded BEFORE
|
||||
`tests.e2e.middleware.scenario` is importable (during sys.modules
|
||||
hijack at the very top of the entrypoint).
|
||||
"""
|
||||
try:
|
||||
from tests.e2e.middleware.scenario import current_scenario
|
||||
|
||||
return current_scenario()
|
||||
except Exception:
|
||||
return "happy"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Strict mixin: every fake class raises on unknown attribute access
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class _StrictFakeMixin:
|
||||
"""Base class for fakes. Any unknown attribute access fails loudly."""
|
||||
|
||||
_component_name: str = "<unknown>"
|
||||
|
||||
def __getattr__(self, name: str) -> Any:
|
||||
raise NotImplementedError(
|
||||
f"E2E Composio fake missing surface: {self._component_name}.{name!r}. "
|
||||
f"If production code needs this, add an explicit method to "
|
||||
f"surfsense_backend/tests/e2e/fakes/composio_module.py — "
|
||||
f"the strict fake refuses to silently fall through to the real SDK."
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Result objects mimicking the real SDK's response dataclasses
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class _ConnectionRequest:
|
||||
"""Mimics composio.connected_accounts.initiate(...) return value."""
|
||||
|
||||
def __init__(self, *, redirect_url: str, account_id: str) -> None:
|
||||
self.redirect_url = redirect_url
|
||||
self.id = account_id
|
||||
|
||||
|
||||
class _ConnectedAccount:
    """Mimics a connected_account row returned by wait_for_connection / refresh.

    Exposes ``id``, ``status``, ``redirect_url`` and a nested ``state``
    object that carries the access token.
    """

    def __init__(
        self,
        *,
        account_id: str,
        status: str = "ACTIVE",
        redirect_url: str | None = None,
        access_token: str = "fake-e2e-access-token-not-real-do-not-use-32chars",
    ) -> None:
        self.id, self.status = account_id, status
        self.redirect_url = redirect_url
        # The token is wrapped in a nested state object, not set directly.
        self.state = _AccountState(access_token=access_token)
|
||||
|
||||
|
||||
class _AccountState:
    """Outer ``state`` wrapper; the token is reachable as ``.val.access_token``."""

    def __init__(self, *, access_token: str) -> None:
        inner = _AccountStateVal(access_token=access_token)
        self.val = inner
|
||||
|
||||
|
||||
class _AccountStateVal:
|
||||
def __init__(self, *, access_token: str) -> None:
|
||||
self.access_token = access_token
|
||||
|
||||
|
||||
class _AuthConfig:
    """Mimics one auth_config row returned by client.auth_configs.list().items.

    Exposes ``id`` and ``toolkit.slug``.
    """

    def __init__(self, *, config_id: str, toolkit_slug: str) -> None:
        toolkit = _Toolkit(slug=toolkit_slug)
        self.id = config_id
        self.toolkit = toolkit
|
||||
|
||||
|
||||
class _Toolkit:
|
||||
def __init__(self, *, slug: str) -> None:
|
||||
self.slug = slug
|
||||
|
||||
|
||||
class _AuthConfigsListResult:
|
||||
def __init__(self, items: list[_AuthConfig]) -> None:
|
||||
self.items = items
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Sub-clients on the Composio top-level object
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class _ConnectedAccounts(_StrictFakeMixin):
    """Strict in-process replacement for ``client.connected_accounts``."""

    _component_name = "connected_accounts"

    def initiate(
        self,
        *,
        user_id: str,
        auth_config_id: str,
        callback_url: str,
        allow_multiple: bool = True,
        **_: Any,
    ) -> _ConnectionRequest:
        """Start a fake connection and return where the browser should go next.

        The redirect points straight back at ``callback_url`` with either
        ``error=access_denied`` (scenario ``"denied"``) or the synthesized
        ``connectedAccountId`` appended as a query parameter.
        """
        scenario = _get_scenario()

        # Deterministic account ID: the same toolkit for the same entity
        # always maps to the same ID, so the duplicate-connection scenario
        # exercises the reconnect branch in composio_routes.py.
        toolkit_id = auth_config_id.replace("auth-config-", "")
        account_id = f"fake-acct-{toolkit_id}-{user_id}"

        # The real SDK would redirect to Composio's hosted OAuth UI, which
        # then bounces to the third-party provider. In E2E we go straight
        # back to our own same-origin callback so no real network is needed.
        joiner = "&" if "?" in callback_url else "?"
        if scenario == "denied":
            outcome = "error=access_denied"
        else:
            outcome = f"connectedAccountId={account_id}"
        redirect = f"{callback_url}{joiner}{outcome}"

        logger.info(
            "[fake-composio] initiate scenario=%s toolkit=%s redirect=%s",
            scenario,
            toolkit_id,
            redirect,
        )
        return _ConnectionRequest(redirect_url=redirect, account_id=account_id)

    def wait_for_connection(
        self, *, id: str, timeout: float = 30.0, **_: Any
    ) -> _ConnectedAccount:
        """Return an ACTIVE account immediately; ``timeout`` is accepted but unused."""
        return _ConnectedAccount(account_id=id, status="ACTIVE")

    def get(self, *, nanoid: str, **_: Any) -> _ConnectedAccount:
        """Return an ACTIVE account whose id is ``nanoid``."""
        return _ConnectedAccount(account_id=nanoid, status="ACTIVE")

    def delete(self, account_id: str, /, **_: Any) -> dict[str, Any]:
        """Log the deletion and report success for any ``account_id``."""
        logger.info("[fake-composio] delete account=%s", account_id)
        return {"success": True, "id": account_id}

    def refresh(
        self,
        *,
        nanoid: str,
        body_redirect_url: str | None = None,
        **_: Any,
    ) -> _ConnectedAccount:
        """Return an ACTIVE account that echoes ``body_redirect_url`` as its redirect."""
        refreshed = _ConnectedAccount(
            account_id=nanoid,
            status="ACTIVE",
            redirect_url=body_redirect_url,
        )
        return refreshed
|
||||
|
||||
|
||||
class _AuthConfigs(_StrictFakeMixin):
    """Strict in-process replacement for ``client.auth_configs``."""

    _component_name = "auth_configs"

    def list(self, **_: Any) -> _AuthConfigsListResult:
        """Return exactly one auth config per toolkit covered by the E2E suite.

        The real SDK allows several configs per toolkit; one each is enough
        here. Every config id is ``auth-config-<toolkit slug>``.
        """
        toolkit_slugs = ("googledrive", "gmail", "googlecalendar")
        configs = [
            _AuthConfig(config_id=f"auth-config-{slug}", toolkit_slug=slug)
            for slug in toolkit_slugs
        ]
        return _AuthConfigsListResult(items=configs)
|
||||
|
||||
|
||||
class _Tools(_StrictFakeMixin):
    """Strict in-process replacement for ``client.tools``."""

    _component_name = "tools"

    def execute(
        self,
        *,
        slug: str,
        connected_account_id: str,
        user_id: str | None = None,
        arguments: dict[str, Any] | None = None,
        dangerously_skip_version_check: bool = True,
        **_: Any,
    ) -> dict[str, Any]:
        """Serve a canned response for the tool identified by ``slug``.

        Raises:
            _AuthExpiredError: when the active scenario is ``"auth_expired"``.
            NotImplementedError: when ``slug`` has no handler in this fake.
        """
        scenario = _get_scenario()
        args = arguments or {}
        logger.info(
            "[fake-composio] tools.execute slug=%s scenario=%s args=%s",
            slug,
            scenario,
            list(args),
        )

        if scenario == "auth_expired":
            # The message mirrors the error strings composio_routes.py
            # classifies as authentication failures (see lines ~720-728).
            raise _AuthExpiredError(
                "Token has been expired or revoked. (HTTP 401: invalid_grant)"
            )

        # Slugs whose response is computed from the arguments / fixtures.
        fixture_backed = {
            "GOOGLEDRIVE_LIST_FILES": _drive_list_files,
            "GOOGLEDRIVE_DOWNLOAD_FILE": _drive_download_file,
            "GOOGLEDRIVE_GET_FILE_METADATA": _drive_get_metadata,
            "GMAIL_FETCH_EMAILS": _gmail_fetch_emails,
            "GMAIL_FETCH_MESSAGE_BY_MESSAGE_ID": _gmail_fetch_message_by_message_id,
            "GOOGLECALENDAR_EVENTS_LIST": _calendar_events_list,
        }
        handler = fixture_backed.get(slug)
        if handler is not None:
            return handler(args)

        # Slugs with a fixed payload; each call builds a fresh dict.
        if slug == "GOOGLEDRIVE_GET_CHANGES_START_PAGE_TOKEN":
            return {"data": {"startPageToken": "fake-start-page-token-1"}}
        if slug == "GOOGLEDRIVE_LIST_CHANGES":
            return {
                "data": {"changes": [], "newStartPageToken": "fake-start-page-token-1"}
            }
        if slug == "GOOGLEDRIVE_GET_ABOUT":
            # Used by ComposioService.get_connected_account_email for
            # googledrive. A fake email gives the connector a nice display
            # name; failure is non-fatal.
            return {"data": {"user": {"emailAddress": "e2e-fake@surfsense.example"}}}
        if slug == "GMAIL_GET_PROFILE":
            return {"data": {"emailAddress": "e2e-fake@surfsense.example"}}
        if slug == "GOOGLECALENDAR_GET_CALENDAR":
            calendar = {
                "id": "primary",
                "summary": "e2e-fake@surfsense.example",
                "primary": True,
                "timeZone": "UTC",
            }
            return {"data": calendar}
        if slug == "GOOGLECALENDAR_CALENDARS_LIST":
            primary_entry = {
                "id": "primary",
                "summary": "e2e-fake@surfsense.example",
                "primary": True,
            }
            return {"data": {"items": [primary_entry]}}

        # No silent passthrough: a slug we have not modelled is a test bug.
        raise NotImplementedError(
            f"E2E Composio fake has no handler for tool slug {slug!r}. "
            "Add it to surfsense_backend/tests/e2e/fakes/composio_module.py "
            "in `_Tools.execute` if production code needs it."
        )
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Drive tool handlers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _drive_list_files(args: dict[str, Any]) -> dict[str, Any]:
    """Mimic GOOGLEDRIVE_LIST_FILES.

    The real SDK accepts a Drive-style `q=` query like
    `'<folder_id>' in parents and trashed = false ...`. The folder id is
    the quoted token immediately preceding the first `in parents`, so the
    clause may appear anywhere in the query (e.g. after a `mimeType`
    clause). The matching fixture list is then filtered by the remaining
    query clauses.

    Args:
        args: Tool arguments; only ``"q"`` is read. A missing or ``None``
            query is treated as an empty query against the ``"root"`` folder.

    Returns:
        ``{"data": {"files": [...], "nextPageToken": None}}``.
    """
    q = args.get("q") or ""
    folder_id = "root"

    before_clause, clause, _ = q.partition("in parents")
    if clause:
        # Take the LAST quoted token before `in parents`; the first quoted
        # token in the whole query may belong to an earlier clause.
        pieces = before_clause.rsplit("'", 2)
        if len(pieces) == 3:
            folder_id = pieces[1]

    files = _filter_drive_files_for_query(q, _DRIVE_FIXTURE.get(folder_id, []))
    return {
        "data": {
            "files": files,
            "nextPageToken": None,
        }
    }
|
||||
|
||||
|
||||
def _extract_quoted_value(q: str, anchor: str) -> str | None:
|
||||
anchor_idx = q.find(anchor)
|
||||
if anchor_idx == -1:
|
||||
return None
|
||||
|
||||
after_anchor = q[anchor_idx + len(anchor) :]
|
||||
first_quote_idx = after_anchor.find("'")
|
||||
if first_quote_idx == -1:
|
||||
return None
|
||||
|
||||
after_first_quote = after_anchor[first_quote_idx + 1 :]
|
||||
second_quote_idx = after_first_quote.find("'")
|
||||
if second_quote_idx == -1:
|
||||
return None
|
||||
|
||||
return after_first_quote[:second_quote_idx]
|
||||
|
||||
|
||||
def _filter_drive_files_for_query(
    q: str, files: list[dict[str, Any]]
) -> list[dict[str, Any]]:
    """Apply the subset of Drive query clauses this fake understands.

    Supported clauses: ``trashed = false``, ``mimeType != '<type>'`` and
    ``mimeType = '<type>'``. Always returns a new list; ``files`` is not
    mutated.
    """
    skip_trashed = "trashed = false" in q
    banned_mime = _extract_quoted_value(q, "mimeType !=")
    required_mime = _extract_quoted_value(q, "mimeType =")

    def _passes(entry: dict[str, Any]) -> bool:
        if skip_trashed and entry.get("trashed") is True:
            return False
        # An empty or missing extracted value means "no such clause".
        if banned_mime and entry.get("mimeType") == banned_mime:
            return False
        if required_mime and entry.get("mimeType") != required_mime:
            return False
        return True

    return [entry for entry in files if _passes(entry)]
|
||||
|
||||
|
||||
def _drive_download_file(args: dict[str, Any]) -> dict[str, Any]:
    """Mimic GOOGLEDRIVE_DOWNLOAD_FILE.

    The real SDK saves the downloaded bytes to a local file and returns
    its path; composio_service.py then reads the bytes back from that
    path. This fake does the same with fixture content written under
    ``_DRIVE_DOWNLOAD_DIR``.

    Raises:
        NotImplementedError: when no canned content (or no metadata) exists
            for the requested ``file_id``.
    """
    file_id = args.get("file_id", "")
    payload = _resolve_file_bytes(_DRIVE_FIXTURE, file_id, _FIXTURES_DIR)
    if payload is None:
        # Unknown file id is a test bug, fail loudly.
        raise NotImplementedError(
            f"E2E Composio fake has no canned content for file_id={file_id!r}. "
            "Add it under '_file_contents' in "
            "surfsense_backend/tests/e2e/fakes/fixtures/drive_files.json."
        )

    _DRIVE_DOWNLOAD_DIR.mkdir(parents=True, exist_ok=True)
    entry = _drive_get_metadata({"file_id": file_id})["data"]
    # Fall back to a synthetic name when the fixture entry has none.
    file_name = entry.get("name") or f"{file_id}.txt"
    destination = _DRIVE_DOWNLOAD_DIR / file_name
    destination.write_bytes(payload)

    download_info = {
        "file_path": str(destination),
        "file_name": file_name,
        "size": len(payload),
    }
    return {"data": download_info}
|
||||
|
||||
|
||||
def _drive_get_metadata(args: dict[str, Any]) -> dict[str, Any]:
    """Mimic GOOGLEDRIVE_GET_FILE_METADATA.

    Scans every list-valued entry of the Drive fixture and returns the
    first file whose ``id`` equals ``args["file_id"]``.

    Raises:
        NotImplementedError: when no fixture entry has that id.
    """
    file_id = args.get("file_id", "")
    # Non-list fixture values are skipped; only folder listings are searched.
    matches = (
        entry
        for items in _DRIVE_FIXTURE.values()
        if isinstance(items, list)
        for entry in items
        if entry.get("id") == file_id
    )
    found = next(matches, None)
    if found is None:
        raise NotImplementedError(
            f"E2E fake: no metadata fixture for file_id={file_id!r}. "
            "Add it to drive_files.json."
        )
    return {"data": found}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Gmail tool handlers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _gmail_fetch_emails(args: dict[str, Any]) -> dict[str, Any]:
    """Mimic GMAIL_FETCH_EMAILS.

    The production indexer treats this as a list page and then calls
    GMAIL_FETCH_MESSAGE_BY_MESSAGE_ID for the full body of each id. The
    arguments are ignored: every fixture message is returned on one page.
    """
    del args
    page = list(_GMAIL_FIXTURE.get("messages", []))
    listing = {
        "messages": page,
        "nextPageToken": None,
        "resultSizeEstimate": len(page),
    }
    return {"data": listing}
|
||||
|
||||
|
||||
def _gmail_fetch_message_by_message_id(args: dict[str, Any]) -> dict[str, Any]:
    """Mimic GMAIL_FETCH_MESSAGE_BY_MESSAGE_ID.

    Looks up ``args["message_id"]`` under the fixture's ``"details"`` map.

    Raises:
        NotImplementedError: when the fixture has no detail for that id.
    """
    message_id = args.get("message_id", "")
    detail = _GMAIL_FIXTURE.get("details", {}).get(message_id)
    if detail is not None:
        return {"data": detail}

    raise NotImplementedError(
        "E2E Composio fake has no Gmail detail fixture for "
        f"message_id={message_id!r}. Add it under 'details' in "
        "surfsense_backend/tests/e2e/fakes/fixtures/gmail_messages.json."
    )
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Google Calendar tool handlers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _calendar_events_list(args: dict[str, Any]) -> dict[str, Any]:
    """Mimic GOOGLECALENDAR_EVENTS_LIST for live calendar chat tools.

    Returns at most ``max_results`` fixture events (250 when the argument
    is missing or falsy).
    """
    requested = args.get("max_results", 250) or 250
    limit = int(requested)
    events = list(_CALENDAR_FIXTURE.get("items", []))
    listing = {
        "items": events[:limit],
        "summary": "e2e-fake@surfsense.example",
        "timeZone": "UTC",
    }
    return {"data": listing}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Errors
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class _AuthExpiredError(Exception):
|
||||
"""Raised by the fake when scenario=auth_expired.
|
||||
|
||||
composio_service.execute_tool catches every exception and surfaces
|
||||
str(error) inside the result dict; composio_routes.py then classifies
|
||||
"expired or revoked" / "401" tokens and sets connector.config.auth_expired.
|
||||
"""
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Top-level Composio class — the only public symbol production imports
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class Composio(_StrictFakeMixin):
    """Drop-in replacement for `composio.Composio`.

    Production constructs it as `Composio(api_key=..., file_download_dir=...)`.
    The instance exposes three sub-clients: ``connected_accounts``,
    ``tools`` and ``auth_configs``.
    """

    _component_name = "Composio"

    def __init__(
        self,
        *,
        api_key: str | None = None,
        file_download_dir: str | None = None,
        **_: Any,
    ) -> None:
        # Constructor arguments are only recorded.
        self.api_key = api_key
        self.file_download_dir = file_download_dir

        # Sub-clients mirror the attribute names of the real client object.
        self.connected_accounts = _ConnectedAccounts()
        self.tools = _Tools()
        self.auth_configs = _AuthConfigs()

        logger.info("[fake-composio] Composio() constructed (E2E mode, no real network)")
|
||||
|
||||
|
||||
# Public re-exports so `from composio import Composio` resolves correctly
|
||||
# when this module is registered as sys.modules["composio"].
|
||||
__all__ = ["Composio"]
|
||||
58
surfsense_backend/tests/e2e/fakes/confluence_indexer.py
Normal file
58
surfsense_backend/tests/e2e/fakes/confluence_indexer.py
Normal file
|
|
@ -0,0 +1,58 @@
|
|||
"""Strict Confluence indexer fake for Playwright E2E."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from unittest.mock import patch
|
||||
|
||||
_FIXTURE_PATH = Path(__file__).parent / "fixtures" / "confluence_pages.json"
|
||||
|
||||
|
||||
def _load_fixture() -> dict[str, Any]:
    """Read and parse the Confluence pages fixture at ``_FIXTURE_PATH``.

    The file is opened explicitly as UTF-8 so parsing does not depend on
    the platform's locale-default text encoding.

    Returns:
        The decoded JSON document.
    """
    with _FIXTURE_PATH.open(encoding="utf-8") as f:
        return json.load(f)
|
||||
|
||||
|
||||
_FIXTURE = _load_fixture()
|
||||
|
||||
|
||||
class _FakeConfluenceHistoryConnector:
|
||||
def __init__(self, *, session: Any, connector_id: int, **kwargs: Any):
|
||||
del session, kwargs
|
||||
self.connector_id = connector_id
|
||||
|
||||
async def get_pages_by_date_range(
|
||||
self,
|
||||
*,
|
||||
start_date: str,
|
||||
end_date: str,
|
||||
include_comments: bool = False,
|
||||
) -> tuple[list[dict[str, Any]], None]:
|
||||
if not start_date or not end_date:
|
||||
raise ValueError(
|
||||
"Confluence indexer fake expected start_date and end_date."
|
||||
)
|
||||
del include_comments
|
||||
return _FIXTURE["pages"], None
|
||||
|
||||
async def close(self) -> None:
|
||||
return None
|
||||
|
||||
def __getattr__(self, name: str) -> Any:
|
||||
raise NotImplementedError(
|
||||
"E2E Confluence indexer fake missing surface: "
|
||||
f"ConfluenceHistoryConnector.{name!r}. "
|
||||
"Add it to surfsense_backend/tests/e2e/fakes/confluence_indexer.py."
|
||||
)
|
||||
|
||||
|
||||
def install(active_patches: list[Any]) -> None:
    """Swap the connector class bound inside the Confluence indexer module.

    The patcher is started immediately and appended to ``active_patches``
    so the caller keeps a handle on it.
    """
    target = "app.tasks.connector_indexers.confluence_indexer.ConfluenceHistoryConnector"
    patcher = patch(target, _FakeConfluenceHistoryConnector)
    patcher.start()
    active_patches.append(patcher)
|
||||
146
surfsense_backend/tests/e2e/fakes/confluence_oauth.py
Normal file
146
surfsense_backend/tests/e2e/fakes/confluence_oauth.py
Normal file
|
|
@ -0,0 +1,146 @@
|
|||
"""Strict Confluence OAuth fakes for Playwright E2E."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from unittest.mock import patch
|
||||
|
||||
_FIXTURE_PATH = Path(__file__).parent / "fixtures" / "confluence_pages.json"
|
||||
|
||||
# The only URL the fake client's ``post`` accepts (token exchange / refresh).
_TOKEN_URL = "https://auth.atlassian.com/oauth/token"
# The only URL the fake client's ``get`` accepts (accessible resources).
_RESOURCES_URL = "https://api.atlassian.com/oauth/token/accessible-resources"
# Tokens issued by every successful fake token response; ``get`` requires
# the access token as a Bearer credential and a refresh grant must present
# the refresh token.
_ACCESS_TOKEN = "fake-confluence-access-token"
_REFRESH_TOKEN = "fake-confluence-refresh-token"
# The only authorization code the fake accepts for an authorization_code grant.
_OAUTH_CODE = "fake-confluence-oauth-code"
|
||||
|
||||
|
||||
def _load_fixture() -> dict[str, Any]:
    """Read and parse the Confluence fixture at ``_FIXTURE_PATH``.

    The file is opened explicitly as UTF-8 so parsing does not depend on
    the platform's locale-default text encoding.

    Returns:
        The decoded JSON document.
    """
    with _FIXTURE_PATH.open(encoding="utf-8") as f:
        return json.load(f)
|
||||
|
||||
|
||||
_FIXTURE = _load_fixture()
|
||||
|
||||
|
||||
class _StrictFakeMixin:
|
||||
_component_name: str = "<unknown>"
|
||||
|
||||
def __getattr__(self, name: str) -> Any:
|
||||
raise NotImplementedError(
|
||||
f"E2E Confluence OAuth fake missing surface: "
|
||||
f"{self._component_name}.{name!r}. "
|
||||
"Add it to surfsense_backend/tests/e2e/fakes/confluence_oauth.py."
|
||||
)
|
||||
|
||||
|
||||
class _FakeResponse(_StrictFakeMixin):
    """Minimal response object exposing ``status_code``, ``text`` and ``json()``."""

    _component_name = "httpx.Response"

    def __init__(self, payload: Any, status_code: int = 200):
        self._payload = payload
        self.status_code = status_code
        # Text body is the payload serialised with sorted keys.
        serialised = json.dumps(payload, sort_keys=True)
        self.text = serialised

    def json(self) -> Any:
        """Return the payload object passed at construction (not a copy)."""
        return self._payload
|
||||
|
||||
|
||||
class _FakeHttpxAsyncClient(_StrictFakeMixin):
    """Async-client stand-in that validates and answers the Confluence OAuth calls."""

    _component_name = "httpx.AsyncClient"

    def __init__(self, *args: Any, **kwargs: Any):
        # Constructor options are accepted and ignored.
        del args, kwargs

    async def __aenter__(self) -> _FakeHttpxAsyncClient:
        return self

    async def __aexit__(self, exc_type: Any, exc: Any, tb: Any) -> None:
        del exc_type, exc, tb

    async def post(self, url: str, **kwargs: Any) -> _FakeResponse:
        """Validate a token request and return a fake token response.

        Raises:
            NotImplementedError: for any URL other than the token endpoint.
            ValueError: when headers, grant type, code, credentials,
                redirect URI or refresh token are not what the fake expects.
        """
        if url != _TOKEN_URL:
            raise NotImplementedError(f"Unexpected Confluence OAuth POST url={url!r}")

        headers = kwargs.get("headers") or {}
        data = kwargs.get("json") or {}
        if headers.get("Content-Type") != "application/json":
            raise ValueError("Confluence OAuth token exchange expected JSON headers.")

        grant_type = data.get("grant_type")
        if grant_type == "authorization_code":
            code = data.get("code")
            if code != _OAUTH_CODE:
                raise ValueError(f"Unexpected fake Confluence OAuth code: {code!r}")
            if not (data.get("client_id") and data.get("client_secret")):
                raise ValueError(
                    "Confluence OAuth token exchange missing client creds."
                )
            callback_path = "/api/v1/auth/confluence/connector/callback"
            if callback_path not in str(data.get("redirect_uri", "")):
                raise ValueError(
                    "Confluence OAuth token exchange got unexpected redirect_uri: "
                    f"{data.get('redirect_uri')!r}"
                )
        elif grant_type == "refresh_token":
            presented = data.get("refresh_token")
            if presented != _REFRESH_TOKEN:
                raise ValueError(
                    f"Unexpected fake Confluence refresh token: {presented!r}"
                )
        else:
            raise ValueError(f"Unexpected fake Confluence grant_type: {grant_type!r}")

        # Both grant types receive the same token payload.
        token_payload = {
            "access_token": _ACCESS_TOKEN,
            "refresh_token": _REFRESH_TOKEN,
            "expires_in": 3600,
            "scope": "read:confluence-user read:space:confluence read:page:confluence",
            "token_type": "Bearer",
        }
        return _FakeResponse(token_payload)

    async def get(self, url: str, **kwargs: Any) -> _FakeResponse:
        """Answer the accessible-resources lookup with the fixture site.

        Raises:
            NotImplementedError: for any URL other than the resources endpoint.
            ValueError: when the Authorization header is not the fake bearer token.
        """
        if url != _RESOURCES_URL:
            raise NotImplementedError(f"Unexpected Confluence OAuth GET url={url!r}")

        auth = (kwargs.get("headers") or {}).get("Authorization")
        if auth != f"Bearer {_ACCESS_TOKEN}":
            raise ValueError(f"Unexpected Confluence resources Authorization: {auth!r}")

        site = _FIXTURE["site"]
        resource = {
            "id": site["cloud_id"],
            "name": site["name"],
            "url": site["url"],
            "scopes": [
                "read:confluence-user",
                "read:space:confluence",
                "read:page:confluence",
            ],
        }
        return _FakeResponse([resource])
|
||||
|
||||
|
||||
class _FakeHttpxModule(_StrictFakeMixin):
    """Object patched in place of the ``httpx`` module; only ``AsyncClient`` exists."""

    _component_name = "httpx"

    # Any other module attribute falls through to the strict mixin and raises.
    AsyncClient = _FakeHttpxAsyncClient
|
||||
|
||||
|
||||
def install(active_patches: list[Any]) -> None:
    """Replace the ``httpx`` name bound inside the Confluence add-connector route.

    The patcher is started immediately and appended to ``active_patches``
    so the caller keeps a handle on it.
    """
    fake_module = _FakeHttpxModule()
    patcher = patch("app.routes.confluence_add_connector_route.httpx", fake_module)
    patcher.start()
    active_patches.append(patcher)
|
||||
182
surfsense_backend/tests/e2e/fakes/dropbox_api.py
Normal file
182
surfsense_backend/tests/e2e/fakes/dropbox_api.py
Normal file
|
|
@ -0,0 +1,182 @@
|
|||
"""Strict Dropbox HTTP/API fakes for Playwright E2E.
|
||||
|
||||
This module patches the Dropbox OAuth route and indexer consumer-site
|
||||
bindings. It keeps the production add/callback/indexing flow intact while
|
||||
serving deterministic Dropbox-shaped token, profile, metadata, and file
|
||||
content responses.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from unittest.mock import patch
|
||||
|
||||
import httpx
|
||||
|
||||
from .binary_loader import _resolve_file_bytes
|
||||
|
||||
_FIXTURES_DIR = Path(__file__).parent / "fixtures"
|
||||
_DROPBOX_FIXTURE_PATH = _FIXTURES_DIR / "dropbox_files.json"
|
||||
|
||||
|
||||
def _load_dropbox_fixture() -> dict[str, Any]:
    """Read and parse the Dropbox fixture at ``_DROPBOX_FIXTURE_PATH``.

    The file is opened explicitly as UTF-8 so parsing does not depend on
    the platform's locale-default text encoding.

    Returns:
        The decoded JSON document.
    """
    with _DROPBOX_FIXTURE_PATH.open(encoding="utf-8") as f:
        return json.load(f)
|
||||
|
||||
|
||||
_DROPBOX_FIXTURE = _load_dropbox_fixture()
|
||||
|
||||
|
||||
class _StrictFakeMixin:
|
||||
_component_name: str = "<unknown>"
|
||||
|
||||
def __getattr__(self, name: str) -> Any:
|
||||
raise NotImplementedError(
|
||||
f"E2E Dropbox fake missing surface: "
|
||||
f"{self._component_name}.{name!r}. Add it to "
|
||||
f"surfsense_backend/tests/e2e/fakes/dropbox_api.py."
|
||||
)
|
||||
|
||||
|
||||
class _FakeDropboxClient(_StrictFakeMixin):
    """Fixture-backed replacement for ``DropboxClient``.

    Methods follow a ``(result, error)`` convention: on success the error
    slot is ``None``; on a fixture miss the result is empty/``None`` and the
    error slot carries a descriptive message.
    """

    _component_name = "DropboxClient"

    def __init__(self, session: Any, connector_id: int):
        # Stored for parity with the constructor signature; not read here.
        self._session = session
        self._connector_id = connector_id

    async def _get_valid_token(self) -> str:
        return "fake-dropbox-access-token"

    async def list_folder(
        self, path: str = ""
    ) -> tuple[list[dict[str, Any]], str | None]:
        """Return shallow copies of the fixture entries stored under ``path``."""
        entries = _DROPBOX_FIXTURE.get(path)
        if isinstance(entries, list):
            return [dict(entry) for entry in entries], None
        return [], f"E2E Dropbox fake has no folder for path={path!r}."

    async def get_latest_cursor(self, path: str = "") -> tuple[str | None, str | None]:
        """Return a synthetic cursor for any ``path`` present in the fixture."""
        if path in _DROPBOX_FIXTURE:
            return f"fake-dropbox-cursor:{path or 'root'}", None
        return None, f"E2E Dropbox fake has no cursor for path={path!r}."

    async def get_changes(
        self, cursor: str
    ) -> tuple[list[dict[str, Any]], str | None, str | None]:
        """Report no changes and echo the cursor back."""
        return [], cursor, None

    async def get_metadata(self, path: str) -> tuple[dict[str, Any] | None, str | None]:
        """Return the fixture entry matching ``path`` (see ``_dropbox_get_metadata``)."""
        entry = _dropbox_get_metadata(path)
        if entry is not None:
            return entry, None
        return None, f"E2E Dropbox fake has no metadata for path={path!r}."

    async def download_file(self, path: str) -> tuple[bytes | None, str | None]:
        """Return the canned bytes for ``path``."""
        payload = _resolve_file_bytes(_DROPBOX_FIXTURE, path, _FIXTURES_DIR)
        if payload is not None:
            return payload, None
        return None, f"E2E Dropbox fake has no content for path={path!r}."

    async def download_file_to_disk(self, path: str, dest_path: str) -> str | None:
        """Write the canned bytes to ``dest_path``; return an error message or ``None``."""
        payload = _resolve_file_bytes(_DROPBOX_FIXTURE, path, _FIXTURES_DIR)
        if payload is None:
            return f"E2E Dropbox fake has no content for path={path!r}."
        with open(dest_path, "wb") as sink:
            sink.write(payload)
        return None

    async def get_current_account(self) -> tuple[dict[str, Any] | None, str | None]:
        """Return the canned account profile."""
        profile = _dropbox_current_account()
        return profile, None
|
||||
|
||||
|
||||
class _FakeAsyncClient(_StrictFakeMixin):
    """Async-client stand-in answering the two Dropbox POST endpoints the fake models."""

    _component_name = "httpx.AsyncClient"

    def __init__(self, *args: Any, **kwargs: Any):
        # Constructor options are accepted and ignored.
        del args, kwargs

    async def __aenter__(self) -> _FakeAsyncClient:
        return self

    async def __aexit__(self, exc_type: Any, exc: Any, tb: Any) -> None:
        del exc_type, exc, tb

    async def post(self, url: str, *args: Any, **kwargs: Any) -> httpx.Response:
        """Return a canned JSON response for the token or current-account URL.

        Request body and headers are ignored.

        Raises:
            NotImplementedError: for any other URL.
        """
        del args, kwargs
        if url == "https://api.dropboxapi.com/oauth2/token":
            payload = {
                "access_token": "fake-dropbox-access-token",
                "refresh_token": "fake-dropbox-refresh-token",
                "token_type": "bearer",
                "expires_in": 3600,
                "account_id": "dbid:fake-dropbox-account",
            }
        elif url == "https://api.dropboxapi.com/2/users/get_current_account":
            payload = _dropbox_current_account()
        else:
            raise NotImplementedError(f"E2E Dropbox fake unexpected POST URL: {url!r}")
        return _json_response("POST", url, payload)

    async def request(
        self, method: str, url: str, *args: Any, **kwargs: Any
    ) -> httpx.Response:
        """Reject every generic request: none are modelled by this fake."""
        del args, kwargs
        raise NotImplementedError(
            f"E2E Dropbox fake unexpected request: {method!r} {url!r}"
        )
|
||||
|
||||
|
||||
class _FakeHttpxModule(_StrictFakeMixin):
    """Object patched in place of the ``httpx`` module; only ``AsyncClient`` exists."""

    _component_name = "httpx"

    # Any other module attribute falls through to the strict mixin and raises.
    AsyncClient = _FakeAsyncClient
|
||||
|
||||
|
||||
def _json_response(
    method: str, url: str, payload: dict[str, Any], status_code: int = 200
) -> httpx.Response:
    """Build a real ``httpx.Response`` carrying ``payload`` as its JSON body.

    The response is bound to a request built from ``method`` and ``url``.
    """
    originating_request = httpx.Request(method, url)
    return httpx.Response(
        status_code=status_code,
        json=payload,
        request=originating_request,
    )
|
||||
|
||||
|
||||
def _dropbox_current_account() -> dict[str, Any]:
|
||||
return {
|
||||
"email": "dropbox-e2e@surfsense.example",
|
||||
"name": {"display_name": "SurfSense Dropbox E2E"},
|
||||
"account_id": "dbid:fake-dropbox-account",
|
||||
}
|
||||
|
||||
|
||||
def _dropbox_get_metadata(path: str | None) -> dict[str, Any] | None:
    """Find the first fixture entry whose ``path_lower`` or ``id`` equals ``path``.

    Returns a shallow copy of the entry, or ``None`` when nothing matches.
    Non-list fixture values are skipped.
    """
    folder_listings = (
        listing for listing in _DROPBOX_FIXTURE.values() if isinstance(listing, list)
    )
    for listing in folder_listings:
        for entry in listing:
            if path in (entry.get("path_lower"), entry.get("id")):
                return dict(entry)
    return None
|
||||
|
||||
|
||||
def install(active_patches: list[Any]) -> None:
    """Patch production Dropbox bindings to use strict Dropbox fakes.

    Each patcher is started immediately and appended to ``active_patches``
    in the order listed below.
    """
    replacements = (
        ("app.routes.dropbox_add_connector_route.httpx", _FakeHttpxModule()),
        ("app.routes.dropbox_add_connector_route.DropboxClient", _FakeDropboxClient),
        (
            "app.tasks.connector_indexers.dropbox_indexer.DropboxClient",
            _FakeDropboxClient,
        ),
        ("app.connectors.dropbox.client.httpx", _FakeHttpxModule()),
    )
    for dotted_name, fake in replacements:
        patcher = patch(dotted_name, fake)
        patcher.start()
        active_patches.append(patcher)
|
||||
80
surfsense_backend/tests/e2e/fakes/embeddings.py
Normal file
80
surfsense_backend/tests/e2e/fakes/embeddings.py
Normal file
|
|
@ -0,0 +1,80 @@
|
|||
"""Deterministic embedding fakes for E2E.
|
||||
|
||||
Mirrors the existing `patched_embed_texts` fixture in
|
||||
`surfsense_backend/tests/integration/conftest.py`:
|
||||
|
||||
MagicMock(side_effect=lambda texts: [[0.1] * _EMBEDDING_DIM for _ in texts])
|
||||
|
||||
The dimension matches whatever `config.embedding_model_instance.dimension`
|
||||
returns in the running process so the fakes are vector-compatible with
|
||||
the documents.embedding pgvector column.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
|
||||
from app.config import config
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _embedding_dim() -> int:
    """Read the embedding dimension from the configured model at call time.

    Nothing is cached: each call looks the value up again, so the fakes
    work for whichever embedding model the running process has configured.
    """
    dimension = config.embedding_model_instance.dimension
    return int(dimension)
|
||||
|
||||
|
||||
def fake_embed_text(text: str) -> np.ndarray:
    """Return a constant float32 vector (every component 0.1); ``text`` is ignored."""
    dim = _embedding_dim()
    return np.full((dim,), 0.1, dtype=np.float32)
|
||||
|
||||
|
||||
def fake_embed_texts(texts: list[str]) -> list[np.ndarray]:
    """Return one constant float32 vector (every component 0.1) per input text.

    An empty (or otherwise falsy) input yields ``[]`` without consulting
    the configured embedding dimension.
    """
    if not texts:
        return []
    vector_shape = (_embedding_dim(),)
    return [np.full(vector_shape, 0.1, dtype=np.float32) for _ in texts]
|
||||
|
||||
|
||||
def install(patches: list[Any]) -> None:
    """Install embedding patches at every binding site we know about.

    Each started patcher is appended to the caller-supplied ``patches``
    list. The patchers are intentionally never stopped here: the process
    exits when the test server stops.

    Raises:
        RuntimeError: when a binding site cannot be patched (module or
            attribute no longer exists).
    """
    from unittest.mock import patch as _patch

    binding_sites = (
        # Source binding (where the real implementation lives)
        ("app.utils.document_converters.embed_text", fake_embed_text),
        ("app.utils.document_converters.embed_texts", fake_embed_texts),
        # Consumers that did `from app.utils.document_converters import embed_text/texts`
        ("app.indexing_pipeline.document_embedder.embed_text", fake_embed_text),
        ("app.indexing_pipeline.document_embedder.embed_texts", fake_embed_texts),
        # Pipeline service binding (the actual call site for indexing.index)
        (
            "app.indexing_pipeline.indexing_pipeline_service.embed_texts",
            fake_embed_texts,
        ),
    )
    for dotted_name, fake in binding_sites:
        try:
            patcher = _patch(dotted_name, fake)
            patcher.start()
            patches.append(patcher)
            logger.info("[fake-embeddings] patched %s", dotted_name)
        except (ModuleNotFoundError, AttributeError) as exc:
            # A moved binding must fail loudly: silently falling through to
            # a real embedding model would be expensive and
            # non-deterministic.
            raise RuntimeError(
                f"Could not patch embedding binding {dotted_name!r}: {exc!s}. "
                "Update surfsense_backend/tests/e2e/fakes/embeddings.py "
                "to point at the new binding site."
            ) from exc
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
|
|
@ -0,0 +1,84 @@
|
|||
"""Generate deterministic one-page PDFs for connector E2E fixtures."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
# Output filename -> the three text lines rendered into that PDF, in order.
# In every entry the third line is a SURFSENSE_E2E_CANARY_TOKEN_* string.
PDF_FIXTURES = {
    "drive-canary.pdf": (
        "Native Drive PDF Canary",
        "This one-page text-layer PDF proves native Drive Docling coverage.",
        "SURFSENSE_E2E_CANARY_TOKEN_DRIVE_PDF_001",
    ),
    "onedrive-canary.pdf": (
        "OneDrive PDF Canary",
        "This one-page text-layer PDF proves OneDrive Docling coverage.",
        "SURFSENSE_E2E_CANARY_TOKEN_ONEDRIVE_PDF_001",
    ),
    "dropbox-canary.pdf": (
        "Dropbox PDF Canary",
        "This one-page text-layer PDF proves Dropbox Docling coverage.",
        "SURFSENSE_E2E_CANARY_TOKEN_DROPBOX_PDF_001",
    ),
    "composio-drive-canary.pdf": (
        "Composio Drive PDF Canary",
        "This one-page text-layer PDF proves Composio Drive Docling coverage.",
        "SURFSENSE_E2E_CANARY_TOKEN_COMPOSIO_DRIVE_PDF_001",
    ),
}
|
||||
|
||||
|
||||
def _escape_pdf_text(text: str) -> str:
|
||||
return text.replace("\\", "\\\\").replace("(", "\\(").replace(")", "\\)")
|
||||
|
||||
|
||||
def _build_pdf(lines: tuple[str, str, str]) -> bytes:
    """Serialise *lines* as a minimal single-page PDF and return its bytes.

    The content stream shows each entry of *lines* using font resource
    ``/F1`` at size 12, starting at (72, 760) and moving down 18 units before
    every entry after the first. The file consists of five indirect objects
    (catalog, page tree, page, Helvetica font, content stream) followed by a
    cross-reference table and trailer.

    Raises:
        UnicodeEncodeError: if an entry contains a non-ASCII character.
    """
    content_ops = ["BT", "/F1 12 Tf", "72 760 Td"]
    for position, entry in enumerate(lines):
        if position > 0:
            # Step down one text row for every entry after the first.
            content_ops.append("0 -18 Td")
        content_ops.append(f"({_escape_pdf_text(entry)}) Tj")
    content_ops.append("ET")
    content = "\n".join(content_ops).encode("ascii")

    catalog = b"<< /Type /Catalog /Pages 2 0 R >>"
    page_tree = b"<< /Type /Pages /Kids [3 0 R] /Count 1 >>"
    page = (
        b"<< /Type /Page /Parent 2 0 R /MediaBox [0 0 612 792] "
        b"/Resources << /Font << /F1 4 0 R >> >> /Contents 5 0 R >>"
    )
    font = b"<< /Type /Font /Subtype /Type1 /BaseFont /Helvetica >>"
    content_obj = b"".join(
        (
            b"<< /Length ",
            str(len(content)).encode("ascii"),
            b" >>\nstream\n",
            content,
            b"\nendstream",
        )
    )
    bodies = (catalog, page_tree, page, font, content_obj)

    out = bytearray(b"%PDF-1.4\n")
    starts = []  # byte offset of each "N 0 obj" header, in object order
    for number, body in enumerate(bodies, start=1):
        starts.append(len(out))
        out += f"{number} 0 obj\n".encode("ascii")
        out += body
        out += b"\nendobj\n"

    xref_start = len(out)
    entry_count = len(bodies) + 1  # the objects plus the free-list head entry
    out += f"xref\n0 {entry_count}\n".encode("ascii")
    out += b"0000000000 65535 f \n"
    out += b"".join(f"{start:010d} 00000 n \n".encode("ascii") for start in starts)
    out += (
        f"trailer\n<< /Size {entry_count} /Root 1 0 R >>\n"
        f"startxref\n{xref_start}\n%%EOF\n"
    ).encode("ascii")
    return bytes(out)
|
||||
|
||||
|
||||
def main() -> None:
    """Write every PDF described by ``PDF_FIXTURES`` into this script's directory."""
    target_dir = Path(__file__).parent
    for name, text_lines in PDF_FIXTURES.items():
        target_dir.joinpath(name).write_bytes(_build_pdf(text_lines))


if __name__ == "__main__":
    main()
|
||||
Binary file not shown.
|
|
@ -0,0 +1,48 @@
|
|||
{
|
||||
"items": [
|
||||
{
|
||||
"id": "fake-calendar-event-canary-001",
|
||||
"status": "confirmed",
|
||||
"summary": "E2E Canary Calendar Event",
|
||||
"description": "This Calendar event proves the Composio Calendar live tool fetched event details. SURFSENSE_E2E_CANARY_TOKEN_CALENDAR_001",
|
||||
"location": "SurfSense E2E Room",
|
||||
"htmlLink": "https://calendar.google.com/calendar/event?eid=fake-calendar-event-canary-001",
|
||||
"start": {
|
||||
"dateTime": "2026-05-12T10:00:00Z",
|
||||
"timeZone": "UTC"
|
||||
},
|
||||
"end": {
|
||||
"dateTime": "2026-05-12T10:30:00Z",
|
||||
"timeZone": "UTC"
|
||||
},
|
||||
"attendees": [
|
||||
{
|
||||
"email": "e2e-fake@surfsense.example",
|
||||
"responseStatus": "accepted"
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"id": "fake-calendar-event-planning-001",
|
||||
"status": "confirmed",
|
||||
"summary": "E2E Planning Sync",
|
||||
"description": "Non-canary planning sync used to prove list responses can contain multiple events.",
|
||||
"location": "SurfSense Planning Room",
|
||||
"htmlLink": "https://calendar.google.com/calendar/event?eid=fake-calendar-event-planning-001",
|
||||
"start": {
|
||||
"dateTime": "2026-05-13T15:00:00Z",
|
||||
"timeZone": "UTC"
|
||||
},
|
||||
"end": {
|
||||
"dateTime": "2026-05-13T15:45:00Z",
|
||||
"timeZone": "UTC"
|
||||
},
|
||||
"attendees": [
|
||||
{
|
||||
"email": "planner@surfsense.example",
|
||||
"responseStatus": "needsAction"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
|
|
@ -0,0 +1,13 @@
|
|||
{
|
||||
"tasks": [
|
||||
{
|
||||
"id": "fake-clickup-task-canary-001",
|
||||
"name": "E2E Canary ClickUp Task",
|
||||
"list_name": "SurfSense E2E ClickUp List",
|
||||
"workspace_id": "fake-clickup-workspace-001",
|
||||
"workspace_name": "SurfSense E2E ClickUp Workspace",
|
||||
"status": "open",
|
||||
"description": "Canary task body containing SURFSENSE_E2E_CANARY_TOKEN_CLICKUP_001"
|
||||
}
|
||||
]
|
||||
}
|
||||
|
|
@ -0,0 +1,33 @@
|
|||
{
|
||||
"site": {
|
||||
"cloud_id": "fake-confluence-cloud-001",
|
||||
"name": "SurfSense E2E Confluence",
|
||||
"url": "https://surfsense-e2e-confluence.atlassian.net"
|
||||
},
|
||||
"pages": [
|
||||
{
|
||||
"id": "fake-confluence-page-canary-001",
|
||||
"title": "E2E Canary Confluence Page",
|
||||
"spaceId": "fake-confluence-space-001",
|
||||
"body": {
|
||||
"storage": {
|
||||
"value": "<h1>E2E Canary Confluence Page</h1><p>This page proves Confluence OAuth indexing works end-to-end. SURFSENSE_E2E_CANARY_TOKEN_CONFLUENCE_001</p>"
|
||||
}
|
||||
},
|
||||
"comments": [
|
||||
{
|
||||
"id": "fake-confluence-comment-canary-001",
|
||||
"body": {
|
||||
"storage": {
|
||||
"value": "<p>Confluence comment content is included in indexed markdown.</p>"
|
||||
}
|
||||
},
|
||||
"version": {
|
||||
"authorId": "fake-confluence-user-001",
|
||||
"createdAt": "2026-05-08T00:00:00.000Z"
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
99
surfsense_backend/tests/e2e/fakes/fixtures/drive_files.json
Normal file
99
surfsense_backend/tests/e2e/fakes/fixtures/drive_files.json
Normal file
|
|
@ -0,0 +1,99 @@
|
|||
{
|
||||
"root": [
|
||||
{
|
||||
"id": "fake-folder-projects",
|
||||
"name": "Projects",
|
||||
"mimeType": "application/vnd.google-apps.folder",
|
||||
"modifiedTime": "2025-01-15T10:00:00.000Z",
|
||||
"createdTime": "2024-12-01T08:00:00.000Z"
|
||||
},
|
||||
{
|
||||
"id": "fake-folder-archive",
|
||||
"name": "Archive",
|
||||
"mimeType": "application/vnd.google-apps.folder",
|
||||
"modifiedTime": "2024-11-20T14:30:00.000Z",
|
||||
"createdTime": "2024-09-10T12:00:00.000Z"
|
||||
},
|
||||
{
|
||||
"id": "fake-file-readme",
|
||||
"name": "README.md",
|
||||
"mimeType": "text/markdown",
|
||||
"modifiedTime": "2025-02-01T09:15:00.000Z",
|
||||
"createdTime": "2025-01-30T16:20:00.000Z"
|
||||
},
|
||||
{
|
||||
"id": "fake-file-canary",
|
||||
"name": "e2e-canary.txt",
|
||||
"mimeType": "text/plain",
|
||||
"modifiedTime": "2025-02-10T11:00:00.000Z",
|
||||
"createdTime": "2025-02-10T11:00:00.000Z"
|
||||
},
|
||||
{
|
||||
"id": "fake-file-pdf-native",
|
||||
"name": "e2e-canary.pdf",
|
||||
"mimeType": "application/pdf",
|
||||
"size": 735,
|
||||
"modifiedTime": "2025-02-10T11:05:00.000Z",
|
||||
"createdTime": "2025-02-10T11:05:00.000Z"
|
||||
},
|
||||
{
|
||||
"id": "fake-file-pdf-composio",
|
||||
"name": "e2e-composio-canary.pdf",
|
||||
"mimeType": "application/pdf",
|
||||
"size": 748,
|
||||
"modifiedTime": "2025-02-10T11:10:00.000Z",
|
||||
"createdTime": "2025-02-10T11:10:00.000Z"
|
||||
},
|
||||
{
|
||||
"id": "fake-file-budget",
|
||||
"name": "Q1-Budget.csv",
|
||||
"mimeType": "text/csv",
|
||||
"modifiedTime": "2025-01-25T13:45:00.000Z",
|
||||
"createdTime": "2025-01-25T13:45:00.000Z"
|
||||
},
|
||||
{
|
||||
"id": "fake-shortcut-canary",
|
||||
"name": "Shortcut to Canary",
|
||||
"mimeType": "application/vnd.google-apps.shortcut",
|
||||
"modifiedTime": "2025-02-10T12:00:00.000Z",
|
||||
"createdTime": "2025-02-10T12:00:00.000Z"
|
||||
},
|
||||
{
|
||||
"id": "fake-file-trashed",
|
||||
"name": "trashed-e2e-note.txt",
|
||||
"mimeType": "text/plain",
|
||||
"modifiedTime": "2025-02-11T09:00:00.000Z",
|
||||
"createdTime": "2025-02-11T09:00:00.000Z",
|
||||
"trashed": true
|
||||
}
|
||||
],
|
||||
"fake-folder-projects": [
|
||||
{
|
||||
"id": "fake-file-roadmap",
|
||||
"name": "2025-Roadmap.md",
|
||||
"mimeType": "text/markdown",
|
||||
"modifiedTime": "2025-02-12T08:30:00.000Z",
|
||||
"createdTime": "2025-01-05T10:00:00.000Z"
|
||||
}
|
||||
],
|
||||
"fake-folder-archive": [
|
||||
{
|
||||
"id": "fake-file-old-notes",
|
||||
"name": "old-meeting-notes.txt",
|
||||
"mimeType": "text/plain",
|
||||
"modifiedTime": "2024-08-15T15:00:00.000Z",
|
||||
"createdTime": "2024-08-15T15:00:00.000Z"
|
||||
}
|
||||
],
|
||||
"_file_contents": {
|
||||
"fake-file-readme": "# E2E Fake Drive\n\nThis README is served by the strict Composio fake. SURFSENSE_E2E_README_MARKER",
|
||||
"fake-file-canary": "Canary token for E2E tests: SURFSENSE_E2E_CANARY_TOKEN_DRIVE_001\nThis file's content is asserted by indexing.spec.ts to confirm the indexing pipeline ran end-to-end.",
|
||||
"fake-file-budget": "Quarter,Revenue,Expenses\nQ1,100000,75000\nSURFSENSE_E2E_BUDGET_MARKER,2025,test",
|
||||
"fake-file-roadmap": "# 2025 Roadmap\n\n- E2E Drive indexing\n- Composio Gmail/Calendar SURFSENSE_E2E_ROADMAP_MARKER",
|
||||
"fake-file-old-notes": "Old meeting notes archived in 2024. SURFSENSE_E2E_ARCHIVE_MARKER"
|
||||
},
|
||||
"_file_binary_paths": {
|
||||
"fake-file-pdf-native": "binary/drive-canary.pdf",
|
||||
"fake-file-pdf-composio": "binary/composio-drive-canary.pdf"
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,34 @@
|
|||
{
|
||||
"": [
|
||||
{
|
||||
".tag": "file",
|
||||
"id": "id:fake-dropbox-canary",
|
||||
"name": "e2e-dropbox-canary.txt",
|
||||
"path_lower": "/e2e-dropbox-canary.txt",
|
||||
"path_display": "/e2e-dropbox-canary.txt",
|
||||
"size": 152,
|
||||
"is_downloadable": true,
|
||||
"server_modified": "2026-05-08T00:00:00Z",
|
||||
"client_modified": "2026-05-08T00:00:00Z",
|
||||
"content_hash": "fake-dropbox-hash-001"
|
||||
},
|
||||
{
|
||||
".tag": "file",
|
||||
"id": "id:fake-dropbox-pdf-canary",
|
||||
"name": "e2e-dropbox-canary.pdf",
|
||||
"path_lower": "/e2e-dropbox-canary.pdf",
|
||||
"path_display": "/e2e-dropbox-canary.pdf",
|
||||
"size": 727,
|
||||
"is_downloadable": true,
|
||||
"server_modified": "2026-05-08T00:05:00Z",
|
||||
"client_modified": "2026-05-08T00:05:00Z",
|
||||
"content_hash": "fake-dropbox-hash-pdf-001"
|
||||
}
|
||||
],
|
||||
"_file_contents": {
|
||||
"/e2e-dropbox-canary.txt": "Canary token for Dropbox E2E tests: SURFSENSE_E2E_CANARY_TOKEN_DROPBOX_001\nThis file proves the Dropbox indexing pipeline ran end-to-end."
|
||||
},
|
||||
"_file_binary_paths": {
|
||||
"/e2e-dropbox-canary.pdf": "binary/dropbox-canary.pdf"
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,34 @@
|
|||
{
|
||||
"messages": [
|
||||
{
|
||||
"id": "fake-msg-canary-001",
|
||||
"threadId": "fake-thread-canary-001",
|
||||
"snippet": "E2E canary email body is loaded through the Gmail detail endpoint."
|
||||
},
|
||||
{
|
||||
"id": "fake-msg-planning-001",
|
||||
"threadId": "fake-thread-planning-001",
|
||||
"snippet": "Planning email used to keep Gmail fixtures representative."
|
||||
}
|
||||
],
|
||||
"details": {
|
||||
"fake-msg-canary-001": {
|
||||
"id": "fake-msg-canary-001",
|
||||
"threadId": "fake-thread-canary-001",
|
||||
"subject": "E2E Canary Email",
|
||||
"from": "sender@surfsense.example",
|
||||
"to": "e2e-fake@surfsense.example",
|
||||
"date": "Mon, 10 Feb 2025 11:00:00 +0000",
|
||||
"messageText": "This Gmail message body proves the Composio Gmail live tool fetched message details. SURFSENSE_E2E_CANARY_TOKEN_GMAIL_001"
|
||||
},
|
||||
"fake-msg-planning-001": {
|
||||
"id": "fake-msg-planning-001",
|
||||
"threadId": "fake-thread-planning-001",
|
||||
"subject": "E2E Planning Notes",
|
||||
"from": "planner@surfsense.example",
|
||||
"to": "e2e-fake@surfsense.example",
|
||||
"date": "Tue, 11 Feb 2025 09:30:00 +0000",
|
||||
"messageText": "Planning notes for a non-canary Gmail fixture."
|
||||
}
|
||||
}
|
||||
}
|
||||
15
surfsense_backend/tests/e2e/fakes/fixtures/jira_issues.json
Normal file
15
surfsense_backend/tests/e2e/fakes/fixtures/jira_issues.json
Normal file
|
|
@ -0,0 +1,15 @@
|
|||
{
|
||||
"site": {
|
||||
"cloud_id": "fake-jira-cloud-001",
|
||||
"name": "SurfSense E2E Atlassian",
|
||||
"url": "https://surfsense-e2e.atlassian.net"
|
||||
},
|
||||
"issues": [
|
||||
{
|
||||
"id": "fake-jira-issue-canary-001",
|
||||
"key": "E2E-101",
|
||||
"summary": "E2E Canary Jira Issue",
|
||||
"description": "This Jira issue proves live MCP tool calls work end-to-end. SURFSENSE_E2E_CANARY_TOKEN_JIRA_001"
|
||||
}
|
||||
]
|
||||
}
|
||||
|
|
@ -0,0 +1,17 @@
|
|||
{
|
||||
"organization": {
|
||||
"name": "SurfSense E2E Linear Org",
|
||||
"url_key": "surfsense-e2e"
|
||||
},
|
||||
"issues": [
|
||||
{
|
||||
"id": "fake-linear-issue-canary-001",
|
||||
"identifier": "E2E-101",
|
||||
"title": "E2E Canary Linear Issue",
|
||||
"description": "This Linear issue proves the live MCP tool fetched issue details. SURFSENSE_E2E_CANARY_TOKEN_LINEAR_001",
|
||||
"state": "Todo",
|
||||
"assignee": "E2E Owner",
|
||||
"team": "SurfSense E2E"
|
||||
}
|
||||
]
|
||||
}
|
||||
64
surfsense_backend/tests/e2e/fakes/fixtures/notion_pages.json
Normal file
64
surfsense_backend/tests/e2e/fakes/fixtures/notion_pages.json
Normal file
|
|
@ -0,0 +1,64 @@
|
|||
{
|
||||
"pages": [
|
||||
{
|
||||
"object": "page",
|
||||
"id": "fake-notion-page-canary-001",
|
||||
"created_time": "2026-05-07T00:00:00.000Z",
|
||||
"last_edited_time": "2026-05-07T00:00:00.000Z",
|
||||
"url": "https://notion.so/fake-notion-page-canary-001",
|
||||
"properties": {
|
||||
"Name": {
|
||||
"id": "title",
|
||||
"type": "title",
|
||||
"title": [
|
||||
{
|
||||
"type": "text",
|
||||
"plain_text": "E2E Canary Notion Page",
|
||||
"text": {
|
||||
"content": "E2E Canary Notion Page"
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
],
|
||||
"blocks": {
|
||||
"fake-notion-page-canary-001": [
|
||||
{
|
||||
"object": "block",
|
||||
"id": "fake-notion-block-heading-001",
|
||||
"type": "heading_2",
|
||||
"has_children": false,
|
||||
"heading_2": {
|
||||
"rich_text": [
|
||||
{
|
||||
"type": "text",
|
||||
"plain_text": "E2E Notion Canary",
|
||||
"text": {
|
||||
"content": "E2E Notion Canary"
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
},
|
||||
{
|
||||
"object": "block",
|
||||
"id": "fake-notion-block-body-001",
|
||||
"type": "paragraph",
|
||||
"has_children": false,
|
||||
"paragraph": {
|
||||
"rich_text": [
|
||||
{
|
||||
"type": "text",
|
||||
"plain_text": "This Notion page proves the indexed connector fetched Notion blocks through OAuth credentials. SURFSENSE_E2E_CANARY_TOKEN_NOTION_001",
|
||||
"text": {
|
||||
"content": "This Notion page proves the indexed connector fetched Notion blocks through OAuth credentials. SURFSENSE_E2E_CANARY_TOKEN_NOTION_001"
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,82 @@
|
|||
{
|
||||
"root": [
|
||||
{
|
||||
"id": "fake-onedrive-folder-projects",
|
||||
"name": "Projects",
|
||||
"size": 0,
|
||||
"folder": {
|
||||
"childCount": 1
|
||||
},
|
||||
"parentReference": {
|
||||
"id": "root",
|
||||
"path": "/drive/root:"
|
||||
},
|
||||
"createdDateTime": "2025-02-01T08:00:00Z",
|
||||
"lastModifiedDateTime": "2025-02-10T10:00:00Z",
|
||||
"webUrl": "https://onedrive.example/fake/projects"
|
||||
},
|
||||
{
|
||||
"id": "fake-onedrive-canary",
|
||||
"name": "e2e-onedrive-canary.txt",
|
||||
"size": 164,
|
||||
"file": {
|
||||
"mimeType": "text/plain",
|
||||
"hashes": {
|
||||
"quickXorHash": "fake-onedrive-canary-qxh"
|
||||
}
|
||||
},
|
||||
"parentReference": {
|
||||
"id": "root",
|
||||
"path": "/drive/root:"
|
||||
},
|
||||
"createdDateTime": "2025-02-10T11:00:00Z",
|
||||
"lastModifiedDateTime": "2025-02-10T11:00:00Z",
|
||||
"webUrl": "https://onedrive.example/fake/e2e-onedrive-canary.txt"
|
||||
},
|
||||
{
|
||||
"id": "fake-onedrive-pdf-canary",
|
||||
"name": "e2e-onedrive-canary.pdf",
|
||||
"size": 730,
|
||||
"file": {
|
||||
"mimeType": "application/pdf",
|
||||
"hashes": {
|
||||
"quickXorHash": "fake-onedrive-pdf-canary-qxh"
|
||||
}
|
||||
},
|
||||
"parentReference": {
|
||||
"id": "root",
|
||||
"path": "/drive/root:"
|
||||
},
|
||||
"createdDateTime": "2025-02-10T11:05:00Z",
|
||||
"lastModifiedDateTime": "2025-02-10T11:05:00Z",
|
||||
"webUrl": "https://onedrive.example/fake/e2e-onedrive-canary.pdf"
|
||||
}
|
||||
],
|
||||
"fake-onedrive-folder-projects": [
|
||||
{
|
||||
"id": "fake-onedrive-roadmap",
|
||||
"name": "OneDrive Roadmap.md",
|
||||
"size": 82,
|
||||
"file": {
|
||||
"mimeType": "text/markdown",
|
||||
"hashes": {
|
||||
"quickXorHash": "fake-onedrive-roadmap-qxh"
|
||||
}
|
||||
},
|
||||
"parentReference": {
|
||||
"id": "fake-onedrive-folder-projects",
|
||||
"path": "/drive/root:/Projects"
|
||||
},
|
||||
"createdDateTime": "2025-02-11T08:00:00Z",
|
||||
"lastModifiedDateTime": "2025-02-12T08:30:00Z",
|
||||
"webUrl": "https://onedrive.example/fake/projects/onedrive-roadmap.md"
|
||||
}
|
||||
],
|
||||
"_file_contents": {
|
||||
"fake-onedrive-canary": "Canary token for OneDrive E2E tests: SURFSENSE_E2E_CANARY_TOKEN_ONEDRIVE_001\nThis file proves the OneDrive indexing pipeline ran end-to-end.",
|
||||
"fake-onedrive-roadmap": "# OneDrive Roadmap\n\n- E2E OneDrive indexing\n- Deterministic Graph fake"
|
||||
},
|
||||
"_file_binary_paths": {
|
||||
"fake-onedrive-pdf-canary": "binary/onedrive-canary.pdf"
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,18 @@
|
|||
{
|
||||
"team": {
|
||||
"id": "T_FAKE_SLACK_TEAM",
|
||||
"name": "SurfSense E2E Slack Workspace"
|
||||
},
|
||||
"channel": {
|
||||
"id": "C_FAKE_SLACK_CANARY",
|
||||
"name": "slack-e2e-canary",
|
||||
"purpose": "SurfSense E2E Slack canary channel"
|
||||
},
|
||||
"messages": [
|
||||
{
|
||||
"ts": "1715000000.000100",
|
||||
"user": "U_FAKE_SLACK_USER",
|
||||
"text": "This Slack message proves the live MCP tool fetched channel content. SURFSENSE_E2E_CANARY_TOKEN_SLACK_001"
|
||||
}
|
||||
]
|
||||
}
|
||||
133
surfsense_backend/tests/e2e/fakes/jira_module.py
Normal file
133
surfsense_backend/tests/e2e/fakes/jira_module.py
Normal file
|
|
@ -0,0 +1,133 @@
|
|||
"""Strict Jira MCP OAuth/tool fakes for Playwright E2E."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
from typing import Any
|
||||
|
||||
from tests.e2e.fakes import mcp_oauth_runtime, mcp_runtime
|
||||
|
||||
# JSON fixture holding the fake site metadata and issues served by this module.
_FIXTURE_PATH = Path(__file__).parent / "fixtures" / "jira_issues.json"

# Atlassian MCP endpoint URLs; ``install`` passes these to the shared
# OAuth/MCP dispatchers when registering this fake.
_AUTHORIZATION_URL = "https://mcp.atlassian.com/v1/authorize"
_REGISTRATION_URL = "https://cf.mcp.atlassian.com/v1/register"
_TOKEN_URL = "https://cf.mcp.atlassian.com/v1/token"
_MCP_URL = "https://mcp.atlassian.com/v1/mcp"

# Fixed fake OAuth credentials; ``install`` passes them to the dispatchers
# (the access token doubles as the expected MCP bearer token).
_CLIENT_ID = "fake-jira-mcp-client-id"
_CLIENT_SECRET = "fake-jira-mcp-client-secret"
_ACCESS_TOKEN = "fake-jira-mcp-access-token"
_REFRESH_TOKEN = "fake-jira-mcp-refresh-token"
_OAUTH_CODE = "fake-jira-oauth-code"
|
||||
|
||||
|
||||
def _load_fixture() -> dict[str, Any]:
    """Parse the Jira fixture file at ``_FIXTURE_PATH`` and return its contents.

    The file is opened explicitly as UTF-8 so the result does not depend on
    the locale's default encoding of the machine running the tests.

    Raises:
        OSError: if the fixture file cannot be opened.
        json.JSONDecodeError: if the file is not valid JSON.
    """
    with _FIXTURE_PATH.open(encoding="utf-8") as f:
        return json.load(f)
|
||||
|
||||
|
||||
_FIXTURE = _load_fixture()
|
||||
|
||||
|
||||
def _issue_text(issue: dict[str, Any]) -> str:
|
||||
return (
|
||||
f"{issue['key']} {issue['summary']}\n"
|
||||
f"id: {issue['id']}\n"
|
||||
f"description: {issue['description']}"
|
||||
)
|
||||
|
||||
|
||||
async def _list_tools() -> SimpleNamespace:
|
||||
return SimpleNamespace(
|
||||
tools=[
|
||||
SimpleNamespace(
|
||||
name="getAccessibleAtlassianResources",
|
||||
description="Get Jira sites accessible to the authenticated Atlassian user.",
|
||||
inputSchema={
|
||||
"type": "object",
|
||||
"properties": {},
|
||||
"required": [],
|
||||
},
|
||||
),
|
||||
SimpleNamespace(
|
||||
name="searchJiraIssuesUsingJql",
|
||||
description="Search Jira issues using a Jira Query Language expression.",
|
||||
inputSchema={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"jql": {
|
||||
"type": "string",
|
||||
"description": "JQL query used to search Jira issues.",
|
||||
},
|
||||
"maxResults": {
|
||||
"type": "integer",
|
||||
"description": "Maximum number of matching issues to return.",
|
||||
},
|
||||
},
|
||||
"required": ["jql"],
|
||||
},
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
async def _call_tool(
    tool_name: str, arguments: dict[str, Any] | None = None
) -> SimpleNamespace:
    """Answer a fake Jira MCP tool call from the loaded fixture.

    ``getAccessibleAtlassianResources`` must be called without arguments and
    returns the fixture site's name, cloud id and URL.
    ``searchJiraIssuesUsingJql`` returns the first fixture issue, provided the
    ``jql`` argument mentions that issue's summary or key (compared
    case-insensitively).

    Raises:
        ValueError: on unexpected arguments, or a JQL string that mentions
            neither the fixture issue's summary nor its key.
        NotImplementedError: for any other tool name.
    """
    params = arguments or {}
    site_info = _FIXTURE["site"]
    canary = _FIXTURE["issues"][0]

    def _reply(body: str) -> SimpleNamespace:
        # Result shape: an object whose ``content`` is a list of ``text`` parts.
        return SimpleNamespace(content=[SimpleNamespace(text=body)])

    if tool_name == "getAccessibleAtlassianResources":
        if params:
            raise ValueError(
                f"Unexpected Jira getAccessibleAtlassianResources args: {params!r}"
            )
        return _reply(
            f"{site_info['name']}\ncloud_id: {site_info['cloud_id']}\nurl: {site_info['url']}"
        )

    if tool_name == "searchJiraIssuesUsingJql":
        jql = str(params.get("jql", ""))
        lowered = jql.lower()
        mentions_issue = (
            canary["summary"].lower() in lowered or canary["key"].lower() in lowered
        )
        if not mentions_issue:
            raise ValueError(f"Unexpected Jira JQL query: {jql!r}")
        return _reply(_issue_text(canary))

    raise NotImplementedError(f"Unexpected Jira MCP tool call: {tool_name!r}")
|
||||
|
||||
|
||||
def install(active_patches: list[Any]) -> None:
    """Register Jira MCP OAuth/tool handlers with the shared dispatchers."""
    # The argument is accepted but not used by this module.
    del active_patches

    discovery = {
        "issuer": "https://mcp.atlassian.com",
        "authorization_endpoint": _AUTHORIZATION_URL,
        "token_endpoint": _TOKEN_URL,
        "registration_endpoint": _REGISTRATION_URL,
        "code_challenge_methods_supported": ["S256"],
        "grant_types_supported": ["authorization_code", "refresh_token"],
        "response_types_supported": ["code"],
    }
    mcp_oauth_runtime.register_service(
        mcp_url=_MCP_URL,
        discovery_metadata=discovery,
        client_id=_CLIENT_ID,
        client_secret=_CLIENT_SECRET,
        token_endpoint=_TOKEN_URL,
        registration_endpoint=_REGISTRATION_URL,
        oauth_code=_OAUTH_CODE,
        access_token=_ACCESS_TOKEN,
        refresh_token=_REFRESH_TOKEN,
        scope="read:jira-work read:me write:jira-work",
        redirect_uri_substring="/api/v1/auth/mcp/jira/connector/callback",
    )

    # Tool traffic for the MCP URL is served by this module's handlers.
    mcp_runtime.register(
        url=_MCP_URL,
        expected_bearer=_ACCESS_TOKEN,
        list_tools=_list_tools,
        call_tool=_call_tool,
    )
|
||||
305
surfsense_backend/tests/e2e/fakes/linear_module.py
Normal file
305
surfsense_backend/tests/e2e/fakes/linear_module.py
Normal file
|
|
@ -0,0 +1,305 @@
|
|||
"""Strict Linear MCP OAuth/tool fakes for Playwright E2E."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
from typing import Any
|
||||
|
||||
from tests.e2e.fakes import mcp_oauth_runtime, mcp_runtime
|
||||
|
||||
# JSON fixture holding the fake organization and issues served by this module.
_FIXTURE_PATH = Path(__file__).parent / "fixtures" / "linear_issues.json"

# Linear MCP endpoint URLs checked by the fakes below and passed to the
# shared dispatchers by ``install``.
# NOTE(review): ``_DISCOVERY_URL`` is not referenced in the visible part of
# this module — confirm whether anything else imports it.
_DISCOVERY_URL = "https://mcp.linear.app/.well-known/oauth-authorization-server"
_AUTHORIZATION_URL = "https://mcp.linear.app/authorize"
_REGISTRATION_URL = "https://mcp.linear.app/register"
_TOKEN_URL = "https://mcp.linear.app/token"
_MCP_URL = "https://mcp.linear.app/mcp"

# Fixed fake OAuth credentials that the fakes below issue and validate.
_CLIENT_ID = "fake-linear-mcp-client-id"
_CLIENT_SECRET = "fake-linear-mcp-client-secret"
_ACCESS_TOKEN = "fake-linear-mcp-access-token"
_REFRESH_TOKEN = "fake-linear-mcp-refresh-token"
_OAUTH_CODE = "fake-linear-oauth-code"
|
||||
|
||||
|
||||
def _load_fixture() -> dict[str, Any]:
    """Parse the Linear fixture file at ``_FIXTURE_PATH`` and return its contents.

    The file is opened explicitly as UTF-8 so the result does not depend on
    the locale's default encoding of the machine running the tests.

    Raises:
        OSError: if the fixture file cannot be opened.
        json.JSONDecodeError: if the file is not valid JSON.
    """
    with _FIXTURE_PATH.open(encoding="utf-8") as f:
        return json.load(f)
|
||||
|
||||
|
||||
_FIXTURE = _load_fixture()
|
||||
|
||||
|
||||
class _StrictFakeMixin:
|
||||
_component_name: str = "<unknown>"
|
||||
|
||||
def __getattr__(self, name: str) -> Any:
|
||||
raise NotImplementedError(
|
||||
f"E2E Linear fake missing surface: {self._component_name}.{name!r}. "
|
||||
"Add it to surfsense_backend/tests/e2e/fakes/linear_module.py."
|
||||
)
|
||||
|
||||
|
||||
async def _fake_discover_oauth_metadata(
    mcp_url: str,
    *,
    origin_override: str | None = None,
    timeout: float = 15.0,
) -> dict[str, Any]:
    """Return canned OAuth authorization-server metadata for the Linear MCP URL.

    ``origin_override`` and ``timeout`` are accepted but ignored.

    Raises:
        NotImplementedError: if ``mcp_url`` is not the Linear MCP URL.
    """
    del origin_override, timeout

    if mcp_url == _MCP_URL:
        return {
            "issuer": "https://mcp.linear.app",
            "authorization_endpoint": _AUTHORIZATION_URL,
            "token_endpoint": _TOKEN_URL,
            "registration_endpoint": _REGISTRATION_URL,
            "code_challenge_methods_supported": ["S256"],
            "grant_types_supported": ["authorization_code", "refresh_token"],
            "response_types_supported": ["code"],
        }

    raise NotImplementedError(f"Unexpected Linear MCP discovery url={mcp_url!r}")
|
||||
|
||||
|
||||
async def _fake_register_client(
    registration_endpoint: str,
    redirect_uri: str,
    *,
    client_name: str = "SurfSense",
    timeout: float = 15.0,
) -> dict[str, Any]:
    """Return a canned dynamic-client-registration response for Linear.

    The endpoint, the client name and the redirect URI are validated, in that
    order, before the fixed fake client credentials are returned. ``timeout``
    is accepted but ignored.

    Raises:
        NotImplementedError: if ``registration_endpoint`` is not the Linear one.
        ValueError: if ``client_name`` is not ``"SurfSense"`` or
            ``redirect_uri`` lacks the Linear connector callback path.
    """
    del timeout
    callback_path = "/api/v1/auth/mcp/linear/connector/callback"

    if registration_endpoint != _REGISTRATION_URL:
        message = f"Unexpected Linear DCR endpoint={registration_endpoint!r}"
        raise NotImplementedError(message)
    if client_name != "SurfSense":
        raise ValueError(f"Unexpected Linear DCR client_name={client_name!r}")
    if callback_path not in redirect_uri:
        raise ValueError(f"Unexpected Linear redirect_uri={redirect_uri!r}")

    registration: dict[str, Any] = {
        "client_id": _CLIENT_ID,
        "client_secret": _CLIENT_SECRET,
        "client_id_issued_at": 1_776_621_600,
        "token_endpoint_auth_method": "client_secret_basic",
    }
    return registration
|
||||
|
||||
|
||||
async def _fake_exchange_code_for_tokens(
    token_endpoint: str,
    code: str,
    redirect_uri: str,
    client_id: str,
    client_secret: str,
    code_verifier: str,
    *,
    timeout: float = 30.0,
) -> dict[str, Any]:
    """Return a canned token response for the fake Linear authorization code.

    Checks run in this order: token endpoint, authorization code, redirect
    URI, client credentials, then presence of a code verifier. ``timeout``
    is accepted but ignored.

    Raises:
        NotImplementedError: if ``token_endpoint`` is not the Linear one.
        ValueError: if any other check fails.
    """
    del timeout
    callback_path = "/api/v1/auth/mcp/linear/connector/callback"

    if token_endpoint != _TOKEN_URL:
        message = f"Unexpected Linear token_endpoint={token_endpoint!r}"
        raise NotImplementedError(message)
    if code != _OAUTH_CODE:
        raise ValueError(f"Unexpected fake Linear OAuth code: {code!r}")
    if callback_path not in redirect_uri:
        raise ValueError(f"Unexpected Linear redirect_uri={redirect_uri!r}")

    credentials_match = client_id == _CLIENT_ID and client_secret == _CLIENT_SECRET
    if not credentials_match:
        raise ValueError(
            "Unexpected Linear client credentials: "
            f"client_id={client_id!r} client_secret={client_secret!r}"
        )
    if not code_verifier:
        raise ValueError("Linear token exchange missing code_verifier.")

    tokens: dict[str, Any] = {
        "access_token": _ACCESS_TOKEN,
        "refresh_token": _REFRESH_TOKEN,
        "expires_in": 3600,
        "scope": "read write",
        "token_type": "Bearer",
    }
    return tokens
|
||||
|
||||
|
||||
async def _fake_refresh_access_token(
    token_endpoint: str,
    refresh_token: str,
    client_id: str,
    client_secret: str,
    *,
    timeout: float = 30.0,
) -> dict[str, Any]:
    """Return a canned token response for the fake Linear refresh token.

    Checks run in this order: token endpoint, refresh token, client
    credentials. ``timeout`` is accepted but ignored.

    Raises:
        NotImplementedError: if ``token_endpoint`` is not the Linear one.
        ValueError: if the refresh token or the client credentials differ
            from the fake values.
    """
    del timeout

    if token_endpoint != _TOKEN_URL:
        message = f"Unexpected Linear token_endpoint={token_endpoint!r}"
        raise NotImplementedError(message)
    if refresh_token != _REFRESH_TOKEN:
        raise ValueError(f"Unexpected fake Linear refresh token: {refresh_token!r}")

    credentials_match = client_id == _CLIENT_ID and client_secret == _CLIENT_SECRET
    if not credentials_match:
        raise ValueError(
            "Unexpected Linear refresh client credentials: "
            f"client_id={client_id!r} client_secret={client_secret!r}"
        )

    tokens: dict[str, Any] = {
        "access_token": _ACCESS_TOKEN,
        "refresh_token": _REFRESH_TOKEN,
        "expires_in": 3600,
        "scope": "read write",
        "token_type": "Bearer",
    }
    return tokens
|
||||
|
||||
|
||||
class _FakeStreamableHttpClient(_StrictFakeMixin):
    """Async context manager standing in for a streamable-HTTP MCP transport.

    Construction validates the target URL and the ``Authorization`` header;
    entering the context yields two placeholder objects and ``None``.
    """

    _component_name = "streamablehttp_client"

    def __init__(
        self, url: str, *, headers: dict[str, str] | None = None, **kwargs: Any
    ):
        """Validate ``url`` and the bearer token; extra keyword arguments are ignored.

        Raises:
            NotImplementedError: if ``url`` is not the Linear MCP URL.
            ValueError: if the ``Authorization`` header is not the fake bearer.
        """
        del kwargs
        if url != _MCP_URL:
            raise NotImplementedError(f"Unexpected Linear MCP url={url!r}")

        supplied = headers or {}
        auth = supplied.get("Authorization")
        expected = f"Bearer {_ACCESS_TOKEN}"
        if auth != expected:
            raise ValueError(f"Unexpected Linear MCP Authorization header: {auth!r}")

        self.url = url
        self.headers = supplied

    async def __aenter__(self) -> tuple[object, object, None]:
        # Two opaque placeholders plus None, handed to whoever opens the context.
        first, second = object(), object()
        return first, second, None

    async def __aexit__(self, exc_type: Any, exc: Any, tb: Any) -> None:
        # Nothing to release; exceptions are not suppressed.
        del exc_type, exc, tb
|
||||
|
||||
|
||||
class _FakeClientSession(_StrictFakeMixin):
    """Fake MCP client session that serves Linear tools from the fixture.

    Supports ``initialize``, ``list_tools`` and ``call_tool``; anything else
    falls through to the strict mixin and fails loudly.
    """

    _component_name = "ClientSession"

    def __init__(self, read: object, write: object):
        self.read = read
        self.write = write

    async def __aenter__(self) -> _FakeClientSession:
        return self

    async def __aexit__(self, exc_type: Any, exc: Any, tb: Any) -> None:
        # Nothing to release; exceptions are not suppressed.
        del exc_type, exc, tb

    async def initialize(self) -> None:
        return None

    async def list_tools(self) -> SimpleNamespace:
        """Return the listing of the two fake tools: ``list_issues`` and ``get_issue``."""
        list_issues_schema = {
            "type": "object",
            "properties": {
                "query": {
                    "type": "string",
                    "description": "Text to search for in Linear issues.",
                },
                "limit": {
                    "type": "integer",
                    "description": "Maximum number of issues to return.",
                },
            },
            "required": [],
        }
        get_issue_schema = {
            "type": "object",
            "properties": {
                "id": {
                    "type": "string",
                    "description": "Issue id or identifier.",
                }
            },
            "required": ["id"],
        }
        return SimpleNamespace(
            tools=[
                SimpleNamespace(
                    name="list_issues",
                    description="List Linear issues visible to the authenticated user.",
                    inputSchema=list_issues_schema,
                ),
                SimpleNamespace(
                    name="get_issue",
                    description="Get a Linear issue by id or identifier.",
                    inputSchema=get_issue_schema,
                ),
            ]
        )

    @staticmethod
    def _issue_reply(issue: dict[str, Any]) -> SimpleNamespace:
        # Shared by both tools: a three-line text body wrapped in a result
        # object whose ``content`` is a list of ``text`` parts.
        body = (
            f"{issue['identifier']} {issue['title']}\n"
            f"id: {issue['id']}\n"
            f"description: {issue['description']}"
        )
        return SimpleNamespace(content=[SimpleNamespace(text=body)])

    async def call_tool(
        self, tool_name: str, *, arguments: dict[str, Any] | None = None
    ) -> SimpleNamespace:
        """Answer a fake Linear tool call using the first fixture issue.

        ``list_issues`` accepts an empty query, or one that contains the
        issue title (case-insensitively). ``get_issue`` requires ``id`` to be
        the issue's id or identifier.

        Raises:
            ValueError: if the query or the id does not match the fixture issue.
            NotImplementedError: for any other tool name.
        """
        params = arguments or {}
        canary = _FIXTURE["issues"][0]

        if tool_name == "list_issues":
            query = str(params.get("query", ""))
            if query and canary["title"].lower() not in query.lower():
                raise ValueError(f"Unexpected Linear issue query: {query!r}")
            return self._issue_reply(canary)

        if tool_name == "get_issue":
            issue_id = params.get("id")
            accepted_ids = {canary["id"], canary["identifier"]}
            if issue_id not in accepted_ids:
                raise ValueError(f"Unexpected Linear issue id: {issue_id!r}")
            return self._issue_reply(canary)

        raise NotImplementedError(f"Unexpected Linear MCP tool call: {tool_name!r}")
|
||||
|
||||
|
||||
def _fake_streamablehttp_client(
    url: str, *, headers: dict[str, str] | None = None, **kwargs: Any
) -> _FakeStreamableHttpClient:
    """Build a ``_FakeStreamableHttpClient``, forwarding every argument unchanged."""
    transport = _FakeStreamableHttpClient(url, headers=headers, **kwargs)
    return transport
|
||||
|
||||
|
||||
async def _list_tools() -> SimpleNamespace:
    """Return the Linear tool listing from a throwaway ``_FakeClientSession``."""
    # The session never uses its read/write streams, so placeholders suffice.
    session = _FakeClientSession(object(), object())
    listing = await session.list_tools()
    return listing
|
||||
|
||||
|
||||
async def _call_tool(
    tool_name: str, arguments: dict[str, Any] | None = None
) -> SimpleNamespace:
    """Run a Linear tool call through a throwaway ``_FakeClientSession``.

    ``arguments`` may be omitted or ``None``; the session treats both as an
    empty argument dict. This matches the signature of the Jira handler and
    of ``_FakeClientSession.call_tool``.

    Raises:
        ValueError: if the arguments do not match the fixture issue.
        NotImplementedError: if ``tool_name`` is not a supported fake tool.
    """
    # The session never uses its read/write streams, so placeholders suffice.
    session = _FakeClientSession(object(), object())
    return await session.call_tool(tool_name, arguments=arguments)
|
||||
|
||||
|
||||
def install(active_patches: list[Any]) -> None:
    """Register Linear MCP OAuth/tool handlers with the shared dispatchers."""
    # The argument is accepted but not used by this module.
    del active_patches

    discovery = {
        "issuer": "https://mcp.linear.app",
        "authorization_endpoint": _AUTHORIZATION_URL,
        "token_endpoint": _TOKEN_URL,
        "registration_endpoint": _REGISTRATION_URL,
        "code_challenge_methods_supported": ["S256"],
        "grant_types_supported": ["authorization_code", "refresh_token"],
        "response_types_supported": ["code"],
    }
    mcp_oauth_runtime.register_service(
        mcp_url=_MCP_URL,
        discovery_metadata=discovery,
        client_id=_CLIENT_ID,
        client_secret=_CLIENT_SECRET,
        token_endpoint=_TOKEN_URL,
        registration_endpoint=_REGISTRATION_URL,
        oauth_code=_OAUTH_CODE,
        access_token=_ACCESS_TOKEN,
        refresh_token=_REFRESH_TOKEN,
        scope="read write",
        redirect_uri_substring="/api/v1/auth/mcp/linear/connector/callback",
    )

    # Tool traffic for the MCP URL is served by this module's handlers.
    mcp_runtime.register(
        url=_MCP_URL,
        expected_bearer=_ACCESS_TOKEN,
        list_tools=_list_tools,
        call_tool=_call_tool,
    )
|
||||
48
surfsense_backend/tests/e2e/fakes/llm.py
Normal file
48
surfsense_backend/tests/e2e/fakes/llm.py
Normal file
|
|
@ -0,0 +1,48 @@
|
|||
"""Deterministic LLM fake for the E2E indexing pipeline.
|
||||
|
||||
The production indexing pipeline summarizes documents with:
|
||||
|
||||
summary_chain = SUMMARY_PROMPT_TEMPLATE | llm
|
||||
summary_result = await summary_chain.ainvoke({"document": ...})
|
||||
summary_content = summary_result.content
|
||||
|
||||
The `llm` parameter is supplied per-document by
|
||||
`app.services.llm_service.get_user_long_context_llm`. We patch THAT
|
||||
function to return a langchain-native FakeListChatModel so the rest of
|
||||
the chain works unchanged. No real LLM provider package is touched.
|
||||
|
||||
Run-backend / run-celery use unittest.mock.patch.start() to install
|
||||
this at every binding site (the source module + every consumer that
|
||||
did `from app.services.llm_service import get_user_long_context_llm`
|
||||
at module load time).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.language_models.fake_chat_models import FakeListChatModel
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _make_fake_llm() -> FakeListChatModel:
    """Build a fresh FakeListChatModel that returns a deterministic summary."""
    # FakeListChatModel cycles through `responses` for each invocation, so a
    # single entry yields the same summary every time. The text carries a
    # marker that specs CAN assert on if they want, but the primary indexing
    # assertion is on the file content (chunked + stored separately by the
    # pipeline).
    canned_responses = [
        "E2E_FAKE_SUMMARY: Indexed by Playwright E2E run with deterministic LLM stub."
    ]
    return FakeListChatModel(responses=canned_responses)
|
||||
|
||||
|
||||
async def fake_get_user_long_context_llm(*args: Any, **kwargs: Any) -> Any:
    """Drop-in replacement for app.services.llm_service.get_user_long_context_llm.

    All positional and keyword arguments are accepted and ignored; each
    call logs one line and returns a newly built fake model.
    """
    logger.info("[fake-llm] returning FakeListChatModel for E2E indexing")
    fake_llm = _make_fake_llm()
    return fake_llm
|
||||
217
surfsense_backend/tests/e2e/fakes/mcp_oauth_runtime.py
Normal file
217
surfsense_backend/tests/e2e/fakes/mcp_oauth_runtime.py
Normal file
|
|
@ -0,0 +1,217 @@
|
|||
"""Shared strict MCP OAuth fake dispatcher for E2E tests."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
from unittest.mock import patch
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class _OAuthHandler:
    """Immutable description of one fake MCP OAuth service.

    A single instance is indexed by MCP URL, registration endpoint and
    token endpoint so each fake OAuth helper can find it from the argument
    it receives.
    """

    # Lookup keys / discovery payload.
    mcp_url: str
    discovery_metadata: dict[str, Any]

    # Client credentials handed out by the fake DCR endpoint.
    client_id: str
    client_secret: str

    # Endpoints the fake helpers are keyed on.
    token_endpoint: str
    registration_endpoint: str

    # Values the token helpers accept and return.
    oauth_code: str
    access_token: str
    refresh_token: str
    scope: str

    # Substring that every redirect_uri must contain.
    redirect_uri_substring: str

    # Exact origin_override the discovery fake requires (None = must be absent).
    expected_origin_override: str | None = None
|
||||
|
||||
|
||||
_SERVICES_BY_MCP_URL: dict[str, _OAuthHandler] = {}
|
||||
_SERVICES_BY_REGISTRATION_URL: dict[str, _OAuthHandler] = {}
|
||||
_SERVICES_BY_TOKEN_URL: dict[str, _OAuthHandler] = {}
|
||||
|
||||
|
||||
def register_service(
    *,
    mcp_url: str,
    discovery_metadata: dict[str, Any],
    client_id: str,
    client_secret: str,
    token_endpoint: str,
    registration_endpoint: str,
    oauth_code: str,
    access_token: str,
    refresh_token: str,
    scope: str,
    redirect_uri_substring: str,
    expected_origin_override: str | None = None,
) -> None:
    """Register deterministic MCP OAuth behavior for one service.

    Builds a single ``_OAuthHandler`` from the keyword arguments and indexes
    it under its MCP URL, registration endpoint and token endpoint, in that
    order.

    Raises:
        ValueError: If any of the three keys is already bound to a
            different handler (raised by ``_register_unique``).
    """
    handler = _OAuthHandler(
        mcp_url=mcp_url,
        discovery_metadata=discovery_metadata,
        client_id=client_id,
        client_secret=client_secret,
        token_endpoint=token_endpoint,
        registration_endpoint=registration_endpoint,
        oauth_code=oauth_code,
        access_token=access_token,
        refresh_token=refresh_token,
        scope=scope,
        redirect_uri_substring=redirect_uri_substring,
        expected_origin_override=expected_origin_override,
    )

    # (registry, key, label) triples, processed in the original order.
    index_plan = (
        (_SERVICES_BY_MCP_URL, mcp_url, "MCP URL"),
        (_SERVICES_BY_REGISTRATION_URL, registration_endpoint, "registration endpoint"),
        (_SERVICES_BY_TOKEN_URL, token_endpoint, "token endpoint"),
    )
    for registry, key, label in index_plan:
        _register_unique(registry, key, handler, label)
|
||||
|
||||
|
||||
def _register_unique(
    registry: dict[str, _OAuthHandler],
    key: str,
    handler: _OAuthHandler,
    label: str,
) -> None:
    """Bind ``key`` to ``handler`` in ``registry``, rejecting conflicts.

    Re-registering an equal handler under the same key is allowed and simply
    overwrites the entry.

    Raises:
        ValueError: If ``key`` is already bound to a handler that compares
            unequal to ``handler``. ``label`` names the key kind in the
            message.
    """
    previous = registry.get(key)
    conflicting = previous is not None and previous != handler
    if conflicting:
        raise ValueError(f"MCP OAuth fake {label} already registered: {key!r}.")
    registry[key] = handler
|
||||
|
||||
|
||||
async def _fake_discover_oauth_metadata(
    mcp_url: str,
    *,
    origin_override: str | None = None,
    timeout: float = 15.0,
) -> dict[str, Any]:
    """Return the registered discovery metadata for ``mcp_url``.

    ``timeout`` is accepted and ignored.

    Returns:
        A shallow copy of the handler's ``discovery_metadata``.

    Raises:
        NotImplementedError: If no service is registered for ``mcp_url``.
        ValueError: If ``origin_override`` differs from the value the
            registered service expects.
    """
    del timeout

    service = _SERVICES_BY_MCP_URL.get(mcp_url)
    if service is None:
        raise NotImplementedError(f"Unexpected MCP OAuth discovery url={mcp_url!r}")

    if origin_override == service.expected_origin_override:
        # Copy so callers cannot mutate the registered metadata dict.
        return dict(service.discovery_metadata)

    raise ValueError(
        f"Unexpected MCP OAuth origin_override for {mcp_url!r}: {origin_override!r}"
    )
|
||||
|
||||
|
||||
async def _fake_register_client(
    registration_endpoint: str,
    redirect_uri: str,
    *,
    client_name: str = "SurfSense",
    timeout: float = 15.0,
) -> dict[str, Any]:
    """Fake dynamic client registration for a registered service.

    ``timeout`` is accepted and ignored.

    Returns:
        A registration response carrying the handler's client id/secret.

    Raises:
        NotImplementedError: If ``registration_endpoint`` is not registered.
        ValueError: If ``client_name`` is not ``"SurfSense"`` or
            ``redirect_uri`` lacks the handler's expected substring.
    """
    del timeout

    service = _SERVICES_BY_REGISTRATION_URL.get(registration_endpoint)
    if service is None:
        raise NotImplementedError(
            f"Unexpected MCP OAuth DCR endpoint={registration_endpoint!r}"
        )

    if client_name != "SurfSense":
        raise ValueError(f"Unexpected MCP OAuth DCR client_name={client_name!r}")

    redirect_ok = service.redirect_uri_substring in redirect_uri
    if not redirect_ok:
        raise ValueError(
            f"Unexpected MCP OAuth DCR redirect_uri={redirect_uri!r}; "
            f"expected {service.redirect_uri_substring!r}"
        )

    registration: dict[str, Any] = {
        "client_id": service.client_id,
        "client_secret": service.client_secret,
        "client_id_issued_at": 1_776_621_600,
        "token_endpoint_auth_method": "client_secret_basic",
    }
    return registration
|
||||
|
||||
|
||||
async def _fake_exchange_code_for_tokens(
    token_endpoint: str,
    code: str,
    redirect_uri: str,
    client_id: str,
    client_secret: str,
    code_verifier: str,
    *,
    timeout: float = 30.0,
) -> dict[str, Any]:
    """Fake the authorization-code token exchange for a registered service.

    Checks run in a fixed order: endpoint, code, redirect URI, client
    credentials, then presence of a PKCE ``code_verifier``. ``timeout`` is
    accepted and ignored.

    Returns:
        A bearer-token response built from the registered handler.

    Raises:
        NotImplementedError: If ``token_endpoint`` is not registered.
        ValueError: If any of the remaining checks fails.
    """
    del timeout

    service = _SERVICES_BY_TOKEN_URL.get(token_endpoint)
    if service is None:
        raise NotImplementedError(
            f"Unexpected MCP OAuth token_endpoint={token_endpoint!r}"
        )

    if code != service.oauth_code:
        raise ValueError(f"Unexpected fake MCP OAuth code: {code!r}")

    if service.redirect_uri_substring not in redirect_uri:
        raise ValueError(
            f"Unexpected MCP OAuth token redirect_uri={redirect_uri!r}; "
            f"expected {service.redirect_uri_substring!r}"
        )

    credentials_match = (
        client_id == service.client_id and client_secret == service.client_secret
    )
    if not credentials_match:
        raise ValueError(
            "Unexpected MCP OAuth client credentials: "
            f"client_id={client_id!r} client_secret={client_secret!r}"
        )

    if not code_verifier:
        raise ValueError("MCP OAuth token exchange missing code_verifier.")

    token_response: dict[str, Any] = {
        "access_token": service.access_token,
        "refresh_token": service.refresh_token,
        "expires_in": 3600,
        "scope": service.scope,
        "token_type": "Bearer",
    }
    return token_response
|
||||
|
||||
|
||||
async def _fake_refresh_access_token(
    token_endpoint: str,
    refresh_token: str,
    client_id: str,
    client_secret: str,
    *,
    timeout: float = 30.0,
) -> dict[str, Any]:
    """Fake the refresh-token grant for a registered service.

    ``timeout`` is accepted and ignored.

    Returns:
        A bearer-token response built from the registered handler; the same
        access and refresh tokens are handed back.

    Raises:
        NotImplementedError: If ``token_endpoint`` is not registered.
        ValueError: If the refresh token or the client credentials do not
            match the registered handler.
    """
    del timeout

    service = _SERVICES_BY_TOKEN_URL.get(token_endpoint)
    if service is None:
        raise NotImplementedError(
            f"Unexpected MCP OAuth refresh token_endpoint={token_endpoint!r}"
        )

    if refresh_token != service.refresh_token:
        raise ValueError(f"Unexpected fake MCP OAuth refresh token: {refresh_token!r}")

    credentials_match = (
        client_id == service.client_id and client_secret == service.client_secret
    )
    if not credentials_match:
        raise ValueError(
            "Unexpected MCP OAuth refresh client credentials: "
            f"client_id={client_id!r} client_secret={client_secret!r}"
        )

    refreshed: dict[str, Any] = {
        "access_token": service.access_token,
        "refresh_token": service.refresh_token,
        "expires_in": 3600,
        "scope": service.scope,
        "token_type": "Bearer",
    }
    return refreshed
|
||||
|
||||
|
||||
def install(active_patches: list[Any]) -> None:
    """Patch generic MCP OAuth helper boundaries exactly once.

    Each helper in ``app.services.mcp_oauth.discovery`` is replaced by its
    fake. Every started patcher is appended to ``active_patches``
    immediately after it starts, so the caller holds all of them.
    """
    replacements = {
        "app.services.mcp_oauth.discovery.discover_oauth_metadata": (
            _fake_discover_oauth_metadata
        ),
        "app.services.mcp_oauth.discovery.register_client": _fake_register_client,
        "app.services.mcp_oauth.discovery.exchange_code_for_tokens": (
            _fake_exchange_code_for_tokens
        ),
        "app.services.mcp_oauth.discovery.refresh_access_token": (
            _fake_refresh_access_token
        ),
    }
    # dicts preserve insertion order, so patches start in the listed order.
    for dotted_path, fake in replacements.items():
        patcher = patch(dotted_path, fake)
        patcher.start()
        active_patches.append(patcher)
|
||||
148
surfsense_backend/tests/e2e/fakes/mcp_runtime.py
Normal file
148
surfsense_backend/tests/e2e/fakes/mcp_runtime.py
Normal file
|
|
@ -0,0 +1,148 @@
|
|||
"""Shared strict MCP streamable-HTTP runtime fake for E2E tests."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import inspect
|
||||
from collections.abc import Awaitable, Callable
|
||||
from dataclasses import dataclass
|
||||
from types import SimpleNamespace
|
||||
from typing import Any
|
||||
from unittest.mock import patch
|
||||
|
||||
ListToolsFn = Callable[[], Any | Awaitable[Any]]
|
||||
CallToolFn = Callable[[str, dict[str, Any]], Any | Awaitable[Any]]
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class _RuntimeHandler:
    """Immutable behavior bundle for one fake streamable-HTTP MCP server."""

    # Token the client must present as "Bearer <token>".
    expected_bearer: str

    # Callable producing the tool listing; may be sync or return an awaitable.
    list_tools: ListToolsFn

    # Callable taking (tool_name, arguments); may be sync or return an awaitable.
    call_tool: CallToolFn
|
||||
|
||||
|
||||
_HANDLERS: dict[str, _RuntimeHandler] = {}
|
||||
|
||||
|
||||
class _StrictFakeMixin:
    """Mixin that turns any unmodelled attribute access into a loud failure."""

    # Human-readable component label used in the error message.
    _component_name: str = "<unknown>"

    def __getattr__(self, name: str) -> Any:
        """Reject ``name``; only invoked when normal attribute lookup fails.

        Raises:
            NotImplementedError: Always, naming the component and attribute.
        """
        component = self._component_name
        raise NotImplementedError(
            f"E2E MCP runtime fake missing surface: {component}.{name!r}. "
            "Add it to surfsense_backend/tests/e2e/fakes/mcp_runtime.py."
        )
|
||||
|
||||
|
||||
class _FakeEndpoint(_StrictFakeMixin):
    """Stand-in for one side (read or write) of a streamable-HTTP connection."""

    _component_name = "streamablehttp_endpoint"

    def __init__(self, url: str, handler: _RuntimeHandler):
        """Remember the server URL and the handler that serves it."""
        self.handler = handler
        self.url = url
|
||||
|
||||
|
||||
class _FakeStreamableHttpClient(_StrictFakeMixin):
    """Async-context-manager fake for the MCP streamable-HTTP client."""

    _component_name = "streamablehttp_client"

    def __init__(
        self,
        url: str,
        *,
        headers: dict[str, str] | None = None,
        **kwargs: Any,
    ):
        """Validate the target URL and bearer token up front.

        Extra keyword arguments are accepted and ignored.

        Raises:
            NotImplementedError: If no handler is registered for ``url``.
            ValueError: If the ``Authorization`` header is not exactly
                ``Bearer <expected_bearer>``.
        """
        del kwargs

        handler = _HANDLERS.get(url)
        if handler is None:
            raise NotImplementedError(f"Unexpected MCP streamable-http url={url!r}")

        supplied_headers = headers or {}
        presented = supplied_headers.get("Authorization")
        if presented != f"Bearer {handler.expected_bearer}":
            raise ValueError(
                f"Unexpected MCP Authorization header for {url!r}: {presented!r}"
            )

        self.url = url
        self.headers = supplied_headers
        self.handler = handler

    async def __aenter__(self) -> tuple[_FakeEndpoint, _FakeEndpoint, None]:
        """Yield a (read, write, None) triple sharing this client's handler."""
        read_side = _FakeEndpoint(self.url, self.handler)
        write_side = _FakeEndpoint(self.url, self.handler)
        return read_side, write_side, None

    async def __aexit__(self, exc_type: Any, exc: Any, tb: Any) -> None:
        """Nothing to release; exceptions are not suppressed."""
        del exc_type, exc, tb
|
||||
|
||||
|
||||
class _FakeClientSession(_StrictFakeMixin):
    """Fake MCP ``ClientSession`` that delegates to a registered handler."""

    _component_name = "ClientSession"

    def __init__(self, read: _FakeEndpoint, write: _FakeEndpoint):
        """Bind to the handler shared by both endpoints.

        Raises:
            ValueError: If ``read`` and ``write`` carry different handlers.
        """
        shared_handler = read.handler
        if shared_handler is not write.handler:
            raise ValueError("MCP fake received mismatched read/write endpoints.")
        self.read = read
        self.write = write
        self.handler = shared_handler

    async def __aenter__(self) -> _FakeClientSession:
        """Enter the session context; the session itself is the resource."""
        return self

    async def __aexit__(self, exc_type: Any, exc: Any, tb: Any) -> None:
        """Nothing to release; exceptions are not suppressed."""
        del exc_type, exc, tb

    async def initialize(self) -> None:
        """No-op handshake."""
        return None

    async def list_tools(self) -> SimpleNamespace:
        """Return the handler's tool listing, awaiting it when necessary."""
        listing = self.handler.list_tools()
        # Handlers may be plain callables or coroutine functions.
        return (await listing) if inspect.isawaitable(listing) else listing

    async def call_tool(
        self, tool_name: str, *, arguments: dict[str, Any] | None = None
    ) -> SimpleNamespace:
        """Run ``tool_name`` through the handler; missing arguments become ``{}``."""
        outcome = self.handler.call_tool(tool_name, arguments or {})
        return (await outcome) if inspect.isawaitable(outcome) else outcome
|
||||
|
||||
|
||||
def _fake_streamablehttp_client(
    url: str,
    *,
    headers: dict[str, str] | None = None,
    **kwargs: Any,
) -> _FakeStreamableHttpClient:
    """Function-shaped factory for ``_FakeStreamableHttpClient``.

    All arguments are forwarded unchanged; validation happens in the
    client's constructor.
    """
    client = _FakeStreamableHttpClient(url, headers=headers, **kwargs)
    return client
|
||||
|
||||
|
||||
def register(
    *,
    url: str,
    expected_bearer: str,
    list_tools: ListToolsFn,
    call_tool: CallToolFn,
) -> None:
    """Register a fake streamable-HTTP MCP server by canonical MCP URL.

    Re-registering an equal handler for the same URL is allowed.

    Raises:
        ValueError: If ``url`` already has a handler that compares unequal
            to the one described by the arguments.
    """
    candidate = _RuntimeHandler(
        expected_bearer=expected_bearer,
        list_tools=list_tools,
        call_tool=call_tool,
    )
    current = _HANDLERS.get(url)
    conflicting = current is not None and current != candidate
    if conflicting:
        raise ValueError(f"MCP runtime fake handler already registered for {url!r}.")
    _HANDLERS[url] = candidate
|
||||
|
||||
|
||||
def install(active_patches: list[Any]) -> None:
    """Patch production MCP streamable-HTTP boundaries exactly once.

    Replaces ``streamablehttp_client`` and ``ClientSession`` as bound in
    ``app.agents.new_chat.tools.mcp_tool``. Each started patcher is appended
    to ``active_patches``.
    """
    replacements = {
        "app.agents.new_chat.tools.mcp_tool.streamablehttp_client": (
            _fake_streamablehttp_client
        ),
        "app.agents.new_chat.tools.mcp_tool.ClientSession": _FakeClientSession,
    }
    # dicts preserve insertion order, so patches start in the listed order.
    for dotted_path, fake in replacements.items():
        patcher = patch(dotted_path, fake)
        patcher.start()
        active_patches.append(patcher)
|
||||
444
surfsense_backend/tests/e2e/fakes/native_google.py
Normal file
444
surfsense_backend/tests/e2e/fakes/native_google.py
Normal file
|
|
@ -0,0 +1,444 @@
|
|||
"""Strict native Google SDK fakes for Playwright E2E.
|
||||
|
||||
This module patches the production Google OAuth and Drive SDK bindings used by
|
||||
the native Google connector happy paths. It deliberately does not replace the
|
||||
whole Google package; unmodelled service methods fail loudly.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import json
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from unittest.mock import patch
|
||||
from urllib.parse import parse_qs, urlencode, urlparse, urlunparse
|
||||
|
||||
from google.oauth2.credentials import Credentials
|
||||
|
||||
from .binary_loader import _resolve_file_bytes
|
||||
|
||||
_FIXTURES_DIR = Path(__file__).parent / "fixtures"
|
||||
_DRIVE_FIXTURE_PATH = _FIXTURES_DIR / "drive_files.json"
|
||||
_GMAIL_FIXTURE_PATH = _FIXTURES_DIR / "gmail_messages.json"
|
||||
_CALENDAR_FIXTURE_PATH = _FIXTURES_DIR / "calendar_events.json"
|
||||
|
||||
|
||||
def _load_drive_fixture() -> dict[str, Any]:
    """Load and parse the Drive fixture JSON from ``_DRIVE_FIXTURE_PATH``.

    The file is read as UTF-8 explicitly so the result does not depend on
    the platform's locale encoding.

    Raises:
        OSError: If the fixture file cannot be opened.
        json.JSONDecodeError: If the file is not valid JSON.
    """
    with _DRIVE_FIXTURE_PATH.open(encoding="utf-8") as f:
        return json.load(f)
|
||||
|
||||
|
||||
_DRIVE_FIXTURE = _load_drive_fixture()
|
||||
|
||||
|
||||
def _load_gmail_fixture() -> dict[str, Any]:
    """Load and parse the Gmail fixture JSON from ``_GMAIL_FIXTURE_PATH``.

    The file is read as UTF-8 explicitly so the result does not depend on
    the platform's locale encoding.

    Raises:
        OSError: If the fixture file cannot be opened.
        json.JSONDecodeError: If the file is not valid JSON.
    """
    with _GMAIL_FIXTURE_PATH.open(encoding="utf-8") as f:
        return json.load(f)
|
||||
|
||||
|
||||
_GMAIL_FIXTURE = _load_gmail_fixture()
|
||||
|
||||
|
||||
def _load_calendar_fixture() -> dict[str, Any]:
    """Load and parse the Calendar fixture JSON from ``_CALENDAR_FIXTURE_PATH``.

    The file is read as UTF-8 explicitly so the result does not depend on
    the platform's locale encoding.

    Raises:
        OSError: If the fixture file cannot be opened.
        json.JSONDecodeError: If the file is not valid JSON.
    """
    with _CALENDAR_FIXTURE_PATH.open(encoding="utf-8") as f:
        return json.load(f)
|
||||
|
||||
|
||||
_CALENDAR_FIXTURE = _load_calendar_fixture()
|
||||
|
||||
|
||||
class _StrictFakeMixin:
    """Mixin that turns any unmodelled attribute access into a loud failure."""

    # Human-readable component label used in the error message.
    _component_name: str = "<unknown>"

    def __getattr__(self, name: str) -> Any:
        """Reject ``name``; only invoked when normal attribute lookup fails.

        Raises:
            NotImplementedError: Always, naming the component and attribute.
        """
        component = self._component_name
        raise NotImplementedError(
            f"E2E native Google fake missing surface: "
            f"{component}.{name!r}. Add it to "
            f"surfsense_backend/tests/e2e/fakes/native_google.py."
        )
|
||||
|
||||
|
||||
class _FakeFlow(_StrictFakeMixin):
    """Fake OAuth ``Flow`` that short-circuits the Google consent screen."""

    _component_name = "Flow"

    def __init__(
        self,
        *,
        redirect_uri: str | None = None,
        scopes: list[str] | None = None,
    ):
        """Store the redirect target/scopes and pre-build fake credentials."""
        self.redirect_uri = redirect_uri
        self.scopes = scopes or []
        self.code_verifier: str | None = None
        self.credentials = _fake_credentials(scopes=self.scopes)

    @classmethod
    def from_client_config(
        cls,
        client_config: dict[str, Any],
        scopes: list[str],
        redirect_uri: str | None = None,
        **_: Any,
    ) -> _FakeFlow:
        """Alternate constructor; ``client_config`` and extras are ignored."""
        del client_config
        return cls(redirect_uri=redirect_uri, scopes=scopes)

    def authorization_url(self, *, state: str, **_: Any) -> tuple[str, str]:
        """Return ``(url, state)`` where ``url`` is the callback itself.

        The redirect URI's existing query parameters are kept and ``code`` /
        ``state`` are set on top, so following the URL lands directly on the
        callback with the fake authorization code.

        Raises:
            ValueError: If the flow has no ``redirect_uri``.
        """
        if not self.redirect_uri:
            raise ValueError("Fake Google Flow requires redirect_uri.")

        target = urlparse(self.redirect_uri)
        params = parse_qs(target.query)
        params["code"] = ["fake-native-drive-oauth-code"]
        params["state"] = [state]
        encoded = urlencode(params, doseq=True)
        callback_url = urlunparse(target._replace(query=encoded))
        return callback_url, state

    def fetch_token(self, *, code: str, **_: Any) -> None:
        """Accept only the fake authorization code and refresh credentials.

        Raises:
            ValueError: If ``code`` is not the fake authorization code.
        """
        if code != "fake-native-drive-oauth-code":
            raise ValueError(f"Unexpected fake Google OAuth code: {code!r}")
        self.credentials = _fake_credentials(scopes=self.scopes)
|
||||
|
||||
|
||||
def _fake_credentials(*, scopes: list[str] | None = None) -> Credentials:
    """Build fixed fake Google credentials that expire one hour from now.

    Args:
        scopes: Scopes to embed; a falsy value falls back to the Drive scope.
    """
    granted_scopes = scopes or ["https://www.googleapis.com/auth/drive"]
    # The expiry is passed as a naive datetime derived from the current UTC time.
    expires_at = datetime.now(UTC).replace(tzinfo=None) + timedelta(hours=1)
    return Credentials(
        token="fake-native-drive-access-token",
        refresh_token="fake-native-drive-refresh-token",
        token_uri="https://oauth2.googleapis.com/token",
        client_id="fake-native-drive-client-id",
        client_secret="fake-native-drive-client-secret",
        scopes=granted_scopes,
        expiry=expires_at,
    )
|
||||
|
||||
|
||||
class _FakeRequest(_StrictFakeMixin):
    """Fake API request object whose ``execute`` returns a canned payload."""

    _component_name = "request"

    def __init__(self, payload: Any):
        """Keep the payload to hand back; ``http`` is always ``None``."""
        self.http = None
        self.payload = payload

    def execute(self, **_: Any) -> Any:
        """Return the stored payload; keyword arguments are ignored."""
        return self.payload
|
||||
|
||||
|
||||
class _FakeMediaRequest(_StrictFakeMixin):
    """Fake media request that simply carries the bytes to be downloaded."""

    _component_name = "media_request"

    def __init__(self, content: bytes):
        """Keep the file bytes; ``http`` is always ``None``."""
        self.http = None
        self.content = content
|
||||
|
||||
|
||||
class _FakeMediaIoBaseDownload(_StrictFakeMixin):
    """Fake downloader that delivers the whole file in a single chunk."""

    _component_name = "MediaIoBaseDownload"

    def __init__(self, fd, request: _FakeMediaRequest, chunksize: int | None = None):
        """Remember the destination and source; ``chunksize`` is ignored."""
        del chunksize
        self.fd = fd
        self.request = request
        self._done = False

    def next_chunk(self) -> tuple[None, bool]:
        """Write the content on the first call and always report completion.

        Returns:
            ``(None, True)`` on every call; later calls write nothing.
        """
        if self._done:
            return None, True
        self.fd.write(self.request.content)
        self._done = True
        return None, True
|
||||
|
||||
|
||||
class _FakeDriveFiles(_StrictFakeMixin):
    """Fake ``drive.files()`` resource backed by the Drive fixture."""

    _component_name = "drive.files"

    def list(self, **kwargs: Any) -> _FakeRequest:
        """List fixture files for the folder named in the ``q`` query.

        The folder id is the quoted literal directly preceding
        ``in parents``. If the query has no such clause, the ``"root"``
        listing is used. A missing or ``None`` ``q`` is treated as an empty
        query.
        """
        query = kwargs.get("q") or ""
        folder_id = "root"

        marker = query.find("in parents")
        if marker != -1:
            # Split off the last quoted literal before "in parents". Taking
            # the first quoted literal in the whole query would pick up e.g.
            # a mimeType value when that clause comes first.
            pieces = query[:marker].rsplit("'", 2)
            if len(pieces) == 3 and not pieces[2].strip():
                folder_id = pieces[1]

        files = _filter_drive_files_for_query(
            query, _DRIVE_FIXTURE.get(folder_id, [])
        )
        return _FakeRequest({"files": files, "nextPageToken": None})

    def get(self, **kwargs: Any) -> _FakeRequest:
        """Return fixture metadata for ``fileId``.

        Raises:
            NotImplementedError: If the fixture has no such file (raised by
                ``_drive_get_metadata``).
        """
        file_id = kwargs.get("fileId")
        metadata = _drive_get_metadata(file_id)
        return _FakeRequest(metadata)

    def get_media(self, **kwargs: Any) -> _FakeMediaRequest:
        """Return a media request carrying the bytes for ``fileId``.

        Raises:
            NotImplementedError: If no content can be resolved for the file.
        """
        file_id = kwargs.get("fileId")
        content = _resolve_file_bytes(_DRIVE_FIXTURE, file_id, _FIXTURES_DIR)
        if content is None:
            raise NotImplementedError(
                f"E2E native Google fake has no content for fileId={file_id!r}."
            )
        return _FakeMediaRequest(content)

    def export(self, **kwargs: Any) -> _FakeRequest:
        """Return the fixture's ``_file_contents`` text for ``fileId`` as UTF-8 bytes.

        Raises:
            NotImplementedError: If the fixture has no export content.
        """
        file_id = kwargs.get("fileId")
        content = _DRIVE_FIXTURE.get("_file_contents", {}).get(file_id)
        if content is None:
            raise NotImplementedError(
                f"E2E native Google fake has no export content for fileId={file_id!r}."
            )
        return _FakeRequest(content.encode("utf-8"))
|
||||
|
||||
|
||||
class _FakeDriveChanges(_StrictFakeMixin):
    """Fake ``drive.changes()`` resource that never reports any change."""

    _component_name = "drive.changes"

    def getStartPageToken(self, **_: Any) -> _FakeRequest:  # noqa: N802
        """Return a fixed start page token; arguments are ignored."""
        payload = {"startPageToken": "fake-native-start-page-token-1"}
        return _FakeRequest(payload)

    def list(self, **_: Any) -> _FakeRequest:
        """Return an empty change list with the same fixed page token."""
        payload = {
            "changes": [],
            "newStartPageToken": "fake-native-start-page-token-1",
        }
        return _FakeRequest(payload)
|
||||
|
||||
|
||||
class _FakeDriveService(_StrictFakeMixin):
    """Fake Drive v3 service exposing only ``files()`` and ``changes()``."""

    _component_name = "drive_service"

    def files(self) -> _FakeDriveFiles:
        """Return a new fake ``files`` resource."""
        resource = _FakeDriveFiles()
        return resource

    def changes(self) -> _FakeDriveChanges:
        """Return a new fake ``changes`` resource."""
        resource = _FakeDriveChanges()
        return resource
|
||||
|
||||
|
||||
class _FakeGmailUsers(_StrictFakeMixin):
    """Fake ``gmail.users()`` resource for the authenticated user only."""

    _component_name = "gmail.users"

    def getProfile(self, **kwargs: Any) -> _FakeRequest:  # noqa: N802
        """Return the fixed fake profile for ``userId="me"``.

        Raises:
            NotImplementedError: For any other ``userId``.
        """
        user_id = kwargs.get("userId")
        if user_id == "me":
            return _FakeRequest({"emailAddress": "native-drive-e2e@surfsense.example"})
        raise NotImplementedError(
            f"Unexpected fake Gmail profile userId={user_id!r}"
        )

    def messages(self) -> _FakeGmailMessages:
        """Return a new fake ``messages`` resource."""
        return _FakeGmailMessages()
|
||||
|
||||
|
||||
class _FakeGmailMessages(_StrictFakeMixin):
    """Fake ``gmail.users().messages()`` resource backed by the Gmail fixture."""

    _component_name = "gmail.messages"

    def list(self, **kwargs: Any) -> _FakeRequest:
        """Return id/threadId stubs for the first ``maxResults`` fixture messages.

        ``maxResults`` defaults to 10 when missing or falsy.

        Raises:
            NotImplementedError: If ``userId`` is not ``"me"``.
        """
        user_id = kwargs.get("userId")
        if user_id != "me":
            raise NotImplementedError(f"Unexpected fake Gmail list userId={user_id!r}")

        limit = int(kwargs.get("maxResults", 10) or 10)
        page = list(_GMAIL_FIXTURE.get("messages", []))[:limit]

        stubs = []
        for message in page:
            stubs.append(
                {
                    "id": message["id"],
                    "threadId": message.get("threadId"),
                }
            )
        return _FakeRequest({"messages": stubs, "resultSizeEstimate": len(page)})

    def get(self, **kwargs: Any) -> _FakeRequest:
        """Return the full native-shaped message for ``id``.

        Raises:
            NotImplementedError: If ``userId`` is not ``"me"`` or the fixture
                has no detail entry for the requested id.
        """
        user_id = kwargs.get("userId")
        if user_id != "me":
            raise NotImplementedError(f"Unexpected fake Gmail get userId={user_id!r}")

        message_id = kwargs.get("id")
        details_by_id = _GMAIL_FIXTURE.get("details", {})
        detail = details_by_id.get(message_id)
        if detail is None:
            raise NotImplementedError(
                f"E2E native Gmail fake has no message detail for id={message_id!r}."
            )
        native_message = _gmail_detail_to_native_message(detail)
        return _FakeRequest(native_message)
|
||||
|
||||
|
||||
class _FakeGmailService(_StrictFakeMixin):
    """Fake Gmail v1 service exposing only ``users()``."""

    _component_name = "gmail_service"

    def users(self) -> _FakeGmailUsers:
        """Return a new fake ``users`` resource."""
        resource = _FakeGmailUsers()
        return resource
|
||||
|
||||
|
||||
class _FakeCalendarEvents(_StrictFakeMixin):
    """Fake ``calendar.events()`` resource backed by the Calendar fixture."""

    _component_name = "calendar.events"

    def list(self, **kwargs: Any) -> _FakeRequest:
        """Return fixture events of the primary calendar inside the time window.

        Events are filtered by ``timeMin`` / ``timeMax`` and then truncated
        to ``maxResults`` (250 when missing or falsy).

        Raises:
            NotImplementedError: If ``calendarId`` is not ``"primary"``.
        """
        calendar_id = kwargs.get("calendarId")
        if calendar_id != "primary":
            raise NotImplementedError(
                f"Unexpected fake Calendar events calendarId={calendar_id!r}"
            )

        limit = int(kwargs.get("maxResults", 250) or 250)
        window_start = kwargs.get("timeMin")
        window_end = kwargs.get("timeMax")

        in_window = [
            event
            for event in _CALENDAR_FIXTURE.get("items", [])
            if _calendar_event_in_range(event, window_start, window_end)
        ]
        payload = {
            "items": in_window[:limit],
            "summary": "native-calendar-e2e@surfsense.example",
            "timeZone": "UTC",
        }
        return _FakeRequest(payload)
|
||||
|
||||
|
||||
class _FakeCalendarService(_StrictFakeMixin):
    """Fake Calendar v3 service exposing only ``events()``."""

    _component_name = "calendar_service"

    def events(self) -> _FakeCalendarEvents:
        """Return a new fake ``events`` resource."""
        resource = _FakeCalendarEvents()
        return resource
|
||||
|
||||
|
||||
def _fake_build(service_name: str, version: str, **_: Any) -> Any:
    """Return the fake service for a supported ``(service_name, version)`` pair.

    Supported pairs are drive/v3, gmail/v1 and calendar/v3; extra keyword
    arguments are ignored.

    Raises:
        NotImplementedError: For any other service/version combination.
    """
    factories = {
        ("drive", "v3"): _FakeDriveService,
        ("gmail", "v1"): _FakeGmailService,
        ("calendar", "v3"): _FakeCalendarService,
    }
    factory = factories.get((service_name, version))
    if factory is None:
        raise NotImplementedError(
            f"E2E native Google fake cannot build {service_name!r} {version!r}."
        )
    return factory()
|
||||
|
||||
|
||||
def _extract_quoted_value(q: str, anchor: str) -> str | None:
    """Return the first single-quoted value that follows ``anchor`` in ``q``.

    Returns:
        The text between the first pair of single quotes after the first
        occurrence of ``anchor``, or ``None`` when the anchor or either
        quote is missing.
    """
    anchor_at = q.find(anchor)
    if anchor_at < 0:
        return None

    _, opening_quote, remainder = q[anchor_at + len(anchor) :].partition("'")
    if not opening_quote:
        return None

    value, closing_quote, _ = remainder.partition("'")
    return value if closing_quote else None
|
||||
|
||||
|
||||
def _filter_drive_files_for_query(
    q: str, files: list[dict[str, Any]]
) -> list[dict[str, Any]]:
    """Apply the subset of Drive query clauses this fake understands.

    Recognised clauses: ``trashed = false``, ``mimeType != '<type>'`` and
    ``mimeType = '<type>'``. Every other part of ``q`` is ignored.

    Returns:
        A new list with the entries of ``files`` that satisfy all recognised
        clauses, in their original order.
    """
    hide_trashed = "trashed = false" in q
    excluded_mime_type = _extract_quoted_value(q, "mimeType !=")
    included_mime_type = _extract_quoted_value(q, "mimeType =")

    def _keep(entry: dict[str, Any]) -> bool:
        # Only an explicit ``True`` counts as trashed.
        if hide_trashed and entry.get("trashed") is True:
            return False
        if excluded_mime_type and entry.get("mimeType") == excluded_mime_type:
            return False
        if included_mime_type and entry.get("mimeType") != included_mime_type:
            return False
        return True

    return [entry for entry in files if _keep(entry)]
|
||||
|
||||
|
||||
def _drive_get_metadata(file_id: str | None) -> dict[str, Any]:
    """Find the fixture entry whose ``id`` equals ``file_id``.

    Only list-valued fixture entries are searched; the first match wins.

    Returns:
        A shallow copy of the matching entry.

    Raises:
        NotImplementedError: If no listing contains the file.
    """
    for listing in _DRIVE_FIXTURE.values():
        if isinstance(listing, list):
            found = next(
                (entry for entry in listing if entry.get("id") == file_id), None
            )
            if found is not None:
                return dict(found)
    raise NotImplementedError(
        f"E2E native Google fake has no metadata for fileId={file_id!r}."
    )
|
||||
|
||||
|
||||
def _parse_rfc3339(value: str | None) -> datetime | None:
|
||||
if not value:
|
||||
return None
|
||||
normalized = value.replace("Z", "+00:00")
|
||||
return datetime.fromisoformat(normalized)
|
||||
|
||||
|
||||
def _calendar_event_start(event: dict[str, Any]) -> datetime | None:
|
||||
start = event.get("start", {})
|
||||
value = start.get("dateTime") or start.get("date")
|
||||
parsed = _parse_rfc3339(value)
|
||||
if parsed and parsed.tzinfo is None:
|
||||
return parsed.replace(tzinfo=UTC)
|
||||
return parsed
|
||||
|
||||
|
||||
def _calendar_event_in_range(
|
||||
event: dict[str, Any], time_min: str | None, time_max: str | None
|
||||
) -> bool:
|
||||
event_start = _calendar_event_start(event)
|
||||
if event_start is None:
|
||||
return True
|
||||
|
||||
parsed_min = _parse_rfc3339(time_min)
|
||||
parsed_max = _parse_rfc3339(time_max)
|
||||
if parsed_min and event_start < parsed_min:
|
||||
return False
|
||||
return not (parsed_max and event_start > parsed_max)
|
||||
|
||||
|
||||
def _gmail_detail_to_native_message(detail: dict[str, Any]) -> dict[str, Any]:
    """Convert a fixture ``detail`` record into a Gmail-API-shaped message.

    The body is URL-safe base64 of the UTF-8 encoded ``messageText``, and
    ``body.size`` is the byte length of that UTF-8 data (not the character
    count, which differs for non-ASCII text). Missing header fields fall
    back to placeholder strings; a missing or ``None`` ``messageText`` is
    treated as empty.
    """
    message_text = detail.get("messageText") or ""
    raw_body = message_text.encode("utf-8")
    encoded_body = base64.urlsafe_b64encode(raw_body).decode("ascii")

    return {
        "id": detail.get("id"),
        "threadId": detail.get("threadId"),
        "labelIds": ["INBOX", "IMPORTANT"],
        "snippet": message_text[:160],
        "payload": {
            "mimeType": "text/plain",
            "headers": [
                {"name": "Subject", "value": detail.get("subject", "No Subject")},
                {"name": "From", "value": detail.get("from", "Unknown Sender")},
                {"name": "To", "value": detail.get("to", "Unknown Recipient")},
                {"name": "Date", "value": detail.get("date", "Unknown Date")},
            ],
            "body": {
                "data": encoded_body,
                "size": len(raw_body),
            },
        },
    }
|
||||
|
||||
|
||||
def install(active_patches: list[Any]) -> None:
    """Patch production bindings to use native Google SDK fakes.

    Each started patcher is appended to ``active_patches`` right after it
    starts.
    """
    oauth_flow_targets = (
        ("app.routes.google_drive_add_connector_route.Flow", _FakeFlow),
        ("app.routes.google_gmail_add_connector_route.Flow", _FakeFlow),
        ("app.routes.google_calendar_add_connector_route.Flow", _FakeFlow),
    )
    service_builder_targets = (
        ("app.connectors.google_drive.client.build", _fake_build),
        ("app.connectors.google_gmail_connector.build", _fake_build),
        ("app.connectors.google_calendar_connector.build", _fake_build),
        ("app.agents.new_chat.tools.google_calendar.create_event.build", _fake_build),
        ("app.agents.new_chat.tools.google_calendar.update_event.build", _fake_build),
        ("app.agents.new_chat.tools.google_calendar.delete_event.build", _fake_build),
    )
    transport_targets = (
        ("googleapiclient.http.MediaIoBaseDownload", _FakeMediaIoBaseDownload),
        (
            "app.connectors.google_drive.client._build_thread_http",
            lambda credentials: None,
        ),
    )

    # Concatenation keeps the original patch order.
    for dotted_path, fake in (
        oauth_flow_targets + service_builder_targets + transport_targets
    ):
        patcher = patch(dotted_path, fake)
        patcher.start()
        active_patches.append(patcher)
|
||||
206
surfsense_backend/tests/e2e/fakes/notion_module.py
Normal file
206
surfsense_backend/tests/e2e/fakes/notion_module.py
Normal file
|
|
@ -0,0 +1,206 @@
|
|||
"""Strict Notion OAuth/API fakes for Playwright E2E."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import types
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from unittest.mock import patch
|
||||
|
||||
_FIXTURE_PATH = Path(__file__).parent / "fixtures" / "notion_pages.json"
|
||||
_TOKEN_URL = "https://api.notion.com/v1/oauth/token"
|
||||
_ACCESS_TOKEN = "fake-notion-access-token"
|
||||
_REFRESH_TOKEN = "fake-notion-refresh-token"
|
||||
|
||||
|
||||
def _load_fixture() -> dict[str, Any]:
    """Load and parse the Notion fixture JSON from ``_FIXTURE_PATH``.

    The file is read as UTF-8 explicitly so the result does not depend on
    the platform's locale encoding.

    Raises:
        OSError: If the fixture file cannot be opened.
        json.JSONDecodeError: If the file is not valid JSON.
    """
    with _FIXTURE_PATH.open(encoding="utf-8") as f:
        return json.load(f)
|
||||
|
||||
|
||||
_FIXTURE = _load_fixture()
|
||||
|
||||
|
||||
class _StrictFakeMixin:
    """Mixin that turns any unmodelled attribute access into a loud failure."""

    # Human-readable component label used in the error message.
    _component_name: str = "<unknown>"

    def __getattr__(self, name: str) -> Any:
        """Reject ``name``; only invoked when normal attribute lookup fails.

        Raises:
            NotImplementedError: Always, naming the component and attribute.
        """
        component = self._component_name
        raise NotImplementedError(
            f"E2E Notion fake missing surface: {component}.{name!r}. "
            "Add it to surfsense_backend/tests/e2e/fakes/notion_module.py."
        )
|
||||
|
||||
|
||||
class APIResponseError(Exception):
    """Fake counterpart of the Notion client's API response error.

    Carries the HTTP ``status``, the error ``code``, response ``headers``
    and the response ``body`` as attributes.
    """

    def __init__(
        self,
        message: str,
        *,
        status: int = 400,
        code: str = "validation_error",
        headers: dict[str, str] | None = None,
        body: Any | None = None,
    ):
        """Store the error details.

        Falsy ``headers`` become ``{}``; a falsy ``body`` becomes
        ``{"message": message}``.
        """
        super().__init__(message)
        self.status = status
        self.code = code
        self.headers = headers if headers else {}
        self.body = body if body else {"message": message}
|
||||
|
||||
|
||||
errors = types.ModuleType("notion_client.errors")
|
||||
errors.APIResponseError = APIResponseError
|
||||
|
||||
|
||||
class _FakeBlocksChildren(_StrictFakeMixin):
    """Fake ``notion.blocks.children`` endpoint backed by the Notion fixture."""

    _component_name = "notion.blocks.children"

    async def list(self, **kwargs: Any) -> dict[str, Any]:
        """Return all fixture child blocks of ``block_id`` as a single page.

        Raises:
            NotImplementedError: If a ``start_cursor`` is supplied;
                pagination is not modelled.
            APIResponseError: With status 404 / ``object_not_found`` when
                the fixture has no entry for ``block_id``.
        """
        block_id = kwargs.get("block_id")
        cursor = kwargs.get("start_cursor")
        if cursor is not None:
            raise NotImplementedError(
                f"E2E Notion fake does not model block pagination cursor={cursor!r}."
            )

        children = _FIXTURE.get("blocks", {}).get(block_id)
        if children is not None:
            return {
                "object": "list",
                "results": children,
                "has_more": False,
                "next_cursor": None,
            }

        raise APIResponseError(
            f"Could not find block: {block_id}",
            status=404,
            code="object_not_found",
            body={"message": f"Could not find block: {block_id}"},
        )
|
||||
|
||||
|
||||
class _FakeBlocks(_StrictFakeMixin):
    """Fake ``blocks`` namespace of the Notion client; only ``children`` is modelled."""

    _component_name = "notion.blocks"

    def __init__(self) -> None:
        # Expose the fake children endpoint so ``blocks.children.list(...)`` works.
        self.children = _FakeBlocksChildren()
|
||||
|
||||
|
||||
class AsyncClient(_StrictFakeMixin):
    """Fake Notion async client exposing ``blocks``, ``search`` and ``aclose``."""

    _component_name = "notion.AsyncClient"

    def __init__(self, *, auth: str, **kwargs: Any):
        """Accept only the fake access token; other client options are ignored.

        Raises:
            ValueError: If ``auth`` is not the expected fake access token.
        """
        del kwargs
        if auth != _ACCESS_TOKEN:
            raise ValueError(f"Unexpected fake Notion auth token: {auth!r}")
        self.blocks = _FakeBlocks()
        self.auth = auth

    async def search(self, **kwargs: Any) -> dict[str, Any]:
        """Return every fixture page as a single, final result page.

        Only the ``filter``, ``sort`` and ``start_cursor`` keywords are
        accepted, the cursor must be absent/``None``, and the filter must be
        the page-object filter.

        Raises:
            NotImplementedError: For any unsupported keyword, cursor or filter.
        """
        allowed = {"filter", "sort", "start_cursor"}
        unsupported = sorted(set(kwargs) - allowed)
        if unsupported:
            raise NotImplementedError(
                f"E2E Notion fake search got unsupported kwargs: {unsupported}"
            )

        cursor = kwargs.get("start_cursor")
        if cursor is not None:
            raise NotImplementedError(
                f"E2E Notion fake does not model search cursor={cursor!r}."
            )

        expected_filter = {"value": "page", "property": "object"}
        requested_filter = kwargs.get("filter")
        if requested_filter != expected_filter:
            raise NotImplementedError(
                f"E2E Notion fake search expected filter={expected_filter!r}, "
                f"got {requested_filter!r}."
            )

        pages = _FIXTURE.get("pages", [])
        return {
            "object": "list",
            "results": pages,
            "has_more": False,
            "next_cursor": None,
        }

    async def aclose(self) -> None:
        """No-op: the fake holds no resources to release."""
        return None
|
||||
|
||||
|
||||
class _FakeTokenResponse(_StrictFakeMixin):
    """Minimal HTTP-response look-alike returned by the fake OAuth ``post``.

    Exposes ``status_code``, ``text`` (the payload serialised with sorted
    keys) and ``json()``.
    """

    _component_name = "notion.oauth.response"

    def __init__(self, payload: dict[str, Any], status_code: int = 200):
        serialised = json.dumps(payload, sort_keys=True)
        self.text = serialised
        self.status_code = status_code
        self._payload = payload

    def json(self) -> dict[str, Any]:
        """Return the payload object given at construction (not a copy)."""
        return self._payload
|
||||
|
||||
|
||||
class _FakeHttpxAsyncClient(_StrictFakeMixin):
    """Fake async HTTP client that only answers the Notion OAuth token POST."""

    _component_name = "httpx.AsyncClient"

    def __init__(self, *args: Any, **kwargs: Any):
        # Constructor options are accepted and discarded.
        del args, kwargs

    async def __aenter__(self) -> _FakeHttpxAsyncClient:
        return self

    async def __aexit__(self, exc_type: Any, exc: Any, tb: Any) -> None:
        # Nothing to clean up; returning None lets exceptions propagate.
        del exc_type, exc, tb

    async def post(self, url: str, **kwargs: Any) -> _FakeTokenResponse:
        """Validate a token-exchange request and return a fixed token payload.

        Raises:
            NotImplementedError: If ``url`` is not the Notion token URL.
            ValueError: If the ``Authorization`` header is missing, the grant
                type is unknown, or the code / refresh token is not the
                expected fake value.
        """
        if url != _TOKEN_URL:
            raise NotImplementedError(f"Unexpected Notion OAuth POST url={url!r}")

        request_headers = kwargs.get("headers") or {}
        form = kwargs.get("json") or {}
        if "Authorization" not in request_headers:
            raise ValueError(
                "Notion OAuth token exchange missing Authorization header."
            )

        grant_type = form.get("grant_type")
        if grant_type not in ("authorization_code", "refresh_token"):
            raise ValueError(f"Unexpected fake Notion grant_type: {grant_type!r}")
        if grant_type == "authorization_code":
            code = form.get("code")
            if code != "fake-notion-oauth-code":
                raise ValueError(f"Unexpected fake Notion OAuth code: {code!r}")
        else:
            refresh_token = form.get("refresh_token")
            if refresh_token != _REFRESH_TOKEN:
                raise ValueError(
                    f"Unexpected fake Notion refresh token: {refresh_token!r}"
                )

        token_payload = {
            "access_token": _ACCESS_TOKEN,
            "refresh_token": _REFRESH_TOKEN,
            "expires_in": 3600,
            "workspace_id": "fake-notion-workspace-001",
            "workspace_name": "SurfSense E2E Notion Workspace",
            "workspace_icon": "https://surfsense.example/notion-icon.png",
            "bot_id": "fake-notion-bot-001",
        }
        return _FakeTokenResponse(token_payload)
|
||||
|
||||
|
||||
class _FakeHttpxModule(_StrictFakeMixin):
    """Object patched over a module-level ``httpx`` binding by ``install``."""

    _component_name = "httpx"

    # The only httpx surface exposed; any other attribute falls through to the
    # strict ``__getattr__`` and raises.
    AsyncClient = _FakeHttpxAsyncClient
|
||||
|
||||
|
||||
def install(active_patches: list[Any]) -> None:
    """Patch production bindings that cannot be covered by sys.modules hijack.

    Each patcher is started immediately and appended to ``active_patches``.
    """
    replacements = {
        "app.routes.notion_add_connector_route.httpx": _FakeHttpxModule(),
    }
    for dotted_path, fake in replacements.items():
        patcher = patch(dotted_path, fake)
        patcher.start()
        active_patches.append(patcher)
|
||||
190
surfsense_backend/tests/e2e/fakes/onedrive_graph.py
Normal file
190
surfsense_backend/tests/e2e/fakes/onedrive_graph.py
Normal file
|
|
@ -0,0 +1,190 @@
|
|||
"""Strict Microsoft OneDrive Graph fakes for Playwright E2E.
|
||||
|
||||
This module patches the OneDrive OAuth route and indexer consumer-site
|
||||
bindings. It keeps the production add/callback/indexing flow intact while
|
||||
serving deterministic Microsoft-shaped token, profile, metadata, and file
|
||||
content responses.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from unittest.mock import patch
|
||||
|
||||
import httpx
|
||||
|
||||
from .binary_loader import _resolve_file_bytes
|
||||
|
||||
_FIXTURES_DIR = Path(__file__).parent / "fixtures"
|
||||
_ONEDRIVE_FIXTURE_PATH = _FIXTURES_DIR / "onedrive_files.json"
|
||||
|
||||
|
||||
def _load_onedrive_fixture() -> dict[str, Any]:
    """Read and parse the OneDrive fixture JSON at ``_ONEDRIVE_FIXTURE_PATH``.

    Returns:
        The decoded JSON document.

    Raises:
        OSError: If the fixture file cannot be opened.
        json.JSONDecodeError: If the file is not valid JSON.
    """
    # Pin the encoding: without it ``open`` falls back to the locale default,
    # which is not UTF-8 on every platform (notably Windows).
    with _ONEDRIVE_FIXTURE_PATH.open(encoding="utf-8") as f:
        return json.load(f)
|
||||
|
||||
|
||||
_ONEDRIVE_FIXTURE = _load_onedrive_fixture()
|
||||
|
||||
|
||||
class _StrictFakeMixin:
    """Mixin that makes access to any unmodelled attribute fail loudly.

    ``__getattr__`` is only consulted when normal attribute lookup fails, so
    attributes and methods that a fake defines are unaffected.
    """

    # Label used in the error message; subclasses override it.
    _component_name: str = "<unknown>"

    def __getattr__(self, name: str) -> Any:
        """Raise ``NotImplementedError`` naming the missing fake surface."""
        surface = f"{self._component_name}.{name!r}"
        message = (
            f"E2E OneDrive fake missing surface: "
            f"{surface}. Add it to "
            f"surfsense_backend/tests/e2e/fakes/onedrive_graph.py."
        )
        raise NotImplementedError(message)
|
||||
|
||||
|
||||
class _FakeOneDriveClient(_StrictFakeMixin):
    """Fixture-backed replacement for the production ``OneDriveClient``.

    Methods report failure through an error string in their return value
    instead of raising.
    """

    _component_name = "OneDriveClient"

    def __init__(self, session: Any, connector_id: int):
        # Stored but not read by any method of the fake.
        self._connector_id = connector_id
        self._session = session

    async def list_children(
        self, item_id: str = "root"
    ) -> tuple[list[dict[str, Any]], str | None]:
        """Return ``(children, None)`` for ``item_id`` or ``([], error)``.

        Children are shallow copies of the fixture entries.
        """
        children = _ONEDRIVE_FIXTURE.get(item_id)
        if isinstance(children, list):
            return [dict(child) for child in children], None
        return [], f"E2E OneDrive fake has no children for item_id={item_id!r}."

    async def get_item_metadata(
        self, item_id: str
    ) -> tuple[dict[str, Any] | None, str | None]:
        """Return ``(metadata, None)`` for ``item_id`` or ``(None, error)``."""
        found = _onedrive_get_metadata(item_id)
        if found is not None:
            return found, None
        return None, f"E2E OneDrive fake has no metadata for item_id={item_id!r}."

    async def download_file(self, item_id: str) -> tuple[bytes | None, str | None]:
        """Return ``(content, None)`` for ``item_id`` or ``(None, error)``."""
        payload = _resolve_file_bytes(_ONEDRIVE_FIXTURE, item_id, _FIXTURES_DIR)
        if payload is not None:
            return payload, None
        return None, f"E2E OneDrive fake has no content for item_id={item_id!r}."

    async def download_file_to_disk(self, item_id: str, dest_path: str) -> str | None:
        """Write the item's content to ``dest_path``.

        Returns:
            ``None`` on success, or an error string when the fixture has no
            content for ``item_id`` (nothing is written in that case).
        """
        payload = _resolve_file_bytes(_ONEDRIVE_FIXTURE, item_id, _FIXTURES_DIR)
        if payload is None:
            return f"E2E OneDrive fake has no content for item_id={item_id!r}."
        with open(dest_path, "wb") as sink:
            sink.write(payload)
        return None

    async def get_delta(
        self, folder_id: str | None = None, delta_link: str | None = None
    ) -> tuple[list[dict[str, Any]], str | None, str | None]:
        """Return ``(changes, new_delta_link, error)`` for a folder.

        The change list is always empty. A truthy ``delta_link`` takes
        precedence over ``folder_id``: its last path segment, minus the
        ``-delta`` suffix, is used as the folder key.
        """
        if delta_link:
            last_segment = delta_link.rsplit("/", 1)[-1]
            folder_key = last_segment.removesuffix("-delta")
        else:
            folder_key = folder_id or "root"

        if folder_key in _ONEDRIVE_FIXTURE:
            next_link = (
                f"https://graph.microsoft.com/v1.0/fake-delta/{folder_key}-delta"
            )
            return [], next_link, None
        return (
            [],
            None,
            f"E2E OneDrive fake has no delta for folder={folder_key!r}.",
        )
|
||||
|
||||
|
||||
class _FakeAsyncClient(_StrictFakeMixin):
    """Fake async HTTP client answering only the Microsoft token and profile URLs."""

    _component_name = "httpx.AsyncClient"

    def __init__(self, *args: Any, **kwargs: Any):
        # Constructor options are accepted and discarded.
        del args, kwargs

    async def __aenter__(self) -> _FakeAsyncClient:
        return self

    async def __aexit__(self, exc_type: Any, exc: Any, tb: Any) -> None:
        # Nothing to clean up; returning None lets exceptions propagate.
        del exc_type, exc, tb

    async def post(self, url: str, *args: Any, **kwargs: Any) -> httpx.Response:
        """Return a fixed token payload for a Microsoft login ``/token`` URL.

        The request body and headers are ignored.

        Raises:
            NotImplementedError: For any other URL.
        """
        del args, kwargs
        is_token_url = "login.microsoftonline.com" in url and url.endswith("/token")
        if not is_token_url:
            raise NotImplementedError(f"E2E OneDrive fake unexpected POST URL: {url!r}")
        token_payload = {
            "access_token": "fake-onedrive-access-token",
            "refresh_token": "fake-onedrive-refresh-token",
            "token_type": "Bearer",
            "expires_in": 3600,
            "scope": "offline_access User.Read Files.Read.All Files.ReadWrite.All",
        }
        return _json_response("POST", url, token_payload)

    async def get(self, url: str, *args: Any, **kwargs: Any) -> httpx.Response:
        """Return a fixed profile payload for the Graph ``/me`` URL.

        Raises:
            NotImplementedError: For any other URL.
        """
        del args, kwargs
        if url != "https://graph.microsoft.com/v1.0/me":
            raise NotImplementedError(f"E2E OneDrive fake unexpected GET URL: {url!r}")
        profile = {
            "mail": "onedrive-e2e@surfsense.example",
            "userPrincipalName": "onedrive-e2e@surfsense.example",
            "displayName": "SurfSense OneDrive E2E",
        }
        return _json_response("GET", url, profile)

    async def request(
        self, method: str, url: str, *args: Any, **kwargs: Any
    ) -> httpx.Response:
        """Always raise: generic requests are not modelled by this fake."""
        del args, kwargs
        detail = f"E2E OneDrive fake unexpected request: {method!r} {url!r}"
        raise NotImplementedError(detail)
|
||||
|
||||
|
||||
class _FakeHttpxModule(_StrictFakeMixin):
    """Object patched over module-level ``httpx`` bindings by ``install``."""

    _component_name = "httpx"

    # The only httpx surface exposed; any other attribute falls through to the
    # strict ``__getattr__`` and raises.
    AsyncClient = _FakeAsyncClient
|
||||
|
||||
|
||||
def _json_response(
    method: str, url: str, payload: dict[str, Any], status_code: int = 200
) -> httpx.Response:
    """Build an ``httpx.Response`` with a JSON body, bound to a matching request.

    Args:
        method: HTTP method recorded on the attached request.
        url: URL recorded on the attached request.
        payload: Object passed as the response's JSON content.
        status_code: Response status code.
    """
    originating_request = httpx.Request(method, url)
    return httpx.Response(
        status_code=status_code, json=payload, request=originating_request
    )
|
||||
|
||||
|
||||
def _onedrive_get_metadata(item_id: str | None) -> dict[str, Any] | None:
    """Find the first fixture entry whose ``"id"`` equals ``item_id``.

    Only fixture values that are lists are searched, in fixture order.

    Returns:
        A shallow copy of the matching entry, or ``None`` when none matches.
    """
    hit = next(
        (
            entry
            for listing in _ONEDRIVE_FIXTURE.values()
            if isinstance(listing, list)
            for entry in listing
            if entry.get("id") == item_id
        ),
        None,
    )
    return None if hit is None else dict(hit)
|
||||
|
||||
|
||||
def install(active_patches: list[Any]) -> None:
    """Patch production OneDrive bindings to use strict Graph fakes.

    Each patcher is started immediately and appended to ``active_patches``.
    """
    route_module = "app.routes.onedrive_add_connector_route"
    # Each httpx binding gets its own fake module instance.
    replacements = (
        (f"{route_module}.httpx", _FakeHttpxModule()),
        (f"{route_module}.OneDriveClient", _FakeOneDriveClient),
        (
            "app.tasks.connector_indexers.onedrive_indexer.OneDriveClient",
            _FakeOneDriveClient,
        ),
        ("app.connectors.onedrive.client.httpx", _FakeHttpxModule()),
    )
    for dotted_path, fake in replacements:
        patcher = patch(dotted_path, fake)
        patcher.start()
        active_patches.append(patcher)
|
||||
213
surfsense_backend/tests/e2e/fakes/slack_module.py
Normal file
213
surfsense_backend/tests/e2e/fakes/slack_module.py
Normal file
|
|
@ -0,0 +1,213 @@
|
|||
"""Strict Slack MCP OAuth/tool fakes for Playwright E2E."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
from typing import Any
|
||||
from unittest.mock import patch
|
||||
|
||||
from tests.e2e.fakes import mcp_oauth_runtime, mcp_runtime
|
||||
|
||||
_FIXTURE_PATH = Path(__file__).parent / "fixtures" / "slack_messages.json"
|
||||
|
||||
_AUTHORIZATION_URL = "https://slack.com/oauth/v2_user/authorize"
|
||||
_REGISTRATION_URL = "https://e2e-fake.invalid/mcp/slack-unused-register"
|
||||
_TOKEN_URL = "https://slack.com/api/oauth.v2.user.access"
|
||||
_MCP_URL = "https://mcp.slack.com/mcp"
|
||||
|
||||
_CLIENT_ID = "fake-slack-mcp-client-id"
|
||||
_CLIENT_SECRET = "fake-slack-mcp-client-secret"
|
||||
_ACCESS_TOKEN = "fake-slack-mcp-access-token"
|
||||
_REFRESH_TOKEN = "fake-slack-mcp-refresh-token"
|
||||
_OAUTH_CODE = "fake-slack-oauth-code"
|
||||
_SCOPE = (
|
||||
"search:read.public search:read.private search:read.mpim search:read.im "
|
||||
"channels:history groups:history mpim:history im:history"
|
||||
)
|
||||
|
||||
|
||||
def _load_fixture() -> dict[str, Any]:
    """Read and parse the Slack fixture JSON located at ``_FIXTURE_PATH``.

    Returns:
        The decoded JSON document.

    Raises:
        OSError: If the fixture file cannot be opened.
        json.JSONDecodeError: If the file is not valid JSON.
    """
    # Pin the encoding: without it ``open`` falls back to the locale default,
    # which is not UTF-8 on every platform (notably Windows).
    with _FIXTURE_PATH.open(encoding="utf-8") as f:
        return json.load(f)
|
||||
|
||||
|
||||
_FIXTURE = _load_fixture()
|
||||
|
||||
|
||||
async def _list_tools() -> SimpleNamespace:
    """Return the fake Slack MCP tool listing.

    The result has a ``tools`` attribute holding three entries, each with
    ``name``, ``description`` and ``inputSchema`` attributes.
    """

    def prop(kind: str, description: str) -> dict[str, str]:
        # One JSON-schema property definition.
        return {"type": kind, "description": description}

    def tool(
        name: str,
        description: str,
        properties: dict[str, dict[str, str]],
        required: list[str],
    ) -> SimpleNamespace:
        # One tool descriptor with an object-typed input schema.
        schema = {
            "type": "object",
            "properties": properties,
            "required": required,
        }
        return SimpleNamespace(name=name, description=description, inputSchema=schema)

    search_channels = tool(
        "slack_search_channels",
        "Search Slack channels visible to the authenticated user.",
        {
            "query": prop("string", "Text to search for in Slack channel names."),
            "limit": prop("integer", "Maximum number of channels to return."),
        },
        [],
    )
    read_channel = tool(
        "slack_read_channel",
        "Read messages from a Slack channel.",
        {"channel_id": prop("string", "Slack channel id.")},
        ["channel_id"],
    )
    read_thread = tool(
        "slack_read_thread",
        "Read a Slack thread from a channel.",
        {
            "channel_id": prop("string", "Slack channel id."),
            "thread_ts": prop("string", "Slack thread timestamp."),
        },
        ["channel_id", "thread_ts"],
    )
    return SimpleNamespace(tools=[search_channels, read_channel, read_thread])
|
||||
|
||||
|
||||
async def _call_tool(
    tool_name: str, arguments: dict[str, Any] | None = None
) -> SimpleNamespace:
    """Serve a fake Slack MCP tool call from the fixture.

    Only ``slack_search_channels`` is implemented; it returns an object whose
    ``content`` is a one-element list with a ``text`` summary of the fixture
    channel and its first message.

    Raises:
        ValueError: If a non-empty ``query`` does not contain the fixture
            channel name (case-insensitive).
        NotImplementedError: For every other tool name.
    """
    params = arguments or {}
    # The fixture is read before dispatching on the tool name.
    channel = _FIXTURE["channel"]
    first_message = _FIXTURE["messages"][0]

    if tool_name != "slack_search_channels":
        if tool_name in {"slack_read_channel", "slack_read_thread"}:
            raise NotImplementedError(
                f"Slack E2E fake does not exercise {tool_name!r}; "
                "extend slack_module.py before using it in a journey."
            )
        raise NotImplementedError(f"Unexpected Slack MCP tool call: {tool_name!r}")

    query = str(params.get("query", ""))
    if query and channel["name"].lower() not in query.lower():
        raise ValueError(f"Unexpected Slack channel query: {query!r}")

    summary_lines = (
        f"#{channel['name']} ({channel['id']})\n",
        f"purpose: {channel['purpose']}\n",
        f"latest_message: {first_message['text']}",
    )
    return SimpleNamespace(content=[SimpleNamespace(text="".join(summary_lines))])
|
||||
|
||||
|
||||
async def _fake_exchange_code_for_tokens(
    token_endpoint: str,
    code: str,
    redirect_uri: str,
    client_id: str,
    client_secret: str,
    code_verifier: str,
    *,
    timeout: float = 30.0,
) -> dict[str, Any]:
    """Fake the OAuth code exchange, answering Slack's endpoint locally.

    Requests for any other token endpoint are delegated unchanged to
    ``mcp_oauth_runtime._fake_exchange_code_for_tokens``.

    Returns:
        A token payload with a top-level ``refresh_token`` and the access
        token nested under ``authed_user``.

    Raises:
        ValueError: If the code, redirect URI, client credentials or
            code verifier are not the expected fake values.
    """
    if token_endpoint != _TOKEN_URL:
        return await mcp_oauth_runtime._fake_exchange_code_for_tokens(
            token_endpoint,
            code,
            redirect_uri,
            client_id,
            client_secret,
            code_verifier,
            timeout=timeout,
        )
    # The timeout is irrelevant for the in-process Slack answer.
    del timeout

    if code != _OAUTH_CODE:
        raise ValueError(f"Unexpected fake Slack OAuth code: {code!r}")
    if "/api/v1/auth/mcp/slack/connector/callback" not in redirect_uri:
        raise ValueError(f"Unexpected Slack redirect_uri={redirect_uri!r}")
    credentials_match = client_id == _CLIENT_ID and client_secret == _CLIENT_SECRET
    if not credentials_match:
        raise ValueError(
            "Unexpected Slack client credentials: "
            f"client_id={client_id!r} client_secret={client_secret!r}"
        )
    if not code_verifier:
        raise ValueError("Slack token exchange missing code_verifier.")

    team = _FIXTURE["team"]
    authed_user = {
        "id": "U_FAKE_SLACK_USER",
        "scope": _SCOPE,
        "access_token": _ACCESS_TOKEN,
        "refresh_token": _REFRESH_TOKEN,
        "expires_in": 3600,
        "token_type": "Bearer",
    }
    return {
        "ok": True,
        "scope": _SCOPE,
        "token_type": "Bearer",
        "expires_in": 3600,
        "refresh_token": _REFRESH_TOKEN,
        "authed_user": authed_user,
        "team": {"id": team["id"], "name": team["name"]},
    }
|
||||
|
||||
|
||||
def install(active_patches: list[Any]) -> None:
    """Register Slack MCP OAuth/tool handlers with the shared dispatchers.

    Also patches ``app.services.mcp_oauth.discovery.exchange_code_for_tokens``
    with the Slack-aware fake; that patcher is started and appended to
    ``active_patches``.
    """
    callback_path = "/api/v1/auth/mcp/slack/connector/callback"
    discovery_metadata = {
        "issuer": "https://slack.com",
        "authorization_endpoint": _AUTHORIZATION_URL,
        "token_endpoint": _TOKEN_URL,
        "registration_endpoint": _REGISTRATION_URL,
        "code_challenge_methods_supported": ["S256"],
        "grant_types_supported": ["authorization_code", "refresh_token"],
        "response_types_supported": ["code"],
    }
    mcp_oauth_runtime.register_service(
        mcp_url=_MCP_URL,
        discovery_metadata=discovery_metadata,
        client_id=_CLIENT_ID,
        client_secret=_CLIENT_SECRET,
        token_endpoint=_TOKEN_URL,
        registration_endpoint=_REGISTRATION_URL,
        oauth_code=_OAUTH_CODE,
        access_token=_ACCESS_TOKEN,
        refresh_token=_REFRESH_TOKEN,
        scope=_SCOPE,
        redirect_uri_substring=callback_path,
    )
    mcp_runtime.register(
        url=_MCP_URL,
        expected_bearer=_ACCESS_TOKEN,
        list_tools=_list_tools,
        call_tool=_call_tool,
    )

    exchange_patcher = patch(
        "app.services.mcp_oauth.discovery.exchange_code_for_tokens",
        _fake_exchange_code_for_tokens,
    )
    exchange_patcher.start()
    active_patches.append(exchange_patcher)
|
||||
4
surfsense_backend/tests/e2e/middleware/__init__.py
Normal file
4
surfsense_backend/tests/e2e/middleware/__init__.py
Normal file
|
|
@ -0,0 +1,4 @@
|
|||
"""Test-only middleware. Mounted on the FastAPI `app` object inside
|
||||
`tests/e2e/run_backend.py`, never registered by production startup
|
||||
(`python main.py`).
|
||||
"""
|
||||
54
surfsense_backend/tests/e2e/middleware/scenario.py
Normal file
54
surfsense_backend/tests/e2e/middleware/scenario.py
Normal file
|
|
@ -0,0 +1,54 @@
|
|||
"""X-E2E-Scenario middleware.
|
||||
|
||||
Reads the X-E2E-Scenario request header and pipes the value into a
|
||||
ContextVar that the strict fakes consult to switch between happy-path
|
||||
and error scenarios on a per-request basis.
|
||||
|
||||
Mounted by tests/e2e/run_backend.py only. Production never adds this
|
||||
middleware, so production never reads the header.
|
||||
|
||||
Supported scenarios:
|
||||
- "happy" (default): everything succeeds with deterministic fixtures.
|
||||
- "denied": Composio.connected_accounts.initiate returns a redirect URL
|
||||
pointing at our callback with ?error=access_denied.
|
||||
- "auth_expired": GOOGLEDRIVE_LIST_FILES returns an authentication
|
||||
failure that the route translates to connector.config.auth_expired.
|
||||
- "duplicate": no special fake behavior; the duplicate path is exercised
|
||||
by running the OAuth flow twice with the same toolkit.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from contextvars import ContextVar
|
||||
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import Response
|
||||
from starlette.types import ASGIApp
|
||||
|
||||
_scenario: ContextVar[str] = ContextVar("e2e_scenario", default="happy")
|
||||
|
||||
|
||||
def current_scenario() -> str:
    """Return the active E2E scenario for the current request context.

    Reads the module-level ``_scenario`` ContextVar, which defaults to
    ``"happy"`` when nothing has been set in the current context.
    """
    return _scenario.get()
|
||||
|
||||
|
||||
class ScenarioMiddleware(BaseHTTPMiddleware):
    """Publish the ``X-E2E-Scenario`` request header through a ContextVar.

    The value is also stored on ``request.state.e2e_scenario`` so route
    handlers can branch on it if they ever need to (Composio routes do not).
    """

    def __init__(self, app: ASGIApp) -> None:
        super().__init__(app)

    async def dispatch(self, request: Request, call_next) -> Response:
        """Set the scenario for this request and restore the old value afterwards.

        A missing header means ``"happy"``.
        """
        scenario = request.headers.get("X-E2E-Scenario", "happy")
        reset_token = _scenario.set(scenario)
        try:
            request.state.e2e_scenario = scenario
            response = await call_next(request)
            return response
        finally:
            # Restore the previous value even when the downstream app raises.
            _scenario.reset(reset_token)
|
||||
247
surfsense_backend/tests/e2e/run_backend.py
Normal file
247
surfsense_backend/tests/e2e/run_backend.py
Normal file
|
|
@ -0,0 +1,247 @@
|
|||
"""E2E backend entrypoint.
|
||||
|
||||
Hijacks third-party SDKs at sys.modules level BEFORE any production
|
||||
code is imported, then starts the same FastAPI app + uvicorn that
|
||||
`main.py` would run.
|
||||
|
||||
Production code is byte-identical with or without this file:
|
||||
- `python main.py` is the production entrypoint (unchanged).
|
||||
- `python tests/e2e/run_backend.py` is the test entrypoint, never imported by production.
|
||||
- `surfsense_backend/.dockerignore` excludes `tests/`, so this file
|
||||
physically does not exist in the production Docker image.
|
||||
|
||||
Defense in depth (see Composio Drive E2E Phase 1 plan):
|
||||
1. sys.modules hijack here (Composio, notion_client).
|
||||
2. Strict __getattr__ inside fakes (NotImplementedError on unknown surface).
|
||||
3. Network deny-list set in CI env (HTTPS_PROXY=http://127.0.0.1:1
|
||||
plus sentinel API keys) so any leaked outbound HTTP fails loudly.
|
||||
|
||||
Usage:
|
||||
cd surfsense_backend
|
||||
uv run python tests/e2e/run_backend.py
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 1) Hijack sys.modules BEFORE any production import.
|
||||
# Production: composio_service.py:11 does `from composio import Composio`.
|
||||
# With this hijack in place, that import resolves to our strict fake.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Make the surfsense_backend root importable as a top-level package so
|
||||
# `import tests.e2e.fakes...` works regardless of how the entrypoint is
|
||||
# invoked (uv run python tests/e2e/run_backend.py from repo root or from
|
||||
# surfsense_backend/).
|
||||
_THIS_DIR = os.path.dirname(os.path.abspath(__file__))
|
||||
_BACKEND_ROOT = os.path.abspath(os.path.join(_THIS_DIR, "..", ".."))
|
||||
if _BACKEND_ROOT not in sys.path:
|
||||
sys.path.insert(0, _BACKEND_ROOT)
|
||||
|
||||
import tests.e2e.fakes.composio_module as _fake_composio # noqa: E402
|
||||
import tests.e2e.fakes.notion_module as _fake_notion # noqa: E402
|
||||
|
||||
sys.modules["composio"] = _fake_composio
|
||||
sys.modules["notion_client"] = _fake_notion
|
||||
sys.modules["notion_client.errors"] = _fake_notion.errors
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 2) Standard logging + dotenv so the rest of the app behaves like main.py.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
from dotenv import load_dotenv # noqa: E402
|
||||
|
||||
load_dotenv()
|
||||
os.environ.setdefault("ATLASSIAN_CLIENT_ID", "fake-atlassian-client-id")
|
||||
os.environ.setdefault("ATLASSIAN_CLIENT_SECRET", "fake-atlassian-client-secret")
|
||||
os.environ.setdefault(
|
||||
"CONFLUENCE_REDIRECT_URI",
|
||||
"http://localhost:8000/api/v1/auth/confluence/connector/callback",
|
||||
)
|
||||
os.environ.setdefault("NOTION_CLIENT_ID", "fake-notion-client-id")
|
||||
os.environ.setdefault("NOTION_CLIENT_SECRET", "fake-notion-client-secret")
|
||||
os.environ.setdefault(
|
||||
"NOTION_REDIRECT_URI",
|
||||
"http://localhost:8000/api/v1/auth/notion/connector/callback",
|
||||
)
|
||||
os.environ.setdefault("MICROSOFT_CLIENT_ID", "fake-microsoft-client-id")
|
||||
os.environ.setdefault("MICROSOFT_CLIENT_SECRET", "fake-microsoft-client-secret")
|
||||
os.environ.setdefault(
|
||||
"ONEDRIVE_REDIRECT_URI",
|
||||
"http://localhost:8000/api/v1/auth/onedrive/connector/callback",
|
||||
)
|
||||
os.environ.setdefault("DROPBOX_APP_KEY", "fake-dropbox-app-key")
|
||||
os.environ.setdefault("DROPBOX_APP_SECRET", "fake-dropbox-app-secret")
|
||||
os.environ.setdefault(
|
||||
"DROPBOX_REDIRECT_URI",
|
||||
"http://localhost:8000/api/v1/auth/dropbox/connector/callback",
|
||||
)
|
||||
os.environ["SLACK_CLIENT_ID"] = "fake-slack-mcp-client-id"
|
||||
os.environ["SLACK_CLIENT_SECRET"] = "fake-slack-mcp-client-secret"
|
||||
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
|
||||
datefmt="%Y-%m-%d %H:%M:%S",
|
||||
)
|
||||
logger = logging.getLogger("surfsense.e2e.backend")
|
||||
logger.warning(
|
||||
"*** SURFSENSE E2E BACKEND ENTRYPOINT — fake Composio + LLM + embeddings ***"
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 3) Now import the production app. Every module in app.* loads here,
|
||||
# creating their bindings (some of which we will patch in step 4).
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 4) Patch LLM + embedding bindings at every consumer site.
|
||||
# Composio is already covered by the sys.modules hijack in step 1.
|
||||
# ---------------------------------------------------------------------------
|
||||
from unittest.mock import patch # noqa: E402
|
||||
|
||||
from app.app import app # noqa: E402
|
||||
from tests.e2e.fakes import ( # noqa: E402
|
||||
clickup_module as _fake_clickup_module,
|
||||
confluence_indexer as _fake_confluence_indexer,
|
||||
confluence_oauth as _fake_confluence_oauth,
|
||||
dropbox_api as _fake_dropbox_api,
|
||||
embeddings as _fake_embeddings,
|
||||
jira_module as _fake_jira_module,
|
||||
linear_module as _fake_linear_module,
|
||||
mcp_oauth_runtime as _fake_mcp_oauth_runtime,
|
||||
mcp_runtime as _fake_mcp_runtime,
|
||||
native_google as _fake_native_google,
|
||||
notion_module as _fake_notion_module,
|
||||
onedrive_graph as _fake_onedrive_graph,
|
||||
slack_module as _fake_slack_module,
|
||||
)
|
||||
from tests.e2e.fakes.chat_llm import ( # noqa: E402
|
||||
fake_create_chat_litellm_from_agent_config,
|
||||
fake_create_chat_litellm_from_config,
|
||||
)
|
||||
from tests.e2e.fakes.llm import fake_get_user_long_context_llm # noqa: E402
|
||||
|
||||
_active_patches: list = []
|
||||
|
||||
|
||||
def _patch_llm_bindings() -> None:
    """Replace LLM factories at every known binding site.

    Each patcher that starts successfully is appended to ``_active_patches``.
    A site that cannot be patched (missing module or attribute) is logged as
    a warning and skipped.
    """
    long_context_modules = (
        "app.services.llm_service",
        "app.tasks.connector_indexers.confluence_indexer",
        "app.tasks.connector_indexers.google_drive_indexer",
        "app.tasks.connector_indexers.google_gmail_indexer",
        "app.tasks.connector_indexers.notion_indexer",
        "app.tasks.connector_indexers.onedrive_indexer",
        "app.tasks.connector_indexers.dropbox_indexer",
        "app.tasks.connector_indexers.local_folder_indexer",
        "app.tasks.document_processors._save",
        "app.tasks.document_processors.markdown_processor",
    )
    for module_path in long_context_modules:
        site = f"{module_path}.get_user_long_context_llm"
        try:
            patcher = patch(site, fake_get_user_long_context_llm)
            patcher.start()
            _active_patches.append(patcher)
            logger.info("[fake-llm] patched %s", site)
        except (ModuleNotFoundError, AttributeError) as exc:
            # Some indexers may not be loaded in every env. Log and move
            # on — but do not silently let a known binding through.
            logger.warning(
                "[fake-llm] could not patch %s: %s. If production code "
                "uses this path in E2E it will hit the real provider; "
                "update tests/e2e/run_backend.py.",
                site,
                exc,
            )

    # The chat factories are bound in two modules; patch both names in each.
    chat_modules = (
        "app.agents.new_chat.llm_config",
        "app.tasks.chat.stream_new_chat",
    )
    chat_factories = (
        (
            "create_chat_litellm_from_agent_config",
            fake_create_chat_litellm_from_agent_config,
        ),
        ("create_chat_litellm_from_config", fake_create_chat_litellm_from_config),
    )
    for module_path in chat_modules:
        for factory_name, fake_factory in chat_factories:
            site = f"{module_path}.{factory_name}"
            try:
                patcher = patch(site, fake_factory)
                patcher.start()
                _active_patches.append(patcher)
                logger.info("[fake-chat-llm] patched %s", site)
            except (ModuleNotFoundError, AttributeError) as exc:
                logger.warning("[fake-chat-llm] could not patch %s: %s.", site, exc)
|
||||
|
||||
|
||||
_patch_llm_bindings()
|
||||
_fake_embeddings.install(_active_patches)
|
||||
_fake_confluence_oauth.install(_active_patches)
|
||||
_fake_confluence_indexer.install(_active_patches)
|
||||
_fake_native_google.install(_active_patches)
|
||||
_fake_onedrive_graph.install(_active_patches)
|
||||
_fake_dropbox_api.install(_active_patches)
|
||||
_fake_notion_module.install(_active_patches)
|
||||
_fake_linear_module.install(_active_patches)
|
||||
_fake_jira_module.install(_active_patches)
|
||||
_fake_clickup_module.install(_active_patches)
|
||||
_fake_mcp_runtime.install(_active_patches)
|
||||
_fake_mcp_oauth_runtime.install(_active_patches)
|
||||
_fake_slack_module.install(_active_patches)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 5) Mount test-only middleware. Production never reaches this code.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
from tests.e2e.middleware.scenario import ScenarioMiddleware # noqa: E402
|
||||
|
||||
app.add_middleware(ScenarioMiddleware)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 6) Start uvicorn, mirroring main.py's behaviour.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
import asyncio # noqa: E402
|
||||
|
||||
import uvicorn # noqa: E402
|
||||
|
||||
|
||||
def _main() -> None:
    """Run the patched FastAPI app under uvicorn.

    Host, port and log level come from ``UVICORN_HOST``, ``UVICORN_PORT`` and
    ``UVICORN_LOG_LEVEL`` (defaults ``0.0.0.0``, ``8000`` and ``info``).
    Auto-reload is disabled.
    """
    if sys.platform == "win32":
        # NOTE(review): presumably mirrors main.py's Windows handling — confirm.
        asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())

    settings = uvicorn.Config(
        app=app,
        host=os.environ.get("UVICORN_HOST", "0.0.0.0"),
        port=int(os.environ.get("UVICORN_PORT", "8000")),
        log_level=os.environ.get("UVICORN_LOG_LEVEL", "info"),
        reload=False,
    )
    uvicorn.Server(settings).run()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
_main()
|
||||
214
surfsense_backend/tests/e2e/run_celery.py
Normal file
214
surfsense_backend/tests/e2e/run_celery.py
Normal file
|
|
@ -0,0 +1,214 @@
|
|||
"""E2E Celery worker entrypoint.
|
||||
|
||||
Same sys.modules hijack + LLM/embedding patches as run_backend.py,
|
||||
applied before importing the production celery_app. Celery workers
|
||||
run in a separate Python interpreter, so the patches must be applied
|
||||
here too — they would NOT carry over from the FastAPI process.
|
||||
|
||||
Production is unaffected: celery_worker.py at the repo root is the
|
||||
production entrypoint and never imports this file.
|
||||
|
||||
Usage:
|
||||
cd surfsense_backend
|
||||
uv run python tests/e2e/run_celery.py
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
|
||||
_THIS_DIR = os.path.dirname(os.path.abspath(__file__))
|
||||
_BACKEND_ROOT = os.path.abspath(os.path.join(_THIS_DIR, "..", ".."))
|
||||
if _BACKEND_ROOT not in sys.path:
|
||||
sys.path.insert(0, _BACKEND_ROOT)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 1) Hijack sys.modules BEFORE production celery imports anything.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
import tests.e2e.fakes.composio_module as _fake_composio # noqa: E402
|
||||
import tests.e2e.fakes.notion_module as _fake_notion # noqa: E402
|
||||
|
||||
sys.modules["composio"] = _fake_composio
|
||||
sys.modules["notion_client"] = _fake_notion
|
||||
sys.modules["notion_client.errors"] = _fake_notion.errors
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 2) Logging + dotenv.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
from dotenv import load_dotenv # noqa: E402
|
||||
|
||||
load_dotenv()

# Fake OAuth client settings. ``setdefault`` only fills in a variable that is
# not already present in the environment at this point.
for _env_name, _env_default in (
    ("ATLASSIAN_CLIENT_ID", "fake-atlassian-client-id"),
    ("ATLASSIAN_CLIENT_SECRET", "fake-atlassian-client-secret"),
    (
        "CONFLUENCE_REDIRECT_URI",
        "http://localhost:8000/api/v1/auth/confluence/connector/callback",
    ),
    ("NOTION_CLIENT_ID", "fake-notion-client-id"),
    ("NOTION_CLIENT_SECRET", "fake-notion-client-secret"),
    (
        "NOTION_REDIRECT_URI",
        "http://localhost:8000/api/v1/auth/notion/connector/callback",
    ),
    ("MICROSOFT_CLIENT_ID", "fake-microsoft-client-id"),
    ("MICROSOFT_CLIENT_SECRET", "fake-microsoft-client-secret"),
    (
        "ONEDRIVE_REDIRECT_URI",
        "http://localhost:8000/api/v1/auth/onedrive/connector/callback",
    ),
    ("DROPBOX_APP_KEY", "fake-dropbox-app-key"),
    ("DROPBOX_APP_SECRET", "fake-dropbox-app-secret"),
    (
        "DROPBOX_REDIRECT_URI",
        "http://localhost:8000/api/v1/auth/dropbox/connector/callback",
    ),
):
    os.environ.setdefault(_env_name, _env_default)

# Unlike the defaults above, the Slack credentials are assigned
# unconditionally and replace any value already in the environment.
# NOTE(review): presumably deliberate so real Slack credentials are never
# used by the E2E worker — confirm.
os.environ["SLACK_CLIENT_ID"] = "fake-slack-mcp-client-id"
os.environ["SLACK_CLIENT_SECRET"] = "fake-slack-mcp-client-secret"
|
||||
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
|
||||
datefmt="%Y-%m-%d %H:%M:%S",
|
||||
)
|
||||
logger = logging.getLogger("surfsense.e2e.celery")
|
||||
logger.warning("*** SURFSENSE E2E CELERY WORKER — fake Composio + LLM + embeddings ***")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 3) Import the production celery_app (the import statement itself sits with the section 4 imports below). All task modules load here.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 4) Patch LLM + embedding bindings inside the worker process.
|
||||
# ---------------------------------------------------------------------------
|
||||
from unittest.mock import patch # noqa: E402
|
||||
|
||||
from app.celery_app import celery_app # noqa: E402
|
||||
from tests.e2e.fakes import ( # noqa: E402
|
||||
clickup_module as _fake_clickup_module,
|
||||
confluence_indexer as _fake_confluence_indexer,
|
||||
confluence_oauth as _fake_confluence_oauth,
|
||||
dropbox_api as _fake_dropbox_api,
|
||||
embeddings as _fake_embeddings,
|
||||
jira_module as _fake_jira_module,
|
||||
linear_module as _fake_linear_module,
|
||||
mcp_oauth_runtime as _fake_mcp_oauth_runtime,
|
||||
mcp_runtime as _fake_mcp_runtime,
|
||||
native_google as _fake_native_google,
|
||||
notion_module as _fake_notion_module,
|
||||
onedrive_graph as _fake_onedrive_graph,
|
||||
slack_module as _fake_slack_module,
|
||||
)
|
||||
from tests.e2e.fakes.chat_llm import ( # noqa: E402
|
||||
fake_create_chat_litellm_from_agent_config,
|
||||
fake_create_chat_litellm_from_config,
|
||||
)
|
||||
from tests.e2e.fakes.llm import fake_get_user_long_context_llm # noqa: E402
|
||||
|
||||
_active_patches: list = []
|
||||
|
||||
|
||||
def _start_patch(target: str, replacement: object, ok_msg: str, fail_msg: str) -> None:
    """Start one ``patch`` and record it in ``_active_patches``.

    A target that raises ``ModuleNotFoundError`` or ``AttributeError`` is
    logged with ``fail_msg`` as a warning and skipped.
    """
    try:
        patcher = patch(target, replacement)
        patcher.start()
        _active_patches.append(patcher)
        logger.info(ok_msg, target)
    except (ModuleNotFoundError, AttributeError) as exc:
        logger.warning(fail_msg, target, exc)


def _patch_llm_bindings() -> None:
    """Swap the LLM entry points for the E2E fakes inside this worker process.

    Every listed ``get_user_long_context_llm`` binding is replaced with
    ``fake_get_user_long_context_llm`` first; the chat LiteLLM factories are
    replaced afterwards. Targets that cannot be patched are only logged.
    """
    long_context_targets = (
        "app.services.llm_service.get_user_long_context_llm",
        "app.tasks.connector_indexers.confluence_indexer.get_user_long_context_llm",
        "app.tasks.connector_indexers.google_drive_indexer.get_user_long_context_llm",
        "app.tasks.connector_indexers.google_gmail_indexer.get_user_long_context_llm",
        "app.tasks.connector_indexers.notion_indexer.get_user_long_context_llm",
        "app.tasks.connector_indexers.onedrive_indexer.get_user_long_context_llm",
        "app.tasks.connector_indexers.dropbox_indexer.get_user_long_context_llm",
        "app.tasks.connector_indexers.local_folder_indexer.get_user_long_context_llm",
        "app.tasks.document_processors._save.get_user_long_context_llm",
        "app.tasks.document_processors.markdown_processor.get_user_long_context_llm",
    )
    for dotted_path in long_context_targets:
        _start_patch(
            dotted_path,
            fake_get_user_long_context_llm,
            "[fake-llm] patched %s in celery worker",
            "[fake-llm] could not patch %s in celery worker: %s.",
        )

    # Insertion order of the mapping is the order the patches are started in.
    chat_factories = {
        "app.agents.new_chat.llm_config.create_chat_litellm_from_agent_config": (
            fake_create_chat_litellm_from_agent_config
        ),
        "app.agents.new_chat.llm_config.create_chat_litellm_from_config": (
            fake_create_chat_litellm_from_config
        ),
        "app.tasks.chat.stream_new_chat.create_chat_litellm_from_agent_config": (
            fake_create_chat_litellm_from_agent_config
        ),
        "app.tasks.chat.stream_new_chat.create_chat_litellm_from_config": (
            fake_create_chat_litellm_from_config
        ),
    }
    for dotted_path, fake_factory in chat_factories.items():
        _start_patch(
            dotted_path,
            fake_factory,
            "[fake-chat-llm] patched %s in celery worker",
            "[fake-chat-llm] could not patch %s in celery worker: %s.",
        )
|
||||
|
||||
|
||||
_patch_llm_bindings()
|
||||
_fake_embeddings.install(_active_patches)
|
||||
_fake_confluence_oauth.install(_active_patches)
|
||||
_fake_confluence_indexer.install(_active_patches)
|
||||
_fake_native_google.install(_active_patches)
|
||||
_fake_onedrive_graph.install(_active_patches)
|
||||
_fake_dropbox_api.install(_active_patches)
|
||||
_fake_notion_module.install(_active_patches)
|
||||
_fake_linear_module.install(_active_patches)
|
||||
_fake_jira_module.install(_active_patches)
|
||||
_fake_clickup_module.install(_active_patches)
|
||||
_fake_mcp_runtime.install(_active_patches)
|
||||
_fake_mcp_oauth_runtime.install(_active_patches)
|
||||
_fake_slack_module.install(_active_patches)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 5) Start the worker.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _main() -> None:
    """Start a Celery worker in-process through ``celery_app.worker_main``.

    The base queue name comes from ``CELERY_TASK_DEFAULT_QUEUE`` (default
    ``surfsense``); the worker consumes that queue and its ``.connectors``
    sibling with a concurrency of 2 and gossip/mingle disabled.
    """
    # Default queues mirror production (default queue + connectors queue
    # so Drive indexing tasks are picked up).
    queue_name = os.getenv("CELERY_TASK_DEFAULT_QUEUE", "surfsense")
    queues = f"{queue_name},{queue_name}.connectors"

    worker_argv = ["worker", "--loglevel=info", f"--queues={queues}"]
    worker_argv += ["--concurrency=2", "--without-gossip", "--without-mingle"]
    celery_app.worker_main(argv=worker_argv)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
_main()
|
||||
|
|
@ -367,18 +367,26 @@ class TestPersistUserTurn:
|
|||
db_thread,
|
||||
patched_shielded_session,
|
||||
):
|
||||
"""The full ``{id, title, document_type}`` triple forwarded by
|
||||
the FE must round-trip into a single ``mentioned-documents``
|
||||
ContentPart on the persisted user message — the history loader
|
||||
renders the chips on reload from this part directly.
|
||||
"""The full ``{id, title, document_type, kind}`` chip metadata
|
||||
forwarded by the FE must round-trip into a single
|
||||
``mentioned-documents`` ContentPart on the persisted user
|
||||
message — the history loader renders the chips on reload from
|
||||
this part directly. Folder chips ride alongside doc chips so
|
||||
the FE can render mixed mention bars without a second fetch.
|
||||
"""
|
||||
thread_id = db_thread.id
|
||||
user_id_str = str(db_user.id)
|
||||
turn_id = f"{thread_id}:8200"
|
||||
|
||||
mentioned = [
|
||||
{"id": 11, "title": "Alpha", "document_type": "GENERAL"},
|
||||
{"id": 22, "title": "Beta", "document_type": "GENERAL"},
|
||||
{"id": 11, "title": "Alpha", "document_type": "GENERAL", "kind": "doc"},
|
||||
{"id": 22, "title": "Beta", "document_type": "GENERAL", "kind": "doc"},
|
||||
{
|
||||
"id": 33,
|
||||
"title": "Reports",
|
||||
"document_type": "FOLDER",
|
||||
"kind": "folder",
|
||||
},
|
||||
]
|
||||
msg_id = await persist_user_turn(
|
||||
chat_id=thread_id,
|
||||
|
|
@ -397,8 +405,61 @@ class TestPersistUserTurn:
|
|||
assert row.content[1] == {
|
||||
"type": "mentioned-documents",
|
||||
"documents": [
|
||||
{"id": 11, "title": "Alpha", "document_type": "GENERAL"},
|
||||
{"id": 22, "title": "Beta", "document_type": "GENERAL"},
|
||||
{"id": 11, "title": "Alpha", "document_type": "GENERAL", "kind": "doc"},
|
||||
{"id": 22, "title": "Beta", "document_type": "GENERAL", "kind": "doc"},
|
||||
{
|
||||
"id": 33,
|
||||
"title": "Reports",
|
||||
"document_type": "FOLDER",
|
||||
"kind": "folder",
|
||||
},
|
||||
],
|
||||
}
|
||||
|
||||
async def test_legacy_chip_without_kind_defaults_to_doc(
    self,
    db_session,
    db_user,
    db_thread,
    patched_shielded_session,
):
    """A mention chip sent without ``kind`` is persisted with ``kind="doc"``.

    The turn is stored through ``persist_user_turn`` and the resulting
    ``mentioned-documents`` content part is compared against the chip with
    ``"kind": "doc"`` added, so the stored row is self-describing.
    """
    thread_id = db_thread.id
    owner_id = str(db_user.id)
    turn_id = f"{thread_id}:8201"

    # Legacy payload shape: no "kind" key on the chip.
    legacy_chips = [{"id": 77, "title": "Legacy", "document_type": "GENERAL"}]

    msg_id = await persist_user_turn(
        chat_id=thread_id,
        user_id=owner_id,
        turn_id=turn_id,
        user_query="hi",
        mentioned_documents=legacy_chips,
    )
    assert isinstance(msg_id, int)

    stored = await db_session.get(NewChatMessage, msg_id)
    assert stored is not None
    assert isinstance(stored.content, list)

    mention_part = next(
        part for part in stored.content if part.get("type") == "mentioned-documents"
    )
    expected_chip = {
        "id": 77,
        "title": "Legacy",
        "document_type": "GENERAL",
        "kind": "doc",
    }
    assert mention_part == {
        "type": "mentioned-documents",
        "documents": [expected_chip],
    }
|
||||
|
||||
|
|
|
|||
1
surfsense_backend/tests/integration/composio/__init__.py
Normal file
1
surfsense_backend/tests/integration/composio/__init__.py
Normal file
|
|
@ -0,0 +1 @@
|
|||
"""Integration tests for Composio connector routes."""
|
||||
90
surfsense_backend/tests/integration/composio/conftest.py
Normal file
90
surfsense_backend/tests/integration/composio/conftest.py
Normal file
|
|
@ -0,0 +1,90 @@
|
|||
"""Composio route integration fixtures.
|
||||
|
||||
The sys.modules hijack happens at module import time, before importing
|
||||
app.app, so production `from composio import Composio` bindings resolve to
|
||||
the strict E2E fake in this pytest process too.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
from collections.abc import AsyncGenerator
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from httpx import ASGITransport
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from tests.e2e.fakes import composio_module as _fake_composio
|
||||
|
||||
sys.modules["composio"] = _fake_composio
|
||||
|
||||
from app.app import app, limiter # noqa: E402
|
||||
from app.config import config # noqa: E402
|
||||
from app.db import ( # noqa: E402
|
||||
SearchSourceConnector,
|
||||
SearchSourceConnectorType,
|
||||
User,
|
||||
get_async_session,
|
||||
)
|
||||
from app.users import current_active_user # noqa: E402
|
||||
|
||||
pytestmark = pytest.mark.integration
|
||||
|
||||
limiter.enabled = False
|
||||
config.COMPOSIO_ENABLED = True
|
||||
config.COMPOSIO_API_KEY = "e2e-integration-composio-sentinel"
|
||||
config.NEXT_FRONTEND_URL = "http://localhost:3000"
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
async def client(
    db_session: AsyncSession,
    db_user: User,
) -> AsyncGenerator[httpx.AsyncClient, None]:
    """Yield an in-process HTTP client bound to the test session and user.

    ``get_async_session`` and ``current_active_user`` are overridden on the
    app for the duration of the fixture; whatever overrides existed before
    are put back on teardown.
    """

    async def _session_override() -> AsyncGenerator[AsyncSession, None]:
        yield db_session

    async def _user_override() -> User:
        return db_user

    overrides = app.dependency_overrides
    saved_overrides = overrides.copy()
    overrides[get_async_session] = _session_override
    overrides[current_active_user] = _user_override

    try:
        http_client = httpx.AsyncClient(
            transport=ASGITransport(app=app),
            base_url="http://test",
            timeout=30.0,
            follow_redirects=False,
        )
        async with http_client as test_client:
            yield test_client
    finally:
        # Restore exactly the overrides that were installed before this fixture.
        overrides.clear()
        overrides.update(saved_overrides)
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
async def drive_connector(
    db_session: AsyncSession,
    db_user: User,
    db_search_space,
) -> SearchSourceConnector:
    """Add and flush an indexable Composio Google Drive connector row.

    The row belongs to ``db_user`` inside ``db_search_space`` and carries a
    fixed fake connected-account id in its config.
    """
    composio_config = {
        "composio_connected_account_id": "fake-acct-googledrive-existing",
        "toolkit_id": "googledrive",
        "toolkit_name": "Google Drive",
        "is_indexable": True,
    }
    row = SearchSourceConnector(
        name="Google Drive (Composio) - e2e-fake@surfsense.example",
        connector_type=SearchSourceConnectorType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR,
        is_indexable=True,
        config=composio_config,
        search_space_id=db_search_space.id,
        user_id=db_user.id,
    )
    db_session.add(row)
    # Flush (not commit) so the row gets its primary key within the test session.
    await db_session.flush()
    return row
|
||||
|
|
@ -0,0 +1,99 @@
|
|||
from typing import Any
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.db import SearchSourceConnector
|
||||
from tests.e2e.fakes import composio_module
|
||||
|
||||
pytestmark = pytest.mark.integration
|
||||
|
||||
|
||||
async def test_root_listing_returns_canned_items(
    client: httpx.AsyncClient,
    drive_connector: SearchSourceConnector,
):
    """The folders endpoint returns the canned root entries for the connector."""
    folders_url = f"/api/v1/connectors/{drive_connector.id}/composio-drive/folders"
    response = await client.get(folders_url)

    assert response.status_code == 200
    listing = response.json()["items"]
    seen_names = {entry["name"] for entry in listing}

    assert "Projects" in seen_names
    assert "e2e-canary.txt" in seen_names
    # The "Projects" entry must be reported as a folder under its fake id.
    assert any(
        entry["id"] == "fake-folder-projects" and entry["isFolder"] is True
        for entry in listing
    )
|
||||
|
||||
|
||||
async def test_save_round_trips_selected_files(
    client: httpx.AsyncClient,
    db_session: AsyncSession,
    drive_connector: SearchSourceConnector,
):
    """A PUT with selected files is reflected in the connector's stored config."""
    selected_files = [
        {
            "id": "fake-file-canary",
            "name": "e2e-canary.txt",
            "mimeType": "text/plain",
        }
    ]
    indexing_options = {
        "max_files_per_folder": 10,
        "incremental_sync": False,
        "include_subfolders": False,
    }
    payload = {
        "config": {
            "selected_folders": [],
            "selected_files": selected_files,
            "indexing_options": indexing_options,
        }
    }

    response = await client.put(
        f"/api/v1/search-source-connectors/{drive_connector.id}",
        json=payload,
    )
    assert response.status_code == 200

    # Reload the row so the assertions read what the route stored.
    await db_session.refresh(drive_connector)
    assert drive_connector.config["selected_files"] == selected_files
    assert drive_connector.config["selected_folders"] == []
|
||||
|
||||
|
||||
async def test_auth_expired_error_classifies_and_flags_connector(
    monkeypatch: pytest.MonkeyPatch,
    client: httpx.AsyncClient,
    db_session: AsyncSession,
    drive_connector: SearchSourceConnector,
):
    """An expired-token failure yields a 400 and sets ``auth_expired``.

    The fake ``_Tools.execute`` is replaced with a function that always
    raises an ``invalid_grant`` style ``RuntimeError``; the folders endpoint
    must answer 400 with a body mentioning "authentication" and "expired",
    and the refreshed connector config must have ``auth_expired`` set.
    """

    def _always_expired(
        self: Any,
        *,
        slug: str,
        connected_account_id: str,
        user_id: str | None = None,
        arguments: dict[str, Any] | None = None,
        dangerously_skip_version_check: bool = True,
        **kwargs: Any,
    ) -> dict[str, Any]:
        raise RuntimeError(
            "Token has been expired or revoked. (HTTP 401: invalid_grant)"
        )

    monkeypatch.setattr(composio_module._Tools, "execute", _always_expired)

    folders_url = f"/api/v1/connectors/{drive_connector.id}/composio-drive/folders"
    response = await client.get(folders_url)

    assert response.status_code == 400
    lowered_body = response.text.lower()
    assert "authentication" in lowered_body
    assert "expired" in lowered_body

    await db_session.refresh(drive_connector)
    assert drive_connector.config["auth_expired"] is True
|
||||
|
|
@ -0,0 +1,113 @@
|
|||
from uuid import UUID
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.config import config
|
||||
from app.db import (
|
||||
SearchSourceConnector,
|
||||
SearchSourceConnectorType,
|
||||
SearchSpace,
|
||||
User,
|
||||
)
|
||||
from app.utils.oauth_security import OAuthStateManager
|
||||
|
||||
pytestmark = pytest.mark.integration
|
||||
|
||||
|
||||
def _state_for(space_id: int, user_id: UUID, toolkit_id: str = "googledrive") -> str:
    """Return an OAuth ``state`` value generated with the app's ``SECRET_KEY``."""
    state_manager = OAuthStateManager(config.SECRET_KEY)
    return state_manager.generate_secure_state(
        space_id=space_id,
        user_id=user_id,
        toolkit_id=toolkit_id,
    )
|
||||
|
||||
|
||||
async def _drive_connectors(
    session: AsyncSession,
    *,
    user_id: UUID,
    search_space_id: int,
) -> list[SearchSourceConnector]:
    """Return the Composio Google Drive connectors of a user in one search space."""
    drive_type = SearchSourceConnectorType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR
    query = select(SearchSourceConnector).where(
        SearchSourceConnector.user_id == user_id,
        SearchSourceConnector.search_space_id == search_space_id,
        SearchSourceConnector.connector_type == drive_type,
    )
    rows = (await session.execute(query)).scalars().all()
    return list(rows)
|
||||
|
||||
|
||||
async def test_callback_with_error_param_redirects_to_denied_page(
    client: httpx.AsyncClient,
    db_session: AsyncSession,
    db_user: User,
    db_search_space: SearchSpace,
):
    """A callback carrying ``error=`` redirects to the denied page, creating nothing."""
    state = _state_for(db_search_space.id, db_user.id)
    callback_url = (
        f"/api/v1/auth/composio/connector/callback?state={state}&error=access_denied"
    )

    response = await client.get(callback_url)

    assert response.status_code in {302, 303, 307}
    redirect_target = response.headers["location"]
    denied_fragment = (
        f"/dashboard/{db_search_space.id}/connectors/callback?"
        "error=composio_oauth_denied"
    )
    assert denied_fragment in redirect_target

    # No Drive connector may exist after a denied OAuth round trip.
    remaining = await _drive_connectors(
        db_session,
        user_id=db_user.id,
        search_space_id=db_search_space.id,
    )
    assert remaining == []
|
||||
|
||||
|
||||
async def test_second_oauth_for_same_toolkit_takes_reconnection_branch(
    client: httpx.AsyncClient,
    db_session: AsyncSession,
    db_user: User,
    db_search_space: SearchSpace,
):
    """A second OAuth callback for the same toolkit updates the existing row.

    After the first callback exactly one Drive connector exists; after the
    second there is still exactly one, with the same id and the connected
    account id from the second callback.
    """
    callback_path = "/api/v1/auth/composio/connector/callback"
    redirect_codes = {302, 303, 307}
    lookup = {"user_id": db_user.id, "search_space_id": db_search_space.id}

    # First round trip: creates the connector.
    first_state = _state_for(db_search_space.id, db_user.id)
    first_response = await client.get(
        callback_path
        + f"?state={first_state}&connectedAccountId=fake-acct-googledrive-first"
    )
    assert first_response.status_code in redirect_codes

    after_first = await _drive_connectors(db_session, **lookup)
    assert len(after_first) == 1
    original = after_first[0]
    assert original.config["composio_connected_account_id"] == (
        "fake-acct-googledrive-first"
    )

    # Second round trip: must reuse the same row with the new account id.
    second_state = _state_for(db_search_space.id, db_user.id)
    second_response = await client.get(
        callback_path
        + f"?state={second_state}&connectedAccountId=fake-acct-googledrive-second"
    )
    assert second_response.status_code in redirect_codes

    after_second = await _drive_connectors(db_session, **lookup)
    assert len(after_second) == 1
    reconnected = after_second[0]
    assert reconnected.id == original.id
    assert reconnected.config["composio_connected_account_id"] == (
        "fake-acct-googledrive-second"
    )
|
||||
|
|
@ -1,13 +1,26 @@
|
|||
"""Integration tests: Drive indexer credential resolution for Composio vs native connectors.
|
||||
"""Integration tests: Drive indexer client + credential resolution.
|
||||
|
||||
Exercises ``index_google_drive_files`` with a real PostgreSQL database
|
||||
containing seeded connector records. Google API and Composio SDK are
|
||||
mocked at their system boundaries.
|
||||
Locks in the post-cea8618 architectural contract:
|
||||
|
||||
- Composio Drive connectors MUST use ``ComposioDriveClient`` (which routes
|
||||
through ``composio.tools.execute``) and MUST NOT depend on a raw OAuth
|
||||
access token via ``ComposioService.get_access_token``.
|
||||
- Native Drive connectors MUST continue to use ``GoogleDriveClient`` with
|
||||
credentials loaded from the connector config (no Composio involvement).
|
||||
- Composio Drive connectors missing ``composio_connected_account_id`` MUST
|
||||
short-circuit with a clear error before any client is constructed.
|
||||
|
||||
Background: prior to ``cea8618`` the Composio path used
|
||||
``build_composio_credentials → GoogleDriveClient``. That broke in production
|
||||
once Composio's "Mask Connected Account Secrets" project toggle was on,
|
||||
because the masked token failed the ``len(access_token) >= 20`` guard in
|
||||
``ComposioService.get_access_token``. The structural assertions here make
|
||||
any future regression to that token-based path fail at PR time.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
from unittest.mock import ANY, AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
|
|
@ -25,6 +38,19 @@ pytestmark = pytest.mark.integration
|
|||
|
||||
_COMPOSIO_ACCOUNT_ID = "composio-test-account-123"
|
||||
_INDEXER_MODULE = "app.tasks.connector_indexers.google_drive_indexer"
|
||||
_GET_ACCESS_TOKEN = "app.services.composio_service.ComposioService.get_access_token"
|
||||
|
||||
|
||||
def _mock_drive_client(*, list_files_return: tuple = ([], None, None)) -> MagicMock:
|
||||
"""Duck-typed client mock whose ``list_files`` yields the supplied tuple.
|
||||
|
||||
Returning an empty file list short-circuits the indexer's full-scan
|
||||
loop after the first page so the test exercises only the
|
||||
construction + listing path, not download / ETL / DB writes.
|
||||
"""
|
||||
mock = MagicMock()
|
||||
mock.list_files = AsyncMock(return_value=list_files_return)
|
||||
return mock
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
|
|
@ -69,25 +95,31 @@ async def committed_composio_no_account_id(async_engine):
|
|||
await cleanup_space(async_engine, data["search_space_id"])
|
||||
|
||||
|
||||
@patch(_GET_ACCESS_TOKEN)
|
||||
@patch(f"{_INDEXER_MODULE}.TaskLoggingService")
|
||||
@patch(f"{_INDEXER_MODULE}.GoogleDriveClient")
|
||||
@patch(f"{_INDEXER_MODULE}.build_composio_credentials")
|
||||
async def test_composio_connector_uses_composio_credentials(
|
||||
mock_build_creds,
|
||||
mock_client_cls,
|
||||
@patch(f"{_INDEXER_MODULE}.ComposioDriveClient")
|
||||
async def test_composio_drive_indexer_uses_composio_drive_client(
|
||||
mock_composio_client_cls,
|
||||
mock_native_client_cls,
|
||||
mock_task_logger_cls,
|
||||
mock_get_access_token,
|
||||
async_engine,
|
||||
committed_drive_connector,
|
||||
):
|
||||
"""Drive indexer calls build_composio_credentials for a Composio connector
|
||||
and passes the result to GoogleDriveClient."""
|
||||
"""Composio Drive must construct ComposioDriveClient and never read raw tokens.
|
||||
|
||||
Reverting to the pre-cea8618 ``build_composio_credentials → GoogleDriveClient``
|
||||
path would either trip ``mock_native_client_cls.assert_not_called()`` (because
|
||||
GoogleDriveClient would be constructed) or ``mock_get_access_token.assert_not_called()``
|
||||
(because the credential builder reads the raw token).
|
||||
"""
|
||||
from app.tasks.connector_indexers.google_drive_indexer import (
|
||||
index_google_drive_files,
|
||||
)
|
||||
|
||||
data = committed_drive_connector
|
||||
mock_creds = MagicMock(name="composio-credentials")
|
||||
mock_build_creds.return_value = mock_creds
|
||||
mock_composio_client_cls.return_value = _mock_drive_client()
|
||||
mock_task_logger_cls.return_value = mock_task_logger()
|
||||
|
||||
maker = make_session_factory(async_engine)
|
||||
|
|
@ -100,21 +132,29 @@ async def test_composio_connector_uses_composio_credentials(
|
|||
folder_id="test-folder-id",
|
||||
)
|
||||
|
||||
mock_build_creds.assert_called_once_with(_COMPOSIO_ACCOUNT_ID)
|
||||
mock_client_cls.assert_called_once()
|
||||
_, kwargs = mock_client_cls.call_args
|
||||
assert kwargs.get("credentials") is mock_creds
|
||||
mock_composio_client_cls.assert_called_once_with(
|
||||
ANY,
|
||||
data["connector_id"],
|
||||
_COMPOSIO_ACCOUNT_ID,
|
||||
entity_id=f"surfsense_{data['user_id']}",
|
||||
)
|
||||
mock_native_client_cls.assert_not_called()
|
||||
mock_get_access_token.assert_not_called()
|
||||
|
||||
|
||||
@patch(_GET_ACCESS_TOKEN)
|
||||
@patch(f"{_INDEXER_MODULE}.TaskLoggingService")
|
||||
@patch(f"{_INDEXER_MODULE}.build_composio_credentials")
|
||||
@patch(f"{_INDEXER_MODULE}.GoogleDriveClient")
|
||||
@patch(f"{_INDEXER_MODULE}.ComposioDriveClient")
|
||||
async def test_composio_connector_without_account_id_returns_error(
|
||||
mock_build_creds,
|
||||
mock_composio_client_cls,
|
||||
mock_native_client_cls,
|
||||
mock_task_logger_cls,
|
||||
mock_get_access_token,
|
||||
async_engine,
|
||||
committed_composio_no_account_id,
|
||||
):
|
||||
"""Drive indexer returns an error when Composio connector lacks connected_account_id."""
|
||||
"""Missing ``composio_connected_account_id`` must short-circuit before any client construction."""
|
||||
from app.tasks.connector_indexers.google_drive_indexer import (
|
||||
index_google_drive_files,
|
||||
)
|
||||
|
|
@ -134,28 +174,32 @@ async def test_composio_connector_without_account_id_returns_error(
|
|||
|
||||
assert count == 0
|
||||
assert error is not None
|
||||
assert (
|
||||
"composio_connected_account_id" in error.lower() or "composio" in error.lower()
|
||||
)
|
||||
mock_build_creds.assert_not_called()
|
||||
assert "composio" in error.lower()
|
||||
assert "connected_account_id" in error.lower()
|
||||
mock_composio_client_cls.assert_not_called()
|
||||
mock_native_client_cls.assert_not_called()
|
||||
mock_get_access_token.assert_not_called()
|
||||
|
||||
|
||||
@patch(_GET_ACCESS_TOKEN)
|
||||
@patch(f"{_INDEXER_MODULE}.TaskLoggingService")
|
||||
@patch(f"{_INDEXER_MODULE}.ComposioDriveClient")
|
||||
@patch(f"{_INDEXER_MODULE}.GoogleDriveClient")
|
||||
@patch(f"{_INDEXER_MODULE}.build_composio_credentials")
|
||||
async def test_native_connector_does_not_use_composio_credentials(
|
||||
mock_build_creds,
|
||||
mock_client_cls,
|
||||
async def test_native_connector_uses_google_drive_client(
|
||||
mock_native_client_cls,
|
||||
mock_composio_client_cls,
|
||||
mock_task_logger_cls,
|
||||
mock_get_access_token,
|
||||
async_engine,
|
||||
committed_native_drive_connector,
|
||||
):
|
||||
"""Drive indexer does NOT call build_composio_credentials for a native connector."""
|
||||
"""Native Drive connector must use GoogleDriveClient (no Composio involvement at all)."""
|
||||
from app.tasks.connector_indexers.google_drive_indexer import (
|
||||
index_google_drive_files,
|
||||
)
|
||||
|
||||
data = committed_native_drive_connector
|
||||
mock_native_client_cls.return_value = _mock_drive_client()
|
||||
mock_task_logger_cls.return_value = mock_task_logger()
|
||||
|
||||
maker = make_session_factory(async_engine)
|
||||
|
|
@ -168,7 +212,6 @@ async def test_native_connector_does_not_use_composio_credentials(
|
|||
folder_id="test-folder-id",
|
||||
)
|
||||
|
||||
mock_build_creds.assert_not_called()
|
||||
mock_client_cls.assert_called_once()
|
||||
_, kwargs = mock_client_cls.call_args
|
||||
assert kwargs.get("credentials") is None
|
||||
mock_native_client_cls.assert_called_once_with(ANY, data["connector_id"])
|
||||
mock_composio_client_cls.assert_not_called()
|
||||
mock_get_access_token.assert_not_called()
|
||||
|
|
|
|||
|
|
@ -0,0 +1,285 @@
|
|||
"""Tests for the @-mention resolver.
|
||||
|
||||
These tests pin down the contract that ``mention_resolver`` is the
|
||||
single seam between ``MentionedDocumentInfo`` chips (frontend) and the
|
||||
canonical ``/documents/...`` virtual paths (agent). The streaming task,
|
||||
priority middleware, and persistence layer all consume the resolver's
|
||||
output — keeping the tests focused on substitute-in-text + the
|
||||
returned id partition keeps the seam stable across refactors.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from app.agents.new_chat import mention_resolver
|
||||
from app.agents.new_chat.mention_resolver import (
|
||||
ResolvedMention,
|
||||
ResolvedMentionSet,
|
||||
resolve_mentions,
|
||||
substitute_in_text,
|
||||
)
|
||||
from app.agents.new_chat.path_resolver import DOCUMENTS_ROOT, PathIndex
|
||||
from app.schemas.new_chat import MentionedDocumentInfo
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
class TestSubstituteInText:
    """Pin the behaviour of the pure string transform ``substitute_in_text``.

    The transform is exercised on every cloud-mode turn, so it has to be
    both fast and behaviour-identical to the frontend's
    ``parseMentionSegments`` (longest-token first, single forward pass).
    """

    def test_returns_text_unchanged_when_no_tokens(self):
        """An empty token list leaves the text untouched."""
        assert substitute_in_text("hello @foo", []) == "hello @foo"

    def test_returns_text_unchanged_when_empty(self):
        """Empty input text stays empty even when tokens are supplied."""
        assert substitute_in_text("", [("@x", "/documents/x.xml")]) == ""

    def test_replaces_single_token_with_backtick_path(self):
        """A matched token is replaced by its path wrapped in backticks."""
        out = substitute_in_text(
            "see @notes please",
            [("@notes", "/documents/notes.xml")],
        )
        assert out == "see `/documents/notes.xml` please"

    def test_longest_token_wins_over_prefix(self):
        """A longer token takes precedence over a token that is its prefix."""
        # ``@Project Roadmap`` must NOT be partially matched by ``@Project``.
        # Mirrors the FE's parseMentionSegments contract.
        token_to_path = [
            ("@Project Roadmap", "/documents/Roadmap.xml"),
            ("@Project", "/documents/Project.xml"),
        ]
        out = substitute_in_text("about @Project Roadmap today", token_to_path)
        assert out == "about `/documents/Roadmap.xml` today"

    def test_handles_repeated_mentions(self):
        """Every occurrence of a token is replaced, not just the first."""
        out = substitute_in_text(
            "@A and @A again @B",
            [
                ("@A", "/documents/a.xml"),
                ("@B", "/documents/b.xml"),
            ],
        )
        assert (
            out == "`/documents/a.xml` and `/documents/a.xml` again `/documents/b.xml`"
        )

    def test_does_not_match_inside_word(self):
        """Pin that substitution has NO word-boundary semantics.

        NOTE: despite the test name, the assertion below shows that a
        token glued to a preceding word IS replaced.
        """
        # Substitution is positional — there's no word-boundary semantics.
        # ``@Pro`` inside ``foo@Project`` still matches; this is the same
        # behaviour as parseMentionSegments. The test pins it so a
        # future "fix" doesn't accidentally diverge between FE/BE.
        out = substitute_in_text("foo@Pro", [("@Pro", "/documents/p.xml")])
        assert out == "foo`/documents/p.xml`"

    def test_idempotent_after_substitution(self):
        """Running the substitution twice yields the same text as once."""
        # The output starts with a backtick, not ``@``, so re-running
        # the substitution leaves it alone.
        once = substitute_in_text("@A", [("@A", "/documents/a.xml")])
        twice = substitute_in_text(once, [("@A", "/documents/a.xml")])
        assert once == twice
|
||||
|
||||
|
||||
class TestResolveMentions:
    """Pin how ``resolve_mentions`` maps chip ids to virtual paths.

    The returned ``ResolvedMentionSet`` carries id partitions that feed
    ``KnowledgePriorityMiddleware``, so the tests assert on both the
    resolved mentions and the ``mentioned_*_ids`` lists.
    """

    @staticmethod
    def _session_returning(rows):
        """Build a mock session whose ``execute(...).scalars().all()`` yields ``rows``."""
        scalars = MagicMock()
        scalars.all.return_value = rows
        result = MagicMock()
        result.scalars.return_value = scalars
        session = MagicMock()
        session.execute = AsyncMock(return_value=result)
        return session

    @staticmethod
    def _stub_path_index(monkeypatch, make_index=PathIndex):
        """Replace ``mention_resolver.build_path_index`` with an async stub.

        ``make_index`` is invoked on every stub call, so each call receives
        a freshly built index object.
        """

        async def fake_build_index(_session, _ssid):
            return make_index()

        monkeypatch.setattr(mention_resolver, "build_path_index", fake_build_index)

    @pytest.mark.asyncio
    async def test_returns_empty_when_no_mentions(self):
        """No chips in means an empty result and no session calls."""
        session = MagicMock()
        session.execute = AsyncMock()
        result = await resolve_mentions(
            session,
            search_space_id=1,
            mentioned_documents=None,
        )
        assert isinstance(result, ResolvedMentionSet)
        assert result.mentions == []
        assert result.token_to_path == []
        assert result.mentioned_document_ids == []
        assert result.mentioned_folder_ids == []
        # No DB roundtrips when there's nothing to resolve.
        session.execute.assert_not_awaited()

    @pytest.mark.asyncio
    async def test_resolves_doc_chip_to_virtual_path(self, monkeypatch):
        """A doc chip resolves to ``<root>/<title>.xml`` and lands in the doc ids."""
        chip = MentionedDocumentInfo(
            id=42,
            title="Notes",
            document_type="EXTENSION",
            kind="doc",
        )
        doc_row = SimpleNamespace(id=42, title="Notes", folder_id=None)
        self._stub_path_index(monkeypatch)
        session = self._session_returning([doc_row])

        out = await resolve_mentions(
            session,
            search_space_id=5,
            mentioned_documents=[chip],
        )
        assert len(out.mentions) == 1
        mention = out.mentions[0]
        assert mention.kind == "doc"
        assert mention.id == 42
        assert mention.virtual_path == f"{DOCUMENTS_ROOT}/Notes.xml"
        assert out.mentioned_document_ids == [42]
        assert out.mentioned_folder_ids == []
        assert ("@Notes", f"{DOCUMENTS_ROOT}/Notes.xml") in out.token_to_path

    @pytest.mark.asyncio
    async def test_resolves_folder_chip_with_trailing_slash(self, monkeypatch):
        """A folder chip resolves to its indexed path plus a trailing slash."""
        chip = MentionedDocumentInfo(
            id=9,
            title="Reports",
            document_type="FOLDER",
            kind="folder",
        )
        folder_row = SimpleNamespace(id=9, name="Reports")
        self._stub_path_index(
            monkeypatch,
            lambda: PathIndex(folder_paths={9: f"{DOCUMENTS_ROOT}/Reports"}),
        )
        session = self._session_returning([folder_row])

        out = await resolve_mentions(
            session,
            search_space_id=3,
            mentioned_documents=[chip],
        )
        assert len(out.mentions) == 1
        mention = out.mentions[0]
        assert mention.kind == "folder"
        assert mention.id == 9
        assert mention.virtual_path == f"{DOCUMENTS_ROOT}/Reports/"
        assert out.mentioned_document_ids == []
        assert out.mentioned_folder_ids == [9]

    @pytest.mark.asyncio
    async def test_drops_chip_when_doc_is_missing(self, monkeypatch):
        """A chip whose row is not returned by the session is dropped entirely."""
        chip = MentionedDocumentInfo(
            id=99, title="ghost", document_type="EXTENSION", kind="doc"
        )
        self._stub_path_index(monkeypatch)
        session = self._session_returning([])

        out = await resolve_mentions(
            session,
            search_space_id=1,
            mentioned_documents=[chip],
        )
        assert out.mentions == []
        assert out.mentioned_document_ids == []
        assert out.token_to_path == []

    @pytest.mark.asyncio
    async def test_token_to_path_is_longest_first(self, monkeypatch):
        """``token_to_path`` is ordered by token length, longest first."""
        # Two chips whose titles are prefixes of each other — the
        # resolver MUST sort longest-first so substitution doesn't
        # break the ``@Project Roadmap`` vs ``@Project`` invariant.
        chip_short = MentionedDocumentInfo(
            id=1, title="A", document_type="EXTENSION", kind="doc"
        )
        chip_long = MentionedDocumentInfo(
            id=2, title="A long one", document_type="EXTENSION", kind="doc"
        )
        rows = [
            SimpleNamespace(id=1, title="A", folder_id=None),
            SimpleNamespace(id=2, title="A long one", folder_id=None),
        ]
        self._stub_path_index(monkeypatch)
        session = self._session_returning(rows)

        out = await resolve_mentions(
            session,
            search_space_id=1,
            mentioned_documents=[chip_short, chip_long],
        )
        tokens = [tok for tok, _ in out.token_to_path]
        # Guard against a vacuous pass: an empty list is trivially sorted.
        assert tokens
        assert tokens == sorted(tokens, key=len, reverse=True)

    @pytest.mark.asyncio
    async def test_legacy_id_arrays_resolve_without_chip_metadata(self, monkeypatch):
        """Bare ``mentioned_document_ids`` resolve even with no chip list."""
        # ``mentioned_document_ids`` (the legacy parallel array) must
        # still resolve when no chip metadata is available — covers
        # callers that haven't migrated to the discriminated chip list.
        doc_row = SimpleNamespace(id=7, title="Legacy", folder_id=None)
        self._stub_path_index(monkeypatch)
        session = self._session_returning([doc_row])

        out = await resolve_mentions(
            session,
            search_space_id=2,
            mentioned_documents=None,
            mentioned_document_ids=[7],
        )
        assert out.mentioned_document_ids == [7]
        assert len(out.mentions) == 1
        assert out.mentions[0].title == "Legacy"
|
||||
|
||||
|
||||
class TestResolvedMentionEquality:
    """Smoke check on the dataclass behaviour we rely on for asserting
    test outputs."""

    def test_equal_when_fields_equal(self):
        """Two ``ResolvedMention`` objects with identical fields compare equal."""
        a = ResolvedMention(
            kind="doc", id=1, title="x", virtual_path="/documents/x.xml"
        )
        b = ResolvedMention(
            kind="doc", id=1, title="x", virtual_path="/documents/x.xml"
        )
        assert a == b
|
||||
|
|
@ -196,3 +196,50 @@ class TestVirtualPathToDoc:
|
|||
)
|
||||
assert document is target_doc
|
||||
assert session.execute.await_count == 2
|
||||
|
||||
@pytest.mark.asyncio
async def test_resolves_double_extension_for_uploaded_pdf(self):
    """A ``<title>.xml`` path resolves to a row whose title lacks ``.xml``."""
    # Regression: the agent renders every KB document under
    # ``/documents/`` with a trailing ``.xml`` (via ``safe_filename``),
    # so an uploaded PDF whose DB title is ``2025-W2.pdf`` shows up as
    # ``/documents/2025-W2.pdf.xml`` in answers. Clicking that path
    # must round-trip back to the row even though the title itself
    # does NOT end in ``.xml``.
    target_doc = SimpleNamespace(id=99, title="2025-W2.pdf", folder_id=None)

    session = MagicMock()
    # First execute() result is a miss, second yields the target row.
    session.execute = AsyncMock(
        side_effect=[
            _result_from_one(None),
            _result_from_scalars([target_doc]),
        ]
    )

    document = await virtual_path_to_doc(
        session,
        search_space_id=5,
        virtual_path=f"{DOCUMENTS_ROOT}/2025-W2.pdf.xml",
    )
    assert document is target_doc
|
||||
|
||||
@pytest.mark.asyncio
async def test_resolves_path_without_xml_suffix(self):
    """A title-only path (no ``.xml`` suffix) still resolves to the row."""
    # The user (or a hand-edited link) may pass the title-only form
    # ``/documents/2025-W2.pdf``. The resolver must still find the row
    # by literal title equality.
    target_doc = SimpleNamespace(id=99, title="2025-W2.pdf", folder_id=None)

    session = MagicMock()
    # First execute() result is a miss, second yields the target row.
    session.execute = AsyncMock(
        side_effect=[
            _result_from_one(None),
            _result_from_scalars([target_doc]),
        ]
    )

    document = await virtual_path_to_doc(
        session,
        search_space_id=5,
        virtual_path=f"{DOCUMENTS_ROOT}/2025-W2.pdf",
    )
    assert document is target_doc
|
||||
|
|
|
|||
|
|
@ -0,0 +1,38 @@
|
|||
from tests.e2e.fakes.composio_module import _drive_list_files
|
||||
|
||||
|
||||
def _ids(result: dict) -> set[str]:
    """Return the set of ``id`` values found under ``result["data"]["files"]``."""
    return {item["id"] for item in result["data"]["files"]}
|
||||
|
||||
|
||||
def test_drive_list_files_filters_shortcuts_and_trashed_items():
    """A query excluding trashed items and shortcuts returns neither."""
    result = _drive_list_files(
        {
            "q": (
                "'root' in parents and trashed = false and "
                "mimeType != 'application/vnd.google-apps.shortcut'"
            )
        }
    )

    ids = _ids(result)

    assert "fake-file-canary" in ids
    assert "fake-shortcut-canary" not in ids
    assert "fake-file-trashed" not in ids
|
||||
|
||||
|
||||
def test_drive_list_files_filters_to_exact_mime_type():
    """A ``mimeType = ...`` clause narrows the listing to that exact type."""
    result = _drive_list_files(
        {"q": "'root' in parents and trashed = false and mimeType = 'text/plain'"}
    )

    assert _ids(result) == {"fake-file-canary"}
|
||||
|
||||
|
||||
def test_drive_list_files_uses_requested_parent_folder():
    """The parent folder named in the query selects which files are listed."""
    result = _drive_list_files(
        {"q": "'fake-folder-projects' in parents and trashed = false"}
    )

    assert _ids(result) == {"fake-file-roadmap"}
|
||||
|
|
@ -1,8 +1,17 @@
|
|||
"""Unit tests: build_composio_credentials returns valid Google Credentials.
|
||||
"""Unit tests: Composio credential helpers + ``get_access_token`` masking guard.
|
||||
|
||||
Mocks the Composio SDK (external system boundary) and verifies that the
|
||||
returned ``google.oauth2.credentials.Credentials`` object is correctly
|
||||
configured with a token and a working refresh handler.
|
||||
Covers two seams between Surfsense and Composio:
|
||||
|
||||
1. ``build_composio_credentials`` returns a ``google.oauth2.credentials.Credentials``
|
||||
object with a working refresh handler (mocks the whole ``ComposioService``).
|
||||
2. ``ComposioService.get_access_token`` rejects masked / missing tokens with
|
||||
actionable error messages (mocks only the Composio SDK boundary so the
|
||||
real guard logic is exercised).
|
||||
|
||||
The masking guard is the boundary handler that production tripped over when
|
||||
Composio's "Mask Connected Account Secrets" project setting was enabled.
|
||||
The corresponding fix landed in ``cea8618``; these tests lock that contract
|
||||
in place so any future weakening of the guard surfaces immediately.
|
||||
"""
|
||||
|
||||
from datetime import UTC, datetime
|
||||
|
|
@ -14,6 +23,11 @@ from google.oauth2.credentials import Credentials
|
|||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# build_composio_credentials — high-level wrapper tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@patch("app.services.composio_service.ComposioService")
|
||||
def test_returns_credentials_with_token_and_expiry(mock_composio_service):
|
||||
"""build_composio_credentials returns a Credentials object with the Composio access token."""
|
||||
|
|
@ -54,3 +68,85 @@ def test_refresh_handler_fetches_fresh_token(mock_composio_service):
|
|||
assert new_token == "refreshed-token"
|
||||
assert new_expiry > datetime.now(UTC).replace(tzinfo=None)
|
||||
assert mock_service.get_access_token.call_count == 2
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# ComposioService.get_access_token — boundary masking guard tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _service_with_account(account: object):
    """Build a real ``ComposioService`` whose underlying Composio SDK is faked.

    Only the SDK boundary is patched — the real ``get_access_token`` method
    runs, so changes to the masking / missing-token guards surface here.

    Args:
        account: Object returned by the fake client's
            ``connected_accounts.get``.

    Returns:
        The ``ComposioService`` constructed while the SDK class was patched.
    """
    from app.services import composio_service as composio_service_module

    with patch.object(composio_service_module, "Composio") as mock_composio_cls:
        mock_client = MagicMock()
        mock_client.connected_accounts.get.return_value = account
        mock_composio_cls.return_value = mock_client

        service = composio_service_module.ComposioService(api_key="unit-test-api-key")

    # ``service.client`` already references ``mock_client`` even after the
    # patch context exits because the constructor captured it during init.
    return service
|
||||
|
||||
|
||||
@pytest.mark.parametrize("masked_token", ["x", "xxxxxxxx", "x" * 19])
def test_get_access_token_raises_on_masked_token(masked_token):
    """Tokens shorter than the 20-char unmask threshold must raise with the dashboard hint.

    Composio masks ``state.val.access_token`` by default (project setting
    "Mask Connected Account Secrets"). A masked token will always silently
    fail downstream OAuth calls, so the guard surfaces it with the exact
    text needed to fix the dashboard config.
    """
    fake_account = MagicMock()
    fake_account.state.val.access_token = masked_token
    service = _service_with_account(fake_account)

    with pytest.raises(ValueError, match="Mask Connected Account Secrets"):
        service.get_access_token("any-account-id")
|
||||
|
||||
|
||||
def test_get_access_token_raises_when_state_val_missing():
    """No ``state.val`` on the connected account is a hard failure with an account-id hint."""
    fake_account = MagicMock()
    fake_account.state = None
    service = _service_with_account(fake_account)

    with pytest.raises(ValueError, match=r"No state\.val.*missing-state-account"):
        service.get_access_token("missing-state-account")
|
||||
|
||||
|
||||
def test_get_access_token_raises_when_access_token_empty():
    """``state.val`` present but ``access_token`` empty must fail before the masking check."""
    fake_account = MagicMock()
    fake_account.state.val.access_token = ""
    service = _service_with_account(fake_account)

    with pytest.raises(ValueError, match=r"No access_token.*missing-token-account"):
        service.get_access_token("missing-token-account")
|
||||
|
||||
|
||||
def test_get_access_token_raises_when_access_token_none():
    """``state.val.access_token = None`` must fail before the masking check."""
    fake_account = MagicMock()
    fake_account.state.val.access_token = None
    service = _service_with_account(fake_account)

    with pytest.raises(ValueError, match=r"No access_token.*none-token-account"):
        service.get_access_token("none-token-account")
|
||||
|
||||
|
||||
def test_get_access_token_returns_unmasked_token():
    """Happy path: a >=20-char access token is returned verbatim."""
    fake_account = MagicMock()
    unmasked = "u" * 32
    fake_account.state.val.access_token = unmasked
    service = _service_with_account(fake_account)

    assert service.get_access_token("happy-account") == unmasked
|
||||
|
|
|
|||
|
|
@ -118,12 +118,8 @@ def test_get_by_tool_call_id_matches_action_request_payload() -> None:
|
|||
tasks=(
|
||||
_Task(
|
||||
interrupts=(
|
||||
_Interrupt(
|
||||
value=_hitl("a", tool_call_id="call_xxx"), id="int_a"
|
||||
),
|
||||
_Interrupt(
|
||||
value=_hitl("b", tool_call_id="call_yyy"), id="int_b"
|
||||
),
|
||||
_Interrupt(value=_hitl("a", tool_call_id="call_xxx"), id="int_a"),
|
||||
_Interrupt(value=_hitl("b", tool_call_id="call_yyy"), id="int_b"),
|
||||
)
|
||||
),
|
||||
)
|
||||
|
|
@ -146,9 +142,7 @@ def test_first_pending_interrupt_matches_legacy_first_wins_behaviour() -> None:
|
|||
|
||||
def test_interrupt_without_id_falls_back_to_none() -> None:
|
||||
"""Snapshots from older LangGraph versions may omit ``id`` — preserve that."""
|
||||
state = _State(
|
||||
tasks=(_Task(interrupts=(_Interrupt(value=_hitl("a"), id=None),)),)
|
||||
)
|
||||
state = _State(tasks=(_Task(interrupts=(_Interrupt(value=_hitl("a"), id=None),)),))
|
||||
pending = list_pending_interrupts(state)
|
||||
assert len(pending) == 1
|
||||
assert pending[0].interrupt_id is None
|
||||
|
|
|
|||
|
|
@ -37,9 +37,7 @@ def test_custom_interrupt_primitive_is_converted_to_canonical_shape() -> None:
|
|||
"context": {"reason": "destructive"},
|
||||
}
|
||||
out = normalize_interrupt_payload(raw)
|
||||
assert out["action_requests"] == [
|
||||
{"name": "send_email", "args": {"to": "a@b"}}
|
||||
]
|
||||
assert out["action_requests"] == [{"name": "send_email", "args": {"to": "a@b"}}]
|
||||
assert out["review_configs"] == [
|
||||
{
|
||||
"action_name": "send_email",
|
||||
|
|
|
|||
|
|
@ -158,9 +158,7 @@ def _classify_cases() -> list[Exception]:
|
|||
"""Inputs that the FE depends on being mapped to specific error codes."""
|
||||
return [
|
||||
Exception("totally generic error"),
|
||||
Exception(
|
||||
'{"error":{"type":"rate_limit_error","message":"slow down"}}'
|
||||
),
|
||||
Exception('{"error":{"type":"rate_limit_error","message":"slow down"}}'),
|
||||
Exception(
|
||||
'OpenrouterException - {"error":{"message":"Provider returned error",'
|
||||
'"code":429}}'
|
||||
|
|
@ -220,7 +218,7 @@ class _FakeStreamingService:
|
|||
self.calls.append(
|
||||
{"message": message, "error_code": error_code, "extra": extra}
|
||||
)
|
||||
return f"data: {{\"type\":\"error\",\"errorText\":\"{message}\"}}\n\n"
|
||||
return f'data: {{"type":"error","errorText":"{message}"}}\n\n'
|
||||
|
||||
|
||||
def test_emit_stream_terminal_error_matches_old_output_and_logs(caplog) -> None:
|
||||
|
|
|
|||
|
|
@ -60,8 +60,14 @@ async def test_stream_output_emits_text_lifecycle_and_updates_result() -> None:
|
|||
service = _StreamingService()
|
||||
agent = _Agent(
|
||||
[
|
||||
{"event": "on_chat_model_stream", "data": {"chunk": _Chunk(content="Hello")}},
|
||||
{"event": "on_chat_model_stream", "data": {"chunk": _Chunk(content=" world")}},
|
||||
{
|
||||
"event": "on_chat_model_stream",
|
||||
"data": {"chunk": _Chunk(content="Hello")},
|
||||
},
|
||||
{
|
||||
"event": "on_chat_model_stream",
|
||||
"data": {"chunk": _Chunk(content=" world")},
|
||||
},
|
||||
]
|
||||
)
|
||||
result = StreamingResult()
|
||||
|
|
|
|||
|
|
@ -37,7 +37,9 @@ def test_clear_ignored_for_non_task_tool() -> None:
|
|||
def test_clear_ignored_when_task_run_id_mismatches() -> None:
|
||||
state = AgentEventRelayState.for_invocation()
|
||||
open_task_span(state, run_id="run-open")
|
||||
clear_task_span_if_delegating_task_ended(state, tool_name="task", run_id="run-other")
|
||||
clear_task_span_if_delegating_task_ended(
|
||||
state, tool_name="task", run_id="run-other"
|
||||
)
|
||||
assert state.active_span_id is not None
|
||||
assert state.active_task_run_id == "run-open"
|
||||
|
||||
|
|
|
|||
|
|
@ -240,9 +240,7 @@ class TestToolHeavyTurn:
|
|||
class TestToolCallSpanMetadata:
|
||||
def test_input_available_merges_new_metadata_keys_after_start(self):
|
||||
b = AssistantContentBuilder()
|
||||
b.on_tool_input_start(
|
||||
"call_t", "task", "lc_t", metadata={"spanId": "spn_1"}
|
||||
)
|
||||
b.on_tool_input_start("call_t", "task", "lc_t", metadata={"spanId": "spn_1"})
|
||||
b.on_tool_input_available(
|
||||
"call_t",
|
||||
"task",
|
||||
|
|
@ -257,9 +255,7 @@ class TestToolCallSpanMetadata:
|
|||
|
||||
def test_input_available_does_not_overwrite_existing_metadata_keys(self):
|
||||
b = AssistantContentBuilder()
|
||||
b.on_tool_input_start(
|
||||
"call_t", "task", "lc_t", metadata={"spanId": "spn_keep"}
|
||||
)
|
||||
b.on_tool_input_start("call_t", "task", "lc_t", metadata={"spanId": "spn_keep"})
|
||||
b.on_tool_input_available(
|
||||
"call_t", "task", {}, "lc_t", metadata={"spanId": "spn_other"}
|
||||
)
|
||||
|
|
|
|||
93
surfsense_backend/tests/unit/utils/test_oauth_security.py
Normal file
93
surfsense_backend/tests/unit/utils/test_oauth_security.py
Normal file
|
|
@ -0,0 +1,93 @@
|
|||
import base64
|
||||
import hashlib
|
||||
import hmac
|
||||
import json
|
||||
import time
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from fastapi import HTTPException
|
||||
|
||||
from app.utils.oauth_security import OAuthStateManager
|
||||
|
||||
SECRET = "unit-test-secret"
|
||||
|
||||
|
||||
def _encode_state(payload: dict, *, signature: str | None = None) -> str:
|
||||
"""Build an OAuth state payload compatible with OAuthStateManager."""
|
||||
signature_payload = payload.copy()
|
||||
payload_str = json.dumps(signature_payload, sort_keys=True)
|
||||
computed_signature = hmac.new(
|
||||
SECRET.encode(),
|
||||
payload_str.encode(),
|
||||
hashlib.sha256,
|
||||
).hexdigest()
|
||||
encoded_payload = {
|
||||
**signature_payload,
|
||||
"signature": signature if signature is not None else computed_signature,
|
||||
}
|
||||
return base64.urlsafe_b64encode(json.dumps(encoded_payload).encode()).decode()
|
||||
|
||||
|
||||
def test_validate_state_accepts_fresh_signed_state():
    """A state generated by the manager validates and round-trips its fields."""
    mgr = OAuthStateManager(secret_key=SECRET, max_age_seconds=600)
    user_id = uuid4()

    state = mgr.generate_secure_state(
        space_id=1,
        user_id=user_id,
        toolkit_id="googledrive",
    )

    decoded = mgr.validate_state(state)

    assert decoded["space_id"] == 1
    # The user id is compared against its string form after decoding.
    assert decoded["user_id"] == str(user_id)
    assert decoded["toolkit_id"] == "googledrive"
|
||||
|
||||
|
||||
def test_validate_state_rejects_expired_state():
    """A correctly signed state older than ``max_age_seconds`` is rejected with 400."""
    mgr = OAuthStateManager(secret_key=SECRET, max_age_seconds=600)
    expired_state = _encode_state(
        {
            "space_id": 1,
            "user_id": str(uuid4()),
            # One hour old, well past the 600-second maximum age.
            "timestamp": int(time.time()) - 3600,
            "toolkit_id": "googledrive",
        }
    )

    with pytest.raises(HTTPException) as exc:
        mgr.validate_state(expired_state)

    assert exc.value.status_code == 400
    assert "expired" in exc.value.detail.lower()
|
||||
|
||||
|
||||
def test_validate_state_rejects_tampered_signature():
    """A fresh state carrying a forged signature is rejected with 400."""
    mgr = OAuthStateManager(secret_key=SECRET, max_age_seconds=600)
    tampered_state = _encode_state(
        {
            "space_id": 1,
            "user_id": str(uuid4()),
            "timestamp": int(time.time()),
            "toolkit_id": "googledrive",
        },
        # 64 hex characters that do not match the computed HMAC.
        signature="deadbeef" * 8,
    )

    with pytest.raises(HTTPException) as exc:
        mgr.validate_state(tampered_state)

    assert exc.value.status_code == 400
    assert "tampering" in exc.value.detail.lower()
|
||||
|
||||
|
||||
def test_validate_state_rejects_malformed_state():
    """A state string that cannot be decoded is rejected with 400."""
    mgr = OAuthStateManager(secret_key=SECRET)

    with pytest.raises(HTTPException) as exc:
        mgr.validate_state("not-base64-and-not-json")

    assert exc.value.status_code == 400
    assert "invalid state format" in exc.value.detail.lower()
|
||||
Loading…
Add table
Add a link
Reference in a new issue