feat: improved document, folder mentions rendering
Some checks are pending
Build and Push Docker Images / tag_release (push) Waiting to run
Build and Push Docker Images / build (./surfsense_backend, ./surfsense_backend/Dockerfile, backend, surfsense-backend, ubuntu-24.04-arm, linux/arm64, arm64) (push) Blocked by required conditions
Build and Push Docker Images / build (./surfsense_backend, ./surfsense_backend/Dockerfile, backend, surfsense-backend, ubuntu-latest, linux/amd64, amd64) (push) Blocked by required conditions
Build and Push Docker Images / build (./surfsense_web, ./surfsense_web/Dockerfile, web, surfsense-web, ubuntu-24.04-arm, linux/arm64, arm64) (push) Blocked by required conditions
Build and Push Docker Images / build (./surfsense_web, ./surfsense_web/Dockerfile, web, surfsense-web, ubuntu-latest, linux/amd64, amd64) (push) Blocked by required conditions
Build and Push Docker Images / create_manifest (backend, surfsense-backend) (push) Blocked by required conditions
Build and Push Docker Images / create_manifest (web, surfsense-web) (push) Blocked by required conditions

This commit is contained in:
DESKTOP-RTLN3BA\$punk 2026-05-09 22:15:51 -07:00
parent 28a02a9143
commit c8374e6c5b
59 changed files with 1725 additions and 361 deletions

View file

@ -46,6 +46,10 @@ class SurfSenseContextSchema:
Read by ``KnowledgePriorityMiddleware`` to seed its priority
list. Stays out of the compiled-agent cache key that's the
whole point of putting it here.
mentioned_folder_ids: KB folders the user @-mentioned this turn
(cloud filesystem mode). Surfaced as ``[USER-MENTIONED]``
entries in ``<priority_documents>`` so the agent prioritises
walking those folders with ``ls`` / ``find_documents``.
file_operation_contract: One-shot file operation contract emitted
by ``FileIntentMiddleware`` for the upcoming turn.
turn_id / request_id: Correlation IDs surfaced by the streaming
@ -59,6 +63,7 @@ class SurfSenseContextSchema:
search_space_id: int | None = None
mentioned_document_ids: list[int] = field(default_factory=list)
mentioned_folder_ids: list[int] = field(default_factory=list)
file_operation_contract: FileOperationContractState | None = None
turn_id: str | None = None
request_id: str | None = None

View file

@ -0,0 +1,281 @@
"""Resolve @-mention chips to canonical virtual paths and substitute the
user-visible ``@title`` tokens with backtick-wrapped paths in the prompt
the agent sees.
The frontend's mention seam is a single discriminated-union list of
``{kind: "doc" | "folder", id, title, document_type?}`` chips (see
``surfsense_web/atoms/chat/mentioned-documents.atom.ts``). When a turn
reaches the backend stream task we have three needs that this module
centralises:
1. Map each chip to its canonical virtual path
(``/documents/.../file.xml`` for docs, ``/documents/MyFolder/`` for
folders) so the agent sees concrete filesystem locations instead of
ambiguous ``@``-titles.
2. Substitute ``@title`` tokens in the user-typed text with backtick-
wrapped paths so the path becomes part of the ``HumanMessage`` body
the LLM consumes without rewriting the persisted user message
text (which keeps ``@title`` so chip rendering on reload is
unchanged).
3. Surface the resolved id sets (docs + folders) to the priority
middleware so it can render ``[USER-MENTIONED]`` priority entries
without re-doing path resolution.
This is intentionally one module — see the architectural note in
``mention-paths-and-folders`` plan: previously the doc-resolution lived
inline in ``stream_new_chat`` and the folder mention had no resolution
at all. Centralising both behind a single ``resolve_mentions`` call
turns a leaky multi-field seam into a single deeper interface.
"""
from __future__ import annotations
import logging
from dataclasses import dataclass, field
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.agents.new_chat.path_resolver import (
DOCUMENTS_ROOT,
build_path_index,
doc_to_virtual_path,
)
from app.db import Document, Folder
from app.schemas.new_chat import MentionedDocumentInfo
logger = logging.getLogger(__name__)
@dataclass(frozen=True)
class ResolvedMention:
    """Canonical view of a single @-mention chip.

    ``virtual_path`` is the path the agent will see (no trailing slash
    for documents, trailing ``/`` for folders to match the convention
    used by ``KnowledgeTreeMiddleware``).
    """

    # Chip discriminator mirroring the frontend union: "doc" | "folder".
    kind: str  # "doc" | "folder"
    # Primary key of the Document or Folder row the chip points at.
    id: int
    # User-visible title exactly as the mention token rendered it.
    title: str
    # Canonical agent-facing path under /documents/.
    virtual_path: str
@dataclass
class ResolvedMentionSet:
    """Aggregate result of resolving a turn's mention chips.

    ``token_to_path`` maps ``@title`` (the literal token the user typed
    and the editor emitted) to the canonical virtual path for that
    chip. It is produced longest-token-first so substitution mirrors
    ``parseMentionSegments`` on the frontend (a longer title like
    ``@Project Roadmap`` is never shadowed by a shorter prefix
    ``@Project``).

    ``mentioned_document_ids`` collapses doc + surfsense_doc chips into
    a single ordered, deduped list because the priority middleware
    treats them uniformly downstream — see
    ``KnowledgePriorityMiddleware._compute_priority_paths``.
    """

    # Every successfully resolved chip (docs first, then folders).
    mentions: list[ResolvedMention] = field(default_factory=list)
    # (token, path) pairs, sorted longest-token-first for substitution.
    token_to_path: list[tuple[str, str]] = field(default_factory=list)
    # Doc ids that survived the DB lookup, deduped, order-preserving.
    mentioned_document_ids: list[int] = field(default_factory=list)
    # Folder ids that survived the DB lookup, deduped, order-preserving.
    mentioned_folder_ids: list[int] = field(default_factory=list)
def _folder_virtual_path(folder_id: int, folder_paths: dict[int, str]) -> str:
    """Return the trailing-slash virtual path for *folder_id*.

    Unknown ids (deleted folders, or folders from another search space
    missing from the index) fall back to the documents root. The
    trailing slash matches ``KnowledgeTreeMiddleware``'s directory
    convention (``/documents/MyFolder/``) so the agent's ``ls`` can
    dispatch on it as a directory.
    """
    path = folder_paths.get(folder_id, DOCUMENTS_ROOT)
    if path.endswith("/"):
        return path
    return path + "/"
async def resolve_mentions(
    session: AsyncSession,
    *,
    search_space_id: int,
    mentioned_documents: list[MentionedDocumentInfo] | None,
    mentioned_document_ids: list[int] | None = None,
    mentioned_surfsense_doc_ids: list[int] | None = None,
    mentioned_folder_ids: list[int] | None = None,
) -> ResolvedMentionSet:
    """Resolve every @-mention chip on a turn into virtual paths.

    Accepts both the ``mentioned_documents`` discriminated chip list
    (metadata used for substitution + persistence) and the parallel id
    arrays (``mentioned_document_ids``, ``mentioned_surfsense_doc_ids``,
    ``mentioned_folder_ids``) for two reasons:

    * Legacy clients that haven't migrated to the unified chip list
      still send the id arrays — the union is treated as authoritative.
    * The id arrays are the canonical input to
      ``KnowledgePriorityMiddleware`` (via ``SurfSenseContextSchema``);
      returning the deduped, validated lists lets the route forward
      them unchanged.

    Resolution is best-effort: a chip whose id no longer exists (e.g.
    the document was deleted between mention and submit) is silently
    dropped. The agent still sees the user's original text, just
    without a backtick-path substitution for that chip.
    """
    # Partition chips by kind and remember the user-visible titles so
    # the substitution tokens match what the editor emitted.
    doc_ids_from_chips: list[int] = []
    folder_ids_from_chips: list[int] = []
    chip_titles: dict[tuple[str, int], str] = {}
    for chip in mentioned_documents or []:
        if chip.kind == "folder":
            folder_ids_from_chips.append(chip.id)
        else:
            doc_ids_from_chips.append(chip.id)
        chip_titles[(chip.kind, chip.id)] = chip.title

    # Union of legacy id arrays + chip ids; dict.fromkeys dedups while
    # preserving first-seen order.
    doc_id_pool: list[int] = list(
        dict.fromkeys(
            [
                *(mentioned_document_ids or []),
                *(mentioned_surfsense_doc_ids or []),
                *doc_ids_from_chips,
            ]
        )
    )
    folder_id_pool: list[int] = list(
        dict.fromkeys([*(mentioned_folder_ids or []), *folder_ids_from_chips])
    )
    if not doc_id_pool and not folder_id_pool:
        return ResolvedMentionSet()

    index = await build_path_index(session, search_space_id)

    # Bulk-fetch the surviving rows, scoped to the search space so a
    # cross-space id can never leak a path.
    doc_rows: dict[int, Document] = {}
    if doc_id_pool:
        docs_result = await session.execute(
            select(Document).where(
                Document.search_space_id == search_space_id,
                Document.id.in_(doc_id_pool),
            )
        )
        doc_rows = {doc.id: doc for doc in docs_result.scalars().all()}
    folder_rows: dict[int, Folder] = {}
    if folder_id_pool:
        folders_result = await session.execute(
            select(Folder).where(
                Folder.search_space_id == search_space_id,
                Folder.id.in_(folder_id_pool),
            )
        )
        folder_rows = {f.id: f for f in folders_result.scalars().all()}

    resolved: list[ResolvedMention] = []
    accepted_doc_ids: list[int] = []
    accepted_folder_ids: list[int] = []
    for doc_id in doc_id_pool:
        doc = doc_rows.get(doc_id)
        if doc is None:
            logger.debug(
                "mention_resolver: dropping doc id=%s (not found in space=%s)",
                doc_id,
                search_space_id,
            )
            continue
        resolved.append(
            ResolvedMention(
                kind="doc",
                id=doc.id,
                # Prefer the chip's title (what the user saw) over the
                # current DB title, which may have changed since.
                title=chip_titles.get(("doc", doc_id), str(doc.title or "")),
                virtual_path=doc_to_virtual_path(
                    doc_id=doc.id,
                    title=str(doc.title or "untitled"),
                    folder_id=doc.folder_id,
                    index=index,
                ),
            )
        )
        accepted_doc_ids.append(doc.id)
    for folder_id in folder_id_pool:
        folder = folder_rows.get(folder_id)
        if folder is None:
            logger.debug(
                "mention_resolver: dropping folder id=%s (not found in space=%s)",
                folder_id,
                search_space_id,
            )
            continue
        resolved.append(
            ResolvedMention(
                kind="folder",
                id=folder.id,
                title=chip_titles.get(("folder", folder_id), str(folder.name or "")),
                virtual_path=_folder_virtual_path(folder.id, index.folder_paths),
            )
        )
        accepted_folder_ids.append(folder.id)

    # Substitution table: one entry per distinct token (first mention
    # wins), then sorted longest-token-first so a longer title is never
    # shadowed by a shorter prefix.
    token_to_path: list[tuple[str, str]] = []
    seen_tokens: set[str] = set()
    for mention in resolved:
        if not mention.title:
            continue
        token = f"@{mention.title}"
        if token not in seen_tokens:
            seen_tokens.add(token)
            token_to_path.append((token, mention.virtual_path))
    token_to_path.sort(key=lambda pair: len(pair[0]), reverse=True)

    return ResolvedMentionSet(
        mentions=resolved,
        token_to_path=token_to_path,
        mentioned_document_ids=accepted_doc_ids,
        mentioned_folder_ids=accepted_folder_ids,
    )
def substitute_in_text(text: str, token_to_path: list[tuple[str, str]]) -> str:
    """Replace each ``@title`` token with a backtick-wrapped virtual path.

    Mirrors ``parseMentionSegments`` on the frontend: a single forward
    pass with no regex (titles can contain regex metacharacters), and
    the caller pre-sorts ``token_to_path`` longest-first so the first
    matching entry at a position wins. The substitution is idempotent
    on already-substituted text because the backtick-wrapped path no
    longer starts with ``@``.

    Empty / no-op cases short-circuit so callers can pass this through
    unconditionally without paying for a scan.
    """
    if not text or not token_to_path:
        return text
    pieces: list[str] = []
    pos = 0
    end = len(text)
    while pos < end:
        replacement: str | None = None
        advance = 1
        for token, virtual_path in token_to_path:
            if text.startswith(token, pos):
                replacement = f"`{virtual_path}`"
                advance = len(token)
                break
        if replacement is None:
            # No token starts here; copy one character and move on.
            pieces.append(text[pos])
        else:
            pieces.append(replacement)
        pos += advance
    return "".join(pieces)
# Public surface of the module: the resolver entry points plus the two
# result types callers pattern on.
__all__ = [
    "ResolvedMention",
    "ResolvedMentionSet",
    "resolve_mentions",
    "substitute_in_text",
]

View file

@ -54,6 +54,7 @@ from app.db import (
NATIVE_TO_LEGACY_DOCTYPE,
Chunk,
Document,
Folder,
shielded_async_session,
)
from app.retriever.chunks_hybrid_search import ChucksHybridSearchRetriever
@ -832,6 +833,22 @@ class KnowledgePriorityMiddleware(AgentMiddleware): # type: ignore[type-arg]
mention_ids = list(self.mentioned_document_ids)
self.mentioned_document_ids = []
# Folder mentions live alongside doc mentions on the runtime
# context. They never feed hybrid search (folders aren't
# embedded) — they're surfaced purely as ``[USER-MENTIONED]``
# priority entries so the agent walks the folder with ``ls`` /
# ``find_documents`` instead of ignoring it. Cloud filesystem
# mode only.
folder_mention_ids: list[int] = []
if (
ctx is not None
and getattr(self, "filesystem_mode", FilesystemMode.CLOUD)
== FilesystemMode.CLOUD
):
ctx_folders = getattr(ctx, "mentioned_folder_ids", None)
if ctx_folders:
folder_mention_ids = list(ctx_folders)
mentioned_results: list[dict[str, Any]] = []
if mention_ids:
mentioned_results = await fetch_mentioned_documents(
@ -876,16 +893,21 @@ class KnowledgePriorityMiddleware(AgentMiddleware): # type: ignore[type-arg]
priority, matched_chunk_ids = await self._materialize_priority(merged)
if folder_mention_ids:
folder_entries = await self._materialize_folder_priority(folder_mention_ids)
priority = folder_entries + priority
new_messages = list(messages)
insert_at = max(len(new_messages) - 1, 0)
new_messages.insert(insert_at, _render_priority_message(priority))
_perf_log.info(
"[kb_priority] completed in %.3fs query=%r priority=%d mentioned=%d",
"[kb_priority] completed in %.3fs query=%r priority=%d mentioned=%d folders=%d",
asyncio.get_event_loop().time() - t0,
user_text[:80],
len(priority),
len(mentioned_results),
len(folder_mention_ids),
)
return {
@ -894,6 +916,58 @@ class KnowledgePriorityMiddleware(AgentMiddleware): # type: ignore[type-arg]
"messages": new_messages,
}
async def _materialize_folder_priority(
    self, folder_ids: list[int]
) -> list[dict[str, Any]]:
    """Resolve user-mentioned folder ids to ``<priority_documents>`` entries.

    Each entry uses the canonical ``/documents/Folder/Sub/`` virtual
    path (matching ``KnowledgeTreeMiddleware`` and the agent's
    ``ls`` adapter) and is flagged ``mentioned=True`` so the
    rendered line carries ``[USER-MENTIONED]``. ``score`` is left
    ``None`` so the renderer prints ``n/a`` — folders aren't
    ranked, the agent decides which children to read.
    """
    if not folder_ids:
        return []
    # Short-lived session: only the path index and the folder display
    # names come from the DB; entry construction below is pure.
    async with shielded_async_session() as session:
        index: PathIndex = await build_path_index(session, self.search_space_id)
        folder_rows = await session.execute(
            select(Folder.id, Folder.name).where(
                Folder.search_space_id == self.search_space_id,
                Folder.id.in_(folder_ids),
            )
        )
        folder_titles: dict[int, str] = {
            row.id: row.name for row in folder_rows.all()
        }
    entries: list[dict[str, Any]] = []
    seen: set[int] = set()  # keep first occurrence, drop duplicate ids
    for folder_id in folder_ids:
        if folder_id in seen:
            continue
        seen.add(folder_id)
        base = index.folder_paths.get(folder_id)
        if base is None:
            # Folder vanished (or belongs to another space) since the
            # mention was made — drop it, best-effort.
            logger.debug(
                "kb_priority: dropping folder id=%s (missing from path index)",
                folder_id,
            )
            continue
        # Trailing slash marks the entry as a directory for ``ls``.
        path = base if base.endswith("/") else f"{base}/"
        entries.append(
            {
                "path": path,
                "score": None,  # unranked → renderer prints "n/a"
                "document_id": None,
                "folder_id": folder_id,
                "title": folder_titles.get(folder_id, ""),
                "mentioned": True,
            }
        )
    return entries
async def _materialize_priority(
self, merged: list[dict[str, Any]]
) -> tuple[list[dict[str, Any]], dict[int, list[int]]]:

View file

@ -7,6 +7,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from sqlalchemy.orm import selectinload
from app.agents.new_chat.path_resolver import virtual_path_to_doc
from app.db import (
Chunk,
Document,
@ -752,7 +753,24 @@ async def get_document_by_virtual_path(
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
):
"""Resolve a knowledge-base document id by exact virtual path."""
"""Resolve a knowledge-base document by its agent-facing virtual path.
The agent renders every document under ``/documents/...`` with a
``.xml`` extension appended via ``safe_filename`` (so a PDF titled
``2025-W2.pdf`` becomes ``/documents/2025-W2.pdf.xml``). When the user
clicks that path in an answer, this endpoint must round-trip back to
the underlying ``Document`` row regardless of its type — agent-created
NOTE docs (which carry ``virtual_path`` in metadata), uploaded PDFs,
and connector docs all flow through here.
Resolution is delegated to :func:`virtual_path_to_doc`, the single
source of truth that handles:
* ``unique_identifier_hash`` lookup (agent NOTE fast path)
* ``" (<doc_id>).xml"`` disambiguation suffixes
* ``.xml`` extension stripping for title-based fallback
* ``safe_filename`` round-trip for connector titles with lossy chars
"""
try:
await check_permission(
session,
@ -762,24 +780,19 @@ async def get_document_by_virtual_path(
"You don't have permission to read documents in this search space",
)
result = await session.execute(
select(
Document.id,
Document.title,
Document.document_type,
).filter(
Document.search_space_id == search_space_id,
Document.document_metadata["virtual_path"].as_string() == virtual_path,
)
document = await virtual_path_to_doc(
session,
search_space_id=search_space_id,
virtual_path=virtual_path,
)
row = result.first()
if row is None:
if document is None:
raise HTTPException(status_code=404, detail="Document not found")
return DocumentTitleRead(
id=row.id,
title=row.title,
document_type=row.document_type,
id=document.id,
title=document.title,
document_type=document.document_type,
folder_id=document.folder_id,
)
except HTTPException:
raise

View file

@ -1781,6 +1781,7 @@ async def handle_new_chat(
llm_config_id=llm_config_id,
mentioned_document_ids=request.mentioned_document_ids,
mentioned_surfsense_doc_ids=request.mentioned_surfsense_doc_ids,
mentioned_folder_ids=request.mentioned_folder_ids,
mentioned_documents=mentioned_documents_payload,
needs_history_bootstrap=thread.needs_history_bootstrap,
thread_visibility=thread.visibility,
@ -2266,6 +2267,7 @@ async def regenerate_response(
llm_config_id=llm_config_id,
mentioned_document_ids=request.mentioned_document_ids,
mentioned_surfsense_doc_ids=request.mentioned_surfsense_doc_ids,
mentioned_folder_ids=request.mentioned_folder_ids,
mentioned_documents=mentioned_documents_payload,
checkpoint_id=target_checkpoint_id,
needs_history_bootstrap=thread.needs_history_bootstrap,

View file

@ -201,18 +201,34 @@ class NewChatUserImagePart(BaseModel):
class MentionedDocumentInfo(BaseModel):
"""Display metadata for a single ``@``-mentioned document.
"""Display metadata for a single ``@``-mention chip.
The full triple ``{id, title, document_type}`` is forwarded by the
frontend mention chip so the server can embed it in the persisted
user message ``ContentPart[]`` (single ``mentioned-documents`` part).
The history loader then renders the chips on reload without an extra
Carries either a knowledge-base document or a knowledge-base folder
(discriminated by ``kind``). The full triple
``{id, title, document_type}`` is forwarded by the frontend mention
chip so the server can embed it in the persisted user message
``ContentPart[]`` (single ``mentioned-documents`` part). The
history loader then renders the chips on reload without an extra
fetch — mirrors the pre-refactor frontend ``persistUserTurn`` shape.
``kind`` defaults to ``"doc"`` so legacy clients and persisted rows
that predate folder mentions deserialise unchanged.
"""
id: int
title: str = Field(..., min_length=1, max_length=500)
document_type: str = Field(..., min_length=1, max_length=100)
kind: Literal["doc", "folder"] = Field(
default="doc",
description=(
"Discriminator for the chip's referent: ``doc`` is a "
"knowledge-base ``Document`` row, ``folder`` is a "
"knowledge-base ``Folder`` row. Folders carry the sentinel "
"``document_type='FOLDER'`` to keep the frontend dedup key "
"``(kind:document_type:id)`` from colliding doc and folder "
"ids that happen to share an integer value."
),
)
class NewChatRequest(BaseModel):
@ -228,15 +244,26 @@ class NewChatRequest(BaseModel):
mentioned_surfsense_doc_ids: list[int] | None = (
None # Optional SurfSense documentation IDs mentioned with @ in the chat
)
mentioned_folder_ids: list[int] | None = Field(
default=None,
description=(
"Optional knowledge-base folder IDs the user mentioned with "
"@. Resolved to virtual paths (``/documents/.../``) by "
"``mention_resolver`` and surfaced to the agent via "
"(a) backtick-wrapped substitution in ``user_query`` and "
"(b) a ``[USER-MENTIONED]`` entry in ``<priority_documents>``. "
"The agent's ``ls`` tool can then walk the folder itself."
),
)
mentioned_documents: list[MentionedDocumentInfo] | None = Field(
default=None,
description=(
"Display metadata (id, title, document_type) for every "
"@-mentioned document. Persisted as a ``mentioned-documents`` "
"ContentPart on the user message so reload renders chips "
"without an extra fetch. Optional and additive — when None "
"the user message is persisted without a mentioned-documents "
"part."
"Display metadata (id, title, document_type, kind) for every "
"@-mention chip — both documents and folders. Persisted as a "
"``mentioned-documents`` ContentPart on the user message so "
"reload renders chips without an extra fetch. Optional and "
"additive — when None the user message is persisted without "
"a mentioned-documents part."
),
)
disabled_tools: list[str] | None = (
@ -290,14 +317,22 @@ class RegenerateRequest(BaseModel):
)
mentioned_document_ids: list[int] | None = None
mentioned_surfsense_doc_ids: list[int] | None = None
mentioned_folder_ids: list[int] | None = Field(
default=None,
description=(
"Optional knowledge-base folder IDs the user mentioned with "
"@ on the edited user turn. Only used when ``user_query`` is "
"non-None (edit). Mirrors ``NewChatRequest.mentioned_folder_ids``."
),
)
mentioned_documents: list[MentionedDocumentInfo] | None = Field(
default=None,
description=(
"Display metadata (id, title, document_type) for every "
"@-mentioned document on the edited user turn. Only used "
"when ``user_query`` is non-None (edit). Persisted as a "
"``mentioned-documents`` ContentPart on the new user "
"message. None means no chip metadata."
"Display metadata (id, title, document_type, kind) for every "
"@-mention chip on the edited user turn — both documents and "
"folders. Only used when ``user_query`` is non-None (edit). "
"Persisted as a ``mentioned-documents`` ContentPart on the "
"new user message. None means no chip metadata."
),
)
disabled_tools: list[str] | None = None
@ -373,6 +408,16 @@ class ResumeRequest(BaseModel):
filesystem_mode: Literal["cloud", "desktop_local_folder"] = "cloud"
client_platform: Literal["web", "desktop"] = "web"
local_filesystem_mounts: list[LocalFilesystemMountPayload] | None = None
mentioned_folder_ids: list[int] | None = Field(
default=None,
description=(
"Forwarded for symmetry with /new_chat and /regenerate. "
"Resume reuses the original interrupted user turn so this "
"field is informational only — the originating turn's "
"folder mentions already shaped the priority hints baked "
"into the agent's checkpoint."
),
)
mentioned_documents: list[MentionedDocumentInfo] | None = Field(
default=None,
description=(

View file

@ -43,9 +43,7 @@ class EmitterRegistry:
return main_emitter()
def has_active_subagents(self) -> bool:
return any(
emitter.level == "subagent" for emitter in self._by_run_id.values()
)
return any(emitter.level == "subagent" for emitter in self._by_run_id.values())
def clear(self) -> None:
self._by_run_id.clear()

View file

@ -6,9 +6,7 @@ from ..emitter import Emitter, attach_emitted_by
from ..envelope import format_sse
def format_reasoning_start(
reasoning_id: str, *, emitter: Emitter | None = None
) -> str:
def format_reasoning_start(reasoning_id: str, *, emitter: Emitter | None = None) -> str:
return format_sse(
attach_emitted_by({"type": "reasoning-start", "id": reasoning_id}, emitter)
)
@ -28,9 +26,7 @@ def format_reasoning_delta(
)
def format_reasoning_end(
reasoning_id: str, *, emitter: Emitter | None = None
) -> str:
def format_reasoning_end(reasoning_id: str, *, emitter: Emitter | None = None) -> str:
return format_sse(
attach_emitted_by({"type": "reasoning-end", "id": reasoning_id}, emitter)
)

View file

@ -7,9 +7,7 @@ from ..envelope import format_sse
def format_text_start(text_id: str, *, emitter: Emitter | None = None) -> str:
return format_sse(
attach_emitted_by({"type": "text-start", "id": text_id}, emitter)
)
return format_sse(attach_emitted_by({"type": "text-start", "id": text_id}, emitter))
def format_text_delta(
@ -26,6 +24,4 @@ def format_text_delta(
def format_text_end(text_id: str, *, emitter: Emitter | None = None) -> str:
return format_sse(
attach_emitted_by({"type": "text-end", "id": text_id}, emitter)
)
return format_sse(attach_emitted_by({"type": "text-end", "id": text_id}, emitter))

View file

@ -84,9 +84,7 @@ class StreamingService:
def format_step_finish(self, *, emitter: Emitter | None = None) -> str:
return lifecycle.format_step_finish(emitter=emitter)
def format_text_start(
self, text_id: str, *, emitter: Emitter | None = None
) -> str:
def format_text_start(self, text_id: str, *, emitter: Emitter | None = None) -> str:
return text.format_text_start(text_id, emitter=emitter)
def format_text_delta(
@ -94,9 +92,7 @@ class StreamingService:
) -> str:
return text.format_text_delta(text_id, delta, emitter=emitter)
def format_text_end(
self, text_id: str, *, emitter: Emitter | None = None
) -> str:
def format_text_end(self, text_id: str, *, emitter: Emitter | None = None) -> str:
return text.format_text_end(text_id, emitter=emitter)
def format_reasoning_start(

View file

@ -51,7 +51,9 @@ logger = logging.getLogger(__name__)
_MEANINGFUL_PART_TYPES: frozenset[str] = frozenset({"text", "reasoning", "tool-call"})
def _merge_tool_part_metadata(part: dict[str, Any], metadata: dict[str, Any] | None) -> None:
def _merge_tool_part_metadata(
part: dict[str, Any], metadata: dict[str, Any] | None
) -> None:
"""Shallow-merge ``metadata`` into ``part["metadata"]``; first key wins.
Used for tool-call linkage (``spanId``, ``thinkingStepId``, ): a later

View file

@ -109,17 +109,18 @@ def _build_user_content(
[{"type": "text", "text": "..."},
{"type": "image", "image": "data:..."},
{"type": "mentioned-documents", "documents": [{"id": int,
"title": str, "document_type": str}, ...]}]
"title": str, "document_type": str, "kind": "doc" | "folder"},
...]}]
The companion reader is
``app.utils.user_message_multimodal.split_persisted_user_content_parts``
which expects exactly this shape keep them in sync.
``mentioned_documents``: optional list of ``{id, title, document_type}``
dicts. When non-empty (and a ``mentioned-documents`` part is not already
in some other input shape), a single ``{"type": "mentioned-documents",
"documents": [...]}`` part is appended. Mirrors the FE injection at
``page.tsx:281-286`` (``persistUserTurn``).
``mentioned_documents``: optional list of mention chip dicts. Each
dict may include a ``kind`` discriminator (``"doc"`` or ``"folder"``)
so the persisted ContentPart round-trips folder chips on reload.
When ``kind`` is missing we default to ``"doc"`` so legacy clients
that haven't migrated to the union schema still persist correctly.
"""
parts: list[dict[str, Any]] = [{"type": "text", "text": user_query or ""}]
for url in user_image_data_urls or ():
@ -135,11 +136,14 @@ def _build_user_content(
document_type = doc.get("document_type")
if doc_id is None or title is None or document_type is None:
continue
kind_raw = doc.get("kind", "doc")
kind = kind_raw if kind_raw in ("doc", "folder") else "doc"
normalized.append(
{
"id": doc_id,
"title": str(title),
"document_type": str(document_type),
"kind": kind,
}
)
if normalized:

View file

@ -43,6 +43,7 @@ from app.agents.new_chat.memory_extraction import (
extract_and_save_memory,
extract_and_save_team_memory,
)
from app.agents.new_chat.mention_resolver import resolve_mentions, substitute_in_text
from app.agents.new_chat.middleware.busy_mutex import (
end_turn,
get_cancel_state,
@ -929,6 +930,7 @@ async def stream_new_chat(
llm_config_id: int = -1,
mentioned_document_ids: list[int] | None = None,
mentioned_surfsense_doc_ids: list[int] | None = None,
mentioned_folder_ids: list[int] | None = None,
mentioned_documents: list[dict[str, Any]] | None = None,
checkpoint_id: str | None = None,
needs_history_bootstrap: bool = False,
@ -958,6 +960,7 @@ async def stream_new_chat(
needs_history_bootstrap: If True, load message history from DB (for cloned chats)
mentioned_document_ids: Optional list of document IDs mentioned with @ in the chat
mentioned_surfsense_doc_ids: Optional list of SurfSense doc IDs mentioned with @ in the chat
mentioned_folder_ids: Optional list of knowledge-base folder IDs mentioned with @ (cloud mode)
checkpoint_id: Optional checkpoint ID to rewind/fork from (for edit/reload operations)
Yields:
@ -1502,6 +1505,53 @@ async def stream_new_chat(
)
recent_reports = list(recent_reports_result.scalars().all())
# Resolve @-mention chips to canonical virtual paths and rewrite
# the user-typed text so the LLM sees ``\`/documents/...\``` instead
# of bare ``@title``. The persisted user-message text keeps
# ``@title`` so chip rendering on reload is unchanged — see
# ``persistence._build_user_content``.
#
# Cloud mode only: local-folder mode keeps the legacy
# ``@title`` text path; mention support there is a follow-up
# task because the path scheme (mount-rooted) and the picker
# UI both need separate work.
accepted_folder_ids: list[int] = []
if fs_mode == FilesystemMode.CLOUD.value and (
mentioned_document_ids
or mentioned_surfsense_doc_ids
or mentioned_folder_ids
or mentioned_documents
):
from app.schemas.new_chat import (
MentionedDocumentInfo as _MentionedDocumentInfo,
)
chip_objs: list[_MentionedDocumentInfo] | None = None
if mentioned_documents:
chip_objs = []
for raw in mentioned_documents:
if isinstance(raw, _MentionedDocumentInfo):
chip_objs.append(raw)
continue
try:
chip_objs.append(_MentionedDocumentInfo.model_validate(raw))
except Exception:
logger.debug(
"stream_new_chat: dropping malformed mention chip %r",
raw,
)
resolved = await resolve_mentions(
session,
search_space_id=search_space_id,
mentioned_documents=chip_objs,
mentioned_document_ids=mentioned_document_ids,
mentioned_surfsense_doc_ids=mentioned_surfsense_doc_ids,
mentioned_folder_ids=mentioned_folder_ids,
)
user_query = substitute_in_text(user_query, resolved.token_to_path)
accepted_folder_ids = resolved.mentioned_folder_ids
# Format the user query with context (SurfSense docs + reports only)
final_query = user_query
context_parts = []
@ -1901,6 +1951,9 @@ async def stream_new_chat(
runtime_context = SurfSenseContextSchema(
search_space_id=search_space_id,
mentioned_document_ids=list(mentioned_document_ids or []),
mentioned_folder_ids=list(
accepted_folder_ids or mentioned_folder_ids or []
),
request_id=request_id,
turn_id=stream_result.turn_id,
)

View file

@ -26,9 +26,7 @@ def handle_report_progress(
return None, last_active_step_items
phase = data.get("phase", "")
topic_items = [
item for item in last_active_step_items if item.startswith("Topic:")
]
topic_items = [item for item in last_active_step_items if item.startswith("Topic:")]
if phase in ("revising_section", "adding_section"):
plan_items = [
@ -56,7 +54,9 @@ def handle_report_progress(
return frame, new_items
def handle_document_created(data: dict[str, Any], *, streaming_service: Any) -> str | None:
def handle_document_created(
data: dict[str, Any], *, streaming_service: Any
) -> str | None:
if not data.get("id"):
return None
return streaming_service.format_data(

View file

@ -13,7 +13,9 @@ from app.tasks.chat.streaming.handlers.tools import (
)
from app.tasks.chat.streaming.helpers.tool_output import tool_output_has_error
from app.tasks.chat.streaming.relay.state import AgentEventRelayState
from app.tasks.chat.streaming.relay.task_span import clear_task_span_if_delegating_task_ended
from app.tasks.chat.streaming.relay.task_span import (
clear_task_span_if_delegating_task_ended,
)
from app.tasks.chat.streaming.relay.thinking_step_sse import emit_thinking_step_frame
@ -32,9 +34,7 @@ def iter_tool_end_frames(
run_id = event.get("run_id", "")
tool_name = event.get("name", "unknown_tool")
raw_output = event.get("data", {}).get("output", "")
staged_file_path = (
state.file_path_by_run.pop(run_id, None) if run_id else None
)
staged_file_path = state.file_path_by_run.pop(run_id, None) if run_id else None
if tool_name == "update_memory":
state.called_update_memory = True
@ -116,6 +116,4 @@ def iter_tool_end_frames(
)
yield from iter_tool_completion_emission_frames(emission_ctx)
clear_task_span_if_delegating_task_ended(
state, tool_name=tool_name, run_id=run_id
)
clear_task_span_if_delegating_task_ended(state, tool_name=tool_name, run_id=run_id)

View file

@ -15,7 +15,9 @@ def resolve_start_thinking(tool_name: str, tool_input: Any) -> ToolStartThinking
def resolve_completed_thinking(
tool_name: str, tool_output: Any, last_items: list[str],
tool_name: str,
tool_output: Any,
last_items: list[str],
) -> tuple[str, list[str]]:
return default_thinking.resolve_completed_thinking(
tool_name, tool_output, last_items

View file

@ -23,7 +23,9 @@ def resolve_start_thinking(tool_name: str, tool_input: Any) -> ToolStartThinking
def resolve_completed_thinking(
tool_name: str, tool_output: Any, last_items: list[str],
tool_name: str,
tool_output: Any,
last_items: list[str],
) -> tuple[str, list[str]]:
del tool_name
items = last_items

View file

@ -34,7 +34,9 @@ def resolve_start_thinking(tool_name: str, tool_input: Any) -> ToolStartThinking
def resolve_completed_thinking(
tool_name: str, tool_output: Any, last_items: list[str],
tool_name: str,
tool_output: Any,
last_items: list[str],
) -> tuple[str, list[str]]:
del tool_name
items = last_items

View file

@ -29,7 +29,9 @@ def resolve_start_thinking(tool_name: str, tool_input: Any) -> ToolStartThinking
def resolve_completed_thinking(
tool_name: str, tool_output: Any, last_items: list[str],
tool_name: str,
tool_output: Any,
last_items: list[str],
) -> tuple[str, list[str]]:
del tool_name
items = last_items
@ -44,9 +46,7 @@ def resolve_completed_thinking(
else "Report"
)
word_count = (
tool_output.get("word_count", 0)
if isinstance(tool_output, dict)
else 0
tool_output.get("word_count", 0) if isinstance(tool_output, dict) else 0
)
is_revision = (
tool_output.get("is_revision", False)

View file

@ -17,7 +17,9 @@ def resolve_start_thinking(tool_name: str, tool_input: Any) -> ToolStartThinking
def resolve_completed_thinking(
tool_name: str, tool_output: Any, last_items: list[str],
tool_name: str,
tool_output: Any,
last_items: list[str],
) -> tuple[str, list[str]]:
return default_thinking.resolve_completed_thinking(
tool_name, tool_output, last_items

View file

@ -17,7 +17,9 @@ def resolve_start_thinking(tool_name: str, tool_input: Any) -> ToolStartThinking
def resolve_completed_thinking(
tool_name: str, tool_output: Any, last_items: list[str],
tool_name: str,
tool_output: Any,
last_items: list[str],
) -> tuple[str, list[str]]:
del tool_name
items = last_items

View file

@ -21,7 +21,9 @@ def resolve_start_thinking(tool_name: str, tool_input: Any) -> ToolStartThinking
def resolve_completed_thinking(
tool_name: str, tool_output: Any, last_items: list[str],
tool_name: str,
tool_output: Any,
last_items: list[str],
) -> tuple[str, list[str]]:
del tool_name
items = last_items

View file

@ -21,7 +21,9 @@ def resolve_start_thinking(tool_name: str, tool_input: Any) -> ToolStartThinking
def resolve_completed_thinking(
tool_name: str, tool_output: Any, last_items: list[str],
tool_name: str,
tool_output: Any,
last_items: list[str],
) -> tuple[str, list[str]]:
del tool_output, tool_name
return ("Editing file", last_items)

View file

@ -24,9 +24,7 @@ def iter_completion_emission_frames(
output_text = om.group(1) if om else ""
thread_id_str = ctx.langgraph_config.get("configurable", {}).get("thread_id", "")
for sf_match in re.finditer(
r"^SANDBOX_FILE:\s*(.+)$", output_text, re.MULTILINE
):
for sf_match in re.finditer(r"^SANDBOX_FILE:\s*(.+)$", output_text, re.MULTILINE):
fpath = sf_match.group(1).strip()
if fpath and fpath not in ctx.stream_result.sandbox_files:
ctx.stream_result.sandbox_files.append(fpath)

View file

@ -22,7 +22,9 @@ def resolve_start_thinking(tool_name: str, tool_input: Any) -> ToolStartThinking
def resolve_completed_thinking(
tool_name: str, tool_output: Any, last_items: list[str],
tool_name: str,
tool_output: Any,
last_items: list[str],
) -> tuple[str, list[str]]:
del tool_name
items = last_items

View file

@ -21,7 +21,9 @@ def resolve_start_thinking(tool_name: str, tool_input: Any) -> ToolStartThinking
def resolve_completed_thinking(
tool_name: str, tool_output: Any, last_items: list[str],
tool_name: str,
tool_output: Any,
last_items: list[str],
) -> tuple[str, list[str]]:
del tool_output, tool_name
return ("Searching files", last_items)

View file

@ -25,7 +25,9 @@ def resolve_start_thinking(tool_name: str, tool_input: Any) -> ToolStartThinking
def resolve_completed_thinking(
tool_name: str, tool_output: Any, last_items: list[str],
tool_name: str,
tool_output: Any,
last_items: list[str],
) -> tuple[str, list[str]]:
del tool_output, tool_name
return ("Searching content", last_items)

View file

@ -20,7 +20,9 @@ def resolve_start_thinking(tool_name: str, tool_input: Any) -> ToolStartThinking
def resolve_completed_thinking(
tool_name: str, tool_output: Any, last_items: list[str],
tool_name: str,
tool_output: Any,
last_items: list[str],
) -> tuple[str, list[str]]:
del tool_name
if isinstance(tool_output, dict):
@ -38,9 +40,7 @@ def resolve_completed_thinking(
paths = [str(p) for p in parsed]
except (ValueError, SyntaxError):
paths = [
line.strip()
for line in ls_output.strip().split("\n")
if line.strip()
line.strip() for line in ls_output.strip().split("\n") if line.strip()
]
for p in paths:
name = p.rstrip("/").split("/")[-1]

View file

@ -17,11 +17,15 @@ def resolve_start_thinking(tool_name: str, tool_input: Any) -> ToolStartThinking
d = as_tool_input_dict(tool_input)
p = d.get("path", "") if isinstance(tool_input, dict) else str(tool_input)
display = p if len(p) <= 80 else "" + p[-77:]
return ToolStartThinking(title="Creating folder", items=[display] if display else [])
return ToolStartThinking(
title="Creating folder", items=[display] if display else []
)
def resolve_completed_thinking(
tool_name: str, tool_output: Any, last_items: list[str],
tool_name: str,
tool_output: Any,
last_items: list[str],
) -> tuple[str, list[str]]:
del tool_output, tool_name
return ("Creating folder", last_items)

View file

@ -27,7 +27,9 @@ def resolve_start_thinking(tool_name: str, tool_input: Any) -> ToolStartThinking
def resolve_completed_thinking(
tool_name: str, tool_output: Any, last_items: list[str],
tool_name: str,
tool_output: Any,
last_items: list[str],
) -> tuple[str, list[str]]:
del tool_output, tool_name
return ("Moving file", last_items)

View file

@ -21,7 +21,9 @@ def resolve_start_thinking(tool_name: str, tool_input: Any) -> ToolStartThinking
def resolve_completed_thinking(
tool_name: str, tool_output: Any, last_items: list[str],
tool_name: str,
tool_output: Any,
last_items: list[str],
) -> tuple[str, list[str]]:
del tool_output, tool_name
return ("Reading file", last_items)

View file

@ -22,7 +22,9 @@ def resolve_start_thinking(tool_name: str, tool_input: Any) -> ToolStartThinking
def resolve_completed_thinking(
tool_name: str, tool_output: Any, last_items: list[str],
tool_name: str,
tool_output: Any,
last_items: list[str],
) -> tuple[str, list[str]]:
del tool_output, tool_name
return ("Deleting file", last_items)

View file

@ -17,11 +17,15 @@ def resolve_start_thinking(tool_name: str, tool_input: Any) -> ToolStartThinking
d = as_tool_input_dict(tool_input)
p = d.get("path", "") if isinstance(tool_input, dict) else str(tool_input)
display = p if len(p) <= 80 else "" + p[-77:]
return ToolStartThinking(title="Deleting folder", items=[display] if display else [])
return ToolStartThinking(
title="Deleting folder", items=[display] if display else []
)
def resolve_completed_thinking(
tool_name: str, tool_output: Any, last_items: list[str],
tool_name: str,
tool_output: Any,
last_items: list[str],
) -> tuple[str, list[str]]:
del tool_output, tool_name
return ("Deleting folder", last_items)

View file

@ -21,7 +21,9 @@ def resolve_start_thinking(tool_name: str, tool_input: Any) -> ToolStartThinking
def resolve_completed_thinking(
tool_name: str, tool_output: Any, last_items: list[str],
tool_name: str,
tool_output: Any,
last_items: list[str],
) -> tuple[str, list[str]]:
del tool_output, tool_name
return ("Writing file", last_items)

View file

@ -20,15 +20,15 @@ def resolve_start_thinking(tool_name: str, tool_input: Any) -> ToolStartThinking
return ToolStartThinking(
title="Planning tasks",
items=(
[f"{todo_count} task{'s' if todo_count != 1 else ''}"]
if todo_count
else []
[f"{todo_count} task{'s' if todo_count != 1 else ''}"] if todo_count else []
),
)
def resolve_completed_thinking(
tool_name: str, tool_output: Any, last_items: list[str],
tool_name: str,
tool_output: Any,
last_items: list[str],
) -> tuple[str, list[str]]:
del tool_output, tool_name
return ("Planning tasks", last_items)

View file

@ -58,14 +58,18 @@ def _emission_module(tool_name: str) -> str:
def _import_thinking(tool_name: str):
try:
return importlib.import_module(f"{_BASE}.{_thinking_module(tool_name)}.thinking")
return importlib.import_module(
f"{_BASE}.{_thinking_module(tool_name)}.thinking"
)
except ModuleNotFoundError:
return importlib.import_module(f"{_BASE}.default.thinking")
def _import_emission(tool_name: str):
try:
return importlib.import_module(f"{_BASE}.{_emission_module(tool_name)}.emission")
return importlib.import_module(
f"{_BASE}.{_emission_module(tool_name)}.emission"
)
except ModuleNotFoundError:
return importlib.import_module(f"{_BASE}.default.emission")

View file

@ -23,7 +23,9 @@ def resolve_start_thinking(tool_name: str, tool_input: Any) -> ToolStartThinking
def resolve_completed_thinking(
tool_name: str, tool_output: Any, last_items: list[str],
tool_name: str,
tool_output: Any,
last_items: list[str],
) -> tuple[str, list[str]]:
del tool_name
items = last_items

View file

@ -28,11 +28,7 @@ def iter_completion_emission_frames(
xml,
):
chunk_url, content = m.group(1).strip(), m.group(2).strip()
if (
chunk_url.startswith("http")
and chunk_url in citations
and content
):
if chunk_url.startswith("http") and chunk_url in citations and content:
citations[chunk_url]["snippet"] = (
content[:200] + "" if len(content) > 200 else content
)

View file

@ -367,18 +367,26 @@ class TestPersistUserTurn:
db_thread,
patched_shielded_session,
):
"""The full ``{id, title, document_type}`` triple forwarded by
the FE must round-trip into a single ``mentioned-documents``
ContentPart on the persisted user message the history loader
renders the chips on reload from this part directly.
"""The full ``{id, title, document_type, kind}`` chip metadata
forwarded by the FE must round-trip into a single
``mentioned-documents`` ContentPart on the persisted user
message the history loader renders the chips on reload from
this part directly. Folder chips ride alongside doc chips so
the FE can render mixed mention bars without a second fetch.
"""
thread_id = db_thread.id
user_id_str = str(db_user.id)
turn_id = f"{thread_id}:8200"
mentioned = [
{"id": 11, "title": "Alpha", "document_type": "GENERAL"},
{"id": 22, "title": "Beta", "document_type": "GENERAL"},
{"id": 11, "title": "Alpha", "document_type": "GENERAL", "kind": "doc"},
{"id": 22, "title": "Beta", "document_type": "GENERAL", "kind": "doc"},
{
"id": 33,
"title": "Reports",
"document_type": "FOLDER",
"kind": "folder",
},
]
msg_id = await persist_user_turn(
chat_id=thread_id,
@ -397,8 +405,61 @@ class TestPersistUserTurn:
assert row.content[1] == {
"type": "mentioned-documents",
"documents": [
{"id": 11, "title": "Alpha", "document_type": "GENERAL"},
{"id": 22, "title": "Beta", "document_type": "GENERAL"},
{"id": 11, "title": "Alpha", "document_type": "GENERAL", "kind": "doc"},
{"id": 22, "title": "Beta", "document_type": "GENERAL", "kind": "doc"},
{
"id": 33,
"title": "Reports",
"document_type": "FOLDER",
"kind": "folder",
},
],
}
async def test_legacy_chip_without_kind_defaults_to_doc(
self,
db_session,
db_user,
db_thread,
patched_shielded_session,
):
"""Pre-folder clients send chips without ``kind``. The persistence
layer defaults them to ``"doc"`` so the round-trip stays
consistent on reload the FE schema's optional default
produces the same value, but persisting it explicitly keeps
the DB row self-describing.
"""
thread_id = db_thread.id
user_id_str = str(db_user.id)
turn_id = f"{thread_id}:8201"
mentioned = [
{"id": 77, "title": "Legacy", "document_type": "GENERAL"},
]
msg_id = await persist_user_turn(
chat_id=thread_id,
user_id=user_id_str,
turn_id=turn_id,
user_query="hi",
mentioned_documents=mentioned,
)
assert isinstance(msg_id, int)
row = await db_session.get(NewChatMessage, msg_id)
assert row is not None
assert isinstance(row.content, list)
mentioned_part = next(
p for p in row.content if p.get("type") == "mentioned-documents"
)
assert mentioned_part == {
"type": "mentioned-documents",
"documents": [
{
"id": 77,
"title": "Legacy",
"document_type": "GENERAL",
"kind": "doc",
},
],
}

View file

@ -0,0 +1,285 @@
"""Tests for the @-mention resolver.
These tests pin down the contract that ``mention_resolver`` is the
single seam between ``MentionedDocumentInfo`` chips (frontend) and the
canonical ``/documents/...`` virtual paths (agent). The streaming task,
priority middleware, and persistence layer all consume the resolver's
output keeping the tests focused on substitute-in-text + the
returned id partition keeps the seam stable across refactors.
"""
from __future__ import annotations
from types import SimpleNamespace
from unittest.mock import AsyncMock, MagicMock
import pytest
from app.agents.new_chat import mention_resolver
from app.agents.new_chat.mention_resolver import (
ResolvedMention,
ResolvedMentionSet,
resolve_mentions,
substitute_in_text,
)
from app.agents.new_chat.path_resolver import DOCUMENTS_ROOT, PathIndex
from app.schemas.new_chat import MentionedDocumentInfo
pytestmark = pytest.mark.unit
class TestSubstituteInText:
"""``substitute_in_text`` is a pure string transform and is exercised
on every cloud-mode turn, so it has to be both fast and behaviour-
identical to the frontend's ``parseMentionSegments`` (longest-token
first, single forward pass)."""
def test_returns_text_unchanged_when_no_tokens(self):
assert substitute_in_text("hello @foo", []) == "hello @foo"
def test_returns_text_unchanged_when_empty(self):
assert substitute_in_text("", [("@x", "/documents/x.xml")]) == ""
def test_replaces_single_token_with_backtick_path(self):
out = substitute_in_text(
"see @notes please",
[("@notes", "/documents/notes.xml")],
)
assert out == "see `/documents/notes.xml` please"
def test_longest_token_wins_over_prefix(self):
# ``@Project Roadmap`` must NOT be partially matched by ``@Project``.
# Mirrors the FE's parseMentionSegments contract.
token_to_path = [
("@Project Roadmap", "/documents/Roadmap.xml"),
("@Project", "/documents/Project.xml"),
]
out = substitute_in_text("about @Project Roadmap today", token_to_path)
assert out == "about `/documents/Roadmap.xml` today"
def test_handles_repeated_mentions(self):
out = substitute_in_text(
"@A and @A again @B",
[
("@A", "/documents/a.xml"),
("@B", "/documents/b.xml"),
],
)
assert (
out == "`/documents/a.xml` and `/documents/a.xml` again `/documents/b.xml`"
)
def test_does_not_match_inside_word(self):
# Substitution is positional — there's no word-boundary semantics.
# ``@Pro`` inside ``foo@Project`` still matches; this is the same
# behaviour as parseMentionSegments. The test pins it so a
# future "fix" doesn't accidentally diverge between FE/BE.
out = substitute_in_text("foo@Pro", [("@Pro", "/documents/p.xml")])
assert out == "foo`/documents/p.xml`"
def test_idempotent_after_substitution(self):
# The output starts with a backtick, not ``@``, so re-running
# the substitution leaves it alone.
once = substitute_in_text("@A", [("@A", "/documents/a.xml")])
twice = substitute_in_text(once, [("@A", "/documents/a.xml")])
assert once == twice
class TestResolveMentions:
"""``resolve_mentions`` resolves chip ids → virtual paths and emits
a ``ResolvedMentionSet`` whose id partitions feed
``KnowledgePriorityMiddleware``."""
@pytest.mark.asyncio
async def test_returns_empty_when_no_mentions(self):
session = MagicMock()
session.execute = AsyncMock()
result = await resolve_mentions(
session,
search_space_id=1,
mentioned_documents=None,
)
assert isinstance(result, ResolvedMentionSet)
assert result.mentions == []
assert result.token_to_path == []
assert result.mentioned_document_ids == []
assert result.mentioned_folder_ids == []
# No DB roundtrips when there's nothing to resolve.
session.execute.assert_not_awaited()
@pytest.mark.asyncio
async def test_resolves_doc_chip_to_virtual_path(self, monkeypatch):
chip = MentionedDocumentInfo(
id=42,
title="Notes",
document_type="EXTENSION",
kind="doc",
)
doc_row = SimpleNamespace(id=42, title="Notes", folder_id=None)
async def fake_build_index(_session, _ssid):
return PathIndex()
monkeypatch.setattr(mention_resolver, "build_path_index", fake_build_index)
scalars = MagicMock()
scalars.all.return_value = [doc_row]
result = MagicMock()
result.scalars.return_value = scalars
session = MagicMock()
session.execute = AsyncMock(return_value=result)
out = await resolve_mentions(
session,
search_space_id=5,
mentioned_documents=[chip],
)
assert len(out.mentions) == 1
mention = out.mentions[0]
assert mention.kind == "doc"
assert mention.id == 42
assert mention.virtual_path == f"{DOCUMENTS_ROOT}/Notes.xml"
assert out.mentioned_document_ids == [42]
assert out.mentioned_folder_ids == []
assert ("@Notes", f"{DOCUMENTS_ROOT}/Notes.xml") in out.token_to_path
@pytest.mark.asyncio
async def test_resolves_folder_chip_with_trailing_slash(self, monkeypatch):
chip = MentionedDocumentInfo(
id=9,
title="Reports",
document_type="FOLDER",
kind="folder",
)
folder_row = SimpleNamespace(id=9, name="Reports")
async def fake_build_index(_session, _ssid):
return PathIndex(folder_paths={9: f"{DOCUMENTS_ROOT}/Reports"})
monkeypatch.setattr(mention_resolver, "build_path_index", fake_build_index)
scalars = MagicMock()
scalars.all.return_value = [folder_row]
result = MagicMock()
result.scalars.return_value = scalars
session = MagicMock()
session.execute = AsyncMock(return_value=result)
out = await resolve_mentions(
session,
search_space_id=3,
mentioned_documents=[chip],
)
assert len(out.mentions) == 1
mention = out.mentions[0]
assert mention.kind == "folder"
assert mention.id == 9
assert mention.virtual_path == f"{DOCUMENTS_ROOT}/Reports/"
assert out.mentioned_document_ids == []
assert out.mentioned_folder_ids == [9]
@pytest.mark.asyncio
async def test_drops_chip_when_doc_is_missing(self, monkeypatch):
chip = MentionedDocumentInfo(
id=99, title="ghost", document_type="EXTENSION", kind="doc"
)
async def fake_build_index(_session, _ssid):
return PathIndex()
monkeypatch.setattr(mention_resolver, "build_path_index", fake_build_index)
scalars = MagicMock()
scalars.all.return_value = []
result = MagicMock()
result.scalars.return_value = scalars
session = MagicMock()
session.execute = AsyncMock(return_value=result)
out = await resolve_mentions(
session,
search_space_id=1,
mentioned_documents=[chip],
)
assert out.mentions == []
assert out.mentioned_document_ids == []
assert out.token_to_path == []
@pytest.mark.asyncio
async def test_token_to_path_is_longest_first(self, monkeypatch):
# Two chips whose titles are prefixes of each other — the
# resolver MUST sort longest-first so substitution doesn't
# break the ``@Project Roadmap`` vs ``@Project`` invariant.
chip_short = MentionedDocumentInfo(
id=1, title="A", document_type="EXTENSION", kind="doc"
)
chip_long = MentionedDocumentInfo(
id=2, title="A long one", document_type="EXTENSION", kind="doc"
)
rows = [
SimpleNamespace(id=1, title="A", folder_id=None),
SimpleNamespace(id=2, title="A long one", folder_id=None),
]
async def fake_build_index(_session, _ssid):
return PathIndex()
monkeypatch.setattr(mention_resolver, "build_path_index", fake_build_index)
scalars = MagicMock()
scalars.all.return_value = rows
result = MagicMock()
result.scalars.return_value = scalars
session = MagicMock()
session.execute = AsyncMock(return_value=result)
out = await resolve_mentions(
session,
search_space_id=1,
mentioned_documents=[chip_short, chip_long],
)
tokens = [tok for tok, _ in out.token_to_path]
assert tokens == sorted(tokens, key=len, reverse=True)
@pytest.mark.asyncio
async def test_legacy_id_arrays_resolve_without_chip_metadata(self, monkeypatch):
# ``mentioned_document_ids`` (the legacy parallel array) must
# still resolve when no chip metadata is available — covers
# callers that haven't migrated to the discriminated chip list.
doc_row = SimpleNamespace(id=7, title="Legacy", folder_id=None)
async def fake_build_index(_session, _ssid):
return PathIndex()
monkeypatch.setattr(mention_resolver, "build_path_index", fake_build_index)
scalars = MagicMock()
scalars.all.return_value = [doc_row]
result = MagicMock()
result.scalars.return_value = scalars
session = MagicMock()
session.execute = AsyncMock(return_value=result)
out = await resolve_mentions(
session,
search_space_id=2,
mentioned_documents=None,
mentioned_document_ids=[7],
)
assert out.mentioned_document_ids == [7]
assert len(out.mentions) == 1
assert out.mentions[0].title == "Legacy"
class TestResolvedMentionEquality:
"""Smoke check on the dataclass behaviour we rely on for asserting
test outputs."""
def test_equal_when_fields_equal(self):
a = ResolvedMention(
kind="doc", id=1, title="x", virtual_path="/documents/x.xml"
)
b = ResolvedMention(
kind="doc", id=1, title="x", virtual_path="/documents/x.xml"
)
assert a == b

View file

@ -196,3 +196,50 @@ class TestVirtualPathToDoc:
)
assert document is target_doc
assert session.execute.await_count == 2
@pytest.mark.asyncio
async def test_resolves_double_extension_for_uploaded_pdf(self):
# Regression: the agent renders every KB document under
# ``/documents/`` with a trailing ``.xml`` (via ``safe_filename``),
# so an uploaded PDF whose DB title is ``2025-W2.pdf`` shows up as
# ``/documents/2025-W2.pdf.xml`` in answers. Clicking that path
# must round-trip back to the row even though the title itself
# does NOT end in ``.xml``.
target_doc = SimpleNamespace(id=99, title="2025-W2.pdf", folder_id=None)
session = MagicMock()
session.execute = AsyncMock(
side_effect=[
_result_from_one(None),
_result_from_scalars([target_doc]),
]
)
document = await virtual_path_to_doc(
session,
search_space_id=5,
virtual_path=f"{DOCUMENTS_ROOT}/2025-W2.pdf.xml",
)
assert document is target_doc
@pytest.mark.asyncio
async def test_resolves_path_without_xml_suffix(self):
# The user (or a hand-edited link) may pass the title-only form
# ``/documents/2025-W2.pdf``. The resolver must still find the row
# by literal title equality.
target_doc = SimpleNamespace(id=99, title="2025-W2.pdf", folder_id=None)
session = MagicMock()
session.execute = AsyncMock(
side_effect=[
_result_from_one(None),
_result_from_scalars([target_doc]),
]
)
document = await virtual_path_to_doc(
session,
search_space_id=5,
virtual_path=f"{DOCUMENTS_ROOT}/2025-W2.pdf",
)
assert document is target_doc

View file

@ -118,12 +118,8 @@ def test_get_by_tool_call_id_matches_action_request_payload() -> None:
tasks=(
_Task(
interrupts=(
_Interrupt(
value=_hitl("a", tool_call_id="call_xxx"), id="int_a"
),
_Interrupt(
value=_hitl("b", tool_call_id="call_yyy"), id="int_b"
),
_Interrupt(value=_hitl("a", tool_call_id="call_xxx"), id="int_a"),
_Interrupt(value=_hitl("b", tool_call_id="call_yyy"), id="int_b"),
)
),
)
@ -146,9 +142,7 @@ def test_first_pending_interrupt_matches_legacy_first_wins_behaviour() -> None:
def test_interrupt_without_id_falls_back_to_none() -> None:
"""Snapshots from older LangGraph versions may omit ``id`` — preserve that."""
state = _State(
tasks=(_Task(interrupts=(_Interrupt(value=_hitl("a"), id=None),)),)
)
state = _State(tasks=(_Task(interrupts=(_Interrupt(value=_hitl("a"), id=None),)),))
pending = list_pending_interrupts(state)
assert len(pending) == 1
assert pending[0].interrupt_id is None

View file

@ -37,9 +37,7 @@ def test_custom_interrupt_primitive_is_converted_to_canonical_shape() -> None:
"context": {"reason": "destructive"},
}
out = normalize_interrupt_payload(raw)
assert out["action_requests"] == [
{"name": "send_email", "args": {"to": "a@b"}}
]
assert out["action_requests"] == [{"name": "send_email", "args": {"to": "a@b"}}]
assert out["review_configs"] == [
{
"action_name": "send_email",

View file

@ -158,9 +158,7 @@ def _classify_cases() -> list[Exception]:
"""Inputs that the FE depends on being mapped to specific error codes."""
return [
Exception("totally generic error"),
Exception(
'{"error":{"type":"rate_limit_error","message":"slow down"}}'
),
Exception('{"error":{"type":"rate_limit_error","message":"slow down"}}'),
Exception(
'OpenrouterException - {"error":{"message":"Provider returned error",'
'"code":429}}'
@ -220,7 +218,7 @@ class _FakeStreamingService:
self.calls.append(
{"message": message, "error_code": error_code, "extra": extra}
)
return f"data: {{\"type\":\"error\",\"errorText\":\"{message}\"}}\n\n"
return f'data: {{"type":"error","errorText":"{message}"}}\n\n'
def test_emit_stream_terminal_error_matches_old_output_and_logs(caplog) -> None:

View file

@ -60,8 +60,14 @@ async def test_stream_output_emits_text_lifecycle_and_updates_result() -> None:
service = _StreamingService()
agent = _Agent(
[
{"event": "on_chat_model_stream", "data": {"chunk": _Chunk(content="Hello")}},
{"event": "on_chat_model_stream", "data": {"chunk": _Chunk(content=" world")}},
{
"event": "on_chat_model_stream",
"data": {"chunk": _Chunk(content="Hello")},
},
{
"event": "on_chat_model_stream",
"data": {"chunk": _Chunk(content=" world")},
},
]
)
result = StreamingResult()

View file

@ -37,7 +37,9 @@ def test_clear_ignored_for_non_task_tool() -> None:
def test_clear_ignored_when_task_run_id_mismatches() -> None:
state = AgentEventRelayState.for_invocation()
open_task_span(state, run_id="run-open")
clear_task_span_if_delegating_task_ended(state, tool_name="task", run_id="run-other")
clear_task_span_if_delegating_task_ended(
state, tool_name="task", run_id="run-other"
)
assert state.active_span_id is not None
assert state.active_task_run_id == "run-open"

View file

@ -240,9 +240,7 @@ class TestToolHeavyTurn:
class TestToolCallSpanMetadata:
def test_input_available_merges_new_metadata_keys_after_start(self):
b = AssistantContentBuilder()
b.on_tool_input_start(
"call_t", "task", "lc_t", metadata={"spanId": "spn_1"}
)
b.on_tool_input_start("call_t", "task", "lc_t", metadata={"spanId": "spn_1"})
b.on_tool_input_available(
"call_t",
"task",
@ -257,9 +255,7 @@ class TestToolCallSpanMetadata:
def test_input_available_does_not_overwrite_existing_metadata_keys(self):
b = AssistantContentBuilder()
b.on_tool_input_start(
"call_t", "task", "lc_t", metadata={"spanId": "spn_keep"}
)
b.on_tool_input_start("call_t", "task", "lc_t", metadata={"spanId": "spn_keep"})
b.on_tool_input_available(
"call_t", "task", {}, "lc_t", metadata={"spanId": "spn_other"}
)