diff --git a/.gitignore b/.gitignore index b45b1961c..2e6ed14e8 100644 --- a/.gitignore +++ b/.gitignore @@ -7,4 +7,5 @@ node_modules/ .pnpm-store .DS_Store deepagents/ -debug.log \ No newline at end of file +debug.log +opencode/ \ No newline at end of file diff --git a/surfsense_backend/.env.example b/surfsense_backend/.env.example index 86bac0aaf..c1bfcc538 100644 --- a/surfsense_backend/.env.example +++ b/surfsense_backend/.env.example @@ -247,3 +247,42 @@ LANGSMITH_TRACING=true LANGSMITH_ENDPOINT=https://api.smith.langchain.com LANGSMITH_API_KEY=lsv2_pt_..... LANGSMITH_PROJECT=surfsense + + +# ============================================================================= +# OPTIONAL: New-chat agent feature flags +# ============================================================================= +# Master kill-switch — when true, every flag below is forced OFF. +# SURFSENSE_DISABLE_NEW_AGENT_STACK=false + +# Agent quality +# SURFSENSE_ENABLE_CONTEXT_EDITING=false +# SURFSENSE_ENABLE_COMPACTION_V2=false +# SURFSENSE_ENABLE_RETRY_AFTER=false +# SURFSENSE_ENABLE_MODEL_FALLBACK=false +# SURFSENSE_ENABLE_MODEL_CALL_LIMIT=false +# SURFSENSE_ENABLE_TOOL_CALL_LIMIT=false +# SURFSENSE_ENABLE_TOOL_CALL_REPAIR=false +# SURFSENSE_ENABLE_DOOM_LOOP=false # leave OFF until UI handles permission='doom_loop' + +# Safety +# SURFSENSE_ENABLE_PERMISSION=false +# SURFSENSE_ENABLE_BUSY_MUTEX=false +# SURFSENSE_ENABLE_LLM_TOOL_SELECTOR=false # adds a per-turn LLM call + +# Observability — OTel (also requires OTEL_EXPORTER_OTLP_ENDPOINT) +# SURFSENSE_ENABLE_OTEL=false + +# Skills + subagents +# SURFSENSE_ENABLE_SKILLS=false +# SURFSENSE_ENABLE_SPECIALIZED_SUBAGENTS=false +# SURFSENSE_ENABLE_KB_PLANNER_RUNNABLE=false + +# Snapshot / revert +# SURFSENSE_ENABLE_ACTION_LOG=false +# SURFSENSE_ENABLE_REVERT_ROUTE=false # Backend-only; flip when UI ships + +# Plugins +# SURFSENSE_ENABLE_PLUGIN_LOADER=false +# Comma-separated allowlist of plugin entry-point names +# SURFSENSE_ALLOWED_PLUGINS=year_substituter diff --git a/surfsense_backend/alembic/versions/130_add_agent_action_log.py b/surfsense_backend/alembic/versions/130_add_agent_action_log.py new file mode 100644 index 000000000..f86a8a3b5 --- /dev/null +++ b/surfsense_backend/alembic/versions/130_add_agent_action_log.py @@ -0,0 +1,94 @@ +"""130_add_agent_action_log + +Revision ID: 130 +Revises: 129 +Create Date: 2026-04-28 + +Adds the append-only ``agent_action_log`` table that +:class:`ActionLogMiddleware` writes to after every tool call. Each row +optionally carries a ``reverse_descriptor`` payload used by +``POST /api/threads/{thread_id}/revert/{action_id}`` to undo the action. +""" + +from __future__ import annotations + +from collections.abc import Sequence + +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +from alembic import op + +revision: str = "130" +down_revision: str | None = "129" +branch_labels: str | Sequence[str] | None = None +depends_on: str | Sequence[str] | None = None + + +def upgrade() -> None: + op.create_table( + "agent_action_log", + sa.Column("id", sa.Integer(), primary_key=True, index=True), + sa.Column( + "thread_id", + sa.Integer(), + sa.ForeignKey("new_chat_threads.id", ondelete="CASCADE"), + nullable=False, + index=True, + ), + sa.Column( + "user_id", + postgresql.UUID(as_uuid=True), + sa.ForeignKey("user.id", ondelete="SET NULL"), + nullable=True, + index=True, + ), + sa.Column( + "search_space_id", + sa.Integer(), + sa.ForeignKey("searchspaces.id", ondelete="CASCADE"), + nullable=False, + index=True, + ), + sa.Column("turn_id", sa.String(length=64), nullable=True, index=True), + sa.Column("message_id", sa.String(length=128), nullable=True, index=True), + sa.Column("tool_name", sa.String(length=255), nullable=False, index=True), + sa.Column("args", postgresql.JSONB(astext_type=sa.Text()), nullable=True), + sa.Column("result_id", sa.String(length=255), nullable=True), + sa.Column( + "reversible", + sa.Boolean(), + nullable=False, + server_default=sa.text("false"), + ), + sa.Column( + "reverse_descriptor", + postgresql.JSONB(astext_type=sa.Text()), + nullable=True, + ), + sa.Column("error", postgresql.JSONB(astext_type=sa.Text()), nullable=True), + sa.Column( + "reverse_of", + sa.Integer(), + sa.ForeignKey("agent_action_log.id", ondelete="SET NULL"), + nullable=True, + index=True, + ), + sa.Column( + "created_at", + sa.TIMESTAMP(timezone=True), + nullable=False, + server_default=sa.text("(now() AT TIME ZONE 'utc')"), + index=True, + ), + ) + op.create_index( + "ix_agent_action_log_thread_created", + "agent_action_log", + ["thread_id", "created_at"], + ) + + +def downgrade() -> None: + op.drop_index("ix_agent_action_log_thread_created", table_name="agent_action_log") + op.drop_table("agent_action_log") diff --git a/surfsense_backend/alembic/versions/131_add_document_revisions.py b/surfsense_backend/alembic/versions/131_add_document_revisions.py new file mode 100644 index 000000000..95ce0e032 --- /dev/null +++ b/surfsense_backend/alembic/versions/131_add_document_revisions.py @@ -0,0 +1,119 @@ +"""131_add_document_revisions + +Revision ID: 131 +Revises: 130 +Create Date: 2026-04-28 + +Adds two snapshot tables that back the per-action revert flow: + +* ``document_revisions``: pre-mutation snapshot of NOTE/FILE/EXTENSION docs. +* ``folder_revisions``: pre-mutation snapshot of folder mkdir/move/delete. + +Both are written by :class:`KnowledgeBasePersistenceMiddleware` ahead of +state-changing tool calls and consumed by ``revert_service.revert_action``. +""" + +from __future__ import annotations + +from collections.abc import Sequence + +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +from alembic import op + +revision: str = "131" +down_revision: str | None = "130" +branch_labels: str | Sequence[str] | None = None +depends_on: str | Sequence[str] | None = None + + +def upgrade() -> None: + op.create_table( + "document_revisions", + sa.Column("id", sa.Integer(), primary_key=True, index=True), + sa.Column( + "document_id", + sa.Integer(), + sa.ForeignKey("documents.id", ondelete="CASCADE"), + nullable=False, + index=True, + ), + sa.Column( + "search_space_id", + sa.Integer(), + sa.ForeignKey("searchspaces.id", ondelete="CASCADE"), + nullable=False, + index=True, + ), + sa.Column("content_before", sa.Text(), nullable=True), + sa.Column("title_before", sa.String(), nullable=True), + sa.Column("folder_id_before", sa.Integer(), nullable=True), + sa.Column( + "chunks_before", postgresql.JSONB(astext_type=sa.Text()), nullable=True + ), + sa.Column( + "metadata_before", postgresql.JSONB(astext_type=sa.Text()), nullable=True + ), + sa.Column( + "created_by_turn_id", sa.String(length=64), nullable=True, index=True + ), + sa.Column( + "agent_action_id", + sa.Integer(), + sa.ForeignKey("agent_action_log.id", ondelete="SET NULL"), + nullable=True, + index=True, + ), + sa.Column( + "created_at", + sa.TIMESTAMP(timezone=True), + nullable=False, + server_default=sa.text("(now() AT TIME ZONE 'utc')"), + index=True, + ), + ) + + op.create_table( + "folder_revisions", + sa.Column("id", sa.Integer(), primary_key=True, index=True), + sa.Column( + "folder_id", + sa.Integer(), + sa.ForeignKey("folders.id", ondelete="CASCADE"), + nullable=False, + index=True, + ), + sa.Column( + "search_space_id", + sa.Integer(), + sa.ForeignKey("searchspaces.id", ondelete="CASCADE"), + nullable=False, + index=True, + ), + sa.Column("name_before", sa.String(length=255), nullable=True), + sa.Column("parent_id_before", sa.Integer(), nullable=True), + sa.Column("position_before", sa.String(length=50), nullable=True), + sa.Column( + "created_by_turn_id", sa.String(length=64), nullable=True, index=True + ), + sa.Column( + "agent_action_id", + sa.Integer(), + sa.ForeignKey("agent_action_log.id", ondelete="SET NULL"), + nullable=True, + index=True, + ), + sa.Column( + "created_at", + sa.TIMESTAMP(timezone=True), + nullable=False, + server_default=sa.text("(now() AT TIME ZONE 'utc')"), + index=True, + ), + ) + + +def downgrade() -> None: + op.drop_table("folder_revisions") + op.drop_table("document_revisions") diff --git a/surfsense_backend/alembic/versions/132_add_agent_permission_rules.py b/surfsense_backend/alembic/versions/132_add_agent_permission_rules.py new file mode 100644 index 000000000..ff5b52e18 --- /dev/null +++ b/surfsense_backend/alembic/versions/132_add_agent_permission_rules.py @@ -0,0 +1,81 @@ +"""132_add_agent_permission_rules + +Revision ID: 132 +Revises: 131 +Create Date: 2026-04-28 + +Adds the persistent ``agent_permission_rules`` table consumed by +:class:`PermissionMiddleware` at agent build time. Rules can be scoped +at search-space (``user_id`` / ``thread_id`` NULL), user-wide +(``user_id`` set, ``thread_id`` NULL), or per-thread (``thread_id`` set). +""" + +from __future__ import annotations + +from collections.abc import Sequence + +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +from alembic import op + +revision: str = "132" +down_revision: str | None = "131" +branch_labels: str | Sequence[str] | None = None +depends_on: str | Sequence[str] | None = None + + +def upgrade() -> None: + op.create_table( + "agent_permission_rules", + sa.Column("id", sa.Integer(), primary_key=True, index=True), + sa.Column( + "search_space_id", + sa.Integer(), + sa.ForeignKey("searchspaces.id", ondelete="CASCADE"), + nullable=False, + index=True, + ), + sa.Column( + "user_id", + postgresql.UUID(as_uuid=True), + sa.ForeignKey("user.id", ondelete="CASCADE"), + nullable=True, + index=True, + ), + sa.Column( + "thread_id", + sa.Integer(), + sa.ForeignKey("new_chat_threads.id", ondelete="CASCADE"), + nullable=True, + index=True, + ), + sa.Column("permission", sa.String(length=255), nullable=False), + sa.Column( + "pattern", + sa.String(length=255), + nullable=False, + server_default="*", + ), + sa.Column("action", sa.String(length=16), nullable=False), + sa.Column( + "created_at", + sa.TIMESTAMP(timezone=True), + nullable=False, + server_default=sa.text("(now() AT TIME ZONE 'utc')"), + index=True, + ), + sa.UniqueConstraint( + "search_space_id", + "user_id", + "thread_id", + "permission", + "pattern", + "action", + name="uq_agent_permission_rules_scope", + ), + ) + + +def downgrade() -> None: + op.drop_table("agent_permission_rules") diff --git a/surfsense_backend/alembic/versions/133_drop_documents_content_hash_unique.py b/surfsense_backend/alembic/versions/133_drop_documents_content_hash_unique.py new file mode 100644 index 000000000..eec53ecb6 --- /dev/null +++ b/surfsense_backend/alembic/versions/133_drop_documents_content_hash_unique.py @@ -0,0 +1,105 @@ +"""133_drop_documents_content_hash_unique + +Revision ID: 133 +Revises: 132 +Create Date: 2026-04-29 + +Drop the global UNIQUE constraint on ``documents.content_hash`` so the +new-chat agent's ``write_file`` flow can persist legitimate file copies +(two paths, identical content) without hitting a constraint that mirrors +no real filesystem semantic. + +Path uniqueness still lives on ``documents.unique_identifier_hash`` (per +search space), which is the right invariant — exactly like an inode at a +given path on a POSIX filesystem. + +The non-unique INDEX on ``content_hash`` is preserved so connector +indexers' "have we seen this content before?" lookup +(:func:`app.tasks.document_processors.base.check_duplicate_document`, +which already uses ``.scalars().first()`` and is therefore tolerant of +duplicates) stays cheap. +""" + +from __future__ import annotations + +from collections.abc import Sequence + +from sqlalchemy import inspect + +from alembic import op + +revision: str = "133" +down_revision: str | None = "132" +branch_labels: str | Sequence[str] | None = None +depends_on: str | Sequence[str] | None = None + + +def _existing_constraint_names(bind, table: str) -> set[str]: + inspector = inspect(bind) + return {c["name"] for c in inspector.get_unique_constraints(table)} + + +def _existing_index_names(bind, table: str) -> set[str]: + inspector = inspect(bind) + return {i["name"] for i in inspector.get_indexes(table)} + + +def upgrade() -> None: + bind = op.get_bind() + + # Both the named UniqueConstraint (added in revision 8) and the + # implicit-unique-index variant SQLAlchemy may emit need draining. + constraints = _existing_constraint_names(bind, "documents") + if "uq_documents_content_hash" in constraints: + op.drop_constraint("uq_documents_content_hash", "documents", type_="unique") + + indexes = _existing_index_names(bind, "documents") + # Some Postgres versions surface the unique constraint via a unique + # index of the same name; check for that too. + for idx_name in ("uq_documents_content_hash",): + if idx_name in indexes: + op.drop_index(idx_name, table_name="documents") + + # Ensure the non-unique index is present for fast lookups. + if "ix_documents_content_hash" not in indexes: + op.create_index( + "ix_documents_content_hash", + "documents", + ["content_hash"], + unique=False, + ) + + +def downgrade() -> None: + bind = op.get_bind() + + # Re-applying UNIQUE is destructive: there may now be legitimate + # duplicates (e.g. two NOTE documents that share content because the + # user explicitly copied one to a new path). To avoid the migration + # silently deleting user data, we keep only the lowest-id row per + # content_hash — same strategy revision 8 used when first introducing + # the constraint. + op.execute( + """ + DELETE FROM documents + WHERE id NOT IN ( + SELECT MIN(id) + FROM documents + GROUP BY content_hash + ) + """ + ) + + indexes = _existing_index_names(bind, "documents") + if "ix_documents_content_hash" in indexes: + op.drop_index("ix_documents_content_hash", table_name="documents") + + op.create_index( + "ix_documents_content_hash", + "documents", + ["content_hash"], + unique=False, + ) + op.create_unique_constraint( + "uq_documents_content_hash", "documents", ["content_hash"] + ) diff --git a/surfsense_backend/app/agents/autocomplete/autocomplete_agent.py b/surfsense_backend/app/agents/autocomplete/autocomplete_agent.py new file mode 100644 index 000000000..890b3e06e --- /dev/null +++ b/surfsense_backend/app/agents/autocomplete/autocomplete_agent.py @@ -0,0 +1,557 @@ +"""Vision autocomplete agent with scoped filesystem exploration. + +Converts the stateless single-shot vision autocomplete into an agent that +seeds a virtual filesystem from KB search results and lets the vision LLM +explore documents via ``ls``, ``read_file``, ``glob``, ``grep``, etc. +before generating the final completion. + +Performance: KB search and agent graph compilation run in parallel so +the only sequential latency is KB-search (or agent compile, whichever is +slower) + the agent's LLM turns. There is no separate "query extraction" +LLM call — the window title is used directly as the KB search query. +""" + +from __future__ import annotations + +import asyncio +import json +import logging +import re +import uuid +from collections.abc import AsyncGenerator +from typing import Any + +from deepagents.graph import BASE_AGENT_PROMPT +from deepagents.middleware.patch_tool_calls import PatchToolCallsMiddleware +from langchain.agents import create_agent +from langchain_anthropic.middleware import AnthropicPromptCachingMiddleware +from langchain_core.language_models import BaseChatModel +from langchain_core.messages import AIMessage, ToolMessage + +from app.agents.new_chat.document_xml import build_document_xml +from app.agents.new_chat.middleware.filesystem import SurfSenseFilesystemMiddleware +from app.agents.new_chat.middleware.knowledge_search import ( + search_knowledge_base, +) +from app.agents.new_chat.path_resolver import ( + DOCUMENTS_ROOT, + build_path_index, + doc_to_virtual_path, +) +from app.db import shielded_async_session +from app.services.new_streaming_service import VercelStreamingService + +try: + from deepagents.backends.utils import create_file_data +except Exception: # pragma: no cover - defensive + + def create_file_data(content: str) -> dict[str, Any]: + return {"content": content.split("\n")} + + +async def _build_autocomplete_filesystem( + *, + documents: Any, + search_space_id: int, +) -> tuple[dict[str, Any], dict[int, str]]: + """Build a ``state['files']``-shaped dict from KB search results. + + This is the autocomplete-specific replacement for the previous + ``build_scoped_filesystem`` helper. It uses the canonical path resolver + so paths line up with the rest of the system, including collision + suffixes for duplicate titles. + """ + files: dict[str, Any] = {} + doc_id_to_path: dict[int, str] = {} + + if not documents: + return files, doc_id_to_path + + async with shielded_async_session() as session: + index = await build_path_index(session, search_space_id) + + for document in documents: + if not isinstance(document, dict): + continue + meta = document.get("document") or {} + doc_id = meta.get("id") + if not isinstance(doc_id, int): + continue + title = str(meta.get("title") or "untitled") + folder_id = meta.get("folder_id") + path = doc_to_virtual_path( + doc_id=doc_id, title=title, folder_id=folder_id, index=index + ) + chunk_ids = document.get("matched_chunk_ids") or [] + try: + matched_set = {int(c) for c in chunk_ids} + except (TypeError, ValueError): + matched_set = set() + xml = build_document_xml(document, matched_chunk_ids=matched_set) + files[path] = create_file_data(xml) + doc_id_to_path[doc_id] = path + + if not files: + # Ensure the synthetic /documents folder is visible even when empty. + files.setdefault(f"{DOCUMENTS_ROOT}/.placeholder", create_file_data("")) + + return files, doc_id_to_path + + +logger = logging.getLogger(__name__) + +KB_TOP_K = 10 + +# --------------------------------------------------------------------------- +# System prompt +# --------------------------------------------------------------------------- + +AUTOCOMPLETE_SYSTEM_PROMPT = """You are a smart writing assistant that analyzes the user's screen to draft or complete text. + +You will receive a screenshot of the user's screen. Your PRIMARY source of truth is the screenshot itself — the visual context determines what to write. + +Your job: +1. Analyze the ENTIRE screenshot to understand what the user is working on (email thread, chat conversation, document, code editor, form, etc.). +2. Identify the text area where the user will type. +3. Generate the text the user most likely wants to write based on the visual context. + +You also have access to the user's knowledge base documents via filesystem tools. However: +- ONLY consult the knowledge base if the screenshot clearly involves a topic where your KB documents are DIRECTLY relevant (e.g., the user is writing about a specific project/topic that matches a document title). +- Do NOT explore documents just because they exist. Most autocomplete requests can be answered purely from the screenshot. +- If you do read a document, only incorporate information that is 100% relevant to what the user is typing RIGHT NOW. Do not add extra details, background, or tangential information from the KB. +- Keep your output SHORT — autocomplete should feel like a natural continuation, not an essay. + +Key behavior: +- If the text area is EMPTY, draft a concise response or message based on what you see on screen (e.g., reply to an email, respond to a chat message, continue a document). +- If the text area already has text, continue it naturally — typically just a sentence or two. + +Rules: +- Be CONCISE. Prefer a single paragraph or a few sentences. Autocomplete is a quick assist, not a full draft. +- Match the tone and formality of the surrounding context. +- If the screen shows code, write code. If it shows a casual chat, be casual. If it shows a formal email, be formal. +- Do NOT describe the screenshot or explain your reasoning. +- Do NOT cite or reference documents explicitly — just let the knowledge inform your writing naturally. +- If you cannot determine what to write, output an empty JSON array: [] + +## Output Format + +You MUST provide exactly 3 different suggestion options. Each should be a distinct, plausible completion — vary the tone, detail level, or angle. + +Return your suggestions as a JSON array of exactly 3 strings. Output ONLY the JSON array, nothing else — no markdown fences, no explanation, no commentary. + +Example format: +["First suggestion text here.", "Second suggestion — a different take.", "Third option with another approach."] + +## Filesystem Tools `ls`, `read_file`, `write_file`, `edit_file`, `glob`, `grep` + +All file paths must start with a `/`. +- ls: list files and directories at a given path. +- read_file: read a file from the filesystem. +- write_file: create a temporary file in the session (not persisted). +- edit_file: edit a file in the session (not persisted for /documents/ files). +- glob: find files matching a pattern (e.g., "**/*.xml"). +- grep: search for text within files. + +## When to Use Filesystem Tools + +BEFORE reaching for any tool, ask yourself: "Can I write a good completion purely from the screenshot?" If yes, just write it — do NOT explore the KB. + +Only use tools when: +- The user is clearly writing about a specific topic that likely has detailed information in their KB. +- You need a specific fact, name, number, or reference that the screenshot doesn't provide. + +When you do use tools, be surgical: +- Check the `ls` output first. If no document title looks relevant, stop — do not read files just to see what's there. +- If a title looks relevant, read only the `` (first ~20 lines) and jump to matched chunks. Do not read entire documents. +- Extract only the specific information you need and move on to generating the completion. + +## Reading Documents Efficiently + +Documents are formatted as XML. Each document contains: +- `` — title, type, URL, etc. +- `` — a table of every chunk with its **line range** and a + `matched="true"` flag for chunks that matched the search query. +- `` — the actual chunks in original document order. + +**Workflow**: read the first ~20 lines to see the ``, identify +chunks marked `matched="true"`, then use `read_file(path, offset=, +limit=)` to jump directly to those sections.""" + +APP_CONTEXT_BLOCK = """ + +The user is currently working in "{app_name}" (window: "{window_title}"). Use this to understand the type of application and adapt your tone and format accordingly.""" + + +def _build_autocomplete_system_prompt(app_name: str, window_title: str) -> str: + prompt = AUTOCOMPLETE_SYSTEM_PROMPT + if app_name: + prompt += APP_CONTEXT_BLOCK.format(app_name=app_name, window_title=window_title) + return prompt + + +# --------------------------------------------------------------------------- +# Pre-compute KB filesystem (runs in parallel with agent compilation) +# --------------------------------------------------------------------------- + + +class _KBResult: + """Container for pre-computed KB filesystem results.""" + + __slots__ = ("files", "ls_ai_msg", "ls_tool_msg") + + def __init__( + self, + files: dict[str, Any] | None = None, + ls_ai_msg: AIMessage | None = None, + ls_tool_msg: ToolMessage | None = None, + ) -> None: + self.files = files + self.ls_ai_msg = ls_ai_msg + self.ls_tool_msg = ls_tool_msg + + @property + def has_documents(self) -> bool: + return bool(self.files) + + +async def precompute_kb_filesystem( + search_space_id: int, + query: str, + top_k: int = KB_TOP_K, +) -> _KBResult: + """Search the KB and build the scoped filesystem outside the agent. + + This is designed to be called via ``asyncio.gather`` alongside agent + graph compilation so the two run concurrently. + """ + if not query: + return _KBResult() + + try: + search_results = await search_knowledge_base( + query=query, + search_space_id=search_space_id, + top_k=top_k, + ) + + if not search_results: + return _KBResult() + + new_files, _ = await _build_autocomplete_filesystem( + documents=search_results, + search_space_id=search_space_id, + ) + + if not new_files: + return _KBResult() + + doc_paths = [ + p + for p, v in new_files.items() + if p.startswith("/documents/") and v is not None + ] + tool_call_id = f"auto_ls_{uuid.uuid4().hex[:12]}" + ai_msg = AIMessage( + content="", + tool_calls=[ + {"name": "ls", "args": {"path": "/documents"}, "id": tool_call_id} + ], + ) + tool_msg = ToolMessage( + content=str(doc_paths) if doc_paths else "No documents found.", + tool_call_id=tool_call_id, + ) + return _KBResult(files=new_files, ls_ai_msg=ai_msg, ls_tool_msg=tool_msg) + + except Exception: + logger.warning( + "KB pre-computation failed, proceeding without KB", exc_info=True + ) + return _KBResult() + + +# --------------------------------------------------------------------------- +# Filesystem middleware — no save_document, no persistence +# --------------------------------------------------------------------------- + + +class AutocompleteFilesystemMiddleware(SurfSenseFilesystemMiddleware): + """Filesystem middleware for autocomplete — read-only exploration only. + + Passes ``search_space_id=None`` so the new persistence pipeline is + bypassed; the autocomplete flow only reads, never commits to Postgres. + """ + + def __init__(self) -> None: + super().__init__(search_space_id=None, created_by_id=None) + + +# --------------------------------------------------------------------------- +# Agent factory +# --------------------------------------------------------------------------- + + +async def _compile_agent( + llm: BaseChatModel, + app_name: str, + window_title: str, +) -> Any: + """Compile the agent graph (CPU-bound, runs in a thread).""" + system_prompt = _build_autocomplete_system_prompt(app_name, window_title) + final_system_prompt = system_prompt + "\n\n" + BASE_AGENT_PROMPT + + middleware = [ + AutocompleteFilesystemMiddleware(), + PatchToolCallsMiddleware(), + AnthropicPromptCachingMiddleware(unsupported_model_behavior="ignore"), + ] + + agent = await asyncio.to_thread( + create_agent, + llm, + system_prompt=final_system_prompt, + tools=[], + middleware=middleware, + ) + return agent.with_config({"recursion_limit": 200}) + + +async def create_autocomplete_agent( + llm: BaseChatModel, + *, + search_space_id: int, + kb_query: str, + app_name: str = "", + window_title: str = "", +) -> tuple[Any, _KBResult]: + """Create the autocomplete agent and pre-compute KB in parallel. + + Returns ``(agent, kb_result)`` so the caller can inject the pre-computed + filesystem into the agent's initial state without any middleware delay. + """ + agent, kb = await asyncio.gather( + _compile_agent(llm, app_name, window_title), + precompute_kb_filesystem(search_space_id, kb_query), + ) + return agent, kb + + +# --------------------------------------------------------------------------- +# JSON suggestion parsing (with fallback) +# --------------------------------------------------------------------------- + + +def _parse_suggestions(raw: str) -> list[str]: + """Extract a list of suggestion strings from the agent's output. + + Tries, in order: + 1. Direct ``json.loads`` + 2. Extract content between ```json ... ``` fences + 3. Find the first ``[`` … ``]`` span + Falls back to wrapping the raw text as a single suggestion. + """ + text = raw.strip() + if not text: + return [] + + for candidate in _json_candidates(text): + try: + parsed = json.loads(candidate) + if isinstance(parsed, list) and all(isinstance(s, str) for s in parsed): + return [s for s in parsed if s.strip()] + except (json.JSONDecodeError, ValueError): + continue + + return [text] + + +def _json_candidates(text: str) -> list[str]: + """Yield candidate JSON strings from raw text.""" + candidates = [text] + + fence = re.search(r"```(?:json)?\s*\n?(.*?)```", text, re.DOTALL) + if fence: + candidates.append(fence.group(1).strip()) + + bracket = re.search(r"\[.*]", text, re.DOTALL) + if bracket: + candidates.append(bracket.group(0)) + + return candidates + + +# --------------------------------------------------------------------------- +# Streaming helper +# --------------------------------------------------------------------------- + + +async def stream_autocomplete_agent( + agent: Any, + input_data: dict[str, Any], + streaming_service: VercelStreamingService, + *, + emit_message_start: bool = True, +) -> AsyncGenerator[str, None]: + """Stream agent events as Vercel SSE, with thinking steps for tool calls. + + When ``emit_message_start`` is False the caller has already sent the + ``message_start`` event (e.g. to show preparation steps before the agent + runs). + """ + thread_id = uuid.uuid4().hex + config = {"configurable": {"thread_id": thread_id}} + + text_buffer: list[str] = [] + active_tool_depth = 0 + thinking_step_counter = 0 + tool_step_ids: dict[str, str] = {} + step_titles: dict[str, str] = {} + completed_step_ids: set[str] = set() + last_active_step_id: str | None = None + + def next_thinking_step_id() -> str: + nonlocal thinking_step_counter + thinking_step_counter += 1 + return f"autocomplete-step-{thinking_step_counter}" + + def complete_current_step() -> str | None: + nonlocal last_active_step_id + if last_active_step_id and last_active_step_id not in completed_step_ids: + completed_step_ids.add(last_active_step_id) + title = step_titles.get(last_active_step_id, "Done") + event = streaming_service.format_thinking_step( + step_id=last_active_step_id, + title=title, + status="complete", + ) + last_active_step_id = None + return event + return None + + if emit_message_start: + yield streaming_service.format_message_start() + + gen_step_id = next_thinking_step_id() + last_active_step_id = gen_step_id + step_titles[gen_step_id] = "Generating suggestions" + yield streaming_service.format_thinking_step( + step_id=gen_step_id, + title="Generating suggestions", + status="in_progress", + ) + + try: + async for event in agent.astream_events( + input_data, config=config, version="v2" + ): + event_type = event.get("event", "") + if event_type == "on_chat_model_stream": + if active_tool_depth > 0: + continue + if "surfsense:internal" in event.get("tags", []): + continue + chunk = event.get("data", {}).get("chunk") + if chunk and hasattr(chunk, "content"): + content = chunk.content + if content and isinstance(content, str): + text_buffer.append(content) + + elif event_type == "on_chat_model_end": + if active_tool_depth > 0: + continue + if "surfsense:internal" in event.get("tags", []): + continue + output = event.get("data", {}).get("output") + if output and hasattr(output, "content"): + if getattr(output, "tool_calls", None): + continue + content = output.content + if content and isinstance(content, str) and not text_buffer: + text_buffer.append(content) + + elif event_type == "on_tool_start": + active_tool_depth += 1 + tool_name = event.get("name", "unknown_tool") + run_id = event.get("run_id", "") + tool_input = event.get("data", {}).get("input", {}) + + step_event = complete_current_step() + if step_event: + yield step_event + + tool_step_id = next_thinking_step_id() + tool_step_ids[run_id] = tool_step_id + last_active_step_id = tool_step_id + + title, items = _describe_tool_call(tool_name, tool_input) + step_titles[tool_step_id] = title + yield streaming_service.format_thinking_step( + step_id=tool_step_id, + title=title, + status="in_progress", + items=items, + ) + + elif event_type == "on_tool_end": + active_tool_depth = max(0, active_tool_depth - 1) + run_id = event.get("run_id", "") + step_id = tool_step_ids.pop(run_id, None) + if step_id and step_id not in completed_step_ids: + completed_step_ids.add(step_id) + title = step_titles.get(step_id, "Done") + yield streaming_service.format_thinking_step( + step_id=step_id, + title=title, + status="complete", + ) + if last_active_step_id == step_id: + last_active_step_id = None + + step_event = complete_current_step() + if step_event: + yield step_event + + raw_text = "".join(text_buffer) + suggestions = _parse_suggestions(raw_text) + + yield streaming_service.format_data("suggestions", {"options": suggestions}) + + yield streaming_service.format_finish() + yield streaming_service.format_done() + + except Exception as e: + logger.error(f"Autocomplete agent streaming error: {e}", exc_info=True) + yield streaming_service.format_error("Autocomplete failed. Please try again.") + yield streaming_service.format_done() + + +def _describe_tool_call(tool_name: str, tool_input: Any) -> tuple[str, list[str]]: + """Return a human-readable (title, items) for a tool call thinking step.""" + inp = tool_input if isinstance(tool_input, dict) else {} + if tool_name == "ls": + path = inp.get("path", "/") + return "Listing files", [path] + if tool_name == "read_file": + fp = inp.get("file_path", "") + display = fp if len(fp) <= 80 else "…" + fp[-77:] + return "Reading file", [display] + if tool_name == "write_file": + fp = inp.get("file_path", "") + display = fp if len(fp) <= 80 else "…" + fp[-77:] + return "Writing file", [display] + if tool_name == "edit_file": + fp = inp.get("file_path", "") + display = fp if len(fp) <= 80 else "…" + fp[-77:] + return "Editing file", [display] + if tool_name == "glob": + pat = inp.get("pattern", "") + base = inp.get("path", "/") + return "Searching files", [f"{pat} in {base}"] + if tool_name == "grep": + pat = inp.get("pattern", "") + path = inp.get("path", "") + display_pat = pat[:60] + ("…" if len(pat) > 60 else "") + return "Searching content", [ + f'"{display_pat}"' + (f" in {path}" if path else "") + ] + return f"Using {tool_name}", [] diff --git a/surfsense_backend/app/agents/new_chat/chat_deepagent.py b/surfsense_backend/app/agents/new_chat/chat_deepagent.py index 73a39ccbf..bfb94ba2d 100644 --- a/surfsense_backend/app/agents/new_chat/chat_deepagent.py +++ b/surfsense_backend/app/agents/new_chat/chat_deepagent.py @@ -23,9 +23,16 @@ from deepagents import SubAgent, SubAgentMiddleware, __version__ as deepagents_v from deepagents.backends import StateBackend from deepagents.graph import BASE_AGENT_PROMPT from deepagents.middleware.patch_tool_calls import PatchToolCallsMiddleware +from deepagents.middleware.skills import SkillsMiddleware from deepagents.middleware.subagents import GENERAL_PURPOSE_SUBAGENT from langchain.agents import create_agent -from langchain.agents.middleware import TodoListMiddleware +from langchain.agents.middleware import ( + LLMToolSelectorMiddleware, + ModelCallLimitMiddleware, + ModelFallbackMiddleware, + TodoListMiddleware, + ToolCallLimitMiddleware, +) from langchain_anthropic.middleware import AnthropicPromptCachingMiddleware from langchain_core.language_models import BaseChatModel from langchain_core.tools import BaseTool @@ -33,24 +40,54 @@ from langgraph.types import Checkpointer from sqlalchemy.ext.asyncio import AsyncSession from app.agents.new_chat.context import SurfSenseContextSchema +from app.agents.new_chat.feature_flags import AgentFeatureFlags, get_flags from app.agents.new_chat.filesystem_backends import build_backend_resolver -from app.agents.new_chat.filesystem_selection import FilesystemSelection +from app.agents.new_chat.filesystem_selection import FilesystemMode, FilesystemSelection from app.agents.new_chat.llm_config import AgentConfig from app.agents.new_chat.middleware import ( + ActionLogMiddleware, + AnonymousDocumentMiddleware, + BusyMutexMiddleware, + ClearToolUsesEdit, DedupHITLToolCallsMiddleware, + DoomLoopMiddleware, FileIntentMiddleware, - KnowledgeBaseSearchMiddleware, + KnowledgeBasePersistenceMiddleware, + KnowledgePriorityMiddleware, + KnowledgeTreeMiddleware, MemoryInjectionMiddleware, + NoopInjectionMiddleware, + OtelSpanMiddleware, + PermissionMiddleware, + RetryAfterMiddleware, + SpillingContextEditingMiddleware, + SpillToBackendEdit, SurfSenseFilesystemMiddleware, + ToolCallNameRepairMiddleware, + build_skills_backend_factory, + create_surfsense_compaction_middleware, + default_skills_sources, ) -from app.agents.new_chat.middleware.safe_summarization import ( - create_safe_summarization_middleware, +from app.agents.new_chat.permissions import Rule, Ruleset +from app.agents.new_chat.plugin_loader import ( + PluginContext, + load_allowed_plugin_names_from_env, + load_plugin_middlewares, ) +from app.agents.new_chat.subagents import build_specialized_subagents from app.agents.new_chat.system_prompt import ( build_configurable_system_prompt, build_surfsense_system_prompt, ) -from app.agents.new_chat.tools.registry import build_tools_async, get_connector_gated_tools +from app.agents.new_chat.tools.invalid_tool import ( + INVALID_TOOL_NAME, + invalid_tool, +) +from app.agents.new_chat.tools.registry import ( + BUILTIN_TOOLS, + build_tools_async, + get_connector_gated_tools, +) from app.db import ChatVisibility from app.services.connector_service import ConnectorService from app.utils.perf import get_perf_logger @@ -243,7 +280,12 @@ async def create_surfsense_deep_agent( """ _t_agent_total = time.perf_counter() filesystem_selection = filesystem_selection or FilesystemSelection() - backend_resolver = build_backend_resolver(filesystem_selection) + backend_resolver = build_backend_resolver( + filesystem_selection, + search_space_id=search_space_id + if filesystem_selection.mode == FilesystemMode.CLOUD + else None, + ) # Discover available connectors and document types for this search space available_connectors: list[str] | None = None @@ -294,11 +336,11 @@ async def create_surfsense_deep_agent( } modified_disabled_tools = list(disabled_tools) if disabled_tools else [] - modified_disabled_tools.extend( - get_connector_gated_tools(available_connectors) - ) + modified_disabled_tools.extend(get_connector_gated_tools(available_connectors)) - # Remove direct KB search tool; we now pre-seed a scoped filesystem via middleware. + # Remove direct KB search tool; KnowledgePriorityMiddleware now runs hybrid + # search per turn and surfaces hits as a hint plus + # `` markers inside lazy-loaded XML. if "search_knowledge_base" not in modified_disabled_tools: modified_disabled_tools.append("search_knowledge_base") @@ -310,6 +352,18 @@ async def create_surfsense_deep_agent( disabled_tools=modified_disabled_tools, additional_tools=list(additional_tools) if additional_tools else None, ) + + # Register the ``invalid`` tool only when tool-call repair is on. It + # is dispatched only when :class:`ToolCallNameRepairMiddleware` + # rewrites a malformed call. We intentionally append it AFTER + # ``build_tools_async`` so it never appears in the system-prompt + # tool list (which is built from the registry, not the bound tool + # list). + _flags: AgentFeatureFlags = get_flags() + if _flags.enable_tool_call_repair and INVALID_TOOL_NAME not in { + t.name for t in tools + }: + tools = [*list(tools), invalid_tool] _perf_log.info( "[create_agent] build_tools_async in %.3fs (%d tools)", time.perf_counter() - _t0, @@ -328,7 +382,8 @@ async def create_surfsense_deep_agent( meta = getattr(t, "metadata", None) or {} if meta.get("mcp_is_generic") and meta.get("mcp_connector_name"): _mcp_connector_tools.setdefault( - meta["mcp_connector_name"], [], + meta["mcp_connector_name"], + [], ).append(t.name) if _mcp_connector_tools: @@ -355,7 +410,139 @@ async def create_surfsense_deep_agent( "[create_agent] System prompt built in %.3fs", time.perf_counter() - _t0 ) - # -- Build the middleware stack (mirrors create_deep_agent internals) ------ + # Combine system_prompt with BASE_AGENT_PROMPT (same as create_deep_agent) + final_system_prompt = system_prompt + "\n\n" + BASE_AGENT_PROMPT + + # The middleware stack — and especially ``SubAgentMiddleware`` — is *not* + # cheap to build. ``SubAgentMiddleware.__init__`` calls ``create_agent`` + # synchronously to compile the general-purpose subagent's full state graph + # (every tool + every middleware → pydantic schemas + langgraph compile). + # On gpt-5.x agents that's roughly 1.5-2s of pure CPU work. If we run it + # directly here it blocks the asyncio event loop for the whole streaming + # task (and any other coroutine sharing this loop), which is why + # "agent creation" wall-clock time used to stretch to ~3-4s. Move the + # entire middleware build + main-graph compile into a single + # ``asyncio.to_thread`` so the heavy CPU work runs off-loop and the + # event loop stays responsive. + _t0 = time.perf_counter() + agent = await asyncio.to_thread( + _build_compiled_agent_blocking, + llm=llm, + tools=tools, + final_system_prompt=final_system_prompt, + backend_resolver=backend_resolver, + filesystem_mode=filesystem_selection.mode, + search_space_id=search_space_id, + user_id=user_id, + thread_id=thread_id, + visibility=visibility, + anon_session_id=anon_session_id, + available_connectors=available_connectors, + available_document_types=available_document_types, + mentioned_document_ids=mentioned_document_ids, + max_input_tokens=_max_input_tokens, + flags=_flags, + checkpointer=checkpointer, + ) + _perf_log.info( + "[create_agent] Middleware stack + graph compiled in %.3fs", + time.perf_counter() - _t0, + ) + + _perf_log.info( + "[create_agent] Total agent creation in %.3fs", + time.perf_counter() - _t_agent_total, + ) + return agent + + +# Tools whose output is too costly / lossy to discard. Keep this +# conservative — anything listed here is *never* pruned by +# :class:`ContextEditingMiddleware`. The list is filtered against +# actually-bound tool names so disabled connectors don't show up here. +_PRUNE_PROTECTED_TOOL_NAMES: frozenset[str] = frozenset( + { + "generate_report", + "generate_resume", + "generate_podcast", + "generate_video_presentation", + "generate_image", + # Read-heavy connector reads — recomputing them is expensive + "read_email", + "search_emails", + # The fallback for malformed tool calls — keep its replies visible + "invalid", + } +) + + +def _safe_exclude_tools(tools: Sequence[BaseTool]) -> tuple[str, ...]: + """Return ``exclude_tools`` derived from the actually-bound tool list. + + Filters :data:`_PRUNE_PROTECTED_TOOL_NAMES` against the bound tools + so we never list tools that don't exist (would be a silent no-op). + """ + enabled = {t.name for t in tools} + return tuple(name for name in _PRUNE_PROTECTED_TOOL_NAMES if name in enabled) + + +# Connector gating: any tool whose ``ToolDefinition.required_connector`` +# isn't actually wired up gets a synthesized permission deny rule so +# execution attempts short-circuit with ``permission_denied`` instead of +# bubbling up provider-specific 401/404 errors. Mirrors OpenCode's +# ``Permission.disabled`` (declarative, per-tool gating) — replaces the +# legacy binary ``_CONNECTOR_TYPE_TO_SEARCHABLE`` substring-heuristic. +def _synthesize_connector_deny_rules( + *, + available_connectors: list[str] | None, + enabled_tool_names: set[str], +) -> list[Rule]: + """Build deny rules for tools whose required connector is not enabled. + + Source of truth is ``ToolDefinition.required_connector`` in + :data:`BUILTIN_TOOLS`. A tool only gets a deny rule when: + + 1. It is currently bound (``enabled_tool_names``). + 2. It declares a ``required_connector``. + 3. That connector is *not* in ``available_connectors``. + """ + available = set(available_connectors or []) + deny: list[Rule] = [] + for tool_def in BUILTIN_TOOLS: + if tool_def.name not in enabled_tool_names: + continue + rc = tool_def.required_connector + if rc and rc not in available: + deny.append(Rule(permission=tool_def.name, pattern="*", action="deny")) + return deny + + +def _build_compiled_agent_blocking( + *, + llm: BaseChatModel, + tools: Sequence[BaseTool], + final_system_prompt: str, + backend_resolver: Any, + filesystem_mode: FilesystemMode, + search_space_id: int, + user_id: str | None, + thread_id: int | None, + visibility: ChatVisibility, + anon_session_id: str | None, + available_connectors: list[str] | None, + available_document_types: list[str] | None, + mentioned_document_ids: list[int] | None, + max_input_tokens: int | None, + flags: AgentFeatureFlags, + checkpointer: Checkpointer, +): + """Build the middleware stack and compile the agent graph synchronously. + + Runs in a worker thread (see ``asyncio.to_thread`` call site) so the heavy + CPU work — most notably ``SubAgentMiddleware.__init__`` eagerly calling + ``create_agent`` to compile the general-purpose subagent — does not block + the event loop. + """ _memory_middleware = MemoryInjectionMiddleware( user_id=user_id, search_space_id=search_space_id, @@ -363,18 +550,23 @@ async def create_surfsense_deep_agent( ) # General-purpose subagent middleware + # Subagent omits AnonymousDocumentMiddleware, KnowledgeTreeMiddleware, + # KnowledgePriorityMiddleware, and KnowledgeBasePersistenceMiddleware - it + # inherits state and tools from the parent, but should not (a) re-load + # anon docs / re-render the tree / re-run hybrid search, or (b) commit at + # its own completion (only the top-level agent's aafter_agent commits). gp_middleware = [ TodoListMiddleware(), _memory_middleware, FileIntentMiddleware(llm=llm), SurfSenseFilesystemMiddleware( backend=backend_resolver, - filesystem_mode=filesystem_selection.mode, + filesystem_mode=filesystem_mode, search_space_id=search_space_id, created_by_id=user_id, thread_id=thread_id, ), - create_safe_summarization_middleware(llm, StateBackend), + create_surfsense_compaction_middleware(llm, StateBackend), PatchToolCallsMiddleware(), AnthropicPromptCachingMiddleware(unsupported_model_behavior="ignore"), ] @@ -386,48 +578,416 @@ async def create_surfsense_deep_agent( "middleware": gp_middleware, } + # Specialized user-facing subagents (explore, report_writer, + # connector_negotiator). Registered through SubAgentMiddleware alongside + # the general-purpose spec so the parent's `task` tool can address them + # by name. Off by default until the flag flips so existing deployments + # don't see new agent types in the task tool description. + specialized_subagents: list[SubAgent] = [] + if flags.enable_specialized_subagents and not flags.disable_new_agent_stack: + try: + # Specialized subagents share the parent's filesystem + + # todo view so their system prompts (which promise + # ``read_file``, ``ls``, ``grep``, ``glob``, ``write_todos``) + # actually match runtime behavior. Build *fresh* instances + # rather than aliasing the parent's GP middleware to avoid + # subtle state coupling across compiled graphs. + subagent_extra_middleware: list = [ + TodoListMiddleware(), + SurfSenseFilesystemMiddleware( + backend=backend_resolver, + filesystem_mode=filesystem_mode, + search_space_id=search_space_id, + created_by_id=user_id, + thread_id=thread_id, + ), + ] + specialized_subagents = build_specialized_subagents( + tools=tools, + model=llm, + extra_middleware=subagent_extra_middleware, + ) + except Exception as exc: # pragma: no cover - defensive + logging.warning( + "Specialized subagent build failed; running without them: %s", + exc, + ) + specialized_subagents = [] + + subagent_specs: list[SubAgent] = [general_purpose_spec, *specialized_subagents] + # Main agent middleware + # Order: AnonDoc -> Tree -> Priority -> FileIntent -> Filesystem -> Persistence -> ... + # before_agent hooks run in declared order; later injections sit closer to + # the latest human turn. Tree (large + cacheable) is injected earliest so + # provider-side prefix caching has more material to hit; FileIntent (most + # actionable per-turn contract) is injected closest to the user message. + # + # ``wrap_model_call`` ordering: the FIRST middleware in the list is the + # OUTERMOST wrapper. To ensure prune executes before summarization, + # place ``SpillingContextEditingMiddleware`` before + # ``SurfSenseCompactionMiddleware``. Compaction is the canonical + # token-budget defense; the Bedrock buffer-empty defense is folded + # into ``SurfSenseCompactionMiddleware``. + summarization_mw = create_surfsense_compaction_middleware(llm, StateBackend) + _ = flags.enable_compaction_v2 # historical flag; retained for telemetry parity + + # ContextEditing prune. Trigger at 55% of ``max_input_tokens``, + # earlier than summarization (~85%). When disabled, no edit runs. + context_edit_mw = None + if ( + flags.enable_context_editing + and not flags.disable_new_agent_stack + and max_input_tokens + ): + spill_edit = SpillToBackendEdit( + trigger=int(max_input_tokens * 0.55), + clear_at_least=int(max_input_tokens * 0.15), + keep=5, + exclude_tools=_safe_exclude_tools(tools), + clear_tool_inputs=True, + ) + clear_edit = ClearToolUsesEdit( + trigger=int(max_input_tokens * 0.55), + clear_at_least=int(max_input_tokens * 0.15), + keep=5, + exclude_tools=_safe_exclude_tools(tools), + clear_tool_inputs=True, + placeholder="[cleared - older tool output trimmed for context]", + ) + context_edit_mw = SpillingContextEditingMiddleware( + edits=[spill_edit, clear_edit], + backend_resolver=backend_resolver, + ) + + # Resilience knobs: header-aware retry, model fallback, and + # per-thread / per-run call-count limits. The fallback / limit + # middlewares are vanilla LangChain primitives; ``RetryAfter`` is + # SurfSense's header-aware variant (see its module docstring). + retry_mw = ( + RetryAfterMiddleware(max_retries=3) + if flags.enable_retry_after and not flags.disable_new_agent_stack + else None + ) + # Fallback chain — primary is the agent's own model; we add cheap + # alternatives. Off by default; only the first call site that + # configures the chain via env should enable it. + fallback_mw: ModelFallbackMiddleware | None = None + if flags.enable_model_fallback and not flags.disable_new_agent_stack: + try: + fallback_mw = ModelFallbackMiddleware( + "openai:gpt-4o-mini", + "anthropic:claude-3-5-haiku-20241022", + ) + except Exception: + logging.warning("ModelFallbackMiddleware init failed; skipping.") + fallback_mw = None + model_call_limit_mw = ( + ModelCallLimitMiddleware( + thread_limit=120, + run_limit=80, + exit_behavior="end", + ) + if flags.enable_model_call_limit and not flags.disable_new_agent_stack + else None + ) + tool_call_limit_mw = ( + ToolCallLimitMiddleware( + thread_limit=300, run_limit=80, exit_behavior="continue" + ) + if flags.enable_tool_call_limit and not flags.disable_new_agent_stack + else None + ) + + # Provider-compat ``_noop`` injection (mirrors OpenCode's + # ``llm.ts`` workaround for providers that reject empty assistant + # turns or alternating-role constraints). + noop_mw = ( + NoopInjectionMiddleware() + if flags.enable_compaction_v2 and not flags.disable_new_agent_stack + else None + ) + + # Tool-call name repair (lowercase + ``invalid`` fallback). + # + # ``registered_tool_names`` MUST cover every tool the model can legitimately + # call. That includes the bound ``tools`` list AND every tool provided by + # middleware in the stack — ``FilesystemMiddleware`` (read_file, ls, grep, + # glob, edit_file, write_file, execute), ``TodoListMiddleware`` + # (write_todos), ``SubAgentMiddleware`` (task), ``SkillsMiddleware`` (skill + # loaders), etc. If we only inspect ``tools`` here, every call to + # ``read_file`` / ``ls`` / ``grep`` from the model will be rewritten to + # ``invalid`` because the repair middleware doesn't recognize them. The + # built-in deepagents middleware aren't in scope yet at this point of the + # function but they're added unconditionally below, so we hard-code their + # canonical names alongside the dynamic ``tools`` set. + repair_mw = None + if flags.enable_tool_call_repair and not flags.disable_new_agent_stack: + registered_names: set[str] = {t.name for t in tools} + # Tools owned by the standard deepagents middleware stack. + registered_names |= { + "write_todos", + "ls", + "read_file", + "write_file", + "edit_file", + "glob", + "grep", + "execute", + "task", + } + repair_mw = ToolCallNameRepairMiddleware( + registered_tool_names=registered_names, + # Disable fuzzy matching to avoid silent rewrites; the + # lowercase + ``invalid`` fallback alone covers >95% of + # observed model errors. + fuzzy_match_threshold=None, + ) + + # Doom-loop detector. Off by default until the frontend handles + # ``permission == "doom_loop"`` interrupts. + doom_loop_mw = ( + DoomLoopMiddleware(threshold=3) + if flags.enable_doom_loop and not flags.disable_new_agent_stack + else None + ) + + # PermissionMiddleware. Layers, earliest -> latest (last match wins, + # same evaluation order as OpenCode's ``permission/index.ts``): + # + # 1. ``surfsense_defaults`` — single ``allow */*`` rule. SurfSense + # already runs per-tool HITL (see ``tools/hitl.py``) for mutating + # connector tools, so we only want PermissionMiddleware to *deny* + # things the user has gated off; the default fallback in + # ``permissions.evaluate`` is ``ask``, which would double-prompt + # on every safe read-only call (``ls``, ``read_file``, ``grep``, + # ``glob``, ``web_search`` …) and, on resume, replay the previous + # reject decision into innocent calls. + # 2. ``connector_synthesized`` — deny rules for tools whose required + # connector is not connected to this space. Overrides #1. + # 3. (future) user-defined rules from ``agent_permission_rules`` table + # via the Agent Permissions UI. Loaded last so they override both. + permission_mw: PermissionMiddleware | None = None + if flags.enable_permission and not flags.disable_new_agent_stack: + synthesized = _synthesize_connector_deny_rules( + available_connectors=available_connectors, + enabled_tool_names={t.name for t in tools}, + ) + permission_mw = PermissionMiddleware( + rulesets=[ + Ruleset( + rules=[Rule(permission="*", pattern="*", action="allow")], + origin="surfsense_defaults", + ), + Ruleset(rules=synthesized, origin="connector_synthesized"), + ], + ) + + # ActionLogMiddleware. Off by default until the ``agent_action_log`` + # table is migrated. When enabled, persists one row per tool call + # with optional reverse_descriptor for + # ``POST /api/threads/{thread_id}/revert/{action_id}``. Sits inside + # ``permission`` so denied calls aren't logged as completions. + action_log_mw: ActionLogMiddleware | None = None + if ( + flags.enable_action_log + and not flags.disable_new_agent_stack + and thread_id is not None + ): + try: + tool_defs_by_name = {td.name: td for td in BUILTIN_TOOLS} + action_log_mw = ActionLogMiddleware( + thread_id=thread_id, + search_space_id=search_space_id, + user_id=user_id, + tool_definitions=tool_defs_by_name, + ) + except Exception: # pragma: no cover - defensive + logging.warning( + "ActionLogMiddleware init failed; running without it.", + exc_info=True, + ) + action_log_mw = None + + # Per-thread busy mutex (refuse a second concurrent turn on the same + # thread; see :class:`BusyMutexMiddleware` docstring). + busy_mutex_mw: BusyMutexMiddleware | None = ( + BusyMutexMiddleware() + if flags.enable_busy_mutex and not flags.disable_new_agent_stack + else None + ) + + # OpenTelemetry spans (model.call + tool.call). Lives just inside + # BusyMutex so it spans every retry/fallback attempt of the current + # turn but never wraps a queued/blocked turn. + otel_mw: OtelSpanMiddleware | None = ( + OtelSpanMiddleware() + if flags.enable_otel and not flags.disable_new_agent_stack + else None + ) + + # Plugin entry-point loader. Off by default; opt-in via the + # ``SURFSENSE_ENABLE_PLUGIN_LOADER`` flag. The allowlist is read from + # the ``SURFSENSE_ALLOWED_PLUGINS`` env var (comma-separated). A future + # PR can wire it through ``global_llm_config.yaml``. + plugin_middlewares: list[Any] = [] + if flags.enable_plugin_loader and not flags.disable_new_agent_stack: + try: + allowed_names = load_allowed_plugin_names_from_env() + if allowed_names: + plugin_middlewares = load_plugin_middlewares( + PluginContext.build( + search_space_id=search_space_id, + user_id=user_id, + thread_visibility=visibility, + llm=llm, + ), + allowed_plugin_names=allowed_names, + ) + except Exception: # pragma: no cover - defensive + logging.warning( + "Plugin loader failed; continuing without plugins.", + exc_info=True, + ) + plugin_middlewares = [] + + # SkillsMiddleware (deepagents) loads built-in + space-authored + # skills via a CompositeBackend. Sources are layered: built-in first, + # space last, so a search-space-authored skill of the same name + # overrides the bundled one. + skills_mw: SkillsMiddleware | None = None + if flags.enable_skills and not flags.disable_new_agent_stack: + try: + skills_factory = build_skills_backend_factory( + search_space_id=search_space_id + if filesystem_mode == FilesystemMode.CLOUD + else None, + ) + skills_mw = SkillsMiddleware( + backend=skills_factory, + sources=default_skills_sources(), + ) + except Exception as exc: # pragma: no cover - defensive + logging.warning("SkillsMiddleware init failed; skipping: %s", exc) + skills_mw = None + + # LangChain's LLM-driven tool selection — only enabled for stacks + # large enough to need narrowing (>30 tools). + selector_mw: LLMToolSelectorMiddleware | None = None + if ( + flags.enable_llm_tool_selector + and not flags.disable_new_agent_stack + and len(tools) > 30 + ): + try: + selector_mw = LLMToolSelectorMiddleware( + model="openai:gpt-4o-mini", + max_tools=12, + always_include=[ + name + for name in ( + "update_memory", + "get_connected_accounts", + "scrape_webpage", + ) + if name in {t.name for t in tools} + ], + ) + except Exception: + logging.warning("LLMToolSelectorMiddleware init failed; skipping.") + selector_mw = None + deepagent_middleware = [ + # BusyMutex is OUTERMOST: it must wrap the entire stream so no + # other turn can sneak in while this one is mid-flight. + busy_mutex_mw, + # OTel spans sit just inside BusyMutex so each retry attempt + # gets its own model.call / tool.call span. + otel_mw, TodoListMiddleware(), _memory_middleware, - FileIntentMiddleware(llm=llm), - KnowledgeBaseSearchMiddleware( + AnonymousDocumentMiddleware( + anon_session_id=anon_session_id, + ) + if filesystem_mode == FilesystemMode.CLOUD + else None, + KnowledgeTreeMiddleware( + search_space_id=search_space_id, + filesystem_mode=filesystem_mode, + llm=llm, + ) + if filesystem_mode == FilesystemMode.CLOUD + else None, + KnowledgePriorityMiddleware( llm=llm, search_space_id=search_space_id, - filesystem_mode=filesystem_selection.mode, + filesystem_mode=filesystem_mode, available_connectors=available_connectors, available_document_types=available_document_types, mentioned_document_ids=mentioned_document_ids, - anon_session_id=anon_session_id, ), + FileIntentMiddleware(llm=llm), SurfSenseFilesystemMiddleware( backend=backend_resolver, - filesystem_mode=filesystem_selection.mode, + filesystem_mode=filesystem_mode, search_space_id=search_space_id, created_by_id=user_id, thread_id=thread_id, ), - SubAgentMiddleware(backend=StateBackend, subagents=[general_purpose_spec]), - create_safe_summarization_middleware(llm, StateBackend), + KnowledgeBasePersistenceMiddleware( + search_space_id=search_space_id, + created_by_id=user_id, + filesystem_mode=filesystem_mode, + ) + if filesystem_mode == FilesystemMode.CLOUD + else None, + # Skill loader. Placed before SubAgentMiddleware so subagents + # inherit the same skill metadata (subagent specs reference the + # same source paths via ``default_skills_sources()``). + skills_mw, + SubAgentMiddleware(backend=StateBackend, subagents=subagent_specs), + # Tool selection (only when >30 tools and flag on). + selector_mw, + # Defensive caps, then prune, then summarize. + model_call_limit_mw, + tool_call_limit_mw, + context_edit_mw, + summarization_mw, + # Provider compatibility + retry chain — placed after prune/compact + # so retries happen on the already-trimmed payload. + noop_mw, + retry_mw, + fallback_mw, + # Tool-call repair must run after model emits but before + # permission / dedup / doom-loop interpret the calls. + repair_mw, + # Permission deny/ask BEFORE the calls are forwarded to tool nodes. + permission_mw, + doom_loop_mw, + # Action log sits inside permission so denied calls don't appear + # as completions, and outside dedup so each unique tool invocation + # gets its own row. + action_log_mw, PatchToolCallsMiddleware(), - DedupHITLToolCallsMiddleware(agent_tools=tools), + DedupHITLToolCallsMiddleware(agent_tools=list(tools)), + # Plugin slot — sits just before AnthropicCache so plugin-side + # transforms see the final tool result and run before any + # caching heuristics. Multiple plugins in declared order; loader + # filtered by the admin allowlist already. + *plugin_middlewares, AnthropicPromptCachingMiddleware(unsupported_model_behavior="ignore"), ] + deepagent_middleware = [m for m in deepagent_middleware if m is not None] - # Combine system_prompt with BASE_AGENT_PROMPT (same as create_deep_agent) - final_system_prompt = system_prompt + "\n\n" + BASE_AGENT_PROMPT - - _t0 = time.perf_counter() - agent = await asyncio.to_thread( - create_agent, + agent = create_agent( llm, system_prompt=final_system_prompt, - tools=tools, + tools=list(tools), middleware=deepagent_middleware, context_schema=SurfSenseContextSchema, checkpointer=checkpointer, ) - agent = agent.with_config( + return agent.with_config( { "recursion_limit": 10_000, "metadata": { @@ -436,13 +996,3 @@ async def create_surfsense_deep_agent( }, } ) - _perf_log.info( - "[create_agent] Graph compiled (create_agent) in %.3fs", - time.perf_counter() - _t0, - ) - - _perf_log.info( - "[create_agent] Total agent creation in %.3fs", - time.perf_counter() - _t_agent_total, - ) - return agent diff --git a/surfsense_backend/app/agents/new_chat/document_xml.py b/surfsense_backend/app/agents/new_chat/document_xml.py new file mode 100644 index 000000000..60e586ae1 --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/document_xml.py @@ -0,0 +1,103 @@ +"""Shared XML builder for KB documents. + +Produces the citation-friendly XML used by every read of a knowledge-base +document (lazy-loaded by :class:`KBPostgresBackend` and synthetic anonymous +files). The XML carries a ```` near the top so the LLM can jump +directly to matched-chunk line ranges via ``read_file(offset=…, limit=…)``. + +Extracted from the original ``knowledge_search.py`` so the backend, the +priority middleware, and any future renderer share a single implementation. +""" + +from __future__ import annotations + +import json +from typing import Any + + +def build_document_xml( + document: dict[str, Any], + matched_chunk_ids: set[int] | None = None, +) -> str: + """Build citation-friendly XML with a ```` for smart seeking. + + Args: + document: Dict shape produced by hybrid search / lazy-load helpers. + Expected keys: ``document`` (with ``id``, ``title``, + ``document_type``, ``metadata``) and ``chunks`` + (list of ``{chunk_id, content}``). + matched_chunk_ids: Optional set of chunk IDs to flag as + ``matched="true"`` in the chunk index. + """ + matched = matched_chunk_ids or set() + + doc_meta = document.get("document") or {} + metadata = (doc_meta.get("metadata") or {}) if isinstance(doc_meta, dict) else {} + document_id = doc_meta.get("id", document.get("document_id", "unknown")) + document_type = doc_meta.get("document_type", document.get("source", "UNKNOWN")) + title = doc_meta.get("title") or metadata.get("title") or "Untitled Document" + url = ( + metadata.get("url") or metadata.get("source") or metadata.get("page_url") or "" + ) + metadata_json = json.dumps(metadata, ensure_ascii=False) + + metadata_lines: list[str] = [ + "", + "", + f" {document_id}", + f" {document_type}", + f" <![CDATA[{title}]]>", + f" ", + f" ", + "", + "", + ] + + chunks = document.get("chunks") or [] + chunk_entries: list[tuple[int | None, str]] = [] + if isinstance(chunks, list): + for chunk in chunks: + if not isinstance(chunk, dict): + continue + chunk_id = chunk.get("chunk_id") or chunk.get("id") + chunk_content = str(chunk.get("content", "")).strip() + if not chunk_content: + continue + if chunk_id is None: + xml = f" " + else: + xml = f" " + chunk_entries.append((chunk_id, xml)) + + index_overhead = 1 + len(chunk_entries) + 1 + 1 + 1 + first_chunk_line = len(metadata_lines) + index_overhead + 1 + + current_line = first_chunk_line + index_entry_lines: list[str] = [] + for cid, xml_str in chunk_entries: + num_lines = xml_str.count("\n") + 1 + end_line = current_line + num_lines - 1 + matched_attr = ' matched="true"' if cid is not None and cid in matched else "" + if cid is not None: + index_entry_lines.append( + f' ' + ) + else: + index_entry_lines.append( + f' ' + ) + current_line = end_line + 1 + + lines = metadata_lines.copy() + lines.append("") + lines.extend(index_entry_lines) + lines.append("") + lines.append("") + lines.append("") + for _, xml_str in chunk_entries: + lines.append(xml_str) + lines.extend(["", ""]) + return "\n".join(lines) + + +__all__ = ["build_document_xml"] diff --git a/surfsense_backend/app/agents/new_chat/errors.py b/surfsense_backend/app/agents/new_chat/errors.py new file mode 100644 index 000000000..a17333acc --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/errors.py @@ -0,0 +1,95 @@ +""" +Typed error taxonomy for the SurfSense agent stack. + +Used by: +- :class:`RetryAfterMiddleware` — its ``retry_on`` callable consults + the error code to decide whether a retry is appropriate. +- :class:`PermissionMiddleware` — emits ``code="permission_denied"`` + errors when a deny rule trips. +- All tools — return :class:`StreamingError` payloads in + ``ToolMessage.additional_kwargs["error"]`` so the model and the + retry/permission layers share a contract. +""" + +from __future__ import annotations + +from typing import Literal + +from pydantic import BaseModel, Field + +ErrorCode = Literal[ + "rate_limit", + "auth", + "tool_validation", + "tool_runtime", + "context_overflow", + "provider", + "permission_denied", + "doom_loop", + "busy", + "cancelled", +] + + +class StreamingError(BaseModel): + """Structured error payload attached to ``ToolMessage.additional_kwargs["error"]``. + + Tools and middleware emit this so retry, permission, and routing + layers can decide what to do without parsing free-form strings. + """ + + code: ErrorCode + retryable: bool = False + suggestion: str | None = None + correlation_id: str | None = None + detail: str | None = Field( + default=None, + description="Free-form additional context. Not surfaced to the model.", + ) + + class Config: + frozen = True + + +class RejectedError(Exception): + """Raised when the user rejects a permission ask without feedback. + + Caught by :class:`PermissionMiddleware`; the agent stops the current + tool fan-out and surfaces a user-facing rejection. + """ + + def __init__(self, *, tool: str | None = None, pattern: str | None = None) -> None: + super().__init__(f"Permission rejected for tool {tool!r}, pattern {pattern!r}") + self.tool = tool + self.pattern = pattern + + +class CorrectedError(Exception): + """Raised when the user rejects a permission ask *with* feedback. + + The :class:`PermissionMiddleware` translates the feedback into a + synthetic ``ToolMessage`` so the model sees the user's correction + and can retry the request differently. + """ + + def __init__(self, feedback: str, *, tool: str | None = None) -> None: + super().__init__(feedback) + self.feedback = feedback + self.tool = tool + + +class BusyError(Exception): + """Raised when a second prompt arrives while the same thread is mid-stream.""" + + def __init__(self, request_id: str | None = None) -> None: + super().__init__("Thread is busy with another request") + self.request_id = request_id + + +__all__ = [ + "BusyError", + "CorrectedError", + "ErrorCode", + "RejectedError", + "StreamingError", +] diff --git a/surfsense_backend/app/agents/new_chat/feature_flags.py b/surfsense_backend/app/agents/new_chat/feature_flags.py new file mode 100644 index 000000000..55525abc5 --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/feature_flags.py @@ -0,0 +1,199 @@ +""" +Feature flags for the SurfSense new_chat agent stack. + +These flags gate the newer agent middleware (some ported from OpenCode, +some sourced from ``langchain.agents.middleware`` / ``deepagents``, some +SurfSense-native). They follow a "default-OFF for risky things, +default-ON for safe upgrades, master kill-switch for everything new" model. + +All new middleware checks its flag at agent build time. If the master +kill-switch ``SURFSENSE_DISABLE_NEW_AGENT_STACK`` is set, every new +middleware is disabled regardless of its individual flag. This gives +operators a single switch to revert to pre-port behavior. + +Examples +-------- + +Local development (recommended for trying everything except doom-loop / selector): + + SURFSENSE_ENABLE_CONTEXT_EDITING=true + SURFSENSE_ENABLE_COMPACTION_V2=true + SURFSENSE_ENABLE_RETRY_AFTER=true + SURFSENSE_ENABLE_TOOL_CALL_REPAIR=true + SURFSENSE_ENABLE_PERMISSION=false # default off, opt-in per deploy + SURFSENSE_ENABLE_DOOM_LOOP=false # default off until UI ships + SURFSENSE_ENABLE_LLM_TOOL_SELECTOR=false + +Master kill-switch (overrides everything else): + + SURFSENSE_DISABLE_NEW_AGENT_STACK=true +""" + +from __future__ import annotations + +import logging +import os +from dataclasses import dataclass + +logger = logging.getLogger(__name__) + + +def _env_bool(name: str, default: bool) -> bool: + """Parse a boolean env var. Accepts ``1``/``true``/``yes``/``on`` (case-insensitive).""" + raw = os.environ.get(name) + if raw is None: + return default + return raw.strip().lower() in ("1", "true", "yes", "on") + + +@dataclass(frozen=True) +class AgentFeatureFlags: + """Resolved feature-flag state for one agent build. + + Constructed via :meth:`from_env`. The dataclass is frozen so it can be + safely shared across coroutines. + """ + + # Master kill-switch — when true, every flag below resolves to False + # regardless of its env value. Used for rapid rollback. + disable_new_agent_stack: bool = False + + # Agent quality — context budget, retry/limits, name-repair, doom-loop + enable_context_editing: bool = False + enable_compaction_v2: bool = False + enable_retry_after: bool = False + enable_model_fallback: bool = False + enable_model_call_limit: bool = False + enable_tool_call_limit: bool = False + enable_tool_call_repair: bool = False + enable_doom_loop: bool = ( + False # Default OFF until UI handles permission='doom_loop' + ) + + # Safety — permissions, concurrency, tool-set narrowing + enable_permission: bool = False # Default OFF for first deploy + enable_busy_mutex: bool = False + enable_llm_tool_selector: bool = False # Default OFF — adds per-turn LLM cost + + # Skills + subagents + enable_skills: bool = False + enable_specialized_subagents: bool = False + enable_kb_planner_runnable: bool = False + + # Snapshot / revert + enable_action_log: bool = False + enable_revert_route: bool = ( + False # Backend ships before UI; route returns 503 until this flips + ) + + # Plugins + enable_plugin_loader: bool = False + + # Observability — OTel (orthogonal; also requires OTEL_EXPORTER_OTLP_ENDPOINT) + enable_otel: bool = False + + @classmethod + def from_env(cls) -> AgentFeatureFlags: + """Read flags from environment. + + Master kill-switch is evaluated first; when set, all other flags + force to False. + """ + master_off = _env_bool("SURFSENSE_DISABLE_NEW_AGENT_STACK", False) + if master_off: + logger.info( + "SURFSENSE_DISABLE_NEW_AGENT_STACK is set: every new agent " + "middleware is forced OFF for this build." + ) + return cls(disable_new_agent_stack=True) + + return cls( + disable_new_agent_stack=False, + # Agent quality + enable_context_editing=_env_bool("SURFSENSE_ENABLE_CONTEXT_EDITING", False), + enable_compaction_v2=_env_bool("SURFSENSE_ENABLE_COMPACTION_V2", False), + enable_retry_after=_env_bool("SURFSENSE_ENABLE_RETRY_AFTER", False), + enable_model_fallback=_env_bool("SURFSENSE_ENABLE_MODEL_FALLBACK", False), + enable_model_call_limit=_env_bool( + "SURFSENSE_ENABLE_MODEL_CALL_LIMIT", False + ), + enable_tool_call_limit=_env_bool("SURFSENSE_ENABLE_TOOL_CALL_LIMIT", False), + enable_tool_call_repair=_env_bool( + "SURFSENSE_ENABLE_TOOL_CALL_REPAIR", False + ), + enable_doom_loop=_env_bool("SURFSENSE_ENABLE_DOOM_LOOP", False), + # Safety + enable_permission=_env_bool("SURFSENSE_ENABLE_PERMISSION", False), + enable_busy_mutex=_env_bool("SURFSENSE_ENABLE_BUSY_MUTEX", False), + enable_llm_tool_selector=_env_bool( + "SURFSENSE_ENABLE_LLM_TOOL_SELECTOR", False + ), + # Skills + subagents + enable_skills=_env_bool("SURFSENSE_ENABLE_SKILLS", False), + enable_specialized_subagents=_env_bool( + "SURFSENSE_ENABLE_SPECIALIZED_SUBAGENTS", False + ), + enable_kb_planner_runnable=_env_bool( + "SURFSENSE_ENABLE_KB_PLANNER_RUNNABLE", False + ), + # Snapshot / revert + enable_action_log=_env_bool("SURFSENSE_ENABLE_ACTION_LOG", False), + enable_revert_route=_env_bool("SURFSENSE_ENABLE_REVERT_ROUTE", False), + # Plugins + enable_plugin_loader=_env_bool("SURFSENSE_ENABLE_PLUGIN_LOADER", False), + # Observability + enable_otel=_env_bool("SURFSENSE_ENABLE_OTEL", False), + ) + + def any_new_middleware_enabled(self) -> bool: + """Return True if any new middleware flag is on.""" + if self.disable_new_agent_stack: + return False + return any( + ( + self.enable_context_editing, + self.enable_compaction_v2, + self.enable_retry_after, + self.enable_model_fallback, + self.enable_model_call_limit, + self.enable_tool_call_limit, + self.enable_tool_call_repair, + self.enable_doom_loop, + self.enable_permission, + self.enable_busy_mutex, + self.enable_llm_tool_selector, + self.enable_skills, + self.enable_specialized_subagents, + self.enable_kb_planner_runnable, + self.enable_action_log, + self.enable_revert_route, + self.enable_plugin_loader, + ) + ) + + +# Module-level cache. Read once at import time so the values are consistent +# across the process lifetime. Use ``reload_for_tests`` to reset in tests. +_FLAGS: AgentFeatureFlags | None = None + + +def get_flags() -> AgentFeatureFlags: + """Return the resolved feature-flag state, caching on first call.""" + global _FLAGS + if _FLAGS is None: + _FLAGS = AgentFeatureFlags.from_env() + return _FLAGS + + +def reload_for_tests() -> AgentFeatureFlags: + """Force a fresh read from env. Tests should call this after monkeypatching env.""" + global _FLAGS + _FLAGS = AgentFeatureFlags.from_env() + return _FLAGS + + +__all__ = [ + "AgentFeatureFlags", + "get_flags", + "reload_for_tests", +] diff --git a/surfsense_backend/app/agents/new_chat/filesystem_backends.py b/surfsense_backend/app/agents/new_chat/filesystem_backends.py index 85ed5f801..c8288be71 100644 --- a/surfsense_backend/app/agents/new_chat/filesystem_backends.py +++ b/surfsense_backend/app/agents/new_chat/filesystem_backends.py @@ -5,10 +5,12 @@ from __future__ import annotations from collections.abc import Callable from functools import lru_cache +from deepagents.backends.protocol import BackendProtocol from deepagents.backends.state import StateBackend from langgraph.prebuilt.tool_node import ToolRuntime from app.agents.new_chat.filesystem_selection import FilesystemMode, FilesystemSelection +from app.agents.new_chat.middleware.kb_postgres_backend import KBPostgresBackend from app.agents.new_chat.middleware.multi_root_local_folder_backend import ( MultiRootLocalFolderBackend, ) @@ -23,8 +25,20 @@ def _cached_multi_root_backend( def build_backend_resolver( selection: FilesystemSelection, -) -> Callable[[ToolRuntime], StateBackend | MultiRootLocalFolderBackend]: - """Create deepagents backend resolver for the selected filesystem mode.""" + *, + search_space_id: int | None = None, +) -> Callable[[ToolRuntime], BackendProtocol]: + """Create deepagents backend resolver for the selected filesystem mode. + + In cloud mode the resolver returns a fresh :class:`KBPostgresBackend` + bound to the current ``runtime`` so the backend can read staging state + (``staged_dirs``, ``pending_moves``, ``files`` cache, ``kb_anon_doc``, + ``kb_matched_chunk_ids``) for each tool call. When no ``search_space_id`` + is provided, the resolver falls back to :class:`StateBackend` (used by + sub-agents and tests that don't need DB-backed reads). + + Desktop-local mode unchanged. + """ if selection.mode == FilesystemMode.DESKTOP_LOCAL_FOLDER and selection.local_mounts: @@ -36,7 +50,14 @@ def build_backend_resolver( return _resolve_local - def _resolve_cloud(runtime: ToolRuntime) -> StateBackend: + if search_space_id is not None: + + def _resolve_kb(runtime: ToolRuntime) -> BackendProtocol: + return KBPostgresBackend(search_space_id, runtime) + + return _resolve_kb + + def _resolve_state(runtime: ToolRuntime) -> StateBackend: return StateBackend(runtime) - return _resolve_cloud + return _resolve_state diff --git a/surfsense_backend/app/agents/new_chat/filesystem_state.py b/surfsense_backend/app/agents/new_chat/filesystem_state.py new file mode 100644 index 000000000..18952ed6f --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/filesystem_state.py @@ -0,0 +1,113 @@ +"""LangGraph state schema additions used by the SurfSense filesystem agent. + +This schema extends deepagents' upstream :class:`FilesystemState` with the +extra fields needed to implement Postgres-backed virtual filesystem semantics: + +* ``cwd`` — current working directory (per-thread checkpointed). +* ``staged_dirs`` — pending mkdir requests (cloud only). +* ``pending_moves`` — pending move_file requests (cloud only). +* ``doc_id_by_path`` — virtual_path -> Document.id, populated by lazy reads. +* ``dirty_paths`` — paths whose state file content differs from DB. +* ``kb_priority`` — top-K priority hints rendered into a system message. +* ``kb_matched_chunk_ids`` — internal hand-off for matched-chunk highlighting. +* ``kb_anon_doc`` — Redis-loaded anonymous document (if any). +* ``tree_version`` — bumped by persistence; invalidates the tree render cache. + +Tools mutate these fields ONLY via ``Command(update=...)`` returns; the +reducers in :mod:`app.agents.new_chat.state_reducers` handle merging. +""" + +from __future__ import annotations + +from typing import Annotated, Any, NotRequired + +from deepagents.middleware.filesystem import FilesystemState +from typing_extensions import TypedDict + +from app.agents.new_chat.state_reducers import ( + _add_unique_reducer, + _dict_merge_with_tombstones_reducer, + _list_append_reducer, + _replace_reducer, +) + + +class PendingMove(TypedDict): + """A staged move_file operation pending end-of-turn commit.""" + + source: str + dest: str + overwrite: bool + + +class KbPriorityEntry(TypedDict, total=False): + path: str + score: float + document_id: int | None + title: str + mentioned: bool + + +class KbAnonDoc(TypedDict, total=False): + """In-memory anonymous-session document loaded from Redis.""" + + path: str + title: str + content: str + chunks: list[dict[str, Any]] + + +class SurfSenseFilesystemState(FilesystemState): + """Filesystem state used by the SurfSense agent (cloud + desktop). + + Extends deepagents' :class:`FilesystemState` (which provides ``files``) + with cloud-mode staging fields and search-priority hints. All extra fields + are reducer-backed so that ``Command(update=...)`` payloads merge cleanly + across agent steps and across checkpoints. + """ + + cwd: NotRequired[Annotated[str, _replace_reducer]] + """Current working directory. + + Defaults to ``"/documents"`` in cloud mode and ``"/"`` (or first mount) in + desktop mode. Initialized once per thread by ``KnowledgeTreeMiddleware``. + """ + + staged_dirs: NotRequired[Annotated[list[str], _add_unique_reducer]] + """mkdir paths staged for end-of-turn folder creation (cloud only).""" + + pending_moves: NotRequired[Annotated[list[PendingMove], _list_append_reducer]] + """move_file ops staged for end-of-turn commit (cloud only).""" + + doc_id_by_path: NotRequired[ + Annotated[dict[str, int], _dict_merge_with_tombstones_reducer] + ] + """virtual_path -> ``Document.id`` for lazily loaded files. + + Populated on first read of a KB document. Used by edit_file/move_file/ + aafter_agent to map paths back to a real DB row. ``None`` values delete + the key (tombstones). + """ + + dirty_paths: NotRequired[Annotated[list[str], _add_unique_reducer]] + """Paths whose ``state["files"]`` content has been modified this turn.""" + + kb_priority: NotRequired[Annotated[list[KbPriorityEntry], _replace_reducer]] + """Top-K priority hints rendered as a system message before the user turn.""" + + kb_matched_chunk_ids: NotRequired[Annotated[dict[int, list[int]], _replace_reducer]] + """Internal: ``Document.id`` -> list of matched chunk IDs from hybrid search.""" + + kb_anon_doc: NotRequired[Annotated[KbAnonDoc | None, _replace_reducer]] + """Anonymous-session document loaded from Redis (read-only, no DB row).""" + + tree_version: NotRequired[Annotated[int, _replace_reducer]] + """Monotonically increasing counter; bumped when commits change the KB tree.""" + + +__all__ = [ + "KbAnonDoc", + "KbPriorityEntry", + "PendingMove", + "SurfSenseFilesystemState", +] diff --git a/surfsense_backend/app/agents/new_chat/middleware/__init__.py b/surfsense_backend/app/agents/new_chat/middleware/__init__.py index 5a24b2f9e..094c102f8 100644 --- a/surfsense_backend/app/agents/new_chat/middleware/__init__.py +++ b/surfsense_backend/app/agents/new_chat/middleware/__init__.py @@ -1,25 +1,83 @@ """Middleware components for the SurfSense new chat agent.""" +from app.agents.new_chat.middleware.action_log import ActionLogMiddleware +from app.agents.new_chat.middleware.anonymous_document import ( + AnonymousDocumentMiddleware, +) +from app.agents.new_chat.middleware.busy_mutex import BusyMutexMiddleware +from app.agents.new_chat.middleware.compaction import ( + SurfSenseCompactionMiddleware, + create_surfsense_compaction_middleware, +) +from app.agents.new_chat.middleware.context_editing import ( + ClearToolUsesEdit, + SpillingContextEditingMiddleware, + SpillToBackendEdit, +) from app.agents.new_chat.middleware.dedup_tool_calls import ( DedupHITLToolCallsMiddleware, ) +from app.agents.new_chat.middleware.doom_loop import DoomLoopMiddleware +from app.agents.new_chat.middleware.file_intent import ( + FileIntentMiddleware, +) from app.agents.new_chat.middleware.filesystem import ( SurfSenseFilesystemMiddleware, ) -from app.agents.new_chat.middleware.file_intent import ( - FileIntentMiddleware, +from app.agents.new_chat.middleware.kb_persistence import ( + KnowledgeBasePersistenceMiddleware, + commit_staged_filesystem_state, ) from app.agents.new_chat.middleware.knowledge_search import ( KnowledgeBaseSearchMiddleware, + KnowledgePriorityMiddleware, +) +from app.agents.new_chat.middleware.knowledge_tree import ( + KnowledgeTreeMiddleware, ) from app.agents.new_chat.middleware.memory_injection import ( MemoryInjectionMiddleware, ) +from app.agents.new_chat.middleware.noop_injection import NoopInjectionMiddleware +from app.agents.new_chat.middleware.otel_span import OtelSpanMiddleware +from app.agents.new_chat.middleware.permission import PermissionMiddleware +from app.agents.new_chat.middleware.retry_after import RetryAfterMiddleware +from app.agents.new_chat.middleware.skills_backends import ( + BuiltinSkillsBackend, + SearchSpaceSkillsBackend, + build_skills_backend_factory, + default_skills_sources, +) +from app.agents.new_chat.middleware.tool_call_repair import ( + ToolCallNameRepairMiddleware, +) __all__ = [ + "ActionLogMiddleware", + "AnonymousDocumentMiddleware", + "BuiltinSkillsBackend", + "BusyMutexMiddleware", + "ClearToolUsesEdit", "DedupHITLToolCallsMiddleware", + "DoomLoopMiddleware", "FileIntentMiddleware", + "KnowledgeBasePersistenceMiddleware", "KnowledgeBaseSearchMiddleware", + "KnowledgePriorityMiddleware", + "KnowledgeTreeMiddleware", "MemoryInjectionMiddleware", + "NoopInjectionMiddleware", + "OtelSpanMiddleware", + "PermissionMiddleware", + "RetryAfterMiddleware", + "SearchSpaceSkillsBackend", + "SpillToBackendEdit", + "SpillingContextEditingMiddleware", + "SurfSenseCompactionMiddleware", "SurfSenseFilesystemMiddleware", + "ToolCallNameRepairMiddleware", + "build_skills_backend_factory", + "commit_staged_filesystem_state", + "create_surfsense_compaction_middleware", + "default_skills_sources", ] diff --git a/surfsense_backend/app/agents/new_chat/middleware/action_log.py b/surfsense_backend/app/agents/new_chat/middleware/action_log.py new file mode 100644 index 000000000..3675064e8 --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/middleware/action_log.py @@ -0,0 +1,292 @@ +"""Append-only action-log middleware for the SurfSense agent. + +Wraps every tool call via :meth:`AgentMiddleware.awrap_tool_call` and writes +a row to :class:`~app.db.AgentActionLog` after the tool returns. Tools opt +into reversibility by declaring a ``reverse`` callable on their +:class:`~app.agents.new_chat.tools.registry.ToolDefinition`; the rendered +descriptor is persisted in ``reverse_descriptor`` for use by +``/api/threads/{thread_id}/revert/{action_id}``. + +Design points: + +* **Defensive.** Logging never blocks the agent. We catch every exception + on the DB write path and emit a warning; the tool's ``ToolMessage`` + result is always returned untouched. +* **Lightweight payload.** Only the tool ``name`` + ``args`` (capped) + + ``result_id`` + ``reverse_descriptor`` are stored. Tool output text + remains in the LangGraph checkpoint / spilled tool-output files. +* **Best-effort reversibility.** We invoke ``reverse(args, result_obj)`` + with the parsed JSON result when the tool's content is a JSON object; + otherwise the raw text is passed. Exceptions in the reverse callable + are swallowed and logged — a failed descriptor render simply means the + action is NOT marked reversible. +""" + +from __future__ import annotations + +import json +import logging +from collections.abc import Awaitable, Callable +from typing import TYPE_CHECKING, Any + +from langchain.agents.middleware import AgentMiddleware +from langchain_core.messages import ToolMessage + +from app.agents.new_chat.feature_flags import get_flags +from app.agents.new_chat.tools.registry import ToolDefinition + +if TYPE_CHECKING: # pragma: no cover - type-only + from langchain.agents.middleware.types import ToolCallRequest + from langgraph.types import Command + + +logger = logging.getLogger(__name__) + + +# Cap for the persisted ``args`` JSON to avoid bloating the action log with +# accidentally-huge inputs. Values are truncated and a flag is set in the +# stored payload so consumers can detect truncation. +_MAX_ARGS_PERSIST_BYTES = 32 * 1024 # 32KB + + +class ActionLogMiddleware(AgentMiddleware): + """Persist a row in :class:`AgentActionLog` after every tool call. + + Should be placed near the OUTERMOST end of the tool-call wrapping stack + so that it sees the *final* :class:`ToolMessage` after all retries, + permission checks, and dedup logic have run. In practice that means + placing it just inside :class:`PermissionMiddleware` and outside + :class:`DedupHITLToolCallsMiddleware`. + + The middleware is fully a no-op when: + + * the master kill-switch ``SURFSENSE_DISABLE_NEW_AGENT_STACK`` is set + (checked via :func:`get_flags`), + * the per-feature flag ``enable_action_log`` is off, or + * persistence raises (defensive: tool-call dispatch always succeeds). + + Args: + thread_id: The current chat thread's primary-key id. Required to + persist a row; if ``None`` the middleware silently no-ops. + search_space_id: Search-space id for cascade-on-delete safety. + user_id: UUID string of the user driving this turn (nullable in + anonymous mode). + tool_definitions: Optional mapping of tool name -> :class:`ToolDefinition` + so the middleware can look up the tool's ``reverse`` callable. + When omitted, no actions are marked reversible. + """ + + tools = () + + def __init__( + self, + *, + thread_id: int | None, + search_space_id: int, + user_id: str | None, + tool_definitions: dict[str, ToolDefinition] | None = None, + ) -> None: + super().__init__() + self._thread_id = thread_id + self._search_space_id = search_space_id + self._user_id = user_id + self._tool_definitions = dict(tool_definitions or {}) + + def _enabled(self) -> bool: + flags = get_flags() + if flags.disable_new_agent_stack: + return False + return bool(flags.enable_action_log) and self._thread_id is not None + + async def awrap_tool_call( + self, + request: ToolCallRequest, + handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command[Any]]], + ) -> ToolMessage | Command[Any]: + if not self._enabled(): + return await handler(request) + + result: ToolMessage | Command[Any] + error_payload: dict[str, Any] | None = None + try: + result = await handler(request) + except Exception as exc: + # Persist the failure too so revert/audit can see it, then + # re-raise so downstream middleware (RetryAfter, etc.) handles it. + error_payload = {"type": type(exc).__name__, "message": str(exc)} + await self._record( + request=request, + result=None, + error_payload=error_payload, + ) + raise + + await self._record(request=request, result=result, error_payload=None) + return result + + async def _record( + self, + *, + request: ToolCallRequest, + result: ToolMessage | Command[Any] | None, + error_payload: dict[str, Any] | None, + ) -> None: + """Persist one ``agent_action_log`` row. Defensive: never raises.""" + try: + from app.db import AgentActionLog, shielded_async_session + + tool_name = _resolve_tool_name(request) + args_payload = _resolve_args_payload(request) + result_id = _resolve_result_id(result) + reverse_descriptor, reversible = self._render_reverse( + tool_name=tool_name, + args=_resolve_args_dict(request), + result=result, + ) + + row = AgentActionLog( + thread_id=self._thread_id, + user_id=self._user_id, + search_space_id=self._search_space_id, + turn_id=_resolve_turn_id(request), + message_id=_resolve_message_id(request), + tool_name=tool_name, + args=args_payload, + result_id=result_id, + reversible=reversible, + reverse_descriptor=reverse_descriptor, + error=error_payload, + ) + async with shielded_async_session() as session: + session.add(row) + await session.commit() + except Exception: + logger.warning( + "ActionLogMiddleware failed to persist action log row", + exc_info=True, + ) + + def _render_reverse( + self, + *, + tool_name: str, + args: dict[str, Any] | None, + result: ToolMessage | Command[Any] | None, + ) -> tuple[dict[str, Any] | None, bool]: + """Run the tool's ``reverse`` callable and return its descriptor. + + Returns a tuple of ``(descriptor_or_None, reversible_bool)``. When + the tool has no ``reverse`` callable, or when the callable raises, + the action is marked non-reversible. + """ + if not result or not isinstance(result, ToolMessage): + return None, False + if args is None: + return None, False + tool_def = self._tool_definitions.get(tool_name) + if tool_def is None or tool_def.reverse is None: + return None, False + try: + parsed_result = _parse_tool_result_content(result) + descriptor = tool_def.reverse(args, parsed_result) + except Exception: + logger.warning( + "Reverse descriptor render failed for tool %s", + tool_name, + exc_info=True, + ) + return None, False + if not isinstance(descriptor, dict): + return None, False + return descriptor, True + + +# --------------------------------------------------------------------------- +# Resolution helpers — defensive against tool_call request shape variation. +# --------------------------------------------------------------------------- + + +def _resolve_tool_name(request: Any) -> str: + try: + tool = getattr(request, "tool", None) + if tool is not None: + name = getattr(tool, "name", None) + if isinstance(name, str) and name: + return name + call = getattr(request, "tool_call", None) or {} + if isinstance(call, dict): + name = call.get("name") + if isinstance(name, str) and name: + return name + except Exception: # pragma: no cover - defensive + pass + return "unknown" + + +def _resolve_args_dict(request: Any) -> dict[str, Any] | None: + try: + call = getattr(request, "tool_call", None) + if not isinstance(call, dict): + return None + args = call.get("args") + if isinstance(args, dict): + return args + return None + except Exception: # pragma: no cover - defensive + return None + + +def _resolve_args_payload(request: Any) -> dict[str, Any] | None: + """Return a JSON-serializable args dict, truncated if too big.""" + args = _resolve_args_dict(request) + if args is None: + return None + try: + encoded = json.dumps(args, default=str) + except Exception: + return {"_repr": repr(args)[:_MAX_ARGS_PERSIST_BYTES]} + if len(encoded) <= _MAX_ARGS_PERSIST_BYTES: + return args + return { + "_truncated": True, + "_size": len(encoded), + "_preview": encoded[:_MAX_ARGS_PERSIST_BYTES], + } + + +def _resolve_turn_id(request: Any) -> str | None: + try: + call = getattr(request, "tool_call", None) or {} + if isinstance(call, dict): + tid = call.get("id") + if isinstance(tid, str): + return tid + except Exception: # pragma: no cover + pass + return None + + +def _resolve_message_id(request: Any) -> str | None: + """Tool-call IDs serve as best-available message correlator at this layer.""" + return _resolve_turn_id(request) + + +def _resolve_result_id(result: Any) -> str | None: + if isinstance(result, ToolMessage): + msg_id = getattr(result, "id", None) + if isinstance(msg_id, str): + return msg_id + return None + + +def _parse_tool_result_content(result: ToolMessage) -> Any: + content = result.content + if isinstance(content, str): + try: + return json.loads(content) + except (json.JSONDecodeError, ValueError): + return content + return content + + +__all__ = ["ActionLogMiddleware"] diff --git a/surfsense_backend/app/agents/new_chat/middleware/anonymous_document.py b/surfsense_backend/app/agents/new_chat/middleware/anonymous_document.py new file mode 100644 index 000000000..2893d2e11 --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/middleware/anonymous_document.py @@ -0,0 +1,91 @@ +"""Lightweight middleware that loads the anonymous-session document into state. + +Anonymous chats receive a single uploaded document via Redis (no DB row, +read-only). This middleware loads it once on the first turn into +``state['kb_anon_doc']`` so: + +* :class:`KnowledgeTreeMiddleware` can render the synthetic ``/documents`` + view without touching the DB. +* :class:`KnowledgePriorityMiddleware` skips hybrid search and emits a + degenerate priority list. +* :class:`KBPostgresBackend` (``als_info`` / ``aread`` / ``_load_file_data``) + recognises the synthetic path. + +The middleware is a no-op when ``anon_session_id`` is not provided or when +the document is already cached in state. +""" + +from __future__ import annotations + +import json +import logging +from typing import Any + +from langchain.agents.middleware import AgentMiddleware, AgentState +from langgraph.runtime import Runtime + +from app.agents.new_chat.filesystem_state import SurfSenseFilesystemState +from app.agents.new_chat.path_resolver import DOCUMENTS_ROOT, safe_filename + +logger = logging.getLogger(__name__) + + +class AnonymousDocumentMiddleware(AgentMiddleware): # type: ignore[type-arg] + """Load the anonymous user's uploaded document from Redis into state.""" + + tools = () + state_schema = SurfSenseFilesystemState + + def __init__(self, *, anon_session_id: str | None) -> None: + self.anon_session_id = anon_session_id + + async def abefore_agent( # type: ignore[override] + self, + state: AgentState, + runtime: Runtime[Any], + ) -> dict[str, Any] | None: + del runtime + if not self.anon_session_id: + return None + if state.get("kb_anon_doc"): + return None + + anon_doc = await self._load_anon_document() + if anon_doc is None: + return None + return {"kb_anon_doc": anon_doc} + + async def _load_anon_document(self) -> dict[str, Any] | None: + """Read ``anon:doc:`` from Redis.""" + try: + import redis.asyncio as aioredis # local import to keep cold paths cheap + + from app.config import config + + redis_client = aioredis.from_url( + config.REDIS_APP_URL, decode_responses=True + ) + try: + redis_key = f"anon:doc:{self.anon_session_id}" + data = await redis_client.get(redis_key) + if not data: + return None + payload = json.loads(data) + finally: + await redis_client.aclose() + except Exception as exc: + logger.warning("Failed to load anonymous document from Redis: %s", exc) + return None + + title = str(payload.get("filename") or "uploaded_document") + content = str(payload.get("content") or "") + path = f"{DOCUMENTS_ROOT}/{safe_filename(title)}" + return { + "path": path, + "title": title, + "content": content, + "chunks": [{"chunk_id": -1, "content": content}] if content else [], + } + + +__all__ = ["AnonymousDocumentMiddleware"] diff --git a/surfsense_backend/app/agents/new_chat/middleware/busy_mutex.py b/surfsense_backend/app/agents/new_chat/middleware/busy_mutex.py new file mode 100644 index 000000000..c57d85004 --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/middleware/busy_mutex.py @@ -0,0 +1,236 @@ +""" +BusyMutexMiddleware — per-thread asyncio lock + cancel token. + +LangChain has no built-in concept of "this thread is already running a +turn — refuse the second concurrent request". Without it, a user +double-clicking "send" or refreshing the page mid-stream can spawn two +turns racing on the same checkpoint, producing duplicated tool calls +and mangled state. + +Ported from OpenCode's ``Stream.scoped(AbortController)`` pattern: a +single-process, in-memory lock + cooperative cancellation token keyed by +``thread_id``. For multi-worker deployments a distributed lock backend +(Redis or PostgreSQL advisory locks) is a phase-2 follow-up. + +What this provides: +- A ``WeakValueDictionary[str, asyncio.Lock]`` keyed by ``thread_id``; + acquiring the lock during ``before_agent`` blocks any concurrent + prompt on the same thread until release. +- A per-thread ``asyncio.Event`` (``cancel_event``) that long-running + tools can poll to abort cooperatively. The event is reset between + turns. Tools should check ``runtime.context.cancel_event.is_set()`` + in tight inner loops. +- A typed :class:`~app.agents.new_chat.errors.BusyError` raised when a + second turn arrives while the lock is held. + +Note: SurfSense's ``stream_new_chat`` is the call site that should +acquire/release. Wiring this as middleware means the contract is +explicit and the lock manager is shared with subagents that compile +their own ``create_agent`` runnables. +""" + +from __future__ import annotations + +import asyncio +import logging +import weakref +from typing import Any + +from langchain.agents.middleware.types import ( + AgentMiddleware, + AgentState, + ContextT, + ResponseT, +) +from langgraph.config import get_config +from langgraph.runtime import Runtime + +from app.agents.new_chat.errors import BusyError + +logger = logging.getLogger(__name__) + + +class _ThreadLockManager: + """Process-local registry of per-thread asyncio locks + cancel events.""" + + def __init__(self) -> None: + self._locks: weakref.WeakValueDictionary[str, asyncio.Lock] = ( + weakref.WeakValueDictionary() + ) + self._cancel_events: dict[str, asyncio.Event] = {} + + def lock_for(self, thread_id: str) -> asyncio.Lock: + lock = self._locks.get(thread_id) + if lock is None: + lock = asyncio.Lock() + self._locks[thread_id] = lock + return lock + + def cancel_event(self, thread_id: str) -> asyncio.Event: + event = self._cancel_events.get(thread_id) + if event is None: + event = asyncio.Event() + self._cancel_events[thread_id] = event + return event + + def request_cancel(self, thread_id: str) -> bool: + event = self._cancel_events.get(thread_id) + if event is None: + return False + event.set() + return True + + def reset(self, thread_id: str) -> None: + event = self._cancel_events.get(thread_id) + if event is not None: + event.clear() + + +# Module-level singleton — process-local but reused across all agent +# instances built in this process. Subagents created in nested +# ``create_agent`` calls also get this so locks are coherent. +manager = _ThreadLockManager() + + +def get_cancel_event(thread_id: str) -> asyncio.Event: + """Public accessor used by long-running tools to poll cancellation.""" + return manager.cancel_event(thread_id) + + +def request_cancel(thread_id: str) -> bool: + """Trip the cancel event for ``thread_id``. Returns True if found.""" + return manager.request_cancel(thread_id) + + +def reset_cancel(thread_id: str) -> None: + """Reset the cancel event for ``thread_id`` (called between turns).""" + manager.reset(thread_id) + + +class BusyMutexMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, ResponseT]): + """Block concurrent prompts on the same thread. + + Acquires the thread's lock in ``abefore_agent`` and releases in + ``aafter_agent``. If the lock is held, raises :class:`BusyError` + so the caller can emit a ``surfsense.busy`` SSE event with the + in-flight request id. + + Args: + require_thread_id: When True, raise :class:`BusyError` if no + ``thread_id`` can be resolved from the active + ``RunnableConfig``. Default is False — we treat a missing + thread_id as "this turn has nothing to lock against" and + no-op the mutex. Set True only when you trust the call + site to always provide ``configurable.thread_id`` (e.g. + in production where ``stream_new_chat`` always does). + """ + + def __init__(self, *, require_thread_id: bool = False) -> None: + super().__init__() + self._require_thread_id = require_thread_id + self.tools = [] + # Per-call locks owned by this middleware. We track them as + # an instance attribute so ``aafter_agent`` knows which lock + # to release. + self._held_locks: dict[str, asyncio.Lock] = {} + + @staticmethod + def _thread_id(runtime: Runtime[ContextT]) -> str | None: + """Extract ``thread_id`` from the active LangGraph ``RunnableConfig``. + + ``langgraph.runtime.Runtime`` deliberately does NOT expose ``config``. + The runnable config (where ``configurable.thread_id`` lives) must be + fetched via :func:`langgraph.config.get_config` from inside a node / + middleware. We fall back to ``getattr(runtime, "config", None)`` for + unit tests / legacy runtimes that synthesize a config-bearing stub. + """ + + def _from_dict(cfg: Any) -> str | None: + if not isinstance(cfg, dict): + return None + tid = (cfg.get("configurable") or {}).get("thread_id") + return str(tid) if tid is not None else None + + # Preferred path: real LangGraph runtime context. + try: + tid = _from_dict(get_config()) + except Exception: + tid = None + if tid is not None: + return tid + + # Fallback for tests and any runtime that surfaces a config dict + # directly on the runtime instance. + return _from_dict(getattr(runtime, "config", None)) + + async def abefore_agent( # type: ignore[override] + self, + state: AgentState[Any], + runtime: Runtime[ContextT], + ) -> dict[str, Any] | None: + del state + thread_id = self._thread_id(runtime) + if thread_id is None: + if self._require_thread_id: + raise BusyError("no thread_id configured") + logger.debug( + "BusyMutexMiddleware: no thread_id resolved from RunnableConfig; " + "skipping per-thread lock for this turn." + ) + return None + + lock = manager.lock_for(thread_id) + if lock.locked(): + raise BusyError(request_id=thread_id) + await lock.acquire() + self._held_locks[thread_id] = lock + # Reset the cancel event so this turn starts fresh + reset_cancel(thread_id) + return None + + async def aafter_agent( # type: ignore[override] + self, + state: AgentState[Any], + runtime: Runtime[ContextT], + ) -> dict[str, Any] | None: + del state + thread_id = self._thread_id(runtime) + if thread_id is None: + return None + lock = self._held_locks.pop(thread_id, None) + if lock is not None and lock.locked(): + lock.release() + # Always clear cancel event between turns so a stale signal + # doesn't leak into the next request. + reset_cancel(thread_id) + return None + + # Provide sync no-ops because the middleware base class allows them + def before_agent( # type: ignore[override] + self, state: AgentState[Any], runtime: Runtime[ContextT] + ) -> dict[str, Any] | None: + # Sync path: no asyncio.Lock to acquire. Best we can do is reject + # if anyone else is in flight. + thread_id = self._thread_id(runtime) + if thread_id is None: + if self._require_thread_id: + raise BusyError("no thread_id configured") + return None + lock = manager.lock_for(thread_id) + if lock.locked(): + raise BusyError(request_id=thread_id) + return None + + def after_agent( # type: ignore[override] + self, state: AgentState[Any], runtime: Runtime[ContextT] + ) -> dict[str, Any] | None: + return None + + +__all__ = [ + "BusyMutexMiddleware", + "get_cancel_event", + "manager", + "request_cancel", + "reset_cancel", +] diff --git a/surfsense_backend/app/agents/new_chat/middleware/compaction.py b/surfsense_backend/app/agents/new_chat/middleware/compaction.py new file mode 100644 index 000000000..16361e16b --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/middleware/compaction.py @@ -0,0 +1,254 @@ +""" +SurfSense compaction middleware. + +Subclasses :class:`deepagents.middleware.summarization.SummarizationMiddleware` +to add SurfSense-specific behavior: + +1. **Structured summary template** (OpenCode-style ``## Goal / Constraints / + Progress / Key Decisions / Next Steps / Critical Context / Relevant Files``) + — see :data:`SURFSENSE_SUMMARY_PROMPT` below. The base + ``SummarizationMiddleware`` only ships a freeform "summarize this" + prompt; the structured template is ported from OpenCode's + ``compaction.ts``. +2. **Protect SurfSense-specific SystemMessages** so injected hints + (````, ````, ````, + ````, ````, ````, ````) + are *not* summarized away and are kept verbatim in the post-summary + message list. Mirrors OpenCode's ``PRUNE_PROTECTED_TOOLS`` philosophy + (some message types are part of the agent's contract and must survive + compaction unchanged). +3. **Sanitize ``content=None``** when feeding messages into ``get_buffer_string`` + (Azure OpenAI / LiteLLM defense — when a provider streams an AIMessage + containing only tool_calls and no text, ``content`` can be ``None`` and + ``get_buffer_string`` crashes iterating over ``None``). SurfSense-specific. +""" + +from __future__ import annotations + +import logging +from typing import TYPE_CHECKING, Any + +from deepagents.middleware.summarization import ( + SummarizationMiddleware, + compute_summarization_defaults, +) +from langchain_core.messages import SystemMessage + +from app.observability import otel as ot + +if TYPE_CHECKING: + from deepagents.backends.protocol import BACKEND_TYPES + from langchain_core.language_models import BaseChatModel + from langchain_core.messages import AnyMessage + +logger = logging.getLogger(__name__) + +# Structured summary template ported from OpenCode's +# ``opencode/packages/opencode/src/session/compaction.ts:40-75``. Kept as a +# module-level constant so unit tests can assert on its sections. +SURFSENSE_SUMMARY_PROMPT = """ +SurfSense Conversation Compaction Assistant + + + +Extract the most important context from the conversation history below into a structured summary that will replace the older messages. + + + +You are running because the conversation has grown beyond the model's input window. The conversation history below will be summarized and replaced with your output. Use the structured template that follows; keep each section concise but comprehensive enough that the agent can resume work without losing context. Each section is a checklist — populate it with relevant content or write "None" if there is nothing to report. + +## Goal +What is the user's primary goal or request? State it in one or two sentences. + +## Constraints +What boundaries must the agent respect (citations rules, visibility scope, allowed tools, user-imposed style, deadlines, deny-listed topics)? + +## Progress +What has the agent already accomplished? List each completed step succinctly. Do not reproduce tool output; just record the conclusion. + +## Key Decisions +What choices were made and why? Include rejected alternatives and the reasoning behind selecting the current path. + +## Next Steps +What specific tasks remain to achieve the goal? Order them by dependency. + +## Critical Context +What facts, IDs, document titles, query keywords, error messages, or partial answers must persist into the next turn? Include verbatim quotes only when the exact wording matters (e.g. a precise filter clause or a literal name). + +## Relevant Files +What documents or paths in the SurfSense knowledge base are in play? Use ``/documents/...`` paths exactly as they appeared in the workspace tree. + + + +Messages to summarize: +{messages} + + +Respond ONLY with the structured summary. Do not include any text before or after. +""" + +# SystemMessage prefixes that must NOT be summarized away. They are +# re-injected on every turn by the corresponding middleware, but the +# compaction step happens *before* re-injection in some paths, so we +# must preserve them verbatim across the cutoff. +PROTECTED_SYSTEM_PREFIXES: tuple[str, ...] = ( + "", # KnowledgePriorityMiddleware + "", # KnowledgeTreeMiddleware + "", # FileIntentMiddleware + "", # MemoryInjectionMiddleware + "", # MemoryInjectionMiddleware + "", # MemoryInjectionMiddleware + "", # MemoryInjectionMiddleware +) + + +def _is_protected_system_message(msg: AnyMessage) -> bool: + """Return True if ``msg`` is a SystemMessage we must not summarize.""" + if not isinstance(msg, SystemMessage): + return False + content = msg.content + if not isinstance(content, str): + return False + stripped = content.lstrip() + return any(stripped.startswith(prefix) for prefix in PROTECTED_SYSTEM_PREFIXES) + + +def _sanitize_message_content(msg: AnyMessage) -> AnyMessage: + """Return ``msg`` with ``content=None`` coerced to ``""``. + + Folds in the historical defense from ``safe_summarization.py`` — + ``get_buffer_string`` reads ``m.text`` which iterates ``self.content``, + so a ``None`` content (Azure OpenAI / LiteLLM streaming a tool-only + AIMessage) explodes. We return a copy with empty string content so + downstream consumers see an empty body without mutating the original. + """ + if getattr(msg, "content", "not-missing") is not None: + return msg + try: + return msg.model_copy(update={"content": ""}) + except AttributeError: + import copy + + new_msg = copy.copy(msg) + try: + new_msg.content = "" + except Exception: + logger.debug( + "Could not sanitize content=None on message of type %s", + type(msg).__name__, + ) + return msg + return new_msg + + +class SurfSenseCompactionMiddleware(SummarizationMiddleware): + """SummarizationMiddleware tuned for SurfSense. + + Notes + ----- + - Overrides :meth:`_partition_messages` so protected SystemMessages + survive into the ``preserved_messages`` half regardless of cutoff. + - Overrides :meth:`_filter_summary_messages` so the buffer-string path + never iterates ``None`` content. + - Inherits everything else (auto-trigger, backend offload, + ``_summarization_event`` plumbing, ``ContextOverflowError`` fallback). + """ + + def _partition_messages( # type: ignore[override] + self, + conversation_messages: list[AnyMessage], + cutoff_index: int, + ) -> tuple[list[AnyMessage], list[AnyMessage]]: + """Split messages but always preserve SurfSense protected SystemMessages. + + Mirrors OpenCode's ``PRUNE_PROTECTED_TOOLS`` philosophy + (``opencode/packages/opencode/src/session/compaction.ts``): some + message types are always kept verbatim because they are part of the + agent's working contract, not transient output. + + Also opens a ``compaction.run`` OTel span (no-op when OTel is off) + so dashboards can count compaction events and message-volume + without having to instrument upstream callers. + """ + # Opening a span here is appropriate because partitioning is the + # first call SummarizationMiddleware makes when it has decided to + # summarize; we record the volume and then close as a normal span. + with ot.compaction_span( + reason="auto", + messages_in=len(conversation_messages), + extra={"compaction.cutoff_index": int(cutoff_index)}, + ): + messages_to_summarize, preserved_messages = super()._partition_messages( + conversation_messages, cutoff_index + ) + + protected: list[AnyMessage] = [] + kept_for_summary: list[AnyMessage] = [] + for msg in messages_to_summarize: + if _is_protected_system_message(msg): + protected.append(msg) + else: + kept_for_summary.append(msg) + + # Place protected blocks at the *front* of preserved_messages so + # they keep their original ordering relative to the summary + # HumanMessage that precedes the rest of the preserved tail. + return kept_for_summary, [*protected, *preserved_messages] + + def _filter_summary_messages( # type: ignore[override] + self, messages: list[AnyMessage] + ) -> list[AnyMessage]: + """Filter previous summaries AND sanitize ``content=None``. + + Folds the ``safe_summarization.py`` defense in: when the buffer + builder iterates ``m.text`` over ``None`` it explodes; sanitizing + here covers both the sync and async offload paths. + """ + filtered = super()._filter_summary_messages(messages) + return [_sanitize_message_content(m) for m in filtered] + + +def create_surfsense_compaction_middleware( + model: BaseChatModel, + backend: BACKEND_TYPES, + *, + summary_prompt: str | None = None, + history_path_prefix: str = "/conversation_history", + **overrides: Any, +) -> SurfSenseCompactionMiddleware: + """Build a :class:`SurfSenseCompactionMiddleware` with sensible defaults. + + Pulls profile-aware ``trigger`` / ``keep`` / ``truncate_args_settings`` + via :func:`deepagents.middleware.summarization.compute_summarization_defaults` + so callers get the same behavior as ``create_summarization_middleware`` + plus our overrides. + + Args: + model: Chat model to call for summary generation. + backend: Backend instance or factory for offloading conversation history. + summary_prompt: Optional override; defaults to :data:`SURFSENSE_SUMMARY_PROMPT`. + history_path_prefix: Path prefix for offloaded conversation history. + **overrides: Forwarded to :class:`SurfSenseCompactionMiddleware`. + """ + defaults = compute_summarization_defaults(model) + return SurfSenseCompactionMiddleware( + model=model, + backend=backend, + trigger=overrides.pop("trigger", defaults["trigger"]), + keep=overrides.pop("keep", defaults["keep"]), + trim_tokens_to_summarize=overrides.pop("trim_tokens_to_summarize", None), + truncate_args_settings=overrides.pop( + "truncate_args_settings", defaults["truncate_args_settings"] + ), + summary_prompt=summary_prompt or SURFSENSE_SUMMARY_PROMPT, + history_path_prefix=history_path_prefix, + **overrides, + ) + + +__all__ = [ + "PROTECTED_SYSTEM_PREFIXES", + "SURFSENSE_SUMMARY_PROMPT", + "SurfSenseCompactionMiddleware", + "create_surfsense_compaction_middleware", +] diff --git a/surfsense_backend/app/agents/new_chat/middleware/context_editing.py b/surfsense_backend/app/agents/new_chat/middleware/context_editing.py new file mode 100644 index 000000000..39bc57c8b --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/middleware/context_editing.py @@ -0,0 +1,350 @@ +""" +SpillToBackendEdit + SpillingContextEditingMiddleware. + +LangChain's :class:`ClearToolUsesEdit` discards old ``ToolMessage.content`` +when the context-editing budget triggers, replacing the body with a fixed +placeholder. That's lossy: anything the agent might want to revisit is +gone. The spill-to-disk pattern (originally from OpenCode's +``opencode/packages/opencode/src/tool/truncate.ts``) keeps the prune +behavior but writes the full original payload to the runtime backend +under ``/tool_outputs/{thread_id}/{message_id}.txt`` first. The +placeholder is then upgraded to point at the spill path so the agent +(or a subagent) can read it back on demand. + +Why this is a middleware subclass instead of a plain ``ContextEdit``: +``ContextEdit.apply`` is sync, but writing to the backend is async. We +capture the spill payloads inside ``apply`` and flush them via +``await backend.aupload_files(...)`` from ``awrap_model_call`` *before* +delegating to the handler, so the explore subagent can always read what +the placeholder advertises. +""" + +from __future__ import annotations + +import logging +import threading +from collections.abc import Awaitable, Callable, Sequence +from copy import deepcopy +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Any + +from langchain.agents.middleware.context_editing import ( + ClearToolUsesEdit, + ContextEdit, + ContextEditingMiddleware, + TokenCounter, +) +from langchain_core.messages import ( + AIMessage, + AnyMessage, + BaseMessage, + ToolMessage, +) +from langchain_core.messages.utils import count_tokens_approximately +from langgraph.config import get_config + +if TYPE_CHECKING: + from deepagents.backends.protocol import BackendProtocol + from langchain.agents.middleware.types import ( + ModelRequest, + ModelResponse, + ) + +logger = logging.getLogger(__name__) + +DEFAULT_SPILL_PREFIX = "/tool_outputs" + + +def _build_spill_placeholder(spill_path: str) -> str: + """Build the user-facing placeholder text shown to the model.""" + return ( + f"[cleared — full output at {spill_path}; ask the explore subagent to read it]" + ) + + +def _get_thread_id_or_session() -> str: + """Best-effort thread_id discovery for the spill path. + + Falls back to a process-stable string if no LangGraph config is + available (e.g. unit tests). The exact value doesn't matter as long + as it's stable within one stream so the placeholder paths line up + with the actual upload path. + """ + try: + config = get_config() + thread_id = config.get("configurable", {}).get("thread_id") + if thread_id is not None: + return str(thread_id) + except RuntimeError: + pass + return "no_thread" + + +@dataclass(slots=True) +class SpillToBackendEdit(ContextEdit): + """Capture-and-replace context edit that spills full tool output to the backend. + + Behaves like :class:`ClearToolUsesEdit` (same trigger / keep / exclude + semantics) **and** records the original ``ToolMessage.content`` in + :attr:`pending_spills` so the wrapping middleware can flush them + before the model call. + + Args: + trigger: Token threshold above which the edit fires. + clear_at_least: Minimum number of tokens to reclaim (best effort). + keep: Number of most-recent ``ToolMessage`` instances to leave + untouched. + exclude_tools: Names of tools whose output is NOT spilled. + clear_tool_inputs: Also clear the originating ``AIMessage.tool_calls`` + args when their pair is cleared. + path_prefix: Path under the backend where spills are written. + Default ``"/tool_outputs"``. + """ + + trigger: int = 100_000 + clear_at_least: int = 0 + keep: int = 3 + clear_tool_inputs: bool = False + exclude_tools: Sequence[str] = () + path_prefix: str = DEFAULT_SPILL_PREFIX + + pending_spills: list[tuple[str, bytes]] = field(default_factory=list) + _lock: threading.Lock = field(default_factory=threading.Lock) + + def drain_pending(self) -> list[tuple[str, bytes]]: + """Return and clear the pending-spill list atomically.""" + with self._lock: + out = list(self.pending_spills) + self.pending_spills.clear() + return out + + def apply( + self, + messages: list[AnyMessage], + *, + count_tokens: TokenCounter, + ) -> None: + """Mirror ``ClearToolUsesEdit.apply`` but capture originals first.""" + tokens = count_tokens(messages) + if tokens <= self.trigger: + return + + candidates = [ + (idx, msg) + for idx, msg in enumerate(messages) + if isinstance(msg, ToolMessage) + ] + if self.keep >= len(candidates): + return + if self.keep: + candidates = candidates[: -self.keep] + + thread_id = _get_thread_id_or_session() + excluded_tools = set(self.exclude_tools) + + for idx, tool_message in candidates: + if tool_message.response_metadata.get("context_editing", {}).get("cleared"): + continue + + ai_message = next( + (m for m in reversed(messages[:idx]) if isinstance(m, AIMessage)), + None, + ) + if ai_message is None: + continue + + tool_call = next( + ( + call + for call in ai_message.tool_calls + if call.get("id") == tool_message.tool_call_id + ), + None, + ) + if tool_call is None: + continue + + tool_name = tool_message.name or tool_call["name"] + if tool_name in excluded_tools: + continue + + message_id = tool_message.id or tool_message.tool_call_id or "unknown" + spill_path = f"{self.path_prefix}/{thread_id}/{message_id}.txt" + + original = tool_message.content + payload = self._encode_payload(original) + with self._lock: + self.pending_spills.append((spill_path, payload)) + + messages[idx] = tool_message.model_copy( + update={ + "artifact": None, + "content": _build_spill_placeholder(spill_path), + "response_metadata": { + **tool_message.response_metadata, + "context_editing": { + "cleared": True, + "strategy": "spill_to_backend", + "spill_path": spill_path, + }, + }, + } + ) + + if self.clear_tool_inputs: + ai_idx = messages.index(ai_message) + messages[ai_idx] = self._clear_input_args( + ai_message, tool_message.tool_call_id or "" + ) + + if self.clear_at_least > 0: + new_token_count = count_tokens(messages) + cleared_tokens = max(0, tokens - new_token_count) + if cleared_tokens >= self.clear_at_least: + break + + @staticmethod + def _encode_payload(content: Any) -> bytes: + """Serialize ``ToolMessage.content`` to bytes for upload.""" + if isinstance(content, bytes): + return content + if isinstance(content, str): + return content.encode("utf-8") + try: + import json + + return json.dumps(content, default=str).encode("utf-8") + except Exception: + return str(content).encode("utf-8") + + @staticmethod + def _clear_input_args(message: AIMessage, tool_call_id: str) -> AIMessage: + updated_tool_calls: list[dict[str, Any]] = [] + cleared_any = False + for tool_call in message.tool_calls: + updated = dict(tool_call) + if updated.get("id") == tool_call_id: + updated["args"] = {} + cleared_any = True + updated_tool_calls.append(updated) + + metadata = dict(getattr(message, "response_metadata", {})) + if cleared_any: + ctx = dict(metadata.get("context_editing", {})) + ids = set(ctx.get("cleared_tool_inputs", [])) + ids.add(tool_call_id) + ctx["cleared_tool_inputs"] = sorted(ids) + metadata["context_editing"] = ctx + return message.model_copy( + update={ + "tool_calls": updated_tool_calls, + "response_metadata": metadata, + } + ) + + +BackendResolver = "Callable[[Any], BackendProtocol] | BackendProtocol" + + +class SpillingContextEditingMiddleware(ContextEditingMiddleware): + """:class:`ContextEditingMiddleware` that flushes :class:`SpillToBackendEdit` writes. + + Runs the configured edits as the parent does, then flushes any + pending spills via the supplied backend resolver before delegating + to the model handler. Spill failures are logged but never abort the + model call — the placeholder text is already in the message, so the + worst case is the agent gets a placeholder it cannot follow up on. + """ + + def __init__( + self, + *, + edits: Sequence[ContextEdit], + backend_resolver: BackendResolver | None = None, + token_count_method: str = "approximate", + ) -> None: + super().__init__(edits=list(edits), token_count_method=token_count_method) # type: ignore[arg-type] + self._backend_resolver = backend_resolver + + def _resolve_backend(self, request: ModelRequest) -> BackendProtocol | None: + if self._backend_resolver is None: + return None + if callable(self._backend_resolver): + try: + from langchain.tools import ToolRuntime + + tool_runtime = ToolRuntime( + state=getattr(request, "state", {}), + context=getattr(request.runtime, "context", None), + stream_writer=getattr(request.runtime, "stream_writer", None), + store=getattr(request.runtime, "store", None), + config=getattr(request.runtime, "config", None) or {}, + tool_call_id=None, + ) + return self._backend_resolver(tool_runtime) + except Exception: + logger.exception("Failed to resolve spill backend") + return None + return self._backend_resolver # type: ignore[return-value] + + def _collect_pending(self) -> list[tuple[str, bytes]]: + out: list[tuple[str, bytes]] = [] + for edit in self.edits: + if isinstance(edit, SpillToBackendEdit): + out.extend(edit.drain_pending()) + return out + + async def awrap_model_call( # type: ignore[override] + self, + request: ModelRequest, + handler: Callable[[ModelRequest], Awaitable[ModelResponse]], + ) -> Any: + if not request.messages: + return await handler(request) + + if self.token_count_method == "approximate": + + def count_tokens(messages: Sequence[BaseMessage]) -> int: + return count_tokens_approximately(messages) + + else: + system_msg = [request.system_message] if request.system_message else [] + + def count_tokens(messages: Sequence[BaseMessage]) -> int: + return request.model.get_num_tokens_from_messages( + system_msg + list(messages), request.tools + ) + + edited_messages = deepcopy(list(request.messages)) + for edit in self.edits: + edit.apply(edited_messages, count_tokens=count_tokens) + + pending = self._collect_pending() + if pending: + backend = self._resolve_backend(request) + if backend is not None: + try: + await backend.aupload_files(pending) + except Exception: + logger.exception( + "Spill-to-backend upload failed (%d files); placeholders " + "remain in messages but content is unrecoverable", + len(pending), + ) + else: + logger.warning( + "SpillToBackendEdit produced %d pending spills but no backend " + "resolver was configured; content is unrecoverable", + len(pending), + ) + + return await handler(request.override(messages=edited_messages)) + + +__all__ = [ + "DEFAULT_SPILL_PREFIX", + "ClearToolUsesEdit", + "SpillToBackendEdit", + "SpillingContextEditingMiddleware", + "_build_spill_placeholder", +] diff --git a/surfsense_backend/app/agents/new_chat/middleware/dedup_tool_calls.py b/surfsense_backend/app/agents/new_chat/middleware/dedup_tool_calls.py index 61494ff1a..c55347284 100644 --- a/surfsense_backend/app/agents/new_chat/middleware/dedup_tool_calls.py +++ b/surfsense_backend/app/agents/new_chat/middleware/dedup_tool_calls.py @@ -2,17 +2,27 @@ When the LLM emits multiple calls to the same HITL tool with the same primary argument (e.g. two ``delete_calendar_event("Doctor Appointment")``), -only the first call is kept. Non-HITL tools are never touched. +only the first call is kept. Non-HITL tools are never touched. This runs in the ``after_model`` hook — **before** any tool executes — so the duplicate call is stripped from the AIMessage that gets checkpointed. That means it is also safe across LangGraph ``interrupt()`` boundaries: the removed call will never appear on graph resume. + +Dedup-key resolution order: + +1. :class:`ToolDefinition.dedup_key` — callable provided by the registry + entry. This is the canonical mechanism. +2. ``tool.metadata["hitl_dedup_key"]`` — string with a primary arg name; + used by MCP / Composio tools whose schemas the registry doesn't see. + +A tool with no resolver from either path simply opts out of dedup. """ from __future__ import annotations import logging +from collections.abc import Callable from typing import Any from langchain.agents.middleware import AgentMiddleware, AgentState @@ -20,81 +30,83 @@ from langgraph.runtime import Runtime logger = logging.getLogger(__name__) -_NATIVE_HITL_TOOL_DEDUP_KEYS: dict[str, str] = { - # Gmail - "send_gmail_email": "subject", - "create_gmail_draft": "subject", - "update_gmail_draft": "draft_subject_or_id", - "trash_gmail_email": "email_subject_or_id", - # Google Calendar - "create_calendar_event": "title", - "update_calendar_event": "event_title_or_id", - "delete_calendar_event": "event_title_or_id", - # Google Drive - "create_google_drive_file": "file_name", - "delete_google_drive_file": "file_name", - # OneDrive - "create_onedrive_file": "file_name", - "delete_onedrive_file": "file_name", - # Dropbox - "create_dropbox_file": "file_name", - "delete_dropbox_file": "file_name", - # Notion - "create_notion_page": "title", - "update_notion_page": "page_title", - "delete_notion_page": "page_title", - # Linear - "create_linear_issue": "title", - "update_linear_issue": "issue_ref", - "delete_linear_issue": "issue_ref", - # Jira - "create_jira_issue": "summary", - "update_jira_issue": "issue_title_or_key", - "delete_jira_issue": "issue_title_or_key", - # Confluence - "create_confluence_page": "title", - "update_confluence_page": "page_title_or_id", - "delete_confluence_page": "page_title_or_id", -} +# Resolver type — given the tool ``args`` dict returns a stable +# string used to dedupe consecutive calls. ``None`` means no dedup. +DedupResolver = Callable[[dict[str, Any]], str] + + +def wrap_dedup_key_by_arg_name(arg_name: str) -> DedupResolver: + """Adapt a string-arg name into a :data:`DedupResolver`. + + Convenience helper used by registry entries that just want to dedupe + on a single arg's lowercased value (the most common case for native + HITL tools like ``send_gmail_email`` keyed on ``subject``). + + Example:: + + ToolDefinition( + name="send_gmail_email", + ..., + dedup_key=wrap_dedup_key_by_arg_name("subject"), + ) + """ + + def _resolver(args: dict[str, Any]) -> str: + return str(args.get(arg_name, "")).lower() + + return _resolver + + +# Backwards-compatible alias for code that imported the original +# private name. New callers should use :func:`wrap_dedup_key_by_arg_name`. +_wrap_string_key = wrap_dedup_key_by_arg_name class DedupHITLToolCallsMiddleware(AgentMiddleware): # type: ignore[type-arg] """Remove duplicate HITL tool calls from a single LLM response. - Only the **first** occurrence of each (tool-name, primary-arg-value) + Only the **first** occurrence of each ``(tool-name, dedup_key)`` pair is kept; subsequent duplicates are silently dropped. - The dedup map is built from two sources: + The dedup-resolver map is built from two sources, in priority order: - 1. A comprehensive list of native HITL tools (hardcoded above). - 2. Any ``StructuredTool`` instances passed via *agent_tools* whose - ``metadata`` contains ``{"hitl": True, "hitl_dedup_key": "..."}``. - This is how MCP tools automatically get dedup support. + 1. ``tool.metadata["dedup_key"]`` — callable provided by the registry's + ``ToolDefinition.dedup_key``. Receives the args dict and returns + a string signature. This is the canonical mechanism. + 2. ``tool.metadata["hitl_dedup_key"]`` — string with a primary arg + name; primarily used by MCP / Composio tools. """ tools = () def __init__(self, *, agent_tools: list[Any] | None = None) -> None: - self._dedup_keys: dict[str, str] = dict(_NATIVE_HITL_TOOL_DEDUP_KEYS) + self._resolvers: dict[str, DedupResolver] = {} + for t in agent_tools or []: meta = getattr(t, "metadata", None) or {} + callable_key = meta.get("dedup_key") + if callable(callable_key): + self._resolvers[t.name] = callable_key + continue if meta.get("hitl") and meta.get("hitl_dedup_key"): - self._dedup_keys[t.name] = meta["hitl_dedup_key"] + self._resolvers[t.name] = wrap_dedup_key_by_arg_name( + meta["hitl_dedup_key"] + ) def after_model( self, state: AgentState, runtime: Runtime[Any] ) -> dict[str, Any] | None: - return self._dedup(state, self._dedup_keys) + return self._dedup(state, self._resolvers) async def aafter_model( self, state: AgentState, runtime: Runtime[Any] ) -> dict[str, Any] | None: - return self._dedup(state, self._dedup_keys) + return self._dedup(state, self._resolvers) @staticmethod def _dedup( state: AgentState, - dedup_keys: dict[str, str], # type: ignore[type-arg] + resolvers: dict[str, DedupResolver], ) -> dict[str, Any] | None: messages = state.get("messages") if not messages: @@ -110,9 +122,16 @@ class DedupHITLToolCallsMiddleware(AgentMiddleware): # type: ignore[type-arg] for tc in tool_calls: name = tc.get("name", "") - dedup_key_arg = dedup_keys.get(name) - if dedup_key_arg is not None: - arg_val = str(tc.get("args", {}).get(dedup_key_arg, "")).lower() + resolver = resolvers.get(name) + if resolver is not None: + try: + arg_val = resolver(tc.get("args", {}) or {}) + except Exception: + logger.exception( + "Dedup resolver for tool %s raised; keeping call", name + ) + deduped.append(tc) + continue key = (name, arg_val) if key in seen: logger.info( diff --git a/surfsense_backend/app/agents/new_chat/middleware/doom_loop.py b/surfsense_backend/app/agents/new_chat/middleware/doom_loop.py new file mode 100644 index 000000000..850ecd1d2 --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/middleware/doom_loop.py @@ -0,0 +1,237 @@ +""" +DoomLoopMiddleware — pattern-based detector for repeated identical tool calls. + +LangChain has :class:`ToolCallLimitMiddleware` which caps the *total* number +of tool calls per turn — but it can't tell apart "10 distinct, useful +calls" from "the same call 10 times in a row". This middleware fills that +gap with a sliding-window check on tool-call signatures, ported from +OpenCode's ``packages/opencode/src/session/processor.ts``. + +When the same tool with the same arguments is called N times in a row, +the agent has likely entered an infinite loop. We surface this to the +user as an interrupt with ``permission="doom_loop"`` so the UI can +render an "Are you stuck? Continue / cancel?" affordance. + +This ships **OFF by default** until the frontend explicitly handles +``context.permission == "doom_loop"`` interrupts. + +Wire format: uses SurfSense's existing ``interrupt()`` payload shape +(see ``app/agents/new_chat/tools/hitl.py``): + + { + "type": "permission_ask", + "action": {"tool": , "params": }, + "context": {"permission": "doom_loop", "recent_signatures": [...]}, + } + +so the frontend that already handles HITL prompts can render this with +no changes beyond a string check. +""" + +from __future__ import annotations + +import hashlib +import json +import logging +from collections import deque +from typing import Any + +from langchain.agents.middleware.types import ( + AgentMiddleware, + AgentState, + ContextT, + ResponseT, +) +from langchain_core.messages import AIMessage +from langgraph.config import get_config +from langgraph.runtime import Runtime +from langgraph.types import interrupt + +from app.observability import otel as ot + +logger = logging.getLogger(__name__) + + +def _signature(name: str, args: Any) -> str: + """Hash a tool call ``(name, args)`` to a short signature.""" + try: + canonical = json.dumps(args, sort_keys=True, default=str) + except (TypeError, ValueError): + canonical = repr(args) + digest = hashlib.sha1(f"{name}::{canonical}".encode()).hexdigest() + return digest[:16] + + +class DoomLoopMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, ResponseT]): + """Detect repeated identical tool calls and prompt the user. + + Tracks a sliding window of the most-recent ``threshold`` tool-call + signatures across the live request. When all entries match, raise + a SurfSense-style HITL interrupt with ``permission="doom_loop"``. + + Args: + threshold: How many consecutive identical signatures count as a + doom loop. Default 3 (matches OpenCode's processor.ts). + """ + + def __init__(self, *, threshold: int = 3) -> None: + super().__init__() + if threshold < 2: + raise ValueError("DoomLoopMiddleware threshold must be >= 2") + self._threshold = threshold + self.tools = [] + # Per-thread sliding windows. We can't put this in graph state + # without state-schema gymnastics; for one process-lifetime it's + # fine to keep an in-memory map keyed by thread_id. + self._windows: dict[str, deque[str]] = {} + + @staticmethod + def _thread_id_from_runtime(runtime: Runtime[ContextT]) -> str: + """Resolve the thread id for sliding-window keying. + + Prefer LangGraph's ``get_config()`` (the only way to read + ``RunnableConfig`` inside a node — :class:`Runtime` does NOT carry + a ``config`` attribute). Fall back to ``runtime.config`` for unit + tests that synthesize a config-bearing stub. Default + ``"no_thread"`` is intentionally only used when both lookups fail + — it would collapse all threads into one window so we keep the + debug log loud. + """ + + def _from_dict(cfg: Any) -> str | None: + if not isinstance(cfg, dict): + return None + tid = (cfg.get("configurable") or {}).get("thread_id") + return str(tid) if tid is not None else None + + try: + tid = _from_dict(get_config()) + except Exception: + tid = None + if tid is not None: + return tid + + tid = _from_dict(getattr(runtime, "config", None)) + if tid is not None: + return tid + + logger.debug( + "DoomLoopMiddleware: no thread_id resolved from RunnableConfig; " + "falling back to shared 'no_thread' window." + ) + return "no_thread" + + def _window(self, thread_id: str) -> deque[str]: + win = self._windows.get(thread_id) + if win is None: + win = deque(maxlen=self._threshold) + self._windows[thread_id] = win + return win + + def _detect( + self, message: AIMessage, runtime: Runtime[ContextT] + ) -> tuple[bool, list[str], dict[str, Any] | None]: + if not message.tool_calls: + return False, [], None + + thread_id = self._thread_id_from_runtime(runtime) + window = self._window(thread_id) + + triggered_call: dict[str, Any] | None = None + for call in message.tool_calls: + name = ( + call.get("name") + if isinstance(call, dict) + else getattr(call, "name", None) + ) + args = ( + call.get("args") + if isinstance(call, dict) + else getattr(call, "args", {}) + ) + if not isinstance(name, str): + continue + sig = _signature(name, args) + window.append(sig) + if len(window) >= self._threshold and len(set(window)) == 1: + triggered_call = {"name": name, "params": args or {}} + break + + if triggered_call is None: + return False, list(window), None + return True, list(window), triggered_call + + def after_model( # type: ignore[override] + self, + state: AgentState[ResponseT], + runtime: Runtime[ContextT], + ) -> dict[str, Any] | None: + messages = state.get("messages") or [] + if not messages: + return None + last = messages[-1] + if not isinstance(last, AIMessage): + return None + + triggered, signatures, action = self._detect(last, runtime) + if not triggered: + return None + + logger.warning( + "Doom loop detected: tool %s called %d times in a row (sig=%s)", + action["name"] if action else "", + self._threshold, + signatures[-1] if signatures else "", + ) + + # Open an interrupt.raised span with permission=doom_loop attribute + # so dashboards can break out doom-loop interrupts from regular + # permission asks via the ``interrupt.permission`` attribute. + with ot.interrupt_span( + interrupt_type="permission_ask", + extra={ + "interrupt.permission": "doom_loop", + "interrupt.threshold": self._threshold, + "interrupt.tool": (action or {}).get("tool", ""), + }, + ): + decision = interrupt( + { + "type": "permission_ask", + "action": action or {"tool": "", "params": {}}, + "context": { + "permission": "doom_loop", + "recent_signatures": signatures, + "threshold": self._threshold, + }, + } + ) + + # Reset window so the next decision (continue/cancel) starts fresh. + thread_id = self._thread_id_from_runtime(runtime) + self._windows.pop(thread_id, None) + + # Decision shape mirrors ``tools/hitl.py``: {"decision_type": "..."} + # If the user cancelled, jump to end. Otherwise return ``None`` so the + # tool call proceeds. The frontend's exact reply names may differ — + # we tolerate any shape that contains a string with "reject"/"cancel". + if isinstance(decision, dict): + kind = str( + decision.get("decision_type") or decision.get("type") or "" + ).lower() + if "reject" in kind or "cancel" in kind: + return {"jump_to": "end"} + return None + + async def aafter_model( # type: ignore[override] + self, + state: AgentState[ResponseT], + runtime: Runtime[ContextT], + ) -> dict[str, Any] | None: + return self.after_model(state, runtime) + + +__all__ = [ + "DoomLoopMiddleware", + "_signature", +] diff --git a/surfsense_backend/app/agents/new_chat/middleware/file_intent.py b/surfsense_backend/app/agents/new_chat/middleware/file_intent.py index 4bf5dcfe4..7897e13d6 100644 --- a/surfsense_backend/app/agents/new_chat/middleware/file_intent.py +++ b/surfsense_backend/app/agents/new_chat/middleware/file_intent.py @@ -21,7 +21,7 @@ from typing import Any from langchain.agents.middleware import AgentMiddleware, AgentState from langchain_core.language_models import BaseChatModel -from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage +from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage from langgraph.runtime import Runtime from pydantic import BaseModel, Field, ValidationError @@ -213,10 +213,23 @@ def _build_classifier_prompt(*, recent_conversation: str, user_text: str) -> str ) -def _build_recent_conversation(messages: list[BaseMessage], *, max_messages: int = 6) -> str: +def _build_recent_conversation( + messages: list[BaseMessage], *, max_messages: int = 6 +) -> str: rows: list[str] = [] - for msg in messages[-max_messages:]: - role = "user" if isinstance(msg, HumanMessage) else "assistant" + filtered: list[tuple[str, BaseMessage]] = [] + for msg in messages: + role: str | None = None + if isinstance(msg, HumanMessage): + role = "user" + elif isinstance(msg, AIMessage): + if getattr(msg, "tool_calls", None): + continue + role = "assistant" + else: + continue + filtered.append((role, msg)) + for role, msg in filtered[-max_messages:]: text = re.sub(r"\s+", " ", _extract_text_from_message(msg)).strip() if text: rows.append(f"{role}: {text[:280]}") @@ -246,7 +259,9 @@ class FileIntentMiddleware(AgentMiddleware): # type: ignore[type-arg] [HumanMessage(content=prompt)], config={"tags": ["surfsense:internal"]}, ) - payload = json.loads(_extract_json_payload(_extract_text_from_message(response))) + payload = json.loads( + _extract_json_payload(_extract_text_from_message(response)) + ) plan = FileIntentPlan.model_validate(payload) return plan except (json.JSONDecodeError, ValidationError, ValueError) as exc: @@ -317,4 +332,3 @@ class FileIntentMiddleware(AgentMiddleware): # type: ignore[type-arg] insert_at = max(len(new_messages) - 1, 0) new_messages.insert(insert_at, contract_msg) return {"messages": new_messages, "file_operation_contract": contract} - diff --git a/surfsense_backend/app/agents/new_chat/middleware/filesystem.py b/surfsense_backend/app/agents/new_chat/middleware/filesystem.py index 8dfa89ef2..62316d69e 100644 --- a/surfsense_backend/app/agents/new_chat/middleware/filesystem.py +++ b/surfsense_backend/app/agents/new_chat/middleware/filesystem.py @@ -1,7 +1,26 @@ """Custom filesystem middleware for the SurfSense agent. -This middleware customizes prompts and persists write/edit operations for -`/documents/*` files into SurfSense's `Document`/`Chunk` tables. +This middleware fully overrides every deepagents filesystem tool so that the +``Command(update=...)`` payload can carry SurfSense-specific state fields +(``cwd``, ``staged_dirs``, ``pending_moves``, ``doc_id_by_path``, +``dirty_paths``) atomically alongside the standard ``files`` update. + +In CLOUD mode the backend is :class:`KBPostgresBackend` (lazy DB reads, no DB +writes). End-of-turn persistence is handled by +:class:`KnowledgeBasePersistenceMiddleware`. In DESKTOP_LOCAL_FOLDER mode the +backend is :class:`MultiRootLocalFolderBackend` and writes go straight to disk. + +New tools introduced here: + +* ``mkdir`` — cloud-only stages folder paths to ``state['staged_dirs']``; + desktop creates real directories. +* ``cd`` / ``pwd`` — manage ``state['cwd']`` (per-thread). +* ``move_file`` — staged commit in cloud, real disk move in desktop. +* ``list_tree`` — works in both modes (cloud uses + :func:`KBPostgresBackend.alist_tree_listing`). + +The middleware no longer ships ``save_document``; persistence is inferred +from ``write_file`` / ``edit_file`` against ``/documents/*`` paths. """ from __future__ import annotations @@ -9,66 +28,92 @@ from __future__ import annotations import asyncio import json import logging +import posixpath import re import secrets -from datetime import UTC, datetime from typing import Annotated, Any from daytona.common.errors import DaytonaError from deepagents import FilesystemMiddleware from deepagents.backends.protocol import EditResult, WriteResult -from deepagents.backends.utils import validate_path -from deepagents.middleware.filesystem import FilesystemState -from fractional_indexing import generate_key_between +from deepagents.backends.utils import ( + create_file_data, + format_read_response, + validate_path, +) from langchain.tools import ToolRuntime -from langchain_core.callbacks import dispatch_custom_event from langchain_core.messages import ToolMessage from langchain_core.tools import BaseTool, StructuredTool from langgraph.types import Command -from sqlalchemy import delete, select from app.agents.new_chat.filesystem_selection import FilesystemMode +from app.agents.new_chat.filesystem_state import SurfSenseFilesystemState +from app.agents.new_chat.middleware.kb_postgres_backend import ( + KBPostgresBackend, + paginate_listing, +) from app.agents.new_chat.middleware.multi_root_local_folder_backend import ( MultiRootLocalFolderBackend, ) +from app.agents.new_chat.path_resolver import DOCUMENTS_ROOT from app.agents.new_chat.sandbox import ( _evict_sandbox_cache, delete_sandbox, get_or_create_sandbox, is_sandbox_enabled, ) -from app.db import Chunk, Document, DocumentType, Folder, shielded_async_session -from app.indexing_pipeline.document_chunker import chunk_text -from app.utils.document_converters import ( - embed_texts, - generate_content_hash, - generate_unique_identifier_hash, -) +from app.agents.new_chat.state_reducers import _CLEAR logger = logging.getLogger(__name__) -# ============================================================================= -# System Prompt (injected into every model call by wrap_model_call) -# ============================================================================= -SURFSENSE_FILESYSTEM_SYSTEM_PROMPT = """## Following Conventions +# ============================================================================= +# System Prompt (built per-session based on filesystem_mode) +# ============================================================================= +# +# Each chat session runs in exactly one filesystem mode. Including rules for +# the OTHER mode just wastes tokens and confuses the model, so we build the +# prompt + tool descriptions for the active mode only. + +_COMMON_PROMPT_HEADER = """## Following Conventions - Read files before editing — understand existing content before making changes. - Mimic existing style, naming conventions, and patterns. - Never claim a file was created/updated unless filesystem tool output confirms success. - If a file write/edit fails, explicitly report the failure. +""" +_CLOUD_SYSTEM_PROMPT = ( + _COMMON_PROMPT_HEADER + + """ ## Filesystem Tools -All file paths must start with a `/`. -- ls: list files and directories at a given path. -- read_file: read a file from the filesystem. -- write_file: create a temporary file in the session (not persisted). -- edit_file: edit a file in the session (not persisted for /documents/ files). -- glob: find files matching a pattern (e.g., "**/*.xml"). -- grep: search for text within files. -- save_document: **permanently** save a new document to the user's knowledge - base. Use only when the user explicitly asks to save/create a document. +All file paths must start with `/`. Relative paths resolve against the +current working directory (`cwd`, default `/documents`). + +- ls(path, offset=0, limit=200): list files and directories at the given path. +- read_file(path, offset, limit): read a file (paginated) from the filesystem. +- write_file(path, content): create a new text file in the workspace. +- edit_file(path, old, new): exact string-replacement edit (lazy-loads KB + documents on first edit). +- glob(pattern, path): find files matching a glob pattern. +- grep(pattern, path, glob): substring search across files. +- mkdir(path): create a folder under `/documents/` (committed at end of turn). +- cd(path): change the current working directory. +- pwd(): print the current working directory. +- move_file(source, dest): move/rename a file under `/documents/`. +- list_tree(path, max_depth, page_size): recursively list files/folders. + +## Persistence Rules + +- Files written under `/documents/<...>` are **persisted** at end of turn as + Documents in the user's knowledge base. +- Files whose **basename** starts with `temp_` (e.g. `temp_plan.md` or + `/documents/temp_scratch.md`) are **discarded** at end of turn — use this + prefix for any scratch/working content you do NOT want saved. +- All other paths (outside `/documents/` and not `temp_*`) are rejected. +- mkdir/move_file are staged this turn and committed at end of turn alongside + any new/edited documents. ## Reading Documents Efficiently @@ -85,23 +130,107 @@ those sections instead of reading the entire file sequentially. Use `` values as citation IDs in your answers. -## User-Mentioned Documents +## Priority List -When the `ls` output tags a file with `[MENTIONED BY USER — read deeply]`, -the user **explicitly selected** that document. These files are your highest- -priority sources: -1. **Always read them thoroughly** — scan the full ``, then read - all major sections, not just matched chunks. -2. **Prefer their content** over other search results when answering. -3. **Cite from them first** whenever applicable. +You receive a `` system message each turn listing the +top-K paths most relevant to the user's query (by hybrid search). Read those +first — matched sections are flagged inside each document's ``. + +## Workspace Tree + +You receive a `` system message each turn with the current +folder/document layout. The tree may be truncated past a hard cap; in that +case, drill into specific folders with `ls(...)` or `list_tree(...)`. + +## grep Line Numbers + +`grep` searches across both your in-memory edits and the indexed chunks in +Postgres. State-cached files return real line numbers; database hits return +`line=0` because their position depends on per-document XML layout — call +`read_file(path)` to find the exact line. """ +) + +_DESKTOP_SYSTEM_PROMPT = ( + _COMMON_PROMPT_HEADER + + """ +## Local Folder Mode + +This chat operates directly on the user's local folders. Writes and edits +hit disk immediately — there is no end-of-turn staging, no `/documents/` +namespace, and no `temp_` semantics. + +## Filesystem Tools + +All file paths must start with `/` and use mount-prefixed absolute paths +like `//file.ext`. Relative paths resolve against the current working +directory (`cwd`). + +- ls(path, offset=0, limit=200): list files and directories at the given path. +- read_file(path, offset, limit): read a file (paginated) from disk. +- write_file(path, content): write a file to disk. +- edit_file(path, old, new): exact string-replacement edit on disk. +- glob(pattern, path): find files matching a glob pattern. +- grep(pattern, path, glob): substring search across files. +- mkdir(path): create a directory on disk. +- cd(path): change the current working directory. +- pwd(): print the current working directory. +- move_file(source, dest): move/rename a file. +- list_tree(path, max_depth, page_size): recursively list files/folders. + +## Workflow Tips + +- If you are unsure which mounts are available, call `ls('/')` first. +- For large trees, prefer `list_tree` then `grep` then `read_file` over + brute-force directory traversal. +- Cross-mount moves are not supported. +""" +) + +_SANDBOX_PROMPT_ADDENDUM = ( + "\n- execute_code: run Python code in an isolated sandbox." + "\n\n## Code Execution" + "\n\nUse execute_code whenever a task benefits from running code." + " Never perform arithmetic manually." + "\n\nDocuments here are XML-wrapped markdown, not raw data files." + " To work with them programmatically, read the document first," + " extract the data, write it as a clean file (CSV, JSON, etc.)," + " and then run your code against it." +) + + +def _build_filesystem_system_prompt( + filesystem_mode: FilesystemMode, + *, + sandbox_available: bool, +) -> str: + """Build the filesystem system prompt for a given session mode. + + The prompt only describes rules and tools that actually apply in the + chosen mode — there is no cross-mode noise. + """ + base = ( + _CLOUD_SYSTEM_PROMPT + if filesystem_mode == FilesystemMode.CLOUD + else _DESKTOP_SYSTEM_PROMPT + ) + if sandbox_available: + base += _SANDBOX_PROMPT_ADDENDUM + return base + + +# Backwards-compatible alias retained for any external imports. +SURFSENSE_FILESYSTEM_SYSTEM_PROMPT = _CLOUD_SYSTEM_PROMPT # ============================================================================= # Per-Tool Descriptions (shown to the LLM as the tool's docstring) # ============================================================================= -SURFSENSE_LIST_FILES_TOOL_DESCRIPTION = """Lists files and directories at the given path. -""" +# ============================================================================= +# Per-Tool Descriptions (mode-specific; injected as the tool's docstring) +# ============================================================================= + +# --- mode-agnostic --------------------------------------------------------- SURFSENSE_READ_FILE_TOOL_DESCRIPTION = """Reads a file from the filesystem. @@ -116,105 +245,241 @@ Usage: - Use chunk IDs (``) as citations in answers. """ -SURFSENSE_WRITE_FILE_TOOL_DESCRIPTION = """Writes a new text file to the in-memory filesystem (session-only). - -Use this to create scratch/working files during the conversation. Files created -here are ephemeral and will not be saved to the user's knowledge base. - -To permanently save a document to the user's knowledge base, use the -`save_document` tool instead. - -Supported outputs include common LLM-friendly text formats like markdown, json, -yaml, csv, xml, html, css, sql, and code files. - -When creating content from open-ended prompts, produce concrete and useful text, -not placeholders. Avoid adding dates/timestamps unless the user explicitly asks -for them. -""" - -SURFSENSE_EDIT_FILE_TOOL_DESCRIPTION = """Performs exact string replacements in files. - -IMPORTANT: -- Read the file before editing. -- Preserve exact indentation and formatting. -- Edits to documents under `/documents/` are session-only (not persisted to the - database) because those files use an XML citation wrapper around the original - content. -""" - -SURFSENSE_MOVE_FILE_TOOL_DESCRIPTION = """Moves or renames a file or folder. - -Use absolute paths for both source and destination. - -Notes: -- In local-folder mode, paths should use mount prefixes (e.g., //foo.txt). -- Rename is a special case of move (same folder, different filename). -- Cross-mount moves are not supported. -""" - -SURFSENSE_LIST_TREE_TOOL_DESCRIPTION = """Lists files/folders recursively in a single bounded call. - -Use this in desktop local-folder mode to discover nested files at scale. - -Args: -- path: absolute mount-prefixed path (e.g., //src) or "/" for mount roots. -- max_depth: recursion depth limit (default 8). -- page_size: maximum number of entries returned (max 1000). -- include_files/include_dirs: filter returned entry types. - -Returns JSON with: -- entries: [{path, is_dir, size, modified_at, depth}] -- truncated: true when additional entries were omitted due to page_size -""" - SURFSENSE_GLOB_TOOL_DESCRIPTION = """Find files matching a glob pattern. Supports standard glob patterns: `*`, `**`, `?`. Returns absolute file paths. """ -SURFSENSE_GREP_TOOL_DESCRIPTION = """Search for a literal text pattern across files. +SURFSENSE_CD_TOOL_DESCRIPTION = """Changes the current working directory (cwd). -Use this to locate relevant document files/chunks before reading full files. +Args: +- path: absolute or relative directory path. Relative paths resolve against + the current cwd. + +The new cwd is used by other filesystem tools whenever a relative path is +given. Returns the resolved cwd. """ +SURFSENSE_PWD_TOOL_DESCRIPTION = """Prints the current working directory.""" + SURFSENSE_EXECUTE_CODE_TOOL_DESCRIPTION = """Executes Python code in an isolated sandbox environment. Common data-science packages are pre-installed (pandas, numpy, matplotlib, scipy, scikit-learn). -When to use this tool: use execute_code for numerical computation, data -analysis, statistics, and any task that benefits from running Python code. -Never perform arithmetic manually when this tool is available. - Usage notes: - No outbound network access. - Returns combined stdout/stderr with exit code. - Use print() to produce output. -- You can create files, run shell commands via subprocess or os.system(), - and use any standard library module. - Use the optional timeout parameter to override the default timeout. """ -SURFSENSE_SAVE_DOCUMENT_TOOL_DESCRIPTION = """Permanently saves a document to the user's knowledge base. +# --- cloud-only ------------------------------------------------------------ -This is an expensive operation — it creates a new Document record in the -database, chunks the content, and generates embeddings for search. +_CLOUD_LIST_FILES_TOOL_DESCRIPTION = """Lists files and directories at the given path. -Use ONLY when the user explicitly asks to save/create/store a document. -Do NOT use this for scratch work; use `write_file` for temporary files. +Usage: +- Provide an absolute path under `/documents` (relative paths resolve under + the current cwd, which defaults to `/documents`). +- For very large folders, use `offset` and `limit` to paginate the listing. +- Returns one entry per line; directories end with a trailing `/`. +""" + +_CLOUD_WRITE_FILE_TOOL_DESCRIPTION = """Writes a new text file to the workspace. + +Usage: +- Files written under `/documents/<...>` are persisted as Documents at end + of turn. +- Use a `temp_` filename prefix (e.g. `temp_plan.md` or `/documents/temp_x.md`) + for scratch/working files; they are automatically discarded at end of turn. +- Writes outside `/documents/` are rejected unless the basename starts with + `temp_`. +- Supported outputs include common LLM-friendly text formats like markdown, + json, yaml, csv, xml, html, css, sql, and code files. +- Avoid placeholders; produce concrete and useful text. +""" + +_CLOUD_EDIT_FILE_TOOL_DESCRIPTION = """Performs exact string replacements in files. + +IMPORTANT: +- Read the file before editing. +- Preserve exact indentation and formatting. +- Edits to documents under `/documents/` are persisted at end of turn. +- Edits to `temp_*` files are discarded at end of turn. +""" + +_CLOUD_MOVE_FILE_TOOL_DESCRIPTION = """Moves or renames a file or folder. + +Use absolute paths for both source and destination. + +Notes: +- `move_file` is staged this turn and committed at end of turn. +- The agent cannot overwrite an existing destination — pass a fresh dest + path or move the existing destination away first. +- The anonymous uploaded document is read-only and cannot be moved. +- Rename is a special case of move (same folder, different filename). +""" + +_CLOUD_LIST_TREE_TOOL_DESCRIPTION = """Lists files/folders recursively in a single bounded call. Args: - title: The document title (e.g., "Meeting Notes 2025-06-01"). - content: The plain-text or markdown content to save. Do NOT include XML - citation wrappers — pass only the actual document text. - folder_path: Optional folder path under /documents/ (e.g., "Work/Notes"). - Folders are created automatically if they don't exist. +- path: absolute path to start from. Defaults to `/documents`. +- max_depth: recursion depth limit (default 8). +- page_size: maximum number of entries returned (max 1000). +- include_files / include_dirs: filter returned entry types. + +Returns JSON with: +- entries: [{path, is_dir, size, modified_at, depth}] +- truncated: true when additional entries were omitted due to page_size. +""" + +_CLOUD_GREP_TOOL_DESCRIPTION = """Search for a literal text pattern across files. + +Searches both your in-memory edits and the indexed chunks in Postgres. +State-cached file matches include real line numbers; database hits return +`line=0` because their position depends on per-document XML layout — call +`read_file(path)` afterwards to find the exact line. +""" + +_CLOUD_MKDIR_TOOL_DESCRIPTION = """Creates a directory under `/documents/`. + +Stages the folder for end-of-turn commit; the Folder row is inserted only +after the agent's turn finishes successfully. + +Args: +- path: absolute path of the new directory (must start with + `/documents/`). + +Notes: +- Parent folders are created as needed. +""" + +# --- desktop-only ---------------------------------------------------------- + +_DESKTOP_LIST_FILES_TOOL_DESCRIPTION = """Lists files and directories at the given path. + +Usage: +- Provide an absolute path using a mount prefix (e.g. `//sub/dir`). + Use `ls('/')` to discover available mounts. +- For very large folders, use `offset` and `limit` to paginate the listing. +- Returns one entry per line; directories end with a trailing `/`. +""" + +_DESKTOP_WRITE_FILE_TOOL_DESCRIPTION = """Writes a text file to disk. + +Usage: +- Use mount-prefixed absolute paths like `//sub/file.ext`. +- Writes hit disk immediately. There is no end-of-turn staging. +- Supported outputs include common LLM-friendly text formats like markdown, + json, yaml, csv, xml, html, css, sql, and code files. +- Avoid placeholders; produce concrete and useful text. +""" + +_DESKTOP_EDIT_FILE_TOOL_DESCRIPTION = """Performs exact string replacements in files on disk. + +IMPORTANT: +- Read the file before editing. +- Preserve exact indentation and formatting. +- Edits hit disk immediately. +""" + +_DESKTOP_MOVE_FILE_TOOL_DESCRIPTION = """Moves or renames a file or folder on disk. + +Use mount-prefixed absolute paths for both source and destination +(e.g. `//old.txt` -> `//new.txt`). + +Notes: +- Cross-mount moves are not supported. +- Rename is a special case of move (same folder, different filename). +""" + +_DESKTOP_LIST_TREE_TOOL_DESCRIPTION = """Lists files/folders recursively in a single bounded call. + +Args: +- path: absolute path to start from. Defaults to `/`. +- max_depth: recursion depth limit (default 8). +- page_size: maximum number of entries returned (max 1000). +- include_files / include_dirs: filter returned entry types. + +Returns JSON with: +- entries: [{path, is_dir, size, modified_at, depth}] +- truncated: true when additional entries were omitted due to page_size. +""" + +_DESKTOP_GREP_TOOL_DESCRIPTION = """Search for a literal text pattern across files. + +Searches files on disk and any in-memory edits. Returns real line numbers. +""" + +_DESKTOP_MKDIR_TOOL_DESCRIPTION = """Creates a directory on disk. + +Args: +- path: absolute mount-prefixed path of the new directory. + +Notes: +- Parent folders are created as needed. """ +def _build_tool_descriptions(filesystem_mode: FilesystemMode) -> dict[str, str]: + """Pick the active-mode description for every filesystem tool.""" + if filesystem_mode == FilesystemMode.CLOUD: + return { + "ls": _CLOUD_LIST_FILES_TOOL_DESCRIPTION, + "read_file": SURFSENSE_READ_FILE_TOOL_DESCRIPTION, + "write_file": _CLOUD_WRITE_FILE_TOOL_DESCRIPTION, + "edit_file": _CLOUD_EDIT_FILE_TOOL_DESCRIPTION, + "move_file": _CLOUD_MOVE_FILE_TOOL_DESCRIPTION, + "list_tree": _CLOUD_LIST_TREE_TOOL_DESCRIPTION, + "glob": SURFSENSE_GLOB_TOOL_DESCRIPTION, + "grep": _CLOUD_GREP_TOOL_DESCRIPTION, + "mkdir": _CLOUD_MKDIR_TOOL_DESCRIPTION, + "cd": SURFSENSE_CD_TOOL_DESCRIPTION, + "pwd": SURFSENSE_PWD_TOOL_DESCRIPTION, + } + return { + "ls": _DESKTOP_LIST_FILES_TOOL_DESCRIPTION, + "read_file": SURFSENSE_READ_FILE_TOOL_DESCRIPTION, + "write_file": _DESKTOP_WRITE_FILE_TOOL_DESCRIPTION, + "edit_file": _DESKTOP_EDIT_FILE_TOOL_DESCRIPTION, + "move_file": _DESKTOP_MOVE_FILE_TOOL_DESCRIPTION, + "list_tree": _DESKTOP_LIST_TREE_TOOL_DESCRIPTION, + "glob": SURFSENSE_GLOB_TOOL_DESCRIPTION, + "grep": _DESKTOP_GREP_TOOL_DESCRIPTION, + "mkdir": _DESKTOP_MKDIR_TOOL_DESCRIPTION, + "cd": SURFSENSE_CD_TOOL_DESCRIPTION, + "pwd": SURFSENSE_PWD_TOOL_DESCRIPTION, + } + + +# Backwards-compatible aliases retained for any external imports/tests that +# referenced the original CLOUD-flavoured constants. +SURFSENSE_LIST_FILES_TOOL_DESCRIPTION = _CLOUD_LIST_FILES_TOOL_DESCRIPTION +SURFSENSE_WRITE_FILE_TOOL_DESCRIPTION = _CLOUD_WRITE_FILE_TOOL_DESCRIPTION +SURFSENSE_EDIT_FILE_TOOL_DESCRIPTION = _CLOUD_EDIT_FILE_TOOL_DESCRIPTION +SURFSENSE_MOVE_FILE_TOOL_DESCRIPTION = _CLOUD_MOVE_FILE_TOOL_DESCRIPTION +SURFSENSE_LIST_TREE_TOOL_DESCRIPTION = _CLOUD_LIST_TREE_TOOL_DESCRIPTION +SURFSENSE_GREP_TOOL_DESCRIPTION = _CLOUD_GREP_TOOL_DESCRIPTION +SURFSENSE_MKDIR_TOOL_DESCRIPTION = _CLOUD_MKDIR_TOOL_DESCRIPTION + + +# ============================================================================= +# Helpers +# ============================================================================= + + +_TEMP_PREFIX = "temp_" + + +def _basename(path: str) -> str: + return path.rsplit("/", 1)[-1] + + class SurfSenseFilesystemMiddleware(FilesystemMiddleware): - """SurfSense-specific filesystem middleware with DB persistence for docs.""" + """SurfSense-specific filesystem middleware (cloud + desktop).""" + + state_schema = SurfSenseFilesystemState _MAX_EXECUTE_TIMEOUT = 300 @@ -234,582 +499,45 @@ class SurfSenseFilesystemMiddleware(FilesystemMiddleware): self._thread_id = thread_id self._sandbox_available = is_sandbox_enabled() and thread_id is not None - system_prompt = SURFSENSE_FILESYSTEM_SYSTEM_PROMPT - if self._sandbox_available: - system_prompt += ( - "\n- execute_code: run Python code in an isolated sandbox." - "\n\n## Code Execution" - "\n\nUse execute_code whenever a task benefits from running code." - " Never perform arithmetic manually." - "\n\nDocuments here are XML-wrapped markdown, not raw data files." - " To work with them programmatically, read the document first," - " extract the data, write it as a clean file (CSV, JSON, etc.)," - " and then run your code against it." - ) - if filesystem_mode == FilesystemMode.DESKTOP_LOCAL_FOLDER: - system_prompt += ( - "\n- move_file: move or rename files/folders in local-folder mode." - "\n- list_tree: recursively list nested local paths in one bounded response." - "\n\n## Local Folder Mode" - "\n\nThis chat is running in desktop local-folder mode." - " Keep all file operations local. Do not use save_document." - " Always use mount-prefixed absolute paths like //file.ext." - " If you are unsure which mounts are available, call ls('/') first." - " For big trees: use list_tree, then grep, then read_file." - ) + # Build the prompt + tool descriptions for the active mode only — + # mixing both modes wastes tokens and confuses the model with rules + # it can't actually use this session. + system_prompt = _build_filesystem_system_prompt( + filesystem_mode, + sandbox_available=self._sandbox_available, + ) super().__init__( backend=backend, system_prompt=system_prompt, - custom_tool_descriptions={ - "ls": SURFSENSE_LIST_FILES_TOOL_DESCRIPTION, - "read_file": SURFSENSE_READ_FILE_TOOL_DESCRIPTION, - "write_file": SURFSENSE_WRITE_FILE_TOOL_DESCRIPTION, - "edit_file": SURFSENSE_EDIT_FILE_TOOL_DESCRIPTION, - "move_file": SURFSENSE_MOVE_FILE_TOOL_DESCRIPTION, - "list_tree": SURFSENSE_LIST_TREE_TOOL_DESCRIPTION, - "glob": SURFSENSE_GLOB_TOOL_DESCRIPTION, - "grep": SURFSENSE_GREP_TOOL_DESCRIPTION, - }, + custom_tool_descriptions=_build_tool_descriptions(filesystem_mode), tool_token_limit_before_evict=tool_token_limit_before_evict, max_execute_timeout=self._MAX_EXECUTE_TIMEOUT, ) self.tools = [t for t in self.tools if t.name != "execute"] - if self._filesystem_mode == FilesystemMode.DESKTOP_LOCAL_FOLDER: - self.tools.append(self._create_move_file_tool()) - self.tools.append(self._create_list_tree_tool()) - if self._should_persist_documents(): - self.tools.append(self._create_save_document_tool()) + self.tools.append(self._create_mkdir_tool()) + self.tools.append(self._create_cd_tool()) + self.tools.append(self._create_pwd_tool()) + self.tools.append(self._create_move_file_tool()) + self.tools.append(self._create_list_tree_tool()) if self._sandbox_available: self.tools.append(self._create_execute_code_tool()) + # ------------------------------------------------------------------ helpers + + def _is_cloud(self) -> bool: + return self._filesystem_mode == FilesystemMode.CLOUD + @staticmethod def _run_async_blocking(coro: Any) -> Any: - """Run async coroutine from sync code path when no event loop is running.""" try: loop = asyncio.get_running_loop() if loop.is_running(): - return "Error: sync filesystem persistence not supported inside an active event loop." + return "Error: sync filesystem operation not supported inside an active event loop." except RuntimeError: pass return asyncio.run(coro) - @staticmethod - def _parse_virtual_path(file_path: str) -> tuple[list[str], str]: - """Parse /documents/... path into folder parts and a document title.""" - if not file_path.startswith("/documents/"): - return [], "" - rel = file_path[len("/documents/") :].strip("/") - if not rel: - return [], "" - parts = [part for part in rel.split("/") if part] - file_name = parts[-1] - title = file_name[:-4] if file_name.lower().endswith(".xml") else file_name - return parts[:-1], title - - async def _ensure_folder_hierarchy( - self, - *, - folder_parts: list[str], - search_space_id: int, - ) -> int | None: - """Ensure folder hierarchy exists and return leaf folder ID.""" - if not folder_parts: - return None - async with shielded_async_session() as session: - parent_id: int | None = None - for name in folder_parts: - result = await session.execute( - select(Folder).where( - Folder.search_space_id == search_space_id, - Folder.parent_id == parent_id - if parent_id is not None - else Folder.parent_id.is_(None), - Folder.name == name, - ) - ) - folder = result.scalar_one_or_none() - if folder is None: - sibling_result = await session.execute( - select(Folder.position) - .where( - Folder.search_space_id == search_space_id, - Folder.parent_id == parent_id - if parent_id is not None - else Folder.parent_id.is_(None), - ) - .order_by(Folder.position.desc()) - .limit(1) - ) - last_position = sibling_result.scalar_one_or_none() - folder = Folder( - name=name, - position=generate_key_between(last_position, None), - parent_id=parent_id, - search_space_id=search_space_id, - created_by_id=self._created_by_id, - updated_at=datetime.now(UTC), - ) - session.add(folder) - await session.flush() - parent_id = folder.id - await session.commit() - return parent_id - - async def _persist_new_document( - self, *, file_path: str, content: str - ) -> dict[str, Any] | str: - """Persist a new NOTE document from a newly written file. - - Returns a dict with document metadata on success, or an error string. - """ - if self._search_space_id is None: - return {} - folder_parts, title = self._parse_virtual_path(file_path) - if not title: - return "Error: write_file for document persistence requires path under /documents/.xml" - folder_id = await self._ensure_folder_hierarchy( - folder_parts=folder_parts, - search_space_id=self._search_space_id, - ) - async with shielded_async_session() as session: - content_hash = generate_content_hash(content, self._search_space_id) - existing = await session.execute( - select(Document.id).where(Document.content_hash == content_hash) - ) - if existing.scalar_one_or_none() is not None: - return "Error: A document with identical content already exists." - unique_identifier_hash = generate_unique_identifier_hash( - DocumentType.NOTE, - file_path, - self._search_space_id, - ) - doc = Document( - title=title, - document_type=DocumentType.NOTE, - document_metadata={"virtual_path": file_path}, - content=content, - content_hash=content_hash, - unique_identifier_hash=unique_identifier_hash, - source_markdown=content, - search_space_id=self._search_space_id, - folder_id=folder_id, - created_by_id=self._created_by_id, - updated_at=datetime.now(UTC), - ) - session.add(doc) - await session.flush() - - summary_embedding = embed_texts([content])[0] - doc.embedding = summary_embedding - chunk_texts = chunk_text(content) - if chunk_texts: - chunk_embeddings = embed_texts(chunk_texts) - chunks = [ - Chunk(document_id=doc.id, content=text, embedding=embedding) - for text, embedding in zip( - chunk_texts, chunk_embeddings, strict=True - ) - ] - session.add_all(chunks) - await session.commit() - - return { - "id": doc.id, - "title": title, - "documentType": DocumentType.NOTE.value, - "searchSpaceId": self._search_space_id, - "folderId": folder_id, - "createdById": str(self._created_by_id) - if self._created_by_id - else None, - } - - async def _persist_edited_document( - self, *, file_path: str, updated_content: str - ) -> str | None: - """Persist edits for an existing NOTE document and recreate chunks.""" - if self._search_space_id is None: - return None - unique_identifier_hash = generate_unique_identifier_hash( - DocumentType.NOTE, - file_path, - self._search_space_id, - ) - doc_id_from_xml: int | None = None - match = re.search(r"\s*(\d+)\s*", updated_content) - if match: - doc_id_from_xml = int(match.group(1)) - async with shielded_async_session() as session: - doc_result = await session.execute( - select(Document).where( - Document.search_space_id == self._search_space_id, - Document.unique_identifier_hash == unique_identifier_hash, - ) - ) - document = doc_result.scalar_one_or_none() - if document is None and doc_id_from_xml is not None: - by_id_result = await session.execute( - select(Document).where( - Document.search_space_id == self._search_space_id, - Document.id == doc_id_from_xml, - ) - ) - document = by_id_result.scalar_one_or_none() - if document is None: - return "Error: Could not map edited file to an existing document." - - document.content = updated_content - document.source_markdown = updated_content - document.content_hash = generate_content_hash( - updated_content, self._search_space_id - ) - document.updated_at = datetime.now(UTC) - if not document.document_metadata: - document.document_metadata = {} - document.document_metadata["virtual_path"] = file_path - - summary_embedding = embed_texts([updated_content])[0] - document.embedding = summary_embedding - - await session.execute(delete(Chunk).where(Chunk.document_id == document.id)) - chunk_texts = chunk_text(updated_content) - if chunk_texts: - chunk_embeddings = embed_texts(chunk_texts) - session.add_all( - [ - Chunk( - document_id=document.id, content=text, embedding=embedding - ) - for text, embedding in zip( - chunk_texts, chunk_embeddings, strict=True - ) - ] - ) - await session.commit() - return None - - def _create_save_document_tool(self) -> BaseTool: - """Create save_document tool that persists a new document to the KB.""" - - def sync_save_document( - title: Annotated[str, "Title for the new document."], - content: Annotated[ - str, - "Plain-text or markdown content to save. Do NOT include XML wrappers.", - ], - runtime: ToolRuntime[None, FilesystemState], - folder_path: Annotated[ - str, - "Optional folder path under /documents/ (e.g. 'Work/Notes'). Created automatically.", - ] = "", - ) -> Command | str: - if not content.strip(): - return "Error: content cannot be empty." - file_name = re.sub(r'[\\/:*?"<>|]+', "_", title).strip() or "untitled" - if not file_name.lower().endswith(".xml"): - file_name = f"{file_name}.xml" - folder = folder_path.strip().strip("/") if folder_path else "" - virtual_path = ( - f"/documents/{folder}/{file_name}" - if folder - else f"/documents/{file_name}" - ) - - persist_result = self._run_async_blocking( - self._persist_new_document(file_path=virtual_path, content=content) - ) - if isinstance(persist_result, str): - return persist_result - if isinstance(persist_result, dict) and persist_result.get("id"): - dispatch_custom_event("document_created", persist_result) - return f"Document '{title}' saved to knowledge base (path: {virtual_path})." - - async def async_save_document( - title: Annotated[str, "Title for the new document."], - content: Annotated[ - str, - "Plain-text or markdown content to save. Do NOT include XML wrappers.", - ], - runtime: ToolRuntime[None, FilesystemState], - folder_path: Annotated[ - str, - "Optional folder path under /documents/ (e.g. 'Work/Notes'). Created automatically.", - ] = "", - ) -> Command | str: - if not content.strip(): - return "Error: content cannot be empty." - file_name = re.sub(r'[\\/:*?"<>|]+', "_", title).strip() or "untitled" - if not file_name.lower().endswith(".xml"): - file_name = f"{file_name}.xml" - folder = folder_path.strip().strip("/") if folder_path else "" - virtual_path = ( - f"/documents/{folder}/{file_name}" - if folder - else f"/documents/{file_name}" - ) - - persist_result = await self._persist_new_document( - file_path=virtual_path, content=content - ) - if isinstance(persist_result, str): - return persist_result - if isinstance(persist_result, dict) and persist_result.get("id"): - dispatch_custom_event("document_created", persist_result) - return f"Document '{title}' saved to knowledge base (path: {virtual_path})." - - return StructuredTool.from_function( - name="save_document", - description=SURFSENSE_SAVE_DOCUMENT_TOOL_DESCRIPTION, - func=sync_save_document, - coroutine=async_save_document, - ) - - def _create_execute_code_tool(self) -> BaseTool: - """Create execute_code tool backed by a Daytona sandbox.""" - - def sync_execute_code( - command: Annotated[ - str, "Python code to execute. Use print() to see output." - ], - runtime: ToolRuntime[None, FilesystemState], - timeout: Annotated[ - int | None, - "Optional timeout in seconds.", - ] = None, - ) -> str: - if timeout is not None: - if timeout < 0: - return f"Error: timeout must be non-negative, got {timeout}." - if timeout > self._MAX_EXECUTE_TIMEOUT: - return f"Error: timeout {timeout}s exceeds maximum ({self._MAX_EXECUTE_TIMEOUT}s)." - return self._run_async_blocking( - self._execute_in_sandbox(command, runtime, timeout) - ) - - async def async_execute_code( - command: Annotated[ - str, "Python code to execute. Use print() to see output." - ], - runtime: ToolRuntime[None, FilesystemState], - timeout: Annotated[ - int | None, - "Optional timeout in seconds.", - ] = None, - ) -> str: - if timeout is not None: - if timeout < 0: - return f"Error: timeout must be non-negative, got {timeout}." - if timeout > self._MAX_EXECUTE_TIMEOUT: - return f"Error: timeout {timeout}s exceeds maximum ({self._MAX_EXECUTE_TIMEOUT}s)." - return await self._execute_in_sandbox(command, runtime, timeout) - - return StructuredTool.from_function( - name="execute_code", - description=SURFSENSE_EXECUTE_CODE_TOOL_DESCRIPTION, - func=sync_execute_code, - coroutine=async_execute_code, - ) - - @staticmethod - def _wrap_as_python(code: str) -> str: - """Wrap Python code in a shell invocation for the sandbox.""" - sentinel = f"_PYEOF_{secrets.token_hex(8)}" - return f"python3 << '{sentinel}'\n{code}\n{sentinel}" - - async def _execute_in_sandbox( - self, - command: str, - runtime: ToolRuntime[None, FilesystemState], - timeout: int | None, - ) -> str: - """Core logic: get sandbox, sync files, run command, handle retries.""" - assert self._thread_id is not None - command = self._wrap_as_python(command) - - try: - return await self._try_sandbox_execute(command, runtime, timeout) - except (DaytonaError, Exception) as first_err: - logger.warning( - "Sandbox execute failed for thread %s, retrying: %s", - self._thread_id, - first_err, - ) - try: - await delete_sandbox(self._thread_id) - except Exception: - _evict_sandbox_cache(self._thread_id) - try: - return await self._try_sandbox_execute(command, runtime, timeout) - except Exception: - logger.exception( - "Sandbox retry also failed for thread %s", self._thread_id - ) - return "Error: Code execution is temporarily unavailable. Please try again." - - async def _try_sandbox_execute( - self, - command: str, - runtime: ToolRuntime[None, FilesystemState], - timeout: int | None, - ) -> str: - sandbox, _is_new = await get_or_create_sandbox(self._thread_id) - # NOTE: sync_files_to_sandbox is intentionally disabled. - # The virtual FS contains XML-wrapped KB documents whose paths - # would double-nest under SANDBOX_DOCUMENTS_ROOT (e.g. - # /home/daytona/documents/documents/Report.xml) and uploading - # all KB docs on the first execute_code call adds significant - # latency. Re-enable once path mapping is fixed and upload is - # limited to user-created scratch files. - # files = runtime.state.get("files") or {} - # await sync_files_to_sandbox(self._thread_id, files, sandbox, is_new) - result = await sandbox.aexecute(command, timeout=timeout) - output = (result.output or "").strip() - if not output and result.exit_code == 0: - return ( - "[Code executed successfully but produced no output. " - "Use print() to display results, then try again.]" - ) - parts = [result.output] - if result.exit_code is not None: - status = "succeeded" if result.exit_code == 0 else "failed" - parts.append(f"\n[Command {status} with exit code {result.exit_code}]") - if result.truncated: - parts.append("\n[Output was truncated due to size limits]") - return "".join(parts) - - def _create_write_file_tool(self) -> BaseTool: - """Create write_file — ephemeral for /documents/*, persisted otherwise.""" - tool_description = ( - self._custom_tool_descriptions.get("write_file") - or SURFSENSE_WRITE_FILE_TOOL_DESCRIPTION - ) - - def sync_write_file( - file_path: Annotated[ - str, - "Absolute path where the file should be created. Must be absolute, not relative.", - ], - content: Annotated[ - str, - "The text content to write to the file. This parameter is required.", - ], - runtime: ToolRuntime[None, FilesystemState], - ) -> Command | str: - resolved_backend = self._get_backend(runtime) - target_path = self._resolve_write_target_path(file_path, runtime) - try: - validated_path = validate_path(target_path) - except ValueError as exc: - return f"Error: {exc}" - res: WriteResult = resolved_backend.write(validated_path, content) - if res.error: - return res.error - verify_error = self._verify_written_content_sync( - backend=resolved_backend, - path=validated_path, - expected_content=content, - ) - if verify_error: - return verify_error - - if self._should_persist_documents() and not self._is_kb_document( - validated_path - ): - persist_result = self._run_async_blocking( - self._persist_new_document( - file_path=validated_path, content=content - ) - ) - if isinstance(persist_result, str): - return persist_result - if isinstance(persist_result, dict) and persist_result.get("id"): - dispatch_custom_event("document_created", persist_result) - - if res.files_update is not None: - return Command( - update={ - "files": res.files_update, - "messages": [ - ToolMessage( - content=f"Updated file {res.path}", - tool_call_id=runtime.tool_call_id, - ) - ], - } - ) - return f"Updated file {res.path}" - - async def async_write_file( - file_path: Annotated[ - str, - "Absolute path where the file should be created. Must be absolute, not relative.", - ], - content: Annotated[ - str, - "The text content to write to the file. This parameter is required.", - ], - runtime: ToolRuntime[None, FilesystemState], - ) -> Command | str: - resolved_backend = self._get_backend(runtime) - target_path = self._resolve_write_target_path(file_path, runtime) - try: - validated_path = validate_path(target_path) - except ValueError as exc: - return f"Error: {exc}" - res: WriteResult = await resolved_backend.awrite(validated_path, content) - if res.error: - return res.error - verify_error = await self._verify_written_content_async( - backend=resolved_backend, - path=validated_path, - expected_content=content, - ) - if verify_error: - return verify_error - - if self._should_persist_documents() and not self._is_kb_document( - validated_path - ): - persist_result = await self._persist_new_document( - file_path=validated_path, - content=content, - ) - if isinstance(persist_result, str): - return persist_result - if isinstance(persist_result, dict) and persist_result.get("id"): - dispatch_custom_event("document_created", persist_result) - - if res.files_update is not None: - return Command( - update={ - "files": res.files_update, - "messages": [ - ToolMessage( - content=f"Updated file {res.path}", - tool_call_id=runtime.tool_call_id, - ) - ], - } - ) - return f"Updated file {res.path}" - - return StructuredTool.from_function( - name="write_file", - description=tool_description, - func=sync_write_file, - coroutine=async_write_file, - ) - - @staticmethod - def _is_kb_document(path: str) -> bool: - """Return True for paths under /documents/ (KB-sourced, XML-wrapped).""" - return path.startswith("/documents/") - - def _should_persist_documents(self) -> bool: - """Only cloud mode persists file content to Document/Chunk tables.""" - return self._filesystem_mode == FilesystemMode.CLOUD - @staticmethod def _normalize_absolute_path(candidate: str) -> str: normalized = re.sub(r"/+", "/", candidate.strip().replace("\\", "/")) @@ -857,7 +585,7 @@ class SurfSenseFilesystemMiddleware(FilesystemMiddleware): def _normalize_local_mount_path( self, candidate: str, - runtime: ToolRuntime[None, FilesystemState], + runtime: ToolRuntime[None, SurfSenseFilesystemState], ) -> str: normalized = self._normalize_absolute_path(candidate) backend = self._get_backend(runtime) @@ -877,7 +605,9 @@ class SurfSenseFilesystemMiddleware(FilesystemMiddleware): suggested_path = contract.get("suggested_path") if isinstance(suggested_path, str) and suggested_path.strip(): normalized_suggested = self._normalize_absolute_path(suggested_path) - suggested_mount = self._extract_mount_from_path(normalized_suggested, mounts) + suggested_mount = self._extract_mount_from_path( + normalized_suggested, mounts + ) matching_mounts = [ mount @@ -902,265 +632,675 @@ class SurfSenseFilesystemMiddleware(FilesystemMiddleware): return f"/{backend.default_mount()}{normalized}" + def _default_cwd(self) -> str: + return DOCUMENTS_ROOT if self._is_cloud() else "/" + + def _current_cwd(self, runtime: ToolRuntime[None, SurfSenseFilesystemState]) -> str: + cwd = runtime.state.get("cwd") if hasattr(runtime, "state") else None + if isinstance(cwd, str) and cwd.startswith("/"): + return cwd + return self._default_cwd() + def _get_contract_suggested_path( - self, runtime: ToolRuntime[None, FilesystemState] + self, runtime: ToolRuntime[None, SurfSenseFilesystemState] ) -> str: contract = runtime.state.get("file_operation_contract") or {} suggested = contract.get("suggested_path") if isinstance(suggested, str) and suggested.strip(): return self._normalize_absolute_path(suggested) - return "/notes.md" + return self._default_cwd().rstrip("/") + "/notes.md" + + def _resolve_relative( + self, + path: str, + runtime: ToolRuntime[None, SurfSenseFilesystemState], + ) -> str: + candidate = path.strip() + if not candidate: + return self._current_cwd(runtime) + if candidate.startswith("/"): + return self._normalize_absolute_path(candidate) + cwd = self._current_cwd(runtime) + joined = posixpath.normpath(posixpath.join(cwd, candidate)) + return self._normalize_absolute_path(joined) def _resolve_write_target_path( self, file_path: str, - runtime: ToolRuntime[None, FilesystemState], + runtime: ToolRuntime[None, SurfSenseFilesystemState], ) -> str: candidate = file_path.strip() if not candidate: return self._get_contract_suggested_path(runtime) if self._filesystem_mode == FilesystemMode.DESKTOP_LOCAL_FOLDER: return self._normalize_local_mount_path(candidate, runtime) - if not candidate.startswith("/"): - return f"/{candidate.lstrip('/')}" - return candidate + return self._resolve_relative(candidate, runtime) def _resolve_move_target_path( self, file_path: str, - runtime: ToolRuntime[None, FilesystemState], + runtime: ToolRuntime[None, SurfSenseFilesystemState], ) -> str: candidate = file_path.strip() if not candidate: return "" if self._filesystem_mode == FilesystemMode.DESKTOP_LOCAL_FOLDER: return self._normalize_local_mount_path(candidate, runtime) - if not candidate.startswith("/"): - return f"/{candidate.lstrip('/')}" - return candidate + return self._resolve_relative(candidate, runtime) def _resolve_list_target_path( self, path: str, - runtime: ToolRuntime[None, FilesystemState], + runtime: ToolRuntime[None, SurfSenseFilesystemState], ) -> str: - candidate = path.strip() or "/" + candidate = path.strip() or self._current_cwd(runtime) if candidate == "/": return "/" if self._filesystem_mode == FilesystemMode.DESKTOP_LOCAL_FOLDER: return self._normalize_local_mount_path(candidate, runtime) - if not candidate.startswith("/"): - return f"/{candidate.lstrip('/')}" - return candidate + return self._resolve_relative(candidate, runtime) - @staticmethod - def _is_error_text(value: str) -> bool: - return value.startswith("Error:") + # ------------------------------------------------------------------ namespace policy - @staticmethod - def _read_for_verification_sync(backend: Any, path: str) -> str: - read_raw = getattr(backend, "read_raw", None) - if callable(read_raw): - return read_raw(path) - return backend.read(path, offset=0, limit=200000) - - @staticmethod - async def _read_for_verification_async(backend: Any, path: str) -> str: - aread_raw = getattr(backend, "aread_raw", None) - if callable(aread_raw): - return await aread_raw(path) - return await backend.aread(path, offset=0, limit=200000) - - def _verify_written_content_sync( + def _check_cloud_write_namespace( self, - *, - backend: Any, path: str, - expected_content: str, + runtime: ToolRuntime[None, SurfSenseFilesystemState], ) -> str | None: - actual = self._read_for_verification_sync(backend, path) - if self._is_error_text(actual): - return f"Error: could not verify written file '{path}'." - if actual.rstrip() != expected_content.rstrip(): - return ( - "Error: file write verification failed; expected content was not fully written " - f"to '{path}'." - ) - return None + """Return an error string if cloud writes to ``path`` are not allowed. - async def _verify_written_content_async( - self, - *, - backend: Any, - path: str, - expected_content: str, - ) -> str | None: - actual = await self._read_for_verification_async(backend, path) - if self._is_error_text(actual): - return f"Error: could not verify written file '{path}'." - if actual.rstrip() != expected_content.rstrip(): - return ( - "Error: file write verification failed; expected content was not fully written " - f"to '{path}'." - ) - return None + Order matters: + 1. Reject writes to the anonymous read-only doc. + 2. Allow ``/documents/*``. + 3. Allow ``temp_*`` basename anywhere. + 4. Reject everything else. + """ + if not self._is_cloud(): + return None + anon = runtime.state.get("kb_anon_doc") or {} + if isinstance(anon, dict): + anon_path = str(anon.get("path") or "") + if anon_path and anon_path == path: + return "Error: the anonymous uploaded document is read-only." + if path.startswith(DOCUMENTS_ROOT + "/") or path == DOCUMENTS_ROOT: + return None + if _basename(path).startswith(_TEMP_PREFIX): + return None + return ( + "Error: cloud writes must target /documents/<...> or use a 'temp_' " + f"basename for scratch (got '{path}')." + ) - def _verify_edited_content_sync( - self, - *, - backend: Any, - path: str, - new_string: str, - ) -> tuple[str | None, str | None]: - updated_content = self._read_for_verification_sync(backend, path) - if self._is_error_text(updated_content): - return ( - f"Error: could not verify edited file '{path}'.", - None, - ) - if new_string and new_string not in updated_content: - return ( - "Error: edit verification failed; updated content was not found in " - f"'{path}'.", - None, - ) - return None, updated_content + # ------------------------------------------------------------------ tool: ls - async def _verify_edited_content_async( - self, - *, - backend: Any, - path: str, - new_string: str, - ) -> tuple[str | None, str | None]: - updated_content = await self._read_for_verification_async(backend, path) - if self._is_error_text(updated_content): - return ( - f"Error: could not verify edited file '{path}'.", - None, + def _create_ls_tool(self) -> BaseTool: + tool_description = ( + self._custom_tool_descriptions.get("ls") + or SURFSENSE_LIST_FILES_TOOL_DESCRIPTION + ) + + def sync_ls( + runtime: ToolRuntime[None, SurfSenseFilesystemState], + path: Annotated[ + str, + "Absolute path to the directory to list. Relative paths resolve against the current cwd.", + ] = "", + offset: Annotated[ + int, + "Number of entries to skip. Use for paginating large folders. Defaults to 0.", + ] = 0, + limit: Annotated[ + int, + "Maximum number of entries to return. Defaults to 200.", + ] = 200, + ) -> str: + return self._run_async_blocking( + async_ls(runtime, path=path, offset=offset, limit=limit) ) - if new_string and new_string not in updated_content: - return ( - "Error: edit verification failed; updated content was not found in " - f"'{path}'.", - None, + + async def async_ls( + runtime: ToolRuntime[None, SurfSenseFilesystemState], + path: Annotated[ + str, + "Absolute path to the directory to list. Relative paths resolve against the current cwd.", + ] = "", + offset: Annotated[ + int, + "Number of entries to skip. Use for paginating large folders. Defaults to 0.", + ] = 0, + limit: Annotated[ + int, + "Maximum number of entries to return. Defaults to 200.", + ] = 200, + ) -> str: + target = self._resolve_list_target_path(path, runtime) + try: + validated = validate_path(target) + except ValueError as exc: + return f"Error: {exc}" + if offset < 0: + offset = 0 + if limit < 1: + limit = 1 + backend = self._get_backend(runtime) + infos = await backend.als_info(validated) + page = paginate_listing(infos, offset=offset, limit=limit) + paths = [ + f"{fi.get('path', '')}/" if fi.get("is_dir") else fi.get("path", "") + for fi in page + ] + total = len(infos) + shown = len(page) + header = ( + f"{validated} ({shown} of {total} entries" + f"{f', offset={offset}' if offset else ''})" ) - return None, updated_content + if not paths: + return f"{header}\n(empty)" + body = "\n".join(paths) + if total > offset + shown: + body += ( + f"\n... {total - offset - shown} more — call ls(" + f"'{validated}', offset={offset + shown}, limit={limit})" + ) + return f"{header}\n{body}" + + return StructuredTool.from_function( + name="ls", + description=tool_description, + func=sync_ls, + coroutine=async_ls, + ) + + # ------------------------------------------------------------------ tool: read_file + + def _create_read_file_tool(self) -> BaseTool: + tool_description = ( + self._custom_tool_descriptions.get("read_file") + or SURFSENSE_READ_FILE_TOOL_DESCRIPTION + ) + + async def async_read_file( + file_path: Annotated[ + str, + "Absolute path to the file to read. Relative paths resolve against the current cwd.", + ], + runtime: ToolRuntime[None, SurfSenseFilesystemState], + offset: Annotated[ + int, + "Line number to start reading from (0-indexed).", + ] = 0, + limit: Annotated[ + int, + "Maximum number of lines to read.", + ] = 100, + ) -> Command | str: + target = self._resolve_relative(file_path, runtime) + try: + validated = validate_path(target) + except ValueError as exc: + return f"Error: {exc}" + + files = runtime.state.get("files") or {} + if validated in files: + return format_read_response(files[validated], offset, limit) + + backend = self._get_backend(runtime) + if isinstance(backend, KBPostgresBackend): + loaded = await backend._load_file_data(validated) + if loaded is None: + return f"Error: File '{validated}' not found" + file_data, doc_id = loaded + rendered = format_read_response(file_data, offset, limit) + update: dict[str, Any] = { + "files": {validated: file_data}, + "messages": [ + ToolMessage( + content=rendered, + tool_call_id=runtime.tool_call_id, + ) + ], + } + if doc_id is not None: + update["doc_id_by_path"] = {validated: doc_id} + return Command(update=update) + + try: + rendered = await backend.aread(validated, offset=offset, limit=limit) + except Exception as exc: # pragma: no cover - defensive + return f"Error: {exc}" + return rendered + + def sync_read_file( + file_path: Annotated[ + str, + "Absolute path to the file to read. Relative paths resolve against the current cwd.", + ], + runtime: ToolRuntime[None, SurfSenseFilesystemState], + offset: Annotated[ + int, + "Line number to start reading from (0-indexed).", + ] = 0, + limit: Annotated[ + int, + "Maximum number of lines to read.", + ] = 100, + ) -> Command | str: + return self._run_async_blocking( + async_read_file(file_path, runtime, offset, limit) + ) + + return StructuredTool.from_function( + name="read_file", + description=tool_description, + func=sync_read_file, + coroutine=async_read_file, + ) + + # ------------------------------------------------------------------ tool: write_file + + def _create_write_file_tool(self) -> BaseTool: + tool_description = ( + self._custom_tool_descriptions.get("write_file") + or SURFSENSE_WRITE_FILE_TOOL_DESCRIPTION + ) + + async def async_write_file( + file_path: Annotated[ + str, + "Absolute path where the file should be created. Relative paths resolve against the current cwd.", + ], + content: Annotated[str, "Text content to write to the file."], + runtime: ToolRuntime[None, SurfSenseFilesystemState], + ) -> Command | str: + target = self._resolve_write_target_path(file_path, runtime) + try: + validated = validate_path(target) + except ValueError as exc: + return f"Error: {exc}" + + namespace_error = self._check_cloud_write_namespace(validated, runtime) + if namespace_error: + return namespace_error + + backend = self._get_backend(runtime) + res: WriteResult = await backend.awrite(validated, content) + if res.error: + return res.error + + path = res.path or validated + files_update = res.files_update or {path: create_file_data(content)} + update: dict[str, Any] = { + "files": files_update, + "messages": [ + ToolMessage( + content=f"Updated file {path}", + tool_call_id=runtime.tool_call_id, + ) + ], + } + if self._is_cloud(): + update["dirty_paths"] = [path] + return Command(update=update) + + def sync_write_file( + file_path: Annotated[ + str, + "Absolute path where the file should be created. Relative paths resolve against the current cwd.", + ], + content: Annotated[str, "Text content to write to the file."], + runtime: ToolRuntime[None, SurfSenseFilesystemState], + ) -> Command | str: + return self._run_async_blocking( + async_write_file(file_path, content, runtime) + ) + + return StructuredTool.from_function( + name="write_file", + description=tool_description, + func=sync_write_file, + coroutine=async_write_file, + ) + + # ------------------------------------------------------------------ tool: edit_file + + def _create_edit_file_tool(self) -> BaseTool: + tool_description = ( + self._custom_tool_descriptions.get("edit_file") + or SURFSENSE_EDIT_FILE_TOOL_DESCRIPTION + ) + + async def async_edit_file( + file_path: Annotated[ + str, + "Absolute path to the file to edit. Relative paths resolve against the current cwd.", + ], + old_string: Annotated[ + str, + "Exact text to replace. Must be unique unless replace_all is True.", + ], + new_string: Annotated[ + str, + "Replacement text. Must differ from old_string.", + ], + runtime: ToolRuntime[None, SurfSenseFilesystemState], + *, + replace_all: Annotated[ + bool, + "If True, replace all occurrences of old_string. Defaults to False.", + ] = False, + ) -> Command | str: + target = self._resolve_relative(file_path, runtime) + try: + validated = validate_path(target) + except ValueError as exc: + return f"Error: {exc}" + + namespace_error = self._check_cloud_write_namespace(validated, runtime) + if namespace_error: + return namespace_error + + backend = self._get_backend(runtime) + files_state = runtime.state.get("files") or {} + doc_id_to_attach: int | None = None + + if ( + self._is_cloud() + and validated not in files_state + and isinstance(backend, KBPostgresBackend) + ): + loaded = await backend._load_file_data(validated) + if loaded is None: + return f"Error: File '{validated}' not found" + _, doc_id_to_attach = loaded + + res: EditResult = await backend.aedit( + validated, old_string, new_string, replace_all=replace_all + ) + if res.error: + return res.error + + path = res.path or validated + files_update = res.files_update or {} + update: dict[str, Any] = { + "files": files_update, + "messages": [ + ToolMessage( + content=( + f"Successfully replaced {res.occurrences} instance(s) " + f"of the string in '{path}'" + ), + tool_call_id=runtime.tool_call_id, + ) + ], + } + if self._is_cloud(): + update["dirty_paths"] = [path] + if doc_id_to_attach is not None: + update["doc_id_by_path"] = {path: doc_id_to_attach} + return Command(update=update) + + def sync_edit_file( + file_path: Annotated[ + str, + "Absolute path to the file to edit. Relative paths resolve against the current cwd.", + ], + old_string: Annotated[ + str, + "Exact text to replace. Must be unique unless replace_all is True.", + ], + new_string: Annotated[ + str, + "Replacement text. Must differ from old_string.", + ], + runtime: ToolRuntime[None, SurfSenseFilesystemState], + *, + replace_all: Annotated[ + bool, + "If True, replace all occurrences of old_string. Defaults to False.", + ] = False, + ) -> Command | str: + return self._run_async_blocking( + async_edit_file( + file_path, old_string, new_string, runtime, replace_all=replace_all + ) + ) + + return StructuredTool.from_function( + name="edit_file", + description=tool_description, + func=sync_edit_file, + coroutine=async_edit_file, + ) + + # ------------------------------------------------------------------ tool: mkdir + + def _create_mkdir_tool(self) -> BaseTool: + tool_description = ( + self._custom_tool_descriptions.get("mkdir") + or SURFSENSE_MKDIR_TOOL_DESCRIPTION + ) + + async def async_mkdir( + path: Annotated[str, "Absolute or relative directory path to create."], + runtime: ToolRuntime[None, SurfSenseFilesystemState], + ) -> Command | str: + target = self._resolve_relative(path, runtime) + try: + validated = validate_path(target) + except ValueError as exc: + return f"Error: {exc}" + + if self._is_cloud(): + if not ( + validated.startswith(DOCUMENTS_ROOT + "/") + or validated == DOCUMENTS_ROOT + ): + return ( + "Error: cloud mkdir must target a path under /documents/ " + f"(got '{validated}')." + ) + return Command( + update={ + "staged_dirs": [validated], + "messages": [ + ToolMessage( + content=( + f"Staged directory '{validated}' (will be created " + "at end of turn)." + ), + tool_call_id=runtime.tool_call_id, + ) + ], + } + ) + + backend = self._get_backend(runtime) + local_method = getattr(backend, "amkdir", None) or getattr( + backend, "mkdir", None + ) + if callable(local_method): + try: + res = local_method(validated, parents=True, exist_ok=True) + if asyncio.iscoroutine(res): + await res + except TypeError: + res = local_method(validated) + if asyncio.iscoroutine(res): + await res + except Exception as exc: # pragma: no cover + return f"Error: {exc}" + return f"Created directory {validated}" + + def sync_mkdir( + path: Annotated[str, "Absolute or relative directory path to create."], + runtime: ToolRuntime[None, SurfSenseFilesystemState], + ) -> Command | str: + return self._run_async_blocking(async_mkdir(path, runtime)) + + return StructuredTool.from_function( + name="mkdir", + description=tool_description, + func=sync_mkdir, + coroutine=async_mkdir, + ) + + # ------------------------------------------------------------------ tool: cd + + def _create_cd_tool(self) -> BaseTool: + tool_description = ( + self._custom_tool_descriptions.get("cd") or SURFSENSE_CD_TOOL_DESCRIPTION + ) + + async def async_cd( + path: Annotated[str, "Absolute or relative directory path to switch into."], + runtime: ToolRuntime[None, SurfSenseFilesystemState], + ) -> Command | str: + target = self._resolve_relative(path, runtime) + try: + validated = validate_path(target) + except ValueError as exc: + return f"Error: {exc}" + + backend = self._get_backend(runtime) + try: + infos = await backend.als_info(validated) + except Exception as exc: # pragma: no cover - defensive + return f"Error: {exc}" + staged_dirs = list(runtime.state.get("staged_dirs") or []) + files = runtime.state.get("files") or {} + cwd_exists = ( + bool(infos) + or validated in staged_dirs + or any(p == validated for p in files) + or any( + isinstance(p, str) and p.startswith(validated.rstrip("/") + "/") + for p in files + ) + or validated == "/" + or validated == DOCUMENTS_ROOT + ) + if not cwd_exists: + return f"Error: directory '{validated}' not found." + return Command( + update={ + "cwd": validated, + "messages": [ + ToolMessage( + content=f"cwd changed to {validated}", + tool_call_id=runtime.tool_call_id, + ) + ], + } + ) + + def sync_cd( + path: Annotated[str, "Absolute or relative directory path to switch into."], + runtime: ToolRuntime[None, SurfSenseFilesystemState], + ) -> Command | str: + return self._run_async_blocking(async_cd(path, runtime)) + + return StructuredTool.from_function( + name="cd", + description=tool_description, + func=sync_cd, + coroutine=async_cd, + ) + + # ------------------------------------------------------------------ tool: pwd + + def _create_pwd_tool(self) -> BaseTool: + tool_description = ( + self._custom_tool_descriptions.get("pwd") or SURFSENSE_PWD_TOOL_DESCRIPTION + ) + + def sync_pwd( + runtime: ToolRuntime[None, SurfSenseFilesystemState], + ) -> str: + return self._current_cwd(runtime) + + async def async_pwd( + runtime: ToolRuntime[None, SurfSenseFilesystemState], + ) -> str: + return self._current_cwd(runtime) + + return StructuredTool.from_function( + name="pwd", + description=tool_description, + func=sync_pwd, + coroutine=async_pwd, + ) + + # ------------------------------------------------------------------ tool: move_file def _create_move_file_tool(self) -> BaseTool: - """Create move_file for desktop local-folder mode.""" tool_description = ( self._custom_tool_descriptions.get("move_file") or SURFSENSE_MOVE_FILE_TOOL_DESCRIPTION ) - def sync_move_file( - source_path: Annotated[ - str, - "Absolute source path to move from.", - ], - destination_path: Annotated[ - str, - "Absolute destination path to move to.", - ], - runtime: ToolRuntime[None, FilesystemState], - *, - overwrite: Annotated[ - bool, - "If True, replace an existing destination file. Defaults to False.", - ] = False, - ) -> Command | str: - if self._filesystem_mode != FilesystemMode.DESKTOP_LOCAL_FOLDER: - return "Error: move_file is only available in desktop local-folder mode." - - if not source_path.strip() or not destination_path.strip(): - return "Error: source_path and destination_path are required." - - resolved_backend = self._get_backend(runtime) - source_target = self._resolve_move_target_path(source_path, runtime) - destination_target = self._resolve_move_target_path(destination_path, runtime) - try: - validated_source = validate_path(source_target) - validated_destination = validate_path(destination_target) - except ValueError as exc: - return f"Error: {exc}" - res: WriteResult = resolved_backend.move( - validated_source, - validated_destination, - overwrite=overwrite, - ) - if res.error: - return res.error - if res.files_update is not None: - return Command( - update={ - "files": res.files_update, - "messages": [ - ToolMessage( - content=( - f"Moved '{validated_source}' to " - f"'{res.path or validated_destination}'" - ), - tool_call_id=runtime.tool_call_id, - ) - ], - } - ) - return f"Moved '{validated_source}' to '{res.path or validated_destination}'" - async def async_move_file( - source_path: Annotated[ - str, - "Absolute source path to move from.", - ], - destination_path: Annotated[ - str, - "Absolute destination path to move to.", - ], - runtime: ToolRuntime[None, FilesystemState], + source_path: Annotated[str, "Absolute or relative source path."], + destination_path: Annotated[str, "Absolute or relative destination path."], + runtime: ToolRuntime[None, SurfSenseFilesystemState], *, overwrite: Annotated[ bool, - "If True, replace an existing destination file. Defaults to False.", + "If True, replace existing destination. Cloud mode rejects True. Defaults to False.", ] = False, ) -> Command | str: - if self._filesystem_mode != FilesystemMode.DESKTOP_LOCAL_FOLDER: - return "Error: move_file is only available in desktop local-folder mode." - if not source_path.strip() or not destination_path.strip(): return "Error: source_path and destination_path are required." - resolved_backend = self._get_backend(runtime) - source_target = self._resolve_move_target_path(source_path, runtime) - destination_target = self._resolve_move_target_path(destination_path, runtime) + source = self._resolve_move_target_path(source_path, runtime) + dest = self._resolve_move_target_path(destination_path, runtime) try: - validated_source = validate_path(source_target) - validated_destination = validate_path(destination_target) + validated_source = validate_path(source) + validated_dest = validate_path(dest) except ValueError as exc: return f"Error: {exc}" - res: WriteResult = await resolved_backend.amove( - validated_source, - validated_destination, - overwrite=overwrite, + + if self._is_cloud(): + return await self._cloud_move_file( + runtime, + validated_source, + validated_dest, + overwrite=overwrite, + ) + + backend = self._get_backend(runtime) + res: WriteResult = await backend.amove( + validated_source, validated_dest, overwrite=overwrite ) if res.error: return res.error + update: dict[str, Any] = { + "messages": [ + ToolMessage( + content=f"Moved '{validated_source}' to '{res.path or validated_dest}'", + tool_call_id=runtime.tool_call_id, + ) + ], + } if res.files_update is not None: - return Command( - update={ - "files": res.files_update, - "messages": [ - ToolMessage( - content=( - f"Moved '{validated_source}' to " - f"'{res.path or validated_destination}'" - ), - tool_call_id=runtime.tool_call_id, - ) - ], - } + update["files"] = res.files_update + return Command(update=update) + + def sync_move_file( + source_path: Annotated[str, "Absolute or relative source path."], + destination_path: Annotated[str, "Absolute or relative destination path."], + runtime: ToolRuntime[None, SurfSenseFilesystemState], + *, + overwrite: Annotated[ + bool, + "If True, replace existing destination. Cloud mode rejects True. Defaults to False.", + ] = False, + ) -> Command | str: + return self._run_async_blocking( + async_move_file( + source_path, destination_path, runtime, overwrite=overwrite ) - return f"Moved '{validated_source}' to '{res.path or validated_destination}'" + ) return StructuredTool.from_function( name="move_file", @@ -1169,91 +1309,112 @@ class SurfSenseFilesystemMiddleware(FilesystemMiddleware): coroutine=async_move_file, ) + async def _cloud_move_file( + self, + runtime: ToolRuntime[None, SurfSenseFilesystemState], + source: str, + dest: str, + *, + overwrite: bool, + ) -> Command | str: + backend = self._get_backend(runtime) + if not isinstance(backend, KBPostgresBackend): + return "Error: cloud move requires KBPostgresBackend." + + if source == dest: + return f"Moved '{source}' to '{dest}' (no-op)" + if overwrite: + return ( + "Error: overwrite=True is not supported in cloud mode. Move/edit " + "the destination doc explicitly first." + ) + if not source.startswith(DOCUMENTS_ROOT + "/"): + return ( + "Error: cloud move_file source must be under /documents/ (got " + f"'{source}')." + ) + if not dest.startswith(DOCUMENTS_ROOT + "/"): + return ( + "Error: cloud move_file destination must be under /documents/ (got " + f"'{dest}')." + ) + anon = runtime.state.get("kb_anon_doc") or {} + if isinstance(anon, dict): + anon_path = str(anon.get("path") or "") + if anon_path and (anon_path in (source, dest)): + return "Error: the anonymous uploaded document is read-only." + + files = runtime.state.get("files") or {} + doc_id_by_path = runtime.state.get("doc_id_by_path") or {} + pending_moves = list(runtime.state.get("pending_moves") or []) + + # Dest collision: occupied in state, in pending moves, or in DB. + if dest in files: + return f"Error: destination '{dest}' already exists." + if any(move.get("dest") == dest for move in pending_moves): + return f"Error: destination '{dest}' already exists." + if dest != source: + existing_dest = await backend._load_file_data(dest) + if existing_dest is not None: + return f"Error: destination '{dest}' already exists." + + # Source materialization: lazy load if not in state. + source_file_data = files.get(source) + source_doc_id = doc_id_by_path.get(source) + if source_file_data is None: + loaded = await backend._load_file_data(source) + if loaded is None: + return f"Error: source '{source}' not found." + source_file_data, loaded_doc_id = loaded + if source_doc_id is None: + source_doc_id = loaded_doc_id + + files_update: dict[str, Any] = {source: None, dest: source_file_data} + update: dict[str, Any] = { + "files": files_update, + "pending_moves": [{"source": source, "dest": dest, "overwrite": False}], + "messages": [ + ToolMessage( + content=( + f"Moved '{source}' to '{dest}' (will commit at end of turn)." + ), + tool_call_id=runtime.tool_call_id, + ) + ], + } + + doc_id_update: dict[str, int | None] = {source: None} + if source_doc_id is not None: + doc_id_update[dest] = source_doc_id + update["doc_id_by_path"] = doc_id_update + + dirty_paths = list(runtime.state.get("dirty_paths") or []) + if source in dirty_paths: + new_dirty: list[Any] = [_CLEAR] + for entry in dirty_paths: + new_dirty.append(dest if entry == source else entry) + update["dirty_paths"] = new_dirty + return Command(update=update) + + # ------------------------------------------------------------------ tool: list_tree + def _create_list_tree_tool(self) -> BaseTool: - """Create list_tree for desktop local-folder mode.""" tool_description = ( self._custom_tool_descriptions.get("list_tree") or SURFSENSE_LIST_TREE_TOOL_DESCRIPTION ) - def sync_list_tree( - runtime: ToolRuntime[None, FilesystemState], - *, - path: Annotated[ - str, - "Absolute path to list from. Use '/' for mount roots.", - ] = "/", - max_depth: Annotated[ - int, - "Maximum recursion depth to traverse. Defaults to 8.", - ] = 8, - page_size: Annotated[ - int, - "Maximum number of entries to return. Defaults to 500 (max 1000).", - ] = 500, - include_files: Annotated[ - bool, - "Whether file entries should be included.", - ] = True, - include_dirs: Annotated[ - bool, - "Whether directory entries should be included.", - ] = True, - ) -> str: - if self._filesystem_mode != FilesystemMode.DESKTOP_LOCAL_FOLDER: - return "Error: list_tree is only available in desktop local-folder mode." - if max_depth < 0: - return "Error: max_depth must be >= 0." - if page_size < 1: - return "Error: page_size must be >= 1." - if not include_files and not include_dirs: - return "Error: include_files and include_dirs cannot both be false." - - resolved_backend = self._get_backend(runtime) - target_path = self._resolve_list_target_path(path, runtime) - try: - validated_path = validate_path(target_path) - except ValueError as exc: - return f"Error: {exc}" - - result = resolved_backend.list_tree( - validated_path, - max_depth=max_depth, - page_size=page_size, - include_files=include_files, - include_dirs=include_dirs, - ) - error = result.get("error") if isinstance(result, dict) else None - if isinstance(error, str) and error: - return error - return json.dumps(result, ensure_ascii=True) - async def async_list_tree( - runtime: ToolRuntime[None, FilesystemState], - *, + runtime: ToolRuntime[None, SurfSenseFilesystemState], path: Annotated[ str, - "Absolute path to list from. Use '/' for mount roots.", - ] = "/", - max_depth: Annotated[ - int, - "Maximum recursion depth to traverse. Defaults to 8.", - ] = 8, - page_size: Annotated[ - int, - "Maximum number of entries to return. Defaults to 500 (max 1000).", - ] = 500, - include_files: Annotated[ - bool, - "Whether file entries should be included.", - ] = True, - include_dirs: Annotated[ - bool, - "Whether directory entries should be included.", - ] = True, + "Absolute path to start from. Defaults to /documents in cloud mode.", + ] = "", + max_depth: Annotated[int, "Recursion depth limit. Default 8."] = 8, + page_size: Annotated[int, "Maximum entries returned. Max 1000."] = 500, + include_files: Annotated[bool, "Include file entries."] = True, + include_dirs: Annotated[bool, "Include directory entries."] = True, ) -> str: - if self._filesystem_mode != FilesystemMode.DESKTOP_LOCAL_FOLDER: - return "Error: list_tree is only available in desktop local-folder mode." if max_depth < 0: return "Error: max_depth must be >= 0." if page_size < 1: @@ -1261,25 +1422,58 @@ class SurfSenseFilesystemMiddleware(FilesystemMiddleware): if not include_files and not include_dirs: return "Error: include_files and include_dirs cannot both be false." - resolved_backend = self._get_backend(runtime) - target_path = self._resolve_list_target_path(path, runtime) + target = self._resolve_list_target_path(path, runtime) try: - validated_path = validate_path(target_path) + validated = validate_path(target) except ValueError as exc: return f"Error: {exc}" - result = await resolved_backend.alist_tree( - validated_path, - max_depth=max_depth, - page_size=page_size, - include_files=include_files, - include_dirs=include_dirs, - ) - error = result.get("error") if isinstance(result, dict) else None - if isinstance(error, str) and error: - return error + backend = self._get_backend(runtime) + if isinstance(backend, KBPostgresBackend): + result = await backend.alist_tree_listing( + validated, + max_depth=max_depth, + page_size=page_size, + include_files=include_files, + include_dirs=include_dirs, + ) + elif hasattr(backend, "alist_tree"): + result = await backend.alist_tree( + validated, + max_depth=max_depth, + page_size=page_size, + include_files=include_files, + include_dirs=include_dirs, + ) + else: + return "Error: list_tree is not supported by the active backend." + + if isinstance(result, dict) and isinstance(result.get("error"), str): + return result["error"] return json.dumps(result, ensure_ascii=True) + def sync_list_tree( + runtime: ToolRuntime[None, SurfSenseFilesystemState], + path: Annotated[ + str, + "Absolute path to start from. Defaults to /documents in cloud mode.", + ] = "", + max_depth: Annotated[int, "Recursion depth limit. Default 8."] = 8, + page_size: Annotated[int, "Maximum entries returned. Max 1000."] = 500, + include_files: Annotated[bool, "Include file entries."] = True, + include_dirs: Annotated[bool, "Include directory entries."] = True, + ) -> str: + return self._run_async_blocking( + async_list_tree( + runtime, + path=path, + max_depth=max_depth, + page_size=page_size, + include_files=include_files, + include_dirs=include_dirs, + ) + ) + return StructuredTool.from_function( name="list_tree", description=tool_description, @@ -1287,162 +1481,103 @@ class SurfSenseFilesystemMiddleware(FilesystemMiddleware): coroutine=async_list_tree, ) - def _create_edit_file_tool(self) -> BaseTool: - """Create edit_file with DB persistence (skipped for KB documents).""" - tool_description = ( - self._custom_tool_descriptions.get("edit_file") - or SURFSENSE_EDIT_FILE_TOOL_DESCRIPTION - ) + # ------------------------------------------------------------------ tool: execute_code (sandbox) - def sync_edit_file( - file_path: Annotated[ - str, - "Absolute path to the file to edit. Must be absolute, not relative.", + def _create_execute_code_tool(self) -> BaseTool: + def sync_execute_code( + command: Annotated[ + str, "Python code to execute. Use print() to see output." ], - old_string: Annotated[ - str, - "The exact text to find and replace. Must be unique in the file unless replace_all is True.", - ], - new_string: Annotated[ - str, - "The text to replace old_string with. Must be different from old_string.", - ], - runtime: ToolRuntime[None, FilesystemState], - *, - replace_all: Annotated[ - bool, - "If True, replace all occurrences of old_string. If False (default), old_string must be unique.", - ] = False, - ) -> Command | str: - resolved_backend = self._get_backend(runtime) - target_path = self._resolve_write_target_path(file_path, runtime) - try: - validated_path = validate_path(target_path) - except ValueError as exc: - return f"Error: {exc}" - res: EditResult = resolved_backend.edit( - validated_path, - old_string, - new_string, - replace_all=replace_all, + runtime: ToolRuntime[None, SurfSenseFilesystemState], + timeout: Annotated[ + int | None, + "Optional timeout in seconds.", + ] = None, + ) -> str: + if timeout is not None: + if timeout < 0: + return f"Error: timeout must be non-negative, got {timeout}." + if timeout > self._MAX_EXECUTE_TIMEOUT: + return f"Error: timeout {timeout}s exceeds maximum ({self._MAX_EXECUTE_TIMEOUT}s)." + return self._run_async_blocking( + self._execute_in_sandbox(command, runtime, timeout) ) - if res.error: - return res.error - verify_error, updated_content = self._verify_edited_content_sync( - backend=resolved_backend, - path=validated_path, - new_string=new_string, - ) - if verify_error: - return verify_error - - if self._should_persist_documents() and not self._is_kb_document( - validated_path - ): - if updated_content is None: - return ( - f"Error: could not reload edited file '{validated_path}' for " - "persistence." - ) - persist_result = self._run_async_blocking( - self._persist_edited_document( - file_path=validated_path, - updated_content=updated_content, - ) - ) - if isinstance(persist_result, str): - return persist_result - - if res.files_update is not None: - return Command( - update={ - "files": res.files_update, - "messages": [ - ToolMessage( - content=f"Successfully replaced {res.occurrences} instance(s) of the string in '{res.path}'", - tool_call_id=runtime.tool_call_id, - ) - ], - } - ) - return f"Successfully replaced {res.occurrences} instance(s) of the string in '{res.path}'" - - async def async_edit_file( - file_path: Annotated[ - str, - "Absolute path to the file to edit. Must be absolute, not relative.", + async def async_execute_code( + command: Annotated[ + str, "Python code to execute. Use print() to see output." ], - old_string: Annotated[ - str, - "The exact text to find and replace. Must be unique in the file unless replace_all is True.", - ], - new_string: Annotated[ - str, - "The text to replace old_string with. Must be different from old_string.", - ], - runtime: ToolRuntime[None, FilesystemState], - *, - replace_all: Annotated[ - bool, - "If True, replace all occurrences of old_string. If False (default), old_string must be unique.", - ] = False, - ) -> Command | str: - resolved_backend = self._get_backend(runtime) - target_path = self._resolve_write_target_path(file_path, runtime) - try: - validated_path = validate_path(target_path) - except ValueError as exc: - return f"Error: {exc}" - res: EditResult = await resolved_backend.aedit( - validated_path, - old_string, - new_string, - replace_all=replace_all, - ) - if res.error: - return res.error - - verify_error, updated_content = await self._verify_edited_content_async( - backend=resolved_backend, - path=validated_path, - new_string=new_string, - ) - if verify_error: - return verify_error - - if self._should_persist_documents() and not self._is_kb_document( - validated_path - ): - if updated_content is None: - return ( - f"Error: could not reload edited file '{validated_path}' for " - "persistence." - ) - persist_error = await self._persist_edited_document( - file_path=validated_path, - updated_content=updated_content, - ) - if persist_error: - return persist_error - - if res.files_update is not None: - return Command( - update={ - "files": res.files_update, - "messages": [ - ToolMessage( - content=f"Successfully replaced {res.occurrences} instance(s) of the string in '{res.path}'", - tool_call_id=runtime.tool_call_id, - ) - ], - } - ) - return f"Successfully replaced {res.occurrences} instance(s) of the string in '{res.path}'" + runtime: ToolRuntime[None, SurfSenseFilesystemState], + timeout: Annotated[ + int | None, + "Optional timeout in seconds.", + ] = None, + ) -> str: + if timeout is not None: + if timeout < 0: + return f"Error: timeout must be non-negative, got {timeout}." + if timeout > self._MAX_EXECUTE_TIMEOUT: + return f"Error: timeout {timeout}s exceeds maximum ({self._MAX_EXECUTE_TIMEOUT}s)." + return await self._execute_in_sandbox(command, runtime, timeout) return StructuredTool.from_function( - name="edit_file", - description=tool_description, - func=sync_edit_file, - coroutine=async_edit_file, + name="execute_code", + description=SURFSENSE_EXECUTE_CODE_TOOL_DESCRIPTION, + func=sync_execute_code, + coroutine=async_execute_code, ) + + @staticmethod + def _wrap_as_python(code: str) -> str: + sentinel = f"_PYEOF_{secrets.token_hex(8)}" + return f"python3 << '{sentinel}'\n{code}\n{sentinel}" + + async def _execute_in_sandbox( + self, + command: str, + runtime: ToolRuntime[None, SurfSenseFilesystemState], + timeout: int | None, + ) -> str: + assert self._thread_id is not None + command = self._wrap_as_python(command) + try: + return await self._try_sandbox_execute(command, runtime, timeout) + except (DaytonaError, Exception) as first_err: + logger.warning( + "Sandbox execute failed for thread %s, retrying: %s", + self._thread_id, + first_err, + ) + try: + await delete_sandbox(self._thread_id) + except Exception: + _evict_sandbox_cache(self._thread_id) + try: + return await self._try_sandbox_execute(command, runtime, timeout) + except Exception: + logger.exception( + "Sandbox retry also failed for thread %s", self._thread_id + ) + return "Error: Code execution is temporarily unavailable. Please try again." + + async def _try_sandbox_execute( + self, + command: str, + runtime: ToolRuntime[None, SurfSenseFilesystemState], + timeout: int | None, + ) -> str: + sandbox, _is_new = await get_or_create_sandbox(self._thread_id) + result = await sandbox.aexecute(command, timeout=timeout) + output = (result.output or "").strip() + if not output and result.exit_code == 0: + return ( + "[Code executed successfully but produced no output. " + "Use print() to display results, then try again.]" + ) + parts = [result.output] + if result.exit_code is not None: + status = "succeeded" if result.exit_code == 0 else "failed" + parts.append(f"\n[Command {status} with exit code {result.exit_code}]") + if result.truncated: + parts.append("\n[Output was truncated due to size limits]") + return "".join(parts) diff --git a/surfsense_backend/app/agents/new_chat/middleware/kb_persistence.py b/surfsense_backend/app/agents/new_chat/middleware/kb_persistence.py new file mode 100644 index 000000000..378b83950 --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/middleware/kb_persistence.py @@ -0,0 +1,645 @@ +"""End-of-turn persistence for the cloud-mode SurfSense filesystem. + +This middleware runs ``aafter_agent`` once per turn (cloud only). It commits +all staged folder creations, file moves, and content writes/edits to +Postgres in a single ordered pass: + +1. Materialize ``staged_dirs`` into ``Folder`` rows. +2. Apply ``pending_moves`` in order (chained moves resolved via + ``doc_id_by_path``). +3. Normalize ``dirty_paths`` through ``pending_moves`` so write-then-move + sequences commit at the final path. +4. Commit content writes / edits for ``/documents/*`` paths, skipping + ``temp_*`` basenames. + +The commit body is exposed as a free function ``commit_staged_filesystem_state`` +so the optional stream-task fallback (``stream_new_chat.py``) can call the +exact same routine when ``aafter_agent`` was skipped (e.g. client disconnect). +""" + +from __future__ import annotations + +import logging +from datetime import UTC, datetime +from typing import Any + +from fractional_indexing import generate_key_between +from langchain.agents.middleware import AgentMiddleware, AgentState +from langchain_core.callbacks import dispatch_custom_event +from langgraph.runtime import Runtime +from sqlalchemy import delete, select +from sqlalchemy.exc import IntegrityError +from sqlalchemy.ext.asyncio import AsyncSession + +from app.agents.new_chat.filesystem_selection import FilesystemMode +from app.agents.new_chat.filesystem_state import SurfSenseFilesystemState +from app.agents.new_chat.path_resolver import ( + DOCUMENTS_ROOT, + parse_documents_path, + safe_folder_segment, + virtual_path_to_doc, +) +from app.agents.new_chat.state_reducers import _CLEAR +from app.db import ( + Chunk, + Document, + DocumentType, + Folder, + shielded_async_session, +) +from app.indexing_pipeline.document_chunker import chunk_text +from app.utils.document_converters import ( + embed_texts, + generate_content_hash, + generate_unique_identifier_hash, +) + +logger = logging.getLogger(__name__) + + +_TEMP_PREFIX = "temp_" + + +def _basename(path: str) -> str: + return path.rsplit("/", 1)[-1] + + +# --------------------------------------------------------------------------- +# Folder helpers +# --------------------------------------------------------------------------- + + +async def _ensure_folder_hierarchy( + session: AsyncSession, + *, + search_space_id: int, + created_by_id: str | None, + folder_parts: list[str], +) -> int | None: + """Ensure a chain of folder names exists under the search space. + + Returns the leaf folder id, or ``None`` if ``folder_parts`` is empty + (i.e. a document directly under ``/documents/``). + """ + if not folder_parts: + return None + parent_id: int | None = None + for raw in folder_parts: + name = safe_folder_segment(str(raw)) + query = select(Folder).where( + Folder.search_space_id == search_space_id, + Folder.name == name, + ) + if parent_id is None: + query = query.where(Folder.parent_id.is_(None)) + else: + query = query.where(Folder.parent_id == parent_id) + result = await session.execute(query) + folder = result.scalar_one_or_none() + if folder is None: + sibling_query = ( + select(Folder.position).order_by(Folder.position.desc()).limit(1) + ) + sibling_query = sibling_query.where( + Folder.search_space_id == search_space_id + ) + if parent_id is None: + sibling_query = sibling_query.where(Folder.parent_id.is_(None)) + else: + sibling_query = sibling_query.where(Folder.parent_id == parent_id) + sibling_result = await session.execute(sibling_query) + last_position = sibling_result.scalar_one_or_none() + folder = Folder( + name=name, + position=generate_key_between(last_position, None), + parent_id=parent_id, + search_space_id=search_space_id, + created_by_id=created_by_id, + updated_at=datetime.now(UTC), + ) + session.add(folder) + await session.flush() + parent_id = folder.id + return parent_id + + +# --------------------------------------------------------------------------- +# Document helpers +# --------------------------------------------------------------------------- + + +async def _create_document( + session: AsyncSession, + *, + virtual_path: str, + content: str, + search_space_id: int, + created_by_id: str | None, +) -> Document: + """Create a NOTE Document + Chunks for ``virtual_path``.""" + folder_parts, title = parse_documents_path(virtual_path) + if not title: + raise ValueError(f"invalid /documents path '{virtual_path}'") + folder_id = await _ensure_folder_hierarchy( + session, + search_space_id=search_space_id, + created_by_id=created_by_id, + folder_parts=folder_parts, + ) + unique_identifier_hash = generate_unique_identifier_hash( + DocumentType.NOTE, + virtual_path, + search_space_id, + ) + # Filesystem-parity invariant: the only thing that *must* be unique is + # the path. Two notes can legitimately share content (e.g. ``cp a b``). + # Guard against the path-derived ``unique_identifier_hash`` constraint + # so we surface a clean ValueError instead of letting the INSERT poison + # the session with an IntegrityError. + path_collision = await session.execute( + select(Document.id).where( + Document.search_space_id == search_space_id, + Document.unique_identifier_hash == unique_identifier_hash, + ) + ) + if path_collision.scalar_one_or_none() is not None: + raise ValueError( + f"a document already exists at path '{virtual_path}' " + "(unique_identifier_hash collision)" + ) + # ``content_hash`` is intentionally NOT checked for uniqueness here. + # In a real filesystem two files at different paths can hold identical + # bytes, and the agent's ``write_file`` path needs that semantic to + # support copy/duplicate operations. The hash remains useful as a + # change-detection hint for connector indexers, which still consult it + # via :func:`check_duplicate_document` but do so with a non-unique + # lookup (``.first()``). + content_hash = generate_content_hash(content, search_space_id) + doc = Document( + title=title, + document_type=DocumentType.NOTE, + document_metadata={"virtual_path": virtual_path}, + content=content, + content_hash=content_hash, + unique_identifier_hash=unique_identifier_hash, + source_markdown=content, + search_space_id=search_space_id, + folder_id=folder_id, + created_by_id=created_by_id, + updated_at=datetime.now(UTC), + ) + session.add(doc) + await session.flush() + + summary_embedding = embed_texts([content])[0] + doc.embedding = summary_embedding + chunks = chunk_text(content) + if chunks: + chunk_embeddings = embed_texts(chunks) + session.add_all( + [ + Chunk(document_id=doc.id, content=text, embedding=embedding) + for text, embedding in zip(chunks, chunk_embeddings, strict=True) + ] + ) + return doc + + +async def _update_document( + session: AsyncSession, + *, + doc_id: int, + content: str, + virtual_path: str, + search_space_id: int, +) -> Document | None: + """Update an existing Document's content + chunks.""" + result = await session.execute( + select(Document).where( + Document.id == doc_id, + Document.search_space_id == search_space_id, + ) + ) + document = result.scalar_one_or_none() + if document is None: + return None + + document.content = content + document.source_markdown = content + document.content_hash = generate_content_hash(content, search_space_id) + document.updated_at = datetime.now(UTC) + metadata = dict(document.document_metadata or {}) + metadata["virtual_path"] = virtual_path + document.document_metadata = metadata + document.unique_identifier_hash = generate_unique_identifier_hash( + DocumentType.NOTE, + virtual_path, + search_space_id, + ) + + summary_embedding = embed_texts([content])[0] + document.embedding = summary_embedding + + await session.execute(delete(Chunk).where(Chunk.document_id == document.id)) + chunks = chunk_text(content) + if chunks: + chunk_embeddings = embed_texts(chunks) + session.add_all( + [ + Chunk(document_id=document.id, content=text, embedding=embedding) + for text, embedding in zip(chunks, chunk_embeddings, strict=True) + ] + ) + return document + + +# --------------------------------------------------------------------------- +# Move helpers +# --------------------------------------------------------------------------- + + +async def _apply_move( + session: AsyncSession, + *, + search_space_id: int, + created_by_id: str | None, + move: dict[str, Any], + doc_id_by_path: dict[str, int], + doc_id_path_tombstones: dict[str, int | None], +) -> dict[str, Any] | None: + """Apply a single staged move; updates the in-memory mapping for chain resolution.""" + source = str(move.get("source") or "") + dest = str(move.get("dest") or "") + if not source or not dest or source == dest: + return None + + if not source.startswith(DOCUMENTS_ROOT + "/") or not dest.startswith( + DOCUMENTS_ROOT + "/" + ): + return None + + doc_id: int | None = doc_id_by_path.get(source) + document: Document | None = None + if doc_id is not None: + result = await session.execute( + select(Document).where( + Document.id == doc_id, + Document.search_space_id == search_space_id, + ) + ) + document = result.scalar_one_or_none() + if document is None: + document = await virtual_path_to_doc( + session, + search_space_id=search_space_id, + virtual_path=source, + ) + if document is None: + logger.info( + "kb_persistence: skipping move %s -> %s (source not found)", + source, + dest, + ) + return None + + folder_parts, new_title = parse_documents_path(dest) + if not new_title: + return None + folder_id = await _ensure_folder_hierarchy( + session, + search_space_id=search_space_id, + created_by_id=created_by_id, + folder_parts=folder_parts, + ) + + document.title = new_title + document.folder_id = folder_id + metadata = dict(document.document_metadata or {}) + metadata["virtual_path"] = dest + document.document_metadata = metadata + document.unique_identifier_hash = generate_unique_identifier_hash( + DocumentType.NOTE, + dest, + search_space_id, + ) + document.updated_at = datetime.now(UTC) + + doc_id_by_path.pop(source, None) + doc_id_by_path[dest] = document.id + doc_id_path_tombstones[source] = None + doc_id_path_tombstones[dest] = document.id + return {"id": document.id, "source": source, "dest": dest, "title": new_title} + + +# --------------------------------------------------------------------------- +# Commit body +# --------------------------------------------------------------------------- + + +async def commit_staged_filesystem_state( + state: dict[str, Any] | AgentState, + *, + search_space_id: int, + created_by_id: str | None, + filesystem_mode: FilesystemMode = FilesystemMode.CLOUD, + dispatch_events: bool = True, +) -> dict[str, Any] | None: + """Commit all staged filesystem changes; return the state delta for reducers. + + Shared between :class:`KnowledgeBasePersistenceMiddleware.aafter_agent` + and the optional stream-task fallback. + """ + if filesystem_mode != FilesystemMode.CLOUD: + return None + + state_dict: dict[str, Any] = ( + dict(state) + if isinstance(state, dict) + else dict(getattr(state, "values", {}) or {}) + ) + + files: dict[str, Any] = state_dict.get("files") or {} + staged_dirs: list[str] = list(state_dict.get("staged_dirs") or []) + pending_moves: list[dict[str, Any]] = list(state_dict.get("pending_moves") or []) + dirty_paths: list[str] = list(state_dict.get("dirty_paths") or []) + doc_id_by_path: dict[str, int] = dict(state_dict.get("doc_id_by_path") or {}) + kb_anon_doc = state_dict.get("kb_anon_doc") + + if kb_anon_doc: + temp_paths = [ + p + for p in files + if isinstance(p, str) and _basename(p).startswith(_TEMP_PREFIX) + ] + return { + "dirty_paths": [_CLEAR], + "staged_dirs": [_CLEAR], + "pending_moves": [_CLEAR], + "files": dict.fromkeys(temp_paths), + } + + if not (staged_dirs or pending_moves or dirty_paths): + return None + + committed_creates: list[dict[str, Any]] = [] + committed_updates: list[dict[str, Any]] = [] + discarded: list[str] = [] + applied_moves: list[dict[str, Any]] = [] + doc_id_path_tombstones: dict[str, int | None] = {} + tree_changed = False + + try: + async with shielded_async_session() as session: + for folder_path in staged_dirs: + if not isinstance(folder_path, str): + continue + if not folder_path.startswith(DOCUMENTS_ROOT): + continue + rel = folder_path[len(DOCUMENTS_ROOT) :].strip("/") + folder_parts_full = [p for p in rel.split("/") if p] + if not folder_parts_full: + continue + await _ensure_folder_hierarchy( + session, + search_space_id=search_space_id, + created_by_id=created_by_id, + folder_parts=folder_parts_full, + ) + tree_changed = True + + for move in pending_moves: + applied = await _apply_move( + session, + search_space_id=search_space_id, + created_by_id=created_by_id, + move=move, + doc_id_by_path=doc_id_by_path, + doc_id_path_tombstones=doc_id_path_tombstones, + ) + if applied: + applied_moves.append(applied) + tree_changed = True + + move_alias = { + m["source"]: m["dest"] for m in pending_moves if m.get("source") + } + + def _final_path(path: str) -> str: + seen: set[str] = set() + while path in move_alias and path not in seen: + seen.add(path) + path = move_alias[path] + return path + + kb_dirty_seen: set[str] = set() + kb_dirty: list[str] = [] + for raw in dirty_paths: + if not isinstance(raw, str): + continue + final = _final_path(raw) + if not final.startswith(DOCUMENTS_ROOT + "/"): + continue + if final in kb_dirty_seen: + continue + kb_dirty_seen.add(final) + kb_dirty.append(final) + + for path in kb_dirty: + basename = _basename(path) + if basename.startswith(_TEMP_PREFIX): + discarded.append(path) + continue + file_data = files.get(path) + if not isinstance(file_data, dict): + continue + content = "\n".join(file_data.get("content") or []) + doc_id = doc_id_by_path.get(path) + if doc_id is None: + # The in-memory ``doc_id_by_path`` is per-thread and starts + # empty in every new chat. If the agent writes to a path + # that already exists in the DB (e.g. a previous chat's + # ``notes.md``), we must NOT try to INSERT — it would hit + # ``unique_identifier_hash`` (path-derived). Look up the + # existing doc and update it in place instead. + existing = await virtual_path_to_doc( + session, + search_space_id=search_space_id, + virtual_path=path, + ) + if existing is not None: + doc_id = existing.id + doc_id_by_path[path] = existing.id + if doc_id is not None: + updated = await _update_document( + session, + doc_id=doc_id, + content=content, + virtual_path=path, + search_space_id=search_space_id, + ) + if updated is not None: + committed_updates.append( + { + "id": updated.id, + "title": updated.title, + "documentType": DocumentType.NOTE.value, + "searchSpaceId": search_space_id, + "folderId": updated.folder_id, + "createdById": str(created_by_id) + if created_by_id + else None, + "virtualPath": path, + } + ) + else: + # Wrap each create in a SAVEPOINT so a residual + # ``IntegrityError`` (e.g. a deployment that hasn't run + # migration 133 yet, where ``documents.content_hash`` + # still carries its legacy global UNIQUE constraint) + # rolls back only this one create instead of poisoning + # the whole turn's transaction. + try: + async with session.begin_nested(): + new_doc = await _create_document( + session, + virtual_path=path, + content=content, + search_space_id=search_space_id, + created_by_id=created_by_id, + ) + except ValueError as exc: + logger.warning( + "kb_persistence: skipping %s create: %s", path, exc + ) + continue + except IntegrityError as exc: + # The path-uniqueness check above already protected + # against ``unique_identifier_hash`` collisions, so + # the most likely culprit is the legacy + # ``ix_documents_content_hash`` UNIQUE constraint + # that migration 133 drops. Log loudly so operators + # know to run the migration; do NOT silently swallow. + msg = str(exc.orig) if exc.orig is not None else str(exc) + logger.error( + "kb_persistence: IntegrityError creating %s: %s. " + "If this mentions content_hash, run alembic " + "upgrade to apply migration 133 which drops the " + "global UNIQUE constraint on documents.content_hash.", + path, + msg, + ) + continue + doc_id_by_path[path] = new_doc.id + committed_creates.append( + { + "id": new_doc.id, + "title": new_doc.title, + "documentType": DocumentType.NOTE.value, + "searchSpaceId": search_space_id, + "folderId": new_doc.folder_id, + "createdById": str(created_by_id) + if created_by_id + else None, + "virtualPath": path, + } + ) + tree_changed = True + + await session.commit() + except Exception: # pragma: no cover - rollback safety net + logger.exception( + "kb_persistence: commit failed (search_space=%s)", search_space_id + ) + return None + + if dispatch_events: + for payload in committed_creates: + try: + dispatch_custom_event("document_created", payload) + except Exception: + logger.exception( + "kb_persistence: failed to dispatch document_created event" + ) + for payload in committed_updates: + try: + dispatch_custom_event("document_updated", payload) + except Exception: + logger.exception( + "kb_persistence: failed to dispatch document_updated event" + ) + + temp_paths = [ + p for p in files if isinstance(p, str) and _basename(p).startswith(_TEMP_PREFIX) + ] + + doc_id_update: dict[str, int | None] = {**doc_id_path_tombstones} + for payload in committed_creates: + doc_id_update[str(payload.get("virtualPath") or "")] = int(payload["id"]) + + delta: dict[str, Any] = { + "dirty_paths": [_CLEAR], + "staged_dirs": [_CLEAR], + "pending_moves": [_CLEAR], + } + if temp_paths: + delta["files"] = dict.fromkeys(temp_paths) + if doc_id_update: + delta["doc_id_by_path"] = doc_id_update + if tree_changed: + delta["tree_version"] = int(state_dict.get("tree_version") or 0) + 1 + + logger.info( + "kb_persistence: commit (search_space=%s) creates=%d updates=%d " + "moves=%d staged_dirs=%d discarded=%d", + search_space_id, + len(committed_creates), + len(committed_updates), + len(applied_moves), + len(staged_dirs), + len(discarded), + ) + return delta + + +# --------------------------------------------------------------------------- +# Middleware +# --------------------------------------------------------------------------- + + +class KnowledgeBasePersistenceMiddleware(AgentMiddleware): # type: ignore[type-arg] + """End-of-turn cloud persistence for the SurfSense filesystem agent.""" + + tools = () + state_schema = SurfSenseFilesystemState + + def __init__( + self, + *, + search_space_id: int, + created_by_id: str | None, + filesystem_mode: FilesystemMode, + ) -> None: + self.search_space_id = search_space_id + self.created_by_id = created_by_id + self.filesystem_mode = filesystem_mode + + async def aafter_agent( # type: ignore[override] + self, + state: AgentState, + runtime: Runtime[Any], + ) -> dict[str, Any] | None: + del runtime + if self.filesystem_mode != FilesystemMode.CLOUD: + return None + return await commit_staged_filesystem_state( + state, + search_space_id=self.search_space_id, + created_by_id=self.created_by_id, + filesystem_mode=self.filesystem_mode, + ) + + +__all__ = [ + "KnowledgeBasePersistenceMiddleware", + "commit_staged_filesystem_state", +] diff --git a/surfsense_backend/app/agents/new_chat/middleware/kb_postgres_backend.py b/surfsense_backend/app/agents/new_chat/middleware/kb_postgres_backend.py new file mode 100644 index 000000000..ddb2d4af1 --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/middleware/kb_postgres_backend.py @@ -0,0 +1,963 @@ +"""Postgres-backed virtual filesystem for the SurfSense agent (cloud mode). + +The backend is **strictly conforming** to deepagents' +:class:`BackendProtocol`. It returns ``WriteResult`` / ``EditResult`` / list +shapes exactly as upstream expects (no extra fields). All side-state +plumbing — ``dirty_paths``, ``doc_id_by_path``, ``staged_dirs``, +``pending_moves``, ``files`` cache — is appended by the overridden tool +wrappers in :class:`SurfSenseFilesystemMiddleware` via ``Command.update``. + +The backend never writes to Postgres. End-of-turn persistence is handled by +:class:`KnowledgeBasePersistenceMiddleware`. This module is purely a +read-side and a state-merging helper. +""" + +from __future__ import annotations + +import asyncio +import contextlib +import fnmatch +import logging +import re +from datetime import UTC +from typing import Any + +from deepagents.backends.protocol import ( + BackendProtocol, + EditResult, + FileDownloadResponse, + FileInfo, + FileUploadResponse, + GrepMatch, + WriteResult, +) +from deepagents.backends.utils import ( + create_file_data, + file_data_to_string, + format_read_response, + perform_string_replacement, + update_file_data, +) +from langchain.tools import ToolRuntime +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession + +from app.agents.new_chat.document_xml import build_document_xml +from app.agents.new_chat.path_resolver import ( + DOCUMENTS_ROOT, + build_path_index, + doc_to_virtual_path, + virtual_path_to_doc, +) +from app.db import Chunk, Document, shielded_async_session + +logger = logging.getLogger(__name__) + +_TEMP_PREFIX = "temp_" +_GREP_MAX_TOTAL_MATCHES = 50 +_GREP_MAX_PER_DOC = 5 + + +def _basename(path: str) -> str: + return path.rsplit("/", 1)[-1] + + +def _is_under(child: str, parent: str) -> bool: + """Return True iff ``child`` is at-or-under ``parent`` (directory semantics).""" + if parent == "/": + return child.startswith("/") + return child == parent or child.startswith(parent.rstrip("/") + "/") + + +def paginate_listing( + infos: list[FileInfo], + *, + offset: int = 0, + limit: int | None = None, +) -> list[FileInfo]: + """Paginate a listing produced by :meth:`KBPostgresBackend.als_info`.""" + if offset < 0: + offset = 0 + end: int | None + end = None if limit is None or limit < 0 else offset + limit + return list(infos[offset:end]) + + +class KBPostgresBackend(BackendProtocol): + """Lazy, read-only Postgres view for ``/documents/*`` virtual paths. + + The backend exposes a virtual ``/documents/`` namespace mirroring the + ``Folder``/``Document`` graph. Reads materialize XML on first access and + cache it via the overriding tool wrappers (NOT here). Writes never touch + the DB — they return ``files_update`` deltas that the wrappers turn into + Command updates, and the persistence middleware commits them at end of + turn. + """ + + _IMAGE_EXTENSIONS = frozenset({".png", ".jpg", ".jpeg", ".gif", ".webp"}) + + def __init__(self, search_space_id: int, runtime: ToolRuntime) -> None: + self.search_space_id = search_space_id + self.runtime = runtime + + @property + def state(self) -> dict[str, Any]: + return getattr(self.runtime, "state", {}) or {} + + # ------------------------------------------------------------------ helpers + + def _state_files(self) -> dict[str, Any]: + return dict(self.state.get("files") or {}) + + def _staged_dirs(self) -> list[str]: + return list(self.state.get("staged_dirs") or []) + + def _pending_moves(self) -> list[dict[str, Any]]: + return list(self.state.get("pending_moves") or []) + + def _kb_anon_doc(self) -> dict[str, Any] | None: + anon = self.state.get("kb_anon_doc") + return anon if isinstance(anon, dict) else None + + def _matched_chunk_ids(self, doc_id: int) -> set[int]: + mapping = self.state.get("kb_matched_chunk_ids") or {} + try: + return set(mapping.get(doc_id, []) or []) + except TypeError: + return set() + + @staticmethod + def _file_data_size(file_data: dict[str, Any]) -> int: + try: + return len("\n".join(file_data.get("content") or [])) + except Exception: + return 0 + + def _normalize_listing_path(self, path: str) -> str: + if not path: + return DOCUMENTS_ROOT + if path == "/": + return path + return path.rstrip("/") if path != "/" else path + + def _moved_view_paths( + self, + existing: dict[str, dict[str, Any]], + ) -> tuple[set[str], dict[str, str]]: + """Apply ``pending_moves`` to a path set and return ``(removed, alias)``. + + Removed paths should disappear from listings; ``alias[source] = dest`` + means a virtual entry should appear at ``dest`` even if no DB row is + yet there. + """ + removed: set[str] = set() + alias: dict[str, str] = {} + for move in self._pending_moves(): + src = move.get("source") + dst = move.get("dest") + if not src or not dst: + continue + removed.add(src) + alias[src] = dst + existing.pop(src, None) + return removed, alias + + # ------------------------------------------------------------------ ls/read + + async def als_info(self, path: str) -> list[FileInfo]: # type: ignore[override] + normalized = self._normalize_listing_path(path) + infos: list[FileInfo] = [] + seen: set[str] = set() + + anon = self._kb_anon_doc() + if anon: + anon_path = str(anon.get("path") or "") + if ( + anon_path + and _is_under(anon_path, normalized) + and anon_path != normalized + and anon_path not in seen + ): + infos.append( + FileInfo( + path=anon_path, + is_dir=False, + size=len(str(anon.get("content") or "")), + modified_at="", + ) + ) + seen.add(anon_path) + + files = self._state_files() + moved_removed, moved_alias = self._moved_view_paths(files) + + if normalized.startswith(DOCUMENTS_ROOT) or normalized == "/": + try: + async with shielded_async_session() as session: + db_infos, subdir_paths = await self._list_db_directory( + session, normalized + ) + except Exception as exc: # pragma: no cover - defensive + logger.warning("KBPostgresBackend.als_info DB error: %s", exc) + db_infos, subdir_paths = [], set() + + for info in db_infos: + p = info.get("path", "") + if not p or p in seen or p in moved_removed: + continue + infos.append(info) + seen.add(p) + + for src, dst in moved_alias.items(): + if src not in seen: + if not _is_under(dst, normalized): + continue + rel = ( + dst[len(normalized) :].lstrip("/") + if normalized != "/" + else dst.lstrip("/") + ) + if "/" in rel: + subdir_paths.add( + (normalized.rstrip("/") + "/" + rel.split("/", 1)[0]) + if normalized != "/" + else "/" + rel.split("/", 1)[0] + ) + continue + if dst in seen: + continue + fd = files.get(dst) + size = self._file_data_size(fd) if isinstance(fd, dict) else 0 + infos.append( + FileInfo( + path=dst, + is_dir=False, + size=int(size), + modified_at=fd.get("modified_at", "") + if isinstance(fd, dict) + else "", + ) + ) + seen.add(dst) + + for staged in self._staged_dirs(): + if not staged or not staged.startswith(DOCUMENTS_ROOT): + continue + if staged == normalized: + continue + if not _is_under(staged, normalized): + continue + rel = ( + staged[len(normalized) :].lstrip("/") + if normalized != "/" + else staged.lstrip("/") + ) + if not rel: + continue + first = rel.split("/", 1)[0] + immediate = ( + normalized.rstrip("/") + "/" + first + if normalized != "/" + else "/" + first + ) + subdir_paths.add(immediate) + + for sub in sorted(subdir_paths): + if sub in seen: + continue + infos.append(FileInfo(path=sub, is_dir=True, size=0, modified_at="")) + seen.add(sub) + + for path_key, fd in files.items(): + if not isinstance(path_key, str) or path_key in seen: + continue + if not _is_under(path_key, normalized) or path_key == normalized: + continue + if normalized == "/": + rel = path_key.lstrip("/") + else: + rel = path_key[len(normalized) :].lstrip("/") + if not rel: + continue + if "/" in rel: + first = rel.split("/", 1)[0] + immediate = ( + normalized.rstrip("/") + "/" + first + if normalized != "/" + else "/" + first + ) + if immediate not in seen: + infos.append( + FileInfo(path=immediate, is_dir=True, size=0, modified_at="") + ) + seen.add(immediate) + continue + include = path_key.startswith(DOCUMENTS_ROOT) or _basename( + path_key + ).startswith(_TEMP_PREFIX) + if not include: + continue + size = self._file_data_size(fd) if isinstance(fd, dict) else 0 + infos.append( + FileInfo( + path=path_key, + is_dir=False, + size=int(size), + modified_at=fd.get("modified_at", "") + if isinstance(fd, dict) + else "", + ) + ) + seen.add(path_key) + + infos.sort(key=lambda fi: (not fi.get("is_dir", False), fi.get("path", ""))) + return infos + + def ls_info(self, path: str) -> list[FileInfo]: # type: ignore[override] + return asyncio.run(self.als_info(path)) + + async def _list_db_directory( + self, + session: AsyncSession, + normalized_path: str, + ) -> tuple[list[FileInfo], set[str]]: + """List immediate Folders + Documents at ``normalized_path``. + + Returns ``(file_infos, subdirectory_paths)``. ``normalized_path`` may + be ``/`` (synthesizes ``/documents``) or a path under ``/documents``. + """ + if normalized_path == "/": + return ( + [], + {DOCUMENTS_ROOT}, + ) + + if not normalized_path.startswith(DOCUMENTS_ROOT): + return [], set() + + index = await build_path_index(session, self.search_space_id) + target_folder_id: int | None = None + if normalized_path != DOCUMENTS_ROOT: + target_path = normalized_path + matches = [ + fid for fid, fpath in index.folder_paths.items() if fpath == target_path + ] + if not matches: + return [], set() + target_folder_id = matches[0] + + result = await session.execute( + select(Document.id, Document.title, Document.folder_id, Document.updated_at) + .where(Document.search_space_id == self.search_space_id) + .where( + Document.folder_id == target_folder_id + if target_folder_id is not None + else Document.folder_id.is_(None) + ) + ) + rows = result.all() + + file_infos: list[FileInfo] = [] + for row in rows: + path = doc_to_virtual_path( + doc_id=row.id, + title=str(row.title or "untitled"), + folder_id=row.folder_id, + index=index, + ) + modified = "" + if row.updated_at is not None: + with contextlib.suppress(Exception): + modified = row.updated_at.astimezone(UTC).isoformat() + file_infos.append( + FileInfo( + path=path, + is_dir=False, + size=0, + modified_at=modified, + ) + ) + + subdirs: set[str] = set() + for _fid, fpath in index.folder_paths.items(): + if fpath == normalized_path: + continue + base = normalized_path.rstrip("/") + if not fpath.startswith(base + "/"): + continue + rel = fpath[len(base) + 1 :] + if "/" in rel: + continue + subdirs.add(base + "/" + rel) + return file_infos, subdirs + + async def aread( # type: ignore[override] + self, + file_path: str, + offset: int = 0, + limit: int = 2000, + ) -> str: + files = self._state_files() + file_data = files.get(file_path) + if file_data is not None: + return format_read_response(file_data, offset, limit) + + loaded = await self._load_file_data(file_path) + if loaded is None: + return f"Error: File '{file_path}' not found" + file_data, _ = loaded + return format_read_response(file_data, offset, limit) + + def read(self, file_path: str, offset: int = 0, limit: int = 2000) -> str: # type: ignore[override] + return asyncio.run(self.aread(file_path, offset, limit)) + + async def _load_file_data( + self, + path: str, + ) -> tuple[dict[str, Any], int | None] | None: + """Lazy-load a virtual KB document into a deepagents ``FileData``. + + Returns ``(file_data, doc_id)`` or ``None`` if the path doesn't map + to any known document. ``doc_id`` is ``None`` for the synthetic + anonymous document so the caller doesn't track it as a DB-backed file. + """ + anon = self._kb_anon_doc() + if anon and str(anon.get("path") or "") == path: + doc_payload = { + "document_id": -1, + "chunks": list(anon.get("chunks") or []), + "matched_chunk_ids": [], + "document": { + "id": -1, + "title": anon.get("title") or "uploaded_document", + "document_type": "FILE", + "metadata": {"source": "anonymous_upload"}, + }, + "source": "FILE", + } + xml = build_document_xml(doc_payload, matched_chunk_ids=set()) + file_data = create_file_data(xml) + return file_data, None + + if not path.startswith(DOCUMENTS_ROOT): + return None + + async with shielded_async_session() as session: + document = await virtual_path_to_doc( + session, + search_space_id=self.search_space_id, + virtual_path=path, + ) + if document is None: + return None + chunk_rows = await session.execute( + select(Chunk.id, Chunk.content) + .where(Chunk.document_id == document.id) + .order_by(Chunk.id) + ) + chunks = [ + {"chunk_id": row.id, "content": row.content} for row in chunk_rows.all() + ] + + doc_payload = { + "document_id": document.id, + "chunks": chunks, + "matched_chunk_ids": list(self._matched_chunk_ids(document.id)), + "document": { + "id": document.id, + "title": document.title, + "document_type": ( + document.document_type.value + if getattr(document, "document_type", None) is not None + else "UNKNOWN" + ), + "metadata": dict(document.document_metadata or {}), + }, + "source": ( + document.document_type.value + if getattr(document, "document_type", None) is not None + else "UNKNOWN" + ), + } + xml = build_document_xml( + doc_payload, + matched_chunk_ids=self._matched_chunk_ids(document.id), + ) + file_data = create_file_data(xml) + return file_data, document.id + + # ------------------------------------------------------------------ writes + + async def awrite(self, file_path: str, content: str) -> WriteResult: # type: ignore[override] + files = self._state_files() + if file_path in files: + return WriteResult( + error=( + f"Cannot write to {file_path} because it already exists. " + "Read and then make an edit, or write to a new path." + ) + ) + new_file_data = create_file_data(content) + return WriteResult(path=file_path, files_update={file_path: new_file_data}) + + def write(self, file_path: str, content: str) -> WriteResult: # type: ignore[override] + return asyncio.run(self.awrite(file_path, content)) + + async def aedit( # type: ignore[override] + self, + file_path: str, + old_string: str, + new_string: str, + replace_all: bool = False, + ) -> EditResult: + files = self._state_files() + file_data = files.get(file_path) + if file_data is None: + loaded = await self._load_file_data(file_path) + if loaded is None: + return EditResult(error=f"Error: File '{file_path}' not found") + file_data, _ = loaded + + content = file_data_to_string(file_data) + result = perform_string_replacement( + content, old_string, new_string, replace_all + ) + if isinstance(result, str): + return EditResult(error=result) + + new_content, occurrences = result + new_file_data = update_file_data(file_data, new_content) + return EditResult( + path=file_path, + files_update={file_path: new_file_data}, + occurrences=int(occurrences), + ) + + def edit( # type: ignore[override] + self, + file_path: str, + old_string: str, + new_string: str, + replace_all: bool = False, + ) -> EditResult: + return asyncio.run(self.aedit(file_path, old_string, new_string, replace_all)) + + # ------------------------------------------------------------------ glob/grep + + async def aglob_info(self, pattern: str, path: str = "/") -> list[FileInfo]: # type: ignore[override] + normalized = self._normalize_listing_path(path) + results: list[FileInfo] = [] + seen: set[str] = set() + + files = self._state_files() + moved_removed, _ = self._moved_view_paths(files) + regex = re.compile(fnmatch.translate(pattern)) + for path_key, fd in files.items(): + if path_key in moved_removed: + continue + if not _is_under(path_key, normalized): + continue + rel = ( + path_key[len(normalized) :].lstrip("/") + if normalized != "/" + else path_key.lstrip("/") + ) + if not regex.match(rel) and not regex.match(path_key): + continue + if path_key in seen: + continue + size = self._file_data_size(fd) if isinstance(fd, dict) else 0 + results.append( + FileInfo( + path=path_key, + is_dir=False, + size=int(size), + modified_at=fd.get("modified_at", "") + if isinstance(fd, dict) + else "", + ) + ) + seen.add(path_key) + + if normalized.startswith(DOCUMENTS_ROOT) or normalized == "/": + try: + async with shielded_async_session() as session: + index = await build_path_index(session, self.search_space_id) + rows = await session.execute( + select(Document.id, Document.title, Document.folder_id).where( + Document.search_space_id == self.search_space_id + ) + ) + for row in rows.all(): + candidate = doc_to_virtual_path( + doc_id=row.id, + title=str(row.title or "untitled"), + folder_id=row.folder_id, + index=index, + ) + if candidate in seen or candidate in moved_removed: + continue + if not _is_under(candidate, normalized): + continue + rel = ( + candidate[len(normalized) :].lstrip("/") + if normalized != "/" + else candidate.lstrip("/") + ) + if not regex.match(rel) and not regex.match(candidate): + continue + results.append( + FileInfo( + path=candidate, is_dir=False, size=0, modified_at="" + ) + ) + seen.add(candidate) + except Exception as exc: # pragma: no cover - defensive + logger.warning("KBPostgresBackend.aglob_info DB error: %s", exc) + + results.sort(key=lambda fi: fi.get("path", "")) + return results + + def glob_info(self, pattern: str, path: str = "/") -> list[FileInfo]: # type: ignore[override] + return asyncio.run(self.aglob_info(pattern, path)) + + async def agrep_raw( # type: ignore[override] + self, + pattern: str, + path: str | None = None, + glob: str | None = None, + ) -> list[GrepMatch] | str: + if not pattern: + return "Error: pattern cannot be empty" + + normalized = self._normalize_listing_path(path or "/") + matches: list[GrepMatch] = [] + + files = self._state_files() + moved_removed, _ = self._moved_view_paths(files) + glob_re = re.compile(fnmatch.translate(glob)) if glob else None + for path_key, fd in files.items(): + if path_key in moved_removed: + continue + if not _is_under(path_key, normalized): + continue + if glob_re is not None and not glob_re.match(_basename(path_key)): + continue + if not isinstance(fd, dict): + continue + for line_no, line in enumerate(fd.get("content") or [], 1): + if pattern in line: + matches.append( + GrepMatch(path=path_key, line=int(line_no), text=str(line)) + ) + if len(matches) >= _GREP_MAX_TOTAL_MATCHES: + return matches + + if normalized.startswith(DOCUMENTS_ROOT) or normalized == "/": + try: + async with shielded_async_session() as session: + index = await build_path_index(session, self.search_space_id) + sub = ( + select(Chunk.document_id, Chunk.id, Chunk.content) + .join(Document, Document.id == Chunk.document_id) + .where(Document.search_space_id == self.search_space_id) + .where(Chunk.content.ilike(f"%{pattern}%")) + .order_by(Chunk.document_id, Chunk.id) + ) + chunk_rows = await session.execute(sub) + per_doc: dict[int, int] = {} + doc_id_to_path: dict[int, str] = {} + needed_doc_ids: set[int] = set() + chunk_buffer: list[tuple[int, int, str]] = [] + for row in chunk_rows.all(): + per_doc.setdefault(row.document_id, 0) + if per_doc[row.document_id] >= _GREP_MAX_PER_DOC: + continue + per_doc[row.document_id] += 1 + chunk_buffer.append((row.document_id, row.id, row.content)) + needed_doc_ids.add(row.document_id) + if sum(per_doc.values()) >= _GREP_MAX_TOTAL_MATCHES - len( + matches + ): + break + if needed_doc_ids: + doc_rows = await session.execute( + select( + Document.id, Document.title, Document.folder_id + ).where(Document.id.in_(list(needed_doc_ids))) + ) + for row in doc_rows.all(): + doc_id_to_path[row.id] = doc_to_virtual_path( + doc_id=row.id, + title=str(row.title or "untitled"), + folder_id=row.folder_id, + index=index, + ) + for doc_id, chunk_id, content in chunk_buffer: + candidate = doc_id_to_path.get(doc_id) + if not candidate or candidate in moved_removed: + continue + if not _is_under(candidate, normalized): + continue + if glob_re is not None and not glob_re.match( + _basename(candidate) + ): + continue + snippet = " ".join(str(content).split())[:240] + matches.append( + GrepMatch( + path=candidate, + line=0, + text=( + f": " + f"{snippet}" + ), + ) + ) + if len(matches) >= _GREP_MAX_TOTAL_MATCHES: + break + except Exception as exc: # pragma: no cover - defensive + logger.warning("KBPostgresBackend.agrep_raw DB error: %s", exc) + + return matches + + def grep_raw( # type: ignore[override] + self, + pattern: str, + path: str | None = None, + glob: str | None = None, + ) -> list[GrepMatch] | str: + return asyncio.run(self.agrep_raw(pattern, path, glob)) + + # ------------------------------------------------------------------ list_tree (helper) + + async def alist_tree_listing( + self, + path: str = DOCUMENTS_ROOT, + *, + max_depth: int | None = 8, + page_size: int = 500, + include_files: bool = True, + include_dirs: bool = True, + ) -> dict[str, Any]: + """Recursive tree listing for cloud mode. + + Mirrors the shape returned by :class:`MultiRootLocalFolderBackend.list_tree`: + ``{"entries": [{path, is_dir, size, modified_at, depth}, ...], "truncated": bool}``. + """ + normalized = self._normalize_listing_path(path or DOCUMENTS_ROOT) + if not normalized.startswith(DOCUMENTS_ROOT) and normalized != "/": + return {"error": "Error: path must be under /documents/"} + + entries: list[dict[str, Any]] = [] + truncated = False + + try: + async with shielded_async_session() as session: + index = await build_path_index(session, self.search_space_id) + doc_rows_raw = await session.execute( + select( + Document.id, + Document.title, + Document.folder_id, + Document.updated_at, + ).where(Document.search_space_id == self.search_space_id) + ) + doc_rows = list(doc_rows_raw.all()) + except Exception as exc: # pragma: no cover + logger.warning("KBPostgresBackend.alist_tree_listing DB error: %s", exc) + return {"entries": [], "truncated": False} + + files = self._state_files() + moved_removed, _ = self._moved_view_paths(files) + anon = self._kb_anon_doc() + anon_path = str(anon.get("path") or "") if anon else "" + + def _depth_of(p: str) -> int: + if p == DOCUMENTS_ROOT: + return 0 + rel_root = ( + p[len(DOCUMENTS_ROOT) :].lstrip("/") + if normalized.startswith(DOCUMENTS_ROOT) + else p.lstrip("/") + ) + return len([part for part in rel_root.split("/") if part]) + + def _add_entry(entry: dict[str, Any]) -> bool: + nonlocal truncated + if len(entries) >= page_size: + truncated = True + return False + entries.append(entry) + return True + + if include_dirs: + for _fid, fpath in sorted(index.folder_paths.items(), key=lambda kv: kv[1]): + if not _is_under(fpath, normalized): + continue + depth = _depth_of(fpath) + if max_depth is not None and depth > max_depth: + continue + if not _add_entry( + { + "path": fpath, + "is_dir": True, + "size": 0, + "modified_at": "", + "depth": depth, + } + ): + return {"entries": entries, "truncated": True} + for staged in self._staged_dirs(): + if not _is_under(staged, normalized): + continue + depth = _depth_of(staged) + if max_depth is not None and depth > max_depth: + continue + if any(e["path"] == staged for e in entries): + continue + if not _add_entry( + { + "path": staged, + "is_dir": True, + "size": 0, + "modified_at": "", + "depth": depth, + } + ): + return {"entries": entries, "truncated": True} + + if include_files: + for row in sorted(doc_rows, key=lambda r: str(r.title or "")): + candidate = doc_to_virtual_path( + doc_id=row.id, + title=str(row.title or "untitled"), + folder_id=row.folder_id, + index=index, + ) + if candidate in moved_removed: + continue + if not _is_under(candidate, normalized): + continue + depth = _depth_of(candidate) + if max_depth is not None and depth > max_depth: + continue + modified = "" + if row.updated_at is not None: + with contextlib.suppress(Exception): + modified = row.updated_at.astimezone(UTC).isoformat() + if not _add_entry( + { + "path": candidate, + "is_dir": False, + "size": 0, + "modified_at": modified, + "depth": depth, + } + ): + return {"entries": entries, "truncated": True} + + if anon_path and _is_under(anon_path, normalized): + depth = _depth_of(anon_path) + if (max_depth is None or depth <= max_depth) and not _add_entry( + { + "path": anon_path, + "is_dir": False, + "size": len(str(anon.get("content") or "")), + "modified_at": "", + "depth": depth, + } + ): + return {"entries": entries, "truncated": True} + + for path_key, fd in files.items(): + if not isinstance(path_key, str): + continue + if not _is_under(path_key, normalized): + continue + if any(e["path"] == path_key for e in entries): + continue + if not ( + path_key.startswith(DOCUMENTS_ROOT) + or _basename(path_key).startswith(_TEMP_PREFIX) + ): + continue + depth = _depth_of(path_key) + if max_depth is not None and depth > max_depth: + continue + size = self._file_data_size(fd) if isinstance(fd, dict) else 0 + if not _add_entry( + { + "path": path_key, + "is_dir": False, + "size": int(size), + "modified_at": fd.get("modified_at", "") + if isinstance(fd, dict) + else "", + "depth": depth, + } + ): + return {"entries": entries, "truncated": True} + + return {"entries": entries, "truncated": truncated} + + # ------------------------------------------------------------------ uploads (unsupported) + + def upload_files( # type: ignore[override] + self, files: list[tuple[str, bytes]] + ) -> list[FileUploadResponse]: + msg = "KBPostgresBackend does not support upload_files." + raise NotImplementedError(msg) + + def download_files( # type: ignore[override] + self, paths: list[str] + ) -> list[FileDownloadResponse]: + responses: list[FileDownloadResponse] = [] + files = self._state_files() + for path in paths: + fd = files.get(path) + if fd is None: + responses.append( + FileDownloadResponse( + path=path, content=None, error="file_not_found" + ) + ) + continue + content_str = file_data_to_string(fd) + responses.append( + FileDownloadResponse( + path=path, + content=content_str.encode("utf-8"), + error=None, + ) + ) + return responses + + +# --- module-level small helpers --------------------------------------------- + + +async def list_tree_listing( + backend: KBPostgresBackend, + path: str, + *, + max_depth: int | None = 8, + page_size: int = 500, + include_files: bool = True, + include_dirs: bool = True, +) -> dict[str, Any]: + """Async helper used by the overridden ``list_tree`` tool wrapper.""" + return await backend.alist_tree_listing( + path, + max_depth=max_depth, + page_size=page_size, + include_files=include_files, + include_dirs=include_dirs, + ) + + +__all__ = [ + "KBPostgresBackend", + "list_tree_listing", + "paginate_listing", +] diff --git a/surfsense_backend/app/agents/new_chat/middleware/knowledge_search.py b/surfsense_backend/app/agents/new_chat/middleware/knowledge_search.py index 51378a013..0820e8c3e 100644 --- a/surfsense_backend/app/agents/new_chat/middleware/knowledge_search.py +++ b/surfsense_backend/app/agents/new_chat/middleware/knowledge_search.py @@ -1,10 +1,24 @@ -"""Knowledge-base pre-search middleware for the SurfSense new chat agent. +"""Hybrid-search priority middleware for the SurfSense new chat agent. -This middleware runs before the main agent loop and seeds a virtual filesystem -(`files` state) with relevant documents retrieved via hybrid search. On each -turn the filesystem is *expanded* — new results merge with documents loaded -during prior turns — and a synthetic ``ls`` result is injected into the message -history so the LLM is immediately aware of the current filesystem structure. +This middleware runs ``before_agent`` on every turn and writes: + +* ``state["kb_priority"]`` — the top-K most relevant documents for the + current user message, used to render a ```` system + message immediately before the user turn. +* ``state["kb_matched_chunk_ids"]`` — internal hand-off mapping + (``Document.id`` → matched chunk IDs) consumed by + :class:`KBPostgresBackend._load_file_data` when the agent first reads each + document, so the XML wrapper can flag matched sections in + ````. + +The previous "scoped filesystem" behaviour (synthetic ``ls`` + state +``files`` seeding) is intentionally removed: documents are now lazy-loaded +from Postgres on demand, with the full workspace tree rendered separately +by :class:`KnowledgeTreeMiddleware`. + +In anonymous mode the middleware skips hybrid search entirely and emits a +single-entry priority list pointing at the Redis-loaded document +(``state["kb_anon_doc"]``). """ from __future__ import annotations @@ -13,27 +27,33 @@ import asyncio import json import logging import re -import uuid from collections.abc import Sequence from datetime import UTC, datetime from typing import Any +from langchain.agents import create_agent from langchain.agents.middleware import AgentMiddleware, AgentState from langchain_core.language_models import BaseChatModel -from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, ToolMessage +from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage +from langchain_core.runnables import Runnable from langgraph.runtime import Runtime from litellm import token_counter from pydantic import BaseModel, Field, ValidationError from sqlalchemy import select -from sqlalchemy.ext.asyncio import AsyncSession -from app.agents.new_chat.utils import parse_date_or_datetime, resolve_date_range +from app.agents.new_chat.feature_flags import get_flags from app.agents.new_chat.filesystem_selection import FilesystemMode +from app.agents.new_chat.filesystem_state import SurfSenseFilesystemState +from app.agents.new_chat.path_resolver import ( + PathIndex, + build_path_index, + doc_to_virtual_path, +) +from app.agents.new_chat.utils import parse_date_or_datetime, resolve_date_range from app.db import ( NATIVE_TO_LEGACY_DOCTYPE, Chunk, Document, - Folder, shielded_async_session, ) from app.retriever.chunks_hybrid_search import ChucksHybridSearchRetriever @@ -70,7 +90,6 @@ class KBSearchPlan(BaseModel): def _extract_text_from_message(message: BaseMessage) -> str: - """Extract plain text from a message content.""" content = getattr(message, "content", "") if isinstance(content, str): return content @@ -85,19 +104,6 @@ def _extract_text_from_message(message: BaseMessage) -> str: return str(content) -def _safe_filename(value: str, *, fallback: str = "untitled.xml") -> str: - """Convert arbitrary text into a filesystem-safe filename.""" - name = re.sub(r"[\\/:*?\"<>|]+", "_", value).strip() - name = re.sub(r"\s+", " ", name) - if not name: - name = fallback - if len(name) > 180: - name = name[:180].rstrip() - if not name.lower().endswith(".xml"): - name = f"{name}.xml" - return name - - def _render_recent_conversation( messages: Sequence[BaseMessage], *, @@ -107,10 +113,9 @@ def _render_recent_conversation( ) -> str: """Render recent dialogue for internal planning under a token budget. - Prefers the latest messages and uses the project's existing model-aware - token budgeting hooks when available on the LLM (`_count_tokens`, - `_get_max_input_tokens`). Falls back to the prior fixed-message heuristic - if token counting is unavailable. + Filters to ``HumanMessage`` and ``AIMessage`` (without tool_calls) so that + injected ``SystemMessage`` artefacts (priority list, workspace tree, + file-write contract) don't pollute the planner prompt. """ rendered: list[tuple[str, str]] = [] for message in messages: @@ -133,8 +138,6 @@ def _render_recent_conversation( if not rendered: return "" - # Exclude the latest user message from "recent conversation" because it is - # already passed separately as "Latest user message" in the planner prompt. if rendered and rendered[-1][0] == "user" and rendered[-1][1] == user_text.strip(): rendered = rendered[:-1] @@ -216,8 +219,6 @@ def _render_recent_conversation( selected_lines = candidate_lines continue - # If the full message does not fit, keep as much of this most-recent - # older message as possible via binary search. lo, hi = 1, len(text) best_line: str | None = None while lo <= hi: @@ -249,7 +250,6 @@ def _build_kb_planner_prompt( recent_conversation: str, user_text: str, ) -> str: - """Build a compact internal prompt for KB query rewriting and date scoping.""" today = datetime.now(UTC).date().isoformat() return ( "You optimize internal knowledge-base search inputs for document retrieval.\n" @@ -275,12 +275,10 @@ def _build_kb_planner_prompt( def _extract_json_payload(text: str) -> str: - """Extract a JSON object from a raw LLM response.""" stripped = text.strip() fenced = re.search(r"```(?:json)?\s*(\{.*?\})\s*```", stripped, re.DOTALL) if fenced: return fenced.group(1) - start = stripped.find("{") end = stripped.rfind("}") if start != -1 and end != -1 and end > start: @@ -289,7 +287,6 @@ def _extract_json_payload(text: str) -> str: def _parse_kb_search_plan_response(response_text: str) -> KBSearchPlan: - """Parse and validate the planner's JSON response.""" payload = json.loads(_extract_json_payload(response_text)) return KBSearchPlan.model_validate(payload) @@ -298,212 +295,19 @@ def _normalize_optional_date_range( start_date: str | None, end_date: str | None, ) -> tuple[datetime | None, datetime | None]: - """Normalize optional planner dates into a UTC datetime range.""" parsed_start = parse_date_or_datetime(start_date) if start_date else None parsed_end = parse_date_or_datetime(end_date) if end_date else None if parsed_start is None and parsed_end is None: return None, None - resolved_start, resolved_end = resolve_date_range(parsed_start, parsed_end) - return resolved_start, resolved_end - - -def _build_document_xml( - document: dict[str, Any], - matched_chunk_ids: set[int] | None = None, -) -> str: - """Build citation-friendly XML with a ```` for smart seeking. - - The ```` at the top of each document lists every chunk with its - line range inside ```` and flags chunks that directly - matched the search query (``matched="true"``). This lets the LLM jump - straight to the most relevant section via ``read_file(offset=…, limit=…)`` - instead of reading sequentially from the start. - """ - matched = matched_chunk_ids or set() - - doc_meta = document.get("document") or {} - metadata = (doc_meta.get("metadata") or {}) if isinstance(doc_meta, dict) else {} - document_id = doc_meta.get("id", document.get("document_id", "unknown")) - document_type = doc_meta.get("document_type", document.get("source", "UNKNOWN")) - title = doc_meta.get("title") or metadata.get("title") or "Untitled Document" - url = ( - metadata.get("url") or metadata.get("source") or metadata.get("page_url") or "" - ) - metadata_json = json.dumps(metadata, ensure_ascii=False) - - # --- 1. Metadata header (fixed structure) --- - metadata_lines: list[str] = [ - "", - "", - f" {document_id}", - f" {document_type}", - f" <![CDATA[{title}]]>", - f" ", - f" ", - "", - "", - ] - - # --- 2. Pre-build chunk XML strings to compute line counts --- - chunks = document.get("chunks") or [] - chunk_entries: list[tuple[int | None, str]] = [] # (chunk_id, xml_string) - if isinstance(chunks, list): - for chunk in chunks: - if not isinstance(chunk, dict): - continue - chunk_id = chunk.get("chunk_id") or chunk.get("id") - chunk_content = str(chunk.get("content", "")).strip() - if not chunk_content: - continue - if chunk_id is None: - xml = f" " - else: - xml = f" " - chunk_entries.append((chunk_id, xml)) - - # --- 3. Compute line numbers for every chunk --- - # Layout (1-indexed lines for read_file): - # metadata_lines -> len(metadata_lines) lines - # -> 1 line - # index entries -> len(chunk_entries) lines - # -> 1 line - # (empty line) -> 1 line - # -> 1 line - # chunk xml lines… - # -> 1 line - # -> 1 line - index_overhead = ( - 1 + len(chunk_entries) + 1 + 1 + 1 - ) # tags + empty + - first_chunk_line = len(metadata_lines) + index_overhead + 1 # 1-indexed - - current_line = first_chunk_line - index_entry_lines: list[str] = [] - for cid, xml_str in chunk_entries: - num_lines = xml_str.count("\n") + 1 - end_line = current_line + num_lines - 1 - matched_attr = ' matched="true"' if cid is not None and cid in matched else "" - if cid is not None: - index_entry_lines.append( - f' ' - ) - else: - index_entry_lines.append( - f' ' - ) - current_line = end_line + 1 - - # --- 4. Assemble final XML --- - lines = metadata_lines.copy() - lines.append("") - lines.extend(index_entry_lines) - lines.append("") - lines.append("") - lines.append("") - for _, xml_str in chunk_entries: - lines.append(xml_str) - lines.extend(["", ""]) - return "\n".join(lines) - - -async def _get_folder_paths( - session: AsyncSession, search_space_id: int -) -> dict[int, str]: - """Return a map of folder_id -> virtual folder path under /documents.""" - result = await session.execute( - select(Folder.id, Folder.name, Folder.parent_id).where( - Folder.search_space_id == search_space_id - ) - ) - rows = result.all() - by_id = {row.id: {"name": row.name, "parent_id": row.parent_id} for row in rows} - - cache: dict[int, str] = {} - - def resolve_path(folder_id: int) -> str: - if folder_id in cache: - return cache[folder_id] - parts: list[str] = [] - cursor: int | None = folder_id - visited: set[int] = set() - while cursor is not None and cursor in by_id and cursor not in visited: - visited.add(cursor) - entry = by_id[cursor] - parts.append( - _safe_filename(str(entry["name"]), fallback="folder").removesuffix( - ".xml" - ) - ) - cursor = entry["parent_id"] - parts.reverse() - path = "/documents/" + "/".join(parts) if parts else "/documents" - cache[folder_id] = path - return path - - for folder_id in by_id: - resolve_path(folder_id) - return cache - - -def _build_synthetic_ls( - existing_files: dict[str, Any] | None, - new_files: dict[str, Any], - *, - mentioned_paths: set[str] | None = None, -) -> tuple[AIMessage, ToolMessage]: - """Build a synthetic ls("/documents") tool-call + result for the LLM context. - - Mentioned files are listed first. A separate header tells the LLM which - files the user explicitly selected; the path list itself stays clean so - paths can be passed directly to ``read_file`` without stripping tags. - """ - _mentioned = mentioned_paths or set() - merged: dict[str, Any] = {**(existing_files or {}), **new_files} - doc_paths = [ - p for p, v in merged.items() if p.startswith("/documents/") and v is not None - ] - - new_set = set(new_files) - mentioned_list = [p for p in doc_paths if p in _mentioned] - new_non_mentioned = [p for p in doc_paths if p in new_set and p not in _mentioned] - old_paths = [p for p in doc_paths if p not in new_set] - ordered = mentioned_list + new_non_mentioned + old_paths - - parts: list[str] = [] - if mentioned_list: - parts.append( - "USER-MENTIONED documents (read these thoroughly before answering):" - ) - for p in mentioned_list: - parts.append(f" {p}") - parts.append("") - parts.append(str(ordered) if ordered else "No documents found.") - - tool_call_id = f"auto_ls_{uuid.uuid4().hex[:12]}" - ai_msg = AIMessage( - content="", - tool_calls=[{"name": "ls", "args": {"path": "/documents"}, "id": tool_call_id}], - ) - tool_msg = ToolMessage( - content="\n".join(parts), - tool_call_id=tool_call_id, - ) - return ai_msg, tool_msg + return resolve_date_range(parsed_start, parsed_end) def _resolve_search_types( available_connectors: list[str] | None, available_document_types: list[str] | None, ) -> list[str] | None: - """Build a flat list of document-type strings for the chunk retriever. - - Includes legacy equivalents from ``NATIVE_TO_LEGACY_DOCTYPE`` so that - old documents indexed under Composio names are still found. - - Returns ``None`` when no filtering is desired (search all types). - """ types: set[str] = set() if available_document_types: types.update(available_document_types) @@ -531,13 +335,8 @@ async def browse_recent_documents( start_date: datetime | None = None, end_date: datetime | None = None, ) -> list[dict[str, Any]]: - """Return documents ordered by recency (newest first), no relevance ranking. - - Used when the user's intent is temporal ("latest file", "most recent upload") - and hybrid search would produce poor results because the query has no - meaningful topical signal. - """ - from sqlalchemy import func, select + """Return documents ordered by recency (newest first), no relevance ranking.""" + from sqlalchemy import func from app.db import DocumentType @@ -581,7 +380,6 @@ async def browse_recent_documents( return [] doc_ids = [d.id for d in documents] - numbered = ( select( Chunk.id.label("chunk_id"), @@ -632,6 +430,7 @@ async def browse_recent_documents( else None ), "metadata": metadata, + "folder_id": getattr(doc, "folder_id", None), }, "source": ( doc.document_type.value @@ -640,12 +439,6 @@ async def browse_recent_documents( ), } ) - - logger.info( - "browse_recent_documents: %d docs returned for space=%d", - len(results), - search_space_id, - ) return results @@ -659,17 +452,11 @@ async def search_knowledge_base( start_date: datetime | None = None, end_date: datetime | None = None, ) -> list[dict[str, Any]]: - """Run a single unified hybrid search against the knowledge base. - - Uses one ``ChucksHybridSearchRetriever`` call across all document types - instead of fanning out per-connector. This reduces the number of DB - queries from ~10 to 2 (one RRF query + one chunk fetch). - """ + """Run a single unified hybrid search against the knowledge base.""" if not query: return [] [embedding] = embed_texts([query]) - doc_types = _resolve_search_types(available_connectors, available_document_types) retriever_top_k = min(top_k * 3, 30) @@ -693,14 +480,7 @@ async def fetch_mentioned_documents( document_ids: list[int], search_space_id: int, ) -> list[dict[str, Any]]: - """Fetch explicitly mentioned documents with *all* their chunks. - - Returns the same dict structure as ``search_knowledge_base`` so results - can be merged directly into ``build_scoped_filesystem``. Unlike search - results, every chunk is included (no top-K limiting) and none are marked - as ``matched`` since the entire document is relevant by virtue of the - user's explicit mention. - """ + """Fetch explicitly mentioned documents.""" if not document_ids: return [] @@ -750,6 +530,7 @@ async def fetch_mentioned_documents( else None ), "metadata": metadata, + "folder_id": getattr(doc, "folder_id", None), }, "source": ( doc.document_type.value @@ -762,96 +543,36 @@ async def fetch_mentioned_documents( return results -async def build_scoped_filesystem( - *, - documents: Sequence[dict[str, Any]], - search_space_id: int, -) -> tuple[dict[str, dict[str, str]], dict[int, str]]: - """Build a StateBackend-compatible files dict from search results. - - Returns ``(files, doc_id_to_path)`` so callers can reliably map a - document id back to its filesystem path without guessing by title. - Paths are collision-proof: when two documents resolve to the same - path the doc-id is appended to disambiguate. - """ - async with shielded_async_session() as session: - folder_paths = await _get_folder_paths(session, search_space_id) - doc_ids = [ - (doc.get("document") or {}).get("id") - for doc in documents - if isinstance(doc, dict) - ] - doc_ids = [doc_id for doc_id in doc_ids if isinstance(doc_id, int)] - folder_by_doc_id: dict[int, int | None] = {} - if doc_ids: - doc_rows = await session.execute( - select(Document.id, Document.folder_id).where( - Document.search_space_id == search_space_id, - Document.id.in_(doc_ids), - ) - ) - folder_by_doc_id = { - row.id: row.folder_id for row in doc_rows.all() if row.id is not None - } - - files: dict[str, dict[str, str]] = {} - doc_id_to_path: dict[int, str] = {} - for document in documents: - doc_meta = document.get("document") or {} - title = str(doc_meta.get("title") or "untitled") - doc_id = doc_meta.get("id") - folder_id = folder_by_doc_id.get(doc_id) if isinstance(doc_id, int) else None - base_folder = folder_paths.get(folder_id, "/documents") - file_name = _safe_filename(title) - path = f"{base_folder}/{file_name}" - if path in files: - stem = file_name.removesuffix(".xml") - path = f"{base_folder}/{stem} ({doc_id}).xml" - matched_ids = set(document.get("matched_chunk_ids") or []) - xml_content = _build_document_xml(document, matched_chunk_ids=matched_ids) - files[path] = { - "content": xml_content.split("\n"), - "encoding": "utf-8", - "created_at": "", - "modified_at": "", - } - if isinstance(doc_id, int): - doc_id_to_path[doc_id] = path - return files, doc_id_to_path +def _render_priority_message(priority: list[dict[str, Any]]) -> SystemMessage: + """Render the priority list as a single ```` system message.""" + if not priority: + body = "(no priority documents for this turn)" + else: + lines: list[str] = [] + for entry in priority: + score = entry.get("score") + mentioned = entry.get("mentioned") + score_str = f"{score:.3f}" if isinstance(score, int | float) else "n/a" + mark = " [USER-MENTIONED]" if mentioned else "" + lines.append(f"- {entry.get('path', '')} (score={score_str}){mark}") + body = "\n".join(lines) + return SystemMessage( + content=( + "\n" + "These documents are most relevant to the latest user message; " + "read them first. Matched sections are flagged inside each " + "document's .\n" + f"{body}\n" + "" + ) + ) -def _build_anon_scoped_filesystem( - documents: Sequence[dict[str, Any]], -) -> dict[str, dict[str, str]]: - """Build a scoped filesystem for anonymous documents without DB queries. - - Anonymous uploads have no folders, so all files go under /documents. - """ - files: dict[str, dict[str, str]] = {} - for document in documents: - doc_meta = document.get("document") or {} - title = str(doc_meta.get("title") or "untitled") - file_name = _safe_filename(title) - path = f"/documents/{file_name}" - if path in files: - doc_id = doc_meta.get("id", "dup") - stem = file_name.removesuffix(".xml") - path = f"/documents/{stem} ({doc_id}).xml" - matched_ids = set(document.get("matched_chunk_ids") or []) - xml_content = _build_document_xml(document, matched_chunk_ids=matched_ids) - files[path] = { - "content": xml_content.split("\n"), - "encoding": "utf-8", - "created_at": "", - "modified_at": "", - } - return files - - -class KnowledgeBaseSearchMiddleware(AgentMiddleware): # type: ignore[type-arg] - """Pre-agent middleware that always searches the KB and seeds a scoped filesystem.""" +class KnowledgePriorityMiddleware(AgentMiddleware): # type: ignore[type-arg] + """Compute hybrid-search priority hints for the current turn.""" tools = () + state_schema = SurfSenseFilesystemState def __init__( self, @@ -863,7 +584,6 @@ class KnowledgeBaseSearchMiddleware(AgentMiddleware): # type: ignore[type-arg] available_document_types: list[str] | None = None, top_k: int = 10, mentioned_document_ids: list[int] | None = None, - anon_session_id: str | None = None, ) -> None: self.llm = llm self.search_space_id = search_space_id @@ -872,7 +592,51 @@ class KnowledgeBaseSearchMiddleware(AgentMiddleware): # type: ignore[type-arg] self.available_document_types = available_document_types self.top_k = top_k self.mentioned_document_ids = mentioned_document_ids or [] - self.anon_session_id = anon_session_id + # Build the kb-planner private Runnable ONCE here so we don't pay + # the ``create_agent`` compile cost (50-200ms) on every turn. + # Disabled by default behind ``enable_kb_planner_runnable``; when + # off the planner falls back to the legacy ``self.llm.ainvoke`` + # path. + self._planner: Runnable | None = None + self._planner_compile_failed = False + + def _build_kb_planner_runnable(self) -> Runnable | None: + """Compile the kb-planner private :class:`Runnable` once. + + Returns ``None`` when the feature flag is disabled, when the LLM is + unavailable, or when ``create_agent`` raises (we fall back to the + legacy ``self.llm.ainvoke`` path in that case). Compilation happens + lazily on first call, then memoized via ``self._planner``. + + The compiled agent is constructed without tools — the planner's + contract is "answer with structured JSON" — but it inherits the + :class:`RetryAfterMiddleware` so transient rate-limit errors + from the planner LLM call don't fail the whole turn. + """ + if self._planner is not None or self._planner_compile_failed: + return self._planner + if self.llm is None: + return None + flags = get_flags() + if not flags.enable_kb_planner_runnable or flags.disable_new_agent_stack: + return None + + from app.agents.new_chat.middleware.retry_after import RetryAfterMiddleware + + try: + self._planner = create_agent( + self.llm, + tools=[], + middleware=[RetryAfterMiddleware(max_retries=2)], + ) + except Exception as exc: # pragma: no cover - defensive + logger.warning( + "kb-planner Runnable compile failed; falling back to llm.ainvoke: %s", + exc, + ) + self._planner_compile_failed = True + self._planner = None + return self._planner async def _plan_search_inputs( self, @@ -880,10 +644,6 @@ class KnowledgeBaseSearchMiddleware(AgentMiddleware): # type: ignore[type-arg] messages: Sequence[BaseMessage], user_text: str, ) -> tuple[str, datetime | None, datetime | None, bool]: - """Rewrite the KB query and infer optional date filters with the LLM. - - Returns (optimized_query, start_date, end_date, is_recency_query). - """ if self.llm is None: return user_text, None, None, False @@ -899,11 +659,32 @@ class KnowledgeBaseSearchMiddleware(AgentMiddleware): # type: ignore[type-arg] loop = asyncio.get_running_loop() t0 = loop.time() + # Prefer the compiled-once planner Runnable when enabled; otherwise + # fall back to ``self.llm.ainvoke``. The ``surfsense:internal`` tag + # is preserved on both paths so ``_stream_agent_events`` still + # suppresses the planner's intermediate events from the UI. + planner = self._build_kb_planner_runnable() try: - response = await self.llm.ainvoke( - [HumanMessage(content=prompt)], - config={"tags": ["surfsense:internal"]}, - ) + if planner is not None: + planner_state = await planner.ainvoke( + {"messages": [HumanMessage(content=prompt)]}, + config={"tags": ["surfsense:internal"]}, + ) + response_messages = ( + planner_state.get("messages", []) + if isinstance(planner_state, dict) + else [] + ) + response = ( + response_messages[-1] + if response_messages + else AIMessage(content="") + ) + else: + response = await self.llm.ainvoke( + [HumanMessage(content=prompt)], + config={"tags": ["surfsense:internal"]}, + ) plan = _parse_kb_search_plan_response(_extract_text_from_message(response)) optimized_query = ( re.sub(r"\s+", " ", plan.optimized_query).strip() or user_text @@ -914,7 +695,7 @@ class KnowledgeBaseSearchMiddleware(AgentMiddleware): # type: ignore[type-arg] ) is_recency = plan.is_recency_query _perf_log.info( - "[kb_fs_middleware] planner in %.3fs query=%r optimized=%r " + "[kb_priority] planner in %.3fs query=%r optimized=%r " "start=%s end=%s recency=%s", loop.time() - t0, user_text[:80], @@ -946,106 +727,68 @@ class KnowledgeBaseSearchMiddleware(AgentMiddleware): # type: ignore[type-arg] pass return asyncio.run(self.abefore_agent(state, runtime)) - async def _load_anon_document(self) -> dict[str, Any] | None: - """Load the anonymous user's uploaded document from Redis.""" - if not self.anon_session_id: - return None - try: - import redis.asyncio as aioredis - - from app.config import config - - redis_client = aioredis.from_url( - config.REDIS_APP_URL, decode_responses=True - ) - try: - redis_key = f"anon:doc:{self.anon_session_id}" - data = await redis_client.get(redis_key) - if not data: - return None - doc = json.loads(data) - return { - "document_id": -1, - "content": doc.get("content", ""), - "score": 1.0, - "chunks": [ - { - "chunk_id": -1, - "content": doc.get("content", ""), - } - ], - "matched_chunk_ids": [-1], - "document": { - "id": -1, - "title": doc.get("filename", "uploaded_document"), - "document_type": "FILE", - "metadata": {"source": "anonymous_upload"}, - }, - "source": "FILE", - "_user_mentioned": True, - } - finally: - await redis_client.aclose() - except Exception as exc: - logger.warning("Failed to load anonymous document from Redis: %s", exc) - return None - async def abefore_agent( # type: ignore[override] self, state: AgentState, runtime: Runtime[Any], ) -> dict[str, Any] | None: del runtime + if self.filesystem_mode != FilesystemMode.CLOUD: + return None + messages = state.get("messages") or [] if not messages: return None - if self.filesystem_mode != FilesystemMode.CLOUD: - # Local-folder mode should not seed cloud KB documents into filesystem. - return None - last_human = None + last_human: HumanMessage | None = None for msg in reversed(messages): if isinstance(msg, HumanMessage): last_human = msg break if last_human is None: return None - user_text = _extract_text_from_message(last_human).strip() if not user_text: return None - t0 = _perf_log and asyncio.get_event_loop().time() - existing_files = state.get("files") + anon_doc = state.get("kb_anon_doc") + if anon_doc: + return self._anon_priority(state, anon_doc) - # --- Anonymous session: load Redis doc and skip DB queries --- - if self.anon_session_id: - merged: list[dict[str, Any]] = [] - anon_doc = await self._load_anon_document() - if anon_doc: - merged.append(anon_doc) + return await self._authenticated_priority(state, messages, user_text) - if merged: - new_files = _build_anon_scoped_filesystem(merged) - mentioned_paths = set(new_files.keys()) - else: - new_files = {} - mentioned_paths = set() + def _anon_priority( + self, + state: AgentState, + anon_doc: dict[str, Any], + ) -> dict[str, Any]: + path = str(anon_doc.get("path") or "") + title = str(anon_doc.get("title") or "uploaded_document") + priority = [ + { + "path": path, + "score": 1.0, + "document_id": None, + "title": title, + "mentioned": True, + } + ] + new_messages = list(state.get("messages") or []) + insert_at = max(len(new_messages) - 1, 0) + new_messages.insert(insert_at, _render_priority_message(priority)) + return { + "kb_priority": priority, + "kb_matched_chunk_ids": {}, + "messages": new_messages, + } - ai_msg, tool_msg = _build_synthetic_ls( - existing_files, - new_files, - mentioned_paths=mentioned_paths, - ) - if t0 is not None: - _perf_log.info( - "[kb_fs_middleware] anon completed in %.3fs new_files=%d", - asyncio.get_event_loop().time() - t0, - len(new_files), - ) - return {"files": new_files, "messages": [ai_msg, tool_msg]} - - # --- Authenticated session: full KB search --- + async def _authenticated_priority( + self, + state: AgentState, + messages: Sequence[BaseMessage], + user_text: str, + ) -> dict[str, Any]: + t0 = asyncio.get_event_loop().time() ( planned_query, start_date, @@ -1056,7 +799,6 @@ class KnowledgeBaseSearchMiddleware(AgentMiddleware): # type: ignore[type-arg] user_text=user_text, ) - # --- 1. Fetch mentioned documents (user-selected, all chunks) --- mentioned_results: list[dict[str, Any]] = [] if self.mentioned_document_ids: mentioned_results = await fetch_mentioned_documents( @@ -1065,7 +807,6 @@ class KnowledgeBaseSearchMiddleware(AgentMiddleware): # type: ignore[type-arg] ) self.mentioned_document_ids = [] - # --- 2. Run KB search (recency browse or hybrid) --- if is_recency: doc_types = _resolve_search_types( self.available_connectors, self.available_document_types @@ -1088,48 +829,108 @@ class KnowledgeBaseSearchMiddleware(AgentMiddleware): # type: ignore[type-arg] end_date=end_date, ) - # --- 3. Merge: mentioned first, then search (dedup by doc id) --- seen_doc_ids: set[int] = set() - merged_auth: list[dict[str, Any]] = [] + merged: list[dict[str, Any]] = [] for doc in mentioned_results: doc_id = (doc.get("document") or {}).get("id") - if doc_id is not None: + if isinstance(doc_id, int): seen_doc_ids.add(doc_id) - merged_auth.append(doc) + merged.append(doc) for doc in search_results: doc_id = (doc.get("document") or {}).get("id") - if doc_id is not None and doc_id in seen_doc_ids: + if isinstance(doc_id, int) and doc_id in seen_doc_ids: continue - merged_auth.append(doc) + merged.append(doc) - # --- 4. Build scoped filesystem --- - new_files, doc_id_to_path = await build_scoped_filesystem( - documents=merged_auth, - search_space_id=self.search_space_id, + priority, matched_chunk_ids = await self._materialize_priority(merged) + + new_messages = list(messages) + insert_at = max(len(new_messages) - 1, 0) + new_messages.insert(insert_at, _render_priority_message(priority)) + + _perf_log.info( + "[kb_priority] completed in %.3fs query=%r priority=%d mentioned=%d", + asyncio.get_event_loop().time() - t0, + user_text[:80], + len(priority), + len(mentioned_results), ) - mentioned_doc_ids = { - (d.get("document") or {}).get("id") for d in mentioned_results - } - mentioned_paths = { - doc_id_to_path[did] for did in mentioned_doc_ids if did in doc_id_to_path + return { + "kb_priority": priority, + "kb_matched_chunk_ids": matched_chunk_ids, + "messages": new_messages, } - ai_msg, tool_msg = _build_synthetic_ls( - existing_files, - new_files, - mentioned_paths=mentioned_paths, - ) + async def _materialize_priority( + self, merged: list[dict[str, Any]] + ) -> tuple[list[dict[str, Any]], dict[int, list[int]]]: + """Resolve canonical paths and matched chunk ids for the priority list.""" + priority: list[dict[str, Any]] = [] + matched_chunk_ids: dict[int, list[int]] = {} - if t0 is not None: - _perf_log.info( - "[kb_fs_middleware] completed in %.3fs query=%r optimized=%r " - "mentioned=%d new_files=%d total=%d", - asyncio.get_event_loop().time() - t0, - user_text[:80], - planned_query[:120], - len(mentioned_results), - len(new_files), - len(new_files) + len(existing_files or {}), + if not merged: + return priority, matched_chunk_ids + + async with shielded_async_session() as session: + index: PathIndex = await build_path_index(session, self.search_space_id) + doc_ids = [ + (doc.get("document") or {}).get("id") + for doc in merged + if isinstance(doc, dict) + ] + doc_ids = [doc_id for doc_id in doc_ids if isinstance(doc_id, int)] + folder_by_doc_id: dict[int, int | None] = {} + if doc_ids: + folder_rows = await session.execute( + select(Document.id, Document.folder_id).where( + Document.search_space_id == self.search_space_id, + Document.id.in_(doc_ids), + ) + ) + folder_by_doc_id = {row.id: row.folder_id for row in folder_rows.all()} + + for doc in merged: + doc_meta = doc.get("document") or {} + doc_id = doc_meta.get("id") + title = doc_meta.get("title") or "untitled" + folder_id = ( + folder_by_doc_id.get(doc_id) + if isinstance(doc_id, int) + else doc_meta.get("folder_id") ) - return {"files": new_files, "messages": [ai_msg, tool_msg]} + path = doc_to_virtual_path( + doc_id=doc_id if isinstance(doc_id, int) else None, + title=str(title), + folder_id=folder_id if isinstance(folder_id, int) else None, + index=index, + ) + priority.append( + { + "path": path, + "score": float(doc.get("score") or 0.0), + "document_id": doc_id if isinstance(doc_id, int) else None, + "title": str(title), + "mentioned": bool(doc.get("_user_mentioned")), + } + ) + if isinstance(doc_id, int): + chunk_ids = doc.get("matched_chunk_ids") or [] + if chunk_ids: + matched_chunk_ids[doc_id] = [ + int(cid) for cid in chunk_ids if isinstance(cid, int | str) + ] + return priority, matched_chunk_ids + + +# Backwards-compatible alias for any external imports. +KnowledgeBaseSearchMiddleware = KnowledgePriorityMiddleware + + +__all__ = [ + "KnowledgeBaseSearchMiddleware", + "KnowledgePriorityMiddleware", + "browse_recent_documents", + "fetch_mentioned_documents", + "search_knowledge_base", +] diff --git a/surfsense_backend/app/agents/new_chat/middleware/knowledge_tree.py b/surfsense_backend/app/agents/new_chat/middleware/knowledge_tree.py new file mode 100644 index 000000000..467d19747 --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/middleware/knowledge_tree.py @@ -0,0 +1,272 @@ +"""Workspace-tree middleware for the SurfSense agent. + +Renders the full ``Folder``+``Document`` tree under ``/documents/`` once per +turn (cloud only), caches it by ``(search_space_id, tree_version)``, and +injects the result as a ```` system message immediately +before the latest human turn. + +The render is bounded by two truncation layers: + +1. **Entry cap** — at most ``MAX_TREE_ENTRIES`` lines. The remainder is + replaced with a "use ls" hint. +2. **Token cap** — at most ``MAX_TREE_TOKENS`` tokens (using the LLM's + token-count profile when available). If the entry-truncated tree still + exceeds the token cap we fall back to a root-only summary. + +Anonymous mode renders only ``state['kb_anon_doc']`` (no DB calls). + +This middleware also performs a one-time initialization of ``state['cwd']`` +to ``"/documents"`` so subsequent middlewares and tools always see a valid +cwd in cloud mode. +""" + +from __future__ import annotations + +import asyncio +import logging +from typing import Any + +from langchain.agents.middleware import AgentMiddleware, AgentState +from langchain_core.language_models import BaseChatModel +from langchain_core.messages import SystemMessage +from langgraph.runtime import Runtime +from sqlalchemy import select + +from app.agents.new_chat.filesystem_selection import FilesystemMode +from app.agents.new_chat.filesystem_state import SurfSenseFilesystemState +from app.agents.new_chat.path_resolver import ( + DOCUMENTS_ROOT, + PathIndex, + build_path_index, + doc_to_virtual_path, +) +from app.db import Document, shielded_async_session + +try: + from litellm import token_counter +except Exception: # pragma: no cover - optional dep + token_counter = None # type: ignore[assignment] + +logger = logging.getLogger(__name__) + + +MAX_TREE_ENTRIES = 500 +MAX_TREE_TOKENS = 4000 + + +def _approx_tokens(text: str) -> int: + """Cheap fallback token estimate (1 token ~= 4 chars).""" + return max(1, (len(text) + 3) // 4) + + +def _count_tokens(text: str, *, llm: BaseChatModel | None) -> int: + if llm is None: + return _approx_tokens(text) + count_fn = getattr(llm, "_count_tokens", None) + if callable(count_fn): + try: + return int(count_fn([{"role": "user", "content": text}])) + except Exception: + pass + profile = getattr(llm, "profile", None) + model_names: list[str] = [] + if isinstance(profile, dict): + tcms = profile.get("token_count_models") + if isinstance(tcms, list): + model_names.extend(name for name in tcms if isinstance(name, str) and name) + tcm = profile.get("token_count_model") + if isinstance(tcm, str) and tcm and tcm not in model_names: + model_names.append(tcm) + model_name = model_names[0] if model_names else getattr(llm, "model", None) + if not isinstance(model_name, str) or not model_name or token_counter is None: + return _approx_tokens(text) + try: + return int( + token_counter( + messages=[{"role": "user", "content": text}], + model=model_name, + ) + ) + except Exception: + return _approx_tokens(text) + + +class KnowledgeTreeMiddleware(AgentMiddleware): # type: ignore[type-arg] + """Inject the workspace folder/document tree into the agent's context.""" + + tools = () + state_schema = SurfSenseFilesystemState + + def __init__( + self, + *, + search_space_id: int, + filesystem_mode: FilesystemMode, + llm: BaseChatModel | None = None, + max_entries: int = MAX_TREE_ENTRIES, + max_tokens: int = MAX_TREE_TOKENS, + ) -> None: + self.search_space_id = search_space_id + self.filesystem_mode = filesystem_mode + self.llm = llm + self.max_entries = max_entries + self.max_tokens = max_tokens + self._cache: dict[tuple[int, int, bool], str] = {} + + async def abefore_agent( # type: ignore[override] + self, + state: AgentState, + runtime: Runtime[Any], + ) -> dict[str, Any] | None: + del runtime + if self.filesystem_mode != FilesystemMode.CLOUD: + return None + + update: dict[str, Any] = {} + if not state.get("cwd"): + update["cwd"] = DOCUMENTS_ROOT + + anon_doc = state.get("kb_anon_doc") + if anon_doc: + tree_msg = self._render_anon_tree(anon_doc) + else: + tree_msg = await self._render_kb_tree(state) + + messages = list(state.get("messages") or []) + insert_at = max(len(messages) - 1, 0) + messages.insert(insert_at, SystemMessage(content=tree_msg)) + update["messages"] = messages + return update + + def before_agent( # type: ignore[override] + self, + state: AgentState, + runtime: Runtime[Any], + ) -> dict[str, Any] | None: + try: + loop = asyncio.get_running_loop() + if loop.is_running(): + return None + except RuntimeError: + pass + return asyncio.run(self.abefore_agent(state, runtime)) + + # ------------------------------------------------------------------ render + + def _render_anon_tree(self, anon_doc: dict[str, Any]) -> str: + path = str(anon_doc.get("path") or "") + title = str(anon_doc.get("title") or "uploaded_document") + return ( + "\n" + "Anonymous session — only one read-only document is available.\n" + f"{DOCUMENTS_ROOT}/\n" + f" {path} — {title}\n" + "" + ) + + async def _render_kb_tree(self, state: AgentState) -> str: + version = int(state.get("tree_version") or 0) + cache_key = (self.search_space_id, version, False) + cached = self._cache.get(cache_key) + if cached is not None: + return cached + + try: + async with shielded_async_session() as session: + index = await build_path_index(session, self.search_space_id) + doc_rows = await session.execute( + select(Document.id, Document.title, Document.folder_id).where( + Document.search_space_id == self.search_space_id + ) + ) + docs = list(doc_rows.all()) + except Exception as exc: # pragma: no cover - defensive + logger.warning("knowledge_tree: DB error %s", exc) + return "\n(unavailable)\n" + + rendered = self._format_tree(index, docs) + self._cache[cache_key] = rendered + return rendered + + def _format_tree(self, index: PathIndex, docs: list[Any]) -> str: + folder_paths = sorted(set(index.folder_paths.values())) + doc_paths = sorted( + doc_to_virtual_path( + doc_id=row.id, + title=str(row.title or "untitled"), + folder_id=row.folder_id, + index=index, + ) + for row in docs + ) + all_paths = sorted(set(folder_paths + doc_paths + [DOCUMENTS_ROOT])) + + lines: list[str] = [] + for path in all_paths: + depth = ( + 0 + if path == DOCUMENTS_ROOT + else len([p for p in path[len(DOCUMENTS_ROOT) :].split("/") if p]) + ) + indent = " " * depth + is_dir = path == DOCUMENTS_ROOT or path in folder_paths + display = ( + path.rsplit("/", 1)[-1] if path != DOCUMENTS_ROOT else "/documents" + ) + if is_dir: + lines.append(f"{indent}{display}/") + else: + lines.append(f"{indent}{display}") + if len(lines) >= self.max_entries: + remaining = len(all_paths) - len(lines) + if remaining > 0: + lines.append( + f"... {remaining} more entries — use " + "ls('/documents/', offset, limit) to expand" + ) + break + + body = "\n".join(lines) + rendered = f"\n{body}\n" + + token_count = _count_tokens(rendered, llm=self.llm) + if token_count <= self.max_tokens: + return rendered + + return self._format_root_summary(folder_paths, doc_paths) + + def _format_root_summary( + self, folder_paths: list[str], doc_paths: list[str] + ) -> str: + top_level: dict[str, int] = {} + loose_docs = 0 + for path in doc_paths: + rel = path[len(DOCUMENTS_ROOT) :].lstrip("/") + if "/" in rel: + top = rel.split("/", 1)[0] + top_level[top] = top_level.get(top, 0) + 1 + else: + loose_docs += 1 + for path in folder_paths: + rel = path[len(DOCUMENTS_ROOT) :].lstrip("/") + if not rel: + continue + top = rel.split("/", 1)[0] + top_level.setdefault(top, 0) + + lines = [DOCUMENTS_ROOT + "/"] + for name in sorted(top_level): + count = top_level[name] + lines.append(f" {name}/ ({count} document{'s' if count != 1 else ''})") + if loose_docs: + lines.append( + f" ({loose_docs} loose document{'s' if loose_docs != 1 else ''})" + ) + lines.append( + "Tree is large; use list_tree('/documents/') to drill in " + "or ls('/documents/', offset, limit) for paginated listings." + ) + return "\n" + "\n".join(lines) + "\n" + + +__all__ = ["KnowledgeTreeMiddleware"] diff --git a/surfsense_backend/app/agents/new_chat/middleware/local_folder_backend.py b/surfsense_backend/app/agents/new_chat/middleware/local_folder_backend.py index 0cee3e007..565fcb48b 100644 --- a/surfsense_backend/app/agents/new_chat/middleware/local_folder_backend.py +++ b/surfsense_backend/app/agents/new_chat/middleware/local_folder_backend.py @@ -120,7 +120,9 @@ class LocalFolderBackend: if not target.exists() or not target.is_dir(): return [] infos: list[FileInfo] = [] - for child in sorted(target.iterdir(), key=lambda p: (not p.is_dir(), p.name.lower())): + for child in sorted( + target.iterdir(), key=lambda p: (not p.is_dir(), p.name.lower()) + ): infos.append( FileInfo( path=self._to_virtual(child, self._root), @@ -317,7 +319,9 @@ class LocalFolderBackend: return WriteResult(error="Error: source and destination paths are the same") with self._acquire_path_locks(source_path, destination_path): if not source.exists(): - return WriteResult(error=f"Error: source path '{source_path}' not found") + return WriteResult( + error=f"Error: source path '{source_path}' not found" + ) if destination.exists(): if not overwrite: return WriteResult( @@ -339,8 +343,12 @@ class LocalFolderBackend: else: source.rename(destination) except OSError as exc: - return WriteResult(error=f"Error: failed to move '{source_path}': {exc}") - return WriteResult(path=self._to_virtual(destination, self._root), files_update=None) + return WriteResult( + error=f"Error: failed to move '{source_path}': {exc}" + ) + return WriteResult( + path=self._to_virtual(destination, self._root), files_update=None + ) async def amove( self, @@ -368,12 +376,16 @@ class LocalFolderBackend: if not path.exists() or not path.is_file(): return EditResult(error=f"Error: File '{file_path}' not found") content = path.read_text(encoding="utf-8", errors="replace") - result = perform_string_replacement(content, old_string, new_string, replace_all) + result = perform_string_replacement( + content, old_string, new_string, replace_all + ) if isinstance(result, str): return EditResult(error=result) updated_content, occurrences = result self._write_text_atomic(path, updated_content) - return EditResult(path=file_path, files_update=None, occurrences=int(occurrences)) + return EditResult( + path=file_path, files_update=None, occurrences=int(occurrences) + ) async def aedit( self, @@ -447,7 +459,9 @@ class LocalFolderBackend: matches: list[GrepMatch] = [] for file_path in self._iter_candidate_files(path, glob): try: - lines = file_path.read_text(encoding="utf-8", errors="replace").splitlines() + lines = file_path.read_text( + encoding="utf-8", errors="replace" + ).splitlines() except Exception: continue for idx, line in enumerate(lines, start=1): @@ -481,12 +495,18 @@ class LocalFolderBackend: FileUploadResponse(path=virtual_path, error=_FILE_NOT_FOUND) ) except IsADirectoryError: - responses.append(FileUploadResponse(path=virtual_path, error=_IS_DIRECTORY)) + responses.append( + FileUploadResponse(path=virtual_path, error=_IS_DIRECTORY) + ) except Exception: - responses.append(FileUploadResponse(path=virtual_path, error=_INVALID_PATH)) + responses.append( + FileUploadResponse(path=virtual_path, error=_INVALID_PATH) + ) return responses - async def aupload_files(self, files: list[tuple[str, bytes]]) -> list[FileUploadResponse]: + async def aupload_files( + self, files: list[tuple[str, bytes]] + ) -> list[FileUploadResponse]: return await asyncio.to_thread(self.upload_files, files) def download_files(self, paths: list[str]) -> list[FileDownloadResponse]: @@ -515,7 +535,9 @@ class LocalFolderBackend: ) except Exception: responses.append( - FileDownloadResponse(path=virtual_path, content=None, error=_INVALID_PATH) + FileDownloadResponse( + path=virtual_path, content=None, error=_INVALID_PATH + ) ) return responses diff --git a/surfsense_backend/app/agents/new_chat/middleware/multi_root_local_folder_backend.py b/surfsense_backend/app/agents/new_chat/middleware/multi_root_local_folder_backend.py index 82914f9ce..93eabe6ff 100644 --- a/surfsense_backend/app/agents/new_chat/middleware/multi_root_local_folder_backend.py +++ b/surfsense_backend/app/agents/new_chat/middleware/multi_root_local_folder_backend.py @@ -127,7 +127,9 @@ class MultiRootLocalFolderBackend: mount, local_path = self._split_mount_path(path) except ValueError: return [] - return self._transform_infos(mount, self._mount_to_backend[mount].ls_info(local_path)) + return self._transform_infos( + mount, self._mount_to_backend[mount].ls_info(local_path) + ) async def als_info(self, path: str) -> list[FileInfo]: return await asyncio.to_thread(self.ls_info, path) @@ -355,7 +357,9 @@ class MultiRootLocalFolderBackend: all_matches.extend( [ GrepMatch( - path=self._prefix_mount_path(mount, self._get_str(match, "path")), + path=self._prefix_mount_path( + mount, self._get_str(match, "path") + ), line=self._get_int(match, "line"), text=self._get_str(match, "text"), ) @@ -394,7 +398,9 @@ class MultiRootLocalFolderBackend: try: mount, local_path = self._split_mount_path(virtual_path) except ValueError: - invalid.append(FileUploadResponse(path=virtual_path, error=_INVALID_PATH)) + invalid.append( + FileUploadResponse(path=virtual_path, error=_INVALID_PATH) + ) continue grouped.setdefault(mount, []).append((local_path, content)) @@ -404,7 +410,9 @@ class MultiRootLocalFolderBackend: responses.extend( [ FileUploadResponse( - path=self._prefix_mount_path(mount, self._get_str(item, "path")), + path=self._prefix_mount_path( + mount, self._get_str(item, "path") + ), error=self._get_str(item, "error") or None, ) for item in result @@ -412,7 +420,9 @@ class MultiRootLocalFolderBackend: ) return responses - async def aupload_files(self, files: list[tuple[str, bytes]]) -> list[FileUploadResponse]: + async def aupload_files( + self, files: list[tuple[str, bytes]] + ) -> list[FileUploadResponse]: return await asyncio.to_thread(self.upload_files, files) def download_files(self, paths: list[str]) -> list[FileDownloadResponse]: @@ -423,7 +433,9 @@ class MultiRootLocalFolderBackend: mount, local_path = self._split_mount_path(virtual_path) except ValueError: invalid.append( - FileDownloadResponse(path=virtual_path, content=None, error=_INVALID_PATH) + FileDownloadResponse( + path=virtual_path, content=None, error=_INVALID_PATH + ) ) continue grouped.setdefault(mount, []).append(local_path) @@ -434,7 +446,9 @@ class MultiRootLocalFolderBackend: responses.extend( [ FileDownloadResponse( - path=self._prefix_mount_path(mount, self._get_str(item, "path")), + path=self._prefix_mount_path( + mount, self._get_str(item, "path") + ), content=self._get_value(item, "content"), error=self._get_str(item, "error") or None, ) diff --git a/surfsense_backend/app/agents/new_chat/middleware/noop_injection.py b/surfsense_backend/app/agents/new_chat/middleware/noop_injection.py new file mode 100644 index 000000000..503c73ccc --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/middleware/noop_injection.py @@ -0,0 +1,141 @@ +""" +``_noop`` provider-compatibility tool + injection middleware. + +Some providers (LiteLLM, Bedrock, Copilot) 400 when a model call has +empty ``tools`` but the message history includes prior ``tool_calls`` — +they treat that shape as malformed even though it's perfectly valid +LangChain. SurfSense hits this on the compaction summarize call (no +tools, history full of tool calls). + +Ported from OpenCode's ``packages/opencode/src/session/llm.ts:209-228``, +which discovered and codified the workaround: inject a no-op tool *only* +on those provider shapes so the request validates without ever being +called. + +Operation: a :class:`NoopInjectionMiddleware` ``wrap_model_call`` checks +if the request has zero tools but the last AI message in history includes +``tool_calls``. If yes, it injects the ``_noop`` tool only — never +globally — mirroring OpenCode's gating exactly. The :func:`noop_tool` +returns empty content when called (which it should never be in +practice). +""" + +from __future__ import annotations + +import logging +from collections.abc import Awaitable, Callable +from typing import Any + +from langchain.agents.middleware.types import ( + AgentMiddleware, + AgentState, + ContextT, + ModelRequest, + ModelResponse, + ResponseT, +) +from langchain_core.messages import AIMessage +from langchain_core.tools import tool + +logger = logging.getLogger(__name__) + +NOOP_TOOL_NAME = "_noop" +NOOP_TOOL_DESCRIPTION = "Do not call this tool. It exists only for API compatibility." + + +@tool(name_or_callable=NOOP_TOOL_NAME, description=NOOP_TOOL_DESCRIPTION) +def noop_tool() -> str: + """Return empty content. Never expected to be called.""" + return "" + + +# Provider markers that benefit from ``_noop`` injection. These match +# OpenCode's gating list (``llm.ts:209-228``). We also accept any string +# containing one of these substrings so e.g. ``litellm`` matches +# ``ChatLiteLLM``. +_NOOP_NEEDED_PROVIDERS: tuple[str, ...] = ( + "litellm", + "bedrock", + "copilot", +) + + +def _provider_needs_noop(model: Any) -> bool: + """Heuristic: does this model's provider need the _noop injection?""" + try: + ls_params = model._get_ls_params() + provider = str(ls_params.get("ls_provider", "")).lower() + except Exception: + provider = "" + + if not provider: + cls_name = type(model).__name__.lower() + provider = cls_name + + return any(needle in provider for needle in _NOOP_NEEDED_PROVIDERS) + + +def _last_ai_has_tool_calls(messages: list[Any]) -> bool: + for msg in reversed(messages): + if isinstance(msg, AIMessage): + return bool(msg.tool_calls) + return False + + +class NoopInjectionMiddleware( + AgentMiddleware[AgentState[ResponseT], ContextT, ResponseT] +): + """Inject the ``_noop`` tool only when the provider would otherwise 400. + + The check fires per model call, not at agent build time, because the + summarization path generates a no-tool subcall at runtime. The + extra tool is appended to ``request.tools`` as an instance — the + actual ``langchain_core.tools.BaseTool`` is bound on every call site + that creates the agent. + """ + + def __init__(self, *, noop_tool_instance: Any | None = None) -> None: + super().__init__() + self._noop_tool = noop_tool_instance or noop_tool + self.tools = [] + + def _should_inject(self, request: ModelRequest[ContextT]) -> bool: + if request.tools: + return False + if not _last_ai_has_tool_calls(request.messages): + return False + return _provider_needs_noop(request.model) + + def _augmented(self, request: ModelRequest[ContextT]) -> ModelRequest[ContextT]: + return request.override(tools=[self._noop_tool]) + + def wrap_model_call( # type: ignore[override] + self, + request: ModelRequest[ContextT], + handler: Callable[[ModelRequest[ContextT]], ModelResponse[ResponseT]], + ) -> Any: + if self._should_inject(request): + logger.debug("Injecting _noop tool for provider compatibility") + return handler(self._augmented(request)) + return handler(request) + + async def awrap_model_call( # type: ignore[override] + self, + request: ModelRequest[ContextT], + handler: Callable[ + [ModelRequest[ContextT]], Awaitable[ModelResponse[ResponseT]] + ], + ) -> Any: + if self._should_inject(request): + logger.debug("Injecting _noop tool for provider compatibility") + return await handler(self._augmented(request)) + return await handler(request) + + +__all__ = [ + "NOOP_TOOL_DESCRIPTION", + "NOOP_TOOL_NAME", + "NoopInjectionMiddleware", + "_provider_needs_noop", + "noop_tool", +] diff --git a/surfsense_backend/app/agents/new_chat/middleware/otel_span.py b/surfsense_backend/app/agents/new_chat/middleware/otel_span.py new file mode 100644 index 000000000..cfe1edae4 --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/middleware/otel_span.py @@ -0,0 +1,202 @@ +""" +OpenTelemetry span middleware for the SurfSense ``new_chat`` agent. + +Wraps both ``model.call`` (LLM invocations) and ``tool.call`` (tool +executions) with OTel spans, attaching low-cardinality span names and +high-cardinality identifiers as attributes. + +This middleware is intentionally a thin adapter over +:mod:`app.observability.otel`; when OTel is not configured all spans +collapse to no-ops and the wrapper adds <1µs overhead per call. When +OTel **is** configured (``OTEL_EXPORTER_OTLP_ENDPOINT`` set), every +model and tool call gets a span with the standard attributes our +dashboards expect. +""" + +from __future__ import annotations + +import logging +from collections.abc import Awaitable, Callable +from typing import TYPE_CHECKING, Any + +from langchain.agents.middleware import AgentMiddleware +from langchain_core.messages import AIMessage, ToolMessage + +from app.observability import otel as ot + +if TYPE_CHECKING: # pragma: no cover — type-only + from langchain.agents.middleware.types import ( + ModelRequest, + ModelResponse, + ToolCallRequest, + ) + from langgraph.types import Command + +logger = logging.getLogger(__name__) + + +class OtelSpanMiddleware(AgentMiddleware): + """Emit ``model.call`` and ``tool.call`` OTel spans for every invocation. + + Should be placed near the **outer** end of the middleware list so + that the spans encompass retry/fallback wrapper effects (i.e. ``N`` + model.call spans for ``N`` retry attempts) but inside any concurrency/ + auth gate. Empirically this means **between** ``BusyMutex`` and + ``RetryAfter``. + """ + + def __init__(self, *, instrumentation_name: str = "surfsense.new_chat") -> None: + super().__init__() + self._instrumentation_name = instrumentation_name + + # ------------------------------------------------------------------ + # Model call spans + # ------------------------------------------------------------------ + + async def awrap_model_call( + self, + request: ModelRequest, + handler: Callable[[ModelRequest], Awaitable[ModelResponse | AIMessage | Any]], + ) -> ModelResponse | AIMessage | Any: + if not ot.is_enabled(): + return await handler(request) + + model_id, provider = _resolve_model_attrs(request) + with ot.model_call_span(model_id=model_id, provider=provider) as sp: + try: + result = await handler(request) + except Exception: + # span context manager records + re-raises + raise + else: + _annotate_model_response(sp, result) + return result + + # ------------------------------------------------------------------ + # Tool call spans + # ------------------------------------------------------------------ + + async def awrap_tool_call( + self, + request: ToolCallRequest, + handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command[Any]]], + ) -> ToolMessage | Command[Any]: + if not ot.is_enabled(): + return await handler(request) + + tool_name = _resolve_tool_name(request) + input_size = _resolve_input_size(request) + + with ot.tool_call_span(tool_name, input_size=input_size) as sp: + result = await handler(request) + _annotate_tool_result(sp, result) + return result + + +# --------------------------------------------------------------------------- +# Attribute helpers (kept defensive; we never want OTel bookkeeping to break +# a real model/tool call). +# --------------------------------------------------------------------------- + + +def _resolve_model_attrs(request: Any) -> tuple[str | None, str | None]: + """Extract ``model.id`` and ``model.provider`` from a ``ModelRequest``.""" + model_id: str | None = None + provider: str | None = None + try: + model = getattr(request, "model", None) + if model is None: + return None, None + # langchain BaseChatModel exposes a few different identifiers + for attr in ("model_name", "model", "model_id"): + value = getattr(model, attr, None) + if value: + model_id = str(value) + break + # provider sometimes lives on ``_llm_type`` (legacy) or ``provider`` + for attr in ("provider", "_llm_type"): + value = getattr(model, attr, None) + if value: + provider = str(value) + break + except Exception: # pragma: no cover — defensive + pass + return model_id, provider + + +def _resolve_tool_name(request: Any) -> str: + try: + tool = getattr(request, "tool", None) + if tool is not None: + name = getattr(tool, "name", None) + if isinstance(name, str) and name: + return name + # Fall back to the tool_call dict + call = getattr(request, "tool_call", None) or {} + name = call.get("name") if isinstance(call, dict) else None + if isinstance(name, str) and name: + return name + except Exception: # pragma: no cover — defensive + pass + return "unknown" + + +def _resolve_input_size(request: Any) -> int | None: + try: + call = getattr(request, "tool_call", None) + if not isinstance(call, dict) or not call: + return None + args = call.get("args") + if args is None: + return None + return len(repr(args)) + except Exception: # pragma: no cover — defensive + return None + + +def _annotate_model_response(span: Any, result: Any) -> None: + """Best-effort: attach prompt/completion token counts when available.""" + try: + # ModelResponse may be a dataclass with .result containing AIMessage + msg: Any + if isinstance(result, AIMessage): + msg = result + else: + inner = getattr(result, "result", None) + msg = inner[-1] if isinstance(inner, list) and inner else inner + if msg is None: + return + usage = getattr(msg, "usage_metadata", None) or {} + if isinstance(usage, dict): + if (n := usage.get("input_tokens")) is not None: + span.set_attribute("tokens.prompt", int(n)) + if (n := usage.get("output_tokens")) is not None: + span.set_attribute("tokens.completion", int(n)) + if (n := usage.get("total_tokens")) is not None: + span.set_attribute("tokens.total", int(n)) + tool_calls = getattr(msg, "tool_calls", None) or [] + span.set_attribute("model.tool_calls", len(tool_calls)) + except Exception: # pragma: no cover — defensive + pass + + +def _annotate_tool_result(span: Any, result: Any) -> None: + try: + if isinstance(result, ToolMessage): + content = ( + result.content + if isinstance(result.content, str) + else repr(result.content) + ) + span.set_attribute("tool.output.size", len(content)) + status = getattr(result, "status", None) + if isinstance(status, str): + span.set_attribute("tool.status", status) + kwargs = getattr(result, "additional_kwargs", None) or {} + if isinstance(kwargs, dict) and kwargs.get("error"): + span.set_attribute("tool.error", True) + except Exception: # pragma: no cover — defensive + pass + + +__all__ = ["OtelSpanMiddleware"] diff --git a/surfsense_backend/app/agents/new_chat/middleware/permission.py b/surfsense_backend/app/agents/new_chat/middleware/permission.py new file mode 100644 index 000000000..37719e96a --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/middleware/permission.py @@ -0,0 +1,358 @@ +""" +PermissionMiddleware — pattern-based allow/deny/ask with HITL fallback. + +LangChain's :class:`HumanInTheLoopMiddleware` only supports a static +"this tool always asks" decision per tool. There's no rule-based +allow/deny/ask layered ruleset, no glob patterns, no per-search-space or +per-thread overrides, and no auto-deny synthesis. + +This middleware ports OpenCode's ``packages/opencode/src/permission/index.ts`` +ruleset model on top of SurfSense's existing ``interrupt({type, action, +context})`` payload shape (see ``app/agents/new_chat/tools/hitl.py``) so +the frontend keeps working unchanged. + +Operation: +1. ``aafter_model`` inspects the latest ``AIMessage.tool_calls``. +2. For each call, the middleware builds a list of ``patterns`` (the + tool name plus any tool-specific patterns from the resolver). It + evaluates each pattern against the layered rulesets and aggregates + the results: ``deny`` > ``ask`` > ``allow``. +3. On ``deny``: replaces the call with a synthetic ``ToolMessage`` + containing a :class:`StreamingError`. +4. On ``ask``: raises a SurfSense-style ``interrupt(...)``. The reply + shape is ``{"decision_type": "once|always|reject", "feedback"?: str}``. + - ``once``: proceed. + - ``always``: also persist allow rules for ``request.always`` patterns. + - ``reject`` w/o feedback: raise :class:`RejectedError`. + - ``reject`` w/ feedback: raise :class:`CorrectedError`. +5. On ``allow``: proceed unchanged. + +The middleware also performs a *pre-model* tool-filter step (the +``before_model`` hook) so globally denied tools are stripped from the +exposed tool list before the model gets to see them. This mirrors +OpenCode's ``Permission.disabled`` and dramatically reduces the chance +the model emits a deny-only call. +""" + +from __future__ import annotations + +import logging +from collections.abc import Callable +from typing import Any + +from langchain.agents.middleware.types import ( + AgentMiddleware, + AgentState, + ContextT, +) +from langchain_core.messages import AIMessage, ToolMessage +from langgraph.runtime import Runtime +from langgraph.types import interrupt + +from app.agents.new_chat.errors import ( + CorrectedError, + RejectedError, + StreamingError, +) +from app.agents.new_chat.permissions import ( + Rule, + Ruleset, + aggregate_action, + evaluate_many, +) +from app.observability import otel as ot + +logger = logging.getLogger(__name__) + + +# Mapping ``tool_name -> resolver`` that converts ``args`` to a list of +# patterns to evaluate. The first pattern is conventionally the bare +# tool name; later entries narrow down to specific resources. +PatternResolver = Callable[[dict[str, Any]], list[str]] + + +def _default_pattern_resolver(name: str) -> PatternResolver: + def _resolve(args: dict[str, Any]) -> list[str]: + # Bare name covers the default catch-all; primary-arg fallbacks + # are best added per-tool by callers. + del args + return [name] + + return _resolve + + +class PermissionMiddleware(AgentMiddleware): # type: ignore[type-arg] + """Allow/deny/ask layer over the agent's tool calls. + + Args: + rulesets: Layered rulesets to evaluate. Earlier entries are + overridden by later ones (last-match-wins). Typical layering: + ``defaults < global < space < thread < runtime_approved``. + pattern_resolvers: Optional per-tool callables that return a list + of patterns to evaluate. When a tool isn't listed, the bare + tool name is used as the only pattern. + runtime_ruleset: Mutable :class:`Ruleset` that the middleware + extends in-place when the user replies ``"always"`` to an + ask interrupt. Reused across all calls in the same agent + instance so newly-allowed rules apply to subsequent calls. + always_emit_interrupt_payload: If True, every ask uses the + SurfSense interrupt wire format (default). Set False to + disable interrupts and treat ``ask`` as ``deny`` for + non-interactive deployments. + """ + + tools = () + + def __init__( + self, + *, + rulesets: list[Ruleset] | None = None, + pattern_resolvers: dict[str, PatternResolver] | None = None, + runtime_ruleset: Ruleset | None = None, + always_emit_interrupt_payload: bool = True, + ) -> None: + super().__init__() + self._static_rulesets: list[Ruleset] = list(rulesets or []) + self._pattern_resolvers: dict[str, PatternResolver] = dict( + pattern_resolvers or {} + ) + self._runtime_ruleset: Ruleset = runtime_ruleset or Ruleset( + origin="runtime_approved" + ) + self._emit_interrupt = always_emit_interrupt_payload + + # ------------------------------------------------------------------ + # Tool-filter step (mirrors OpenCode's ``Permission.disabled``) + # ------------------------------------------------------------------ + + def _globally_denied(self, tool_name: str) -> bool: + """Return True if a deny rule with no narrowing pattern matches.""" + rules = evaluate_many(tool_name, ["*"], *self._all_rulesets()) + return aggregate_action(rules) == "deny" + + def _all_rulesets(self) -> list[Ruleset]: + return [*self._static_rulesets, self._runtime_ruleset] + + # NOTE: ``before_model`` filtering of the tools list is left to the + # agent factory. This middleware only blocks at execution time — and + # only via the rule-evaluator path, not by mutating ``request.tools``. + # Mutating ``request.tools`` per-call would invalidate provider + # prompt-cache prefixes (see Operational risks: prompt-cache regression). + + # ------------------------------------------------------------------ + # Tool-call evaluation + # ------------------------------------------------------------------ + + def _resolve_patterns(self, tool_name: str, args: dict[str, Any]) -> list[str]: + resolver = self._pattern_resolvers.get( + tool_name, _default_pattern_resolver(tool_name) + ) + try: + patterns = resolver(args or {}) + except Exception: + logger.exception( + "Pattern resolver for %s raised; using bare name", tool_name + ) + patterns = [tool_name] + if not patterns: + patterns = [tool_name] + return patterns + + def _evaluate( + self, tool_name: str, args: dict[str, Any] + ) -> tuple[str, list[str], list[Rule]]: + patterns = self._resolve_patterns(tool_name, args) + rules = evaluate_many(tool_name, patterns, *self._all_rulesets()) + action = aggregate_action(rules) + return action, patterns, rules + + # ------------------------------------------------------------------ + # HITL ask flow — SurfSense wire format + # ------------------------------------------------------------------ + + def _raise_interrupt( + self, + *, + tool_name: str, + args: dict[str, Any], + patterns: list[str], + rules: list[Rule], + ) -> dict[str, Any]: + """Block on user approval via SurfSense's ``interrupt`` shape.""" + if not self._emit_interrupt: + return {"decision_type": "reject"} + + # ``params`` (NOT ``args``) is what SurfSense's streaming + # normalizer forwards. Other fields move into ``context``. + payload = { + "type": "permission_ask", + "action": {"tool": tool_name, "params": args or {}}, + "context": { + "patterns": patterns, + "rules": [ + { + "permission": r.permission, + "pattern": r.pattern, + "action": r.action, + } + for r in rules + ], + # Rules of thumb for the frontend: surface the patterns + # the user can promote to "always" with a single reply. + "always": patterns, + }, + } + # Open ``permission.asked`` + ``interrupt.raised`` OTel spans + # (no-op when OTel is disabled) so dashboards can correlate + # "we asked X" with "interrupt was actually delivered". + with ( + ot.permission_asked_span( + permission=tool_name, + pattern=patterns[0] if patterns else None, + extra={"permission.patterns": list(patterns)}, + ), + ot.interrupt_span(interrupt_type="permission_ask"), + ): + decision = interrupt(payload) + if isinstance(decision, dict): + return decision + # Tolerate a plain string reply ("once", "always", "reject") + if isinstance(decision, str): + return {"decision_type": decision} + return {"decision_type": "reject"} + + def _persist_always(self, tool_name: str, patterns: list[str]) -> None: + """Promote ``always`` reply into runtime allow rules. + + Persistence to ``agent_permission_rules`` is done by the + streaming layer (``stream_new_chat``) once it observes the + ``always`` reply — the middleware just keeps an in-memory + copy so subsequent calls in the same stream see the rule. + """ + for pattern in patterns: + self._runtime_ruleset.rules.append( + Rule(permission=tool_name, pattern=pattern, action="allow") + ) + + # ------------------------------------------------------------------ + # Synthesizing deny -> ToolMessage + # ------------------------------------------------------------------ + + @staticmethod + def _deny_message( + tool_call: dict[str, Any], + rule: Rule, + ) -> ToolMessage: + err = StreamingError( + code="permission_denied", + retryable=False, + suggestion=( + f"rule permission={rule.permission!r} pattern={rule.pattern!r} " + f"blocked this call" + ), + ) + return ToolMessage( + content=( + f"Permission denied: rule {rule.permission}/{rule.pattern} " + f"blocked tool {tool_call.get('name')!r}." + ), + tool_call_id=tool_call.get("id") or "", + name=tool_call.get("name"), + status="error", + additional_kwargs={"error": err.model_dump()}, + ) + + # ------------------------------------------------------------------ + # The hook: aafter_model + # ------------------------------------------------------------------ + + def _process( + self, + state: AgentState, + runtime: Runtime[Any], + ) -> dict[str, Any] | None: + del runtime # unused + messages = state.get("messages") or [] + if not messages: + return None + last = messages[-1] + if not isinstance(last, AIMessage) or not last.tool_calls: + return None + + deny_messages: list[ToolMessage] = [] + kept_calls: list[dict[str, Any]] = [] + any_change = False + + for raw in last.tool_calls: + call = ( + dict(raw) + if isinstance(raw, dict) + else { + "name": getattr(raw, "name", None), + "args": getattr(raw, "args", {}), + "id": getattr(raw, "id", None), + "type": "tool_call", + } + ) + name = call.get("name") or "" + args = call.get("args") or {} + action, patterns, rules = self._evaluate(name, args) + + if action == "deny": + # Find the deny rule for the suggestion text + deny_rule = next((r for r in rules if r.action == "deny"), rules[0]) + deny_messages.append(self._deny_message(call, deny_rule)) + any_change = True + continue + + if action == "ask": + decision = self._raise_interrupt( + tool_name=name, args=args, patterns=patterns, rules=rules + ) + kind = str(decision.get("decision_type") or "reject").lower() + if kind == "once": + kept_calls.append(call) + elif kind == "always": + self._persist_always(name, patterns) + kept_calls.append(call) + elif kind == "reject": + feedback = decision.get("feedback") + if isinstance(feedback, str) and feedback.strip(): + raise CorrectedError(feedback, tool=name) + raise RejectedError( + tool=name, pattern=patterns[0] if patterns else None + ) + else: + logger.warning( + "Unknown permission decision %r; treating as reject", kind + ) + raise RejectedError(tool=name) + continue + + # allow + kept_calls.append(call) + + if not any_change and len(kept_calls) == len(last.tool_calls): + return None + + updated = last.model_copy(update={"tool_calls": kept_calls}) + result_messages: list[Any] = [updated] + if deny_messages: + result_messages.extend(deny_messages) + return {"messages": result_messages} + + def after_model( # type: ignore[override] + self, state: AgentState, runtime: Runtime[ContextT] + ) -> dict[str, Any] | None: + return self._process(state, runtime) + + async def aafter_model( # type: ignore[override] + self, state: AgentState, runtime: Runtime[ContextT] + ) -> dict[str, Any] | None: + return self._process(state, runtime) + + +__all__ = [ + "PatternResolver", + "PermissionMiddleware", +] diff --git a/surfsense_backend/app/agents/new_chat/middleware/retry_after.py b/surfsense_backend/app/agents/new_chat/middleware/retry_after.py new file mode 100644 index 000000000..0c3d3d017 --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/middleware/retry_after.py @@ -0,0 +1,257 @@ +""" +RetryAfterMiddleware — Header-aware retry with custom backoff and SSE eventing. + +LangChain's :class:`ModelRetryMiddleware` retries on exceptions but ignores +the ``Retry-After`` HTTP header — it just runs its own exponential backoff. +That wastes time when a provider has explicitly told us how long to wait. +This middleware honors the header (mirroring OpenCode's +``packages/opencode/src/session/llm.ts`` retry pathway) and emits an SSE +event so the UI can show "rate-limited, retrying in Ns". + +We can't subclass ``ModelRetryMiddleware`` cleanly because its loop calls a +module-level ``calculate_delay`` inline (no overridable +``_calculate_delay`` hook), so this is a standalone implementation. + +Behaviour: +- Extracts ``Retry-After`` / ``retry-after-ms`` from + ``litellm.exceptions.RateLimitError.response.headers`` (or any exception + exposing a similar shape). +- Sleeps ``max(exponential_backoff, header_delay)`` between retries. +- Returns ``False`` from ``retry_on`` for ``ContextWindowExceededError`` / + ``ContextOverflowError`` so :class:`SurfSenseCompactionMiddleware` (or + the LangChain summarization fallback path) handles those instead. +- Emits ``surfsense.retrying`` via ``adispatch_custom_event`` on each retry + so ``stream_new_chat`` can forward it to clients as an SSE event. +""" + +from __future__ import annotations + +import asyncio +import logging +import random +import re +import time +from collections.abc import Awaitable, Callable +from typing import Any + +from langchain.agents.middleware.types import ( + AgentMiddleware, + AgentState, + ContextT, + ModelRequest, + ModelResponse, + ResponseT, +) +from langchain_core.callbacks import adispatch_custom_event, dispatch_custom_event +from langchain_core.messages import AIMessage + +logger = logging.getLogger(__name__) + +# Names of exception classes for which a retry would not help — context +# overflow needs compaction, auth needs human intervention, etc. Detected +# by class-name substring so we don't have to import LiteLLM/Anthropic +# here (which would tie this module to optional deps). +_NON_RETRYABLE_NAME_HINTS: tuple[str, ...] = ( + "ContextWindowExceeded", + "ContextOverflow", + "AuthenticationError", + "InvalidRequestError", + "PermissionDenied", + "InvalidApiKey", + "ContextLimit", +) + + +def _is_non_retryable(exc: BaseException) -> bool: + name = type(exc).__name__ + return any(hint in name for hint in _NON_RETRYABLE_NAME_HINTS) + + +def _extract_retry_after_seconds(exc: BaseException) -> float | None: + """Return seconds-to-wait suggested by the provider, if any. + + Looks at ``exc.response.headers`` or ``exc.headers`` for the standard + HTTP ``Retry-After`` header (in seconds) or its millisecond cousin + ``retry-after-ms`` (sometimes used by Anthropic / OpenAI). Falls back + to a regex on the exception message for shapes like + ``"Please retry after 30s"``. + """ + headers: dict[str, Any] | None = None + response = getattr(exc, "response", None) + if response is not None: + headers = getattr(response, "headers", None) + if headers is None: + headers = getattr(exc, "headers", None) + + if isinstance(headers, dict): + # Normalize keys to lowercase for case-insensitive matching + norm = {str(k).lower(): v for k, v in headers.items()} + ms = norm.get("retry-after-ms") + if ms is not None: + try: + return float(ms) / 1000.0 + except (TypeError, ValueError): + pass + seconds = norm.get("retry-after") + if seconds is not None: + try: + return float(seconds) + except (TypeError, ValueError): + pass + + # Last resort: scan the message for "retry after Xs" or "X seconds" + msg = str(exc) + match = re.search(r"retry\s+after\s+([0-9]+(?:\.[0-9]+)?)", msg, re.IGNORECASE) + if match: + try: + return float(match.group(1)) + except ValueError: + return None + return None + + +def _exponential_delay( + attempt: int, + *, + initial_delay: float, + backoff_factor: float, + max_delay: float, + jitter: bool, +) -> float: + """Compute an exponential-backoff delay with optional ±25% jitter.""" + delay = ( + initial_delay * (backoff_factor**attempt) if backoff_factor else initial_delay + ) + delay = min(delay, max_delay) + if jitter and delay > 0: + delay *= 1 + random.uniform(-0.25, 0.25) + return max(delay, 0.0) + + +class RetryAfterMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, ResponseT]): + """Retry middleware that honors provider-issued Retry-After hints. + + Drop-in replacement for :class:`langchain.agents.middleware.ModelRetryMiddleware` + when working with LiteLLM/Anthropic/OpenAI providers that surface + rate-limit hints in headers. Always emits ``surfsense.retrying`` SSE + events so the UI can show a friendly "rate limited, retrying in Xs" + indicator. + + Args: + max_retries: Maximum retries after the initial attempt (default 3). + initial_delay: Initial backoff delay in seconds. + backoff_factor: Exponential growth factor for backoff. + max_delay: Cap on per-attempt delay in seconds. + jitter: Whether to add ±25% jitter. + retry_on: Optional callable that returns True for retryable + exceptions. The default retries everything except known + non-retryable classes (context overflow, auth, etc.). + """ + + def __init__( + self, + *, + max_retries: int = 3, + initial_delay: float = 1.0, + backoff_factor: float = 2.0, + max_delay: float = 60.0, + jitter: bool = True, + retry_on: Callable[[BaseException], bool] | None = None, + ) -> None: + super().__init__() + self.max_retries = max_retries + self.initial_delay = initial_delay + self.backoff_factor = backoff_factor + self.max_delay = max_delay + self.jitter = jitter + self._retry_on: Callable[[BaseException], bool] = retry_on or ( + lambda exc: not _is_non_retryable(exc) + ) + + def _should_retry(self, exc: BaseException) -> bool: + try: + return bool(self._retry_on(exc)) + except Exception: + logger.exception("retry_on callable raised; defaulting to False") + return False + + def _delay_for_attempt(self, attempt: int, exc: BaseException) -> float: + backoff = _exponential_delay( + attempt, + initial_delay=self.initial_delay, + backoff_factor=self.backoff_factor, + max_delay=self.max_delay, + jitter=self.jitter, + ) + header = _extract_retry_after_seconds(exc) or 0.0 + return max(backoff, header) + + def wrap_model_call( # type: ignore[override] + self, + request: ModelRequest[ContextT], + handler: Callable[[ModelRequest[ContextT]], ModelResponse[ResponseT]], + ) -> ModelResponse[ResponseT] | AIMessage: + for attempt in range(self.max_retries + 1): + try: + return handler(request) + except Exception as exc: + if not self._should_retry(exc) or attempt >= self.max_retries: + raise + delay = self._delay_for_attempt(attempt, exc) + try: + dispatch_custom_event( + "surfsense.retrying", + { + "attempt": attempt + 1, + "max_retries": self.max_retries, + "delay_ms": int(delay * 1000), + "reason": type(exc).__name__, + }, + ) + except Exception: + logger.debug( + "dispatch_custom_event failed; suppressed", exc_info=True + ) + if delay > 0: + time.sleep(delay) + # Unreachable + raise RuntimeError("RetryAfterMiddleware: retry loop exited without resolution") + + async def awrap_model_call( # type: ignore[override] + self, + request: ModelRequest[ContextT], + handler: Callable[ + [ModelRequest[ContextT]], Awaitable[ModelResponse[ResponseT]] + ], + ) -> ModelResponse[ResponseT] | AIMessage: + for attempt in range(self.max_retries + 1): + try: + return await handler(request) + except Exception as exc: + if not self._should_retry(exc) or attempt >= self.max_retries: + raise + delay = self._delay_for_attempt(attempt, exc) + try: + await adispatch_custom_event( + "surfsense.retrying", + { + "attempt": attempt + 1, + "max_retries": self.max_retries, + "delay_ms": int(delay * 1000), + "reason": type(exc).__name__, + }, + ) + except Exception: + logger.debug( + "adispatch_custom_event failed; suppressed", exc_info=True + ) + if delay > 0: + await asyncio.sleep(delay) + raise RuntimeError("RetryAfterMiddleware: retry loop exited without resolution") + + +__all__ = [ + "RetryAfterMiddleware", + "_extract_retry_after_seconds", + "_is_non_retryable", +] diff --git a/surfsense_backend/app/agents/new_chat/middleware/safe_summarization.py b/surfsense_backend/app/agents/new_chat/middleware/safe_summarization.py deleted file mode 100644 index 4ddcf334f..000000000 --- a/surfsense_backend/app/agents/new_chat/middleware/safe_summarization.py +++ /dev/null @@ -1,123 +0,0 @@ -"""Safe wrapper around deepagents' SummarizationMiddleware. - -Upstream issue --------------- -`deepagents.middleware.summarization.SummarizationMiddleware._aoffload_to_backend` -(and its sync counterpart) call -``get_buffer_string(filtered_messages)`` before writing the evicted history -to the backend file. In recent ``langchain-core`` versions, ``get_buffer_string`` -accesses ``m.text`` which iterates ``self.content`` — this raises -``TypeError: 'NoneType' object is not iterable`` whenever an ``AIMessage`` -has ``content=None`` (common when a model returns *only* tool_calls, seen -frequently with Azure OpenAI ``gpt-5.x`` responses streamed through -LiteLLM). - -The exception aborts the whole agent turn, so the user just sees "Error during -chat" with no assistant response. - -Fix ---- -We subclass ``SummarizationMiddleware`` and override -``_filter_summary_messages`` — the only call site that feeds messages into -``get_buffer_string`` — to return *copies* of messages whose ``content`` is -``None`` with ``content=""``. The originals flowing through the rest of the -agent state are untouched. - -We also expose a drop-in ``create_safe_summarization_middleware`` factory -that mirrors ``deepagents.middleware.summarization.create_summarization_middleware`` -but instantiates our safe subclass. -""" - -from __future__ import annotations - -import logging -from typing import TYPE_CHECKING - -from deepagents.middleware.summarization import ( - SummarizationMiddleware, - compute_summarization_defaults, -) - -if TYPE_CHECKING: - from deepagents.backends.protocol import BACKEND_TYPES - from langchain_core.language_models import BaseChatModel - from langchain_core.messages import AnyMessage - -logger = logging.getLogger(__name__) - - -def _sanitize_message_content(msg: AnyMessage) -> AnyMessage: - """Return ``msg`` with ``content`` coerced to a non-``None`` value. - - ``get_buffer_string`` reads ``m.text`` which iterates ``self.content``; - when a provider streams back an ``AIMessage`` with only tool_calls and - no text, ``content`` can be ``None`` and the iteration explodes. We - replace ``None`` with an empty string so downstream consumers that only - care about text see an empty body. - - The original message is left untouched — we return a copy via - pydantic's ``model_copy`` when available, otherwise we fall back to - re-setting the attribute on a shallow copy. - """ - - if getattr(msg, "content", "not-missing") is not None: - return msg - - try: - return msg.model_copy(update={"content": ""}) - except AttributeError: - import copy - - new_msg = copy.copy(msg) - try: - new_msg.content = "" - except Exception: # pragma: no cover - defensive - logger.debug( - "Could not sanitize content=None on message of type %s", - type(msg).__name__, - ) - return msg - return new_msg - - -class SafeSummarizationMiddleware(SummarizationMiddleware): - """`SummarizationMiddleware` that tolerates messages with ``content=None``. - - Only ``_filter_summary_messages`` is overridden — this is the single - helper invoked by both the sync and async offload paths immediately - before ``get_buffer_string``. Normalising here means we get coverage - for both without having to copy the (long, rapidly-changing) offload - implementations from upstream. - """ - - def _filter_summary_messages(self, messages: list[AnyMessage]) -> list[AnyMessage]: - filtered = super()._filter_summary_messages(messages) - return [_sanitize_message_content(m) for m in filtered] - - -def create_safe_summarization_middleware( - model: BaseChatModel, - backend: BACKEND_TYPES, -) -> SafeSummarizationMiddleware: - """Drop-in replacement for ``create_summarization_middleware``. - - Mirrors the defaults computed by ``deepagents`` but returns our - ``SafeSummarizationMiddleware`` subclass so the - ``content=None`` crash in ``get_buffer_string`` is avoided. - """ - - defaults = compute_summarization_defaults(model) - return SafeSummarizationMiddleware( - model=model, - backend=backend, - trigger=defaults["trigger"], - keep=defaults["keep"], - trim_tokens_to_summarize=None, - truncate_args_settings=defaults["truncate_args_settings"], - ) - - -__all__ = [ - "SafeSummarizationMiddleware", - "create_safe_summarization_middleware", -] diff --git a/surfsense_backend/app/agents/new_chat/middleware/skills_backends.py b/surfsense_backend/app/agents/new_chat/middleware/skills_backends.py new file mode 100644 index 000000000..072d73401 --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/middleware/skills_backends.py @@ -0,0 +1,337 @@ +"""Skills backends for SurfSense. + +Implements two minimal :class:`deepagents.backends.protocol.BackendProtocol` +subclasses tailored for use with :class:`deepagents.middleware.skills.SkillsMiddleware`. + +The middleware only needs four methods to load skills from a backend: + +* ``ls_info`` / ``als_info`` — list directories under a source path. +* ``download_files`` / ``adownload_files`` — fetch ``SKILL.md`` bytes. + +Other ``BackendProtocol`` methods (``read``/``write``/``edit``/``grep_raw`` …) +default to ``NotImplementedError`` from the base class. They are never reached +by the skills middleware because skill content is rendered into the system +prompt at agent build time, not edited at runtime. + +Two backends are provided: + +* :class:`BuiltinSkillsBackend` — disk-backed read of bundled skills from + ``app/agents/new_chat/skills/builtin/``. +* :class:`SearchSpaceSkillsBackend` — a thin read-only wrapper over + :class:`KBPostgresBackend` that filters notes under the privileged folder + ``/documents/_skills/``. + +Both backends are intentionally read-only: skill authoring happens out of band +(via filesystem or a search-space-admin route), so we never expose +``write`` / ``edit`` / ``upload_files``. The base class' ``NotImplementedError`` +gives a clean failure mode if anything tries. +""" + +from __future__ import annotations + +import contextlib +import logging +from collections.abc import Callable +from dataclasses import replace +from pathlib import Path +from typing import TYPE_CHECKING + +from deepagents.backends.composite import CompositeBackend +from deepagents.backends.protocol import ( + BackendProtocol, + FileDownloadResponse, + FileInfo, +) +from deepagents.backends.state import StateBackend + +if TYPE_CHECKING: + from langchain.tools import ToolRuntime + + from app.agents.new_chat.middleware.kb_postgres_backend import KBPostgresBackend + +logger = logging.getLogger(__name__) + + +# Limit per Agent Skills spec; matches deepagents.middleware.skills.MAX_SKILL_FILE_SIZE. +_MAX_SKILL_FILE_SIZE = 10 * 1024 * 1024 + + +def _default_builtin_root() -> Path: + """Return the absolute path to the bundled builtin skills directory. + + Located at ``app/agents/new_chat/skills/builtin/`` relative to this module. + """ + return (Path(__file__).resolve().parent.parent / "skills" / "builtin").resolve() + + +class BuiltinSkillsBackend(BackendProtocol): + """Read-only disk-backed skills source. + + Maps a virtual ``/skills/builtin/`` namespace onto a directory on local disk, + where each skill is its own subdirectory containing a ``SKILL.md`` file:: + + //SKILL.md + + The middleware calls :meth:`als_info` with the source path and expects a + ``list[FileInfo]`` whose ``is_dir=True`` entries are descended into. Then it + calls :meth:`adownload_files` with the synthesized ``SKILL.md`` paths and + parses YAML frontmatter from the returned ``content`` bytes. + + Mounting under :class:`~deepagents.backends.composite.CompositeBackend` at + prefix ``/skills/builtin/`` means the middleware can issue paths like + ``/skills/builtin/kb-research/SKILL.md`` which the composite strips down to + ``/kb-research/SKILL.md`` before forwarding here. We treat any leading + slash as anchoring at :attr:`root`. + """ + + def __init__(self, root: Path | str | None = None) -> None: + self.root: Path = Path(root).resolve() if root else _default_builtin_root() + if not self.root.exists(): + logger.info( + "BuiltinSkillsBackend root %s does not exist; skills will be empty.", + self.root, + ) + + def _resolve(self, path: str) -> Path: + """Resolve a virtual posix path under :attr:`root`, refusing escapes.""" + bare = path.lstrip("/") + candidate = (self.root / bare).resolve() if bare else self.root + # Refuse symlink/.. traversal that escapes the root. + try: + candidate.relative_to(self.root) + except ValueError as exc: + raise ValueError(f"path {path!r} escapes builtin skills root") from exc + return candidate + + def ls_info(self, path: str) -> list[FileInfo]: + try: + target = self._resolve(path) + except ValueError as exc: + logger.warning("BuiltinSkillsBackend.ls_info refused: %s", exc) + return [] + if not target.exists() or not target.is_dir(): + return [] + + infos: list[FileInfo] = [] + # Build virtual paths anchored at "/" because CompositeBackend already + # stripped the route prefix before calling us. + target_virtual = ( + "/" + if target == self.root + else ("/" + str(target.relative_to(self.root)).replace("\\", "/")) + ) + for child in sorted(target.iterdir()): + child_virtual = ( + target_virtual.rstrip("/") + "/" + child.name + if target_virtual != "/" + else "/" + child.name + ) + info: FileInfo = { + "path": child_virtual, + "is_dir": child.is_dir(), + } + if child.is_file(): + with contextlib.suppress(OSError): # pragma: no cover - defensive + info["size"] = child.stat().st_size + infos.append(info) + return infos + + def download_files(self, paths: list[str]) -> list[FileDownloadResponse]: + responses: list[FileDownloadResponse] = [] + for p in paths: + try: + target = self._resolve(p) + except ValueError: + responses.append(FileDownloadResponse(path=p, error="invalid_path")) + continue + if not target.exists(): + responses.append(FileDownloadResponse(path=p, error="file_not_found")) + continue + if target.is_dir(): + responses.append(FileDownloadResponse(path=p, error="is_directory")) + continue + try: + # Hard cap to avoid loading rogue mega-files into memory. + size = target.stat().st_size + if size > _MAX_SKILL_FILE_SIZE: + logger.warning( + "Builtin skill file %s exceeds %d bytes; truncating.", + target, + _MAX_SKILL_FILE_SIZE, + ) + with target.open("rb") as fh: + content = fh.read(_MAX_SKILL_FILE_SIZE) + else: + content = target.read_bytes() + except PermissionError: + responses.append( + FileDownloadResponse(path=p, error="permission_denied") + ) + continue + except OSError as exc: # pragma: no cover - defensive + logger.warning("Builtin skill read failed %s: %s", target, exc) + responses.append(FileDownloadResponse(path=p, error="file_not_found")) + continue + responses.append(FileDownloadResponse(path=p, content=content, error=None)) + return responses + + +class SearchSpaceSkillsBackend(BackendProtocol): + """Read-only view of search-space-authored skills. + + Wraps a :class:`KBPostgresBackend` and only ever reads under the privileged + folder ``/documents/_skills/`` (configurable). The folder is intended to be + writable only by search-space admins; this backend never writes. + + The skills middleware expects a layout like:: + + ///SKILL.md + + But the KB stores documents like ``/documents/_skills//SKILL.md``. + We expose the inner namespace by remapping each path. When mounted under + :class:`CompositeBackend` at prefix ``/skills/space/`` the paths the + middleware sees become ``/skills/space//SKILL.md``; the composite + strips ``/skills/space/`` and hands us ``//SKILL.md``, which we + rewrite to ``/documents/_skills//SKILL.md`` before forwarding to the + KB. + + No new database table is needed: the privileged folder convention is + enforced server-side outside of this class. We intentionally swallow any + write/edit attempts (the base class raises ``NotImplementedError``). + """ + + DEFAULT_KB_ROOT: str = "/documents/_skills" + + def __init__( + self, + kb_backend: KBPostgresBackend, + *, + kb_root: str = DEFAULT_KB_ROOT, + ) -> None: + self._kb = kb_backend + # Normalize trailing slash off so we can join cleanly. + self._kb_root = kb_root.rstrip("/") or "/" + + def _to_kb(self, path: str) -> str: + """Rewrite a virtual path into the underlying KB namespace.""" + bare = path.lstrip("/") + if not bare: + return self._kb_root + return f"{self._kb_root}/{bare}" + + def _from_kb(self, kb_path: str) -> str: + """Rewrite a KB path back into our virtual namespace.""" + if not kb_path.startswith(self._kb_root): + return kb_path # pragma: no cover - defensive + rel = kb_path[len(self._kb_root) :] + return rel if rel.startswith("/") else "/" + rel + + def ls_info(self, path: str) -> list[FileInfo]: + # KBPostgresBackend exposes only the async API meaningfully; the sync + # path falls back to ``asyncio.to_thread(...)`` in the base class. We + # keep this stub to satisfy abstract resolution; the middleware calls + # ``als_info``. + raise NotImplementedError("SearchSpaceSkillsBackend is async-only") + + async def als_info(self, path: str) -> list[FileInfo]: + kb_path = self._to_kb(path) + try: + infos = await self._kb.als_info(kb_path) + except Exception as exc: # pragma: no cover - defensive + logger.warning("SearchSpaceSkillsBackend.als_info failed: %s", exc) + return [] + remapped: list[FileInfo] = [] + for info in infos: + kb_p = info.get("path", "") + if not kb_p.startswith(self._kb_root): + continue + remapped.append({**info, "path": self._from_kb(kb_p)}) + return remapped + + def download_files(self, paths: list[str]) -> list[FileDownloadResponse]: + raise NotImplementedError("SearchSpaceSkillsBackend is async-only") + + async def adownload_files(self, paths: list[str]) -> list[FileDownloadResponse]: + kb_paths = [self._to_kb(p) for p in paths] + responses = await self._kb.adownload_files(kb_paths) + # Re-map response paths back to the virtual namespace so the middleware + # correlates them to the input list correctly. + remapped: list[FileDownloadResponse] = [] + for original, resp in zip(paths, responses, strict=True): + remapped.append(replace(resp, path=original)) + return remapped + + +SKILLS_BUILTIN_PREFIX = "/skills/builtin/" +SKILLS_SPACE_PREFIX = "/skills/space/" + + +def build_skills_backend_factory( + *, + builtin_root: Path | str | None = None, + search_space_id: int | None = None, +) -> Callable[[ToolRuntime], BackendProtocol]: + """Return a runtime-aware factory for the skills :class:`CompositeBackend`. + + When ``search_space_id`` is provided the composite includes a + :class:`SearchSpaceSkillsBackend` route at ``/skills/space/`` over a fresh + per-runtime :class:`KBPostgresBackend`, mirroring how + :func:`build_backend_resolver` constructs the main filesystem backend. + + When ``search_space_id`` is ``None`` (e.g., desktop-local mode or unit + tests) only the bundled :class:`BuiltinSkillsBackend` is exposed. + + Returning a factory rather than a fixed instance is intentional: the + underlying KB backend depends on per-call ``ToolRuntime`` state + (``staged_dirs``, ``files`` cache, runtime config), so a single shared + instance cannot serve multiple concurrent agent runs. + """ + builtin = BuiltinSkillsBackend(builtin_root) + + if search_space_id is None: + + def _factory_builtin_only(runtime: ToolRuntime) -> BackendProtocol: + # Default StateBackend is intentionally inert: any path outside the + # ``/skills/builtin/`` route resolves to an empty per-runtime state + # so the SkillsMiddleware can iterate sources without raising. + return CompositeBackend( + default=StateBackend(runtime), + routes={SKILLS_BUILTIN_PREFIX: builtin}, + ) + + return _factory_builtin_only + + def _factory_with_space(runtime: ToolRuntime) -> BackendProtocol: + # Imported lazily to avoid a hard dependency at module import time: + # ``KBPostgresBackend`` pulls in DB models, which are unnecessary for + # the unit-tested builtin path. + from app.agents.new_chat.middleware.kb_postgres_backend import ( + KBPostgresBackend, + ) + + kb = KBPostgresBackend(search_space_id, runtime) + space = SearchSpaceSkillsBackend(kb) + return CompositeBackend( + default=StateBackend(runtime), + routes={ + SKILLS_BUILTIN_PREFIX: builtin, + SKILLS_SPACE_PREFIX: space, + }, + ) + + return _factory_with_space + + +def default_skills_sources() -> list[str]: + """Return the canonical source list for SkillsMiddleware (built-in then space).""" + return [SKILLS_BUILTIN_PREFIX, SKILLS_SPACE_PREFIX] + + +__all__ = [ + "SKILLS_BUILTIN_PREFIX", + "SKILLS_SPACE_PREFIX", + "BuiltinSkillsBackend", + "SearchSpaceSkillsBackend", + "build_skills_backend_factory", + "default_skills_sources", +] diff --git a/surfsense_backend/app/agents/new_chat/middleware/tool_call_repair.py b/surfsense_backend/app/agents/new_chat/middleware/tool_call_repair.py new file mode 100644 index 000000000..9f81a168b --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/middleware/tool_call_repair.py @@ -0,0 +1,193 @@ +""" +ToolCallNameRepairMiddleware — two-stage tool-name repair. + +Operation: +1. **Stage 1 — lowercase repair:** if a tool call's ``name`` is not in + the registry but ``name.lower()`` is, rewrite in place. Catches + models that emit ``Search`` instead of ``search``. +2. **Stage 2 — invalid fallback:** if still unmatched, rewrite the call + to ``invalid`` with ``args={"tool": original_name, "error": }`` + so the registered :func:`invalid_tool` returns the error to the model + for self-correction. + +Ported from OpenCode's ``packages/opencode/src/session/llm.ts:339-358`` ++ ``packages/opencode/src/tool/invalid.ts``. LangChain has no equivalent: +:class:`deepagents.middleware.PatchToolCallsMiddleware` patches +*dangling* tool calls (no matching ToolMessage) but does nothing about +wrong names, and the model framework's default behavior on an unknown +name is to crash the turn rather than route to a self-correction +fallback. +""" + +from __future__ import annotations + +import difflib +import logging +from typing import Any + +from langchain.agents.middleware.types import ( + AgentMiddleware, + AgentState, + ContextT, + ResponseT, +) +from langchain_core.messages import AIMessage +from langgraph.runtime import Runtime + +from app.agents.new_chat.tools.invalid_tool import INVALID_TOOL_NAME + +logger = logging.getLogger(__name__) + + +def _coerce_existing_tool_call(call: Any) -> dict[str, Any]: + """Normalize a tool call entry to a mutable dict.""" + if isinstance(call, dict): + return dict(call) + return { + "name": getattr(call, "name", None), + "args": getattr(call, "args", {}), + "id": getattr(call, "id", None), + "type": "tool_call", + } + + +class ToolCallNameRepairMiddleware( + AgentMiddleware[AgentState[ResponseT], ContextT, ResponseT] +): + """Two-stage tool-name repair on the most recent ``AIMessage``. + + Args: + registered_tool_names: Set of canonically-registered tool names. + ``invalid`` should be in this set so the fallback dispatches. + fuzzy_match_threshold: Optional ``difflib`` ratio (0-1) for the + fuzzy-match step that runs *between* lowercase and invalid. + Set to ``None`` to disable fuzzy matching (default in + OpenCode; we mirror that to avoid silent rewrites). + """ + + def __init__( + self, + *, + registered_tool_names: set[str], + fuzzy_match_threshold: float | None = 0.85, + ) -> None: + super().__init__() + self._registered = set(registered_tool_names) + self._registered_lower = {name.lower(): name for name in self._registered} + self._fuzzy_threshold = fuzzy_match_threshold + self.tools = [] + + def _registered_for_runtime(self, runtime: Runtime[ContextT]) -> set[str]: + """Allow runtime overrides to expand the set (e.g. dynamic MCP tools).""" + ctx_tools = getattr(runtime.context, "registered_tool_names", None) + if isinstance(ctx_tools, set | frozenset): + return self._registered | set(ctx_tools) + if isinstance(ctx_tools, list | tuple): + return self._registered | set(ctx_tools) + return self._registered + + def _repair_one( + self, + call: dict[str, Any], + registered: set[str], + ) -> dict[str, Any]: + name = call.get("name") + if not isinstance(name, str): + return call + + if name in registered: + return call + + # Stage 1 — lowercase + lowered = name.lower() + if lowered in registered: + call["name"] = lowered + metadata = dict(call.get("response_metadata") or {}) + metadata.setdefault("repair", "lowercase") + call["response_metadata"] = metadata + return call + + # Optional fuzzy step (off by default — see class docstring) + if self._fuzzy_threshold is not None: + close = difflib.get_close_matches( + name, registered, n=1, cutoff=self._fuzzy_threshold + ) + if close: + call["name"] = close[0] + metadata = dict(call.get("response_metadata") or {}) + metadata.setdefault("repair", f"fuzzy:{name}->{close[0]}") + call["response_metadata"] = metadata + return call + + # Stage 2 — invalid fallback + if INVALID_TOOL_NAME in registered: + original_args = call.get("args") or {} + error_msg = ( + f"Tool name '{name}' is not registered. " + f"Original arguments were: {original_args!r}." + ) + call["name"] = INVALID_TOOL_NAME + call["args"] = {"tool": name, "error": error_msg} + metadata = dict(call.get("response_metadata") or {}) + metadata.setdefault("repair", f"invalid_fallback:{name}") + call["response_metadata"] = metadata + else: + logger.warning( + "Could not repair unknown tool call %r; 'invalid' tool not registered", + name, + ) + return call + + def _maybe_repair( + self, + message: AIMessage, + registered: set[str], + ) -> AIMessage | None: + if not message.tool_calls: + return None + + new_calls: list[dict[str, Any]] = [] + any_changed = False + for raw in message.tool_calls: + call = _coerce_existing_tool_call(raw) + before = (call.get("name"), call.get("args")) + repaired = self._repair_one(call, registered) + after = (repaired.get("name"), repaired.get("args")) + if before != after: + any_changed = True + new_calls.append(repaired) + + if not any_changed: + return None + + return message.model_copy(update={"tool_calls": new_calls}) + + def after_model( # type: ignore[override] + self, + state: AgentState[ResponseT], + runtime: Runtime[ContextT], + ) -> dict[str, Any] | None: + messages = state.get("messages") or [] + if not messages: + return None + last = messages[-1] + if not isinstance(last, AIMessage): + return None + + registered = self._registered_for_runtime(runtime) + repaired = self._maybe_repair(last, registered) + if repaired is None: + return None + return {"messages": [repaired]} + + async def aafter_model( # type: ignore[override] + self, + state: AgentState[ResponseT], + runtime: Runtime[ContextT], + ) -> dict[str, Any] | None: + return self.after_model(state, runtime) + + +__all__ = [ + "ToolCallNameRepairMiddleware", +] diff --git a/surfsense_backend/app/agents/new_chat/path_resolver.py b/surfsense_backend/app/agents/new_chat/path_resolver.py new file mode 100644 index 000000000..861f48ee7 --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/path_resolver.py @@ -0,0 +1,351 @@ +"""Canonical virtual-path resolver for SurfSense knowledge-base documents. + +This module is the single source of truth for mapping ``Document`` rows to +virtual paths under ``/documents/`` and back. It is used by: + +* :class:`KnowledgeTreeMiddleware` (rendering the workspace tree) +* :class:`KnowledgePriorityMiddleware` (computing priority paths) +* :class:`KBPostgresBackend` (``als_info`` / ``aread`` / move operations) +* :class:`KnowledgeBasePersistenceMiddleware` (resolving moves and creates) + +Centralising the logic ensures that title-collision suffixes, folder paths, +and ``unique_identifier_hash`` lookups never drift between renders and +commits. +""" + +from __future__ import annotations + +import re +from dataclasses import dataclass, field + +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession + +from app.db import Document, DocumentType, Folder +from app.utils.document_converters import generate_unique_identifier_hash + +DOCUMENTS_ROOT = "/documents" +"""Root virtual folder for all KB documents.""" + +_INVALID_FILENAME_CHARS = re.compile(r"[\\/:*?\"<>|]+") +_WHITESPACE_RUN = re.compile(r"\s+") + + +def safe_filename(value: str, *, fallback: str = "untitled.xml") -> str: + """Convert arbitrary text into a filesystem-safe ``.xml`` filename.""" + name = _INVALID_FILENAME_CHARS.sub("_", value).strip() + name = _WHITESPACE_RUN.sub(" ", name) + if not name: + name = fallback + if len(name) > 180: + name = name[:180].rstrip() + if not name.lower().endswith(".xml"): + name = f"{name}.xml" + return name + + +def safe_folder_segment(value: str, *, fallback: str = "folder") -> str: + """Sanitize a single folder name into a path-safe segment.""" + name = _INVALID_FILENAME_CHARS.sub("_", value).strip() + name = _WHITESPACE_RUN.sub(" ", name) + if not name: + return fallback + if len(name) > 180: + name = name[:180].rstrip() + return name + + +def _suffix_with_doc_id(filename: str, doc_id: int | None) -> str: + if doc_id is None: + return filename + if not filename.lower().endswith(".xml"): + return f"{filename} ({doc_id}).xml" + stem = filename[:-4] + return f"{stem} ({doc_id}).xml" + + +_SUFFIX_PATTERN = re.compile(r"\s\((\d+)\)\.xml$", re.IGNORECASE) + + +def parse_doc_id_suffix(filename: str) -> tuple[str, int | None]: + """Strip a trailing ``" ().xml"`` suffix; return ``(stem, doc_id)``. + + If no suffix is present, returns ``(stem_without_xml_extension, None)``. + """ + match = _SUFFIX_PATTERN.search(filename) + if match: + doc_id = int(match.group(1)) + stem = filename[: match.start()] + return stem, doc_id + if filename.lower().endswith(".xml"): + return filename[:-4], None + return filename, None + + +@dataclass +class PathIndex: + """In-memory occupancy snapshot used by :func:`doc_to_virtual_path`. + + Built once per call site so collision handling is deterministic and so + we don't perform N folder lookups per render. + """ + + folder_paths: dict[int, str] = field(default_factory=dict) + """``Folder.id`` -> absolute virtual folder path under ``/documents``.""" + + occupants: dict[str, int] = field(default_factory=dict) + """virtual path -> ``Document.id`` already occupying that path (this render).""" + + +async def _build_folder_paths( + session: AsyncSession, + search_space_id: int, +) -> dict[int, str]: + """Compute ``Folder.id`` -> absolute virtual path under ``/documents``.""" + result = await session.execute( + select(Folder.id, Folder.name, Folder.parent_id).where( + Folder.search_space_id == search_space_id + ) + ) + rows = result.all() + by_id = {row.id: {"name": row.name, "parent_id": row.parent_id} for row in rows} + cache: dict[int, str] = {} + + def resolve(folder_id: int) -> str: + if folder_id in cache: + return cache[folder_id] + parts: list[str] = [] + cursor: int | None = folder_id + visited: set[int] = set() + while cursor is not None and cursor in by_id and cursor not in visited: + visited.add(cursor) + entry = by_id[cursor] + parts.append(safe_folder_segment(str(entry["name"]))) + cursor = entry["parent_id"] + parts.reverse() + path = f"{DOCUMENTS_ROOT}/" + "/".join(parts) if parts else DOCUMENTS_ROOT + cache[folder_id] = path + return path + + for folder_id in by_id: + resolve(folder_id) + return cache + + +async def build_path_index( + session: AsyncSession, + search_space_id: int, + *, + populate_occupants: bool = True, +) -> PathIndex: + """Build a :class:`PathIndex` for a search space. + + ``populate_occupants`` controls whether the occupancy map is pre-seeded + from existing ``Document`` rows. Most callers want this so that + :func:`doc_to_virtual_path` can detect collisions across the whole space; + the persistence middleware sets this to ``False`` when it is iterating to + decide where to place fresh documents. + """ + folder_paths = await _build_folder_paths(session, search_space_id) + occupants: dict[str, int] = {} + if populate_occupants: + rows = await session.execute( + select(Document.id, Document.title, Document.folder_id).where( + Document.search_space_id == search_space_id, + ) + ) + for row in rows.all(): + base = folder_paths.get(row.folder_id, DOCUMENTS_ROOT) + filename = safe_filename(str(row.title or "untitled")) + path = f"{base}/{filename}" + if path in occupants and occupants[path] != row.id: + path = f"{base}/{_suffix_with_doc_id(filename, row.id)}" + occupants[path] = row.id + return PathIndex(folder_paths=folder_paths, occupants=occupants) + + +def doc_to_virtual_path( + *, + doc_id: int | None, + title: str, + folder_id: int | None, + index: PathIndex, +) -> str: + """Return the canonical virtual path for a document. + + Mutates ``index.occupants`` so subsequent calls see this assignment and + deterministically pick a different suffix for the next colliding doc. + """ + base = index.folder_paths.get(folder_id, DOCUMENTS_ROOT) + filename = safe_filename(str(title or "untitled")) + path = f"{base}/{filename}" + occupant = index.occupants.get(path) + if occupant is not None and occupant != doc_id: + path = f"{base}/{_suffix_with_doc_id(filename, doc_id)}" + if doc_id is not None: + index.occupants[path] = doc_id + return path + + +async def virtual_path_to_doc( + session: AsyncSession, + *, + search_space_id: int, + virtual_path: str, +) -> Document | None: + """Resolve a virtual path back to a ``Document`` row. + + Resolution order: + 1. ``Document.unique_identifier_hash`` lookup (fast path for paths created + by SurfSense itself — every NOTE write goes through this hash). + 2. If the basename carries a ``" ().xml"`` disambiguation suffix, + try a direct id lookup constrained to the search space. + 3. Title-from-basename + folder-resolution lookup as a last resort. + """ + if not virtual_path or not virtual_path.startswith(DOCUMENTS_ROOT): + return None + + unique_hash = generate_unique_identifier_hash( + DocumentType.NOTE, + virtual_path, + search_space_id, + ) + result = await session.execute( + select(Document).where( + Document.search_space_id == search_space_id, + Document.unique_identifier_hash == unique_hash, + ) + ) + document = result.scalar_one_or_none() + if document is not None: + return document + + rel = virtual_path[len(DOCUMENTS_ROOT) :].lstrip("/") + if not rel: + return None + parts = [p for p in rel.split("/") if p] + if not parts: + return None + basename = parts[-1] + folder_parts = parts[:-1] + + stem, suffix_doc_id = parse_doc_id_suffix(basename) + if suffix_doc_id is not None: + result = await session.execute( + select(Document).where( + Document.search_space_id == search_space_id, + Document.id == suffix_doc_id, + ) + ) + document = result.scalar_one_or_none() + if document is not None: + return document + + folder_id = await _resolve_folder_id( + session, search_space_id=search_space_id, folder_parts=folder_parts + ) + title_candidates: list[str] = [] + raw_title = stem + title_candidates.append(raw_title) + if raw_title.endswith(".xml"): + title_candidates.append(raw_title[:-4]) + + for candidate in dict.fromkeys(title_candidates): + if not candidate: + continue + query = select(Document).where( + Document.search_space_id == search_space_id, + Document.title == candidate, + ) + if folder_id is None: + query = query.where(Document.folder_id.is_(None)) + else: + query = query.where(Document.folder_id == folder_id) + result = await session.execute(query) + document = result.scalars().first() + if document is not None: + return document + + # Fallback: title-as-string lookup misses when the real DB title contains + # characters that ``safe_filename`` lossily replaces (``:``, ``/``, ``*``, + # etc.) — common for connector-imported docs (Google Calendar/Drive etc.). + # The workspace tree shows the lossy filename, so the agent passes that + # filename back here. Scan all documents in the resolved folder and match + # by ``safe_filename(title)`` to recover the original document. + folder_scan = select(Document).where( + Document.search_space_id == search_space_id, + ) + if folder_id is None: + folder_scan = folder_scan.where(Document.folder_id.is_(None)) + else: + folder_scan = folder_scan.where(Document.folder_id == folder_id) + result = await session.execute(folder_scan) + for candidate_doc in result.scalars().all(): + encoded = safe_filename(str(candidate_doc.title or "untitled")) + if encoded == basename: + return candidate_doc + return None + + +async def _resolve_folder_id( + session: AsyncSession, + *, + search_space_id: int, + folder_parts: list[str], +) -> int | None: + """Look up the leaf folder id for a chain of folder names; return ``None`` if missing.""" + if not folder_parts: + return None + parent_id: int | None = None + for raw in folder_parts: + name = safe_folder_segment(raw) + query = select(Folder.id).where( + Folder.search_space_id == search_space_id, + Folder.name == name, + ) + if parent_id is None: + query = query.where(Folder.parent_id.is_(None)) + else: + query = query.where(Folder.parent_id == parent_id) + result = await session.execute(query) + row = result.first() + if row is None: + return None + parent_id = row[0] + return parent_id + + +def parse_documents_path(virtual_path: str) -> tuple[list[str], str]: + """Parse a ``/documents/...`` path into ``(folder_parts, document_title)``. + + The title has any ``.xml`` extension and trailing ``" ()"`` + disambiguation suffix stripped. + """ + if not virtual_path or not virtual_path.startswith(DOCUMENTS_ROOT): + return [], "" + rel = virtual_path[len(DOCUMENTS_ROOT) :].strip("/") + if not rel: + return [], "" + parts = [p for p in rel.split("/") if p] + if not parts: + return [], "" + folder_parts = parts[:-1] + basename = parts[-1] + stem, _ = parse_doc_id_suffix(basename) + title = stem + if title.endswith(".xml"): + title = title[:-4] + return folder_parts, title + + +__all__ = [ + "DOCUMENTS_ROOT", + "PathIndex", + "build_path_index", + "doc_to_virtual_path", + "parse_doc_id_suffix", + "parse_documents_path", + "safe_filename", + "safe_folder_segment", + "virtual_path_to_doc", +] diff --git a/surfsense_backend/app/agents/new_chat/permissions.py b/surfsense_backend/app/agents/new_chat/permissions.py new file mode 100644 index 000000000..523deb11f --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/permissions.py @@ -0,0 +1,203 @@ +""" +Wildcard pattern matching + rule evaluation for the SurfSense permission system. + +Ported from OpenCode's ``packages/opencode/src/permission/evaluate.ts`` and +``packages/opencode/src/util/wildcard.ts``. LangChain has no rule-based +permission evaluator, so we keep OpenCode's semantics intact: + +- ``Wildcard.match`` matches both the ``permission`` and the ``pattern`` + fields of a rule against the requested ``(permission, pattern)`` pair. + ``*`` matches any segment, ``**`` matches across separators. +- The evaluator runs ``findLast`` over the **flattened** list of rules + from all rulesets — last matching rule wins. +- The default fallback is ``ask`` (NOT deny), matching OpenCode. +- Multi-pattern requests AND together: if ANY pattern resolves to + ``deny``, the whole request is denied; if ANY needs ``ask``, an + interrupt is raised; only when all patterns ``allow`` does the + request proceed. +""" + +from __future__ import annotations + +import re +from collections.abc import Iterable +from dataclasses import dataclass, field +from typing import Literal + +RuleAction = Literal["allow", "deny", "ask"] + + +@dataclass(frozen=True) +class Rule: + """A single permission rule. + + Attributes: + permission: A wildcard-matched permission identifier + (e.g. ``"edit"``, ``"linear_*"``, ``"mcp:*"``, + ``"doom_loop"``). Anchored at start AND end of the input. + pattern: A wildcard-matched pattern over the request payload + (e.g. ``"/documents/secrets/**"``, ``"page_id=123"``, + ``"*"``). Anchored at start AND end. + action: One of ``"allow"`` / ``"deny"`` / ``"ask"``. + """ + + permission: str + pattern: str + action: RuleAction + + +@dataclass +class Ruleset: + """A list of rules with an associated origin used for debugging.""" + + rules: list[Rule] = field(default_factory=list) + origin: str = "unknown" # e.g. "defaults", "global", "space", "thread", "runtime" + + +# ----------------------------------------------------------------------------- +# Wildcard matcher +# ----------------------------------------------------------------------------- + + +_GLOB_TOKEN = re.compile(r"\*\*|\*|[^*]+") + + +def _wildcard_to_regex(pattern: str) -> re.Pattern[str]: + """Translate an opencode-style wildcard pattern to a compiled regex. + + Rules: + - ``**`` matches any sequence of any characters (including separators). + - ``*`` matches any sequence of characters that does **not** include + the path separator ``/`` — same as glob. + - All other characters match literally. + - The pattern is anchored at both ends (``^...$``). + """ + parts: list[str] = ["^"] + for token in _GLOB_TOKEN.findall(pattern): + if token == "**": + parts.append(r".*") + elif token == "*": + parts.append(r"[^/]*") + else: + parts.append(re.escape(token)) + parts.append("$") + return re.compile("".join(parts)) + + +_REGEX_CACHE: dict[str, re.Pattern[str]] = {} + + +def wildcard_match(value: str, pattern: str) -> bool: + """Return True if ``value`` matches the wildcard ``pattern``. + + Special case: a bare ``"*"`` pattern matches any value, including + those containing ``/`` separators. This mirrors opencode's + ``Wildcard.match`` short-circuit and matches the convention that + ``pattern="*"`` means "any pattern" in permission rules. + """ + if pattern == "*": + return True + compiled = _REGEX_CACHE.get(pattern) + if compiled is None: + compiled = _wildcard_to_regex(pattern) + _REGEX_CACHE[pattern] = compiled + return compiled.match(value) is not None + + +# ----------------------------------------------------------------------------- +# Evaluator +# ----------------------------------------------------------------------------- + + +def evaluate( + permission: str, + pattern: str, + *rulesets: Ruleset | Iterable[Rule], +) -> Rule: + """Find the last rule matching ``(permission, pattern)`` from ``rulesets``. + + Mirrors opencode ``permission/evaluate.ts:9-15`` precisely: + - Flatten rulesets in argument order. + - Walk the flat list **in reverse**. + - First reverse-match wins (i.e. the last specified rule wins). + - When no rule matches, default to ``Rule(permission, "*", "ask")``. + + Args: + permission: The permission identifier being requested + (e.g. tool name, ``"edit"``, ``"doom_loop"``). + pattern: The request-specific pattern (e.g. file path, + primary arg value). Use ``"*"`` when no specific pattern + applies. + *rulesets: Layered rulesets, applied earliest to latest. Later + rulesets override earlier ones. + + Returns: + The matched :class:`Rule`, or the default ask fallback. + """ + flat: list[Rule] = [] + for rs in rulesets: + if isinstance(rs, Ruleset): + flat.extend(rs.rules) + else: + flat.extend(rs) + + for rule in reversed(flat): + if wildcard_match(permission, rule.permission) and wildcard_match( + pattern, rule.pattern + ): + return rule + + return Rule(permission=permission, pattern="*", action="ask") + + +def evaluate_many( + permission: str, + patterns: Iterable[str], + *rulesets: Ruleset | Iterable[Rule], +) -> list[Rule]: + """Evaluate ``permission`` against each of ``patterns`` (multi-pattern AND). + + Returns the list of resolved rules in the same order as ``patterns``. + The caller is responsible for combining the results — opencode-style + multi-pattern AND collapses ``deny`` first, then ``ask``, then + ``allow``. + """ + return [evaluate(permission, p, *rulesets) for p in patterns] + + +def aggregate_action(rules: Iterable[Rule]) -> RuleAction: + """Collapse a list of per-pattern rules into one action. + + Order: + 1. If any rule is ``deny`` -> ``deny``. + 2. Else if any rule is ``ask`` -> ``ask``. + 3. Else if at least one rule is ``allow`` -> ``allow``. + 4. Else (empty input) -> ``ask`` (safe default mirroring ``evaluate``). + + Mirrors opencode's behavior in ``permission/index.ts:180-272``. + """ + saw_ask = False + saw_allow = False + for rule in rules: + if rule.action == "deny": + return "deny" + if rule.action == "ask": + saw_ask = True + elif rule.action == "allow": + saw_allow = True + if saw_ask: + return "ask" + if saw_allow: + return "allow" + return "ask" + + +__all__ = [ + "Rule", + "RuleAction", + "Ruleset", + "aggregate_action", + "evaluate", + "evaluate_many", + "wildcard_match", +] diff --git a/surfsense_backend/app/agents/new_chat/plugin_loader.py b/surfsense_backend/app/agents/new_chat/plugin_loader.py new file mode 100644 index 000000000..c52620d40 --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/plugin_loader.py @@ -0,0 +1,158 @@ +"""Entry-point based plugin loader for SurfSense agent middleware. + +LangChain's :class:`AgentMiddleware` ABC already covers the practical +surface most plugins need (``before_agent`` / ``before_model`` / +``wrap_tool_call`` / their async counterparts), so a SurfSense-specific +plugin protocol would be redundant. We just need a way to discover and +admit third-party middleware safely. + +A plugin is therefore just an installable Python package that registers a +factory callable under the ``surfsense.plugins`` entry-point group: + +.. code-block:: toml + + # in a plugin package's pyproject.toml + [project.entry-points."surfsense.plugins"] + year_substituter = "my_plugin:make_middleware" + +The factory has the signature ``Callable[[PluginContext], AgentMiddleware]``. +It receives a small, sanitized :class:`PluginContext` with the IDs and the +LLM the plugin is allowed to talk to — and **never** raw secrets, DB +sessions, or other connectors. + +## Trust model + +Plugins are loaded **only if** their entry-point ``name`` appears in +``allowed_plugins`` (admin-controlled, sourced from +``global_llm_config.yaml`` or :func:`load_allowed_plugin_names_from_env`). +There is **no env-driven auto-load**. A plugin failure is logged and +isolated; it does not break agent construction. +""" + +from __future__ import annotations + +import logging +import os +from collections.abc import Iterable +from importlib.metadata import entry_points +from typing import TYPE_CHECKING + +from langchain.agents.middleware import AgentMiddleware + +if TYPE_CHECKING: # pragma: no cover - type-only + from langchain_core.language_models import BaseChatModel + + from app.db import ChatVisibility + + +logger = logging.getLogger(__name__) + + +PLUGIN_ENTRY_POINT_GROUP = "surfsense.plugins" + + +class PluginContext(dict): + """Sanitized DI bag handed to each plugin factory. + + Backed by ``dict`` so plugins can inspect the keys they care about + without coupling to a concrete dataclass shape. Required keys: + + * ``search_space_id`` (int) + * ``user_id`` (str | None) + * ``thread_visibility`` (:class:`app.db.ChatVisibility`) + * ``llm`` (:class:`langchain_core.language_models.BaseChatModel`) + + The context **never** carries DB sessions, raw secrets, or other + connectors. If a future plugin genuinely needs DB access, that + integration goes through a rate-limited service interface, not + through this bag. + """ + + @classmethod + def build( + cls, + *, + search_space_id: int, + user_id: str | None, + thread_visibility: ChatVisibility, + llm: BaseChatModel, + ) -> PluginContext: + return cls( + search_space_id=search_space_id, + user_id=user_id, + thread_visibility=thread_visibility, + llm=llm, + ) + + +def load_plugin_middlewares( + ctx: PluginContext, + allowed_plugin_names: Iterable[str], +) -> list[AgentMiddleware]: + """Discover, allowlist-filter, and instantiate plugin middleware. + + For each entry-point in :data:`PLUGIN_ENTRY_POINT_GROUP` whose name is + in ``allowed_plugin_names``, load the factory and call it with ``ctx``. + The factory's return value must be an :class:`AgentMiddleware` instance; + anything else is logged and skipped. + + Errors are isolated — a plugin that raises during ``ep.load()`` or + factory invocation is logged at ``ERROR`` and ignored. Agent + construction continues with whatever plugins did succeed. + """ + allowed = {name for name in allowed_plugin_names if name} + if not allowed: + return [] + + out: list[AgentMiddleware] = [] + try: + eps = entry_points(group=PLUGIN_ENTRY_POINT_GROUP) + except Exception: # pragma: no cover - defensive (entry_points is robust) + logger.exception("Failed to enumerate plugin entry points") + return [] + + for ep in eps: + if ep.name not in allowed: + logger.info("Skipping non-allowlisted plugin %s", ep.name) + continue + try: + factory = ep.load() + except Exception: + logger.exception("Failed to load plugin %s", ep.name) + continue + try: + mw = factory(ctx) + except Exception: + logger.exception("Plugin %s factory raised", ep.name) + continue + if not isinstance(mw, AgentMiddleware): + logger.warning( + "Plugin %s returned %s, expected AgentMiddleware; skipping", + ep.name, + type(mw).__name__, + ) + continue + out.append(mw) + logger.info("Loaded plugin %s as %s", ep.name, type(mw).__name__) + return out + + +def load_allowed_plugin_names_from_env() -> set[str]: + """Read ``SURFSENSE_ALLOWED_PLUGINS`` (comma-separated) into a set. + + Provided as a thin convenience for deployments that don't surface plugins + through ``global_llm_config.yaml`` yet. Whitespace is stripped and empty + entries are dropped. + """ + raw = os.environ.get("SURFSENSE_ALLOWED_PLUGINS", "").strip() + if not raw: + return set() + return {token.strip() for token in raw.split(",") if token.strip()} + + +__all__ = [ + "PLUGIN_ENTRY_POINT_GROUP", + "PluginContext", + "load_allowed_plugin_names_from_env", + "load_plugin_middlewares", +] diff --git a/surfsense_backend/app/agents/new_chat/plugins/__init__.py b/surfsense_backend/app/agents/new_chat/plugins/__init__.py new file mode 100644 index 000000000..cef6bd367 --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/plugins/__init__.py @@ -0,0 +1,6 @@ +"""Reference plugins bundled with SurfSense. + +These plugins are intentionally small and demonstrative. They are NOT +auto-loaded — they ship as examples that a deployment can opt into via +``global_llm_config.yaml`` or ``SURFSENSE_ALLOWED_PLUGINS``. +""" diff --git a/surfsense_backend/app/agents/new_chat/plugins/year_substituter.py b/surfsense_backend/app/agents/new_chat/plugins/year_substituter.py new file mode 100644 index 000000000..2b7781b90 --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/plugins/year_substituter.py @@ -0,0 +1,88 @@ +"""Reference plugin: substitute ``{{year}}`` in tool descriptions. + +Demonstrates the :meth:`AgentMiddleware.awrap_tool_call` hook -- the +plugin sees every tool invocation and can rewrite the request *or* the +result. This particular plugin is read-only and only transforms the +*description* the user might see in error messages (no request +mutation). + +The plugin is built as a factory function so the entry-point loader can +inject :class:`PluginContext` (containing the agent's LLM, search-space +ID, etc.). The factory signature +``Callable[[PluginContext], AgentMiddleware]`` is the only contract -- +SurfSense doesn't define a custom plugin protocol on top of LangChain's +:class:`AgentMiddleware`. + +Wire-up in ``pyproject.toml`` (illustrative; the in-repo plugin doesn't +need this -- it's already on the import path):: + + [project.entry-points."surfsense.plugins"] + year_substituter = "app.agents.new_chat.plugins.year_substituter:make_middleware" +""" + +from __future__ import annotations + +import logging +from collections.abc import Awaitable, Callable +from datetime import UTC, datetime +from typing import TYPE_CHECKING, Any + +from langchain.agents.middleware import AgentMiddleware + +if TYPE_CHECKING: # pragma: no cover - type-only + from langchain.agents.middleware.types import ToolCallRequest + from langchain_core.messages import ToolMessage + from langgraph.types import Command + + from app.agents.new_chat.plugin_loader import PluginContext + + +logger = logging.getLogger(__name__) + + +class _YearSubstituterMiddleware(AgentMiddleware): + """Replace ``{{year}}`` in the result text with the current UTC year.""" + + tools = () + + def __init__(self, year: int | None = None) -> None: + super().__init__() + self._year = str(year if year is not None else datetime.now(UTC).year) + + async def awrap_tool_call( + self, + request: ToolCallRequest, + handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command[Any]]], + ) -> ToolMessage | Command[Any]: + result = await handler(request) + try: + from langchain_core.messages import ToolMessage + + if ( + isinstance(result, ToolMessage) + and isinstance(result.content, str) + and "{{year}}" in result.content + ): + new_text = result.content.replace("{{year}}", self._year) + result = ToolMessage( + content=new_text, + tool_call_id=result.tool_call_id, + id=result.id, + name=result.name, + status=result.status, + artifact=result.artifact, + ) + except Exception: # pragma: no cover - defensive + logger.exception("year_substituter plugin failed; passing original result") + return result + + +def make_middleware(ctx: PluginContext) -> AgentMiddleware: + """Plugin factory used by :func:`load_plugin_middlewares`.""" + # Plugin is intentionally small so it has no state to threading-protect + # and ignores ``ctx`` beyond demonstrating that the loader passes it in. + _ = ctx + return _YearSubstituterMiddleware() + + +__all__ = ["make_middleware"] diff --git a/surfsense_backend/app/agents/new_chat/prompts/__init__.py b/surfsense_backend/app/agents/new_chat/prompts/__init__.py new file mode 100644 index 000000000..c91bb8a0b --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/prompts/__init__.py @@ -0,0 +1,7 @@ +"""SurfSense agent prompt fragments. + +The prompt is composed at runtime by :mod:`composer` from the markdown +fragments under ``base/``, ``providers/``, ``tools/``, ``examples/``, and +``routing/``. ``system_prompt.py`` is now a thin wrapper that delegates +to :func:`composer.compose_system_prompt`. +""" diff --git a/surfsense_backend/app/agents/new_chat/prompts/base/__init__.py b/surfsense_backend/app/agents/new_chat/prompts/base/__init__.py new file mode 100644 index 000000000..8b1378917 --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/prompts/base/__init__.py @@ -0,0 +1 @@ + diff --git a/surfsense_backend/app/agents/new_chat/prompts/base/agent_private.md b/surfsense_backend/app/agents/new_chat/prompts/base/agent_private.md new file mode 100644 index 000000000..88554ad4e --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/prompts/base/agent_private.md @@ -0,0 +1,7 @@ +You are SurfSense, a reasoning and acting AI agent designed to answer user questions using the user's personal knowledge base. + +Today's date (UTC): {resolved_today} + +When writing mathematical formulas or equations, ALWAYS use LaTeX notation. NEVER use backtick code spans or Unicode symbols for math. + +NEVER expose internal tool parameter names, backend IDs, or implementation details to the user. Always use natural, user-friendly language instead. diff --git a/surfsense_backend/app/agents/new_chat/prompts/base/agent_team.md b/surfsense_backend/app/agents/new_chat/prompts/base/agent_team.md new file mode 100644 index 000000000..5fd56ae1b --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/prompts/base/agent_team.md @@ -0,0 +1,9 @@ +You are SurfSense, a reasoning and acting AI agent designed to answer questions in this team space using the team's shared knowledge base. + +In this team thread, each message is prefixed with **[DisplayName of the author]**. Use this to attribute and reference the author of anything in the discussion (who asked a question, made a suggestion, or contributed an idea) and to cite who said what in your answers. + +Today's date (UTC): {resolved_today} + +When writing mathematical formulas or equations, ALWAYS use LaTeX notation. NEVER use backtick code spans or Unicode symbols for math. + +NEVER expose internal tool parameter names, backend IDs, or implementation details to the user. Always use natural, user-friendly language instead. diff --git a/surfsense_backend/app/agents/new_chat/prompts/base/citations_off.md b/surfsense_backend/app/agents/new_chat/prompts/base/citations_off.md new file mode 100644 index 000000000..8288886e9 --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/prompts/base/citations_off.md @@ -0,0 +1,16 @@ + +IMPORTANT: Citations are DISABLED for this configuration. + +DO NOT include any citations in your responses. Specifically: +1. Do NOT use the [citation:chunk_id] format anywhere in your response. +2. Do NOT reference document IDs, chunk IDs, or source IDs. +3. Simply provide the information naturally without any citation markers. +4. Write your response as if you're having a normal conversation, incorporating the information from your knowledge seamlessly. + +When answering questions based on documents from the knowledge base: +- Present the information directly and confidently +- Do not mention that information comes from specific documents or chunks +- Integrate facts naturally into your response without attribution markers + +Your goal is to provide helpful, informative answers in a clean, readable format without any citation notation. + diff --git a/surfsense_backend/app/agents/new_chat/prompts/base/citations_on.md b/surfsense_backend/app/agents/new_chat/prompts/base/citations_on.md new file mode 100644 index 000000000..56291bf3e --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/prompts/base/citations_on.md @@ -0,0 +1,90 @@ + +CRITICAL CITATION REQUIREMENTS: + +1. For EVERY piece of information you include from the documents, add a citation in the format [citation:chunk_id] where chunk_id is the exact value from the `` tag inside ``. +2. Make sure ALL factual statements from the documents have proper citations. +3. If multiple chunks support the same point, include all relevant citations [citation:chunk_id1], [citation:chunk_id2]. +4. You MUST use the exact chunk_id values from the `` attributes. Do not create your own citation numbers. +5. Every citation MUST be in the format [citation:chunk_id] where chunk_id is the exact chunk id value. +6. Never modify or change the chunk_id - always use the original values exactly as provided in the chunk tags. +7. Do not return citations as clickable links. +8. Never format citations as markdown links like "([citation:5](https://example.com))". Always use plain square brackets only. +9. Citations must ONLY appear as [citation:chunk_id] or [citation:chunk_id1], [citation:chunk_id2] format - never with parentheses, hyperlinks, or other formatting. +10. Never make up chunk IDs. Only use chunk_id values that are explicitly provided in the `` tags. +11. If you are unsure about a chunk_id, do not include a citation rather than guessing or making one up. + + +The documents you receive are structured like this: + +**Knowledge base documents (numeric chunk IDs):** + + + 42 + GITHUB_CONNECTOR + <![CDATA[Some repo / file / issue title]]> + + + + + + + + + + +**Web search results (URL chunk IDs):** + + + WEB_SEARCH + <![CDATA[Some web search result]]> + + + + + + + + +IMPORTANT: You MUST cite using the EXACT chunk ids from the `` tags. +- For knowledge base documents, chunk ids are numeric (e.g. 123, 124) or prefixed (e.g. doc-45). +- For live web search results, chunk ids are URLs (e.g. https://example.com/article). +Do NOT cite document_id. Always use the chunk id. + + + +- Every fact from the documents must have a citation in the format [citation:chunk_id] where chunk_id is the EXACT id value from a `` tag +- Citations should appear at the end of the sentence containing the information they support +- Multiple citations should be separated by commas: [citation:chunk_id1], [citation:chunk_id2], [citation:chunk_id3] +- No need to return references section. Just citations in answer. +- NEVER create your own citation format - use the exact chunk_id values from the documents in the [citation:chunk_id] format +- NEVER format citations as clickable links or as markdown links like "([citation:5](https://example.com))". Always use plain square brackets only +- NEVER make up chunk IDs if you are unsure about the chunk_id. It is better to omit the citation than to guess +- Copy the EXACT chunk id from the XML - if it says ``, use [citation:doc-123] +- If the chunk id is a URL like ``, use [citation:https://example.com/page] + + + +CORRECT citation formats: +- [citation:5] (numeric chunk ID from knowledge base) +- [citation:doc-123] (for Surfsense documentation chunks) +- [citation:https://example.com/article] (URL chunk ID from web search results) +- [citation:chunk_id1], [citation:chunk_id2], [citation:chunk_id3] (multiple citations) + +INCORRECT citation formats (DO NOT use): +- Using parentheses and markdown links: ([citation:5](https://github.com/MODSetter/SurfSense)) +- Using parentheses around brackets: ([citation:5]) +- Using hyperlinked text: [link to source 5](https://example.com) +- Using footnote style: ... library¹ +- Making up source IDs when source_id is unknown +- Using old IEEE format: [1], [2], [3] +- Using source types instead of IDs: [citation:GITHUB_CONNECTOR] instead of [citation:5] + + + +Based on your GitHub repositories and video content, Python's asyncio library provides tools for writing concurrent code using the async/await syntax [citation:5]. It's particularly useful for I/O-bound and high-level structured network code [citation:5]. + +According to web search results, the key advantage of asyncio is that it can improve performance by allowing other code to run while waiting for I/O operations to complete [citation:https://docs.python.org/3/library/asyncio.html]. This makes it excellent for scenarios like web scraping, API calls, database operations, or any situation where your program spends time waiting for external resources. + +However, from your video learning, it's important to note that asyncio is not suitable for CPU-bound tasks as it runs on a single thread [citation:12]. For computationally intensive work, you'd want to use multiprocessing instead. + + diff --git a/surfsense_backend/app/agents/new_chat/prompts/base/kb_only_policy_private.md b/surfsense_backend/app/agents/new_chat/prompts/base/kb_only_policy_private.md new file mode 100644 index 000000000..9cc767e7e --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/prompts/base/kb_only_policy_private.md @@ -0,0 +1,15 @@ + +CRITICAL RULE — KNOWLEDGE BASE FIRST, NEVER DEFAULT TO GENERAL KNOWLEDGE: +- You MUST answer questions ONLY using information retrieved from the user's knowledge base, web search results, scraped webpages, or other tool outputs. +- You MUST NOT answer factual or informational questions from your own training data or general knowledge unless the user explicitly grants permission. +- If the knowledge base search returns no relevant results AND no other tool provides the answer, you MUST: + 1. Inform the user that you could not find relevant information in their knowledge base. + 2. Ask the user: "Would you like me to answer from my general knowledge instead?" + 3. ONLY provide a general-knowledge answer AFTER the user explicitly says yes. +- This policy does NOT apply to: + * Casual conversation, greetings, or meta-questions about SurfSense itself (e.g., "what can you do?") + * Formatting, summarization, or analysis of content already present in the conversation + * Following user instructions that are clearly task-oriented (e.g., "rewrite this in bullet points") + * Tool-usage actions like generating reports, podcasts, images, or scraping webpages + * Queries about services that have direct tools (Linear, ClickUp, Jira, Slack, Airtable) — see below + diff --git a/surfsense_backend/app/agents/new_chat/prompts/base/kb_only_policy_team.md b/surfsense_backend/app/agents/new_chat/prompts/base/kb_only_policy_team.md new file mode 100644 index 000000000..1d806dbae --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/prompts/base/kb_only_policy_team.md @@ -0,0 +1,15 @@ + +CRITICAL RULE — KNOWLEDGE BASE FIRST, NEVER DEFAULT TO GENERAL KNOWLEDGE: +- You MUST answer questions ONLY using information retrieved from the team's shared knowledge base, web search results, scraped webpages, or other tool outputs. +- You MUST NOT answer factual or informational questions from your own training data or general knowledge unless a team member explicitly grants permission. +- If the knowledge base search returns no relevant results AND no other tool provides the answer, you MUST: + 1. Inform the team that you could not find relevant information in the shared knowledge base. + 2. Ask: "Would you like me to answer from my general knowledge instead?" + 3. ONLY provide a general-knowledge answer AFTER a team member explicitly says yes. +- This policy does NOT apply to: + * Casual conversation, greetings, or meta-questions about SurfSense itself (e.g., "what can you do?") + * Formatting, summarization, or analysis of content already present in the conversation + * Following user instructions that are clearly task-oriented (e.g., "rewrite this in bullet points") + * Tool-usage actions like generating reports, podcasts, images, or scraping webpages + * Queries about services that have direct tools (Linear, ClickUp, Jira, Slack, Airtable) — see below + diff --git a/surfsense_backend/app/agents/new_chat/prompts/base/memory_protocol_private.md b/surfsense_backend/app/agents/new_chat/prompts/base/memory_protocol_private.md new file mode 100644 index 000000000..8f7da14f8 --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/prompts/base/memory_protocol_private.md @@ -0,0 +1,6 @@ + +IMPORTANT — After understanding each user message, ALWAYS check: does this message +reveal durable facts about the user (role, interests, preferences, projects, +background, or standing instructions)? If yes, you MUST call update_memory +alongside your normal response — do not defer this to a later turn. + diff --git a/surfsense_backend/app/agents/new_chat/prompts/base/memory_protocol_team.md b/surfsense_backend/app/agents/new_chat/prompts/base/memory_protocol_team.md new file mode 100644 index 000000000..61d89cc5d --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/prompts/base/memory_protocol_team.md @@ -0,0 +1,6 @@ + +IMPORTANT — After understanding each user message, ALWAYS check: does this message +reveal durable facts about the team (decisions, conventions, architecture, processes, +or key facts)? If yes, you MUST call update_memory alongside your normal response — +do not defer this to a later turn. + diff --git a/surfsense_backend/app/agents/new_chat/prompts/base/parameter_resolution.md b/surfsense_backend/app/agents/new_chat/prompts/base/parameter_resolution.md new file mode 100644 index 000000000..77be4d87c --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/prompts/base/parameter_resolution.md @@ -0,0 +1,39 @@ + +Some service tools require identifiers or context you do not have (account IDs, +workspace names, channel IDs, project keys, etc.). NEVER ask the user for raw +IDs or technical identifiers — they cannot memorise them. + +Instead, follow this discovery pattern: +1. Call a listing/discovery tool to find available options. +2. ONE result → use it silently, no question to the user. +3. MULTIPLE results → present the options by their display names and let the + user choose. Never show raw UUIDs — always use friendly names. + +Discovery tools by level: +- Which account/workspace? → get_connected_accounts("") +- Which Jira site (cloudId)? → getAccessibleAtlassianResources +- Which Jira project? → getVisibleJiraProjects (after resolving cloudId) +- Which Jira issue type? → getJiraProjectIssueTypesMetadata (after resolving project) +- Which channel? → slack_search_channels +- Which base? → list_bases +- Which table? → list_tables_for_base (after resolving baseId) +- Which task? → clickup_search +- Which issue? → list_issues (Linear) or searchJiraIssuesUsingJql (Jira) + +For Jira specifically: ALWAYS call getAccessibleAtlassianResources first to +obtain the cloudId, then pass it to other Jira tools. When creating an issue, +chain: getAccessibleAtlassianResources → getVisibleJiraProjects → createJiraIssue. +If there is only one option at each step, use it silently. If multiple, present +friendly names. + +Chain discovery when needed — e.g. for Airtable records: list_bases → pick +base → list_tables_for_base → pick table → list_records_for_table. + +MULTI-ACCOUNT TOOL NAMING: When the user has multiple accounts connected for +the same service, tool names are prefixed to avoid collisions — e.g. +linear_25_list_issues and linear_30_list_issues instead of two list_issues. +Each prefixed tool's description starts with [Account: ] so you +know which account it targets. Use get_connected_accounts("") to see +the full list of accounts with their connector IDs and display names. +When only one account is connected, tools have their normal unprefixed names. + diff --git a/surfsense_backend/app/agents/new_chat/prompts/base/tool_routing_private.md b/surfsense_backend/app/agents/new_chat/prompts/base/tool_routing_private.md new file mode 100644 index 000000000..ec667bf88 --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/prompts/base/tool_routing_private.md @@ -0,0 +1,16 @@ + +CRITICAL — You have direct tools for these services: Linear, ClickUp, Jira, Slack, Airtable. +Their data is NEVER in the knowledge base. You MUST call their tools immediately — never +say "I don't see it in the knowledge base" or ask the user if they want you to check. +Ignore any knowledge base results for these services. + +When to use which tool: +- Linear (issues) → list_issues, get_issue, save_issue (create/update) +- ClickUp (tasks) → clickup_search, clickup_get_task +- Jira (issues) → getAccessibleAtlassianResources (cloudId discovery), getVisibleJiraProjects (project discovery), getJiraProjectIssueTypesMetadata (issue type discovery), searchJiraIssuesUsingJql, createJiraIssue, editJiraIssue +- Slack (messages, channels) → slack_search_channels, slack_read_channel, slack_read_thread +- Airtable (bases, tables, records) → list_bases, list_tables_for_base, list_records_for_table +- Knowledge base content (Notion, GitHub, files, notes) → automatically searched +- Real-time public web data → call web_search +- Reading a specific webpage → call scrape_webpage + diff --git a/surfsense_backend/app/agents/new_chat/prompts/base/tool_routing_team.md b/surfsense_backend/app/agents/new_chat/prompts/base/tool_routing_team.md new file mode 100644 index 000000000..48b7a990b --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/prompts/base/tool_routing_team.md @@ -0,0 +1,16 @@ + +CRITICAL — You have direct tools for these services: Linear, ClickUp, Jira, Slack, Airtable. +Their data is NEVER in the knowledge base. You MUST call their tools immediately — never +say "I don't see it in the knowledge base" or ask if they want you to check. +Ignore any knowledge base results for these services. + +When to use which tool: +- Linear (issues) → list_issues, get_issue, save_issue (create/update) +- ClickUp (tasks) → clickup_search, clickup_get_task +- Jira (issues) → getAccessibleAtlassianResources (cloudId discovery), getVisibleJiraProjects (project discovery), getJiraProjectIssueTypesMetadata (issue type discovery), searchJiraIssuesUsingJql, createJiraIssue, editJiraIssue +- Slack (messages, channels) → slack_search_channels, slack_read_channel, slack_read_thread +- Airtable (bases, tables, records) → list_bases, list_tables_for_base, list_records_for_table +- Knowledge base content (Notion, GitHub, files, notes) → automatically searched +- Real-time public web data → call web_search +- Reading a specific webpage → call scrape_webpage + diff --git a/surfsense_backend/app/agents/new_chat/prompts/composer.py b/surfsense_backend/app/agents/new_chat/prompts/composer.py new file mode 100644 index 000000000..42f8303e6 --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/prompts/composer.py @@ -0,0 +1,405 @@ +""" +Prompt composer for the SurfSense ``new_chat`` agent. + +This module assembles the agent's system prompt from the markdown fragments +under :mod:`app.agents.new_chat.prompts`. It replaces the monolithic +``system_prompt.py`` with a clean, fragment-based composition: + +:: + + prompts/ + base/ # agent identity, KB policy, tool routing, … + providers/ # provider-specific tweaks (anthropic, gpt5, …) + tools/ # one ``.md`` per tool + examples/ # one ``.md`` per tool with call examples + routing/ # connector-specific routing notes (linear, slack, …) + +The model-family dispatch step (see :func:`detect_provider_variant`) +mirrors OpenCode's ``packages/opencode/src/session/system.ts`` — different +model families respond best to differently-styled prompts (Claude likes +XML/narrative, GPT-5 wants channel-aware pragmatic, Codex needs +terse/file:line, Gemini wants formal numbered steps, etc.). LangChain's +``dynamic_prompt`` helper supports per-call prompt swaps but ships no +out-of-the-box family classifier, so we keep our own. + +Backwards compatibility +======================= + +``system_prompt.py`` re-exports :func:`compose_system_prompt` and wraps it +in functions with the same signatures as the legacy +``build_surfsense_system_prompt`` / ``build_configurable_system_prompt`` so +existing call sites do not change. +""" + +from __future__ import annotations + +import re +from collections.abc import Iterable +from datetime import UTC, datetime +from importlib import resources + +from app.db import ChatVisibility + +# ----------------------------------------------------------------------------- +# Provider variant detection +# ----------------------------------------------------------------------------- + +# String literal alias for the supported provider-specific prompt variants. +# When adding a new variant, also drop a matching ``providers/.md`` +# file in this package and (if appropriate) extend the regex matchers below. +# +# Stylistic clusters: each variant is a focused style nudge, NOT a full +# system prompt — the main prompt is already assembled from base/ + +# tools/ + routing/. The clustering itself (which models map to which +# style) follows OpenCode's ``system.ts`` family table; see the module +# docstring for credits. +ProviderVariant = str +# Known values: +# "anthropic" — Claude family (XML-friendly, narrative todos) +# "openai_reasoning" — GPT-5 / o-series (channel-aware pragmatic) +# "openai_classic" — GPT-4 family (autonomous persistence) +# "openai_codex" — gpt-*-codex (code-purist, terse, file:line refs) +# "google" — Gemini (formal, <3-line, numbered workflow) +# "kimi" — Moonshot Kimi-K* (action-bias, parallel tools) +# "grok" — xAI Grok (extreme-terse, one-word ok) +# "deepseek" — DeepSeek V3 / R1 (terse, R1-aware reasoning) +# "default" — fallback, no provider-specific block emitted + +# IMPORTANT: order of evaluation matters in :func:`detect_provider_variant`. +# More specific patterns must come first (e.g. ``codex`` before +# ``openai_reasoning`` because codex model ids contain ``gpt``). + +_OPENAI_CODEX_RE = re.compile( + r"\b(gpt-codex|codex-mini|gpt-[\d.]+-codex)\b", re.IGNORECASE +) +_OPENAI_REASONING_RE = re.compile(r"\b(gpt-5|o\d|o-)", re.IGNORECASE) +_OPENAI_CLASSIC_RE = re.compile(r"\bgpt-4", re.IGNORECASE) +_ANTHROPIC_RE = re.compile(r"\bclaude\b", re.IGNORECASE) +_GOOGLE_RE = re.compile(r"\bgemini\b", re.IGNORECASE) +_KIMI_RE = re.compile(r"\b(kimi[-\d.]*|moonshot)\b", re.IGNORECASE) +_GROK_RE = re.compile(r"\bgrok\b", re.IGNORECASE) +_DEEPSEEK_RE = re.compile(r"\bdeepseek\b", re.IGNORECASE) + + +def detect_provider_variant(model_name: str | None) -> ProviderVariant: + """Pick a provider-specific prompt variant from a model id string. + + Heuristic match on the model id; returns ``"default"`` when nothing + matches so the composer can fall back to the empty placeholder file. + + Order is significant: more-specific patterns are tried first so + ``gpt-5-codex`` routes to ``"openai_codex"`` rather than + ``"openai_reasoning"`` — same dispatch order as OpenCode's + ``packages/opencode/src/session/system.ts``. + """ + if not model_name: + return "default" + name = model_name.strip() + if _OPENAI_CODEX_RE.search(name): + return "openai_codex" + if _OPENAI_REASONING_RE.search(name): + return "openai_reasoning" + if _OPENAI_CLASSIC_RE.search(name): + return "openai_classic" + if _ANTHROPIC_RE.search(name): + return "anthropic" + if _GOOGLE_RE.search(name): + return "google" + if _KIMI_RE.search(name): + return "kimi" + if _GROK_RE.search(name): + return "grok" + if _DEEPSEEK_RE.search(name): + return "deepseek" + return "default" + + +# ----------------------------------------------------------------------------- +# Fragment loading +# ----------------------------------------------------------------------------- + + +_PROMPTS_PACKAGE = "app.agents.new_chat.prompts" + + +def _read_fragment(subpath: str) -> str: + """Read a fragment file from the ``prompts/`` resource tree. + + Returns the raw contents stripped of any single trailing newline so + composition can append explicit separators without compounding blank + lines. Missing files return an empty string so optional fragments + (e.g. provider hints) act as no-ops. + """ + parts = subpath.split("/") + try: + ref = resources.files(_PROMPTS_PACKAGE).joinpath(*parts) + if not ref.is_file(): + return "" + text = ref.read_text(encoding="utf-8") + except (FileNotFoundError, ModuleNotFoundError): + return "" + if text.endswith("\n"): + text = text[:-1] + return text + + +# ----------------------------------------------------------------------------- +# Tool ordering + memory variant resolution +# ----------------------------------------------------------------------------- + + +# Ordered for reading flow: fundamentals first, then artifact generators, +# then memory at the end (mirrors the legacy ``_ALL_TOOL_NAMES_ORDERED``). +ALL_TOOL_NAMES_ORDERED: tuple[str, ...] = ( + "search_surfsense_docs", + "web_search", + "generate_podcast", + "generate_video_presentation", + "generate_report", + "generate_resume", + "generate_image", + "scrape_webpage", + "update_memory", +) + + +_MEMORY_VARIANT_TOOLS: frozenset[str] = frozenset({"update_memory"}) + + +def _tool_fragment_path(tool_name: str, variant: str) -> str: + """Resolve a tool's instruction fragment path. + + Tools listed in :data:`_MEMORY_VARIANT_TOOLS` switch on the conversation + visibility and load ``tools/_.md``; everything else + falls back to ``tools/.md``. + """ + if tool_name in _MEMORY_VARIANT_TOOLS: + return f"tools/{tool_name}_{variant}.md" + return f"tools/{tool_name}.md" + + +def _example_fragment_path(tool_name: str, variant: str) -> str: + if tool_name in _MEMORY_VARIANT_TOOLS: + return f"examples/{tool_name}_{variant}.md" + return f"examples/{tool_name}.md" + + +def _format_tool_label(tool_name: str) -> str: + return tool_name.replace("_", " ").title() + + +# ----------------------------------------------------------------------------- +# Section builders +# ----------------------------------------------------------------------------- + + +def _build_system_instructions( + *, + visibility: ChatVisibility, + resolved_today: str, +) -> str: + """Reconstruct the legacy ```` block from fragments.""" + variant = "team" if visibility == ChatVisibility.SEARCH_SPACE else "private" + + sections = [ + _read_fragment(f"base/agent_{variant}.md"), + _read_fragment(f"base/kb_only_policy_{variant}.md"), + _read_fragment(f"base/tool_routing_{variant}.md"), + _read_fragment("base/parameter_resolution.md"), + _read_fragment(f"base/memory_protocol_{variant}.md"), + ] + body = "\n\n".join(s for s in sections if s) + block = f"\n\n{body}\n\n\n" + return block.format(resolved_today=resolved_today) + + +def _build_mcp_routing_block( + mcp_connector_tools: dict[str, list[str]] | None, +) -> str: + """Emit the ```` block when at least one MCP server is wired.""" + if not mcp_connector_tools: + return "" + lines: list[str] = [ + "\n", + "You also have direct tools from these user-connected MCP servers.", + "Their data is NEVER in the knowledge base — call their tools directly.", + "", + ] + for server_name, tool_names in mcp_connector_tools.items(): + lines.append(f"- {server_name} → {', '.join(tool_names)}") + lines.append("\n") + return "\n".join(lines) + + +def _build_tools_section( + *, + visibility: ChatVisibility, + enabled_tool_names: set[str] | None, + disabled_tool_names: set[str] | None, +) -> str: + """Reconstruct the ```` block + ```` block.""" + variant = "team" if visibility == ChatVisibility.SEARCH_SPACE else "private" + + parts: list[str] = [] + preamble = _read_fragment("tools/_preamble.md") + if preamble: + parts.append(preamble + "\n") + + examples: list[str] = [] + + for tool_name in ALL_TOOL_NAMES_ORDERED: + if enabled_tool_names is not None and tool_name not in enabled_tool_names: + continue + + instruction = _read_fragment(_tool_fragment_path(tool_name, variant)) + if instruction: + parts.append(instruction + "\n") + + example = _read_fragment(_example_fragment_path(tool_name, variant)) + if example: + examples.append(example + "\n") + + known_disabled = ( + set(disabled_tool_names) & set(ALL_TOOL_NAMES_ORDERED) + if disabled_tool_names + else set() + ) + if known_disabled: + disabled_list = ", ".join( + _format_tool_label(n) for n in ALL_TOOL_NAMES_ORDERED if n in known_disabled + ) + parts.append( + "\n" + "DISABLED TOOLS (by user):\n" + f"The following tools are available in SurfSense but have been disabled by the user for this session: {disabled_list}.\n" + "You do NOT have access to these tools and MUST NOT claim you can use them.\n" + "If the user asks about a capability provided by a disabled tool, let them know the relevant tool\n" + "is currently disabled and they can re-enable it.\n" + ) + + parts.append("\n\n") + + if examples: + parts.append("") + parts.extend(examples) + parts.append("\n") + + return "".join(parts) + + +def _build_provider_block(provider_variant: ProviderVariant) -> str: + """Optional provider-tuned hints. Empty for ``"default"``.""" + if not provider_variant or provider_variant == "default": + return "" + text = _read_fragment(f"providers/{provider_variant}.md") + return f"\n{text}\n" if text else "" + + +def _build_routing_block(connector_routing: Iterable[str] | None) -> str: + if not connector_routing: + return "" + fragments: list[str] = [] + for name in connector_routing: + text = _read_fragment(f"routing/{name}.md") + if text: + fragments.append(text) + if not fragments: + return "" + return "\n" + "\n\n".join(fragments) + "\n" + + +def _build_citation_block(citations_enabled: bool) -> str: + fragment = ( + _read_fragment("base/citations_on.md") + if citations_enabled + else _read_fragment("base/citations_off.md") + ) + return f"\n{fragment}\n" if fragment else "" + + +# ----------------------------------------------------------------------------- +# Public API +# ----------------------------------------------------------------------------- + + +def compose_system_prompt( + *, + today: datetime | None = None, + thread_visibility: ChatVisibility | None = None, + enabled_tool_names: set[str] | None = None, + disabled_tool_names: set[str] | None = None, + mcp_connector_tools: dict[str, list[str]] | None = None, + custom_system_instructions: str | None = None, + use_default_system_instructions: bool = True, + citations_enabled: bool = True, + provider_variant: ProviderVariant | None = None, + model_name: str | None = None, + connector_routing: Iterable[str] | None = None, +) -> str: + """Assemble the SurfSense system prompt from disk fragments. + + Args: + today: Optional clock injection for tests. + thread_visibility: Private vs shared (team) — drives memory wording + and a few base block variants. + enabled_tool_names: When provided, only these tools' instructions + are included; ``None`` keeps the legacy "include everything" + behavior. + disabled_tool_names: User-disabled tools (note appended to prompt). + mcp_connector_tools: ``{server_name: [tool_names...]}`` to inject + an explicit MCP routing block. + custom_system_instructions: Free-form instructions that override + the default ```` block (legacy support + for ``NewLLMConfig.system_instructions``). + use_default_system_instructions: When ``custom_system_instructions`` + is empty/None, fall back to defaults (legacy semantics). + citations_enabled: Include ``citations_on.md`` (true) or + ``citations_off.md`` (false). + provider_variant: Explicit provider variant override + (``"anthropic" | "openai_reasoning" | "openai_classic" | "google" | "default"``). + When ``None``, falls back to :func:`detect_provider_variant` + on ``model_name``. + model_name: Used to auto-detect ``provider_variant`` when not + provided explicitly. + connector_routing: Optional list of routing fragment names + (``["linear", "slack", ...]``) to include from + ``prompts/routing/``. + + Returns: + The fully composed system prompt string. + """ + resolved_today = (today or datetime.now(UTC)).astimezone(UTC).date().isoformat() + visibility = thread_visibility or ChatVisibility.PRIVATE + + if custom_system_instructions and custom_system_instructions.strip(): + sys_block = custom_system_instructions.format(resolved_today=resolved_today) + elif use_default_system_instructions: + sys_block = _build_system_instructions( + visibility=visibility, resolved_today=resolved_today + ) + else: + sys_block = "" + + sys_block += _build_mcp_routing_block(mcp_connector_tools) + + if provider_variant is None: + provider_variant = detect_provider_variant(model_name) + sys_block += _build_provider_block(provider_variant) + sys_block += _build_routing_block(connector_routing) + + tools_block = _build_tools_section( + visibility=visibility, + enabled_tool_names=enabled_tool_names, + disabled_tool_names=disabled_tool_names, + ) + citation_block = _build_citation_block(citations_enabled) + + return sys_block + tools_block + citation_block + + +__all__ = [ + "ALL_TOOL_NAMES_ORDERED", + "ProviderVariant", + "compose_system_prompt", + "detect_provider_variant", +] diff --git a/surfsense_backend/app/agents/new_chat/prompts/examples/__init__.py b/surfsense_backend/app/agents/new_chat/prompts/examples/__init__.py new file mode 100644 index 000000000..8b1378917 --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/prompts/examples/__init__.py @@ -0,0 +1 @@ + diff --git a/surfsense_backend/app/agents/new_chat/prompts/examples/generate_image.md b/surfsense_backend/app/agents/new_chat/prompts/examples/generate_image.md new file mode 100644 index 000000000..216c2926a --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/prompts/examples/generate_image.md @@ -0,0 +1,12 @@ + +- User: "Generate an image of a cat" + - Call: `generate_image(prompt="A fluffy orange tabby cat sitting on a windowsill, bathed in warm golden sunlight, soft bokeh background with green houseplants, photorealistic style, cozy atmosphere")` + - The generated image will automatically be displayed in the chat. +- User: "Draw me a logo for a coffee shop called Bean Dream" + - Call: `generate_image(prompt="Minimalist modern logo design for a coffee shop called 'Bean Dream', featuring a stylized coffee bean with dream-like swirls of steam, clean vector style, warm brown and cream color palette, white background, professional branding")` + - The generated image will automatically be displayed in the chat. +- User: "Show me this image: https://example.com/image.png" + - Simply include it in your response using markdown: `![Image](https://example.com/image.png)` +- User uploads an image file and asks: "What is this image about?" + - The user's uploaded image is already visible in the chat. + - Simply analyze the image content and respond directly. diff --git a/surfsense_backend/app/agents/new_chat/prompts/examples/generate_podcast.md b/surfsense_backend/app/agents/new_chat/prompts/examples/generate_podcast.md new file mode 100644 index 000000000..aabf8ce7a --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/prompts/examples/generate_podcast.md @@ -0,0 +1,7 @@ + +- User: "Give me a podcast about AI trends based on what we discussed" + - First search for relevant content, then call: `generate_podcast(source_content="Based on our conversation and search results: [detailed summary of chat + search findings]", podcast_title="AI Trends Podcast")` +- User: "Create a podcast summary of this conversation" + - Call: `generate_podcast(source_content="Complete conversation summary:\n\nUser asked about [topic 1]:\n[Your detailed response]\n\nUser then asked about [topic 2]:\n[Your detailed response]\n\n[Continue for all exchanges in the conversation]", podcast_title="Conversation Summary")` +- User: "Make a podcast about quantum computing" + - First explore `/documents/` (ls/glob/grep/read_file), then: `generate_podcast(source_content="Key insights about quantum computing from retrieved files:\n\n[Comprehensive summary of findings]", podcast_title="Quantum Computing Explained")` diff --git a/surfsense_backend/app/agents/new_chat/prompts/examples/generate_report.md b/surfsense_backend/app/agents/new_chat/prompts/examples/generate_report.md new file mode 100644 index 000000000..7e9d0a595 --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/prompts/examples/generate_report.md @@ -0,0 +1,13 @@ + +- User: "Generate a report about AI trends" + - Call: `generate_report(topic="AI Trends Report", source_strategy="kb_search", search_queries=["AI trends recent developments", "artificial intelligence industry trends", "AI market growth and predictions"], report_style="detailed")` + - WHY: Has creation verb "generate" → call the tool. No prior discussion → use kb_search. +- User: "Write a research report from this conversation" + - Call: `generate_report(topic="Research Report", source_strategy="conversation", source_content="Complete conversation summary:\n\n...", report_style="deep_research")` + - WHY: Has creation verb "write" → call the tool. Conversation has the content → use source_strategy="conversation". +- User: (after a report on Climate Change was generated) "Add a section about carbon capture technologies" + - Call: `generate_report(topic="Climate Crisis: Causes, Impacts, and Solutions", source_strategy="conversation", source_content="[summary of conversation context if any]", parent_report_id=, user_instructions="Add a new section about carbon capture technologies")` + - WHY: Has modification verb "add" + specific deliverable target → call the tool with parent_report_id. +- User: (after a report was generated) "What else could we add to have more depth?" + - Do NOT call generate_report. Answer in chat with suggestions. + - WHY: No creation/modification verb directed at producing a deliverable. diff --git a/surfsense_backend/app/agents/new_chat/prompts/examples/generate_resume.md b/surfsense_backend/app/agents/new_chat/prompts/examples/generate_resume.md new file mode 100644 index 000000000..d8a6c381e --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/prompts/examples/generate_resume.md @@ -0,0 +1,19 @@ + +- User: "Build me a resume. I'm John Doe, engineer at Acme Corp..." + - Call: `generate_resume(user_info="John Doe, engineer at Acme Corp...", max_pages=1)` + - WHY: Has creation verb "build" + resume → call the tool. +- User: "Create my CV with this info: [experience, education, skills]" + - Call: `generate_resume(user_info="[experience, education, skills]", max_pages=1)` +- User: "Build me a resume" (and there is a resume/CV document in the conversation context) + - Extract the FULL content from the document in context, then call: + `generate_resume(user_info="Name: John Doe\nEmail: john@example.com\n\nExperience:\n- Senior Engineer at Acme Corp (2020-2024)\n Led team of 5...\n\nEducation:\n- BS Computer Science, MIT (2016-2020)\n\nSkills: Python, TypeScript, AWS...", max_pages=1)` + - WHY: Document content is available in context — extract ALL of it into user_info. Do NOT ignore referenced documents. +- User: (after resume generated) "Change my title to Senior Engineer" + - Call: `generate_resume(user_info="", user_instructions="Change the job title to Senior Engineer", parent_report_id=, max_pages=1)` + - WHY: Modification verb "change" + refers to existing resume → set parent_report_id. +- User: (after resume generated) "Make this 2 pages and expand projects" + - Call: `generate_resume(user_info="", user_instructions="Expand projects and keep this to at most 2 pages", parent_report_id=, max_pages=2)` + - WHY: Explicit page increase request → set max_pages to 2. +- User: "How should I structure my resume?" + - Do NOT call generate_resume. Answer in chat with advice. + - WHY: No creation/modification verb. diff --git a/surfsense_backend/app/agents/new_chat/prompts/examples/generate_video_presentation.md b/surfsense_backend/app/agents/new_chat/prompts/examples/generate_video_presentation.md new file mode 100644 index 000000000..257ec86cf --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/prompts/examples/generate_video_presentation.md @@ -0,0 +1,7 @@ + +- User: "Give me a presentation about AI trends based on what we discussed" + - First search for relevant content, then call: `generate_video_presentation(source_content="Based on our conversation and search results: [detailed summary of chat + search findings]", video_title="AI Trends Presentation")` +- User: "Create slides summarizing this conversation" + - Call: `generate_video_presentation(source_content="Complete conversation summary:\n\nUser asked about [topic 1]:\n[Your detailed response]\n\nUser then asked about [topic 2]:\n[Your detailed response]\n\n[Continue for all exchanges in the conversation]", video_title="Conversation Summary")` +- User: "Make a video presentation about quantum computing" + - First explore `/documents/` (ls/glob/grep/read_file), then: `generate_video_presentation(source_content="Key insights about quantum computing from retrieved files:\n\n[Comprehensive summary of findings]", video_title="Quantum Computing Explained")` diff --git a/surfsense_backend/app/agents/new_chat/prompts/examples/scrape_webpage.md b/surfsense_backend/app/agents/new_chat/prompts/examples/scrape_webpage.md new file mode 100644 index 000000000..0f156bf24 --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/prompts/examples/scrape_webpage.md @@ -0,0 +1,13 @@ + +- User: "Check out https://dev.to/some-article" + - Call: `scrape_webpage(url="https://dev.to/some-article")` + - Respond with a structured analysis — key points, takeaways. +- User: "Read this article and summarize it for me: https://example.com/blog/ai-trends" + - Call: `scrape_webpage(url="https://example.com/blog/ai-trends")` + - Respond with a thorough summary using headings and bullet points. +- User: (after discussing https://example.com/stats) "Can you get the live data from that page?" + - Call: `scrape_webpage(url="https://example.com/stats")` + - IMPORTANT: Always attempt scraping first. Never refuse before trying the tool. +- User: "https://example.com/blog/weekend-recipes" + - Call: `scrape_webpage(url="https://example.com/blog/weekend-recipes")` + - When a user sends just a URL with no instructions, scrape it and provide a concise summary of the content. diff --git a/surfsense_backend/app/agents/new_chat/prompts/examples/search_surfsense_docs.md b/surfsense_backend/app/agents/new_chat/prompts/examples/search_surfsense_docs.md new file mode 100644 index 000000000..b90f2b7a7 --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/prompts/examples/search_surfsense_docs.md @@ -0,0 +1,9 @@ + +- User: "How do I install SurfSense?" + - Call: `search_surfsense_docs(query="installation setup")` +- User: "What connectors does SurfSense support?" + - Call: `search_surfsense_docs(query="available connectors integrations")` +- User: "How do I set up the Notion connector?" + - Call: `search_surfsense_docs(query="Notion connector setup configuration")` +- User: "How do I use Docker to run SurfSense?" + - Call: `search_surfsense_docs(query="Docker installation setup")` diff --git a/surfsense_backend/app/agents/new_chat/prompts/examples/update_memory_private.md b/surfsense_backend/app/agents/new_chat/prompts/examples/update_memory_private.md new file mode 100644 index 000000000..f83fe40b4 --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/prompts/examples/update_memory_private.md @@ -0,0 +1,16 @@ + +- Alex, is empty. User: "I'm a space enthusiast, explain astrophage to me" + - The user casually shared a durable fact. Use their first name in the entry, short neutral heading: + update_memory(updated_memory="## Interests & background\n- (2025-03-15) [fact] Alex is a space enthusiast\n") +- User: "Remember that I prefer concise answers over detailed explanations" + - Durable preference. Merge with existing memory, add a new heading: + update_memory(updated_memory="## Interests & background\n- (2025-03-15) [fact] Alex is a space enthusiast\n\n## Response style\n- (2025-03-15) [pref] Alex prefers concise answers over detailed explanations\n") +- User: "I actually moved to Tokyo last month" + - Updated fact, date prefix reflects when recorded: + update_memory(updated_memory="## Interests & background\n...\n\n## Personal context\n- (2025-03-15) [fact] Alex lives in Tokyo (previously London)\n...") +- User: "I'm a freelance photographer working on a nature documentary" + - Durable background info under a fitting heading: + update_memory(updated_memory="...\n\n## Current focus\n- (2025-03-15) [fact] Alex is a freelance photographer\n- (2025-03-15) [fact] Alex is working on a nature documentary\n") +- User: "Always respond in bullet points" + - Standing instruction: + update_memory(updated_memory="...\n\n## Response style\n- (2025-03-15) [instr] Always respond to Alex in bullet points\n") diff --git a/surfsense_backend/app/agents/new_chat/prompts/examples/update_memory_team.md b/surfsense_backend/app/agents/new_chat/prompts/examples/update_memory_team.md new file mode 100644 index 000000000..1c74fdf6e --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/prompts/examples/update_memory_team.md @@ -0,0 +1,7 @@ + +- User: "Let's remember that we decided to do weekly standup meetings on Mondays" + - Durable team decision: + update_memory(updated_memory="- (2025-03-15) [fact] Weekly standup meetings on Mondays\n...") +- User: "Our office is in downtown Seattle, 5th floor" + - Durable team fact: + update_memory(updated_memory="- (2025-03-15) [fact] Office location: downtown Seattle, 5th floor\n...") diff --git a/surfsense_backend/app/agents/new_chat/prompts/examples/web_search.md b/surfsense_backend/app/agents/new_chat/prompts/examples/web_search.md new file mode 100644 index 000000000..6b9828ac7 --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/prompts/examples/web_search.md @@ -0,0 +1,8 @@ + +- User: "What's the current USD to INR exchange rate?" + - Call: `web_search(query="current USD to INR exchange rate")` + - Then answer using the returned web results with citations. +- User: "What's the latest news about AI?" + - Call: `web_search(query="latest AI news today")` +- User: "What's the weather in New York?" + - Call: `web_search(query="weather New York today")` diff --git a/surfsense_backend/app/agents/new_chat/prompts/providers/__init__.py b/surfsense_backend/app/agents/new_chat/prompts/providers/__init__.py new file mode 100644 index 000000000..8b1378917 --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/prompts/providers/__init__.py @@ -0,0 +1 @@ + diff --git a/surfsense_backend/app/agents/new_chat/prompts/providers/anthropic.md b/surfsense_backend/app/agents/new_chat/prompts/providers/anthropic.md new file mode 100644 index 000000000..f574da541 --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/prompts/providers/anthropic.md @@ -0,0 +1,20 @@ + +You are running on an Anthropic Claude model. + +Structured reasoning: +- Use XML tags liberally to organise intermediate reasoning when a task is non-trivial. `...` blocks are encouraged before tool calls or before producing a complex final answer. +- For multi-step requests, briefly outline a plan inside a `` block before issuing the first tool call. + +Professional objectivity: +- Prioritise technical accuracy over validating the user's beliefs. Provide direct, factual guidance without unnecessary superlatives, praise, or emotional validation. +- When uncertain, investigate (search the KB, fetch the page) rather than confirming the user's assumption. +- Disagree with the user when the evidence warrants it; respectful correction beats false agreement. + +Task management: +- For tasks with 3+ distinct steps use the todo / planning tool aggressively. Mark items in_progress before starting, completed immediately when finished — do not batch completions. +- Narrate progress through the todo list itself, not through chatty status lines. + +Tool calls: +- Run independent tool calls in parallel within one response. Sequence them only when a later call genuinely needs an earlier one's output. +- Never chain bash-like commands with `;` or `&&` to "narrate" — use prose between tool calls instead. + diff --git a/surfsense_backend/app/agents/new_chat/prompts/providers/deepseek.md b/surfsense_backend/app/agents/new_chat/prompts/providers/deepseek.md new file mode 100644 index 000000000..8acf008ca --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/prompts/providers/deepseek.md @@ -0,0 +1,18 @@ + +You are running on a DeepSeek model (DeepSeek-V3 chat / DeepSeek-R1 reasoning). + +Reasoning hygiene (R1-aware): +- If the model surfaces explicit `` blocks, keep that internal scratch focused — do NOT restate the user's question inside it; jump straight to the analysis. +- Never paste the contents of `` into your final answer. Final answer should reflect only the conclusion, citations, and any user-facing rationale. +- Do not let chain-of-thought leak into tool-call arguments — keep tool inputs minimal and structural. + +Output style: +- Be concise. Default to a one-paragraph answer; expand only when the user asks for detail. +- Don't open with sycophantic phrasing ("Great question", "Sure, here you go"). Lead with the answer or the next action. +- For factual answers, cite once with `[citation:chunk_id]` and stop. + +Tool calls: +- Issue independent tool calls in parallel within a single turn. +- Prefer the knowledge-base search tools before any web-search; this model has strong recall but stale training data. +- Don't fabricate file paths, chunk ids, or URLs — only use values returned by tools or provided by the user. + diff --git a/surfsense_backend/app/agents/new_chat/prompts/providers/default.md b/surfsense_backend/app/agents/new_chat/prompts/providers/default.md new file mode 100644 index 000000000..8b1378917 --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/prompts/providers/default.md @@ -0,0 +1 @@ + diff --git a/surfsense_backend/app/agents/new_chat/prompts/providers/google.md b/surfsense_backend/app/agents/new_chat/prompts/providers/google.md new file mode 100644 index 000000000..cac3b328b --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/prompts/providers/google.md @@ -0,0 +1,20 @@ + +You are running on a Google Gemini model. + +Output style: +- Concise & direct. Aim for fewer than 3 lines of prose (excluding tool output, citations, and code/snippets) when the task allows. +- No conversational filler — skip openers like "Okay, I will now…" and closers like "I have finished the changes…". Get straight to the action or answer. +- Format with GitHub-flavoured Markdown; assume monospace rendering. +- For one-line factual answers, just answer. No headers, no bullets. + +Workflow for non-trivial tasks (Understand → Plan → Act → Verify): +1. **Understand:** read the user's request and the relevant KB / connector context. Use search and read tools (in parallel when independent) before assuming anything. +2. **Plan:** when the task touches multiple steps, share an extremely concise plan first. +3. **Act:** call the appropriate tools, strictly adhering to the prompts/routing already established for this agent. +4. **Verify:** confirm with a follow-up read or search where it materially de-risks the answer. + +Discipline: +- Do not take significant actions beyond the clear scope of the user's request without confirming first. +- Do not assume a connector / tool / file exists — check (e.g. via `get_connected_accounts`) before referencing it. +- Path arguments must be the exact strings returned by tools; do not synthesise file paths. + diff --git a/surfsense_backend/app/agents/new_chat/prompts/providers/grok.md b/surfsense_backend/app/agents/new_chat/prompts/providers/grok.md new file mode 100644 index 000000000..95b8fcc14 --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/prompts/providers/grok.md @@ -0,0 +1,17 @@ + +You are running on an xAI Grok model. + +Maximum terseness: +- Answer in fewer than 4 lines unless the user asks for detail. One-word answers are best when they suffice. +- No preamble ("The answer is", "Here's what I'll do"), no postamble ("Hope that helps", "Let me know"). Get straight to the answer. +- Avoid restating the user's question. +- For factual lookups inside the knowledge base, give the answer with a single `[citation:chunk_id]` and stop. + +Tool discipline: +- Use exactly ONE tool per assistant turn when investigating; wait for the result before deciding the next call. Do not loop on the same tool with the same arguments — pick a result and act. +- For obviously parallelizable read-only batches (multiple independent searches), one turn with several tool calls is fine — but never chain into a fishing expedition. + +Style: +- No emojis unless the user asked. No nested bullets, no headers for short answers. +- If you can't help, say so in 1-2 sentences without explaining "why this could lead to…". + diff --git a/surfsense_backend/app/agents/new_chat/prompts/providers/kimi.md b/surfsense_backend/app/agents/new_chat/prompts/providers/kimi.md new file mode 100644 index 000000000..c3c11ad5e --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/prompts/providers/kimi.md @@ -0,0 +1,21 @@ + +You are running on a Moonshot Kimi model (Kimi-K1.5 / Kimi-K2 / Kimi-K2.5+). + +Action bias: +- Default to taking action with tools rather than describing solutions in prose. If a tool can answer the question, call the tool. +- Don't narrate routine reads, searches, or obvious next steps. Combine related progress into one short status line. +- Be thorough in actions (test what you build, verify what you change). Be brief in explanations. + +Tool calls: +- Output multiple non-interfering tool calls in a SINGLE response — parallelism is a major efficiency win on this model. +- When the `task` tool is available, delegate focused subtasks to a subagent with full context (subagents don't inherit yours). +- Don't apologise or pre-announce tool calls. The tool call itself is self-explanatory. + +Language: +- Respond in the SAME language as the user's most recent turn unless explicitly instructed otherwise. + +Discipline: +- Stay on track. Never give the user more than what they asked for. +- Fact-check before stating anything as factual; don't fabricate citations. +- Keep it stupidly simple. Don't overcomplicate. + diff --git a/surfsense_backend/app/agents/new_chat/prompts/providers/openai_classic.md b/surfsense_backend/app/agents/new_chat/prompts/providers/openai_classic.md new file mode 100644 index 000000000..9128609e0 --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/prompts/providers/openai_classic.md @@ -0,0 +1,21 @@ + +You are running on a classic OpenAI chat model (GPT-4 family). + +Persistence: +- Keep going until the user's query is completely resolved before yielding back. Don't end the turn at "I would do X" — actually do X. +- When you say "Next I will…" or "Now I will…", you MUST actually take that action in the same turn. +- If a tool call fails, diagnose and try again with corrected arguments; do not surface the raw error and stop. + +Planning: +- Plan extensively before each tool call and reflect briefly on the result of the previous call. For tasks with 3+ steps, use the todo / planning tool and mark items as `in_progress` / `completed` as you go. +- Always announce the next action in ONE concise sentence before making a non-trivial tool call ("I'll search the KB for the migration spec."). + +Output style: +- Conversational but professional. Plain prose for explanations, bullet points for findings, fenced code blocks (with language tags) for code. +- Don't dump tool output verbatim — summarise the relevant lines. +- Don't add a closing recap unless the user asked for one. After completing the work, just stop. + +Tool calls: +- Issue independent tool calls in parallel within one response. +- Use specialised tools over generic ones (e.g. KB search before web search; named connectors over MCP fallback). + diff --git a/surfsense_backend/app/agents/new_chat/prompts/providers/openai_codex.md b/surfsense_backend/app/agents/new_chat/prompts/providers/openai_codex.md new file mode 100644 index 000000000..6167d4b06 --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/prompts/providers/openai_codex.md @@ -0,0 +1,19 @@ + +You are running on an OpenAI Codex-class model (gpt-codex / codex-mini / gpt-*-codex). + +Output style: +- Be concise. Don't dump fetched/searched content back at the user — reference paths or chunk ids instead. +- Reference sources as `path:line` (or `chunk:`) so they're clickable. Stand-alone paths per reference, even when repeated. +- Prefer numbered lists (`1.`, `2.`, `3.`) when offering options the user can pick by replying with a single number. +- Skip headers and heavy formatting for simple confirmations. +- No emojis, no em-dashes, no nested bullets. Single-level lists only. + +Code & structured-output tasks: +- Lead with a one-sentence explanation of the change before context. Don't open with "Summary:" — jump in. +- Suggest natural next steps (run tests, diff review, commit) only when they're genuinely the next move. +- For multi-line snippets use fenced code blocks with a language tag. + +Tool calls: +- Run independent tool calls in parallel; chain only when later calls need earlier results. +- Don't ask permission ("Should I proceed?") — proceed with the most reasonable default and state what you did. + diff --git a/surfsense_backend/app/agents/new_chat/prompts/providers/openai_reasoning.md b/surfsense_backend/app/agents/new_chat/prompts/providers/openai_reasoning.md new file mode 100644 index 000000000..dd7a61536 --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/prompts/providers/openai_reasoning.md @@ -0,0 +1,21 @@ + +You are running on an OpenAI reasoning model (GPT-5+ / o-series). + +Output style: +- Be terse and direct. Don't restate the user's request before answering. +- Don't begin with conversational openers ("Done!", "Got it", "Great question", "Sure thing"). Get to the answer or the action. +- Match response complexity to the task: simple questions → one-line answer; substantial work → lead with the outcome, then context, then any next steps. +- No nested bullets — keep lists flat (single level). For options the user can pick by replying with a number, use `1.` `2.` `3.`. +- Use inline backticks for paths/commands/identifiers; fenced code blocks (with language tags) for multi-line snippets. + +Channels (for clients that support them): +- `commentary` — short progress updates only when they add genuinely new information (a discovery, a tradeoff, a blocker, the start of a non-trivial step). Don't narrate routine reads or obvious next steps. +- `final` — the completed response. Keep it self-contained; no "see above" / "see below" cross-references. + +Tool calls: +- Parallelise independent tool calls in a single response (`multi_tool_use.parallel` where supported). Only sequence when a later call needs an earlier one's output. +- Don't ask permission ("Should I proceed?", "Do you want me to…?"). Pick the most reasonable default, do it, and state what you did. + +Autonomy: +- Persist until the task is fully resolved within the current turn whenever feasible. Don't stop at analysis when the user clearly wants the change applied. + diff --git a/surfsense_backend/app/agents/new_chat/prompts/routing/__init__.py b/surfsense_backend/app/agents/new_chat/prompts/routing/__init__.py new file mode 100644 index 000000000..8b1378917 --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/prompts/routing/__init__.py @@ -0,0 +1 @@ + diff --git a/surfsense_backend/app/agents/new_chat/prompts/routing/jira.md b/surfsense_backend/app/agents/new_chat/prompts/routing/jira.md new file mode 100644 index 000000000..8b1378917 --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/prompts/routing/jira.md @@ -0,0 +1 @@ + diff --git a/surfsense_backend/app/agents/new_chat/prompts/routing/linear.md b/surfsense_backend/app/agents/new_chat/prompts/routing/linear.md new file mode 100644 index 000000000..8b1378917 --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/prompts/routing/linear.md @@ -0,0 +1 @@ + diff --git a/surfsense_backend/app/agents/new_chat/prompts/routing/slack.md b/surfsense_backend/app/agents/new_chat/prompts/routing/slack.md new file mode 100644 index 000000000..8b1378917 --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/prompts/routing/slack.md @@ -0,0 +1 @@ + diff --git a/surfsense_backend/app/agents/new_chat/prompts/tools/__init__.py b/surfsense_backend/app/agents/new_chat/prompts/tools/__init__.py new file mode 100644 index 000000000..8b1378917 --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/prompts/tools/__init__.py @@ -0,0 +1 @@ + diff --git a/surfsense_backend/app/agents/new_chat/prompts/tools/_preamble.md b/surfsense_backend/app/agents/new_chat/prompts/tools/_preamble.md new file mode 100644 index 000000000..2c169e015 --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/prompts/tools/_preamble.md @@ -0,0 +1,6 @@ + +You have access to the following tools: + +IMPORTANT: You can ONLY use the tools listed below. If a capability is not listed here, you do NOT have it. +Do NOT claim you can do something if the corresponding tool is not listed. + diff --git a/surfsense_backend/app/agents/new_chat/prompts/tools/generate_image.md b/surfsense_backend/app/agents/new_chat/prompts/tools/generate_image.md new file mode 100644 index 000000000..8bde13f22 --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/prompts/tools/generate_image.md @@ -0,0 +1,11 @@ + +- generate_image: Generate images from text descriptions using AI image models. + - Use this when the user asks you to create, generate, draw, design, or make an image. + - Trigger phrases: "generate an image of", "create a picture of", "draw me", "make an image", "design a logo", "create artwork" + - Args: + - prompt: A detailed text description of the image to generate. Be specific about subject, style, colors, composition, and mood. + - n: Number of images to generate (1-4, default: 1) + - Returns: A dictionary with the generated image metadata. The image will automatically be displayed in the chat. + - IMPORTANT: Write a detailed, descriptive prompt for best results. Don't just pass the user's words verbatim - + expand and improve the prompt with specific details about style, lighting, composition, and mood. + - If the user's request is vague (e.g., "make me an image of a cat"), enhance the prompt with artistic details. diff --git a/surfsense_backend/app/agents/new_chat/prompts/tools/generate_podcast.md b/surfsense_backend/app/agents/new_chat/prompts/tools/generate_podcast.md new file mode 100644 index 000000000..58be143d7 --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/prompts/tools/generate_podcast.md @@ -0,0 +1,15 @@ + +- generate_podcast: Generate an audio podcast from provided content. + - Use this when the user asks to create, generate, or make a podcast. + - Trigger phrases: "give me a podcast about", "create a podcast", "generate a podcast", "make a podcast", "turn this into a podcast" + - Args: + - source_content: The text content to convert into a podcast. This MUST be comprehensive and include: + * If discussing the current conversation: Include a detailed summary of the FULL chat history (all user questions and your responses) + * If based on knowledge base search: Include the key findings and insights from the search results + * You can combine both: conversation context + search results for richer podcasts + * The more detailed the source_content, the better the podcast quality + - podcast_title: Optional title for the podcast (default: "SurfSense Podcast") + - user_prompt: Optional instructions for podcast style/format (e.g., "Make it casual and fun") + - Returns: A task_id for tracking. The podcast will be generated in the background. + - IMPORTANT: Only one podcast can be generated at a time. If a podcast is already being generated, the tool will return status "already_generating". + - After calling this tool, inform the user that podcast generation has started and they will see the player when it's ready (takes 3-5 minutes). diff --git a/surfsense_backend/app/agents/new_chat/prompts/tools/generate_report.md b/surfsense_backend/app/agents/new_chat/prompts/tools/generate_report.md new file mode 100644 index 000000000..8a285a433 --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/prompts/tools/generate_report.md @@ -0,0 +1,39 @@ + +- generate_report: Generate or revise a structured Markdown report artifact. + - WHEN TO CALL THIS TOOL — the message must contain a creation or modification VERB directed at producing a deliverable: + * Creation verbs: write, create, generate, draft, produce, summarize into, turn into, make + * Modification verbs: revise, update, expand, add (a section), rewrite, make (it shorter/longer/formal) + * Example triggers: "generate a report about...", "write a document on...", "add a section about budget", "make the report shorter", "rewrite in formal tone" + - WHEN NOT TO CALL THIS TOOL (answer in chat instead): + * Questions or discussion about the report: "What can we add?", "What's missing?", "Is the data accurate?", "How could this be improved?" + * Suggestions or brainstorming: "What other topics could be covered?", "What else could be added?", "What would make this better?" + * Asking for explanations: "Can you explain section 2?", "Why did you include that?", "What does this part mean?" + * Quick follow-ups or critiques: "Is the conclusion strong enough?", "Are there any gaps?", "What about the competitors?" + * THE TEST: Does the message contain a creation/modification VERB (from the list above) directed at producing or changing a deliverable? If NO verb → answer conversationally in chat. Do NOT assume the user wants a revision just because a report exists in the conversation. + - IMPORTANT FORMAT RULE: Reports are ALWAYS generated in Markdown. + - Args: + - topic: Short title for the report (max ~8 words). + - source_content: The text content to base the report on. + * For source_strategy="conversation" or "provided": Include a comprehensive summary of the relevant content. + * For source_strategy="kb_search": Can be empty or minimal — the tool handles searching internally. + * For source_strategy="auto": Include what you have; the tool searches KB if it's not enough. + - source_strategy: Controls how the tool collects source material. One of: + * "conversation" — The conversation already contains enough context (prior Q&A, discussion, pasted text, scraped pages). Pass a thorough summary as source_content. + * "kb_search" — The tool will search the knowledge base internally. Provide search_queries with 1-5 targeted queries. + * "auto" — Use source_content if sufficient, otherwise fall back to internal KB search using search_queries. + * "provided" — Use only what is in source_content (default, backward-compatible). + - search_queries: When source_strategy is "kb_search" or "auto", provide 1-5 specific search queries for the knowledge base. These should be precise, not just the topic name repeated. + - report_style: Controls report depth. Options: "detailed" (DEFAULT), "deep_research", "brief". + Use "brief" ONLY when the user explicitly asks for a short/concise/one-page report (e.g., "one page", "keep it short", "brief report", "500 words"). Default to "detailed" for all other requests. + - user_instructions: Optional specific instructions (e.g., "focus on financial impacts", "include recommendations"). When revising (parent_report_id set), describe WHAT TO CHANGE. If the user mentions a length preference (e.g., "one page", "500 words", "2 pages"), include that VERBATIM here AND set report_style="brief". + - parent_report_id: Set this to the report_id from a previous generate_report result when the user wants to MODIFY an existing report. Do NOT set it for new reports or questions about reports. + - Returns: A dictionary with status "ready" or "failed", report_id, title, and word_count. + - The report is generated immediately in Markdown and displayed inline in the chat. + - Export/download formats (PDF, DOCX, HTML, LaTeX, EPUB, ODT, plain text) are produced from the generated Markdown report. + - SOURCE STRATEGY DECISION (HIGH PRIORITY — follow this exactly): + * If the conversation already has substantive Q&A / discussion on the topic → use source_strategy="conversation" with a comprehensive summary as source_content. + * If the user wants a report on a topic not yet discussed → use source_strategy="kb_search" with targeted search_queries. + * If you have some content but might need more → use source_strategy="auto" with both source_content and search_queries. + * When revising an existing report (parent_report_id set) and the conversation has relevant context → use source_strategy="conversation". The revision will use the previous report content plus your source_content. + * NEVER run a separate KB lookup step and then pass those results to generate_report. The tool handles KB search internally. + - AFTER CALLING THIS TOOL: Do NOT repeat, summarize, or reproduce the report content in the chat. The report is already displayed as an interactive card that the user can open, read, copy, and export. Simply confirm that the report was generated (e.g., "I've generated your report on [topic]. You can view the Markdown report now, and export it in various formats from the card."). NEVER write out the report text in the chat. diff --git a/surfsense_backend/app/agents/new_chat/prompts/tools/generate_resume.md b/surfsense_backend/app/agents/new_chat/prompts/tools/generate_resume.md new file mode 100644 index 000000000..321ea90c9 --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/prompts/tools/generate_resume.md @@ -0,0 +1,30 @@ + +- generate_resume: Generate or revise a professional resume as a Typst document. + - WHEN TO CALL: The user asks to create, build, generate, write, or draft a resume or CV. + Also when they ask to modify, update, or revise an existing resume from this conversation. + - WHEN NOT TO CALL: General career advice, resume tips, cover letters, or reviewing + a resume without making changes. For cover letters, use generate_report instead. + - The tool produces Typst source code that is compiled to a PDF preview automatically. + - PAGE POLICY: + - Default behavior is ONE PAGE. For new resume creation, set max_pages=1 unless the user explicitly asks for more. + - If the user requests a longer resume (e.g., "make it 2 pages"), set max_pages to that value. + - Args: + - user_info: The user's resume content — work experience, education, skills, contact + info, etc. Can be structured or unstructured text. + CRITICAL: user_info must be COMPREHENSIVE. Do NOT just pass the user's raw message. + You MUST gather and consolidate ALL available information: + * Content from referenced/mentioned documents (e.g., uploaded resumes, CVs, LinkedIn profiles) + that appear in the conversation context — extract and include their FULL content. + * Information the user shared across multiple messages in the conversation. + * Any relevant details from knowledge base search results in the context. + The more complete the user_info, the better the resume. Include names, contact info, + work experience with dates, education, skills, projects, certifications — everything available. + - user_instructions: Optional style or content preferences (e.g. "emphasize leadership", + "keep it to one page"). For revisions, describe what to change. + - parent_report_id: Set this when the user wants to MODIFY an existing resume from + this conversation. Use the report_id from a previous generate_resume result. + - max_pages: Maximum resume length in pages (integer 1-5). Default is 1. + - Returns: Dict with status, report_id, title, and content_type. + - After calling: Give a brief confirmation. Do NOT paste resume content in chat. Do NOT mention report_id or any internal IDs — the resume card is shown automatically. + - VERSIONING: Same rules as generate_report — set parent_report_id for modifications + of an existing resume, leave as None for new resumes. diff --git a/surfsense_backend/app/agents/new_chat/prompts/tools/generate_video_presentation.md b/surfsense_backend/app/agents/new_chat/prompts/tools/generate_video_presentation.md new file mode 100644 index 000000000..c3def88f2 --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/prompts/tools/generate_video_presentation.md @@ -0,0 +1,9 @@ + +- generate_video_presentation: Generate a video presentation from provided content. + - Use this when the user asks to create a video, presentation, slides, or slide deck. + - Trigger phrases: "give me a presentation", "create slides", "generate a video", "make a slide deck", "turn this into a presentation" + - Args: + - source_content: The text content to turn into a presentation. The more detailed, the better. + - video_title: Optional title (default: "SurfSense Presentation") + - user_prompt: Optional style instructions (e.g., "Make it technical and detailed") + - After calling this tool, inform the user that generation has started and they will see the presentation when it's ready. diff --git a/surfsense_backend/app/agents/new_chat/prompts/tools/scrape_webpage.md b/surfsense_backend/app/agents/new_chat/prompts/tools/scrape_webpage.md new file mode 100644 index 000000000..46e299392 --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/prompts/tools/scrape_webpage.md @@ -0,0 +1,30 @@ + +- scrape_webpage: Scrape and extract the main content from a webpage. + - Use this when the user wants you to READ and UNDERSTAND the actual content of a webpage. + - CRITICAL — WHEN TO USE (always attempt scraping, never refuse before trying): + * When a user asks to "get", "fetch", "pull", "grab", "scrape", or "read" content from a URL + * When the user wants live/dynamic data from a specific webpage (e.g., tables, scores, stats, prices) + * When a URL was mentioned earlier in the conversation and the user asks for its actual content + * When `/documents/` knowledge-base data is insufficient and the user wants more + - Trigger scenarios: + * "Read this article and summarize it" + * "What does this page say about X?" + * "Summarize this blog post for me" + * "Tell me the key points from this article" + * "What's in this webpage?" + * "Can you analyze this article?" + * "Can you get the live table/data from [URL]?" + * "Scrape it" / "Can you scrape that?" (referring to a previously mentioned URL) + * "Fetch the content from [URL]" + * "Pull the data from that page" + - Args: + - url: The URL of the webpage to scrape (must be HTTP/HTTPS) + - max_length: Maximum content length to return (default: 50000 chars) + - Returns: The page title, description, full content (in markdown), word count, and metadata + - After scraping, provide a comprehensive, well-structured summary with key takeaways using headings or bullet points. + - Reference the source using markdown links [descriptive text](url) — never bare URLs. + - IMAGES: The scraped content may contain image URLs in markdown format like `![alt text](image_url)`. + * When you find relevant/important images in the scraped content, include them in your response using standard markdown image syntax: `![alt text](image_url)`. + * This makes your response more visual and engaging. + * Prioritize showing: diagrams, charts, infographics, key illustrations, or images that help explain the content. + * Don't show every image - just the most relevant 1-3 images that enhance understanding. diff --git a/surfsense_backend/app/agents/new_chat/prompts/tools/search_surfsense_docs.md b/surfsense_backend/app/agents/new_chat/prompts/tools/search_surfsense_docs.md new file mode 100644 index 000000000..133717fec --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/prompts/tools/search_surfsense_docs.md @@ -0,0 +1,7 @@ + +- search_surfsense_docs: Search the official SurfSense documentation. + - Use this tool when the user asks anything about SurfSense itself (the application they are using). + - Args: + - query: The search query about SurfSense + - top_k: Number of documentation chunks to retrieve (default: 10) + - Returns: Documentation content with chunk IDs for citations (prefixed with 'doc-', e.g., [citation:doc-123]) diff --git a/surfsense_backend/app/agents/new_chat/prompts/tools/update_memory_private.md b/surfsense_backend/app/agents/new_chat/prompts/tools/update_memory_private.md new file mode 100644 index 000000000..184013804 --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/prompts/tools/update_memory_private.md @@ -0,0 +1,31 @@ + +- update_memory: Update your personal memory document about the user. + - Your current memory is already in in your context. The `chars` and + `limit` attributes show your current usage and the maximum allowed size. + - This is your curated long-term memory — the distilled essence of what you know about + the user, not raw conversation logs. + - Call update_memory when: + * The user explicitly asks to remember or forget something + * The user shares durable facts or preferences that will matter in future conversations + - The user's first name is provided in . Use it in memory entries + instead of "the user" (e.g. "{name} works at..." not "The user works at..."). + Do not store the name itself as a separate memory entry. + - Do not store short-lived or ephemeral info: one-off questions, greetings, + session logistics, or things that only matter for the current task. + - Args: + - updated_memory: The FULL updated markdown document (not a diff). + Merge new facts with existing ones, update contradictions, remove outdated entries. + Treat every update as a curation pass — consolidate, don't just append. + - Every bullet MUST use this format: - (YYYY-MM-DD) [marker] text + Markers: + [fact] — durable facts (role, background, projects, tools, expertise) + [pref] — preferences (response style, languages, formats, tools) + [instr] — standing instructions (always/never do, response rules) + - Keep it concise and well under the character limit shown in . + - Every entry MUST be under a `##` heading. Keep heading names short (2-3 words) and + natural. Do NOT include the user's name in headings. Organize by context — e.g. + who they are, what they're focused on, how they prefer things. Create, split, or + merge headings freely as the memory grows. + - Each entry MUST be a single bullet point. Be descriptive but concise — include relevant + details and context rather than just a few words. + - During consolidation, prioritize keeping: [instr] > [pref] > [fact]. diff --git a/surfsense_backend/app/agents/new_chat/prompts/tools/update_memory_team.md b/surfsense_backend/app/agents/new_chat/prompts/tools/update_memory_team.md new file mode 100644 index 000000000..7eaca8818 --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/prompts/tools/update_memory_team.md @@ -0,0 +1,26 @@ + +- update_memory: Update the team's shared memory document for this search space. + - Your current team memory is already in in your context. The `chars` + and `limit` attributes show current usage and the maximum allowed size. + - This is the team's curated long-term memory — decisions, conventions, key facts. + - NEVER store personal memory in team memory (e.g. personal bio, individual + preferences, or user-only standing instructions). + - Call update_memory when: + * A team member explicitly asks to remember or forget something + * The conversation surfaces durable team decisions, conventions, or facts + that will matter in future conversations + - Do not store short-lived or ephemeral info: one-off questions, greetings, + session logistics, or things that only matter for the current task. + - Args: + - updated_memory: The FULL updated markdown document (not a diff). + Merge new facts with existing ones, update contradictions, remove outdated entries. + Treat every update as a curation pass — consolidate, don't just append. + - Every bullet MUST use this format: - (YYYY-MM-DD) [fact] text + Team memory uses ONLY the [fact] marker. Never use [pref] or [instr] in team memory. + - Keep it concise and well under the character limit shown in . + - Every entry MUST be under a `##` heading. Keep heading names short (2-3 words) and + natural. Organize by context — e.g. what the team decided, current architecture, + active processes. Create, split, or merge headings freely as the memory grows. + - Each entry MUST be a single bullet point. Be descriptive but concise — include relevant + details and context rather than just a few words. + - During consolidation, prioritize keeping: decisions/conventions > key facts > current priorities. diff --git a/surfsense_backend/app/agents/new_chat/prompts/tools/web_search.md b/surfsense_backend/app/agents/new_chat/prompts/tools/web_search.md new file mode 100644 index 000000000..7ed7c332d --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/prompts/tools/web_search.md @@ -0,0 +1,18 @@ + +- web_search: Search the web for real-time information using all configured search engines. + - Use this for current events, news, prices, weather, public facts, or any question requiring + up-to-date information from the internet. + - This tool dispatches to all configured search engines (SearXNG, Tavily, Linkup, Baidu) in + parallel and merges the results. + - IMPORTANT (REAL-TIME / PUBLIC WEB QUERIES): For questions that require current public web data + (e.g., live exchange rates, stock prices, breaking news, weather, current events), you MUST call + `web_search` instead of answering from memory. + - For these real-time/public web queries, DO NOT answer from memory and DO NOT say you lack internet + access before attempting a web search. + - If the search returns no relevant results, explain that web sources did not return enough + data and ask the user if they want you to retry with a refined query. + - Args: + - query: The search query - use specific, descriptive terms + - top_k: Number of results to retrieve (default: 10, max: 50) + - If search snippets are insufficient for the user's question, use `scrape_webpage` on the most relevant result URL for full content. + - When presenting results, reference sources as markdown links [descriptive text](url) — never bare URLs. diff --git a/surfsense_backend/app/agents/new_chat/skills/__init__.py b/surfsense_backend/app/agents/new_chat/skills/__init__.py new file mode 100644 index 000000000..bb7ac055c --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/skills/__init__.py @@ -0,0 +1,7 @@ +"""SurfSense built-in agent skills (Anthropic Skills format). + +Each subdirectory corresponds to one skill and contains a ``SKILL.md`` file +with YAML frontmatter (name, description, allowed_tools) plus markdown +instructions. The :class:`BuiltinSkillsBackend` exposes them to the +deepagents :class:`SkillsMiddleware`. +""" diff --git a/surfsense_backend/app/agents/new_chat/skills/builtin/__init__.py b/surfsense_backend/app/agents/new_chat/skills/builtin/__init__.py new file mode 100644 index 000000000..8b1378917 --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/skills/builtin/__init__.py @@ -0,0 +1 @@ + diff --git a/surfsense_backend/app/agents/new_chat/skills/builtin/email-drafting/SKILL.md b/surfsense_backend/app/agents/new_chat/skills/builtin/email-drafting/SKILL.md new file mode 100644 index 000000000..32e599e98 --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/skills/builtin/email-drafting/SKILL.md @@ -0,0 +1,25 @@ +--- +name: email-drafting +description: Draft an email matching the user's voice, with structured intent and CTA +allowed-tools: search_surfsense_docs +--- + +# Email drafting + +## When to use this skill +"Draft an email to ...", "reply to this thread", "write a follow-up to X". Plain "summarize the email" is **not** in scope — that's a comprehension task. + +## Voice +Search the KB for prior emails from the user to similar audiences (same recipient, same topic class). Mirror tone, opening style, sign-off, and length distribution. If there is no precedent, default to: warm, direct, no filler, short paragraphs, one clear ask. + +## Required structure +Every draft includes, in this order: + +1. **Subject line** — concrete, ≤ 8 words, no clickbait, no `Re:` unless replying. +2. **Opening (1 sentence)** — context the recipient already shares; never restate what they wrote unless the thread is long. +3. **Body** — the actual point in one short paragraph. Bullets only if there are >3 discrete items. +4. **Single explicit CTA** — what you want the recipient to do, with a soft deadline if relevant. +5. **Sign-off** — match the user's prior closing style. + +## Always offer alternatives +End your message with: "Want me to make it shorter, more formal, or add a different angle?" — give the user one obvious next step. diff --git a/surfsense_backend/app/agents/new_chat/skills/builtin/kb-research/SKILL.md b/surfsense_backend/app/agents/new_chat/skills/builtin/kb-research/SKILL.md new file mode 100644 index 000000000..c268278ab --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/skills/builtin/kb-research/SKILL.md @@ -0,0 +1,23 @@ +--- +name: kb-research +description: Structured approach to finding and synthesizing information from the user's knowledge base +allowed-tools: search_surfsense_docs, scrape_webpage, read_file, ls_tree, grep, web_search +--- + +# Knowledge-base research + +## When to use this skill +- The user asks "find/look up/research" something specifically inside their knowledge base. +- The user references documents, notes, repos, or connector data they expect to exist already. +- A multi-document synthesis is required (e.g., "summarize what we've discussed about X across all my notes"). + +## Plan +1. Decompose the user's question into 2-4 specific, citation-worthy sub-questions. +2. For each sub-question, run **one** targeted KB search (focused on terms the user would have written, not synonyms). Open the most relevant 2-3 documents fully via `read_file` if their excerpts are too short. +3. Use `grep` to find supporting passages in long files instead of re-reading them end to end. +4. Cite every claim with `[citation:chunk_id]` exactly as the chunk tag specifies. + +## What good output looks like +- Short paragraphs with inline citations. +- Quoted phrases when wording matters. +- An explicit "Not found in your knowledge base" callout when a sub-question has no support — never fabricate. diff --git a/surfsense_backend/app/agents/new_chat/skills/builtin/meeting-prep/SKILL.md b/surfsense_backend/app/agents/new_chat/skills/builtin/meeting-prep/SKILL.md new file mode 100644 index 000000000..9657eb078 --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/skills/builtin/meeting-prep/SKILL.md @@ -0,0 +1,22 @@ +--- +name: meeting-prep +description: Pull together briefing materials before a scheduled meeting +allowed-tools: search_surfsense_docs, web_search, scrape_webpage, read_file +--- + +# Meeting preparation + +## When to use this skill +The user mentions an upcoming meeting, call, or interview and asks you to "prep", "brief me", "pull background", or "what do I need to know about X before tomorrow". + +## Output structure +Always produce these sections (omit any with no signal — don't pad): + +1. **Attendees & context** — who's in the room, their roles, what they care about. Pull from KB notes about prior interactions; supplement with public profile facts via `web_search` when names or companies are unfamiliar. +2. **Open threads** — outstanding action items, unresolved decisions, last-mentioned blockers from prior conversation history. +3. **Recent moves** — within the last 30 days: relevant launches, hires, news. Cite KB chunks when present, otherwise external sources. +4. **Suggested questions** — 3-5 questions the user could ask, tailored to the open threads and the attendees' likely priorities. + +## Source ordering +- Always check the user's KB **first** for prior meeting notes, internal docs, or Slack threads about these attendees. +- Only fall back to `web_search` for *publicly verifiable* facts — never to fabricate a participant's preferences or relationships. diff --git a/surfsense_backend/app/agents/new_chat/skills/builtin/report-writing/SKILL.md b/surfsense_backend/app/agents/new_chat/skills/builtin/report-writing/SKILL.md new file mode 100644 index 000000000..17ac2f391 --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/skills/builtin/report-writing/SKILL.md @@ -0,0 +1,23 @@ +--- +name: report-writing +description: How to scope, draft, and revise a Markdown report artifact via generate_report +allowed-tools: generate_report, search_surfsense_docs, read_file +--- + +# Report writing + +## When to use this skill +The user explicitly requests a deliverable: "write a report on …", "draft a memo", "produce a brief", "expand the previous report". A creation or modification verb pointed at an artifact is required (see `generate_report`'s when-to-call rules). + +## Decision flow +1. **Source strategy.** Decide which `source_strategy` fits: + - `conversation` — substantive Q&A on the topic already in chat. + - `kb_search` — fresh topic; supply 1–5 precise `search_queries`. + - `auto` — partial conversation context; let the tool fall back. + - `provided` — verbatim source text only. +2. **Style.** Default to `report_style="detailed"` unless the user explicitly asks for "brief", "one page", "500 words". +3. **Revisions.** When modifying an existing report from this conversation, set `parent_report_id` and put the change list in `user_instructions` ("add carbon-capture section", "tighten conclusion"). +4. **Never paste the report back into chat** after `generate_report` returns — confirm and let the artifact card render itself. + +## Hooks for KB-only mode +If `kb_search`/`auto` returns no results, do **not** silently switch to general knowledge. Surface the gap in your confirmation message. diff --git a/surfsense_backend/app/agents/new_chat/skills/builtin/slack-summary/SKILL.md b/surfsense_backend/app/agents/new_chat/skills/builtin/slack-summary/SKILL.md new file mode 100644 index 000000000..33b9e72a2 --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/skills/builtin/slack-summary/SKILL.md @@ -0,0 +1,26 @@ +--- +name: slack-summary +description: Distill a Slack channel or thread into actionable summary +allowed-tools: search_surfsense_docs +--- + +# Slack summarization + +## When to use this skill +The user asks to summarize Slack ("what happened in #eng-platform this week", "what did Alice say about the launch", "catch me up on the design channel"). + +## Required inputs +Confirm before searching: +- **Which channel(s) or thread(s)?** Don't guess if ambiguous. +- **What time window?** Default to the last 7 days when not specified, but say so. + +## Output shape +Produce three concise sections: +1. **Key decisions** — explicit choices that were made, with the deciding message cited. +2. **Open questions** — things asked but not answered, with the asking message cited. +3. **Action items** — `@mention` who owes what by when, *only if explicitly stated*. Don't invent assignees. + +## What not to do +- Never produce a chronological play-by-play of every message — distill. +- Never quote private messages without flagging them as such. +- If the channel was empty in the time window, say so — don't fabricate filler. diff --git a/surfsense_backend/app/agents/new_chat/state_reducers.py b/surfsense_backend/app/agents/new_chat/state_reducers.py new file mode 100644 index 000000000..ce32406e6 --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/state_reducers.py @@ -0,0 +1,201 @@ +"""Reducers and sentinels for SurfSense filesystem state. + +These reducers back the extra state fields used by the cloud-mode filesystem +agent (`cwd`, `staged_dirs`, `pending_moves`, `dirty_paths`, `doc_id_by_path`, +`kb_priority`, `kb_matched_chunk_ids`, `kb_anon_doc`, `tree_version`). + +Tools mutate these fields ONLY via `Command(update={...})` returns; the +reducers are responsible for merging successive updates atomically and for +honouring an explicit reset sentinel (`_CLEAR`) so that a single update can +both reset and reseed a list (used by `move_file` / `aafter_agent`). + +The sentinel is intentionally a plain string constant rather than a custom +object so that LangGraph's checkpointer (which serializes raw `Command.update` +deltas via ``ormsgpack`` BEFORE reducers are applied) can round-trip writes +that contain it. The token uses a NUL-bracketed form that cannot collide with +any real virtual path, document title, or dict key produced by the agent. +""" + +from __future__ import annotations + +from typing import Any, Final, TypeVar + +_CLEAR: Final[str] = "\x00__SURFSENSE_FILESYSTEM_CLEAR__\x00" +"""Reset sentinel; pass it inside a list/dict update to request a reset. + +For list reducers: ``[_CLEAR, *items]`` resets the field then appends ``items``. +For dict reducers: ``{_CLEAR: True, **items}`` resets the field then merges ``items``. + +Because the value is a plain string with embedded NUL bytes, it is natively +serializable by ``ormsgpack`` (used by LangGraph's PostgreSQL checkpointer) +yet still distinct from any real path / key produced by application code. +""" + + +T = TypeVar("T") + + +def _replace_reducer[T](left: T | None, right: T | None) -> T | None: + """Replace `left` outright with `right`. ``None`` on the right is honored as a reset.""" + return right + + +def _is_clear(value: Any) -> bool: + return isinstance(value, str) and value == _CLEAR + + +def _add_unique_reducer( + left: list[Any] | None, + right: list[Any] | None, +) -> list[Any]: + """Append items from ``right`` to ``left`` while preserving uniqueness. + + Semantics: + - If ``right`` is ``None`` or empty, return ``left`` unchanged. + - If ``right`` contains the ``_CLEAR`` sentinel anywhere, the result is + reseeded with only the items that appear AFTER the LAST occurrence of + ``_CLEAR`` (deduplicated, preserving first-seen order). This gives a + single-update "reset and reseed" capability. + - Otherwise, items from ``right`` are appended to ``left`` (order preserved + from first seen) while skipping values that are already present. + """ + if right is None: + return list(left or []) + if not right: + return list(left or []) + + last_clear = -1 + for index, item in enumerate(right): + if _is_clear(item): + last_clear = index + + if last_clear >= 0: + seed: list[Any] = [] + seen: set[Any] = set() + for item in right[last_clear + 1 :]: + if _is_clear(item): + continue + try: + if item in seen: + continue + seen.add(item) + except TypeError: + if item in seed: + continue + seed.append(item) + return seed + + base = list(left or []) + try: + seen: set[Any] = set(base) + except TypeError: + seen = set() + for item in right: + if _is_clear(item): + continue + try: + if item in seen: + continue + seen.add(item) + except TypeError: + if item in base: + continue + base.append(item) + return base + + +def _list_append_reducer( + left: list[Any] | None, + right: list[Any] | None, +) -> list[Any]: + """Append items from ``right`` to ``left`` preserving order and duplicates. + + Honours the ``_CLEAR`` sentinel exactly like :func:`_add_unique_reducer`, + but does NOT deduplicate. Used for queues whose ordering and duplicate + occurrences matter (e.g. ``pending_moves``). + """ + if right is None: + return list(left or []) + if not right: + return list(left or []) + + last_clear = -1 + for index, item in enumerate(right): + if _is_clear(item): + last_clear = index + + if last_clear >= 0: + return [item for item in right[last_clear + 1 :] if not _is_clear(item)] + + base = list(left or []) + base.extend(item for item in right if not _is_clear(item)) + return base + + +def _dict_merge_with_tombstones_reducer( + left: dict[Any, Any] | None, + right: dict[Any, Any] | None, +) -> dict[Any, Any]: + """Merge ``right`` into ``left`` with two extra capabilities: + + * Keys whose value is ``None`` are removed from the merged result + (tombstone semantics, matching the deepagents file-data reducer). + * The special key ``_CLEAR`` (with any truthy value) resets ``left`` to + ``{}`` before merging the remaining keys from ``right``. This makes it + possible to atomically clear and reseed the dictionary in a single + update. + """ + if right is None: + return dict(left or {}) + + if _CLEAR in right or any(_is_clear(k) for k in right): + result: dict[Any, Any] = {} + for key, value in right.items(): + if _is_clear(key): + continue + if value is None: + result.pop(key, None) + continue + result[key] = value + return result + + if left is None: + return {key: value for key, value in right.items() if value is not None} + + result = dict(left) + for key, value in right.items(): + if value is None: + result.pop(key, None) + else: + result[key] = value + return result + + +def _initial_filesystem_state() -> dict[str, Any]: + """Default empty values for SurfSense filesystem state fields. + + Consumers should always treat these fields as ``state.get(key) or + DEFAULT`` so that fresh threads (without checkpointed state) work + correctly. + """ + return { + "cwd": "/documents", + "staged_dirs": [], + "pending_moves": [], + "doc_id_by_path": {}, + "dirty_paths": [], + "kb_priority": [], + "kb_matched_chunk_ids": {}, + "kb_anon_doc": None, + "tree_version": 0, + } + + +__all__ = [ + "_CLEAR", + "_add_unique_reducer", + "_dict_merge_with_tombstones_reducer", + "_initial_filesystem_state", + "_list_append_reducer", + "_replace_reducer", +] diff --git a/surfsense_backend/app/agents/new_chat/subagents/__init__.py b/surfsense_backend/app/agents/new_chat/subagents/__init__.py new file mode 100644 index 000000000..7d678ec79 --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/subagents/__init__.py @@ -0,0 +1,29 @@ +"""Specialized user-facing subagents for the SurfSense agent. + +The :class:`deepagents.SubAgentMiddleware` already provides the +materialization machinery (each :class:`deepagents.SubAgent` typed-dict +spec is compiled into an ephemeral runnable invoked via the ``task`` +tool); what's specific to SurfSense is the *seeding* of those subagents +with declarative deny rules. + +Per-subagent permission rules are injected as a +:class:`PermissionMiddleware` entry inside the subagent's ``middleware`` +field. The auto-deny pattern (e.g. forbid ``task``/``todowrite`` +recursion, block write tools for read-only research roles) is borrowed +from OpenCode's ``packages/opencode/src/tool/task.ts``, which has +analogous logic for restricting child sessions. +""" + +from .config import ( + build_connector_negotiator_subagent, + build_explore_subagent, + build_report_writer_subagent, + build_specialized_subagents, +) + +__all__ = [ + "build_connector_negotiator_subagent", + "build_explore_subagent", + "build_report_writer_subagent", + "build_specialized_subagents", +] diff --git a/surfsense_backend/app/agents/new_chat/subagents/config.py b/surfsense_backend/app/agents/new_chat/subagents/config.py new file mode 100644 index 000000000..b36d35fa0 --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/subagents/config.py @@ -0,0 +1,425 @@ +"""Builders for specialized SurfSense subagents. + +Each subagent is built from three pieces: + +1. A name + description + system prompt (the user-facing contract for + when ``task`` should delegate to this role). +2. A filtered tool list (subset of the parent's bound tools). +3. A :class:`PermissionMiddleware` instance carrying a deny ruleset that + prevents the subagent from acting outside its scope (e.g. an + explore-only role cannot mutate state). + +Skill sources (``/skills/builtin/`` + ``/skills/space/``) are inherited +from the parent unconditionally — every subagent benefits from the same +authored guidance documents. +""" + +from __future__ import annotations + +import logging +from collections.abc import Iterable, Sequence +from typing import TYPE_CHECKING, Any + +from app.agents.new_chat.middleware.skills_backends import default_skills_sources +from app.agents.new_chat.permissions import Rule, Ruleset + +if TYPE_CHECKING: + from deepagents import SubAgent + from langchain_core.language_models import BaseChatModel + from langchain_core.tools import BaseTool + +logger = logging.getLogger(__name__) + + +# --------------------------------------------------------------------------- +# Tool name constants +# --------------------------------------------------------------------------- + +# Read-only tools that ``explore`` is permitted to use. Names match the +# tools provided by the deepagents ``FilesystemMiddleware`` (``ls``, ``read_file``, +# ``glob``, ``grep``) plus the SurfSense-side read tools. +EXPLORE_READ_TOOLS: frozenset[str] = frozenset( + { + "search_surfsense_docs", + "web_search", + "scrape_webpage", + "read_file", + "ls", + "glob", + "grep", + } +) + +# Tools ``report_writer`` may call. The set is intentionally narrow so the +# subagent doesn't drift into tangential research; if richer source-gathering +# is needed, the parent should hand off to ``explore`` first. +REPORT_WRITER_TOOLS: frozenset[str] = frozenset( + { + "search_surfsense_docs", + "read_file", + "generate_report", + } +) + +# Wildcard patterns that match write tools we deny by default in read-only +# subagents. Anchored at start AND end via :func:`Rule` semantics. We use +# substring-style ``*verb*`` patterns because connector tool names typically +# put the verb in the middle (``linear_create_issue``, ``slack_send_message``, +# ``notion_update_page``); strict suffix patterns (``*_create``) miss those. +# +# A handful of canonical exact-match names is appended so that bare verbs +# (``edit``, ``write``) are also blocked even when a connector dropped the +# usual prefix. +WRITE_TOOL_DENY_PATTERNS: tuple[str, ...] = ( + "*create*", + "*update*", + "*delete*", + "*send*", + "*write*", + "*edit*", + "*move*", + "*mkdir*", + "*upload*", + "edit_file", + "write_file", + "move_file", + "mkdir", + "update_memory", + "update_memory_team", + "update_memory_private", +) + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +# Tool names that are NOT in the registry's ``tools`` list because they +# are provided dynamically by middleware at compile time. We don't pass +# them through ``_filter_tools`` (the actual ``BaseTool`` instances live +# inside the middleware), but we do exempt them from the "missing" warning +# below — operators were seeing spurious noise like +# ``missing: ['glob', 'grep', 'ls', 'read_file']`` even though those +# tools are reachable via :class:`SurfSenseFilesystemMiddleware` once the +# subagent is compiled. +_MIDDLEWARE_PROVIDED_TOOL_NAMES: frozenset[str] = frozenset( + { + "ls", + "read_file", + "write_file", + "edit_file", + "glob", + "grep", + "execute", + "write_todos", + "task", + } +) + + +def _filter_tools( + tools: Sequence[BaseTool], + allowed_names: Iterable[str], +) -> list[BaseTool]: + """Return only tools whose ``name`` appears in ``allowed_names``. + + Tools are looked up by exact name. Names matching + :data:`_MIDDLEWARE_PROVIDED_TOOL_NAMES` are intentionally absent from + ``tools`` (they're injected by middleware at compile time) and are + silently excluded from the "missing" warning so operators don't see + false positives every build. + """ + allowed = set(allowed_names) + selected = [t for t in tools if t.name in allowed] + missing = sorted( + (allowed - {t.name for t in selected}) - _MIDDLEWARE_PROVIDED_TOOL_NAMES + ) + if missing: + logger.info( + "Subagent build: %d/%d registry tools available; missing: %s", + len(selected), + len(allowed - _MIDDLEWARE_PROVIDED_TOOL_NAMES), + missing, + ) + return selected + + +def _read_only_deny_rules() -> list[Rule]: + """Synthesize a list of deny rules covering common write-tool patterns.""" + return [ + Rule(permission=pattern, pattern="*", action="deny") + for pattern in WRITE_TOOL_DENY_PATTERNS + ] + + +def _build_permission_middleware(deny_rules: list[Rule], origin: str): + """Construct a :class:`PermissionMiddleware` seeded with ``deny_rules``. + + Imported lazily because the middleware module pulls in interrupt/HITL + machinery we don't want at import time of this config file. + """ + from app.agents.new_chat.middleware.permission import PermissionMiddleware + + return PermissionMiddleware( + rulesets=[Ruleset(rules=deny_rules, origin=origin)], + ) + + +def _wrap_with_subagent_essentials( + custom_middleware: list, + *, + agent_tools: Sequence[BaseTool], + extra_middleware: Sequence[Any] | None = None, +): + """Compose the final middleware list for a specialized subagent. + + Order, outer to inner: + + 1. ``extra_middleware`` — provided by the caller (typically the parent + agent's ``SurfSenseFilesystemMiddleware`` and ``TodoListMiddleware``) + so the subagent inherits the parent's filesystem/todo view. These + run **before** the subagent-local middleware so their tools are + wired up before permissioning kicks in. + 2. ``custom_middleware`` — subagent-local rules (e.g. permission deny + lists). + 3. :class:`PatchToolCallsMiddleware` — normalizes tool-call shapes. + 4. :class:`DedupHITLToolCallsMiddleware` — collapses duplicate HITL + calls using metadata declared at registry time. + + Without ``extra_middleware`` the subagent will only have the registry + tools listed in its ``tools`` field — meaning ``read_file``, ``ls``, + ``grep``, etc. won't exist. Always pass ``extra_middleware`` from the + parent unless you specifically want a sandboxed subagent. + """ + from deepagents.middleware.patch_tool_calls import PatchToolCallsMiddleware + + from app.agents.new_chat.middleware import DedupHITLToolCallsMiddleware + + return [ + *(extra_middleware or []), + *custom_middleware, + PatchToolCallsMiddleware(), + DedupHITLToolCallsMiddleware(agent_tools=list(agent_tools)), + ] + + +# --------------------------------------------------------------------------- +# System prompts +# --------------------------------------------------------------------------- + +EXPLORE_SYSTEM_PROMPT = """You are the **explore** subagent for SurfSense. + +## Your job +Conduct read-only research across the user's knowledge base, the web, and any documents the parent agent has surfaced. Return a synthesized answer with explicit citations — never speculate beyond the sources you have actually inspected. + +## Tools available +- `search_surfsense_docs` — fast hybrid search over the user's knowledge base. +- `web_search` — only when the user's KB clearly does not contain the answer. +- `scrape_webpage` — to read a URL the user or the search results provided. +- `read_file`, `ls`, `glob`, `grep` — to inspect specific documents or trees the parent has flagged. + +## Rules +- Read-only. You cannot create, edit, delete, send, or move anything. +- Cite every claim. Use `[citation:chunk_id]` exactly as the chunk tag specifies. +- If a sub-question has no support in the inspected sources, say so explicitly. Do not fabricate. +- Return the most useful synthesis in your single final message. The parent agent will not be able to follow up. +""" + + +REPORT_WRITER_SYSTEM_PROMPT = """You are the **report_writer** subagent for SurfSense. + +## Your job +Produce a single high-quality report deliverable using `generate_report`. The parent has already gathered (or knows where to gather) the underlying sources. + +## Workflow +1. **Outline first.** Before calling `generate_report`, write a one-paragraph outline of the sections you plan to produce. Confirm the outline reflects the parent's instructions. +2. **Source resolution.** Decide whether to call `search_surfsense_docs` and `read_file` for any final-checks, or whether the parent's earlier tool calls already cover the source set. +3. **One report.** Call `generate_report` exactly once with `source_strategy` chosen per the topic and chat history (see the `report-writing` skill). +4. **Confirm.** End with a one-sentence summary in your final message — never paste the report back into chat; the artifact card renders itself. +""" + + +CONNECTOR_NEGOTIATOR_SYSTEM_PROMPT = """You are the **connector_negotiator** subagent for SurfSense. + +## Your job +Coordinate cross-connector workflows: chains where the result of one service's tool feeds into another's. Common shapes include "find Linear issues mentioned in last week's Slack messages", "draft a Gmail reply citing a Notion doc", or "list Linear tickets opened by the same person who filed Jira FOO-123". + +## Workflow +1. **Plan.** Identify the connector hops needed and the order they should run in. Write a short plan in your first message. +2. **Verify access.** Use `get_connected_accounts` to confirm the relevant connectors are actually wired up before issuing tool calls. If a connector is missing, stop and report — do not fabricate. +3. **Execute.** Run each hop, citing IDs (issue keys, message ts, page IDs) in your scratch notes so the parent can audit. +4. **Hand back.** Return a structured summary with the final answer plus the chain of evidence (issue → message → page, etc.). + +## Caveats +- If a hop fails, do not retry blindly — return the partial result and explain. +- Mutating tools (create, update, delete, send) require parent permission; you are NOT cleared to call them on your own. +""" + + +# --------------------------------------------------------------------------- +# Subagent builders +# --------------------------------------------------------------------------- + + +def build_explore_subagent( + *, + tools: Sequence[BaseTool], + model: BaseChatModel | None = None, + extra_middleware: Sequence[Any] | None = None, +) -> SubAgent: + """Build the read-only ``explore`` subagent spec. + + Pass ``extra_middleware`` (typically the parent's filesystem + todo + middleware) so the subagent can actually use ``read_file``, ``ls``, + ``grep``, ``glob`` — which its system prompt promises but which only + exist when their middleware is mounted. + """ + from deepagents import SubAgent # noqa: F401 (TypedDict for type clarity) + + selected_tools = _filter_tools(tools, EXPLORE_READ_TOOLS) + deny_rules = _read_only_deny_rules() + permission_mw = _build_permission_middleware(deny_rules, origin="subagent_explore") + + spec: dict = { + "name": "explore", + "description": ( + "Read-only research across the user's knowledge base and the web. " + "Use when the parent needs deeply-cited synthesis without " + "modifying anything." + ), + "system_prompt": EXPLORE_SYSTEM_PROMPT, + "tools": selected_tools, + "middleware": _wrap_with_subagent_essentials( + [permission_mw], + agent_tools=selected_tools, + extra_middleware=extra_middleware, + ), + "skills": default_skills_sources(), + } + if model is not None: + spec["model"] = model + return spec # type: ignore[return-value] + + +def build_report_writer_subagent( + *, + tools: Sequence[BaseTool], + model: BaseChatModel | None = None, + extra_middleware: Sequence[Any] | None = None, +) -> SubAgent: + """Build the ``report_writer`` subagent spec. + + Read-only deny ruleset still applies — the subagent should call + ``generate_report`` and nothing else mutating. ``generate_report`` + creates a report artifact via a backend service and is intentionally + **not** denied. + + Pass ``extra_middleware`` (typically the parent's filesystem + todo + middleware) so the subagent can run ``read_file`` for source-checks + before calling ``generate_report``. + """ + selected_tools = _filter_tools(tools, REPORT_WRITER_TOOLS) + deny_rules = _read_only_deny_rules() + permission_mw = _build_permission_middleware( + deny_rules, origin="subagent_report_writer" + ) + + spec: dict = { + "name": "report_writer", + "description": ( + "Produce a single Markdown report artifact via generate_report, " + "using the outline-then-fill protocol. Use when the parent has " + "decided a deliverable is needed." + ), + "system_prompt": REPORT_WRITER_SYSTEM_PROMPT, + "tools": selected_tools, + "middleware": _wrap_with_subagent_essentials( + [permission_mw], + agent_tools=selected_tools, + extra_middleware=extra_middleware, + ), + "skills": default_skills_sources(), + } + if model is not None: + spec["model"] = model + return spec # type: ignore[return-value] + + +def build_connector_negotiator_subagent( + *, + tools: Sequence[BaseTool], + model: BaseChatModel | None = None, + extra_middleware: Sequence[Any] | None = None, +) -> SubAgent: + """Build the ``connector_negotiator`` subagent spec. + + Inherits all MCP / connector tools the parent has plus + ``get_connected_accounts``. Read-only by default; permission rules deny + write/mutation patterns. The parent agent re-asks for permission if a + connector mutation is genuinely needed. + + Pass ``extra_middleware`` (typically the parent's filesystem + todo + middleware) so this subagent shares the parent's filesystem view when + citing evidence across hops. + """ + parent_tool_names = {t.name for t in tools} + allowed: set[str] = set() + if "get_connected_accounts" in parent_tool_names: + allowed.add("get_connected_accounts") + # Inherit anything that smells connector- or MCP-related but is not a + # bulk-write API. Heuristic: keep all parent tools; rely on the deny + # ruleset to block mutation patterns. This mirrors the plan: "all + # MCP/connector tools the parent has". + for name in parent_tool_names: + allowed.add(name) + selected_tools = _filter_tools(tools, allowed) + + deny_rules = _read_only_deny_rules() + permission_mw = _build_permission_middleware( + deny_rules, origin="subagent_connector_negotiator" + ) + + spec: dict = { + "name": "connector_negotiator", + "description": ( + "Coordinate read-only chains across connectors (Slack → Linear, " + "Notion → Gmail, etc.). Returns a structured summary with the " + "evidence chain. Cannot mutate connector state." + ), + "system_prompt": CONNECTOR_NEGOTIATOR_SYSTEM_PROMPT, + "tools": selected_tools, + "middleware": _wrap_with_subagent_essentials( + [permission_mw], + agent_tools=selected_tools, + extra_middleware=extra_middleware, + ), + "skills": default_skills_sources(), + } + if model is not None: + spec["model"] = model + return spec # type: ignore[return-value] + + +def build_specialized_subagents( + *, + tools: Sequence[BaseTool], + model: BaseChatModel | None = None, + extra_middleware: Sequence[Any] | None = None, +) -> list[SubAgent]: + """Return the canonical list of specialized subagents to register. + + Order matters only for the order they appear in the ``task`` tool + description — most useful first. + """ + return [ + build_explore_subagent( + tools=tools, model=model, extra_middleware=extra_middleware + ), + build_report_writer_subagent( + tools=tools, model=model, extra_middleware=extra_middleware + ), + build_connector_negotiator_subagent( + tools=tools, model=model, extra_middleware=extra_middleware + ), + ] diff --git a/surfsense_backend/app/agents/new_chat/system_prompt.py b/surfsense_backend/app/agents/new_chat/system_prompt.py index e77132182..56f838d7e 100644 --- a/surfsense_backend/app/agents/new_chat/system_prompt.py +++ b/surfsense_backend/app/agents/new_chat/system_prompt.py @@ -1,842 +1,44 @@ """ -System prompt building for SurfSense agents. +Thin compatibility wrapper around :mod:`app.agents.new_chat.prompts.composer`. -This module provides functions and constants for building the SurfSense system prompt -with configurable user instructions and citation support. +The composer split the previous monolithic prompt string into a fragment +tree under ``prompts/`` plus a model-family dispatch step (see the +composer module docstring for credits). This module preserves the public +function surface (``build_surfsense_system_prompt`` / +``build_configurable_system_prompt`` / +``get_default_system_instructions`` / ``SURFSENSE_SYSTEM_PROMPT``) so +that existing call sites — `chat_deepagent.py`, anonymous chat routes, +and the configurable-prompt admin path — keep working without churn. -The prompt is composed of three parts: -1. System Instructions (configurable via NewLLMConfig) -2. Tools Instructions (always included, not configurable) -3. Citation Instructions (toggleable via NewLLMConfig.citations_enabled) +For new call sites prefer importing ``compose_system_prompt`` directly +from :mod:`app.agents.new_chat.prompts.composer`. """ +from __future__ import annotations + from datetime import UTC, datetime from app.db import ChatVisibility -# Default system instructions - can be overridden via NewLLMConfig.system_instructions -SURFSENSE_SYSTEM_INSTRUCTIONS = """ - -You are SurfSense, a reasoning and acting AI agent designed to answer user questions using the user's personal knowledge base. - -Today's date (UTC): {resolved_today} - -When writing mathematical formulas or equations, ALWAYS use LaTeX notation. NEVER use backtick code spans or Unicode symbols for math. - -NEVER expose internal tool parameter names, backend IDs, or implementation details to the user. Always use natural, user-friendly language instead. - - -CRITICAL RULE — KNOWLEDGE BASE FIRST, NEVER DEFAULT TO GENERAL KNOWLEDGE: -- You MUST answer questions ONLY using information retrieved from the user's knowledge base, web search results, scraped webpages, or other tool outputs. -- You MUST NOT answer factual or informational questions from your own training data or general knowledge unless the user explicitly grants permission. -- If the knowledge base search returns no relevant results AND no other tool provides the answer, you MUST: - 1. Inform the user that you could not find relevant information in their knowledge base. - 2. Ask the user: "Would you like me to answer from my general knowledge instead?" - 3. ONLY provide a general-knowledge answer AFTER the user explicitly says yes. -- This policy does NOT apply to: - * Casual conversation, greetings, or meta-questions about SurfSense itself (e.g., "what can you do?") - * Formatting, summarization, or analysis of content already present in the conversation - * Following user instructions that are clearly task-oriented (e.g., "rewrite this in bullet points") - * Tool-usage actions like generating reports, podcasts, images, or scraping webpages - * Queries about services that have direct tools (Linear, ClickUp, Jira, Slack, Airtable) — see below - - - -CRITICAL — You have direct tools for these services: Linear, ClickUp, Jira, Slack, Airtable. -Their data is NEVER in the knowledge base. You MUST call their tools immediately — never -say "I don't see it in the knowledge base" or ask the user if they want you to check. -Ignore any knowledge base results for these services. - -When to use which tool: -- Linear (issues) → list_issues, get_issue, save_issue (create/update) -- ClickUp (tasks) → clickup_search, clickup_get_task -- Jira (issues) → getAccessibleAtlassianResources (cloudId discovery), getVisibleJiraProjects (project discovery), getJiraProjectIssueTypesMetadata (issue type discovery), searchJiraIssuesUsingJql, createJiraIssue, editJiraIssue -- Slack (messages, channels) → slack_search_channels, slack_read_channel, slack_read_thread -- Airtable (bases, tables, records) → list_bases, list_tables_for_base, list_records_for_table -- Knowledge base content (Notion, GitHub, files, notes) → automatically searched -- Real-time public web data → call web_search -- Reading a specific webpage → call scrape_webpage - - - -Some service tools require identifiers or context you do not have (account IDs, -workspace names, channel IDs, project keys, etc.). NEVER ask the user for raw -IDs or technical identifiers — they cannot memorise them. - -Instead, follow this discovery pattern: -1. Call a listing/discovery tool to find available options. -2. ONE result → use it silently, no question to the user. -3. MULTIPLE results → present the options by their display names and let the - user choose. Never show raw UUIDs — always use friendly names. - -Discovery tools by level: -- Which account/workspace? → get_connected_accounts("") -- Which Jira site (cloudId)? → getAccessibleAtlassianResources -- Which Jira project? → getVisibleJiraProjects (after resolving cloudId) -- Which Jira issue type? → getJiraProjectIssueTypesMetadata (after resolving project) -- Which channel? → slack_search_channels -- Which base? → list_bases -- Which table? → list_tables_for_base (after resolving baseId) -- Which task? → clickup_search -- Which issue? → list_issues (Linear) or searchJiraIssuesUsingJql (Jira) - -For Jira specifically: ALWAYS call getAccessibleAtlassianResources first to -obtain the cloudId, then pass it to other Jira tools. When creating an issue, -chain: getAccessibleAtlassianResources → getVisibleJiraProjects → createJiraIssue. -If there is only one option at each step, use it silently. If multiple, present -friendly names. - -Chain discovery when needed — e.g. for Airtable records: list_bases → pick -base → list_tables_for_base → pick table → list_records_for_table. - -MULTI-ACCOUNT TOOL NAMING: When the user has multiple accounts connected for -the same service, tool names are prefixed to avoid collisions — e.g. -linear_25_list_issues and linear_30_list_issues instead of two list_issues. -Each prefixed tool's description starts with [Account: ] so you -know which account it targets. Use get_connected_accounts("") to see -the full list of accounts with their connector IDs and display names. -When only one account is connected, tools have their normal unprefixed names. - - - -IMPORTANT — After understanding each user message, ALWAYS check: does this message -reveal durable facts about the user (role, interests, preferences, projects, -background, or standing instructions)? If yes, you MUST call update_memory -alongside your normal response — do not defer this to a later turn. - - - -""" - -# Default system instructions for shared (team) threads: team context + message format for attribution -_SYSTEM_INSTRUCTIONS_SHARED = """ - -You are SurfSense, a reasoning and acting AI agent designed to answer questions in this team space using the team's shared knowledge base. - -In this team thread, each message is prefixed with **[DisplayName of the author]**. Use this to attribute and reference the author of anything in the discussion (who asked a question, made a suggestion, or contributed an idea) and to cite who said what in your answers. - -Today's date (UTC): {resolved_today} - -When writing mathematical formulas or equations, ALWAYS use LaTeX notation. NEVER use backtick code spans or Unicode symbols for math. - -NEVER expose internal tool parameter names, backend IDs, or implementation details to the user. Always use natural, user-friendly language instead. - - -CRITICAL RULE — KNOWLEDGE BASE FIRST, NEVER DEFAULT TO GENERAL KNOWLEDGE: -- You MUST answer questions ONLY using information retrieved from the team's shared knowledge base, web search results, scraped webpages, or other tool outputs. -- You MUST NOT answer factual or informational questions from your own training data or general knowledge unless a team member explicitly grants permission. -- If the knowledge base search returns no relevant results AND no other tool provides the answer, you MUST: - 1. Inform the team that you could not find relevant information in the shared knowledge base. - 2. Ask: "Would you like me to answer from my general knowledge instead?" - 3. ONLY provide a general-knowledge answer AFTER a team member explicitly says yes. -- This policy does NOT apply to: - * Casual conversation, greetings, or meta-questions about SurfSense itself (e.g., "what can you do?") - * Formatting, summarization, or analysis of content already present in the conversation - * Following user instructions that are clearly task-oriented (e.g., "rewrite this in bullet points") - * Tool-usage actions like generating reports, podcasts, images, or scraping webpages - * Queries about services that have direct tools (Linear, ClickUp, Jira, Slack, Airtable) — see below - - - -CRITICAL — You have direct tools for these services: Linear, ClickUp, Jira, Slack, Airtable. -Their data is NEVER in the knowledge base. You MUST call their tools immediately — never -say "I don't see it in the knowledge base" or ask if they want you to check. -Ignore any knowledge base results for these services. - -When to use which tool: -- Linear (issues) → list_issues, get_issue, save_issue (create/update) -- ClickUp (tasks) → clickup_search, clickup_get_task -- Jira (issues) → getAccessibleAtlassianResources (cloudId discovery), getVisibleJiraProjects (project discovery), getJiraProjectIssueTypesMetadata (issue type discovery), searchJiraIssuesUsingJql, createJiraIssue, editJiraIssue -- Slack (messages, channels) → slack_search_channels, slack_read_channel, slack_read_thread -- Airtable (bases, tables, records) → list_bases, list_tables_for_base, list_records_for_table -- Knowledge base content (Notion, GitHub, files, notes) → automatically searched -- Real-time public web data → call web_search -- Reading a specific webpage → call scrape_webpage - - - -Some service tools require identifiers or context you do not have (account IDs, -workspace names, channel IDs, project keys, etc.). NEVER ask the user for raw -IDs or technical identifiers — they cannot memorise them. - -Instead, follow this discovery pattern: -1. Call a listing/discovery tool to find available options. -2. ONE result → use it silently, no question to the user. -3. MULTIPLE results → present the options by their display names and let the - user choose. Never show raw UUIDs — always use friendly names. - -Discovery tools by level: -- Which account/workspace? → get_connected_accounts("") -- Which Jira site (cloudId)? → getAccessibleAtlassianResources -- Which Jira project? → getVisibleJiraProjects (after resolving cloudId) -- Which Jira issue type? → getJiraProjectIssueTypesMetadata (after resolving project) -- Which channel? → slack_search_channels -- Which base? → list_bases -- Which table? → list_tables_for_base (after resolving baseId) -- Which task? → clickup_search -- Which issue? → list_issues (Linear) or searchJiraIssuesUsingJql (Jira) - -For Jira specifically: ALWAYS call getAccessibleAtlassianResources first to -obtain the cloudId, then pass it to other Jira tools. When creating an issue, -chain: getAccessibleAtlassianResources → getVisibleJiraProjects → createJiraIssue. -If there is only one option at each step, use it silently. If multiple, present -friendly names. - -Chain discovery when needed — e.g. for Airtable records: list_bases → pick -base → list_tables_for_base → pick table → list_records_for_table. - -MULTI-ACCOUNT TOOL NAMING: When the user has multiple accounts connected for -the same service, tool names are prefixed to avoid collisions — e.g. -linear_25_list_issues and linear_30_list_issues instead of two list_issues. -Each prefixed tool's description starts with [Account: ] so you -know which account it targets. Use get_connected_accounts("") to see -the full list of accounts with their connector IDs and display names. -When only one account is connected, tools have their normal unprefixed names. - - - -IMPORTANT — After understanding each user message, ALWAYS check: does this message -reveal durable facts about the team (decisions, conventions, architecture, processes, -or key facts)? If yes, you MUST call update_memory alongside your normal response — -do not defer this to a later turn. - - - -""" - - -def _get_system_instructions( - thread_visibility: ChatVisibility | None = None, today: datetime | None = None -) -> str: - """Build system instructions based on thread visibility (private vs shared).""" - - resolved_today = (today or datetime.now(UTC)).astimezone(UTC).date().isoformat() - visibility = thread_visibility or ChatVisibility.PRIVATE - if visibility == ChatVisibility.SEARCH_SPACE: - return _SYSTEM_INSTRUCTIONS_SHARED.format(resolved_today=resolved_today) - else: - return SURFSENSE_SYSTEM_INSTRUCTIONS.format(resolved_today=resolved_today) - - -# ============================================================================= -# Per-tool prompt instructions keyed by registry tool name. -# Only tools present in the enabled set will be included in the system prompt. -# ============================================================================= - -_TOOLS_PREAMBLE = """ - -You have access to the following tools: - -IMPORTANT: You can ONLY use the tools listed below. If a capability is not listed here, you do NOT have it. -Do NOT claim you can do something if the corresponding tool is not listed. - -""" - -_TOOL_INSTRUCTIONS: dict[str, str] = {} - -_TOOL_INSTRUCTIONS["search_surfsense_docs"] = """ -- search_surfsense_docs: Search the official SurfSense documentation. - - Use this tool when the user asks anything about SurfSense itself (the application they are using). - - Args: - - query: The search query about SurfSense - - top_k: Number of documentation chunks to retrieve (default: 10) - - Returns: Documentation content with chunk IDs for citations (prefixed with 'doc-', e.g., [citation:doc-123]) -""" - -_TOOL_INSTRUCTIONS["generate_podcast"] = """ -- generate_podcast: Generate an audio podcast from provided content. - - Use this when the user asks to create, generate, or make a podcast. - - Trigger phrases: "give me a podcast about", "create a podcast", "generate a podcast", "make a podcast", "turn this into a podcast" - - Args: - - source_content: The text content to convert into a podcast. This MUST be comprehensive and include: - * If discussing the current conversation: Include a detailed summary of the FULL chat history (all user questions and your responses) - * If based on knowledge base search: Include the key findings and insights from the search results - * You can combine both: conversation context + search results for richer podcasts - * The more detailed the source_content, the better the podcast quality - - podcast_title: Optional title for the podcast (default: "SurfSense Podcast") - - user_prompt: Optional instructions for podcast style/format (e.g., "Make it casual and fun") - - Returns: A task_id for tracking. The podcast will be generated in the background. - - IMPORTANT: Only one podcast can be generated at a time. If a podcast is already being generated, the tool will return status "already_generating". - - After calling this tool, inform the user that podcast generation has started and they will see the player when it's ready (takes 3-5 minutes). -""" - -_TOOL_INSTRUCTIONS["generate_video_presentation"] = """ -- generate_video_presentation: Generate a video presentation from provided content. - - Use this when the user asks to create a video, presentation, slides, or slide deck. - - Trigger phrases: "give me a presentation", "create slides", "generate a video", "make a slide deck", "turn this into a presentation" - - Args: - - source_content: The text content to turn into a presentation. The more detailed, the better. - - video_title: Optional title (default: "SurfSense Presentation") - - user_prompt: Optional style instructions (e.g., "Make it technical and detailed") - - After calling this tool, inform the user that generation has started and they will see the presentation when it's ready. -""" - -_TOOL_INSTRUCTIONS["generate_report"] = """ -- generate_report: Generate or revise a structured Markdown report artifact. - - WHEN TO CALL THIS TOOL — the message must contain a creation or modification VERB directed at producing a deliverable: - * Creation verbs: write, create, generate, draft, produce, summarize into, turn into, make - * Modification verbs: revise, update, expand, add (a section), rewrite, make (it shorter/longer/formal) - * Example triggers: "generate a report about...", "write a document on...", "add a section about budget", "make the report shorter", "rewrite in formal tone" - - WHEN NOT TO CALL THIS TOOL (answer in chat instead): - * Questions or discussion about the report: "What can we add?", "What's missing?", "Is the data accurate?", "How could this be improved?" - * Suggestions or brainstorming: "What other topics could be covered?", "What else could be added?", "What would make this better?" - * Asking for explanations: "Can you explain section 2?", "Why did you include that?", "What does this part mean?" - * Quick follow-ups or critiques: "Is the conclusion strong enough?", "Are there any gaps?", "What about the competitors?" - * THE TEST: Does the message contain a creation/modification VERB (from the list above) directed at producing or changing a deliverable? If NO verb → answer conversationally in chat. Do NOT assume the user wants a revision just because a report exists in the conversation. - - IMPORTANT FORMAT RULE: Reports are ALWAYS generated in Markdown. - - Args: - - topic: Short title for the report (max ~8 words). - - source_content: The text content to base the report on. - * For source_strategy="conversation" or "provided": Include a comprehensive summary of the relevant content. - * For source_strategy="kb_search": Can be empty or minimal — the tool handles searching internally. - * For source_strategy="auto": Include what you have; the tool searches KB if it's not enough. - - source_strategy: Controls how the tool collects source material. One of: - * "conversation" — The conversation already contains enough context (prior Q&A, discussion, pasted text, scraped pages). Pass a thorough summary as source_content. - * "kb_search" — The tool will search the knowledge base internally. Provide search_queries with 1-5 targeted queries. - * "auto" — Use source_content if sufficient, otherwise fall back to internal KB search using search_queries. - * "provided" — Use only what is in source_content (default, backward-compatible). - - search_queries: When source_strategy is "kb_search" or "auto", provide 1-5 specific search queries for the knowledge base. These should be precise, not just the topic name repeated. - - report_style: Controls report depth. Options: "detailed" (DEFAULT), "deep_research", "brief". - Use "brief" ONLY when the user explicitly asks for a short/concise/one-page report (e.g., "one page", "keep it short", "brief report", "500 words"). Default to "detailed" for all other requests. - - user_instructions: Optional specific instructions (e.g., "focus on financial impacts", "include recommendations"). When revising (parent_report_id set), describe WHAT TO CHANGE. If the user mentions a length preference (e.g., "one page", "500 words", "2 pages"), include that VERBATIM here AND set report_style="brief". - - parent_report_id: Set this to the report_id from a previous generate_report result when the user wants to MODIFY an existing report. Do NOT set it for new reports or questions about reports. - - Returns: A dictionary with status "ready" or "failed", report_id, title, and word_count. - - The report is generated immediately in Markdown and displayed inline in the chat. - - Export/download formats (PDF, DOCX, HTML, LaTeX, EPUB, ODT, plain text) are produced from the generated Markdown report. - - SOURCE STRATEGY DECISION (HIGH PRIORITY — follow this exactly): - * If the conversation already has substantive Q&A / discussion on the topic → use source_strategy="conversation" with a comprehensive summary as source_content. - * If the user wants a report on a topic not yet discussed → use source_strategy="kb_search" with targeted search_queries. - * If you have some content but might need more → use source_strategy="auto" with both source_content and search_queries. - * When revising an existing report (parent_report_id set) and the conversation has relevant context → use source_strategy="conversation". The revision will use the previous report content plus your source_content. - * NEVER run a separate KB lookup step and then pass those results to generate_report. The tool handles KB search internally. - - AFTER CALLING THIS TOOL: Do NOT repeat, summarize, or reproduce the report content in the chat. The report is already displayed as an interactive card that the user can open, read, copy, and export. Simply confirm that the report was generated (e.g., "I've generated your report on [topic]. You can view the Markdown report now, and export it in various formats from the card."). NEVER write out the report text in the chat. -""" - -_TOOL_INSTRUCTIONS["generate_image"] = """ -- generate_image: Generate images from text descriptions using AI image models. - - Use this when the user asks you to create, generate, draw, design, or make an image. - - Trigger phrases: "generate an image of", "create a picture of", "draw me", "make an image", "design a logo", "create artwork" - - Args: - - prompt: A detailed text description of the image to generate. Be specific about subject, style, colors, composition, and mood. - - n: Number of images to generate (1-4, default: 1) - - Returns: A dictionary with the generated image metadata. The image will automatically be displayed in the chat. - - IMPORTANT: Write a detailed, descriptive prompt for best results. Don't just pass the user's words verbatim - - expand and improve the prompt with specific details about style, lighting, composition, and mood. - - If the user's request is vague (e.g., "make me an image of a cat"), enhance the prompt with artistic details. -""" - -_TOOL_INSTRUCTIONS["scrape_webpage"] = """ -- scrape_webpage: Scrape and extract the main content from a webpage. - - Use this when the user wants you to READ and UNDERSTAND the actual content of a webpage. - - CRITICAL — WHEN TO USE (always attempt scraping, never refuse before trying): - * When a user asks to "get", "fetch", "pull", "grab", "scrape", or "read" content from a URL - * When the user wants live/dynamic data from a specific webpage (e.g., tables, scores, stats, prices) - * When a URL was mentioned earlier in the conversation and the user asks for its actual content - * When preloaded `/documents/` data is insufficient and the user wants more - - Trigger scenarios: - * "Read this article and summarize it" - * "What does this page say about X?" - * "Summarize this blog post for me" - * "Tell me the key points from this article" - * "What's in this webpage?" - * "Can you analyze this article?" - * "Can you get the live table/data from [URL]?" - * "Scrape it" / "Can you scrape that?" (referring to a previously mentioned URL) - * "Fetch the content from [URL]" - * "Pull the data from that page" - - Args: - - url: The URL of the webpage to scrape (must be HTTP/HTTPS) - - max_length: Maximum content length to return (default: 50000 chars) - - Returns: The page title, description, full content (in markdown), word count, and metadata - - After scraping, provide a comprehensive, well-structured summary with key takeaways using headings or bullet points. - - Reference the source using markdown links [descriptive text](url) — never bare URLs. - - IMAGES: The scraped content may contain image URLs in markdown format like `![alt text](image_url)`. - * When you find relevant/important images in the scraped content, include them in your response using standard markdown image syntax: `![alt text](image_url)`. - * This makes your response more visual and engaging. - * Prioritize showing: diagrams, charts, infographics, key illustrations, or images that help explain the content. - * Don't show every image - just the most relevant 1-3 images that enhance understanding. -""" - -_TOOL_INSTRUCTIONS["web_search"] = """ -- web_search: Search the web for real-time information using all configured search engines. - - Use this for current events, news, prices, weather, public facts, or any question requiring - up-to-date information from the internet. - - This tool dispatches to all configured search engines (SearXNG, Tavily, Linkup, Baidu) in - parallel and merges the results. - - IMPORTANT (REAL-TIME / PUBLIC WEB QUERIES): For questions that require current public web data - (e.g., live exchange rates, stock prices, breaking news, weather, current events), you MUST call - `web_search` instead of answering from memory. - - For these real-time/public web queries, DO NOT answer from memory and DO NOT say you lack internet - access before attempting a web search. - - If the search returns no relevant results, explain that web sources did not return enough - data and ask the user if they want you to retry with a refined query. - - Args: - - query: The search query - use specific, descriptive terms - - top_k: Number of results to retrieve (default: 10, max: 50) - - If search snippets are insufficient for the user's question, use `scrape_webpage` on the most relevant result URL for full content. - - When presenting results, reference sources as markdown links [descriptive text](url) — never bare URLs. -""" - -# Memory tool instructions have private and shared variants. -# We store them keyed as "update_memory" with sub-keys. -_MEMORY_TOOL_INSTRUCTIONS: dict[str, dict[str, str]] = { - "update_memory": { - "private": """ -- update_memory: Update your personal memory document about the user. - - Your current memory is already in in your context. The `chars` and - `limit` attributes show your current usage and the maximum allowed size. - - This is your curated long-term memory — the distilled essence of what you know about - the user, not raw conversation logs. - - Call update_memory when: - * The user explicitly asks to remember or forget something - * The user shares durable facts or preferences that will matter in future conversations - - The user's first name is provided in . Use it in memory entries - instead of "the user" (e.g. "{name} works at..." not "The user works at..."). - Do not store the name itself as a separate memory entry. - - Do not store short-lived or ephemeral info: one-off questions, greetings, - session logistics, or things that only matter for the current task. - - Args: - - updated_memory: The FULL updated markdown document (not a diff). - Merge new facts with existing ones, update contradictions, remove outdated entries. - Treat every update as a curation pass — consolidate, don't just append. - - Every bullet MUST use this format: - (YYYY-MM-DD) [marker] text - Markers: - [fact] — durable facts (role, background, projects, tools, expertise) - [pref] — preferences (response style, languages, formats, tools) - [instr] — standing instructions (always/never do, response rules) - - Keep it concise and well under the character limit shown in . - - Every entry MUST be under a `##` heading. Keep heading names short (2-3 words) and - natural. Do NOT include the user's name in headings. Organize by context — e.g. - who they are, what they're focused on, how they prefer things. Create, split, or - merge headings freely as the memory grows. - - Each entry MUST be a single bullet point. Be descriptive but concise — include relevant - details and context rather than just a few words. - - During consolidation, prioritize keeping: [instr] > [pref] > [fact]. -""", - "shared": """ -- update_memory: Update the team's shared memory document for this search space. - - Your current team memory is already in in your context. The `chars` - and `limit` attributes show current usage and the maximum allowed size. - - This is the team's curated long-term memory — decisions, conventions, key facts. - - NEVER store personal memory in team memory (e.g. personal bio, individual - preferences, or user-only standing instructions). - - Call update_memory when: - * A team member explicitly asks to remember or forget something - * The conversation surfaces durable team decisions, conventions, or facts - that will matter in future conversations - - Do not store short-lived or ephemeral info: one-off questions, greetings, - session logistics, or things that only matter for the current task. - - Args: - - updated_memory: The FULL updated markdown document (not a diff). - Merge new facts with existing ones, update contradictions, remove outdated entries. - Treat every update as a curation pass — consolidate, don't just append. - - Every bullet MUST use this format: - (YYYY-MM-DD) [fact] text - Team memory uses ONLY the [fact] marker. Never use [pref] or [instr] in team memory. - - Keep it concise and well under the character limit shown in . - - Every entry MUST be under a `##` heading. Keep heading names short (2-3 words) and - natural. Organize by context — e.g. what the team decided, current architecture, - active processes. Create, split, or merge headings freely as the memory grows. - - Each entry MUST be a single bullet point. Be descriptive but concise — include relevant - details and context rather than just a few words. - - During consolidation, prioritize keeping: decisions/conventions > key facts > current priorities. -""", - }, -} - -_MEMORY_TOOL_EXAMPLES: dict[str, dict[str, str]] = { - "update_memory": { - "private": """ -- Alex, is empty. User: "I'm a space enthusiast, explain astrophage to me" - - The user casually shared a durable fact. Use their first name in the entry, short neutral heading: - update_memory(updated_memory="## Interests & background\\n- (2025-03-15) [fact] Alex is a space enthusiast\\n") -- User: "Remember that I prefer concise answers over detailed explanations" - - Durable preference. Merge with existing memory, add a new heading: - update_memory(updated_memory="## Interests & background\\n- (2025-03-15) [fact] Alex is a space enthusiast\\n\\n## Response style\\n- (2025-03-15) [pref] Alex prefers concise answers over detailed explanations\\n") -- User: "I actually moved to Tokyo last month" - - Updated fact, date prefix reflects when recorded: - update_memory(updated_memory="## Interests & background\\n...\\n\\n## Personal context\\n- (2025-03-15) [fact] Alex lives in Tokyo (previously London)\\n...") -- User: "I'm a freelance photographer working on a nature documentary" - - Durable background info under a fitting heading: - update_memory(updated_memory="...\\n\\n## Current focus\\n- (2025-03-15) [fact] Alex is a freelance photographer\\n- (2025-03-15) [fact] Alex is working on a nature documentary\\n") -- User: "Always respond in bullet points" - - Standing instruction: - update_memory(updated_memory="...\\n\\n## Response style\\n- (2025-03-15) [instr] Always respond to Alex in bullet points\\n") -""", - "shared": """ -- User: "Let's remember that we decided to do weekly standup meetings on Mondays" - - Durable team decision: - update_memory(updated_memory="- (2025-03-15) [fact] Weekly standup meetings on Mondays\\n...") -- User: "Our office is in downtown Seattle, 5th floor" - - Durable team fact: - update_memory(updated_memory="- (2025-03-15) [fact] Office location: downtown Seattle, 5th floor\\n...") -""", - }, -} - -# Per-tool examples keyed by tool name. Only examples for enabled tools are included. -_TOOL_EXAMPLES: dict[str, str] = {} - -_TOOL_EXAMPLES["search_surfsense_docs"] = """ -- User: "How do I install SurfSense?" - - Call: `search_surfsense_docs(query="installation setup")` -- User: "What connectors does SurfSense support?" - - Call: `search_surfsense_docs(query="available connectors integrations")` -- User: "How do I set up the Notion connector?" - - Call: `search_surfsense_docs(query="Notion connector setup configuration")` -- User: "How do I use Docker to run SurfSense?" - - Call: `search_surfsense_docs(query="Docker installation setup")` -""" - -_TOOL_EXAMPLES["generate_podcast"] = """ -- User: "Give me a podcast about AI trends based on what we discussed" - - First search for relevant content, then call: `generate_podcast(source_content="Based on our conversation and search results: [detailed summary of chat + search findings]", podcast_title="AI Trends Podcast")` -- User: "Create a podcast summary of this conversation" - - Call: `generate_podcast(source_content="Complete conversation summary:\\n\\nUser asked about [topic 1]:\\n[Your detailed response]\\n\\nUser then asked about [topic 2]:\\n[Your detailed response]\\n\\n[Continue for all exchanges in the conversation]", podcast_title="Conversation Summary")` -- User: "Make a podcast about quantum computing" - - First explore `/documents/` (ls/glob/grep/read_file), then: `generate_podcast(source_content="Key insights about quantum computing from retrieved files:\\n\\n[Comprehensive summary of findings]", podcast_title="Quantum Computing Explained")` -""" - -_TOOL_EXAMPLES["generate_video_presentation"] = """ -- User: "Give me a presentation about AI trends based on what we discussed" - - First search for relevant content, then call: `generate_video_presentation(source_content="Based on our conversation and search results: [detailed summary of chat + search findings]", video_title="AI Trends Presentation")` -- User: "Create slides summarizing this conversation" - - Call: `generate_video_presentation(source_content="Complete conversation summary:\\n\\nUser asked about [topic 1]:\\n[Your detailed response]\\n\\nUser then asked about [topic 2]:\\n[Your detailed response]\\n\\n[Continue for all exchanges in the conversation]", video_title="Conversation Summary")` -- User: "Make a video presentation about quantum computing" - - First explore `/documents/` (ls/glob/grep/read_file), then: `generate_video_presentation(source_content="Key insights about quantum computing from retrieved files:\\n\\n[Comprehensive summary of findings]", video_title="Quantum Computing Explained")` -""" - -_TOOL_EXAMPLES["generate_report"] = """ -- User: "Generate a report about AI trends" - - Call: `generate_report(topic="AI Trends Report", source_strategy="kb_search", search_queries=["AI trends recent developments", "artificial intelligence industry trends", "AI market growth and predictions"], report_style="detailed")` - - WHY: Has creation verb "generate" → call the tool. No prior discussion → use kb_search. -- User: "Write a research report from this conversation" - - Call: `generate_report(topic="Research Report", source_strategy="conversation", source_content="Complete conversation summary:\\n\\n...", report_style="deep_research")` - - WHY: Has creation verb "write" → call the tool. Conversation has the content → use source_strategy="conversation". -- User: (after a report on Climate Change was generated) "Add a section about carbon capture technologies" - - Call: `generate_report(topic="Climate Crisis: Causes, Impacts, and Solutions", source_strategy="conversation", source_content="[summary of conversation context if any]", parent_report_id=, user_instructions="Add a new section about carbon capture technologies")` - - WHY: Has modification verb "add" + specific deliverable target → call the tool with parent_report_id. -- User: (after a report was generated) "What else could we add to have more depth?" - - Do NOT call generate_report. Answer in chat with suggestions. - - WHY: No creation/modification verb directed at producing a deliverable. -""" - -_TOOL_EXAMPLES["scrape_webpage"] = """ -- User: "Check out https://dev.to/some-article" - - Call: `scrape_webpage(url="https://dev.to/some-article")` - - Respond with a structured analysis — key points, takeaways. -- User: "Read this article and summarize it for me: https://example.com/blog/ai-trends" - - Call: `scrape_webpage(url="https://example.com/blog/ai-trends")` - - Respond with a thorough summary using headings and bullet points. -- User: (after discussing https://example.com/stats) "Can you get the live data from that page?" - - Call: `scrape_webpage(url="https://example.com/stats")` - - IMPORTANT: Always attempt scraping first. Never refuse before trying the tool. -- User: "https://example.com/blog/weekend-recipes" - - Call: `scrape_webpage(url="https://example.com/blog/weekend-recipes")` - - When a user sends just a URL with no instructions, scrape it and provide a concise summary of the content. -""" - -_TOOL_EXAMPLES["generate_image"] = """ -- User: "Generate an image of a cat" - - Call: `generate_image(prompt="A fluffy orange tabby cat sitting on a windowsill, bathed in warm golden sunlight, soft bokeh background with green houseplants, photorealistic style, cozy atmosphere")` - - The generated image will automatically be displayed in the chat. -- User: "Draw me a logo for a coffee shop called Bean Dream" - - Call: `generate_image(prompt="Minimalist modern logo design for a coffee shop called 'Bean Dream', featuring a stylized coffee bean with dream-like swirls of steam, clean vector style, warm brown and cream color palette, white background, professional branding")` - - The generated image will automatically be displayed in the chat. -- User: "Show me this image: https://example.com/image.png" - - Simply include it in your response using markdown: `![Image](https://example.com/image.png)` -- User uploads an image file and asks: "What is this image about?" - - The user's uploaded image is already visible in the chat. - - Simply analyze the image content and respond directly. -""" - -_TOOL_EXAMPLES["web_search"] = """ -- User: "What's the current USD to INR exchange rate?" - - Call: `web_search(query="current USD to INR exchange rate")` - - Then answer using the returned web results with citations. -- User: "What's the latest news about AI?" - - Call: `web_search(query="latest AI news today")` -- User: "What's the weather in New York?" - - Call: `web_search(query="weather New York today")` -""" - -_TOOL_INSTRUCTIONS["generate_resume"] = """ -- generate_resume: Generate or revise a professional resume as a Typst document. - - WHEN TO CALL: The user asks to create, build, generate, write, or draft a resume or CV. - Also when they ask to modify, update, or revise an existing resume from this conversation. - - WHEN NOT TO CALL: General career advice, resume tips, cover letters, or reviewing - a resume without making changes. For cover letters, use generate_report instead. - - The tool produces Typst source code that is compiled to a PDF preview automatically. - - PAGE POLICY: - - Default behavior is ONE PAGE. For new resume creation, set max_pages=1 unless the user explicitly asks for more. - - If the user requests a longer resume (e.g., "make it 2 pages"), set max_pages to that value. - - Args: - - user_info: The user's resume content — work experience, education, skills, contact - info, etc. Can be structured or unstructured text. - CRITICAL: user_info must be COMPREHENSIVE. Do NOT just pass the user's raw message. - You MUST gather and consolidate ALL available information: - * Content from referenced/mentioned documents (e.g., uploaded resumes, CVs, LinkedIn profiles) - that appear in the conversation context — extract and include their FULL content. - * Information the user shared across multiple messages in the conversation. - * Any relevant details from knowledge base search results in the context. - The more complete the user_info, the better the resume. Include names, contact info, - work experience with dates, education, skills, projects, certifications — everything available. - - user_instructions: Optional style or content preferences (e.g. "emphasize leadership", - "keep it to one page"). For revisions, describe what to change. - - parent_report_id: Set this when the user wants to MODIFY an existing resume from - this conversation. Use the report_id from a previous generate_resume result. - - max_pages: Maximum resume length in pages (integer 1-5). Default is 1. - - Returns: Dict with status, report_id, title, and content_type. - - After calling: Give a brief confirmation. Do NOT paste resume content in chat. Do NOT mention report_id or any internal IDs — the resume card is shown automatically. - - VERSIONING: Same rules as generate_report — set parent_report_id for modifications - of an existing resume, leave as None for new resumes. -""" - -_TOOL_EXAMPLES["generate_resume"] = """ -- User: "Build me a resume. I'm John Doe, engineer at Acme Corp..." - - Call: `generate_resume(user_info="John Doe, engineer at Acme Corp...", max_pages=1)` - - WHY: Has creation verb "build" + resume → call the tool. -- User: "Create my CV with this info: [experience, education, skills]" - - Call: `generate_resume(user_info="[experience, education, skills]", max_pages=1)` -- User: "Build me a resume" (and there is a resume/CV document in the conversation context) - - Extract the FULL content from the document in context, then call: - `generate_resume(user_info="Name: John Doe\\nEmail: john@example.com\\n\\nExperience:\\n- Senior Engineer at Acme Corp (2020-2024)\\n Led team of 5...\\n\\nEducation:\\n- BS Computer Science, MIT (2016-2020)\\n\\nSkills: Python, TypeScript, AWS...", max_pages=1)` - - WHY: Document content is available in context — extract ALL of it into user_info. Do NOT ignore referenced documents. -- User: (after resume generated) "Change my title to Senior Engineer" - - Call: `generate_resume(user_info="", user_instructions="Change the job title to Senior Engineer", parent_report_id=, max_pages=1)` - - WHY: Modification verb "change" + refers to existing resume → set parent_report_id. -- User: (after resume generated) "Make this 2 pages and expand projects" - - Call: `generate_resume(user_info="", user_instructions="Expand projects and keep this to at most 2 pages", parent_report_id=, max_pages=2)` - - WHY: Explicit page increase request → set max_pages to 2. -- User: "How should I structure my resume?" - - Do NOT call generate_resume. Answer in chat with advice. - - WHY: No creation/modification verb. -""" - -# All tool names that have prompt instructions (order matters for prompt readability) -_ALL_TOOL_NAMES_ORDERED = [ - "search_surfsense_docs", - "web_search", - "generate_podcast", - "generate_video_presentation", - "generate_report", - "generate_resume", - "generate_image", - "scrape_webpage", - "update_memory", -] - - -def _format_tool_name(name: str) -> str: - """Convert snake_case tool name to a human-readable label.""" - return name.replace("_", " ").title() - - -def _get_tools_instructions( - thread_visibility: ChatVisibility | None = None, - enabled_tool_names: set[str] | None = None, - disabled_tool_names: set[str] | None = None, -) -> str: - """Build tools instructions containing only the enabled tools. - - Args: - thread_visibility: Private vs shared — affects memory tool wording. - enabled_tool_names: Set of tool names that are actually bound to the agent. - When None, all tools are included (backward-compatible default). - disabled_tool_names: Set of tool names that the user explicitly disabled. - When provided, a note is appended telling the model about these tools - so it can inform the user they can re-enable them. - """ - visibility = thread_visibility or ChatVisibility.PRIVATE - memory_variant = ( - "shared" if visibility == ChatVisibility.SEARCH_SPACE else "private" - ) - - parts: list[str] = [_TOOLS_PREAMBLE] - examples: list[str] = [] - - for tool_name in _ALL_TOOL_NAMES_ORDERED: - if enabled_tool_names is not None and tool_name not in enabled_tool_names: - continue - - if tool_name in _TOOL_INSTRUCTIONS: - parts.append(_TOOL_INSTRUCTIONS[tool_name]) - elif tool_name in _MEMORY_TOOL_INSTRUCTIONS: - parts.append(_MEMORY_TOOL_INSTRUCTIONS[tool_name][memory_variant]) - - if tool_name in _TOOL_EXAMPLES: - examples.append(_TOOL_EXAMPLES[tool_name]) - elif tool_name in _MEMORY_TOOL_EXAMPLES: - examples.append(_MEMORY_TOOL_EXAMPLES[tool_name][memory_variant]) - - # Append a note about user-disabled tools so the model can inform the user - known_disabled = ( - disabled_tool_names & set(_ALL_TOOL_NAMES_ORDERED) - if disabled_tool_names - else set() - ) - if known_disabled: - disabled_list = ", ".join( - _format_tool_name(n) for n in _ALL_TOOL_NAMES_ORDERED if n in known_disabled - ) - parts.append(f""" -DISABLED TOOLS (by user): -The following tools are available in SurfSense but have been disabled by the user for this session: {disabled_list}. -You do NOT have access to these tools and MUST NOT claim you can use them. -If the user asks about a capability provided by a disabled tool, let them know the relevant tool -is currently disabled and they can re-enable it. -""") - - parts.append("\n\n") - - if examples: - parts.append("") - parts.extend(examples) - parts.append("\n") - - return "".join(parts) - - -# Backward-compatible constant: all tools included (private memory variant) -SURFSENSE_TOOLS_INSTRUCTIONS = _get_tools_instructions() - - -SURFSENSE_CITATION_INSTRUCTIONS = """ - -CRITICAL CITATION REQUIREMENTS: - -1. For EVERY piece of information you include from the documents, add a citation in the format [citation:chunk_id] where chunk_id is the exact value from the `` tag inside ``. -2. Make sure ALL factual statements from the documents have proper citations. -3. If multiple chunks support the same point, include all relevant citations [citation:chunk_id1], [citation:chunk_id2]. -4. You MUST use the exact chunk_id values from the `` attributes. Do not create your own citation numbers. -5. Every citation MUST be in the format [citation:chunk_id] where chunk_id is the exact chunk id value. -6. Never modify or change the chunk_id - always use the original values exactly as provided in the chunk tags. -7. Do not return citations as clickable links. -8. Never format citations as markdown links like "([citation:5](https://example.com))". Always use plain square brackets only. -9. Citations must ONLY appear as [citation:chunk_id] or [citation:chunk_id1], [citation:chunk_id2] format - never with parentheses, hyperlinks, or other formatting. -10. Never make up chunk IDs. Only use chunk_id values that are explicitly provided in the `` tags. -11. If you are unsure about a chunk_id, do not include a citation rather than guessing or making one up. - - -The documents you receive are structured like this: - -**Knowledge base documents (numeric chunk IDs):** - - - 42 - GITHUB_CONNECTOR - <![CDATA[Some repo / file / issue title]]> - - - - - - - - - - -**Web search results (URL chunk IDs):** - - - WEB_SEARCH - <![CDATA[Some web search result]]> - - - - - - - - -IMPORTANT: You MUST cite using the EXACT chunk ids from the `` tags. -- For knowledge base documents, chunk ids are numeric (e.g. 123, 124) or prefixed (e.g. doc-45). -- For live web search results, chunk ids are URLs (e.g. https://example.com/article). -Do NOT cite document_id. Always use the chunk id. - - - -- Every fact from the documents must have a citation in the format [citation:chunk_id] where chunk_id is the EXACT id value from a `` tag -- Citations should appear at the end of the sentence containing the information they support -- Multiple citations should be separated by commas: [citation:chunk_id1], [citation:chunk_id2], [citation:chunk_id3] -- No need to return references section. Just citations in answer. -- NEVER create your own citation format - use the exact chunk_id values from the documents in the [citation:chunk_id] format -- NEVER format citations as clickable links or as markdown links like "([citation:5](https://example.com))". Always use plain square brackets only -- NEVER make up chunk IDs if you are unsure about the chunk_id. It is better to omit the citation than to guess -- Copy the EXACT chunk id from the XML - if it says ``, use [citation:doc-123] -- If the chunk id is a URL like ``, use [citation:https://example.com/page] - - - -CORRECT citation formats: -- [citation:5] (numeric chunk ID from knowledge base) -- [citation:doc-123] (for Surfsense documentation chunks) -- [citation:https://example.com/article] (URL chunk ID from web search results) -- [citation:chunk_id1], [citation:chunk_id2], [citation:chunk_id3] (multiple citations) - -INCORRECT citation formats (DO NOT use): -- Using parentheses and markdown links: ([citation:5](https://github.com/MODSetter/SurfSense)) -- Using parentheses around brackets: ([citation:5]) -- Using hyperlinked text: [link to source 5](https://example.com) -- Using footnote style: ... library¹ -- Making up source IDs when source_id is unknown -- Using old IEEE format: [1], [2], [3] -- Using source types instead of IDs: [citation:GITHUB_CONNECTOR] instead of [citation:5] - - - -Based on your GitHub repositories and video content, Python's asyncio library provides tools for writing concurrent code using the async/await syntax [citation:5]. It's particularly useful for I/O-bound and high-level structured network code [citation:5]. - -According to web search results, the key advantage of asyncio is that it can improve performance by allowing other code to run while waiting for I/O operations to complete [citation:https://docs.python.org/3/library/asyncio.html]. This makes it excellent for scenarios like web scraping, API calls, database operations, or any situation where your program spends time waiting for external resources. - -However, from your video learning, it's important to note that asyncio is not suitable for CPU-bound tasks as it runs on a single thread [citation:12]. For computationally intensive work, you'd want to use multiprocessing instead. - - -""" - -# Anti-citation prompt - used when citations are disabled -# This explicitly tells the model NOT to include citations -SURFSENSE_NO_CITATION_INSTRUCTIONS = """ - -IMPORTANT: Citations are DISABLED for this configuration. - -DO NOT include any citations in your responses. Specifically: -1. Do NOT use the [citation:chunk_id] format anywhere in your response. -2. Do NOT reference document IDs, chunk IDs, or source IDs. -3. Simply provide the information naturally without any citation markers. -4. Write your response as if you're having a normal conversation, incorporating the information from your knowledge seamlessly. - -When answering questions based on documents from the knowledge base: -- Present the information directly and confidently -- Do not mention that information comes from specific documents or chunks -- Integrate facts naturally into your response without attribution markers - -Your goal is to provide helpful, informative answers in a clean, readable format without any citation notation. - -""" - - -def _build_mcp_routing_block( - mcp_connector_tools: dict[str, list[str]] | None, -) -> str: - """Build an additional tool routing block for generic MCP connectors. - - When users add MCP servers (e.g. GitLab, GitHub), the LLM needs to know - those tools exist and should be called directly — not searched in the - knowledge base. - """ - if not mcp_connector_tools: - return "" - - lines = [ - "\n", - "You also have direct tools from these user-connected MCP servers.", - "Their data is NEVER in the knowledge base — call their tools directly.", - "", - ] - for server_name, tool_names in mcp_connector_tools.items(): - lines.append(f"- {server_name} → {', '.join(tool_names)}") - lines.append("\n") - return "\n".join(lines) +from .prompts.composer import ( + _read_fragment, + compose_system_prompt, + detect_provider_variant, +) + +# Public re-exports for backwards compatibility (some legacy code reads the +# raw default-instructions text directly). +SURFSENSE_SYSTEM_INSTRUCTIONS_TEMPLATE = ( + "\nDefault SurfSense agent system instructions are now\n" + "composed from prompts/base/*.md. See compose_system_prompt() for details.\n" + "" +) + +# Citation block re-exposed for legacy importers that referenced this constant +# directly. The composer is the canonical source; this is a frozen snapshot +# loaded at module-init time. +SURFSENSE_CITATION_INSTRUCTIONS = _read_fragment("base/citations_on.md") +SURFSENSE_NO_CITATION_INSTRUCTIONS = _read_fragment("base/citations_off.md") def build_surfsense_system_prompt( @@ -845,36 +47,23 @@ def build_surfsense_system_prompt( enabled_tool_names: set[str] | None = None, disabled_tool_names: set[str] | None = None, mcp_connector_tools: dict[str, list[str]] | None = None, + *, + model_name: str | None = None, ) -> str: + """Build the default SurfSense system prompt (citations on, defaults). + + See :func:`app.agents.new_chat.prompts.composer.compose_system_prompt` + for full parameter docs. """ - Build the SurfSense system prompt with default settings. - - This is a convenience function that builds the prompt with: - - Default system instructions - - Tools instructions (only for enabled tools) - - Citation instructions enabled - - Args: - today: Optional datetime for today's date (defaults to current UTC date) - thread_visibility: Optional; when provided, used for conditional prompt (e.g. private vs shared memory wording). Defaults to private behavior when None. - enabled_tool_names: Set of tool names actually bound to the agent. When None all tools are included. - disabled_tool_names: Set of tool names the user explicitly disabled. Included as a note so the model can inform the user. - mcp_connector_tools: Mapping of MCP server display name → list of tool names - for generic MCP connectors. Injected into the system prompt so the LLM - knows to call these tools directly. - - Returns: - Complete system prompt string - """ - - visibility = thread_visibility or ChatVisibility.PRIVATE - system_instructions = _get_system_instructions(visibility, today) - system_instructions += _build_mcp_routing_block(mcp_connector_tools) - tools_instructions = _get_tools_instructions( - visibility, enabled_tool_names, disabled_tool_names + return compose_system_prompt( + today=today, + thread_visibility=thread_visibility, + enabled_tool_names=enabled_tool_names, + disabled_tool_names=disabled_tool_names, + mcp_connector_tools=mcp_connector_tools, + citations_enabled=True, + model_name=model_name, ) - citation_instructions = SURFSENSE_CITATION_INSTRUCTIONS - return system_instructions + tools_instructions + citation_instructions def build_configurable_system_prompt( @@ -886,75 +75,54 @@ def build_configurable_system_prompt( enabled_tool_names: set[str] | None = None, disabled_tool_names: set[str] | None = None, mcp_connector_tools: dict[str, list[str]] | None = None, + *, + model_name: str | None = None, ) -> str: + """Build a configurable SurfSense system prompt (NewLLMConfig path). + + See :func:`app.agents.new_chat.prompts.composer.compose_system_prompt` + for full parameter docs. """ - Build a configurable SurfSense system prompt based on NewLLMConfig settings. - - The prompt is composed of three parts: - 1. System Instructions - either custom or default SURFSENSE_SYSTEM_INSTRUCTIONS - 2. Tools Instructions - only for enabled tools, with a note about disabled ones - 3. Citation Instructions - either SURFSENSE_CITATION_INSTRUCTIONS or SURFSENSE_NO_CITATION_INSTRUCTIONS - - Args: - custom_system_instructions: Custom system instructions to use. If empty/None and - use_default_system_instructions is True, defaults to - SURFSENSE_SYSTEM_INSTRUCTIONS. - use_default_system_instructions: Whether to use default instructions when - custom_system_instructions is empty/None. - citations_enabled: Whether to include citation instructions (True) or - anti-citation instructions (False). - today: Optional datetime for today's date (defaults to current UTC date) - thread_visibility: Optional; when provided, used for conditional prompt (e.g. private vs shared memory wording). Defaults to private behavior when None. - enabled_tool_names: Set of tool names actually bound to the agent. When None all tools are included. - disabled_tool_names: Set of tool names the user explicitly disabled. Included as a note so the model can inform the user. - mcp_connector_tools: Mapping of MCP server display name → list of tool names - for generic MCP connectors. Injected into the system prompt so the LLM - knows to call these tools directly. - - Returns: - Complete system prompt string - """ - resolved_today = (today or datetime.now(UTC)).astimezone(UTC).date().isoformat() - - # Determine system instructions - if custom_system_instructions and custom_system_instructions.strip(): - system_instructions = custom_system_instructions.format( - resolved_today=resolved_today - ) - elif use_default_system_instructions: - visibility = thread_visibility or ChatVisibility.PRIVATE - system_instructions = _get_system_instructions(visibility, today) - else: - system_instructions = "" - - system_instructions += _build_mcp_routing_block(mcp_connector_tools) - - # Tools instructions: only include enabled tools, note disabled ones - tools_instructions = _get_tools_instructions( - thread_visibility, enabled_tool_names, disabled_tool_names + return compose_system_prompt( + today=today, + thread_visibility=thread_visibility, + enabled_tool_names=enabled_tool_names, + disabled_tool_names=disabled_tool_names, + mcp_connector_tools=mcp_connector_tools, + custom_system_instructions=custom_system_instructions, + use_default_system_instructions=use_default_system_instructions, + citations_enabled=citations_enabled, + model_name=model_name, ) - # Citation instructions based on toggle - citation_instructions = ( - SURFSENSE_CITATION_INSTRUCTIONS - if citations_enabled - else SURFSENSE_NO_CITATION_INSTRUCTIONS - ) - - return system_instructions + tools_instructions + citation_instructions - def get_default_system_instructions() -> str: + """Return the default ```` block (no tools / citations). + + Useful for populating the UI when seeding ``NewLLMConfig.system_instructions``. + The output reflects the current fragment tree, not a baked-in constant. """ - Get the default system instructions template. + resolved_today = datetime.now(UTC).date().isoformat() + from .prompts.composer import _build_system_instructions # local import - This is useful for populating the UI with the default value when - creating a new NewLLMConfig. - - Returns: - Default system instructions string (with {resolved_today} placeholder) - """ - return SURFSENSE_SYSTEM_INSTRUCTIONS.strip() + return _build_system_instructions( + visibility=ChatVisibility.PRIVATE, + resolved_today=resolved_today, + ).strip() +# Backwards compatibility — some modules import the constant directly. SURFSENSE_SYSTEM_PROMPT = build_surfsense_system_prompt() + + +__all__ = [ + "SURFSENSE_CITATION_INSTRUCTIONS", + "SURFSENSE_NO_CITATION_INSTRUCTIONS", + "SURFSENSE_SYSTEM_INSTRUCTIONS_TEMPLATE", + "SURFSENSE_SYSTEM_PROMPT", + "build_configurable_system_prompt", + "build_surfsense_system_prompt", + "compose_system_prompt", + "detect_provider_variant", + "get_default_system_instructions", +] diff --git a/surfsense_backend/app/agents/new_chat/tools/connected_accounts.py b/surfsense_backend/app/agents/new_chat/tools/connected_accounts.py index e0b1978e1..5675a42e6 100644 --- a/surfsense_backend/app/agents/new_chat/tools/connected_accounts.py +++ b/surfsense_backend/app/agents/new_chat/tools/connected_accounts.py @@ -57,7 +57,11 @@ def create_get_connected_accounts_tool( async def _run(service: str) -> list[dict[str, Any]]: svc_cfg = MCP_SERVICES.get(service) if not svc_cfg: - return [{"error": f"Unknown service '{service}'. Valid: {', '.join(sorted(MCP_SERVICES.keys()))}"}] + return [ + { + "error": f"Unknown service '{service}'. Valid: {', '.join(sorted(MCP_SERVICES.keys()))}" + } + ] try: connector_type = SearchSourceConnectorType(svc_cfg.connector_type) @@ -74,7 +78,11 @@ def create_get_connected_accounts_tool( connectors = result.scalars().all() if not connectors: - return [{"error": f"No {svc_cfg.name} accounts connected. Ask the user to connect one in settings."}] + return [ + { + "error": f"No {svc_cfg.name} accounts connected. Ask the user to connect one in settings." + } + ] is_multi = len(connectors) > 1 diff --git a/surfsense_backend/app/agents/new_chat/tools/discord/_auth.py b/surfsense_backend/app/agents/new_chat/tools/discord/_auth.py index 1f51e3660..c345f8a5e 100644 --- a/surfsense_backend/app/agents/new_chat/tools/discord/_auth.py +++ b/surfsense_backend/app/agents/new_chat/tools/discord/_auth.py @@ -19,7 +19,8 @@ async def get_discord_connector( select(SearchSourceConnector).filter( SearchSourceConnector.search_space_id == search_space_id, SearchSourceConnector.user_id == user_id, - SearchSourceConnector.connector_type == SearchSourceConnectorType.DISCORD_CONNECTOR, + SearchSourceConnector.connector_type + == SearchSourceConnectorType.DISCORD_CONNECTOR, ) ) return result.scalars().first() diff --git a/surfsense_backend/app/agents/new_chat/tools/discord/list_channels.py b/surfsense_backend/app/agents/new_chat/tools/discord/list_channels.py index a33b88aa0..3cc99ac17 100644 --- a/surfsense_backend/app/agents/new_chat/tools/discord/list_channels.py +++ b/surfsense_backend/app/agents/new_chat/tools/discord/list_channels.py @@ -23,16 +23,24 @@ def create_list_discord_channels_tool( Dictionary with status and a list of channels (id, name). """ if db_session is None or search_space_id is None or user_id is None: - return {"status": "error", "message": "Discord tool not properly configured."} + return { + "status": "error", + "message": "Discord tool not properly configured.", + } try: - connector = await get_discord_connector(db_session, search_space_id, user_id) + connector = await get_discord_connector( + db_session, search_space_id, user_id + ) if not connector: return {"status": "error", "message": "No Discord connector found."} guild_id = get_guild_id(connector) if not guild_id: - return {"status": "error", "message": "No guild ID in Discord connector config."} + return { + "status": "error", + "message": "No guild ID in Discord connector config.", + } token = get_bot_token(connector) @@ -44,9 +52,16 @@ def create_list_discord_channels_tool( ) if resp.status_code == 401: - return {"status": "auth_error", "message": "Discord bot token is invalid.", "connector_type": "discord"} + return { + "status": "auth_error", + "message": "Discord bot token is invalid.", + "connector_type": "discord", + } if resp.status_code != 200: - return {"status": "error", "message": f"Discord API error: {resp.status_code}"} + return { + "status": "error", + "message": f"Discord API error: {resp.status_code}", + } # Type 0 = text channel channels = [ @@ -54,7 +69,12 @@ def create_list_discord_channels_tool( for ch in resp.json() if ch.get("type") == 0 ] - return {"status": "success", "guild_id": guild_id, "channels": channels, "total": len(channels)} + return { + "status": "success", + "guild_id": guild_id, + "channels": channels, + "total": len(channels), + } except Exception as e: from langgraph.errors import GraphInterrupt diff --git a/surfsense_backend/app/agents/new_chat/tools/discord/read_messages.py b/surfsense_backend/app/agents/new_chat/tools/discord/read_messages.py index 852a9297b..d8bf989a1 100644 --- a/surfsense_backend/app/agents/new_chat/tools/discord/read_messages.py +++ b/surfsense_backend/app/agents/new_chat/tools/discord/read_messages.py @@ -31,12 +31,17 @@ def create_read_discord_messages_tool( id, author, content, timestamp. """ if db_session is None or search_space_id is None or user_id is None: - return {"status": "error", "message": "Discord tool not properly configured."} + return { + "status": "error", + "message": "Discord tool not properly configured.", + } limit = min(limit, 50) try: - connector = await get_discord_connector(db_session, search_space_id, user_id) + connector = await get_discord_connector( + db_session, search_space_id, user_id + ) if not connector: return {"status": "error", "message": "No Discord connector found."} @@ -51,11 +56,21 @@ def create_read_discord_messages_tool( ) if resp.status_code == 401: - return {"status": "auth_error", "message": "Discord bot token is invalid.", "connector_type": "discord"} + return { + "status": "auth_error", + "message": "Discord bot token is invalid.", + "connector_type": "discord", + } if resp.status_code == 403: - return {"status": "error", "message": "Bot lacks permission to read this channel."} + return { + "status": "error", + "message": "Bot lacks permission to read this channel.", + } if resp.status_code != 200: - return {"status": "error", "message": f"Discord API error: {resp.status_code}"} + return { + "status": "error", + "message": f"Discord API error: {resp.status_code}", + } messages = [ { @@ -67,7 +82,12 @@ def create_read_discord_messages_tool( for m in resp.json() ] - return {"status": "success", "channel_id": channel_id, "messages": messages, "total": len(messages)} + return { + "status": "success", + "channel_id": channel_id, + "messages": messages, + "total": len(messages), + } except Exception as e: from langgraph.errors import GraphInterrupt diff --git a/surfsense_backend/app/agents/new_chat/tools/discord/send_message.py b/surfsense_backend/app/agents/new_chat/tools/discord/send_message.py index be4e6fdb2..236cd017a 100644 --- a/surfsense_backend/app/agents/new_chat/tools/discord/send_message.py +++ b/surfsense_backend/app/agents/new_chat/tools/discord/send_message.py @@ -35,13 +35,21 @@ def create_send_discord_message_tool( - If status is "rejected", the user explicitly declined. Do NOT retry. """ if db_session is None or search_space_id is None or user_id is None: - return {"status": "error", "message": "Discord tool not properly configured."} + return { + "status": "error", + "message": "Discord tool not properly configured.", + } if len(content) > 2000: - return {"status": "error", "message": "Message exceeds Discord's 2000-character limit."} + return { + "status": "error", + "message": "Message exceeds Discord's 2000-character limit.", + } try: - connector = await get_discord_connector(db_session, search_space_id, user_id) + connector = await get_discord_connector( + db_session, search_space_id, user_id + ) if not connector: return {"status": "error", "message": "No Discord connector found."} @@ -53,7 +61,10 @@ def create_send_discord_message_tool( ) if result.rejected: - return {"status": "rejected", "message": "User declined. Message was not sent."} + return { + "status": "rejected", + "message": "User declined. Message was not sent.", + } final_content = result.params.get("content", content) final_channel = result.params.get("channel_id", channel_id) @@ -72,11 +83,21 @@ def create_send_discord_message_tool( ) if resp.status_code == 401: - return {"status": "auth_error", "message": "Discord bot token is invalid.", "connector_type": "discord"} + return { + "status": "auth_error", + "message": "Discord bot token is invalid.", + "connector_type": "discord", + } if resp.status_code == 403: - return {"status": "error", "message": "Bot lacks permission to send messages in this channel."} + return { + "status": "error", + "message": "Bot lacks permission to send messages in this channel.", + } if resp.status_code not in (200, 201): - return {"status": "error", "message": f"Discord API error: {resp.status_code}"} + return { + "status": "error", + "message": f"Discord API error: {resp.status_code}", + } msg_data = resp.json() return { diff --git a/surfsense_backend/app/agents/new_chat/tools/generate_image.py b/surfsense_backend/app/agents/new_chat/tools/generate_image.py index d94d55b1a..3803fa39c 100644 --- a/surfsense_backend/app/agents/new_chat/tools/generate_image.py +++ b/surfsense_backend/app/agents/new_chat/tools/generate_image.py @@ -20,7 +20,12 @@ from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession from app.config import config -from app.db import ImageGeneration, ImageGenerationConfig, SearchSpace +from app.db import ( + ImageGeneration, + ImageGenerationConfig, + SearchSpace, + shielded_async_session, +) from app.services.image_gen_router_service import ( IMAGE_GEN_AUTO_MODE_ID, ImageGenRouterService, @@ -70,8 +75,13 @@ def create_generate_image_tool( Args: search_space_id: The search space ID (for config resolution) - db_session: Async database session + db_session: Reserved for compatibility with the tool registry. + The streaming task's ``AsyncSession`` is shared by every tool; + because AsyncSession is not concurrency-safe, parallel tool calls + would interleave flushes (e.g. podcast + image in the same step) + and poison the transaction. This tool opens its own session. """ + del db_session # use a fresh per-call session, see below @tool async def generate_image( @@ -93,110 +103,119 @@ def create_generate_image_tool( A dictionary containing the generated image(s) for display in the chat. """ try: - # Resolve the image generation config from the search space preference - result = await db_session.execute( - select(SearchSpace).filter(SearchSpace.id == search_space_id) - ) - search_space = result.scalars().first() - if not search_space: - return {"error": "Search space not found"} - - config_id = ( - search_space.image_generation_config_id or IMAGE_GEN_AUTO_MODE_ID - ) - - # Build generation kwargs - # NOTE: size, quality, and style are intentionally NOT passed. - # Different models support different values for these params - # (e.g. DALL-E 3 wants "hd"/"standard" for quality while - # gpt-image-1 wants "high"/"medium"/"low"; size options also - # differ). Letting the model use its own defaults avoids errors. - gen_kwargs: dict[str, Any] = {} - if n is not None and n > 1: - gen_kwargs["n"] = n - - # Call litellm based on config type - if is_image_gen_auto_mode(config_id): - if not ImageGenRouterService.is_initialized(): - return { - "error": "No image generation models configured. " - "Please add an image model in Settings > Image Models." - } - response = await ImageGenRouterService.aimage_generation( - prompt=prompt, model="auto", **gen_kwargs + # Use a per-call session so concurrent tool calls don't share an + # AsyncSession (which is not concurrency-safe). The streaming + # task's session is shared across every tool; without isolation, + # autoflushes from a concurrent writer poison this tool too. + async with shielded_async_session() as session: + result = await session.execute( + select(SearchSpace).filter(SearchSpace.id == search_space_id) ) - elif config_id < 0: - cfg = _get_global_image_gen_config(config_id) - if not cfg: - return {"error": f"Image generation config {config_id} not found"} + search_space = result.scalars().first() + if not search_space: + return {"error": "Search space not found"} - model_string = _build_model_string( - cfg.get("provider", ""), - cfg["model_name"], - cfg.get("custom_provider"), + config_id = ( + search_space.image_generation_config_id or IMAGE_GEN_AUTO_MODE_ID ) - gen_kwargs["api_key"] = cfg.get("api_key") - if cfg.get("api_base"): - gen_kwargs["api_base"] = cfg["api_base"] - if cfg.get("api_version"): - gen_kwargs["api_version"] = cfg["api_version"] - if cfg.get("litellm_params"): - gen_kwargs.update(cfg["litellm_params"]) - response = await aimage_generation( - prompt=prompt, model=model_string, **gen_kwargs - ) - else: - # Positive ID = user-created ImageGenerationConfig - cfg_result = await db_session.execute( - select(ImageGenerationConfig).filter( - ImageGenerationConfig.id == config_id + # Build generation kwargs + # NOTE: size, quality, and style are intentionally NOT passed. + # Different models support different values for these params + # (e.g. DALL-E 3 wants "hd"/"standard" for quality while + # gpt-image-1 wants "high"/"medium"/"low"; size options also + # differ). Letting the model use its own defaults avoids errors. + gen_kwargs: dict[str, Any] = {} + if n is not None and n > 1: + gen_kwargs["n"] = n + + # Call litellm based on config type + if is_image_gen_auto_mode(config_id): + if not ImageGenRouterService.is_initialized(): + return { + "error": "No image generation models configured. " + "Please add an image model in Settings > Image Models." + } + response = await ImageGenRouterService.aimage_generation( + prompt=prompt, model="auto", **gen_kwargs ) - ) - db_cfg = cfg_result.scalars().first() - if not db_cfg: - return {"error": f"Image generation config {config_id} not found"} + elif config_id < 0: + cfg = _get_global_image_gen_config(config_id) + if not cfg: + return { + "error": f"Image generation config {config_id} not found" + } - model_string = _build_model_string( - db_cfg.provider.value, - db_cfg.model_name, - db_cfg.custom_provider, - ) - gen_kwargs["api_key"] = db_cfg.api_key - if db_cfg.api_base: - gen_kwargs["api_base"] = db_cfg.api_base - if db_cfg.api_version: - gen_kwargs["api_version"] = db_cfg.api_version - if db_cfg.litellm_params: - gen_kwargs.update(db_cfg.litellm_params) + model_string = _build_model_string( + cfg.get("provider", ""), + cfg["model_name"], + cfg.get("custom_provider"), + ) + gen_kwargs["api_key"] = cfg.get("api_key") + if cfg.get("api_base"): + gen_kwargs["api_base"] = cfg["api_base"] + if cfg.get("api_version"): + gen_kwargs["api_version"] = cfg["api_version"] + if cfg.get("litellm_params"): + gen_kwargs.update(cfg["litellm_params"]) - response = await aimage_generation( - prompt=prompt, model=model_string, **gen_kwargs + response = await aimage_generation( + prompt=prompt, model=model_string, **gen_kwargs + ) + else: + # Positive ID = user-created ImageGenerationConfig + cfg_result = await session.execute( + select(ImageGenerationConfig).filter( + ImageGenerationConfig.id == config_id + ) + ) + db_cfg = cfg_result.scalars().first() + if not db_cfg: + return { + "error": f"Image generation config {config_id} not found" + } + + model_string = _build_model_string( + db_cfg.provider.value, + db_cfg.model_name, + db_cfg.custom_provider, + ) + gen_kwargs["api_key"] = db_cfg.api_key + if db_cfg.api_base: + gen_kwargs["api_base"] = db_cfg.api_base + if db_cfg.api_version: + gen_kwargs["api_version"] = db_cfg.api_version + if db_cfg.litellm_params: + gen_kwargs.update(db_cfg.litellm_params) + + response = await aimage_generation( + prompt=prompt, model=model_string, **gen_kwargs + ) + + # Parse the response and store in DB + response_dict = ( + response.model_dump() + if hasattr(response, "model_dump") + else dict(response) ) - # Parse the response and store in DB - response_dict = ( - response.model_dump() - if hasattr(response, "model_dump") - else dict(response) - ) + # Generate a random access token for this image + access_token = generate_image_token() - # Generate a random access token for this image - access_token = generate_image_token() - - # Save to image_generations table for history - db_image_gen = ImageGeneration( - prompt=prompt, - model=getattr(response, "_hidden_params", {}).get("model"), - n=n, - image_generation_config_id=config_id, - response_data=response_dict, - search_space_id=search_space_id, - access_token=access_token, - ) - db_session.add(db_image_gen) - await db_session.commit() - await db_session.refresh(db_image_gen) + # Save to image_generations table for history + db_image_gen = ImageGeneration( + prompt=prompt, + model=getattr(response, "_hidden_params", {}).get("model"), + n=n, + image_generation_config_id=config_id, + response_data=response_dict, + search_space_id=search_space_id, + access_token=access_token, + ) + session.add(db_image_gen) + await session.commit() + await session.refresh(db_image_gen) + db_image_gen_id = db_image_gen.id # Extract image URLs from response images = response_dict.get("data", []) @@ -217,7 +236,7 @@ def create_generate_image_tool( backend_url = config.BACKEND_URL or "http://localhost:8000" image_url = ( f"{backend_url}/api/v1/image-generations/" - f"{db_image_gen.id}/image?token={access_token}" + f"{db_image_gen_id}/image?token={access_token}" ) else: return {"error": "No displayable image data in the response"} diff --git a/surfsense_backend/app/agents/new_chat/tools/gmail/read_email.py b/surfsense_backend/app/agents/new_chat/tools/gmail/read_email.py index 9071f129a..deec1627c 100644 --- a/surfsense_backend/app/agents/new_chat/tools/gmail/read_email.py +++ b/surfsense_backend/app/agents/new_chat/tools/gmail/read_email.py @@ -65,12 +65,22 @@ def create_read_gmail_email_tool( detail, error = await gmail.get_message_details(message_id) if error: - if "re-authenticate" in error.lower() or "authentication failed" in error.lower(): - return {"status": "auth_error", "message": error, "connector_type": "gmail"} + if ( + "re-authenticate" in error.lower() + or "authentication failed" in error.lower() + ): + return { + "status": "auth_error", + "message": error, + "connector_type": "gmail", + } return {"status": "error", "message": error} if not detail: - return {"status": "not_found", "message": f"Email with ID '{message_id}' not found."} + return { + "status": "not_found", + "message": f"Email with ID '{message_id}' not found.", + } content = gmail.format_message_to_markdown(detail) @@ -82,6 +92,9 @@ def create_read_gmail_email_tool( if isinstance(e, GraphInterrupt): raise logger.error("Error reading Gmail email: %s", e, exc_info=True) - return {"status": "error", "message": "Failed to read email. Please try again."} + return { + "status": "error", + "message": "Failed to read email. Please try again.", + } return read_gmail_email diff --git a/surfsense_backend/app/agents/new_chat/tools/gmail/search_emails.py b/surfsense_backend/app/agents/new_chat/tools/gmail/search_emails.py index de43f03d0..2e363609e 100644 --- a/surfsense_backend/app/agents/new_chat/tools/gmail/search_emails.py +++ b/surfsense_backend/app/agents/new_chat/tools/gmail/search_emails.py @@ -125,12 +125,24 @@ def create_search_gmail_tool( max_results=max_results, query=query ) if error: - if "re-authenticate" in error.lower() or "authentication failed" in error.lower(): - return {"status": "auth_error", "message": error, "connector_type": "gmail"} + if ( + "re-authenticate" in error.lower() + or "authentication failed" in error.lower() + ): + return { + "status": "auth_error", + "message": error, + "connector_type": "gmail", + } return {"status": "error", "message": error} if not messages_list: - return {"status": "success", "emails": [], "total": 0, "message": "No emails found."} + return { + "status": "success", + "emails": [], + "total": 0, + "message": "No emails found.", + } emails = [] for msg in messages_list: @@ -141,16 +153,18 @@ def create_search_gmail_tool( h["name"].lower(): h["value"] for h in detail.get("payload", {}).get("headers", []) } - emails.append({ - "message_id": detail.get("id"), - "thread_id": detail.get("threadId"), - "subject": headers.get("subject", "No Subject"), - "from": headers.get("from", "Unknown"), - "to": headers.get("to", ""), - "date": headers.get("date", ""), - "snippet": detail.get("snippet", ""), - "labels": detail.get("labelIds", []), - }) + emails.append( + { + "message_id": detail.get("id"), + "thread_id": detail.get("threadId"), + "subject": headers.get("subject", "No Subject"), + "from": headers.get("from", "Unknown"), + "to": headers.get("to", ""), + "date": headers.get("date", ""), + "snippet": detail.get("snippet", ""), + "labels": detail.get("labelIds", []), + } + ) return {"status": "success", "emails": emails, "total": len(emails)} @@ -160,6 +174,9 @@ def create_search_gmail_tool( if isinstance(e, GraphInterrupt): raise logger.error("Error searching Gmail: %s", e, exc_info=True) - return {"status": "error", "message": "Failed to search Gmail. Please try again."} + return { + "status": "error", + "message": "Failed to search Gmail. Please try again.", + } return search_gmail diff --git a/surfsense_backend/app/agents/new_chat/tools/google_calendar/search_events.py b/surfsense_backend/app/agents/new_chat/tools/google_calendar/search_events.py index a622b0efa..dc6adb822 100644 --- a/surfsense_backend/app/agents/new_chat/tools/google_calendar/search_events.py +++ b/surfsense_backend/app/agents/new_chat/tools/google_calendar/search_events.py @@ -39,7 +39,10 @@ def create_search_calendar_events_tool( event_id, summary, start, end, location, attendees. """ if db_session is None or search_space_id is None or user_id is None: - return {"status": "error", "message": "Calendar tool not properly configured."} + return { + "status": "error", + "message": "Calendar tool not properly configured.", + } max_results = min(max_results, 50) @@ -76,10 +79,22 @@ def create_search_calendar_events_tool( ) if error: - if "re-authenticate" in error.lower() or "authentication failed" in error.lower(): - return {"status": "auth_error", "message": error, "connector_type": "google_calendar"} + if ( + "re-authenticate" in error.lower() + or "authentication failed" in error.lower() + ): + return { + "status": "auth_error", + "message": error, + "connector_type": "google_calendar", + } if "no events found" in error.lower(): - return {"status": "success", "events": [], "total": 0, "message": error} + return { + "status": "success", + "events": [], + "total": 0, + "message": error, + } return {"status": "error", "message": error} events = [] @@ -87,19 +102,19 @@ def create_search_calendar_events_tool( start = ev.get("start", {}) end = ev.get("end", {}) attendees_raw = ev.get("attendees", []) - events.append({ - "event_id": ev.get("id"), - "summary": ev.get("summary", "No Title"), - "start": start.get("dateTime") or start.get("date", ""), - "end": end.get("dateTime") or end.get("date", ""), - "location": ev.get("location", ""), - "description": ev.get("description", ""), - "html_link": ev.get("htmlLink", ""), - "attendees": [ - a.get("email", "") for a in attendees_raw[:10] - ], - "status": ev.get("status", ""), - }) + events.append( + { + "event_id": ev.get("id"), + "summary": ev.get("summary", "No Title"), + "start": start.get("dateTime") or start.get("date", ""), + "end": end.get("dateTime") or end.get("date", ""), + "location": ev.get("location", ""), + "description": ev.get("description", ""), + "html_link": ev.get("htmlLink", ""), + "attendees": [a.get("email", "") for a in attendees_raw[:10]], + "status": ev.get("status", ""), + } + ) return {"status": "success", "events": events, "total": len(events)} @@ -109,6 +124,9 @@ def create_search_calendar_events_tool( if isinstance(e, GraphInterrupt): raise logger.error("Error searching calendar events: %s", e, exc_info=True) - return {"status": "error", "message": "Failed to search calendar events. Please try again."} + return { + "status": "error", + "message": "Failed to search calendar events. Please try again.", + } return search_calendar_events diff --git a/surfsense_backend/app/agents/new_chat/tools/hitl.py b/surfsense_backend/app/agents/new_chat/tools/hitl.py index 89f02abf6..8480e57b1 100644 --- a/surfsense_backend/app/agents/new_chat/tools/hitl.py +++ b/surfsense_backend/app/agents/new_chat/tools/hitl.py @@ -130,7 +130,9 @@ def request_approval( try: decision_type, edited_params = _parse_decision(approval) except ValueError: - logger.warning("No approval decision received for %s — rejecting for safety", tool_name) + logger.warning( + "No approval decision received for %s — rejecting for safety", tool_name + ) return HITLResult(rejected=True, decision_type="error", params=params) logger.info("User decision for %s: %s", tool_name, decision_type) diff --git a/surfsense_backend/app/agents/new_chat/tools/invalid_tool.py b/surfsense_backend/app/agents/new_chat/tools/invalid_tool.py new file mode 100644 index 000000000..ea4bc0bc1 --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/tools/invalid_tool.py @@ -0,0 +1,53 @@ +""" +The ``invalid`` fallback tool. + +When the model emits a tool call whose name doesn't match any registered +tool, :class:`ToolCallNameRepairMiddleware` rewrites the call to ``invalid`` +with the original name and a parser/validation error string. This tool's +execution then returns that error to the model so it can self-correct. + +Ported from OpenCode's ``packages/opencode/src/tool/invalid.ts`` — +LangChain has no equivalent fallback path; the default behavior on an +unknown tool name is a hard ``ToolNotFoundError`` which kills the turn. + +Critically, the :class:`ToolDefinition` for this tool is **excluded** from +the system-prompt tool list and from ``LLMToolSelectorMiddleware`` selection +(see ``ToolDefinition.always_include`` filtering in the registry) — the +model never advertises ``invalid`` as a callable. It only ever shows up +in the tool registry so LangGraph can dispatch the rewritten call. +""" + +from __future__ import annotations + +from langchain_core.tools import tool + +INVALID_TOOL_NAME = "invalid" +INVALID_TOOL_DESCRIPTION = "Do not use" + + +def _format_invalid_message(tool: str | None, error: str | None) -> str: + """Return the user-visible error string. Mirrors ``invalid.ts``.""" + name = tool or "" + detail = error or "(no error message provided)" + return ( + f"The arguments provided to the tool `{name}` are invalid: {detail}\n" + f"Read the tool's docstring carefully and try again with valid arguments." + ) + + +@tool(name_or_callable=INVALID_TOOL_NAME, description=INVALID_TOOL_DESCRIPTION) +def invalid_tool(tool: str | None = None, error: str | None = None) -> str: + """Return a human-readable explanation of a tool-call validation failure. + + Activated only when :class:`ToolCallNameRepairMiddleware` rewrites a + failed tool call to ``invalid`` with the original tool name and the + error message produced during validation. + """ + return _format_invalid_message(tool, error) + + +__all__ = [ + "INVALID_TOOL_DESCRIPTION", + "INVALID_TOOL_NAME", + "invalid_tool", +] diff --git a/surfsense_backend/app/agents/new_chat/tools/luma/_auth.py b/surfsense_backend/app/agents/new_chat/tools/luma/_auth.py index 1d88161d6..37deb1525 100644 --- a/surfsense_backend/app/agents/new_chat/tools/luma/_auth.py +++ b/surfsense_backend/app/agents/new_chat/tools/luma/_auth.py @@ -17,7 +17,8 @@ async def get_luma_connector( select(SearchSourceConnector).filter( SearchSourceConnector.search_space_id == search_space_id, SearchSourceConnector.user_id == user_id, - SearchSourceConnector.connector_type == SearchSourceConnectorType.LUMA_CONNECTOR, + SearchSourceConnector.connector_type + == SearchSourceConnectorType.LUMA_CONNECTOR, ) ) return result.scalars().first() diff --git a/surfsense_backend/app/agents/new_chat/tools/luma/create_event.py b/surfsense_backend/app/agents/new_chat/tools/luma/create_event.py index 2217d29e6..0a24a988f 100644 --- a/surfsense_backend/app/agents/new_chat/tools/luma/create_event.py +++ b/surfsense_backend/app/agents/new_chat/tools/luma/create_event.py @@ -62,7 +62,10 @@ def create_create_luma_event_tool( ) if result.rejected: - return {"status": "rejected", "message": "User declined. Event was not created."} + return { + "status": "rejected", + "message": "User declined. Event was not created.", + } final_name = result.params.get("name", name) final_start = result.params.get("start_at", start_at) @@ -90,11 +93,21 @@ def create_create_luma_event_tool( ) if resp.status_code == 401: - return {"status": "auth_error", "message": "Luma API key is invalid.", "connector_type": "luma"} + return { + "status": "auth_error", + "message": "Luma API key is invalid.", + "connector_type": "luma", + } if resp.status_code == 403: - return {"status": "error", "message": "Luma Plus subscription required to create events via API."} + return { + "status": "error", + "message": "Luma Plus subscription required to create events via API.", + } if resp.status_code not in (200, 201): - return {"status": "error", "message": f"Luma API error: {resp.status_code} — {resp.text[:200]}"} + return { + "status": "error", + "message": f"Luma API error: {resp.status_code} — {resp.text[:200]}", + } data = resp.json() event_id = data.get("api_id") or data.get("event", {}).get("api_id") diff --git a/surfsense_backend/app/agents/new_chat/tools/luma/list_events.py b/surfsense_backend/app/agents/new_chat/tools/luma/list_events.py index cd4721758..aec5ad220 100644 --- a/surfsense_backend/app/agents/new_chat/tools/luma/list_events.py +++ b/surfsense_backend/app/agents/new_chat/tools/luma/list_events.py @@ -46,7 +46,9 @@ def create_list_luma_events_tool( async with httpx.AsyncClient(timeout=20.0) as client: while len(all_entries) < max_results: - params: dict[str, Any] = {"limit": min(100, max_results - len(all_entries))} + params: dict[str, Any] = { + "limit": min(100, max_results - len(all_entries)) + } if cursor: params["cursor"] = cursor @@ -57,9 +59,16 @@ def create_list_luma_events_tool( ) if resp.status_code == 401: - return {"status": "auth_error", "message": "Luma API key is invalid.", "connector_type": "luma"} + return { + "status": "auth_error", + "message": "Luma API key is invalid.", + "connector_type": "luma", + } if resp.status_code != 200: - return {"status": "error", "message": f"Luma API error: {resp.status_code}"} + return { + "status": "error", + "message": f"Luma API error: {resp.status_code}", + } data = resp.json() entries = data.get("entries", []) @@ -76,16 +85,18 @@ def create_list_luma_events_tool( for entry in all_entries[:max_results]: ev = entry.get("event", {}) geo = ev.get("geo_info", {}) - events.append({ - "event_id": entry.get("api_id"), - "name": ev.get("name", "Untitled"), - "start_at": ev.get("start_at", ""), - "end_at": ev.get("end_at", ""), - "timezone": ev.get("timezone", ""), - "location": geo.get("name", ""), - "url": ev.get("url", ""), - "visibility": ev.get("visibility", ""), - }) + events.append( + { + "event_id": entry.get("api_id"), + "name": ev.get("name", "Untitled"), + "start_at": ev.get("start_at", ""), + "end_at": ev.get("end_at", ""), + "timezone": ev.get("timezone", ""), + "location": geo.get("name", ""), + "url": ev.get("url", ""), + "visibility": ev.get("visibility", ""), + } + ) return {"status": "success", "events": events, "total": len(events)} diff --git a/surfsense_backend/app/agents/new_chat/tools/luma/read_event.py b/surfsense_backend/app/agents/new_chat/tools/luma/read_event.py index eb3ac55c6..b37a9d617 100644 --- a/surfsense_backend/app/agents/new_chat/tools/luma/read_event.py +++ b/surfsense_backend/app/agents/new_chat/tools/luma/read_event.py @@ -44,11 +44,21 @@ def create_read_luma_event_tool( ) if resp.status_code == 401: - return {"status": "auth_error", "message": "Luma API key is invalid.", "connector_type": "luma"} + return { + "status": "auth_error", + "message": "Luma API key is invalid.", + "connector_type": "luma", + } if resp.status_code == 404: - return {"status": "not_found", "message": f"Event '{event_id}' not found."} + return { + "status": "not_found", + "message": f"Event '{event_id}' not found.", + } if resp.status_code != 200: - return {"status": "error", "message": f"Luma API error: {resp.status_code}"} + return { + "status": "error", + "message": f"Luma API error: {resp.status_code}", + } data = resp.json() ev = data.get("event", data) diff --git a/surfsense_backend/app/agents/new_chat/tools/mcp_client.py b/surfsense_backend/app/agents/new_chat/tools/mcp_client.py index b46ddbcc5..e28ac8bda 100644 --- a/surfsense_backend/app/agents/new_chat/tools/mcp_client.py +++ b/surfsense_backend/app/agents/new_chat/tools/mcp_client.py @@ -220,10 +220,8 @@ class MCPClient: logger.info("MCP tool '%s' succeeded: %s", tool_name, result_str[:200]) return result_str - except asyncio.TimeoutError: - logger.error( - "MCP tool '%s' timed out after %.0fs", tool_name, timeout - ) + except TimeoutError: + logger.error("MCP tool '%s' timed out after %.0fs", tool_name, timeout) return f"Error: MCP tool '{tool_name}' timed out after {timeout:.0f}s" except RuntimeError as e: if "Invalid structured content" in str(e): diff --git a/surfsense_backend/app/agents/new_chat/tools/mcp_tool.py b/surfsense_backend/app/agents/new_chat/tools/mcp_tool.py index dfee24516..5b96ab374 100644 --- a/surfsense_backend/app/agents/new_chat/tools/mcp_tool.py +++ b/surfsense_backend/app/agents/new_chat/tools/mcp_tool.py @@ -35,7 +35,7 @@ from sqlalchemy.ext.asyncio import AsyncSession from app.agents.new_chat.tools.hitl import request_approval from app.agents.new_chat.tools.mcp_client import MCPClient -from app.db import SearchSourceConnector, SearchSourceConnectorType +from app.db import SearchSourceConnector from app.services.mcp_oauth.registry import MCP_SERVICES, get_service_by_connector_type logger = logging.getLogger(__name__) @@ -105,13 +105,15 @@ def _create_dynamic_input_model_from_schema( description=( "Arguments to pass to this tool as a JSON object. " "Infer sensible key names from the tool name and description " - "(e.g. {\"search\": \"my query\"} for a search tool)." + '(e.g. {"search": "my query"} for a search tool).' ), ), ) model_name = f"{tool_name.replace(' ', '').replace('-', '_')}Input" - model = create_model(model_name, __config__=ConfigDict(extra="allow"), **field_definitions) + model = create_model( + model_name, __config__=ConfigDict(extra="allow"), **field_definitions + ) return model @@ -187,16 +189,23 @@ async def _create_mcp_tool_from_definition_stdio( except Exception as e: last_error = e if attempt < _TOOL_CALL_MAX_RETRIES - 1: - delay = _TOOL_CALL_RETRY_DELAY * (2 ** attempt) + delay = _TOOL_CALL_RETRY_DELAY * (2**attempt) logger.warning( "MCP tool '%s' failed (attempt %d/%d): %s. Retrying in %.1fs...", - tool_name, attempt + 1, _TOOL_CALL_MAX_RETRIES, e, delay, + tool_name, + attempt + 1, + _TOOL_CALL_MAX_RETRIES, + e, + delay, ) await asyncio.sleep(delay) else: logger.error( "MCP tool '%s' failed after %d attempts: %s", - tool_name, _TOOL_CALL_MAX_RETRIES, e, exc_info=True, + tool_name, + _TOOL_CALL_MAX_RETRIES, + e, + exc_info=True, ) return f"Error: MCP tool '{tool_name}' failed after {_TOOL_CALL_MAX_RETRIES} attempts: {last_error!s}" @@ -318,17 +327,22 @@ async def _create_mcp_tool_from_definition_http( try: result_str = await _do_mcp_call(headers, call_kwargs) - logger.debug("MCP HTTP tool '%s' succeeded (len=%d)", exposed_name, len(result_str)) + logger.debug( + "MCP HTTP tool '%s' succeeded (len=%d)", exposed_name, len(result_str) + ) return result_str except Exception as first_err: if not _is_auth_error(first_err) or connector_id is None: - logger.exception("MCP HTTP tool '%s' execution failed: %s", exposed_name, first_err) + logger.exception( + "MCP HTTP tool '%s' execution failed: %s", exposed_name, first_err + ) return f"Error: MCP HTTP tool '{exposed_name}' execution failed: {first_err!s}" logger.warning( "MCP HTTP tool '%s' got 401 — attempting token refresh for connector %s", - exposed_name, connector_id, + exposed_name, + connector_id, ) fresh_headers = await _force_refresh_and_get_headers(connector_id) if fresh_headers is None: @@ -348,7 +362,8 @@ async def _create_mcp_tool_from_definition_http( except Exception as retry_err: logger.exception( "MCP HTTP tool '%s' still failing after token refresh: %s", - exposed_name, retry_err, + exposed_name, + retry_err, ) if _is_auth_error(retry_err): await _mark_connector_auth_expired(connector_id) @@ -393,7 +408,8 @@ async def _load_stdio_mcp_tools( if not command or not isinstance(command, str): logger.warning( "MCP connector %d (name: '%s') missing or invalid command field, skipping", - connector_id, connector_name, + connector_id, + connector_name, ) return tools @@ -401,7 +417,8 @@ async def _load_stdio_mcp_tools( if not isinstance(args, list): logger.warning( "MCP connector %d (name: '%s') has invalid args field (must be list), skipping", - connector_id, connector_name, + connector_id, + connector_name, ) return tools @@ -409,7 +426,8 @@ async def _load_stdio_mcp_tools( if not isinstance(env, dict): logger.warning( "MCP connector %d (name: '%s') has invalid env field (must be dict), skipping", - connector_id, connector_name, + connector_id, + connector_name, ) return tools @@ -420,7 +438,9 @@ async def _load_stdio_mcp_tools( logger.info( "Discovered %d tools from stdio MCP server '%s' (connector %d)", - len(tool_definitions), command, connector_id, + len(tool_definitions), + command, + connector_id, ) for tool_def in tool_definitions: @@ -436,7 +456,9 @@ async def _load_stdio_mcp_tools( except Exception as e: logger.exception( "Failed to create tool '%s' from connector %d: %s", - tool_def.get("name"), connector_id, e, + tool_def.get("name"), + connector_id, + e, ) return tools @@ -468,7 +490,8 @@ async def _load_http_mcp_tools( if not url or not isinstance(url, str): logger.warning( "MCP connector %d (name: '%s') missing or invalid url field, skipping", - connector_id, connector_name, + connector_id, + connector_name, ) return tools @@ -476,7 +499,8 @@ async def _load_http_mcp_tools( if not isinstance(headers, dict): logger.warning( "MCP connector %d (name: '%s') has invalid headers field (must be dict), skipping", - connector_id, connector_name, + connector_id, + connector_name, ) return tools @@ -507,7 +531,9 @@ async def _load_http_mcp_tools( if not _is_auth_error(first_err) or connector_id is None: logger.exception( "Failed to connect to HTTP MCP server at '%s' (connector %d): %s", - url, connector_id, first_err, + url, + connector_id, + first_err, ) return tools @@ -534,7 +560,8 @@ async def _load_http_mcp_tools( except Exception as retry_err: logger.exception( "HTTP MCP discovery for connector %d still failing after refresh: %s", - connector_id, retry_err, + connector_id, + retry_err, ) if _is_auth_error(retry_err): await _mark_connector_auth_expired(connector_id) @@ -543,17 +570,20 @@ async def _load_http_mcp_tools( total_discovered = len(tool_definitions) if allowed_set: - tool_definitions = [ - td for td in tool_definitions if td["name"] in allowed_set - ] + tool_definitions = [td for td in tool_definitions if td["name"] in allowed_set] logger.info( "HTTP MCP server '%s' (connector %d): %d/%d tools after allowlist filter", - url, connector_id, len(tool_definitions), total_discovered, + url, + connector_id, + len(tool_definitions), + total_discovered, ) else: logger.info( "Discovered %d tools from HTTP MCP server '%s' (connector %d) — no allowlist, loading all", - total_discovered, url, connector_id, + total_discovered, + url, + connector_id, ) for tool_def in tool_definitions: @@ -573,7 +603,9 @@ async def _load_http_mcp_tools( except Exception as e: logger.exception( "Failed to create HTTP tool '%s' from connector %d: %s", - tool_def.get("name"), connector_id, e, + tool_def.get("name"), + connector_id, + e, ) return tools @@ -628,7 +660,7 @@ def _inject_oauth_headers( async def _refresh_connector_token( session: AsyncSession, - connector: "SearchSourceConnector", + connector: SearchSourceConnector, ) -> str | None: """Refresh the OAuth token for an MCP connector and persist the result. @@ -692,12 +724,8 @@ async def _refresh_connector_token( updated_oauth = dict(mcp_oauth) updated_oauth["access_token"] = enc.encrypt_token(new_access) if token_json.get("refresh_token"): - updated_oauth["refresh_token"] = enc.encrypt_token( - token_json["refresh_token"] - ) - updated_oauth["expires_at"] = ( - new_expires_at.isoformat() if new_expires_at else None - ) + updated_oauth["refresh_token"] = enc.encrypt_token(token_json["refresh_token"]) + updated_oauth["expires_at"] = new_expires_at.isoformat() if new_expires_at else None updated_cfg = {**cfg, "mcp_oauth": updated_oauth} updated_cfg.pop("auth_expired", None) @@ -713,7 +741,7 @@ async def _refresh_connector_token( async def _maybe_refresh_mcp_oauth_token( session: AsyncSession, - connector: "SearchSourceConnector", + connector: SearchSourceConnector, cfg: dict[str, Any], server_config: dict[str, Any], ) -> dict[str, Any]: @@ -731,10 +759,11 @@ async def _maybe_refresh_mcp_oauth_token( try: expires_at = datetime.fromisoformat(expires_at_str) if expires_at.tzinfo is None: - from datetime import timezone - expires_at = expires_at.replace(tzinfo=timezone.utc) + expires_at = expires_at.replace(tzinfo=UTC) - if datetime.now(UTC) < expires_at - timedelta(seconds=_TOKEN_REFRESH_BUFFER_SECONDS): + if datetime.now(UTC) < expires_at - timedelta( + seconds=_TOKEN_REFRESH_BUFFER_SECONDS + ): return server_config except (ValueError, TypeError): return server_config @@ -744,7 +773,9 @@ async def _maybe_refresh_mcp_oauth_token( if not new_access: return server_config - logger.info("Proactively refreshed MCP OAuth token for connector %s", connector.id) + logger.info( + "Proactively refreshed MCP OAuth token for connector %s", connector.id + ) refreshed_config = dict(server_config) refreshed_config["headers"] = { @@ -920,7 +951,7 @@ async def load_mcp_tools( result = await session.execute( select(SearchSourceConnector).filter( SearchSourceConnector.search_space_id == search_space_id, - cast(SearchSourceConnector.config, JSONB).has_key("server_config"), # noqa: W601 + cast(SearchSourceConnector.config, JSONB).has_key("server_config"), ), ) @@ -956,13 +987,17 @@ async def load_mcp_tools( if not server_config or not isinstance(server_config, dict): logger.warning( "MCP connector %d (name: '%s') has invalid or missing server_config, skipping", - connector.id, connector.name, + connector.id, + connector.name, ) continue if cfg.get("mcp_oauth"): server_config = await _maybe_refresh_mcp_oauth_token( - session, connector, cfg, server_config, + session, + connector, + cfg, + server_config, ) cfg = connector.config or {} server_config = _inject_oauth_headers(cfg, server_config) @@ -995,22 +1030,25 @@ async def load_mcp_tools( if service_key: tool_name_prefix = f"{service_key}_{connector.id}" - discovery_tasks.append({ - "connector_id": connector.id, - "connector_name": connector.name, - "server_config": server_config, - "trusted_tools": trusted_tools, - "allowed_tools": allowed_tools, - "readonly_tools": readonly_tools, - "tool_name_prefix": tool_name_prefix, - "transport": server_config.get("transport", "stdio"), - "is_generic_mcp": svc_cfg is None, - }) + discovery_tasks.append( + { + "connector_id": connector.id, + "connector_name": connector.name, + "server_config": server_config, + "trusted_tools": trusted_tools, + "allowed_tools": allowed_tools, + "readonly_tools": readonly_tools, + "tool_name_prefix": tool_name_prefix, + "transport": server_config.get("transport", "stdio"), + "is_generic_mcp": svc_cfg is None, + } + ) except Exception as e: logger.exception( "Failed to prepare MCP connector %d: %s", - connector.id, e, + connector.id, + e, ) async def _discover_one(task: dict[str, Any]) -> list[StructuredTool]: @@ -1039,23 +1077,23 @@ async def load_mcp_tools( ), timeout=_MCP_DISCOVERY_TIMEOUT_SECONDS, ) - except asyncio.TimeoutError: + except TimeoutError: logger.error( "MCP connector %d timed out after %ds during discovery", - task["connector_id"], _MCP_DISCOVERY_TIMEOUT_SECONDS, + task["connector_id"], + _MCP_DISCOVERY_TIMEOUT_SECONDS, ) return [] except Exception as e: logger.exception( "Failed to load tools from MCP connector %d: %s", - task["connector_id"], e, + task["connector_id"], + e, ) return [] results = await asyncio.gather(*[_discover_one(t) for t in discovery_tasks]) - tools: list[StructuredTool] = [ - tool for sublist in results for tool in sublist - ] + tools: list[StructuredTool] = [tool for sublist in results for tool in sublist] _mcp_tools_cache[search_space_id] = (now, tools) @@ -1063,7 +1101,9 @@ async def load_mcp_tools( oldest_key = min(_mcp_tools_cache, key=lambda k: _mcp_tools_cache[k][0]) del _mcp_tools_cache[oldest_key] - logger.info("Loaded %d MCP tools for search space %d", len(tools), search_space_id) + logger.info( + "Loaded %d MCP tools for search space %d", len(tools), search_space_id + ) return tools except Exception as e: diff --git a/surfsense_backend/app/agents/new_chat/tools/podcast.py b/surfsense_backend/app/agents/new_chat/tools/podcast.py index 248a4f450..2c9b7fa0c 100644 --- a/surfsense_backend/app/agents/new_chat/tools/podcast.py +++ b/surfsense_backend/app/agents/new_chat/tools/podcast.py @@ -11,7 +11,7 @@ from typing import Any from langchain_core.tools import tool from sqlalchemy.ext.asyncio import AsyncSession -from app.db import Podcast, PodcastStatus +from app.db import Podcast, PodcastStatus, shielded_async_session def create_generate_podcast_tool( @@ -27,12 +27,16 @@ def create_generate_podcast_tool( Args: search_space_id: The user's search space ID - db_session: Database session for creating the podcast record + db_session: Reserved for future read-side use; the row is written via a + fresh, tool-local session so parallel tool calls (e.g. podcast + + video presentation in the same agent step) don't share an + ``AsyncSession`` (which is not concurrency-safe). thread_id: The chat thread ID for associating the podcast Returns: A configured tool function for generating podcasts """ + del db_session # writes use a fresh tool-local session, see below @tool async def generate_podcast( @@ -64,32 +68,40 @@ def create_generate_podcast_tool( - message: Status message (or "error" field if status is failed) """ try: - podcast = Podcast( - title=podcast_title, - status=PodcastStatus.PENDING, - search_space_id=search_space_id, - thread_id=thread_id, - ) - db_session.add(podcast) - await db_session.commit() - await db_session.refresh(podcast) + # Open a fresh session per call. The streaming task's session is + # shared between every tool, and ``AsyncSession`` is NOT safe for + # concurrent use: when the LLM emits parallel tool calls, two + # concurrent ``add()`` / ``commit()`` paths interleave and the + # second one hits "Session.add() during flush" → the transaction + # is poisoned for both tools. + async with shielded_async_session() as session: + podcast = Podcast( + title=podcast_title, + status=PodcastStatus.PENDING, + search_space_id=search_space_id, + thread_id=thread_id, + ) + session.add(podcast) + await session.commit() + await session.refresh(podcast) + podcast_id = podcast.id from app.tasks.celery_tasks.podcast_tasks import ( generate_content_podcast_task, ) task = generate_content_podcast_task.delay( - podcast_id=podcast.id, + podcast_id=podcast_id, source_content=source_content, search_space_id=search_space_id, user_prompt=user_prompt, ) - print(f"[generate_podcast] Created podcast {podcast.id}, task: {task.id}") + print(f"[generate_podcast] Created podcast {podcast_id}, task: {task.id}") return { "status": PodcastStatus.PENDING.value, - "podcast_id": podcast.id, + "podcast_id": podcast_id, "title": podcast_title, "message": "Podcast generation started. This may take a few minutes.", } diff --git a/surfsense_backend/app/agents/new_chat/tools/registry.py b/surfsense_backend/app/agents/new_chat/tools/registry.py index 85c89b114..e8bab36fd 100644 --- a/surfsense_backend/app/agents/new_chat/tools/registry.py +++ b/surfsense_backend/app/agents/new_chat/tools/registry.py @@ -43,6 +43,9 @@ from typing import Any from langchain_core.tools import BaseTool +from app.agents.new_chat.middleware.dedup_tool_calls import ( + wrap_dedup_key_by_arg_name, +) from app.db import ChatVisibility from .confluence import ( @@ -50,6 +53,7 @@ from .confluence import ( create_delete_confluence_page_tool, create_update_confluence_page_tool, ) +from .connected_accounts import create_get_connected_accounts_tool from .discord import ( create_list_discord_channels_tool, create_read_discord_messages_tool, @@ -78,7 +82,6 @@ from .google_drive import ( create_create_google_drive_file_tool, create_delete_google_drive_file_tool, ) -from .connected_accounts import create_get_connected_accounts_tool from .luma import ( create_create_luma_event_tool, create_list_luma_events_tool, @@ -108,6 +111,8 @@ from .update_memory import create_update_memory_tool, create_update_team_memory_ from .video_presentation import create_generate_video_presentation_tool from .web_search import create_web_search_tool +logger = logging.getLogger(__name__) + # ============================================================================= # Tool Definition # ============================================================================= @@ -125,6 +130,12 @@ class ToolDefinition: enabled_by_default: Whether the tool is enabled when no explicit config is provided required_connector: Searchable type string (e.g. ``"LINEAR_CONNECTOR"``) that must be in ``available_connectors`` for the tool to be enabled. + dedup_key: Optional callable that maps a tool's ``args`` dict to a + string signature used by :class:`DedupHITLToolCallsMiddleware` + to drop duplicate calls within a single LLM response. + reverse: Optional callable that, given the tool's ``(args, result)``, + returns a ``ReverseDescriptor`` describing the inverse tool + invocation. Consumed by the snapshot/revert pipeline. """ @@ -135,6 +146,8 @@ class ToolDefinition: enabled_by_default: bool = True hidden: bool = False required_connector: str | None = None + dedup_key: Callable[[dict[str, Any]], str] | None = None + reverse: Callable[[dict[str, Any], Any], dict[str, Any]] | None = None # ============================================================================= @@ -288,6 +301,7 @@ BUILTIN_TOOLS: list[ToolDefinition] = [ ), requires=["db_session", "search_space_id", "user_id"], required_connector="NOTION_CONNECTOR", + dedup_key=wrap_dedup_key_by_arg_name("title"), ), ToolDefinition( name="update_notion_page", @@ -299,6 +313,7 @@ BUILTIN_TOOLS: list[ToolDefinition] = [ ), requires=["db_session", "search_space_id", "user_id"], required_connector="NOTION_CONNECTOR", + dedup_key=wrap_dedup_key_by_arg_name("page_title"), ), ToolDefinition( name="delete_notion_page", @@ -310,6 +325,7 @@ BUILTIN_TOOLS: list[ToolDefinition] = [ ), requires=["db_session", "search_space_id", "user_id"], required_connector="NOTION_CONNECTOR", + dedup_key=wrap_dedup_key_by_arg_name("page_title"), ), # ========================================================================= # GOOGLE DRIVE TOOLS - create files, delete files @@ -325,6 +341,7 @@ BUILTIN_TOOLS: list[ToolDefinition] = [ ), requires=["db_session", "search_space_id", "user_id"], required_connector="GOOGLE_DRIVE_FILE", + dedup_key=wrap_dedup_key_by_arg_name("file_name"), ), ToolDefinition( name="delete_google_drive_file", @@ -336,6 +353,7 @@ BUILTIN_TOOLS: list[ToolDefinition] = [ ), requires=["db_session", "search_space_id", "user_id"], required_connector="GOOGLE_DRIVE_FILE", + dedup_key=wrap_dedup_key_by_arg_name("file_name"), ), # ========================================================================= # DROPBOX TOOLS - create and trash files @@ -351,6 +369,7 @@ BUILTIN_TOOLS: list[ToolDefinition] = [ ), requires=["db_session", "search_space_id", "user_id"], required_connector="DROPBOX_FILE", + dedup_key=wrap_dedup_key_by_arg_name("file_name"), ), ToolDefinition( name="delete_dropbox_file", @@ -362,6 +381,7 @@ BUILTIN_TOOLS: list[ToolDefinition] = [ ), requires=["db_session", "search_space_id", "user_id"], required_connector="DROPBOX_FILE", + dedup_key=wrap_dedup_key_by_arg_name("file_name"), ), # ========================================================================= # ONEDRIVE TOOLS - create and trash files @@ -377,6 +397,7 @@ BUILTIN_TOOLS: list[ToolDefinition] = [ ), requires=["db_session", "search_space_id", "user_id"], required_connector="ONEDRIVE_FILE", + dedup_key=wrap_dedup_key_by_arg_name("file_name"), ), ToolDefinition( name="delete_onedrive_file", @@ -388,6 +409,7 @@ BUILTIN_TOOLS: list[ToolDefinition] = [ ), requires=["db_session", "search_space_id", "user_id"], required_connector="ONEDRIVE_FILE", + dedup_key=wrap_dedup_key_by_arg_name("file_name"), ), # ========================================================================= # GOOGLE CALENDAR TOOLS - search, create, update, delete events @@ -414,6 +436,7 @@ BUILTIN_TOOLS: list[ToolDefinition] = [ ), requires=["db_session", "search_space_id", "user_id"], required_connector="GOOGLE_CALENDAR_CONNECTOR", + dedup_key=wrap_dedup_key_by_arg_name("title"), ), ToolDefinition( name="update_calendar_event", @@ -425,6 +448,7 @@ BUILTIN_TOOLS: list[ToolDefinition] = [ ), requires=["db_session", "search_space_id", "user_id"], required_connector="GOOGLE_CALENDAR_CONNECTOR", + dedup_key=wrap_dedup_key_by_arg_name("event_title_or_id"), ), ToolDefinition( name="delete_calendar_event", @@ -436,6 +460,7 @@ BUILTIN_TOOLS: list[ToolDefinition] = [ ), requires=["db_session", "search_space_id", "user_id"], required_connector="GOOGLE_CALENDAR_CONNECTOR", + dedup_key=wrap_dedup_key_by_arg_name("event_title_or_id"), ), # ========================================================================= # GMAIL TOOLS - search, read, create drafts, update drafts, send, trash @@ -473,6 +498,7 @@ BUILTIN_TOOLS: list[ToolDefinition] = [ ), requires=["db_session", "search_space_id", "user_id"], required_connector="GOOGLE_GMAIL_CONNECTOR", + dedup_key=wrap_dedup_key_by_arg_name("subject"), ), ToolDefinition( name="send_gmail_email", @@ -484,6 +510,7 @@ BUILTIN_TOOLS: list[ToolDefinition] = [ ), requires=["db_session", "search_space_id", "user_id"], required_connector="GOOGLE_GMAIL_CONNECTOR", + dedup_key=wrap_dedup_key_by_arg_name("subject"), ), ToolDefinition( name="trash_gmail_email", @@ -495,6 +522,7 @@ BUILTIN_TOOLS: list[ToolDefinition] = [ ), requires=["db_session", "search_space_id", "user_id"], required_connector="GOOGLE_GMAIL_CONNECTOR", + dedup_key=wrap_dedup_key_by_arg_name("email_subject_or_id"), ), ToolDefinition( name="update_gmail_draft", @@ -506,6 +534,7 @@ BUILTIN_TOOLS: list[ToolDefinition] = [ ), requires=["db_session", "search_space_id", "user_id"], required_connector="GOOGLE_GMAIL_CONNECTOR", + dedup_key=wrap_dedup_key_by_arg_name("draft_subject_or_id"), ), # ========================================================================= # CONFLUENCE TOOLS - create, update, delete pages @@ -521,6 +550,7 @@ BUILTIN_TOOLS: list[ToolDefinition] = [ ), requires=["db_session", "search_space_id", "user_id"], required_connector="CONFLUENCE_CONNECTOR", + dedup_key=wrap_dedup_key_by_arg_name("title"), ), ToolDefinition( name="update_confluence_page", @@ -532,6 +562,7 @@ BUILTIN_TOOLS: list[ToolDefinition] = [ ), requires=["db_session", "search_space_id", "user_id"], required_connector="CONFLUENCE_CONNECTOR", + dedup_key=wrap_dedup_key_by_arg_name("page_title_or_id"), ), ToolDefinition( name="delete_confluence_page", @@ -543,6 +574,7 @@ BUILTIN_TOOLS: list[ToolDefinition] = [ ), requires=["db_session", "search_space_id", "user_id"], required_connector="CONFLUENCE_CONNECTOR", + dedup_key=wrap_dedup_key_by_arg_name("page_title_or_id"), ), # ========================================================================= # DISCORD TOOLS - list channels, read messages, send messages @@ -675,10 +707,7 @@ def get_connector_gated_tools( available_connectors: list[str] | None, ) -> list[str]: """Return tool names to disable""" - if available_connectors is None: - available = set() - else: - available = set(available_connectors) + available = set() if available_connectors is None else set(available_connectors) disabled: list[str] = [] for tool_def in BUILTIN_TOOLS: @@ -758,6 +787,24 @@ def build_tools( # Create the tool tool = tool_def.factory(dependencies) + # Propagate the registry-level metadata so middleware (e.g. + # ``DedupHITLToolCallsMiddleware``) and the action-log/revert + # pipeline can pick the resolvers up via ``tool.metadata`` without + # re-importing :data:`BUILTIN_TOOLS`. + if tool_def.dedup_key is not None or tool_def.reverse is not None: + existing_meta = getattr(tool, "metadata", None) or {} + merged_meta = dict(existing_meta) + if tool_def.dedup_key is not None: + merged_meta.setdefault("dedup_key", tool_def.dedup_key) + if tool_def.reverse is not None: + merged_meta.setdefault("reverse", tool_def.reverse) + try: + tool.metadata = merged_meta + except Exception: + logger.debug( + "Tool %s rejected metadata mutation; relying on registry lookup", + tool_def.name, + ) tools.append(tool) # Add any additional custom tools @@ -829,14 +876,16 @@ async def build_tools_async( tools.extend(mcp_tools) logging.info( "Registered %d MCP tools: %s", - len(mcp_tools), [t.name for t in mcp_tools], + len(mcp_tools), + [t.name for t in mcp_tools], ) except Exception as e: logging.exception("Failed to load MCP tools: %s", e) logging.info( "Total tools for agent: %d — %s", - len(tools), [t.name for t in tools], + len(tools), + [t.name for t in tools], ) return tools diff --git a/surfsense_backend/app/agents/new_chat/tools/teams/_auth.py b/surfsense_backend/app/agents/new_chat/tools/teams/_auth.py index f24f5502e..4345bb476 100644 --- a/surfsense_backend/app/agents/new_chat/tools/teams/_auth.py +++ b/surfsense_backend/app/agents/new_chat/tools/teams/_auth.py @@ -17,7 +17,8 @@ async def get_teams_connector( select(SearchSourceConnector).filter( SearchSourceConnector.search_space_id == search_space_id, SearchSourceConnector.user_id == user_id, - SearchSourceConnector.connector_type == SearchSourceConnectorType.TEAMS_CONNECTOR, + SearchSourceConnector.connector_type + == SearchSourceConnectorType.TEAMS_CONNECTOR, ) ) return result.scalars().first() diff --git a/surfsense_backend/app/agents/new_chat/tools/teams/list_channels.py b/surfsense_backend/app/agents/new_chat/tools/teams/list_channels.py index a676595c1..d7b000853 100644 --- a/surfsense_backend/app/agents/new_chat/tools/teams/list_channels.py +++ b/surfsense_backend/app/agents/new_chat/tools/teams/list_channels.py @@ -35,12 +35,21 @@ def create_list_teams_channels_tool( headers = {"Authorization": f"Bearer {token}"} async with httpx.AsyncClient(timeout=20.0) as client: - teams_resp = await client.get(f"{GRAPH_API}/me/joinedTeams", headers=headers) + teams_resp = await client.get( + f"{GRAPH_API}/me/joinedTeams", headers=headers + ) if teams_resp.status_code == 401: - return {"status": "auth_error", "message": "Teams token expired. Please re-authenticate.", "connector_type": "teams"} + return { + "status": "auth_error", + "message": "Teams token expired. Please re-authenticate.", + "connector_type": "teams", + } if teams_resp.status_code != 200: - return {"status": "error", "message": f"Graph API error: {teams_resp.status_code}"} + return { + "status": "error", + "message": f"Graph API error: {teams_resp.status_code}", + } teams_data = teams_resp.json().get("value", []) result_teams = [] @@ -58,13 +67,19 @@ def create_list_teams_channels_tool( {"id": ch["id"], "name": ch.get("displayName", "")} for ch in ch_resp.json().get("value", []) ] - result_teams.append({ - "team_id": team_id, - "team_name": team.get("displayName", ""), - "channels": channels, - }) + result_teams.append( + { + "team_id": team_id, + "team_name": team.get("displayName", ""), + "channels": channels, + } + ) - return {"status": "success", "teams": result_teams, "total_teams": len(result_teams)} + return { + "status": "success", + "teams": result_teams, + "total_teams": len(result_teams), + } except Exception as e: from langgraph.errors import GraphInterrupt diff --git a/surfsense_backend/app/agents/new_chat/tools/teams/read_messages.py b/surfsense_backend/app/agents/new_chat/tools/teams/read_messages.py index 90896cb95..d24a7e4d3 100644 --- a/surfsense_backend/app/agents/new_chat/tools/teams/read_messages.py +++ b/surfsense_backend/app/agents/new_chat/tools/teams/read_messages.py @@ -52,11 +52,21 @@ def create_read_teams_messages_tool( ) if resp.status_code == 401: - return {"status": "auth_error", "message": "Teams token expired. Please re-authenticate.", "connector_type": "teams"} + return { + "status": "auth_error", + "message": "Teams token expired. Please re-authenticate.", + "connector_type": "teams", + } if resp.status_code == 403: - return {"status": "error", "message": "Insufficient permissions to read this channel."} + return { + "status": "error", + "message": "Insufficient permissions to read this channel.", + } if resp.status_code != 200: - return {"status": "error", "message": f"Graph API error: {resp.status_code}"} + return { + "status": "error", + "message": f"Graph API error: {resp.status_code}", + } raw_msgs = resp.json().get("value", []) messages = [] @@ -64,13 +74,15 @@ def create_read_teams_messages_tool( sender = m.get("from", {}) user_info = sender.get("user", {}) if sender else {} body = m.get("body", {}) - messages.append({ - "id": m.get("id"), - "sender": user_info.get("displayName", "Unknown"), - "content": body.get("content", ""), - "content_type": body.get("contentType", "text"), - "timestamp": m.get("createdDateTime", ""), - }) + messages.append( + { + "id": m.get("id"), + "sender": user_info.get("displayName", "Unknown"), + "content": body.get("content", ""), + "content_type": body.get("contentType", "text"), + "timestamp": m.get("createdDateTime", ""), + } + ) return { "status": "success", diff --git a/surfsense_backend/app/agents/new_chat/tools/teams/send_message.py b/surfsense_backend/app/agents/new_chat/tools/teams/send_message.py index ba3a515d9..fd8d00870 100644 --- a/surfsense_backend/app/agents/new_chat/tools/teams/send_message.py +++ b/surfsense_backend/app/agents/new_chat/tools/teams/send_message.py @@ -50,12 +50,19 @@ def create_send_teams_message_tool( result = request_approval( action_type="teams_send_message", tool_name="send_teams_message", - params={"team_id": team_id, "channel_id": channel_id, "content": content}, + params={ + "team_id": team_id, + "channel_id": channel_id, + "content": content, + }, context={"connector_id": connector.id}, ) if result.rejected: - return {"status": "rejected", "message": "User declined. Message was not sent."} + return { + "status": "rejected", + "message": "User declined. Message was not sent.", + } final_content = result.params.get("content", content) final_team = result.params.get("team_id", team_id) @@ -74,20 +81,27 @@ def create_send_teams_message_tool( ) if resp.status_code == 401: - return {"status": "auth_error", "message": "Teams token expired. Please re-authenticate.", "connector_type": "teams"} + return { + "status": "auth_error", + "message": "Teams token expired. Please re-authenticate.", + "connector_type": "teams", + } if resp.status_code == 403: return { "status": "insufficient_permissions", "message": "Missing ChannelMessage.Send permission. Please re-authenticate with updated scopes.", } if resp.status_code not in (200, 201): - return {"status": "error", "message": f"Graph API error: {resp.status_code} — {resp.text[:200]}"} + return { + "status": "error", + "message": f"Graph API error: {resp.status_code} — {resp.text[:200]}", + } msg_data = resp.json() return { "status": "success", "message_id": msg_data.get("id"), - "message": f"Message sent to Teams channel.", + "message": "Message sent to Teams channel.", } except Exception as e: diff --git a/surfsense_backend/app/agents/new_chat/tools/tool_response.py b/surfsense_backend/app/agents/new_chat/tools/tool_response.py index 5fb1864b7..8644ada5c 100644 --- a/surfsense_backend/app/agents/new_chat/tools/tool_response.py +++ b/surfsense_backend/app/agents/new_chat/tools/tool_response.py @@ -6,7 +6,6 @@ from typing import Any class ToolResponse: - @staticmethod def success(message: str, **data: Any) -> dict[str, Any]: return {"status": "success", "message": message, **data} @@ -31,9 +30,7 @@ class ToolResponse: return {"status": "rejected", "message": message} @staticmethod - def not_found( - resource: str, identifier: str, **data: Any - ) -> dict[str, Any]: + def not_found(resource: str, identifier: str, **data: Any) -> dict[str, Any]: return { "status": "not_found", "error": f"{resource} '{identifier}' was not found.", diff --git a/surfsense_backend/app/agents/new_chat/tools/video_presentation.py b/surfsense_backend/app/agents/new_chat/tools/video_presentation.py index a90e08ac3..7bf9a1c3b 100644 --- a/surfsense_backend/app/agents/new_chat/tools/video_presentation.py +++ b/surfsense_backend/app/agents/new_chat/tools/video_presentation.py @@ -11,7 +11,7 @@ from typing import Any from langchain_core.tools import tool from sqlalchemy.ext.asyncio import AsyncSession -from app.db import VideoPresentation, VideoPresentationStatus +from app.db import VideoPresentation, VideoPresentationStatus, shielded_async_session def create_generate_video_presentation_tool( @@ -23,8 +23,11 @@ def create_generate_video_presentation_tool( Factory function to create the generate_video_presentation tool with injected dependencies. Pre-creates video presentation record with pending status so the ID is available - immediately for frontend polling. + immediately for frontend polling. The row is written via a fresh, tool-local + session so parallel tool calls (e.g. video + podcast in the same agent step) + don't share an ``AsyncSession`` (which is not concurrency-safe). """ + del db_session # writes use a fresh tool-local session, see below @tool async def generate_video_presentation( @@ -42,34 +45,40 @@ def create_generate_video_presentation_tool( user_prompt: Optional style/tone instructions. """ try: - video_pres = VideoPresentation( - title=video_title, - status=VideoPresentationStatus.PENDING, - search_space_id=search_space_id, - thread_id=thread_id, - ) - db_session.add(video_pres) - await db_session.commit() - await db_session.refresh(video_pres) + # See podcast.py for the rationale: parallel tool calls share the + # streaming session, and AsyncSession is not concurrency-safe — + # interleaved flushes produce "Session.add() during flush" and + # poison the transaction for every concurrent tool. + async with shielded_async_session() as session: + video_pres = VideoPresentation( + title=video_title, + status=VideoPresentationStatus.PENDING, + search_space_id=search_space_id, + thread_id=thread_id, + ) + session.add(video_pres) + await session.commit() + await session.refresh(video_pres) + video_pres_id = video_pres.id from app.tasks.celery_tasks.video_presentation_tasks import ( generate_video_presentation_task, ) task = generate_video_presentation_task.delay( - video_presentation_id=video_pres.id, + video_presentation_id=video_pres_id, source_content=source_content, search_space_id=search_space_id, user_prompt=user_prompt, ) print( - f"[generate_video_presentation] Created video presentation {video_pres.id}, task: {task.id}" + f"[generate_video_presentation] Created video presentation {video_pres_id}, task: {task.id}" ) return { "status": VideoPresentationStatus.PENDING.value, - "video_presentation_id": video_pres.id, + "video_presentation_id": video_pres_id, "title": video_title, "message": "Video presentation generation started. This may take a few minutes.", } diff --git a/surfsense_backend/app/connectors/exceptions.py b/surfsense_backend/app/connectors/exceptions.py index 32a1e7bdc..027adbb87 100644 --- a/surfsense_backend/app/connectors/exceptions.py +++ b/surfsense_backend/app/connectors/exceptions.py @@ -13,7 +13,6 @@ from typing import Any class ConnectorError(Exception): - def __init__( self, message: str, diff --git a/surfsense_backend/app/db.py b/surfsense_backend/app/db.py index e16590afc..75342a8e1 100644 --- a/surfsense_backend/app/db.py +++ b/surfsense_backend/app/db.py @@ -976,7 +976,15 @@ class Document(BaseModel, TimestampMixin): document_metadata = Column(JSON, nullable=True) content = Column(Text, nullable=False) - content_hash = Column(String, nullable=False, index=True, unique=True) + # ``content_hash`` is intentionally NOT globally unique. In a real + # filesystem two files at different paths can hold identical bytes, + # and the agent's ``write_file`` flow needs that semantic to support + # copy / duplicate operations. Path uniqueness lives on + # ``unique_identifier_hash`` (per search space). The hash remains + # indexed because connector indexers consult it as a change-detection + # / cross-source dedup hint via :func:`check_duplicate_document`. + # See migration 133. + content_hash = Column(String, nullable=False, index=True) unique_identifier_hash = Column(String, nullable=True, index=True, unique=True) embedding = Column(Vector(config.embedding_model_instance.dimension)) @@ -2250,6 +2258,202 @@ else: ) +class AgentActionLog(BaseModel): + """Append-only audit trail of every tool call dispatched by the agent. + + One row per ``ToolMessage`` produced; written by ``ActionLogMiddleware`` + in its ``aafter_tool`` hook. Rows are referenced by the + ``/api/threads/{thread_id}/revert/{action_id}`` route to look up an + action's stored ``reverse_descriptor`` and replay it. + + The table is intentionally narrow: large tool outputs are NOT stored + here. Result text lives in the langgraph checkpoint; this row only + keeps a short ``result_id`` (the LangChain ``ToolMessage.id`` or a + spilled-content path) for correlation. + """ + + __tablename__ = "agent_action_log" + + thread_id = Column( + Integer, + ForeignKey("new_chat_threads.id", ondelete="CASCADE"), + nullable=False, + index=True, + ) + user_id = Column( + UUID(as_uuid=True), + ForeignKey("user.id", ondelete="SET NULL"), + nullable=True, + index=True, + ) + search_space_id = Column( + Integer, + ForeignKey("searchspaces.id", ondelete="CASCADE"), + nullable=False, + index=True, + ) + turn_id = Column(String(64), nullable=True, index=True) + message_id = Column(String(128), nullable=True, index=True) + tool_name = Column(String(255), nullable=False, index=True) + args = Column(JSONB, nullable=True) + result_id = Column(String(255), nullable=True) + reversible = Column( + Boolean, nullable=False, default=False, server_default=text("false") + ) + reverse_descriptor = Column(JSONB, nullable=True) + error = Column(JSONB, nullable=True) + reverse_of = Column( + Integer, + ForeignKey("agent_action_log.id", ondelete="SET NULL"), + nullable=True, + index=True, + ) + created_at = Column( + TIMESTAMP(timezone=True), + nullable=False, + default=lambda: datetime.now(UTC), + server_default=text("(now() AT TIME ZONE 'utc')"), + index=True, + ) + + __table_args__ = ( + Index("ix_agent_action_log_thread_created", "thread_id", "created_at"), + ) + + +class DocumentRevision(BaseModel): + """Snapshot of a :class:`Document` row taken before a mutating tool call. + + Written by :class:`KnowledgeBasePersistenceMiddleware` (or its safety-net + `commit_staged_filesystem_state`) ahead of any NOTE / FILE / EXTENSION + document write. The row is referenced by ``/revert/{action_id}`` to + restore the original content in place. + """ + + __tablename__ = "document_revisions" + + document_id = Column( + Integer, + ForeignKey("documents.id", ondelete="CASCADE"), + nullable=False, + index=True, + ) + search_space_id = Column( + Integer, + ForeignKey("searchspaces.id", ondelete="CASCADE"), + nullable=False, + index=True, + ) + content_before = Column(Text, nullable=True) + title_before = Column(String, nullable=True) + folder_id_before = Column(Integer, nullable=True) + chunks_before = Column(JSONB, nullable=True) + metadata_before = Column("metadata_before", JSONB, nullable=True) + created_by_turn_id = Column(String(64), nullable=True, index=True) + agent_action_id = Column( + Integer, + ForeignKey("agent_action_log.id", ondelete="SET NULL"), + nullable=True, + index=True, + ) + created_at = Column( + TIMESTAMP(timezone=True), + nullable=False, + default=lambda: datetime.now(UTC), + server_default=text("(now() AT TIME ZONE 'utc')"), + index=True, + ) + + +class FolderRevision(BaseModel): + """Snapshot of a :class:`Folder` row taken before a mkdir / move.""" + + __tablename__ = "folder_revisions" + + folder_id = Column( + Integer, + ForeignKey("folders.id", ondelete="CASCADE"), + nullable=False, + index=True, + ) + search_space_id = Column( + Integer, + ForeignKey("searchspaces.id", ondelete="CASCADE"), + nullable=False, + index=True, + ) + name_before = Column(String(255), nullable=True) + parent_id_before = Column(Integer, nullable=True) + position_before = Column(String(50), nullable=True) + created_by_turn_id = Column(String(64), nullable=True, index=True) + agent_action_id = Column( + Integer, + ForeignKey("agent_action_log.id", ondelete="SET NULL"), + nullable=True, + index=True, + ) + created_at = Column( + TIMESTAMP(timezone=True), + nullable=False, + default=lambda: datetime.now(UTC), + server_default=text("(now() AT TIME ZONE 'utc')"), + index=True, + ) + + +class AgentPermissionRule(BaseModel): + """Persistent permission rule consumed by :class:`PermissionMiddleware`. + + Scoped at one of: search-space-wide (``user_id`` and ``thread_id`` NULL), + user-wide (``user_id`` set, ``thread_id`` NULL), or per-thread + (``thread_id`` set). Loaded at agent build time and converted to + :class:`Rule` instances inside the agent factory. + """ + + __tablename__ = "agent_permission_rules" + + search_space_id = Column( + Integer, + ForeignKey("searchspaces.id", ondelete="CASCADE"), + nullable=False, + index=True, + ) + user_id = Column( + UUID(as_uuid=True), + ForeignKey("user.id", ondelete="CASCADE"), + nullable=True, + index=True, + ) + thread_id = Column( + Integer, + ForeignKey("new_chat_threads.id", ondelete="CASCADE"), + nullable=True, + index=True, + ) + permission = Column(String(255), nullable=False) + pattern = Column(String(255), nullable=False, default="*", server_default="*") + action = Column(String(16), nullable=False) # allow / deny / ask + created_at = Column( + TIMESTAMP(timezone=True), + nullable=False, + default=lambda: datetime.now(UTC), + server_default=text("(now() AT TIME ZONE 'utc')"), + index=True, + ) + + __table_args__ = ( + UniqueConstraint( + "search_space_id", + "user_id", + "thread_id", + "permission", + "pattern", + "action", + name="uq_agent_permission_rules_scope", + ), + ) + + class RefreshToken(Base, TimestampMixin): """ Stores refresh tokens for user session management. diff --git a/surfsense_backend/app/observability/__init__.py b/surfsense_backend/app/observability/__init__.py new file mode 100644 index 000000000..dbf082561 --- /dev/null +++ b/surfsense_backend/app/observability/__init__.py @@ -0,0 +1,7 @@ +"""SurfSense observability surface. + +The single user-visible API right now is :mod:`otel`, which exposes a +small wrapper around the optional ``opentelemetry`` instrumentation. The +wrapper is a no-op when OTEL is not configured, so importing it from +performance-critical paths is safe. +""" diff --git a/surfsense_backend/app/observability/otel.py b/surfsense_backend/app/observability/otel.py new file mode 100644 index 000000000..6791ab499 --- /dev/null +++ b/surfsense_backend/app/observability/otel.py @@ -0,0 +1,314 @@ +""" +OpenTelemetry instrumentation helpers for the SurfSense agent stack. + +Goals +===== + +- Provide one tiny, ergonomic API for the spans we care about + (``tool.call``, ``model.call``, ``kb.search``, ``kb.persist``, + ``compaction.run``, ``interrupt.raised``, ``permission.asked``). +- Keep span **names** low-cardinality (``tool.call`` rather than + ``tool.call.``); tool name lives in the ``tool.name`` attribute + so dashboards aggregate cleanly. +- Default to **no-op** behavior unless ``OTEL_EXPORTER_OTLP_ENDPOINT`` is + set, OR an external SDK has installed a real ``TracerProvider`` already + (e.g. via the ``opentelemetry-instrument`` agent). +- Coexist with LangSmith: we never disable LangSmith tracing; we add OTel + alongside. +- Gracefully degrade if the ``opentelemetry-api`` package is missing. +""" + +from __future__ import annotations + +import contextlib +import logging +import os +from collections.abc import Iterator +from contextlib import contextmanager +from typing import Any + +logger = logging.getLogger(__name__) + +# ----------------------------------------------------------------------------- +# Lazy/optional OpenTelemetry import +# ----------------------------------------------------------------------------- + +try: + from opentelemetry import trace as _ot_trace + from opentelemetry.trace import ( + Span as _OtSpan, + Status as _OtStatus, + StatusCode as _OtStatusCode, + ) + + _OTEL_AVAILABLE = True +except ImportError: # pragma: no cover — optional dep + _ot_trace = None # type: ignore[assignment] + _OtSpan = Any # type: ignore[assignment, misc] + _OtStatus = Any # type: ignore[assignment, misc] + _OtStatusCode = Any # type: ignore[assignment, misc] + _OTEL_AVAILABLE = False + + +_INSTRUMENTATION_NAME = "surfsense.new_chat" +_INSTRUMENTATION_VERSION = "0.1.0" + + +# ----------------------------------------------------------------------------- +# Configuration +# ----------------------------------------------------------------------------- + + +def _resolve_enabled() -> bool: + """Return True if OTel spans should actually be emitted.""" + if not _OTEL_AVAILABLE: + return False + # Honor an explicit kill-switch first. + if os.environ.get("SURFSENSE_DISABLE_OTEL", "").lower() in {"1", "true", "yes"}: + return False + # Treat a configured endpoint as the canonical "OTel is wired up" signal. + if os.environ.get("OTEL_EXPORTER_OTLP_ENDPOINT"): + return True + # Or honor an external SDK that already installed a non-default TracerProvider. + if _ot_trace is not None: + try: + provider = _ot_trace.get_tracer_provider() + # The default proxy provider has no real exporter wired up. + type_name = type(provider).__name__ + if type_name not in {"ProxyTracerProvider", "NoOpTracerProvider"}: + return True + except Exception: # pragma: no cover — defensive + return False + return False + + +_ENABLED: bool = _resolve_enabled() + + +def is_enabled() -> bool: + """Return True if instrumentation is actively emitting spans.""" + return _ENABLED + + +def _get_tracer(): + if not _OTEL_AVAILABLE: + return None + try: + return _ot_trace.get_tracer(_INSTRUMENTATION_NAME, _INSTRUMENTATION_VERSION) + except Exception: # pragma: no cover — defensive + return None + + +# ----------------------------------------------------------------------------- +# No-op span used when OTel is disabled (avoids a None check at every call site) +# ----------------------------------------------------------------------------- + + +class _NoopSpan: + """A lightweight stand-in that mimics the subset of ``Span`` we use.""" + + def set_attribute(self, key: str, value: Any) -> None: + return None + + def set_attributes(self, attributes: dict[str, Any]) -> None: + return None + + def add_event(self, name: str, attributes: dict[str, Any] | None = None) -> None: + return None + + def record_exception(self, exception: BaseException) -> None: + return None + + def set_status(self, status: Any) -> None: + return None + + +# ----------------------------------------------------------------------------- +# Public span helpers +# ----------------------------------------------------------------------------- + + +@contextmanager +def span( + name: str, + *, + attributes: dict[str, Any] | None = None, +) -> Iterator[Any]: + """Generic span context manager. + + Yields the underlying span (or a :class:`_NoopSpan` when disabled) + so callers can attach attributes/events incrementally. + + On exception, the span records the error via :meth:`record_exception` + and sets ``StatusCode.ERROR``; the exception is then re-raised. + """ + if not _ENABLED: + yield _NoopSpan() + return + + tracer = _get_tracer() + if tracer is None: # pragma: no cover — defensive + yield _NoopSpan() + return + + with tracer.start_as_current_span(name) as sp: + if attributes: + with contextlib.suppress(Exception): # pragma: no cover — defensive + sp.set_attributes(attributes) + try: + yield sp + except BaseException as exc: + with contextlib.suppress(Exception): # pragma: no cover — defensive + sp.record_exception(exc) + sp.set_status(_OtStatus(_OtStatusCode.ERROR, str(exc))) + raise + + +# ----------------------------------------------------------------------------- +# Domain-specific shortcuts (mirror the plan's enumerated span list) +# ----------------------------------------------------------------------------- + + +def tool_call_span( + tool_name: str, + *, + input_size: int | None = None, + extra: dict[str, Any] | None = None, +): + """Span for an individual tool execution. + + Span name is the constant ``tool.call`` (low-cardinality); the tool + identifier lives in the ``tool.name`` attribute. + """ + attrs: dict[str, Any] = {"tool.name": tool_name} + if input_size is not None: + attrs["tool.input.size"] = int(input_size) + if extra: + attrs.update(extra) + return span("tool.call", attributes=attrs) + + +def model_call_span( + *, + model_id: str | None = None, + provider: str | None = None, + extra: dict[str, Any] | None = None, +): + """Span around a single ``astream`` / ``ainvoke`` call to the LLM.""" + attrs: dict[str, Any] = {} + if model_id: + attrs["model.id"] = model_id + if provider: + attrs["model.provider"] = provider + if extra: + attrs.update(extra) + return span("model.call", attributes=attrs) + + +def kb_search_span( + *, + search_space_id: int | None = None, + query_chars: int | None = None, + extra: dict[str, Any] | None = None, +): + """Span around knowledge-base search routines.""" + attrs: dict[str, Any] = {} + if search_space_id is not None: + attrs["search_space.id"] = int(search_space_id) + if query_chars is not None: + attrs["query.chars"] = int(query_chars) + if extra: + attrs.update(extra) + return span("kb.search", attributes=attrs) + + +def kb_persist_span( + *, + document_type: str | None = None, + document_id: int | None = None, + extra: dict[str, Any] | None = None, +): + """Span around knowledge-base persistence operations (NOTE/EXTENSION/FILE).""" + attrs: dict[str, Any] = {} + if document_type: + attrs["document.type"] = document_type + if document_id is not None: + attrs["document.id"] = int(document_id) + if extra: + attrs.update(extra) + return span("kb.persist", attributes=attrs) + + +def compaction_span( + *, + reason: str | None = None, + messages_in: int | None = None, + extra: dict[str, Any] | None = None, +): + """Span around the compaction (summarization) middleware run.""" + attrs: dict[str, Any] = {} + if reason: + attrs["compaction.reason"] = reason + if messages_in is not None: + attrs["compaction.messages.in"] = int(messages_in) + if extra: + attrs.update(extra) + return span("compaction.run", attributes=attrs) + + +def interrupt_span( + *, + interrupt_type: str, + extra: dict[str, Any] | None = None, +): + """Span recording an interrupt being raised (HITL or permission_ask).""" + attrs: dict[str, Any] = {"interrupt.type": interrupt_type} + if extra: + attrs.update(extra) + return span("interrupt.raised", attributes=attrs) + + +def permission_asked_span( + *, + permission: str, + pattern: str | None = None, + extra: dict[str, Any] | None = None, +): + """Span recording a permission ask (PermissionMiddleware).""" + attrs: dict[str, Any] = {"permission.permission": permission} + if pattern: + attrs["permission.pattern"] = pattern + if extra: + attrs.update(extra) + return span("permission.asked", attributes=attrs) + + +# ----------------------------------------------------------------------------- +# Test/utility hooks +# ----------------------------------------------------------------------------- + + +def reload_for_tests() -> bool: + """Re-evaluate :data:`_ENABLED` from the current environment. + + Tests that toggle ``OTEL_EXPORTER_OTLP_ENDPOINT`` or + ``SURFSENSE_DISABLE_OTEL`` can call this to reset cached state. + Returns the new value of :func:`is_enabled`. + """ + global _ENABLED + _ENABLED = _resolve_enabled() + return _ENABLED + + +__all__ = [ + "compaction_span", + "interrupt_span", + "is_enabled", + "kb_persist_span", + "kb_search_span", + "model_call_span", + "permission_asked_span", + "reload_for_tests", + "span", + "tool_call_span", +] diff --git a/surfsense_backend/app/routes/__init__.py b/surfsense_backend/app/routes/__init__.py index fafd4d356..5b6a74376 100644 --- a/surfsense_backend/app/routes/__init__.py +++ b/surfsense_backend/app/routes/__init__.py @@ -1,5 +1,9 @@ from fastapi import APIRouter +from .agent_action_log_route import router as agent_action_log_router +from .agent_flags_route import router as agent_flags_router +from .agent_permissions_route import router as agent_permissions_router +from .agent_revert_route import router as agent_revert_router from .airtable_add_connector_route import ( router as airtable_add_connector_router, ) @@ -65,6 +69,12 @@ router.include_router(documents_router) router.include_router(folders_router) router.include_router(notes_router) router.include_router(new_chat_router) # Chat with assistant-ui persistence +router.include_router(agent_revert_router) # POST /threads/{id}/revert/{action_id} +router.include_router(agent_action_log_router) # GET /threads/{id}/actions +router.include_router( + agent_permissions_router +) # CRUD for /searchspaces/{id}/agent/permissions/rules +router.include_router(agent_flags_router) # GET /agent/flags router.include_router(sandbox_router) # Sandbox file downloads (Daytona) router.include_router(chat_comments_router) router.include_router(podcasts_router) # Podcast task status and audio @@ -97,7 +107,9 @@ router.include_router(logs_router) router.include_router(circleback_webhook_router) # Circleback meeting webhooks router.include_router(surfsense_docs_router) # Surfsense documentation for citations router.include_router(notifications_router) # Notifications with Zero sync -router.include_router(mcp_oauth_router) # MCP OAuth 2.1 for Linear, Jira, ClickUp, Slack, Airtable +router.include_router( + mcp_oauth_router +) # MCP OAuth 2.1 for Linear, Jira, ClickUp, Slack, Airtable router.include_router(composio_router) # Composio OAuth and toolkit management router.include_router(public_chat_router) # Public chat sharing and cloning router.include_router(incentive_tasks_router) # Incentive tasks for earning free pages diff --git a/surfsense_backend/app/routes/agent_action_log_route.py b/surfsense_backend/app/routes/agent_action_log_route.py new file mode 100644 index 000000000..458635761 --- /dev/null +++ b/surfsense_backend/app/routes/agent_action_log_route.py @@ -0,0 +1,186 @@ +"""``GET /api/threads/{thread_id}/actions``: list agent action-log entries. + +Pairs with ``POST /api/threads/{thread_id}/revert/{action_id}`` (see +``agent_revert_route.py``). The action log is the read-side surface for +the audit/undo UI: it returns a paginated list of every tool call +recorded by :class:`ActionLogMiddleware` against the thread, plus +metadata about whether the action is reversible and whether it has +already been reverted. + +The route is gated by the same ``SURFSENSE_ENABLE_ACTION_LOG`` flag that +controls the middleware. When the flag is off the endpoint returns 503 +so the UI can detect "this deployment doesn't have the action log +enabled" without 404-ing on a missing route. + +The list is ordered DESC by ``created_at`` (newest first) so the +revert UI can render a familiar reverse-chronological feed without an +additional client-side sort. +""" + +from __future__ import annotations + +import logging +from datetime import datetime +from typing import Any + +from fastapi import APIRouter, Depends, HTTPException, Query +from pydantic import BaseModel +from sqlalchemy import func, select +from sqlalchemy.ext.asyncio import AsyncSession + +from app.agents.new_chat.feature_flags import get_flags +from app.db import ( + AgentActionLog, + NewChatThread, + Permission, + User, + get_async_session, +) +from app.users import current_active_user +from app.utils.rbac import check_permission + +logger = logging.getLogger(__name__) + +router = APIRouter() + + +# --------------------------------------------------------------------------- +# Response schemas +# --------------------------------------------------------------------------- + + +class AgentActionRead(BaseModel): + """One row of the action log surfaced to the client.""" + + id: int + thread_id: int + user_id: str | None + search_space_id: int + tool_name: str + args: dict[str, Any] | None + result_id: str | None + reversible: bool + reverse_descriptor: dict[str, Any] | None + error: dict[str, Any] | None + reverse_of: int | None + reverted_by_action_id: int | None + is_revert_action: bool + created_at: datetime + + +class AgentActionListResponse(BaseModel): + """Paginated list response for the action log.""" + + items: list[AgentActionRead] + total: int + page: int + page_size: int + has_more: bool + + +# --------------------------------------------------------------------------- +# Routes +# --------------------------------------------------------------------------- + + +def _flag_guard() -> None: + flags = get_flags() + if flags.disable_new_agent_stack or not flags.enable_action_log: + raise HTTPException( + status_code=503, + detail=( + "Action log is not available on this deployment. Flip " + "SURFSENSE_ENABLE_ACTION_LOG to enable it." + ), + ) + + +@router.get( + "/threads/{thread_id}/actions", + response_model=AgentActionListResponse, +) +async def list_thread_actions( + thread_id: int, + page: int = Query(0, ge=0), + page_size: int = Query(50, ge=1, le=200), + session: AsyncSession = Depends(get_async_session), + user: User = Depends(current_active_user), +) -> AgentActionListResponse: + """List agent actions for a thread, newest first. + + Authorization: + * Caller must be a member of the thread's search space with + ``CHATS_READ`` permission. + + Pagination: + * ``page`` is 0-indexed. + * ``page_size`` defaults to 50, max 200. + """ + + _flag_guard() + + thread = await session.get(NewChatThread, thread_id) + if thread is None: + raise HTTPException(status_code=404, detail="Thread not found.") + + await check_permission( + session, + user, + thread.search_space_id, + Permission.CHATS_READ.value, + "You don't have permission to view this thread's action log.", + ) + + total_stmt = select(func.count(AgentActionLog.id)).where( + AgentActionLog.thread_id == thread_id + ) + total = (await session.execute(total_stmt)).scalar_one() + + rows_stmt = ( + select(AgentActionLog) + .where(AgentActionLog.thread_id == thread_id) + .order_by(AgentActionLog.created_at.desc(), AgentActionLog.id.desc()) + .offset(page * page_size) + .limit(page_size) + ) + rows = (await session.execute(rows_stmt)).scalars().all() + + # Build a reverse_of -> revert_action_id map so the UI can render + # "Reverted" badges on actions that have already been undone. + if rows: + original_ids = [r.id for r in rows] + reverts_stmt = select(AgentActionLog.id, AgentActionLog.reverse_of).where( + AgentActionLog.reverse_of.in_(original_ids) + ) + reverts = (await session.execute(reverts_stmt)).all() + revert_map: dict[int, int] = {orig: rev for rev, orig in reverts} + else: + revert_map = {} + + items = [ + AgentActionRead( + id=row.id, + thread_id=row.thread_id, + user_id=str(row.user_id) if row.user_id is not None else None, + search_space_id=row.search_space_id, + tool_name=row.tool_name, + args=row.args, + result_id=row.result_id, + reversible=bool(row.reversible), + reverse_descriptor=row.reverse_descriptor, + error=row.error, + reverse_of=row.reverse_of, + reverted_by_action_id=revert_map.get(row.id), + is_revert_action=row.reverse_of is not None, + created_at=row.created_at, + ) + for row in rows + ] + + return AgentActionListResponse( + items=items, + total=int(total), + page=page, + page_size=page_size, + has_more=(page + 1) * page_size < int(total), + ) diff --git a/surfsense_backend/app/routes/agent_flags_route.py b/surfsense_backend/app/routes/agent_flags_route.py new file mode 100644 index 000000000..5732a8dfb --- /dev/null +++ b/surfsense_backend/app/routes/agent_flags_route.py @@ -0,0 +1,71 @@ +"""``GET /api/agent/flags``: read-only feature-flag status. + +Surfaces :class:`AgentFeatureFlags` to the frontend so the UI can: + +* Render conditional surfaces (e.g. show the action-log button only when + ``enable_action_log`` is on). +* Display an admin diagnostics card so operators can verify which + middleware tier is active without shelling into the box. + +The endpoint is *read-only*. Flipping flags requires an env-var change +plus a process restart — by design, since the values are baked into the +agent factory at build time. The route does not require any special +permission (any authenticated user can see them) since the flag values +do not leak data, and the UI surfaces are conditionally rendered based +on them anyway. +""" + +from __future__ import annotations + +from dataclasses import asdict + +from fastapi import APIRouter, Depends +from pydantic import BaseModel + +from app.agents.new_chat.feature_flags import AgentFeatureFlags, get_flags +from app.db import User +from app.users import current_active_user + +router = APIRouter() + + +class AgentFeatureFlagsRead(BaseModel): + """Mirror of :class:`AgentFeatureFlags`. Updated together with it.""" + + disable_new_agent_stack: bool + + enable_context_editing: bool + enable_compaction_v2: bool + enable_retry_after: bool + enable_model_fallback: bool + enable_model_call_limit: bool + enable_tool_call_limit: bool + enable_tool_call_repair: bool + enable_doom_loop: bool + + enable_permission: bool + enable_busy_mutex: bool + enable_llm_tool_selector: bool + + enable_skills: bool + enable_specialized_subagents: bool + enable_kb_planner_runnable: bool + + enable_action_log: bool + enable_revert_route: bool + + enable_plugin_loader: bool + + enable_otel: bool + + @classmethod + def from_flags(cls, flags: AgentFeatureFlags) -> AgentFeatureFlagsRead: + # asdict() avoids missing-field bugs when AgentFeatureFlags grows. + return cls(**asdict(flags)) + + +@router.get("/agent/flags", response_model=AgentFeatureFlagsRead) +async def get_agent_flags( + _user: User = Depends(current_active_user), +) -> AgentFeatureFlagsRead: + return AgentFeatureFlagsRead.from_flags(get_flags()) diff --git a/surfsense_backend/app/routes/agent_permissions_route.py b/surfsense_backend/app/routes/agent_permissions_route.py new file mode 100644 index 000000000..1c76e00e6 --- /dev/null +++ b/surfsense_backend/app/routes/agent_permissions_route.py @@ -0,0 +1,280 @@ +"""CRUD for :class:`app.db.AgentPermissionRule`. + +Surfaces the permission rules consumed by +:class:`PermissionMiddleware`. Rules are scoped at one of three levels: + +* **Search-space wide** — both ``user_id`` and ``thread_id`` are NULL. +* **Per-user** — ``user_id`` set, ``thread_id`` NULL. +* **Per-thread** — ``thread_id`` set (``user_id`` typically NULL). + +The middleware reads these rows at agent build time (see +``chat_deepagent.py``). UI lets a search-space owner curate them so +the agent can ask for approval / auto-deny / auto-allow specific +tool patterns. + +The route group is gated by ``SURFSENSE_ENABLE_PERMISSION``: when off +all endpoints return 503 so the UI can render a "feature not enabled" +empty state without breaking on a missing route. +""" + +from __future__ import annotations + +import logging +import re +from datetime import datetime +from typing import Literal + +from fastapi import APIRouter, Depends, HTTPException +from pydantic import BaseModel, Field +from sqlalchemy import select +from sqlalchemy.exc import IntegrityError +from sqlalchemy.ext.asyncio import AsyncSession + +from app.agents.new_chat.feature_flags import get_flags +from app.db import ( + AgentPermissionRule, + NewChatThread, + Permission, + SearchSpace, + User, + get_async_session, +) +from app.users import current_active_user +from app.utils.rbac import check_permission + +logger = logging.getLogger(__name__) + +router = APIRouter() + + +# --------------------------------------------------------------------------- +# Schemas +# --------------------------------------------------------------------------- + + +_ACTION_VALUES: tuple[str, ...] = ("allow", "deny", "ask") +_PERMISSION_PATTERN = re.compile(r"^[a-zA-Z0-9_:.\-*]+$") + + +class AgentPermissionRuleRead(BaseModel): + id: int + search_space_id: int + user_id: str | None + thread_id: int | None + permission: str + pattern: str + action: Literal["allow", "deny", "ask"] + created_at: datetime + + +class AgentPermissionRuleCreate(BaseModel): + permission: str = Field( + ..., + min_length=1, + max_length=255, + description="Tool / capability the rule targets, e.g. 'tool:create_linear_issue'.", + ) + pattern: str = Field( + "*", + min_length=1, + max_length=255, + description="Wildcard pattern (e.g. '*' or 'production-*') applied to the matched tool argument.", + ) + action: Literal["allow", "deny", "ask"] + user_id: str | None = None + thread_id: int | None = None + + +class AgentPermissionRuleUpdate(BaseModel): + pattern: str | None = Field(default=None, min_length=1, max_length=255) + action: Literal["allow", "deny", "ask"] | None = None + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _flag_guard() -> None: + flags = get_flags() + if flags.disable_new_agent_stack or not flags.enable_permission: + raise HTTPException( + status_code=503, + detail=( + "Agent permission rules are not enabled on this deployment. " + "Flip SURFSENSE_ENABLE_PERMISSION to enable them." + ), + ) + + +def _validate_permission_string(value: str) -> str: + if not _PERMISSION_PATTERN.match(value): + raise HTTPException( + status_code=400, + detail=( + "permission must contain only alphanumerics, '.', '_', ':', '-', " + "or '*' wildcards." + ), + ) + return value + + +def _to_read(row: AgentPermissionRule) -> AgentPermissionRuleRead: + return AgentPermissionRuleRead( + id=row.id, + search_space_id=row.search_space_id, + user_id=str(row.user_id) if row.user_id is not None else None, + thread_id=row.thread_id, + permission=row.permission, + pattern=row.pattern, + action=row.action, # type: ignore[arg-type] + created_at=row.created_at, + ) + + +async def _ensure_search_space_membership_admin( + session: AsyncSession, user: User, search_space_id: int +) -> None: + """Curating agent rules == "settings" administration on the space.""" + space = await session.get(SearchSpace, search_space_id) + if space is None: + raise HTTPException(status_code=404, detail="Search space not found.") + await check_permission( + session, + user, + search_space_id, + Permission.SETTINGS_UPDATE.value, + "You don't have permission to manage agent permission rules in this space.", + ) + + +# --------------------------------------------------------------------------- +# Routes +# --------------------------------------------------------------------------- + + +@router.get( + "/searchspaces/{search_space_id}/agent/permissions/rules", + response_model=list[AgentPermissionRuleRead], +) +async def list_rules( + search_space_id: int, + session: AsyncSession = Depends(get_async_session), + user: User = Depends(current_active_user), +) -> list[AgentPermissionRuleRead]: + _flag_guard() + await _ensure_search_space_membership_admin(session, user, search_space_id) + + stmt = ( + select(AgentPermissionRule) + .where(AgentPermissionRule.search_space_id == search_space_id) + .order_by(AgentPermissionRule.created_at.desc(), AgentPermissionRule.id.desc()) + ) + rows = (await session.execute(stmt)).scalars().all() + return [_to_read(r) for r in rows] + + +@router.post( + "/searchspaces/{search_space_id}/agent/permissions/rules", + response_model=AgentPermissionRuleRead, + status_code=201, +) +async def create_rule( + search_space_id: int, + payload: AgentPermissionRuleCreate, + session: AsyncSession = Depends(get_async_session), + user: User = Depends(current_active_user), +) -> AgentPermissionRuleRead: + _flag_guard() + await _ensure_search_space_membership_admin(session, user, search_space_id) + + permission = _validate_permission_string(payload.permission.strip()) + pattern = payload.pattern.strip() or "*" + + if payload.thread_id is not None: + thread = await session.get(NewChatThread, payload.thread_id) + if thread is None or thread.search_space_id != search_space_id: + raise HTTPException( + status_code=404, + detail="Thread not found in this search space.", + ) + + row = AgentPermissionRule( + search_space_id=search_space_id, + user_id=payload.user_id, + thread_id=payload.thread_id, + permission=permission, + pattern=pattern, + action=payload.action, + ) + session.add(row) + try: + await session.commit() + except IntegrityError as err: + await session.rollback() + raise HTTPException( + status_code=409, + detail=( + "An identical rule already exists for this scope. Update the " + "existing rule instead." + ), + ) from err + await session.refresh(row) + return _to_read(row) + + +@router.patch( + "/searchspaces/{search_space_id}/agent/permissions/rules/{rule_id}", + response_model=AgentPermissionRuleRead, +) +async def update_rule( + search_space_id: int, + rule_id: int, + payload: AgentPermissionRuleUpdate, + session: AsyncSession = Depends(get_async_session), + user: User = Depends(current_active_user), +) -> AgentPermissionRuleRead: + _flag_guard() + await _ensure_search_space_membership_admin(session, user, search_space_id) + + row = await session.get(AgentPermissionRule, rule_id) + if row is None or row.search_space_id != search_space_id: + raise HTTPException(status_code=404, detail="Rule not found.") + + if payload.pattern is not None: + row.pattern = payload.pattern.strip() or "*" + if payload.action is not None: + row.action = payload.action + + try: + await session.commit() + except IntegrityError as err: + await session.rollback() + raise HTTPException( + status_code=409, + detail="Update would create a duplicate rule for this scope.", + ) from err + await session.refresh(row) + return _to_read(row) + + +@router.delete( + "/searchspaces/{search_space_id}/agent/permissions/rules/{rule_id}", + status_code=204, +) +async def delete_rule( + search_space_id: int, + rule_id: int, + session: AsyncSession = Depends(get_async_session), + user: User = Depends(current_active_user), +) -> None: + _flag_guard() + await _ensure_search_space_membership_admin(session, user, search_space_id) + + row = await session.get(AgentPermissionRule, rule_id) + if row is None or row.search_space_id != search_space_id: + raise HTTPException(status_code=404, detail="Rule not found.") + + await session.delete(row) + await session.commit() + return None diff --git a/surfsense_backend/app/routes/agent_revert_route.py b/surfsense_backend/app/routes/agent_revert_route.py new file mode 100644 index 000000000..12484ff53 --- /dev/null +++ b/surfsense_backend/app/routes/agent_revert_route.py @@ -0,0 +1,124 @@ +"""POST ``/api/threads/{thread_id}/revert/{action_id}``: undo an agent action. + +The route ships **before** the UI lights up the per-message "Undo from +here" affordance. To prevent accidental usage during the gap we return +``503 Service Unavailable`` until the ``SURFSENSE_ENABLE_REVERT_ROUTE`` +flag flips. Once enabled, the route runs: + +1. Authentication via :func:`current_active_user`. +2. Action lookup; 404 if the action does not belong to the thread. +3. Authorization via :func:`app.services.revert_service.can_revert`. +4. Revert dispatch via :func:`app.services.revert_service.revert_action`. +5. Idempotent on retries: if the same action is reverted twice the second + call returns 409 ``"already reverted"``. +""" + +from __future__ import annotations + +import logging + +from fastapi import APIRouter, Depends, HTTPException +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession + +from app.agents.new_chat.feature_flags import get_flags +from app.db import ( + AgentActionLog, + User, + get_async_session, +) +from app.services.revert_service import ( + RevertOutcome, + can_revert, + load_action, + load_thread, + revert_action, +) +from app.users import current_active_user + +logger = logging.getLogger(__name__) + +router = APIRouter() + + +@router.post("/threads/{thread_id}/revert/{action_id}") +async def revert_agent_action( + thread_id: int, + action_id: int, + session: AsyncSession = Depends(get_async_session), + user: User = Depends(current_active_user), +) -> dict: + flags = get_flags() + if flags.disable_new_agent_stack or not flags.enable_revert_route: + raise HTTPException( + status_code=503, + detail=( + "Revert is not available on this deployment yet. The route " + "ships before the UI; flip SURFSENSE_ENABLE_REVERT_ROUTE to " + "enable it." + ), + ) + + thread = await load_thread(session, thread_id=thread_id) + if thread is None: + raise HTTPException(status_code=404, detail="Thread not found.") + + action = await load_action(session, action_id=action_id, thread_id=thread_id) + if action is None: + raise HTTPException( + status_code=404, + detail="Action not found or does not belong to this thread.", + ) + + # Idempotency: if a successful revert already exists, return 409. + existing_revert = await session.execute( + select(AgentActionLog).where(AgentActionLog.reverse_of == action.id) + ) + if existing_revert.scalars().first() is not None: + raise HTTPException( + status_code=409, + detail="This action has already been reverted.", + ) + + if not can_revert( + requester_user_id=str(user.id) if user is not None else None, + action=action, + is_admin=False, # role lookup is done by RBAC layer; default conservative + ): + raise HTTPException( + status_code=403, + detail="You are not allowed to revert this action.", + ) + + outcome: RevertOutcome + try: + outcome = await revert_action( + session, + action=action, + requester_user_id=str(user.id) if user is not None else None, + ) + except Exception as err: + logger.exception("Revert dispatch raised for action_id=%s", action_id) + await session.rollback() + raise HTTPException( + status_code=500, detail="Internal error during revert." + ) from err + + if outcome.status == "ok": + await session.commit() + return { + "status": "ok", + "message": outcome.message, + "new_action_id": outcome.new_action_id, + } + + await session.rollback() + + if outcome.status == "not_found" or outcome.status == "tool_unavailable": + raise HTTPException(status_code=409, detail=outcome.message) + if outcome.status == "permission_denied": + raise HTTPException(status_code=403, detail=outcome.message) + if outcome.status == "reverse_not_implemented": + raise HTTPException(status_code=501, detail=outcome.message) + # not_reversible + raise HTTPException(status_code=409, detail=outcome.message) diff --git a/surfsense_backend/app/routes/mcp_oauth_route.py b/surfsense_backend/app/routes/mcp_oauth_route.py index e14be83d0..1abc1f1ec 100644 --- a/surfsense_backend/app/routes/mcp_oauth_route.py +++ b/surfsense_backend/app/routes/mcp_oauth_route.py @@ -29,7 +29,11 @@ from app.db import ( ) from app.users import current_active_user from app.utils.connector_naming import generate_unique_connector_name -from app.utils.oauth_security import OAuthStateManager, TokenEncryption, generate_pkce_pair +from app.utils.oauth_security import ( + OAuthStateManager, + TokenEncryption, + generate_pkce_pair, +) logger = logging.getLogger(__name__) @@ -37,7 +41,9 @@ router = APIRouter() async def _fetch_account_metadata( - service_key: str, access_token: str, token_json: dict[str, Any], + service_key: str, + access_token: str, + token_json: dict[str, Any], ) -> dict[str, Any]: """Fetch display-friendly account metadata after a successful token exchange. @@ -86,7 +92,8 @@ async def _fetch_account_metadata( meta["display_name"] = whoami.get("email", "Airtable") else: logger.warning( - "Airtable whoami API returned %d (non-blocking)", resp.status_code, + "Airtable whoami API returned %d (non-blocking)", + resp.status_code, ) except Exception: @@ -98,6 +105,7 @@ async def _fetch_account_metadata( return meta + _state_manager: OAuthStateManager | None = None _token_encryption: TokenEncryption | None = None @@ -151,6 +159,7 @@ def _frontend_redirect( # /add — start MCP OAuth flow # --------------------------------------------------------------------------- + @router.get("/auth/mcp/{service}/connector/add") async def connect_mcp_service( service: str, @@ -170,9 +179,12 @@ async def connect_mcp_service( ) metadata = await discover_oauth_metadata( - svc.mcp_url, origin_override=svc.oauth_discovery_origin, + svc.mcp_url, + origin_override=svc.oauth_discovery_origin, + ) + auth_endpoint = svc.auth_endpoint_override or metadata.get( + "authorization_endpoint" ) - auth_endpoint = svc.auth_endpoint_override or metadata.get("authorization_endpoint") token_endpoint = svc.token_endpoint_override or metadata.get("token_endpoint") registration_endpoint = metadata.get("registration_endpoint") @@ -236,7 +248,9 @@ async def connect_mcp_service( logger.info( "Generated %s MCP OAuth URL for user %s, space %s", - svc.name, user.id, space_id, + svc.name, + user.id, + space_id, ) return {"auth_url": auth_url} @@ -245,7 +259,8 @@ async def connect_mcp_service( except Exception as e: logger.error("Failed to initiate %s MCP OAuth: %s", service, e, exc_info=True) raise HTTPException( - status_code=500, detail=f"Failed to initiate {service} MCP OAuth.", + status_code=500, + detail=f"Failed to initiate {service} MCP OAuth.", ) from e @@ -253,6 +268,7 @@ async def connect_mcp_service( # /callback — handle OAuth redirect # --------------------------------------------------------------------------- + @router.get("/auth/mcp/{service}/connector/callback") async def mcp_oauth_callback( service: str, @@ -271,7 +287,9 @@ async def mcp_oauth_callback( except Exception: pass return _frontend_redirect( - space_id, error=f"{service}_mcp_oauth_denied", service=service, + space_id, + error=f"{service}_mcp_oauth_denied", + service=service, ) if not code: @@ -337,9 +355,7 @@ async def mcp_oauth_callback( expires_at = None if expires_in: - expires_at = datetime.now(UTC) + timedelta( - seconds=int(expires_in) - ) + expires_at = datetime.now(UTC) + timedelta(seconds=int(expires_in)) connector_config = { "server_config": { @@ -349,10 +365,14 @@ async def mcp_oauth_callback( "mcp_service": svc_key, "mcp_oauth": { "client_id": client_id, - "client_secret": enc.encrypt_token(client_secret) if client_secret else "", + "client_secret": enc.encrypt_token(client_secret) + if client_secret + else "", "token_endpoint": token_endpoint, "access_token": enc.encrypt_token(access_token), - "refresh_token": enc.encrypt_token(refresh_token) if refresh_token else None, + "refresh_token": enc.encrypt_token(refresh_token) + if refresh_token + else None, "expires_at": expires_at.isoformat() if expires_at else None, "scope": scope, }, @@ -361,15 +381,27 @@ async def mcp_oauth_callback( account_meta = await _fetch_account_metadata(svc_key, access_token, token_json) if account_meta: - _SAFE_META_KEYS = {"display_name", "team_id", "team_name", "user_id", "user_email", - "workspace_id", "workspace_name", "organization_name", - "organization_url_key", "cloud_id", "site_name", "base_url"} + safe_meta_keys = { + "display_name", + "team_id", + "team_name", + "user_id", + "user_email", + "workspace_id", + "workspace_name", + "organization_name", + "organization_url_key", + "cloud_id", + "site_name", + "base_url", + } for k, v in account_meta.items(): - if k in _SAFE_META_KEYS: + if k in safe_meta_keys: connector_config[k] = v logger.info( "Stored account metadata for %s: display_name=%s", - svc_key, account_meta.get("display_name", ""), + svc_key, + account_meta.get("display_name", ""), ) # ---- Re-auth path ---- @@ -400,15 +432,24 @@ async def mcp_oauth_callback( logger.info( "Re-authenticated %s MCP connector %s for user %s", - svc.name, db_connector.id, user_id, + svc.name, + db_connector.id, + user_id, ) reauth_return_url = data.get("return_url") - if reauth_return_url and reauth_return_url.startswith("/") and not reauth_return_url.startswith("//"): + if ( + reauth_return_url + and reauth_return_url.startswith("/") + and not reauth_return_url.startswith("//") + ): return RedirectResponse( url=f"{config.NEXT_FRONTEND_URL}{reauth_return_url}" ) return _frontend_redirect( - space_id, success=True, connector_id=db_connector.id, service=service, + space_id, + success=True, + connector_id=db_connector.id, + service=service, ) # ---- New connector path ---- @@ -436,24 +477,34 @@ async def mcp_oauth_callback( except IntegrityError as e: await session.rollback() raise HTTPException( - status_code=409, detail="A connector for this service already exists.", + status_code=409, + detail="A connector for this service already exists.", ) from e _invalidate_cache(space_id) logger.info( "Created %s MCP connector %s for user %s in space %s", - svc.name, new_connector.id, user_id, space_id, + svc.name, + new_connector.id, + user_id, + space_id, ) return _frontend_redirect( - space_id, success=True, connector_id=new_connector.id, service=service, + space_id, + success=True, + connector_id=new_connector.id, + service=service, ) except HTTPException: raise except Exception as e: logger.error( - "Failed to complete %s MCP OAuth: %s", service, e, exc_info=True, + "Failed to complete %s MCP OAuth: %s", + service, + e, + exc_info=True, ) raise HTTPException( status_code=500, @@ -465,6 +516,7 @@ async def mcp_oauth_callback( # /reauth — re-authenticate an existing MCP connector # --------------------------------------------------------------------------- + @router.get("/auth/mcp/{service}/connector/reauth") async def reauth_mcp_service( service: str, @@ -491,7 +543,8 @@ async def reauth_mcp_service( ) if not result.scalars().first(): raise HTTPException( - status_code=404, detail="Connector not found or access denied", + status_code=404, + detail="Connector not found or access denied", ) try: @@ -501,9 +554,12 @@ async def reauth_mcp_service( ) metadata = await discover_oauth_metadata( - svc.mcp_url, origin_override=svc.oauth_discovery_origin, + svc.mcp_url, + origin_override=svc.oauth_discovery_origin, + ) + auth_endpoint = svc.auth_endpoint_override or metadata.get( + "authorization_endpoint" ) - auth_endpoint = svc.auth_endpoint_override or metadata.get("authorization_endpoint") token_endpoint = svc.token_endpoint_override or metadata.get("token_endpoint") registration_endpoint = metadata.get("registration_endpoint") @@ -545,7 +601,9 @@ async def reauth_mcp_service( "service": service, "code_verifier": verifier, "mcp_client_id": client_id, - "mcp_client_secret": enc.encrypt_token(client_secret) if client_secret else "", + "mcp_client_secret": enc.encrypt_token(client_secret) + if client_secret + else "", "mcp_token_endpoint": token_endpoint, "mcp_url": svc.mcp_url, "connector_id": connector_id, @@ -554,7 +612,9 @@ async def reauth_mcp_service( extra["return_url"] = return_url state = _get_state_manager().generate_secure_state( - space_id, user.id, **extra, + space_id, + user.id, + **extra, ) auth_params: dict[str, str] = { @@ -572,7 +632,9 @@ async def reauth_mcp_service( logger.info( "Initiating %s MCP re-auth for user %s, connector %s", - svc.name, user.id, connector_id, + svc.name, + user.id, + connector_id, ) return {"auth_url": auth_url} @@ -580,7 +642,10 @@ async def reauth_mcp_service( raise except Exception as e: logger.error( - "Failed to initiate %s MCP re-auth: %s", service, e, exc_info=True, + "Failed to initiate %s MCP re-auth: %s", + service, + e, + exc_info=True, ) raise HTTPException( status_code=500, @@ -592,6 +657,7 @@ async def reauth_mcp_service( # Helpers # --------------------------------------------------------------------------- + def _invalidate_cache(space_id: int) -> None: try: from app.agents.new_chat.tools.mcp_tool import invalidate_mcp_tools_cache diff --git a/surfsense_backend/app/routes/new_chat_routes.py b/surfsense_backend/app/routes/new_chat_routes.py index cbc660222..b5560d90d 100644 --- a/surfsense_backend/app/routes/new_chat_routes.py +++ b/surfsense_backend/app/routes/new_chat_routes.py @@ -1242,7 +1242,9 @@ async def handle_new_chat( await session.close() image_urls = ( - [p.as_data_url() for p in request.user_images] if request.user_images else None + [p.as_data_url() for p in request.user_images] + if request.user_images + else None ) return StreamingResponse( diff --git a/surfsense_backend/app/routes/oauth_connector_base.py b/surfsense_backend/app/routes/oauth_connector_base.py index 0638e8f34..5b75d8519 100644 --- a/surfsense_backend/app/routes/oauth_connector_base.py +++ b/surfsense_backend/app/routes/oauth_connector_base.py @@ -9,6 +9,7 @@ Call ``build_router()`` to get a FastAPI ``APIRouter`` with ``/connector/add``, from __future__ import annotations import base64 +import contextlib import logging from datetime import UTC, datetime, timedelta from typing import Any @@ -41,7 +42,6 @@ logger = logging.getLogger(__name__) class OAuthConnectorRoute: - def __init__( self, *, @@ -244,10 +244,8 @@ class OAuthConnectorRoute: if resp.status_code != 200: detail = resp.text - try: + with contextlib.suppress(Exception): detail = resp.json().get("error_description", detail) - except Exception: - pass raise HTTPException( status_code=400, detail=f"Token exchange failed: {detail}" ) @@ -430,7 +428,11 @@ class OAuthConnectorRoute: state_mgr = oauth._get_state_manager() extra: dict[str, Any] = {"connector_id": connector_id} - if return_url and return_url.startswith("/") and not return_url.startswith("//"): + if ( + return_url + and return_url.startswith("/") + and not return_url.startswith("//") + ): extra["return_url"] = return_url auth_params: dict[str, str] = { @@ -450,9 +452,7 @@ class OAuthConnectorRoute: auth_params.update(oauth.extra_auth_params) - state_encoded = state_mgr.generate_secure_state( - space_id, user.id, **extra - ) + state_encoded = state_mgr.generate_secure_state(space_id, user.id, **extra) auth_params["state"] = state_encoded auth_url = f"{oauth.authorize_url}?{urlencode(auth_params)}" @@ -489,9 +489,7 @@ class OAuthConnectorRoute: status_code=400, detail="Missing authorization code" ) if not state: - raise HTTPException( - status_code=400, detail="Missing state parameter" - ) + raise HTTPException(status_code=400, detail="Missing state parameter") state_mgr = oauth._get_state_manager() try: @@ -552,7 +550,11 @@ class OAuthConnectorRoute: db_connector.id, user_id, ) - if reauth_return_url and reauth_return_url.startswith("/") and not reauth_return_url.startswith("//"): + if ( + reauth_return_url + and reauth_return_url.startswith("/") + and not reauth_return_url.startswith("//") + ): return RedirectResponse( url=f"{config.NEXT_FRONTEND_URL}{reauth_return_url}" ) @@ -603,7 +605,8 @@ class OAuthConnectorRoute: except IntegrityError as e: await session.rollback() raise HTTPException( - status_code=409, detail="A connector for this service already exists." + status_code=409, + detail="A connector for this service already exists.", ) from e logger.info( diff --git a/surfsense_backend/app/routes/search_source_connectors_routes.py b/surfsense_backend/app/routes/search_source_connectors_routes.py index d42a7fa1a..9037d275a 100644 --- a/surfsense_backend/app/routes/search_source_connectors_routes.py +++ b/surfsense_backend/app/routes/search_source_connectors_routes.py @@ -3092,7 +3092,7 @@ async def trust_mcp_tool( select(SearchSourceConnector).filter( SearchSourceConnector.id == connector_id, SearchSourceConnector.user_id == user.id, - cast(SearchSourceConnector.config, PG_JSONB).has_key("server_config"), # noqa: W601 + cast(SearchSourceConnector.config, PG_JSONB).has_key("server_config"), ) ) connector = result.scalars().first() @@ -3147,7 +3147,7 @@ async def untrust_mcp_tool( select(SearchSourceConnector).filter( SearchSourceConnector.id == connector_id, SearchSourceConnector.user_id == user.id, - cast(SearchSourceConnector.config, PG_JSONB).has_key("server_config"), # noqa: W601 + cast(SearchSourceConnector.config, PG_JSONB).has_key("server_config"), ) ) connector = result.scalars().first() diff --git a/surfsense_backend/app/services/mcp_oauth/discovery.py b/surfsense_backend/app/services/mcp_oauth/discovery.py index b0f3fef2a..dc21443bc 100644 --- a/surfsense_backend/app/services/mcp_oauth/discovery.py +++ b/surfsense_backend/app/services/mcp_oauth/discovery.py @@ -55,7 +55,9 @@ async def register_client( async with httpx.AsyncClient(follow_redirects=True) as client: resp = await client.post( - registration_endpoint, json=payload, timeout=timeout, + registration_endpoint, + json=payload, + timeout=timeout, ) resp.raise_for_status() return resp.json() diff --git a/surfsense_backend/app/services/mcp_oauth/registry.py b/surfsense_backend/app/services/mcp_oauth/registry.py index 49bc74d3d..835d70184 100644 --- a/surfsense_backend/app/services/mcp_oauth/registry.py +++ b/surfsense_backend/app/services/mcp_oauth/registry.py @@ -70,12 +70,14 @@ MCP_SERVICES: dict[str, MCPServiceConfig] = { "createJiraIssue", "editJiraIssue", ], - readonly_tools=frozenset({ - "getAccessibleAtlassianResources", - "searchJiraIssuesUsingJql", - "getVisibleJiraProjects", - "getJiraProjectIssueTypesMetadata", - }), + readonly_tools=frozenset( + { + "getAccessibleAtlassianResources", + "searchJiraIssuesUsingJql", + "getVisibleJiraProjects", + "getJiraProjectIssueTypesMetadata", + } + ), account_metadata_keys=["cloud_id", "site_name", "base_url"], ), "clickup": MCPServiceConfig( @@ -99,15 +101,23 @@ MCP_SERVICES: dict[str, MCPServiceConfig] = { auth_endpoint_override="https://slack.com/oauth/v2_user/authorize", token_endpoint_override="https://slack.com/api/oauth.v2.user.access", scopes=[ - "search:read.public", "search:read.private", "search:read.mpim", "search:read.im", - "channels:history", "groups:history", "mpim:history", "im:history", + "search:read.public", + "search:read.private", + "search:read.mpim", + "search:read.im", + "channels:history", + "groups:history", + "mpim:history", + "im:history", ], allowed_tools=[ "slack_search_channels", "slack_read_channel", "slack_read_thread", ], - readonly_tools=frozenset({"slack_search_channels", "slack_read_channel", "slack_read_thread"}), + readonly_tools=frozenset( + {"slack_search_channels", "slack_read_channel", "slack_read_thread"} + ), # TODO: oauth.v2.user.access only returns team.id, not team.name. # To populate team_name, either add "team:read" scope and call # GET /api/team.info during OAuth callback, or switch to oauth.v2.access. @@ -127,7 +137,9 @@ MCP_SERVICES: dict[str, MCPServiceConfig] = { "list_tables_for_base", "list_records_for_table", ], - readonly_tools=frozenset({"list_bases", "list_tables_for_base", "list_records_for_table"}), + readonly_tools=frozenset( + {"list_bases", "list_tables_for_base", "list_records_for_table"} + ), account_metadata_keys=["user_id", "user_email"], ), } @@ -136,20 +148,22 @@ _CONNECTOR_TYPE_TO_SERVICE: dict[str, MCPServiceConfig] = { svc.connector_type: svc for svc in MCP_SERVICES.values() } -LIVE_CONNECTOR_TYPES: frozenset[SearchSourceConnectorType] = frozenset({ - SearchSourceConnectorType.SLACK_CONNECTOR, - SearchSourceConnectorType.TEAMS_CONNECTOR, - SearchSourceConnectorType.LINEAR_CONNECTOR, - SearchSourceConnectorType.JIRA_CONNECTOR, - SearchSourceConnectorType.CLICKUP_CONNECTOR, - SearchSourceConnectorType.GOOGLE_CALENDAR_CONNECTOR, - SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR, - SearchSourceConnectorType.AIRTABLE_CONNECTOR, - SearchSourceConnectorType.GOOGLE_GMAIL_CONNECTOR, - SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR, - SearchSourceConnectorType.DISCORD_CONNECTOR, - SearchSourceConnectorType.LUMA_CONNECTOR, -}) +LIVE_CONNECTOR_TYPES: frozenset[SearchSourceConnectorType] = frozenset( + { + SearchSourceConnectorType.SLACK_CONNECTOR, + SearchSourceConnectorType.TEAMS_CONNECTOR, + SearchSourceConnectorType.LINEAR_CONNECTOR, + SearchSourceConnectorType.JIRA_CONNECTOR, + SearchSourceConnectorType.CLICKUP_CONNECTOR, + SearchSourceConnectorType.GOOGLE_CALENDAR_CONNECTOR, + SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR, + SearchSourceConnectorType.AIRTABLE_CONNECTOR, + SearchSourceConnectorType.GOOGLE_GMAIL_CONNECTOR, + SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR, + SearchSourceConnectorType.DISCORD_CONNECTOR, + SearchSourceConnectorType.LUMA_CONNECTOR, + } +) def get_service(key: str) -> MCPServiceConfig | None: diff --git a/surfsense_backend/app/services/obsidian_plugin_indexer.py b/surfsense_backend/app/services/obsidian_plugin_indexer.py index 8fbdad269..0fc4f30f4 100644 --- a/surfsense_backend/app/services/obsidian_plugin_indexer.py +++ b/surfsense_backend/app/services/obsidian_plugin_indexer.py @@ -156,7 +156,9 @@ async def _extract_binary_attachment_markdown( try: raw_bytes = base64.b64decode(payload.binary_base64, validate=True) except Exception: - logger.warning("obsidian attachment payload had invalid base64: %s", payload.path) + logger.warning( + "obsidian attachment payload had invalid base64: %s", payload.path + ) return "", {"attachment_extraction_status": "invalid_binary_payload"} suffix = f".{payload.extension.lstrip('.')}" @@ -180,7 +182,10 @@ async def _extract_binary_attachment_markdown( return result.markdown_content, metadata except Exception as exc: logger.warning( - "obsidian attachment ETL failed for %s: %s", payload.path, exc, exc_info=True + "obsidian attachment ETL failed for %s: %s", + payload.path, + exc, + exc_info=True, ) return "", { "attachment_extraction_status": "etl_failed", diff --git a/surfsense_backend/app/services/revert_service.py b/surfsense_backend/app/services/revert_service.py new file mode 100644 index 000000000..f3630e0b4 --- /dev/null +++ b/surfsense_backend/app/services/revert_service.py @@ -0,0 +1,277 @@ +"""Revert service for the SurfSense agent action log. + +Implements the actual revert workflow used by +``POST /api/threads/{thread_id}/revert/{action_id}``. The route handler is a +thin auth + flag wrapper around the functions defined here. + +Operation outcomes mirror the plan: + +* **KB-owned actions** (NOTE / FILE / FOLDER mutations): restore from + :class:`app.db.DocumentRevision` / :class:`app.db.FolderRevision` rows + written before the original mutation. +* **Connector-owned actions with a declared ``reverse_descriptor``**: invoke + the inverse tool through the agent's normal permission stack (NOT + bypassed). Out of scope for this PR — returns ``REVERSE_NOT_IMPLEMENTED``. +* **Anything else** (deprecated tool / no descriptor / schema drift): + returns ``NOT_REVERSIBLE`` and the route surfaces it as 409. + +A successful revert appends a NEW row to ``agent_action_log`` with +``reverse_of=`` and the requesting user's +``user_id``, preserving an auditable chain. +""" + +from __future__ import annotations + +import logging +from dataclasses import dataclass +from datetime import UTC, datetime +from typing import Literal + +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession + +from app.db import ( + AgentActionLog, + DocumentRevision, + FolderRevision, + NewChatThread, +) + +logger = logging.getLogger(__name__) + + +RevertOutcomeStatus = Literal[ + "ok", + "not_reversible", + "not_found", + "permission_denied", + "tool_unavailable", + "reverse_not_implemented", +] + + +@dataclass +class RevertOutcome: + """Structured result of :func:`revert_action`.""" + + status: RevertOutcomeStatus + message: str + new_action_id: int | None = None + + +# --------------------------------------------------------------------------- +# Lookup helpers +# --------------------------------------------------------------------------- + + +async def load_action( + session: AsyncSession, + *, + action_id: int, + thread_id: int, +) -> AgentActionLog | None: + """Load the action_log row for ``action_id`` if it belongs to the thread.""" + stmt = select(AgentActionLog).where( + AgentActionLog.id == action_id, + AgentActionLog.thread_id == thread_id, + ) + result = await session.execute(stmt) + return result.scalars().first() + + +async def load_thread(session: AsyncSession, *, thread_id: int) -> NewChatThread | None: + stmt = select(NewChatThread).where(NewChatThread.id == thread_id) + result = await session.execute(stmt) + return result.scalars().first() + + +# --------------------------------------------------------------------------- +# Authorization +# --------------------------------------------------------------------------- + + +def can_revert( + *, + requester_user_id: str | None, + action: AgentActionLog, + is_admin: bool, +) -> bool: + """Return True iff the requester is allowed to revert this action. + + The plan's rule: "requester must be the original `user_id` on the + action, or hold the search-space admin role." Anonymous actions + (``action.user_id is None``) can only be reverted by admins. + """ + if is_admin: + return True + if action.user_id is None: + return False + return str(action.user_id) == str(requester_user_id) + + +# --------------------------------------------------------------------------- +# Revert paths +# --------------------------------------------------------------------------- + + +async def _restore_document_revision( + session: AsyncSession, *, action: AgentActionLog +) -> RevertOutcome: + """Restore the most recent :class:`DocumentRevision` for ``action``.""" + stmt = ( + select(DocumentRevision) + .where(DocumentRevision.agent_action_id == action.id) + .order_by(DocumentRevision.created_at.desc()) + .limit(1) + ) + result = await session.execute(stmt) + revision = result.scalars().first() + if revision is None: + return RevertOutcome( + status="not_reversible", + message="No document_revisions row tied to this action.", + ) + + from app.db import Document # late import to avoid cycles at module load + + doc = await session.get(Document, revision.document_id) + if doc is None: + return RevertOutcome( + status="tool_unavailable", + message="Original document has been deleted; revert cannot proceed.", + ) + + if revision.content_before is not None: + doc.content = revision.content_before + if revision.title_before is not None: + doc.title = revision.title_before + if revision.folder_id_before is not None: + doc.folder_id = revision.folder_id_before + doc.updated_at = datetime.now(UTC) + return RevertOutcome(status="ok", message="Document restored from snapshot.") + + +async def _restore_folder_revision( + session: AsyncSession, *, action: AgentActionLog +) -> RevertOutcome: + stmt = ( + select(FolderRevision) + .where(FolderRevision.agent_action_id == action.id) + .order_by(FolderRevision.created_at.desc()) + .limit(1) + ) + result = await session.execute(stmt) + revision = result.scalars().first() + if revision is None: + return RevertOutcome( + status="not_reversible", + message="No folder_revisions row tied to this action.", + ) + + from app.db import Folder + + folder = await session.get(Folder, revision.folder_id) + if folder is None: + return RevertOutcome( + status="tool_unavailable", + message="Original folder has been deleted; revert cannot proceed.", + ) + + if revision.name_before is not None: + folder.name = revision.name_before + if revision.parent_id_before is not None: + folder.parent_id = revision.parent_id_before + if revision.position_before is not None: + folder.position = revision.position_before + folder.updated_at = datetime.now(UTC) + return RevertOutcome(status="ok", message="Folder restored from snapshot.") + + +# Tool-name prefixes that route to KB document / folder revert paths. Kept +# as data so a future PR adding new KB-owned tools doesn't have to touch +# this module's control flow. +_DOC_TOOL_PREFIXES: tuple[str, ...] = ( + "edit_file", + "write_file", + "update_memory", + "create_note", + "update_note", + "delete_note", +) +_FOLDER_TOOL_PREFIXES: tuple[str, ...] = ( + "mkdir", + "move_file", + "rename_folder", + "delete_folder", +) + + +async def revert_action( + session: AsyncSession, + *, + action: AgentActionLog, + requester_user_id: str | None, +) -> RevertOutcome: + """Execute the revert for ``action`` and return a structured outcome. + + The function does **not** commit — the caller is expected to commit on + success or roll back on failure. A new ``agent_action_log`` row is + added to the session on success with ``reverse_of=action.id``. + """ + tool_name = (action.tool_name or "").lower() + + if tool_name.startswith(_DOC_TOOL_PREFIXES): + outcome = await _restore_document_revision(session, action=action) + elif tool_name.startswith(_FOLDER_TOOL_PREFIXES): + outcome = await _restore_folder_revision(session, action=action) + elif action.reverse_descriptor: + # Connector-owned reversibles run through the normal permission + # stack; out of scope for this PR — the route returns 503 anyway + # until UI ships, so 501-style "not implemented" is fine. + return RevertOutcome( + status="reverse_not_implemented", + message=( + "Connector-action revert is not yet implemented. The " + "reverse_descriptor is stored; future work will replay it " + "through PermissionMiddleware." + ), + ) + else: + return RevertOutcome( + status="not_reversible", + message=( + f"Tool {action.tool_name!r} is not reversible: no document " + "revision and no reverse_descriptor." + ), + ) + + if outcome.status != "ok": + return outcome + + new_row = AgentActionLog( + thread_id=action.thread_id, + user_id=requester_user_id, + search_space_id=action.search_space_id, + turn_id=None, + message_id=None, + tool_name=f"_revert:{action.tool_name}", + args={"reverted_action_id": action.id}, + result_id=None, + reversible=False, + reverse_descriptor=None, + error=None, + reverse_of=action.id, + ) + session.add(new_row) + await session.flush() + outcome.new_action_id = new_row.id + return outcome + + +__all__ = [ + "RevertOutcome", + "can_revert", + "load_action", + "load_thread", + "revert_action", +] diff --git a/surfsense_backend/app/tasks/chat/stream_new_chat.py b/surfsense_backend/app/tasks/chat/stream_new_chat.py index 396c7574e..c254e66e2 100644 --- a/surfsense_backend/app/tasks/chat/stream_new_chat.py +++ b/surfsense_backend/app/tasks/chat/stream_new_chat.py @@ -30,7 +30,7 @@ from sqlalchemy.orm import selectinload from app.agents.new_chat.chat_deepagent import create_surfsense_deep_agent from app.agents.new_chat.checkpointer import get_checkpointer -from app.agents.new_chat.filesystem_selection import FilesystemSelection +from app.agents.new_chat.filesystem_selection import FilesystemMode, FilesystemSelection from app.agents.new_chat.llm_config import ( AgentConfig, create_chat_litellm_from_agent_config, @@ -42,6 +42,9 @@ from app.agents.new_chat.memory_extraction import ( extract_and_save_memory, extract_and_save_team_memory, ) +from app.agents.new_chat.middleware.kb_persistence import ( + commit_staged_filesystem_state, +) from app.db import ( ChatVisibility, NewChatMessage, @@ -182,9 +185,9 @@ def _tool_output_has_error(tool_output: Any) -> bool: if tool_output.get("error"): return True result = tool_output.get("result") - if isinstance(result, str) and result.strip().lower().startswith("error:"): - return True - return False + return bool( + isinstance(result, str) and result.strip().lower().startswith("error:") + ) if isinstance(tool_output, str): return tool_output.strip().lower().startswith("error:") return False @@ -230,7 +233,9 @@ def _log_file_contract(stage: str, result: StreamResult, **extra: Any) -> None: "stage": stage, "request_id": result.request_id or "unknown", "turn_id": result.turn_id or "unknown", - "chat_id": result.turn_id.split(":", 1)[0] if ":" in result.turn_id else "unknown", + "chat_id": result.turn_id.split(":", 1)[0] + if ":" in result.turn_id + else "unknown", "filesystem_mode": result.filesystem_mode, "client_platform": result.client_platform, "intent_detected": result.intent_detected, @@ -242,7 +247,9 @@ def _log_file_contract(stage: str, result: StreamResult, **extra: Any) -> None: "commit_gate_reason": result.commit_gate_reason or None, } payload.update(extra) - _perf_log.info("[file_operation_contract] %s", json.dumps(payload, ensure_ascii=False)) + _perf_log.info( + "[file_operation_contract] %s", json.dumps(payload, ensure_ascii=False) + ) async def _stream_agent_events( @@ -255,6 +262,10 @@ async def _stream_agent_events( initial_step_id: str | None = None, initial_step_title: str = "", initial_step_items: list[str] | None = None, + *, + fallback_commit_search_space_id: int | None = None, + fallback_commit_created_by_id: str | None = None, + fallback_commit_filesystem_mode: FilesystemMode = FilesystemMode.CLOUD, ) -> AsyncGenerator[str, None]: """Shared async generator that streams and formats astream_events from the agent. @@ -1277,6 +1288,40 @@ async def _stream_agent_events( state = await agent.aget_state(config) state_values = getattr(state, "values", {}) or {} + + # Safety net: if astream_events was cancelled before + # KnowledgeBasePersistenceMiddleware.aafter_agent ran, any staged work + # (dirty_paths / staged_dirs / pending_moves) will still be in the + # checkpointed state. Run the SAME shared commit helper here so the + # turn's writes don't get lost on client disconnect, then push the + # delta back into the graph using `as_node=...` so reducers fire as if + # the after_agent hook produced it. + if ( + fallback_commit_filesystem_mode == FilesystemMode.CLOUD + and fallback_commit_search_space_id is not None + and ( + (state_values.get("dirty_paths") or []) + or (state_values.get("staged_dirs") or []) + or (state_values.get("pending_moves") or []) + ) + ): + try: + delta = await commit_staged_filesystem_state( + state_values, + search_space_id=fallback_commit_search_space_id, + created_by_id=fallback_commit_created_by_id, + filesystem_mode=fallback_commit_filesystem_mode, + dispatch_events=False, + ) + if delta: + await agent.aupdate_state( + config, + delta, + as_node="KnowledgeBasePersistenceMiddleware.after_agent", + ) + except Exception as exc: + _perf_log.warning("[stream_new_chat] safety-net commit failed: %s", exc) + contract_state = state_values.get("file_operation_contract") or {} contract_turn_id = contract_state.get("turn_id") current_turn_id = config.get("configurable", {}).get("turn_id", "") @@ -1289,7 +1334,8 @@ async def _stream_agent_events( result.intent_detected = intent_value if ( isinstance(intent_value, str) - and intent_value in ( + and intent_value + in ( "chat_only", "file_write", "file_read", @@ -1308,18 +1354,17 @@ async def _stream_agent_events( result.commit_gate_passed, result.commit_gate_reason = ( _evaluate_file_contract_outcome(result) ) - if not result.commit_gate_passed: - if _contract_enforcement_active(result): - gate_notice = ( - "I could not complete the requested file write because no successful " - "write_file/edit_file operation was confirmed." - ) - gate_text_id = streaming_service.generate_text_id() - yield streaming_service.format_text_start(gate_text_id) - yield streaming_service.format_text_delta(gate_text_id, gate_notice) - yield streaming_service.format_text_end(gate_text_id) - yield streaming_service.format_terminal_info(gate_notice, "error") - accumulated_text = gate_notice + if not result.commit_gate_passed and _contract_enforcement_active(result): + gate_notice = ( + "I could not complete the requested file write because no successful " + "write_file/edit_file operation was confirmed." + ) + gate_text_id = streaming_service.generate_text_id() + yield streaming_service.format_text_start(gate_text_id) + yield streaming_service.format_text_delta(gate_text_id, gate_notice) + yield streaming_service.format_text_end(gate_text_id) + yield streaming_service.format_terminal_info(gate_notice, "error") + accumulated_text = gate_notice else: result.commit_gate_passed = True result.commit_gate_reason = "" @@ -1824,6 +1869,13 @@ async def stream_new_chat( initial_step_id=initial_step_id, initial_step_title=initial_title, initial_step_items=initial_items, + fallback_commit_search_space_id=search_space_id, + fallback_commit_created_by_id=user_id, + fallback_commit_filesystem_mode=( + filesystem_selection.mode + if filesystem_selection + else FilesystemMode.CLOUD + ), ): if not _first_event_logged: _perf_log.info( @@ -2266,6 +2318,13 @@ async def stream_resume_chat( streaming_service=streaming_service, result=stream_result, step_prefix="thinking-resume", + fallback_commit_search_space_id=search_space_id, + fallback_commit_created_by_id=user_id, + fallback_commit_filesystem_mode=( + filesystem_selection.mode + if filesystem_selection + else FilesystemMode.CLOUD + ), ): if not _first_event_logged: _perf_log.info( diff --git a/surfsense_backend/app/utils/async_retry.py b/surfsense_backend/app/utils/async_retry.py index c3bdd5386..607b7a156 100644 --- a/surfsense_backend/app/utils/async_retry.py +++ b/surfsense_backend/app/utils/async_retry.py @@ -2,6 +2,7 @@ from __future__ import annotations +import contextlib import logging from collections.abc import Callable from typing import TypeVar @@ -32,9 +33,7 @@ F = TypeVar("F", bound=Callable) def _is_retryable(exc: BaseException) -> bool: if isinstance(exc, ConnectorError): return exc.retryable - if isinstance(exc, (httpx.TimeoutException, httpx.ConnectError)): - return True - return False + return bool(isinstance(exc, httpx.TimeoutException | httpx.ConnectError)) def build_retry( @@ -86,10 +85,8 @@ def raise_for_status( retry_after_raw = response.headers.get("Retry-After") retry_after: float | None = None if retry_after_raw: - try: + with contextlib.suppress(ValueError, TypeError): retry_after = float(retry_after_raw) - except (ValueError, TypeError): - pass raise ConnectorRateLimitError( f"{service} rate limited (429)", service=service, diff --git a/surfsense_backend/app/utils/connector_naming.py b/surfsense_backend/app/utils/connector_naming.py index 889bf1464..99c8243a5 100644 --- a/surfsense_backend/app/utils/connector_naming.py +++ b/surfsense_backend/app/utils/connector_naming.py @@ -233,7 +233,10 @@ async def generate_unique_connector_name( if identifier: name = f"{base} - {identifier}" return await ensure_unique_connector_name( - session, name, search_space_id, user_id, + session, + name, + search_space_id, + user_id, ) count = await count_connectors_of_type( diff --git a/surfsense_backend/app/utils/user_message_multimodal.py b/surfsense_backend/app/utils/user_message_multimodal.py index 1d0691697..dc9a6fe76 100644 --- a/surfsense_backend/app/utils/user_message_multimodal.py +++ b/surfsense_backend/app/utils/user_message_multimodal.py @@ -7,7 +7,9 @@ import binascii from typing import Any -def build_human_message_content(final_query: str, image_data_urls: list[str]) -> str | list[dict[str, Any]]: +def build_human_message_content( + final_query: str, image_data_urls: list[str] +) -> str | list[dict[str, Any]]: if not image_data_urls: return final_query parts: list[dict[str, Any]] = [{"type": "text", "text": final_query}] diff --git a/surfsense_backend/tests/integration/harness/__init__.py b/surfsense_backend/tests/integration/harness/__init__.py new file mode 100644 index 000000000..9a7ec07dc --- /dev/null +++ b/surfsense_backend/tests/integration/harness/__init__.py @@ -0,0 +1,146 @@ +""" +Integration test harness for the SurfSense agent stack. + +The plan calls for an ``LLMToolEmulator``-backed harness for end-to-end +replay of ``stream_new_chat``. The currently-installed langchain version +does not expose ``LLMToolEmulator``, so this harness builds the equivalent +on top of :class:`langchain_core.language_models.fake_chat_models.FakeMessagesListChatModel`. + +The harness lets a test author script a sequence of model responses +(text + optional tool calls) and replay them against the new_chat agent +graph. Tools are stubbed via ``StubToolSpec`` -> ``langchain_core.tools.tool`` +decorator and execute deterministic Python callbacks. + +Used by: +- ``tests/integration/agents/new_chat/test_feature_flag_smoke.py`` to + confirm the kill-switch path produces identical-shape output regardless + of which middleware flags are toggled. +- Future per-tier PRs to record golden transcripts. +""" + +from __future__ import annotations + +import uuid +from collections.abc import Callable, Sequence +from dataclasses import dataclass, field +from typing import Any + +from langchain_core.language_models import LanguageModelInput +from langchain_core.language_models.fake_chat_models import ( + FakeMessagesListChatModel, +) +from langchain_core.messages import AIMessage, BaseMessage +from langchain_core.runnables import Runnable +from langchain_core.tools import BaseTool, tool + + +class _ToolBindingFakeChatModel(FakeMessagesListChatModel): + """Adapter so the harness model can pretend it understands ``bind_tools``. + + The base ``FakeMessagesListChatModel`` raises ``NotImplementedError`` from + ``bind_tools``, but ``langchain.agents.create_agent`` always calls + ``bind_tools`` to attach the tool registry. We don't actually need the + fake to honor the tool schema — it's already scripted to emit the right + tool calls — so we return self. + """ + + def bind_tools( # type: ignore[override] + self, + tools: Sequence[Any], + *, + tool_choice: Any = None, + **kwargs: Any, + ) -> Runnable[LanguageModelInput, AIMessage]: + return self + + +@dataclass +class StubToolSpec: + """A test-mode tool: a name, description, and a deterministic body.""" + + name: str + description: str + handler: Callable[..., Any] + args_schema: dict[str, Any] | None = None + + def build(self) -> BaseTool: + """Realize as a `langchain_core.tools.BaseTool`.""" + + @tool(name_or_callable=self.name, description=self.description) + def _stub_tool(**kwargs: Any) -> Any: + return self.handler(**kwargs) + + return _stub_tool + + +@dataclass +class ScriptedTurn: + """One scripted assistant turn. + + `text` is the assistant text (may be empty if pure tool call). + `tool_calls` is a list of dicts ``{name, args, id}``; if non-empty, the + agent will route to those tools and append a follow-up turn. + """ + + text: str = "" + tool_calls: list[dict[str, Any]] = field(default_factory=list) + + +def build_scripted_messages(turns: list[ScriptedTurn]) -> list[BaseMessage]: + """Convert :class:`ScriptedTurn` records to AIMessage payloads.""" + out: list[BaseMessage] = [] + for turn in turns: + tool_calls: list[dict[str, Any]] = [] + for tc in turn.tool_calls: + tool_calls.append( + { + "name": tc["name"], + "args": tc.get("args", {}), + "id": tc.get("id") or f"call_{uuid.uuid4().hex[:8]}", + } + ) + out.append(AIMessage(content=turn.text, tool_calls=tool_calls or [])) + return out + + +@dataclass +class ScriptedHarness: + """Bundle of (model, tools) ready to plug into ``create_agent``.""" + + model: _ToolBindingFakeChatModel + tools: list[BaseTool] + + +def build_scripted_harness( + *, + turns: list[ScriptedTurn], + tools: list[StubToolSpec] | None = None, + sleep: float | None = None, +) -> ScriptedHarness: + """Construct a deterministic agent harness from a script. + + Example:: + + harness = build_scripted_harness( + turns=[ + ScriptedTurn(tool_calls=[{"name": "echo", "args": {"x": 1}}]), + ScriptedTurn(text="done"), + ], + tools=[ + StubToolSpec(name="echo", description="echo args", handler=lambda **kw: kw), + ], + ) + """ + messages = build_scripted_messages(turns) + model = _ToolBindingFakeChatModel(responses=messages, sleep=sleep) + realized_tools = [t.build() for t in (tools or [])] + return ScriptedHarness(model=model, tools=realized_tools) + + +__all__ = [ + "ScriptedHarness", + "ScriptedTurn", + "StubToolSpec", + "build_scripted_harness", + "build_scripted_messages", +] diff --git a/surfsense_backend/tests/integration/harness/test_scripted_harness.py b/surfsense_backend/tests/integration/harness/test_scripted_harness.py new file mode 100644 index 000000000..6e9f7ab91 --- /dev/null +++ b/surfsense_backend/tests/integration/harness/test_scripted_harness.py @@ -0,0 +1,53 @@ +"""Smoke test: scripted harness drives create_agent end-to-end and produces a tool-call-then-final-text trace.""" + +from __future__ import annotations + +import pytest +from langchain.agents import create_agent + +from tests.integration.harness import ( + ScriptedTurn, + StubToolSpec, + build_scripted_harness, +) + +pytestmark = pytest.mark.integration + + +@pytest.mark.asyncio +async def test_scripted_harness_drives_basic_agent() -> None: + harness = build_scripted_harness( + turns=[ + ScriptedTurn( + tool_calls=[ + {"name": "echo", "args": {"x": 1}, "id": "call_1"}, + ] + ), + ScriptedTurn(text="done"), + ], + tools=[ + StubToolSpec( + name="echo", + description="Echo args back.", + handler=lambda **kwargs: {"echoed": kwargs}, + ), + ], + ) + + agent = create_agent( + harness.model, + system_prompt="You are a test agent.", + tools=harness.tools, + ) + + result = await agent.ainvoke({"messages": [("user", "do the thing")]}) + messages = result["messages"] + final_ai = next( + (m for m in reversed(messages) if m.__class__.__name__ == "AIMessage"), + None, + ) + assert final_ai is not None + assert final_ai.content == "done" + tool_messages = [m for m in messages if m.__class__.__name__ == "ToolMessage"] + assert len(tool_messages) == 1 + assert "echoed" in str(tool_messages[0].content) diff --git a/surfsense_backend/tests/integration/test_obsidian_plugin_routes.py b/surfsense_backend/tests/integration/test_obsidian_plugin_routes.py index 41779a570..22f6c6de5 100644 --- a/surfsense_backend/tests/integration/test_obsidian_plugin_routes.py +++ b/surfsense_backend/tests/integration/test_obsidian_plugin_routes.py @@ -499,7 +499,9 @@ class TestWireContractSmoke: "app.routes.obsidian_plugin_routes.upsert_note", new=AsyncMock(return_value=fake_doc), ) as upsert_mock, - patch("app.routes.obsidian_plugin_routes._queue_obsidian_attachment") as queue_mock, + patch( + "app.routes.obsidian_plugin_routes._queue_obsidian_attachment" + ) as queue_mock, ): sync_resp = await obsidian_sync( SyncBatchRequest( @@ -548,7 +550,9 @@ class TestWireContractSmoke: "app.routes.obsidian_plugin_routes.upsert_note", new=AsyncMock(return_value=fake_doc), ), - patch("app.routes.obsidian_plugin_routes._queue_obsidian_attachment") as queue_mock, + patch( + "app.routes.obsidian_plugin_routes._queue_obsidian_attachment" + ) as queue_mock, ): sync_resp = await obsidian_sync( SyncBatchRequest( @@ -600,7 +604,9 @@ class TestWireContractSmoke: "app.routes.obsidian_plugin_routes.upsert_note", new=AsyncMock(return_value=fake_doc), ), - patch("app.routes.obsidian_plugin_routes._queue_obsidian_attachment") as queue_mock, + patch( + "app.routes.obsidian_plugin_routes._queue_obsidian_attachment" + ) as queue_mock, ): sync_resp = await obsidian_sync( SyncBatchRequest( @@ -619,7 +625,5 @@ class TestWireContractSmoke: items_by_path = {it.path: it for it in sync_resp.items} assert items_by_path["ok.md"].status == "ok" assert items_by_path["image.png"].status == "error" - assert "does not match extension" in ( - items_by_path["image.png"].error or "" - ) + assert "does not match extension" in (items_by_path["image.png"].error or "") queue_mock.assert_not_called() diff --git a/surfsense_backend/tests/unit/agents/__init__.py b/surfsense_backend/tests/unit/agents/__init__.py new file mode 100644 index 000000000..8b1378917 --- /dev/null +++ b/surfsense_backend/tests/unit/agents/__init__.py @@ -0,0 +1 @@ + diff --git a/surfsense_backend/tests/unit/agents/new_chat/__init__.py b/surfsense_backend/tests/unit/agents/new_chat/__init__.py new file mode 100644 index 000000000..8b1378917 --- /dev/null +++ b/surfsense_backend/tests/unit/agents/new_chat/__init__.py @@ -0,0 +1 @@ + diff --git a/surfsense_backend/tests/unit/agents/new_chat/prompts/__init__.py b/surfsense_backend/tests/unit/agents/new_chat/prompts/__init__.py new file mode 100644 index 000000000..a92d371bd --- /dev/null +++ b/surfsense_backend/tests/unit/agents/new_chat/prompts/__init__.py @@ -0,0 +1 @@ +"""__init__ stub so pytest discovers the prompts test module.""" diff --git a/surfsense_backend/tests/unit/agents/new_chat/prompts/test_composer.py b/surfsense_backend/tests/unit/agents/new_chat/prompts/test_composer.py new file mode 100644 index 000000000..397b1c787 --- /dev/null +++ b/surfsense_backend/tests/unit/agents/new_chat/prompts/test_composer.py @@ -0,0 +1,270 @@ +"""Tests for the prompt fragment composer.""" + +from __future__ import annotations + +from datetime import UTC, datetime + +import pytest + +from app.agents.new_chat.prompts.composer import ( + ALL_TOOL_NAMES_ORDERED, + compose_system_prompt, + detect_provider_variant, +) +from app.db import ChatVisibility + +pytestmark = pytest.mark.unit + + +@pytest.fixture +def fixed_today() -> datetime: + return datetime(2025, 6, 1, 12, 0, tzinfo=UTC) + + +class TestProviderVariantDetection: + @pytest.mark.parametrize( + "model_name,expected", + [ + # GPT-4 family routes to "classic" (autonomous-persistence style) + ("openai:gpt-4o-mini", "openai_classic"), + ("openai:gpt-4-turbo", "openai_classic"), + # GPT-5 / o-series route to "reasoning" (channel-aware pragmatic) + ("openai:gpt-5", "openai_reasoning"), + ("openai:o1-preview", "openai_reasoning"), + ("openai:o3-mini", "openai_reasoning"), + # Codex family beats reasoning (more specific). Mirrors OpenCode + # ``system.ts`` — ``gpt-*-codex`` gets the code-purist prompt. + ("openai:gpt-5-codex", "openai_codex"), + ("openai:gpt-codex", "openai_codex"), + ("openai:codex-mini", "openai_codex"), + # Anthropic + Google + ("anthropic:claude-3-5-sonnet", "anthropic"), + ("anthropic/claude-opus-4", "anthropic"), + ("google:gemini-2.0-flash", "google"), + ("vertex:gemini-1.5-pro", "google"), + # Newly-covered families + ("moonshot:kimi-k2", "kimi"), + ("openrouter:moonshot/kimi-k2.5", "kimi"), + ("xai:grok-2", "grok"), + ("openrouter:x-ai/grok-3", "grok"), + ("openai:deepseek-v3", "deepseek"), + ("deepseek:deepseek-r1", "deepseek"), + # Unknown families fall back to default (no provider block emitted) + ("groq:mixtral-8x7b", "default"), + ("together:llama-3.1-70b", "default"), + (None, "default"), + ("", "default"), + ], + ) + def test_detection(self, model_name: str | None, expected: str) -> None: + assert detect_provider_variant(model_name) == expected + + def test_codex_takes_precedence_over_reasoning(self) -> None: + """Regression guard: ``gpt-5-codex`` must NOT match the generic + ``gpt-5`` reasoning regex first. Codex is the more specialised + prompt and mirrors OpenCode's dispatch order. + """ + from app.agents.new_chat.prompts.composer import detect_provider_variant + + assert detect_provider_variant("openai:gpt-5-codex") == "openai_codex" + assert detect_provider_variant("openai:gpt-5") == "openai_reasoning" + + +class TestCompose: + def test_default_prompt_has_required_blocks(self, fixed_today: datetime) -> None: + prompt = compose_system_prompt(today=fixed_today) + # System instruction wrapper + assert "" in prompt + assert "" in prompt + # Date interpolated + assert "2025-06-01" in prompt + # Core policy blocks present + assert "" in prompt + assert "" in prompt + assert "" in prompt + assert "" in prompt + # Tools + assert "" in prompt + assert "" in prompt + # Citations on by default + assert "" in prompt + assert "[citation:chunk_id]" in prompt + + def test_team_visibility_uses_team_variants(self, fixed_today: datetime) -> None: + prompt = compose_system_prompt( + today=fixed_today, + thread_visibility=ChatVisibility.SEARCH_SPACE, + ) + # Team-specific phrasing in the agent block + assert "team space" in prompt + # Memory protocol mentions team + assert "team" in prompt + # Should NOT mention the user-only memory phrasing + assert "personal knowledge base" not in prompt + + def test_private_visibility_uses_private_variants( + self, fixed_today: datetime + ) -> None: + prompt = compose_system_prompt( + today=fixed_today, + thread_visibility=ChatVisibility.PRIVATE, + ) + assert "personal knowledge base" in prompt + # Should NOT mention the team-specific phrasing about prefixed authors + assert "[DisplayName of the author]" not in prompt + + def test_citations_disabled_swaps_block(self, fixed_today: datetime) -> None: + prompt_on = compose_system_prompt(today=fixed_today, citations_enabled=True) + prompt_off = compose_system_prompt(today=fixed_today, citations_enabled=False) + assert "Citations are DISABLED" in prompt_off + assert "Citations are DISABLED" not in prompt_on + assert "[citation:chunk_id]" in prompt_on + + def test_enabled_tool_filter_only_includes_listed_tools( + self, fixed_today: datetime + ) -> None: + prompt = compose_system_prompt( + today=fixed_today, + enabled_tool_names={"web_search", "scrape_webpage"}, + ) + assert "web_search:" in prompt or "- web_search:" in prompt + assert "scrape_webpage:" in prompt or "- scrape_webpage:" in prompt + # Excluded tools should NOT appear in tool listing + assert "generate_podcast:" not in prompt + assert "generate_image:" not in prompt + + def test_disabled_tool_note_is_appended(self, fixed_today: datetime) -> None: + prompt = compose_system_prompt( + today=fixed_today, + enabled_tool_names={"web_search"}, + disabled_tool_names={"generate_image", "generate_podcast"}, + ) + assert "DISABLED TOOLS (by user):" in prompt + assert "Generate Image" in prompt + assert "Generate Podcast" in prompt + + def test_mcp_routing_block_emits_when_provided(self, fixed_today: datetime) -> None: + prompt = compose_system_prompt( + today=fixed_today, + mcp_connector_tools={"My GitLab": ["gitlab_search", "gitlab_create_mr"]}, + ) + assert "" in prompt + assert "My GitLab" in prompt + assert "gitlab_search" in prompt + + def test_mcp_routing_block_absent_when_no_servers( + self, fixed_today: datetime + ) -> None: + prompt = compose_system_prompt(today=fixed_today, mcp_connector_tools={}) + assert "" not in prompt + + def test_provider_block_renders_when_anthropic(self, fixed_today: datetime) -> None: + prompt = compose_system_prompt( + today=fixed_today, model_name="anthropic:claude-3-5-sonnet" + ) + assert "" in prompt + assert "Anthropic" in prompt or "Claude" in prompt + + def test_provider_block_absent_for_default(self, fixed_today: datetime) -> None: + prompt = compose_system_prompt(today=fixed_today, model_name="custom:foo") + assert "" not in prompt + + @pytest.mark.parametrize( + "model_name,expected_marker", + [ + # Each marker is a unique-ish phrase from the corresponding fragment. + # If a fragment is renamed/rewritten such that the marker is gone, + # update both the fragment and this test deliberately. + ("openai:gpt-5-codex", "Codex-class"), + ("openai:gpt-5", "OpenAI reasoning model"), + ("openai:gpt-4o", "classic OpenAI chat model"), + ("anthropic:claude-3-5-sonnet", "Anthropic Claude"), + ("google:gemini-2.0-flash", "Google Gemini"), + ("moonshot:kimi-k2", "Moonshot Kimi"), + ("xai:grok-2", "xAI Grok"), + ("deepseek:deepseek-r1", "DeepSeek"), + ], + ) + def test_each_known_variant_renders_with_its_marker( + self, + fixed_today: datetime, + model_name: str, + expected_marker: str, + ) -> None: + """Every supported variant must produce a ```` block + containing its identifying marker. This pins the dispatch + the + on-disk fragments together so a missing/renamed file is caught + immediately. + """ + prompt = compose_system_prompt(today=fixed_today, model_name=model_name) + assert "" in prompt, ( + f"variant for {model_name!r} did not emit a provider_hints block; " + "the corresponding providers/.md may be missing" + ) + assert expected_marker in prompt, ( + f"variant for {model_name!r} emitted hints but lacked the " + f"expected marker {expected_marker!r} — the fragment may have " + "drifted from the dispatch table" + ) + + def test_provider_blocks_are_byte_stable_across_calls( + self, fixed_today: datetime + ) -> None: + """Cache-stability guard: same model id → byte-identical prompt.""" + a = compose_system_prompt(today=fixed_today, model_name="moonshot:kimi-k2") + b = compose_system_prompt(today=fixed_today, model_name="moonshot:kimi-k2") + assert a == b + + def test_custom_system_instructions_override_default( + self, fixed_today: datetime + ) -> None: + custom = "You are a custom assistant. Today is {resolved_today}." + prompt = compose_system_prompt( + today=fixed_today, custom_system_instructions=custom + ) + assert "You are a custom assistant. Today is 2025-06-01." in prompt + # Default block should NOT be present + assert "" not in prompt + + def test_use_default_false_with_no_custom_yields_no_system_block( + self, fixed_today: datetime + ) -> None: + prompt = compose_system_prompt( + today=fixed_today, + use_default_system_instructions=False, + ) + # No system_instruction wrapper but tools/citations still emitted + assert "" not in prompt + assert "" in prompt + + def test_all_known_tools_have_fragments(self) -> None: + # Soft assertion: verify that every tool in the canonical order + # produces non-empty content for at least one variant. + for tool in ALL_TOOL_NAMES_ORDERED: + prompt = compose_system_prompt( + today=datetime(2025, 1, 1, tzinfo=UTC), + enabled_tool_names={tool}, + ) + assert tool in prompt, f"tool {tool!r} missing from composed prompt" + + +class TestStableOrderingForCacheStability: + """Regression guard: prompt cache hit-rate depends on byte-stable prefix.""" + + def test_composition_is_deterministic_given_same_inputs( + self, fixed_today: datetime + ) -> None: + a = compose_system_prompt( + today=fixed_today, + enabled_tool_names={"web_search", "scrape_webpage"}, + mcp_connector_tools={"X": ["x_a", "x_b"]}, + ) + b = compose_system_prompt( + today=fixed_today, + enabled_tool_names={ + "scrape_webpage", + "web_search", + }, # set order shouldn't matter + mcp_connector_tools={"X": ["x_a", "x_b"]}, + ) + assert a == b diff --git a/surfsense_backend/tests/unit/agents/new_chat/test_action_log.py b/surfsense_backend/tests/unit/agents/new_chat/test_action_log.py new file mode 100644 index 000000000..aad1524c9 --- /dev/null +++ b/surfsense_backend/tests/unit/agents/new_chat/test_action_log.py @@ -0,0 +1,317 @@ +"""Unit tests for ActionLogMiddleware (Tier 5.2).""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any +from unittest.mock import AsyncMock, patch + +import pytest +from langchain_core.messages import ToolMessage +from langchain_core.tools import tool + +from app.agents.new_chat.feature_flags import AgentFeatureFlags +from app.agents.new_chat.middleware.action_log import ActionLogMiddleware +from app.agents.new_chat.tools.registry import ToolDefinition + + +@dataclass +class _FakeRequest: + """Minimal stand-in for ToolCallRequest used in unit tests.""" + + tool_call: dict[str, Any] + tool: Any = None + state: Any = None + runtime: Any = None + + +@tool +def make_widget(color: str, size: int) -> str: + """Create a widget.""" + return f"made {color} {size}" + + +def _enabled_flags(**overrides: bool) -> AgentFeatureFlags: + return AgentFeatureFlags( + disable_new_agent_stack=False, + enable_action_log=True, + **overrides, + ) + + +def _disabled_flags() -> AgentFeatureFlags: + return AgentFeatureFlags(disable_new_agent_stack=False, enable_action_log=False) + + +@pytest.fixture +def patch_get_flags(): + def _patch(flags: AgentFeatureFlags): + return patch( + "app.agents.new_chat.middleware.action_log.get_flags", + return_value=flags, + ) + + return _patch + + +@pytest.fixture +def fake_session_factory(): + """Patch ``shielded_async_session`` with a recording fake.""" + captured: dict[str, list] = {"rows": []} + + class _FakeSession: + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc, tb): + return False + + def add(self, row): + captured["rows"].append(row) + + async def commit(self): + captured["committed"] = True + + def _factory(): + return _FakeSession() + + return captured, _factory + + +class TestActionLogMiddlewareDisabled: + @pytest.mark.asyncio + async def test_no_op_when_flag_off(self, patch_get_flags) -> None: + mw = ActionLogMiddleware(thread_id=1, search_space_id=1, user_id=None) + request = _FakeRequest( + tool_call={ + "name": "make_widget", + "args": {"color": "red", "size": 1}, + "id": "tc1", + } + ) + handler = AsyncMock(return_value=ToolMessage(content="ok", tool_call_id="tc1")) + with patch_get_flags(_disabled_flags()): + result = await mw.awrap_tool_call(request, handler) + handler.assert_awaited_once() + assert isinstance(result, ToolMessage) + + @pytest.mark.asyncio + async def test_no_op_when_thread_id_none(self, patch_get_flags) -> None: + mw = ActionLogMiddleware(thread_id=None, search_space_id=1, user_id=None) + request = _FakeRequest( + tool_call={"name": "make_widget", "args": {}, "id": "tc1"} + ) + handler = AsyncMock(return_value=ToolMessage(content="ok", tool_call_id="tc1")) + with patch_get_flags(_enabled_flags()): + result = await mw.awrap_tool_call(request, handler) + assert isinstance(result, ToolMessage) + + +class TestActionLogMiddlewarePersistence: + @pytest.mark.asyncio + async def test_writes_row_on_success( + self, patch_get_flags, fake_session_factory + ) -> None: + captured, factory = fake_session_factory + mw = ActionLogMiddleware(thread_id=42, search_space_id=7, user_id="u1") + request = _FakeRequest( + tool_call={ + "name": "make_widget", + "args": {"color": "red", "size": 3}, + "id": "tc-abc", + }, + ) + result_msg = ToolMessage(content="ok", tool_call_id="tc-abc", id="msg-1") + handler = AsyncMock(return_value=result_msg) + + with ( + patch_get_flags(_enabled_flags()), + patch("app.db.shielded_async_session", side_effect=lambda: factory()), + ): + result = await mw.awrap_tool_call(request, handler) + + assert result is result_msg + assert len(captured["rows"]) == 1 + row = captured["rows"][0] + assert row.thread_id == 42 + assert row.search_space_id == 7 + assert row.user_id == "u1" + assert row.tool_name == "make_widget" + assert row.args == {"color": "red", "size": 3} + assert row.result_id == "msg-1" + assert row.error is None + assert row.reverse_descriptor is None + assert row.reversible is False + + @pytest.mark.asyncio + async def test_writes_row_on_failure_and_reraises( + self, patch_get_flags, fake_session_factory + ) -> None: + captured, factory = fake_session_factory + mw = ActionLogMiddleware(thread_id=42, search_space_id=7, user_id="u1") + request = _FakeRequest( + tool_call={"name": "make_widget", "args": {"color": "red"}, "id": "tc1"} + ) + handler = AsyncMock(side_effect=ValueError("boom")) + + with ( + patch_get_flags(_enabled_flags()), + patch("app.db.shielded_async_session", side_effect=lambda: factory()), + pytest.raises(ValueError, match="boom"), + ): + await mw.awrap_tool_call(request, handler) + + assert len(captured["rows"]) == 1 + row = captured["rows"][0] + assert row.tool_name == "make_widget" + assert row.error == {"type": "ValueError", "message": "boom"} + assert row.result_id is None + + @pytest.mark.asyncio + async def test_persistence_failure_does_not_break_tool_call( + self, patch_get_flags + ) -> None: + """Even if the DB write blows up, the tool's result must reach the model.""" + mw = ActionLogMiddleware(thread_id=1, search_space_id=1, user_id=None) + request = _FakeRequest( + tool_call={"name": "make_widget", "args": {}, "id": "tc1"} + ) + result_msg = ToolMessage(content="ok", tool_call_id="tc1") + handler = AsyncMock(return_value=result_msg) + + def _exploding_session(): + raise RuntimeError("DB is down") + + with ( + patch_get_flags(_enabled_flags()), + patch("app.db.shielded_async_session", side_effect=_exploding_session), + ): + result = await mw.awrap_tool_call(request, handler) + assert result is result_msg + + +class TestReverseDescriptor: + @pytest.mark.asyncio + async def test_renders_reverse_descriptor_when_tool_declares_one( + self, patch_get_flags, fake_session_factory + ) -> None: + captured, factory = fake_session_factory + + def _reverse(args, result): + return {"tool": "delete_widget", "args": {"id": result["id"]}} + + tool_def = ToolDefinition( + name="make_widget", + description="Create a widget", + factory=lambda deps: make_widget, + reverse=_reverse, + ) + mw = ActionLogMiddleware( + thread_id=1, + search_space_id=1, + user_id="u", + tool_definitions={"make_widget": tool_def}, + ) + request = _FakeRequest( + tool_call={ + "name": "make_widget", + "args": {"color": "blue", "size": 1}, + "id": "tc-xyz", + }, + ) + result_msg = ToolMessage( + content='{"id": "widget-9"}', tool_call_id="tc-xyz", id="msg-9" + ) + handler = AsyncMock(return_value=result_msg) + + with ( + patch_get_flags(_enabled_flags()), + patch("app.db.shielded_async_session", side_effect=lambda: factory()), + ): + await mw.awrap_tool_call(request, handler) + + row = captured["rows"][0] + assert row.reversible is True + assert row.reverse_descriptor == { + "tool": "delete_widget", + "args": {"id": "widget-9"}, + } + + @pytest.mark.asyncio + async def test_swallows_reverse_callable_errors( + self, patch_get_flags, fake_session_factory + ) -> None: + captured, factory = fake_session_factory + + def _bad_reverse(args, result): + raise RuntimeError("reverse blew up") + + tool_def = ToolDefinition( + name="make_widget", + description="Create a widget", + factory=lambda deps: make_widget, + reverse=_bad_reverse, + ) + mw = ActionLogMiddleware( + thread_id=1, + search_space_id=1, + user_id=None, + tool_definitions={"make_widget": tool_def}, + ) + request = _FakeRequest( + tool_call={"name": "make_widget", "args": {}, "id": "tc1"} + ) + result_msg = ToolMessage(content="ok", tool_call_id="tc1") + handler = AsyncMock(return_value=result_msg) + + with ( + patch_get_flags(_enabled_flags()), + patch("app.db.shielded_async_session", side_effect=lambda: factory()), + ): + await mw.awrap_tool_call(request, handler) + + row = captured["rows"][0] + assert row.reversible is False + assert row.reverse_descriptor is None + + @pytest.mark.asyncio + async def test_no_reverse_when_tool_definition_missing( + self, patch_get_flags, fake_session_factory + ) -> None: + captured, factory = fake_session_factory + mw = ActionLogMiddleware(thread_id=1, search_space_id=1, user_id=None) + request = _FakeRequest( + tool_call={"name": "unknown_tool", "args": {}, "id": "tc1"} + ) + handler = AsyncMock(return_value=ToolMessage(content="ok", tool_call_id="tc1")) + with ( + patch_get_flags(_enabled_flags()), + patch("app.db.shielded_async_session", side_effect=lambda: factory()), + ): + await mw.awrap_tool_call(request, handler) + row = captured["rows"][0] + assert row.reversible is False + + +class TestArgsTruncation: + @pytest.mark.asyncio + async def test_huge_args_payload_is_truncated( + self, patch_get_flags, fake_session_factory + ) -> None: + captured, factory = fake_session_factory + mw = ActionLogMiddleware(thread_id=1, search_space_id=1, user_id=None) + # Build a > 32KB string so the persisted payload triggers the truncation path. + huge = "x" * (40 * 1024) + request = _FakeRequest( + tool_call={"name": "make_widget", "args": {"blob": huge}, "id": "tc1"}, + ) + handler = AsyncMock(return_value=ToolMessage(content="ok", tool_call_id="tc1")) + with ( + patch_get_flags(_enabled_flags()), + patch("app.db.shielded_async_session", side_effect=lambda: factory()), + ): + await mw.awrap_tool_call(request, handler) + row = captured["rows"][0] + assert row.args is not None + assert row.args.get("_truncated") is True + assert row.args.get("_size", 0) >= 40 * 1024 diff --git a/surfsense_backend/tests/unit/agents/new_chat/test_busy_mutex.py b/surfsense_backend/tests/unit/agents/new_chat/test_busy_mutex.py new file mode 100644 index 000000000..0c7bf17f6 --- /dev/null +++ b/surfsense_backend/tests/unit/agents/new_chat/test_busy_mutex.py @@ -0,0 +1,90 @@ +"""Tests for BusyMutexMiddleware: per-thread lock + cancel event behavior.""" + +from __future__ import annotations + +import pytest + +from app.agents.new_chat.errors import BusyError +from app.agents.new_chat.middleware.busy_mutex import ( + BusyMutexMiddleware, + get_cancel_event, + manager, + request_cancel, + reset_cancel, +) + +pytestmark = pytest.mark.unit + + +class _Runtime: + def __init__(self, thread_id: str | None) -> None: + self.config = {"configurable": {"thread_id": thread_id}} + + +@pytest.mark.asyncio +async def test_first_acquire_succeeds_and_release_unblocks() -> None: + mw = BusyMutexMiddleware() + runtime = _Runtime("t1") + await mw.abefore_agent({}, runtime) + + # Lock should now be held + lock = manager.lock_for("t1") + assert lock.locked() + + await mw.aafter_agent({}, runtime) + assert not lock.locked() + + +@pytest.mark.asyncio +async def test_second_concurrent_acquire_raises_busy() -> None: + mw_a = BusyMutexMiddleware() + mw_b = BusyMutexMiddleware() + runtime = _Runtime("t-conflict") + await mw_a.abefore_agent({}, runtime) + + with pytest.raises(BusyError) as excinfo: + await mw_b.abefore_agent({}, runtime) + assert excinfo.value.request_id == "t-conflict" + + await mw_a.aafter_agent({}, runtime) + # After release, mw_b can acquire + await mw_b.abefore_agent({}, runtime) + await mw_b.aafter_agent({}, runtime) + + +@pytest.mark.asyncio +async def test_cancel_event_lifecycle() -> None: + mw = BusyMutexMiddleware() + runtime = _Runtime("t-cancel") + + await mw.abefore_agent({}, runtime) + event = get_cancel_event("t-cancel") + assert not event.is_set() + + request_cancel("t-cancel") + assert event.is_set() + + # End of turn should reset + await mw.aafter_agent({}, runtime) + assert not event.is_set() + + +@pytest.mark.asyncio +async def test_no_thread_id_raises_when_required() -> None: + mw = BusyMutexMiddleware(require_thread_id=True) + runtime = _Runtime(None) + with pytest.raises(BusyError): + await mw.abefore_agent({}, runtime) + + +@pytest.mark.asyncio +async def test_no_thread_id_skipped_when_not_required() -> None: + mw = BusyMutexMiddleware(require_thread_id=False) + runtime = _Runtime(None) + await mw.abefore_agent({}, runtime) + await mw.aafter_agent({}, runtime) + + +def test_reset_cancel_idempotent() -> None: + # Should not raise even if event was never created + reset_cancel("never-seen") diff --git a/surfsense_backend/tests/unit/agents/new_chat/test_compaction.py b/surfsense_backend/tests/unit/agents/new_chat/test_compaction.py new file mode 100644 index 000000000..c6d4cc452 --- /dev/null +++ b/surfsense_backend/tests/unit/agents/new_chat/test_compaction.py @@ -0,0 +1,119 @@ +"""Tests for SurfSenseCompactionMiddleware: protected SystemMessage handling and content sanitization.""" + +from __future__ import annotations + +import pytest +from langchain_core.messages import ( + AIMessage, + HumanMessage, + SystemMessage, + ToolMessage, +) + +from app.agents.new_chat.middleware.compaction import ( + PROTECTED_SYSTEM_PREFIXES, + _is_protected_system_message, + _sanitize_message_content, +) + +pytestmark = pytest.mark.unit + + +class TestIsProtectedSystemMessage: + @pytest.mark.parametrize("prefix", PROTECTED_SYSTEM_PREFIXES) + def test_each_prefix_protected(self, prefix: str) -> None: + msg = SystemMessage(content=f"{prefix}\nbody\n") + assert _is_protected_system_message(msg) is True + + def test_unprotected_system_message(self) -> None: + assert ( + _is_protected_system_message(SystemMessage(content="random instructions")) + is False + ) + + def test_human_message_never_protected(self) -> None: + assert ( + _is_protected_system_message(HumanMessage(content="...")) + is False + ) + + def test_tolerates_leading_whitespace(self) -> None: + msg = SystemMessage(content=" \n\n...") + assert _is_protected_system_message(msg) is True + + +class TestSanitizeMessageContent: + def test_returns_same_message_when_content_present(self) -> None: + msg = AIMessage(content="hello") + assert _sanitize_message_content(msg) is msg + + def test_replaces_none_with_empty_string(self) -> None: + # Pydantic blocks ``content=None`` at construction; the real + # crash happens when the streaming layer mutates ``content`` + # after-the-fact. Replicate that by force-setting on a built + # message. + msg = AIMessage( + content="", + tool_calls=[{"name": "x", "args": {}, "id": "1"}], + ) + # Bypass pydantic validation to simulate the LiteLLM/Bedrock case + object.__setattr__(msg, "content", None) + sanitized = _sanitize_message_content(msg) + assert sanitized.content == "" + + +class TestPartitionMessages: + """Verify the partition override surfaces protected SystemMessages + into ``preserved_messages`` regardless of cutoff position. + """ + + def _build_partitioner(self): + # Construct a thin shim — we can't easily instantiate the full + # SurfSenseCompactionMiddleware without a real model, but the + # override path needs ``_lc_helper`` to delegate to. We mock + # that with a simple slicing partitioner equivalent to the real one. + from app.agents.new_chat.middleware.compaction import ( + SurfSenseCompactionMiddleware, + ) + + class _LcHelper: + @staticmethod + def _partition_messages(messages, cutoff): + return messages[:cutoff], messages[cutoff:] + + class _Stub(SurfSenseCompactionMiddleware): + def __init__(self): + self._lc_helper = _LcHelper() + + return _Stub() + + def test_protected_system_message_preserved_even_in_summarize_half(self) -> None: + partitioner = self._build_partitioner() + protected = SystemMessage(content="\n...") + msgs = [ + HumanMessage(content="old human"), + AIMessage(content="old ai"), + protected, + ToolMessage(content="tool 1", tool_call_id="t1"), + HumanMessage(content="new"), + ] + # Cutoff = 4 means everything before index 4 should be summarized + to_summary, preserved = partitioner._partition_messages(msgs, 4) + + assert protected not in to_summary + assert protected in preserved + # The non-protected old messages remain in to_summary + assert any( + isinstance(m, HumanMessage) and m.content == "old human" for m in to_summary + ) + + def test_unprotected_messages_unaffected(self) -> None: + partitioner = self._build_partitioner() + msgs = [ + HumanMessage(content="a"), + HumanMessage(content="b"), + HumanMessage(content="c"), + ] + to_summary, preserved = partitioner._partition_messages(msgs, 2) + assert [m.content for m in to_summary] == ["a", "b"] + assert [m.content for m in preserved] == ["c"] diff --git a/surfsense_backend/tests/unit/agents/new_chat/test_context_editing.py b/surfsense_backend/tests/unit/agents/new_chat/test_context_editing.py new file mode 100644 index 000000000..ba2246413 --- /dev/null +++ b/surfsense_backend/tests/unit/agents/new_chat/test_context_editing.py @@ -0,0 +1,108 @@ +"""Tests for SpillToBackendEdit and SpillingContextEditingMiddleware.""" + +from __future__ import annotations + +from typing import Any + +import pytest +from langchain_core.messages import AIMessage, HumanMessage, ToolMessage + +from app.agents.new_chat.middleware.context_editing import ( + SpillToBackendEdit, + _build_spill_placeholder, +) + +pytestmark = pytest.mark.unit + + +def _build_history(num_pairs: int = 6) -> list[Any]: + """Build a long history of (AIMessage with tool_call, ToolMessage) pairs.""" + msgs: list[Any] = [HumanMessage(content="please do many things")] + for i in range(num_pairs): + msgs.append( + AIMessage( + content="", + tool_calls=[ + {"name": f"tool_{i}", "args": {"i": i}, "id": f"call-{i}"}, + ], + ) + ) + msgs.append( + ToolMessage( + content="x" * 5000, + tool_call_id=f"call-{i}", + name=f"tool_{i}", + id=f"tool-msg-{i}", + ) + ) + return msgs + + +def _approx_count(messages: list[Any]) -> int: + """Trivial token counter: 1 token per 4 chars.""" + total = 0 + for msg in messages: + content = getattr(msg, "content", "") + if isinstance(content, str): + total += len(content) // 4 + return total + + +class TestSpillEdit: + def test_below_trigger_does_nothing(self) -> None: + edit = SpillToBackendEdit(trigger=1_000_000, keep=2) + msgs = _build_history(3) + original_lengths = [len(getattr(m, "content", "")) for m in msgs] + edit.apply(msgs, count_tokens=_approx_count) + new_lengths = [len(getattr(m, "content", "")) for m in msgs] + assert original_lengths == new_lengths + assert edit.pending_spills == [] + + def test_above_trigger_clears_and_records(self) -> None: + edit = SpillToBackendEdit(trigger=100, keep=1, path_prefix="/tool_outputs") + msgs = _build_history(4) + edit.apply(msgs, count_tokens=_approx_count) + + # The most-recent ToolMessage (keep=1) should remain intact + tool_messages = [m for m in msgs if isinstance(m, ToolMessage)] + intact = tool_messages[-1] + assert intact.content.startswith("x") # untouched + + # Earlier ToolMessages should now contain the placeholder text + cleared = [ + m + for m in tool_messages + if isinstance(m.content, str) and m.content.startswith("[cleared") + ] + assert len(cleared) >= 1 + # And the spill list should match + assert len(edit.pending_spills) == len(cleared) + + def test_excluded_tools_not_cleared(self) -> None: + edit = SpillToBackendEdit( + trigger=100, + keep=0, + exclude_tools=("tool_0",), + ) + msgs = _build_history(4) + edit.apply(msgs, count_tokens=_approx_count) + + first_tool = next( + m for m in msgs if isinstance(m, ToolMessage) and m.name == "tool_0" + ) + # Excluded — untouched + assert first_tool.content.startswith("x") + + def test_drain_clears_pending(self) -> None: + edit = SpillToBackendEdit(trigger=100, keep=1) + msgs = _build_history(4) + edit.apply(msgs, count_tokens=_approx_count) + first_drain = edit.drain_pending() + assert len(first_drain) > 0 + assert edit.drain_pending() == [] + + def test_placeholder_format(self) -> None: + path = "/tool_outputs/thread-1/tool-msg-0.txt" + text = _build_spill_placeholder(path) + assert path in text + assert "explore" in text # mentions the recovery agent diff --git a/surfsense_backend/tests/unit/agents/new_chat/test_dedup_tool_calls.py b/surfsense_backend/tests/unit/agents/new_chat/test_dedup_tool_calls.py new file mode 100644 index 000000000..e04f50815 --- /dev/null +++ b/surfsense_backend/tests/unit/agents/new_chat/test_dedup_tool_calls.py @@ -0,0 +1,144 @@ +"""Tests for declarative dedup_key on ToolDefinition (Tier 2.3 migration).""" + +from __future__ import annotations + +import pytest +from langchain_core.messages import AIMessage +from langchain_core.tools import StructuredTool + +from app.agents.new_chat.middleware.dedup_tool_calls import ( + DedupHITLToolCallsMiddleware, +) + +pytestmark = pytest.mark.unit + + +def _make_tool(name: str, *, dedup_key=None, hitl_dedup_key=None): + metadata = {} + if dedup_key is not None: + metadata["dedup_key"] = dedup_key + if hitl_dedup_key is not None: + metadata["hitl"] = True + metadata["hitl_dedup_key"] = hitl_dedup_key + + def _fn(**kwargs): + return "ok" + + return StructuredTool.from_function( + func=_fn, name=name, description="x", metadata=metadata + ) + + +def _msg(*calls: dict) -> AIMessage: + return AIMessage(content="", tool_calls=list(calls)) + + +class _Runtime: + pass + + +def test_callable_dedup_key_takes_priority() -> None: + tool = _make_tool( + "create_doc", + dedup_key=lambda args: f"{args.get('parent_id')}::{args.get('title')}", + ) + mw = DedupHITLToolCallsMiddleware(agent_tools=[tool]) + state = { + "messages": [ + _msg( + { + "name": "create_doc", + "args": {"parent_id": "x", "title": "y"}, + "id": "1", + }, + { + "name": "create_doc", + "args": {"parent_id": "x", "title": "y"}, + "id": "2", + }, + { + "name": "create_doc", + "args": {"parent_id": "x", "title": "z"}, + "id": "3", + }, + ) + ] + } + out = mw.after_model(state, _Runtime()) + assert out is not None + new_calls = out["messages"][0].tool_calls + assert len(new_calls) == 2 # one duplicate dropped + assert {c["id"] for c in new_calls} == {"1", "3"} + + +def test_string_hitl_dedup_key_still_works() -> None: + tool = _make_tool("send_x", hitl_dedup_key="subject") + mw = DedupHITLToolCallsMiddleware(agent_tools=[tool]) + state = { + "messages": [ + _msg( + {"name": "send_x", "args": {"subject": "Hello"}, "id": "1"}, + {"name": "send_x", "args": {"subject": "hello"}, "id": "2"}, # case + ) + ] + } + out = mw.after_model(state, _Runtime()) + assert out is not None + assert len(out["messages"][0].tool_calls) == 1 + + +def test_no_agent_tools_means_no_dedup() -> None: + """After the cleanup tier removed the legacy ``_NATIVE_HITL_TOOL_DEDUP_KEYS`` + map, dedup is purely declarative — no resolvers means no dedup runs. + + Coverage for the previously hardcoded native HITL tools now lives on + each :class:`ToolDefinition.dedup_key` in + :mod:`app.agents.new_chat.tools.registry`, which is wired through to + ``tool.metadata`` by :func:`build_tools`. + """ + mw = DedupHITLToolCallsMiddleware(agent_tools=None) + state = { + "messages": [ + _msg( + {"name": "create_notion_page", "args": {"title": "X"}, "id": "1"}, + {"name": "create_notion_page", "args": {"title": "x"}, "id": "2"}, + ) + ] + } + out = mw.after_model(state, _Runtime()) + assert out is None + + +def test_registry_propagates_dedup_key_to_tool_metadata() -> None: + """Smoke-check the wiring path that replaced the legacy native map. + + ``ToolDefinition.dedup_key`` set in the registry must be copied onto + the constructed tool's ``metadata`` so :class:`DedupHITLToolCallsMiddleware` + can pick it up at agent build time. + """ + from app.agents.new_chat.tools.registry import ( + BUILTIN_TOOLS, + wrap_dedup_key_by_arg_name, + ) + + notion_tool_defs = [t for t in BUILTIN_TOOLS if t.name == "create_notion_page"] + assert notion_tool_defs, "registry should still expose create_notion_page" + tool_def = notion_tool_defs[0] + assert tool_def.dedup_key is not None + # Same wrapping helper used in the registry — sanity check identity + sample = wrap_dedup_key_by_arg_name("title")({"title": "Plan"}) + assert sample == "plan" + + +def test_unknown_tool_passes_through() -> None: + mw = DedupHITLToolCallsMiddleware(agent_tools=None) + state = { + "messages": [ + _msg( + {"name": "anything_else", "args": {"x": 1}, "id": "1"}, + {"name": "anything_else", "args": {"x": 1}, "id": "2"}, + ) + ] + } + out = mw.after_model(state, _Runtime()) + assert out is None # no dedup configured -> kept diff --git a/surfsense_backend/tests/unit/agents/new_chat/test_default_permissions_layering.py b/surfsense_backend/tests/unit/agents/new_chat/test_default_permissions_layering.py new file mode 100644 index 000000000..ac6b5d95c --- /dev/null +++ b/surfsense_backend/tests/unit/agents/new_chat/test_default_permissions_layering.py @@ -0,0 +1,128 @@ +"""Lock in the default-allow layering used by ``chat_deepagent``. + +The agent factory wires ``PermissionMiddleware`` with three rulesets, +earliest -> latest: + +1. ``surfsense_defaults`` (single ``allow */*`` rule) +2. ``connector_synthesized`` (deny rules for tools whose required + connector is missing) +3. (future) user-defined rules from the Agent Permissions UI + +Without #1 every read-only built-in (``ls``, ``read_file``, ``grep``, +``glob``, ``web_search`` …) defaulted to ``ask`` because +``permissions.evaluate`` returns ``ask`` when no rule matches. That +caused two production-painful behaviors: + +* Resume payloads with a prior reject decision bled into innocent + read-only tool calls, raising ``RejectedError("ls")``. +* Mutating connector tools got *double* prompted — once via the + middleware ``ask`` and again via the per-tool ``interrupt()`` in + ``app.agents.new_chat.tools.hitl``. + +These tests pin the layering so a refactor that drops the default +ruleset fails loud. +""" + +from __future__ import annotations + +import pytest + +from app.agents.new_chat.permissions import ( + Rule, + Ruleset, + aggregate_action, + evaluate_many, +) + +pytestmark = pytest.mark.unit + + +def _layered_rulesets(connector_denies: list[Rule]) -> list[Ruleset]: + """Replicate ``chat_deepagent`` layering for the test.""" + return [ + Ruleset( + rules=[Rule(permission="*", pattern="*", action="allow")], + origin="surfsense_defaults", + ), + Ruleset(rules=connector_denies, origin="connector_synthesized"), + ] + + +class TestReadOnlyToolsAllowed: + """Read-only built-ins must NOT default to ask.""" + + @pytest.mark.parametrize( + "tool_name", + [ + "ls", + "read_file", + "grep", + "glob", + "web_search", + "scrape_webpage", + "search_surfsense_docs", + "get_connected_accounts", + "write_todos", + "task", + "_noop", + "invalid", + "update_memory", + ], + ) + def test_default_allow_covers_safe_builtin(self, tool_name: str) -> None: + rulesets = _layered_rulesets(connector_denies=[]) + rules = evaluate_many(tool_name, [tool_name], *rulesets) + assert aggregate_action(rules) == "allow" + + +class TestConnectorDenyOverridesDefaultAllow: + """Connector-synthesized denies must beat the default-allow rule.""" + + def test_missing_connector_tool_is_denied(self) -> None: + rulesets = _layered_rulesets( + connector_denies=[ + Rule(permission="linear_create_issue", pattern="*", action="deny") + ] + ) + rules = evaluate_many("linear_create_issue", ["linear_create_issue"], *rulesets) + assert aggregate_action(rules) == "deny" + + def test_default_allow_still_applies_to_other_tools(self) -> None: + """A deny rule for one tool must not bleed onto unrelated calls.""" + rulesets = _layered_rulesets( + connector_denies=[ + Rule(permission="linear_create_issue", pattern="*", action="deny") + ] + ) + rules = evaluate_many("ls", ["ls"], *rulesets) + assert aggregate_action(rules) == "allow" + + +class TestUserRuleOverridesDefault: + """User rules layered last must override the default-allow rule.""" + + def test_user_ask_overrides_default_allow(self) -> None: + defaults = Ruleset( + rules=[Rule(permission="*", pattern="*", action="allow")], + origin="surfsense_defaults", + ) + user_ruleset = Ruleset( + rules=[Rule(permission="ls", pattern="*", action="ask")], + origin="user", + ) + rules = evaluate_many("ls", ["ls"], defaults, user_ruleset) + assert aggregate_action(rules) == "ask" + + def test_user_deny_overrides_default_allow(self) -> None: + defaults = Ruleset( + rules=[Rule(permission="*", pattern="*", action="allow")], + origin="surfsense_defaults", + ) + user_ruleset = Ruleset( + rules=[Rule(permission="send_*", pattern="*", action="deny")], + origin="user", + ) + rules = evaluate_many( + "send_gmail_email", ["send_gmail_email"], defaults, user_ruleset + ) + assert aggregate_action(rules) == "deny" diff --git a/surfsense_backend/tests/unit/agents/new_chat/test_doom_loop.py b/surfsense_backend/tests/unit/agents/new_chat/test_doom_loop.py new file mode 100644 index 000000000..802129bf6 --- /dev/null +++ b/surfsense_backend/tests/unit/agents/new_chat/test_doom_loop.py @@ -0,0 +1,94 @@ +"""Tests for DoomLoopMiddleware signature equality detection.""" + +from __future__ import annotations + +import pytest +from langchain_core.messages import AIMessage + +from app.agents.new_chat.middleware.doom_loop import DoomLoopMiddleware, _signature + +pytestmark = pytest.mark.unit + + +def test_signature_is_stable_for_identical_args() -> None: + a = _signature("search", {"q": "hello", "n": 10}) + b = _signature("search", {"n": 10, "q": "hello"}) + assert a == b + + +def test_signature_changes_with_args() -> None: + a = _signature("search", {"q": "hello"}) + b = _signature("search", {"q": "world"}) + assert a != b + + +def test_signature_changes_with_name() -> None: + a = _signature("search", {"q": "x"}) + b = _signature("read", {"q": "x"}) + assert a != b + + +class _FakeRuntime: + def __init__(self, thread_id: str | None = "thread-1") -> None: + self.config = {"configurable": {"thread_id": thread_id}} + + +def _msg_calling(name: str, args: dict, call_id: str) -> AIMessage: + return AIMessage( + content="", + tool_calls=[{"name": name, "args": args, "id": call_id}], + ) + + +def test_threshold_triggers_after_n_identical_calls() -> None: + mw = DoomLoopMiddleware(threshold=3) + runtime = _FakeRuntime() + + # First two calls — under threshold + for i in range(2): + out = mw.after_model( + {"messages": [_msg_calling("repeat", {"x": 1}, f"call-{i}")]}, + runtime, + ) + assert out is None + + # Third identical call should trigger ``langgraph.types.interrupt``. + # In a unit-test context (no runnable graph), ``interrupt`` raises + # ``RuntimeError`` because ``get_config`` has nothing to bind to — + # we accept that as proof the interrupt path was taken (the + # alternative would be no exception, which would mean the loop + # detection never fired). + with pytest.raises(Exception) as excinfo: + mw.after_model( + {"messages": [_msg_calling("repeat", {"x": 1}, "call-3")]}, + runtime, + ) + name = type(excinfo.value).__name__.lower() + assert "interrupt" in name or "runtimeerror" in name, ( + f"Expected an interrupt-style exception, got {name}" + ) + + +def test_does_not_trigger_when_args_differ() -> None: + mw = DoomLoopMiddleware(threshold=2) + runtime = _FakeRuntime() + out = mw.after_model({"messages": [_msg_calling("repeat", {"x": 1}, "1")]}, runtime) + assert out is None + out = mw.after_model({"messages": [_msg_calling("repeat", {"x": 2}, "2")]}, runtime) + assert out is None + + +def test_separate_threads_have_independent_windows() -> None: + mw = DoomLoopMiddleware(threshold=2) + rt_a = _FakeRuntime(thread_id="A") + rt_b = _FakeRuntime(thread_id="B") + + mw.after_model({"messages": [_msg_calling("foo", {}, "1")]}, rt_a) + # thread B should NOT count thread A's call + out = mw.after_model({"messages": [_msg_calling("foo", {}, "1")]}, rt_b) + assert out is None # not yet at threshold for B + + +def test_invalid_threshold_rejected() -> None: + with pytest.raises(ValueError): + DoomLoopMiddleware(threshold=1) diff --git a/surfsense_backend/tests/unit/agents/new_chat/test_feature_flags.py b/surfsense_backend/tests/unit/agents/new_chat/test_feature_flags.py new file mode 100644 index 000000000..38a70a443 --- /dev/null +++ b/surfsense_backend/tests/unit/agents/new_chat/test_feature_flags.py @@ -0,0 +1,120 @@ +"""Tests for the agent feature-flag system.""" + +from __future__ import annotations + +import pytest + +from app.agents.new_chat.feature_flags import ( + AgentFeatureFlags, + reload_for_tests, +) + +pytestmark = pytest.mark.unit + + +def _clear_all(monkeypatch: pytest.MonkeyPatch) -> None: + for name in [ + "SURFSENSE_DISABLE_NEW_AGENT_STACK", + "SURFSENSE_ENABLE_CONTEXT_EDITING", + "SURFSENSE_ENABLE_COMPACTION_V2", + "SURFSENSE_ENABLE_RETRY_AFTER", + "SURFSENSE_ENABLE_MODEL_FALLBACK", + "SURFSENSE_ENABLE_MODEL_CALL_LIMIT", + "SURFSENSE_ENABLE_TOOL_CALL_LIMIT", + "SURFSENSE_ENABLE_TOOL_CALL_REPAIR", + "SURFSENSE_ENABLE_DOOM_LOOP", + "SURFSENSE_ENABLE_PERMISSION", + "SURFSENSE_ENABLE_BUSY_MUTEX", + "SURFSENSE_ENABLE_LLM_TOOL_SELECTOR", + "SURFSENSE_ENABLE_SKILLS", + "SURFSENSE_ENABLE_SPECIALIZED_SUBAGENTS", + "SURFSENSE_ENABLE_KB_PLANNER_RUNNABLE", + "SURFSENSE_ENABLE_ACTION_LOG", + "SURFSENSE_ENABLE_REVERT_ROUTE", + "SURFSENSE_ENABLE_PLUGIN_LOADER", + "SURFSENSE_ENABLE_OTEL", + ]: + monkeypatch.delenv(name, raising=False) + + +def test_defaults_all_off(monkeypatch: pytest.MonkeyPatch) -> None: + _clear_all(monkeypatch) + flags = reload_for_tests() + assert isinstance(flags, AgentFeatureFlags) + assert flags.disable_new_agent_stack is False + assert flags.any_new_middleware_enabled() is False + + +def test_master_kill_switch_overrides_individual_flags( + monkeypatch: pytest.MonkeyPatch, +) -> None: + _clear_all(monkeypatch) + monkeypatch.setenv("SURFSENSE_DISABLE_NEW_AGENT_STACK", "true") + monkeypatch.setenv("SURFSENSE_ENABLE_CONTEXT_EDITING", "true") + monkeypatch.setenv("SURFSENSE_ENABLE_PERMISSION", "true") + + flags = reload_for_tests() + assert flags.disable_new_agent_stack is True + assert flags.enable_context_editing is False + assert flags.enable_permission is False + assert flags.any_new_middleware_enabled() is False + + +@pytest.mark.parametrize("truthy", ["1", "true", "TRUE", "yes", "on"]) +def test_individual_flags_truthy_values( + monkeypatch: pytest.MonkeyPatch, truthy: str +) -> None: + _clear_all(monkeypatch) + monkeypatch.setenv("SURFSENSE_ENABLE_RETRY_AFTER", truthy) + flags = reload_for_tests() + assert flags.enable_retry_after is True + assert flags.any_new_middleware_enabled() is True + + +@pytest.mark.parametrize("falsy", ["0", "false", "no", "off", "", "garbage"]) +def test_individual_flags_falsy_values( + monkeypatch: pytest.MonkeyPatch, falsy: str +) -> None: + _clear_all(monkeypatch) + monkeypatch.setenv("SURFSENSE_ENABLE_RETRY_AFTER", falsy) + flags = reload_for_tests() + assert flags.enable_retry_after is False + + +def test_each_flag_can_be_set_independently(monkeypatch: pytest.MonkeyPatch) -> None: + _clear_all(monkeypatch) + flag_to_env = { + "enable_context_editing": "SURFSENSE_ENABLE_CONTEXT_EDITING", + "enable_compaction_v2": "SURFSENSE_ENABLE_COMPACTION_V2", + "enable_retry_after": "SURFSENSE_ENABLE_RETRY_AFTER", + "enable_model_fallback": "SURFSENSE_ENABLE_MODEL_FALLBACK", + "enable_model_call_limit": "SURFSENSE_ENABLE_MODEL_CALL_LIMIT", + "enable_tool_call_limit": "SURFSENSE_ENABLE_TOOL_CALL_LIMIT", + "enable_tool_call_repair": "SURFSENSE_ENABLE_TOOL_CALL_REPAIR", + "enable_doom_loop": "SURFSENSE_ENABLE_DOOM_LOOP", + "enable_permission": "SURFSENSE_ENABLE_PERMISSION", + "enable_busy_mutex": "SURFSENSE_ENABLE_BUSY_MUTEX", + "enable_llm_tool_selector": "SURFSENSE_ENABLE_LLM_TOOL_SELECTOR", + "enable_skills": "SURFSENSE_ENABLE_SKILLS", + "enable_specialized_subagents": "SURFSENSE_ENABLE_SPECIALIZED_SUBAGENTS", + "enable_kb_planner_runnable": "SURFSENSE_ENABLE_KB_PLANNER_RUNNABLE", + "enable_action_log": "SURFSENSE_ENABLE_ACTION_LOG", + "enable_revert_route": "SURFSENSE_ENABLE_REVERT_ROUTE", + "enable_plugin_loader": "SURFSENSE_ENABLE_PLUGIN_LOADER", + "enable_otel": "SURFSENSE_ENABLE_OTEL", + } + + # `enable_otel` is intentionally orthogonal — it does NOT count toward + # ``any_new_middleware_enabled`` because OTel is observability-only and + # ships under its own ``OTEL_EXPORTER_OTLP_ENDPOINT`` requirement. + counts_toward_middleware = {k for k in flag_to_env if k != "enable_otel"} + + for attr, env_name in flag_to_env.items(): + _clear_all(monkeypatch) + monkeypatch.setenv(env_name, "true") + flags = reload_for_tests() + assert getattr(flags, attr) is True, f"{attr} did not flip on for {env_name}" + if attr in counts_toward_middleware: + assert flags.any_new_middleware_enabled() is True + else: + assert flags.any_new_middleware_enabled() is False diff --git a/surfsense_backend/tests/unit/agents/new_chat/test_noop_injection.py b/surfsense_backend/tests/unit/agents/new_chat/test_noop_injection.py new file mode 100644 index 000000000..346271f4b --- /dev/null +++ b/surfsense_backend/tests/unit/agents/new_chat/test_noop_injection.py @@ -0,0 +1,123 @@ +"""Tests for NoopInjectionMiddleware provider-compat logic.""" + +from __future__ import annotations + +import pytest +from langchain_core.messages import AIMessage, HumanMessage + +from app.agents.new_chat.middleware.noop_injection import ( + NOOP_TOOL_NAME, + NoopInjectionMiddleware, + _last_ai_has_tool_calls, + _provider_needs_noop, +) + +pytestmark = pytest.mark.unit + + +class _LiteLLMModel: + def _get_ls_params(self): + return {"ls_provider": "litellm"} + + +class _BedrockModel: + def _get_ls_params(self): + return {"ls_provider": "bedrock"} + + +class _OpenAIModel: + def _get_ls_params(self): + return {"ls_provider": "openai"} + + +class _ChatLiteLLM: # name-only fallback + pass + + +class TestProviderDetection: + def test_litellm(self) -> None: + assert _provider_needs_noop(_LiteLLMModel()) is True + + def test_bedrock(self) -> None: + assert _provider_needs_noop(_BedrockModel()) is True + + def test_openai_does_not_need(self) -> None: + assert _provider_needs_noop(_OpenAIModel()) is False + + def test_class_name_fallback(self) -> None: + assert _provider_needs_noop(_ChatLiteLLM()) is True + + +class TestHistoryDetection: + def test_last_ai_has_tool_calls(self) -> None: + msgs = [ + HumanMessage(content="hi"), + AIMessage(content="", tool_calls=[{"name": "x", "args": {}, "id": "1"}]), + ] + assert _last_ai_has_tool_calls(msgs) is True + + def test_last_ai_no_tool_calls(self) -> None: + msgs = [ + HumanMessage(content="hi"), + AIMessage(content="hello"), + ] + assert _last_ai_has_tool_calls(msgs) is False + + def test_no_ai_in_history(self) -> None: + assert _last_ai_has_tool_calls([HumanMessage(content="hi")]) is False + + +class _FakeRequest: + def __init__(self, *, tools, messages, model) -> None: + self.tools = tools + self.messages = messages + self.model = model + + def override(self, *, tools): + return _FakeRequest(tools=tools, messages=self.messages, model=self.model) + + +class TestShouldInject: + def test_injects_when_all_conditions_met(self) -> None: + mw = NoopInjectionMiddleware() + msgs = [ + HumanMessage(content="hi"), + AIMessage(content="", tool_calls=[{"name": "x", "args": {}, "id": "1"}]), + ] + req = _FakeRequest(tools=[], messages=msgs, model=_LiteLLMModel()) + assert mw._should_inject(req) is True + + def test_skips_when_tools_present(self) -> None: + mw = NoopInjectionMiddleware() + req = _FakeRequest( + tools=[object()], + messages=[ + AIMessage(content="", tool_calls=[{"name": "x", "args": {}, "id": "1"}]) + ], + model=_LiteLLMModel(), + ) + assert mw._should_inject(req) is False + + def test_skips_when_no_history_tool_calls(self) -> None: + mw = NoopInjectionMiddleware() + req = _FakeRequest( + tools=[], + messages=[HumanMessage(content="hi")], + model=_LiteLLMModel(), + ) + assert mw._should_inject(req) is False + + def test_skips_for_openai(self) -> None: + mw = NoopInjectionMiddleware() + req = _FakeRequest( + tools=[], + messages=[ + AIMessage(content="", tool_calls=[{"name": "x", "args": {}, "id": "1"}]) + ], + model=_OpenAIModel(), + ) + assert mw._should_inject(req) is False + + +def test_noop_tool_name_is_underscore_noop() -> None: + assert NOOP_TOOL_NAME == "_noop" diff --git a/surfsense_backend/tests/unit/agents/new_chat/test_otel_span.py b/surfsense_backend/tests/unit/agents/new_chat/test_otel_span.py new file mode 100644 index 000000000..55434c04d --- /dev/null +++ b/surfsense_backend/tests/unit/agents/new_chat/test_otel_span.py @@ -0,0 +1,195 @@ +"""Tests for the OtelSpanMiddleware adapter.""" + +from __future__ import annotations + +from typing import Any +from unittest.mock import MagicMock + +import pytest +from langchain_core.messages import AIMessage, ToolMessage + +from app.agents.new_chat.middleware.otel_span import ( + OtelSpanMiddleware, + _annotate_model_response, + _annotate_tool_result, + _resolve_input_size, + _resolve_model_attrs, + _resolve_tool_name, +) + +pytestmark = pytest.mark.unit + + +@pytest.fixture(autouse=True) +def _disable_otel(monkeypatch: pytest.MonkeyPatch): + monkeypatch.delenv("OTEL_EXPORTER_OTLP_ENDPOINT", raising=False) + monkeypatch.setenv("SURFSENSE_DISABLE_OTEL", "true") + from app.observability import otel as ot + + ot.reload_for_tests() + yield + ot.reload_for_tests() + + +class TestResolveModelAttrs: + def test_extracts_model_name_and_provider(self) -> None: + request = MagicMock() + request.model = MagicMock(spec=["model_name", "provider"]) + request.model.model_name = "gpt-4o-mini" + request.model.provider = "openai" + assert _resolve_model_attrs(request) == ("gpt-4o-mini", "openai") + + def test_handles_missing_model(self) -> None: + request = MagicMock() + request.model = None + assert _resolve_model_attrs(request) == (None, None) + + def test_falls_back_through_attribute_chain(self) -> None: + request = MagicMock() + request.model = MagicMock(spec=["model_id", "_llm_type"]) + request.model.model_id = "claude-3-5-sonnet" + request.model._llm_type = "anthropic-chat" + model_id, provider = _resolve_model_attrs(request) + assert model_id == "claude-3-5-sonnet" + assert provider == "anthropic-chat" + + +class TestResolveToolName: + def test_prefers_request_tool_name(self) -> None: + request = MagicMock() + request.tool = MagicMock(name="ToolStub") + request.tool.name = "scrape_webpage" + assert _resolve_tool_name(request) == "scrape_webpage" + + def test_falls_back_to_tool_call_name(self) -> None: + request = MagicMock() + request.tool = None + request.tool_call = {"name": "web_search", "args": {}} + assert _resolve_tool_name(request) == "web_search" + + def test_unknown_when_nothing_resolves(self) -> None: + request = MagicMock() + request.tool = None + request.tool_call = {} + assert _resolve_tool_name(request) == "unknown" + + +class TestResolveInputSize: + def test_returns_repr_length_of_args(self) -> None: + request = MagicMock() + request.tool_call = {"args": {"query": "hello world"}} + size = _resolve_input_size(request) + assert isinstance(size, int) + assert size > 0 + + def test_handles_no_tool_call(self) -> None: + request = MagicMock() + request.tool_call = None + assert _resolve_input_size(request) is None + + +class TestAnnotateModelResponse: + def test_attaches_token_counts_when_present(self) -> None: + sp = MagicMock() + msg = AIMessage( + content="hello", + usage_metadata={ + "input_tokens": 100, + "output_tokens": 50, + "total_tokens": 150, + }, + ) + _annotate_model_response(sp, msg) + sp.set_attribute.assert_any_call("tokens.prompt", 100) + sp.set_attribute.assert_any_call("tokens.completion", 50) + sp.set_attribute.assert_any_call("tokens.total", 150) + + def test_handles_response_with_no_metadata(self) -> None: + sp = MagicMock() + msg = AIMessage(content="hello") + # Should not raise even when usage_metadata is missing + _annotate_model_response(sp, msg) + + +class TestAnnotateToolResult: + def test_records_size_and_status(self) -> None: + sp = MagicMock() + result = ToolMessage( + content="result text", + tool_call_id="abc", + status="success", + ) + _annotate_tool_result(sp, result) + sp.set_attribute.assert_any_call("tool.output.size", len("result text")) + sp.set_attribute.assert_any_call("tool.status", "success") + + def test_marks_errors(self) -> None: + sp = MagicMock() + result = ToolMessage( + content="oops", + tool_call_id="abc", + additional_kwargs={"error": {"code": "x"}}, + ) + _annotate_tool_result(sp, result) + sp.set_attribute.assert_any_call("tool.error", True) + + +@pytest.mark.asyncio +class TestMiddlewareIntegration: + async def test_awrap_model_call_passes_through_when_disabled(self) -> None: + mw = OtelSpanMiddleware() + called: dict[str, Any] = {} + + async def handler(req): + called["req"] = req + return AIMessage(content="ok") + + request = MagicMock() + result = await mw.awrap_model_call(request, handler) + assert called["req"] is request + assert isinstance(result, AIMessage) + assert result.content == "ok" + + async def test_awrap_tool_call_passes_through_when_disabled(self) -> None: + mw = OtelSpanMiddleware() + + async def handler(req): + return ToolMessage(content="result", tool_call_id="abc") + + request = MagicMock() + result = await mw.awrap_tool_call(request, handler) + assert isinstance(result, ToolMessage) + assert result.content == "result" + + async def test_awrap_model_call_propagates_exceptions(self) -> None: + mw = OtelSpanMiddleware() + + async def handler(req): + raise ValueError("boom") + + with pytest.raises(ValueError): + await mw.awrap_model_call(MagicMock(), handler) + + async def test_with_otel_enabled_does_not_alter_result( + self, monkeypatch: pytest.MonkeyPatch + ) -> None: + monkeypatch.delenv("SURFSENSE_DISABLE_OTEL", raising=False) + monkeypatch.setenv("OTEL_EXPORTER_OTLP_ENDPOINT", "http://localhost:4317") + from app.observability import otel as ot + + ot.reload_for_tests() + try: + mw = OtelSpanMiddleware() + + async def handler(req): + return AIMessage(content="enabled") + + request = MagicMock() + request.model = MagicMock() + request.model.model_name = "gpt-4o" + request.model.provider = "openai" + result = await mw.awrap_model_call(request, handler) + assert isinstance(result, AIMessage) + assert result.content == "enabled" + finally: + ot.reload_for_tests() diff --git a/surfsense_backend/tests/unit/agents/new_chat/test_path_resolver.py b/surfsense_backend/tests/unit/agents/new_chat/test_path_resolver.py new file mode 100644 index 000000000..ddb20330d --- /dev/null +++ b/surfsense_backend/tests/unit/agents/new_chat/test_path_resolver.py @@ -0,0 +1,198 @@ +"""Tests for canonical virtual-path resolver helpers.""" + +from __future__ import annotations + +from types import SimpleNamespace +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from app.agents.new_chat.path_resolver import ( + DOCUMENTS_ROOT, + PathIndex, + doc_to_virtual_path, + parse_doc_id_suffix, + parse_documents_path, + safe_filename, + safe_folder_segment, + virtual_path_to_doc, +) + +pytestmark = pytest.mark.unit + + +class TestSafeFilename: + def test_appends_xml_extension(self): + assert safe_filename("notes").endswith(".xml") + + def test_strips_invalid_chars(self): + assert "/" not in safe_filename("a/b\\c.xml") + + def test_falls_back_when_empty(self): + assert safe_filename("").endswith(".xml") + assert safe_filename("///") == "untitled.xml" or safe_filename("///").endswith( + ".xml" + ) + + +class TestSafeFolderSegment: + def test_strips_path_separators(self): + assert "/" not in safe_folder_segment("a/b") + + def test_falls_back(self): + assert safe_folder_segment("") == "folder" + + +class TestParseDocIdSuffix: + def test_parses_suffix(self): + stem, doc_id = parse_doc_id_suffix("My Doc (42).xml") + assert stem == "My Doc" + assert doc_id == 42 + + def test_no_suffix_returns_none(self): + stem, doc_id = parse_doc_id_suffix("My Doc.xml") + assert stem == "My Doc" + assert doc_id is None + + def test_no_xml_extension(self): + stem, doc_id = parse_doc_id_suffix("plain") + assert stem == "plain" + assert doc_id is None + + +class TestDocToVirtualPath: + def test_root_when_no_folder(self): + index = PathIndex() + path = doc_to_virtual_path(doc_id=1, title="Hello", folder_id=None, index=index) + assert path == f"{DOCUMENTS_ROOT}/Hello.xml" + assert index.occupants[path] == 1 + + def test_collision_picks_doc_id_suffix(self): + index = PathIndex(occupants={f"{DOCUMENTS_ROOT}/Hello.xml": 7}) + path = doc_to_virtual_path(doc_id=8, title="Hello", folder_id=None, index=index) + assert path == f"{DOCUMENTS_ROOT}/Hello (8).xml" + assert index.occupants[path] == 8 + + def test_uses_folder_path_when_known(self): + index = PathIndex(folder_paths={5: f"{DOCUMENTS_ROOT}/notes"}) + path = doc_to_virtual_path(doc_id=2, title="A", folder_id=5, index=index) + assert path == f"{DOCUMENTS_ROOT}/notes/A.xml" + + +class TestParseDocumentsPath: + def test_extracts_folder_parts_and_title(self): + parts, title = parse_documents_path(f"{DOCUMENTS_ROOT}/foo/bar/baz.xml") + assert parts == ["foo", "bar"] + assert title == "baz" + + def test_strips_doc_id_suffix(self): + parts, title = parse_documents_path(f"{DOCUMENTS_ROOT}/foo/My Doc (12).xml") + assert parts == ["foo"] + assert title == "My Doc" + + def test_non_documents_returns_empty(self): + assert parse_documents_path("/other/x.xml") == ([], "") + + +def _result_from_scalars(rows: list): + """Build a fake SQLAlchemy ``Result`` whose ``.scalars().all()`` and + ``.scalars().first()`` yield ``rows``.""" + scalars = MagicMock() + scalars.all.return_value = list(rows) + scalars.first.return_value = rows[0] if rows else None + result = MagicMock() + result.scalars.return_value = scalars + result.scalar_one_or_none.return_value = None + result.first.return_value = None + return result + + +def _result_from_one(value): + result = MagicMock() + result.scalar_one_or_none.return_value = value + return result + + +class TestVirtualPathToDoc: + """Lookup must round-trip through ``safe_filename``'s lossy encoding. + + The workspace tree displays ``safe_filename(title)`` as the basename, so + when the agent passes that basename back to a tool (move/edit/read) the + resolver must find the original document even though characters like + ``:`` were replaced with ``_``. + """ + + @pytest.mark.asyncio + async def test_falls_back_to_safe_filename_match_when_title_lossy(self): + # A Google Calendar-style title that contains a colon — safe_filename + # rewrites the colon to ``_``, so the literal title-equality lookup + # would miss this row. + original_title = "Calendar: Happy birthday!" + encoded_basename = safe_filename(original_title) + assert encoded_basename == "Calendar_ Happy birthday!.xml" + + target_doc = SimpleNamespace(id=42, title=original_title, folder_id=None) + + session = MagicMock() + # Each ``await session.execute(...)`` returns a fresh canned result. + # Order matches the resolver's lookup steps: + # 1) unique_identifier_hash → no match + # 2) literal title match → no match (lossy encoding) + # 3) folder scan → returns the row whose title encodes to basename + session.execute = AsyncMock( + side_effect=[ + _result_from_one(None), + _result_from_scalars([]), + _result_from_scalars([target_doc]), + ] + ) + + document = await virtual_path_to_doc( + session, + search_space_id=5, + virtual_path=f"{DOCUMENTS_ROOT}/{encoded_basename}", + ) + assert document is target_doc + + @pytest.mark.asyncio + async def test_returns_none_when_no_doc_matches_safe_filename(self): + session = MagicMock() + session.execute = AsyncMock( + side_effect=[ + _result_from_one(None), + _result_from_scalars([]), + _result_from_scalars( + [SimpleNamespace(id=1, title="Something else", folder_id=None)] + ), + ] + ) + + document = await virtual_path_to_doc( + session, + search_space_id=5, + virtual_path=f"{DOCUMENTS_ROOT}/Calendar_ Happy birthday!.xml", + ) + assert document is None + + @pytest.mark.asyncio + async def test_literal_title_match_short_circuits_fallback(self): + # When the literal title query hits, the folder-scan fallback must + # NOT run (saves a query and avoids picking the wrong doc when two + # rows share a lossy encoding). + target_doc = SimpleNamespace(id=7, title="Plain Note", folder_id=None) + + session = MagicMock() + session.execute = AsyncMock( + side_effect=[ + _result_from_one(None), + _result_from_scalars([target_doc]), + ] + ) + + document = await virtual_path_to_doc( + session, + search_space_id=5, + virtual_path=f"{DOCUMENTS_ROOT}/Plain Note.xml", + ) + assert document is target_doc + assert session.execute.await_count == 2 diff --git a/surfsense_backend/tests/unit/agents/new_chat/test_permission_middleware.py b/surfsense_backend/tests/unit/agents/new_chat/test_permission_middleware.py new file mode 100644 index 000000000..a997c8d61 --- /dev/null +++ b/surfsense_backend/tests/unit/agents/new_chat/test_permission_middleware.py @@ -0,0 +1,114 @@ +"""Tests for PermissionMiddleware end-to-end behavior.""" + +from __future__ import annotations + +import pytest +from langchain_core.messages import AIMessage, ToolMessage + +from app.agents.new_chat.errors import CorrectedError, RejectedError +from app.agents.new_chat.middleware.permission import PermissionMiddleware +from app.agents.new_chat.permissions import Rule, Ruleset + +pytestmark = pytest.mark.unit + + +class _FakeRuntime: + config: dict = {"configurable": {"thread_id": "test"}} + + +def _msg(*tool_calls: dict) -> AIMessage: + return AIMessage(content="", tool_calls=list(tool_calls)) + + +class TestAllow: + def test_passthrough_when_allow(self) -> None: + rs = Ruleset(rules=[Rule("send_email", "*", "allow")]) + mw = PermissionMiddleware(rulesets=[rs]) + state = {"messages": [_msg({"name": "send_email", "args": {}, "id": "1"})]} + out = mw.after_model(state, _FakeRuntime()) + assert out is None # no change + + +class TestDeny: + def test_replaces_with_deny_tool_message(self) -> None: + rs = Ruleset(rules=[Rule("send_email", "*", "deny")]) + mw = PermissionMiddleware(rulesets=[rs]) + state = {"messages": [_msg({"name": "send_email", "args": {}, "id": "1"})]} + out = mw.after_model(state, _FakeRuntime()) + assert out is not None + msgs = out["messages"] + # Find the deny ToolMessage + deny_msgs = [m for m in msgs if isinstance(m, ToolMessage)] + assert len(deny_msgs) == 1 + assert deny_msgs[0].status == "error" + assert "permission_denied" in str(deny_msgs[0].additional_kwargs) + # AIMessage's tool_calls should now be empty (denied call removed) + ai_msg = next(m for m in msgs if isinstance(m, AIMessage)) + assert ai_msg.tool_calls == [] + + def test_mixed_allow_deny(self) -> None: + rs = Ruleset( + rules=[ + Rule("send_email", "*", "deny"), + Rule("read", "*", "allow"), + ] + ) + mw = PermissionMiddleware(rulesets=[rs]) + state = { + "messages": [ + _msg( + {"name": "send_email", "args": {}, "id": "1"}, + {"name": "read", "args": {}, "id": "2"}, + ) + ] + } + out = mw.after_model(state, _FakeRuntime()) + assert out is not None + ai_msg = next(m for m in out["messages"] if isinstance(m, AIMessage)) + assert len(ai_msg.tool_calls) == 1 + assert ai_msg.tool_calls[0]["name"] == "read" + + +class TestAsk: + def test_reject_without_feedback_raises(self) -> None: + # Default: nothing matches -> ask + rs = Ruleset(rules=[]) + mw = PermissionMiddleware(rulesets=[rs]) + + # Bypass real interrupt — patch the helper + mw._raise_interrupt = lambda **kw: {"decision_type": "reject"} # type: ignore[assignment] + state = {"messages": [_msg({"name": "send_email", "args": {}, "id": "1"})]} + with pytest.raises(RejectedError): + mw.after_model(state, _FakeRuntime()) + + def test_reject_with_feedback_raises_corrected(self) -> None: + rs = Ruleset(rules=[]) + mw = PermissionMiddleware(rulesets=[rs]) + mw._raise_interrupt = lambda **kw: { # type: ignore[assignment] + "decision_type": "reject", + "feedback": "use a different subject line", + } + state = {"messages": [_msg({"name": "send_email", "args": {}, "id": "1"})]} + with pytest.raises(CorrectedError) as excinfo: + mw.after_model(state, _FakeRuntime()) + assert excinfo.value.feedback == "use a different subject line" + + def test_once_proceeds_without_persisting(self) -> None: + mw = PermissionMiddleware(rulesets=[]) + mw._raise_interrupt = lambda **kw: {"decision_type": "once"} # type: ignore[assignment] + state = {"messages": [_msg({"name": "send_email", "args": {}, "id": "1"})]} + out = mw.after_model(state, _FakeRuntime()) + # No state change because all calls kept + assert out is None + # No new rule persisted + assert mw._runtime_ruleset.rules == [] + + def test_always_persists_runtime_rule(self) -> None: + mw = PermissionMiddleware(rulesets=[]) + mw._raise_interrupt = lambda **kw: {"decision_type": "always"} # type: ignore[assignment] + state = {"messages": [_msg({"name": "send_email", "args": {}, "id": "1"})]} + out = mw.after_model(state, _FakeRuntime()) + assert out is None # call kept + # Runtime ruleset got the always-allow rule + new_rules = [r for r in mw._runtime_ruleset.rules if r.action == "allow"] + assert any(r.permission == "send_email" for r in new_rules) diff --git a/surfsense_backend/tests/unit/agents/new_chat/test_permissions.py b/surfsense_backend/tests/unit/agents/new_chat/test_permissions.py new file mode 100644 index 000000000..8ec16617a --- /dev/null +++ b/surfsense_backend/tests/unit/agents/new_chat/test_permissions.py @@ -0,0 +1,111 @@ +"""Tests for the wildcard matcher and rule evaluator (parity with OpenCode evaluate.ts).""" + +from __future__ import annotations + +import pytest + +from app.agents.new_chat.permissions import ( + Rule, + Ruleset, + aggregate_action, + evaluate, + evaluate_many, + wildcard_match, +) + +pytestmark = pytest.mark.unit + + +class TestWildcardMatch: + @pytest.mark.parametrize( + "value,pattern,expected", + [ + ("edit", "edit", True), + ("edit", "*", True), + ("read", "edit", False), + ("/documents/secrets/x", "/documents/secrets/**", True), + # Single-segment glob: '*' does not cross '/' + ("/documents/secrets/x", "/documents/*/x", True), + ("/documents/foo/bar/x", "/documents/*/x", False), + ("/documents/foo/x", "/documents/*/x", True), + ("linear_create", "linear_*", True), + ("notion_create", "linear_*", False), + # ':' is not a separator, so '*' matches it + ("mcp:notion:create_page", "mcp:*", True), + ("mcp:notion:create_page", "mcp:**", True), + # But '/' IS a separator + ("foo/bar", "foo/*", True), + ("foo/bar/baz", "foo/*", False), + ], + ) + def test_match(self, value: str, pattern: str, expected: bool) -> None: + assert wildcard_match(value, pattern) is expected + + +class TestEvaluate: + def test_default_action_is_ask(self) -> None: + rule = evaluate("edit", "/foo/bar") + assert rule.action == "ask" + assert rule.permission == "edit" + + def test_last_match_wins(self) -> None: + rs = Ruleset( + rules=[ + Rule("edit", "*", "allow"), + Rule("edit", "/secrets/**", "deny"), + ] + ) + # Second rule (deny) is more specific AND specified later + assert evaluate("edit", "/secrets/x", rs).action == "deny" + # First rule (allow) covers the rest + assert evaluate("edit", "/public/x", rs).action == "allow" + + def test_layered_rulesets_later_overrides_earlier(self) -> None: + defaults = Ruleset(rules=[Rule("edit", "*", "ask")], origin="defaults") + space = Ruleset(rules=[Rule("edit", "*", "allow")], origin="space") + thread = Ruleset(rules=[Rule("edit", "*", "deny")], origin="thread") + # All three layered: thread wins + assert evaluate("edit", "x", defaults, space, thread).action == "deny" + # Without thread: space wins + assert evaluate("edit", "x", defaults, space).action == "allow" + + def test_permission_wildcard(self) -> None: + rs = Ruleset(rules=[Rule("linear_*", "*", "allow")]) + assert evaluate("linear_create_issue", "x", rs).action == "allow" + assert evaluate("notion_create", "x", rs).action == "ask" + + def test_pattern_wildcard(self) -> None: + rs = Ruleset(rules=[Rule("edit", "/documents/secrets/**", "deny")]) + assert evaluate("edit", "/documents/secrets/foo", rs).action == "deny" + assert evaluate("edit", "/documents/public/foo", rs).action == "ask" + + def test_evaluate_many(self) -> None: + rs = Ruleset( + rules=[ + Rule("edit", "*", "allow"), + Rule("edit", "/secrets/*", "deny"), + ] + ) + results = evaluate_many("edit", ["/public/x", "/secrets/y"], rs) + assert [r.action for r in results] == ["allow", "deny"] + + +class TestAggregateAction: + def test_any_deny_means_deny(self) -> None: + rules = [ + Rule("a", "*", "allow"), + Rule("a", "*", "deny"), + Rule("a", "*", "ask"), + ] + assert aggregate_action(rules) == "deny" + + def test_any_ask_means_ask_when_no_deny(self) -> None: + rules = [Rule("a", "*", "allow"), Rule("a", "*", "ask")] + assert aggregate_action(rules) == "ask" + + def test_all_allow_means_allow(self) -> None: + rules = [Rule("a", "*", "allow"), Rule("a", "*", "allow")] + assert aggregate_action(rules) == "allow" + + def test_empty_means_ask(self) -> None: + assert aggregate_action([]) == "ask" diff --git a/surfsense_backend/tests/unit/agents/new_chat/test_plugin_loader.py b/surfsense_backend/tests/unit/agents/new_chat/test_plugin_loader.py new file mode 100644 index 000000000..5dbf765a7 --- /dev/null +++ b/surfsense_backend/tests/unit/agents/new_chat/test_plugin_loader.py @@ -0,0 +1,185 @@ +"""Unit tests for the SurfSense plugin entry-point loader.""" + +from __future__ import annotations + +from unittest.mock import MagicMock, patch + +from langchain.agents.middleware import AgentMiddleware + +from app.agents.new_chat.plugin_loader import ( + PLUGIN_ENTRY_POINT_GROUP, + PluginContext, + load_allowed_plugin_names_from_env, + load_plugin_middlewares, +) +from app.agents.new_chat.plugins.year_substituter import ( + _YearSubstituterMiddleware, + make_middleware as year_substituter_factory, +) + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +class _DummyMiddleware(AgentMiddleware): + """Trivial middleware used as the success-path return value.""" + + tools = () + + +def _ctx() -> PluginContext: + return PluginContext.build( + search_space_id=1, + user_id="u", + thread_visibility="PRIVATE", # type: ignore[arg-type] + llm=MagicMock(), + ) + + +class _FakeEntryPoint: + """Stand-in for ``importlib.metadata.EntryPoint``.""" + + def __init__(self, name: str, factory) -> None: + self.name = name + self._factory = factory + + def load(self): + return self._factory + + +# --------------------------------------------------------------------------- +# Loader behaviour +# --------------------------------------------------------------------------- + + +class TestPluginLoaderBasics: + def test_returns_empty_when_allowlist_is_empty(self) -> None: + assert load_plugin_middlewares(_ctx(), allowed_plugin_names=[]) == [] + + def test_skips_non_allowlisted_plugin(self) -> None: + called = [] + + def factory(_): # would be an obvious bug if called + called.append(True) + return _DummyMiddleware() + + ep = _FakeEntryPoint("dangerous_plugin", factory) + with patch( + "app.agents.new_chat.plugin_loader.entry_points", + return_value=[ep], + ): + result = load_plugin_middlewares( + _ctx(), allowed_plugin_names=["allowed_only"] + ) + assert result == [] + assert not called + + def test_loads_allowlisted_plugin(self) -> None: + ep = _FakeEntryPoint("year_substituter", year_substituter_factory) + with patch( + "app.agents.new_chat.plugin_loader.entry_points", + return_value=[ep], + ): + result = load_plugin_middlewares( + _ctx(), allowed_plugin_names={"year_substituter"} + ) + assert len(result) == 1 + assert isinstance(result[0], _YearSubstituterMiddleware) + + +class TestPluginLoaderIsolation: + def test_factory_exception_is_isolated(self) -> None: + def crashing_factory(_): + raise RuntimeError("boom") + + ep = _FakeEntryPoint("buggy", crashing_factory) + with patch( + "app.agents.new_chat.plugin_loader.entry_points", + return_value=[ep], + ): + result = load_plugin_middlewares(_ctx(), allowed_plugin_names={"buggy"}) + assert result == [] # construction continued without the plugin + + def test_non_middleware_return_is_rejected(self) -> None: + def bad_factory(_): + return "not a middleware" + + ep = _FakeEntryPoint("liar", bad_factory) + with patch( + "app.agents.new_chat.plugin_loader.entry_points", + return_value=[ep], + ): + result = load_plugin_middlewares(_ctx(), allowed_plugin_names={"liar"}) + assert result == [] + + def test_load_phase_exception_is_isolated(self) -> None: + class _BrokenEP: + name = "broken" + + def load(self): + raise ImportError("cannot import") + + with patch( + "app.agents.new_chat.plugin_loader.entry_points", + return_value=[_BrokenEP()], + ): + result = load_plugin_middlewares(_ctx(), allowed_plugin_names={"broken"}) + assert result == [] + + def test_one_failure_does_not_block_others(self) -> None: + """Two plugins; one crashes during factory; the other still loads.""" + + def crashing_factory(_): + raise RuntimeError("boom") + + eps = [ + _FakeEntryPoint("crashing", crashing_factory), + _FakeEntryPoint("ok", year_substituter_factory), + ] + with patch("app.agents.new_chat.plugin_loader.entry_points", return_value=eps): + result = load_plugin_middlewares( + _ctx(), allowed_plugin_names={"crashing", "ok"} + ) + assert len(result) == 1 + assert isinstance(result[0], _YearSubstituterMiddleware) + + +class TestAllowlistEnv: + def test_empty_env_returns_empty_set(self, monkeypatch) -> None: + monkeypatch.delenv("SURFSENSE_ALLOWED_PLUGINS", raising=False) + assert load_allowed_plugin_names_from_env() == set() + + def test_parses_comma_separated_value(self, monkeypatch) -> None: + monkeypatch.setenv("SURFSENSE_ALLOWED_PLUGINS", " year_substituter , noisy , ") + assert load_allowed_plugin_names_from_env() == { + "year_substituter", + "noisy", + } + + +class TestPluginContext: + def test_build_includes_required_fields(self) -> None: + llm = MagicMock() + ctx = PluginContext.build( + search_space_id=42, + user_id="user-1", + thread_visibility="PRIVATE", # type: ignore[arg-type] + llm=llm, + ) + assert ctx["search_space_id"] == 42 + assert ctx["user_id"] == "user-1" + assert ctx["llm"] is llm + + def test_does_not_carry_secrets_or_db_session(self) -> None: + ctx = _ctx() + # If a future change tries to add these keys, this test will fail loudly. + for forbidden in ("api_key", "secret", "db_session", "session"): + assert forbidden not in ctx + + +class TestEntryPointGroup: + def test_group_name_matches_pyproject_convention(self) -> None: + # Plugins register under `surfsense.plugins`; this is part of our + # public contract for plugin authors. + assert PLUGIN_ENTRY_POINT_GROUP == "surfsense.plugins" diff --git a/surfsense_backend/tests/unit/agents/new_chat/test_retry_after.py b/surfsense_backend/tests/unit/agents/new_chat/test_retry_after.py new file mode 100644 index 000000000..d23fd693b --- /dev/null +++ b/surfsense_backend/tests/unit/agents/new_chat/test_retry_after.py @@ -0,0 +1,107 @@ +"""Tests for RetryAfterMiddleware Retry-After parsing and retry decision logic.""" + +from __future__ import annotations + +import pytest + +from app.agents.new_chat.middleware.retry_after import ( + RetryAfterMiddleware, + _extract_retry_after_seconds, + _is_non_retryable, +) + +pytestmark = pytest.mark.unit + + +class _FakeResponse: + def __init__(self, headers: dict[str, str]) -> None: + self.headers = headers + + +class _FakeRateLimitError(Exception): + def __init__(self, msg: str, headers: dict[str, str] | None = None) -> None: + super().__init__(msg) + if headers is not None: + self.response = _FakeResponse(headers) + + +class TestExtractRetryAfter: + def test_seconds_header(self) -> None: + exc = _FakeRateLimitError("rate", {"Retry-After": "30"}) + assert _extract_retry_after_seconds(exc) == 30.0 + + def test_milliseconds_header_overrides_seconds(self) -> None: + exc = _FakeRateLimitError("rate", {"retry-after-ms": "1500"}) + assert _extract_retry_after_seconds(exc) == 1.5 + + def test_case_insensitive(self) -> None: + exc = _FakeRateLimitError("rate", {"RETRY-AFTER": "12"}) + assert _extract_retry_after_seconds(exc) == 12.0 + + def test_falls_back_to_message_regex(self) -> None: + exc = Exception("Please retry after 7 seconds") + assert _extract_retry_after_seconds(exc) == 7.0 + + def test_returns_none_when_no_hint(self) -> None: + exc = Exception("oops") + assert _extract_retry_after_seconds(exc) is None + + def test_handles_missing_headers_attr(self) -> None: + exc = ValueError("no headers") + assert _extract_retry_after_seconds(exc) is None + + +class TestIsNonRetryable: + @pytest.mark.parametrize( + "name", + ["ContextWindowExceededError", "AuthenticationError", "InvalidRequestError"], + ) + def test_non_retryable_classes(self, name: str) -> None: + cls = type(name, (Exception,), {}) + assert _is_non_retryable(cls("x")) is True + + def test_generic_exception_is_retryable(self) -> None: + assert _is_non_retryable(RuntimeError("transient")) is False + + +class TestDelayCalculation: + def test_takes_max_of_backoff_and_header(self) -> None: + mw = RetryAfterMiddleware(max_retries=3, initial_delay=1.0, jitter=False) + exc = _FakeRateLimitError("rl", {"retry-after": "10"}) + delay = mw._delay_for_attempt(0, exc) + assert delay == pytest.approx(10.0) + + def test_uses_backoff_when_no_header(self) -> None: + mw = RetryAfterMiddleware( + max_retries=3, initial_delay=2.0, backoff_factor=2.0, jitter=False + ) + delay = mw._delay_for_attempt(2, RuntimeError("transient")) + # 2 * 2^2 = 8 + assert delay == pytest.approx(8.0) + + def test_caps_at_max_delay(self) -> None: + mw = RetryAfterMiddleware( + max_retries=3, + initial_delay=10.0, + backoff_factor=10.0, + max_delay=15.0, + jitter=False, + ) + delay = mw._delay_for_attempt(5, RuntimeError("x")) + assert delay <= 15.0 + + +class TestShouldRetry: + def test_default_retries_generic(self) -> None: + mw = RetryAfterMiddleware() + assert mw._should_retry(RuntimeError("transient")) is True + + def test_default_skips_non_retryable(self) -> None: + mw = RetryAfterMiddleware() + cls = type("ContextWindowExceededError", (Exception,), {}) + assert mw._should_retry(cls("too big")) is False + + def test_custom_retry_on(self) -> None: + mw = RetryAfterMiddleware(retry_on=lambda exc: isinstance(exc, ValueError)) + assert mw._should_retry(ValueError()) is True + assert mw._should_retry(KeyError()) is False diff --git a/surfsense_backend/tests/unit/agents/new_chat/test_skills_backends.py b/surfsense_backend/tests/unit/agents/new_chat/test_skills_backends.py new file mode 100644 index 000000000..eb9cf396c --- /dev/null +++ b/surfsense_backend/tests/unit/agents/new_chat/test_skills_backends.py @@ -0,0 +1,242 @@ +"""Tests for the skills backends used by SurfSense's SkillsMiddleware.""" + +from __future__ import annotations + +import asyncio +from pathlib import Path + +import pytest + +from app.agents.new_chat.middleware.skills_backends import ( + SKILLS_BUILTIN_PREFIX, + SKILLS_SPACE_PREFIX, + BuiltinSkillsBackend, + SearchSpaceSkillsBackend, + build_skills_backend_factory, + default_skills_sources, +) + + +@pytest.fixture +def skills_root(tmp_path: Path) -> Path: + """Build a small synthetic skill-tree used by the tests.""" + root = tmp_path / "skills" + (root / "alpha").mkdir(parents=True) + (root / "alpha" / "SKILL.md").write_text( + "---\nname: alpha\ndescription: alpha skill\n---\n# Alpha\n" + ) + (root / "beta").mkdir(parents=True) + (root / "beta" / "SKILL.md").write_text( + "---\nname: beta\ndescription: beta skill\n---\n# Beta\n" + ) + (root / "_orphan_file.md").write_text("not a skill, just a stray file") + return root + + +class TestBuiltinSkillsBackendListing: + def test_lists_skill_directories_at_root(self, skills_root: Path) -> None: + backend = BuiltinSkillsBackend(skills_root) + infos = backend.ls_info("/") + names = {info["path"] for info in infos} + assert "/alpha" in names + assert "/beta" in names + assert "/_orphan_file.md" in names + for info in infos: + if info["path"] in {"/alpha", "/beta"}: + assert info["is_dir"] is True + + def test_lists_skill_md_under_skill_directory(self, skills_root: Path) -> None: + backend = BuiltinSkillsBackend(skills_root) + infos = backend.ls_info("/alpha") + paths = {info["path"] for info in infos} + assert paths == {"/alpha/SKILL.md"} + assert infos[0]["is_dir"] is False + assert infos[0]["size"] > 0 + + def test_returns_empty_for_missing_path(self, skills_root: Path) -> None: + backend = BuiltinSkillsBackend(skills_root) + assert backend.ls_info("/nonexistent") == [] + + def test_returns_empty_when_root_missing(self, tmp_path: Path) -> None: + backend = BuiltinSkillsBackend(tmp_path / "definitely-missing") + assert backend.ls_info("/") == [] + assert backend.download_files(["/x/SKILL.md"])[0].error == "file_not_found" + + def test_refuses_path_traversal(self, skills_root: Path) -> None: + backend = BuiltinSkillsBackend(skills_root) + assert backend.ls_info("/../../../etc") == [] + responses = backend.download_files(["/../../../etc/passwd"]) + assert responses[0].error == "invalid_path" + + +class TestBuiltinSkillsBackendDownload: + def test_downloads_skill_md_content(self, skills_root: Path) -> None: + backend = BuiltinSkillsBackend(skills_root) + responses = backend.download_files(["/alpha/SKILL.md", "/beta/SKILL.md"]) + assert len(responses) == 2 + assert responses[0].path == "/alpha/SKILL.md" + assert responses[0].content is not None + assert b"name: alpha" in responses[0].content + assert responses[1].error is None + + def test_marks_directory_as_is_directory_error(self, skills_root: Path) -> None: + backend = BuiltinSkillsBackend(skills_root) + responses = backend.download_files(["/alpha"]) + assert responses[0].error == "is_directory" + + def test_marks_missing_file_as_file_not_found(self, skills_root: Path) -> None: + backend = BuiltinSkillsBackend(skills_root) + responses = backend.download_files(["/alpha/missing.md"]) + assert responses[0].error == "file_not_found" + assert responses[0].content is None + + def test_response_path_matches_input_for_correlation( + self, skills_root: Path + ) -> None: + backend = BuiltinSkillsBackend(skills_root) + inputs = ["/alpha/SKILL.md", "/missing.md", "/beta/SKILL.md"] + responses = backend.download_files(inputs) + assert [r.path for r in responses] == inputs + + +class TestBuiltinSkillsBackendIntegration: + """Mirror the call sequence the SkillsMiddleware actually uses.""" + + def test_skills_middleware_call_pattern(self, skills_root: Path) -> None: + backend = BuiltinSkillsBackend(skills_root) + + infos = asyncio.run(backend.als_info("/")) + skill_dirs = [i["path"] for i in infos if i.get("is_dir")] + assert sorted(skill_dirs) == ["/alpha", "/beta"] + + skill_md_paths = [f"{p}/SKILL.md" for p in skill_dirs] + responses = asyncio.run(backend.adownload_files(skill_md_paths)) + assert all(r.error is None for r in responses) + assert all(r.content is not None for r in responses) + + +class TestBundledSkills: + def test_default_root_resolves_to_repo_skills_dir(self) -> None: + backend = BuiltinSkillsBackend() + assert backend.root.name == "builtin" + assert backend.root.parent.name == "skills" + + def test_bundled_starter_skills_are_present(self) -> None: + backend = BuiltinSkillsBackend() + infos = backend.ls_info("/") + names = {info["path"].lstrip("/") for info in infos if info.get("is_dir")} + # Five starter skills required by the Tier 4 plan. + for required in ( + "kb-research", + "report-writing", + "meeting-prep", + "slack-summary", + "email-drafting", + ): + assert required in names, f"missing starter skill: {required}" + + def test_each_starter_skill_has_valid_skill_md(self) -> None: + backend = BuiltinSkillsBackend() + infos = backend.ls_info("/") + skill_dirs = [info["path"] for info in infos if info.get("is_dir")] + for skill_dir in skill_dirs: + md_path = f"{skill_dir}/SKILL.md" + response = backend.download_files([md_path])[0] + assert response.error is None, f"missing SKILL.md in {skill_dir}" + content = response.content.decode("utf-8").replace("\r\n", "\n") + assert content.startswith("---\n"), f"missing frontmatter in {skill_dir}" + assert "\nname:" in content + assert "\ndescription:" in content + + +class _FakeKBBackend: + """Stand-in for :class:`KBPostgresBackend` with the two methods we need.""" + + def __init__(self, listing: list[dict], file_contents: dict[str, bytes]) -> None: + self._listing = listing + self._file_contents = file_contents + self.last_ls_path: str | None = None + self.last_download_paths: list[str] | None = None + + async def als_info(self, path: str): + self.last_ls_path = path + return self._listing + + async def adownload_files(self, paths): + from deepagents.backends.protocol import FileDownloadResponse + + self.last_download_paths = list(paths) + out: list[FileDownloadResponse] = [] + for p in paths: + content = self._file_contents.get(p) + if content is None: + out.append(FileDownloadResponse(path=p, error="file_not_found")) + else: + out.append(FileDownloadResponse(path=p, content=content)) + return out + + +class TestSearchSpaceSkillsBackend: + def test_remaps_paths_when_listing(self) -> None: + listing = [ + {"path": "/documents/_skills/policy", "is_dir": True}, + {"path": "/documents/_skills/policy/SKILL.md", "is_dir": False}, + {"path": "/documents/other-folder/x.md", "is_dir": False}, + ] + kb = _FakeKBBackend(listing=listing, file_contents={}) + backend = SearchSpaceSkillsBackend(kb) + infos = asyncio.run(backend.als_info("/")) + assert kb.last_ls_path == "/documents/_skills" + paths = [info["path"] for info in infos] + assert "/policy" in paths + assert "/policy/SKILL.md" in paths + # Unrelated KB documents must NOT leak into the skills namespace. + assert all(not p.startswith("/documents") for p in paths) + + def test_remaps_paths_when_downloading(self) -> None: + kb = _FakeKBBackend( + listing=[], + file_contents={ + "/documents/_skills/policy/SKILL.md": b"---\nname: policy\n---\n", + }, + ) + backend = SearchSpaceSkillsBackend(kb) + responses = asyncio.run(backend.adownload_files(["/policy/SKILL.md"])) + assert kb.last_download_paths == ["/documents/_skills/policy/SKILL.md"] + assert responses[0].path == "/policy/SKILL.md" + assert responses[0].error is None + assert responses[0].content is not None + + def test_sync_methods_raise_not_implemented(self) -> None: + backend = SearchSpaceSkillsBackend(_FakeKBBackend([], {})) + with pytest.raises(NotImplementedError): + backend.ls_info("/") + with pytest.raises(NotImplementedError): + backend.download_files(["/x"]) + + def test_custom_kb_root_is_honored(self) -> None: + kb = _FakeKBBackend( + listing=[ + {"path": "/skills_admin/x", "is_dir": True}, + ], + file_contents={}, + ) + backend = SearchSpaceSkillsBackend(kb, kb_root="/skills_admin") + infos = asyncio.run(backend.als_info("/")) + assert kb.last_ls_path == "/skills_admin" + assert infos[0]["path"] == "/x" + + +class TestBackendFactory: + def test_builtin_only_factory_returns_composite(self) -> None: + factory = build_skills_backend_factory() + backend = factory(runtime=None) # type: ignore[arg-type] + from deepagents.backends.composite import CompositeBackend + + assert isinstance(backend, CompositeBackend) + assert SKILLS_BUILTIN_PREFIX in backend.routes + assert SKILLS_SPACE_PREFIX not in backend.routes + + def test_default_skills_sources_lists_builtin_then_space(self) -> None: + sources = default_skills_sources() + assert sources == [SKILLS_BUILTIN_PREFIX, SKILLS_SPACE_PREFIX] diff --git a/surfsense_backend/tests/unit/agents/new_chat/test_specialized_subagents.py b/surfsense_backend/tests/unit/agents/new_chat/test_specialized_subagents.py new file mode 100644 index 000000000..0adb578ce --- /dev/null +++ b/surfsense_backend/tests/unit/agents/new_chat/test_specialized_subagents.py @@ -0,0 +1,339 @@ +"""Tests for the specialized subagents (explore / report_writer / connector_negotiator).""" + +from __future__ import annotations + +from langchain_core.tools import tool + +from app.agents.new_chat.middleware.permission import PermissionMiddleware +from app.agents.new_chat.subagents import ( + build_connector_negotiator_subagent, + build_explore_subagent, + build_report_writer_subagent, + build_specialized_subagents, +) +from app.agents.new_chat.subagents.config import ( + EXPLORE_READ_TOOLS, + REPORT_WRITER_TOOLS, + WRITE_TOOL_DENY_PATTERNS, +) + +# --------------------------------------------------------------------------- +# Fake tools used to verify filtering & permission behavior +# --------------------------------------------------------------------------- + + +@tool +def search_surfsense_docs(query: str) -> str: + """Search the user's KB.""" + return "" + + +@tool +def web_search(query: str) -> str: + """Search the public web.""" + return "" + + +@tool +def scrape_webpage(url: str) -> str: + """Scrape a single webpage.""" + return "" + + +@tool +def read_file(path: str) -> str: + """Read a file.""" + return "" + + +@tool +def ls_tree(path: str) -> str: + """List a tree.""" + return "" + + +@tool +def grep(pattern: str) -> str: + """Grep.""" + return "" + + +@tool +def update_memory(content: str) -> str: + """Update the user's memory.""" + return "" + + +@tool +def edit_file(path: str, old: str, new: str) -> str: + """Edit a file.""" + return "" + + +@tool +def linear_create_issue(title: str) -> str: + """Create a Linear issue.""" + return "" + + +@tool +def slack_send_message(channel: str, text: str) -> str: + """Send a Slack message.""" + return "" + + +@tool +def get_connected_accounts() -> str: + """List connected accounts.""" + return "" + + +@tool +def generate_report(topic: str) -> str: + """Generate a report artifact.""" + return "" + + +ALL_TOOLS = [ + search_surfsense_docs, + web_search, + scrape_webpage, + read_file, + ls_tree, + grep, + update_memory, + edit_file, + linear_create_issue, + slack_send_message, + get_connected_accounts, + generate_report, +] + + +class TestExploreSubagent: + def test_only_read_tools_are_exposed(self) -> None: + spec = build_explore_subagent(tools=ALL_TOOLS) + names = {t.name for t in spec["tools"]} # type: ignore[index] + assert names == EXPLORE_READ_TOOLS & {t.name for t in ALL_TOOLS} + assert "update_memory" not in names + assert "linear_create_issue" not in names + assert "edit_file" not in names + + def test_includes_permission_middleware_with_deny_rules(self) -> None: + spec = build_explore_subagent(tools=ALL_TOOLS) + permission_mws = [ + m + for m in spec["middleware"] + if isinstance(m, PermissionMiddleware) # type: ignore[index] + ] + assert len(permission_mws) == 1 + ruleset = permission_mws[0]._static_rulesets[0] + assert ruleset.origin == "subagent_explore" + deny_patterns = {r.permission for r in ruleset.rules if r.action == "deny"} + assert "update_memory" in deny_patterns + assert "edit_file" in deny_patterns + assert "*create*" in deny_patterns + assert "*send*" in deny_patterns + + def test_skills_inherits_default_sources(self) -> None: + spec = build_explore_subagent(tools=ALL_TOOLS) + assert spec["skills"] == ["/skills/builtin/", "/skills/space/"] # type: ignore[index] + + def test_name_and_description_match_contract(self) -> None: + spec = build_explore_subagent(tools=ALL_TOOLS) + assert spec["name"] == "explore" + assert "read-only" in spec["description"].lower() + + def test_includes_dedup_and_patch_middleware(self) -> None: + from deepagents.middleware.patch_tool_calls import PatchToolCallsMiddleware + + from app.agents.new_chat.middleware import DedupHITLToolCallsMiddleware + + spec = build_explore_subagent(tools=ALL_TOOLS) + types = {type(m) for m in spec["middleware"]} # type: ignore[index] + assert PatchToolCallsMiddleware in types + assert DedupHITLToolCallsMiddleware in types + + +class TestReportWriterSubagent: + def test_exposes_only_report_writing_tools(self) -> None: + spec = build_report_writer_subagent(tools=ALL_TOOLS) + names = {t.name for t in spec["tools"]} # type: ignore[index] + assert names == REPORT_WRITER_TOOLS & {t.name for t in ALL_TOOLS} + assert "generate_report" in names + assert "search_surfsense_docs" in names + + def test_deny_rules_block_writes_but_allow_generate_report(self) -> None: + spec = build_report_writer_subagent(tools=ALL_TOOLS) + permission_mws = [ + m + for m in spec["middleware"] + if isinstance(m, PermissionMiddleware) # type: ignore[index] + ] + ruleset = permission_mws[0]._static_rulesets[0] + deny_patterns = {r.permission for r in ruleset.rules if r.action == "deny"} + assert "update_memory" in deny_patterns + # generate_report MUST not be denied — it's the whole point of the subagent. + assert "generate_report" not in deny_patterns + # No deny pattern should match `generate_report` either. + assert all( + not _wildcard_matches(pattern, "generate_report") + for pattern in deny_patterns + ) + + +class TestConnectorNegotiatorSubagent: + def test_inherits_all_parent_tools(self) -> None: + spec = build_connector_negotiator_subagent(tools=ALL_TOOLS) + names = {t.name for t in spec["tools"]} # type: ignore[index] + # Every parent tool is inherited; the deny ruleset enforces behavior + # at execution time instead of trimming the tool list. + assert names == {t.name for t in ALL_TOOLS} + + def test_get_connected_accounts_is_present(self) -> None: + spec = build_connector_negotiator_subagent(tools=ALL_TOOLS) + names = {t.name for t in spec["tools"]} # type: ignore[index] + assert "get_connected_accounts" in names + + def test_deny_ruleset_blocks_mutating_connector_tools(self) -> None: + spec = build_connector_negotiator_subagent(tools=ALL_TOOLS) + permission_mws = [ + m + for m in spec["middleware"] + if isinstance(m, PermissionMiddleware) # type: ignore[index] + ] + ruleset = permission_mws[0]._static_rulesets[0] + deny_patterns = {r.permission for r in ruleset.rules if r.action == "deny"} + # `linear_create_issue` matches the `*_create` deny pattern. + assert any(_wildcard_matches(p, "linear_create_issue") for p in deny_patterns) + assert any(_wildcard_matches(p, "slack_send_message") for p in deny_patterns) + + +class TestBuildSpecializedSubagents: + def test_returns_three_specs(self) -> None: + specs = build_specialized_subagents(tools=ALL_TOOLS) + names = [s["name"] for s in specs] # type: ignore[index] + assert names == ["explore", "report_writer", "connector_negotiator"] + + def test_all_specs_have_unique_names(self) -> None: + specs = build_specialized_subagents(tools=ALL_TOOLS) + names = [s["name"] for s in specs] # type: ignore[index] + assert len(set(names)) == len(names) + + def test_extra_middleware_is_prepended_to_each_spec(self) -> None: + """Sentinel middleware passed via ``extra_middleware`` must appear + in each subagent's ``middleware`` list, before the local rules. + + This guards against the regression where specialized subagents + promised filesystem tools (``read_file``, ``ls``, ``grep``) in + their system prompts but had no filesystem middleware mounted. + """ + + class _Sentinel: + pass + + sentinel = _Sentinel() + specs = build_specialized_subagents( + tools=ALL_TOOLS, extra_middleware=[sentinel] + ) + for spec in specs: + mws = spec["middleware"] # type: ignore[index] + assert sentinel in mws + # The sentinel must appear *before* the permission middleware + # (subagent-local rules), preserving the documented composition + # order: extra → custom → patch → dedup. + sentinel_idx = mws.index(sentinel) + perm_idx = next( + (i for i, m in enumerate(mws) if isinstance(m, PermissionMiddleware)), + None, + ) + assert perm_idx is not None + assert sentinel_idx < perm_idx + + +class TestFilterToolsWarningSuppression: + """Names provided by middleware (read_file, ls, grep, …) must not + trigger the spurious "missing" warning in :func:`_filter_tools`.""" + + def test_middleware_provided_names_are_silent(self, caplog) -> None: + import logging + + from app.agents.new_chat.subagents.config import _filter_tools + + with caplog.at_level( + logging.INFO, logger="app.agents.new_chat.subagents.config" + ): + # Allowed set asks for two registry tools (one present, one + # not) plus a bunch of middleware-provided names. + _filter_tools( + [search_surfsense_docs], + allowed_names={ + "search_surfsense_docs", + "scrape_webpage", # legitimately missing → should warn + "read_file", # mw-provided → suppressed + "ls", + "grep", + "glob", + "write_todos", + }, + ) + + warnings = [r.message for r in caplog.records if r.levelno >= logging.INFO] + # Exactly one warning, and it should mention scrape_webpage but not + # any middleware-provided name. Inspect the rendered "missing" + # list (between the brackets) so we don't false-match substrings + # like ``ls`` inside ``available``. + assert len(warnings) == 1, warnings + msg = warnings[0] + assert "scrape_webpage" in msg + bracket_section = msg.split("missing: ", 1)[1] + for noisy in ("read_file", "ls", "grep", "glob", "write_todos"): + assert f"'{noisy}'" not in bracket_section, msg + + +class TestDenyPatternsCoverage: + def test_deny_patterns_cover_canonical_write_tools(self) -> None: + canonical_writes = [ + "update_memory", + "edit_file", + "write_file", + "move_file", + "mkdir", + "linear_create_issue", + "linear_update_issue", + "linear_delete_issue", + "slack_send_message", + "create_index", + "update_account", + "delete_record", + "send_email", + ] + for tool_name in canonical_writes: + assert any( + _wildcard_matches(pattern, tool_name) + for pattern in WRITE_TOOL_DENY_PATTERNS + ), f"no deny pattern matches {tool_name!r}" + + def test_deny_patterns_do_not_match_safe_read_tools(self) -> None: + canonical_reads = [ + "search_surfsense_docs", + "read_file", + "ls_tree", + "grep", + "web_search", + "scrape_webpage", + "get_connected_accounts", + "generate_report", + ] + for tool_name in canonical_reads: + assert not any( + _wildcard_matches(pattern, tool_name) + for pattern in WRITE_TOOL_DENY_PATTERNS + ), f"deny pattern incorrectly matches read tool {tool_name!r}" + + +def _wildcard_matches(pattern: str, value: str) -> bool: + """Helper using the same matcher the rule evaluator does.""" + from app.agents.new_chat.permissions import wildcard_match + + return wildcard_match(value, pattern) diff --git a/surfsense_backend/tests/unit/agents/new_chat/test_state_reducers.py b/surfsense_backend/tests/unit/agents/new_chat/test_state_reducers.py new file mode 100644 index 000000000..3caeb9a34 --- /dev/null +++ b/surfsense_backend/tests/unit/agents/new_chat/test_state_reducers.py @@ -0,0 +1,107 @@ +"""Tests for SurfSense filesystem state reducers.""" + +from __future__ import annotations + +import pytest + +from app.agents.new_chat.state_reducers import ( + _CLEAR, + _add_unique_reducer, + _dict_merge_with_tombstones_reducer, + _initial_filesystem_state, + _list_append_reducer, + _replace_reducer, +) + +pytestmark = pytest.mark.unit + + +class TestReplaceReducer: + def test_right_wins_outright(self): + assert _replace_reducer("a", "b") == "b" + + def test_none_right_returns_none(self): + assert _replace_reducer("a", None) is None + + def test_none_left_returns_right(self): + assert _replace_reducer(None, "b") == "b" + + +class TestAddUniqueReducer: + def test_appends_unique_items(self): + assert _add_unique_reducer(["a"], ["b", "c"]) == ["a", "b", "c"] + + def test_dedupes_against_left(self): + assert _add_unique_reducer(["a", "b"], ["b", "c"]) == ["a", "b", "c"] + + def test_dedupes_within_right(self): + assert _add_unique_reducer([], ["a", "a", "b"]) == ["a", "b"] + + def test_clear_anywhere_resets_and_reseeds_with_after_items(self): + # _CLEAR semantics: only items AFTER the LAST _CLEAR are kept. + result = _add_unique_reducer(["x", "y"], ["a", _CLEAR, "b", "c"]) + assert result == ["b", "c"] + + def test_multiple_clears_use_last(self): + result = _add_unique_reducer(["x"], [_CLEAR, "a", _CLEAR, "b"]) + assert result == ["b"] + + def test_clear_only_resets_to_empty(self): + assert _add_unique_reducer(["x", "y"], [_CLEAR]) == [] + + def test_empty_right_keeps_left(self): + assert _add_unique_reducer(["a"], []) == ["a"] + assert _add_unique_reducer(["a"], None) == ["a"] + + +class TestListAppendReducer: + def test_preserves_order_and_duplicates(self): + result = _list_append_reducer([{"a": 1}], [{"b": 2}, {"a": 1}]) + assert result == [{"a": 1}, {"b": 2}, {"a": 1}] + + def test_clear_resets_keeping_after_items(self): + result = _list_append_reducer([{"a": 1}], [{"old": 1}, _CLEAR, {"new": 2}]) + assert result == [{"new": 2}] + + +class TestDictMergeWithTombstones: + def test_merges_keys(self): + assert _dict_merge_with_tombstones_reducer({"a": 1}, {"b": 2}) == { + "a": 1, + "b": 2, + } + + def test_none_value_deletes_key(self): + result = _dict_merge_with_tombstones_reducer({"a": 1, "b": 2}, {"a": None}) + assert result == {"b": 2} + + def test_clear_resets_then_merges(self): + result = _dict_merge_with_tombstones_reducer( + {"a": 1, "b": 2}, {_CLEAR: True, "c": 3} + ) + assert result == {"c": 3} + + def test_clear_keeps_only_post_clear_non_none(self): + result = _dict_merge_with_tombstones_reducer( + {"a": 1}, {_CLEAR: True, "b": 2, "c": None} + ) + assert result == {"b": 2} + + def test_none_left_handled(self): + assert _dict_merge_with_tombstones_reducer(None, {"a": 1, "b": None}) == { + "a": 1 + } + + +class TestInitialFilesystemState: + def test_default_shape(self): + state = _initial_filesystem_state() + assert state["cwd"] == "/documents" + assert state["staged_dirs"] == [] + assert state["pending_moves"] == [] + assert state["doc_id_by_path"] == {} + assert state["dirty_paths"] == [] + assert state["kb_priority"] == [] + assert state["kb_matched_chunk_ids"] == {} + assert state["kb_anon_doc"] is None + assert state["tree_version"] == 0 diff --git a/surfsense_backend/tests/unit/agents/new_chat/test_tool_call_repair.py b/surfsense_backend/tests/unit/agents/new_chat/test_tool_call_repair.py new file mode 100644 index 000000000..e02a04774 --- /dev/null +++ b/surfsense_backend/tests/unit/agents/new_chat/test_tool_call_repair.py @@ -0,0 +1,121 @@ +"""Tests for ToolCallNameRepairMiddleware.""" + +from __future__ import annotations + +import pytest +from langchain_core.messages import AIMessage + +from app.agents.new_chat.middleware.tool_call_repair import ( + ToolCallNameRepairMiddleware, +) +from app.agents.new_chat.tools.invalid_tool import INVALID_TOOL_NAME + +pytestmark = pytest.mark.unit + + +def _make_state(message: AIMessage) -> dict: + return {"messages": [message]} + + +class _FakeRuntime: + def __init__(self, context: object | None = None) -> None: + self.context = context + + +class TestRepair: + def test_passthrough_when_name_matches(self) -> None: + mw = ToolCallNameRepairMiddleware( + registered_tool_names={"echo"}, fuzzy_match_threshold=None + ) + msg = AIMessage( + content="", + tool_calls=[ + {"name": "echo", "args": {}, "id": "1"}, + ], + ) + out = mw.after_model(_make_state(msg), _FakeRuntime()) + assert out is None # no change + + def test_lowercase_repair(self) -> None: + mw = ToolCallNameRepairMiddleware( + registered_tool_names={"echo"}, fuzzy_match_threshold=None + ) + msg = AIMessage( + content="", + tool_calls=[ + {"name": "Echo", "args": {"x": 1}, "id": "1"}, + ], + ) + out = mw.after_model(_make_state(msg), _FakeRuntime()) + assert out is not None + repaired = out["messages"][0] + assert repaired.tool_calls[0]["name"] == "echo" + + def test_invalid_fallback_when_no_match(self) -> None: + mw = ToolCallNameRepairMiddleware( + registered_tool_names={"echo", INVALID_TOOL_NAME}, + fuzzy_match_threshold=None, + ) + msg = AIMessage( + content="", + tool_calls=[ + {"name": "totally_different_name", "args": {"k": "v"}, "id": "1"}, + ], + ) + out = mw.after_model(_make_state(msg), _FakeRuntime()) + assert out is not None + repaired_call = out["messages"][0].tool_calls[0] + assert repaired_call["name"] == INVALID_TOOL_NAME + assert repaired_call["args"]["tool"] == "totally_different_name" + assert "totally_different_name" in repaired_call["args"]["error"] + + def test_no_invalid_means_skip_when_unknown(self) -> None: + mw = ToolCallNameRepairMiddleware( + registered_tool_names={"echo"}, fuzzy_match_threshold=None + ) + msg = AIMessage( + content="", + tool_calls=[ + {"name": "unknown", "args": {}, "id": "1"}, + ], + ) + out = mw.after_model(_make_state(msg), _FakeRuntime()) + # No repair available; original returned unchanged (no update) + assert out is None + + def test_fuzzy_match_works_when_enabled(self) -> None: + mw = ToolCallNameRepairMiddleware( + registered_tool_names={"search_documents"}, + fuzzy_match_threshold=0.7, + ) + msg = AIMessage( + content="", + tool_calls=[ + {"name": "search_docments", "args": {}, "id": "1"}, + ], + ) + out = mw.after_model(_make_state(msg), _FakeRuntime()) + assert out is not None + assert out["messages"][0].tool_calls[0]["name"] == "search_documents" + + def test_skips_when_no_messages(self) -> None: + mw = ToolCallNameRepairMiddleware(registered_tool_names={"echo"}) + out = mw.after_model({"messages": []}, _FakeRuntime()) + assert out is None + + def test_runtime_context_extends_registered(self) -> None: + from types import SimpleNamespace + + mw = ToolCallNameRepairMiddleware( + registered_tool_names={"echo"}, fuzzy_match_threshold=None + ) + msg = AIMessage( + content="", + tool_calls=[ + {"name": "DynamicTool", "args": {}, "id": "1"}, + ], + ) + runtime = _FakeRuntime(SimpleNamespace(registered_tool_names=["dynamictool"])) + out = mw.after_model(_make_state(msg), runtime) + assert out is not None + assert out["messages"][0].tool_calls[0]["name"] == "dynamictool" diff --git a/surfsense_backend/tests/unit/middleware/test_dedup_hitl_tool_calls.py b/surfsense_backend/tests/unit/middleware/test_dedup_hitl_tool_calls.py index add0105e4..467ba6d5f 100644 --- a/surfsense_backend/tests/unit/middleware/test_dedup_hitl_tool_calls.py +++ b/surfsense_backend/tests/unit/middleware/test_dedup_hitl_tool_calls.py @@ -1,8 +1,10 @@ import pytest from langchain_core.messages import AIMessage +from langchain_core.tools import StructuredTool from app.agents.new_chat.middleware.dedup_tool_calls import ( DedupHITLToolCallsMiddleware, + wrap_dedup_key_by_arg_name, ) pytestmark = pytest.mark.unit @@ -14,9 +16,34 @@ def _make_state(tool_calls: list[dict]) -> dict: return {"messages": [msg]} +def _hitl_tool(name: str, *, dedup_arg: str) -> StructuredTool: + """Build a tool with declarative ``dedup_key`` metadata. + + Mirrors the ``ToolDefinition.dedup_key`` -> ``tool.metadata["dedup_key"]`` + propagation done by :func:`build_tools` after the cleanup tier. + """ + + def _fn(**kwargs): + return "ok" + + return StructuredTool.from_function( + func=_fn, + name=name, + description="x", + metadata={"dedup_key": wrap_dedup_key_by_arg_name(dedup_arg)}, + ) + + def test_duplicate_hitl_calls_reduced_to_first(): - """When the LLM emits the same HITL tool call twice, only the first is kept.""" - mw = DedupHITLToolCallsMiddleware() + """When the LLM emits the same HITL tool call twice, only the first is kept. + + After the cleanup tier removed ``_NATIVE_HITL_TOOL_DEDUP_KEYS``, the + resolver is sourced from ``ToolDefinition.dedup_key`` propagated onto + ``tool.metadata`` — which the registry does at agent build time. The + test mirrors that wiring with an in-memory tool. + """ + tool = _hitl_tool("delete_calendar_event", dedup_arg="event_title_or_id") + mw = DedupHITLToolCallsMiddleware(agent_tools=[tool]) state = _make_state( [ diff --git a/surfsense_backend/tests/unit/middleware/test_file_intent_middleware.py b/surfsense_backend/tests/unit/middleware/test_file_intent_middleware.py index 673331b0a..7fd3fe4a7 100644 --- a/surfsense_backend/tests/unit/middleware/test_file_intent_middleware.py +++ b/surfsense_backend/tests/unit/middleware/test_file_intent_middleware.py @@ -45,9 +45,7 @@ async def test_file_write_intent_injects_contract_message(): @pytest.mark.asyncio async def test_non_write_intent_does_not_inject_contract_message(): - llm = _FakeLLM( - '{"intent":"file_read","confidence":0.88,"suggested_filename":null}' - ) + llm = _FakeLLM('{"intent":"file_read","confidence":0.88,"suggested_filename":null}') middleware = FileIntentMiddleware(llm=llm) original_messages = [HumanMessage(content="Read /notes.md")] state = {"messages": original_messages, "turn_id": "abc:def"} @@ -55,7 +53,10 @@ async def test_non_write_intent_does_not_inject_contract_message(): result = await middleware.abefore_agent(state, runtime=None) # type: ignore[arg-type] assert result is not None - assert result["file_operation_contract"]["intent"] == FileOperationIntent.FILE_READ.value + assert ( + result["file_operation_contract"]["intent"] + == FileOperationIntent.FILE_READ.value + ) assert "messages" not in result @@ -211,4 +212,3 @@ def test_fallback_path_keeps_posix_style_absolute_path_for_linux_and_macos() -> ) assert resolved == "/var/log/surfsense/notes.md" - diff --git a/surfsense_backend/tests/unit/middleware/test_filesystem_backends.py b/surfsense_backend/tests/unit/middleware/test_filesystem_backends.py index 98996d6bc..c71b5efde 100644 --- a/surfsense_backend/tests/unit/middleware/test_filesystem_backends.py +++ b/surfsense_backend/tests/unit/middleware/test_filesystem_backends.py @@ -36,11 +36,18 @@ def test_backend_resolver_returns_multi_root_backend_for_single_root(tmp_path: P def test_backend_resolver_uses_cloud_mode_by_default(): resolver = build_backend_resolver(FilesystemSelection()) backend = resolver(_RuntimeStub()) - # StateBackend class name check keeps this test decoupled - # from internal deepagents runtime class identity. + # When no search_space_id is provided we fall back to StateBackend so + # sub-agents / tests without DB access still work. assert backend.__class__.__name__ == "StateBackend" +def test_backend_resolver_uses_kb_postgres_in_cloud_with_search_space(): + resolver = build_backend_resolver(FilesystemSelection(), search_space_id=42) + backend = resolver(_RuntimeStub()) + assert backend.__class__.__name__ == "KBPostgresBackend" + assert backend.search_space_id == 42 + + def test_backend_resolver_returns_multi_root_backend_for_multiple_roots(tmp_path: Path): root_one = tmp_path / "resume" root_two = tmp_path / "notes" diff --git a/surfsense_backend/tests/unit/middleware/test_filesystem_middleware.py b/surfsense_backend/tests/unit/middleware/test_filesystem_middleware.py new file mode 100644 index 000000000..c2e304399 --- /dev/null +++ b/surfsense_backend/tests/unit/middleware/test_filesystem_middleware.py @@ -0,0 +1,204 @@ +"""Unit tests for the SurfSense filesystem middleware new behaviors. + +Covers: +* cloud cwd defaults to ``/documents`` and relative paths resolve under it +* cloud writes outside ``/documents/`` are rejected unless basename starts + with ``temp_`` +* cloud writes/edits to the anonymous document are rejected (read-only) +* helper methods on the middleware (``_resolve_relative``, + ``_check_cloud_write_namespace``, ``_default_cwd``) + +These tests use ``__new__`` to bypass the heavy ``__init__`` and exercise +the helper methods directly so the test surface stays narrow and fast. +""" + +from __future__ import annotations + +from types import SimpleNamespace + +import pytest + +from app.agents.new_chat.filesystem_selection import FilesystemMode +from app.agents.new_chat.middleware.filesystem import ( + SurfSenseFilesystemMiddleware, + _build_filesystem_system_prompt, + _build_tool_descriptions, +) + +pytestmark = pytest.mark.unit + + +def _make_middleware(mode: FilesystemMode = FilesystemMode.CLOUD): + middleware = SurfSenseFilesystemMiddleware.__new__(SurfSenseFilesystemMiddleware) + middleware._filesystem_mode = mode + return middleware + + +def _runtime(state: dict | None = None) -> SimpleNamespace: + return SimpleNamespace(state=state or {}) + + +class TestCloudCwdDefaults: + def test_default_cwd_in_cloud_is_documents_root(self): + m = _make_middleware() + assert m._default_cwd() == "/documents" + + def test_default_cwd_in_desktop_is_root(self): + m = _make_middleware(FilesystemMode.DESKTOP_LOCAL_FOLDER) + assert m._default_cwd() == "/" + + def test_current_cwd_uses_state_when_set(self): + m = _make_middleware() + runtime = _runtime({"cwd": "/documents/notes"}) + assert m._current_cwd(runtime) == "/documents/notes" + + def test_current_cwd_falls_back_to_default(self): + m = _make_middleware() + runtime = _runtime({}) + assert m._current_cwd(runtime) == "/documents" + + def test_current_cwd_ignores_invalid(self): + m = _make_middleware() + runtime = _runtime({"cwd": "not-absolute"}) + assert m._current_cwd(runtime) == "/documents" + + +class TestRelativePathResolution: + def test_relative_path_resolves_against_cwd(self): + m = _make_middleware() + runtime = _runtime({"cwd": "/documents/projects"}) + assert ( + m._resolve_relative("notes.md", runtime) == "/documents/projects/notes.md" + ) + + def test_relative_path_with_dotdot(self): + m = _make_middleware() + runtime = _runtime({"cwd": "/documents/a/b"}) + assert m._resolve_relative("../c.md", runtime) == "/documents/a/c.md" + + def test_absolute_path_is_kept(self): + m = _make_middleware() + runtime = _runtime({"cwd": "/documents"}) + assert m._resolve_relative("/other/x.md", runtime) == "/other/x.md" + + def test_empty_path_returns_cwd(self): + m = _make_middleware() + runtime = _runtime({"cwd": "/documents/projects"}) + assert m._resolve_relative("", runtime) == "/documents/projects" + + +class TestCloudWriteNamespacePolicy: + def test_documents_path_allowed(self): + m = _make_middleware() + runtime = _runtime() + assert m._check_cloud_write_namespace("/documents/foo.md", runtime) is None + + def test_documents_root_allowed(self): + m = _make_middleware() + runtime = _runtime() + assert m._check_cloud_write_namespace("/documents", runtime) is None + + def test_temp_basename_anywhere_allowed(self): + m = _make_middleware() + runtime = _runtime() + assert m._check_cloud_write_namespace("/temp_scratch.md", runtime) is None + assert m._check_cloud_write_namespace("/foo/temp_x.md", runtime) is None + assert m._check_cloud_write_namespace("/documents/temp_x.md", runtime) is None + + def test_other_paths_rejected(self): + m = _make_middleware() + runtime = _runtime() + err = m._check_cloud_write_namespace("/foo/bar.md", runtime) + assert err is not None + assert "must target /documents" in err + + def test_anon_doc_path_is_read_only(self): + m = _make_middleware() + runtime = _runtime( + { + "kb_anon_doc": { + "path": "/documents/uploaded.xml", + "title": "uploaded", + "content": "", + "chunks": [], + } + } + ) + err = m._check_cloud_write_namespace("/documents/uploaded.xml", runtime) + assert err is not None + assert "read-only" in err + + def test_desktop_mode_skips_namespace_policy(self): + m = _make_middleware(FilesystemMode.DESKTOP_LOCAL_FOLDER) + runtime = _runtime() + assert m._check_cloud_write_namespace("/random/path.md", runtime) is None + + +class TestModeSpecificPrompts: + """The prompt and tool descriptions must only describe the active mode. + + Cross-mode noise wastes tokens and confuses the model with rules it + cannot use this session. + """ + + def test_cloud_prompt_omits_desktop_section(self): + prompt = _build_filesystem_system_prompt( + FilesystemMode.CLOUD, sandbox_available=False + ) + assert "Local Folder Mode" not in prompt + assert "mount-prefixed" not in prompt + assert "Persistence Rules" in prompt + assert "/documents" in prompt + assert "temp_" in prompt + + def test_desktop_prompt_omits_cloud_persistence_rules(self): + prompt = _build_filesystem_system_prompt( + FilesystemMode.DESKTOP_LOCAL_FOLDER, sandbox_available=False + ) + assert "Persistence Rules" not in prompt + assert "Workspace Tree" not in prompt + assert "" not in prompt + assert "Local Folder Mode" in prompt + assert "mount-prefixed" in prompt + + def test_cloud_tool_descs_omit_desktop_phrases(self): + descs = _build_tool_descriptions(FilesystemMode.CLOUD) + for name in ( + "write_file", + "edit_file", + "move_file", + "mkdir", + "list_tree", + "grep", + ): + text = descs[name] + assert "Desktop" not in text, f"{name} leaks desktop hints" + assert "Cloud mode:" not in text, f"{name} qualifies a cloud-only desc" + + def test_desktop_tool_descs_omit_cloud_phrases(self): + descs = _build_tool_descriptions(FilesystemMode.DESKTOP_LOCAL_FOLDER) + for name in ( + "write_file", + "edit_file", + "move_file", + "mkdir", + "list_tree", + "grep", + ): + text = descs[name] + assert "Cloud" not in text, f"{name} leaks cloud hints" + assert "/documents/" not in text, f"{name} mentions cloud namespace" + assert "temp_" not in text, f"{name} mentions cloud temp_ semantics" + + def test_sandbox_addendum_appended_when_available(self): + prompt = _build_filesystem_system_prompt( + FilesystemMode.CLOUD, sandbox_available=True + ) + assert "execute_code" in prompt + assert "Code Execution" in prompt + + def test_sandbox_addendum_absent_when_unavailable(self): + prompt = _build_filesystem_system_prompt( + FilesystemMode.CLOUD, sandbox_available=False + ) + assert "execute_code" not in prompt diff --git a/surfsense_backend/tests/unit/middleware/test_filesystem_verification.py b/surfsense_backend/tests/unit/middleware/test_filesystem_verification.py index d00365032..81cf590d3 100644 --- a/surfsense_backend/tests/unit/middleware/test_filesystem_verification.py +++ b/surfsense_backend/tests/unit/middleware/test_filesystem_verification.py @@ -2,34 +2,15 @@ from pathlib import Path import pytest +from app.agents.new_chat.filesystem_selection import FilesystemMode +from app.agents.new_chat.middleware.filesystem import SurfSenseFilesystemMiddleware from app.agents.new_chat.middleware.multi_root_local_folder_backend import ( MultiRootLocalFolderBackend, ) -from app.agents.new_chat.filesystem_selection import FilesystemMode -from app.agents.new_chat.middleware.filesystem import SurfSenseFilesystemMiddleware pytestmark = pytest.mark.unit -class _BackendWithRawRead: - def __init__(self, content: str) -> None: - self._content = content - - def read(self, file_path: str, offset: int = 0, limit: int = 200000) -> str: - del file_path, offset, limit - return " 1\tline1\n 2\tline2" - - async def aread(self, file_path: str, offset: int = 0, limit: int = 200000) -> str: - return self.read(file_path, offset, limit) - - def read_raw(self, file_path: str) -> str: - del file_path - return self._content - - async def aread_raw(self, file_path: str) -> str: - return self.read_raw(file_path) - - class _RuntimeNoSuggestedPath: state = {"file_operation_contract": {}} @@ -39,40 +20,19 @@ class _RuntimeWithSuggestedPath: self.state = {"file_operation_contract": {"suggested_path": suggested_path}} -def test_verify_written_content_prefers_raw_sync() -> None: - middleware = SurfSenseFilesystemMiddleware.__new__(SurfSenseFilesystemMiddleware) - expected = "line1\nline2" - backend = _BackendWithRawRead(expected) - - verify_error = middleware._verify_written_content_sync( - backend=backend, - path="/note.md", - expected_content=expected, - ) - - assert verify_error is None - - -def test_contract_suggested_path_falls_back_to_notes_md() -> None: +def test_contract_suggested_path_falls_back_to_documents_notes_md() -> None: middleware = SurfSenseFilesystemMiddleware.__new__(SurfSenseFilesystemMiddleware) middleware._filesystem_mode = FilesystemMode.CLOUD suggested = middleware._get_contract_suggested_path(_RuntimeNoSuggestedPath()) # type: ignore[arg-type] - assert suggested == "/notes.md" + # Cloud default cwd is /documents so the fallback lands in the KB. + assert suggested == "/documents/notes.md" -@pytest.mark.asyncio -async def test_verify_written_content_prefers_raw_async() -> None: +def test_contract_suggested_path_falls_back_to_root_notes_md_in_desktop() -> None: middleware = SurfSenseFilesystemMiddleware.__new__(SurfSenseFilesystemMiddleware) - expected = "line1\nline2" - backend = _BackendWithRawRead(expected) - - verify_error = await middleware._verify_written_content_async( - backend=backend, - path="/note.md", - expected_content=expected, - ) - - assert verify_error is None + middleware._filesystem_mode = FilesystemMode.DESKTOP_LOCAL_FOLDER + suggested = middleware._get_contract_suggested_path(_RuntimeNoSuggestedPath()) # type: ignore[arg-type] + assert suggested == "/notes.md" def test_normalize_local_mount_path_prefixes_default_mount(tmp_path: Path) -> None: diff --git a/surfsense_backend/tests/unit/middleware/test_kb_persistence_filesystem_parity.py b/surfsense_backend/tests/unit/middleware/test_kb_persistence_filesystem_parity.py new file mode 100644 index 000000000..ef95434bf --- /dev/null +++ b/surfsense_backend/tests/unit/middleware/test_kb_persistence_filesystem_parity.py @@ -0,0 +1,168 @@ +"""Unit tests for kb_persistence filesystem-parity invariants. + +Specifically, these tests pin down that the agent-driven write_file flow +treats path uniqueness — not content uniqueness — as the only hard +invariant. This mirrors a real filesystem: ``cp a b`` produces two files +with identical bytes living at different paths, and that should round-trip +through :class:`KnowledgeBasePersistenceMiddleware` without losing the copy. +""" + +from __future__ import annotations + +from typing import Any +from unittest.mock import AsyncMock + +import numpy as np +import pytest + +from app.agents.new_chat.middleware import kb_persistence +from app.db import Document + + +class _FakeResult: + """Minimal stand-in for ``sqlalchemy.engine.Result``.""" + + def __init__(self, value: Any = None) -> None: + self._value = value + + def scalar_one_or_none(self) -> Any: + return self._value + + def scalar(self) -> Any: + return self._value + + +class _FakeSession: + """Minimal AsyncSession stand-in scoped to ``_create_document`` needs. + + Records every ``add`` so we can assert against the resulting Documents + and Chunks. ``execute`` always returns "no row" by default — i.e. no + folder hierarchy preexists and no path collision exists. Tests that + want a path collision can override that on a per-call basis. + """ + + def __init__(self) -> None: + self.added: list[Any] = [] + self.execute = AsyncMock(return_value=_FakeResult(None)) + self.flush = AsyncMock() + + # Simulate ``await session.flush()`` assigning an id to the doc; + # we increment a counter so each Document gets a unique id. + self._next_id = 1 + + async def _flush_assigning_ids() -> None: + for obj in self.added: + if getattr(obj, "id", None) is None: + obj.id = self._next_id + self._next_id += 1 + + self.flush.side_effect = _flush_assigning_ids + + def add(self, obj: Any) -> None: + self.added.append(obj) + + def add_all(self, objs: list[Any]) -> None: + self.added.extend(objs) + + +@pytest.fixture(autouse=True) +def _stub_embeddings_and_chunks(monkeypatch: pytest.MonkeyPatch) -> None: + """Avoid loading the embedding model in unit tests.""" + monkeypatch.setattr( + kb_persistence, + "embed_texts", + lambda texts: [np.zeros(8, dtype=np.float32) for _ in texts], + ) + monkeypatch.setattr(kb_persistence, "chunk_text", lambda content: [content]) + + +@pytest.mark.asyncio +async def test_create_document_allows_identical_content_at_different_paths() -> None: + """The core regression: ``cp /a/notes.md /b/notes-copy.md``. + + Both create calls must succeed even though the bytes are byte-for-byte + identical, because path is the only filesystem-style unique key. + """ + session = _FakeSession() + content = "# Same body\n\nIdentical content used by two different paths.\n" + + first = await kb_persistence._create_document( + session, # type: ignore[arg-type] + virtual_path="/documents/a/notes.md", + content=content, + search_space_id=42, + created_by_id="user-1", + ) + assert isinstance(first, Document) + assert first.title == "notes.md" + + # Second create with byte-identical content at a different path should + # not raise — that's the whole point of the filesystem-parity fix. + second = await kb_persistence._create_document( + session, # type: ignore[arg-type] + virtual_path="/documents/b/notes-copy.md", + content=content, + search_space_id=42, + created_by_id="user-1", + ) + assert isinstance(second, Document) + assert second.title == "notes-copy.md" + + # Both rows share the same content_hash but live at distinct paths + # (distinct ``unique_identifier_hash``). That's the desired contract. + assert first.content_hash == second.content_hash + assert first.unique_identifier_hash != second.unique_identifier_hash + + +@pytest.mark.asyncio +async def test_create_document_still_rejects_path_collision() -> None: + """Path uniqueness remains the hard invariant. + + If ``unique_identifier_hash`` already points at an existing row in + the same search space, the create call must raise ``ValueError`` + with a clear message — matching the behavior the commit loop relies + on to upsert via the existing-row code path. + """ + session = _FakeSession() + + # Path with no folder parts so ``_ensure_folder_hierarchy`` is a + # no-op and the only SELECT executed is the path-collision check. + # That SELECT returns an existing doc id, triggering the guard. + session.execute = AsyncMock(return_value=_FakeResult(value=99)) + + with pytest.raises(ValueError, match="already exists at path"): + await kb_persistence._create_document( + session, # type: ignore[arg-type] + virtual_path="/documents/notes.md", + content="anything", + search_space_id=42, + created_by_id="user-1", + ) + + +@pytest.mark.asyncio +async def test_create_document_does_not_query_for_content_hash_collision( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """Regression guard: the legacy second SELECT (content_hash collision + pre-check) must be gone. Counting ``execute`` calls is a brittle but + effective way to lock that in. + + The current flow runs exactly one ``execute`` for the path-collision + SELECT (no folder parts in this path → ``_ensure_folder_hierarchy`` + short-circuits). If a future refactor reintroduces a content-hash + SELECT, this test will fail loud. + """ + session = _FakeSession() + await kb_persistence._create_document( + session, # type: ignore[arg-type] + virtual_path="/documents/notes.md", + content="hello", + search_space_id=42, + created_by_id="user-1", + ) + # Path-collision SELECT only. No content_hash SELECT. + assert session.execute.await_count == 1, ( + f"Unexpected execute count {session.execute.await_count}; " + "did the legacy content_hash collision pre-check get re-added?" + ) diff --git a/surfsense_backend/tests/unit/middleware/test_knowledge_search.py b/surfsense_backend/tests/unit/middleware/test_knowledge_search.py index 1aaf5d127..2ca470680 100644 --- a/surfsense_backend/tests/unit/middleware/test_knowledge_search.py +++ b/surfsense_backend/tests/unit/middleware/test_knowledge_search.py @@ -5,10 +5,10 @@ import json import pytest from langchain_core.messages import AIMessage, HumanMessage +from app.agents.new_chat.document_xml import build_document_xml as _build_document_xml from app.agents.new_chat.middleware.knowledge_search import ( KBSearchPlan, KnowledgeBaseSearchMiddleware, - _build_document_xml, _normalize_optional_date_range, _parse_kb_search_plan_response, _render_recent_conversation, @@ -248,17 +248,10 @@ class TestKnowledgeBaseSearchMiddlewarePlanner: captured.update(kwargs) return [] - async def fake_build_scoped_filesystem(**kwargs): - return {}, {} - monkeypatch.setattr( "app.agents.new_chat.middleware.knowledge_search.search_knowledge_base", fake_search_knowledge_base, ) - monkeypatch.setattr( - "app.agents.new_chat.middleware.knowledge_search.build_scoped_filesystem", - fake_build_scoped_filesystem, - ) llm = FakeLLM( json.dumps( @@ -298,17 +291,10 @@ class TestKnowledgeBaseSearchMiddlewarePlanner: captured.update(kwargs) return [] - async def fake_build_scoped_filesystem(**kwargs): - return {}, {} - monkeypatch.setattr( "app.agents.new_chat.middleware.knowledge_search.search_knowledge_base", fake_search_knowledge_base, ) - monkeypatch.setattr( - "app.agents.new_chat.middleware.knowledge_search.build_scoped_filesystem", - fake_build_scoped_filesystem, - ) middleware = KnowledgeBaseSearchMiddleware( llm=FakeLLM("not json"), @@ -334,17 +320,10 @@ class TestKnowledgeBaseSearchMiddlewarePlanner: captured.update(kwargs) return [] - async def fake_build_scoped_filesystem(**kwargs): - return {}, {} - monkeypatch.setattr( "app.agents.new_chat.middleware.knowledge_search.search_knowledge_base", fake_search_knowledge_base, ) - monkeypatch.setattr( - "app.agents.new_chat.middleware.knowledge_search.build_scoped_filesystem", - fake_build_scoped_filesystem, - ) middleware = KnowledgeBaseSearchMiddleware( llm=FakeLLM( @@ -386,9 +365,6 @@ class TestKnowledgeBaseSearchMiddlewarePlanner: search_called = True return [] - async def fake_build_scoped_filesystem(**kwargs): - return {}, {} - monkeypatch.setattr( "app.agents.new_chat.middleware.knowledge_search.browse_recent_documents", fake_browse_recent_documents, @@ -397,10 +373,6 @@ class TestKnowledgeBaseSearchMiddlewarePlanner: "app.agents.new_chat.middleware.knowledge_search.search_knowledge_base", fake_search_knowledge_base, ) - monkeypatch.setattr( - "app.agents.new_chat.middleware.knowledge_search.build_scoped_filesystem", - fake_build_scoped_filesystem, - ) llm = FakeLLM( json.dumps( @@ -440,9 +412,6 @@ class TestKnowledgeBaseSearchMiddlewarePlanner: search_captured.update(kwargs) return [] - async def fake_build_scoped_filesystem(**kwargs): - return {}, {} - monkeypatch.setattr( "app.agents.new_chat.middleware.knowledge_search.browse_recent_documents", fake_browse_recent_documents, @@ -451,10 +420,6 @@ class TestKnowledgeBaseSearchMiddlewarePlanner: "app.agents.new_chat.middleware.knowledge_search.search_knowledge_base", fake_search_knowledge_base, ) - monkeypatch.setattr( - "app.agents.new_chat.middleware.knowledge_search.build_scoped_filesystem", - fake_build_scoped_filesystem, - ) llm = FakeLLM( json.dumps( diff --git a/surfsense_backend/tests/unit/observability/__init__.py b/surfsense_backend/tests/unit/observability/__init__.py new file mode 100644 index 000000000..8b1378917 --- /dev/null +++ b/surfsense_backend/tests/unit/observability/__init__.py @@ -0,0 +1 @@ + diff --git a/surfsense_backend/tests/unit/observability/test_otel.py b/surfsense_backend/tests/unit/observability/test_otel.py new file mode 100644 index 000000000..fc5813973 --- /dev/null +++ b/surfsense_backend/tests/unit/observability/test_otel.py @@ -0,0 +1,84 @@ +"""Tests for the SurfSense OpenTelemetry shim.""" + +from __future__ import annotations + +import pytest + +from app.observability import otel + +pytestmark = pytest.mark.unit + + +@pytest.fixture(autouse=True) +def _reset_otel_state(monkeypatch: pytest.MonkeyPatch): + """Force a clean OTel disabled state per test, then restore after.""" + for env in ("OTEL_EXPORTER_OTLP_ENDPOINT", "SURFSENSE_DISABLE_OTEL"): + monkeypatch.delenv(env, raising=False) + monkeypatch.setenv("SURFSENSE_DISABLE_OTEL", "true") + otel.reload_for_tests() + yield + otel.reload_for_tests() + + +def test_disabled_by_default_when_no_endpoint() -> None: + assert otel.is_enabled() is False + + +def test_enabled_when_endpoint_configured(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.delenv("SURFSENSE_DISABLE_OTEL", raising=False) + monkeypatch.setenv("OTEL_EXPORTER_OTLP_ENDPOINT", "http://localhost:4317") + assert otel.reload_for_tests() is True + + +def test_kill_switch_overrides_endpoint(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setenv("OTEL_EXPORTER_OTLP_ENDPOINT", "http://localhost:4317") + monkeypatch.setenv("SURFSENSE_DISABLE_OTEL", "true") + assert otel.reload_for_tests() is False + + +class TestNoopSpansWhenDisabled: + def test_generic_span_yields_noop(self) -> None: + with otel.span("any.thing", attributes={"x": 1}) as sp: + sp.set_attribute("y", 2) + sp.set_attributes({"a": "b"}) + sp.add_event("evt") + sp.record_exception(RuntimeError("ignored")) + sp.set_status("ignored") + # Reaching here without raising means the no-op is well-formed + + def test_exception_propagates_through_span(self) -> None: + with pytest.raises(ValueError), otel.span("err"): + raise ValueError("boom") + + def test_each_helper_is_a_no_op_when_disabled(self) -> None: + helpers = [ + otel.tool_call_span("write_file", input_size=42), + otel.model_call_span(model_id="openai:gpt-4o", provider="openai"), + otel.kb_search_span(search_space_id=1, query_chars=99), + otel.kb_persist_span(document_type="NOTE", document_id=7), + otel.compaction_span(reason="overflow", messages_in=120), + otel.interrupt_span(interrupt_type="permission_ask"), + otel.permission_asked_span(permission="edit", pattern="/x/**"), + ] + for cm in helpers: + with cm as sp: + assert sp is not None + sp.set_attribute("ok", True) + + +class TestEnabledIntegration: + """When OTel is wired but no SDK exporter is bound, the API still works.""" + + def test_span_attaches_attributes(self, monkeypatch: pytest.MonkeyPatch) -> None: + # Use the API tracer (no-op-ish but real Span objects). + monkeypatch.delenv("SURFSENSE_DISABLE_OTEL", raising=False) + monkeypatch.setenv("OTEL_EXPORTER_OTLP_ENDPOINT", "http://localhost:4317") + assert otel.reload_for_tests() is True + + # Should not raise even when set_attributes/record_exception fall through + # to an SDK that isn't actually installed. + with otel.tool_call_span("scrape_webpage", input_size=10) as sp: + sp.set_attribute("tool.output.size", 200) + sp.set_attribute("tool.truncated", False) + with otel.model_call_span(model_id="m", provider="p") as sp: + sp.set_attribute("retry.count", 3) diff --git a/surfsense_backend/tests/unit/services/test_revert_service.py b/surfsense_backend/tests/unit/services/test_revert_service.py new file mode 100644 index 000000000..a81e52041 --- /dev/null +++ b/surfsense_backend/tests/unit/services/test_revert_service.py @@ -0,0 +1,46 @@ +"""Unit tests for the agent revert service.""" + +from __future__ import annotations + +from typing import Any + +from app.services.revert_service import can_revert + + +class _FakeAction: + def __init__(self, *, user_id: Any, tool_name: str = "edit_file") -> None: + self.user_id = user_id + self.tool_name = tool_name + + +class TestCanRevert: + def test_owner_can_revert_their_own_action(self) -> None: + action = _FakeAction(user_id="user-123") + assert can_revert(requester_user_id="user-123", action=action, is_admin=False) + + def test_other_user_cannot_revert(self) -> None: + action = _FakeAction(user_id="user-123") + assert not can_revert( + requester_user_id="someone-else", action=action, is_admin=False + ) + + def test_admin_always_allowed(self) -> None: + action = _FakeAction(user_id="user-123") + assert can_revert(requester_user_id="anybody", action=action, is_admin=True) + + def test_admin_can_revert_anonymous_action(self) -> None: + action = _FakeAction(user_id=None) + assert can_revert(requester_user_id="admin", action=action, is_admin=True) + + def test_anonymous_action_blocks_non_admin(self) -> None: + action = _FakeAction(user_id=None) + assert not can_revert(requester_user_id="user-1", action=action, is_admin=False) + + def test_uuid_string_normalization(self) -> None: + """``user_id`` may be a UUID object; comparison should still work.""" + import uuid + + u = uuid.uuid4() + action = _FakeAction(user_id=u) + # Same UUID, passed as string from the requesting side. + assert can_revert(requester_user_id=str(u), action=action, is_admin=False) diff --git a/surfsense_backend/tests/unit/test_obsidian_plugin_indexer.py b/surfsense_backend/tests/unit/test_obsidian_plugin_indexer.py index 7ab3c52e0..20795c739 100644 --- a/surfsense_backend/tests/unit/test_obsidian_plugin_indexer.py +++ b/surfsense_backend/tests/unit/test_obsidian_plugin_indexer.py @@ -15,7 +15,6 @@ from app.services.obsidian_plugin_indexer import ( _require_extracted_attachment_content, ) - _FAKE_PNG_B64 = base64.b64encode(b"\x89PNG\r\n\x1a\n").decode("ascii") @@ -102,9 +101,7 @@ async def test_extract_binary_attachment_markdown_uses_etl(monkeypatch) -> None: mime_type="application/pdf", ) - async def _fake_run_etl_extract( # noqa: ANN001 - *, file_path, filename, vision_llm - ): + async def _fake_run_etl_extract(*, file_path, filename, vision_llm): assert filename == "spec.pdf" assert file_path assert vision_llm is None @@ -216,7 +213,7 @@ def test_note_payload_rejects_markdown_with_binary_fields() -> None: def test_require_extracted_attachment_content_rejects_empty_content() -> None: with pytest.raises( - RuntimeError, match="Attachment extraction failed for assets/img.png" + RuntimeError, match=r"Attachment extraction failed for assets/img\.png" ): _require_extracted_attachment_content( content=" ", diff --git a/surfsense_backend/tests/unit/test_stream_new_chat_contract.py b/surfsense_backend/tests/unit/test_stream_new_chat_contract.py index f4adc3d73..034aa484c 100644 --- a/surfsense_backend/tests/unit/test_stream_new_chat_contract.py +++ b/surfsense_backend/tests/unit/test_stream_new_chat_contract.py @@ -45,4 +45,3 @@ def test_contract_enforcement_local_only(): result.filesystem_mode = "cloud" assert not _contract_enforcement_active(result) - diff --git a/surfsense_web/app/dashboard/[search_space_id]/user-settings/components/AgentPermissionsContent.tsx b/surfsense_web/app/dashboard/[search_space_id]/user-settings/components/AgentPermissionsContent.tsx new file mode 100644 index 000000000..b01f556ad --- /dev/null +++ b/surfsense_web/app/dashboard/[search_space_id]/user-settings/components/AgentPermissionsContent.tsx @@ -0,0 +1,451 @@ +"use client"; + +import { useMutation, useQuery, useQueryClient } from "@tanstack/react-query"; +import { useAtomValue } from "jotai"; +import { AlertTriangle, Check, Plus, ShieldCheck, Trash2, X } from "lucide-react"; +import { useCallback, useMemo, useState } from "react"; +import { toast } from "sonner"; +import { agentFlagsAtom } from "@/atoms/agent/agent-flags-query.atom"; +import { activeSearchSpaceIdAtom } from "@/atoms/search-spaces/search-space-query.atoms"; +import { Alert, AlertDescription, AlertTitle } from "@/components/ui/alert"; +import { + AlertDialog, + AlertDialogAction, + AlertDialogCancel, + AlertDialogContent, + AlertDialogDescription, + AlertDialogFooter, + AlertDialogHeader, + AlertDialogTitle, +} from "@/components/ui/alert-dialog"; +import { Badge } from "@/components/ui/badge"; +import { Button } from "@/components/ui/button"; +import { Input } from "@/components/ui/input"; +import { Label } from "@/components/ui/label"; +import { + Select, + SelectContent, + SelectItem, + SelectTrigger, + SelectValue, +} from "@/components/ui/select"; +import { Spinner } from "@/components/ui/spinner"; +import { + type AgentPermissionAction, + type AgentPermissionRule, + type AgentPermissionRuleCreate, + agentPermissionsApiService, +} from "@/lib/apis/agent-permissions-api.service"; +import { AppError } from "@/lib/error"; +import { formatRelativeDate } from "@/lib/format-date"; +import { cn } from "@/lib/utils"; + +const ACTION_DESCRIPTIONS: Record = { + allow: "Always run without prompting", + deny: "Block silently", + ask: "Pause and ask for approval", +}; + +const ACTION_BADGE: Record = { + allow: { label: "Allow", className: "bg-emerald-500/10 text-emerald-600 border-emerald-500/30" }, + deny: { label: "Deny", className: "bg-destructive/10 text-destructive border-destructive/30" }, + ask: { label: "Ask", className: "bg-amber-500/10 text-amber-600 border-amber-500/30" }, +}; + +const EMPTY_FORM: AgentPermissionRuleCreate = { + permission: "", + pattern: "*", + action: "ask", + user_id: null, + thread_id: null, +}; + +function permissionRulesQueryKey(searchSpaceId: number) { + return ["agent-permission-rules", searchSpaceId] as const; +} + +function ScopeBadge({ rule }: { rule: AgentPermissionRule }) { + if (rule.thread_id !== null) { + return ( + + Thread #{rule.thread_id} + + ); + } + if (rule.user_id !== null) { + return ( + + User-specific + + ); + } + return ( + + Search space + + ); +} + +export function AgentPermissionsContent() { + const searchSpaceIdRaw = useAtomValue(activeSearchSpaceIdAtom); + const searchSpaceId = searchSpaceIdRaw ? Number(searchSpaceIdRaw) : null; + + const { data: flags } = useAtomValue(agentFlagsAtom); + const featureEnabled = !!flags?.enable_permission && !flags?.disable_new_agent_stack; + + const queryClient = useQueryClient(); + + const { + data: rules, + isLoading, + isError, + error, + } = useQuery({ + queryKey: searchSpaceId + ? permissionRulesQueryKey(searchSpaceId) + : ["agent-permission-rules", "none"], + queryFn: () => agentPermissionsApiService.list(searchSpaceId as number), + enabled: !!searchSpaceId && featureEnabled, + staleTime: 60 * 1000, + }); + + const createMutation = useMutation({ + mutationFn: (payload: AgentPermissionRuleCreate) => + agentPermissionsApiService.create(searchSpaceId as number, payload), + onSuccess: () => { + toast.success("Rule created."); + queryClient.invalidateQueries({ + queryKey: permissionRulesQueryKey(searchSpaceId as number), + }); + }, + onError: (err: unknown) => { + toast.error(err instanceof Error ? err.message : "Failed to create rule."); + }, + }); + + const updateMutation = useMutation({ + mutationFn: (params: { ruleId: number; action: AgentPermissionAction; pattern?: string }) => + agentPermissionsApiService.update(searchSpaceId as number, params.ruleId, { + action: params.action, + pattern: params.pattern, + }), + onSuccess: () => { + queryClient.invalidateQueries({ + queryKey: permissionRulesQueryKey(searchSpaceId as number), + }); + }, + onError: (err: unknown) => { + toast.error(err instanceof Error ? err.message : "Failed to update rule."); + }, + }); + + const deleteMutation = useMutation({ + mutationFn: (ruleId: number) => + agentPermissionsApiService.remove(searchSpaceId as number, ruleId), + onSuccess: () => { + toast.success("Rule deleted."); + queryClient.invalidateQueries({ + queryKey: permissionRulesQueryKey(searchSpaceId as number), + }); + }, + onError: (err: unknown) => { + toast.error(err instanceof Error ? err.message : "Failed to delete rule."); + }, + }); + + const [showForm, setShowForm] = useState(false); + const [formData, setFormData] = useState(EMPTY_FORM); + const [deleteTarget, setDeleteTarget] = useState(null); + + const sortedRules = useMemo(() => rules ?? [], [rules]); + + const handleCreate = useCallback(async () => { + if (!formData.permission.trim()) { + toast.error("Permission is required."); + return; + } + try { + await createMutation.mutateAsync({ + ...formData, + permission: formData.permission.trim(), + pattern: formData.pattern.trim() || "*", + }); + setShowForm(false); + setFormData(EMPTY_FORM); + } catch (err) { + if (err instanceof AppError && err.message) { + // already toasted by onError + } + } + }, [createMutation, formData]); + + const handleConfirmDelete = useCallback(async () => { + if (deleteTarget === null) return; + try { + await deleteMutation.mutateAsync(deleteTarget); + } finally { + setDeleteTarget(null); + } + }, [deleteMutation, deleteTarget]); + + if (!featureEnabled) { + return ( + + + Permission middleware is disabled + + Flip{" "} + SURFSENSE_ENABLE_PERMISSION on + the backend to manage allow/deny/ask rules from this panel. + + + ); + } + + if (!searchSpaceId) { + return ( +

Open a search space to manage agent rules.

+ ); + } + + if (isLoading) { + return ( +
+ +
+ ); + } + + if (isError) { + return ( +
+ +

Failed to load rules

+

+ {error instanceof Error ? error.message : "Unknown error."} +

+
+ ); + } + + return ( +
+
+
+

+ Tell the agent which tools to allow, deny, or ask before running. Rules use wildcard + patterns and are evaluated at the most specific scope first. +

+
+ {!showForm && ( + + )} +
+ + {showForm && ( +
+
+

New permission rule

+ +
+
+ + setFormData((p) => ({ ...p, permission: e.target.value }))} + /> +

+ Match a tool capability. Use * for wildcards. +

+
+ +
+ + setFormData((p) => ({ ...p, pattern: e.target.value }))} + /> +

+ Wildcard against the canonical argument (e.g. prod-*). +

+
+
+ +
+ + +

+ {ACTION_DESCRIPTIONS[formData.action]} +

+
+ +
+ + +
+
+
+ )} + + {sortedRules.length === 0 && !showForm && ( +
+ +

No rules yet

+

+ Without rules the agent uses the deployment default for every tool. +

+
+ )} + + {sortedRules.length > 0 && ( +
+ {sortedRules.map((rule) => { + const badge = ACTION_BADGE[rule.action]; + const isUpdating = + updateMutation.isPending && updateMutation.variables?.ruleId === rule.id; + const isDeleting = deleteMutation.isPending && deleteMutation.variables === rule.id; + + return ( +
+
+
+
+ + {rule.permission} + + {rule.pattern !== "*" && ( + + → {rule.pattern} + + )} + +
+

+ Created {formatRelativeDate(rule.created_at)} +

+
+ +
+ + + +
+
+
+ ); + })} +
+ )} + + !open && setDeleteTarget(null)} + > + + + Delete this rule? + + The agent will fall back to deployment defaults for matching tool calls. + + + + Cancel + { + e.preventDefault(); + handleConfirmDelete(); + }} + disabled={deleteMutation.isPending} + > + {deleteMutation.isPending ? "Deleting…" : "Delete"} + + + + +
+ ); +} diff --git a/surfsense_web/app/dashboard/[search_space_id]/user-settings/components/AgentStatusContent.tsx b/surfsense_web/app/dashboard/[search_space_id]/user-settings/components/AgentStatusContent.tsx new file mode 100644 index 000000000..bd8f03a70 --- /dev/null +++ b/surfsense_web/app/dashboard/[search_space_id]/user-settings/components/AgentStatusContent.tsx @@ -0,0 +1,309 @@ +"use client"; + +import { useAtomValue } from "jotai"; +import { CircleCheck, CircleSlash, Cog, RotateCcw } from "lucide-react"; +import { useMemo } from "react"; +import { agentFlagsAtom } from "@/atoms/agent/agent-flags-query.atom"; +import { Alert, AlertDescription, AlertTitle } from "@/components/ui/alert"; +import { Badge } from "@/components/ui/badge"; +import { Separator } from "@/components/ui/separator"; +import { Skeleton } from "@/components/ui/skeleton"; +import type { AgentFeatureFlags } from "@/lib/apis/agent-flags-api.service"; +import { cn } from "@/lib/utils"; + +type FlagKey = keyof AgentFeatureFlags; + +interface FlagDef { + key: FlagKey; + label: string; + description: string; + envVar: string; +} + +interface FlagGroup { + id: string; + title: string; + subtitle: string; + flags: FlagDef[]; +} + +const FLAG_GROUPS: FlagGroup[] = [ + { + id: "tier1", + title: "Tier 1 — Agent quality", + subtitle: "Context editing, retries, fallbacks, doom-loop, tool-call repair.", + flags: [ + { + key: "enable_context_editing", + label: "Context editing", + description: "Trim tool outputs and spill old text into backend storage.", + envVar: "SURFSENSE_ENABLE_CONTEXT_EDITING", + }, + { + key: "enable_compaction_v2", + label: "Compaction v2", + description: "SurfSense-aware compaction replacing safe summarization.", + envVar: "SURFSENSE_ENABLE_COMPACTION_V2", + }, + { + key: "enable_retry_after", + label: "Retry-After", + description: "Honour rate-limit retry-after headers automatically.", + envVar: "SURFSENSE_ENABLE_RETRY_AFTER", + }, + { + key: "enable_model_fallback", + label: "Model fallback", + description: "Fail over to a backup model on persistent errors.", + envVar: "SURFSENSE_ENABLE_MODEL_FALLBACK", + }, + { + key: "enable_model_call_limit", + label: "Model call limit", + description: "Cap total model calls per turn to prevent budget run-aways.", + envVar: "SURFSENSE_ENABLE_MODEL_CALL_LIMIT", + }, + { + key: "enable_tool_call_limit", + label: "Tool call limit", + description: "Cap total tool calls per turn.", + envVar: "SURFSENSE_ENABLE_TOOL_CALL_LIMIT", + }, + { + key: "enable_tool_call_repair", + label: "Tool-call name repair", + description: "Recover from lower-cased / fuzzy tool names emitted by smaller models.", + envVar: "SURFSENSE_ENABLE_TOOL_CALL_REPAIR", + }, + { + key: "enable_doom_loop", + label: "Doom-loop detection", + description: "Detect repeated identical tool calls and ask the user to confirm.", + envVar: "SURFSENSE_ENABLE_DOOM_LOOP", + }, + ], + }, + { + id: "tier2", + title: "Tier 2 — Safety", + subtitle: "Permission rules, busy-mutex, smarter tool selection.", + flags: [ + { + key: "enable_permission", + label: "Permission middleware", + description: "Apply allow/deny/ask rules from the Agent Permissions tab.", + envVar: "SURFSENSE_ENABLE_PERMISSION", + }, + { + key: "enable_busy_mutex", + label: "Busy mutex", + description: "Prevent two concurrent runs from corrupting the same thread.", + envVar: "SURFSENSE_ENABLE_BUSY_MUTEX", + }, + { + key: "enable_llm_tool_selector", + label: "LLM tool selector", + description: "Use a smaller model to pre-filter the tool list per turn.", + envVar: "SURFSENSE_ENABLE_LLM_TOOL_SELECTOR", + }, + ], + }, + { + id: "tier4", + title: "Tier 4 — Skills + subagents", + subtitle: "Built-in skills, specialized subagents, KB planner runnable.", + flags: [ + { + key: "enable_skills", + label: "Skills", + description: "Load on-demand skill packs (kb-research, report-writing, …).", + envVar: "SURFSENSE_ENABLE_SKILLS", + }, + { + key: "enable_specialized_subagents", + label: "Specialized subagents", + description: "Spin up explore / report_writer / connector_negotiator subagents.", + envVar: "SURFSENSE_ENABLE_SPECIALIZED_SUBAGENTS", + }, + { + key: "enable_kb_planner_runnable", + label: "KB planner runnable", + description: "Compile a private planner sub-agent for KB search.", + envVar: "SURFSENSE_ENABLE_KB_PLANNER_RUNNABLE", + }, + ], + }, + { + id: "tier5", + title: "Tier 5 — Audit + revert", + subtitle: "Action log + revert route used by the Agent Actions sheet.", + flags: [ + { + key: "enable_action_log", + label: "Action log", + description: "Persist every tool call to agent_action_log.", + envVar: "SURFSENSE_ENABLE_ACTION_LOG", + }, + { + key: "enable_revert_route", + label: "Revert route", + description: "Allow reverting reversible actions from the action log.", + envVar: "SURFSENSE_ENABLE_REVERT_ROUTE", + }, + ], + }, + { + id: "tier6", + title: "Tier 6 — Plugins", + subtitle: "Optional middleware loaded from entry points.", + flags: [ + { + key: "enable_plugin_loader", + label: "Plugin loader", + description: "Load surfsense.plugins entry-point middleware.", + envVar: "SURFSENSE_ENABLE_PLUGIN_LOADER", + }, + ], + }, + { + id: "obs", + title: "Observability", + subtitle: "Telemetry pipelines (orthogonal to feature gating).", + flags: [ + { + key: "enable_otel", + label: "OpenTelemetry", + description: "Emit OTel spans (also requires OTEL_EXPORTER_OTLP_ENDPOINT).", + envVar: "SURFSENSE_ENABLE_OTEL", + }, + ], + }, +]; + +function FlagRow({ def, value }: { def: FlagDef; value: boolean }) { + return ( +
+
+
+ {def.label} + + {def.envVar} + +
+

{def.description}

+
+ + {value ? : } + {value ? "On" : "Off"} + +
+ ); +} + +export function AgentStatusContent() { + const { data: flags, isLoading, isError, error, refetch } = useAtomValue(agentFlagsAtom); + + const enabledCount = useMemo(() => { + if (!flags) return 0; + return Object.entries(flags).filter(([k, v]) => k !== "disable_new_agent_stack" && v === true) + .length; + }, [flags]); + + if (isLoading) { + return ( +
+ + + +
+ ); + } + + if (isError || !flags) { + return ( + + Failed to load agent status + + {error instanceof Error ? error.message : "Unknown error."} + + + + ); + } + + const masterOff = flags.disable_new_agent_stack; + + return ( +
+ {masterOff ? ( + + + Master kill-switch is on + + + SURFSENSE_DISABLE_NEW_AGENT_STACK=true + + forces every new middleware off, regardless of the individual flags below. Restart the + backend after changing it. + + + ) : ( + + + + Agent stack + + {enabledCount} on + + + + Read-only mirror of the backend's AgentFeatureFlags. Flip an env var and + restart the backend to change a value. + + + )} + + {FLAG_GROUPS.map((group, groupIdx) => { + const allOff = group.flags.every((f) => !flags[f.key]); + return ( +
+ {groupIdx > 0 && } +
+
+
+

{group.title}

+

{group.subtitle}

+
+ {allOff && ( + + all off + + )} +
+
+ {group.flags.map((def) => ( + + ))} +
+
+
+ ); + })} +
+ ); +} diff --git a/surfsense_web/atoms/agent/action-log-sheet.atom.ts b/surfsense_web/atoms/agent/action-log-sheet.atom.ts new file mode 100644 index 000000000..f88d3ed1e --- /dev/null +++ b/surfsense_web/atoms/agent/action-log-sheet.atom.ts @@ -0,0 +1,19 @@ +import { atom } from "jotai"; + +interface ActionLogSheetState { + open: boolean; + threadId: number | null; +} + +export const actionLogSheetAtom = atom({ + open: false, + threadId: null, +}); + +export const openActionLogSheetAtom = atom(null, (_get, set, threadId: number) => { + set(actionLogSheetAtom, { open: true, threadId }); +}); + +export const closeActionLogSheetAtom = atom(null, (_get, set) => { + set(actionLogSheetAtom, { open: false, threadId: null }); +}); diff --git a/surfsense_web/atoms/agent/agent-flags-query.atom.ts b/surfsense_web/atoms/agent/agent-flags-query.atom.ts new file mode 100644 index 000000000..30158deaa --- /dev/null +++ b/surfsense_web/atoms/agent/agent-flags-query.atom.ts @@ -0,0 +1,17 @@ +import { atomWithQuery } from "jotai-tanstack-query"; +import { agentFlagsApiService } from "@/lib/apis/agent-flags-api.service"; +import { getBearerToken } from "@/lib/auth-utils"; + +export const AGENT_FLAGS_QUERY_KEY = ["agent", "flags"] as const; + +/** + * Reads the backend agent feature flags. Cached for the lifetime of the + * page (flags only change on backend restart) so we can drive UI gating + * without re-hitting the API. + */ +export const agentFlagsAtom = atomWithQuery(() => ({ + queryKey: AGENT_FLAGS_QUERY_KEY, + staleTime: 10 * 60 * 1000, + enabled: !!getBearerToken(), + queryFn: () => agentFlagsApiService.get(), +})); diff --git a/surfsense_web/atoms/citation/citation-panel.atom.ts b/surfsense_web/atoms/citation/citation-panel.atom.ts new file mode 100644 index 000000000..ca7312857 --- /dev/null +++ b/surfsense_web/atoms/citation/citation-panel.atom.ts @@ -0,0 +1,40 @@ +import { atom } from "jotai"; +import { rightPanelCollapsedAtom, rightPanelTabAtom } from "@/atoms/layout/right-panel.atom"; + +interface CitationPanelState { + isOpen: boolean; + chunkId: number | null; +} + +const initialState: CitationPanelState = { + isOpen: false, + chunkId: null, +}; + +export const citationPanelAtom = atom(initialState); + +export const citationPanelOpenAtom = atom((get) => get(citationPanelAtom).isOpen); + +const preCitationCollapsedAtom = atom(null); + +export const openCitationPanelAtom = atom(null, (get, set, payload: { chunkId: number }) => { + if (!get(citationPanelAtom).isOpen) { + set(preCitationCollapsedAtom, get(rightPanelCollapsedAtom)); + } + set(citationPanelAtom, { + isOpen: true, + chunkId: payload.chunkId, + }); + set(rightPanelTabAtom, "citation"); + set(rightPanelCollapsedAtom, false); +}); + +export const closeCitationPanelAtom = atom(null, (get, set) => { + set(citationPanelAtom, initialState); + set(rightPanelTabAtom, "sources"); + const prev = get(preCitationCollapsedAtom); + if (prev !== null) { + set(rightPanelCollapsedAtom, prev); + set(preCitationCollapsedAtom, null); + } +}); diff --git a/surfsense_web/atoms/layout/right-panel.atom.ts b/surfsense_web/atoms/layout/right-panel.atom.ts index e06500113..d296587ed 100644 --- a/surfsense_web/atoms/layout/right-panel.atom.ts +++ b/surfsense_web/atoms/layout/right-panel.atom.ts @@ -1,6 +1,6 @@ import { atom } from "jotai"; -export type RightPanelTab = "sources" | "report" | "editor" | "hitl-edit"; +export type RightPanelTab = "sources" | "report" | "editor" | "hitl-edit" | "citation"; export const rightPanelTabAtom = atom("sources"); diff --git a/surfsense_web/components/agent-action-log/action-log-button.tsx b/surfsense_web/components/agent-action-log/action-log-button.tsx new file mode 100644 index 000000000..1c0383136 --- /dev/null +++ b/surfsense_web/components/agent-action-log/action-log-button.tsx @@ -0,0 +1,50 @@ +"use client"; + +import { useAtomValue, useSetAtom } from "jotai"; +import { Activity } from "lucide-react"; +import { useCallback } from "react"; +import { openActionLogSheetAtom } from "@/atoms/agent/action-log-sheet.atom"; +import { agentFlagsAtom } from "@/atoms/agent/agent-flags-query.atom"; +import { Button } from "@/components/ui/button"; +import { Tooltip, TooltipContent, TooltipTrigger } from "@/components/ui/tooltip"; + +interface ActionLogButtonProps { + threadId: number | null; +} + +/** + * Header button that opens the agent action log sheet for the current + * thread. Renders nothing when: + * - the action log feature flag is off (graceful no-op for older + * deployments), OR + * - there is no active thread (lazy-created chats haven't started). + */ +export function ActionLogButton({ threadId }: ActionLogButtonProps) { + const { data: flags } = useAtomValue(agentFlagsAtom); + const open = useSetAtom(openActionLogSheetAtom); + + const enabled = !!flags?.enable_action_log && !flags?.disable_new_agent_stack; + + const handleClick = useCallback(() => { + if (threadId !== null) open(threadId); + }, [open, threadId]); + + if (!enabled || threadId === null) return null; + + return ( + + + + + Agent actions + + ); +} diff --git a/surfsense_web/components/agent-action-log/action-log-item.tsx b/surfsense_web/components/agent-action-log/action-log-item.tsx new file mode 100644 index 000000000..425714c1f --- /dev/null +++ b/surfsense_web/components/agent-action-log/action-log-item.tsx @@ -0,0 +1,215 @@ +"use client"; + +import { ChevronRight, RotateCcw, ShieldOff, Undo2 } from "lucide-react"; +import { useState } from "react"; +import { toast } from "sonner"; +import { + AlertDialog, + AlertDialogAction, + AlertDialogCancel, + AlertDialogContent, + AlertDialogDescription, + AlertDialogFooter, + AlertDialogHeader, + AlertDialogTitle, + AlertDialogTrigger, +} from "@/components/ui/alert-dialog"; +import { Badge } from "@/components/ui/badge"; +import { Button } from "@/components/ui/button"; +import { Separator } from "@/components/ui/separator"; +import { getToolIcon } from "@/contracts/enums/toolIcons"; +import { type AgentAction, agentActionsApiService } from "@/lib/apis/agent-actions-api.service"; +import { AppError } from "@/lib/error"; +import { formatRelativeDate } from "@/lib/format-date"; +import { cn } from "@/lib/utils"; + +function formatToolName(name: string): string { + return name.replace(/_/g, " ").replace(/\b\w/g, (c) => c.toUpperCase()); +} + +interface ActionLogItemProps { + action: AgentAction; + threadId: number; + onRevertSuccess: () => void; +} + +export function ActionLogItem({ action, threadId, onRevertSuccess }: ActionLogItemProps) { + const [isExpanded, setIsExpanded] = useState(false); + const [isReverting, setIsReverting] = useState(false); + const [confirmOpen, setConfirmOpen] = useState(false); + + const isAlreadyReverted = action.reverted_by_action_id !== null; + const isRevertAction = action.is_revert_action; + const hasError = action.error !== null && action.error !== undefined; + + const Icon = getToolIcon(action.tool_name); + const displayName = formatToolName(action.tool_name); + + const argsPreview = action.args ? JSON.stringify(action.args, null, 2) : null; + const truncatedArgs = + argsPreview && argsPreview.length > 600 ? `${argsPreview.slice(0, 600)}…` : argsPreview; + + const canRevert = action.reversible && !isAlreadyReverted && !isRevertAction && !hasError; + + const handleRevert = async () => { + setIsReverting(true); + try { + const response = await agentActionsApiService.revert(threadId, action.id); + toast.success(response.message || "Action reverted successfully."); + onRevertSuccess(); + } catch (err) { + const message = + err instanceof AppError + ? err.message + : err instanceof Error + ? err.message + : "Failed to revert action."; + toast.error(message); + } finally { + setIsReverting(false); + setConfirmOpen(false); + } + }; + + return ( +
+ + + {isExpanded && ( +
+ {truncatedArgs && ( +
+

+ Arguments +

+
+								{truncatedArgs}
+							
+
+ )} + {action.error && ( +
+

+ Error +

+
+								{JSON.stringify(action.error, null, 2)}
+							
+
+ )} + {action.reverse_descriptor && ( +
+

+ Reverse plan +

+
+								{JSON.stringify(action.reverse_descriptor, null, 2)}
+							
+
+ )} + + + +
+

+ Action ID: {action.id} +

+ {canRevert ? ( + + + + + + + Revert this action? + + This will undo {displayName} and append a + new audit entry. The agent's chat history is preserved — only the tool's + effects on your knowledge base or connectors will be reversed where possible. + + + + Cancel + { + e.preventDefault(); + handleRevert(); + }} + disabled={isReverting} + > + {isReverting ? "Reverting…" : "Revert"} + + + + + ) : ( +
+ + {isAlreadyReverted + ? "Already reverted" + : isRevertAction + ? "Revert entry" + : hasError + ? "Cannot revert errored action" + : "Not reversible"} +
+ )} +
+
+ )} +
+ ); +} diff --git a/surfsense_web/components/agent-action-log/action-log-sheet.tsx b/surfsense_web/components/agent-action-log/action-log-sheet.tsx new file mode 100644 index 000000000..68d2ffef3 --- /dev/null +++ b/surfsense_web/components/agent-action-log/action-log-sheet.tsx @@ -0,0 +1,185 @@ +"use client"; + +import { useQuery, useQueryClient } from "@tanstack/react-query"; +import { useAtom, useAtomValue } from "jotai"; +import { Activity, RefreshCcw } from "lucide-react"; +import { useCallback, useMemo } from "react"; +import { actionLogSheetAtom } from "@/atoms/agent/action-log-sheet.atom"; +import { agentFlagsAtom } from "@/atoms/agent/agent-flags-query.atom"; +import { Badge } from "@/components/ui/badge"; +import { Button } from "@/components/ui/button"; +import { Separator } from "@/components/ui/separator"; +import { + Sheet, + SheetContent, + SheetDescription, + SheetHeader, + SheetTitle, +} from "@/components/ui/sheet"; +import { Skeleton } from "@/components/ui/skeleton"; +import { agentActionsApiService } from "@/lib/apis/agent-actions-api.service"; +import { ActionLogItem } from "./action-log-item"; + +const ACTION_LOG_PAGE_SIZE = 50; + +function actionLogQueryKey(threadId: number) { + return ["agent-actions", threadId] as const; +} + +function EmptyState() { + return ( +
+
+ +
+
+

No actions logged yet

+

+ Once the agent calls a tool in this thread, it will show up here. From the log you can + inspect arguments and revert reversible actions. +

+
+
+ ); +} + +function DisabledState() { + return ( +
+
+ +
+
+

Action log is disabled

+

+ This deployment hasn't enabled the agent action log. An admin can flip + + SURFSENSE_ENABLE_ACTION_LOG + + . +

+
+
+ ); +} + +const SKELETON_KEYS = ["s1", "s2", "s3", "s4"] as const; + +function LoadingState() { + return ( +
+ {SKELETON_KEYS.map((key) => ( + + ))} +
+ ); +} + +export function ActionLogSheet() { + const [state, setState] = useAtom(actionLogSheetAtom); + const queryClient = useQueryClient(); + + const { data: flags } = useAtomValue(agentFlagsAtom); + const actionLogEnabled = !!flags?.enable_action_log && !flags?.disable_new_agent_stack; + const revertEnabled = !!flags?.enable_revert_route && !flags?.disable_new_agent_stack; + + const threadId = state.threadId; + + const { data, isLoading, isFetching, isError, error, refetch } = useQuery({ + queryKey: threadId !== null ? actionLogQueryKey(threadId) : ["agent-actions", "none"], + queryFn: () => + agentActionsApiService.listForThread(threadId as number, { + page: 0, + pageSize: ACTION_LOG_PAGE_SIZE, + }), + enabled: state.open && threadId !== null && actionLogEnabled, + staleTime: 15 * 1000, + }); + + const handleRevertSuccess = useCallback(() => { + if (threadId !== null) { + queryClient.invalidateQueries({ queryKey: actionLogQueryKey(threadId) }); + } + }, [queryClient, threadId]); + + const items = useMemo(() => data?.items ?? [], [data]); + + return ( + setState((s) => ({ ...s, open }))}> + + +
+
+ + Agent actions + {data?.total !== undefined && data.total > 0 && ( + + {data.total} + + )} +
+ +
+ + Audit trail of every tool call the agent made in this thread. + {revertEnabled + ? " Reversible actions can be undone in place." + : " Reverts are read-only on this deployment."} + +
+ + + +
+ {!actionLogEnabled ? ( + + ) : threadId === null ? ( + + ) : isLoading ? ( + + ) : isError ? ( +
+

Failed to load actions

+

+ {error instanceof Error ? error.message : "Unknown error"} +

+ +
+ ) : items.length === 0 ? ( + + ) : ( +
+ {items.map((action) => ( + + ))} + {data?.has_more && ( +

+ Showing {items.length} of {data.total}. Older actions are paginated. +

+ )} +
+ )} +
+
+
+ ); +} diff --git a/surfsense_web/components/assistant-ui/connector-popup/connector-configs/views/connector-edit-view.tsx b/surfsense_web/components/assistant-ui/connector-popup/connector-configs/views/connector-edit-view.tsx index e7895c2e9..c104f140a 100644 --- a/surfsense_web/components/assistant-ui/connector-popup/connector-configs/views/connector-edit-view.tsx +++ b/surfsense_web/components/assistant-ui/connector-popup/connector-configs/views/connector-edit-view.tsx @@ -314,8 +314,7 @@ export const ConnectorEditView: FC = ({ {connector.is_indexable && (() => { - const isGoogleDrive = - connector.connector_type === "GOOGLE_DRIVE_CONNECTOR"; + const isGoogleDrive = connector.connector_type === "GOOGLE_DRIVE_CONNECTOR"; const isComposioGoogleDrive = connector.connector_type === "COMPOSIO_GOOGLE_DRIVE_CONNECTOR"; const requiresFolderSelection = isGoogleDrive || isComposioGoogleDrive; @@ -327,8 +326,7 @@ export const ConnectorEditView: FC = ({ (connector.config?.selected_files as | Array<{ id: string; name: string }> | undefined) || []; - const hasItemsSelected = - selectedFolders.length > 0 || selectedFiles.length > 0; + const hasItemsSelected = selectedFolders.length > 0 || selectedFiles.length > 0; const isDisabled = requiresFolderSelection && !hasItemsSelected; return ( diff --git a/surfsense_web/components/assistant-ui/inline-citation.tsx b/surfsense_web/components/assistant-ui/inline-citation.tsx index eb4bd9af8..2aeba89ca 100644 --- a/surfsense_web/components/assistant-ui/inline-citation.tsx +++ b/surfsense_web/components/assistant-ui/inline-citation.tsx @@ -1,26 +1,43 @@ "use client"; -import { FileText } from "lucide-react"; +import { useQuery } from "@tanstack/react-query"; +import { useSetAtom } from "jotai"; +import { ExternalLink, FileText } from "lucide-react"; import type { FC } from "react"; -import { useState } from "react"; +import { useCallback, useEffect, useRef, useState } from "react"; +import { openCitationPanelAtom } from "@/atoms/citation/citation-panel.atom"; import { useCitationMetadata } from "@/components/assistant-ui/citation-metadata-context"; -import { SourceDetailPanel } from "@/components/new-chat/source-detail-panel"; +import { MarkdownViewer } from "@/components/markdown-viewer"; import { Citation } from "@/components/tool-ui/citation"; +import { Popover, PopoverContent, PopoverTrigger } from "@/components/ui/popover"; +import { Spinner } from "@/components/ui/spinner"; import { Tooltip, TooltipContent, TooltipTrigger } from "@/components/ui/tooltip"; +import { documentsApiService } from "@/lib/apis/documents-api.service"; +import { cacheKeys } from "@/lib/query-client/cache-keys"; interface InlineCitationProps { chunkId: number; isDocsChunk?: boolean; } +const POPOVER_HOVER_CLOSE_DELAY_MS = 150; + /** - * Inline citation for knowledge-base chunks (numeric chunk IDs). - * Renders a clickable badge showing the actual chunk ID that opens the SourceDetailPanel. - * Negative chunk IDs indicate anonymous/synthetic uploads and render as a static badge. + * Inline citation badge for knowledge-base chunks (numeric chunk IDs) and + * Surfsense documentation chunks (`isDocsChunk`). Negative chunk IDs render as + * a static "doc" pill (anonymous/synthetic uploads). + * + * Numeric KB chunks: clicking opens the citation panel in the right + * sidebar (alongside the chat — does not replace it). The panel shows + * the cited chunk surrounded by adjacent chunks (via the API's + * `chunk_window`), with the cited one highlighted and an option to + * expand the window or jump into the full document via the editor panel. + * + * Surfsense docs chunks: rendered as a hover-controlled shadcn Popover that + * lazily fetches and previews the cited chunk inline, since those docs aren't + * indexed into the user's search space and have no tab to open. */ export const InlineCitation: FC = ({ chunkId, isDocsChunk = false }) => { - const [isOpen, setIsOpen] = useState(false); - if (chunkId < 0) { return ( @@ -38,26 +55,131 @@ export const InlineCitation: FC = ({ chunkId, isDocsChunk = ); } + if (isDocsChunk) { + return ; + } + + return ; +}; + +const NumericChunkCitation: FC<{ chunkId: number }> = ({ chunkId }) => { + const openCitationPanel = useSetAtom(openCitationPanelAtom); + return ( - openCitationPanel({ chunkId })} + className="ml-0.5 inline-flex h-5 min-w-5 cursor-pointer items-center justify-center rounded-md bg-muted/60 px-1.5 text-[11px] font-medium text-muted-foreground align-baseline shadow-sm transition-colors hover:bg-muted hover:text-foreground focus-visible:ring-ring focus-visible:ring-2 focus-visible:outline-none" + title={`View source chunk #${chunkId}`} + aria-label={`View cited chunk ${chunkId}`} > - + ); +}; + +const SurfsenseDocCitation: FC<{ chunkId: number }> = ({ chunkId }) => { + const [open, setOpen] = useState(false); + const closeTimerRef = useRef | null>(null); + + const cancelClose = useCallback(() => { + if (closeTimerRef.current) { + clearTimeout(closeTimerRef.current); + closeTimerRef.current = null; + } + }, []); + + const scheduleClose = useCallback(() => { + cancelClose(); + closeTimerRef.current = setTimeout(() => { + setOpen(false); + closeTimerRef.current = null; + }, POPOVER_HOVER_CLOSE_DELAY_MS); + }, [cancelClose]); + + useEffect(() => () => cancelClose(), [cancelClose]); + + const { data, isLoading, error } = useQuery({ + queryKey: cacheKeys.documents.byChunk(`doc-${chunkId}`), + queryFn: () => documentsApiService.getSurfsenseDocByChunk(chunkId), + enabled: open, + staleTime: 5 * 60 * 1000, + }); + + const citedChunk = data?.chunks.find((c) => c.id === chunkId) ?? data?.chunks[0]; + + return ( + + + + + e.preventDefault()} > - {chunkId} - - +
+
+

+ {data?.title ?? "Surfsense documentation"} +

+

Chunk #{chunkId}

+
+ {data?.source && ( + + + Open + + )} +
+
+ {isLoading && ( +
+ + Loading… +
+ )} + {error && ( +

+ {error instanceof Error ? error.message : "Failed to load chunk"} +

+ )} + {!isLoading && !error && citedChunk?.content && ( + + )} + {!isLoading && !error && !citedChunk?.content && ( +

No content available.

+ )} +
+ + ); }; diff --git a/surfsense_web/components/assistant-ui/markdown-text.tsx b/surfsense_web/components/assistant-ui/markdown-text.tsx index f8abed486..7655e10cc 100644 --- a/surfsense_web/components/assistant-ui/markdown-text.tsx +++ b/surfsense_web/components/assistant-ui/markdown-text.tsx @@ -85,10 +85,13 @@ function preprocessMarkdown(content: string): string { } ); + // All math forms are normalised to $$...$$ so we can disable single-dollar + // inline math in remark-math (otherwise currency like "$3,120.00 and $0.00" + // gets parsed as a LaTeX expression). // 1. Block math: \[...\] → $$...$$ content = content.replace(/\\\[([\s\S]*?)\\\]/g, (_, inner) => `$$${inner}$$`); - // 2. Inline math: \(...\) → $...$ - content = content.replace(/\\\(([\s\S]*?)\\\)/g, (_, inner) => `$${inner}$`); + // 2. Inline math: \(...\) → $$...$$ + content = content.replace(/\\\(([\s\S]*?)\\\)/g, (_, inner) => `$$${inner}$$`); // 3. Block: \begin{equation}...\end{equation} → $$...$$ content = content.replace( /\\begin\{equation\}([\s\S]*?)\\end\{equation\}/g, @@ -99,8 +102,11 @@ function preprocessMarkdown(content: string): string { /\\begin\{displaymath\}([\s\S]*?)\\end\{displaymath\}/g, (_, inner) => `$$${inner}$$` ); - // 5. Inline: \begin{math}...\end{math} → $...$ - content = content.replace(/\\begin\{math\}([\s\S]*?)\\end\{math\}/g, (_, inner) => `$${inner}$`); + // 5. Inline: \begin{math}...\end{math} → $$...$$ + content = content.replace( + /\\begin\{math\}([\s\S]*?)\\end\{math\}/g, + (_, inner) => `$$${inner}$$` + ); // 6. Strip backtick wrapping around math: `$$...$$` → $$...$$ and `$...$` → $...$ content = content.replace(/`(\${1,2})((?:(?!\1).)+)\1`/g, "$1$2$1"); @@ -180,7 +186,7 @@ const MarkdownTextImpl = () => { return ( { if (isInterruptResult(props.result)) { + if (isDoomLoopInterrupt(props.result)) { + return ; + } return ; } return ; diff --git a/surfsense_web/components/citation-panel/citation-panel.tsx b/surfsense_web/components/citation-panel/citation-panel.tsx new file mode 100644 index 000000000..cec07b9cf --- /dev/null +++ b/surfsense_web/components/citation-panel/citation-panel.tsx @@ -0,0 +1,230 @@ +"use client"; + +import { useQuery } from "@tanstack/react-query"; +import { useSetAtom } from "jotai"; +import { ChevronDown, ChevronUp, ExternalLink, XIcon } from "lucide-react"; +import type { FC } from "react"; +import { useEffect, useMemo, useRef, useState } from "react"; +import { openEditorPanelAtom } from "@/atoms/editor/editor-panel.atom"; +import { MarkdownViewer } from "@/components/markdown-viewer"; +import { Button } from "@/components/ui/button"; +import { Spinner } from "@/components/ui/spinner"; +import { documentsApiService } from "@/lib/apis/documents-api.service"; + +const DEFAULT_CHUNK_WINDOW = 5; +const EXPANDED_CHUNK_WINDOW = 50; + +interface CitationPanelContentProps { + chunkId: number; + onClose?: () => void; +} + +/** + * Right-panel citation viewer. Shows the cited chunk surrounded by + * adjacent chunks (±N chunks via the API's `chunk_window` parameter), + * with the cited one visually highlighted and auto-scrolled into view. + * The window can be expanded to a wider range, or the user can jump to + * the full document via the editor panel. + */ +export const CitationPanelContent: FC = ({ chunkId, onClose }) => { + const openEditorPanel = useSetAtom(openEditorPanelAtom); + const [expanded, setExpanded] = useState(false); + + useEffect(() => { + setExpanded(false); + }, []); + + const chunkWindow = expanded ? EXPANDED_CHUNK_WINDOW : DEFAULT_CHUNK_WINDOW; + + const { data, isLoading, error } = useQuery({ + queryKey: ["citation-panel", chunkId, chunkWindow] as const, + queryFn: () => + documentsApiService.getDocumentByChunk({ + chunk_id: chunkId, + chunk_window: chunkWindow, + }), + staleTime: 5 * 60 * 1000, + }); + + const cited = useMemo(() => data?.chunks.find((c) => c.id === chunkId) ?? null, [data, chunkId]); + + const totalChunks = data?.total_chunks ?? data?.chunks.length ?? 0; + const startIndex = data?.chunk_start_index ?? 0; + const citedIndexInWindow = data + ? Math.max( + 0, + data.chunks.findIndex((c) => c.id === chunkId) + ) + : 0; + const shownAbove = citedIndexInWindow; + const shownBelow = data ? Math.max(0, data.chunks.length - 1 - citedIndexInWindow) : 0; + const hasMoreAbove = startIndex > 0; + const hasMoreBelow = data ? startIndex + data.chunks.length < totalChunks : false; + + // Scroll the cited chunk into view inside the panel's scroll container + // (not the page). We anchor the scroll to the panel's scroll element + // so opening the citation doesn't yank the chat scroll on the left. + const scrollContainerRef = useRef(null); + const citedRef = useRef(null); + useEffect(() => { + if (!cited) return; + const id = requestAnimationFrame(() => { + const container = scrollContainerRef.current; + const target = citedRef.current; + if (!container || !target) return; + const containerRect = container.getBoundingClientRect(); + const targetRect = target.getBoundingClientRect(); + const offset = targetRect.top - containerRect.top + container.scrollTop; + container.scrollTo({ + top: Math.max(0, offset - 16), + behavior: "smooth", + }); + }); + return () => cancelAnimationFrame(id); + }, [cited]); + + const handleOpenFullDocument = () => { + if (!data) return; + openEditorPanel({ + documentId: data.id, + searchSpaceId: data.search_space_id, + title: data.title, + }); + }; + + return ( + <> +
+
+

Citation

+
+ {onClose && ( + + )} +
+
+
+
+

+ {data?.title ?? (isLoading ? "Loading…" : `Chunk #${chunkId}`)} +

+
+
+ Chunk #{chunkId} + {totalChunks > 0 && · {totalChunks} chunks} +
+
+
+ +
+ {isLoading && ( +
+ + Loading citation… +
+ )} + + {error && ( +

+ {error instanceof Error ? error.message : "Failed to load citation"} +

+ )} + + {!isLoading && !error && data && ( + <> + {hasMoreAbove && ( +

+ … {startIndex} earlier chunk{startIndex === 1 ? "" : "s"} not shown +

+ )} +
+ {data.chunks.map((chunk) => { + const isCited = chunk.id === chunkId; + return ( +
+
+ + {isCited ? "Cited chunk" : `Chunk #${chunk.id}`} + + {isCited && ( + #{chunk.id} + )} +
+
+ +
+
+ ); + })} +
+ {hasMoreBelow && ( +

+ … {totalChunks - (startIndex + data.chunks.length)} later chunk + {totalChunks - (startIndex + data.chunks.length) === 1 ? "" : "s"} not shown +

+ )} + + )} +
+ + {!isLoading && !error && data && ( +
+
+ Showing {shownAbove} above · cited · {shownBelow} below +
+
+ {(hasMoreAbove || hasMoreBelow) && !expanded && ( + + )} + {expanded && ( + + )} + +
+
+ )} + + ); +}; diff --git a/surfsense_web/components/editor-panel/editor-panel.tsx b/surfsense_web/components/editor-panel/editor-panel.tsx index 3b69ae6e0..df138e97e 100644 --- a/surfsense_web/components/editor-panel/editor-panel.tsx +++ b/surfsense_web/components/editor-panel/editor-panel.tsx @@ -136,6 +136,7 @@ export function EditorPanelContent({ const [displayTitle, setDisplayTitle] = useState(title || "Untitled"); const isLocalFileMode = kind === "local_file"; const editorRenderMode: EditorRenderMode = isLocalFileMode ? "source_code" : "rich_markdown"; + const resolveLocalVirtualPath = useCallback( async (candidatePath: string): Promise => { if (!electronAPI?.getAgentFilesystemMounts) { @@ -291,7 +292,7 @@ export function EditorPanelContent({ }, [editorDoc?.source_markdown]); const handleSave = useCallback( - async (_options?: { silent?: boolean }) => { + async (options?: { silent?: boolean }) => { setSaving(true); try { if (isLocalFileMode) { @@ -342,11 +343,15 @@ export function EditorPanelContent({ setEditorDoc((prev) => (prev ? { ...prev, source_markdown: markdownRef.current } : prev)); setEditedMarkdown(null); - toast.success("Document saved! Reindexing in background..."); + if (!options?.silent) { + toast.success("Document saved! Reindexing in background..."); + } return true; } catch (err) { console.error("Error saving document:", err); - toast.error(err instanceof Error ? err.message : "Failed to save document"); + if (!options?.silent) { + toast.error(err instanceof Error ? err.message : "Failed to save document"); + } return false; } finally { setSaving(false); @@ -367,6 +372,11 @@ export function EditorPanelContent({ EDITABLE_DOCUMENT_TYPES.has(editorDoc.document_type ?? "")) && !isLargeDocument : false; + // Render through PlateEditor for editable doc types (FILE/NOTE). + // Everything else (large docs, non-editable types) falls back to the + // lightweight `MarkdownViewer` — Plate is heavy on multi-MB docs and + // non-editable types don't benefit from its editing UX. + const renderInPlateEditor = isEditableType; const hasUnsavedChanges = editedMarkdown !== null; const showDesktopHeader = !!onClose; const showEditingActions = isEditableType && isEditing; @@ -381,6 +391,60 @@ export function EditorPanelContent({ setIsEditing(false); }, [editorDoc?.source_markdown]); + const handleDownloadMarkdown = useCallback(async () => { + if (!searchSpaceId || !documentId) return; + setDownloading(true); + try { + const response = await authenticatedFetch( + `${process.env.NEXT_PUBLIC_FASTAPI_BACKEND_URL}/api/v1/search-spaces/${searchSpaceId}/documents/${documentId}/download-markdown`, + { method: "GET" } + ); + if (!response.ok) throw new Error("Download failed"); + const blob = await response.blob(); + const url = URL.createObjectURL(blob); + const a = document.createElement("a"); + a.href = url; + const disposition = response.headers.get("content-disposition"); + const match = disposition?.match(/filename="(.+)"/); + a.download = match?.[1] ?? `${editorDoc?.title || "document"}.md`; + document.body.appendChild(a); + a.click(); + a.remove(); + URL.revokeObjectURL(url); + toast.success("Download started"); + } catch { + toast.error("Failed to download document"); + } finally { + setDownloading(false); + } + }, [documentId, editorDoc?.title, searchSpaceId]); + + const largeDocAlert = isLargeDocument && !isLocalFileMode && editorDoc && ( + + + + + This document is too large for the editor ( + {Math.round((editorDoc.content_size_bytes ?? 0) / 1024 / 1024)}MB,{" "} + {editorDoc.chunk_count ?? 0} chunks). Showing a preview below. + + + + + ); + return ( <> {showDesktopHeader ? ( @@ -565,61 +629,6 @@ export function EditorPanelContent({

- ) : isLargeDocument && !isLocalFileMode ? ( -
- - - - - This document is too large for the editor ( - {Math.round((editorDoc.content_size_bytes ?? 0) / 1024 / 1024)}MB,{" "} - {editorDoc.chunk_count ?? 0} chunks). Showing a preview below. - - - - - -
) : editorRenderMode === "source_code" ? (
- ) : isEditableType ? ( - + ) : isLargeDocument && !isLocalFileMode ? ( + // Large doc — fast Streamdown preview + download CTA. + // Plate is heavy on multi-MB docs. +
+ {largeDocAlert} + +
+ ) : renderInPlateEditor ? ( + // Editable doc (FILE/NOTE) — Plate editing UX. +
+
+ +
+
) : (
diff --git a/surfsense_web/components/editor/plate-editor.tsx b/surfsense_web/components/editor/plate-editor.tsx index 481a420fb..7f12d3cae 100644 --- a/surfsense_web/components/editor/plate-editor.tsx +++ b/surfsense_web/components/editor/plate-editor.tsx @@ -12,6 +12,9 @@ import { type EditorPreset, presetMap } from "@/components/editor/presets"; import { escapeMdxExpressions } from "@/components/editor/utils/escape-mdx"; import { Editor, EditorContainer } from "@/components/ui/editor"; +/** Live editor instance returned by `usePlateEditor`. */ +export type PlateEditorInstance = ReturnType; + export interface PlateEditorProps { /** Markdown string to load as initial content */ markdown?: string; diff --git a/surfsense_web/components/layout/providers/LayoutDataProvider.tsx b/surfsense_web/components/layout/providers/LayoutDataProvider.tsx index aecf55a27..3efdab03b 100644 --- a/surfsense_web/components/layout/providers/LayoutDataProvider.tsx +++ b/surfsense_web/components/layout/providers/LayoutDataProvider.tsx @@ -28,6 +28,7 @@ import { import { currentUserAtom } from "@/atoms/user/user-query.atoms"; import { SearchSpaceSettingsDialog } from "@/components/settings/search-space-settings-dialog"; import { TeamDialog } from "@/components/settings/team-dialog"; +import { ActionLogSheet } from "@/components/agent-action-log/action-log-sheet"; import { UserSettingsDialog } from "@/components/settings/user-settings-dialog"; import { AlertDialog, @@ -909,6 +910,9 @@ export function LayoutDataProvider({ searchSpaceId, children }: LayoutDataProvid + + {/* Agent action log + revert sheet */} + ); } diff --git a/surfsense_web/components/layout/ui/header/Header.tsx b/surfsense_web/components/layout/ui/header/Header.tsx index ec54cb901..f49d7fb88 100644 --- a/surfsense_web/components/layout/ui/header/Header.tsx +++ b/surfsense_web/components/layout/ui/header/Header.tsx @@ -5,6 +5,7 @@ import { usePathname } from "next/navigation"; import { currentThreadAtom } from "@/atoms/chat/current-thread.atom"; import { activeSearchSpaceIdAtom } from "@/atoms/search-spaces/search-space-query.atoms"; import { activeTabAtom, tabsAtom } from "@/atoms/tabs/tabs.atom"; +import { ActionLogButton } from "@/components/agent-action-log/action-log-button"; import { ChatHeader } from "@/components/new-chat/chat-header"; import { ChatShareButton } from "@/components/new-chat/chat-share-button"; import { useIsMobile } from "@/hooks/use-mobile"; @@ -69,6 +70,7 @@ export function Header({ mobileMenuTrigger }: HeaderProps) { {/* Right side - Actions */}
+ {hasThread && } {hasThread && ( )} diff --git a/surfsense_web/components/layout/ui/right-panel/RightPanel.tsx b/surfsense_web/components/layout/ui/right-panel/RightPanel.tsx index 04bae010c..3481eec28 100644 --- a/surfsense_web/components/layout/ui/right-panel/RightPanel.tsx +++ b/surfsense_web/components/layout/ui/right-panel/RightPanel.tsx @@ -6,6 +6,7 @@ import dynamic from "next/dynamic"; import { startTransition, useEffect } from "react"; import { closeHitlEditPanelAtom, hitlEditPanelAtom } from "@/atoms/chat/hitl-edit-panel.atom"; import { closeReportPanelAtom, reportPanelAtom } from "@/atoms/chat/report-panel.atom"; +import { citationPanelAtom, closeCitationPanelAtom } from "@/atoms/citation/citation-panel.atom"; import { documentsSidebarOpenAtom } from "@/atoms/documents/ui.atoms"; import { closeEditorPanelAtom, editorPanelAtom } from "@/atoms/editor/editor-panel.atom"; import { rightPanelCollapsedAtom, rightPanelTabAtom } from "@/atoms/layout/right-panel.atom"; @@ -21,6 +22,14 @@ const EditorPanelContent = dynamic( { ssr: false, loading: () => null } ); +const CitationPanelContent = dynamic( + () => + import("@/components/citation-panel/citation-panel").then((m) => ({ + default: m.CitationPanelContent, + })), + { ssr: false, loading: () => null } +); + const HitlEditPanelContent = dynamic( () => import("@/components/hitl-edit-panel/hitl-edit-panel").then((m) => ({ @@ -69,12 +78,14 @@ export function RightPanelExpandButton() { const reportState = useAtomValue(reportPanelAtom); const editorState = useAtomValue(editorPanelAtom); const hitlEditState = useAtomValue(hitlEditPanelAtom); + const citationState = useAtomValue(citationPanelAtom); const reportOpen = reportState.isOpen && !!reportState.reportId; const editorOpen = editorState.isOpen && (editorState.kind === "document" ? !!editorState.documentId : !!editorState.localFilePath); const hitlEditOpen = hitlEditState.isOpen && !!hitlEditState.onSave; - const hasContent = documentsOpen || reportOpen || editorOpen || hitlEditOpen; + const citationOpen = citationState.isOpen && citationState.chunkId != null; + const hasContent = documentsOpen || reportOpen || editorOpen || hitlEditOpen || citationOpen; if (!collapsed || !hasContent) return null; @@ -98,7 +109,13 @@ export function RightPanelExpandButton() { ); } -const PANEL_WIDTHS = { sources: 420, report: 640, editor: 640, "hitl-edit": 640 } as const; +const PANEL_WIDTHS = { + sources: 420, + report: 640, + editor: 640, + "hitl-edit": 640, + citation: 560, +} as const; export function RightPanel({ documentsPanel }: RightPanelProps) { const [activeTab] = useAtom(rightPanelTabAtom); @@ -108,6 +125,8 @@ export function RightPanel({ documentsPanel }: RightPanelProps) { const closeEditor = useSetAtom(closeEditorPanelAtom); const hitlEditState = useAtomValue(hitlEditPanelAtom); const closeHitlEdit = useSetAtom(closeHitlEditPanelAtom); + const citationState = useAtomValue(citationPanelAtom); + const closeCitation = useSetAtom(closeCitationPanelAtom); const [collapsed, setCollapsed] = useAtom(rightPanelCollapsedAtom); const documentsOpen = documentsPanel?.open ?? false; @@ -116,37 +135,59 @@ export function RightPanel({ documentsPanel }: RightPanelProps) { editorState.isOpen && (editorState.kind === "document" ? !!editorState.documentId : !!editorState.localFilePath); const hitlEditOpen = hitlEditState.isOpen && !!hitlEditState.onSave; + const citationOpen = citationState.isOpen && citationState.chunkId != null; useEffect(() => { - if (!reportOpen && !editorOpen && !hitlEditOpen) return; + if (!reportOpen && !editorOpen && !hitlEditOpen && !citationOpen) return; const handleKeyDown = (e: KeyboardEvent) => { if (e.key === "Escape") { if (hitlEditOpen) closeHitlEdit(); + else if (citationOpen) closeCitation(); else if (editorOpen) closeEditor(); else if (reportOpen) closeReport(); } }; document.addEventListener("keydown", handleKeyDown); return () => document.removeEventListener("keydown", handleKeyDown); - }, [reportOpen, editorOpen, hitlEditOpen, closeReport, closeEditor, closeHitlEdit]); + }, [ + reportOpen, + editorOpen, + hitlEditOpen, + citationOpen, + closeReport, + closeEditor, + closeHitlEdit, + closeCitation, + ]); - const isVisible = (documentsOpen || reportOpen || editorOpen || hitlEditOpen) && !collapsed; + const isVisible = + (documentsOpen || reportOpen || editorOpen || hitlEditOpen || citationOpen) && !collapsed; let effectiveTab = activeTab; if (effectiveTab === "hitl-edit" && !hitlEditOpen) { - effectiveTab = editorOpen ? "editor" : reportOpen ? "report" : "sources"; - } else if (effectiveTab === "editor" && !editorOpen) { - effectiveTab = reportOpen ? "report" : "sources"; - } else if (effectiveTab === "report" && !reportOpen) { - effectiveTab = editorOpen ? "editor" : "sources"; - } else if (effectiveTab === "sources" && !documentsOpen) { - effectiveTab = hitlEditOpen - ? "hitl-edit" + effectiveTab = citationOpen + ? "citation" : editorOpen ? "editor" : reportOpen ? "report" : "sources"; + } else if (effectiveTab === "citation" && !citationOpen) { + effectiveTab = editorOpen ? "editor" : reportOpen ? "report" : "sources"; + } else if (effectiveTab === "editor" && !editorOpen) { + effectiveTab = citationOpen ? "citation" : reportOpen ? "report" : "sources"; + } else if (effectiveTab === "report" && !reportOpen) { + effectiveTab = citationOpen ? "citation" : editorOpen ? "editor" : "sources"; + } else if (effectiveTab === "sources" && !documentsOpen) { + effectiveTab = hitlEditOpen + ? "hitl-edit" + : citationOpen + ? "citation" + : editorOpen + ? "editor" + : reportOpen + ? "report" + : "sources"; } const targetWidth = PANEL_WIDTHS[effectiveTab]; @@ -205,6 +246,11 @@ export function RightPanel({ documentsPanel }: RightPanelProps) { />
)} + {effectiveTab === "citation" && citationOpen && citationState.chunkId != null && ( +
+ +
+ )}
); diff --git a/surfsense_web/components/layout/ui/sidebar/DesktopLocalTabContent.tsx b/surfsense_web/components/layout/ui/sidebar/DesktopLocalTabContent.tsx index dd7520d24..cd8fca331 100644 --- a/surfsense_web/components/layout/ui/sidebar/DesktopLocalTabContent.tsx +++ b/surfsense_web/components/layout/ui/sidebar/DesktopLocalTabContent.tsx @@ -1,11 +1,9 @@ "use client"; -import { Folder, FolderPlus, Search, X } from "lucide-react"; import { useAtom } from "jotai"; +import { Folder, FolderPlus, Search, X } from "lucide-react"; import { useCallback, useMemo, useRef, useState } from "react"; import { localExpandedFolderKeysAtom } from "@/atoms/documents/folder.atoms"; -import { Input } from "@/components/ui/input"; -import { Separator } from "@/components/ui/separator"; import { DropdownMenu, DropdownMenuContent, @@ -14,6 +12,8 @@ import { DropdownMenuSeparator, DropdownMenuTrigger, } from "@/components/ui/dropdown-menu"; +import { Input } from "@/components/ui/input"; +import { Separator } from "@/components/ui/separator"; import { Tooltip, TooltipContent, TooltipTrigger } from "@/components/ui/tooltip"; import { useDebouncedValue } from "@/hooks/use-debounced-value"; import { LocalFilesystemBrowser } from "./LocalFilesystemBrowser"; diff --git a/surfsense_web/components/layout/ui/tabs/DocumentTabContent.tsx b/surfsense_web/components/layout/ui/tabs/DocumentTabContent.tsx index 77668a93d..ac5463873 100644 --- a/surfsense_web/components/layout/ui/tabs/DocumentTabContent.tsx +++ b/surfsense_web/components/layout/ui/tabs/DocumentTabContent.tsx @@ -308,9 +308,7 @@ export function DocumentTabContent({ documentId, searchSpaceId, title }: Documen } }} > - + Download .md diff --git a/surfsense_web/components/markdown-viewer.tsx b/surfsense_web/components/markdown-viewer.tsx index 5775fe083..c4d73e30b 100644 --- a/surfsense_web/components/markdown-viewer.tsx +++ b/surfsense_web/components/markdown-viewer.tsx @@ -10,7 +10,11 @@ const code = createCodePlugin({ }); const math = createMathPlugin({ - singleDollarTextMath: true, + // Disabled so currency like "$3,120.00 and ... $0.00" isn't parsed as + // inline LaTeX. convertLatexDelimiters() below normalises any genuine + // inline math (\(...\), $...$ starting with a LaTeX command, etc.) to + // $$...$$, so this flip doesn't lose any math rendering. + singleDollarTextMath: false, }); interface MarkdownViewerProps { diff --git a/surfsense_web/components/new-chat/model-selector.tsx b/surfsense_web/components/new-chat/model-selector.tsx index 3f5a5fa8c..9fe9dd8da 100644 --- a/surfsense_web/components/new-chat/model-selector.tsx +++ b/surfsense_web/components/new-chat/model-selector.tsx @@ -8,9 +8,9 @@ import { ChevronLeft, ChevronRight, ChevronUp, - Pencil, ImageIcon, Layers, + Pencil, Plus, ScanEye, Search, @@ -741,9 +741,7 @@ export function ModelSelector({
{!isMobile && ( @@ -769,9 +767,7 @@ export function ModelSelector({
diff --git a/surfsense_web/components/new-chat/source-detail-panel.tsx b/surfsense_web/components/new-chat/source-detail-panel.tsx deleted file mode 100644 index aded206c7..000000000 --- a/surfsense_web/components/new-chat/source-detail-panel.tsx +++ /dev/null @@ -1,719 +0,0 @@ -"use client"; - -import { useQuery } from "@tanstack/react-query"; -import { - BookOpen, - ChevronDown, - ChevronUp, - ExternalLink, - FileQuestionMark, - FileText, - Hash, - Loader2, - Sparkles, - X, -} from "lucide-react"; -import { AnimatePresence, motion, useReducedMotion } from "motion/react"; -import { useTranslations } from "next-intl"; -import type React from "react"; -import { forwardRef, memo, type ReactNode, useCallback, useEffect, useRef, useState } from "react"; -import { createPortal } from "react-dom"; -import { MarkdownViewer } from "@/components/markdown-viewer"; -import { Badge } from "@/components/ui/badge"; -import { Button } from "@/components/ui/button"; -import { ScrollArea } from "@/components/ui/scroll-area"; -import { Spinner } from "@/components/ui/spinner"; -import type { - GetDocumentByChunkResponse, - GetSurfsenseDocsByChunkResponse, -} from "@/contracts/types/document.types"; -import { documentsApiService } from "@/lib/apis/documents-api.service"; -import { cacheKeys } from "@/lib/query-client/cache-keys"; -import { cn } from "@/lib/utils"; - -type DocumentData = GetDocumentByChunkResponse | GetSurfsenseDocsByChunkResponse; - -interface SourceDetailPanelProps { - open: boolean; - onOpenChange: (open: boolean) => void; - chunkId: number; - sourceType: string; - title: string; - description?: string; - url?: string; - children?: ReactNode; - isDocsChunk?: boolean; -} - -const formatDocumentType = (type: string) => { - if (!type) return ""; - return type - .split("_") - .map((word) => word.charAt(0) + word.slice(1).toLowerCase()) - .join(" "); -}; - -// Chunk card component -// For large documents (>30 chunks), we disable animation to prevent layout shifts -// which break auto-scroll functionality -interface ChunkCardProps { - chunk: { id: number; content: string }; - localIndex: number; - chunkNumber: number; - totalChunks: number; - isCited: boolean; - isActive: boolean; - disableLayoutAnimation?: boolean; -} - -const ChunkCard = memo( - forwardRef( - ({ chunk, localIndex, chunkNumber, totalChunks, isCited }, ref) => { - return ( -
- {isCited &&
} - -
-
-
- {chunkNumber} -
- - Chunk {chunkNumber} of {totalChunks} - -
- {isCited && ( - - - Cited Source - - )} -
- -
- -
-
- ); - } - ) -); -ChunkCard.displayName = "ChunkCard"; - -export function SourceDetailPanel({ - open, - onOpenChange, - chunkId, - sourceType, - title, - description, - url, - children, - isDocsChunk = false, -}: SourceDetailPanelProps) { - const t = useTranslations("dashboard"); - const scrollAreaRef = useRef(null); - const hasScrolledRef = useRef(false); // Use ref to avoid stale closures - const scrollTimersRef = useRef[]>([]); - const [activeChunkIndex, setActiveChunkIndex] = useState(null); - const [mounted, setMounted] = useState(false); - const shouldReduceMotion = useReducedMotion(); - - useEffect(() => { - setMounted(true); - }, []); - - const { - data: documentData, - isLoading: isDocumentByChunkFetching, - error: documentByChunkFetchingError, - } = useQuery({ - queryKey: isDocsChunk - ? cacheKeys.documents.byChunk(`doc-${chunkId}`) - : cacheKeys.documents.byChunk(chunkId.toString()), - queryFn: async () => { - if (isDocsChunk) { - return documentsApiService.getSurfsenseDocByChunk(chunkId); - } - return documentsApiService.getDocumentByChunk({ chunk_id: chunkId, chunk_window: 5 }); - }, - enabled: !!chunkId && open, - staleTime: 5 * 60 * 1000, - }); - - const totalChunks = - documentData && "total_chunks" in documentData - ? (documentData.total_chunks ?? documentData.chunks.length) - : (documentData?.chunks?.length ?? 0); - const [beforeChunks, setBeforeChunks] = useState< - Array<{ id: number; content: string; created_at: string }> - >([]); - const [afterChunks, setAfterChunks] = useState< - Array<{ id: number; content: string; created_at: string }> - >([]); - const [loadingBefore, setLoadingBefore] = useState(false); - const [loadingAfter, setLoadingAfter] = useState(false); - - useEffect(() => { - setBeforeChunks([]); - setAfterChunks([]); - }, [chunkId, open]); - - const chunkStartIndex = - documentData && "chunk_start_index" in documentData ? (documentData.chunk_start_index ?? 0) : 0; - const initialChunks = documentData?.chunks ?? []; - const allChunks = [...beforeChunks, ...initialChunks, ...afterChunks]; - const absoluteStart = chunkStartIndex - beforeChunks.length; - const absoluteEnd = chunkStartIndex + initialChunks.length + afterChunks.length; - const canLoadBefore = absoluteStart > 0; - const canLoadAfter = absoluteEnd < totalChunks; - - const EXPAND_SIZE = 10; - - const loadBefore = useCallback(async () => { - if (!documentData || !("search_space_id" in documentData) || !canLoadBefore) return; - setLoadingBefore(true); - try { - const count = Math.min(EXPAND_SIZE, absoluteStart); - const result = await documentsApiService.getDocumentChunks({ - document_id: documentData.id, - page: 0, - page_size: count, - start_offset: absoluteStart - count, - }); - const existingIds = new Set(allChunks.map((c) => c.id)); - const newChunks = result.items - .filter((c) => !existingIds.has(c.id)) - .map((c) => ({ id: c.id, content: c.content, created_at: c.created_at })); - setBeforeChunks((prev) => [...newChunks, ...prev]); - } catch (err) { - console.error("Failed to load earlier chunks:", err); - } finally { - setLoadingBefore(false); - } - }, [documentData, absoluteStart, canLoadBefore, allChunks]); - - const loadAfter = useCallback(async () => { - if (!documentData || !("search_space_id" in documentData) || !canLoadAfter) return; - setLoadingAfter(true); - try { - const result = await documentsApiService.getDocumentChunks({ - document_id: documentData.id, - page: 0, - page_size: EXPAND_SIZE, - start_offset: absoluteEnd, - }); - const existingIds = new Set(allChunks.map((c) => c.id)); - const newChunks = result.items - .filter((c) => !existingIds.has(c.id)) - .map((c) => ({ id: c.id, content: c.content, created_at: c.created_at })); - setAfterChunks((prev) => [...prev, ...newChunks]); - } catch (err) { - console.error("Failed to load later chunks:", err); - } finally { - setLoadingAfter(false); - } - }, [documentData, absoluteEnd, canLoadAfter, allChunks]); - - const isDirectRenderSource = - sourceType === "TAVILY_API" || - sourceType === "LINKUP_API" || - sourceType === "SEARXNG_API" || - sourceType === "BAIDU_SEARCH_API"; - - const citedChunkIndex = allChunks.findIndex((chunk) => chunk.id === chunkId); - - // Simple scroll function that scrolls to a chunk by index - const scrollToChunkByIndex = useCallback( - (chunkIndex: number, smooth = true) => { - const scrollContainer = scrollAreaRef.current; - if (!scrollContainer) return; - - const viewport = scrollContainer.querySelector( - "[data-radix-scroll-area-viewport]" - ) as HTMLElement | null; - if (!viewport) return; - - const chunkElement = scrollContainer.querySelector( - `[data-chunk-index="${chunkIndex}"]` - ) as HTMLElement | null; - if (!chunkElement) return; - - // Get positions using getBoundingClientRect for accuracy - const viewportRect = viewport.getBoundingClientRect(); - const chunkRect = chunkElement.getBoundingClientRect(); - - // Calculate where to scroll to center the chunk - const currentScrollTop = viewport.scrollTop; - const chunkTopRelativeToViewport = chunkRect.top - viewportRect.top + currentScrollTop; - const scrollTarget = - chunkTopRelativeToViewport - viewportRect.height / 2 + chunkRect.height / 2; - - viewport.scrollTo({ - top: Math.max(0, scrollTarget), - behavior: smooth && !shouldReduceMotion ? "smooth" : "auto", - }); - - setActiveChunkIndex(chunkIndex); - }, - [shouldReduceMotion] - ); - - // Callback ref for the cited chunk - scrolls when the element mounts - const citedChunkRefCallback = useCallback( - (node: HTMLDivElement | null) => { - if (node && !hasScrolledRef.current && open) { - hasScrolledRef.current = true; // Mark immediately to prevent duplicate scrolls - - // Store the node reference for the delayed scroll - const scrollToCitedChunk = () => { - const scrollContainer = scrollAreaRef.current; - if (!scrollContainer || !node.isConnected) return false; - - const viewport = scrollContainer.querySelector( - "[data-radix-scroll-area-viewport]" - ) as HTMLElement | null; - if (!viewport) return false; - - // Get positions - const viewportRect = viewport.getBoundingClientRect(); - const chunkRect = node.getBoundingClientRect(); - - // Calculate scroll position to center the chunk - const currentScrollTop = viewport.scrollTop; - const chunkTopRelativeToViewport = chunkRect.top - viewportRect.top + currentScrollTop; - const scrollTarget = - chunkTopRelativeToViewport - viewportRect.height / 2 + chunkRect.height / 2; - - viewport.scrollTo({ - top: Math.max(0, scrollTarget), - behavior: "auto", // Instant scroll for initial positioning - }); - - return true; - }; - - // Scroll multiple times with delays to handle progressive content rendering - // Each subsequent scroll will correct for any layout shifts - const scrollAttempts = [50, 150, 300, 600, 1000]; - - scrollAttempts.forEach((delay) => { - scrollTimersRef.current.push( - setTimeout(() => { - scrollToCitedChunk(); - }, delay) - ); - }); - - // After final attempt, mark the cited chunk as active - scrollTimersRef.current.push( - setTimeout( - () => { - setActiveChunkIndex(citedChunkIndex); - }, - scrollAttempts[scrollAttempts.length - 1] + 50 - ) - ); - } - }, - [open, citedChunkIndex] - ); - - // Reset scroll state when panel closes - useEffect(() => { - if (!open) { - scrollTimersRef.current.forEach(clearTimeout); - scrollTimersRef.current = []; - hasScrolledRef.current = false; - setActiveChunkIndex(null); - } - return () => { - scrollTimersRef.current.forEach(clearTimeout); - scrollTimersRef.current = []; - }; - }, [open]); - - // Handle escape key - useEffect(() => { - const handleEscape = (e: KeyboardEvent) => { - if (e.key === "Escape" && open) { - onOpenChange(false); - } - }; - window.addEventListener("keydown", handleEscape); - return () => window.removeEventListener("keydown", handleEscape); - }, [open, onOpenChange]); - - // Prevent body scroll when open - useEffect(() => { - if (open) { - document.body.style.overflow = "hidden"; - } else { - document.body.style.overflow = ""; - } - return () => { - document.body.style.overflow = ""; - }; - }, [open]); - - const handleUrlClick = (e: React.MouseEvent, clickUrl: string) => { - e.preventDefault(); - e.stopPropagation(); - window.open(clickUrl, "_blank", "noopener,noreferrer"); - }; - - const scrollToChunk = useCallback( - (index: number) => { - scrollToChunkByIndex(index, true); - }, - [scrollToChunkByIndex] - ); - - const panelContent = ( - - {open && ( - <> - {/* Backdrop */} - onOpenChange(false)} - /> - - {/* Panel */} - - {/* Header */} - -
-

- {documentData?.title || title || "Source Document"} -

-

- {documentData && "document_type" in documentData - ? formatDocumentType(documentData.document_type) - : sourceType && formatDocumentType(sourceType)} - {totalChunks > 0 && ( - - • {totalChunks} chunk{totalChunks !== 1 ? "s" : ""} - {allChunks.length < totalChunks && ` (showing ${allChunks.length})`} - - )} -

-
-
- {url && ( - - )} - -
-
- - {/* Loading State */} - {!isDirectRenderSource && isDocumentByChunkFetching && ( -
- - -

- {t("loading_document")} -

-
-
- )} - - {/* Error State */} - {!isDirectRenderSource && documentByChunkFetchingError && ( -
- -
- -
-
-

Document unavailable

-

- {documentByChunkFetchingError.message || - "An unexpected error occurred. Please try again."} -

-
- -
-
- )} - - {/* Direct render for web search providers */} - {isDirectRenderSource && ( - -
- {url && ( - - )} - -

- - Source Information -

-
- {title || "Untitled"} -
-
- {description || "No content available"} -
-
-
-
- )} - - {/* API-fetched document content */} - {!isDirectRenderSource && documentData && ( -
- {/* Chunk Navigation Sidebar */} - {allChunks.length > 1 && ( - - -
- {allChunks.map((chunk, idx) => { - const absNum = absoluteStart + idx + 1; - const isCited = chunk.id === chunkId; - const isActive = activeChunkIndex === idx; - return ( - scrollToChunk(idx)} - initial={{ opacity: 0, scale: 0.8 }} - animate={{ opacity: 1, scale: 1 }} - transition={{ delay: Math.min(idx * 0.02, 0.2) }} - className={cn( - "relative w-11 h-9 mx-auto rounded-lg text-xs font-semibold transition-all duration-200 flex items-center justify-center", - isCited - ? "bg-primary text-primary-foreground shadow-md" - : isActive - ? "bg-muted text-foreground" - : "bg-muted/50 text-muted-foreground hover:bg-muted hover:text-foreground" - )} - title={isCited ? `Chunk ${absNum} (Cited)` : `Chunk ${absNum}`} - > - {absNum} - {isCited && ( - - - - )} - - ); - })} -
-
-
- )} - - {/* Main Content */} - -
- {/* Document Metadata */} - {"document_metadata" in documentData && - documentData.document_metadata && - Object.keys(documentData.document_metadata).length > 0 && ( - -

- - Document Information -

-
- {Object.entries(documentData.document_metadata).map(([key, value]) => ( -
-
- {key.replace(/_/g, " ")} -
-
{String(value)}
-
- ))} -
-
- )} - - {/* Chunks Header */} -
-

- - Chunks {absoluteStart + 1}–{absoluteEnd} of {totalChunks} -

- {citedChunkIndex !== -1 && ( - - )} -
- - {/* Load Earlier */} - {canLoadBefore && ( -
- -
- )} - - {/* Chunks */} -
- {allChunks.map((chunk, idx) => { - const isCited = chunk.id === chunkId; - const chunkNumber = absoluteStart + idx + 1; - return ( - 30} - /> - ); - })} -
- - {/* Load Later */} - {canLoadAfter && ( -
- -
- )} -
-
-
- )} -
- - )} -
- ); - - if (!mounted) return <>{children}; - - return ( - <> - {children} - {createPortal(panelContent, globalThis.document.body)} - - ); -} diff --git a/surfsense_web/components/report-panel/report-panel.tsx b/surfsense_web/components/report-panel/report-panel.tsx index ede63d902..621cf13ce 100644 --- a/surfsense_web/components/report-panel/report-panel.tsx +++ b/surfsense_web/components/report-panel/report-panel.tsx @@ -398,7 +398,8 @@ export function ReportPanelContent({ ); - const editingActions = showReportEditingTier && + const editingActions = + showReportEditingTier && !isReadOnly && (isEditing ? ( <> diff --git a/surfsense_web/components/settings/agent-model-manager.tsx b/surfsense_web/components/settings/agent-model-manager.tsx index 988befdd0..a0b700c2d 100644 --- a/surfsense_web/components/settings/agent-model-manager.tsx +++ b/surfsense_web/components/settings/agent-model-manager.tsx @@ -1,15 +1,7 @@ "use client"; import { useAtomValue } from "jotai"; -import { - AlertCircle, - Dot, - FileText, - Info, - Pencil, - RefreshCw, - Trash2, -} from "lucide-react"; +import { AlertCircle, Dot, FileText, Info, Pencil, RefreshCw, Trash2 } from "lucide-react"; import { useMemo, useState } from "react"; import { membersAtom, myAccessAtom } from "@/atoms/members/members-query.atoms"; import { deleteNewLLMConfigMutationAtom } from "@/atoms/new-llm-config/new-llm-config-mutation.atoms"; diff --git a/surfsense_web/components/settings/roles-manager.tsx b/surfsense_web/components/settings/roles-manager.tsx index e7dadc20f..335cfc8a9 100644 --- a/surfsense_web/components/settings/roles-manager.tsx +++ b/surfsense_web/components/settings/roles-manager.tsx @@ -5,10 +5,8 @@ import { useAtomValue } from "jotai"; import { Bot, ChevronRight, - ScanEye, - Pencil, - FileText, Earth, + FileText, Image, Logs, type LucideIcon, @@ -16,11 +14,13 @@ import { MessageSquare, Mic, MoreHorizontal, - Unplug, + Pencil, + ScanEye, Settings, Shield, SlidersHorizontal, Trash2, + Unplug, Users, Video, } from "lucide-react"; @@ -462,9 +462,19 @@ function RolesContent({ return (
+ {/* biome-ignore lint/a11y/useSemanticElements: row contains nested interactive elements (DropdownMenu); using a + +
+ )} + + + ); +} + +export const DoomLoopApprovalToolUI: ToolCallMessagePartComponent = ({ + toolName, + args, + result, +}) => { + const { dispatch } = useHitlDecision(); + + if (!result || !isInterruptResult(result)) return null; + + return ( + } + interruptData={result} + onDecision={(decision) => dispatch([decision])} + /> + ); +}; + +export function isDoomLoopInterrupt(result: unknown): boolean { + if (!isInterruptResult(result)) return false; + const ctx = (result.context ?? {}) as Record; + return ctx.permission === "doom_loop"; +} diff --git a/surfsense_web/lib/apis/agent-actions-api.service.ts b/surfsense_web/lib/apis/agent-actions-api.service.ts new file mode 100644 index 000000000..007bb131e --- /dev/null +++ b/surfsense_web/lib/apis/agent-actions-api.service.ts @@ -0,0 +1,64 @@ +import { z } from "zod"; +import { baseApiService } from "./base-api.service"; + +const AgentActionReadSchema = z.object({ + id: z.number(), + thread_id: z.number(), + user_id: z.string().nullable(), + search_space_id: z.number(), + tool_name: z.string(), + args: z.record(z.string(), z.unknown()).nullable(), + result_id: z.string().nullable(), + reversible: z.boolean(), + reverse_descriptor: z.record(z.string(), z.unknown()).nullable(), + error: z.record(z.string(), z.unknown()).nullable(), + reverse_of: z.number().nullable(), + reverted_by_action_id: z.number().nullable(), + is_revert_action: z.boolean(), + created_at: z.string(), +}); + +export type AgentAction = z.infer; + +const AgentActionListResponseSchema = z.object({ + items: z.array(AgentActionReadSchema), + total: z.number(), + page: z.number(), + page_size: z.number(), + has_more: z.boolean(), +}); + +export type AgentActionListResponse = z.infer; + +const RevertResponseSchema = z.object({ + status: z.literal("ok"), + message: z.string(), + new_action_id: z.number().nullable().optional(), +}); + +export type RevertResponse = z.infer; + +class AgentActionsApiService { + listForThread = async ( + threadId: number, + opts: { page?: number; pageSize?: number } = {} + ): Promise => { + const params = new URLSearchParams(); + params.set("page", String(opts.page ?? 0)); + params.set("page_size", String(opts.pageSize ?? 50)); + return baseApiService.get( + `/api/v1/threads/${threadId}/actions?${params.toString()}`, + AgentActionListResponseSchema + ); + }; + + revert = async (threadId: number, actionId: number): Promise => { + return baseApiService.post( + `/api/v1/threads/${threadId}/revert/${actionId}`, + RevertResponseSchema, + { body: {} } + ); + }; +} + +export const agentActionsApiService = new AgentActionsApiService(); diff --git a/surfsense_web/lib/apis/agent-flags-api.service.ts b/surfsense_web/lib/apis/agent-flags-api.service.ts new file mode 100644 index 000000000..87332ca9f --- /dev/null +++ b/surfsense_web/lib/apis/agent-flags-api.service.ts @@ -0,0 +1,40 @@ +import { z } from "zod"; +import { baseApiService } from "./base-api.service"; + +const AgentFeatureFlagsSchema = z.object({ + disable_new_agent_stack: z.boolean(), + + enable_context_editing: z.boolean(), + enable_compaction_v2: z.boolean(), + enable_retry_after: z.boolean(), + enable_model_fallback: z.boolean(), + enable_model_call_limit: z.boolean(), + enable_tool_call_limit: z.boolean(), + enable_tool_call_repair: z.boolean(), + enable_doom_loop: z.boolean(), + + enable_permission: z.boolean(), + enable_busy_mutex: z.boolean(), + enable_llm_tool_selector: z.boolean(), + + enable_skills: z.boolean(), + enable_specialized_subagents: z.boolean(), + enable_kb_planner_runnable: z.boolean(), + + enable_action_log: z.boolean(), + enable_revert_route: z.boolean(), + + enable_plugin_loader: z.boolean(), + + enable_otel: z.boolean(), +}); + +export type AgentFeatureFlags = z.infer; + +class AgentFlagsApiService { + get = async (): Promise => { + return baseApiService.get(`/api/v1/agent/flags`, AgentFeatureFlagsSchema); + }; +} + +export const agentFlagsApiService = new AgentFlagsApiService(); diff --git a/surfsense_web/lib/apis/agent-permissions-api.service.ts b/surfsense_web/lib/apis/agent-permissions-api.service.ts new file mode 100644 index 000000000..6927c55d0 --- /dev/null +++ b/surfsense_web/lib/apis/agent-permissions-api.service.ts @@ -0,0 +1,90 @@ +import { z } from "zod"; +import { ValidationError } from "@/lib/error"; +import { baseApiService } from "./base-api.service"; + +const ActionEnum = z.enum(["allow", "deny", "ask"]); +export type AgentPermissionAction = z.infer; + +const AgentPermissionRuleSchema = z.object({ + id: z.number(), + search_space_id: z.number(), + user_id: z.string().nullable(), + thread_id: z.number().nullable(), + permission: z.string(), + pattern: z.string(), + action: ActionEnum, + created_at: z.string(), +}); + +export type AgentPermissionRule = z.infer; + +const AgentPermissionRuleListSchema = z.array(AgentPermissionRuleSchema); + +const AgentPermissionRuleCreateSchema = z.object({ + permission: z + .string() + .min(1, "Permission is required") + .max(255) + .regex(/^[a-zA-Z0-9_:.\-*]+$/, "Use letters, digits, '.', '_', ':', '-', or '*' wildcards."), + pattern: z.string().min(1).max(255).default("*"), + action: ActionEnum, + user_id: z.string().nullable().optional(), + thread_id: z.number().nullable().optional(), +}); + +export type AgentPermissionRuleCreate = z.infer; + +const AgentPermissionRuleUpdateSchema = z.object({ + pattern: z.string().min(1).max(255).optional(), + action: ActionEnum.optional(), +}); + +export type AgentPermissionRuleUpdate = z.infer; + +class AgentPermissionsApiService { + list = async (searchSpaceId: number): Promise => { + return baseApiService.get( + `/api/v1/searchspaces/${searchSpaceId}/agent/permissions/rules`, + AgentPermissionRuleListSchema + ); + }; + + create = async ( + searchSpaceId: number, + payload: AgentPermissionRuleCreate + ): Promise => { + const parsed = AgentPermissionRuleCreateSchema.safeParse(payload); + if (!parsed.success) { + throw new ValidationError(parsed.error.issues.map((i) => i.message).join(", ")); + } + return baseApiService.post( + `/api/v1/searchspaces/${searchSpaceId}/agent/permissions/rules`, + AgentPermissionRuleSchema, + { body: parsed.data } + ); + }; + + update = async ( + searchSpaceId: number, + ruleId: number, + payload: AgentPermissionRuleUpdate + ): Promise => { + const parsed = AgentPermissionRuleUpdateSchema.safeParse(payload); + if (!parsed.success) { + throw new ValidationError(parsed.error.issues.map((i) => i.message).join(", ")); + } + return baseApiService.patch( + `/api/v1/searchspaces/${searchSpaceId}/agent/permissions/rules/${ruleId}`, + AgentPermissionRuleSchema, + { body: parsed.data } + ); + }; + + remove = async (searchSpaceId: number, ruleId: number): Promise => { + await baseApiService.delete( + `/api/v1/searchspaces/${searchSpaceId}/agent/permissions/rules/${ruleId}` + ); + }; +} + +export const agentPermissionsApiService = new AgentPermissionsApiService();